sapling/mercurial/sshserver.py

145 lines
4.0 KiB
Python
Raw Normal View History

2006-06-04 21:32:13 +04:00
# sshserver.py - ssh protocol server support for mercurial
2006-06-04 21:26:05 +04:00
#
# Copyright 2005-2007 Matt Mackall <mpm@selenic.com>
2006-08-12 23:30:02 +04:00
# Copyright 2006 Vadim Gelfer <vadim.gelfer@gmail.com>
2006-06-04 21:26:05 +04:00
#
# This software may be used and distributed according to the terms of the
2010-01-20 07:20:08 +03:00
# GNU General Public License version 2 or any later version.
2006-06-04 21:26:05 +04:00
import util, hook, wireproto, changegroup
2010-07-16 00:06:45 +04:00
import os, sys
2006-06-04 21:26:05 +04:00
class sshserver(object):
    '''Server side of the Mercurial ssh wire protocol.

    Commands are read one per line from ``ui.fin`` and replies are written
    to ``ui.fout``.  Both streams are put into binary mode at construction,
    after which ``ui.fout`` and ``repo.ui.fout`` are pointed at ``ui.ferr``.
    '''

    def __init__(self, ui, repo):
        self.ui, self.repo = ui, repo
        # Lock acquired on a client's behalf by the legacy "lock" command.
        self.lock = None
        # Capture the real protocol streams before ui.fout is rebound below.
        self.fin, self.fout = ui.fin, ui.fout

        # NOTE(review): presumably routes hook output away from the
        # protocol stream -- confirm against hook.redirect.
        hook.redirect(True)
        # From here on, output written through ui.fout goes to stderr,
        # leaving self.fout for protocol replies only.
        ui.fout = repo.ui.fout = ui.ferr

        # Prevent insertion/deletion of CRs
        for stream in (self.fin, self.fout):
            util.setbinary(stream)

    def getargs(self, args):
        '''Read the arguments named in ``args`` from the input stream.

        ``args`` is a whitespace-separated list of argument names.  For each
        name, one header line "<name> <length>" is read, followed by
        ``length`` bytes of value.  The special name '*' instead carries a
        count; that many "<name> <length>" + value entries are read and
        gathered into a dict.

        Returns the values in the order the names appear in ``args``.
        Raises util.Abort if the client sends a name that was not requested.
        '''
        names = args.split()
        received = {}
        for _ in xrange(len(names)):
            name, size = self.fin.readline()[:-1].split()
            if name not in names:
                raise util.Abort("unexpected parameter %r" % name)
            if name != '*':
                received[name] = self.fin.read(int(size))
                continue
            # '*': ``size`` is the number of nested name/value pairs.
            extra = {}
            for _ in xrange(int(size)):
                key, length = self.fin.readline()[:-1].split()
                extra[key] = self.fin.read(int(length))
            received['*'] = extra
        return [received[name] for name in names]

    def getarg(self, name):
        '''Read the single argument ``name`` and return its value.'''
        values = self.getargs(name)
        return values[0]

    def getfile(self, fpout):
        '''Receive a chunked upload from the client and copy it to ``fpout``.

        An empty response is sent first.  The client then sends a sequence
        of "<count>" lines, each followed by ``count`` bytes; a count of
        zero ends the transfer.
        '''
        self.sendresponse('')
        while True:
            count = int(self.fin.readline())
            if not count:
                break
            fpout.write(self.fin.read(count))

    def redirect(self):
        '''No-op; ui.fout is already redirected in __init__.'''

    def groupchunks(self, changegroup):
        '''Yield the data of ``changegroup``, reading 4096 bytes at a time.

        Iteration stops at the first empty read.
        '''
        chunk = changegroup.read(4096)
        while chunk:
            yield chunk
            chunk = changegroup.read(4096)

    def sendresponse(self, v):
        '''Write ``v`` framed as "<length>\\n<payload>" and flush.'''
        out = self.fout
        out.write("%d\n" % len(v))
        out.write(v)
        out.flush()

    def sendstream(self, source):
        '''Write every chunk of ``source.gen`` unframed, then flush once.'''
        out = self.fout
        for piece in source.gen:
            out.write(piece)
        out.flush()

    def sendpushresponse(self, rsp):
        '''Send an empty response followed by ``str(rsp.res)``.'''
        self.sendresponse('')
        self.sendresponse(str(rsp.res))

    def sendpusherror(self, rsp):
        '''Send ``rsp.res`` as a single framed response.'''
        self.sendresponse(rsp.res)

    def serve_forever(self):
        '''Process commands until input ends, then exit the process.

        Any lock still held for the client is released first.
        sys.exit(0) raises SystemExit.
        '''
        try:
            keepgoing = True
            while keepgoing:
                keepgoing = self.serve_one()
        finally:
            held = self.lock
            if held is not None:
                held.release()
        sys.exit(0)

    # Maps the exact class of a wireproto.dispatch result to the function
    # that sends it.  Entries are plain functions, so callers pass ``self``
    # explicitly.
    handlers = {
        str: sendresponse,
        wireproto.streamres: sendstream,
        wireproto.pushres: sendpushresponse,
        wireproto.pusherr: sendpusherror,
    }

    def serve_one(self):
        '''Read and handle one command line.

        Commands known to wireproto are dispatched there, and the reply is
        sent via ``handlers``.  Other commands go to a legacy ``do_<cmd>``
        method if one exists; otherwise an empty response is sent.

        Returns False when the command line is empty (e.g. at EOF) and True
        otherwise.
        '''
        cmd = self.fin.readline()[:-1]
        if not cmd:
            return False
        if cmd in wireproto.commands:
            rsp = wireproto.dispatch(self.repo, self, cmd)
            self.handlers[rsp.__class__](self, rsp)
            return True
        impl = getattr(self, 'do_' + cmd, None)
        if not impl:
            self.sendresponse("")
            return True
        result = impl()
        if result is not None:
            self.sendresponse(result)
        return True

    def do_lock(self):
        '''DEPRECATED - allowing remote client to lock repo is not safe

        Acquire the repository lock and hold it on behalf of the client.
        '''
        self.lock = self.repo.lock()
        return ""

    def do_unlock(self):
        '''DEPRECATED

        Release the lock taken by do_lock, if any.
        '''
        held = self.lock
        if held:
            held.release()
        self.lock = None
        return ""

    def do_addchangegroup(self):
        '''DEPRECATED

        Apply a changegroup read from the input stream.  The client must
        already hold the lock via do_lock.  Otherwise "not locked" is sent
        and nothing is applied.
        '''
        if not self.lock:
            self.sendresponse("not locked")
            return None

        self.sendresponse("")
        cg = changegroup.unbundle10(self.fin, "UN")
        result = self.repo.addchangegroup(cg, 'serve', self._client(),
                                          lock=self.lock)
        return str(result)

    def _client(self):
        '''Return "remote:ssh:<addr>", taking <addr> from SSH_CLIENT.

        <addr> is the first space-separated field of the SSH_CLIENT
        environment variable, or empty if the variable is unset.
        '''
        host = os.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
        return 'remote:ssh:' + host