mirror of
https://github.com/Chia-Network/chia-blockchain.git
synced 2024-11-29 05:18:11 +03:00
Some more work on boilerplate
This commit is contained in:
parent
8d4b98ed30
commit
52dfcba527
2
.gitignore
vendored
2
.gitignore
vendored
@ -21,6 +21,8 @@ pip-delete-this-directory.txt
|
||||
.cache
|
||||
.pytest_cache/
|
||||
|
||||
# PoSpace plots
|
||||
**/*.dat
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
|
@ -1,8 +1,10 @@
|
||||
from typing import List
|
||||
from chiapos import DiskPlotter, DiskProver
|
||||
from blspy import PrivateKey
|
||||
from hashlib import sha256
|
||||
import secrets
|
||||
from blspy import PrivateKey
|
||||
from .util.api_decorators import api_request
|
||||
from chiapos import DiskPlotter, DiskProver
|
||||
from .types import plotter_api
|
||||
|
||||
|
||||
class Plotter:
|
||||
@ -12,7 +14,8 @@ class Plotter:
|
||||
def find_plotfiles(self, directories=[]) -> List[str]:
|
||||
return []
|
||||
|
||||
def create_plot(self, k: int, filename: str, pool_pk):
|
||||
@api_request(create_plot=plotter_api.CreatePlot.from_bin)
|
||||
def create_plot(self, create_plot: plotter_api.CreatePlot):
|
||||
# TODO: Check if we have enough disk space
|
||||
|
||||
# Uses python secure random number generation
|
||||
@ -23,20 +26,22 @@ class Plotter:
|
||||
|
||||
# TODO: store the private key and plot id on disk
|
||||
public_key_ser = private_key.get_public_key().serialize()
|
||||
plot_seed: bytes = sha256(pool_pk + public_key_ser).digest()
|
||||
plot_seed: bytes = sha256(create_plot.pool_pubkey.serialize() + public_key_ser).digest()
|
||||
plotter = DiskPlotter()
|
||||
plotter.create_plot_disk(filename, k, bytes([]), plot_seed)
|
||||
self.plots_[plot_seed] = (private_key, DiskProver(filename))
|
||||
plotter.create_plot_disk(create_plot.filename, create_plot.size, bytes([]), plot_seed)
|
||||
self.plots_[plot_seed] = (private_key, DiskProver(create_plot.filename))
|
||||
|
||||
def new_challenge(self, challenge_hash: bytes):
|
||||
@api_request(new_challenge=plotter_api.NewChallenge.from_bin)
|
||||
def new_challenge(self, new_challenge: plotter_api.NewChallenge):
|
||||
# TODO: Create an ID based on plot id and index
|
||||
all_qualities = []
|
||||
for _, (_, prover) in self.plots_.items():
|
||||
qualities = prover.get_qualities_for_challenge(challenge_hash)
|
||||
qualities = prover.get_qualities_for_challenge(new_challenge.challenge_hash)
|
||||
for index, quality in enumerate(qualities):
|
||||
all_qualities.append((index, quality))
|
||||
return all_qualities
|
||||
|
||||
def request_proof_of_space(self, challenge_hash: bytes, block_hash: bytes):
|
||||
@api_request(request=plotter_api.RequestProofOfSpace.from_bin)
|
||||
def request_proof_of_space(self, request: plotter_api.RequestProofOfSpace):
|
||||
# TODO: Lookup private key, plot id
|
||||
pass
|
||||
|
@ -1,6 +1,8 @@
|
||||
import asyncio
|
||||
from protocol import ChiaProtocol
|
||||
from blspy import PrivateKey
|
||||
from src.protocol.protocol import ChiaProtocol
|
||||
from src.types.plotter_api import CreatePlot, NewChallenge
|
||||
|
||||
|
||||
async def main():
|
||||
# Get a reference to the event loop as we plan to use
|
||||
@ -13,10 +15,10 @@ async def main():
|
||||
lambda: ChiaProtocol(on_con_lost, loop, lambda x: x),
|
||||
'127.0.0.1', 8888)
|
||||
|
||||
ppk = PrivateKey.from_seed(b"123").get_public_key().serialize()
|
||||
protocol.send("create_plot", 16, "myplot1.dat", ppk)
|
||||
protocol.send("create_plot", 17, "myplot2.dat", ppk)
|
||||
protocol.send("new_challenge", bytes([2]*32))
|
||||
ppk = PrivateKey.from_seed(b"123").get_public_key()
|
||||
await protocol.send("create_plot", CreatePlot(16, ppk, b"myplot_1.dat"))
|
||||
# protocol.send("create_plot", CreatePlot(17, ppk, b"myplot_2.dat"))
|
||||
await protocol.send("new_challenge", NewChallenge(bytes([77]*32)))
|
||||
|
||||
# Wait until the protocol signals that the connection
|
||||
# is lost and close the transport.
|
||||
@ -26,4 +28,4 @@ async def main():
|
||||
transport.close()
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
asyncio.run(main())
|
||||
|
@ -4,6 +4,23 @@ from cbor2 import dumps, loads
|
||||
LENGTH_BYTES: int = 5
|
||||
|
||||
|
||||
def transform_to_streamable(d):
|
||||
"""
|
||||
Drill down through dictionaries and lists and transform objects with "as_bin" to bytes.
|
||||
"""
|
||||
print("?")
|
||||
if hasattr(d, "as_bin"):
|
||||
return d.as_bin()
|
||||
if isinstance(d, (str, bytes, int)):
|
||||
return d
|
||||
if isinstance(d, dict):
|
||||
new_d = {}
|
||||
for k, v in d.items():
|
||||
new_d[transform_to_streamable(k)] = transform_to_streamable(v)
|
||||
return new_d
|
||||
return [transform_to_streamable(_) for _ in d]
|
||||
|
||||
|
||||
class ChiaProtocol(asyncio.Protocol):
|
||||
def __init__(self, on_con_lost, loop, api):
|
||||
self.loop_ = loop
|
||||
@ -23,18 +40,17 @@ class ChiaProtocol(asyncio.Protocol):
|
||||
else:
|
||||
print(f'Connection lost to {peername} exception {exc}')
|
||||
|
||||
def send(self, function, data):
|
||||
encoded = dumps({"function": function, "data": data})
|
||||
self.transport_.write(len(encoded).to_bytes(5, "big") + encoded)
|
||||
async def send(self, function, data):
|
||||
encoded = dumps({"function": function, "data": transform_to_streamable(data)})
|
||||
await self.transport_.write(len(encoded).to_bytes(LENGTH_BYTES, "big") + encoded)
|
||||
|
||||
def data_received(self, data):
|
||||
peername = self.transport_.get_extra_info('peername')
|
||||
print(f'Received data: {data} from {peername}')
|
||||
if data is not None:
|
||||
self.message_ += data
|
||||
full_message_length = 0
|
||||
if len(self.message_) >= LENGTH_BYTES:
|
||||
ful_message_length = int.from_bytes(self.message_[:LENGTH_BYTES], "big")
|
||||
full_message_length = int.from_bytes(self.message_[:LENGTH_BYTES], "big")
|
||||
if len(self.message_) - LENGTH_BYTES < full_message_length:
|
||||
return
|
||||
else:
|
||||
|
@ -1,20 +1,45 @@
|
||||
import asyncio
|
||||
from protocol import ChiaProtocol
|
||||
import cbor2
|
||||
from src.plotter import Plotter
|
||||
|
||||
|
||||
async def main(api_cls):
|
||||
loop = asyncio.get_running_loop()
|
||||
LENGTH_BYTES: int = 5
|
||||
|
||||
on_con_lost = loop.create_future()
|
||||
class ChiaConnection:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
server = await loop.create_server(
|
||||
lambda: ChiaProtocol(on_con_lost, loop, api_cls()),
|
||||
'127.0.0.1', 8888, start_serving=False)
|
||||
|
||||
print(f'Starting {api_cls.__name__} server')
|
||||
async def new_connection(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
||||
peername = writer.get_extra_info('peername')
|
||||
print(f'Connection from {peername}')
|
||||
size = await reader.read(LENGTH_BYTES)
|
||||
full_message_length = int.from_bytes(size, "big")
|
||||
full_message = await reader.read(full_message_length)
|
||||
|
||||
decoded = cbor2.loads(full_message)
|
||||
function: str = decoded["function"]
|
||||
function_data: bytes = decoded["data"]
|
||||
|
||||
f = getattr(self.api_, function)
|
||||
if f is not None:
|
||||
print(f'Message of size {full_message_length}: {function}({function_data[:100]}) from {peername}')
|
||||
f(function_data)
|
||||
else:
|
||||
print(f'Invalid message: {function} from {peername}')
|
||||
|
||||
|
||||
async def main(api):
|
||||
|
||||
async def new_connection(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
||||
server = await asyncio.start_server(
|
||||
lambda x, y: ChiaConnection(x, y, api), '127.0.0.1', 8888)
|
||||
|
||||
addr = server.sockets[0].getsockname()
|
||||
print(f'Serving on {addr}')
|
||||
|
||||
async with server:
|
||||
await server.serve_forever()
|
||||
|
||||
asyncio.run(main(Plotter))
|
||||
asyncio.run(main(Plotter()))
|
||||
# # TODO: run other servers (farmer, full node, timelord)
|
10
src/types/challenge.py
Normal file
10
src/types/challenge.py
Normal file
@ -0,0 +1,10 @@
|
||||
from ..util.streamable import streamable
|
||||
from .sized_bytes import bytes32
|
||||
|
||||
|
||||
@streamable
|
||||
class Challenge:
|
||||
proof_of_time_output_hash: bytes32
|
||||
proof_of_space_hash: bytes32
|
||||
height: int
|
||||
total_weight: int
|
8
src/types/classgroup.py
Normal file
8
src/types/classgroup.py
Normal file
@ -0,0 +1,8 @@
|
||||
from ..util.streamable import streamable
|
||||
from ..util.ints import uint1024
|
||||
|
||||
|
||||
@streamable
|
||||
class ClassgroupElement:
|
||||
a: uint1024
|
||||
b: uint1024
|
51
src/types/plotter_api.py
Normal file
51
src/types/plotter_api.py
Normal file
@ -0,0 +1,51 @@
|
||||
from blspy import PublicKey, PrependSignature
|
||||
from ..util.streamable import streamable
|
||||
from .sized_bytes import bytes32
|
||||
from ..util.ints import uint8, uint32
|
||||
from .proof_of_space import ProofOfSpace
|
||||
|
||||
|
||||
@streamable
|
||||
class CreatePlot:
|
||||
size: uint8
|
||||
pool_pubkey: PublicKey
|
||||
filename: bytes
|
||||
|
||||
@classmethod
|
||||
def parse(cls, f):
|
||||
return cls(uint8.parse(f), PublicKey.from_bytes(f.read(PublicKey.PUBLIC_KEY_SIZE)), f.read())
|
||||
|
||||
|
||||
@streamable
|
||||
class PlotterHandshake:
|
||||
pool_pk: PublicKey
|
||||
|
||||
@classmethod
|
||||
def parse(cls, f):
|
||||
return cls(PublicKey.from_bytes(f.read(PublicKey.PUBLIC_KEY_SIZE)))
|
||||
|
||||
|
||||
@streamable
|
||||
class NewChallenge:
|
||||
challenge_hash: bytes32
|
||||
|
||||
|
||||
@streamable
|
||||
class ChallengeResponse:
|
||||
challenge_hash: bytes32
|
||||
response_id: uint32
|
||||
quality: bytes
|
||||
|
||||
|
||||
@streamable
|
||||
class RequestProofOfSpace:
|
||||
challenge_hash: bytes32
|
||||
block_hash: bytes32
|
||||
|
||||
|
||||
@streamable
|
||||
class ProofOfSpaceResponse:
|
||||
proof: ProofOfSpace
|
||||
block_hash: bytes32
|
||||
block_hash_signature: PrependSignature
|
||||
proof_of_possession: PrependSignature
|
18
src/types/proof_of_space.py
Normal file
18
src/types/proof_of_space.py
Normal file
@ -0,0 +1,18 @@
|
||||
from blspy import PublicKey
|
||||
from ..util.streamable import streamable
|
||||
from ..util.ints import uint8
|
||||
|
||||
|
||||
@streamable
|
||||
class ProofOfSpace:
|
||||
pool_pubkey: PublicKey
|
||||
plot_pubkey: PublicKey
|
||||
size: uint8
|
||||
proof: bytes
|
||||
|
||||
@classmethod
|
||||
def parse(cls, f):
|
||||
return cls(PublicKey.from_bytes(f.read(PublicKey.PUBLIC_KEY_SIZE)),
|
||||
PublicKey.from_bytes(f.read(PublicKey.PUBLIC_KEY_SIZE)),
|
||||
uint8.parse(f),
|
||||
f.read())
|
10
src/types/proof_of_time_output.py
Normal file
10
src/types/proof_of_time_output.py
Normal file
@ -0,0 +1,10 @@
|
||||
from ..util.streamable import streamable
|
||||
from .sized_bytes import bytes32
|
||||
from .classgroup import ClassgroupElement
|
||||
|
||||
|
||||
@streamable
|
||||
class ProofOfTimeOutput:
|
||||
challenge_hash: bytes32
|
||||
number_of_iterations: int
|
||||
output: ClassgroupElement
|
6
src/types/sized_bytes.py
Normal file
6
src/types/sized_bytes.py
Normal file
@ -0,0 +1,6 @@
|
||||
from ..util.byte_types import make_sized_bytes
|
||||
|
||||
|
||||
bytes32 = make_sized_bytes(32)
|
||||
bytes48 = make_sized_bytes(48)
|
||||
bytes96 = make_sized_bytes(96)
|
31
src/util/api_decorators.py
Normal file
31
src/util/api_decorators.py
Normal file
@ -0,0 +1,31 @@
|
||||
import functools
|
||||
from inspect import signature
|
||||
|
||||
|
||||
def transform_args(kwarg_transformers, message):
|
||||
if not isinstance(message, dict):
|
||||
return message
|
||||
new_message = dict(message)
|
||||
for k, v in kwarg_transformers.items():
|
||||
new_message[k] = v(message[k])
|
||||
return new_message
|
||||
|
||||
|
||||
def api_request(**kwarg_transformers):
|
||||
"""
|
||||
This decorator will transform the values for the given keywords by the corresponding
|
||||
function.
|
||||
@api_request(block=Block.from_blob)
|
||||
def accept_block(block):
|
||||
# do some stuff with block as Block rather than bytes
|
||||
"""
|
||||
def inner(f):
|
||||
@functools.wraps(f)
|
||||
def f_substitute(*args, **kwargs):
|
||||
sig = signature(f)
|
||||
binding = sig.bind(*args, **kwargs)
|
||||
binding.apply_defaults()
|
||||
inter = transform_args(kwarg_transformers, dict(binding.arguments))
|
||||
return f(**inter)
|
||||
return f_substitute
|
||||
return inner
|
18
src/util/bin_methods.py
Normal file
18
src/util/bin_methods.py
Normal file
@ -0,0 +1,18 @@
|
||||
import io
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class bin_methods:
|
||||
"""
|
||||
Create "from_bin" and "as_bin" methods in terms of "parse" and "stream" methods.
|
||||
"""
|
||||
@classmethod
|
||||
def from_bin(cls, blob: bytes) -> Any:
|
||||
f = io.BytesIO(blob)
|
||||
return cls.parse(f)
|
||||
|
||||
def as_bin(self) -> bytes:
|
||||
f = io.BytesIO()
|
||||
self.stream(f)
|
||||
return bytes(f.getvalue())
|
40
src/util/byte_types.py
Normal file
40
src/util/byte_types.py
Normal file
@ -0,0 +1,40 @@
|
||||
import binascii
|
||||
|
||||
from typing import Any, BinaryIO
|
||||
|
||||
from .bin_methods import bin_methods
|
||||
|
||||
|
||||
def make_sized_bytes(size):
|
||||
"""
|
||||
Create a streamable type that subclasses "hexbytes" but requires instances
|
||||
to be a certain, fixed size.
|
||||
"""
|
||||
name = "bytes%d" % size
|
||||
|
||||
def __new__(self, v):
|
||||
v = bytes(v)
|
||||
if not isinstance(v, bytes) or len(v) != size:
|
||||
raise ValueError("bad %s initializer %s" % (name, v))
|
||||
return bytes.__new__(self, v)
|
||||
|
||||
@classmethod
|
||||
def parse(cls, f: BinaryIO) -> Any:
|
||||
b = f.read(size)
|
||||
assert len(b) == size
|
||||
return cls(b)
|
||||
|
||||
def stream(self, f):
|
||||
f.write(self)
|
||||
|
||||
def __str__(self):
|
||||
return binascii.hexlify(self).decode("utf8")
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s: %s>" % (self.__class__.__name__, str(self))
|
||||
|
||||
namespace = dict(__new__=__new__, parse=parse, stream=stream, __str__=__str__, __repr__=__repr__)
|
||||
|
||||
cls = type(name, (bytes, bin_methods), namespace)
|
||||
|
||||
return cls
|
43
src/util/ints.py
Normal file
43
src/util/ints.py
Normal file
@ -0,0 +1,43 @@
|
||||
from .struct_stream import struct_stream
|
||||
from typing import Any, BinaryIO
|
||||
|
||||
|
||||
class int8(int, struct_stream):
|
||||
PACK = "!b"
|
||||
|
||||
|
||||
class uint8(int, struct_stream):
|
||||
PACK = "!B"
|
||||
|
||||
|
||||
class int16(int, struct_stream):
|
||||
PACK = "!h"
|
||||
|
||||
|
||||
class uint16(int, struct_stream):
|
||||
PACK = "!H"
|
||||
|
||||
|
||||
class int32(int, struct_stream):
|
||||
PACK = "!l"
|
||||
|
||||
|
||||
class uint32(int, struct_stream):
|
||||
PACK = "!L"
|
||||
|
||||
|
||||
class int64(int, struct_stream):
|
||||
PACK = "!q"
|
||||
|
||||
|
||||
class uint64(int, struct_stream):
|
||||
PACK = "!Q"
|
||||
|
||||
|
||||
class uint1024(int):
|
||||
@classmethod
|
||||
def parse(cls, f: BinaryIO) -> Any:
|
||||
return cls(int.from_bytes(f.read(1024), "big"))
|
||||
|
||||
def stream(self, f):
|
||||
f.write(self.to_bytes(1024, "big"))
|
55
src/util/streamable.py
Normal file
55
src/util/streamable.py
Normal file
@ -0,0 +1,55 @@
|
||||
import dataclasses
|
||||
|
||||
from typing import Type, BinaryIO, get_type_hints
|
||||
|
||||
from .bin_methods import bin_methods
|
||||
|
||||
|
||||
def streamable(cls):
|
||||
"""
|
||||
This is a decorator for class definitions. It applies the dataclasses.dataclass
|
||||
decorator, and also allows fields to be cast to their expected type. The resulting
|
||||
class also gets parse and stream for free, as long as all its constituent elements
|
||||
have it.
|
||||
"""
|
||||
|
||||
class _local:
|
||||
def __init__(self, *args):
|
||||
fields = get_type_hints(self)
|
||||
la, lf = len(args), len(fields)
|
||||
if la != lf:
|
||||
raise ValueError("got %d and expected %d args" % (la, lf))
|
||||
for a, (f_name, f_type) in zip(args, fields.items()):
|
||||
if not isinstance(a, f_type):
|
||||
a = f_type(a)
|
||||
if not isinstance(a, f_type):
|
||||
raise ValueError("wrong type for %s" % f_name)
|
||||
object.__setattr__(self, f_name, a)
|
||||
|
||||
@classmethod
|
||||
def parse(cls: Type[cls.__name__], f: BinaryIO) -> cls.__name__:
|
||||
values = []
|
||||
for f_name, f_type in get_type_hints(cls).items():
|
||||
if hasattr(f_type, "parse"):
|
||||
values.append(f_type.parse(f))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return cls(*values)
|
||||
|
||||
def stream(self, f: BinaryIO) -> None:
|
||||
for f_name, f_type in get_type_hints(self).items():
|
||||
v = getattr(self, f_name)
|
||||
if hasattr(f_type, "stream"):
|
||||
v.stream(f)
|
||||
elif hasattr(f_type, "serialize"):
|
||||
to_write = v.serialize()
|
||||
f.write(to_write)
|
||||
elif isinstance(v, bytes):
|
||||
f.write(v)
|
||||
else:
|
||||
raise NotImplementedError("can't stream %s: %s" % (v, f_name))
|
||||
|
||||
cls1 = dataclasses.dataclass(_cls=cls, frozen=True, init=False)
|
||||
|
||||
cls2 = type(cls.__name__, (cls1, bin_methods, _local), {})
|
||||
return cls2
|
17
src/util/struct_stream.py
Normal file
17
src/util/struct_stream.py
Normal file
@ -0,0 +1,17 @@
|
||||
import struct
|
||||
|
||||
from typing import Any, BinaryIO
|
||||
|
||||
from .bin_methods import bin_methods
|
||||
|
||||
|
||||
class struct_stream(bin_methods):
|
||||
"""
|
||||
Create a class that can parse and stream itself based on a struct.pack template string.
|
||||
"""
|
||||
@classmethod
|
||||
def parse(cls, f: BinaryIO) -> Any:
|
||||
return cls(*struct.unpack(cls.PACK, f.read(struct.calcsize(cls.PACK))))
|
||||
|
||||
def stream(self, f):
|
||||
f.write(struct.pack(self.PACK, self))
|
Loading…
Reference in New Issue
Block a user