Refactor CLN stubs to allow for mocking

This commit is contained in:
Reckless_Satoshi 2023-11-08 14:56:36 +00:00 committed by Reckless_Satoshi
parent 68b1186277
commit bece7c7d4e

View File

@ -48,10 +48,6 @@ class CLNNode:
hold_channel = grpc.secure_channel(CLN_GRPC_HOLD_HOST, creds)
node_channel = grpc.secure_channel(CLN_GRPC_HOST, creds)
# Create the gRPC stub
hstub = hold_pb2_grpc.HoldStub(hold_channel)
nstub = node_pb2_grpc.NodeStub(node_channel)
payment_failure_context = {
-1: "Catchall nonspecific error.",
201: "Already paid with this hash using different amount or destination.",
@ -65,9 +61,9 @@ class CLNNode:
@classmethod
def get_version(cls):
try:
nstub = node_pb2_grpc.NodeStub(cls.node_channel)
nodestub = node_pb2_grpc.NodeStub(cls.node_channel)
request = node_pb2.GetinfoRequest()
response = nstub.Getinfo(request)
response = nodestub.Getinfo(request)
return response.version
except Exception as e:
print(f"Cannot get CLN version: {e}")
@ -77,8 +73,8 @@ class CLNNode:
def decode_payreq(cls, invoice):
"""Decodes a lightning payment request (invoice)"""
request = hold_pb2.DecodeBolt11Request(bolt11=invoice)
response = cls.hstub.DecodeBolt11(request)
holdstub = hold_pb2_grpc.HoldStub(cls.hold_channel)
response = holdstub.DecodeBolt11(request)
return response
@classmethod
@ -86,8 +82,8 @@ class CLNNode:
"""Returns estimated fee for onchain payouts"""
# feerate estimaes work a bit differently in cln see https://lightning.readthedocs.io/lightning-feerates.7.html
request = node_pb2.FeeratesRequest(style="PERKB")
response = cls.nstub.Feerates(request)
nodestub = node_pb2_grpc.NodeStub(cls.node_channel)
response = nodestub.Feerates(request)
# "opening" -> ~12 block target
return {
@ -102,8 +98,8 @@ class CLNNode:
def wallet_balance(cls):
"""Returns onchain balance"""
request = node_pb2.ListfundsRequest()
response = cls.nstub.ListFunds(request)
nodestub = node_pb2_grpc.NodeStub(cls.node_channel)
response = nodestub.ListFunds(request)
unconfirmed_balance = 0
confirmed_balance = 0
@ -136,8 +132,8 @@ class CLNNode:
def channel_balance(cls):
"""Returns channels balance"""
request = node_pb2.ListpeerchannelsRequest()
response = cls.nstub.ListPeerChannels(request)
nodestub = node_pb2_grpc.NodeStub(cls.node_channel)
response = nodestub.ListPeerChannels(request)
local_balance_sat = 0
remote_balance_sat = 0
@ -199,7 +195,8 @@ class CLNNode:
# Changing the state to "MEMPO" should be atomic with SendCoins.
onchainpayment.status = on_mempool_code
onchainpayment.save(update_fields=["status"])
response = cls.nstub.Withdraw(request)
nodestub = node_pb2_grpc.NodeStub(cls.node_channel)
response = nodestub.Withdraw(request)
if response.txid:
onchainpayment.txid = response.txid.hex()
@ -217,7 +214,8 @@ class CLNNode:
request = hold_pb2.HoldInvoiceCancelRequest(
payment_hash=bytes.fromhex(payment_hash)
)
response = cls.hstub.HoldInvoiceCancel(request)
holdstub = hold_pb2_grpc.HoldStub(cls.hold_channel)
response = holdstub.HoldInvoiceCancel(request)
return response.state == hold_pb2.HoldInvoiceCancelResponse.Holdstate.CANCELED
@ -227,7 +225,8 @@ class CLNNode:
request = hold_pb2.HoldInvoiceSettleRequest(
payment_hash=hashlib.sha256(bytes.fromhex(preimage)).digest()
)
response = cls.hstub.HoldInvoiceSettle(request)
holdstub = hold_pb2_grpc.HoldStub(cls.hold_channel)
response = holdstub.HoldInvoiceSettle(request)
return response.state == hold_pb2.HoldInvoiceSettleResponse.Holdstate.SETTLED
@ -260,7 +259,8 @@ class CLNNode:
cltv=cltv_expiry_blocks,
preimage=preimage, # preimage is actually optional in cln, as cln would generate one by default
)
response = cls.hstub.HoldInvoice(request)
holdstub = hold_pb2_grpc.HoldStub(cls.hold_channel)
response = holdstub.HoldInvoice(request)
hold_payment["invoice"] = response.bolt11
payreq_decoded = cls.decode_payreq(hold_payment["invoice"])
@ -284,7 +284,8 @@ class CLNNode:
request = hold_pb2.HoldInvoiceLookupRequest(
payment_hash=bytes.fromhex(lnpayment.payment_hash)
)
response = cls.hstub.HoldInvoiceLookup(request)
holdstub = hold_pb2_grpc.HoldStub(cls.hold_channel)
response = holdstub.HoldInvoiceLookup(request)
# Will fail if 'unable to locate invoice'. Happens if invoice expiry
# time has passed (but these are 15% padded at the moment). Should catch it
@ -324,7 +325,8 @@ class CLNNode:
request = hold_pb2.HoldInvoiceLookupRequest(
payment_hash=bytes.fromhex(lnpayment.payment_hash)
)
response = cls.hstub.HoldInvoiceLookup(request)
holdstub = hold_pb2_grpc.HoldStub(cls.hold_channel)
response = holdstub.HoldInvoiceLookup(request)
status = cln_response_state_to_lnpayment_status[response.state]
@ -345,7 +347,8 @@ class CLNNode:
payment_hash=bytes.fromhex(lnpayment.payment_hash)
)
try:
response2 = cls.nstub.ListInvoices(request2).invoices
nodestub = node_pb2_grpc.NodeStub(cls.node_channel)
response2 = nodestub.ListInvoices(request2).invoices
except Exception as e:
print(str(e))
@ -482,7 +485,8 @@ class CLNNode:
)
try:
response = cls.nstub.Pay(request)
nodestub = node_pb2_grpc.NodeStub(cls.node_channel)
response = nodestub.Pay(request)
if response.status == node_pb2.PayResponse.PayStatus.COMPLETE:
lnpayment.status = LNPayment.Status.SUCCED
@ -540,7 +544,8 @@ class CLNNode:
)
while True:
try:
response_listpays = cls.nstub.ListPays(request_listpays)
nodestub = node_pb2_grpc.NodeStub(cls.node_channel)
response_listpays = nodestub.ListPays(request_listpays)
except Exception as e:
print(str(e))
time.sleep(2)
@ -562,8 +567,8 @@ class CLNNode:
lnpayment.save(update_fields=["in_flight", "status"])
order.update_status(Order.Status.PAY)
response = cls.nstub.Pay(request)
nodestub = node_pb2_grpc.NodeStub(cls.node_channel)
response = nodestub.Pay(request)
if response.status == node_pb2.PayResponse.PayStatus.PENDING:
print(f"Order: {order.id} IN_FLIGHT. Hash {hash}")
@ -758,9 +763,11 @@ class CLNNode:
)
)
if sign:
self_pubkey = cls.nstub.Getinfo(node_pb2.GetinfoRequest()).id
nodestub = node_pb2_grpc.NodeStub(cls.node_channel)
self_pubkey = nodestub.Getinfo(node_pb2.GetinfoRequest()).id
timestamp = struct.pack(">i", int(time.time()))
signature = cls.nstub.SignMessage(
nodestub = node_pb2_grpc.NodeStub(cls.node_channel)
signature = nodestub.SignMessage(
node_pb2.SignmessageRequest(
message=(
bytes.fromhex(self_pubkey)
@ -791,7 +798,8 @@ class CLNNode:
retry_for=timeout,
amount_msat=primitives__pb2.Amount(msat=num_satoshis * 1000),
)
response = cls.nstub.KeySend(request)
nodestub = node_pb2_grpc.NodeStub(cls.node_channel)
response = nodestub.KeySend(request)
keysend_payment["preimage"] = response.payment_preimage.hex()
keysend_payment["payment_hash"] = response.payment_hash.hex()
@ -800,7 +808,8 @@ class CLNNode:
payment_hash=response.payment_hash, timeout=timeout
)
try:
waitresp = cls.nstub.WaitSendPay(waitreq)
nodestub = node_pb2_grpc.NodeStub(cls.node_channel)
waitresp = nodestub.WaitSendPay(waitreq)
keysend_payment["fee"] = (
float(waitresp.amount_sent_msat.msat - waitresp.amount_msat.msat)
/ 1000
@ -833,7 +842,8 @@ class CLNNode:
payment_hash=bytes.fromhex(payment_hash)
)
try:
response = cls.hstub.HoldInvoiceLookup(request)
holdstub = hold_pb2_grpc.HoldStub(cls.hold_channel)
response = holdstub.HoldInvoiceLookup(request)
except Exception as e:
if "Timed out" in str(e):
return False