wireproto: perform chunking and compression at protocol layer (API)

Currently, the "streamres" response type is populated with a generator
of chunks with compression possibly already applied. This puts the onus
on commands to perform chunking and compression. Architecturally, I
think this is the wrong place to perform this work. I think commands
should say "here is the data" and the protocol layer should take care
of encoding the final bytes to put on the wire.

Additionally, upcoming commits will improve wire protocol support for
compression. Having a central place for performing compression in the
protocol transport layer will be easier than having to deal with
compression at the commands layer.

This commit refactors the "streamres" response type to accept either
a generator or an object with "read." Additionally, the type now
accepts a flag indicating whether the response is a "version 1
compressible" response. This basically identifies all commands
currently performing compression. I could have used a special type
for this, but a flag works just as well. The argument name
foreshadows the introduction of wire protocol changes, hence the "v1."

The code for chunking and compressing has been moved to the output
generation function for each protocol transport. Some code has been
inlined, resulting in the deletion of now unused methods.
This commit is contained in:
Gregory Szorc 2016-11-20 13:50:45 -08:00
parent 037fb7f520
commit 2112fb0fd2
4 changed files with 34 additions and 43 deletions

View File

@ -76,7 +76,7 @@ def getlfile(repo, proto, sha):
yield '%d\n' % length
for chunk in util.filechunkiter(f):
yield chunk
return wireproto.streamres(generator())
return wireproto.streamres(gen=generator())
def statlfile(repo, proto, sha):
'''Server command for checking if a largefile is present - returns '2\n' if

View File

@ -73,16 +73,6 @@ class webproto(wireproto.abstractserverproto):
self.ui.ferr, self.ui.fout = self.oldio
return val
def groupchunks(self, fh):
def getchunks():
while True:
chunk = fh.read(32768)
if not chunk:
break
yield chunk
return self.compresschunks(getchunks())
def compresschunks(self, chunks):
# Don't allow untrusted settings because disabling compression or
# setting a very high compression level could lead to flooding
@ -106,8 +96,16 @@ def call(repo, req, cmd):
req.respond(HTTP_OK, HGTYPE, body=rsp)
return []
elif isinstance(rsp, wireproto.streamres):
if rsp.reader:
gen = iter(lambda: rsp.reader.read(32768), '')
else:
gen = rsp.gen
if rsp.v1compressible:
gen = p.compresschunks(gen)
req.respond(HTTP_OK, HGTYPE)
return rsp.gen
return gen
elif isinstance(rsp, wireproto.pushres):
val = p.restore()
rsp = '%d\n%s' % (rsp.res, val)

View File

@ -68,13 +68,6 @@ class sshserver(wireproto.abstractserverproto):
def redirect(self):
pass
def groupchunks(self, fh):
return iter(lambda: fh.read(4096), '')
def compresschunks(self, chunks):
for chunk in chunks:
yield chunk
def sendresponse(self, v):
self.fout.write("%d\n" % len(v))
self.fout.write(v)
@ -82,7 +75,13 @@ class sshserver(wireproto.abstractserverproto):
def sendstream(self, source):
write = self.fout.write
for chunk in source.gen:
if source.reader:
gen = iter(lambda: source.reader.read(4096), '')
else:
gen = source.gen
for chunk in gen:
write(chunk)
self.fout.flush()

View File

@ -78,21 +78,6 @@ class abstractserverproto(object):
# """
# raise NotImplementedError()
def groupchunks(self, fh):
"""Generator of chunks to send to the client.
Some protocols may have compressed the contents.
"""
raise NotImplementedError()
def compresschunks(self, chunks):
"""Generator of possible compressed chunks to send to the client.
This is like ``groupchunks()`` except it accepts a generator as
its argument.
"""
raise NotImplementedError()
class remotebatch(peer.batcher):
'''batches the queued calls; uses as few roundtrips as possible'''
def __init__(self, remote):
@ -529,10 +514,19 @@ class streamres(object):
"""wireproto reply: binary stream
The call was successful and the result is a stream.
Iterate on the `self.gen` attribute to retrieve chunks.
Accepts either a generator or an object with a ``read(size)`` method.
``v1compressible`` indicates whether this data can be compressed to
"version 1" clients (technically: HTTP peers using
application/mercurial-0.1 media type). This flag should NOT be used on
new commands because new clients should support a more modern compression
mechanism.
"""
def __init__(self, gen):
def __init__(self, gen=None, reader=None, v1compressible=False):
self.gen = gen
self.reader = reader
self.v1compressible = v1compressible
class pushres(object):
"""wireproto reply: success with simple integer return
@ -739,14 +733,14 @@ def capabilities(repo, proto):
def changegroup(repo, proto, roots):
nodes = decodelist(roots)
cg = changegroupmod.changegroup(repo, nodes, 'serve')
return streamres(proto.groupchunks(cg))
return streamres(reader=cg, v1compressible=True)
@wireprotocommand('changegroupsubset', 'bases heads')
def changegroupsubset(repo, proto, bases, heads):
bases = decodelist(bases)
heads = decodelist(heads)
cg = changegroupmod.changegroupsubset(repo, bases, heads, 'serve')
return streamres(proto.groupchunks(cg))
return streamres(reader=cg, v1compressible=True)
@wireprotocommand('debugwireargs', 'one two *')
def debugwireargs(repo, proto, one, two, others):
@ -781,7 +775,7 @@ def getbundle(repo, proto, others):
return ooberror(bundle2required)
chunks = exchange.getbundlechunks(repo, 'serve', **opts)
return streamres(proto.compresschunks(chunks))
return streamres(gen=chunks, v1compressible=True)
@wireprotocommand('heads')
def heads(repo, proto):
@ -870,7 +864,7 @@ def stream(repo, proto):
# LockError may be raised before the first result is yielded. Don't
# emit output until we're sure we got the lock successfully.
it = streamclone.generatev1wireproto(repo)
return streamres(getstream(it))
return streamres(gen=getstream(it))
except error.LockError:
return '2\n'
@ -900,7 +894,7 @@ def unbundle(repo, proto, heads):
if util.safehasattr(r, 'addpart'):
# The return looks streamable, we are in the bundle2 case and
# should return a stream.
return streamres(r.getchunks())
return streamres(gen=r.getchunks())
return pushres(r)
finally:
@ -962,4 +956,4 @@ def unbundle(repo, proto, heads):
manargs, advargs))
except error.PushRaced as exc:
bundler.newpart('error:pushraced', [('message', str(exc))])
return streamres(bundler.getchunks())
return streamres(gen=bundler.getchunks())