diff --git a/AGHTechDoc.md b/AGHTechDoc.md index 556ef859..7564a007 100644 --- a/AGHTechDoc.md +++ b/AGHTechDoc.md @@ -9,6 +9,9 @@ Contents: * "Check configuration" command * Disable DNSStubListener * "Apply configuration" command +* Updating + * Get version command + * Update command * Enable DHCP server * "Check DHCP" command * "Enable DHCP" command @@ -187,6 +190,92 @@ On error, server responds with code 400 or 500. In this case UI should show err ERROR MESSAGE +## Updating + +Algorithm of an update by command: + +* UI requests the latest version information from Server +* Server requests information from Internet; stores the data in cache for several hours; sends data to UI +* If UI sees that a new version is available, it shows notification message and "Update Now" button +* When user clicks on "Update Now" button, UI sends Update command to Server +* UI shows "Please wait, AGH is being updated..." message +* Server performs an update: + * Use working directory from `--work-dir` if necessary + * Download new package for the current OS and CPU + * Unpack the package to a temporary directory `update-vXXX` + * Copy the current configuration file to the directory we unpacked new AGH to + * Check configuration compatibility by executing `./AGH --check-config`. If this command fails, we won't be able to update. + * Create `backup-vXXX` directory and copy the current configuration file there + * Stop all tasks, including DNS server, DHCP server, HTTP server + * Move the current binary file to backup directory + * Note: if power fails here, AGH won't be able to start at system boot. Administrator has to fix it manually + * Move new binary file to the current directory + * If AGH is running as a service, use service control functionality to restart + * If AGH is not running as a service, use the current process arguments to start a new process + * Exit process +* UI resends Get Status command until Server responds to it with the new version. This means that Server is successfully restarted after update. +* UI reloads itself + + +### Get version command + +On receiving this request server downloads version.json data from github and stores it in cache for several hours. + +Example of version.json data: + + { + "version": "v0.95-hotfix", + "announcement": "AdGuard Home v0.95-hotfix is now available!", + "announcement_url": "", + "download_windows_amd64": "", + "download_windows_386": "", + "download_darwin_amd64": "", + "download_linux_amd64": "", + "download_linux_386": "", + "download_linux_arm": "", + "download_linux_arm64": "", + "download_linux_mips": "", + "download_linux_mipsle": "", + "selfupdate_min_version": "v0.0" + } + +Request: + + GET /control/version.json + +Response: + + 200 OK + + { + "new_version": "v0.95", + "announcement": "AdGuard Home v0.95 is now available!", + "announcement_url": "http://...", + "can_autoupdate": true + } + +If `can_autoupdate` is true, then the server can automatically upgrade to a new version. + + +### Update command + +Perform an update procedure to the latest available version + +Request: + + POST /control/update + +Response: + + 200 OK + +Error response: + + 500 + +UI shows error message "Auto-update has failed" + + ## Enable DHCP server Algorithm: diff --git a/app.go b/app.go index 200a3bd3..1c32aa52 100644 --- a/app.go +++ b/app.go @@ -2,6 +2,7 @@ package main import ( "bufio" + "context" "crypto/tls" "fmt" "io" @@ -30,6 +31,7 @@ var httpsServer struct { server *http.Server cond *sync.Cond // reacts to config.TLS.Enabled, PortHTTPS, CertificateChain and PrivateKey sync.Mutex // protects config.TLS + shutdown bool // if TRUE, don't restart the server } var pidFileName string // PID file name. Empty if no PID file was created. @@ -76,6 +78,7 @@ func run(args options) { if args.runningAsService { log.Info("AdGuard Home is running as a service") } + config.runningAsService = args.runningAsService config.firstRun = detectFirstRun() if config.firstRun { @@ -91,16 +94,22 @@ func run(args options) { os.Exit(0) }() - // Do the upgrade if necessary - err := upgradeConfig() - if err != nil { - log.Fatal(err) - } + if !config.firstRun { + // Do the upgrade if necessary + err := upgradeConfig() + if err != nil { + log.Fatal(err) + } - // parse from config file - err = parseConfig() - if err != nil { - log.Fatal(err) + err = parseConfig() + if err != nil { + os.Exit(1) + } + + if args.checkConfig { + log.Info("Configuration file is OK") + os.Exit(0) + } } if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") && @@ -118,10 +127,12 @@ func run(args options) { loadFilters() - // Save the updated config - err = config.write() - if err != nil { - log.Fatal(err) + if !config.firstRun { + // Save the updated config + err := config.write() + if err != nil { + log.Fatal(err) + } } // Init the DNS server instance before registering HTTP handlers @@ -129,7 +140,7 @@ func run(args options) { initDNSServer(dnsBaseDir) if !config.firstRun { - err = startDNSServer() + err := startDNSServer() if err != nil { log.Fatal(err) } @@ -171,7 +182,7 @@ func run(args options) { go httpServerLoop() // this loop is used as an ability to change listening host and/or port - for { + for !httpsServer.shutdown { printHTTPAddresses("http") // we need to have new instance, because after Shutdown() the Server is not usable @@ -186,10 +197,13 @@ func run(args options) { } // We use ErrServerClosed as a sign that we need to rebind on new address, so go back to the start of the loop } + + // wait indefinitely for other go-routines to complete their job + select {} } func httpServerLoop() { - for { + for !httpsServer.shutdown { httpsServer.cond.L.Lock() // this mechanism doesn't let us through until all conditions are met for config.TLS.Enabled == false || @@ -367,6 +381,15 @@ func cleanup() { } } +// Stop HTTP server, possibly waiting for all active connections to be closed +func stopHTTPServer() { + httpsServer.shutdown = true + if httpsServer.server != nil { + httpsServer.server.Shutdown(context.TODO()) + } + httpServer.Shutdown(context.TODO()) +} + // This function is called before application exits func cleanupAlways() { if len(pidFileName) != 0 { @@ -384,6 +407,7 @@ type options struct { bindPort int // port to serve HTTP pages on logFile string // Path to the log file. If empty, write to stdout. If "syslog", writes to syslog pidFile string // File name to save PID to + checkConfig bool // Check configuration and exit // service control action (see service.ControlAction array + "status" command) serviceControlAction string @@ -404,25 +428,26 @@ func loadOptions() options { callbackWithValue func(value string) callbackNoValue func() }{ - {"config", "c", "path to the config file", func(value string) { o.configFilename = value }, nil}, - {"work-dir", "w", "path to the working directory", func(value string) { o.workDir = value }, nil}, - {"host", "h", "host address to bind HTTP server on", func(value string) { o.bindHost = value }, nil}, - {"port", "p", "port to serve HTTP pages on", func(value string) { + {"config", "c", "Path to the config file", func(value string) { o.configFilename = value }, nil}, + {"work-dir", "w", "Path to the working directory", func(value string) { o.workDir = value }, nil}, + {"host", "h", "Host address to bind HTTP server on", func(value string) { o.bindHost = value }, nil}, + {"port", "p", "Port to serve HTTP pages on", func(value string) { v, err := strconv.Atoi(value) if err != nil { panic("Got port that is not a number") } o.bindPort = v }, nil}, - {"service", "s", "service control action: status, install, uninstall, start, stop, restart", func(value string) { + {"service", "s", "Service control action: status, install, uninstall, start, stop, restart", func(value string) { o.serviceControlAction = value }, nil}, - {"logfile", "l", "path to the log file. If empty, writes to stdout, if 'syslog' -- system log", func(value string) { + {"logfile", "l", "Path to log file. If empty: write to stdout; if 'syslog': write to system log", func(value string) { o.logFile = value }, nil}, - {"pidfile", "", "File name to save PID to", func(value string) { o.pidFile = value }, nil}, - {"verbose", "v", "enable verbose output", nil, func() { o.verbose = true }}, - {"help", "", "print this help", nil, func() { + {"pidfile", "", "Path to a file where PID is stored", func(value string) { o.pidFile = value }, nil}, + {"check-config", "", "Check configuration and exit", nil, func() { o.checkConfig = true }}, + {"verbose", "v", "Enable verbose output", nil, func() { o.verbose = true }}, + {"help", "", "Print this help", nil, func() { printHelp() os.Exit(64) }}, @@ -432,10 +457,14 @@ func loadOptions() options { fmt.Printf("%s [options]\n\n", os.Args[0]) fmt.Printf("Options:\n") for _, opt := range opts { + val := "" + if opt.callbackWithValue != nil { + val = " VALUE" + } if opt.shortName != "" { - fmt.Printf(" -%s, %-30s %s\n", opt.shortName, "--"+opt.longName, opt.description) + fmt.Printf(" -%s, %-30s %s\n", opt.shortName, "--"+opt.longName+val, opt.description) } else { - fmt.Printf(" %-34s %s\n", "--"+opt.longName, opt.description) + fmt.Printf(" %-34s %s\n", "--"+opt.longName+val, opt.description) } } } diff --git a/client/src/__locales/en.json b/client/src/__locales/en.json index e0029945..e0cf54e7 100644 --- a/client/src/__locales/en.json +++ b/client/src/__locales/en.json @@ -260,5 +260,8 @@ "dns_addresses": "DNS addresses", "down": "Down", "fix": "Fix", - "dns_providers": "Here is a <0>list of known DNS providers to choose from." + "dns_providers": "Here is a <0>list of known DNS providers to choose from.", + "update_now": "Update now", + "update_failed": "Auto-update failed. Please follow the steps<\/a> to update manually.", + "processing_update": "Please wait, AdGuard Home is being updated" } \ No newline at end of file diff --git a/client/src/actions/index.js b/client/src/actions/index.js index 94830ada..070c9324 100644 --- a/client/src/actions/index.js +++ b/client/src/actions/index.js @@ -2,15 +2,17 @@ import { createAction } from 'redux-actions'; import round from 'lodash/round'; import { t } from 'i18next'; import { showLoading, hideLoading } from 'react-redux-loading-bar'; +import axios from 'axios'; import { normalizeHistory, normalizeFilteringStatus, normalizeLogs, normalizeTextarea } from '../helpers/helpers'; -import { SETTINGS_NAMES } from '../helpers/constants'; +import { SETTINGS_NAMES, CHECK_TIMEOUT } from '../helpers/constants'; import Api from '../api/Api'; const apiClient = new Api(); export const addErrorToast = createAction('ADD_ERROR_TOAST'); export const addSuccessToast = createAction('ADD_SUCCESS_TOAST'); +export const addNoticeToast = createAction('ADD_NOTICE_TOAST'); export const removeToast = createAction('REMOVE_TOAST'); export const toggleSettingStatus = createAction('SETTING_STATUS_TOGGLE'); @@ -154,6 +156,56 @@ export const getVersion = () => async (dispatch) => { } }; +export const getUpdateRequest = createAction('GET_UPDATE_REQUEST'); +export const getUpdateFailure = createAction('GET_UPDATE_FAILURE'); +export const getUpdateSuccess = createAction('GET_UPDATE_SUCCESS'); + +export const getUpdate = () => async (dispatch) => { + dispatch(getUpdateRequest()); + try { + await apiClient.getUpdate(); + + const checkUpdate = async (attempts) => { + let count = attempts || 1; + let timeout; + + if (count > 60) { + dispatch(addNoticeToast({ error: 'update_failed' })); + dispatch(getUpdateFailure()); + return false; + } + + const rmTimeout = t => t && clearTimeout(t); + const setRecursiveTimeout = (time, ...args) => setTimeout( + checkUpdate, + time, + ...args, + ); + + axios.get('control/status') + .then((response) => { + rmTimeout(timeout); + if (response) { + dispatch(getUpdateSuccess()); + window.location.reload(true); + } + timeout = setRecursiveTimeout(CHECK_TIMEOUT, count += 1); + }) + .catch(() => { + rmTimeout(timeout); + timeout = setRecursiveTimeout(CHECK_TIMEOUT, count += 1); + }); + + return false; + }; + + checkUpdate(); + } catch (error) { + dispatch(addNoticeToast({ error: 'update_failed' })); + dispatch(getUpdateFailure()); + } +}; + export const getClientsRequest = createAction('GET_CLIENTS_REQUEST'); export const getClientsFailure = createAction('GET_CLIENTS_FAILURE'); export const getClientsSuccess = createAction('GET_CLIENTS_SUCCESS'); diff --git a/client/src/api/Api.js b/client/src/api/Api.js index 6d8a2f52..1743cc06 100644 --- a/client/src/api/Api.js +++ b/client/src/api/Api.js @@ -40,6 +40,8 @@ export default class Api { GLOBAL_ENABLE_PROTECTION = { path: 'enable_protection', method: 'POST' }; GLOBAL_DISABLE_PROTECTION = { path: 'disable_protection', method: 'POST' }; GLOBAL_CLIENTS = { path: 'clients', method: 'GET' } + GLOBAL_CLIENTS = { path: 'clients', method: 'GET' }; + GLOBAL_UPDATE = { path: 'update', method: 'POST' }; restartGlobalFiltering() { const { path, method } = this.GLOBAL_RESTART; @@ -145,6 +147,11 @@ export default class Api { return this.makeRequest(path, method); } + getUpdate() { + const { path, method } = this.GLOBAL_UPDATE; + return this.makeRequest(path, method); + } + // Filtering FILTERING_STATUS = { path: 'filtering/status', method: 'GET' }; FILTERING_ENABLE = { path: 'filtering/enable', method: 'POST' }; diff --git a/client/src/components/App/index.js b/client/src/components/App/index.js index 545d5007..157a55e6 100644 --- a/client/src/components/App/index.js +++ b/client/src/components/App/index.js @@ -19,6 +19,7 @@ import Toasts from '../Toasts'; import Footer from '../ui/Footer'; import Status from '../ui/Status'; import UpdateTopline from '../ui/UpdateTopline'; +import UpdateOverlay from '../ui/UpdateOverlay'; import EncryptionTopline from '../ui/EncryptionTopline'; import i18n from '../../i18n'; @@ -37,6 +38,10 @@ class App extends Component { this.props.enableDns(); }; + handleUpdate = () => { + this.props.getUpdate(); + } + setLanguage = () => { const { processing, language } = this.props.dashboard; @@ -62,10 +67,16 @@ class App extends Component { {updateAvailable && - + + + + } {!encryption.processing && @@ -100,6 +111,7 @@ class App extends Component { App.propTypes = { getDnsStatus: PropTypes.func, + getUpdate: PropTypes.func, enableDns: PropTypes.func, dashboard: PropTypes.object, isCoreRunning: PropTypes.bool, diff --git a/client/src/components/Toasts/Toast.css b/client/src/components/Toasts/Toast.css index dd4f3c1b..c9496dba 100644 --- a/client/src/components/Toasts/Toast.css +++ b/client/src/components/Toasts/Toast.css @@ -32,6 +32,12 @@ overflow: hidden; } +.toast__content a { + font-weight: 600; + color: #fff; + text-decoration: underline; +} + .toast__dismiss { display: block; flex: 0 0 auto; diff --git a/client/src/components/Toasts/Toast.js b/client/src/components/Toasts/Toast.js index 0e951f09..ca9e1dd8 100644 --- a/client/src/components/Toasts/Toast.js +++ b/client/src/components/Toasts/Toast.js @@ -4,7 +4,7 @@ import { Trans, withNamespaces } from 'react-i18next'; class Toast extends Component { componentDidMount() { - const timeout = this.props.type === 'error' ? 30000 : 5000; + const timeout = this.props.type === 'success' ? 5000 : 30000; setTimeout(() => { this.props.removeToast(this.props.id); @@ -15,13 +15,25 @@ class Toast extends Component { return false; } + showMessage(t, type, message) { + if (type === 'notice') { + return ; + } + + return {message}; + } + render() { + const { + type, id, t, message, + } = this.props; + return ( -
+

