Make monadic blocks lazy by deferring execution of continuations with free vars

Nicolas Abril 2024-05-30 21:07:12 +02:00
parent 3c9a532df1
commit 6182ac74fa
18 changed files with 132 additions and 68 deletions

View File

@ -5,7 +5,7 @@ license = "Apache-2.0"
version = "0.2.28" version = "0.2.28"
edition = "2021" edition = "2021"
rust-version = "1.74" rust-version = "1.74"
exclude = ["tests/snapshots/"] exclude = ["tests/"]
[lib] [lib]
name = "bend" name = "bend"

View File

@ -641,7 +641,7 @@ A Parallel Bitonic Sort
The bitonic sort is a popular algorithm that sorts a set of numbers by moving The bitonic sort is a popular algorithm that sorts a set of numbers by moving
them through a "circuit" (sorting network) and swapping as they pass through: them through a "circuit" (sorting network) and swapping as they pass through:
![bsort](https://upload.wikimedia.org/wikipedia/commons/thumb/b/bd/BitonicSort1.svg/1686px-BitonicSort1.svg.png) ![bitonic-sort](https://upload.wikimedia.org/wikipedia/commons/thumb/b/bd/BitonicSort1.svg/1686px-BitonicSort1.svg.png)
In CUDA, this can be implemented by using mutable arrays and synchronization In CUDA, this can be implemented by using mutable arrays and synchronization
primitives. This is well known. What is less known is that it can also be primitives. This is well known. What is less known is that it can also be

View File

@ -4,14 +4,12 @@
"words": [ "words": [
"anni", "anni",
"annihilations", "annihilations",
"argn",
"arities", "arities",
"arity", "arity",
"arrayvec", "arrayvec",
"behaviour", "behaviour",
"bitand", "bitand",
"Bitonic", "Bitonic",
"bsort",
"builtins", "builtins",
"callcc", "callcc",
"chumsky", "chumsky",
@ -23,14 +21,12 @@
"concat", "concat",
"ctrs", "ctrs",
"cuda", "cuda",
"Dall",
"datatypes", "datatypes",
"Deque", "Deque",
"destructures", "destructures",
"desugared", "desugared",
"desugars", "desugars",
"devs", "devs",
"dref",
"dups", "dups",
"effectful", "effectful",
"elif", "elif",
@ -39,8 +35,6 @@
"hasher", "hasher",
"hexdigit", "hexdigit",
"hvm's", "hvm's",
"hvmc",
"iexp",
"indexmap", "indexmap",
"inet", "inet",
"inets", "inets",
@ -60,7 +54,6 @@
"linearization", "linearization",
"linearizes", "linearizes",
"linearizing", "linearizing",
"lnet",
"lnil", "lnil",
"lpthread", "lpthread",
"mant", "mant",
@ -75,9 +68,6 @@
"nums", "nums",
"OOM's", "OOM's",
"oper", "oper",
"opre",
"oprune",
"oref",
"parallelizable", "parallelizable",
"peekable", "peekable",
"postcondition", "postcondition",
@ -118,12 +108,15 @@
"TSPL", "TSPL",
"tunr", "tunr",
"unbounds", "unbounds",
"undefer",
"vectorize", "vectorize",
"vectorizes", "vectorizes",
"walkdir", "walkdir",
"wopts"
], ],
"files": ["**/*.rs", "**/*.md"], "files": [
"**/*.rs",
"**/*.md"
],
"ignoreRegExpList": [ "ignoreRegExpList": [
"HexValues", "HexValues",
"/λ/g", "/λ/g",

View File

@ -362,22 +362,40 @@ with Result:
return x return x
``` ```
A monadic with block. A monadic `with` block.
Where `x <- ...` performs a monadic operation. Where `x <- ...` performs a monadic operation.
Expects `Result` to be a type defined with `type` and a function `Result/bind` to be defined. Expects `Result` to be a type defined with `type` or `object` and the function `Result/bind` to be defined.
The monadic bind function should be of type `(Result a) -> (a -> Result b) -> Result b`, like this: The monadic bind function should be of type `(Result a) -> (a -> Result b) -> Result b`, like this:
```python ```python
def Result/bind(res, nxt): def Result/bind(res, nxt):
match res: match res:
case Result/Ok: case Result/Ok:
nxt = undefer(nxt)
return nxt(res.value) return nxt(res.value)
case Result/Err: case Result/Err:
return res return res
``` ```
However, the second argument, `nxt`, is actually a deferred call to the continuation, passing any free variables as arguments.
Therefore, all `bind` functions must call the builtin function `undefer` before using the value of `nxt`, as in the example above.
This is necessary to ensure that the continuation in recursive monadic functions stays lazy and doesn't expand infinitely.
The following recursive function would loop forever if passing the variable `a` to the recursive call `Result/foo(a, b)` were not deferred:
```python
def Result/foo(x, y):
  with Result:
    a <- Result/Ok(1)
    if y:
      b = Result/Err(x)
    else:
      b = Result/Ok(y)
    b <- b
    return Result/foo(a, b)
```
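As a rough illustration of why `bind` has to call `undefer`, the sketch below models the idea in plain Python rather than Bend. The names `ok`, `err`, `result_bind`, `continuation` and `deferred_continuation` are hypothetical stand-ins, and Python's evaluation model differs from HVM's, so this only mirrors the shape of the encoding: the continuation receives its free variable `a` as an explicit argument and is only forced on the `Ok` branch.

```python
def undefer(deferred):
    # Model of the builtin: apply the thunk to the identity function.
    return deferred(lambda x: x)

def ok(value):
    return ("Ok", value)

def err(value):
    return ("Err", value)

def result_bind(res, nxt):
    tag, _value = res
    if tag == "Ok":
        nxt = undefer(nxt)  # force the deferred continuation only here
        return nxt(res[1])
    return res  # on Err the continuation is never forced

forced = []  # records every time the continuation body actually runs

def continuation(a):
    # Continuation with its free variable `a` turned into a parameter.
    def run(b):
        forced.append((a, b))
        return ok(a + b)
    return run

def deferred_continuation(a):
    # Same shape as `@x (x @a @b (...) a)`: `a` is passed explicitly.
    return lambda x: x(continuation)(a)

short = result_bind(err("boom"), deferred_continuation(1))
full = result_bind(ok(2), deferred_continuation(1))
```

In this model `short` is the original error and the continuation body runs exactly once, for `full`.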
Other statements are allowed inside the `with` block and it can both return a value at the end and bind a variable, like branching statements do. Other statements are allowed inside the `with` block and it can both return a value at the end and bind a variable, like branching statements do.
```python ```python
@ -389,7 +407,7 @@ return y
``` ```
The name `wrap` is bound inside a `with` block as a shorthand for `Type/wrap`, The name `wrap` is bound inside a `with` block as a shorthand for `Type/wrap`,
the equivalent as a `pure` function in other functional languages: and it calls the unit function of the monad, also called `pure` in some languages:
```python ```python
def Result/wrap(x): def Result/wrap(x):
@ -968,8 +986,8 @@ match x {
### With block ### With block
```rust ```rust
Result/bind (Result/Ok val) f = (f val) Result/bind (Result/Ok val) nxt = (nxt val)
Result/bind err _ = err Result/bind err _nxt = err
div a b = switch b { div a b = switch b {
0: (Result/Err "Div by 0") 0: (Result/Err "Div by 0")
@ -991,6 +1009,23 @@ Main = with Result {
Receives a type defined with `type` and expects `Result/bind` to be defined as a monadic bind function. Receives a type defined with `type` and expects `Result/bind` to be defined as a monadic bind function.
It should be of type `(Result a) -> (a -> Result b) -> Result b`, like in the example above. It should be of type `(Result a) -> (a -> Result b) -> Result b`, like in the example above.
However, the second argument, `nxt`, is actually a deferred call to the continuation, passing any free variables as arguments.
Therefore, all `bind` functions must call the builtin function `undefer` before using the value of `nxt`, for example `((undefer nxt) val)`.
This is necessary to ensure that the continuation in recursive monadic functions stays lazy and doesn't expand infinitely.
The following recursive function would loop forever if passing the variable `a` to the recursive call `(Result/foo a b)` were not deferred:
```rust
Result/foo x y = with Result {
  ask a = (Result/Ok 1)
  ask b = if y {
    (Result/Err x)
  } else {
    (Result/Ok y)
  }
  (Result/foo a b)
}
```
Inside a `with` block, you can use `ask`, to access the continuation value of the monadic operation. Inside a `with` block, you can use `ask`, to access the continuation value of the monadic operation.
```rust ```rust

View File

@ -73,6 +73,7 @@ def IO/wrap(x):
def IO/bind(a, b): def IO/bind(a, b):
match a: match a:
case IO/Done: case IO/Done:
b = undefer(b)
return b(a.expr) return b(a.expr)
case IO/Call: case IO/Call:
return IO/Call(IO/MAGIC, a.func, a.argm, lambda x: IO/bind(a.cont(x), b)) return IO/Call(IO/MAGIC, a.func, a.argm, lambda x: IO/bind(a.cont(x), b))
@ -89,3 +90,15 @@ print text = (IO/Call IO/MAGIC "PUT_TEXT" text @x (IO/Done IO/MAGIC x))
get_time = (IO/Call IO/MAGIC "GET_TIME" * @x (IO/Done IO/MAGIC x)) get_time = (IO/Call IO/MAGIC "GET_TIME" * @x (IO/Done IO/MAGIC x))
sleep hi_lo = (IO/Call IO/MAGIC "PUT_TIME" hi_lo @x (IO/Done IO/MAGIC x)) sleep hi_lo = (IO/Call IO/MAGIC "PUT_TIME" hi_lo @x (IO/Done IO/MAGIC x))
# Lazy thunks
# We can defer the evaluation of a function by wrapping it in a thunk
# Ex: @x (x @arg1 @arg2 @arg3 (f arg1 arg2 arg3) arg1 arg2 arg3)
# This is only evaluated when we call it with 'undefer' (undefer my_thunk)
# We can build a deferred call directly or by using defer and defer_arg
# The example above can be written as:
# (defer_arg (defer_arg (defer_arg (defer @arg1 @arg2 @arg3 (f arg1 arg2 arg3)) arg1) arg2) arg3)
defer val = @x (x val)
defer_arg defered arg = @x (defered x arg)
undefer defered = (defered @x x)
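
The three combinators above can be mirrored in Python to check that the directly written thunk and the one built with `defer` and `defer_arg` behave the same. This is only a model under the assumption that curried Python lambdas stand in for Bend lambdas. The function `f` and the `calls` log are hypothetical, and nothing here reflects how HVM actually reduces these terms.

```python
def defer(val):
    # defer val = @x (x val)
    return lambda x: x(val)

def defer_arg(deferred, arg):
    # defer_arg defered arg = @x (defered x arg)
    return lambda x: deferred(x)(arg)

def undefer(deferred):
    # undefer defered = (defered @x x)
    return deferred(lambda x: x)

calls = []  # records each time f is fully applied

def f(a1):
    # Curried three-argument function used as the deferred computation.
    return lambda a2: lambda a3: (calls.append((a1, a2, a3)), a1 + a2 + a3)[1]

# Built step by step with the combinators.
built = defer_arg(defer_arg(defer_arg(defer(f), 1), 2), 3)

# Written directly, following the comment's
# `@x (x @arg1 @arg2 @arg3 (f arg1 arg2 arg3) arg1 arg2 arg3)`.
direct = lambda x: x(lambda a1: lambda a2: lambda a3: f(a1)(a2)(a3))(1)(2)(3)

calls_before = list(calls)  # nothing has run yet
built_result = undefer(built)
direct_result = undefer(direct)
```

Building either thunk does not call `f`. Only `undefer` does, once per thunk.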

View File

@ -2,7 +2,6 @@ use crate::{
diagnostics::{Diagnostics, DiagnosticsConfig}, diagnostics::{Diagnostics, DiagnosticsConfig},
maybe_grow, multi_iterator, ENTRY_POINT, maybe_grow, multi_iterator, ENTRY_POINT,
}; };
// use hvmc::ast::get_typ;
use indexmap::{IndexMap, IndexSet}; use indexmap::{IndexMap, IndexSet};
use interner::global::{GlobalPool, GlobalString}; use interner::global::{GlobalPool, GlobalString};
use itertools::Itertools; use itertools::Itertools;

View File

@ -1,20 +1,20 @@
use crate::{ use crate::{
diagnostics::Diagnostics, diagnostics::Diagnostics,
fun::{Ctx, Name, Term}, fun::{Ctx, Name, Pattern, Term},
maybe_grow, maybe_grow,
}; };
use std::collections::HashSet; use std::collections::HashSet;
impl Ctx<'_> { impl Ctx<'_> {
/// Converts `ask` terms inside `do` blocks into calls to a monadic bind operation. /// Converts `ask` terms inside `do` blocks into calls to a monadic bind operation.
pub fn desugar_do_blocks(&mut self) -> Result<(), Diagnostics> { pub fn desugar_with_blocks(&mut self) -> Result<(), Diagnostics> {
self.info.start_pass(); self.info.start_pass();
let def_names = self.book.defs.keys().cloned().collect::<HashSet<_>>(); let def_names = self.book.defs.keys().cloned().collect::<HashSet<_>>();
for def in self.book.defs.values_mut() { for def in self.book.defs.values_mut() {
for rule in def.rules.iter_mut() { for rule in def.rules.iter_mut() {
if let Err(e) = rule.body.desugar_do_blocks(None, &def_names) { if let Err(e) = rule.body.desugar_with_blocks(None, &def_names) {
self.info.add_rule_error(e, def.name.clone()); self.info.add_rule_error(e, def.name.clone());
} }
} }
@ -25,14 +25,14 @@ impl Ctx<'_> {
} }
impl Term { impl Term {
pub fn desugar_do_blocks( pub fn desugar_with_blocks(
&mut self, &mut self,
cur_block: Option<&Name>, cur_block: Option<&Name>,
def_names: &HashSet<Name>, def_names: &HashSet<Name>,
) -> Result<(), String> { ) -> Result<(), String> {
maybe_grow(|| { maybe_grow(|| {
if let Term::With { typ, bod } = self { if let Term::With { typ, bod } = self {
bod.desugar_do_blocks(Some(typ), def_names)?; bod.desugar_with_blocks(Some(typ), def_names)?;
let wrap_ref = Term::r#ref(&format!("{typ}/wrap")); let wrap_ref = Term::r#ref(&format!("{typ}/wrap"));
// let wrap_ref = if def_names.contains(&wrap_nam) { // let wrap_ref = if def_names.contains(&wrap_nam) {
// Term::r#ref(&wrap_nam) // Term::r#ref(&wrap_nam)
@ -47,8 +47,9 @@ impl Term {
let bind_nam = Name::new(format!("{typ}/bind")); let bind_nam = Name::new(format!("{typ}/bind"));
if def_names.contains(&bind_nam) { if def_names.contains(&bind_nam) {
// TODO: come up with a strategy for forwarding free vars to prevent infinite recursion.
let nxt = Term::lam(*pat.clone(), std::mem::take(nxt)); let nxt = Term::lam(*pat.clone(), std::mem::take(nxt));
let nxt = nxt.defer();
*self = Term::call(Term::Ref { nam: bind_nam }, [*val.clone(), nxt]); *self = Term::call(Term::Ref { nam: bind_nam }, [*val.clone(), nxt]);
} else { } else {
return Err(format!("Could not find definition {bind_nam} for type {typ}.")); return Err(format!("Could not find definition {bind_nam} for type {typ}."));
@ -59,10 +60,25 @@ impl Term {
} }
for children in self.children_mut() { for children in self.children_mut() {
children.desugar_do_blocks(cur_block, def_names)?; children.desugar_with_blocks(cur_block, def_names)?;
} }
Ok(()) Ok(())
}) })
} }
/// Converts a term with free vars `(f x1 .. xn)` into a deferred
/// call that passes those vars to the term.
///
/// Ex: `(f x1 .. xn)` becomes `@x (x @x1 .. @xn (f x1 .. xn) x1 .. xn)`.
///
/// The user must call this lazy thunk by calling the builtin
/// `undefer` function, or by applying `@x x` to the term.
fn defer(self) -> Term {
let free_vars = self.free_vars().into_keys().collect::<Vec<_>>();
let term = Term::rfold_lams(self, free_vars.iter().cloned().map(Some));
let term = Term::call(Term::Var { nam: Name::new("%x") }, [term]);
let term = Term::call(term, free_vars.iter().cloned().map(|nam| Term::Var { nam }));
Term::lam(Pattern::Var(Some(Name::new("%x"))), term)
}
} }
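
The transformation done by `defer` can be sketched on a toy term representation. The Python below is a simplified model, not the compiler's code: terms are hypothetical tuples (`("var", name)`, `("lam", name, body)`, `("app", fun, arg)`), free variables are sorted for determinism (the Rust code takes them in map order), and the small `evaluate` function exists only to check that applying the identity to the deferred term gives back the original value.

```python
def free_vars(term, bound=frozenset()):
    # Collect the names of variables not bound by an enclosing lambda.
    tag = term[0]
    if tag == "var":
        return set() if term[1] in bound else {term[1]}
    if tag == "lam":
        return free_vars(term[2], bound | {term[1]})
    return free_vars(term[1], bound) | free_vars(term[2], bound)

def defer(term):
    # (f x1 .. xn)  becomes  @%x (%x @x1 .. @xn (f x1 .. xn) x1 .. xn)
    names = sorted(free_vars(term))
    body = term
    for name in reversed(names):  # wrap in lambdas, outermost first
        body = ("lam", name, body)
    call = ("app", ("var", "%x"), body)
    for name in names:  # pass the free variables back in as arguments
        call = ("app", call, ("var", name))
    return ("lam", "%x", call)

def evaluate(term, env):
    # Minimal evaluator mapping terms onto Python values and closures.
    tag = term[0]
    if tag == "var":
        return env[term[1]]
    if tag == "lam":
        return lambda arg: evaluate(term[2], {**env, term[1]: arg})
    return evaluate(term[1], env)(evaluate(term[2], env))

IDENTITY = ("lam", "y", ("var", "y"))
term = ("app", ("app", ("var", "f"), ("var", "x1")), ("var", "x2"))
env = {"f": lambda a: lambda b: a - b, "x1": 10, "x2": 3}

deferred = defer(term)
forced = evaluate(("app", deferred, IDENTITY), env)  # same as undefer
plain = evaluate(term, env)
```

Note that in this toy model `f` is itself a free variable, whereas in the compiler a reference to a definition is a `Ref` and would not be abstracted.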

View File

@ -2,11 +2,11 @@ pub mod apply_args;
pub mod definition_merge; pub mod definition_merge;
pub mod definition_pruning; pub mod definition_pruning;
pub mod desugar_bend; pub mod desugar_bend;
pub mod desugar_do_blocks;
pub mod desugar_fold; pub mod desugar_fold;
pub mod desugar_match_defs; pub mod desugar_match_defs;
pub mod desugar_open; pub mod desugar_open;
pub mod desugar_use; pub mod desugar_use;
pub mod desugar_with_blocks;
pub mod encode_adts; pub mod encode_adts;
pub mod encode_match_terms; pub mod encode_match_terms;
pub mod expand_generated; pub mod expand_generated;

View File

@ -108,7 +108,7 @@ pub fn desugar_book(
ctx.desugar_bend()?; ctx.desugar_bend()?;
ctx.desugar_fold()?; ctx.desugar_fold()?;
ctx.desugar_do_blocks()?; ctx.desugar_with_blocks()?;
ctx.check_unbound_vars()?; ctx.check_unbound_vars()?;
@ -305,10 +305,10 @@ impl OptLevel {
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct CompileOpts { pub struct CompileOpts {
/// Enables [hvmc::transform::eta_reduce]. /// Enables [hvm::eta_reduce].
pub eta: bool, pub eta: bool,
/// Enables [fun::transform::definition_pruning] and [hvmc_net::prune]. /// Enables [fun::transform::definition_pruning] and [hvm::prune].
pub prune: bool, pub prune: bool,
/// Enables [fun::transform::linearize_matches]. /// Enables [fun::transform::linearize_matches].
@ -320,7 +320,7 @@ pub struct CompileOpts {
/// Enables [fun::transform::definition_merge] /// Enables [fun::transform::definition_merge]
pub merge: bool, pub merge: bool,
/// Enables [hvmc::transform::inline]. /// Enables [hvm::inline].
pub inline: bool, pub inline: bool,
/// Enables [hvm::check_net_size]. /// Enables [hvm::check_net_size].

View File

@ -49,7 +49,7 @@ enum Mode {
RunC(RunArgs), RunC(RunArgs),
/// Compiles the program and runs it with the Cuda HVM implementation. /// Compiles the program and runs it with the Cuda HVM implementation.
RunCu(RunArgs), RunCu(RunArgs),
/// Compiles the program to hvmc and prints to stdout. /// Compiles the program to hvm and prints to stdout.
GenHvm(GenArgs), GenHvm(GenArgs),
/// Compiles the program to standalone C and prints to stdout. /// Compiles the program to standalone C and prints to stdout.
GenC(GenArgs), GenC(GenArgs),

View File

@ -235,7 +235,7 @@ fn simplify_matches() {
ctx.fix_match_terms()?; ctx.fix_match_terms()?;
ctx.desugar_bend()?; ctx.desugar_bend()?;
ctx.desugar_fold()?; ctx.desugar_fold()?;
ctx.desugar_do_blocks()?; ctx.desugar_with_blocks()?;
ctx.check_unbound_vars()?; ctx.check_unbound_vars()?;
ctx.book.make_var_names_unique(); ctx.book.make_var_names_unique();
ctx.book.linearize_match_binds(); ctx.book.linearize_match_binds();
@ -284,7 +284,7 @@ fn encode_pattern_match() {
ctx.fix_match_terms()?; ctx.fix_match_terms()?;
ctx.desugar_bend()?; ctx.desugar_bend()?;
ctx.desugar_fold()?; ctx.desugar_fold()?;
ctx.desugar_do_blocks()?; ctx.desugar_with_blocks()?;
ctx.check_unbound_vars()?; ctx.check_unbound_vars()?;
ctx.book.make_var_names_unique(); ctx.book.make_var_names_unique();
ctx.book.linearize_match_binds(); ctx.book.linearize_match_binds();

View File

@ -1,6 +1,6 @@
type Result = (Ok val) | (Err val) type Result = (Ok val) | (Err val)
Result/bind (Result/Ok val) f = (f val) Result/bind (Result/Ok val) f = ((undefer f) val)
Result/bind err _ = err Result/bind err _ = err
safe_div a b = switch b { safe_div a b = switch b {

View File

@ -2,7 +2,7 @@
type Result = (Ok val) | (Err val) type Result = (Ok val) | (Err val)
Result/bind r nxt = match r { Result/bind r nxt = match r {
Result/Ok: (nxt r.val) Result/Ok: ((undefer nxt) r.val)
Result/Err: r Result/Err: r
} }

View File

@ -1,6 +1,6 @@
type Result = (Ok val) | (Err val) type Result = (Ok val) | (Err val)
Result/bind (Result/Ok val) f = (f val) Result/bind (Result/Ok val) f = ((undefer f) val)
Result/bind err _ = err Result/bind err _ = err
Bar x = (Result/Err 0) Bar x = (Result/Err 0)

View File

@ -0,0 +1,15 @@
# This will only work if we make the call to `(Result/foo a b)` lazy (by converting it to a combinator).
type Result = (Ok val) | (Err val)
Result/bind = @val @nxt match val {
Result/Ok: ((undefer nxt) val.val)
Result/Err: (Result/Err val.val)
}
Result/foo x y =
with Result {
ask a = (Result/Ok x)
ask b = switch y { 0: (Result/Err a); _: (Result/Ok y-1) }
(Result/foo a b)
}
main = (Result/foo 1 2)

View File

@ -2,13 +2,15 @@
source: tests/golden_tests.rs source: tests/golden_tests.rs
input_file: tests/golden_tests/desugar_file/bind_syntax.bend input_file: tests/golden_tests/desugar_file/bind_syntax.bend
--- ---
(undefer) = λa (a λb b)
(Result/bind) = λa λb (a Result/bind__C2 b) (Result/bind) = λa λb (a Result/bind__C2 b)
(safe_div) = λa λb (switch b { 0: λ* (Result/Err (String/Cons 68 (String/Cons 105 (String/Cons 118 (String/Cons 32 (String/Cons 98 (String/Cons 121 (String/Cons 32 (String/Cons 48 String/Nil))))))))); _: safe_div__C0; } a) (safe_div) = λa λb (switch b { 0: λ* (Result/Err (String/Cons 68 (String/Cons 105 (String/Cons 118 (String/Cons 32 (String/Cons 98 (String/Cons 121 (String/Cons 32 (String/Cons 48 String/Nil))))))))); _: safe_div__C0; } a)
(safe_rem) = λa λb (switch b { 0: λ* (Result/Err (String/Cons 77 (String/Cons 111 (String/Cons 100 (String/Cons 32 (String/Cons 98 (String/Cons 121 (String/Cons 32 (String/Cons 48 String/Nil))))))))); _: safe_rem__C0; } a) (safe_rem) = λa λb (switch b { 0: λ* (Result/Err (String/Cons 77 (String/Cons 111 (String/Cons 100 (String/Cons 32 (String/Cons 98 (String/Cons 121 (String/Cons 32 (String/Cons 48 String/Nil))))))))); _: safe_rem__C0; } a)
(Main) = (Result/bind Main__C1 Main__C0) (Main) = (Result/bind Main__C3 Main__C2)
(String/Nil) = λa (a String/Nil/tag) (String/Nil) = λa (a String/Nil/tag)
@ -26,11 +28,15 @@ input_file: tests/golden_tests/desugar_file/bind_syntax.bend
(Result/Err/tag) = 1 (Result/Err/tag) = 1
(Main__C0) = λa (Result/bind (safe_rem a 0) λb b) (Main__C0) = λa (a λb b)
(Main__C1) = (safe_div 3 2) (Main__C1) = λa (Result/bind (safe_rem a 0) Main__C0)
(Result/bind__C0) = λa λb (b a) (Main__C2) = λa (a Main__C1)
(Main__C3) = (safe_div 3 2)
(Result/bind__C0) = λa λb (undefer b a)
(Result/bind__C1) = λ* λa λ* (Result/Err a) (Result/bind__C1) = λ* λa λ* (Result/Err a)

View File

@ -2,30 +2,8 @@
source: tests/golden_tests.rs source: tests/golden_tests.rs
input_file: tests/golden_tests/run_file/recursive_bind.bend input_file: tests/golden_tests/run_file/recursive_bind.bend
--- ---
Errors: NumScott:
The following functions contain recursive cycles incompatible with HVM's strict evaluation: λa (a Result/Err/tag 0)
* Foo -> Foo
The greedy eager evaluation of HVM may cause infinite loops. Scott:
Refactor these functions to use lazy references instead of direct function calls. λ* λa (a 0)
A reference is strict when it's being called ('(Foo x)') or when it's used non-linearly ('let x = Foo; (x x)').
It is lazy when it's an argument ('(x Foo)') or when it's used linearly ('let x = Foo; (x 0)').
Try one of these strategies:
- Use pattern matching with 'match', 'fold', and 'bend' to automatically lift expressions to lazy references.
- Replace direct calls with combinators. For example, change:
'Foo = λa λb (b (λc (Foo a c)) a)'
to:
'Foo = λa λb (b (λc λa (Foo a c)) (λa a) a)'
which is lifted to:
'Foo = λa λb (b Foo__C1 Foo__C2 a)'
- Replace non-linear 'let' expressions with 'use' expressions. For example, change:
'Foo = λf let x = Foo; (f x x)'
to:
'Foo = λf use x = Foo; (f x x)'
which inlines to:
'Foo = λf (f Foo Foo)'
- If disabled, re-enable the default 'float-combinators' and 'linearize-matches' compiler options.
For more information, visit: https://github.com/HigherOrderCO/Bend/blob/main/docs/lazy-definitions.md.
To disable this check, use the "-Arecursion-cycle" compiler option.

View File

@ -0,0 +1,9 @@
---
source: tests/golden_tests.rs
input_file: tests/golden_tests/run_file/strict_monad_fn.bend
---
NumScott:
λa (a Result/Err/tag 1)
Scott:
λ* λa (a 1)