oauth2-proxy/providers/adfs_test.go
2021-12-01 19:16:42 -08:00

293 lines
7.9 KiB
Go

package providers
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"errors"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/golang-jwt/jwt"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
. "github.com/onsi/ginkgo"
. "github.com/onsi/ginkgo/extensions/table"
. "github.com/onsi/gomega"
)
type fakeADFSJwks struct{}
func (fakeADFSJwks) VerifySignature(_ context.Context, jwt string) (payload []byte, err error) {
decodeString, err := base64.RawURLEncoding.DecodeString(strings.Split(jwt, ".")[1])
if err != nil {
return nil, err
}
return decodeString, nil
}
type adfsClaims struct {
UPN string `json:"upn,omitempty"`
idTokenClaims
}
func newSignedTestADFSToken(tokenClaims adfsClaims) (string, error) {
key, _ := rsa.GenerateKey(rand.Reader, 2048)
standardClaims := jwt.NewWithClaims(jwt.SigningMethodRS256, tokenClaims)
return standardClaims.SignedString(key)
}
func testADFSProvider(hostname string) *ADFSProvider {
o := oidc.NewVerifier(
"https://issuer.example.com",
fakeADFSJwks{},
&oidc.Config{ClientID: "https://test.myapp.com"},
)
p := NewADFSProvider(&ProviderData{
ProviderName: "",
LoginURL: &url.URL{},
RedeemURL: &url.URL{},
ProfileURL: &url.URL{},
ValidateURL: &url.URL{},
Scope: "",
Verifier: o,
EmailClaim: OIDCEmailClaim,
})
if hostname != "" {
updateURL(p.Data().LoginURL, hostname)
updateURL(p.Data().RedeemURL, hostname)
updateURL(p.Data().ProfileURL, hostname)
updateURL(p.Data().ValidateURL, hostname)
}
return p
}
func testADFSBackend() *httptest.Server {
authResponse := `
{
"access_token": "my_access_token",
"id_token": "my_id_token",
"refresh_token": "my_refresh_token"
}
`
userInfo := `
{
"email": "samiracho@email.com"
}
`
refreshResponse := `{ "access_token": "new_some_access_token", "refresh_token": "new_some_refresh_token", "expires_in": "32693148245", "id_token": "new_some_id_token" }`
authHeader := "Bearer adfs_access_token"
return httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/adfs/oauth2/authorize":
w.WriteHeader(200)
w.Write([]byte(authResponse))
case "/adfs/oauth2/refresh":
w.WriteHeader(200)
w.Write([]byte(refreshResponse))
case "/adfs/oauth2/userinfo":
if r.Header["Authorization"][0] == authHeader {
w.WriteHeader(200)
w.Write([]byte(userInfo))
} else {
w.WriteHeader(401)
}
default:
w.WriteHeader(200)
}
}))
}
var _ = Describe("ADFS Provider Tests", func() {
var p *ADFSProvider
var b *httptest.Server
BeforeEach(func() {
b = testADFSBackend()
bURL, err := url.Parse(b.URL)
Expect(err).To(BeNil())
p = testADFSProvider(bURL.Host)
})
AfterEach(func() {
b.Close()
})
Context("New Provider Init", func() {
It("uses defaults", func() {
providerData := NewADFSProvider(&ProviderData{}).Data()
Expect(providerData.ProviderName).To(Equal("ADFS"))
Expect(providerData.Scope).To(Equal("openid email profile"))
})
})
Context("with bad token", func() {
It("should trigger an error", func() {
session := &sessions.SessionState{AccessToken: "unexpected_adfs_access_token", IDToken: "malformed_token"}
err := p.EnrichSession(context.Background(), session)
Expect(err).NotTo(BeNil())
})
})
Context("with valid token", func() {
It("should not throw an error", func() {
rawIDToken, _ := newSignedTestIDToken(defaultIDToken)
idToken, err := p.Verifier.Verify(context.Background(), rawIDToken)
Expect(err).To(BeNil())
session, err := p.buildSessionFromClaims(idToken)
Expect(err).To(BeNil())
session.IDToken = rawIDToken
err = p.EnrichSession(context.Background(), session)
Expect(session.Email).To(Equal("janed@me.com"))
Expect(err).To(BeNil())
})
})
Context("with skipScope enabled", func() {
It("should not include parameter scope", func() {
resource, _ := url.Parse("http://example.com")
p := NewADFSProvider(&ProviderData{
ProtectedResource: resource,
Scope: "",
})
p.skipScope = true
result := p.GetLoginURL("https://example.com/adfs/oauth2/", "", "")
Expect(result).NotTo(ContainSubstring("scope="))
})
})
Context("With resource parameter", func() {
type scopeTableInput struct {
resource string
scope string
expectedScope string
}
DescribeTable("should return expected results",
func(in scopeTableInput) {
resource, _ := url.Parse(in.resource)
p := NewADFSProvider(&ProviderData{
ProtectedResource: resource,
Scope: in.scope,
})
Expect(p.Data().Scope).To(Equal(in.expectedScope))
result := p.GetLoginURL("https://example.com/adfs/oauth2/", "", "")
Expect(result).To(ContainSubstring("scope=" + url.QueryEscape(in.expectedScope)))
},
Entry("should add slash", scopeTableInput{
resource: "http://resource.com",
scope: "openid",
expectedScope: "http://resource.com/openid",
}),
Entry("shouldn't add extra slash", scopeTableInput{
resource: "http://resource.com/",
scope: "openid",
expectedScope: "http://resource.com/openid",
}),
Entry("should add default scopes with resource", scopeTableInput{
resource: "http://resource.com/",
scope: "",
expectedScope: "http://resource.com/openid email profile",
}),
Entry("should add default scopes", scopeTableInput{
resource: "",
scope: "",
expectedScope: "openid email profile",
}),
Entry("shouldn't add resource if already in scopes", scopeTableInput{
resource: "http://resource.com",
scope: "http://resource.com/openid",
expectedScope: "http://resource.com/openid",
}),
)
})
Context("UPN Fallback", func() {
var idToken string
var session *sessions.SessionState
BeforeEach(func() {
var err error
idToken, err = newSignedTestADFSToken(adfsClaims{
UPN: "upn@company.com",
idTokenClaims: minimalIDToken,
})
Expect(err).ToNot(HaveOccurred())
session = &sessions.SessionState{
IDToken: idToken,
}
})
Describe("EnrichSession", func() {
It("uses email claim if present", func() {
p.oidcEnrichFunc = func(_ context.Context, s *sessions.SessionState) error {
s.Email = "person@company.com"
return nil
}
err := p.EnrichSession(context.Background(), session)
Expect(err).ToNot(HaveOccurred())
Expect(session.Email).To(Equal("person@company.com"))
})
It("falls back to UPN claim if Email is missing", func() {
p.oidcEnrichFunc = func(_ context.Context, s *sessions.SessionState) error {
return nil
}
err := p.EnrichSession(context.Background(), session)
Expect(err).ToNot(HaveOccurred())
Expect(session.Email).To(Equal("upn@company.com"))
})
It("falls back to UPN claim on errors", func() {
p.oidcEnrichFunc = func(_ context.Context, s *sessions.SessionState) error {
return errors.New("neither the id_token nor the profileURL set an email")
}
err := p.EnrichSession(context.Background(), session)
Expect(err).ToNot(HaveOccurred())
Expect(session.Email).To(Equal("upn@company.com"))
})
})
Describe("RefreshSession", func() {
It("uses email claim if present", func() {
p.oidcRefreshFunc = func(_ context.Context, s *sessions.SessionState) (bool, error) {
s.Email = "person@company.com"
return true, nil
}
_, err := p.RefreshSession(context.Background(), session)
Expect(err).ToNot(HaveOccurred())
Expect(session.Email).To(Equal("person@company.com"))
})
It("falls back to UPN claim if Email is missing", func() {
p.oidcRefreshFunc = func(_ context.Context, s *sessions.SessionState) (bool, error) {
return true, nil
}
_, err := p.RefreshSession(context.Background(), session)
Expect(err).ToNot(HaveOccurred())
Expect(session.Email).To(Equal("upn@company.com"))
})
})
})
})