Support nonce checks in OIDC Provider (#967)

* Set and verify a nonce with OIDC

* Create a CSRF object to manage nonces & cookies

* Add missing generic cookie unit tests

* Add config flag to control OIDC SkipNonce

* Send hashed nonces in authentication requests

* Encrypt the CSRF cookie

* Add clarity to naming & add more helper methods

* Make CSRF an interface and keep underlying nonces private

* Add ReverseProxy scope to cookie tests

* Align to new 1.16 SameSite cookie default

* Perform SecretBytes conversion on CSRF cookie crypto

* Make state encoding signatures consistent

* Mock time in CSRF struct via Clock

* Improve InsecureSkipNonce docstring
This commit is contained in:
Nick Meves 2021-04-21 02:33:27 -07:00 committed by GitHub
parent d3423408c7
commit 7eeaea0b3f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 860 additions and 170 deletions

View File

@ -4,10 +4,14 @@
## Important Notes
- [#967](https://github.com/oauth2-proxy/oauth2-proxy/pull/967) `--insecure-oidc-skip-nonce` is currently `true` by default in case
any existing OIDC Identity Providers don't support it. The default will switch to `false` in a future version.
## Breaking Changes
## Changes since v7.1.2
- [#967](https://github.com/oauth2-proxy/oauth2-proxy/pull/967) Set & verify a nonce with OIDC providers (@NickMeves)
- [#1136](https://github.com/oauth2-proxy/oauth2-proxy/pull/1136) Add clock package for better time mocking in tests (@NickMeves)
- [#947](https://github.com/oauth2-proxy/oauth2-proxy/pull/947) Multiple provider ingestion and validation in alpha options (first stage: [#926](https://github.com/oauth2-proxy/oauth2-proxy/issues/926)) (@yanasega)

View File

@ -264,6 +264,7 @@ make up the header value
| `issuerURL` | _string_ | IssuerURL is the OpenID Connect issuer URL<br/>eg: https://accounts.google.com |
| `insecureAllowUnverifiedEmail` | _bool_ | InsecureAllowUnverifiedEmail prevents failures if an email address in an id_token is not verified<br/>default set to 'false' |
| `insecureSkipIssuerVerification` | _bool_ | InsecureSkipIssuerVerification skips verification of ID token issuers. When false, ID Token Issuers must match the OIDC discovery URL<br/>default set to 'false' |
| `insecureSkipNonce` | _bool_ | InsecureSkipNonce skips verifying the ID Token's nonce claim that must match<br/>the random nonce sent in the initial OAuth flow. Otherwise, the nonce is checked<br/>after the initial OAuth redeem & subsequent token refreshes.<br/>default set to 'true'<br/>Warning: In a future release, this will change to 'false' by default for enhanced security. |
| `skipDiscovery` | _bool_ | SkipDiscovery allows to skip OIDC discovery and use manually supplied Endpoints<br/>default set to 'false' |
| `jwksURL` | _string_ | JwksURL is the OpenID Connect JWKS URL<br/>eg: https://www.googleapis.com/oauth2/v3/certs |
| `emailClaim` | _string_ | EmailClaim indicates which claim contains the user email,<br/>default set to 'email' |

View File

@ -75,6 +75,7 @@ An example [oauth2-proxy.cfg](https://github.com/oauth2-proxy/oauth2-proxy/blob/
| `--login-url` | string | Authentication endpoint | |
| `--insecure-oidc-allow-unverified-email` | bool | don't fail if an email address in an id_token is not verified | false |
| `--insecure-oidc-skip-issuer-verification` | bool | allow the OIDC issuer URL to differ from the expected (currently required for Azure multi-tenant compatibility) | false |
| `--insecure-oidc-skip-nonce` | bool | skip verifying the OIDC ID Token's nonce claim | true |
| `--oidc-issuer-url` | string | the OpenID Connect issuer URL, e.g. `"https://accounts.google.com"` | |
| `--oidc-jwks-url` | string | OIDC JWKS URI for token verification; required if OIDC discovery is disabled | |
| `--oidc-email-claim` | string | which OIDC claim contains the user's email | `"email"` |

View File

@ -71,6 +71,7 @@ providers:
groupsClaim: groups
emailClaim: email
userIDClaim: email
insecureSkipNonce: true
`
const testCoreConfig = `
@ -138,9 +139,10 @@ redirect_url="http://localhost:4180/oauth2/callback"
Tenant: "common",
},
OIDCConfig: options.OIDCOptions{
GroupsClaim: "groups",
EmailClaim: "email",
UserIDClaim: "email",
GroupsClaim: "groups",
EmailClaim: "email",
UserIDClaim: "email",
InsecureSkipNonce: true,
},
ApprovalPrompt: "force",
},
@ -228,7 +230,7 @@ redirect_url="http://localhost:4180/oauth2/callback"
configContent: testCoreConfig,
alphaConfigContent: testAlphaConfig + ":",
expectedOptions: func() *options.Options { return nil },
expectedErr: errors.New("failed to load alpha options: error unmarshalling config: error converting YAML to JSON: yaml: line 48: did not find expected key"),
expectedErr: errors.New("failed to load alpha options: error unmarshalling config: error converting YAML to JSON: yaml: line 49: did not find expected key"),
}),
Entry("with alpha configuration and bad core configuration", loadConfigurationTableInput{
configContent: testCoreConfig + "unknown_field=\"something\"",

View File

@ -23,8 +23,8 @@ import (
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/app/pagewriter"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/cookies"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
proxyhttp "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/http"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
@ -60,14 +60,8 @@ type allowedRoute struct {
// OAuthProxy is the main authentication proxy
type OAuthProxy struct {
CSRFCookieName string
CookieDomains []string
CookiePath string
CookieSecure bool
CookieHTTPOnly bool
CookieExpire time.Duration
CookieSameSite string
Validator func(string) bool
CookieOptions *options.Cookie
Validator func(string) bool
RobotsPath string
SignInPath string
@ -179,14 +173,8 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
}
p := &OAuthProxy{
CSRFCookieName: fmt.Sprintf("%v_%v", opts.Cookie.Name, "csrf"),
CookieDomains: opts.Cookie.Domains,
CookiePath: opts.Cookie.Path,
CookieSecure: opts.Cookie.Secure,
CookieHTTPOnly: opts.Cookie.HTTPOnly,
CookieExpire: opts.Cookie.Expire,
CookieSameSite: opts.Cookie.SameSite,
Validator: validator,
CookieOptions: &opts.Cookie,
Validator: validator,
RobotsPath: "/robots.txt",
SignInPath: fmt.Sprintf("%s/sign_in", opts.ProxyPrefix),
@ -427,47 +415,6 @@ func buildRoutesAllowlist(opts *options.Options) ([]allowedRoute, error) {
return routes, nil
}
// MakeCSRFCookie creates a cookie for CSRF
func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
return p.makeCookie(req, p.CSRFCookieName, value, expiration, now)
}
func (p *OAuthProxy) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie {
cookieDomain := cookies.GetCookieDomain(req, p.CookieDomains)
if cookieDomain != "" {
domain := requestutil.GetRequestHost(req)
if h, _, err := net.SplitHostPort(domain); err == nil {
domain = h
}
if !strings.HasSuffix(domain, cookieDomain) {
logger.Errorf("Warning: request host is %q but using configured cookie domain of %q", domain, cookieDomain)
}
}
return &http.Cookie{
Name: name,
Value: value,
Path: p.CookiePath,
Domain: cookieDomain,
HttpOnly: p.CookieHTTPOnly,
Secure: p.CookieSecure,
Expires: now.Add(expiration),
SameSite: cookies.ParseSameSite(p.CookieSameSite),
}
}
// ClearCSRFCookie creates a cookie to unset the CSRF cookie stored in the user's
// session
func (p *OAuthProxy) ClearCSRFCookie(rw http.ResponseWriter, req *http.Request) {
http.SetCookie(rw, p.MakeCSRFCookie(req, "", time.Hour*-1, time.Now()))
}
// SetCSRFCookie adds a CSRF cookie to the response
func (p *OAuthProxy) SetCSRFCookie(rw http.ResponseWriter, req *http.Request, val string) {
http.SetCookie(rw, p.MakeCSRFCookie(req, val, p.CookieExpire, time.Now()))
}
// ClearSessionCookie creates a cookie to unset the user's authentication cookie
// stored in the user's session
func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) error {
@ -744,21 +691,35 @@ func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) {
// OAuthStart starts the OAuth2 authentication flow
func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
prepareNoCache(rw)
nonce, err := encryption.Nonce()
csrf, err := cookies.NewCSRF(p.CookieOptions)
if err != nil {
logger.Errorf("Error obtaining nonce: %v", err)
logger.Errorf("Error creating CSRF nonce: %v", err)
p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error())
return
}
p.SetCSRFCookie(rw, req, nonce)
redirect, err := p.getAppRedirect(req)
appRedirect, err := p.getAppRedirect(req)
if err != nil {
logger.Errorf("Error obtaining redirect: %v", err)
logger.Errorf("Error obtaining application redirect: %v", err)
p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error())
return
}
redirectURI := p.getOAuthRedirectURI(req)
http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), http.StatusFound)
callbackRedirect := p.getOAuthRedirectURI(req)
loginURL := p.provider.GetLoginURL(
callbackRedirect,
encodeState(csrf.HashOAuthState(), appRedirect),
csrf.HashOIDCNonce(),
)
if _, err := csrf.SetCookie(rw, req); err != nil {
logger.Errorf("Error setting CSRF cookie: %v", err)
p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error())
return
}
http.Redirect(rw, req, loginURL, http.StatusFound)
}
// OAuthCallback is the OAuth2 authentication flow callback that finishes the
@ -796,29 +757,33 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
return
}
state := strings.SplitN(req.Form.Get("state"), ":", 2)
if len(state) != 2 {
logger.Error("Error while parsing OAuth2 state: invalid length")
p.ErrorPage(rw, req, http.StatusInternalServerError, "State paremeter did not have expected length", "Login Failed: Invalid State after login.")
return
}
nonce := state[0]
redirect := state[1]
c, err := req.Cookie(p.CSRFCookieName)
csrf, err := cookies.LoadCSRFCookie(req, p.CookieOptions)
if err != nil {
logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via OAuth2: unable to obtain CSRF cookie")
p.ErrorPage(rw, req, http.StatusForbidden, err.Error(), "Login Failed: Unable to find a valid CSRF token. Please try again.")
return
}
p.ClearCSRFCookie(rw, req)
if c.Value != nonce {
csrf.ClearCookie(rw, req)
nonce, appRedirect, err := decodeState(req)
if err != nil {
logger.Errorf("Error while parsing OAuth2 state: %v", err)
p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error())
return
}
if !csrf.CheckOAuthState(nonce) {
logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via OAuth2: CSRF token mismatch, potential attack")
p.ErrorPage(rw, req, http.StatusForbidden, "CSRF token mismatch, potential attack", "Login Failed: Unable to find a valid CSRF token. Please try again.")
return
}
if !p.IsValidRedirect(redirect) {
redirect = "/"
csrf.SetSessionNonce(session)
p.provider.ValidateSession(req.Context(), session)
if !p.IsValidRedirect(appRedirect) {
appRedirect = "/"
}
// set cookie, or deny
@ -834,7 +799,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error())
return
}
http.Redirect(rw, req, redirect, http.StatusFound)
http.Redirect(rw, req, appRedirect, http.StatusFound)
} else {
logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via OAuth2: unauthorized")
p.ErrorPage(rw, req, http.StatusForbidden, "Invalid session: unauthorized")
@ -966,7 +931,7 @@ func (p *OAuthProxy) getOAuthRedirectURI(req *http.Request) string {
// If CookieSecure is true, return `https` no matter what
// Not all reverse proxies set X-Forwarded-Proto
if p.CookieSecure {
if p.CookieOptions.Secure {
rd.Scheme = schemeHTTPS
}
return rd.String()
@ -1207,6 +1172,22 @@ func extractAllowedGroups(req *http.Request) map[string]struct{} {
return groups
}
// encodedState builds the OAuth state param out of our nonce and
// original application redirect
func encodeState(nonce string, redirect string) string {
return fmt.Sprintf("%v:%v", nonce, redirect)
}
// decodeState splits the reflected OAuth state response back into
// the nonce and original application redirect
func decodeState(req *http.Request) (string, string, error) {
state := strings.SplitN(req.Form.Get("state"), ":", 2)
if len(state) != 2 {
return "", "", errors.New("invalid length")
}
return state[0], state[1], nil
}
// addHeadersForProxying adds the appropriate headers the request / response for proxying
func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, session *sessionsapi.SessionState) {
if session.Email == "" {

View File

@ -22,6 +22,7 @@ import (
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/cookies"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
sessionscookie "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions/cookie"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/upstream"
@ -698,23 +699,42 @@ func (patTest *PassAccessTokenTest) Close() {
patTest.providerServer.Close()
}
func (patTest *PassAccessTokenTest) getCallbackEndpoint() (httpCode int,
cookie string) {
func (patTest *PassAccessTokenTest) getCallbackEndpoint() (httpCode int, cookie string) {
rw := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:",
strings.NewReader(""))
csrf, err := cookies.NewCSRF(patTest.proxy.CookieOptions)
if err != nil {
panic(err)
}
req, err := http.NewRequest(
http.MethodGet,
fmt.Sprintf(
"/oauth2/callback?code=callback_code&state=%s",
encodeState(csrf.HashOAuthState(), "%2F"),
),
strings.NewReader(""),
)
if err != nil {
return 0, ""
}
req.AddCookie(patTest.proxy.MakeCSRFCookie(req, "nonce", time.Hour, time.Now()))
// rw is a dummy here, we just want the csrfCookie to add to our req
csrfCookie, err := csrf.SetCookie(httptest.NewRecorder(), req)
if err != nil {
panic(err)
}
req.AddCookie(csrfCookie)
patTest.proxy.ServeHTTP(rw, req)
return rw.Code, rw.Header().Values("Set-Cookie")[1]
}
// getEndpointWithCookie makes a requests againt the oauthproxy with passed requestPath
// and cookie and returns body and status code.
func (patTest *PassAccessTokenTest) getEndpointWithCookie(cookie string, endpoint string) (httpCode int, accessToken string) {
cookieName := patTest.opts.Cookie.Name
cookieName := patTest.proxy.CookieOptions.Name
var value string
keyPrefix := cookieName + "="
@ -983,6 +1003,9 @@ func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifi
}
pcTest.proxy.provider.(*TestProvider).SetAllowedGroups(pcTest.opts.Providers[0].AllowedGroups)
// Now, zero-out proxy.CookieRefresh for the cases that don't involve
// access_token validation.
pcTest.proxy.CookieOptions.Refresh = time.Duration(0)
pcTest.rw = httptest.NewRecorder()
pcTest.req, _ = http.NewRequest("GET", "/", strings.NewReader(""))
pcTest.validateUser = true
@ -1104,6 +1127,7 @@ func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
err = pcTest.SaveSession(startSession)
assert.NoError(t, err)
pcTest.proxy.CookieOptions.Refresh = time.Hour
session, err := pcTest.LoadCookiedSession()
assert.NotEqual(t, nil, err)
if session != nil {
@ -1999,7 +2023,7 @@ func TestClearSplitCookie(t *testing.T) {
t.Fatal(err)
}
p := OAuthProxy{sessionStore: store}
p := OAuthProxy{CookieOptions: &opts.Cookie, sessionStore: store}
var rw = httptest.NewRecorder()
req := httptest.NewRequest("get", "/", nil)
@ -2032,7 +2056,7 @@ func TestClearSingleCookie(t *testing.T) {
t.Fatal(err)
}
p := OAuthProxy{sessionStore: store}
p := OAuthProxy{CookieOptions: &opts.Cookie, sessionStore: store}
var rw = httptest.NewRecorder()
req := httptest.NewRequest("get", "/", nil)

View File

@ -48,12 +48,13 @@ func NewLegacyOptions() *LegacyOptions {
},
LegacyProvider: LegacyProvider{
ProviderType: "google",
AzureTenant: "common",
ApprovalPrompt: "force",
UserIDClaim: "email",
OIDCEmailClaim: "email",
OIDCGroupsClaim: "groups",
ProviderType: "google",
AzureTenant: "common",
ApprovalPrompt: "force",
UserIDClaim: "email",
OIDCEmailClaim: "email",
OIDCGroupsClaim: "groups",
InsecureOIDCSkipNonce: true,
},
Options: *NewOptions(),
@ -492,6 +493,7 @@ type LegacyProvider struct {
OIDCIssuerURL string `flag:"oidc-issuer-url" cfg:"oidc_issuer_url"`
InsecureOIDCAllowUnverifiedEmail bool `flag:"insecure-oidc-allow-unverified-email" cfg:"insecure_oidc_allow_unverified_email"`
InsecureOIDCSkipIssuerVerification bool `flag:"insecure-oidc-skip-issuer-verification" cfg:"insecure_oidc_skip_issuer_verification"`
InsecureOIDCSkipNonce bool `flag:"insecure-oidc-skip-nonce" cfg:"insecure_oidc_skip_nonce"`
SkipOIDCDiscovery bool `flag:"skip-oidc-discovery" cfg:"skip_oidc_discovery"`
OIDCJwksURL string `flag:"oidc-jwks-url" cfg:"oidc_jwks_url"`
OIDCEmailClaim string `flag:"oidc-email-claim" cfg:"oidc_email_claim"`
@ -540,6 +542,7 @@ func legacyProviderFlagSet() *pflag.FlagSet {
flagSet.String("oidc-issuer-url", "", "OpenID Connect issuer URL (ie: https://accounts.google.com)")
flagSet.Bool("insecure-oidc-allow-unverified-email", false, "Don't fail if an email address in an id_token is not verified")
flagSet.Bool("insecure-oidc-skip-issuer-verification", false, "Do not verify if issuer matches OIDC discovery URL")
flagSet.Bool("insecure-oidc-skip-nonce", true, "skip verifying the OIDC ID Token's nonce claim")
flagSet.Bool("skip-oidc-discovery", false, "Skip OIDC discovery and use manually supplied Endpoints")
flagSet.String("oidc-jwks-url", "", "OpenID Connect JWKS URL (ie: https://www.googleapis.com/oauth2/v3/certs)")
flagSet.String("oidc-groups-claim", providers.OIDCGroupsClaim, "which OIDC claim contains the user groups")
@ -630,6 +633,7 @@ func (l *LegacyProvider) convert() (Providers, error) {
IssuerURL: l.OIDCIssuerURL,
InsecureAllowUnverifiedEmail: l.InsecureOIDCAllowUnverifiedEmail,
InsecureSkipIssuerVerification: l.InsecureOIDCSkipIssuerVerification,
InsecureSkipNonce: l.InsecureOIDCSkipNonce,
SkipDiscovery: l.SkipOIDCDiscovery,
JwksURL: l.OIDCJwksURL,
UserIDClaim: l.UserIDClaim,

View File

@ -113,6 +113,7 @@ var _ = Describe("Legacy Options", func() {
opts.Providers[0].ClientID = "oauth-proxy"
opts.Providers[0].ID = "google=oauth-proxy"
opts.Providers[0].OIDCConfig.InsecureSkipNonce = true
converted, err := legacyOpts.ToOptions()
Expect(err).ToNot(HaveOccurred())

View File

@ -36,12 +36,13 @@ var _ = Describe("Load", func() {
},
LegacyProvider: LegacyProvider{
ProviderType: "google",
AzureTenant: "common",
ApprovalPrompt: "force",
UserIDClaim: "email",
OIDCEmailClaim: "email",
OIDCGroupsClaim: "groups",
ProviderType: "google",
AzureTenant: "common",
ApprovalPrompt: "force",
UserIDClaim: "email",
OIDCEmailClaim: "email",
OIDCGroupsClaim: "groups",
InsecureOIDCSkipNonce: true,
},
Options: Options{

View File

@ -132,6 +132,12 @@ type OIDCOptions struct {
// InsecureSkipIssuerVerification skips verification of ID token issuers. When false, ID Token Issuers must match the OIDC discovery URL
// default set to 'false'
InsecureSkipIssuerVerification bool `json:"insecureSkipIssuerVerification,omitempty"`
// InsecureSkipNonce skips verifying the ID Token's nonce claim that must match
// the random nonce sent in the initial OAuth flow. Otherwise, the nonce is checked
// after the initial OAuth redeem & subsequent token refreshes.
// default set to 'true'
// Warning: In a future release, this will change to 'false' by default for enhanced security.
InsecureSkipNonce bool `json:"insecureSkipNonce,omitempty"`
// SkipDiscovery allows to skip OIDC discovery and use manually supplied Endpoints
// default set to 'false'
SkipDiscovery bool `json:"skipDiscovery,omitempty"`
@ -169,6 +175,7 @@ func providerDefaults() Providers {
},
OIDCConfig: OIDCOptions{
InsecureAllowUnverifiedEmail: false,
InsecureSkipNonce: true,
SkipDiscovery: false,
UserIDClaim: providers.OIDCEmailClaim, // Deprecated: Use OIDCEmailClaim
EmailClaim: providers.OIDCEmailClaim,

View File

@ -24,6 +24,8 @@ type SessionState struct {
IDToken string `msgpack:"it,omitempty"`
RefreshToken string `msgpack:"rt,omitempty"`
Nonce []byte `msgpack:"n,omitempty"`
Email string `msgpack:"e,omitempty"`
User string `msgpack:"u,omitempty"`
Groups []string `msgpack:"g,omitempty"`
@ -100,6 +102,11 @@ func (s *SessionState) GetClaim(claim string) []string {
}
}
// CheckNonce compares the Nonce against a potential hash of it
func (s *SessionState) CheckNonce(hashed string) bool {
return encryption.CheckNonce(s.Nonce, hashed)
}
// EncodeSessionState returns an encrypted, lz4 compressed, MessagePack encoded session
func (s *SessionState) EncodeSessionState(c encryption.Cipher, compress bool) ([]byte, error) {
packed, err := msgpack.Marshal(s)

View File

@ -153,6 +153,7 @@ func TestEncodeAndDecodeSessionState(t *testing.T) {
CreatedAt: &created,
ExpiresOn: &expires,
RefreshToken: "RefreshToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7",
Nonce: []byte("abcdef1234567890abcdef1234567890"),
},
"No ExpiresOn": {
Email: "username@example.com",
@ -162,6 +163,7 @@ func TestEncodeAndDecodeSessionState(t *testing.T) {
IDToken: "IDToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7",
CreatedAt: &created,
RefreshToken: "RefreshToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7",
Nonce: []byte("abcdef1234567890abcdef1234567890"),
},
"No PreferredUsername": {
Email: "username@example.com",
@ -171,6 +173,7 @@ func TestEncodeAndDecodeSessionState(t *testing.T) {
CreatedAt: &created,
ExpiresOn: &expires,
RefreshToken: "RefreshToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7",
Nonce: []byte("abcdef1234567890abcdef1234567890"),
},
"Minimal session": {
User: "username",
@ -194,6 +197,7 @@ func TestEncodeAndDecodeSessionState(t *testing.T) {
CreatedAt: &created,
ExpiresOn: &expires,
RefreshToken: "RefreshToken.12349871293847fdsaihf9238h4f91h8fr.1349f831y98fd7",
Nonce: []byte("abcdef1234567890abcdef1234567890"),
Groups: []string{"group-a", "group-b"},
},
}

View File

@ -12,46 +12,33 @@ import (
requestutil "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util"
)
// MakeCookie constructs a cookie from the given parameters,
// discovering the domain from the request if not specified.
func MakeCookie(req *http.Request, name string, value string, path string, domain string, httpOnly bool, secure bool, expiration time.Duration, now time.Time, sameSite http.SameSite) *http.Cookie {
if domain != "" {
host := requestutil.GetRequestHost(req)
if h, _, err := net.SplitHostPort(host); err == nil {
host = h
}
if !strings.HasSuffix(host, domain) {
logger.Errorf("Warning: request host is %q but using configured cookie domain of %q", host, domain)
}
}
return &http.Cookie{
Name: name,
Value: value,
Path: path,
Domain: domain,
HttpOnly: httpOnly,
Secure: secure,
Expires: now.Add(expiration),
SameSite: sameSite,
}
}
// MakeCookieFromOptions constructs a cookie based on the given *options.CookieOptions,
// value and creation time
func MakeCookieFromOptions(req *http.Request, name string, value string, cookieOpts *options.Cookie, expiration time.Duration, now time.Time) *http.Cookie {
domain := GetCookieDomain(req, cookieOpts.Domains)
if domain != "" {
return MakeCookie(req, name, value, cookieOpts.Path, domain, cookieOpts.HTTPOnly, cookieOpts.Secure, expiration, now, ParseSameSite(cookieOpts.SameSite))
}
func MakeCookieFromOptions(req *http.Request, name string, value string, opts *options.Cookie, expiration time.Duration, now time.Time) *http.Cookie {
domain := GetCookieDomain(req, opts.Domains)
// If nothing matches, create the cookie with the shortest domain
defaultDomain := ""
if len(cookieOpts.Domains) > 0 {
logger.Errorf("Warning: request host %q did not match any of the specific cookie domains of %q", requestutil.GetRequestHost(req), strings.Join(cookieOpts.Domains, ","))
defaultDomain = cookieOpts.Domains[len(cookieOpts.Domains)-1]
if domain == "" && len(opts.Domains) > 0 {
logger.Errorf("Warning: request host %q did not match any of the specific cookie domains of %q",
requestutil.GetRequestHost(req),
strings.Join(opts.Domains, ","),
)
domain = opts.Domains[len(opts.Domains)-1]
}
return MakeCookie(req, name, value, cookieOpts.Path, defaultDomain, cookieOpts.HTTPOnly, cookieOpts.Secure, expiration, now, ParseSameSite(cookieOpts.SameSite))
c := &http.Cookie{
Name: name,
Value: value,
Path: opts.Path,
Domain: domain,
Expires: now.Add(expiration),
HttpOnly: opts.HTTPOnly,
Secure: opts.Secure,
SameSite: ParseSameSite(opts.SameSite),
}
warnInvalidDomain(c, req)
return c
}
// GetCookieDomain returns the correct cookie domain given a list of domains
@ -81,3 +68,19 @@ func ParseSameSite(v string) http.SameSite {
panic(fmt.Sprintf("Invalid value for SameSite: %s", v))
}
}
// warnInvalidDomain logs a warning if the request host and cookie domain are
// mismatched.
func warnInvalidDomain(c *http.Cookie, req *http.Request) {
if c.Domain == "" {
return
}
host := requestutil.GetRequestHost(req)
if h, _, err := net.SplitHostPort(host); err == nil {
host = h
}
if !strings.HasSuffix(host, c.Domain) {
logger.Errorf("Warning: request host is %q but using configured cookie domain of %q", host, c.Domain)
}
}

View File

@ -0,0 +1,35 @@
package cookies
import (
"net/http"
"testing"
"time"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
const (
csrfState = "1234asdf1234asdf1234asdf"
csrfNonce = "0987lkjh0987lkjh0987lkjh"
cookieName = "cookie_test_12345"
cookieSecret = "3q48hmFH30FJ2HfJF0239UFJCVcl3kj3"
cookieDomain = "o2p.cookies.test"
cookiePath = "/cookie-tests"
nowEpoch = 1609366421
)
func TestProviderSuite(t *testing.T) {
logger.SetOutput(GinkgoWriter)
RegisterFailHandler(Fail)
RunSpecs(t, "Cookies")
}
func testCookieExpires(exp time.Time) string {
var buf [len(http.TimeFormat)]byte
return string(exp.UTC().AppendFormat(buf[:0], http.TimeFormat))
}

View File

@ -0,0 +1,79 @@
package cookies
import (
"fmt"
"net/http"
middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
. "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/extensions/table"
. "github.com/onsi/gomega"
)
var _ = Describe("Cookie Tests", func() {
Context("GetCookieDomain", func() {
type getCookieDomainTableInput struct {
host string
xForwardedHost string
cookieDomains []string
expectedOutput string
}
DescribeTable("should return expected results",
func(in getCookieDomainTableInput) {
req, err := http.NewRequest(
http.MethodGet,
fmt.Sprintf("https://%s/%s", in.host, cookiePath),
nil,
)
Expect(err).ToNot(HaveOccurred())
if in.xForwardedHost != "" {
req.Header.Add("X-Forwarded-Host", in.xForwardedHost)
req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{
ReverseProxy: true,
})
}
Expect(GetCookieDomain(req, in.cookieDomains)).To(Equal(in.expectedOutput))
},
Entry("a single exact match for the Host header", getCookieDomainTableInput{
host: "www.cookies.test",
cookieDomains: []string{"www.cookies.test"},
expectedOutput: "www.cookies.test",
}),
Entry("a single exact match for the X-Forwarded-Host header", getCookieDomainTableInput{
host: "backend.cookies.internal",
xForwardedHost: "www.cookies.test",
cookieDomains: []string{"www.cookies.test"},
expectedOutput: "www.cookies.test",
}),
Entry("a single suffix match for the Host header", getCookieDomainTableInput{
host: "www.cookies.test",
cookieDomains: []string{".cookies.test"},
expectedOutput: ".cookies.test",
}),
Entry("a single suffix match for the X-Forwarded-Host header", getCookieDomainTableInput{
host: "backend.cookies.internal",
xForwardedHost: "www.cookies.test",
cookieDomains: []string{".cookies.test"},
expectedOutput: ".cookies.test",
}),
Entry("the first match is used", getCookieDomainTableInput{
host: "www.cookies.test",
cookieDomains: []string{"www.cookies.test", ".cookies.test"},
expectedOutput: "www.cookies.test",
}),
Entry("the only match is used", getCookieDomainTableInput{
host: "www.cookies.test",
cookieDomains: []string{".cookies.wrong", ".cookies.test"},
expectedOutput: ".cookies.test",
}),
Entry("blank is returned for no matches", getCookieDomainTableInput{
host: "www.cookies.test",
cookieDomains: []string{".cookies.wrong", ".cookies.false"},
expectedOutput: "",
}),
)
})
})

199
pkg/cookies/csrf.go Normal file
View File

@ -0,0 +1,199 @@
package cookies
import (
"errors"
"fmt"
"net/http"
"time"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
"github.com/vmihailenco/msgpack/v4"
)
// CSRF manages various nonces stored in the CSRF cookie during the initial
// authentication flows.
type CSRF interface {
HashOAuthState() string
HashOIDCNonce() string
CheckOAuthState(string) bool
CheckOIDCNonce(string) bool
SetSessionNonce(s *sessions.SessionState)
SetCookie(http.ResponseWriter, *http.Request) (*http.Cookie, error)
ClearCookie(http.ResponseWriter, *http.Request)
}
type csrf struct {
// OAuthState holds the OAuth2 state parameter's nonce component set in the
// initial authentication request and mirrored back in the callback
// redirect from the IdP for CSRF protection.
OAuthState []byte `msgpack:"s,omitempty"`
// OIDCNonce holds the OIDC nonce parameter used in the initial authentication
// and then set in all subsequent OIDC ID Tokens as the nonce claim. This
// is used to mitigate replay attacks.
OIDCNonce []byte `msgpack:"n,omitempty"`
cookieOpts *options.Cookie
time clock.Clock
}
// NewCSRF creates a CSRF with random nonces
func NewCSRF(opts *options.Cookie) (CSRF, error) {
state, err := encryption.Nonce()
if err != nil {
return nil, err
}
nonce, err := encryption.Nonce()
if err != nil {
return nil, err
}
return &csrf{
OAuthState: state,
OIDCNonce: nonce,
cookieOpts: opts,
}, nil
}
// LoadCSRFCookie loads a CSRF object from a request's CSRF cookie
func LoadCSRFCookie(req *http.Request, opts *options.Cookie) (CSRF, error) {
cookie, err := req.Cookie(csrfCookieName(opts))
if err != nil {
return nil, err
}
return decodeCSRFCookie(cookie, opts)
}
// HashOAuthState returns the hash of the OAuth state nonce
func (c *csrf) HashOAuthState() string {
return encryption.HashNonce(c.OAuthState)
}
// HashOIDCNonce returns the hash of the OIDC nonce
func (c *csrf) HashOIDCNonce() string {
return encryption.HashNonce(c.OIDCNonce)
}
// CheckOAuthState compares the OAuth state nonce against a potential
// hash of it
func (c *csrf) CheckOAuthState(hashed string) bool {
return encryption.CheckNonce(c.OAuthState, hashed)
}
// CheckOIDCNonce compares the OIDC nonce against a potential hash of it
func (c *csrf) CheckOIDCNonce(hashed string) bool {
return encryption.CheckNonce(c.OIDCNonce, hashed)
}
// SetSessionNonce sets the OIDCNonce on a SessionState
func (c *csrf) SetSessionNonce(s *sessions.SessionState) {
s.Nonce = c.OIDCNonce
}
// SetCookie encodes the CSRF to a signed cookie and sets it on the ResponseWriter
func (c *csrf) SetCookie(rw http.ResponseWriter, req *http.Request) (*http.Cookie, error) {
encoded, err := c.encodeCookie()
if err != nil {
return nil, err
}
cookie := MakeCookieFromOptions(
req,
c.cookieName(),
encoded,
c.cookieOpts,
c.cookieOpts.Expire,
c.time.Now(),
)
http.SetCookie(rw, cookie)
return cookie, nil
}
// ClearCookie removes the CSRF cookie
func (c *csrf) ClearCookie(rw http.ResponseWriter, req *http.Request) {
http.SetCookie(rw, MakeCookieFromOptions(
req,
c.cookieName(),
"",
c.cookieOpts,
time.Hour*-1,
c.time.Now(),
))
}
// encodeCookie MessagePack encodes and encrypts the CSRF and then creates a
// signed cookie value
func (c *csrf) encodeCookie() (string, error) {
packed, err := msgpack.Marshal(c)
if err != nil {
return "", fmt.Errorf("error marshalling CSRF to msgpack: %v", err)
}
encrypted, err := encrypt(packed, c.cookieOpts)
if err != nil {
return "", err
}
return encryption.SignedValue(c.cookieOpts.Secret, c.cookieName(), encrypted, c.time.Now())
}
// decodeCSRFCookie validates the signature then decrypts and decodes a CSRF
// cookie into a CSRF struct
func decodeCSRFCookie(cookie *http.Cookie, opts *options.Cookie) (*csrf, error) {
val, _, ok := encryption.Validate(cookie, opts.Secret, opts.Expire)
if !ok {
return nil, errors.New("CSRF cookie failed validation")
}
decrypted, err := decrypt(val, opts)
if err != nil {
return nil, err
}
// Valid cookie, Unmarshal the CSRF
csrf := &csrf{cookieOpts: opts}
err = msgpack.Unmarshal(decrypted, csrf)
if err != nil {
return nil, fmt.Errorf("error unmarshalling data to CSRF: %v", err)
}
return csrf, nil
}
// cookieName returns the CSRF cookie's name derived from the base
// session cookie name
func (c *csrf) cookieName() string {
return csrfCookieName(c.cookieOpts)
}
func csrfCookieName(opts *options.Cookie) string {
return fmt.Sprintf("%v_csrf", opts.Name)
}
func encrypt(data []byte, opts *options.Cookie) ([]byte, error) {
cipher, err := makeCipher(opts)
if err != nil {
return nil, err
}
return cipher.Encrypt(data)
}
func decrypt(data []byte, opts *options.Cookie) ([]byte, error) {
cipher, err := makeCipher(opts)
if err != nil {
return nil, err
}
return cipher.Decrypt(data)
}
func makeCipher(opts *options.Cookie) (encryption.Cipher, error) {
return encryption.NewCFBCipher(encryption.SecretBytes(opts.Secret))
}

190
pkg/cookies/csrf_test.go Normal file
View File

@ -0,0 +1,190 @@
package cookies
import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"time"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("CSRF Cookie Tests", func() {
var (
cookieOpts *options.Cookie
publicCSRF CSRF
privateCSRF *csrf
)
BeforeEach(func() {
cookieOpts = &options.Cookie{
Name: cookieName,
Secret: cookieSecret,
Domains: []string{cookieDomain},
Path: cookiePath,
Expire: time.Hour,
Secure: true,
HTTPOnly: true,
}
var err error
publicCSRF, err = NewCSRF(cookieOpts)
Expect(err).ToNot(HaveOccurred())
privateCSRF = publicCSRF.(*csrf)
})
Context("NewCSRF", func() {
It("makes unique nonces for OAuth and OIDC", func() {
Expect(privateCSRF.OAuthState).ToNot(BeEmpty())
Expect(privateCSRF.OIDCNonce).ToNot(BeEmpty())
Expect(privateCSRF.OAuthState).ToNot(Equal(privateCSRF.OIDCNonce))
})
It("makes unique nonces between multiple CSRFs", func() {
other, err := NewCSRF(cookieOpts)
Expect(err).ToNot(HaveOccurred())
Expect(privateCSRF.OAuthState).ToNot(Equal(other.(*csrf).OAuthState))
Expect(privateCSRF.OIDCNonce).ToNot(Equal(other.(*csrf).OIDCNonce))
})
})
Context("CheckOAuthState and CheckOIDCNonce", func() {
It("checks that hashed versions match", func() {
privateCSRF.OAuthState = []byte(csrfState)
privateCSRF.OIDCNonce = []byte(csrfNonce)
stateHashed := encryption.HashNonce([]byte(csrfState))
nonceHashed := encryption.HashNonce([]byte(csrfNonce))
Expect(publicCSRF.CheckOAuthState(stateHashed)).To(BeTrue())
Expect(publicCSRF.CheckOIDCNonce(nonceHashed)).To(BeTrue())
Expect(publicCSRF.CheckOAuthState(csrfNonce)).To(BeFalse())
Expect(publicCSRF.CheckOIDCNonce(csrfState)).To(BeFalse())
Expect(publicCSRF.CheckOAuthState(csrfState + csrfNonce)).To(BeFalse())
Expect(publicCSRF.CheckOIDCNonce(csrfNonce + csrfState)).To(BeFalse())
Expect(publicCSRF.CheckOAuthState("")).To(BeFalse())
Expect(publicCSRF.CheckOIDCNonce("")).To(BeFalse())
})
})
Context("SetSessionNonce", func() {
It("sets the session.Nonce", func() {
session := &sessions.SessionState{}
publicCSRF.SetSessionNonce(session)
Expect(session.Nonce).To(Equal(privateCSRF.OIDCNonce))
})
})
Context("encodeCookie and decodeCSRFCookie", func() {
It("encodes and decodes to the same nonces", func() {
privateCSRF.OAuthState = []byte(csrfState)
privateCSRF.OIDCNonce = []byte(csrfNonce)
encoded, err := privateCSRF.encodeCookie()
Expect(err).ToNot(HaveOccurred())
cookie := &http.Cookie{
Name: privateCSRF.cookieName(),
Value: encoded,
}
decoded, err := decodeCSRFCookie(cookie, cookieOpts)
Expect(err).ToNot(HaveOccurred())
Expect(decoded).ToNot(BeNil())
Expect(decoded.OAuthState).To(Equal([]byte(csrfState)))
Expect(decoded.OIDCNonce).To(Equal([]byte(csrfNonce)))
})
It("signs the encoded cookie value", func() {
encoded, err := privateCSRF.encodeCookie()
Expect(err).ToNot(HaveOccurred())
cookie := &http.Cookie{
Name: privateCSRF.cookieName(),
Value: encoded,
}
_, _, valid := encryption.Validate(cookie, cookieOpts.Secret, cookieOpts.Expire)
Expect(valid).To(BeTrue())
})
})
Context("Cookie Management", func() {
var req *http.Request
testNow := time.Unix(nowEpoch, 0)
BeforeEach(func() {
privateCSRF.time.Set(testNow)
req = &http.Request{
Method: http.MethodGet,
Proto: "HTTP/1.1",
Host: cookieDomain,
URL: &url.URL{
Scheme: "https",
Host: cookieDomain,
Path: cookiePath,
},
}
})
AfterEach(func() {
privateCSRF.time.Reset()
})
Context("SetCookie", func() {
It("adds the encoded CSRF cookie to a ResponseWriter", func() {
rw := httptest.NewRecorder()
_, err := publicCSRF.SetCookie(rw, req)
Expect(err).ToNot(HaveOccurred())
Expect(rw.Header().Get("Set-Cookie")).To(ContainSubstring(
fmt.Sprintf("%s=", privateCSRF.cookieName()),
))
Expect(rw.Header().Get("Set-Cookie")).To(ContainSubstring(
fmt.Sprintf(
"; Path=%s; Domain=%s; Expires=%s; HttpOnly; Secure",
cookiePath,
cookieDomain,
testCookieExpires(testNow.Add(cookieOpts.Expire)),
),
))
})
})
Context("ClearCookie", func() {
It("sets a cookie with an empty value in the past", func() {
rw := httptest.NewRecorder()
publicCSRF.ClearCookie(rw, req)
Expect(rw.Header().Get("Set-Cookie")).To(Equal(
fmt.Sprintf(
"%s=; Path=%s; Domain=%s; Expires=%s; HttpOnly; Secure",
privateCSRF.cookieName(),
cookiePath,
cookieDomain,
testCookieExpires(testNow.Add(time.Hour*-1)),
),
))
})
})
Context("cookieName", func() {
It("has the cookie options name as a base", func() {
Expect(privateCSRF.cookieName()).To(ContainSubstring(cookieName))
})
})
})
})

View File

@ -1,17 +1,37 @@
package encryption
import (
"crypto/hmac"
"crypto/rand"
"fmt"
"encoding/base64"
"golang.org/x/crypto/blake2b"
)
// Nonce generates a random 16 byte string to be used as a nonce
func Nonce() (nonce string, err error) {
b := make([]byte, 16)
_, err = rand.Read(b)
// Nonce generates a random 32-byte slice to be used as a nonce
func Nonce() ([]byte, error) {
b := make([]byte, 32)
_, err := rand.Read(b)
if err != nil {
return
return nil, err
}
nonce = fmt.Sprintf("%x", b)
return
return b, nil
}
// HashNonce returns the BLAKE2b 256-bit hash of a nonce
// NOTE: Error checking (G104) is purposefully skipped:
// - `blake2b.New256` has no error path with a nil signing key
// - `hash.Hash` interface's `Write` has an error signature, but
// `blake2b.digest.Write` does not use it.
/* #nosec G104 */
func HashNonce(nonce []byte) string {
hasher, _ := blake2b.New256(nil)
hasher.Write(nonce)
sum := hasher.Sum(nil)
return base64.RawURLEncoding.EncodeToString(sum)
}
// CheckNonce tests if a nonce matches the hashed version of it
func CheckNonce(nonce []byte, hashed string) bool {
return hmac.Equal([]byte(HashNonce(nonce)), []byte(hashed))
}

View File

@ -264,6 +264,7 @@ func parseProviderInfo(o *options.Options, msgs []string) []string {
p.SetTeam(o.Providers[0].BitbucketConfig.Team)
p.SetRepository(o.Providers[0].BitbucketConfig.Repository)
case *providers.OIDCProvider:
p.SkipNonce = o.Providers[0].OIDCConfig.InsecureSkipNonce
if p.Verifier == nil {
msgs = append(msgs, "oidc provider requires an oidc issuer URL")
}

View File

@ -2,6 +2,7 @@ package validation
import (
"context"
"encoding/base64"
"fmt"
"time"
@ -50,10 +51,11 @@ func validateRedisSessionStore(o *options.Options) []string {
return []string{fmt.Sprintf("unable to initialize a redis client: %v", err)}
}
nonce, err := encryption.Nonce()
n, err := encryption.Nonce()
if err != nil {
return []string{fmt.Sprintf("unable to generate a redis initialization test key: %v", err)}
}
nonce := base64.RawURLEncoding.EncodeToString(n)
key := fmt.Sprintf("%s-healthcheck-%s", o.Cookie.Name, nonce)
return sendRedisConnectionTest(client, key, nonce)

View File

@ -107,7 +107,7 @@ func overrideTenantURL(current, defaultURL *url.URL, tenant, path string) {
}
}
func (p *AzureProvider) GetLoginURL(redirectURI, state string) string {
func (p *AzureProvider) GetLoginURL(redirectURI, state, _ string) string {
extraParams := url.Values{}
if p.ProtectedResource != nil && p.ProtectedResource.String() != "" {
extraParams.Add("resource", p.ProtectedResource.String())

View File

@ -336,7 +336,7 @@ func TestAzureProviderRedeem(t *testing.T) {
func TestAzureProviderProtectedResourceConfigured(t *testing.T) {
p := testAzureProvider("")
p.ProtectedResource, _ = url.Parse("http://my.resource.test")
result := p.GetLoginURL("https://my.test.app/oauth", "")
result := p.GetLoginURL("https://my.test.app/oauth", "", "")
assert.Contains(t, result, "resource="+url.QueryEscape("http://my.resource.test"))
}

View File

@ -228,7 +228,7 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
}
// GetLoginURL overrides GetLoginURL to add login.gov parameters
func (p *LoginGovProvider) GetLoginURL(redirectURI, state string) string {
func (p *LoginGovProvider) GetLoginURL(redirectURI, state, _ string) string {
extraParams := url.Values{}
if p.AcrValues == "" {
acr := "http://idmanagement.gov/ns/assurance/loa/1"

View File

@ -292,7 +292,7 @@ func TestLoginGovProviderBadNonce(t *testing.T) {
func TestLoginGovProviderGetLoginURL(t *testing.T) {
p, _, _ := newLoginGovProvider()
result := p.GetLoginURL("http://redirect/", "")
result := p.GetLoginURL("http://redirect/", "", "")
assert.Contains(t, result, "acr_values="+url.QueryEscape("http://idmanagement.gov/ns/assurance/loa/1"))
assert.Contains(t, result, "nonce=fakenonce")
}

View File

@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"net/url"
"reflect"
"time"
@ -16,16 +17,31 @@ import (
// OIDCProvider represents an OIDC based Identity Provider
type OIDCProvider struct {
*ProviderData
SkipNonce bool
}
// NewOIDCProvider initiates a new OIDCProvider
func NewOIDCProvider(p *ProviderData) *OIDCProvider {
p.ProviderName = "OpenID Connect"
return &OIDCProvider{ProviderData: p}
return &OIDCProvider{
ProviderData: p,
SkipNonce: true,
}
}
var _ Provider = (*OIDCProvider)(nil)
// GetLoginURL makes the LoginURL with optional nonce support
func (p *OIDCProvider) GetLoginURL(redirectURI, state, nonce string) string {
extraParams := url.Values{}
if !p.SkipNonce {
extraParams.Add("nonce", nonce)
}
loginURL := makeLoginURL(p.Data(), redirectURI, state, extraParams)
return loginURL.String()
}
// Redeem exchanges the OAuth2 authentication token for an ID token
func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
clientSecret, err := p.GetClientSecret()
@ -109,8 +125,22 @@ func (p *OIDCProvider) enrichFromProfileURL(ctx context.Context, s *sessions.Ses
// ValidateSession checks that the session's IDToken is still valid
func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool {
_, err := p.Verifier.Verify(ctx, s.IDToken)
return err == nil
idToken, err := p.Verifier.Verify(ctx, s.IDToken)
if err != nil {
logger.Errorf("id_token verification failed: %v", err)
return false
}
if p.SkipNonce {
return true
}
err = p.checkNonce(s, idToken)
if err != nil {
logger.Errorf("nonce verification failed: %v", err)
return false
}
return true
}
// RefreshSessionIfNeeded checks if the session has expired and uses the

View File

@ -2,8 +2,10 @@ package providers
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
@ -11,6 +13,7 @@ import (
"github.com/coreos/go-oidc"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
"github.com/stretchr/testify/assert"
)
@ -23,7 +26,6 @@ type redeemTokenResponse struct {
}
func newOIDCProvider(serverURL *url.URL) *OIDCProvider {
providerData := &ProviderData{
ProviderName: "oidc",
ClientID: oidcClientID,
@ -54,7 +56,7 @@ func newOIDCProvider(serverURL *url.URL) *OIDCProvider {
),
}
p := &OIDCProvider{ProviderData: providerData}
p := NewOIDCProvider(providerData)
return p
}
@ -74,8 +76,27 @@ func newTestOIDCSetup(body []byte) (*httptest.Server, *OIDCProvider) {
return server, provider
}
func TestOIDCProviderRedeem(t *testing.T) {
func TestOIDCProviderGetLoginURL(t *testing.T) {
serverURL := &url.URL{
Scheme: "https",
Host: "oauth2proxy.oidctest",
}
provider := newOIDCProvider(serverURL)
n, err := encryption.Nonce()
assert.NoError(t, err)
nonce := base64.RawURLEncoding.EncodeToString(n)
// SkipNonce defaults to true
skipNonce := provider.GetLoginURL("http://redirect/", "", nonce)
assert.NotContains(t, skipNonce, "nonce")
provider.SkipNonce = false
withNonce := provider.GetLoginURL("http://redirect/", "", nonce)
assert.Contains(t, withNonce, fmt.Sprintf("nonce=%s", nonce))
}
func TestOIDCProviderRedeem(t *testing.T) {
idToken, _ := newSignedTestIDToken(defaultIDToken)
body, _ := json.Marshal(redeemTokenResponse{
AccessToken: accessToken,
@ -98,7 +119,6 @@ func TestOIDCProviderRedeem(t *testing.T) {
}
func TestOIDCProviderRedeem_custom_userid(t *testing.T) {
idToken, _ := newSignedTestIDToken(defaultIDToken)
body, _ := json.Marshal(redeemTokenResponse{
AccessToken: accessToken,

View File

@ -122,6 +122,7 @@ type OIDCClaims struct {
Email string `json:"-"`
Groups []string `json:"-"`
Verified *bool `json:"email_verified"`
Nonce string `json:"nonce"`
raw map[string]interface{}
}
@ -192,6 +193,18 @@ func (p *ProviderData) getClaims(idToken *oidc.IDToken) (*OIDCClaims, error) {
return claims, nil
}
// checkNonce compares the session's nonce with the IDToken's nonce claim
func (p *ProviderData) checkNonce(s *sessions.SessionState, idToken *oidc.IDToken) error {
claims, err := p.getClaims(idToken)
if err != nil {
return fmt.Errorf("id_token claims extraction failed: %v", err)
}
if !s.CheckNonce(claims.Nonce) {
return errors.New("id_token nonce claim does not match the session nonce")
}
return nil
}
// extractGroups extracts groups from a claim to a list in a type safe manner.
// If the claim isn't present, `nil` is returned. If the groups claim is
// present but empty, `[]string{}` is returned.

View File

@ -15,6 +15,7 @@ import (
"github.com/coreos/go-oidc"
"github.com/dgrijalva/jwt-go"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
. "github.com/onsi/gomega"
"golang.org/x/oauth2"
)
@ -27,6 +28,7 @@ const (
oidcIssuer = "https://issuer.example.com"
oidcClientID = "https://test.myapp.com"
oidcSecret = "SuperSecret123456789"
oidcNonce = "abcde12345edcba09876abcde12345ff"
failureTokenID = "this-id-fails-verification"
)
@ -53,6 +55,7 @@ var (
Groups: []string{"test:a", "test:b"},
Roles: []string{"test:c", "test:d"},
Verified: &verified,
Nonce: encryption.HashNonce([]byte(oidcNonce)),
StandardClaims: standardClaims,
}
@ -96,6 +99,7 @@ type idTokenClaims struct {
Groups interface{} `json:"groups,omitempty"`
Roles interface{} `json:"roles,omitempty"`
Verified *bool `json:"email_verified,omitempty"`
Nonce string `json:"nonce,omitempty"`
jwt.StandardClaims
}
@ -348,6 +352,63 @@ func TestProviderData_buildSessionFromClaims(t *testing.T) {
}
}
func TestProviderData_checkNonce(t *testing.T) {
testCases := map[string]struct {
Session *sessions.SessionState
IDToken idTokenClaims
ExpectedError error
}{
"Nonces match": {
Session: &sessions.SessionState{
Nonce: []byte(oidcNonce),
},
IDToken: defaultIDToken,
ExpectedError: nil,
},
"Nonces do not match": {
Session: &sessions.SessionState{
Nonce: []byte("WrongWrongWrong"),
},
IDToken: defaultIDToken,
ExpectedError: errors.New("id_token nonce claim does not match the session nonce"),
},
"Missing nonce claim": {
Session: &sessions.SessionState{
Nonce: []byte(oidcNonce),
},
IDToken: minimalIDToken,
ExpectedError: errors.New("id_token nonce claim does not match the session nonce"),
},
}
for testName, tc := range testCases {
t.Run(testName, func(t *testing.T) {
g := NewWithT(t)
provider := &ProviderData{
Verifier: oidc.NewVerifier(
oidcIssuer,
mockJWKS{},
&oidc.Config{ClientID: oidcClientID},
),
}
rawIDToken, err := newSignedTestIDToken(tc.IDToken)
g.Expect(err).ToNot(HaveOccurred())
idToken, err := provider.Verifier.Verify(context.Background(), rawIDToken)
g.Expect(err).ToNot(HaveOccurred())
err = provider.checkNonce(tc.Session, idToken)
if err != nil {
g.Expect(err).To(Equal(tc.ExpectedError))
} else {
g.Expect(err).ToNot(HaveOccurred())
}
})
}
}
func TestProviderData_extractGroups(t *testing.T) {
testCases := map[string]struct {
Claims map[string]interface{}

View File

@ -33,6 +33,13 @@ var (
_ Provider = (*ProviderData)(nil)
)
// GetLoginURL with typical oauth parameters
func (p *ProviderData) GetLoginURL(redirectURI, state, _ string) string {
extraParams := url.Values{}
loginURL := makeLoginURL(p, redirectURI, state, extraParams)
return loginURL.String()
}
// Redeem provides a default implementation of the OAuth2 token redemption process
func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) {
if code == "" {
@ -86,13 +93,6 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (*s
return nil, fmt.Errorf("no access token found %s", result.Body())
}
// GetLoginURL with typical oauth parameters
func (p *ProviderData) GetLoginURL(redirectURI, state string) string {
extraParams := url.Values{}
a := makeLoginURL(p, redirectURI, state, extraParams)
return a.String()
}
// GetEmailAddress returns the Account email address
// Deprecated: Migrate to EnrichSession
func (p *ProviderData) GetEmailAddress(_ context.Context, _ *sessions.SessionState) (string, error) {

View File

@ -31,7 +31,7 @@ func TestAcrValuesNotConfigured(t *testing.T) {
},
}
result := p.GetLoginURL("https://my.test.app/oauth", "")
result := p.GetLoginURL("https://my.test.app/oauth", "", "")
assert.NotContains(t, result, "acr_values")
}
@ -45,7 +45,7 @@ func TestAcrValuesConfigured(t *testing.T) {
AcrValues: "testValue",
}
result := p.GetLoginURL("https://my.test.app/oauth", "")
result := p.GetLoginURL("https://my.test.app/oauth", "", "")
assert.Contains(t, result, "acr_values=testValue")
}

View File

@ -11,11 +11,11 @@ type Provider interface {
Data() *ProviderData
// Deprecated: Migrate to EnrichSession
GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error)
GetLoginURL(redirectURI, state, nonce string) string
Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error)
EnrichSession(ctx context.Context, s *sessions.SessionState) error
Authorize(ctx context.Context, s *sessions.SessionState) (bool, error)
ValidateSession(ctx context.Context, s *sessions.SessionState) bool
GetLoginURL(redirectURI, finalRedirect string) string
RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error)
CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error)
}