* refactor

1. Auth module was initialized inside dns.go - now it's moved to initWeb()

2. stopHTTPServer() wasn't called on server stop - now we do that

3. Don't use postInstall() HTTP filter where it's not necessary.
Now we register handlers after installation is complete.
This commit is contained in:
Simon Zolin 2020-02-18 19:27:09 +03:00
parent c77907694d
commit e8129f15c7
7 changed files with 75 additions and 36 deletions

View File

@ -18,6 +18,8 @@ import (
const defaultDiscoverTime = time.Second * 3 const defaultDiscoverTime = time.Second * 3
const leaseExpireStatic = 1 const leaseExpireStatic = 1
var webHandlersRegistered = false
// Lease contains the necessary information about a DHCP lease // Lease contains the necessary information about a DHCP lease
// field ordering is important -- yaml fields will mirror ordering from here // field ordering is important -- yaml fields will mirror ordering from here
type Lease struct { type Lease struct {
@ -121,9 +123,6 @@ func Create(config ServerConfig) *Server {
return nil return nil
} }
} }
if s.conf.HTTPRegister != nil {
s.registerHandlers()
}
// we can't delay database loading until DHCP server is started, // we can't delay database loading until DHCP server is started,
// because we need static leases functionality available beforehand // because we need static leases functionality available beforehand
@ -221,6 +220,11 @@ func (s *Server) setConfig(config ServerConfig) error {
// Start will listen on port 67 and serve DHCP requests. // Start will listen on port 67 and serve DHCP requests.
func (s *Server) Start() error { func (s *Server) Start() error {
if !webHandlersRegistered && s.conf.HTTPRegister != nil {
webHandlersRegistered = true
s.registerHandlers()
}
// TODO: don't close if interface and addresses are the same // TODO: don't close if interface and addresses are the same
if s.conn != nil { if s.conn != nil {
s.closeConn() s.closeConn()

View File

@ -24,6 +24,8 @@ const (
clientsUpdatePeriod = 1 * time.Hour clientsUpdatePeriod = 1 * time.Hour
) )
var webHandlersRegistered = false
// Client information // Client information
type Client struct { type Client struct {
IDs []string IDs []string
@ -98,13 +100,21 @@ func (clients *clientsContainer) Init(objects []clientObject, dhcpServer *dhcpd.
clients.addFromConfig(objects) clients.addFromConfig(objects)
if !clients.testing { if !clients.testing {
go clients.periodicUpdate()
clients.addFromDHCP() clients.addFromDHCP()
clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged) clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged)
}
}
// Start - start the module
func (clients *clientsContainer) Start() {
if !clients.testing {
if !webHandlersRegistered {
webHandlersRegistered = true
clients.registerWebHandlers() clients.registerWebHandlers()
} }
go clients.periodicUpdate()
}
} }
// Reload - reload auto-clients // Reload - reload auto-clients

View File

@ -258,12 +258,14 @@ func preInstallHandler(handler http.Handler) http.Handler {
// it also enforces HTTPS if it is enabled and configured // it also enforces HTTPS if it is enabled and configured
func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if Context.firstRun && if Context.firstRun &&
!strings.HasPrefix(r.URL.Path, "/install.") && !strings.HasPrefix(r.URL.Path, "/install.") &&
r.URL.Path != "/favicon.png" { r.URL.Path != "/favicon.png" {
http.Redirect(w, r, "/install.html", http.StatusSeeOther) // should not be cacheable http.Redirect(w, r, "/install.html", http.StatusFound)
return return
} }
// enforce https? // enforce https?
if config.TLS.ForceHTTPS && r.TLS == nil && config.TLS.Enabled && config.TLS.PortHTTPS != 0 && Context.httpsServer.server != nil { if config.TLS.ForceHTTPS && r.TLS == nil && config.TLS.Enabled && config.TLS.PortHTTPS != 0 && Context.httpsServer.server != nil {
// yes, and we want host from host:port // yes, and we want host from host:port
@ -282,6 +284,7 @@ func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.Res
http.Redirect(w, r, newURL.String(), http.StatusTemporaryRedirect) http.Redirect(w, r, newURL.String(), http.StatusTemporaryRedirect)
return return
} }
w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Origin", "*")
handler(w, r) handler(w, r)
} }

View File

@ -356,6 +356,8 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
return return
} }
registerControlHandlers()
// this needs to be done in a goroutine because Shutdown() is a blocking call, and it will block // this needs to be done in a goroutine because Shutdown() is a blocking call, and it will block
// until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely // until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely
if restartHTTP { if restartHTTP {

View File

@ -476,7 +476,6 @@ func doUpdate(u *updateInfo) error {
func finishUpdate(u *updateInfo) { func finishUpdate(u *updateInfo) {
log.Info("Stopping all tasks") log.Info("Stopping all tasks")
cleanup() cleanup()
stopHTTPServer()
cleanupAlways() cleanupAlways()
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {

View File

@ -3,7 +3,6 @@ package home
import ( import (
"fmt" "fmt"
"net" "net"
"os"
"path/filepath" "path/filepath"
"github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/AdGuardHome/dnsfilter"
@ -25,13 +24,9 @@ func onConfigModified() {
// Please note that we must do it even if we don't start it // Please note that we must do it even if we don't start it
// so that we had access to the query log and the stats // so that we had access to the query log and the stats
func initDNSServer() error { func initDNSServer() error {
var err error
baseDir := Context.getDataDir() baseDir := Context.getDataDir()
err := os.MkdirAll(baseDir, 0755)
if err != nil {
return fmt.Errorf("Cannot create DNS data dir at %s: %s", baseDir, err)
}
statsConf := stats.Config{ statsConf := stats.Config{
Filename: filepath.Join(baseDir, "stats.db"), Filename: filepath.Join(baseDir, "stats.db"),
LimitDays: config.DNS.StatsInterval, LimitDays: config.DNS.StatsInterval,
@ -70,14 +65,6 @@ func initDNSServer() error {
return fmt.Errorf("dnsServer.Prepare: %s", err) return fmt.Errorf("dnsServer.Prepare: %s", err)
} }
sessFilename := filepath.Join(baseDir, "sessions.db")
Context.auth = InitAuth(sessFilename, config.Users, config.WebSessionTTLHours*60*60)
if Context.auth == nil {
closeDNSServer()
return fmt.Errorf("Couldn't initialize Auth module")
}
config.Users = nil
Context.rdns = InitRDNS(Context.dnsServer, &Context.clients) Context.rdns = InitRDNS(Context.dnsServer, &Context.clients)
Context.whois = initWhois(&Context.clients) Context.whois = initWhois(&Context.clients)
@ -224,6 +211,8 @@ func startDNSServer() error {
enableFilters(false) enableFilters(false)
Context.clients.Start()
err := Context.dnsServer.Start() err := Context.dnsServer.Start()
if err != nil { if err != nil {
return errorx.Decorate(err, "Couldn't start forwarding DNS server") return errorx.Decorate(err, "Couldn't start forwarding DNS server")
@ -295,11 +284,6 @@ func closeDNSServer() {
Context.queryLog = nil Context.queryLog = nil
} }
if Context.auth != nil {
Context.auth.Close()
Context.auth = nil
}
Context.filters.Close() Context.filters.Close()
log.Debug("Closed all DNS modules") log.Debug("Closed all DNS modules")

View File

@ -227,6 +227,9 @@ func run(args options) {
if args.bindPort != 0 { if args.bindPort != 0 {
config.BindPort = args.bindPort config.BindPort = args.bindPort
} }
if len(args.pidFile) != 0 && writePIDFile(args.pidFile) {
Context.pidFileName = args.pidFile
}
if !Context.firstRun { if !Context.firstRun {
// Save the updated config // Save the updated config
@ -234,8 +237,20 @@ func run(args options) {
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
}
err = initDNSServer() err := os.MkdirAll(Context.getDataDir(), 0755)
if err != nil {
log.Fatalf("Cannot create DNS data dir at %s: %s", Context.getDataDir(), err)
}
err = initWeb()
if err != nil {
log.Fatalf("%s", err)
}
if !Context.firstRun {
err := initDNSServer()
if err != nil { if err != nil {
log.Fatalf("%s", err) log.Fatalf("%s", err)
} }
@ -252,26 +267,41 @@ func run(args options) {
} }
} }
if len(args.pidFile) != 0 && writePIDFile(args.pidFile) { startWeb()
Context.pidFileName = args.pidFile
// wait indefinitely for other go-routines to complete their job
select {}
} }
// Initialize Web modules
func initWeb() error {
sessFilename := filepath.Join(Context.getDataDir(), "sessions.db")
Context.auth = InitAuth(sessFilename, config.Users, config.WebSessionTTLHours*60*60)
if Context.auth == nil {
return fmt.Errorf("Couldn't initialize Auth module")
}
config.Users = nil
// Initialize and run the admin Web interface // Initialize and run the admin Web interface
box := packr.NewBox("../build/static") box := packr.NewBox("../build/static")
// if not configured, redirect / to /install.html, otherwise redirect /install.html to / // if not configured, redirect / to /install.html, otherwise redirect /install.html to /
http.Handle("/", postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(http.FileServer(box))))) http.Handle("/", postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(http.FileServer(box)))))
registerControlHandlers()
// add handlers for /install paths, we only need them when we're not configured yet // add handlers for /install paths, we only need them when we're not configured yet
if Context.firstRun { if Context.firstRun {
log.Info("This is the first launch of AdGuard Home, redirecting everything to /install.html ") log.Info("This is the first launch of AdGuard Home, redirecting everything to /install.html ")
http.Handle("/install.html", preInstallHandler(http.FileServer(box))) http.Handle("/install.html", preInstallHandler(http.FileServer(box)))
registerInstallHandlers() registerInstallHandlers()
} else {
registerControlHandlers()
} }
Context.httpsServer.cond = sync.NewCond(&Context.httpsServer.Mutex) Context.httpsServer.cond = sync.NewCond(&Context.httpsServer.Mutex)
return nil
}
func startWeb() {
// for https, we have a separate goroutine loop // for https, we have a separate goroutine loop
go httpServerLoop() go httpServerLoop()
@ -291,9 +321,6 @@ func run(args options) {
} }
// We use ErrServerClosed as a sign that we need to rebind on new address, so go back to the start of the loop // We use ErrServerClosed as a sign that we need to rebind on new address, so go back to the start of the loop
} }
// wait indefinitely for other go-routines to complete their job
select {}
} }
func httpServerLoop() { func httpServerLoop() {
@ -458,6 +485,8 @@ func configureLogger(args options) {
func cleanup() { func cleanup() {
log.Info("Stopping AdGuard Home") log.Info("Stopping AdGuard Home")
stopHTTPServer()
err := stopDNSServer() err := stopDNSServer()
if err != nil { if err != nil {
log.Error("Couldn't stop DNS server: %s", err) log.Error("Couldn't stop DNS server: %s", err)
@ -475,7 +504,15 @@ func stopHTTPServer() {
if Context.httpsServer.server != nil { if Context.httpsServer.server != nil {
_ = Context.httpsServer.server.Shutdown(context.TODO()) _ = Context.httpsServer.server.Shutdown(context.TODO())
} }
if Context.httpServer != nil {
_ = Context.httpServer.Shutdown(context.TODO()) _ = Context.httpServer.Shutdown(context.TODO())
}
if Context.auth != nil {
Context.auth.Close()
Context.auth = nil
}
log.Info("Stopped HTTP server") log.Info("Stopped HTTP server")
} }