Integrate sessions middlewares

This commit is contained in:
Joel Speed 2020-07-18 00:42:51 +01:00
parent 034f057b60
commit eb234011eb
No known key found for this signature in database
GPG Key ID: 6E80578D6751DEFB
4 changed files with 53 additions and 309 deletions

View File

@ -15,7 +15,9 @@ import (
"time"
"github.com/coreos/go-oidc"
"github.com/justinas/alice"
ipapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/ip"
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/middleware"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/pkg/authentication/basic"
@ -23,6 +25,7 @@ import (
"github.com/oauth2-proxy/oauth2-proxy/pkg/encryption"
"github.com/oauth2-proxy/oauth2-proxy/pkg/ip"
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/pkg/middleware"
"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions"
"github.com/oauth2-proxy/oauth2-proxy/pkg/upstream"
"github.com/oauth2-proxy/oauth2-proxy/providers"
@ -98,6 +101,8 @@ type OAuthProxy struct {
trustedIPs *ip.NetSet
Banner string
Footer string
sessionChain alice.Chain
}
// NewOAuthProxy creates a new instance of OAuthProxy from the options provided
@ -156,6 +161,8 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
}
}
sessionChain := buildSessionChain(opts, sessionStore, basicAuthValidator)
return &OAuthProxy{
CookieName: opts.Cookie.Name,
CSRFCookieName: fmt.Sprintf("%v_%v", opts.Cookie.Name, "csrf"),
@ -209,9 +216,45 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
basicAuthValidator: basicAuthValidator,
displayHtpasswdForm: basicAuthValidator != nil,
sessionChain: sessionChain,
}, nil
}
func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionStore, validator basic.Validator) alice.Chain {
chain := alice.New(middleware.NewScope())
if opts.SkipJwtBearerTokens {
sessionLoaders := []middlewareapi.TokenToSessionLoader{}
if opts.GetOIDCVerifier() != nil {
sessionLoaders = append(sessionLoaders, middlewareapi.TokenToSessionLoader{
Verifier: opts.GetOIDCVerifier(),
TokenToSession: opts.GetProvider().CreateSessionStateFromBearerToken,
})
}
for _, verifier := range opts.GetJWTBearerVerifiers() {
sessionLoaders = append(sessionLoaders, middlewareapi.TokenToSessionLoader{
Verifier: verifier,
})
}
chain = chain.Append(middleware.NewJwtSessionLoader(sessionLoaders))
}
if validator != nil {
chain = chain.Append(middleware.NewBasicAuthSessionLoader(validator))
}
chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{
SessionStore: sessionStore,
RefreshPeriod: opts.Cookie.Refresh,
RefreshSessionIfNeeded: opts.GetProvider().RefreshSessionIfNeeded,
ValidateSessionState: opts.GetProvider().ValidateSessionState,
}))
return chain
}
// GetRedirectURI returns the redirectURL that the upstream OAuth Provider will
// redirect clients to once authenticated
func (p *OAuthProxy) GetRedirectURI(host string) string {
@ -780,86 +823,20 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
// Set-Cookie headers may be set on the response as a side-effect of calling this method.
func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.Request) (*sessionsapi.SessionState, error) {
var session *sessionsapi.SessionState
var err error
var saveSession, clearSession, revalidated bool
if p.skipJwtBearerTokens && req.Header.Get("Authorization") != "" {
session, err = p.GetJwtSession(req)
if err != nil {
logger.Printf("Error retrieving session from token in Authorization header: %s", err)
}
if session != nil {
saveSession = false
}
}
getSession := p.sessionChain.Then(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
session = middleware.GetRequestScope(req).Session
}))
getSession.ServeHTTP(rw, req)
remoteAddr := ip.GetClientString(p.realClientIPParser, req, true)
if session == nil {
session, err = p.LoadCookiedSession(req)
if err != nil {
logger.Printf("Error loading cookied session: %s", err)
}
if session != nil {
if session.Age() > p.CookieRefresh && p.CookieRefresh != time.Duration(0) {
logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, p.CookieRefresh)
saveSession = true
}
if ok, err := p.provider.RefreshSessionIfNeeded(req.Context(), session); err != nil {
logger.Printf("%s removing session. error refreshing access token %s %s", remoteAddr, err, session)
clearSession = true
session = nil
} else if ok {
saveSession = true
revalidated = true
}
}
}
if session != nil && session.IsExpired() {
logger.Printf("Removing session: token expired %s", session)
session = nil
saveSession = false
clearSession = true
}
if saveSession && !revalidated && session != nil && session.AccessToken != "" {
if !p.provider.ValidateSessionState(req.Context(), session) {
logger.Printf("Removing session: error validating %s", session)
saveSession = false
session = nil
clearSession = true
}
return nil, ErrNeedsLogin
}
if session != nil && session.Email != "" && !p.Validator(session.Email) {
logger.Printf(session.Email, req, logger.AuthFailure, "Invalid authentication via session: removing session %s", session)
session = nil
saveSession = false
clearSession = true
}
if saveSession && session != nil {
err = p.SaveSession(rw, req, session)
if err != nil {
logger.PrintAuthf(session.Email, req, logger.AuthError, "Save session error %s", err)
return nil, err
}
}
if clearSession {
// Invalid session, clear it
p.ClearSessionCookie(rw, req)
}
if session == nil {
session, err = p.CheckBasicAuth(req)
if err != nil {
logger.Printf("Error during basic auth validation: %s", err)
}
}
if session == nil {
return nil, ErrNeedsLogin
}
@ -997,36 +974,6 @@ func (p *OAuthProxy) stripAuthHeaders(req *http.Request) {
}
}
// CheckBasicAuth checks the requests Authorization header for basic auth
// credentials and authenticates these against the proxies HtpasswdFile
func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessionsapi.SessionState, error) {
if p.basicAuthValidator == nil {
return nil, nil
}
auth := req.Header.Get("Authorization")
if auth == "" {
return nil, nil
}
s := strings.SplitN(auth, " ", 2)
if len(s) != 2 || s[0] != "Basic" {
return nil, fmt.Errorf("invalid Authorization header %s", req.Header.Get("Authorization"))
}
b, err := b64.StdEncoding.DecodeString(s[1])
if err != nil {
return nil, err
}
pair := strings.SplitN(string(b), ":", 2)
if len(pair) != 2 {
return nil, fmt.Errorf("invalid format %s", b)
}
if p.basicAuthValidator.Validate(pair[0], pair[1]) {
logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File")
return &sessionsapi.SessionState{User: pair[0]}, nil
}
logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File")
return nil, nil
}
// isAjax checks if a request is an ajax request
func isAjax(req *http.Request) bool {
acceptValues := req.Header.Values("Accept")
@ -1044,74 +991,3 @@ func (p *OAuthProxy) ErrorJSON(rw http.ResponseWriter, code int) {
rw.Header().Set("Content-Type", applicationJSON)
rw.WriteHeader(code)
}
// GetJwtSession loads a session based on a JWT token in the authorization header.
// (see the config options skip-jwt-bearer-tokens and extra-jwt-issuers)
func (p *OAuthProxy) GetJwtSession(req *http.Request) (*sessionsapi.SessionState, error) {
rawBearerToken, err := p.findBearerToken(req)
if err != nil {
return nil, err
}
// If we are using an oidc provider, go ahead and try that provider first with its Verifier
// and Bearer Token -> Session converter
if p.mainJwtBearerVerifier != nil {
bearerToken, err := p.mainJwtBearerVerifier.Verify(req.Context(), rawBearerToken)
if err == nil {
return p.provider.CreateSessionStateFromBearerToken(req.Context(), rawBearerToken, bearerToken)
}
}
// Otherwise, attempt to verify against the extra JWT issuers and use a more generic
// Bearer Token -> Session converter
for _, verifier := range p.extraJwtBearerVerifiers {
bearerToken, err := verifier.Verify(req.Context(), rawBearerToken)
if err != nil {
continue
}
return (*providers.ProviderData)(nil).CreateSessionStateFromBearerToken(req.Context(), rawBearerToken, bearerToken)
}
return nil, fmt.Errorf("unable to verify jwt token %s", req.Header.Get("Authorization"))
}
// findBearerToken finds a valid JWT token from the Authorization header of a given request.
func (p *OAuthProxy) findBearerToken(req *http.Request) (string, error) {
auth := req.Header.Get("Authorization")
s := strings.SplitN(auth, " ", 2)
if len(s) != 2 {
return "", fmt.Errorf("invalid authorization header %s", auth)
}
jwtRegex := regexp.MustCompile(`^eyJ[a-zA-Z0-9_-]*\.eyJ[a-zA-Z0-9_-]*\.[a-zA-Z0-9_-]+$`)
var rawBearerToken string
if s[0] == "Bearer" && jwtRegex.MatchString(s[1]) {
rawBearerToken = s[1]
} else if s[0] == "Basic" {
// Check if we have a Bearer token masquerading in Basic
b, err := b64.StdEncoding.DecodeString(s[1])
if err != nil {
return "", err
}
pair := strings.SplitN(string(b), ":", 2)
if len(pair) != 2 {
return "", fmt.Errorf("invalid format %s", b)
}
user, password := pair[0], pair[1]
// check user, user+password, or just password for a token
if jwtRegex.MatchString(user) {
// Support blank passwords or magic `x-oauth-basic` passwords - nothing else
if password == "" || password == "x-oauth-basic" {
rawBearerToken = user
}
} else if jwtRegex.MatchString(password) {
// support passwords and ignore user
rawBearerToken = password
}
}
if rawBearerToken == "" {
return "", fmt.Errorf("no valid bearer token found in authorization header")
}
return rawBearerToken, nil
}

View File

@ -1889,7 +1889,7 @@ func TestGetJwtSession(t *testing.T) {
// Bearer
expires := time.Unix(1912151821, 0)
session, err := test.proxy.GetJwtSession(test.req)
session, err := test.proxy.getAuthenticatedSession(test.rw, test.req)
assert.NoError(t, err)
assert.Equal(t, session.User, "1234567890")
assert.Equal(t, session.Email, "john@example.com")
@ -1912,70 +1912,6 @@ func TestGetJwtSession(t *testing.T) {
assert.Equal(t, test.rw.Header().Get("X-Auth-Request-Email"), "john@example.com")
}
func TestFindJwtBearerToken(t *testing.T) {
p := OAuthProxy{CookieName: "oauth2", CookieDomains: []string{"abc"}}
getReq := &http.Request{URL: &url.URL{Scheme: "http", Host: "example.com"}}
validToken := "eyJfoobar.eyJfoobar.12345asdf"
var token string
// Bearer
getReq.Header = map[string][]string{
"Authorization": {fmt.Sprintf("Bearer %s", validToken)},
}
token, err := p.findBearerToken(getReq)
assert.NoError(t, err)
assert.Equal(t, validToken, token)
// Basic - no password
getReq.SetBasicAuth(token, "")
token, err = p.findBearerToken(getReq)
assert.NoError(t, err)
assert.Equal(t, validToken, token)
// Basic - sentinel password
getReq.SetBasicAuth(token, "x-oauth-basic")
token, err = p.findBearerToken(getReq)
assert.NoError(t, err)
assert.Equal(t, validToken, token)
// Basic - any username, password matching jwt pattern
getReq.SetBasicAuth("any-username-you-could-wish-for", token)
token, err = p.findBearerToken(getReq)
assert.NoError(t, err)
assert.Equal(t, validToken, token)
failures := []string{
// Too many parts
"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.dGVzdA.dGVzdA.dGVzdA",
// Not enough parts
"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.dGVzdA",
// Invalid encrypted key
"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.//////.dGVzdA.dGVzdA.dGVzdA",
// Invalid IV
"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.//////.dGVzdA.dGVzdA",
// Invalid ciphertext
"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.//////.dGVzdA",
// Invalid tag
"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.dGVzdA.//////",
// Invalid header
"W10.dGVzdA.dGVzdA.dGVzdA.dGVzdA",
// Invalid header
"######.dGVzdA.dGVzdA.dGVzdA.dGVzdA",
// Missing alc/enc params
"e30.dGVzdA.dGVzdA.dGVzdA.dGVzdA",
}
for _, failure := range failures {
getReq.Header = map[string][]string{
"Authorization": {fmt.Sprintf("Bearer %s", failure)},
}
_, err := p.findBearerToken(getReq)
assert.Error(t, err)
}
}
func Test_prepareNoCache(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
prepareNoCache(w)

View File

@ -126,35 +126,8 @@ func (p *ProviderData) RefreshSessionIfNeeded(ctx context.Context, s *sessions.S
return false, nil
}
// CreateSessionStateFromBearerToken should be implemented to allow providers
// to convert ID tokens into sessions
func (p *ProviderData) CreateSessionStateFromBearerToken(ctx context.Context, rawIDToken string, idToken *oidc.IDToken) (*sessions.SessionState, error) {
var claims struct {
Subject string `json:"sub"`
Email string `json:"email"`
Verified *bool `json:"email_verified"`
PreferredUsername string `json:"preferred_username"`
}
if err := idToken.Claims(&claims); err != nil {
return nil, fmt.Errorf("failed to parse bearer token claims: %v", err)
}
if claims.Email == "" {
claims.Email = claims.Subject
}
if claims.Verified != nil && !*claims.Verified {
return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email)
}
newSession := &sessions.SessionState{
Email: claims.Email,
User: claims.Subject,
PreferredUsername: claims.PreferredUsername,
AccessToken: rawIDToken,
IDToken: rawIDToken,
RefreshToken: "",
ExpiresOn: &idToken.Expiry,
}
return newSession, nil
return nil, errors.New("not implemented")
}

View File

@ -2,15 +2,10 @@ package providers
import (
"context"
"crypto/rand"
"crypto/rsa"
"net/url"
"testing"
"time"
"github.com/coreos/go-oidc"
"github.com/dgrijalva/jwt-go"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
"github.com/stretchr/testify/assert"
)
@ -52,39 +47,3 @@ func TestAcrValuesConfigured(t *testing.T) {
result := p.GetLoginURL("https://my.test.app/oauth", "")
assert.Contains(t, result, "acr_values=testValue")
}
func TestCreateSessionStateFromBearerToken(t *testing.T) {
minimalIDToken := jwt.StandardClaims{
Audience: "asdf1234",
ExpiresAt: time.Now().Add(time.Duration(5) * time.Minute).Unix(),
Id: "id-some-id",
IssuedAt: time.Now().Unix(),
Issuer: "https://issuer.example.com",
NotBefore: 0,
Subject: "123456789",
}
// From oidc_test.go
verifier := oidc.NewVerifier(
"https://issuer.example.com",
fakeKeySetStub{},
&oidc.Config{ClientID: "asdf1234"},
)
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
rawIDToken, err := jwt.NewWithClaims(jwt.SigningMethodRS256, minimalIDToken).SignedString(key)
assert.NoError(t, err)
// Pass to a dummy Verifier to get an oidc.IDToken from the rawIDToken for our actual test below
idToken, err := verifier.Verify(context.Background(), rawIDToken)
assert.NoError(t, err)
session, err := (*ProviderData)(nil).CreateSessionStateFromBearerToken(context.Background(), rawIDToken, idToken)
assert.NoError(t, err)
assert.Equal(t, rawIDToken, session.AccessToken)
assert.Equal(t, rawIDToken, session.IDToken)
assert.Equal(t, "123456789", session.Email)
assert.Equal(t, "123456789", session.User)
assert.Empty(t, session.RefreshToken)
assert.Empty(t, session.PreferredUsername)
}