From 4efc59d41608fb77b52f14ce9f068f294d836b85 Mon Sep 17 00:00:00 2001 From: Reckless_Satoshi Date: Wed, 8 Nov 2023 14:25:34 +0000 Subject: [PATCH] Refactor gRPC and mocks. Add coordinator info test. --- api/lightning/cln.py | 105 ++++++++-------- api/lightning/lnd.py | 215 ++++++++++++++++++--------------- api/tests/test_utils.py | 31 ++--- tests/mocks/cln.py | 58 +++++++++ tests/mocks/lnd.py | 29 +++-- tests/test_coordinator_info.py | 61 ++++++++++ tests/test_trade_pipeline.py | 13 +- 7 files changed, 330 insertions(+), 182 deletions(-) create mode 100644 tests/mocks/cln.py create mode 100644 tests/test_coordinator_info.py diff --git a/api/lightning/cln.py b/api/lightning/cln.py index 9d97d405..cb92dd15 100755 --- a/api/lightning/cln.py +++ b/api/lightning/cln.py @@ -10,10 +10,7 @@ import ring from decouple import config from django.utils import timezone -from . import hold_pb2 as holdrpc -from . import hold_pb2_grpc as holdstub -from . import node_pb2 as noderpc -from . import node_pb2_grpc as nodestub +from . import hold_pb2, hold_pb2_grpc, node_pb2, node_pb2_grpc from . import primitives_pb2 as primitives__pb2 ####### @@ -52,11 +49,8 @@ class CLNNode: node_channel = grpc.secure_channel(CLN_GRPC_HOST, creds) # Create the gRPC stub - hstub = holdstub.HoldStub(hold_channel) - nstub = nodestub.NodeStub(node_channel) - - holdrpc = holdrpc - noderpc = noderpc + hstub = hold_pb2_grpc.HoldStub(hold_channel) + nstub = node_pb2_grpc.NodeStub(node_channel) payment_failure_context = { -1: "Catchall nonspecific error.", @@ -71,19 +65,18 @@ class CLNNode: @classmethod def get_version(cls): try: - request = noderpc.GetinfoRequest() - print(request) - response = cls.nstub.Getinfo(request) - print(response) + nstub = node_pb2_grpc.NodeStub(cls.node_channel) + request = node_pb2.GetinfoRequest() + response = nstub.Getinfo(request) return response.version except Exception as e: - print(e) + print(f"Cannot get CLN version: {e}") return None @classmethod def decode_payreq(cls, invoice): """Decodes a lightning payment request (invoice)""" - request = holdrpc.DecodeBolt11Request(bolt11=invoice) + request = hold_pb2.DecodeBolt11Request(bolt11=invoice) response = cls.hstub.DecodeBolt11(request) return response @@ -92,7 +85,7 @@ class CLNNode: def estimate_fee(cls, amount_sats, target_conf=2, min_confs=1): """Returns estimated fee for onchain payouts""" # feerate estimaes work a bit differently in cln see https://lightning.readthedocs.io/lightning-feerates.7.html - request = noderpc.FeeratesRequest(style="PERKB") + request = node_pb2.FeeratesRequest(style="PERKB") response = cls.nstub.Feerates(request) @@ -108,7 +101,7 @@ class CLNNode: @classmethod def wallet_balance(cls): """Returns onchain balance""" - request = noderpc.ListfundsRequest() + request = node_pb2.ListfundsRequest() response = cls.nstub.ListFunds(request) @@ -119,13 +112,13 @@ class CLNNode: if not utxo.reserved: if ( utxo.status - == noderpc.ListfundsOutputs.ListfundsOutputsStatus.UNCONFIRMED + == node_pb2.ListfundsOutputs.ListfundsOutputsStatus.UNCONFIRMED ): unconfirmed_balance += utxo.amount_msat.msat // 1_000 total_balance += utxo.amount_msat.msat // 1_000 elif ( utxo.status - == noderpc.ListfundsOutputs.ListfundsOutputsStatus.CONFIRMED + == node_pb2.ListfundsOutputs.ListfundsOutputsStatus.CONFIRMED ): confirmed_balance += utxo.amount_msat.msat // 1_000 total_balance += utxo.amount_msat.msat // 1_000 @@ -142,7 +135,7 @@ class CLNNode: @classmethod def channel_balance(cls): """Returns channels balance""" - request = noderpc.ListpeerchannelsRequest() + request = node_pb2.ListpeerchannelsRequest() response = cls.nstub.ListPeerChannels(request) @@ -153,7 +146,7 @@ class CLNNode: for channel in response.channels: if ( channel.state - == noderpc.ListpeerchannelsChannels.ListpeerchannelsChannelsState.CHANNELD_NORMAL + == node_pb2.ListpeerchannelsChannels.ListpeerchannelsChannelsState.CHANNELD_NORMAL ): local_balance_sat += channel.to_us_msat.msat // 1_000 remote_balance_sat += ( @@ -162,12 +155,12 @@ class CLNNode: for htlc in channel.htlcs: if ( htlc.direction - == noderpc.ListpeerchannelsChannelsHtlcs.ListpeerchannelsChannelsHtlcsDirection.IN + == node_pb2.ListpeerchannelsChannelsHtlcs.ListpeerchannelsChannelsHtlcsDirection.IN ): unsettled_local_balance += htlc.amount_msat.msat // 1_000 elif ( htlc.direction - == noderpc.ListpeerchannelsChannelsHtlcs.ListpeerchannelsChannelsHtlcsDirection.OUT + == node_pb2.ListpeerchannelsChannelsHtlcs.ListpeerchannelsChannelsHtlcsDirection.OUT ): unsettled_remote_balance += htlc.amount_msat.msat // 1_000 @@ -185,7 +178,7 @@ class CLNNode: if DISABLE_ONCHAIN or onchainpayment.sent_satoshis > MAX_SWAP_AMOUNT: return False - request = noderpc.WithdrawRequest( + request = node_pb2.WithdrawRequest( destination=onchainpayment.address, satoshi=primitives__pb2.AmountOrAll( amount=primitives__pb2.Amount(msat=onchainpayment.sent_satoshis * 1_000) @@ -221,22 +214,22 @@ class CLNNode: @classmethod def cancel_return_hold_invoice(cls, payment_hash): """Cancels or returns a hold invoice""" - request = holdrpc.HoldInvoiceCancelRequest( + request = hold_pb2.HoldInvoiceCancelRequest( payment_hash=bytes.fromhex(payment_hash) ) response = cls.hstub.HoldInvoiceCancel(request) - return response.state == holdrpc.HoldInvoiceCancelResponse.Holdstate.CANCELED + return response.state == hold_pb2.HoldInvoiceCancelResponse.Holdstate.CANCELED @classmethod def settle_hold_invoice(cls, preimage): """settles a hold invoice""" - request = holdrpc.HoldInvoiceSettleRequest( + request = hold_pb2.HoldInvoiceSettleRequest( payment_hash=hashlib.sha256(bytes.fromhex(preimage)).digest() ) response = cls.hstub.HoldInvoiceSettle(request) - return response.state == holdrpc.HoldInvoiceSettleResponse.Holdstate.SETTLED + return response.state == hold_pb2.HoldInvoiceSettleResponse.Holdstate.SETTLED @classmethod def gen_hold_invoice( @@ -259,7 +252,7 @@ class CLNNode: # The preimage is a random hash of 256 bits entropy preimage = hashlib.sha256(secrets.token_bytes(nbytes=32)).digest() - request = holdrpc.HoldInvoiceRequest( + request = hold_pb2.HoldInvoiceRequest( description=description, amount_msat=primitives__pb2.Amount(msat=num_satoshis * 1_000), label=f"Order:{order_id}-{lnpayment_concept}-{time}", @@ -288,7 +281,7 @@ class CLNNode: """Checks if hold invoice is locked""" from api.models import LNPayment - request = holdrpc.HoldInvoiceLookupRequest( + request = hold_pb2.HoldInvoiceLookupRequest( payment_hash=bytes.fromhex(lnpayment.payment_hash) ) response = cls.hstub.HoldInvoiceLookup(request) @@ -296,13 +289,13 @@ class CLNNode: # Will fail if 'unable to locate invoice'. Happens if invoice expiry # time has passed (but these are 15% padded at the moment). Should catch it # and report back that the invoice has expired (better robustness) - if response.state == holdrpc.HoldInvoiceLookupResponse.Holdstate.OPEN: + if response.state == hold_pb2.HoldInvoiceLookupResponse.Holdstate.OPEN: pass - if response.state == holdrpc.HoldInvoiceLookupResponse.Holdstate.SETTLED: + if response.state == hold_pb2.HoldInvoiceLookupResponse.Holdstate.SETTLED: pass - if response.state == holdrpc.HoldInvoiceLookupResponse.Holdstate.CANCELED: + if response.state == hold_pb2.HoldInvoiceLookupResponse.Holdstate.CANCELED: pass - if response.state == holdrpc.HoldInvoiceLookupResponse.Holdstate.ACCEPTED: + if response.state == hold_pb2.HoldInvoiceLookupResponse.Holdstate.ACCEPTED: lnpayment.expiry_height = response.htlc_expiry lnpayment.status = LNPayment.Status.LOCKED lnpayment.save(update_fields=["expiry_height", "status"]) @@ -328,7 +321,7 @@ class CLNNode: try: # this is similar to LNNnode.validate_hold_invoice_locked - request = holdrpc.HoldInvoiceLookupRequest( + request = hold_pb2.HoldInvoiceLookupRequest( payment_hash=bytes.fromhex(lnpayment.payment_hash) ) response = cls.hstub.HoldInvoiceLookup(request) @@ -348,7 +341,7 @@ class CLNNode: # (cln-grpc-hodl has separate state for hodl-invoices, which it forgets after an invoice expired more than an hour ago) if "empty result for listdatastore_state" in str(e): print(str(e)) - request2 = noderpc.ListinvoicesRequest( + request2 = node_pb2.ListinvoicesRequest( payment_hash=bytes.fromhex(lnpayment.payment_hash) ) try: @@ -358,12 +351,12 @@ class CLNNode: if ( response2[0].status - == noderpc.ListinvoicesInvoices.ListinvoicesInvoicesStatus.PAID + == node_pb2.ListinvoicesInvoices.ListinvoicesInvoicesStatus.PAID ): status = LNPayment.Status.SETLED elif ( response2[0].status - == noderpc.ListinvoicesInvoices.ListinvoicesInvoicesStatus.EXPIRED + == node_pb2.ListinvoicesInvoices.ListinvoicesInvoicesStatus.EXPIRED ): status = LNPayment.Status.CANCEL else: @@ -482,7 +475,7 @@ class CLNNode: ) ) # 200 ppm or 10 sats timeout_seconds = int(config("REWARDS_TIMEOUT_SECONDS")) - request = noderpc.PayRequest( + request = node_pb2.PayRequest( bolt11=lnpayment.invoice, maxfee=primitives__pb2.Amount(msat=fee_limit_sat * 1_000), retry_for=timeout_seconds, @@ -491,7 +484,7 @@ class CLNNode: try: response = cls.nstub.Pay(request) - if response.status == noderpc.PayResponse.PayStatus.COMPLETE: + if response.status == node_pb2.PayResponse.PayStatus.COMPLETE: lnpayment.status = LNPayment.Status.SUCCED lnpayment.fee = ( float(response.amount_sent_msat.msat - response.amount_msat.msat) @@ -500,13 +493,13 @@ class CLNNode: lnpayment.preimage = response.payment_preimage.hex() lnpayment.save(update_fields=["fee", "status", "preimage"]) return True, None - elif response.status == noderpc.PayResponse.PayStatus.PENDING: + elif response.status == node_pb2.PayResponse.PayStatus.PENDING: failure_reason = "Payment isn't failed (yet)" lnpayment.failure_reason = LNPayment.FailureReason.NOTYETF lnpayment.status = LNPayment.Status.FLIGHT lnpayment.save(update_fields=["failure_reason", "status"]) return False, failure_reason - else: # response.status == noderpc.PayResponse.PayStatus.FAILED + else: # response.status == node_pb2.PayResponse.PayStatus.FAILED failure_reason = "All possible routes were tried and failed permanently. Or were no routes to the destination at all." lnpayment.failure_reason = LNPayment.FailureReason.NOROUTE lnpayment.status = LNPayment.Status.FAILRO @@ -530,7 +523,7 @@ class CLNNode: # retry_for is not quite the same as a timeout. Pay can still take SIGNIFICANTLY longer to return if htlcs are stuck! # allow_self_payment=True, No such thing in pay command and self_payments do not work with pay! - request = noderpc.PayRequest( + request = node_pb2.PayRequest( bolt11=lnpayment.invoice, maxfee=primitives__pb2.Amount(msat=fee_limit_sat * 1_000), retry_for=timeout_seconds, @@ -542,7 +535,9 @@ class CLNNode: return def watchpayment(): - request_listpays = noderpc.ListpaysRequest(payment_hash=bytes.fromhex(hash)) + request_listpays = node_pb2.ListpaysRequest( + payment_hash=bytes.fromhex(hash) + ) while True: try: response_listpays = cls.nstub.ListPays(request_listpays) @@ -554,7 +549,7 @@ class CLNNode: if ( len(response_listpays.pays) == 0 or response_listpays.pays[0].status - != noderpc.ListpaysPays.ListpaysPaysStatus.PENDING + != node_pb2.ListpaysPays.ListpaysPaysStatus.PENDING ): return response_listpays else: @@ -570,14 +565,14 @@ class CLNNode: response = cls.nstub.Pay(request) - if response.status == noderpc.PayResponse.PayStatus.PENDING: + if response.status == node_pb2.PayResponse.PayStatus.PENDING: print(f"Order: {order.id} IN_FLIGHT. Hash {hash}") watchpayment() handle_response() - if response.status == noderpc.PayResponse.PayStatus.FAILED: + if response.status == node_pb2.PayResponse.PayStatus.FAILED: lnpayment.status = LNPayment.Status.FAILRO lnpayment.last_routing_time = timezone.now() lnpayment.routing_attempts += 1 @@ -614,7 +609,7 @@ class CLNNode: "context": f"payment failure reason: {cls.payment_failure_context[-1]}", } - if response.status == noderpc.PayResponse.PayStatus.COMPLETE: + if response.status == node_pb2.PayResponse.PayStatus.COMPLETE: print(f"Order: {order.id} SUCCEEDED. Hash: {hash}") lnpayment.status = LNPayment.Status.SUCCED lnpayment.fee = ( @@ -702,7 +697,7 @@ class CLNNode: if ( len(last_payresponse.pays) > 0 and last_payresponse.pays[0].status - == noderpc.ListpaysPays.ListpaysPaysStatus.COMPLETE + == node_pb2.ListpaysPays.ListpaysPaysStatus.COMPLETE ): handle_response() else: @@ -763,10 +758,10 @@ class CLNNode: ) ) if sign: - self_pubkey = cls.nstub.GetInfo(noderpc.GetinfoRequest()).id + self_pubkey = cls.nstub.Getinfo(node_pb2.GetinfoRequest()).id timestamp = struct.pack(">i", int(time.time())) signature = cls.nstub.SignMessage( - noderpc.SignmessageRequest( + node_pb2.SignmessageRequest( message=( bytes.fromhex(self_pubkey) + bytes.fromhex(target_pubkey) @@ -789,7 +784,7 @@ class CLNNode: # no maxfee for Keysend maxfeepercent = (routing_budget_sats / num_satoshis) * 100 - request = noderpc.KeysendRequest( + request = node_pb2.KeysendRequest( destination=bytes.fromhex(target_pubkey), extratlvs=primitives__pb2.TlvStream(entries=custom_records), maxfeepercent=maxfeepercent, @@ -801,7 +796,7 @@ class CLNNode: keysend_payment["preimage"] = response.payment_preimage.hex() keysend_payment["payment_hash"] = response.payment_hash.hex() - waitreq = noderpc.WaitsendpayRequest( + waitreq = node_pb2.WaitsendpayRequest( payment_hash=response.payment_hash, timeout=timeout ) try: @@ -834,7 +829,7 @@ class CLNNode: @classmethod def double_check_htlc_is_settled(cls, payment_hash): """Just as it sounds. Better safe than sorry!""" - request = holdrpc.HoldInvoiceLookupRequest( + request = hold_pb2.HoldInvoiceLookupRequest( payment_hash=bytes.fromhex(payment_hash) ) try: @@ -845,4 +840,4 @@ class CLNNode: else: raise e - return response.state == holdrpc.HoldInvoiceLookupResponse.Holdstate.SETTLED + return response.state == hold_pb2.HoldInvoiceLookupResponse.Holdstate.SETTLED diff --git a/api/lightning/lnd.py b/api/lightning/lnd.py index fdda1093..8a6ac59f 100644 --- a/api/lightning/lnd.py +++ b/api/lightning/lnd.py @@ -11,16 +11,18 @@ import ring from decouple import config from django.utils import timezone -from . import invoices_pb2 as invoicesrpc -from . import invoices_pb2_grpc as invoicesstub -from . import lightning_pb2 as lnrpc -from . import lightning_pb2_grpc as lightningstub -from . import router_pb2 as routerrpc -from . import router_pb2_grpc as routerstub -from . import signer_pb2 as signerrpc -from . import signer_pb2_grpc as signerstub -from . import verrpc_pb2 as verrpc -from . import verrpc_pb2_grpc as verstub +from . import ( + invoices_pb2, + invoices_pb2_grpc, + lightning_pb2, + lightning_pb2_grpc, + router_pb2, + router_pb2_grpc, + signer_pb2, + signer_pb2_grpc, + verrpc_pb2, + verrpc_pb2_grpc, +) ####### # Works with LND (c-lightning in the future for multi-vendor resilience) @@ -67,12 +69,6 @@ class LNDNode: combined_creds = grpc.composite_channel_credentials(ssl_creds, auth_creds) channel = grpc.secure_channel(LND_GRPC_HOST, combined_creds) - lightningstub = lightningstub.LightningStub(channel) - invoicesstub = invoicesstub.InvoicesStub(channel) - routerstub = routerstub.RouterStub(channel) - signerstub = signerstub.SignerStub(channel) - verstub = verstub.VersionerStub(channel) - payment_failure_context = { 0: "Payment isn't failed (yet)", 1: "There are more routes to try, but the payment timeout was exceeded.", @@ -85,8 +81,9 @@ class LNDNode: @classmethod def get_version(cls): try: - request = verrpc.VersionRequest() - response = cls.verstub.GetVersion(request) + request = verrpc_pb2.VersionRequest() + verstub = verrpc_pb2_grpc.VersionerStub(cls.channel) + response = verstub.GetVersion(request) log("verstub.GetVersion", request, response) return "v" + response.version except Exception as e: @@ -96,33 +93,35 @@ class LNDNode: @classmethod def decode_payreq(cls, invoice): """Decodes a lightning payment request (invoice)""" - request = lnrpc.PayReqString(pay_req=invoice) - response = cls.lightningstub.DecodePayReq(request) - log("lightningstub.DecodePayReq", request, response) + lightningstub = lightning_pb2_grpc.LightningStub(cls.channel) + request = lightning_pb2.PayReqString(pay_req=invoice) + response = lightningstub.DecodePayReq(request) + log("lightning_pb2_grpc.DecodePayReq", request, response) return response @classmethod def estimate_fee(cls, amount_sats, target_conf=2, min_confs=1): """Returns estimated fee for onchain payouts""" - - request = lnrpc.GetInfoRequest() + lightningstub = lightning_pb2_grpc.LightningStub(cls.channel) + request = lightning_pb2.GetInfoRequest() response = lightningstub.GetInfo(request) - log("lightningstub.GetInfo", request, response) + log("lightning_pb2_grpc.GetInfo", request, response) if response.testnet: dummy_address = "tb1qehyqhruxwl2p5pt52k6nxj4v8wwc3f3pg7377x" else: dummy_address = "bc1qgxwaqe4m9mypd7ltww53yv3lyxhcfnhzzvy5j3" # We assume segwit. Use hardcoded address as shortcut so there is no need of user inputs yet. - request = lnrpc.EstimateFeeRequest( + request = lightning_pb2.EstimateFeeRequest( AddrToAmount={dummy_address: amount_sats}, target_conf=target_conf, min_confs=min_confs, spend_unconfirmed=False, ) - response = cls.lightningstub.EstimateFee(request) - log("lightningstub.EstimateFee", request, response) + lightningstub = lightning_pb2_grpc.LightningStub(cls.channel) + response = lightningstub.EstimateFee(request) + log("lightning_pb2_grpc.EstimateFee", request, response) return { "mining_fee_sats": response.fee_sat, @@ -135,9 +134,10 @@ class LNDNode: @classmethod def wallet_balance(cls): """Returns onchain balance""" - request = lnrpc.WalletBalanceRequest() - response = cls.lightningstub.WalletBalance(request) - log("lightningstub.WalletBalance", request, response) + lightningstub = lightning_pb2_grpc.LightningStub(cls.channel) + request = lightning_pb2.WalletBalanceRequest() + response = lightningstub.WalletBalance(request) + log("lightning_pb2_grpc.WalletBalance", request, response) return { "total_balance": response.total_balance, @@ -151,9 +151,10 @@ class LNDNode: @classmethod def channel_balance(cls): """Returns channels balance""" - request = lnrpc.ChannelBalanceRequest() - response = cls.lightningstub.ChannelBalance(request) - log("lightningstub.ChannelBalance", request, response) + lightningstub = lightning_pb2_grpc.LightningStub(cls.channel) + request = lightning_pb2.ChannelBalanceRequest() + response = lightningstub.ChannelBalance(request) + log("lightning_pb2_grpc.ChannelBalance", request, response) return { "local_balance": response.local_balance.sat, @@ -169,7 +170,7 @@ class LNDNode: if DISABLE_ONCHAIN or onchainpayment.sent_satoshis > MAX_SWAP_AMOUNT: return False - request = lnrpc.SendCoinsRequest( + request = lightning_pb2.SendCoinsRequest( addr=onchainpayment.address, amount=int(onchainpayment.sent_satoshis), sat_per_vbyte=int(onchainpayment.mining_fee_rate), @@ -187,8 +188,9 @@ class LNDNode: # Changing the state to "MEMPO" should be atomic with SendCoins. onchainpayment.status = on_mempool_code onchainpayment.save(update_fields=["status"]) - response = cls.lightningstub.SendCoins(request) - log("lightningstub.SendCoins", request, response) + lightningstub = lightning_pb2_grpc.LightningStub(cls.channel) + response = lightningstub.SendCoins(request) + log("lightning_pb2_grpc.SendCoins", request, response) if response.txid: onchainpayment.txid = response.txid @@ -210,18 +212,22 @@ class LNDNode: @classmethod def cancel_return_hold_invoice(cls, payment_hash): """Cancels or returns a hold invoice""" - request = invoicesrpc.CancelInvoiceMsg(payment_hash=bytes.fromhex(payment_hash)) - response = cls.invoicesstub.CancelInvoice(request) - log("invoicesstub.CancelInvoice", request, response) + request = invoices_pb2.CancelInvoiceMsg( + payment_hash=bytes.fromhex(payment_hash) + ) + invoicesstub = invoices_pb2_grpc.InvoicesStub(cls.channel) + response = invoicesstub.CancelInvoice(request) + log("invoices_pb2_grpc.CancelInvoice", request, response) # Fix this: tricky because canceling sucessfully an invoice has no response. TODO return str(response) == "" # True if no response, false otherwise. @classmethod def settle_hold_invoice(cls, preimage): """settles a hold invoice""" - request = invoicesrpc.SettleInvoiceMsg(preimage=bytes.fromhex(preimage)) - response = cls.invoicesstub.SettleInvoice(request) - log("invoicesstub.SettleInvoice", request, response) + request = invoices_pb2.SettleInvoiceMsg(preimage=bytes.fromhex(preimage)) + invoicesstub = invoices_pb2_grpc.InvoicesStub(cls.channel) + response = invoicesstub.SettleInvoice(request) + log("invoices_pb2_grpc.SettleInvoice", request, response) # Fix this: tricky because settling sucessfully an invoice has None response. TODO return str(response) == "" # True if no response, false otherwise. @@ -244,7 +250,7 @@ class LNDNode: # Its hash is used to generate the hold invoice r_hash = hashlib.sha256(preimage).digest() - request = invoicesrpc.AddHoldInvoiceRequest( + request = invoices_pb2.AddHoldInvoiceRequest( memo=description, value=num_satoshis, hash=r_hash, @@ -253,8 +259,9 @@ class LNDNode: ), # actual expiry is padded by 50%, if tight, wrong client system clock will say invoice is expired. cltv_expiry=cltv_expiry_blocks, ) - response = cls.invoicesstub.AddHoldInvoice(request) - log("invoicesstub.AddHoldInvoice", request, response) + invoicesstub = invoices_pb2_grpc.InvoicesStub(cls.channel) + response = invoicesstub.AddHoldInvoice(request) + log("invoices_pb2_grpc.AddHoldInvoice", request, response) hold_payment["invoice"] = response.payment_request payreq_decoded = cls.decode_payreq(hold_payment["invoice"]) @@ -275,22 +282,25 @@ class LNDNode: """Checks if hold invoice is locked""" from api.models import LNPayment - request = invoicesrpc.LookupInvoiceMsg( + request = invoices_pb2.LookupInvoiceMsg( payment_hash=bytes.fromhex(lnpayment.payment_hash) ) - response = cls.invoicesstub.LookupInvoiceV2(request) - log("invoicesstub.LookupInvoiceV2", request, response) + invoicesstub = invoices_pb2_grpc.InvoicesStub(cls.channel) + response = invoicesstub.LookupInvoiceV2(request) + log("invoices_pb2_grpc.LookupInvoiceV2", request, response) # Will fail if 'unable to locate invoice'. Happens if invoice expiry # time has passed (but these are 15% padded at the moment). Should catch it # and report back that the invoice has expired (better robustness) - if response.state == lnrpc.Invoice.InvoiceState.OPEN: # OPEN + if response.state == lightning_pb2.Invoice.InvoiceState.OPEN: # OPEN pass - if response.state == lnrpc.Invoice.InvoiceState.SETTLED: # SETTLED + if response.state == lightning_pb2.Invoice.InvoiceState.SETTLED: # SETTLED pass - if response.state == lnrpc.Invoice.InvoiceState.CANCELED: # CANCELED + if response.state == lightning_pb2.Invoice.InvoiceState.CANCELED: # CANCELED pass - if response.state == lnrpc.Invoice.InvoiceState.ACCEPTED: # ACCEPTED (LOCKED) + if ( + response.state == lightning_pb2.Invoice.InvoiceState.ACCEPTED + ): # ACCEPTED (LOCKED) lnpayment.expiry_height = response.htlcs[0].expiry_height lnpayment.status = LNPayment.Status.LOCKED lnpayment.save(update_fields=["expiry_height", "status"]) @@ -316,11 +326,12 @@ class LNDNode: try: # this is similar to LNNnode.validate_hold_invoice_locked - request = invoicesrpc.LookupInvoiceMsg( + request = invoices_pb2.LookupInvoiceMsg( payment_hash=bytes.fromhex(lnpayment.payment_hash) ) - response = cls.invoicesstub.LookupInvoiceV2(request) - log("invoicesstub.LookupInvoiceV2", request, response) + invoicesstub = invoices_pb2_grpc.InvoicesStub(cls.channel) + response = invoicesstub.LookupInvoiceV2(request) + log("invoices_pb2_grpc.LookupInvoiceV2", request, response) status = lnd_response_state_to_lnpayment_status[response.state] @@ -351,8 +362,9 @@ class LNDNode: @classmethod def resetmc(cls): - request = routerrpc.ResetMissionControlRequest() - _ = cls.routerstub.ResetMissionControl(request) + routerstub = router_pb2_grpc.RouterStub(cls.channel) + request = router_pb2.ResetMissionControlRequest() + _ = routerstub.ResetMissionControl(request) return True @classmethod @@ -459,27 +471,28 @@ class LNDNode: ) ) # 200 ppm or 10 sats timeout_seconds = int(config("REWARDS_TIMEOUT_SECONDS")) - request = routerrpc.SendPaymentRequest( + request = router_pb2.SendPaymentRequest( payment_request=lnpayment.invoice, fee_limit_sat=fee_limit_sat, timeout_seconds=timeout_seconds, ) - for response in cls.routerstub.SendPaymentV2(request): - log("routerstub.SendPaymentV2", request, response) + routerstub = router_pb2_grpc.RouterStub(cls.channel) + for response in routerstub.SendPaymentV2(request): + log("router_pb2_grpc.SendPaymentV2", request, response) if ( - response.status == lnrpc.Payment.PaymentStatus.UNKNOWN + response.status == lightning_pb2.Payment.PaymentStatus.UNKNOWN ): # Status 0 'UNKNOWN' # Not sure when this status happens pass if ( - response.status == lnrpc.Payment.PaymentStatus.IN_FLIGHT + response.status == lightning_pb2.Payment.PaymentStatus.IN_FLIGHT ): # Status 1 'IN_FLIGHT' pass if ( - response.status == lnrpc.Payment.PaymentStatus.FAILED + response.status == lightning_pb2.Payment.PaymentStatus.FAILED ): # Status 3 'FAILED' """0 Payment isn't failed (yet). 1 There are more routes to try, but the payment timeout was exceeded. @@ -495,7 +508,7 @@ class LNDNode: return False, failure_reason if ( - response.status == lnrpc.Payment.PaymentStatus.SUCCEEDED + response.status == lightning_pb2.Payment.PaymentStatus.SUCCEEDED ): # STATUS 'SUCCEEDED' lnpayment.status = LNPayment.Status.SUCCED lnpayment.fee = float(response.fee_msat) / 1000 @@ -515,7 +528,7 @@ class LNDNode: hash = lnpayment.payment_hash - request = routerrpc.SendPaymentRequest( + request = router_pb2.SendPaymentRequest( payment_request=lnpayment.invoice, fee_limit_sat=fee_limit_sat, timeout_seconds=timeout_seconds, @@ -535,7 +548,7 @@ class LNDNode: order.save(update_fields=["status"]) if ( - response.status == lnrpc.Payment.PaymentStatus.UNKNOWN + response.status == lightning_pb2.Payment.PaymentStatus.UNKNOWN ): # Status 0 'UNKNOWN' # Not sure when this status happens print(f"Order: {order.id} UNKNOWN. Hash {hash}") @@ -543,7 +556,7 @@ class LNDNode: lnpayment.save(update_fields=["in_flight"]) if ( - response.status == lnrpc.Payment.PaymentStatus.IN_FLIGHT + response.status == lightning_pb2.Payment.PaymentStatus.IN_FLIGHT ): # Status 1 'IN_FLIGHT' print(f"Order: {order.id} IN_FLIGHT. Hash {hash}") @@ -556,7 +569,7 @@ class LNDNode: lnpayment.save(update_fields=["last_routing_time"]) if ( - response.status == lnrpc.Payment.PaymentStatus.FAILED + response.status == lightning_pb2.Payment.PaymentStatus.FAILED ): # Status 3 'FAILED' lnpayment.status = LNPayment.Status.FAILRO lnpayment.last_routing_time = timezone.now() @@ -599,7 +612,7 @@ class LNDNode: } if ( - response.status == lnrpc.Payment.PaymentStatus.SUCCEEDED + response.status == lightning_pb2.Payment.PaymentStatus.SUCCEEDED ): # Status 2 'SUCCEEDED' print(f"Order: {order.id} SUCCEEDED. Hash: {hash}") lnpayment.status = LNPayment.Status.SUCCED @@ -621,8 +634,9 @@ class LNDNode: return results try: - for response in cls.routerstub.SendPaymentV2(request): - log("routerstub.SendPaymentV2", request, response) + routerstub = router_pb2_grpc.RouterStub(cls.channel) + for response in routerstub.SendPaymentV2(request): + log("router_pb2_grpc.SendPaymentV2", request, response) handle_response(response) except Exception as e: @@ -630,12 +644,13 @@ class LNDNode: print(f"Order: {order.id}. INVOICE EXPIRED. Hash: {hash}") # An expired invoice can already be in-flight. Check. try: - request = routerrpc.TrackPaymentRequest( + request = router_pb2.TrackPaymentRequest( payment_hash=bytes.fromhex(hash) ) - for response in cls.routerstub.TrackPaymentV2(request): - log("routerstub.TrackPaymentV2", request, response) + routerstub = router_pb2_grpc.RouterStub(cls.channel) + for response in routerstub.TrackPaymentV2(request): + log("router_pb2_grpc.TrackPaymentV2", request, response) handle_response(response, was_in_transit=True) except Exception as e: @@ -670,23 +685,25 @@ class LNDNode: elif "payment is in transition" in str(e): print(f"Order: {order.id} ALREADY IN TRANSITION. Hash: {hash}.") - request = routerrpc.TrackPaymentRequest( + request = router_pb2.TrackPaymentRequest( payment_hash=bytes.fromhex(hash) ) - for response in cls.routerstub.TrackPaymentV2(request): - log("routerstub.TrackPaymentV2", request, response) + routerstub = router_pb2_grpc.RouterStub(cls.channel) + for response in routerstub.TrackPaymentV2(request): + log("router_pb2_grpc.TrackPaymentV2", request, response) handle_response(response, was_in_transit=True) elif "invoice is already paid" in str(e): print(f"Order: {order.id} ALREADY PAID. Hash: {hash}.") - request = routerrpc.TrackPaymentRequest( + request = router_pb2.TrackPaymentRequest( payment_hash=bytes.fromhex(hash) ) - for response in cls.routerstub.TrackPaymentV2(request): - log("routerstub.TrackPaymentV2", request, response) + routerstub = router_pb2_grpc.RouterStub(cls.channel) + for response in routerstub.TrackPaymentV2(request): + log("router_pb2_grpc.TrackPaymentV2", request, response) handle_response(response) else: @@ -721,26 +738,28 @@ class LNDNode: (34349334, bytes.fromhex(msg.encode("utf-8").hex())) ) if sign: - self_pubkey = cls.lightningstub.GetInfo( - lnrpc.GetInfoRequest() + lightningstub = lightning_pb2_grpc.LightningStub(cls.channel) + self_pubkey = lightningstub.GetInfo( + lightning_pb2.GetInfoRequest() ).identity_pubkey timestamp = struct.pack(">i", int(time.time())) - signature = cls.signerstub.SignMessage( - signerrpc.SignMessageReq( + signerstub = signer_pb2_grpc.SignerStub(cls.channel) + signature = signerstub.SignMessage( + signer_pb2.SignMessageReq( msg=( bytes.fromhex(self_pubkey) + bytes.fromhex(target_pubkey) + timestamp + bytes.fromhex(msg.encode("utf-8").hex()) ), - key_loc=signerrpc.KeyLocator(key_family=6, key_index=0), + key_loc=signer_pb2.KeyLocator(key_family=6, key_index=0), ) ).signature custom_records.append((34349337, signature)) custom_records.append((34349339, bytes.fromhex(self_pubkey))) custom_records.append((34349343, timestamp)) - request = routerrpc.SendPaymentRequest( + request = router_pb2.SendPaymentRequest( dest=bytes.fromhex(target_pubkey), dest_custom_records=custom_records, fee_limit_sat=routing_budget_sats, @@ -749,17 +768,18 @@ class LNDNode: payment_hash=bytes.fromhex(hashed_secret), allow_self_payment=ALLOW_SELF_KEYSEND, ) - for response in cls.routerstub.SendPaymentV2(request): - log("routerstub.SendPaymentV2", request, response) - if response.status == lnrpc.Payment.PaymentStatus.IN_FLIGHT: + routerstub = router_pb2_grpc.RouterStub(cls.channel) + for response in routerstub.SendPaymentV2(request): + log("router_pb2_grpc.SendPaymentV2", request, response) + if response.status == lightning_pb2.Payment.PaymentStatus.IN_FLIGHT: keysend_payment["status"] = LNPayment.Status.FLIGHT - if response.status == lnrpc.Payment.PaymentStatus.SUCCEEDED: + if response.status == lightning_pb2.Payment.PaymentStatus.SUCCEEDED: keysend_payment["fee"] = float(response.fee_msat) / 1000 keysend_payment["status"] = LNPayment.Status.SUCCED - if response.status == lnrpc.Payment.PaymentStatus.FAILED: + if response.status == lightning_pb2.Payment.PaymentStatus.FAILED: keysend_payment["status"] = LNPayment.Status.FAILRO keysend_payment["failure_reason"] = response.failure_reason - if response.status == lnrpc.Payment.PaymentStatus.UNKNOWN: + if response.status == lightning_pb2.Payment.PaymentStatus.UNKNOWN: print("Unknown Error") except Exception as e: if "self-payments not allowed" in str(e): @@ -772,10 +792,13 @@ class LNDNode: @classmethod def double_check_htlc_is_settled(cls, payment_hash): """Just as it sounds. Better safe than sorry!""" - request = invoicesrpc.LookupInvoiceMsg(payment_hash=bytes.fromhex(payment_hash)) - response = cls.invoicesstub.LookupInvoiceV2(request) - log("invoicesstub.LookupInvoiceV2", request, response) + request = invoices_pb2.LookupInvoiceMsg( + payment_hash=bytes.fromhex(payment_hash) + ) + invoicesstub = invoices_pb2_grpc.InvoicesStub(cls.channel) + response = invoicesstub.LookupInvoiceV2(request) + log("invoices_pb2_grpc.LookupInvoiceV2", request, response) return ( - response.state == lnrpc.Invoice.InvoiceState.SETTLED + response.state == lightning_pb2.Invoice.InvoiceState.SETTLED ) # LND states: 0 OPEN, 1 SETTLED, 3 ACCEPTED, GRPC_ERROR status 5 when CANCELED/returned diff --git a/api/tests/test_utils.py b/api/tests/test_utils.py index b1d9d913..5280b4bf 100644 --- a/api/tests/test_utils.py +++ b/api/tests/test_utils.py @@ -1,7 +1,6 @@ from unittest.mock import MagicMock, Mock, mock_open, patch import numpy as np -from decouple import config from django.test import TestCase from api.models import Order @@ -22,6 +21,8 @@ from api.utils import ( verify_signed_message, weighted_median, ) +from tests.mocks.cln import MockNodeStub +from tests.mocks.lnd import MockVersionerStub class TestUtils(TestCase): @@ -95,25 +96,19 @@ class TestUtils(TestCase): mock_response_blockchain.json.assert_called_once() mock_response_yadio.json.assert_called_once() - LNVENDOR = config("LNVENDOR", cast=str, default="LND") + @patch("api.lightning.lnd.verrpc_pb2_grpc.VersionerStub", MockVersionerStub) + def test_get_lnd_version(self): + version = get_lnd_version() + self.assertEqual(version, "v0.17.0-beta") - if LNVENDOR == "LND": + @patch("api.lightning.cln.node_pb2_grpc.NodeStub", MockNodeStub) + def test_get_cln_version(self): + version = get_cln_version() + self.assertEqual(version, "v23.08") - @patch("api.lightning.lnd.LNDNode.get_version") - def test_get_lnd_version(self, mock_get_version): - mock_get_version.return_value = "v0.17.0-beta" - version = get_lnd_version() - self.assertEqual(version, "v0.17.0-beta") - - elif LNVENDOR == "CLN": - - @patch("api.lightning.cln.CLNNode.get_version") - def test_get_cln_version(self, mock_get_version): - mock_get_version.return_value = "v23.08.1" - version = get_cln_version() - self.assertEqual(version, "v23.08.1") - - @patch("builtins.open", new_callable=mock_open, read_data="test_commit_hash") + @patch( + "builtins.open", new_callable=mock_open, read_data="00000000000000000000 dev" + ) def test_get_robosats_commit(self, mock_file): # Call the get_robosats_commit function commit_hash = get_robosats_commit() diff --git a/tests/mocks/cln.py b/tests/mocks/cln.py new file mode 100644 index 00000000..4aaad91c --- /dev/null +++ b/tests/mocks/cln.py @@ -0,0 +1,58 @@ +from unittest.mock import MagicMock + +# Mock up of CLN gRPC responses + + +class MockNodeStub: + def __init__(channel, other): + pass + + def Getinfo(self, request): + response = MagicMock() + response.id = b"\002\202Y\300\330\2564\005\357\263\221;\300\266\326F\010}\370/\252&!v\221iM\251\241V\241\034\034" + response.alias = "ROBOSATS-TEST-CLN-v23.08" + response.color = "\002\202Y" + response.num_peers = 1 + response.num_active_channels = 1 + response.version = "v23.08" + response.lightning_dir = "/root/.lightning/testnet" + response.our_features.init = b"\010\240\000\n\002i\242" + response.our_features.node = b"\210\240\000\n\002i\242" + response.our_features.invoice = b"\002\000\000\002\002A\000" + response.blockheight = 2100000 + response.network = "testnet" + response.fees_collected_msat.msat: 21000 + response.address.item_type = "TORV3" + response.address.port = 19735 + response.address.address = ( + "21000000gwfmvmig5xlzc2yzm6uzisode5vhs7kyegwstu5hflhx5fid.onion" + ) + response.binding.item_type = "IPV6" + response.binding.address = "127.0.0.1" + response.binding.port = 9736 + return response + + +class MockHoldStub: + def __init__(channel, other): + pass + + def HoldInvoiceLookup(self, request): + response = MagicMock() + return response + + def HoldInvoice(self, request): + response = MagicMock() + return response + + def HoldInvoiceSettle(self, request): + response = MagicMock() + return response + + def HoldInvoiceCancel(self, request): + response = MagicMock() + return response + + def DecodeBolt11(self, request): + response = MagicMock() + return response diff --git a/tests/mocks/lnd.py b/tests/mocks/lnd.py index 9286da9f..53f622b5 100644 --- a/tests/mocks/lnd.py +++ b/tests/mocks/lnd.py @@ -17,20 +17,30 @@ class MockLightningStub: def DecodePayReq(self, request): response = MagicMock() - if request.pay_req == "lntb17314....x": - response.destination = "00000000" - response.payment_hash = "00000000" + if ( + request.pay_req + == "lntb17310n1pj552mdpp50p2utzh7mpsf3uq7u7cws4a96tj3kyq54hchdkpw8zecamx9klrqd2j2pshjmt9de6zqun9vejhyetwvdjn5gphxs6nsvfe893z6wphvfsj6dryvymj6wp5xvuz6wp5xcukvdec8yukgcf49cs9g6rfwvs8qcted4jkuapq2ay5cnpqgefy2326g5syjn3qt984253q2aq5cnz92skzqcmgv43kkgr0dcs9ymmzdafkzarnyp5kvgr5dpjjqmr0vd4jqampwvs8xatrvdjhxumxw4kzugzfwss8w6tvdssxyefqw4hxcmmrddjkggpgveskjmpfyp6kumr9wdejq7t0w5sxx6r9v96zqmmjyp3kzmnrv4kzqatwd9kxzar9wfskcmre9ccqz52xqzwzsp5hkzegrhn6kegr33z8qfxtcudaklugygdrakgyy7va0wt2qs7drfq9qyyssqc6rztchzl4m7mlulrhlcajszcl9fan8908k9n5x7gmz8g8d6ht5pj4l8r0dushq6j5s8x7yv9a5klz0kfxwy8v6ze6adyrrp4wu0q0sq3t604x" + ): + response.destination = ( + "033b58d7681fe5dd2fb21fd741996cda5449616f77317dd1156b80128d6a71b807" + ) + response.payment_hash = ( + "7855c58afed86098f01ee7b0e857a5d2e51b1014adf176d82e38b38eecc5b7c6" + ) response.num_satoshis = 1731 response.timestamp = 1699359597 response.expiry = 450 - response.description = "Payment reference: xxxxxxxxxxxxxxxxxxxxxxx. This payment WILL FREEZE IN YOUR WALLET, check on RoboSats if the lock was successful. It will be unlocked (fail) unless you cheat or cancel unilaterally." + response.description = "Payment reference: 7458199b-87ba-4da7-8438-8469f7899da5. This payment WILL FREEZE IN YOUR WALLET, check on RoboSats if the lock was successful. It will be unlocked (fail) unless you cheat or cancel unilaterally." response.cltv_expiry = 650 - response.payment_addr = "\275\205\224\002\036h\322" + response.payment_addr = '\275\205\224\016\363\325\262\201\306"8\022e\343\215\355\277\304\021\r\037l\202\023\314\353\334\265\002\036h\322' response.num_msat = 1731000 def CancelInvoice(self, request): response = MagicMock() - if request == b"xU\305\212\306": + if ( + request + == b"xU\305\212\376\330`\230\360\036\347\260\350W\245\322\345\033\020\024\255\361v\330.8\263\216\354\305\267\306" + ): response = {} return response @@ -69,9 +79,9 @@ class MockInvoicesStub: def AddHoldInvoice(self, request): response = MagicMock() if request.value == 1731: - response.payment_request = "lntb17314....x" + response.payment_request = "lntb17310n1pj552mdpp50p2utzh7mpsf3uq7u7cws4a96tj3kyq54hchdkpw8zecamx9klrqd2j2pshjmt9de6zqun9vejhyetwvdjn5gphxs6nsvfe893z6wphvfsj6dryvymj6wp5xvuz6wp5xcukvdec8yukgcf49cs9g6rfwvs8qcted4jkuapq2ay5cnpqgefy2326g5syjn3qt984253q2aq5cnz92skzqcmgv43kkgr0dcs9ymmzdafkzarnyp5kvgr5dpjjqmr0vd4jqampwvs8xatrvdjhxumxw4kzugzfwss8w6tvdssxyefqw4hxcmmrddjkggpgveskjmpfyp6kumr9wdejq7t0w5sxx6r9v96zqmmjyp3kzmnrv4kzqatwd9kxzar9wfskcmre9ccqz52xqzwzsp5hkzegrhn6kegr33z8qfxtcudaklugygdrakgyy7va0wt2qs7drfq9qyyssqc6rztchzl4m7mlulrhlcajszcl9fan8908k9n5x7gmz8g8d6ht5pj4l8r0dushq6j5s8x7yv9a5klz0kfxwy8v6ze6adyrrp4wu0q0sq3t604x" response.add_index = 1 - response.payment_addr = b"\275\205\322" + response.payment_addr = b'\275\205\224\016\363\325\262\201\306"8\022e\343\215\355\277\304\021\r\037l\202\023\314\353\334\265\002\036h\322' def CancelInvoice(self, request): response = MagicMock() @@ -107,6 +117,9 @@ class MockSignerStub: class MockVersionerStub: + def __init__(channel, other): + pass + def GetVersion(self, request): response = MagicMock() response.commit = "v0.17.0-beta" diff --git a/tests/test_coordinator_info.py b/tests/test_coordinator_info.py new file mode 100644 index 00000000..2abf984a --- /dev/null +++ b/tests/test_coordinator_info.py @@ -0,0 +1,61 @@ +import json +from unittest.mock import patch + +from decouple import config +from django.conf import settings +from django.contrib.auth.models import User +from django.test import Client, TestCase + +from tests.mocks.cln import MockNodeStub +from tests.mocks.lnd import MockVersionerStub + +FEE = config("FEE", cast=float, default=0.2) +NODE_ID = config("NODE_ID", cast=str, default="033b58d7......") +MAKER_FEE = FEE * config("FEE_SPLIT", cast=float, default=0.125) +TAKER_FEE = FEE * (1 - config("FEE_SPLIT", cast=float, default=0.125)) +BOND_SIZE = config("BOND_SIZE", cast=float, default=3) +NOTICE_SEVERITY = config("NOTICE_SEVERITY", cast=str, default="none") +NOTICE_MESSAGE = config("NOTICE_MESSAGE", cast=str, default="") + + +class CoordinatorInfoTest(TestCase): + su_pass = "12345678" + su_name = config("ESCROW_USERNAME", cast=str, default="admin") + + def setUp(self): + """ + Create a superuser. The superuser is the escrow party. + """ + self.client = Client() + User.objects.create_superuser(self.su_name, "super@user.com", self.su_pass) + + @patch("api.lightning.cln.node_pb2_grpc.NodeStub", MockNodeStub) + @patch("api.lightning.lnd.verrpc_pb2_grpc.VersionerStub", MockVersionerStub) + def test_info(self): + path = "/api/info/" + + response = self.client.get(path) + data = json.loads(response.content.decode()) + + self.assertEqual(response.status_code, 200) + self.assertEqual(data["num_public_buy_orders"], 0) + self.assertEqual(data["num_public_sell_orders"], 0) + self.assertEqual(data["book_liquidity"], 0) + self.assertEqual(data["active_robots_today"], 0) + self.assertEqual(data["last_day_nonkyc_btc_premium"], 0) + self.assertEqual(data["last_day_volume"], 0) + self.assertEqual(data["lifetime_volume"], 0) + self.assertEqual(data["lnd_version"], "v0.17.0-beta") + self.assertEqual(data["cln_version"], "v23.08") + self.assertEqual( + data["robosats_running_commit_hash"], "00000000000000000000 dev" + ) + self.assertEqual(data["version"], settings.VERSION) + self.assertEqual(data["node_id"], NODE_ID) + self.assertEqual(data["network"], "testnet") + self.assertAlmostEqual(data["maker_fee"], MAKER_FEE) + self.assertAlmostEqual(data["taker_fee"], TAKER_FEE) + self.assertAlmostEqual(data["bond_size"], BOND_SIZE) + self.assertEqual(data["notice_severity"], NOTICE_SEVERITY) + self.assertEqual(data["notice_message"], NOTICE_MESSAGE) + self.assertEqual(data["current_swap_fee_rate"], 0) diff --git a/tests/test_trade_pipeline.py b/tests/test_trade_pipeline.py index 6aa11723..d3b6996e 100644 --- a/tests/test_trade_pipeline.py +++ b/tests/test_trade_pipeline.py @@ -9,6 +9,7 @@ from django.test import Client, TestCase from api.models import Currency, Order from api.tasks import cache_market +from tests.mocks.cln import MockHoldStub, MockNodeStub from tests.mocks.lnd import ( MockInvoicesStub, MockLightningStub, @@ -225,11 +226,13 @@ class TradeTest(TestCase): ) self.assertIsNone(data["taker"], "New order's taker is not null") - @patch("api.lightning.lightning_pb2_grpc.LightningStub", MockLightningStub) - @patch("api.lightning.invoices_pb2_grpc.InvoicesStub", MockInvoicesStub) - @patch("api.lightning.router_pb2_grpc.RouterStub", MockRouterStub) - @patch("api.lightning.signer_pb2_grpc.SignerStub", MockSignerStub) - @patch("api.lightning.verrpc_pb2_grpc.VersionerStub", MockVersionerStub) + @patch("api.lightning.cln.node_pb2_grpc.NodeStub", MockNodeStub) + @patch("api.lightning.cln.hold_pb2_grpc.HoldStub", MockHoldStub) + @patch("api.lightning.lnd.verrpc_pb2_grpc.VersionerStub", MockVersionerStub) + @patch("api.lightning.lnd.lightning_pb2_grpc.LightningStub", MockLightningStub) + @patch("api.lightning.lnd.invoices_pb2_grpc.InvoicesStub", MockInvoicesStub) + @patch("api.lightning.lnd.router_pb2_grpc.RouterStub", MockRouterStub) + @patch("api.lightning.lnd.signer_pb2_grpc.SignerStub", MockSignerStub) def test_maker_bond_locked(self): self.test_create_order( robot_index=1,