- {this.props.message} + {this.showMessage(t, type, message)}

-
@@ -30,6 +42,7 @@ class Toast extends Component { } Toast.propTypes = { + t: PropTypes.func.isRequired, id: PropTypes.string.isRequired, message: PropTypes.string.isRequired, type: PropTypes.string.isRequired, diff --git a/client/src/components/ui/Overlay.css b/client/src/components/ui/Overlay.css new file mode 100644 index 00000000..d12a55b7 --- /dev/null +++ b/client/src/components/ui/Overlay.css @@ -0,0 +1,40 @@ +.overlay { + display: none; + position: fixed; + top: 0; + left: 0; + z-index: 110; + width: 100%; + height: 100%; + flex-direction: column; + align-items: center; + justify-content: center; + padding: 20px; + font-size: 28px; + font-weight: 600; + text-align: center; + background-color: rgba(255, 255, 255, 0.8); +} + +.overlay--visible { + display: flex; +} + +.overlay__loading { + width: 40px; + height: 40px; + margin-bottom: 20px; + background-image: url("data:image/svg+xml;charset=utf-8,%3Csvg%20xmlns%3D%22http%3A%2F%2Fwww.w3.org%2F2000%2Fsvg%22%20viewBox%3D%220%200%2047.6%2047.6%22%20height%3D%22100%25%22%20width%3D%22100%25%22%3E%3Cpath%20opacity%3D%22.235%22%20fill%3D%22%23979797%22%20d%3D%22M44.4%2011.9l-5.2%203c1.5%202.6%202.4%205.6%202.4%208.9%200%209.8-8%2017.8-17.8%2017.8-6.6%200-12.3-3.6-15.4-8.9l-5.2%203C7.3%2042.8%2015%2047.6%2023.8%2047.6c13.1%200%2023.8-10.7%2023.8-23.8%200-4.3-1.2-8.4-3.2-11.9z%22%2F%3E%3Cpath%20fill%3D%22%2366b574%22%20d%3D%22M3.2%2035.7C0%2030.2-.8%2023.8.8%2017.6%202.5%2011.5%206.4%206.4%2011.9%203.2%2017.4%200%2023.8-.8%2030%20.8c6.1%201.6%2011.3%205.6%2014.4%2011.1l-5.2%203c-2.4-4.1-6.2-7.1-10.8-8.3C23.8%205.4%2019%206%2014.9%208.4s-7.1%206.2-8.3%2010.8c-1.2%204.6-.6%209.4%201.8%2013.5l-5.2%203z%22%2F%3E%3C%2Fsvg%3E"); + will-change: transform; + animation: clockwise 2s linear infinite; +} + +@keyframes clockwise { + 0% { + transform: rotate(0deg); + } + + 100% { + transform: rotate(360deg); + } +} diff --git a/client/src/components/ui/UpdateOverlay.js b/client/src/components/ui/UpdateOverlay.js new file mode 100644 index 00000000..7a35264a --- /dev/null +++ b/client/src/components/ui/UpdateOverlay.js @@ -0,0 +1,26 @@ +import React from 'react'; +import PropTypes from 'prop-types'; +import { Trans, withNamespaces } from 'react-i18next'; +import classnames from 'classnames'; + +import './Overlay.css'; + +const UpdateOverlay = (props) => { + const overlayClass = classnames({ + overlay: true, + 'overlay--visible': props.processingUpdate, + }); + + return ( +
+
+ processing_update +
+ ); +}; + +UpdateOverlay.propTypes = { + processingUpdate: PropTypes.bool, +}; + +export default withNamespaces()(UpdateOverlay); diff --git a/client/src/components/ui/UpdateTopline.js b/client/src/components/ui/UpdateTopline.js index a9124666..833a833d 100644 --- a/client/src/components/ui/UpdateTopline.js +++ b/client/src/components/ui/UpdateTopline.js @@ -1,4 +1,4 @@ -import React from 'react'; +import React, { Fragment } from 'react'; import PropTypes from 'prop-types'; import { Trans, withNamespaces } from 'react-i18next'; @@ -6,22 +6,37 @@ import Topline from './Topline'; const UpdateTopline = props => ( - - Click here -
, - ]} - > - update_announcement - + + + Click here + , + ]} + > + update_announcement + + {props.canAutoUpdate && + + } + ); UpdateTopline.propTypes = { - version: PropTypes.string.isRequired, + version: PropTypes.string, url: PropTypes.string.isRequired, + canAutoUpdate: PropTypes.bool, + getUpdate: PropTypes.func, + processingUpdate: PropTypes.bool, }; export default withNamespaces()(UpdateTopline); diff --git a/client/src/reducers/index.js b/client/src/reducers/index.js index 25156688..58b16e94 100644 --- a/client/src/reducers/index.js +++ b/client/src/reducers/index.js @@ -126,12 +126,16 @@ const dashboard = handleActions({ const { version, announcement_url: announcementUrl, + new_version: newVersion, + can_autoupdate: canAutoUpdate, } = payload; const newState = { ...state, version, announcementUrl, + newVersion, + canAutoUpdate, isUpdateAvailable: true, }; return newState; @@ -140,6 +144,13 @@ const dashboard = handleActions({ return state; }, + [actions.getUpdateRequest]: state => ({ ...state, processingUpdate: true }), + [actions.getUpdateFailure]: state => ({ ...state, processingUpdate: false }), + [actions.getUpdateSuccess]: (state) => { + const newState = { ...state, processingUpdate: false }; + return newState; + }, + [actions.getFilteringRequest]: state => ({ ...state, processingFiltering: true }), [actions.getFilteringFailure]: state => ({ ...state, processingFiltering: false }), [actions.getFilteringSuccess]: (state, { payload }) => { @@ -187,6 +198,7 @@ const dashboard = handleActions({ processingVersion: true, processingFiltering: true, processingClients: true, + processingUpdate: false, upstreamDns: '', bootstrapDns: '', allServers: false, diff --git a/client/src/reducers/toasts.js b/client/src/reducers/toasts.js index c56085d3..34698480 100644 --- a/client/src/reducers/toasts.js +++ b/client/src/reducers/toasts.js @@ -1,7 +1,7 @@ import { handleActions } from 'redux-actions'; import nanoid from 'nanoid'; -import { addErrorToast, addSuccessToast, removeToast } from '../actions'; +import { addErrorToast, addSuccessToast, addNoticeToast, removeToast } from '../actions'; const toasts = handleActions({ [addErrorToast]: (state, { payload }) => { @@ -24,6 +24,16 @@ const toasts = handleActions({ const newState = { ...state, notices: [...state.notices, successToast] }; return newState; }, + [addNoticeToast]: (state, { payload }) => { + const noticeToast = { + id: nanoid(), + message: payload.error.toString(), + type: 'notice', + }; + + const newState = { ...state, notices: [...state.notices, noticeToast] }; + return newState; + }, [removeToast]: (state, { payload }) => { const filtered = state.notices.filter(notice => notice.id !== payload); const newState = { ...state, notices: filtered }; diff --git a/config.go b/config.go index 24ee8605..49ffe93a 100644 --- a/config.go +++ b/config.go @@ -30,9 +30,15 @@ type logSettings struct { // configuration is loaded from YAML // field ordering is important -- yaml fields will mirror ordering from here type configuration struct { + // Raw file data to avoid re-reading of configuration file + // It's reset after config is parsed + fileData []byte + ourConfigFilename string // Config filename (can be overridden via the command line arguments) ourWorkingDir string // Location of our directory, used to protect against CWD being somewhere else firstRun bool // if set to true, don't run any services except HTTP web inteface, and serve only first-run html + // runningAsService flag is set to true when options are passed from the service runner + runningAsService bool BindHost string `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server @@ -113,10 +119,10 @@ var config = configuration{ BindHost: "0.0.0.0", Port: 53, FilteringConfig: dnsforward.FilteringConfig{ - ProtectionEnabled: true, // whether or not use any of dnsfilter features - FilteringEnabled: true, // whether or not use filter lists + ProtectionEnabled: true, // whether or not use any of dnsfilter features + FilteringEnabled: true, // whether or not use filter lists BlockingMode: "nxdomain", // mode how to answer filtered requests - BlockedResponseTTL: 10, // in seconds + BlockedResponseTTL: 10, // in seconds QueryLogEnabled: true, Ratelimit: 20, RefuseAny: true, @@ -174,7 +180,7 @@ func (c *configuration) getConfigFilename() string { func getLogSettings() logSettings { l := logSettings{} yamlFile, err := readConfigFile() - if err != nil || yamlFile == nil { + if err != nil { return l } err = yaml.Unmarshal(yamlFile, &l) @@ -190,13 +196,9 @@ func parseConfig() error { log.Debug("Reading config file: %s", configFile) yamlFile, err := readConfigFile() if err != nil { - log.Error("Couldn't read config file: %s", err) return err } - if yamlFile == nil { - log.Error("YAML file doesn't exist, skipping it") - return nil - } + config.fileData = nil err = yaml.Unmarshal(yamlFile, &config) if err != nil { log.Error("Couldn't parse config file: %s", err) @@ -213,22 +215,23 @@ func parseConfig() error { // readConfigFile reads config file contents if it exists func readConfigFile() ([]byte, error) { - configFile := config.getConfigFilename() - if _, err := os.Stat(configFile); os.IsNotExist(err) { - // do nothing, file doesn't exist - return nil, nil + if len(config.fileData) != 0 { + return config.fileData, nil } - return ioutil.ReadFile(configFile) + + configFile := config.getConfigFilename() + d, err := ioutil.ReadFile(configFile) + if err != nil { + log.Error("Couldn't read config file %s: %s", configFile, err) + return nil, err + } + return d, nil } // Saves configuration to the YAML file and also saves the user filter contents to a file func (c *configuration) write() error { c.Lock() defer c.Unlock() - if config.firstRun { - log.Debug("Silently refusing to write config because first run and not configured yet") - return nil - } configFile := config.getConfigFilename() log.Debug("Writing YAML file: %s", configFile) yamlText, err := yaml.Marshal(&config) diff --git a/control.go b/control.go index 35214d82..4646b221 100644 --- a/control.go +++ b/control.go @@ -557,42 +557,6 @@ func checkDNS(input string, bootstrap []string) error { return nil } -func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { - log.Tracef("%s %v", r.Method, r.URL) - now := time.Now() - if now.Sub(versionCheckLastTime) <= versionCheckPeriod && len(versionCheckJSON) != 0 { - // return cached copy - w.Header().Set("Content-Type", "application/json") - w.Write(versionCheckJSON) - return - } - - resp, err := client.Get(versionCheckURL) - if err != nil { - httpError(w, http.StatusBadGateway, "Couldn't get version check json from %s: %T %s\n", versionCheckURL, err, err) - return - } - if resp != nil && resp.Body != nil { - defer resp.Body.Close() - } - - // read the body entirely - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - httpError(w, http.StatusBadGateway, "Couldn't read response body from %s: %s", versionCheckURL, err) - return - } - - w.Header().Set("Content-Type", "application/json") - _, err = w.Write(body) - if err != nil { - httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err) - } - - versionCheckLastTime = now - versionCheckJSON = body -} - // --------- // filtering // --------- @@ -1006,6 +970,7 @@ func registerControlHandlers() { http.HandleFunc("/control/stats_history", postInstall(optionalAuth(ensureGET(handleStatsHistory)))) http.HandleFunc("/control/stats_reset", postInstall(optionalAuth(ensurePOST(handleStatsReset)))) http.HandleFunc("/control/version.json", postInstall(optionalAuth(handleGetVersionJSON))) + http.HandleFunc("/control/update", postInstall(optionalAuth(ensurePOST(handleUpdate)))) http.HandleFunc("/control/filtering/enable", postInstall(optionalAuth(ensurePOST(handleFilteringEnable)))) http.HandleFunc("/control/filtering/disable", postInstall(optionalAuth(ensurePOST(handleFilteringDisable)))) http.HandleFunc("/control/filtering/add_url", postInstall(optionalAuth(ensurePOST(handleFilteringAddURL)))) diff --git a/control_update.go b/control_update.go new file mode 100644 index 00000000..247eb043 --- /dev/null +++ b/control_update.go @@ -0,0 +1,371 @@ +package main + +import ( + "archive/zip" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "net/http" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "syscall" + "time" + + "github.com/AdguardTeam/golibs/log" +) + +// Convert version.json data to our JSON response +func getVersionResp(data []byte) []byte { + versionJSON := make(map[string]interface{}) + err := json.Unmarshal(data, &versionJSON) + if err != nil { + log.Error("version.json: %s", err) + return []byte{} + } + + ret := make(map[string]interface{}) + ret["can_autoupdate"] = false + + var ok1, ok2, ok3 bool + ret["new_version"], ok1 = versionJSON["version"].(string) + ret["announcement"], ok2 = versionJSON["announcement"].(string) + ret["announcement_url"], ok3 = versionJSON["announcement_url"].(string) + if !ok1 || !ok2 || !ok3 { + log.Error("version.json: invalid data") + return []byte{} + } + + _, ok := versionJSON[fmt.Sprintf("download_%s_%s", runtime.GOOS, runtime.GOARCH)] + if ok && ret["new_version"] != VersionString { + ret["can_autoupdate"] = true + } + + d, _ := json.Marshal(ret) + return d +} + +// Get the latest available version from the Internet +func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { + log.Tracef("%s %v", r.Method, r.URL) + + now := time.Now() + controlLock.Lock() + cached := now.Sub(versionCheckLastTime) <= versionCheckPeriod && len(versionCheckJSON) != 0 + data := versionCheckJSON + controlLock.Unlock() + + if cached { + // return cached copy + w.Header().Set("Content-Type", "application/json") + w.Write(getVersionResp(data)) + return + } + + resp, err := client.Get(versionCheckURL) + if err != nil { + httpError(w, http.StatusBadGateway, "Couldn't get version check json from %s: %T %s\n", versionCheckURL, err, err) + return + } + if resp != nil && resp.Body != nil { + defer resp.Body.Close() + } + + // read the body entirely + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + httpError(w, http.StatusBadGateway, "Couldn't read response body from %s: %s", versionCheckURL, err) + return + } + + controlLock.Lock() + versionCheckLastTime = now + versionCheckJSON = body + controlLock.Unlock() + + w.Header().Set("Content-Type", "application/json") + _, err = w.Write(getVersionResp(body)) + if err != nil { + httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err) + } +} + +// Copy file on disk +func copyFile(src, dst string) error { + d, e := ioutil.ReadFile(src) + if e != nil { + return e + } + e = ioutil.WriteFile(dst, d, 0644) + if e != nil { + return e + } + return nil +} + +type updateInfo struct { + pkgURL string // URL for the new package + pkgName string // Full path to package file + newVer string // New version string + updateDir string // Full path to the directory containing unpacked files from the new package + backupDir string // Full path to backup directory + configName string // Full path to the current configuration file + updateConfigName string // Full path to the configuration file to check by the new binary + curBinName string // Full path to the current executable file + bkpBinName string // Full path to the current executable file in backup directory + newBinName string // Full path to the new executable file +} + +// Fill in updateInfo object +func getUpdateInfo(jsonData []byte) (*updateInfo, error) { + var u updateInfo + + workDir := config.ourWorkingDir + + versionJSON := make(map[string]interface{}) + err := json.Unmarshal(jsonData, &versionJSON) + if err != nil { + return nil, fmt.Errorf("JSON parse: %s", err) + } + + u.pkgURL = versionJSON[fmt.Sprintf("download_%s_%s", runtime.GOOS, runtime.GOARCH)].(string) + u.newVer = versionJSON["version"].(string) + if len(u.pkgURL) == 0 || len(u.newVer) == 0 { + return nil, fmt.Errorf("Invalid JSON") + } + + if u.newVer == VersionString { + return nil, fmt.Errorf("No need to update") + } + + _, pkgFileName := filepath.Split(u.pkgURL) + if len(pkgFileName) == 0 { + return nil, fmt.Errorf("Invalid JSON") + } + u.pkgName = filepath.Join(workDir, pkgFileName) + + u.updateDir = filepath.Join(workDir, fmt.Sprintf("update-%s", u.newVer)) + u.backupDir = filepath.Join(workDir, fmt.Sprintf("backup-%s", VersionString)) + u.configName = config.getConfigFilename() + u.updateConfigName = filepath.Join(u.updateDir, "AdGuardHome", "AdGuardHome.yaml") + if strings.HasSuffix(pkgFileName, ".zip") { + u.updateConfigName = filepath.Join(u.updateDir, "AdGuardHome.yaml") + } + + binName := "AdGuardHome" + if runtime.GOOS == "windows" { + binName = "AdGuardHome.exe" + } + u.curBinName = filepath.Join(workDir, binName) + u.bkpBinName = filepath.Join(u.backupDir, binName) + u.newBinName = filepath.Join(u.updateDir, "AdGuardHome", binName) + if strings.HasSuffix(pkgFileName, ".zip") { + u.newBinName = filepath.Join(u.updateDir, binName) + } + + return &u, nil +} + +// Unpack all files from .zip file to the specified directory +func zipFileUnpack(zipfile, outdir string) error { + r, err := zip.OpenReader(zipfile) + if err != nil { + return fmt.Errorf("zip.OpenReader(): %s", err) + } + defer r.Close() + + for _, zf := range r.File { + zr, err := zf.Open() + if err != nil { + return fmt.Errorf("zip file Open(): %s", err) + } + fi := zf.FileInfo() + fn := filepath.Join(outdir, fi.Name()) + + if fi.IsDir() { + err = os.Mkdir(fn, fi.Mode()) + if err != nil { + return fmt.Errorf("zip file Read(): %s", err) + } + continue + } + + f, err := os.OpenFile(fn, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode()) + if err != nil { + zr.Close() + return fmt.Errorf("os.OpenFile(): %s", err) + } + _, err = io.Copy(f, zr) + if err != nil { + zr.Close() + return fmt.Errorf("io.Copy(): %s", err) + } + zr.Close() + } + return nil +} + +// Unpack all files from .tar.gz file to the specified directory +func targzFileUnpack(tarfile, outdir string) error { + cmd := exec.Command("tar", "zxf", tarfile, "-C", outdir) + log.Tracef("Unpacking: %v", cmd.Args) + _, err := cmd.Output() + if err != nil || cmd.ProcessState.ExitCode() != 0 { + return fmt.Errorf("exec.Command() failed: %s", err) + } + return nil +} + +// Perform an update procedure +func doUpdate(u *updateInfo) error { + log.Info("Updating from %s to %s. URL:%s Package:%s", + VersionString, u.newVer, u.pkgURL, u.pkgName) + + resp, err := client.Get(u.pkgURL) + if err != nil { + return fmt.Errorf("HTTP request failed: %s", err) + } + if resp != nil && resp.Body != nil { + defer resp.Body.Close() + } + + log.Tracef("Reading HTTP body") + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("ioutil.ReadAll() failed: %s", err) + } + + log.Tracef("Saving package to file") + err = ioutil.WriteFile(u.pkgName, body, 0644) + if err != nil { + return fmt.Errorf("ioutil.WriteFile() failed: %s", err) + } + + log.Tracef("Unpacking the package") + _ = os.Mkdir(u.updateDir, 0755) + _, file := filepath.Split(u.pkgName) + if strings.HasSuffix(file, ".zip") { + err = zipFileUnpack(u.pkgName, u.updateDir) + if err != nil { + return fmt.Errorf("zipFileUnpack() failed: %s", err) + } + } else if strings.HasSuffix(file, ".tar.gz") { + err = targzFileUnpack(u.pkgName, u.updateDir) + if err != nil { + return fmt.Errorf("zipFileUnpack() failed: %s", err) + } + } else { + return fmt.Errorf("Unknown package extension") + } + + log.Tracef("Checking configuration") + err = copyFile(u.configName, u.updateConfigName) + if err != nil { + return fmt.Errorf("copyFile() failed: %s", err) + } + cmd := exec.Command(u.newBinName, "--check-config") + err = cmd.Run() + if err != nil || cmd.ProcessState.ExitCode() != 0 { + return fmt.Errorf("exec.Command(): %s %d", err, cmd.ProcessState.ExitCode()) + } + + log.Tracef("Backing up the current configuration") + _ = os.Mkdir(u.backupDir, 0755) + err = copyFile(u.configName, filepath.Join(u.backupDir, "AdGuardHome.yaml")) + if err != nil { + return fmt.Errorf("copyFile() failed: %s", err) + } + + log.Tracef("Renaming: %s -> %s", u.curBinName, u.bkpBinName) + err = os.Rename(u.curBinName, u.bkpBinName) + if err != nil { + return err + } + if runtime.GOOS == "windows" { + // rename fails with "File in use" error + err = copyFile(u.newBinName, u.curBinName) + } else { + err = os.Rename(u.newBinName, u.curBinName) + } + if err != nil { + return err + } + log.Tracef("Renamed: %s -> %s", u.newBinName, u.curBinName) + + _ = os.Remove(u.pkgName) + // _ = os.RemoveAll(u.updateDir) + return nil +} + +// Complete an update procedure +func finishUpdate(u *updateInfo) { + log.Info("Stopping all tasks") + cleanup() + stopHTTPServer() + cleanupAlways() + + if runtime.GOOS == "windows" { + + if config.runningAsService { + // Note: + // we can't restart the service via "kardianos/service" package - it kills the process first + // we can't start a new instance - Windows doesn't allow it + cmd := exec.Command("cmd", "/c", "net stop AdGuardHome & net start AdGuardHome") + err := cmd.Start() + if err != nil { + log.Fatalf("exec.Command() failed: %s", err) + } + os.Exit(0) + } + + cmd := exec.Command(u.curBinName, os.Args[1:]...) + log.Info("Restarting: %v", cmd.Args) + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + err := cmd.Start() + if err != nil { + log.Fatalf("exec.Command() failed: %s", err) + } + os.Exit(0) + + } else { + + log.Info("Restarting: %v", os.Args) + err := syscall.Exec(u.curBinName, os.Args, os.Environ()) + if err != nil { + log.Fatalf("syscall.Exec() failed: %s", err) + } + // Unreachable code + } +} + +// Perform an update procedure to the latest available version +func handleUpdate(w http.ResponseWriter, r *http.Request) { + log.Tracef("%s %v", r.Method, r.URL) + + if len(versionCheckJSON) == 0 { + httpError(w, http.StatusBadRequest, "/update request isn't allowed now") + return + } + + u, err := getUpdateInfo(versionCheckJSON) + if err != nil { + httpError(w, http.StatusInternalServerError, "%s", err) + return + } + + err = doUpdate(u) + if err != nil { + httpError(w, http.StatusInternalServerError, "%s", err) + return + } + + returnOK(w) + + time.Sleep(time.Second) // wait (hopefully) until response is sent (not sure whether it's really necessary) + go finishUpdate(u) +} diff --git a/control_update_test.go b/control_update_test.go new file mode 100644 index 00000000..346b98f3 --- /dev/null +++ b/control_update_test.go @@ -0,0 +1,39 @@ +package main + +import ( + "os" + "testing" +) + +func testDoUpdate(t *testing.T) { + config.DNS.Port = 0 + u := updateInfo{ + pkgURL: "https://github.com/AdguardTeam/AdGuardHome/releases/download/v0.95/AdGuardHome_v0.95_linux_amd64.tar.gz", + pkgName: "./AdGuardHome_v0.95_linux_amd64.tar.gz", + newVer: "v0.95", + updateDir: "./update-v0.95", + backupDir: "./backup-v0.94", + configName: "./AdGuardHome.yaml", + updateConfigName: "./update-v0.95/AdGuardHome/AdGuardHome.yaml", + curBinName: "./AdGuardHome", + bkpBinName: "./backup-v0.94/AdGuardHome", + newBinName: "./update-v0.95/AdGuardHome/AdGuardHome", + } + e := doUpdate(&u) + if e != nil { + t.Fatalf("FAILED: %s", e) + } + os.RemoveAll(u.backupDir) + os.RemoveAll(u.updateDir) +} + +func testZipFileUnpack(t *testing.T) { + fn := "./dist/AdGuardHome_v0.95_Windows_amd64.zip" + outdir := "./test-unpack" + _ = os.Mkdir(outdir, 0755) + e := zipFileUnpack(fn, outdir) + if e != nil { + t.Fatalf("FAILED: %s", e) + } + os.RemoveAll(outdir) +} diff --git a/helpers.go b/helpers.go index 33cf7a62..f94b2007 100644 --- a/helpers.go +++ b/helpers.go @@ -318,7 +318,7 @@ func customDialContext(ctx context.Context, network, addr string) (net.Conn, err Timeout: time.Minute * 5, } - if net.ParseIP(host) != nil { + if net.ParseIP(host) != nil || config.DNS.Port == 0 { con, err := dialer.DialContext(ctx, network, addr) return con, err } diff --git a/openapi/openapi.yaml b/openapi/openapi.yaml index e5692606..6fd9503f 100644 --- a/openapi/openapi.yaml +++ b/openapi/openapi.yaml @@ -151,6 +151,17 @@ paths: description: 'Cannot write answer' 502: description: 'Cannot retrieve the version.json file contents' + /update: + post: + tags: + - global + operationId: beginUpdate + summary: 'Begin auto-upgrade procedure' + responses: + 200: + description: OK + 500: + description: Failed # -------------------------------------------------- # Query log methods @@ -906,17 +917,8 @@ definitions: VersionInfo: type: "object" description: "Information about the latest available version of AdGuard Home" - required: - - "version" - - "announcement" - - "announcement_url" - - "download_darwin_amd64" - - "download_linux_amd64" - - "download_linux_386" - - "download_linux_arm" - - "selfupdate_min_version" properties: - version: + new_version: type: "string" example: "v0.9" announcement: @@ -925,21 +927,8 @@ definitions: announcement_url: type: "string" example: "https://github.com/AdguardTeam/AdGuardHome/releases/tag/v0.9" - download_darwin_amd64: - type: "string" - example: "https://github.com/AdguardTeam/AdGuardHome/releases/download/v0.9/AdGuardHome_v0.9_MacOS.zip" - download_linux_amd64: - type: "string" - example: "https://github.com/AdguardTeam/AdGuardHome/releases/download/v0.9/AdGuardHome_v0.9_linux_amd64.tar.gz" - download_linux_386: - type: "string" - example: "https://github.com/AdguardTeam/AdGuardHome/releases/download/v0.9/AdGuardHome_v0.9_linux_386.tar.gz" - download_linux_arm: - type: "string" - example: "https://github.com/AdguardTeam/AdGuardHome/releases/download/v0.9/AdGuardHome_v0.9_linux_arm.tar.gz" - selfupdate_min_version: - type: "string" - example: "v0.0" + can_autoupdate: + type: "boolean" Stats: type: "object" description: "General server stats for the last 24 hours" diff --git a/upgrade.go b/upgrade.go index e730d34b..c2de9f68 100644 --- a/upgrade.go +++ b/upgrade.go @@ -2,7 +2,6 @@ package main import ( "fmt" - "io/ioutil" "os" "path/filepath" @@ -16,21 +15,15 @@ const currentSchemaVersion = 3 // used for upgrading from old configs to new con // Performs necessary upgrade operations if needed func upgradeConfig() error { // read a config file into an interface map, so we can manipulate values without losing any - configFile := config.getConfigFilename() - if _, err := os.Stat(configFile); os.IsNotExist(err) { - log.Printf("config file %s does not exist, nothing to upgrade", configFile) - return nil - } diskConfig := map[string]interface{}{} - body, err := ioutil.ReadFile(configFile) + body, err := readConfigFile() if err != nil { - log.Printf("Couldn't read config file '%s': %s", configFile, err) return err } err = yaml.Unmarshal(body, &diskConfig) if err != nil { - log.Printf("Couldn't parse config file '%s': %s", configFile, err) + log.Printf("Couldn't parse config file: %s", err) return err } @@ -87,6 +80,7 @@ func upgradeConfigSchema(oldVersion int, diskConfig *map[string]interface{}) err return err } + config.fileData = body err = file.SafeWrite(configFile, body) if err != nil { log.Printf("Couldn't save YAML config: %s", err)