mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2024-12-15 03:02:07 +03:00
+ GET /control/profile
* openapi: get /profile * auth: store user names along with sessions
This commit is contained in:
parent
e8bb0fdcb7
commit
c185f6826a
@ -54,6 +54,7 @@ Contents:
|
||||
* Log-in page
|
||||
* API: Log in
|
||||
* API: Log out
|
||||
* API: Get current user info
|
||||
|
||||
|
||||
## Relations between subsystems
|
||||
@ -1207,7 +1208,7 @@ YAML configuration:
|
||||
|
||||
Session DB file:
|
||||
|
||||
session="..." expire=123456
|
||||
session="..." user=name expire=123456
|
||||
...
|
||||
|
||||
Session data is SHA(random()+name+password).
|
||||
@ -1270,3 +1271,20 @@ Response:
|
||||
302 Found
|
||||
Location: /login.html
|
||||
Set-Cookie: session=...; Expires=Thu, 01 Jan 1970 00:00:00 GMT
|
||||
|
||||
|
||||
### API: Get current user info
|
||||
|
||||
Request:
|
||||
|
||||
GET /control/profile
|
||||
|
||||
Response:
|
||||
|
||||
200 OK
|
||||
|
||||
{
|
||||
"name":"..."
|
||||
}
|
||||
|
||||
If no client is configured then authentication is disabled and server sends an empty response.
|
||||
|
118
home/auth.go
118
home/auth.go
@ -20,10 +20,44 @@ import (
|
||||
const cookieTTL = 365 * 24 // in hours
|
||||
const expireTime = 30 * 24 // in hours
|
||||
|
||||
type session struct {
|
||||
userName string
|
||||
expire uint32 // expiration time (in seconds)
|
||||
}
|
||||
|
||||
/*
|
||||
expire byte[4]
|
||||
name_len byte[2]
|
||||
name byte[]
|
||||
*/
|
||||
func (s *session) serialize() []byte {
|
||||
var data []byte
|
||||
data = make([]byte, 4+2+len(s.userName))
|
||||
binary.BigEndian.PutUint32(data[0:4], s.expire)
|
||||
binary.BigEndian.PutUint16(data[4:6], uint16(len(s.userName)))
|
||||
copy(data[6:], []byte(s.userName))
|
||||
return data
|
||||
}
|
||||
|
||||
func (s *session) deserialize(data []byte) bool {
|
||||
if len(data) < 4+2 {
|
||||
return false
|
||||
}
|
||||
s.expire = binary.BigEndian.Uint32(data[0:4])
|
||||
nameLen := binary.BigEndian.Uint16(data[4:6])
|
||||
data = data[6:]
|
||||
|
||||
if len(data) < int(nameLen) {
|
||||
return false
|
||||
}
|
||||
s.userName = string(data)
|
||||
return true
|
||||
}
|
||||
|
||||
// Auth - global object
|
||||
type Auth struct {
|
||||
db *bbolt.DB
|
||||
sessions map[string]uint32 // session -> expiration time (in seconds)
|
||||
sessions map[string]*session // session name -> session data
|
||||
lock sync.Mutex
|
||||
users []User
|
||||
}
|
||||
@ -37,7 +71,7 @@ type User struct {
|
||||
// InitAuth - create a global object
|
||||
func InitAuth(dbFilename string, users []User) *Auth {
|
||||
a := Auth{}
|
||||
a.sessions = make(map[string]uint32)
|
||||
a.sessions = make(map[string]*session)
|
||||
rand.Seed(time.Now().UTC().Unix())
|
||||
var err error
|
||||
a.db, err = bbolt.Open(dbFilename, 0644, nil)
|
||||
@ -56,6 +90,10 @@ func (a *Auth) Close() {
|
||||
_ = a.db.Close()
|
||||
}
|
||||
|
||||
func bucketName() []byte {
|
||||
return []byte("sessions-2")
|
||||
}
|
||||
|
||||
// load sessions from file, remove expired sessions
|
||||
func (a *Auth) loadSessions() {
|
||||
tx, err := a.db.Begin(true)
|
||||
@ -67,16 +105,22 @@ func (a *Auth) loadSessions() {
|
||||
_ = tx.Rollback()
|
||||
}()
|
||||
|
||||
bkt := tx.Bucket([]byte("sessions"))
|
||||
bkt := tx.Bucket(bucketName())
|
||||
if bkt == nil {
|
||||
return
|
||||
}
|
||||
|
||||
removed := 0
|
||||
|
||||
if tx.Bucket([]byte("sessions")) != nil {
|
||||
_ = tx.DeleteBucket([]byte("sessions"))
|
||||
removed = 1
|
||||
}
|
||||
|
||||
now := uint32(time.Now().UTC().Unix())
|
||||
forEach := func(k, v []byte) error {
|
||||
i := binary.BigEndian.Uint32(v)
|
||||
if i <= now {
|
||||
s := session{}
|
||||
if !s.deserialize(v) || s.expire <= now {
|
||||
err = bkt.Delete(k)
|
||||
if err != nil {
|
||||
log.Error("Auth: bbolt.Delete: %s", err)
|
||||
@ -85,7 +129,8 @@ func (a *Auth) loadSessions() {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
a.sessions[hex.EncodeToString(k)] = i
|
||||
|
||||
a.sessions[hex.EncodeToString(k)] = &s
|
||||
return nil
|
||||
}
|
||||
_ = bkt.ForEach(forEach)
|
||||
@ -99,11 +144,15 @@ func (a *Auth) loadSessions() {
|
||||
}
|
||||
|
||||
// store session data in file
|
||||
func (a *Auth) storeSession(data []byte, expire uint32) {
|
||||
func (a *Auth) addSession(data []byte, s *session) {
|
||||
a.lock.Lock()
|
||||
a.sessions[hex.EncodeToString(data)] = expire
|
||||
a.sessions[hex.EncodeToString(data)] = s
|
||||
a.lock.Unlock()
|
||||
a.storeSession(data, s)
|
||||
}
|
||||
|
||||
// store session data in file
|
||||
func (a *Auth) storeSession(data []byte, s *session) {
|
||||
tx, err := a.db.Begin(true)
|
||||
if err != nil {
|
||||
log.Error("Auth: bbolt.Begin: %s", err)
|
||||
@ -113,15 +162,12 @@ func (a *Auth) storeSession(data []byte, expire uint32) {
|
||||
_ = tx.Rollback()
|
||||
}()
|
||||
|
||||
bkt, err := tx.CreateBucketIfNotExists([]byte("sessions"))
|
||||
bkt, err := tx.CreateBucketIfNotExists(bucketName())
|
||||
if err != nil {
|
||||
log.Error("Auth: bbolt.CreateBucketIfNotExists: %s", err)
|
||||
return
|
||||
}
|
||||
var val []byte
|
||||
val = make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(val, expire)
|
||||
err = bkt.Put(data, val)
|
||||
err = bkt.Put(data, s.serialize())
|
||||
if err != nil {
|
||||
log.Error("Auth: bbolt.Put: %s", err)
|
||||
return
|
||||
@ -147,7 +193,7 @@ func (a *Auth) removeSession(sess []byte) {
|
||||
_ = tx.Rollback()
|
||||
}()
|
||||
|
||||
bkt := tx.Bucket([]byte("sessions"))
|
||||
bkt := tx.Bucket(bucketName())
|
||||
if bkt == nil {
|
||||
log.Error("Auth: bbolt.Bucket")
|
||||
return
|
||||
@ -174,12 +220,12 @@ func (a *Auth) CheckSession(sess string) int {
|
||||
update := false
|
||||
|
||||
a.lock.Lock()
|
||||
expire, ok := a.sessions[sess]
|
||||
s, ok := a.sessions[sess]
|
||||
if !ok {
|
||||
a.lock.Unlock()
|
||||
return -1
|
||||
}
|
||||
if expire <= now {
|
||||
if s.expire <= now {
|
||||
delete(a.sessions, sess)
|
||||
key, _ := hex.DecodeString(sess)
|
||||
a.removeSession(key)
|
||||
@ -188,17 +234,17 @@ func (a *Auth) CheckSession(sess string) int {
|
||||
}
|
||||
|
||||
newExpire := now + expireTime*60*60
|
||||
if expire/(24*60*60) != newExpire/(24*60*60) {
|
||||
if s.expire/(24*60*60) != newExpire/(24*60*60) {
|
||||
// update expiration time once a day
|
||||
update = true
|
||||
a.sessions[sess] = newExpire
|
||||
s.expire = newExpire
|
||||
}
|
||||
|
||||
a.lock.Unlock()
|
||||
|
||||
if update {
|
||||
key, _ := hex.DecodeString(sess)
|
||||
a.storeSession(key, expire)
|
||||
a.storeSession(key, s)
|
||||
}
|
||||
|
||||
return 0
|
||||
@ -238,8 +284,10 @@ func httpCookie(req loginJSON) string {
|
||||
expstr = expstr[:len(expstr)-len("UTC")] // "UTC" -> "GMT"
|
||||
expstr += "GMT"
|
||||
|
||||
expireSess := uint32(now.Unix()) + expireTime*60*60
|
||||
config.auth.storeSession(sess, expireSess)
|
||||
s := session{}
|
||||
s.userName = u.Name
|
||||
s.expire = uint32(now.Unix()) + expireTime*60*60
|
||||
config.auth.addSession(sess, &s)
|
||||
|
||||
return fmt.Sprintf("session=%s; Path=/; HttpOnly; Expires=%s", hex.EncodeToString(sess), expstr)
|
||||
}
|
||||
@ -402,6 +450,34 @@ func (a *Auth) UserFind(login string, password string) User {
|
||||
return User{}
|
||||
}
|
||||
|
||||
// GetCurrentUser - get the current user
|
||||
func (a *Auth) GetCurrentUser(r *http.Request) User {
|
||||
cookie, err := r.Cookie("session")
|
||||
if err != nil {
|
||||
// there's no Cookie, check Basic authentication
|
||||
user, pass, ok := r.BasicAuth()
|
||||
if ok {
|
||||
u := config.auth.UserFind(user, pass)
|
||||
return u
|
||||
}
|
||||
}
|
||||
|
||||
a.lock.Lock()
|
||||
s, ok := a.sessions[cookie.Value]
|
||||
if !ok {
|
||||
a.lock.Unlock()
|
||||
return User{}
|
||||
}
|
||||
for _, u := range a.users {
|
||||
if u.Name == s.userName {
|
||||
a.lock.Unlock()
|
||||
return u
|
||||
}
|
||||
}
|
||||
a.lock.Unlock()
|
||||
return User{}
|
||||
}
|
||||
|
||||
// GetUsers - get users
|
||||
func (a *Auth) GetUsers() []User {
|
||||
a.lock.Lock()
|
||||
|
@ -28,6 +28,7 @@ func TestAuth(t *testing.T) {
|
||||
User{Name: "name", PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2"},
|
||||
}
|
||||
a := InitAuth(fn, nil)
|
||||
s := session{}
|
||||
|
||||
user := User{Name: "name"}
|
||||
a.UserAdd(&user, "password")
|
||||
@ -38,12 +39,16 @@ func TestAuth(t *testing.T) {
|
||||
sess := getSession(&users[0])
|
||||
sessStr := hex.EncodeToString(sess)
|
||||
|
||||
now := time.Now().UTC().Unix()
|
||||
// check expiration
|
||||
a.storeSession(sess, uint32(time.Now().UTC().Unix()))
|
||||
s.expire = uint32(now)
|
||||
a.addSession(sess, &s)
|
||||
assert.True(t, a.CheckSession(sessStr) == 1)
|
||||
|
||||
// add session with TTL = 2 sec
|
||||
a.storeSession(sess, uint32(time.Now().UTC().Unix()+2))
|
||||
s = session{}
|
||||
s.expire = uint32(now + 2)
|
||||
a.addSession(sess, &s)
|
||||
assert.True(t, a.CheckSession(sessStr) == 0)
|
||||
|
||||
a.Close()
|
||||
@ -53,6 +58,9 @@ func TestAuth(t *testing.T) {
|
||||
|
||||
// the session is still alive
|
||||
assert.True(t, a.CheckSession(sessStr) == 0)
|
||||
// reset our expiration time because CheckSession() has just updated it
|
||||
s.expire = uint32(now + 2)
|
||||
a.storeSession(sess, &s)
|
||||
a.Close()
|
||||
|
||||
u := a.UserFind("name", "password")
|
||||
|
@ -377,6 +377,23 @@ func checkDNS(input string, bootstrap []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type profileJSON struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func handleGetProfile(w http.ResponseWriter, r *http.Request) {
|
||||
pj := profileJSON{}
|
||||
u := config.auth.GetCurrentUser(r)
|
||||
pj.Name = u.Name
|
||||
|
||||
data, err := json.Marshal(pj)
|
||||
if err != nil {
|
||||
httpError(w, http.StatusInternalServerError, "json.Marshal: %s", err)
|
||||
return
|
||||
}
|
||||
_, _ = w.Write(data)
|
||||
}
|
||||
|
||||
// --------------
|
||||
// DNS-over-HTTPS
|
||||
// --------------
|
||||
@ -416,6 +433,7 @@ func registerControlHandlers() {
|
||||
|
||||
httpRegister(http.MethodGet, "/control/access/list", handleAccessList)
|
||||
httpRegister(http.MethodPost, "/control/access/set", handleAccessSet)
|
||||
httpRegister("GET", "/control/profile", handleGetProfile)
|
||||
|
||||
RegisterFilteringHandlers()
|
||||
RegisterTLSHandlers()
|
||||
|
@ -970,6 +970,18 @@ paths:
|
||||
302:
|
||||
description: OK
|
||||
|
||||
/profile:
|
||||
get:
|
||||
tags:
|
||||
- global
|
||||
operationId: getProfile
|
||||
summary: ""
|
||||
responses:
|
||||
200:
|
||||
description: OK
|
||||
schema:
|
||||
$ref: "#/definitions/ProfileInfo"
|
||||
|
||||
definitions:
|
||||
ServerStatus:
|
||||
type: "object"
|
||||
@ -1559,6 +1571,14 @@ definitions:
|
||||
description: "Network interfaces dictionary (key is the interface name)"
|
||||
additionalProperties:
|
||||
$ref: "#/definitions/NetInterface"
|
||||
|
||||
ProfileInfo:
|
||||
type: "object"
|
||||
description: "Information about the current user"
|
||||
properties:
|
||||
name:
|
||||
type: "string"
|
||||
|
||||
Client:
|
||||
type: "object"
|
||||
description: "Client information"
|
||||
|
Loading…
Reference in New Issue
Block a user