mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2024-12-17 05:01:35 +03:00
Merge: Refactoring: move global variables; move initialization of periodic tasks
Close #583 * commit 'b8444ff46aff5e45194c6cb61fdf4d3e7aa798fa': * minor * dnsforward: move initialization of periodic tasks to NewServer() * move "dnsctx" to "config" * move "dnsServer" to "config" * move "dhcpServer" to "config" * move "httpServer" to "config" * move "httpsServer" to "config" * move "pidFileName" to "config" * move "versionCheckJSON" to "config" * move "client", "transport" to "config" * move "controlLock" mutex to "config" * clients: move container object to "config"
This commit is contained in:
commit
bd5162ada3
@ -41,7 +41,6 @@ type Server struct {
|
|||||||
dnsFilter *dnsfilter.Dnsfilter // DNS filter instance
|
dnsFilter *dnsfilter.Dnsfilter // DNS filter instance
|
||||||
queryLog *queryLog // Query log instance
|
queryLog *queryLog // Query log instance
|
||||||
stats *stats // General server statistics
|
stats *stats // General server statistics
|
||||||
once sync.Once
|
|
||||||
|
|
||||||
AllowedClients map[string]bool // IP addresses of whitelist clients
|
AllowedClients map[string]bool // IP addresses of whitelist clients
|
||||||
DisallowedClients map[string]bool // IP addresses of clients that should be blocked
|
DisallowedClients map[string]bool // IP addresses of clients that should be blocked
|
||||||
@ -55,11 +54,24 @@ type Server struct {
|
|||||||
|
|
||||||
// NewServer creates a new instance of the dnsforward.Server
|
// NewServer creates a new instance of the dnsforward.Server
|
||||||
// baseDir is the base directory for query logs
|
// baseDir is the base directory for query logs
|
||||||
|
// Note: this function must be called only once
|
||||||
func NewServer(baseDir string) *Server {
|
func NewServer(baseDir string) *Server {
|
||||||
return &Server{
|
s := &Server{
|
||||||
queryLog: newQueryLog(baseDir),
|
queryLog: newQueryLog(baseDir),
|
||||||
stats: newStats(),
|
stats: newStats(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Tracef("Loading stats from querylog")
|
||||||
|
err := s.queryLog.fillStatsFromQueryLog(s.stats)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("failed to load stats from querylog: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("Start DNS server periodic jobs")
|
||||||
|
go s.queryLog.periodicQueryLogRotate()
|
||||||
|
go s.queryLog.runningTop.periodicHourlyTopRotate()
|
||||||
|
go s.stats.statsRotator()
|
||||||
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
// FilteringConfig represents the DNS filtering configuration of AdGuard Home
|
// FilteringConfig represents the DNS filtering configuration of AdGuard Home
|
||||||
@ -169,33 +181,11 @@ func (s *Server) startInternal(config *ServerConfig) error {
|
|||||||
return errors.New("DNS server is already started")
|
return errors.New("DNS server is already started")
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.queryLog == nil {
|
|
||||||
s.queryLog = newQueryLog(".")
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.stats == nil {
|
|
||||||
s.stats = newStats()
|
|
||||||
}
|
|
||||||
|
|
||||||
err := s.initDNSFilter()
|
err := s.initDNSFilter()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Tracef("Loading stats from querylog")
|
|
||||||
err = s.queryLog.fillStatsFromQueryLog(s.stats)
|
|
||||||
if err != nil {
|
|
||||||
return errorx.Decorate(err, "failed to load stats from querylog")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Think about reworking this, the current approach won't work properly if AG Home is restarted periodically
|
|
||||||
s.once.Do(func() {
|
|
||||||
log.Printf("Start DNS server periodic jobs")
|
|
||||||
go s.queryLog.periodicQueryLogRotate()
|
|
||||||
go s.queryLog.runningTop.periodicHourlyTopRotate()
|
|
||||||
go s.stats.statsRotator()
|
|
||||||
})
|
|
||||||
|
|
||||||
proxyConfig := proxy.Config{
|
proxyConfig := proxy.Config{
|
||||||
UDPListenAddr: s.conf.UDPListenAddr,
|
UDPListenAddr: s.conf.UDPListenAddr,
|
||||||
TCPListenAddr: s.conf.TCPListenAddr,
|
TCPListenAddr: s.conf.TCPListenAddr,
|
||||||
|
@ -66,10 +66,9 @@ type clientsContainer struct {
|
|||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
var clients clientsContainer
|
// Init initializes clients container
|
||||||
|
// Note: this function must be called only once
|
||||||
// Initialize clients container
|
func (clients *clientsContainer) Init() {
|
||||||
func clientsInit() {
|
|
||||||
if clients.list != nil {
|
if clients.list != nil {
|
||||||
log.Fatal("clients.list != nil")
|
log.Fatal("clients.list != nil")
|
||||||
}
|
}
|
||||||
@ -77,22 +76,24 @@ func clientsInit() {
|
|||||||
clients.ipIndex = make(map[string]*Client)
|
clients.ipIndex = make(map[string]*Client)
|
||||||
clients.ipHost = make(map[string]ClientHost)
|
clients.ipHost = make(map[string]ClientHost)
|
||||||
|
|
||||||
go periodicClientsUpdate()
|
go clients.periodicUpdate()
|
||||||
}
|
}
|
||||||
|
|
||||||
func periodicClientsUpdate() {
|
func (clients *clientsContainer) periodicUpdate() {
|
||||||
for {
|
for {
|
||||||
clientsAddFromHostsFile()
|
clients.addFromHostsFile()
|
||||||
clientsAddFromSystemARP()
|
clients.addFromSystemARP()
|
||||||
time.Sleep(clientsUpdatePeriod)
|
time.Sleep(clientsUpdatePeriod)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func clientsGetList() map[string]*Client {
|
// GetList returns the pointer to clients list
|
||||||
|
func (clients *clientsContainer) GetList() map[string]*Client {
|
||||||
return clients.list
|
return clients.list
|
||||||
}
|
}
|
||||||
|
|
||||||
func clientExists(ip string) bool {
|
// Exists checks if client with this IP already exists
|
||||||
|
func (clients *clientsContainer) Exists(ip string) bool {
|
||||||
clients.lock.Lock()
|
clients.lock.Lock()
|
||||||
defer clients.lock.Unlock()
|
defer clients.lock.Unlock()
|
||||||
|
|
||||||
@ -105,8 +106,8 @@ func clientExists(ip string) bool {
|
|||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
// Search for a client by IP
|
// Find searches for a client by IP
|
||||||
func clientFind(ip string) (Client, bool) {
|
func (clients *clientsContainer) Find(ip string) (Client, bool) {
|
||||||
clients.lock.Lock()
|
clients.lock.Lock()
|
||||||
defer clients.lock.Unlock()
|
defer clients.lock.Unlock()
|
||||||
|
|
||||||
@ -121,7 +122,7 @@ func clientFind(ip string) (Client, bool) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
ipAddr := dhcpServer.FindIPbyMAC(mac)
|
ipAddr := config.dhcpServer.FindIPbyMAC(mac)
|
||||||
if ipAddr == nil {
|
if ipAddr == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -135,7 +136,7 @@ func clientFind(ip string) (Client, bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check if Client object's fields are correct
|
// Check if Client object's fields are correct
|
||||||
func clientCheck(c *Client) error {
|
func (c *Client) check() error {
|
||||||
if len(c.Name) == 0 {
|
if len(c.Name) == 0 {
|
||||||
return fmt.Errorf("Invalid Name")
|
return fmt.Errorf("Invalid Name")
|
||||||
}
|
}
|
||||||
@ -162,8 +163,8 @@ func clientCheck(c *Client) error {
|
|||||||
|
|
||||||
// Add a new client object
|
// Add a new client object
|
||||||
// Return true: success; false: client exists.
|
// Return true: success; false: client exists.
|
||||||
func clientAdd(c Client) (bool, error) {
|
func (clients *clientsContainer) Add(c Client) (bool, error) {
|
||||||
e := clientCheck(&c)
|
e := c.check()
|
||||||
if e != nil {
|
if e != nil {
|
||||||
return false, e
|
return false, e
|
||||||
}
|
}
|
||||||
@ -194,8 +195,8 @@ func clientAdd(c Client) (bool, error) {
|
|||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove a client
|
// Del removes a client
|
||||||
func clientDel(name string) bool {
|
func (clients *clientsContainer) Del(name string) bool {
|
||||||
clients.lock.Lock()
|
clients.lock.Lock()
|
||||||
defer clients.lock.Unlock()
|
defer clients.lock.Unlock()
|
||||||
|
|
||||||
@ -210,8 +211,8 @@ func clientDel(name string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Update a client
|
// Update a client
|
||||||
func clientUpdate(name string, c Client) error {
|
func (clients *clientsContainer) Update(name string, c Client) error {
|
||||||
err := clientCheck(&c)
|
err := c.check()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -257,10 +258,10 @@ func clientUpdate(name string, c Client) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add new IP -> Host pair
|
// AddHost adds new IP -> Host pair
|
||||||
// Use priority of the source (etc/hosts > ARP > rDNS)
|
// Use priority of the source (etc/hosts > ARP > rDNS)
|
||||||
// so we overwrite existing entries with an equal or higher priority
|
// so we overwrite existing entries with an equal or higher priority
|
||||||
func clientAddHost(ip, host string, source clientSource) (bool, error) {
|
func (clients *clientsContainer) AddHost(ip, host string, source clientSource) (bool, error) {
|
||||||
clients.lock.Lock()
|
clients.lock.Lock()
|
||||||
defer clients.lock.Unlock()
|
defer clients.lock.Unlock()
|
||||||
|
|
||||||
@ -279,7 +280,7 @@ func clientAddHost(ip, host string, source clientSource) (bool, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Parse system 'hosts' file and fill clients array
|
// Parse system 'hosts' file and fill clients array
|
||||||
func clientsAddFromHostsFile() {
|
func (clients *clientsContainer) addFromHostsFile() {
|
||||||
hostsFn := "/etc/hosts"
|
hostsFn := "/etc/hosts"
|
||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
hostsFn = os.ExpandEnv("$SystemRoot\\system32\\drivers\\etc\\hosts")
|
hostsFn = os.ExpandEnv("$SystemRoot\\system32\\drivers\\etc\\hosts")
|
||||||
@ -304,7 +305,7 @@ func clientsAddFromHostsFile() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
ok, e := clientAddHost(fields[0], fields[1], ClientSourceHostsFile)
|
ok, e := clients.AddHost(fields[0], fields[1], ClientSourceHostsFile)
|
||||||
if e != nil {
|
if e != nil {
|
||||||
log.Tracef("%s", e)
|
log.Tracef("%s", e)
|
||||||
}
|
}
|
||||||
@ -319,7 +320,7 @@ func clientsAddFromHostsFile() {
|
|||||||
// Add IP -> Host pairs from the system's `arp -a` command output
|
// Add IP -> Host pairs from the system's `arp -a` command output
|
||||||
// The command's output is:
|
// The command's output is:
|
||||||
// HOST (IP) at MAC on IFACE
|
// HOST (IP) at MAC on IFACE
|
||||||
func clientsAddFromSystemARP() {
|
func (clients *clientsContainer) addFromSystemARP() {
|
||||||
|
|
||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
return
|
return
|
||||||
@ -350,7 +351,7 @@ func clientsAddFromSystemARP() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
ok, e := clientAddHost(ip, host, ClientSourceARP)
|
ok, e := clients.AddHost(ip, host, ClientSourceARP)
|
||||||
if e != nil {
|
if e != nil {
|
||||||
log.Tracef("%s", e)
|
log.Tracef("%s", e)
|
||||||
}
|
}
|
||||||
@ -379,8 +380,8 @@ func handleGetClients(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
data := clientListJSON{}
|
data := clientListJSON{}
|
||||||
|
|
||||||
clients.lock.Lock()
|
config.clients.lock.Lock()
|
||||||
for _, c := range clients.list {
|
for _, c := range config.clients.list {
|
||||||
cj := clientJSON{
|
cj := clientJSON{
|
||||||
IP: c.IP,
|
IP: c.IP,
|
||||||
MAC: c.MAC,
|
MAC: c.MAC,
|
||||||
@ -394,7 +395,7 @@ func handleGetClients(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
if len(c.MAC) != 0 {
|
if len(c.MAC) != 0 {
|
||||||
hwAddr, _ := net.ParseMAC(c.MAC)
|
hwAddr, _ := net.ParseMAC(c.MAC)
|
||||||
ipAddr := dhcpServer.FindIPbyMAC(hwAddr)
|
ipAddr := config.dhcpServer.FindIPbyMAC(hwAddr)
|
||||||
if ipAddr != nil {
|
if ipAddr != nil {
|
||||||
cj.IP = ipAddr.String()
|
cj.IP = ipAddr.String()
|
||||||
}
|
}
|
||||||
@ -402,7 +403,7 @@ func handleGetClients(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
data.Clients = append(data.Clients, cj)
|
data.Clients = append(data.Clients, cj)
|
||||||
}
|
}
|
||||||
for ip, ch := range clients.ipHost {
|
for ip, ch := range config.clients.ipHost {
|
||||||
cj := clientHostJSON{
|
cj := clientHostJSON{
|
||||||
IP: ip,
|
IP: ip,
|
||||||
Name: ch.Host,
|
Name: ch.Host,
|
||||||
@ -416,7 +417,7 @@ func handleGetClients(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
data.AutoClients = append(data.AutoClients, cj)
|
data.AutoClients = append(data.AutoClients, cj)
|
||||||
}
|
}
|
||||||
clients.lock.Unlock()
|
config.clients.lock.Unlock()
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
e := json.NewEncoder(w).Encode(data)
|
e := json.NewEncoder(w).Encode(data)
|
||||||
@ -462,7 +463,7 @@ func handleAddClient(w http.ResponseWriter, r *http.Request) {
|
|||||||
httpError(w, http.StatusBadRequest, "%s", err)
|
httpError(w, http.StatusBadRequest, "%s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ok, err := clientAdd(*c)
|
ok, err := config.clients.Add(*c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, http.StatusBadRequest, "%s", err)
|
httpError(w, http.StatusBadRequest, "%s", err)
|
||||||
return
|
return
|
||||||
@ -492,7 +493,7 @@ func handleDelClient(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !clientDel(cj.Name) {
|
if !config.clients.Del(cj.Name) {
|
||||||
httpError(w, http.StatusBadRequest, "Client not found")
|
httpError(w, http.StatusBadRequest, "Client not found")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -501,7 +502,7 @@ func handleDelClient(w http.ResponseWriter, r *http.Request) {
|
|||||||
returnOK(w)
|
returnOK(w)
|
||||||
}
|
}
|
||||||
|
|
||||||
type clientUpdateJSON struct {
|
type updateJSON struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Data clientJSON `json:"data"`
|
Data clientJSON `json:"data"`
|
||||||
}
|
}
|
||||||
@ -515,7 +516,7 @@ func handleUpdateClient(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var dj clientUpdateJSON
|
var dj updateJSON
|
||||||
err = json.Unmarshal(body, &dj)
|
err = json.Unmarshal(body, &dj)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, http.StatusBadRequest, "JSON parse: %s", err)
|
httpError(w, http.StatusBadRequest, "JSON parse: %s", err)
|
||||||
@ -532,7 +533,7 @@ func handleUpdateClient(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = clientUpdate(dj.Name, *c)
|
err = config.clients.Update(dj.Name, *c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, http.StatusBadRequest, "%s", err)
|
httpError(w, http.StatusBadRequest, "%s", err)
|
||||||
return
|
return
|
||||||
|
@ -6,17 +6,18 @@ func TestClients(t *testing.T) {
|
|||||||
var c Client
|
var c Client
|
||||||
var e error
|
var e error
|
||||||
var b bool
|
var b bool
|
||||||
|
clients := clientsContainer{}
|
||||||
|
|
||||||
clientsInit()
|
clients.Init()
|
||||||
|
|
||||||
// add
|
// add
|
||||||
c = Client{
|
c = Client{
|
||||||
IP: "1.1.1.1",
|
IP: "1.1.1.1",
|
||||||
Name: "client1",
|
Name: "client1",
|
||||||
}
|
}
|
||||||
b, e = clientAdd(c)
|
b, e = clients.Add(c)
|
||||||
if !b || e != nil {
|
if !b || e != nil {
|
||||||
t.Fatalf("clientAdd #1")
|
t.Fatalf("Add #1")
|
||||||
}
|
}
|
||||||
|
|
||||||
// add #2
|
// add #2
|
||||||
@ -24,19 +25,19 @@ func TestClients(t *testing.T) {
|
|||||||
IP: "2.2.2.2",
|
IP: "2.2.2.2",
|
||||||
Name: "client2",
|
Name: "client2",
|
||||||
}
|
}
|
||||||
b, e = clientAdd(c)
|
b, e = clients.Add(c)
|
||||||
if !b || e != nil {
|
if !b || e != nil {
|
||||||
t.Fatalf("clientAdd #2")
|
t.Fatalf("Add #2")
|
||||||
}
|
}
|
||||||
|
|
||||||
c, b = clientFind("1.1.1.1")
|
c, b = clients.Find("1.1.1.1")
|
||||||
if !b || c.Name != "client1" {
|
if !b || c.Name != "client1" {
|
||||||
t.Fatalf("clientFind #1")
|
t.Fatalf("Find #1")
|
||||||
}
|
}
|
||||||
|
|
||||||
c, b = clientFind("2.2.2.2")
|
c, b = clients.Find("2.2.2.2")
|
||||||
if !b || c.Name != "client2" {
|
if !b || c.Name != "client2" {
|
||||||
t.Fatalf("clientFind #2")
|
t.Fatalf("Find #2")
|
||||||
}
|
}
|
||||||
|
|
||||||
// failed add - name in use
|
// failed add - name in use
|
||||||
@ -44,9 +45,9 @@ func TestClients(t *testing.T) {
|
|||||||
IP: "1.2.3.5",
|
IP: "1.2.3.5",
|
||||||
Name: "client1",
|
Name: "client1",
|
||||||
}
|
}
|
||||||
b, _ = clientAdd(c)
|
b, _ = clients.Add(c)
|
||||||
if b {
|
if b {
|
||||||
t.Fatalf("clientAdd - name in use")
|
t.Fatalf("Add - name in use")
|
||||||
}
|
}
|
||||||
|
|
||||||
// failed add - ip in use
|
// failed add - ip in use
|
||||||
@ -54,91 +55,91 @@ func TestClients(t *testing.T) {
|
|||||||
IP: "2.2.2.2",
|
IP: "2.2.2.2",
|
||||||
Name: "client3",
|
Name: "client3",
|
||||||
}
|
}
|
||||||
b, e = clientAdd(c)
|
b, e = clients.Add(c)
|
||||||
if b || e == nil {
|
if b || e == nil {
|
||||||
t.Fatalf("clientAdd - ip in use")
|
t.Fatalf("Add - ip in use")
|
||||||
}
|
}
|
||||||
|
|
||||||
// get
|
// get
|
||||||
if clientExists("1.2.3.4") {
|
if clients.Exists("1.2.3.4") {
|
||||||
t.Fatalf("clientExists")
|
t.Fatalf("Exists")
|
||||||
}
|
}
|
||||||
if !clientExists("1.1.1.1") {
|
if !clients.Exists("1.1.1.1") {
|
||||||
t.Fatalf("clientExists #1")
|
t.Fatalf("Exists #1")
|
||||||
}
|
}
|
||||||
if !clientExists("2.2.2.2") {
|
if !clients.Exists("2.2.2.2") {
|
||||||
t.Fatalf("clientExists #2")
|
t.Fatalf("Exists #2")
|
||||||
}
|
}
|
||||||
|
|
||||||
// failed update - no such name
|
// failed update - no such name
|
||||||
c.IP = "1.2.3.0"
|
c.IP = "1.2.3.0"
|
||||||
c.Name = "client3"
|
c.Name = "client3"
|
||||||
if clientUpdate("client3", c) == nil {
|
if clients.Update("client3", c) == nil {
|
||||||
t.Fatalf("clientUpdate")
|
t.Fatalf("Update")
|
||||||
}
|
}
|
||||||
|
|
||||||
// failed update - name in use
|
// failed update - name in use
|
||||||
c.IP = "1.2.3.0"
|
c.IP = "1.2.3.0"
|
||||||
c.Name = "client2"
|
c.Name = "client2"
|
||||||
if clientUpdate("client1", c) == nil {
|
if clients.Update("client1", c) == nil {
|
||||||
t.Fatalf("clientUpdate - name in use")
|
t.Fatalf("Update - name in use")
|
||||||
}
|
}
|
||||||
|
|
||||||
// failed update - ip in use
|
// failed update - ip in use
|
||||||
c.IP = "2.2.2.2"
|
c.IP = "2.2.2.2"
|
||||||
c.Name = "client1"
|
c.Name = "client1"
|
||||||
if clientUpdate("client1", c) == nil {
|
if clients.Update("client1", c) == nil {
|
||||||
t.Fatalf("clientUpdate - ip in use")
|
t.Fatalf("Update - ip in use")
|
||||||
}
|
}
|
||||||
|
|
||||||
// update
|
// update
|
||||||
c.IP = "1.1.1.2"
|
c.IP = "1.1.1.2"
|
||||||
c.Name = "client1"
|
c.Name = "client1"
|
||||||
if clientUpdate("client1", c) != nil {
|
if clients.Update("client1", c) != nil {
|
||||||
t.Fatalf("clientUpdate")
|
t.Fatalf("Update")
|
||||||
}
|
}
|
||||||
|
|
||||||
// get after update
|
// get after update
|
||||||
if clientExists("1.1.1.1") || !clientExists("1.1.1.2") {
|
if clients.Exists("1.1.1.1") || !clients.Exists("1.1.1.2") {
|
||||||
t.Fatalf("clientExists - get after update")
|
t.Fatalf("Exists - get after update")
|
||||||
}
|
}
|
||||||
|
|
||||||
// failed remove - no such name
|
// failed remove - no such name
|
||||||
if clientDel("client3") {
|
if clients.Del("client3") {
|
||||||
t.Fatalf("clientDel - no such name")
|
t.Fatalf("Del - no such name")
|
||||||
}
|
}
|
||||||
|
|
||||||
// remove
|
// remove
|
||||||
if !clientDel("client1") || clientExists("1.1.1.2") {
|
if !clients.Del("client1") || clients.Exists("1.1.1.2") {
|
||||||
t.Fatalf("clientDel")
|
t.Fatalf("Del")
|
||||||
}
|
}
|
||||||
|
|
||||||
// add host client
|
// add host client
|
||||||
b, e = clientAddHost("1.1.1.1", "host", ClientSourceARP)
|
b, e = clients.AddHost("1.1.1.1", "host", ClientSourceARP)
|
||||||
if !b || e != nil {
|
if !b || e != nil {
|
||||||
t.Fatalf("clientAddHost")
|
t.Fatalf("clientAddHost")
|
||||||
}
|
}
|
||||||
|
|
||||||
// failed add - ip exists
|
// failed add - ip exists
|
||||||
b, e = clientAddHost("1.1.1.1", "host1", ClientSourceRDNS)
|
b, e = clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS)
|
||||||
if b || e != nil {
|
if b || e != nil {
|
||||||
t.Fatalf("clientAddHost - ip exists")
|
t.Fatalf("clientAddHost - ip exists")
|
||||||
}
|
}
|
||||||
|
|
||||||
// overwrite with new data
|
// overwrite with new data
|
||||||
b, e = clientAddHost("1.1.1.1", "host2", ClientSourceARP)
|
b, e = clients.AddHost("1.1.1.1", "host2", ClientSourceARP)
|
||||||
if !b || e != nil {
|
if !b || e != nil {
|
||||||
t.Fatalf("clientAddHost - overwrite with new data")
|
t.Fatalf("clientAddHost - overwrite with new data")
|
||||||
}
|
}
|
||||||
|
|
||||||
// overwrite with new data (higher priority)
|
// overwrite with new data (higher priority)
|
||||||
b, e = clientAddHost("1.1.1.1", "host3", ClientSourceHostsFile)
|
b, e = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile)
|
||||||
if !b || e != nil {
|
if !b || e != nil {
|
||||||
t.Fatalf("clientAddHost - overwrite with new data (higher priority)")
|
t.Fatalf("clientAddHost - overwrite with new data (higher priority)")
|
||||||
}
|
}
|
||||||
|
|
||||||
// get
|
// get
|
||||||
if !clientExists("1.1.1.1") {
|
if !clients.Exists("1.1.1.1") {
|
||||||
t.Fatalf("clientAddHost")
|
t.Fatalf("clientAddHost")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2,6 +2,7 @@ package home
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
@ -38,6 +39,13 @@ type clientObject struct {
|
|||||||
SafeBrowsingEnabled bool `yaml:"safesearch_enabled"`
|
SafeBrowsingEnabled bool `yaml:"safesearch_enabled"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type HTTPSServer struct {
|
||||||
|
server *http.Server
|
||||||
|
cond *sync.Cond // reacts to config.TLS.Enabled, PortHTTPS, CertificateChain and PrivateKey
|
||||||
|
sync.Mutex // protects config.TLS
|
||||||
|
shutdown bool // if TRUE, don't restart the server
|
||||||
|
}
|
||||||
|
|
||||||
// configuration is loaded from YAML
|
// configuration is loaded from YAML
|
||||||
// field ordering is important -- yaml fields will mirror ordering from here
|
// field ordering is important -- yaml fields will mirror ordering from here
|
||||||
type configuration struct {
|
type configuration struct {
|
||||||
@ -48,10 +56,25 @@ type configuration struct {
|
|||||||
ourConfigFilename string // Config filename (can be overridden via the command line arguments)
|
ourConfigFilename string // Config filename (can be overridden via the command line arguments)
|
||||||
ourWorkingDir string // Location of our directory, used to protect against CWD being somewhere else
|
ourWorkingDir string // Location of our directory, used to protect against CWD being somewhere else
|
||||||
firstRun bool // if set to true, don't run any services except HTTP web inteface, and serve only first-run html
|
firstRun bool // if set to true, don't run any services except HTTP web inteface, and serve only first-run html
|
||||||
|
pidFileName string // PID file name. Empty if no PID file was created.
|
||||||
// runningAsService flag is set to true when options are passed from the service runner
|
// runningAsService flag is set to true when options are passed from the service runner
|
||||||
runningAsService bool
|
runningAsService bool
|
||||||
disableUpdate bool // If set, don't check for updates
|
disableUpdate bool // If set, don't check for updates
|
||||||
appSignalChannel chan os.Signal
|
appSignalChannel chan os.Signal
|
||||||
|
clients clientsContainer
|
||||||
|
controlLock sync.Mutex
|
||||||
|
transport *http.Transport
|
||||||
|
client *http.Client
|
||||||
|
|
||||||
|
// cached version.json to avoid hammering github.io for each page reload
|
||||||
|
versionCheckJSON []byte
|
||||||
|
versionCheckLastTime time.Time
|
||||||
|
|
||||||
|
dnsctx dnsContext
|
||||||
|
dnsServer *dnsforward.Server
|
||||||
|
dhcpServer dhcpd.Server
|
||||||
|
httpServer *http.Server
|
||||||
|
httpsServer HTTPSServer
|
||||||
|
|
||||||
BindHost string `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to
|
BindHost string `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to
|
||||||
BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server
|
BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server
|
||||||
@ -127,7 +150,6 @@ type tlsConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// initialize to default values, will be changed later when reading config or parsing command line
|
// initialize to default values, will be changed later when reading config or parsing command line
|
||||||
// TODO: Get rid of global variables
|
|
||||||
var config = configuration{
|
var config = configuration{
|
||||||
ourConfigFilename: "AdGuardHome.yaml",
|
ourConfigFilename: "AdGuardHome.yaml",
|
||||||
BindPort: 3000,
|
BindPort: 3000,
|
||||||
@ -167,8 +189,16 @@ var config = configuration{
|
|||||||
SchemaVersion: currentSchemaVersion,
|
SchemaVersion: currentSchemaVersion,
|
||||||
}
|
}
|
||||||
|
|
||||||
// init initializes default configuration for the current OS&ARCH
|
// initConfig initializes default configuration for the current OS&ARCH
|
||||||
func init() {
|
func initConfig() {
|
||||||
|
config.transport = &http.Transport{
|
||||||
|
DialContext: customDialContext,
|
||||||
|
}
|
||||||
|
config.client = &http.Client{
|
||||||
|
Timeout: time.Minute * 5,
|
||||||
|
Transport: config.transport,
|
||||||
|
}
|
||||||
|
|
||||||
if runtime.GOARCH == "mips" || runtime.GOARCH == "mipsle" {
|
if runtime.GOARCH == "mips" || runtime.GOARCH == "mipsle" {
|
||||||
// Use plain DNS on MIPS, encryption is too slow
|
// Use plain DNS on MIPS, encryption is too slow
|
||||||
defaultDNS = []string{"1.1.1.1", "1.0.0.1"}
|
defaultDNS = []string{"1.1.1.1", "1.0.0.1"}
|
||||||
@ -233,7 +263,7 @@ func parseConfig() error {
|
|||||||
SafeSearchEnabled: cy.SafeSearchEnabled,
|
SafeSearchEnabled: cy.SafeSearchEnabled,
|
||||||
SafeBrowsingEnabled: cy.SafeBrowsingEnabled,
|
SafeBrowsingEnabled: cy.SafeBrowsingEnabled,
|
||||||
}
|
}
|
||||||
_, err = clientAdd(cli)
|
_, err = config.clients.Add(cli)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Tracef("clientAdd: %s", err)
|
log.Tracef("clientAdd: %s", err)
|
||||||
}
|
}
|
||||||
@ -268,7 +298,7 @@ func (c *configuration) write() error {
|
|||||||
c.Lock()
|
c.Lock()
|
||||||
defer c.Unlock()
|
defer c.Unlock()
|
||||||
|
|
||||||
clientsList := clientsGetList()
|
clientsList := config.clients.GetList()
|
||||||
for _, cli := range clientsList {
|
for _, cli := range clientsList {
|
||||||
ip := cli.IP
|
ip := cli.IP
|
||||||
if len(cli.MAC) != 0 {
|
if len(cli.MAC) != 0 {
|
||||||
|
@ -11,7 +11,6 @@ import (
|
|||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||||
@ -25,23 +24,8 @@ import (
|
|||||||
|
|
||||||
const updatePeriod = time.Minute * 30
|
const updatePeriod = time.Minute * 30
|
||||||
|
|
||||||
// cached version.json to avoid hammering github.io for each page reload
|
|
||||||
var versionCheckJSON []byte
|
|
||||||
var versionCheckLastTime time.Time
|
|
||||||
|
|
||||||
var protocols = []string{"tls://", "https://", "tcp://", "sdns://"}
|
var protocols = []string{"tls://", "https://", "tcp://", "sdns://"}
|
||||||
|
|
||||||
var transport = &http.Transport{
|
|
||||||
DialContext: customDialContext,
|
|
||||||
}
|
|
||||||
|
|
||||||
var client = &http.Client{
|
|
||||||
Timeout: time.Minute * 5,
|
|
||||||
Transport: transport,
|
|
||||||
}
|
|
||||||
|
|
||||||
var controlLock sync.Mutex
|
|
||||||
|
|
||||||
// ----------------
|
// ----------------
|
||||||
// helper functions
|
// helper functions
|
||||||
// ----------------
|
// ----------------
|
||||||
@ -188,7 +172,7 @@ func handleQueryLogDisable(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
func handleQueryLog(w http.ResponseWriter, r *http.Request) {
|
func handleQueryLog(w http.ResponseWriter, r *http.Request) {
|
||||||
log.Tracef("%s %v", r.Method, r.URL)
|
log.Tracef("%s %v", r.Method, r.URL)
|
||||||
data := dnsServer.GetQueryLog()
|
data := config.dnsServer.GetQueryLog()
|
||||||
|
|
||||||
jsonVal, err := json.Marshal(data)
|
jsonVal, err := json.Marshal(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -205,7 +189,7 @@ func handleQueryLog(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
func handleStatsTop(w http.ResponseWriter, r *http.Request) {
|
func handleStatsTop(w http.ResponseWriter, r *http.Request) {
|
||||||
log.Tracef("%s %v", r.Method, r.URL)
|
log.Tracef("%s %v", r.Method, r.URL)
|
||||||
s := dnsServer.GetStatsTop()
|
s := config.dnsServer.GetStatsTop()
|
||||||
|
|
||||||
// use manual json marshalling because we want maps to be sorted by value
|
// use manual json marshalling because we want maps to be sorted by value
|
||||||
statsJSON := bytes.Buffer{}
|
statsJSON := bytes.Buffer{}
|
||||||
@ -252,7 +236,7 @@ func handleStatsTop(w http.ResponseWriter, r *http.Request) {
|
|||||||
// handleStatsReset resets the stats caches
|
// handleStatsReset resets the stats caches
|
||||||
func handleStatsReset(w http.ResponseWriter, r *http.Request) {
|
func handleStatsReset(w http.ResponseWriter, r *http.Request) {
|
||||||
log.Tracef("%s %v", r.Method, r.URL)
|
log.Tracef("%s %v", r.Method, r.URL)
|
||||||
dnsServer.PurgeStats()
|
config.dnsServer.PurgeStats()
|
||||||
_, err := fmt.Fprintf(w, "OK\n")
|
_, err := fmt.Fprintf(w, "OK\n")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err)
|
httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err)
|
||||||
@ -262,7 +246,7 @@ func handleStatsReset(w http.ResponseWriter, r *http.Request) {
|
|||||||
// handleStats returns aggregated stats data for the 24 hours
|
// handleStats returns aggregated stats data for the 24 hours
|
||||||
func handleStats(w http.ResponseWriter, r *http.Request) {
|
func handleStats(w http.ResponseWriter, r *http.Request) {
|
||||||
log.Tracef("%s %v", r.Method, r.URL)
|
log.Tracef("%s %v", r.Method, r.URL)
|
||||||
summed := dnsServer.GetAggregatedStats()
|
summed := config.dnsServer.GetAggregatedStats()
|
||||||
|
|
||||||
statsJSON, err := json.Marshal(summed)
|
statsJSON, err := json.Marshal(summed)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -309,7 +293,7 @@ func handleStatsHistory(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := dnsServer.GetStatsHistory(timeUnit, startTime, endTime)
|
data, err := config.dnsServer.GetStatsHistory(timeUnit, startTime, endTime)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, http.StatusBadRequest, "Cannot get stats history: %s", err)
|
httpError(w, http.StatusBadRequest, "Cannot get stats history: %s", err)
|
||||||
return
|
return
|
||||||
@ -725,7 +709,7 @@ func handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Stop DNS server:
|
// Stop DNS server:
|
||||||
// we close urlfilter object which in turn closes file descriptors to filter files.
|
// we close urlfilter object which in turn closes file descriptors to filter files.
|
||||||
// Otherwise, Windows won't allow us to remove the file which is being currently used.
|
// Otherwise, Windows won't allow us to remove the file which is being currently used.
|
||||||
_ = dnsServer.Stop()
|
_ = config.dnsServer.Stop()
|
||||||
|
|
||||||
// go through each element and delete if url matches
|
// go through each element and delete if url matches
|
||||||
config.Lock()
|
config.Lock()
|
||||||
@ -984,7 +968,7 @@ func handleDOH(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsServer.ServeHTTP(w, r)
|
config.dnsServer.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ------------------------
|
// ------------------------
|
||||||
|
@ -17,13 +17,13 @@ type accessListJSON struct {
|
|||||||
func handleAccessList(w http.ResponseWriter, r *http.Request) {
|
func handleAccessList(w http.ResponseWriter, r *http.Request) {
|
||||||
log.Tracef("%s %v", r.Method, r.URL)
|
log.Tracef("%s %v", r.Method, r.URL)
|
||||||
|
|
||||||
controlLock.Lock()
|
config.controlLock.Lock()
|
||||||
j := accessListJSON{
|
j := accessListJSON{
|
||||||
AllowedClients: config.DNS.AllowedClients,
|
AllowedClients: config.DNS.AllowedClients,
|
||||||
DisallowedClients: config.DNS.DisallowedClients,
|
DisallowedClients: config.DNS.DisallowedClients,
|
||||||
BlockedHosts: config.DNS.BlockedHosts,
|
BlockedHosts: config.DNS.BlockedHosts,
|
||||||
}
|
}
|
||||||
controlLock.Unlock()
|
config.controlLock.Unlock()
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
err := json.NewEncoder(w).Encode(j)
|
err := json.NewEncoder(w).Encode(j)
|
||||||
|
@ -264,7 +264,7 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
|
|||||||
// until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely
|
// until all requests are finished, and _we_ are inside a request right now, so it will block indefinitely
|
||||||
if restartHTTP {
|
if restartHTTP {
|
||||||
go func() {
|
go func() {
|
||||||
httpServer.Shutdown(context.TODO())
|
config.httpServer.Shutdown(context.TODO())
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -46,7 +46,7 @@ func handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
|||||||
// check if port is available
|
// check if port is available
|
||||||
// BUT: if we are already using this port, no need
|
// BUT: if we are already using this port, no need
|
||||||
alreadyRunning := false
|
alreadyRunning := false
|
||||||
if httpsServer.server != nil {
|
if config.httpsServer.server != nil {
|
||||||
alreadyRunning = true
|
alreadyRunning = true
|
||||||
}
|
}
|
||||||
if !alreadyRunning {
|
if !alreadyRunning {
|
||||||
@ -72,7 +72,7 @@ func handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
|
|||||||
// check if port is available
|
// check if port is available
|
||||||
// BUT: if we are already using this port, no need
|
// BUT: if we are already using this port, no need
|
||||||
alreadyRunning := false
|
alreadyRunning := false
|
||||||
if httpsServer.server != nil {
|
if config.httpsServer.server != nil {
|
||||||
alreadyRunning = true
|
alreadyRunning = true
|
||||||
}
|
}
|
||||||
if !alreadyRunning {
|
if !alreadyRunning {
|
||||||
@ -101,12 +101,12 @@ func handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
|
|||||||
if restartHTTPS {
|
if restartHTTPS {
|
||||||
go func() {
|
go func() {
|
||||||
time.Sleep(time.Second) // TODO: could not find a way to reliably know that data was fully sent to client by https server, so we wait a bit to let response through before closing the server
|
time.Sleep(time.Second) // TODO: could not find a way to reliably know that data was fully sent to client by https server, so we wait a bit to let response through before closing the server
|
||||||
httpsServer.cond.L.Lock()
|
config.httpsServer.cond.L.Lock()
|
||||||
httpsServer.cond.Broadcast()
|
config.httpsServer.cond.Broadcast()
|
||||||
if httpsServer.server != nil {
|
if config.httpsServer.server != nil {
|
||||||
httpsServer.server.Shutdown(context.TODO())
|
config.httpsServer.server.Shutdown(context.TODO())
|
||||||
}
|
}
|
||||||
httpsServer.cond.L.Unlock()
|
config.httpsServer.cond.L.Unlock()
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -73,10 +73,10 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
if !req.RecheckNow {
|
if !req.RecheckNow {
|
||||||
controlLock.Lock()
|
config.controlLock.Lock()
|
||||||
cached := now.Sub(versionCheckLastTime) <= versionCheckPeriod && len(versionCheckJSON) != 0
|
cached := now.Sub(config.versionCheckLastTime) <= versionCheckPeriod && len(config.versionCheckJSON) != 0
|
||||||
data := versionCheckJSON
|
data := config.versionCheckJSON
|
||||||
controlLock.Unlock()
|
config.controlLock.Unlock()
|
||||||
|
|
||||||
if cached {
|
if cached {
|
||||||
log.Tracef("Returning cached data")
|
log.Tracef("Returning cached data")
|
||||||
@ -87,7 +87,7 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
log.Tracef("Downloading data from %s", versionCheckURL)
|
log.Tracef("Downloading data from %s", versionCheckURL)
|
||||||
resp, err := client.Get(versionCheckURL)
|
resp, err := config.client.Get(versionCheckURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, http.StatusBadGateway, "Couldn't get version check json from %s: %T %s\n", versionCheckURL, err, err)
|
httpError(w, http.StatusBadGateway, "Couldn't get version check json from %s: %T %s\n", versionCheckURL, err, err)
|
||||||
return
|
return
|
||||||
@ -103,10 +103,10 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
controlLock.Lock()
|
config.controlLock.Lock()
|
||||||
versionCheckLastTime = now
|
config.versionCheckLastTime = now
|
||||||
versionCheckJSON = body
|
config.versionCheckJSON = body
|
||||||
controlLock.Unlock()
|
config.controlLock.Unlock()
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
_, err = w.Write(getVersionResp(body))
|
_, err = w.Write(getVersionResp(body))
|
||||||
@ -349,7 +349,7 @@ func copySupportingFiles(files []string, srcdir, dstdir string, useSrcNameOnly,
|
|||||||
|
|
||||||
// Download package file and save it to disk
|
// Download package file and save it to disk
|
||||||
func getPackageFile(u *updateInfo) error {
|
func getPackageFile(u *updateInfo) error {
|
||||||
resp, err := client.Get(u.pkgURL)
|
resp, err := config.client.Get(u.pkgURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("HTTP request failed: %s", err)
|
return fmt.Errorf("HTTP request failed: %s", err)
|
||||||
}
|
}
|
||||||
@ -501,12 +501,12 @@ func finishUpdate(u *updateInfo) {
|
|||||||
func handleUpdate(w http.ResponseWriter, r *http.Request) {
|
func handleUpdate(w http.ResponseWriter, r *http.Request) {
|
||||||
log.Tracef("%s %v", r.Method, r.URL)
|
log.Tracef("%s %v", r.Method, r.URL)
|
||||||
|
|
||||||
if len(versionCheckJSON) == 0 {
|
if len(config.versionCheckJSON) == 0 {
|
||||||
httpError(w, http.StatusBadRequest, "/update request isn't allowed now")
|
httpError(w, http.StatusBadRequest, "/update request isn't allowed now")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
u, err := getUpdateInfo(versionCheckJSON)
|
u, err := getUpdateInfo(config.versionCheckJSON)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, http.StatusInternalServerError, "%s", err)
|
httpError(w, http.StatusInternalServerError, "%s", err)
|
||||||
return
|
return
|
||||||
|
24
home/dhcp.go
24
home/dhcp.go
@ -18,8 +18,6 @@ import (
|
|||||||
"github.com/joomcode/errorx"
|
"github.com/joomcode/errorx"
|
||||||
)
|
)
|
||||||
|
|
||||||
var dhcpServer = dhcpd.Server{}
|
|
||||||
|
|
||||||
// []dhcpd.Lease -> JSON
|
// []dhcpd.Lease -> JSON
|
||||||
func convertLeases(inputLeases []dhcpd.Lease, includeExpires bool) []map[string]string {
|
func convertLeases(inputLeases []dhcpd.Lease, includeExpires bool) []map[string]string {
|
||||||
leases := []map[string]string{}
|
leases := []map[string]string{}
|
||||||
@ -41,8 +39,8 @@ func convertLeases(inputLeases []dhcpd.Lease, includeExpires bool) []map[string]
|
|||||||
|
|
||||||
func handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
|
func handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
|
||||||
log.Tracef("%s %v", r.Method, r.URL)
|
log.Tracef("%s %v", r.Method, r.URL)
|
||||||
leases := convertLeases(dhcpServer.Leases(), true)
|
leases := convertLeases(config.dhcpServer.Leases(), true)
|
||||||
staticLeases := convertLeases(dhcpServer.StaticLeases(), false)
|
staticLeases := convertLeases(config.dhcpServer.StaticLeases(), false)
|
||||||
status := map[string]interface{}{
|
status := map[string]interface{}{
|
||||||
"config": config.DHCP,
|
"config": config.DHCP,
|
||||||
"leases": leases,
|
"leases": leases,
|
||||||
@ -77,18 +75,18 @@ func handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = dhcpServer.CheckConfig(newconfig.ServerConfig)
|
err = config.dhcpServer.CheckConfig(newconfig.ServerConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, http.StatusBadRequest, "Invalid DHCP configuration: %s", err)
|
httpError(w, http.StatusBadRequest, "Invalid DHCP configuration: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = dhcpServer.Stop()
|
err = config.dhcpServer.Stop()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("failed to stop the DHCP server: %s", err)
|
log.Error("failed to stop the DHCP server: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = dhcpServer.Init(newconfig.ServerConfig)
|
err = config.dhcpServer.Init(newconfig.ServerConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, http.StatusBadRequest, "Invalid DHCP configuration: %s", err)
|
httpError(w, http.StatusBadRequest, "Invalid DHCP configuration: %s", err)
|
||||||
return
|
return
|
||||||
@ -105,7 +103,7 @@ func handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = dhcpServer.Start()
|
err = config.dhcpServer.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, http.StatusBadRequest, "Failed to start DHCP server: %s", err)
|
httpError(w, http.StatusBadRequest, "Failed to start DHCP server: %s", err)
|
||||||
return
|
return
|
||||||
@ -389,7 +387,7 @@ func handleDHCPAddStaticLease(w http.ResponseWriter, r *http.Request) {
|
|||||||
HWAddr: mac,
|
HWAddr: mac,
|
||||||
Hostname: lj.Hostname,
|
Hostname: lj.Hostname,
|
||||||
}
|
}
|
||||||
err = dhcpServer.AddStaticLease(lease)
|
err = config.dhcpServer.AddStaticLease(lease)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, http.StatusBadRequest, "%s", err)
|
httpError(w, http.StatusBadRequest, "%s", err)
|
||||||
return
|
return
|
||||||
@ -420,7 +418,7 @@ func handleDHCPRemoveStaticLease(w http.ResponseWriter, r *http.Request) {
|
|||||||
HWAddr: mac,
|
HWAddr: mac,
|
||||||
Hostname: lj.Hostname,
|
Hostname: lj.Hostname,
|
||||||
}
|
}
|
||||||
err = dhcpServer.RemoveStaticLease(lease)
|
err = config.dhcpServer.RemoveStaticLease(lease)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, http.StatusBadRequest, "%s", err)
|
httpError(w, http.StatusBadRequest, "%s", err)
|
||||||
return
|
return
|
||||||
@ -434,12 +432,12 @@ func startDHCPServer() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err := dhcpServer.Init(config.DHCP)
|
err := config.dhcpServer.Init(config.DHCP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorx.Decorate(err, "Couldn't init DHCP server")
|
return errorx.Decorate(err, "Couldn't init DHCP server")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = dhcpServer.Start()
|
err = config.dhcpServer.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorx.Decorate(err, "Couldn't start DHCP server")
|
return errorx.Decorate(err, "Couldn't start DHCP server")
|
||||||
}
|
}
|
||||||
@ -451,7 +449,7 @@ func stopDHCPServer() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err := dhcpServer.Stop()
|
err := config.dhcpServer.Stop()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorx.Decorate(err, "Couldn't stop DHCP server")
|
return errorx.Decorate(err, "Couldn't stop DHCP server")
|
||||||
}
|
}
|
||||||
|
50
home/dns.go
50
home/dns.go
@ -17,8 +17,6 @@ import (
|
|||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
var dnsServer *dnsforward.Server
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
rdnsTimeout = 3 * time.Second // max time to wait for rDNS response
|
rdnsTimeout = 3 * time.Second // max time to wait for rDNS response
|
||||||
)
|
)
|
||||||
@ -32,8 +30,6 @@ type dnsContext struct {
|
|||||||
upstream upstream.Upstream // Upstream object for our own DNS server
|
upstream upstream.Upstream // Upstream object for our own DNS server
|
||||||
}
|
}
|
||||||
|
|
||||||
var dnsctx dnsContext
|
|
||||||
|
|
||||||
// initDNSServer creates an instance of the dnsforward.Server
|
// initDNSServer creates an instance of the dnsforward.Server
|
||||||
// Please note that we must do it even if we don't start it
|
// Please note that we must do it even if we don't start it
|
||||||
// so that we had access to the query log and the stats
|
// so that we had access to the query log and the stats
|
||||||
@ -43,7 +39,7 @@ func initDNSServer(baseDir string) {
|
|||||||
log.Fatalf("Cannot create DNS data dir at %s: %s", baseDir, err)
|
log.Fatalf("Cannot create DNS data dir at %s: %s", baseDir, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsServer = dnsforward.NewServer(baseDir)
|
config.dnsServer = dnsforward.NewServer(baseDir)
|
||||||
|
|
||||||
bindhost := config.DNS.BindHost
|
bindhost := config.DNS.BindHost
|
||||||
if config.DNS.BindHost == "0.0.0.0" {
|
if config.DNS.BindHost == "0.0.0.0" {
|
||||||
@ -53,37 +49,37 @@ func initDNSServer(baseDir string) {
|
|||||||
opts := upstream.Options{
|
opts := upstream.Options{
|
||||||
Timeout: rdnsTimeout,
|
Timeout: rdnsTimeout,
|
||||||
}
|
}
|
||||||
dnsctx.upstream, err = upstream.AddressToUpstream(resolverAddress, opts)
|
config.dnsctx.upstream, err = upstream.AddressToUpstream(resolverAddress, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("upstream.AddressToUpstream: %s", err)
|
log.Error("upstream.AddressToUpstream: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
dnsctx.rdnsIP = make(map[string]bool)
|
config.dnsctx.rdnsIP = make(map[string]bool)
|
||||||
dnsctx.rdnsChannel = make(chan string, 256)
|
config.dnsctx.rdnsChannel = make(chan string, 256)
|
||||||
go asyncRDNSLoop()
|
go asyncRDNSLoop()
|
||||||
}
|
}
|
||||||
|
|
||||||
func isRunning() bool {
|
func isRunning() bool {
|
||||||
return dnsServer != nil && dnsServer.IsRunning()
|
return config.dnsServer != nil && config.dnsServer.IsRunning()
|
||||||
}
|
}
|
||||||
|
|
||||||
func beginAsyncRDNS(ip string) {
|
func beginAsyncRDNS(ip string) {
|
||||||
if clientExists(ip) {
|
if config.clients.Exists(ip) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// add IP to rdnsIP, if not exists
|
// add IP to rdnsIP, if not exists
|
||||||
dnsctx.rdnsLock.Lock()
|
config.dnsctx.rdnsLock.Lock()
|
||||||
defer dnsctx.rdnsLock.Unlock()
|
defer config.dnsctx.rdnsLock.Unlock()
|
||||||
_, ok := dnsctx.rdnsIP[ip]
|
_, ok := config.dnsctx.rdnsIP[ip]
|
||||||
if ok {
|
if ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
dnsctx.rdnsIP[ip] = true
|
config.dnsctx.rdnsIP[ip] = true
|
||||||
|
|
||||||
log.Tracef("Adding %s for rDNS resolve", ip)
|
log.Tracef("Adding %s for rDNS resolve", ip)
|
||||||
select {
|
select {
|
||||||
case dnsctx.rdnsChannel <- ip:
|
case config.dnsctx.rdnsChannel <- ip:
|
||||||
//
|
//
|
||||||
default:
|
default:
|
||||||
log.Tracef("rDNS queue is full")
|
log.Tracef("rDNS queue is full")
|
||||||
@ -110,7 +106,7 @@ func resolveRDNS(ip string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := dnsctx.upstream.Exchange(&req)
|
resp, err := config.dnsctx.upstream.Exchange(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("Error while making an rDNS lookup for %s: %s", ip, err)
|
log.Error("Error while making an rDNS lookup for %s: %s", ip, err)
|
||||||
return ""
|
return ""
|
||||||
@ -138,18 +134,18 @@ func resolveRDNS(ip string) string {
|
|||||||
func asyncRDNSLoop() {
|
func asyncRDNSLoop() {
|
||||||
for {
|
for {
|
||||||
var ip string
|
var ip string
|
||||||
ip = <-dnsctx.rdnsChannel
|
ip = <-config.dnsctx.rdnsChannel
|
||||||
|
|
||||||
host := resolveRDNS(ip)
|
host := resolveRDNS(ip)
|
||||||
if len(host) == 0 {
|
if len(host) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsctx.rdnsLock.Lock()
|
config.dnsctx.rdnsLock.Lock()
|
||||||
delete(dnsctx.rdnsIP, ip)
|
delete(config.dnsctx.rdnsIP, ip)
|
||||||
dnsctx.rdnsLock.Unlock()
|
config.dnsctx.rdnsLock.Unlock()
|
||||||
|
|
||||||
_, _ = clientAddHost(ip, host, ClientSourceRDNS)
|
_, _ = config.clients.AddHost(ip, host, ClientSourceRDNS)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -221,7 +217,7 @@ func generateServerConfig() (dnsforward.ServerConfig, error) {
|
|||||||
|
|
||||||
// If a client has his own settings, apply them
|
// If a client has his own settings, apply them
|
||||||
func applyClientSettings(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {
|
func applyClientSettings(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {
|
||||||
c, ok := clientFind(clientAddr)
|
c, ok := config.clients.Find(clientAddr)
|
||||||
if !ok || !c.UseOwnSettings {
|
if !ok || !c.UseOwnSettings {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -242,12 +238,12 @@ func startDNSServer() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
|
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
|
||||||
}
|
}
|
||||||
err = dnsServer.Start(&newconfig)
|
err = config.dnsServer.Start(&newconfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
|
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
|
||||||
}
|
}
|
||||||
|
|
||||||
top := dnsServer.GetStatsTop()
|
top := config.dnsServer.GetStatsTop()
|
||||||
for k := range top.Clients {
|
for k := range top.Clients {
|
||||||
beginAsyncRDNS(k)
|
beginAsyncRDNS(k)
|
||||||
}
|
}
|
||||||
@ -256,11 +252,11 @@ func startDNSServer() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func reconfigureDNSServer() error {
|
func reconfigureDNSServer() error {
|
||||||
config, err := generateServerConfig()
|
newconfig, err := generateServerConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
|
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
|
||||||
}
|
}
|
||||||
err = dnsServer.Reconfigure(&config)
|
err = config.dnsServer.Reconfigure(&newconfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
|
return errorx.Decorate(err, "Couldn't start forwarding DNS server")
|
||||||
}
|
}
|
||||||
@ -273,7 +269,7 @@ func stopDNSServer() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err := dnsServer.Stop()
|
err := config.dnsServer.Stop()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorx.Decorate(err, "Couldn't stop forwarding DNS server")
|
return errorx.Decorate(err, "Couldn't stop forwarding DNS server")
|
||||||
}
|
}
|
||||||
|
@ -222,7 +222,7 @@ func refreshFiltersIfNecessary(force bool) int {
|
|||||||
|
|
||||||
stopped := false
|
stopped := false
|
||||||
if updateCount != 0 {
|
if updateCount != 0 {
|
||||||
_ = dnsServer.Stop()
|
_ = config.dnsServer.Stop()
|
||||||
stopped = true
|
stopped = true
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -308,7 +308,7 @@ func parseFilterContents(contents []byte) (int, string) {
|
|||||||
func (filter *filter) update() (bool, error) {
|
func (filter *filter) update() (bool, error) {
|
||||||
log.Tracef("Downloading update for filter %d from %s", filter.ID, filter.URL)
|
log.Tracef("Downloading update for filter %d from %s", filter.ID, filter.URL)
|
||||||
|
|
||||||
resp, err := client.Get(filter.URL)
|
resp, err := config.client.Get(filter.URL)
|
||||||
if resp != nil && resp.Body != nil {
|
if resp != nil && resp.Body != nil {
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
}
|
}
|
||||||
|
@ -35,8 +35,8 @@ func ensure(method string, handler func(http.ResponseWriter, *http.Request)) fun
|
|||||||
}
|
}
|
||||||
|
|
||||||
if method == "POST" || method == "PUT" || method == "DELETE" {
|
if method == "POST" || method == "PUT" || method == "DELETE" {
|
||||||
controlLock.Lock()
|
config.controlLock.Lock()
|
||||||
defer controlLock.Unlock()
|
defer config.controlLock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
handler(w, r)
|
handler(w, r)
|
||||||
@ -148,7 +148,7 @@ func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.Res
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
// enforce https?
|
// enforce https?
|
||||||
if config.TLS.ForceHTTPS && r.TLS == nil && config.TLS.Enabled && config.TLS.PortHTTPS != 0 && httpsServer.server != nil {
|
if config.TLS.ForceHTTPS && r.TLS == nil && config.TLS.Enabled && config.TLS.PortHTTPS != 0 && config.httpsServer.server != nil {
|
||||||
// yes, and we want host from host:port
|
// yes, and we want host from host:port
|
||||||
host, _, err := net.SplitHostPort(r.Host)
|
host, _, err := net.SplitHostPort(r.Host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
48
home/home.go
48
home/home.go
@ -25,15 +25,6 @@ import (
|
|||||||
"github.com/gobuffalo/packr"
|
"github.com/gobuffalo/packr"
|
||||||
)
|
)
|
||||||
|
|
||||||
var httpServer *http.Server
|
|
||||||
var httpsServer struct {
|
|
||||||
server *http.Server
|
|
||||||
cond *sync.Cond // reacts to config.TLS.Enabled, PortHTTPS, CertificateChain and PrivateKey
|
|
||||||
sync.Mutex // protects config.TLS
|
|
||||||
shutdown bool // if TRUE, don't restart the server
|
|
||||||
}
|
|
||||||
var pidFileName string // PID file name. Empty if no PID file was created.
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// Used in config to indicate that syslog or eventlog (win) should be used for logger output
|
// Used in config to indicate that syslog or eventlog (win) should be used for logger output
|
||||||
configSyslog = "syslog"
|
configSyslog = "syslog"
|
||||||
@ -48,7 +39,7 @@ var (
|
|||||||
|
|
||||||
const versionCheckPeriod = time.Hour * 8
|
const versionCheckPeriod = time.Hour * 8
|
||||||
|
|
||||||
// main is the entry point
|
// Main is the entry point
|
||||||
func Main(version string, channel string) {
|
func Main(version string, channel string) {
|
||||||
// Init update-related global variables
|
// Init update-related global variables
|
||||||
versionString = version
|
versionString = version
|
||||||
@ -108,7 +99,8 @@ func run(args options) {
|
|||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
clientsInit()
|
initConfig()
|
||||||
|
config.clients.Init()
|
||||||
|
|
||||||
if !config.firstRun {
|
if !config.firstRun {
|
||||||
// Do the upgrade if necessary
|
// Do the upgrade if necessary
|
||||||
@ -168,7 +160,7 @@ func run(args options) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(args.pidFile) != 0 && writePIDFile(args.pidFile) {
|
if len(args.pidFile) != 0 && writePIDFile(args.pidFile) {
|
||||||
pidFileName = args.pidFile
|
config.pidFileName = args.pidFile
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update filters we've just loaded right away, don't wait for periodic update timer
|
// Update filters we've just loaded right away, don't wait for periodic update timer
|
||||||
@ -192,21 +184,21 @@ func run(args options) {
|
|||||||
registerInstallHandlers()
|
registerInstallHandlers()
|
||||||
}
|
}
|
||||||
|
|
||||||
httpsServer.cond = sync.NewCond(&httpsServer.Mutex)
|
config.httpsServer.cond = sync.NewCond(&config.httpsServer.Mutex)
|
||||||
|
|
||||||
// for https, we have a separate goroutine loop
|
// for https, we have a separate goroutine loop
|
||||||
go httpServerLoop()
|
go httpServerLoop()
|
||||||
|
|
||||||
// this loop is used as an ability to change listening host and/or port
|
// this loop is used as an ability to change listening host and/or port
|
||||||
for !httpsServer.shutdown {
|
for !config.httpsServer.shutdown {
|
||||||
printHTTPAddresses("http")
|
printHTTPAddresses("http")
|
||||||
|
|
||||||
// we need to have new instance, because after Shutdown() the Server is not usable
|
// we need to have new instance, because after Shutdown() the Server is not usable
|
||||||
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
|
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
|
||||||
httpServer = &http.Server{
|
config.httpServer = &http.Server{
|
||||||
Addr: address,
|
Addr: address,
|
||||||
}
|
}
|
||||||
err := httpServer.ListenAndServe()
|
err := config.httpServer.ListenAndServe()
|
||||||
if err != http.ErrServerClosed {
|
if err != http.ErrServerClosed {
|
||||||
cleanupAlways()
|
cleanupAlways()
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
@ -219,14 +211,14 @@ func run(args options) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func httpServerLoop() {
|
func httpServerLoop() {
|
||||||
for !httpsServer.shutdown {
|
for !config.httpsServer.shutdown {
|
||||||
httpsServer.cond.L.Lock()
|
config.httpsServer.cond.L.Lock()
|
||||||
// this mechanism doesn't let us through until all conditions are met
|
// this mechanism doesn't let us through until all conditions are met
|
||||||
for config.TLS.Enabled == false ||
|
for config.TLS.Enabled == false ||
|
||||||
config.TLS.PortHTTPS == 0 ||
|
config.TLS.PortHTTPS == 0 ||
|
||||||
config.TLS.PrivateKey == "" ||
|
config.TLS.PrivateKey == "" ||
|
||||||
config.TLS.CertificateChain == "" { // sleep until necessary data is supplied
|
config.TLS.CertificateChain == "" { // sleep until necessary data is supplied
|
||||||
httpsServer.cond.Wait()
|
config.httpsServer.cond.Wait()
|
||||||
}
|
}
|
||||||
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.TLS.PortHTTPS))
|
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.TLS.PortHTTPS))
|
||||||
// validate current TLS config and update warnings (it could have been loaded from file)
|
// validate current TLS config and update warnings (it could have been loaded from file)
|
||||||
@ -250,10 +242,10 @@ func httpServerLoop() {
|
|||||||
cleanupAlways()
|
cleanupAlways()
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
httpsServer.cond.L.Unlock()
|
config.httpsServer.cond.L.Unlock()
|
||||||
|
|
||||||
// prepare HTTPS server
|
// prepare HTTPS server
|
||||||
httpsServer.server = &http.Server{
|
config.httpsServer.server = &http.Server{
|
||||||
Addr: address,
|
Addr: address,
|
||||||
TLSConfig: &tls.Config{
|
TLSConfig: &tls.Config{
|
||||||
Certificates: []tls.Certificate{cert},
|
Certificates: []tls.Certificate{cert},
|
||||||
@ -262,7 +254,7 @@ func httpServerLoop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
printHTTPAddresses("https")
|
printHTTPAddresses("https")
|
||||||
err = httpsServer.server.ListenAndServeTLS("", "")
|
err = config.httpsServer.server.ListenAndServeTLS("", "")
|
||||||
if err != http.ErrServerClosed {
|
if err != http.ErrServerClosed {
|
||||||
cleanupAlways()
|
cleanupAlways()
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
@ -399,17 +391,17 @@ func cleanup() {
|
|||||||
|
|
||||||
// Stop HTTP server, possibly waiting for all active connections to be closed
|
// Stop HTTP server, possibly waiting for all active connections to be closed
|
||||||
func stopHTTPServer() {
|
func stopHTTPServer() {
|
||||||
httpsServer.shutdown = true
|
config.httpsServer.shutdown = true
|
||||||
if httpsServer.server != nil {
|
if config.httpsServer.server != nil {
|
||||||
httpsServer.server.Shutdown(context.TODO())
|
config.httpsServer.server.Shutdown(context.TODO())
|
||||||
}
|
}
|
||||||
httpServer.Shutdown(context.TODO())
|
config.httpServer.Shutdown(context.TODO())
|
||||||
}
|
}
|
||||||
|
|
||||||
// This function is called before application exits
|
// This function is called before application exits
|
||||||
func cleanupAlways() {
|
func cleanupAlways() {
|
||||||
if len(pidFileName) != 0 {
|
if len(config.pidFileName) != 0 {
|
||||||
os.Remove(pidFileName)
|
os.Remove(config.pidFileName)
|
||||||
}
|
}
|
||||||
log.Info("Stopped")
|
log.Info("Stopped")
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user