Some more work on boilerplate

This commit is contained in:
Mariano Sorgente 2019-07-22 22:01:23 +09:00
parent 8d4b98ed30
commit 52dfcba527
17 changed files with 386 additions and 29 deletions

2
.gitignore vendored
View File

@ -21,6 +21,8 @@ pip-delete-this-directory.txt
.cache
.pytest_cache/
# PoSpace plots
**/*.dat
*.mo
*.pot

View File

@ -1,8 +1,10 @@
from typing import List
from chiapos import DiskPlotter, DiskProver
from blspy import PrivateKey
from hashlib import sha256
import secrets
from blspy import PrivateKey
from .util.api_decorators import api_request
from chiapos import DiskPlotter, DiskProver
from .types import plotter_api
class Plotter:
@ -12,7 +14,8 @@ class Plotter:
def find_plotfiles(self, directories=[]) -> List[str]:
return []
def create_plot(self, k: int, filename: str, pool_pk):
@api_request(create_plot=plotter_api.CreatePlot.from_bin)
def create_plot(self, create_plot: plotter_api.CreatePlot):
# TODO: Check if we have enough disk space
# Uses python secure random number generation
@ -23,20 +26,22 @@ class Plotter:
# TODO: store the private key and plot id on disk
public_key_ser = private_key.get_public_key().serialize()
plot_seed: bytes = sha256(pool_pk + public_key_ser).digest()
plot_seed: bytes = sha256(create_plot.pool_pubkey.serialize() + public_key_ser).digest()
plotter = DiskPlotter()
plotter.create_plot_disk(filename, k, bytes([]), plot_seed)
self.plots_[plot_seed] = (private_key, DiskProver(filename))
plotter.create_plot_disk(create_plot.filename, create_plot.size, bytes([]), plot_seed)
self.plots_[plot_seed] = (private_key, DiskProver(create_plot.filename))
def new_challenge(self, challenge_hash: bytes):
@api_request(new_challenge=plotter_api.NewChallenge.from_bin)
def new_challenge(self, new_challenge: plotter_api.NewChallenge):
# TODO: Create an ID based on plot id and index
all_qualities = []
for _, (_, prover) in self.plots_.items():
qualities = prover.get_qualities_for_challenge(challenge_hash)
qualities = prover.get_qualities_for_challenge(new_challenge.challenge_hash)
for index, quality in enumerate(qualities):
all_qualities.append((index, quality))
return all_qualities
def request_proof_of_space(self, challenge_hash: bytes, block_hash: bytes):
@api_request(request=plotter_api.RequestProofOfSpace.from_bin)
def request_proof_of_space(self, request: plotter_api.RequestProofOfSpace):
# TODO: Lookup private key, plot id
pass

View File

@ -1,6 +1,8 @@
import asyncio
from protocol import ChiaProtocol
from blspy import PrivateKey
from src.protocol.protocol import ChiaProtocol
from src.types.plotter_api import CreatePlot, NewChallenge
async def main():
# Get a reference to the event loop as we plan to use
@ -13,10 +15,10 @@ async def main():
lambda: ChiaProtocol(on_con_lost, loop, lambda x: x),
'127.0.0.1', 8888)
ppk = PrivateKey.from_seed(b"123").get_public_key().serialize()
protocol.send("create_plot", 16, "myplot1.dat", ppk)
protocol.send("create_plot", 17, "myplot2.dat", ppk)
protocol.send("new_challenge", bytes([2]*32))
ppk = PrivateKey.from_seed(b"123").get_public_key()
await protocol.send("create_plot", CreatePlot(16, ppk, b"myplot_1.dat"))
# protocol.send("create_plot", CreatePlot(17, ppk, b"myplot_2.dat"))
await protocol.send("new_challenge", NewChallenge(bytes([77]*32)))
# Wait until the protocol signals that the connection
# is lost and close the transport.
@ -26,4 +28,4 @@ async def main():
transport.close()
asyncio.run(main())
asyncio.run(main())

View File

@ -4,6 +4,23 @@ from cbor2 import dumps, loads
LENGTH_BYTES: int = 5
def transform_to_streamable(d):
    """
    Drill down through dictionaries and lists and transform objects with "as_bin" to bytes.

    str/bytes/int values pass through unchanged; dict keys are transformed
    as well (cbor map keys may themselves be streamables); anything else is
    assumed to be iterable and becomes a list of transformed items.
    """
    # Fix: removed leftover debug print("?") that fired on every recursion.
    if hasattr(d, "as_bin"):
        return d.as_bin()
    if isinstance(d, (str, bytes, int)):
        return d
    if isinstance(d, dict):
        return {
            transform_to_streamable(k): transform_to_streamable(v)
            for k, v in d.items()
        }
    return [transform_to_streamable(_) for _ in d]
class ChiaProtocol(asyncio.Protocol):
def __init__(self, on_con_lost, loop, api):
self.loop_ = loop
@ -23,18 +40,17 @@ class ChiaProtocol(asyncio.Protocol):
else:
print(f'Connection lost to {peername} exception {exc}')
def send(self, function, data):
encoded = dumps({"function": function, "data": data})
self.transport_.write(len(encoded).to_bytes(5, "big") + encoded)
async def send(self, function, data):
encoded = dumps({"function": function, "data": transform_to_streamable(data)})
await self.transport_.write(len(encoded).to_bytes(LENGTH_BYTES, "big") + encoded)
def data_received(self, data):
peername = self.transport_.get_extra_info('peername')
print(f'Received data: {data} from {peername}')
if data is not None:
self.message_ += data
full_message_length = 0
if len(self.message_) >= LENGTH_BYTES:
ful_message_length = int.from_bytes(self.message_[:LENGTH_BYTES], "big")
full_message_length = int.from_bytes(self.message_[:LENGTH_BYTES], "big")
if len(self.message_) - LENGTH_BYTES < full_message_length:
return
else:

View File

@ -1,20 +1,45 @@
import asyncio
from protocol import ChiaProtocol
import cbor2
from src.plotter import Plotter
async def main(api_cls):
loop = asyncio.get_running_loop()
LENGTH_BYTES: int = 5
on_con_lost = loop.create_future()
class ChiaConnection:
def __init__(self):
pass
server = await loop.create_server(
lambda: ChiaProtocol(on_con_lost, loop, api_cls()),
'127.0.0.1', 8888, start_serving=False)
print(f'Starting {api_cls.__name__} server')
async def new_connection(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
peername = writer.get_extra_info('peername')
print(f'Connection from {peername}')
size = await reader.read(LENGTH_BYTES)
full_message_length = int.from_bytes(size, "big")
full_message = await reader.read(full_message_length)
decoded = cbor2.loads(full_message)
function: str = decoded["function"]
function_data: bytes = decoded["data"]
f = getattr(self.api_, function)
if f is not None:
print(f'Message of size {full_message_length}: {function}({function_data[:100]}) from {peername}')
f(function_data)
else:
print(f'Invalid message: {function} from {peername}')
async def main(api):
async def new_connection(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
server = await asyncio.start_server(
lambda x, y: ChiaConnection(x, y, api), '127.0.0.1', 8888)
addr = server.sockets[0].getsockname()
print(f'Serving on {addr}')
async with server:
await server.serve_forever()
asyncio.run(main(Plotter))
asyncio.run(main(Plotter()))
# # TODO: run other servers (farmer, full node, timelord)

10
src/types/challenge.py Normal file
View File

@ -0,0 +1,10 @@
from ..util.streamable import streamable
from ..util.ints import uint32, uint64
from .sized_bytes import bytes32


@streamable
class Challenge:
    """A chain challenge linking a proof-of-time output and a proof of space.

    Fix: `height` and `total_weight` were plain `int`, which the streamable
    machinery can neither parse nor stream (it raises NotImplementedError
    for field types without parse/stream/serialize). Sized uints make the
    class serializable; this is backward compatible because streamable
    casts constructor arguments to the declared field types.
    """

    proof_of_time_output_hash: bytes32
    proof_of_space_hash: bytes32
    height: uint32
    total_weight: uint64

8
src/types/classgroup.py Normal file
View File

@ -0,0 +1,8 @@
from ..util.streamable import streamable
from ..util.ints import uint1024
@streamable
class ClassgroupElement:
    """A classgroup element represented by two 1024-bit coefficients (a, b).

    Used as the `output` field of ProofOfTimeOutput.
    """

    a: uint1024
    b: uint1024

51
src/types/plotter_api.py Normal file
View File

@ -0,0 +1,51 @@
from blspy import PublicKey, PrependSignature
from ..util.streamable import streamable
from .sized_bytes import bytes32
from ..util.ints import uint8, uint32
from .proof_of_space import ProofOfSpace
@streamable
class CreatePlot:
    """Request asking the plotter to create a plot file.

    size is the plot parameter k, pool_pubkey is mixed into the plot seed,
    and filename is the destination path as bytes.
    """

    size: uint8
    pool_pubkey: PublicKey
    filename: bytes

    @classmethod
    def parse(cls, f):
        """Custom parse: fixed-size fields first, then the remainder of
        the stream is the filename (PublicKey has no parse())."""
        plot_size = uint8.parse(f)
        pool_pubkey = PublicKey.from_bytes(f.read(PublicKey.PUBLIC_KEY_SIZE))
        return cls(plot_size, pool_pubkey, f.read())
@streamable
class PlotterHandshake:
    """Handshake message carrying a pool public key.

    NOTE(review): no sender/handler for this message is visible in this
    chunk — confirm intended use.
    """

    pool_pk: PublicKey

    @classmethod
    def parse(cls, f):
        # PublicKey provides no parse(); read its fixed serialized size.
        return cls(PublicKey.from_bytes(f.read(PublicKey.PUBLIC_KEY_SIZE)))
@streamable
class NewChallenge:
    """Notifies the plotter of a new challenge; the plotter looks up
    qualities for it across all loaded plots (Plotter.new_challenge)."""

    challenge_hash: bytes32
@streamable
class ChallengeResponse:
    """The plotter's answer to NewChallenge: one quality for a challenge.

    Fix: the generated streamable parse() raises NotImplementedError for
    the trailing plain `bytes` field, so from_bin() was unusable. A custom
    parse() is added following the same pattern as CreatePlot and
    ProofOfSpace: fixed-size fields first, then the remainder of the
    stream is the quality.
    """

    challenge_hash: bytes32
    response_id: uint32
    quality: bytes

    @classmethod
    def parse(cls, f):
        return cls(bytes32.parse(f), uint32.parse(f), f.read())
@streamable
class RequestProofOfSpace:
    """Request for the full proof of space behind a previously reported
    quality.

    NOTE(review): Plotter.request_proof_of_space is still a TODO/pass in
    this commit, so handling is unimplemented.
    """

    challenge_hash: bytes32
    block_hash: bytes32
@streamable
class ProofOfSpaceResponse:
    """Response carrying a proof of space plus BLS signatures.

    NOTE(review): no custom parse() is defined and PrependSignature has no
    parse(), so from_bin() on this class raises NotImplementedError (see
    streamable's generated parse). as_bin() presumably works through each
    field's stream()/serialize() — confirm whether deserialization of this
    message is ever needed.
    """

    proof: ProofOfSpace
    block_hash: bytes32
    block_hash_signature: PrependSignature
    proof_of_possession: PrependSignature

View File

@ -0,0 +1,18 @@
from blspy import PublicKey
from ..util.streamable import streamable
from ..util.ints import uint8
@streamable
class ProofOfSpace:
    """A proof of space: pool and plot public keys, the plot size k, and
    the proof bytes (the remainder of the serialized stream)."""

    pool_pubkey: PublicKey
    plot_pubkey: PublicKey
    size: uint8
    proof: bytes

    @classmethod
    def parse(cls, f):
        """Custom parse: PublicKey lacks parse(), and the trailing bytes
        field consumes whatever is left in the stream."""
        key_size = PublicKey.PUBLIC_KEY_SIZE
        pool_pubkey = PublicKey.from_bytes(f.read(key_size))
        plot_pubkey = PublicKey.from_bytes(f.read(key_size))
        return cls(pool_pubkey, plot_pubkey, uint8.parse(f), f.read())

View File

@ -0,0 +1,10 @@
from ..util.streamable import streamable
from ..util.ints import uint64
from .sized_bytes import bytes32
from .classgroup import ClassgroupElement


@streamable
class ProofOfTimeOutput:
    """Output of a proof-of-time computation over a classgroup element.

    Fix: `number_of_iterations` was a plain `int`, which the streamable
    machinery can neither parse nor stream (NotImplementedError for field
    types without parse/stream/serialize). uint64 makes the class
    serializable and stays backward compatible because streamable casts
    constructor arguments to the declared field types.
    """

    challenge_hash: bytes32
    number_of_iterations: uint64
    output: ClassgroupElement

6
src/types/sized_bytes.py Normal file
View File

@ -0,0 +1,6 @@
from ..util.byte_types import make_sized_bytes

# Fixed-size byte types used in protocol messages.
# 32 bytes: sha256 hashes (e.g. challenge hashes, plot seeds).
# 48/96 bytes: presumably serialized BLS public keys / signatures —
# confirm against blspy's serialized sizes.
bytes32 = make_sized_bytes(32)
bytes48 = make_sized_bytes(48)
bytes96 = make_sized_bytes(96)

View File

@ -0,0 +1,31 @@
import functools
from inspect import signature
def transform_args(kwarg_transformers, message):
    """Return a copy of `message` with selected values transformed.

    For each (key, fn) in kwarg_transformers, the result maps key to
    fn(message[key]); all other entries are copied unchanged. Non-dict
    messages pass through untouched. The input dict is never mutated.
    """
    if not isinstance(message, dict):
        return message
    new_message = dict(message)
    for k, v in kwarg_transformers.items():
        new_message[k] = v(message[k])
    return new_message


def api_request(**kwarg_transformers):
    """
    This decorator will transform the values for the given keywords by the corresponding
    function.

    @api_request(block=Block.from_blob)
    def accept_block(block):
        # do some stuff with block as Block rather than bytes
    """

    def inner(f):
        # Fix: compute the signature once at decoration time instead of on
        # every call — inspect.signature() is expensive on a hot path.
        sig = signature(f)

        @functools.wraps(f)
        def f_substitute(*args, **kwargs):
            # Bind to named parameters (with defaults filled in) so the
            # transformers can be applied by parameter name.
            binding = sig.bind(*args, **kwargs)
            binding.apply_defaults()
            inter = transform_args(kwarg_transformers, dict(binding.arguments))
            return f(**inter)

        return f_substitute

    return inner

18
src/util/bin_methods.py Normal file
View File

@ -0,0 +1,18 @@
import io
from typing import Any
class bin_methods:
    """
    Mix-in that provides "from_bin" and "as_bin", implemented on top of the
    "parse" and "stream" methods that subclasses must supply.
    """

    @classmethod
    def from_bin(cls, blob: bytes) -> Any:
        """Deserialize an instance from raw bytes via the subclass's parse()."""
        stream = io.BytesIO(blob)
        return cls.parse(stream)

    def as_bin(self) -> bytes:
        """Serialize this instance to raw bytes via the subclass's stream()."""
        buffer = io.BytesIO()
        self.stream(buffer)
        return bytes(buffer.getvalue())

40
src/util/byte_types.py Normal file
View File

@ -0,0 +1,40 @@
import binascii
from typing import Any, BinaryIO
from .bin_methods import bin_methods
def make_sized_bytes(size):
    """
    Create a streamable type that subclasses "bytes" but requires instances
    to be a certain, fixed size.

    Returns a new class named "bytes<size>" mixing in bin_methods, so
    instances also get from_bin()/as_bin().
    """
    name = "bytes%d" % size

    def __new__(cls, v):
        # Fix: reject int initializers up front. Previously v went through
        # bytes(v) first, and bytes(n) yields n NUL bytes — so e.g.
        # bytes32(32) silently "succeeded" as 32 zero bytes, and the
        # isinstance check after conversion was dead code.
        if isinstance(v, int):
            raise ValueError("bad %s initializer %s" % (name, v))
        v = bytes(v)
        if len(v) != size:
            raise ValueError("bad %s initializer %s" % (name, v))
        return bytes.__new__(cls, v)

    @classmethod
    def parse(cls, f: BinaryIO) -> Any:
        """Read exactly `size` bytes from f."""
        b = f.read(size)
        # Fix: raise instead of assert (asserts are stripped under -O).
        if len(b) != size:
            raise ValueError("bad %s initializer %s" % (name, b))
        return cls(b)

    def stream(self, f):
        """Write the raw bytes to f."""
        f.write(self)

    def __str__(self):
        return binascii.hexlify(self).decode("utf8")

    def __repr__(self):
        return "<%s: %s>" % (self.__class__.__name__, str(self))

    namespace = dict(
        __new__=__new__, parse=parse, stream=stream, __str__=__str__, __repr__=__repr__
    )
    return type(name, (bytes, bin_methods), namespace)

43
src/util/ints.py Normal file
View File

@ -0,0 +1,43 @@
from .struct_stream import struct_stream
from typing import Any, BinaryIO
# Fixed-width integer wire types, named by bit width. struct_stream
# supplies parse()/stream() from the PACK template string; "!" selects
# network (big-endian) byte order.
class int8(int, struct_stream):
    PACK = "!b"


class uint8(int, struct_stream):
    PACK = "!B"


class int16(int, struct_stream):
    PACK = "!h"


class uint16(int, struct_stream):
    PACK = "!H"


class int32(int, struct_stream):
    PACK = "!l"


class uint32(int, struct_stream):
    PACK = "!L"


class int64(int, struct_stream):
    PACK = "!q"


class uint64(int, struct_stream):
    PACK = "!Q"
class uint1024(int):
    """Unsigned 1024-bit integer, serialized as 128 big-endian bytes.

    Fix: the previous code read/wrote 1024 *bytes* (8192 bits); all the
    sibling types in this module are named by bit width, and 1024 bits is
    128 bytes. stream() raises OverflowError for values >= 2**1024.
    """

    SIZE_BYTES = 1024 // 8  # 128

    @classmethod
    def parse(cls, f: BinaryIO) -> Any:
        b = f.read(cls.SIZE_BYTES)
        if len(b) != cls.SIZE_BYTES:
            raise ValueError("short read for uint1024")
        return cls(int.from_bytes(b, "big"))

    def stream(self, f):
        f.write(self.to_bytes(self.SIZE_BYTES, "big"))

55
src/util/streamable.py Normal file
View File

@ -0,0 +1,55 @@
import dataclasses
from typing import Type, BinaryIO, get_type_hints
from .bin_methods import bin_methods
def streamable(cls):
    """
    This is a decorator for class definitions. It applies the dataclasses.dataclass
    decorator, and also allows fields to be cast to their expected type. The resulting
    class also gets parse and stream for free, as long as all its constituent elements
    have it.
    """

    class _local:
        def __init__(self, *args):
            # Positional-only construction: one argument per annotated
            # field, in declaration order.
            fields = get_type_hints(self)
            la, lf = len(args), len(fields)
            if la != lf:
                raise ValueError("got %d and expected %d args" % (la, lf))
            for a, (f_name, f_type) in zip(args, fields.items()):
                # Cast each argument to the declared field type if needed
                # (e.g. int -> uint8), then re-check that the cast took.
                if not isinstance(a, f_type):
                    a = f_type(a)
                if not isinstance(a, f_type):
                    raise ValueError("wrong type for %s" % f_name)
                # object.__setattr__ bypasses the frozen-dataclass guard.
                object.__setattr__(self, f_name, a)

        @classmethod
        def parse(cls: Type[cls.__name__], f: BinaryIO) -> cls.__name__:
            # Every field type must itself provide parse(); plain types
            # such as int or bytes have none and raise here.
            values = []
            for f_name, f_type in get_type_hints(cls).items():
                if hasattr(f_type, "parse"):
                    values.append(f_type.parse(f))
                else:
                    raise NotImplementedError
            return cls(*values)

        def stream(self, f: BinaryIO) -> None:
            # Serialize fields in declaration order: prefer the field
            # type's own stream(), then a serialize() method (e.g. blspy
            # keys), then raw bytes; anything else is unstreamable.
            for f_name, f_type in get_type_hints(self).items():
                v = getattr(self, f_name)
                if hasattr(f_type, "stream"):
                    v.stream(f)
                elif hasattr(f_type, "serialize"):
                    to_write = v.serialize()
                    f.write(to_write)
                elif isinstance(v, bytes):
                    f.write(v)
                else:
                    raise NotImplementedError("can't stream %s: %s" % (v, f_name))

    # Freeze the original class as a dataclass (init=False so that
    # _local.__init__ is used instead of the generated one), then mix in
    # bin_methods (from_bin/as_bin) and the _local machinery above.
    cls1 = dataclasses.dataclass(_cls=cls, frozen=True, init=False)
    cls2 = type(cls.__name__, (cls1, bin_methods, _local), {})
    return cls2

17
src/util/struct_stream.py Normal file
View File

@ -0,0 +1,17 @@
import struct
from typing import Any, BinaryIO
from .bin_methods import bin_methods
class struct_stream(bin_methods):
    """
    Mix-in that parses and streams a fixed-size value using the
    struct.pack template string held in the subclass's PACK attribute.
    """

    @classmethod
    def parse(cls, f: BinaryIO) -> Any:
        """Read struct.calcsize(PACK) bytes from f and build an instance."""
        raw = f.read(struct.calcsize(cls.PACK))
        return cls(*struct.unpack(cls.PACK, raw))

    def stream(self, f):
        """Write this value to f using the PACK template."""
        f.write(struct.pack(self.PACK, self))