# Portions Copyright (c) Facebook, Inc. and its affiliates. # # This software may be used and distributed according to the terms of the # GNU General Public License version 2. # sshpeer.py - ssh repository proxy class for mercurial # # Copyright 2005, 2006 Matt Mackall # # This software may be used and distributed according to the terms of the # GNU General Public License version 2 or any later version. from __future__ import absolute_import import re import threading from typing import Any from . import error, progress, pycompat, util, wireproto from .i18n import _ from .pycompat import decodeutf8, encodeutf8 # Record of the bytes sent and received to SSH peers. This records the # cumulative total bytes sent to all peers for the life of the process. _totalbytessent = 0 _totalbytesreceived = 0 def _serverquote(s): if not s: return s """quote a string for the remote shell ... which we assume is sh""" if re.match("[a-zA-Z0-9@%_+=:,./-]*$", s): return s return "'%s'" % s.replace("'", "'\\''") def _writessherror(ui, s): # type: (Any, bytes) -> None if s and not ui.quiet: for l in s.splitlines(): if l.startswith(b"ssh:"): prefix = "" else: prefix = _("remote: ") ui.write_err(prefix, decodeutf8(l, errors="replace"), "\n") class countingpipe(object): """Wraps a pipe that count the number of bytes read/written to it """ def __init__(self, ui, pipe): self._ui = ui self._pipe = pipe self._totalbytes = 0 def write(self, data): assert isinstance(data, bytes) self._totalbytes += len(data) self._ui.metrics.gauge("ssh_write_bytes", len(data)) return self._pipe.write(data) def read(self, size): # type: (int) -> bytes r = self._pipe.read(size) bufs = [r] # In Python 3 _pipe is a FileIO and is not guaranteed to return size # bytes. So let's loop until we get the bytes, or we get 0 bytes, # indicating the end of the pipe. if len(r) < size: totalread = len(r) while totalread < size and len(r) != 0: r = self._pipe.read(size - totalread) totalread += len(r) bufs.append(r) r = b"".join(bufs) self._totalbytes += len(r) self._ui.metrics.gauge("ssh_read_bytes", len(r)) return r def readline(self): r = self._pipe.readline() self._totalbytes += len(r) self._ui.metrics.gauge("ssh_read_bytes", len(r)) return r def close(self): return self._pipe.close() def flush(self): return self._pipe.flush() class threadedstderr(object): def __init__(self, ui, stderr): self._ui = ui self._stderr = stderr self._thread = None def start(self): # type: () -> None thread = threading.Thread(target=self.run) thread.daemon = True thread.start() self._thread = thread def run(self): # type: () -> None while True: try: buf = self._stderr.readline() except (Exception, KeyboardInterrupt): # Not fatal. Treat it as if the stderr stream has ended. break if len(buf) == 0: break _writessherror(self._ui, buf) # Close the pipe. It's likely already closed on the other end. # Note: during readline(), close() will raise an IOError. So there is # no "close" method that can be used by the main thread. self._stderr.close() def join(self, timeout): if self._thread: self._thread.join(timeout) class sshpeer(wireproto.wirepeer): def __init__(self, ui, path, create=False): self._url = path self._ui = ui self._pipeo = self._pipei = self._pipee = None u = util.url(path, parsequery=False, parsefragment=False) if u.scheme != "ssh" or not u.host or u.path is None: self._abort(error.RepoError(_("couldn't parse location %s") % path)) util.checksafessh(path) if u.passwd is not None: self._abort(error.RepoError(_("password in URL not supported"))) self._user = u.user self._host = u.host self._port = u.port self._path = u.path or "." sshcmd = self.ui.config("ui", "ssh") remotecmd = self.ui.config("ui", "remotecmd") sshaddenv = dict(self.ui.configitems("sshenv")) sshenv = util.shellenviron(sshaddenv) args = util.sshargs(sshcmd, self._host, self._user, self._port) if create: cmd = "%s %s %s" % ( sshcmd, args, util.shellquote( "%s init %s" % (_serverquote(remotecmd), _serverquote(self._path)) ), ) ui.debug("running %s\n" % cmd) res = ui.system(cmd, blockedtag="sshpeer", environ=sshenv) if res != 0: self._abort(error.RepoError(_("could not create remote repo"))) with self.ui.timeblockedsection("sshsetup"), progress.suspend(), util.traced( "ssh_setup", cat="blocked" ): self._validaterepo(sshcmd, args, remotecmd, sshenv) # Begin of _basepeer interface. @util.propertycache def ui(self): return self._ui def url(self): return self._url def local(self): return None def peer(self): return self def canpush(self): return True def close(self): self._cleanup() # End of _basepeer interface. # Begin of _basewirecommands interface. def capabilities(self): return self._caps # End of _basewirecommands interface. def _validaterepo(self, sshcmd, args, remotecmd, sshenv=None): # cleanup up previous run self._cleanup() cmd = "%s %s %s" % ( sshcmd, args, util.shellquote( "%s -R %s serve --stdio" % (_serverquote(remotecmd), _serverquote(self._path)) ), ) self.ui.debug("running %s\n" % cmd) cmd = util.quotecommand(cmd) # while self._subprocess isn't used, having it allows the subprocess to # to clean up correctly later sub = util.popen4(cmd, bufsize=0, env=sshenv) pipeo, pipei, pipee, self._subprocess = sub self._pipee = threadedstderr(self.ui, pipee) self._pipee.start() self._pipei = countingpipe(self.ui, pipei) self._pipeo = countingpipe(self.ui, pipeo) self.ui.metrics.gauge("ssh_connections") def badresponse(errortext): msg = _("no suitable response from remote hg") if errortext: msg += ": '%s'" % errortext hint = self.ui.config("ui", "ssherrorhint") self._abort(error.BadResponseError(msg, hint=hint)) timer = None try: def timeout(): self.ui.warn( _("timed out establishing the ssh connection, killing ssh\n") ) self._subprocess.kill() sshsetuptimeout = self.ui.configint("ui", "sshsetuptimeout") if sshsetuptimeout: timer = threading.Timer(sshsetuptimeout, timeout) timer.start() try: # skip any noise generated by remote shell self._callstream("hello") r = self._callstream("between", pairs=("%s-%s" % ("0" * 40, "0" * 40))) except IOError as ex: badresponse(str(ex)) lines = ["", "dummy"] max_noise = 500 while lines[-1] and max_noise: try: l = decodeutf8(r.readline()) if lines[-1] == "1\n" and l == "\n": break if l: self.ui.debug("remote: ", l) lines.append(l) max_noise -= 1 except IOError as ex: badresponse(str(ex)) else: badresponse("".join(lines[2:])) finally: if timer: timer.cancel() self._caps = set() for l in reversed(lines): if l.startswith("capabilities:"): self._caps.update(l[:-1].split(":")[1].split()) break def _abort(self, exception): self._cleanup() raise exception def _cleanup(self): global _totalbytessent, _totalbytesreceived if self._pipeo is None: return # Close the pipe connecting to the stdin of the remote ssh process. # This means if the remote process tries to read its stdin, it will get # an empty buffer that indicates EOF. The remote process should then # exit, which will close its stdout and stderr so the background stderr # reader thread will notice that it reaches EOF and becomes joinable. self._pipeo.close() _totalbytessent += self._pipeo._totalbytes # Clear the pipe to indicate this has already been cleaned up. self._pipeo = None # Wait for the stderr thread to complete reading all stderr text from # the remote ssh process (i.e. hitting EOF). # # This must be after pipeo.close(). Otherwise the remote process might # still wait for stdin and does not close its stderr. # # This is better before pipei.close(). Otherwise the remote process # might nondeterministically get EPIPE when writing to its stdout, # which can trigger different code paths nondeterministically that # might affect stderr. In other words, moving this after pipei.close() # can potentially increase test flakiness. if util.istest(): # In the test environment, we control all remote processes. They # are expected to exit after getting EOF from stdin. Wait # indefinitely to make sure all stderr messages are received. # # If this line hangs forever, that indicates a bug in the remote # process, not here. self._pipee.join(None) else: # In real world environment, remote processes might mis-behave. # Therefore be inpatient on waiting. self._pipee.join(1) # Close the pipe connected to the stdout of the remote process. # The remote end of the pipe is likely already closed since we waited # the pipee thread. If not, the remote process will get EPIPE or # SIGPIPE if it writes a bit more to its stdout. self._pipei.close() _totalbytesreceived += self._pipei._totalbytes self.ui.log( "sshbytes", "", sshbytessent=_totalbytessent, sshbytesreceived=_totalbytesreceived, ) __del__ = _cleanup def _submitbatch(self, req): rsp = self._callstream("batch", cmds=wireproto.encodebatchcmds(req)) available = self._getamount() # TODO this response parsing is probably suboptimal for large # batches with large responses. toread = min(available, 1024) work = rsp.read(toread) available -= toread chunk = work while chunk: while b";" in work: one, work = work.split(b";", 1) yield wireproto.unescapebytearg(one) toread = min(available, 1024) chunk = rsp.read(toread) available -= toread work += chunk yield wireproto.unescapebytearg(work) def _callstream(self, cmd, **args): args = args self.ui.debug("sending %s command\n" % cmd) self._pipeo.write(encodeutf8("%s\n" % cmd)) _func, names = wireproto.commands[cmd] keys = names.split() wireargs = {} for k in keys: if k == "*": wireargs["*"] = args break else: wireargs[k] = args[k] del args[k] for k, v in sorted(pycompat.iteritems(wireargs)): k = encodeutf8(k) if isinstance(v, str): v = encodeutf8(v) self._pipeo.write(b"%s %d\n" % (k, len(v))) if isinstance(v, dict): for dk, dv in pycompat.iteritems(v): if isinstance(dk, str): dk = encodeutf8(dk) if isinstance(dv, str): dv = encodeutf8(dv) self._pipeo.write(b"%s %d\n" % (dk, len(dv))) self._pipeo.write(dv) else: self._pipeo.write(v) self._pipeo.flush() return self._pipei def _callcompressable(self, cmd, **args): return self._callstream(cmd, **args) def _call(self, cmd, **args): self._callstream(cmd, **args) return self._recv() def _callpush(self, cmd, fp, **args): r = self._call(cmd, **args) if r: return b"", r for d in iter(lambda: fp.read(4096), b""): self._send(d) self._send(b"", flush=True) r = self._recv() if r: return b"", r return self._recv(), b"" def _calltwowaystream(self, cmd, fp, **args): r = self._call(cmd, **args) if r: # XXX needs to be made better raise error.Abort(_("unexpected remote reply: %s") % r) for d in iter(lambda: fp.read(4096), b""): self._send(d) self._send(b"", flush=True) return self._pipei def _getamount(self): l = self._pipei.readline() if l == "\n": msg = _("check previous remote output") self._abort(error.OutOfBandError(hint=msg)) try: return int(l) except ValueError: self._abort(error.ResponseError(_("unexpected response:"), l)) def _recv(self): return self._pipei.read(self._getamount()) def _send(self, data, flush=False): self._pipeo.write(b"%d\n" % len(data)) if data: self._pipeo.write(data) if flush: self._pipeo.flush() instance = sshpeer