add synchronous, single-connection interface

This commit is contained in:
Matthew Yacavone 2021-08-26 16:51:11 -04:00
parent 0eea2cb7e9
commit cff0c1600a
3 changed files with 292 additions and 55 deletions

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import base64
from enum import Enum
from dataclasses import dataclass
from typing import Any, List, Optional, Union
from typing import Any, Tuple, List, Dict, Optional, Union
from typing_extensions import Literal
import argo_client.interaction as argo
@ -20,7 +20,9 @@ def extend_hex(string : str) -> str:
else:
return string
def from_cryptol_arg(val : Any) -> Any:
CryptolPython = Union[bool, int, BV, Tuple, List, Dict, OpaqueValue]
def from_cryptol_arg(val : Any) -> CryptolPython:
"""Return the canonical Python value for a Cryptol JSON value."""
if isinstance(val, bool):
return val
@ -80,27 +82,36 @@ class CryptolExtendSearchPath(argo.Command):
return res
class CryptolEvalExpr(argo.Command):
class CryptolEvalExprRaw(argo.Command):
def __init__(self, connection : HasProtocolState, expr : Any) -> None:
super(CryptolEvalExpr, self).__init__(
super(CryptolEvalExprRaw, self).__init__(
'evaluate expression',
{'expression': expr},
connection
)
def process_result(self, res : Any) -> Any:
return from_cryptol_arg(res['value'])
return res['value']
class CryptolCall(argo.Command):
class CryptolEvalExpr(CryptolEvalExprRaw):
def process_result(self, res : Any) -> Any:
return from_cryptol_arg(super(CryptolEvalExpr, self).process_result(res))
class CryptolCallRaw(argo.Command):
def __init__(self, connection : HasProtocolState, fun : str, args : List[Any]) -> None:
super(CryptolCall, self).__init__(
super(CryptolCallRaw, self).__init__(
'call',
{'function': fun, 'arguments': args},
connection
)
def process_result(self, res : Any) -> Any:
return from_cryptol_arg(res['value'])
return res['value']
class CryptolCall(CryptolCallRaw):
def process_result(self, res : Any) -> Any:
return from_cryptol_arg(super(CryptolCall, self).process_result(res))
@dataclass
@ -112,42 +123,49 @@ class CheckReport:
tests_run: int
tests_possible: Optional[int]
class CryptolCheck(argo.Command):
def to_check_report(res : Any) -> CheckReport:
if res['result'] == 'pass':
return CheckReport(
success=True,
args=[],
error_msg = None,
tests_run=res['tests run'],
tests_possible=res['tests possible'])
elif res['result'] == 'fail':
return CheckReport(
success=False,
args=[from_cryptol_arg(arg['expr']) for arg in res['arguments']],
error_msg = None,
tests_run=res['tests run'],
tests_possible=res['tests possible'])
elif res['result'] == 'error':
return CheckReport(
success=False,
args=[from_cryptol_arg(arg['expr']) for arg in res['arguments']],
error_msg = res['error message'],
tests_run=res['tests run'],
tests_possible=res['tests possible'])
else:
raise ValueError("Unknown check result " + str(res))
class CryptolCheckRaw(argo.Command):
def __init__(self, connection : HasProtocolState, expr : Any, num_tests : Union[Literal['all'],int, None]) -> None:
if num_tests:
args = {'expression': expr, 'number of tests':num_tests}
else:
args = {'expression': expr}
super(CryptolCheck, self).__init__(
super(CryptolCheckRaw, self).__init__(
'check',
args,
connection
)
def process_result(self, res : Any) -> Any:
return res
class CryptolCheck(CryptolCheckRaw):
def process_result(self, res : Any) -> 'CheckReport':
if res['result'] == 'pass':
return CheckReport(
success=True,
args=[],
error_msg = None,
tests_run=res['tests run'],
tests_possible=res['tests possible'])
elif res['result'] == 'fail':
return CheckReport(
success=False,
args=[from_cryptol_arg(arg['expr']) for arg in res['arguments']],
error_msg = None,
tests_run=res['tests run'],
tests_possible=res['tests possible'])
elif res['result'] == 'error':
return CheckReport(
success=False,
args=[from_cryptol_arg(arg['expr']) for arg in res['arguments']],
error_msg = res['error message'],
tests_run=res['tests run'],
tests_possible=res['tests possible'])
else:
raise ValueError("Unknown check result " + str(res))
return to_check_report(super(CryptolCheck, self).process_result(res))
class CryptolCheckType(argo.Command):
@ -161,14 +179,37 @@ class CryptolCheckType(argo.Command):
def process_result(self, res : Any) -> Any:
return res['type schema']
class SmtQueryType(str, Enum):
PROVE = 'prove'
SAFE = 'safe'
SAT = 'sat'
class CryptolProveSat(argo.Command):
SmtQueryResult = Union[bool, List[CryptolPython], OfflineSmtQuery]
def to_smt_query_result(qtype : SmtQueryType, res : Any) -> SmtQueryResult:
if res['result'] == 'unsatisfiable':
if qtype == SmtQueryType.SAT:
return False
elif qtype == SmtQueryType.PROVE or qtype == SmtQueryType.SAFE:
return True
else:
raise ValueError("Unknown SMT query type: " + qtype)
elif res['result'] == 'invalid':
return [from_cryptol_arg(arg['expr'])
for arg in res['counterexample']]
elif res['result'] == 'satisfied':
return [from_cryptol_arg(arg['expr'])
for m in res['models']
for arg in m]
elif res['result'] == 'offline':
return OfflineSmtQuery(content=res['query'])
else:
raise ValueError("Unknown SMT result: " + str(res))
class CryptolProveSatRaw(argo.Command):
def __init__(self, connection : HasProtocolState, qtype : SmtQueryType, expr : Any, solver : Solver, count : Optional[int]) -> None:
super(CryptolProveSat, self).__init__(
super(CryptolProveSatRaw, self).__init__(
'prove or satisfy',
{'query type': qtype,
'expression': expr,
@ -180,24 +221,11 @@ class CryptolProveSat(argo.Command):
self.qtype = qtype
def process_result(self, res : Any) -> Any:
if res['result'] == 'unsatisfiable':
if self.qtype == SmtQueryType.SAT:
return False
elif self.qtype == SmtQueryType.PROVE or self.qtype == SmtQueryType.SAFE:
return True
else:
raise ValueError("Unknown SMT query type: " + self.qtype)
elif res['result'] == 'invalid':
return [from_cryptol_arg(arg['expr'])
for arg in res['counterexample']]
elif res['result'] == 'satisfied':
return [from_cryptol_arg(arg['expr'])
for m in res['models']
for arg in m]
elif res['result'] == 'offline':
return OfflineSmtQuery(content=res['query'])
else:
raise ValueError("Unknown SMT result: " + str(res))
return res
class CryptolProveSat(CryptolProveSatRaw):
def process_result(self, res : Any) -> Any:
return to_smt_query_result(self.qtype, super(CryptolProveSat, self).process_result(res))
class CryptolProve(CryptolProveSat):
def __init__(self, connection : HasProtocolState, expr : Any, solver : Solver) -> None:
@ -211,6 +239,7 @@ class CryptolSafe(CryptolProveSat):
def __init__(self, connection : HasProtocolState, expr : Any, solver : Solver) -> None:
super(CryptolSafe, self).__init__(connection, SmtQueryType.SAFE, expr, solver, 1)
class CryptolNames(argo.Command):
def __init__(self, connection : HasProtocolState) -> None:
super(CryptolNames, self).__init__('visible names', {}, connection)
@ -218,6 +247,7 @@ class CryptolNames(argo.Command):
def process_result(self, res : Any) -> Any:
return res
class CryptolFocusedModule(argo.Command):
def __init__(self, connection : HasProtocolState) -> None:
super(CryptolFocusedModule, self).__init__(
@ -229,6 +259,7 @@ class CryptolFocusedModule(argo.Command):
def process_result(self, res : Any) -> Any:
return res
class CryptolReset(argo.Notification):
def __init__(self, connection : HasProtocolState) -> None:
super(CryptolReset, self).__init__(
@ -237,6 +268,7 @@ class CryptolReset(argo.Notification):
connection
)
class CryptolResetServer(argo.Notification):
def __init__(self, connection : HasProtocolState) -> None:
super(CryptolResetServer, self).__init__(

View File

@ -175,6 +175,13 @@ class CryptolConnection:
self.most_recent_result = CryptolLoadModule(self, module_name)
return self.most_recent_result
def eval_raw(self, expression : Any) -> argo.Command:
"""Like the member method ``eval``, but does not call
``from_cryptol_arg`` on the ``.result()``.
"""
self.most_recent_result = CryptolEvalExprRaw(self, expression)
return self.most_recent_result
def eval(self, expression : Any) -> argo.Command:
"""Evaluate a Cryptol expression, represented according to
:ref:`cryptol-json-expression`, with Python datatypes standing
@ -189,15 +196,33 @@ class CryptolConnection:
return self.eval(expression)
def extend_search_path(self, *dir : str) -> argo.Command:
"""Load a Cryptol module, like ``:module`` at the Cryptol REPL."""
"""Extend the search path for loading Cryptol modules."""
self.most_recent_result = CryptolExtendSearchPath(self, list(dir))
return self.most_recent_result
def call_raw(self, fun : str, *args : List[Any]) -> argo.Command:
"""Like the member method ``call``, but does not call
``from_cryptol_arg`` on the ``.result()``.
"""
encoded_args = [cryptoltypes.CryptolType().from_python(a) for a in args]
self.most_recent_result = CryptolCallRaw(self, fun, encoded_args)
return self.most_recent_result
def call(self, fun : str, *args : List[Any]) -> argo.Command:
encoded_args = [cryptoltypes.CryptolType().from_python(a) for a in args]
self.most_recent_result = CryptolCall(self, fun, encoded_args)
return self.most_recent_result
def check_raw(self, expr : Any, *, num_tests : Union[Literal['all'], int, None] = None) -> argo.Command:
"""Like the member method ``check``, but does not call
`to_check_report` on the ``.result()``.
"""
if num_tests == "all" or isinstance(num_tests, int) or num_tests is None:
self.most_recent_result = CryptolCheckRaw(self, expr, num_tests)
return self.most_recent_result
else:
raise ValueError('``num_tests`` must be an integer, ``None``, or the string literall ``"all"``')
def check(self, expr : Any, *, num_tests : Union[Literal['all'], int, None] = None) -> argo.Command:
"""Tests the validity of a Cryptol expression with random inputs. The expression must be a function with
return type ``Bit``.
@ -212,7 +237,6 @@ class CryptolConnection:
else:
raise ValueError('``num_tests`` must be an integer, ``None``, or the string literall ``"all"``')
def check_type(self, code : Any) -> argo.Command:
"""Check the type of a Cryptol expression, represented according to
:ref:`cryptol-json-expression`, with Python datatypes standing for
@ -221,6 +245,14 @@ class CryptolConnection:
self.most_recent_result = CryptolCheckType(self, code)
return self.most_recent_result
def prove_sat_raw(self, expr : Any, qtype : SmtQueryType, solver : solver.Solver, count : Optional[int]) -> argo.Command:
"""A generalization of the member methods ``sat``, ``prove``, and
``check`, but additionally does not call `to_smt_query_result` on the
``.result()``.
"""
self.most_recent_result = CryptolProveSatRaw(self, expr, qtype, solver, count)
return self.most_recent_result
def sat(self, expr : Any, solver : solver.Solver = solver.Z3, count : int = 1) -> argo.Command:
"""Check the satisfiability of a Cryptol expression, represented according to
:ref:`cryptol-json-expression`, with Python datatypes standing for

View File

@ -0,0 +1,173 @@
"""A synchronous, single-connection interface for the Cryptol bindings"""
from __future__ import annotations
from typing import cast, Any, Optional, Union, List, Dict
from typing_extensions import Literal
from . import solver
from . import connection
from . import cryptoltypes
from .commands import *
__designated_connection = None # type: Optional[connection.CryptolConnection]
def __get_designated_connection() -> connection.CryptolConnection:
global __designated_connection
if __designated_connection is None:
raise ValueError("There is no current synchronous connection (see `connect_sync`).")
else:
return __designated_connection
def __set_designated_connection(conn: connection.CryptolConnection) -> None:
global __designated_connection
if __designated_connection is None:
__designated_connection = conn
else:
raise ValueError("There is already a current synchronous connection."
" Did you call `connect_sync()` more than once?")
def sync_connected() -> bool:
"""Return true iff there is a current synchronous connection."""
global __designated_connection
return __designated_connection is not None
def connect_sync(command : Optional[str]=None,
*,
cryptol_path : Optional[str] = None,
url : Optional[str] = None,
reset_server : bool = False) -> None:
"""
Connect to a (possibly new) Cryptol server process synchronously.
:param command: A command to launch a new Cryptol server in socket mode (if provided).
:param cryptol_path: A replacement for the contents of
the ``CRYPTOLPATH`` environment variable (if provided).
:param url: A URL at which to connect to an already running Cryptol
HTTP server.
:param reset_server: If ``True``, the server that is connected to will be
reset. (This ensures any states from previous server usages have been
cleared.)
If no ``command`` or ``url`` parameters are provided, the following are attempted in order:
1. If the environment variable ``CRYPTOL_SERVER`` is set and referse to an executable,
it is assumed to be a Cryptol server and will be used for a new ``socket`` connection.
2. If the environment variable ``CRYPTOL_SERVER_URL`` is set, it is assumed to be
the URL for a running Cryptol server in ``http`` mode and will be connected to.
3. If an executable ``cryptol-remote-api`` is available on the ``PATH``
it is assumed to be a Cryptol server and will be used for a new ``socket`` connection.
"""
global __designated_connection
# Set the designated connection by starting a server process
if __designated_connection is None:
__designated_connection = connection.connect(
command=command,
cryptol_path=cryptol_path,
url=url,
reset_server=reset_server)
elif reset_server:
__designated_connection.reset_server()
else:
raise ValueError("There is already a current synchronous connection."
" Did you call `connect_sync()` more than once?")
def connect_sync_stdio(command : str, cryptol_path : Optional[str] = None) -> None:
"""Start a new synchronous connection to a new Cryptol server process.
:param command: The command to launch the Cryptol server.
:param cryptol_path: An optional replacement for the contents of
the ``CRYPTOLPATH`` environment variable.
"""
__set_designated_connection(connection.connect_stdio(
command=command,
cryptol_path=cryptol_path))
def load_file(filename : str) -> None:
"""Load a filename as a Cryptol module, like ``:load`` at the Cryptol
REPL.
"""
__get_designated_connection().load_file(filename).result()
def load_module(module_name : str) -> None:
"""Load a Cryptol module, like ``:module`` at the Cryptol REPL."""
__get_designated_connection().load_module(module_name).result()
def evalCry(expression : Any) -> CryptolPython:
"""Evaluate a Cryptol expression, represented according to
:ref:`cryptol-json-expression`, with Python datatypes standing
for their JSON equivalents.
"""
return from_cryptol_arg(__get_designated_connection().eval_raw(expression).result())
def evaluate_expression(expression : Any) -> CryptolPython:
"""Synonym for ``evalCry``. """
return evalCry(expression)
def extend_search_path(*dir : str) -> None:
"""Extend the search path for loading Cryptol modules."""
__get_designated_connection().extend_search_path(*dir).result()
def call(fun : str, *args : List[Any]) -> CryptolPython:
return from_cryptol_arg(__get_designated_connection().call_raw(fun, *args).result())
def check(expr : Any, *, num_tests : Union[Literal['all'], int, None] = None) -> CheckReport:
"""Tests the validity of a Cryptol expression with random inputs. The expression must be a function with
return type ``Bit``.
If ``num_tests`` is ``"all"`` then the expression is tested exhaustively (i.e., against all possible inputs).
If ``num_tests`` is omitted, Cryptol defaults to running 100 tests.
"""
return to_check_report(__get_designated_connection().check_raw(expr, num_tests=num_tests).result())
def check_type(code : Any) -> cryptoltypes.CryptolType:
"""Check the type of a Cryptol expression, represented according to
:ref:`cryptol-json-expression`, with Python datatypes standing for
their JSON equivalents.
"""
return cryptoltypes.to_type(__get_designated_connection().check_type(code).result()['type'])
def sat(expr : Any, solver : solver.Solver = solver.Z3, count : int = 1) -> SmtQueryResult:
"""Check the satisfiability of a Cryptol expression, represented according to
:ref:`cryptol-json-expression`, with Python datatypes standing for
their JSON equivalents. Use the solver named `solver`, and return up to
`count` solutions.
"""
return to_smt_query_result(SmtQueryType.SAT, __get_designated_connection().prove_sat_raw(expr, SmtQueryType.SAT, solver, count).result())
def prove(expr : Any, solver : solver.Solver = solver.Z3) -> SmtQueryResult:
"""Check the validity of a Cryptol expression, represented according to
:ref:`cryptol-json-expression`, with Python datatypes standing for
their JSON equivalents. Use the solver named `solver`.
"""
return to_smt_query_result(SmtQueryType.PROVE, __get_designated_connection().prove_sat_raw(expr, SmtQueryType.PROVE, solver, 1).result())
def safe(expr : Any, solver : solver.Solver = solver.Z3) -> SmtQueryResult:
"""Check via an external SMT solver that the given term is safe for all inputs,
which means it cannot encounter a run-time error.
"""
return to_smt_query_result(SmtQueryType.SAFE, __get_designated_connection().prove_sat_raw(expr, SmtQueryType.SAFE, solver, 1).result())
def names() -> List[Dict[str,Any]]:
"""Discover the list of names currently in scope in the current context."""
res = __get_designated_connection().names().result()
if isinstance(res, list) and all(isinstance(d, dict) and all(isinstance(k, str) for k in d.keys()) for d in res):
return res
else:
raise ValueError("Panic! Result of `names()` is malformed: " + str(res))
def focused_module() -> Dict[str,Any]:
"""Return the name of the currently-focused module."""
res = __get_designated_connection().focused_module().result()
if isinstance(res, dict) and all(isinstance(k, str) for k in res.keys()):
return res
else:
raise ValueError("Panic! Result of `focused_module()` is malformed: " + str(res))
def reset() -> None:
"""Resets the connection, causing its unique state on the server to be freed (if applicable).
After a reset a connection may be treated as if it were a fresh connection with the server if desired."""
__get_designated_connection().reset()
def reset_server() -> None:
"""Resets the Cryptol server, clearing all states."""
__get_designated_connection().reset_server()