diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index 47da99bc..bfbf43c6 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -41,7 +41,6 @@ type Server struct { dnsFilter *dnsfilter.Dnsfilter // DNS filter instance queryLog *queryLog // Query log instance stats *stats // General server statistics - once sync.Once AllowedClients map[string]bool // IP addresses of whitelist clients DisallowedClients map[string]bool // IP addresses of clients that should be blocked @@ -55,11 +54,24 @@ type Server struct { // NewServer creates a new instance of the dnsforward.Server // baseDir is the base directory for query logs +// Note: this function must be called only once func NewServer(baseDir string) *Server { - return &Server{ + s := &Server{ queryLog: newQueryLog(baseDir), stats: newStats(), } + + log.Tracef("Loading stats from querylog") + err := s.queryLog.fillStatsFromQueryLog(s.stats) + if err != nil { + log.Error("failed to load stats from querylog: %s", err) + } + + log.Printf("Start DNS server periodic jobs") + go s.queryLog.periodicQueryLogRotate() + go s.queryLog.runningTop.periodicHourlyTopRotate() + go s.stats.statsRotator() + return s } // FilteringConfig represents the DNS filtering configuration of AdGuard Home @@ -169,33 +181,11 @@ func (s *Server) startInternal(config *ServerConfig) error { return errors.New("DNS server is already started") } - if s.queryLog == nil { - s.queryLog = newQueryLog(".") - } - - if s.stats == nil { - s.stats = newStats() - } - err := s.initDNSFilter() if err != nil { return err } - log.Tracef("Loading stats from querylog") - err = s.queryLog.fillStatsFromQueryLog(s.stats) - if err != nil { - return errorx.Decorate(err, "failed to load stats from querylog") - } - - // TODO: Think about reworking this, the current approach won't work properly if AG Home is restarted periodically - s.once.Do(func() { - log.Printf("Start DNS server periodic jobs") - go s.queryLog.periodicQueryLogRotate() - go s.queryLog.runningTop.periodicHourlyTopRotate() - go s.stats.statsRotator() - }) - proxyConfig := proxy.Config{ UDPListenAddr: s.conf.UDPListenAddr, TCPListenAddr: s.conf.TCPListenAddr, diff --git a/home/clients.go b/home/clients.go index 4ae478be..0c003a56 100644 --- a/home/clients.go +++ b/home/clients.go @@ -66,10 +66,9 @@ type clientsContainer struct { lock sync.Mutex } -var clients clientsContainer - -// Initialize clients container -func clientsInit() { +// Init initializes clients container +// Note: this function must be called only once +func (clients *clientsContainer) Init() { if clients.list != nil { log.Fatal("clients.list != nil") } @@ -77,22 +76,24 @@ func clientsInit() { clients.ipIndex = make(map[string]*Client) clients.ipHost = make(map[string]ClientHost) - go periodicClientsUpdate() + go clients.periodicUpdate() } -func periodicClientsUpdate() { +func (clients *clientsContainer) periodicUpdate() { for { - clientsAddFromHostsFile() - clientsAddFromSystemARP() + clients.addFromHostsFile() + clients.addFromSystemARP() time.Sleep(clientsUpdatePeriod) } } -func clientsGetList() map[string]*Client { +// GetList returns the pointer to clients list +func (clients *clientsContainer) GetList() map[string]*Client { return clients.list } -func clientExists(ip string) bool { +// Exists checks if client with this IP already exists +func (clients *clientsContainer) Exists(ip string) bool { clients.lock.Lock() defer clients.lock.Unlock() @@ -105,8 +106,8 @@ func clientExists(ip string) bool { return ok } -// Search for a client by IP -func clientFind(ip string) (Client, bool) { +// Find searches for a client by IP +func (clients *clientsContainer) Find(ip string) (Client, bool) { clients.lock.Lock() defer clients.lock.Unlock() @@ -121,7 +122,7 @@ func clientFind(ip string) (Client, bool) { if err != nil { continue } - ipAddr := dhcpServer.FindIPbyMAC(mac) + ipAddr := config.dhcpServer.FindIPbyMAC(mac) if ipAddr == nil { continue } @@ -135,7 +136,7 @@ func clientFind(ip string) (Client, bool) { } // Check if Client object's fields are correct -func clientCheck(c *Client) error { +func (c *Client) check() error { if len(c.Name) == 0 { return fmt.Errorf("Invalid Name") } @@ -162,8 +163,8 @@ func clientCheck(c *Client) error { // Add a new client object // Return true: success; false: client exists. -func clientAdd(c Client) (bool, error) { - e := clientCheck(&c) +func (clients *clientsContainer) Add(c Client) (bool, error) { + e := c.check() if e != nil { return false, e } @@ -194,8 +195,8 @@ func clientAdd(c Client) (bool, error) { return true, nil } -// Remove a client -func clientDel(name string) bool { +// Del removes a client +func (clients *clientsContainer) Del(name string) bool { clients.lock.Lock() defer clients.lock.Unlock() @@ -210,8 +211,8 @@ func clientDel(name string) bool { } // Update a client -func clientUpdate(name string, c Client) error { - err := clientCheck(&c) +func (clients *clientsContainer) Update(name string, c Client) error { + err := c.check() if err != nil { return err } @@ -257,10 +258,10 @@ func clientUpdate(name string, c Client) error { return nil } -// Add new IP -> Host pair +// AddHost adds new IP -> Host pair // Use priority of the source (etc/hosts > ARP > rDNS) // so we overwrite existing entries with an equal or higher priority -func clientAddHost(ip, host string, source clientSource) (bool, error) { +func (clients *clientsContainer) AddHost(ip, host string, source clientSource) (bool, error) { clients.lock.Lock() defer clients.lock.Unlock() @@ -279,7 +280,7 @@ func clientAddHost(ip, host string, source clientSource) (bool, error) { } // Parse system 'hosts' file and fill clients array -func clientsAddFromHostsFile() { +func (clients *clientsContainer) addFromHostsFile() { hostsFn := "/etc/hosts" if runtime.GOOS == "windows" { hostsFn = os.ExpandEnv("$SystemRoot\\system32\\drivers\\etc\\hosts") @@ -304,7 +305,7 @@ func clientsAddFromHostsFile() { continue } - ok, e := clientAddHost(fields[0], fields[1], ClientSourceHostsFile) + ok, e := clients.AddHost(fields[0], fields[1], ClientSourceHostsFile) if e != nil { log.Tracef("%s", e) } @@ -319,7 +320,7 @@ func clientsAddFromHostsFile() { // Add IP -> Host pairs from the system's `arp -a` command output // The command's output is: // HOST (IP) at MAC on IFACE -func clientsAddFromSystemARP() { +func (clients *clientsContainer) addFromSystemARP() { if runtime.GOOS == "windows" { return @@ -350,7 +351,7 @@ func clientsAddFromSystemARP() { continue } - ok, e := clientAddHost(ip, host, ClientSourceARP) + ok, e := clients.AddHost(ip, host, ClientSourceARP) if e != nil { log.Tracef("%s", e) } @@ -379,8 +380,8 @@ func handleGetClients(w http.ResponseWriter, r *http.Request) { data := clientListJSON{} - clients.lock.Lock() - for _, c := range clients.list { + config.clients.lock.Lock() + for _, c := range config.clients.list { cj := clientJSON{ IP: c.IP, MAC: c.MAC, @@ -394,7 +395,7 @@ func handleGetClients(w http.ResponseWriter, r *http.Request) { if len(c.MAC) != 0 { hwAddr, _ := net.ParseMAC(c.MAC) - ipAddr := dhcpServer.FindIPbyMAC(hwAddr) + ipAddr := config.dhcpServer.FindIPbyMAC(hwAddr) if ipAddr != nil { cj.IP = ipAddr.String() } @@ -402,7 +403,7 @@ func handleGetClients(w http.ResponseWriter, r *http.Request) { data.Clients = append(data.Clients, cj) } - for ip, ch := range clients.ipHost { + for ip, ch := range config.clients.ipHost { cj := clientHostJSON{ IP: ip, Name: ch.Host, @@ -416,7 +417,7 @@ func handleGetClients(w http.ResponseWriter, r *http.Request) { } data.AutoClients = append(data.AutoClients, cj) } - clients.lock.Unlock() + config.clients.lock.Unlock() w.Header().Set("Content-Type", "application/json") e := json.NewEncoder(w).Encode(data) @@ -462,7 +463,7 @@ func handleAddClient(w http.ResponseWriter, r *http.Request) { httpError(w, http.StatusBadRequest, "%s", err) return } - ok, err := clientAdd(*c) + ok, err := config.clients.Add(*c) if err != nil { httpError(w, http.StatusBadRequest, "%s", err) return @@ -492,7 +493,7 @@ func handleDelClient(w http.ResponseWriter, r *http.Request) { return } - if !clientDel(cj.Name) { + if !config.clients.Del(cj.Name) { httpError(w, http.StatusBadRequest, "Client not found") return } @@ -501,7 +502,7 @@ func handleDelClient(w http.ResponseWriter, r *http.Request) { returnOK(w) } -type clientUpdateJSON struct { +type updateJSON struct { Name string `json:"name"` Data clientJSON `json:"data"` } @@ -515,7 +516,7 @@ func handleUpdateClient(w http.ResponseWriter, r *http.Request) { return } - var dj clientUpdateJSON + var dj updateJSON err = json.Unmarshal(body, &dj) if err != nil { httpError(w, http.StatusBadRequest, "JSON parse: %s", err) @@ -532,7 +533,7 @@ func handleUpdateClient(w http.ResponseWriter, r *http.Request) { return } - err = clientUpdate(dj.Name, *c) + err = config.clients.Update(dj.Name, *c) if err != nil { httpError(w, http.StatusBadRequest, "%s", err) return diff --git a/home/clients_test.go b/home/clients_test.go index 4b45cee2..d5dc5143 100644 --- a/home/clients_test.go +++ b/home/clients_test.go @@ -6,17 +6,18 @@ func TestClients(t *testing.T) { var c Client var e error var b bool + clients := clientsContainer{} - clientsInit() + clients.Init() // add c = Client{ IP: "1.1.1.1", Name: "client1", } - b, e = clientAdd(c) + b, e = clients.Add(c) if !b || e != nil { - t.Fatalf("clientAdd #1") + t.Fatalf("Add #1") } // add #2 @@ -24,19 +25,19 @@ func TestClients(t *testing.T) { IP: "2.2.2.2", Name: "client2", } - b, e = clientAdd(c) + b, e = clients.Add(c) if !b || e != nil { - t.Fatalf("clientAdd #2") + t.Fatalf("Add #2") } - c, b = clientFind("1.1.1.1") + c, b = clients.Find("1.1.1.1") if !b || c.Name != "client1" { - t.Fatalf("clientFind #1") + t.Fatalf("Find #1") } - c, b = clientFind("2.2.2.2") + c, b = clients.Find("2.2.2.2") if !b || c.Name != "client2" { - t.Fatalf("clientFind #2") + t.Fatalf("Find #2") } // failed add - name in use @@ -44,9 +45,9 @@ func TestClients(t *testing.T) { IP: "1.2.3.5", Name: "client1", } - b, _ = clientAdd(c) + b, _ = clients.Add(c) if b { - t.Fatalf("clientAdd - name in use") + t.Fatalf("Add - name in use") } // failed add - ip in use @@ -54,91 +55,91 @@ func TestClients(t *testing.T) { IP: "2.2.2.2", Name: "client3", } - b, e = clientAdd(c) + b, e = clients.Add(c) if b || e == nil { - t.Fatalf("clientAdd - ip in use") + t.Fatalf("Add - ip in use") } // get - if clientExists("1.2.3.4") { - t.Fatalf("clientExists") + if clients.Exists("1.2.3.4") { + t.Fatalf("Exists") } - if !clientExists("1.1.1.1") { - t.Fatalf("clientExists #1") + if !clients.Exists("1.1.1.1") { + t.Fatalf("Exists #1") } - if !clientExists("2.2.2.2") { - t.Fatalf("clientExists #2") + if !clients.Exists("2.2.2.2") { + t.Fatalf("Exists #2") } // failed update - no such name c.IP = "1.2.3.0" c.Name = "client3" - if clientUpdate("client3", c) == nil { - t.Fatalf("clientUpdate") + if clients.Update("client3", c) == nil { + t.Fatalf("Update") } // failed update - name in use c.IP = "1.2.3.0" c.Name = "client2" - if clientUpdate("client1", c) == nil { - t.Fatalf("clientUpdate - name in use") + if clients.Update("client1", c) == nil { + t.Fatalf("Update - name in use") } // failed update - ip in use c.IP = "2.2.2.2" c.Name = "client1" - if clientUpdate("client1", c) == nil { - t.Fatalf("clientUpdate - ip in use") + if clients.Update("client1", c) == nil { + t.Fatalf("Update - ip in use") } // update c.IP = "1.1.1.2" c.Name = "client1" - if clientUpdate("client1", c) != nil { - t.Fatalf("clientUpdate") + if clients.Update("client1", c) != nil { + t.Fatalf("Update") } // get after update - if clientExists("1.1.1.1") || !clientExists("1.1.1.2") { - t.Fatalf("clientExists - get after update") + if clients.Exists("1.1.1.1") || !clients.Exists("1.1.1.2") { + t.Fatalf("Exists - get after update") } // failed remove - no such name - if clientDel("client3") { - t.Fatalf("clientDel - no such name") + if clients.Del("client3") { + t.Fatalf("Del - no such name") } // remove - if !clientDel("client1") || clientExists("1.1.1.2") { - t.Fatalf("clientDel") + if !clients.Del("client1") || clients.Exists("1.1.1.2") { + t.Fatalf("Del") } // add host client - b, e = clientAddHost("1.1.1.1", "host", ClientSourceARP) + b, e = clients.AddHost("1.1.1.1", "host", ClientSourceARP) if !b || e != nil { t.Fatalf("clientAddHost") } // failed add - ip exists - b, e = clientAddHost("1.1.1.1", "host1", ClientSourceRDNS) + b, e = clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS) if b || e != nil { t.Fatalf("clientAddHost - ip exists") } // overwrite with new data - b, e = clientAddHost("1.1.1.1", "host2", ClientSourceARP) + b, e = clients.AddHost("1.1.1.1", "host2", ClientSourceARP) if !b || e != nil { t.Fatalf("clientAddHost - overwrite with new data") } // overwrite with new data (higher priority) - b, e = clientAddHost("1.1.1.1", "host3", ClientSourceHostsFile) + b, e = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile) if !b || e != nil { t.Fatalf("clientAddHost - overwrite with new data (higher priority)") } // get - if !clientExists("1.1.1.1") { + if !clients.Exists("1.1.1.1") { t.Fatalf("clientAddHost") } } diff --git a/home/config.go b/home/config.go index 9a294376..079a9298 100644 --- a/home/config.go +++ b/home/config.go @@ -2,6 +2,7 @@ package home import ( "io/ioutil" + "net/http" "os" "path/filepath" "runtime" @@ -38,6 +39,13 @@ type clientObject struct { SafeBrowsingEnabled bool `yaml:"safesearch_enabled"` } +type HTTPSServer struct { + server *http.Server + cond *sync.Cond // reacts to config.TLS.Enabled, PortHTTPS, CertificateChain and PrivateKey + sync.Mutex // protects config.TLS + shutdown bool // if TRUE, don't restart the server +} + // configuration is loaded from YAML // field ordering is important -- yaml fields will mirror ordering from here type configuration struct { @@ -48,10 +56,25 @@ type configuration struct { ourConfigFilename string // Config filename (can be overridden via the command line arguments) ourWorkingDir string // Location of our directory, used to protect against CWD being somewhere else firstRun bool // if set to true, don't run any services except HTTP web inteface, and serve only first-run html + pidFileName string // PID file name. Empty if no PID file was created. // runningAsService flag is set to true when options are passed from the service runner runningAsService bool disableUpdate bool // If set, don't check for updates appSignalChannel chan os.Signal + clients clientsContainer + controlLock sync.Mutex + transport *http.Transport + client *http.Client + + // cached version.json to avoid hammering github.io for each page reload + versionCheckJSON []byte + versionCheckLastTime time.Time + + dnsctx dnsContext + dnsServer *dnsforward.Server + dhcpServer dhcpd.Server + httpServer *http.Server + httpsServer HTTPSServer BindHost string `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server @@ -127,7 +150,6 @@ type tlsConfig struct { } // initialize to default values, will be changed later when reading config or parsing command line -// TODO: Get rid of global variables var config = configuration{ ourConfigFilename: "AdGuardHome.yaml", BindPort: 3000, @@ -167,8 +189,16 @@ var config = configuration{ SchemaVersion: currentSchemaVersion, } -// init initializes default configuration for the current OS&ARCH -func init() { +// initConfig initializes default configuration for the current OS&ARCH +func initConfig() { + config.transport = &http.Transport{ + DialContext: customDialContext, + } + config.client = &http.Client{ + Timeout: time.Minute * 5, + Transport: config.transport, + } + if runtime.GOARCH == "mips" || runtime.GOARCH == "mipsle" { // Use plain DNS on MIPS, encryption is too slow defaultDNS = []string{"1.1.1.1", "1.0.0.1"} @@ -233,7 +263,7 @@ func parseConfig() error { SafeSearchEnabled: cy.SafeSearchEnabled, SafeBrowsingEnabled: cy.SafeBrowsingEnabled, } - _, err = clientAdd(cli) + _, err = config.clients.Add(cli) if err != nil { log.Tracef("clientAdd: %s", err) } @@ -268,7 +298,7 @@ func (c *configuration) write() error { c.Lock() defer c.Unlock() - clientsList := clientsGetList() + clientsList := config.clients.GetList() for _, cli := range clientsList { ip := cli.IP if len(cli.MAC) != 0 { diff --git a/home/control.go b/home/control.go index c17771d0..b012f9e9 100644 --- a/home/control.go +++ b/home/control.go @@ -11,7 +11,6 @@ import ( "sort" "strconv" "strings" - "sync" "time" "github.com/AdguardTeam/AdGuardHome/dnsforward" @@ -25,23 +24,8 @@ import ( const updatePeriod = time.Minute * 30 -// cached version.json to avoid hammering github.io for each page reload -var versionCheckJSON []byte -var versionCheckLastTime time.Time - var protocols = []string{"tls://", "https://", "tcp://", "sdns://"} -var transport = &http.Transport{ - DialContext: customDialContext, -} - -var client = &http.Client{ - Timeout: time.Minute * 5, - Transport: transport, -} - -var controlLock sync.Mutex - // ---------------- // helper functions // ---------------- @@ -188,7 +172,7 @@ func handleQueryLogDisable(w http.ResponseWriter, r *http.Request) { func handleQueryLog(w http.ResponseWriter, r *http.Request) { log.Tracef("%s %v", r.Method, r.URL) - data := dnsServer.GetQueryLog() + data := config.dnsServer.GetQueryLog() jsonVal, err := json.Marshal(data) if err != nil { @@ -205,7 +189,7 @@ func handleQueryLog(w http.ResponseWriter, r *http.Request) { func handleStatsTop(w http.ResponseWriter, r *http.Request) { log.Tracef("%s %v", r.Method, r.URL) - s := dnsServer.GetStatsTop() + s := config.dnsServer.GetStatsTop() // use manual json marshalling because we want maps to be sorted by value statsJSON := bytes.Buffer{} @@ -252,7 +236,7 @@ func handleStatsTop(w http.ResponseWriter, r *http.Request) { // handleStatsReset resets the stats caches func handleStatsReset(w http.ResponseWriter, r *http.Request) { log.Tracef("%s %v", r.Method, r.URL) - dnsServer.PurgeStats() + config.dnsServer.PurgeStats() _, err := fmt.Fprintf(w, "OK\n") if err != nil { httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err) @@ -262,7 +246,7 @@ func handleStatsReset(w http.ResponseWriter, r *http.Request) { // handleStats returns aggregated stats data for the 24 hours func handleStats(w http.ResponseWriter, r *http.Request) { log.Tracef("%s %v", r.Method, r.URL) - summed := dnsServer.GetAggregatedStats() + summed := config.dnsServer.GetAggregatedStats() statsJSON, err := json.Marshal(summed) if err != nil { @@ -309,7 +293,7 @@ func handleStatsHistory(w http.ResponseWriter, r *http.Request) { return } - data, err := dnsServer.GetStatsHistory(timeUnit, startTime, endTime) + data, err := config.dnsServer.GetStatsHistory(timeUnit, startTime, endTime) if err != nil { httpError(w, http.StatusBadRequest, "Cannot get stats history: %s", err) return @@ -725,7 +709,7 @@ func handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) { // Stop DNS server: // we close urlfilter object which in turn closes file descriptors to filter files. // Otherwise, Windows won't allow us to remove the file which is being currently used. - _ = dnsServer.Stop() + _ = config.dnsServer.Stop() // go through each element and delete if url matches config.Lock() @@ -984,7 +968,7 @@ func handleDOH(w http.ResponseWriter, r *http.Request) { return } - dnsServer.ServeHTTP(w, r) + config.dnsServer.ServeHTTP(w, r) } // ------------------------ diff --git a/home/control_access.go b/home/control_access.go index 0f0cb8bd..f33df8b4 100644 --- a/home/control_access.go +++ b/home/control_access.go @@ -17,13 +17,13 @@ type accessListJSON struct { func handleAccessList(w http.ResponseWriter, r *http.Request) { log.Tracef("%s %v", r.Method, r.URL) - controlLock.Lock() + config.controlLock.Lock() j := accessListJSON{ AllowedClients: config.DNS.AllowedClients, DisallowedClients: config.DNS.DisallowedClients, BlockedHosts: config.DNS.BlockedHosts, } - controlLock.Unlock() + config.controlLock.Unlock() w.Header().Set("Content-Type", "application/json") err := json.NewEncoder(w).Encode(j) diff --git a/home/control_install.go b/home/control_install.go index f5758b99..30c48cca 100644 --- a/home/control_install.go +++ b/home/control_install.go @@ -264,7 +264,7 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) { // until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely if restartHTTP { go func() { - httpServer.Shutdown(context.TODO()) + config.httpServer.Shutdown(context.TODO()) }() } diff --git a/home/control_tls.go b/home/control_tls.go index cdf62843..ea83dba1 100644 --- a/home/control_tls.go +++ b/home/control_tls.go @@ -46,7 +46,7 @@ func handleTLSValidate(w http.ResponseWriter, r *http.Request) { // check if port is available // BUT: if we are already using this port, no need alreadyRunning := false - if httpsServer.server != nil { + if config.httpsServer.server != nil { alreadyRunning = true } if !alreadyRunning { @@ -72,7 +72,7 @@ func handleTLSConfigure(w http.ResponseWriter, r *http.Request) { // check if port is available // BUT: if we are already using this port, no need alreadyRunning := false - if httpsServer.server != nil { + if config.httpsServer.server != nil { alreadyRunning = true } if !alreadyRunning { @@ -101,12 +101,12 @@ func handleTLSConfigure(w http.ResponseWriter, r *http.Request) { if restartHTTPS { go func() { time.Sleep(time.Second) // TODO: could not find a way to reliably know that data was fully sent to client by https server, so we wait a bit to let response through before closing the server - httpsServer.cond.L.Lock() - httpsServer.cond.Broadcast() - if httpsServer.server != nil { - httpsServer.server.Shutdown(context.TODO()) + config.httpsServer.cond.L.Lock() + config.httpsServer.cond.Broadcast() + if config.httpsServer.server != nil { + config.httpsServer.server.Shutdown(context.TODO()) } - httpsServer.cond.L.Unlock() + config.httpsServer.cond.L.Unlock() }() } } diff --git a/home/control_update.go b/home/control_update.go index 9b20ca62..3e02d71d 100644 --- a/home/control_update.go +++ b/home/control_update.go @@ -73,10 +73,10 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { now := time.Now() if !req.RecheckNow { - controlLock.Lock() - cached := now.Sub(versionCheckLastTime) <= versionCheckPeriod && len(versionCheckJSON) != 0 - data := versionCheckJSON - controlLock.Unlock() + config.controlLock.Lock() + cached := now.Sub(config.versionCheckLastTime) <= versionCheckPeriod && len(config.versionCheckJSON) != 0 + data := config.versionCheckJSON + config.controlLock.Unlock() if cached { log.Tracef("Returning cached data") @@ -87,7 +87,7 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { } log.Tracef("Downloading data from %s", versionCheckURL) - resp, err := client.Get(versionCheckURL) + resp, err := config.client.Get(versionCheckURL) if err != nil { httpError(w, http.StatusBadGateway, "Couldn't get version check json from %s: %T %s\n", versionCheckURL, err, err) return @@ -103,10 +103,10 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { return } - controlLock.Lock() - versionCheckLastTime = now - versionCheckJSON = body - controlLock.Unlock() + config.controlLock.Lock() + config.versionCheckLastTime = now + config.versionCheckJSON = body + config.controlLock.Unlock() w.Header().Set("Content-Type", "application/json") _, err = w.Write(getVersionResp(body)) @@ -349,7 +349,7 @@ func copySupportingFiles(files []string, srcdir, dstdir string, useSrcNameOnly, // Download package file and save it to disk func getPackageFile(u *updateInfo) error { - resp, err := client.Get(u.pkgURL) + resp, err := config.client.Get(u.pkgURL) if err != nil { return fmt.Errorf("HTTP request failed: %s", err) } @@ -501,12 +501,12 @@ func finishUpdate(u *updateInfo) { func handleUpdate(w http.ResponseWriter, r *http.Request) { log.Tracef("%s %v", r.Method, r.URL) - if len(versionCheckJSON) == 0 { + if len(config.versionCheckJSON) == 0 { httpError(w, http.StatusBadRequest, "/update request isn't allowed now") return } - u, err := getUpdateInfo(versionCheckJSON) + u, err := getUpdateInfo(config.versionCheckJSON) if err != nil { httpError(w, http.StatusInternalServerError, "%s", err) return diff --git a/home/dhcp.go b/home/dhcp.go index 066e6843..078647f0 100644 --- a/home/dhcp.go +++ b/home/dhcp.go @@ -18,8 +18,6 @@ import ( "github.com/joomcode/errorx" ) -var dhcpServer = dhcpd.Server{} - // []dhcpd.Lease -> JSON func convertLeases(inputLeases []dhcpd.Lease, includeExpires bool) []map[string]string { leases := []map[string]string{} @@ -41,8 +39,8 @@ func convertLeases(inputLeases []dhcpd.Lease, includeExpires bool) []map[string] func handleDHCPStatus(w http.ResponseWriter, r *http.Request) { log.Tracef("%s %v", r.Method, r.URL) - leases := convertLeases(dhcpServer.Leases(), true) - staticLeases := convertLeases(dhcpServer.StaticLeases(), false) + leases := convertLeases(config.dhcpServer.Leases(), true) + staticLeases := convertLeases(config.dhcpServer.StaticLeases(), false) status := map[string]interface{}{ "config": config.DHCP, "leases": leases, @@ -77,18 +75,18 @@ func handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) { return } - err = dhcpServer.CheckConfig(newconfig.ServerConfig) + err = config.dhcpServer.CheckConfig(newconfig.ServerConfig) if err != nil { httpError(w, http.StatusBadRequest, "Invalid DHCP configuration: %s", err) return } - err = dhcpServer.Stop() + err = config.dhcpServer.Stop() if err != nil { log.Error("failed to stop the DHCP server: %s", err) } - err = dhcpServer.Init(newconfig.ServerConfig) + err = config.dhcpServer.Init(newconfig.ServerConfig) if err != nil { httpError(w, http.StatusBadRequest, "Invalid DHCP configuration: %s", err) return @@ -105,7 +103,7 @@ func handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) { } } - err = dhcpServer.Start() + err = config.dhcpServer.Start() if err != nil { httpError(w, http.StatusBadRequest, "Failed to start DHCP server: %s", err) return @@ -389,7 +387,7 @@ func handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request) { HWAddr: mac, Hostname: lj.Hostname, } - err = dhcpServer.AddStaticLease(lease) + err = config.dhcpServer.AddStaticLease(lease) if err != nil { httpError(w, http.StatusBadRequest, "%s", err) return @@ -420,7 +418,7 @@ func handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Request) { HWAddr: mac, Hostname: lj.Hostname, } - err = dhcpServer.RemoveStaticLease(lease) + err = config.dhcpServer.RemoveStaticLease(lease) if err != nil { httpError(w, http.StatusBadRequest, "%s", err) return @@ -434,12 +432,12 @@ func startDHCPServer() error { return nil } - err := dhcpServer.Init(config.DHCP) + err := config.dhcpServer.Init(config.DHCP) if err != nil { return errorx.Decorate(err, "Couldn't init DHCP server") } - err = dhcpServer.Start() + err = config.dhcpServer.Start() if err != nil { return errorx.Decorate(err, "Couldn't start DHCP server") } @@ -451,7 +449,7 @@ func stopDHCPServer() error { return nil } - err := dhcpServer.Stop() + err := config.dhcpServer.Stop() if err != nil { return errorx.Decorate(err, "Couldn't stop DHCP server") } diff --git a/home/dns.go b/home/dns.go index 52f57202..b195bd40 100644 --- a/home/dns.go +++ b/home/dns.go @@ -17,8 +17,6 @@ import ( "github.com/miekg/dns" ) -var dnsServer *dnsforward.Server - const ( rdnsTimeout = 3 * time.Second // max time to wait for rDNS response ) @@ -32,8 +30,6 @@ type dnsContext struct { upstream upstream.Upstream // Upstream object for our own DNS server } -var dnsctx dnsContext - // initDNSServer creates an instance of the dnsforward.Server // Please note that we must do it even if we don't start it // so that we had access to the query log and the stats @@ -43,7 +39,7 @@ func initDNSServer(baseDir string) { log.Fatalf("Cannot create DNS data dir at %s: %s", baseDir, err) } - dnsServer = dnsforward.NewServer(baseDir) + config.dnsServer = dnsforward.NewServer(baseDir) bindhost := config.DNS.BindHost if config.DNS.BindHost == "0.0.0.0" { @@ -53,37 +49,37 @@ func initDNSServer(baseDir string) { opts := upstream.Options{ Timeout: rdnsTimeout, } - dnsctx.upstream, err = upstream.AddressToUpstream(resolverAddress, opts) + config.dnsctx.upstream, err = upstream.AddressToUpstream(resolverAddress, opts) if err != nil { log.Error("upstream.AddressToUpstream: %s", err) return } - dnsctx.rdnsIP = make(map[string]bool) - dnsctx.rdnsChannel = make(chan string, 256) + config.dnsctx.rdnsIP = make(map[string]bool) + config.dnsctx.rdnsChannel = make(chan string, 256) go asyncRDNSLoop() } func isRunning() bool { - return dnsServer != nil && dnsServer.IsRunning() + return config.dnsServer != nil && config.dnsServer.IsRunning() } func beginAsyncRDNS(ip string) { - if clientExists(ip) { + if config.clients.Exists(ip) { return } // add IP to rdnsIP, if not exists - dnsctx.rdnsLock.Lock() - defer dnsctx.rdnsLock.Unlock() - _, ok := dnsctx.rdnsIP[ip] + config.dnsctx.rdnsLock.Lock() + defer config.dnsctx.rdnsLock.Unlock() + _, ok := config.dnsctx.rdnsIP[ip] if ok { return } - dnsctx.rdnsIP[ip] = true + config.dnsctx.rdnsIP[ip] = true log.Tracef("Adding %s for rDNS resolve", ip) select { - case dnsctx.rdnsChannel <- ip: + case config.dnsctx.rdnsChannel <- ip: // default: log.Tracef("rDNS queue is full") @@ -110,7 +106,7 @@ func resolveRDNS(ip string) string { return "" } - resp, err := dnsctx.upstream.Exchange(&req) + resp, err := config.dnsctx.upstream.Exchange(&req) if err != nil { log.Error("Error while making an rDNS lookup for %s: %s", ip, err) return "" @@ -138,18 +134,18 @@ func resolveRDNS(ip string) string { func asyncRDNSLoop() { for { var ip string - ip = <-dnsctx.rdnsChannel + ip = <-config.dnsctx.rdnsChannel host := resolveRDNS(ip) if len(host) == 0 { continue } - dnsctx.rdnsLock.Lock() - delete(dnsctx.rdnsIP, ip) - dnsctx.rdnsLock.Unlock() + config.dnsctx.rdnsLock.Lock() + delete(config.dnsctx.rdnsIP, ip) + config.dnsctx.rdnsLock.Unlock() - _, _ = clientAddHost(ip, host, ClientSourceRDNS) + _, _ = config.clients.AddHost(ip, host, ClientSourceRDNS) } } @@ -221,7 +217,7 @@ func generateServerConfig() (dnsforward.ServerConfig, error) { // If a client has his own settings, apply them func applyClientSettings(clientAddr string, setts *dnsfilter.RequestFilteringSettings) { - c, ok := clientFind(clientAddr) + c, ok := config.clients.Find(clientAddr) if !ok || !c.UseOwnSettings { return } @@ -242,12 +238,12 @@ func startDNSServer() error { if err != nil { return errorx.Decorate(err, "Couldn't start forwarding DNS server") } - err = dnsServer.Start(&newconfig) + err = config.dnsServer.Start(&newconfig) if err != nil { return errorx.Decorate(err, "Couldn't start forwarding DNS server") } - top := dnsServer.GetStatsTop() + top := config.dnsServer.GetStatsTop() for k := range top.Clients { beginAsyncRDNS(k) } @@ -256,11 +252,11 @@ func startDNSServer() error { } func reconfigureDNSServer() error { - config, err := generateServerConfig() + newconfig, err := generateServerConfig() if err != nil { return errorx.Decorate(err, "Couldn't start forwarding DNS server") } - err = dnsServer.Reconfigure(&config) + err = config.dnsServer.Reconfigure(&newconfig) if err != nil { return errorx.Decorate(err, "Couldn't start forwarding DNS server") } @@ -273,7 +269,7 @@ func stopDNSServer() error { return nil } - err := dnsServer.Stop() + err := config.dnsServer.Stop() if err != nil { return errorx.Decorate(err, "Couldn't stop forwarding DNS server") } diff --git a/home/filter.go b/home/filter.go index 1cd39935..b03eb311 100644 --- a/home/filter.go +++ b/home/filter.go @@ -222,7 +222,7 @@ func refreshFiltersIfNecessary(force bool) int { stopped := false if updateCount != 0 { - _ = dnsServer.Stop() + _ = config.dnsServer.Stop() stopped = true } @@ -308,7 +308,7 @@ func parseFilterContents(contents []byte) (int, string) { func (filter *filter) update() (bool, error) { log.Tracef("Downloading update for filter %d from %s", filter.ID, filter.URL) - resp, err := client.Get(filter.URL) + resp, err := config.client.Get(filter.URL) if resp != nil && resp.Body != nil { defer resp.Body.Close() } diff --git a/home/helpers.go b/home/helpers.go index e1500311..7ff08004 100644 --- a/home/helpers.go +++ b/home/helpers.go @@ -35,8 +35,8 @@ func ensure(method string, handler func(http.ResponseWriter, *http.Request)) fun } if method == "POST" || method == "PUT" || method == "DELETE" { - controlLock.Lock() - defer controlLock.Unlock() + config.controlLock.Lock() + defer config.controlLock.Unlock() } handler(w, r) @@ -148,7 +148,7 @@ func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.Res return } // enforce https? - if config.TLS.ForceHTTPS && r.TLS == nil && config.TLS.Enabled && config.TLS.PortHTTPS != 0 && httpsServer.server != nil { + if config.TLS.ForceHTTPS && r.TLS == nil && config.TLS.Enabled && config.TLS.PortHTTPS != 0 && config.httpsServer.server != nil { // yes, and we want host from host:port host, _, err := net.SplitHostPort(r.Host) if err != nil { diff --git a/home/home.go b/home/home.go index 768bd17a..4d711559 100644 --- a/home/home.go +++ b/home/home.go @@ -25,15 +25,6 @@ import ( "github.com/gobuffalo/packr" ) -var httpServer *http.Server -var httpsServer struct { - server *http.Server - cond *sync.Cond // reacts to config.TLS.Enabled, PortHTTPS, CertificateChain and PrivateKey - sync.Mutex // protects config.TLS - shutdown bool // if TRUE, don't restart the server -} -var pidFileName string // PID file name. Empty if no PID file was created. - const ( // Used in config to indicate that syslog or eventlog (win) should be used for logger output configSyslog = "syslog" @@ -48,7 +39,7 @@ var ( const versionCheckPeriod = time.Hour * 8 -// main is the entry point +// Main is the entry point func Main(version string, channel string) { // Init update-related global variables versionString = version @@ -108,7 +99,8 @@ func run(args options) { os.Exit(0) }() - clientsInit() + initConfig() + config.clients.Init() if !config.firstRun { // Do the upgrade if necessary @@ -168,7 +160,7 @@ func run(args options) { } if len(args.pidFile) != 0 && writePIDFile(args.pidFile) { - pidFileName = args.pidFile + config.pidFileName = args.pidFile } // Update filters we've just loaded right away, don't wait for periodic update timer @@ -192,21 +184,21 @@ func run(args options) { registerInstallHandlers() } - httpsServer.cond = sync.NewCond(&httpsServer.Mutex) + config.httpsServer.cond = sync.NewCond(&config.httpsServer.Mutex) // for https, we have a separate goroutine loop go httpServerLoop() // this loop is used as an ability to change listening host and/or port - for !httpsServer.shutdown { + for !config.httpsServer.shutdown { printHTTPAddresses("http") // we need to have new instance, because after Shutdown() the Server is not usable address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort)) - httpServer = &http.Server{ + config.httpServer = &http.Server{ Addr: address, } - err := httpServer.ListenAndServe() + err := config.httpServer.ListenAndServe() if err != http.ErrServerClosed { cleanupAlways() log.Fatal(err) @@ -219,14 +211,14 @@ func run(args options) { } func httpServerLoop() { - for !httpsServer.shutdown { - httpsServer.cond.L.Lock() + for !config.httpsServer.shutdown { + config.httpsServer.cond.L.Lock() // this mechanism doesn't let us through until all conditions are met for config.TLS.Enabled == false || config.TLS.PortHTTPS == 0 || config.TLS.PrivateKey == "" || config.TLS.CertificateChain == "" { // sleep until necessary data is supplied - httpsServer.cond.Wait() + config.httpsServer.cond.Wait() } address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.TLS.PortHTTPS)) // validate current TLS config and update warnings (it could have been loaded from file) @@ -250,10 +242,10 @@ func httpServerLoop() { cleanupAlways() log.Fatal(err) } - httpsServer.cond.L.Unlock() + config.httpsServer.cond.L.Unlock() // prepare HTTPS server - httpsServer.server = &http.Server{ + config.httpsServer.server = &http.Server{ Addr: address, TLSConfig: &tls.Config{ Certificates: []tls.Certificate{cert}, @@ -262,7 +254,7 @@ func httpServerLoop() { } printHTTPAddresses("https") - err = httpsServer.server.ListenAndServeTLS("", "") + err = config.httpsServer.server.ListenAndServeTLS("", "") if err != http.ErrServerClosed { cleanupAlways() log.Fatal(err) @@ -399,17 +391,17 @@ func cleanup() { // Stop HTTP server, possibly waiting for all active connections to be closed func stopHTTPServer() { - httpsServer.shutdown = true - if httpsServer.server != nil { - httpsServer.server.Shutdown(context.TODO()) + config.httpsServer.shutdown = true + if config.httpsServer.server != nil { + config.httpsServer.server.Shutdown(context.TODO()) } - httpServer.Shutdown(context.TODO()) + config.httpServer.Shutdown(context.TODO()) } // This function is called before application exits func cleanupAlways() { - if len(pidFileName) != 0 { - os.Remove(pidFileName) + if len(config.pidFileName) != 0 { + os.Remove(config.pidFileName) } log.Info("Stopped") }