mirror of
https://github.com/digital-asset/daml.git
synced 2024-09-17 07:47:14 +03:00
trigger service auth /auth endpoint (#7545)
* Factor out request bodies Addressing review comment https://github.com/digital-asset/daml/pull/7519#discussion_r497321689 * Implement /auth endpoint changelog_begin changelog_end * /auth check the required claims * Factor out middlewareUri * fmt * less implicit variables Co-authored-by: Andreas Herrmann <andreas.herrmann@tweag.io>
This commit is contained in:
parent
2f325349ef
commit
c5abcece56
@ -112,7 +112,9 @@ da_scala_test(
|
||||
"//libs-scala/ports",
|
||||
"//libs-scala/resources",
|
||||
"@maven//:com_typesafe_akka_akka_actor_2_12",
|
||||
"@maven//:com_typesafe_akka_akka_http_2_12",
|
||||
"@maven//:com_typesafe_akka_akka_http_core_2_12",
|
||||
"@maven//:com_typesafe_akka_akka_http_spray_json_2_12",
|
||||
"@maven//:com_typesafe_akka_akka_parsing_2_12",
|
||||
"@maven//:com_typesafe_akka_akka_stream_2_12",
|
||||
"@maven//:com_typesafe_scala_logging_scala_logging_2_12",
|
||||
|
@ -4,10 +4,21 @@
|
||||
package com.daml.oauth.middleware
|
||||
|
||||
import akka.http.scaladsl.model.Uri
|
||||
import spray.json.{DefaultJsonProtocol, JsString, JsValue, JsonFormat, deserializationError}
|
||||
import spray.json.{
|
||||
DefaultJsonProtocol,
|
||||
JsString,
|
||||
JsValue,
|
||||
JsonFormat,
|
||||
RootJsonFormat,
|
||||
deserializationError
|
||||
}
|
||||
|
||||
object Request {
|
||||
|
||||
/** Auth endpoint query parameters
|
||||
*/
|
||||
case class Auth(claims: String) // TODO[AH] parse ledger claims
|
||||
|
||||
/** Login endpoint query parameters
|
||||
*
|
||||
* @param redirectUri Redirect target after the login flow completed. I.e. the original request URI on the trigger service.
|
||||
@ -17,7 +28,11 @@ object Request {
|
||||
|
||||
}
|
||||
|
||||
object Response {}
|
||||
object Response {
|
||||
|
||||
case class Authorize(accessToken: String, refreshToken: Option[String])
|
||||
|
||||
}
|
||||
|
||||
object JsonProtocol extends DefaultJsonProtocol {
|
||||
implicit object UriFormat extends JsonFormat[Uri] {
|
||||
@ -27,4 +42,6 @@ object JsonProtocol extends DefaultJsonProtocol {
|
||||
}
|
||||
def write(uri: Uri) = JsString(uri.toString)
|
||||
}
|
||||
implicit val responseAuthorizeFormat: RootJsonFormat[Response.Authorize] =
|
||||
jsonFormat(Response.Authorize, "access_token", "refresh_token")
|
||||
}
|
||||
|
@ -9,12 +9,12 @@ import akka.http.scaladsl.Http
|
||||
import akka.http.scaladsl.Http.ServerBinding
|
||||
import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._
|
||||
import akka.http.scaladsl.model._
|
||||
import akka.http.scaladsl.model.headers.HttpCookie
|
||||
import akka.http.scaladsl.model.headers.{HttpCookie, HttpCookiePair}
|
||||
import akka.http.scaladsl.server.Directive1
|
||||
import akka.http.scaladsl.server.Directives._
|
||||
import akka.http.scaladsl.unmarshalling.{Unmarshal, Unmarshaller}
|
||||
import com.daml.oauth.server.{Request => OAuthRequest, Response => OAuthResponse}
|
||||
import com.typesafe.scalalogging.StrictLogging
|
||||
import java.util.Base64
|
||||
import java.util.UUID
|
||||
import scala.collection.concurrent.TrieMap
|
||||
import scala.concurrent.{ExecutionContext, Future}
|
||||
@ -25,7 +25,9 @@ import spray.json._
|
||||
// This is an implementation of the trigger service authentication middleware
|
||||
// for OAuth2 as specified in `/triggers/service/authentication.md`
|
||||
object Server extends StrictLogging {
|
||||
import JsonProtocol._
|
||||
import com.daml.oauth.server.JsonProtocol._
|
||||
implicit private val unmarshal: Unmarshaller[String, Uri] = Unmarshaller.strict(Uri(_))
|
||||
|
||||
// TODO[AH] Make the redirect URI configurable, especially the authority. E.g. when running behind nginx.
|
||||
private def toRedirectUri(uri: Uri) =
|
||||
@ -36,82 +38,12 @@ object Server extends StrictLogging {
|
||||
|
||||
def start(
|
||||
config: Config)(implicit system: ActorSystem, ec: ExecutionContext): Future[ServerBinding] = {
|
||||
implicit val unmarshal: Unmarshaller[String, Uri] = Unmarshaller.strict(Uri(_))
|
||||
// TODO[AH] Make sure this is bounded in size - or avoid state altogether.
|
||||
val requests = TrieMap[UUID, Uri]()
|
||||
val requests: TrieMap[UUID, Uri] = TrieMap()
|
||||
val route = concat(
|
||||
path("auth") {
|
||||
get {
|
||||
complete((StatusCodes.NotImplemented, "The /auth endpoint is not implemented yet"))
|
||||
}
|
||||
},
|
||||
path("login") {
|
||||
get {
|
||||
parameters(('redirect_uri.as[Uri], 'claims))
|
||||
.as[Request.Login](Request.Login) {
|
||||
login =>
|
||||
extractRequest {
|
||||
request =>
|
||||
val requestId = UUID.randomUUID
|
||||
requests += (requestId -> login.redirectUri)
|
||||
val authorize = OAuthRequest.Authorize(
|
||||
responseType = "code",
|
||||
clientId = config.clientId,
|
||||
redirectUri = toRedirectUri(request.uri),
|
||||
scope = Some(login.claims),
|
||||
state = Some(requestId.toString))
|
||||
redirect(
|
||||
config.oauthUri
|
||||
.withPath(Uri.Path./("authorize"))
|
||||
.withQuery(authorize.toQuery),
|
||||
StatusCodes.Found)
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
path("cb") {
|
||||
get {
|
||||
parameters(('code, 'state ?))
|
||||
.as[OAuthResponse.Authorize](OAuthResponse.Authorize) {
|
||||
authorize =>
|
||||
extractRequest {
|
||||
request =>
|
||||
val redirectUri = for {
|
||||
state <- authorize.state
|
||||
requestId <- Try(UUID.fromString(state)).toOption
|
||||
redirectUri <- requests.remove(requestId)
|
||||
} yield redirectUri
|
||||
redirectUri match {
|
||||
case None =>
|
||||
complete(StatusCodes.NotFound)
|
||||
case Some(redirectUri) =>
|
||||
val body = OAuthRequest.Token(
|
||||
grantType = "authorization_code",
|
||||
code = authorize.code,
|
||||
redirectUri = toRedirectUri(request.uri),
|
||||
clientId = config.clientId,
|
||||
clientSecret = config.clientSecret)
|
||||
val req = HttpRequest(
|
||||
uri = config.oauthUri.withPath(Uri.Path./("token")),
|
||||
entity =
|
||||
HttpEntity(MediaTypes.`application/json`, body.toJson.compactPrint),
|
||||
method = HttpMethods.POST)
|
||||
val tokenRequest = for {
|
||||
resp <- Http().singleRequest(req)
|
||||
tokenResp <- Unmarshal(resp).to[OAuthResponse.Token]
|
||||
} yield tokenResp
|
||||
onSuccess(tokenRequest) { token =>
|
||||
val encoder = Base64.getUrlEncoder()
|
||||
val content = encoder.encodeToString(token.toJson.compactPrint.getBytes)
|
||||
setCookie(HttpCookie("daml-ledger-token", content)) {
|
||||
redirect(redirectUri, StatusCodes.Found)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
path("auth") { get { auth } },
|
||||
path("login") { get { login(config, requests) } },
|
||||
path("cb") { get { loginCallback(config, requests) } },
|
||||
path("refresh") {
|
||||
get {
|
||||
complete((StatusCodes.NotImplemented, "The /refresh endpoint is not implemented yet"))
|
||||
@ -121,6 +53,86 @@ object Server extends StrictLogging {
|
||||
|
||||
Http().bindAndHandle(route, "localhost", config.port.value)
|
||||
}
|
||||
|
||||
def stop(f: Future[ServerBinding])(implicit ec: ExecutionContext): Future[Done] =
|
||||
f.flatMap(_.unbind())
|
||||
|
||||
private val cookieName = "daml-ledger-token"
|
||||
|
||||
private def optionalToken: Directive1[Option[OAuthResponse.Token]] = {
|
||||
def f(x: HttpCookiePair) = OAuthResponse.Token.fromCookieValue(x.value)
|
||||
optionalCookie(cookieName).map(_.flatMap(f))
|
||||
}
|
||||
|
||||
private def auth =
|
||||
parameters(('claims))
|
||||
.as[Request.Auth](Request.Auth) { auth =>
|
||||
optionalToken {
|
||||
// TODO[AH] Implement mapping from scope to claims
|
||||
// TODO[AH] Check whether granted scope subsumes requested claims
|
||||
case Some(token) if token.scope == Some(auth.claims) =>
|
||||
complete(
|
||||
Response
|
||||
.Authorize(accessToken = token.accessToken, refreshToken = token.refreshToken))
|
||||
case _ => complete(StatusCodes.Unauthorized)
|
||||
}
|
||||
}
|
||||
|
||||
private def login(config: Config, requests: TrieMap[UUID, Uri]) =
|
||||
parameters(('redirect_uri.as[Uri], 'claims))
|
||||
.as[Request.Login](Request.Login) { login =>
|
||||
extractRequest { request =>
|
||||
val requestId = UUID.randomUUID
|
||||
requests += (requestId -> login.redirectUri)
|
||||
val authorize = OAuthRequest.Authorize(
|
||||
responseType = "code",
|
||||
clientId = config.clientId,
|
||||
redirectUri = toRedirectUri(request.uri),
|
||||
scope = Some(login.claims),
|
||||
state = Some(requestId.toString))
|
||||
redirect(
|
||||
config.oauthUri
|
||||
.withPath(Uri.Path./("authorize"))
|
||||
.withQuery(authorize.toQuery),
|
||||
StatusCodes.Found)
|
||||
}
|
||||
}
|
||||
|
||||
private def loginCallback(config: Config, requests: TrieMap[UUID, Uri])(
|
||||
implicit system: ActorSystem,
|
||||
ec: ExecutionContext) =
|
||||
parameters(('code, 'state ?))
|
||||
.as[OAuthResponse.Authorize](OAuthResponse.Authorize) { authorize =>
|
||||
extractRequest { request =>
|
||||
val redirectUri = for {
|
||||
state <- authorize.state
|
||||
requestId <- Try(UUID.fromString(state)).toOption
|
||||
redirectUri <- requests.remove(requestId)
|
||||
} yield redirectUri
|
||||
redirectUri match {
|
||||
case None =>
|
||||
complete(StatusCodes.NotFound)
|
||||
case Some(redirectUri) =>
|
||||
val body = OAuthRequest.Token(
|
||||
grantType = "authorization_code",
|
||||
code = authorize.code,
|
||||
redirectUri = toRedirectUri(request.uri),
|
||||
clientId = config.clientId,
|
||||
clientSecret = config.clientSecret)
|
||||
val req = HttpRequest(
|
||||
uri = config.oauthUri.withPath(Uri.Path./("token")),
|
||||
entity = HttpEntity(MediaTypes.`application/json`, body.toJson.compactPrint),
|
||||
method = HttpMethods.POST)
|
||||
val tokenRequest = for {
|
||||
resp <- Http().singleRequest(req)
|
||||
tokenResp <- Unmarshal(resp).to[OAuthResponse.Token]
|
||||
} yield tokenResp
|
||||
onSuccess(tokenRequest) { token =>
|
||||
setCookie(HttpCookie(cookieName, token.toCookieValue)) {
|
||||
redirect(redirectUri, StatusCodes.Found)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -3,16 +3,13 @@
|
||||
|
||||
package com.daml.oauth.server
|
||||
|
||||
import java.util.Base64
|
||||
|
||||
import akka.http.scaladsl.model.Uri
|
||||
import akka.http.scaladsl.model.Uri.Query
|
||||
import spray.json.{
|
||||
DefaultJsonProtocol,
|
||||
JsString,
|
||||
JsValue,
|
||||
JsonFormat,
|
||||
RootJsonFormat,
|
||||
deserializationError
|
||||
}
|
||||
import spray.json._
|
||||
|
||||
import scala.util.Try
|
||||
|
||||
object Request {
|
||||
|
||||
@ -65,7 +62,23 @@ object Response {
|
||||
tokenType: String,
|
||||
expiresIn: Option[String],
|
||||
refreshToken: Option[String],
|
||||
scope: Option[String])
|
||||
scope: Option[String]) {
|
||||
def toCookieValue: String = {
|
||||
import JsonProtocol._
|
||||
Base64.getUrlEncoder().encodeToString(this.toJson.compactPrint.getBytes)
|
||||
}
|
||||
}
|
||||
|
||||
object Token {
|
||||
def fromCookieValue(s: String): Option[Token] = {
|
||||
import JsonProtocol._
|
||||
for {
|
||||
bytes <- Try(Base64.getUrlDecoder().decode(s))
|
||||
json <- Try(new String(bytes).parseJson)
|
||||
token <- Try(json.convertTo[Token])
|
||||
} yield token
|
||||
}.toOption
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -78,7 +91,19 @@ object JsonProtocol extends DefaultJsonProtocol {
|
||||
def write(uri: Uri) = JsString(uri.toString)
|
||||
}
|
||||
implicit val tokenReqFormat: RootJsonFormat[Request.Token] =
|
||||
jsonFormat(Request.Token, "grant_type", "code", "redirect_uri", "client_id", "client_secret")
|
||||
jsonFormat(
|
||||
Request.Token.apply,
|
||||
"grant_type",
|
||||
"code",
|
||||
"redirect_uri",
|
||||
"client_id",
|
||||
"client_secret")
|
||||
implicit val tokenRespFormat: RootJsonFormat[Response.Token] =
|
||||
jsonFormat(Response.Token, "access_token", "token_type", "expires_in", "refresh_token", "scope")
|
||||
jsonFormat(
|
||||
Response.Token.apply,
|
||||
"access_token",
|
||||
"token_type",
|
||||
"expires_in",
|
||||
"refresh_token",
|
||||
"scope")
|
||||
}
|
||||
|
@ -4,24 +4,83 @@
|
||||
package com.daml.oauth.middleware
|
||||
|
||||
import akka.http.scaladsl.Http
|
||||
import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._
|
||||
import akka.http.scaladsl.model.Uri.{Path, Query}
|
||||
import akka.http.scaladsl.model._
|
||||
import akka.http.scaladsl.model.headers.{Location, `Set-Cookie`}
|
||||
import akka.http.scaladsl.model.headers.{Cookie, Location, `Set-Cookie`}
|
||||
import akka.http.scaladsl.unmarshalling.Unmarshal
|
||||
import com.daml.ledger.api.testing.utils.SuiteResourceManagementAroundAll
|
||||
import com.daml.oauth.server.{Response => OAuthResponse}
|
||||
import java.util.Base64
|
||||
import org.scalatest.AsyncWordSpec
|
||||
import spray.json._
|
||||
|
||||
class Test extends AsyncWordSpec with TestFixture with SuiteResourceManagementAroundAll {
|
||||
import com.daml.oauth.server.JsonProtocol._
|
||||
"the middleware" should {
|
||||
"redirect and set cookie after login" in {
|
||||
lazy val middlewareBinding = suiteResource.value._2.localAddress
|
||||
lazy val middlewareUri =
|
||||
Uri()
|
||||
.withScheme("http")
|
||||
.withAuthority(middlewareBinding.getHostString, middlewareBinding.getPort)
|
||||
import JsonProtocol._
|
||||
lazy private val middlewareUri = {
|
||||
lazy val middlewareBinding = suiteResource.value._2.localAddress
|
||||
Uri()
|
||||
.withScheme("http")
|
||||
.withAuthority(middlewareBinding.getHostString, middlewareBinding.getPort)
|
||||
}
|
||||
"the /auth endpoint" should {
|
||||
"return unauthorized without cookie" in {
|
||||
val claims = "actAs:Alice"
|
||||
val req = HttpRequest(
|
||||
uri = middlewareUri
|
||||
.withPath(Path./("auth"))
|
||||
.withQuery(Query(("claims", claims))))
|
||||
for {
|
||||
resp <- Http().singleRequest(req)
|
||||
} yield {
|
||||
assert(resp.status == StatusCodes.Unauthorized)
|
||||
}
|
||||
}
|
||||
"return the token from a cookie" in {
|
||||
val claims = "actAs:Alice"
|
||||
val token = OAuthResponse.Token(
|
||||
accessToken = "access",
|
||||
tokenType = "bearer",
|
||||
expiresIn = None,
|
||||
refreshToken = Some("refresh"),
|
||||
scope = Some(claims))
|
||||
val cookieHeader = Cookie("daml-ledger-token", token.toCookieValue)
|
||||
val req = HttpRequest(
|
||||
uri = middlewareUri
|
||||
.withPath(Path./("auth"))
|
||||
.withQuery(Query(("claims", claims))),
|
||||
headers = List(cookieHeader))
|
||||
for {
|
||||
resp <- Http().singleRequest(req)
|
||||
auth <- {
|
||||
assert(resp.status == StatusCodes.OK)
|
||||
Unmarshal(resp).to[Response.Authorize]
|
||||
}
|
||||
} yield {
|
||||
assert(auth.accessToken == token.accessToken)
|
||||
assert(auth.refreshToken == token.refreshToken)
|
||||
}
|
||||
}
|
||||
"return unauthorized on insufficient claims" in {
|
||||
val token = OAuthResponse.Token(
|
||||
accessToken = "access",
|
||||
tokenType = "bearer",
|
||||
expiresIn = None,
|
||||
refreshToken = Some("refresh"),
|
||||
scope = Some("actAs:Alice"))
|
||||
val cookieHeader = Cookie("daml-ledger-token", token.toCookieValue)
|
||||
val req = HttpRequest(
|
||||
uri = middlewareUri
|
||||
.withPath(Path./("auth"))
|
||||
.withQuery(Query(("claims", "actAs:Bob"))),
|
||||
headers = List(cookieHeader))
|
||||
for {
|
||||
resp <- Http().singleRequest(req)
|
||||
} yield {
|
||||
assert(resp.status == StatusCodes.Unauthorized)
|
||||
}
|
||||
}
|
||||
}
|
||||
"the /login endpoint" should {
|
||||
"redirect and set cookie" in {
|
||||
val claims = "actAs:Alice"
|
||||
val req = HttpRequest(
|
||||
uri = middlewareUri
|
||||
@ -48,9 +107,7 @@ class Test extends AsyncWordSpec with TestFixture with SuiteResourceManagementAr
|
||||
// Store token in cookie
|
||||
val cookie = resp.header[`Set-Cookie`].get.cookie
|
||||
assert(cookie.name == "daml-ledger-token")
|
||||
val decoder = Base64.getUrlDecoder()
|
||||
val token =
|
||||
new String(decoder.decode(cookie.value)).parseJson.convertTo[OAuthResponse.Token]
|
||||
val token = OAuthResponse.Token.fromCookieValue(cookie.value).get
|
||||
assert(token.tokenType == "bearer")
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user