Make monadic blocks lazy by defering execution of continuations with free vars

This commit is contained in:
Nicolas Abril 2024-05-30 21:07:12 +02:00
parent 3c9a532df1
commit 6182ac74fa
18 changed files with 132 additions and 68 deletions

View File

@ -5,7 +5,7 @@ license = "Apache-2.0"
version = "0.2.28"
edition = "2021"
rust-version = "1.74"
exclude = ["tests/snapshots/"]
exclude = ["tests/"]
[lib]
name = "bend"

View File

@ -641,7 +641,7 @@ A Parallel Bitonic Sort
The bitonic sort is a popular algorithm that sorts a set of numbers by moving
them through a "circuit" (sorting network) and swapping as they pass through:
![bsort](https://upload.wikimedia.org/wikipedia/commons/thumb/b/bd/BitonicSort1.svg/1686px-BitonicSort1.svg.png)
![bitonic-sort](https://upload.wikimedia.org/wikipedia/commons/thumb/b/bd/BitonicSort1.svg/1686px-BitonicSort1.svg.png)
In CUDA, this can be implemented by using mutable arrays and synchronization
primitives. This is well known. What is less known is that it can also be

View File

@ -4,14 +4,12 @@
"words": [
"anni",
"annihilations",
"argn",
"arities",
"arity",
"arrayvec",
"behaviour",
"bitand",
"Bitonic",
"bsort",
"builtins",
"callcc",
"chumsky",
@ -23,14 +21,12 @@
"concat",
"ctrs",
"cuda",
"Dall",
"datatypes",
"Deque",
"destructures",
"desugared",
"desugars",
"devs",
"dref",
"dups",
"effectful",
"elif",
@ -39,8 +35,6 @@
"hasher",
"hexdigit",
"hvm's",
"hvmc",
"iexp",
"indexmap",
"inet",
"inets",
@ -60,7 +54,6 @@
"linearization",
"linearizes",
"linearizing",
"lnet",
"lnil",
"lpthread",
"mant",
@ -75,9 +68,6 @@
"nums",
"OOM's",
"oper",
"opre",
"oprune",
"oref",
"parallelizable",
"peekable",
"postcondition",
@ -118,12 +108,15 @@
"TSPL",
"tunr",
"unbounds",
"undefer",
"vectorize",
"vectorizes",
"walkdir",
"wopts"
],
"files": ["**/*.rs", "**/*.md"],
"files": [
"**/*.rs",
"**/*.md"
],
"ignoreRegExpList": [
"HexValues",
"/λ/g",

View File

@ -362,22 +362,40 @@ with Result:
return x
```
A monadic with block.
A monadic `with` block.
Where `x <- ...` performs a monadic operation.
Expects `Result` to be a type defined with `type` and a function `Result/bind` to be defined.
Expects `Result` to be a type defined with `type` or `object` and the function `Result/bind` to be defined.
The monadic bind function should be of type `(Result a) -> (a -> Result b) -> Result b`, like this:
```python
def Result/bind(res, nxt):
match res:
case Result/Ok:
nxt = undefer(nxt)
return nxt(res.value)
case Result/Err:
return res
```
However, the second argument, `nxt`, is actually a deferred call to the continuation, passing any free variables as arguments.
Therefore, all `bind` functions must call the builtin function `undefer` before using the value of `nxt`, as in the example above.
This is necessary to ensure that the continuation in recursive monadic functions stays lazy and doesn't expand infinitely.
This is an example of a recursive function that would loop if passing the variable `a` to the recursive call `Result/foo(a, b)` was not deferred:
```python
def Result/foo(x, y):
with Result:
a <- Result/Ok(1)
if b:
b = Result/Err(x)
else:
b = Result/Ok(y)
b <- b
return Result/foo(a, b)
```
Other statements are allowed inside the `with` block and it can both return a value at the end and bind a variable, like branching statements do.
```python
@ -389,7 +407,7 @@ return y
```
The name `wrap` is bound inside a `with` block as a shorthand for `Type/wrap`,
the equivalent as a `pure` function in other functional languages:
and it calls the unit function of the monad, also called `pure` in some languages:
```python
def Result/wrap(x):
@ -968,8 +986,8 @@ match x {
### With block
```rust
Result/bind (Result/Ok val) f = (f val)
Result/bind err _ = err
Result/bind (Result/Ok val) nxt = (nxt val)
Result/bind err _nxt = err
div a b = switch b {
0: (Result/Err "Div by 0")
@ -991,6 +1009,23 @@ Main = with Result {
Receives a type defined with `type` and expects `Result/bind` to be defined as a monadic bind function.
It should be of type `(Result a) -> (a -> Result b) -> Result b`, like in the example above.
However, the second argument, `nxt`, is actually a deferred call to the continuation, passing any free variables as arguments.
Therefore, all `bind` functions must call the builtin function `undefer` before using the value of `nxt`, as in the example above.
This is necessary to ensure that the continuation in recursive monadic functions stays lazy and doesn't expand infinitely.
This is an example of a recursive function that would loop if passing the variable `a` to the recursive call `Result/foo(a, b)` was not deferred:
```python
Result/foo x y = with Result {
ask a = (Result/Ok 1)
ask b = if b {
(Result/Err x)
} else {
(Result/Ok y)
}
(Result/foo a b)
}
```
Inside a `with` block, you can use `ask`, to access the continuation value of the monadic operation.
```rust

View File

@ -73,6 +73,7 @@ def IO/wrap(x):
def IO/bind(a, b):
match a:
case IO/Done:
b = undefer(b)
return b(a.expr)
case IO/Call:
return IO/Call(IO/MAGIC, a.func, a.argm, lambda x: IO/bind(a.cont(x), b))
@ -89,3 +90,15 @@ print text = (IO/Call IO/MAGIC "PUT_TEXT" text @x (IO/Done IO/MAGIC x))
get_time = (IO/Call IO/MAGIC "GET_TIME" * @x (IO/Done IO/MAGIC x))
sleep hi_lo = (IO/Call IO/MAGIC "PUT_TIME" hi_lo @x (IO/Done IO/MAGIC x))
# Lazy thunks
# We can defer the evaluation of a function by wrapping it in a thunk
# Ex: @x (x @arg1 @arg2 @arg3 (f arg1 arg2 arg3) arg1 arg2 arg3)
# This is only evaluated when we call it with 'undefer' (undefer my_thunk)
# We can build a defered call directly or by by using defer and defer_arg
# The example above can be written as:
# (defer_arg (defer_arg (defer_arg (defer @arg1 @arg2 @arg3 (f arg1 arg2 arg3)) arg1) arg2) arg3)
defer val = @x (x val)
defer_arg defered arg = @x (defered x arg)
undefer defered = (defered @x x)

View File

@ -2,7 +2,6 @@ use crate::{
diagnostics::{Diagnostics, DiagnosticsConfig},
maybe_grow, multi_iterator, ENTRY_POINT,
};
// use hvmc::ast::get_typ;
use indexmap::{IndexMap, IndexSet};
use interner::global::{GlobalPool, GlobalString};
use itertools::Itertools;

View File

@ -1,20 +1,20 @@
use crate::{
diagnostics::Diagnostics,
fun::{Ctx, Name, Term},
fun::{Ctx, Name, Pattern, Term},
maybe_grow,
};
use std::collections::HashSet;
impl Ctx<'_> {
/// Converts `ask` terms inside `do` blocks into calls to a monadic bind operation.
pub fn desugar_do_blocks(&mut self) -> Result<(), Diagnostics> {
pub fn desugar_with_blocks(&mut self) -> Result<(), Diagnostics> {
self.info.start_pass();
let def_names = self.book.defs.keys().cloned().collect::<HashSet<_>>();
for def in self.book.defs.values_mut() {
for rule in def.rules.iter_mut() {
if let Err(e) = rule.body.desugar_do_blocks(None, &def_names) {
if let Err(e) = rule.body.desugar_with_blocks(None, &def_names) {
self.info.add_rule_error(e, def.name.clone());
}
}
@ -25,14 +25,14 @@ impl Ctx<'_> {
}
impl Term {
pub fn desugar_do_blocks(
pub fn desugar_with_blocks(
&mut self,
cur_block: Option<&Name>,
def_names: &HashSet<Name>,
) -> Result<(), String> {
maybe_grow(|| {
if let Term::With { typ, bod } = self {
bod.desugar_do_blocks(Some(typ), def_names)?;
bod.desugar_with_blocks(Some(typ), def_names)?;
let wrap_ref = Term::r#ref(&format!("{typ}/wrap"));
// let wrap_ref = if def_names.contains(&wrap_nam) {
// Term::r#ref(&wrap_nam)
@ -47,8 +47,9 @@ impl Term {
let bind_nam = Name::new(format!("{typ}/bind"));
if def_names.contains(&bind_nam) {
// TODO: come up with a strategy for forwarding free vars to prevent infinite recursion.
let nxt = Term::lam(*pat.clone(), std::mem::take(nxt));
let nxt = nxt.defer();
*self = Term::call(Term::Ref { nam: bind_nam }, [*val.clone(), nxt]);
} else {
return Err(format!("Could not find definition {bind_nam} for type {typ}."));
@ -59,10 +60,25 @@ impl Term {
}
for children in self.children_mut() {
children.desugar_do_blocks(cur_block, def_names)?;
children.desugar_with_blocks(cur_block, def_names)?;
}
Ok(())
})
}
/// Converts a term with free vars `(f x1 .. xn)` into a deferred
/// call that passes those vars to the term.
///
/// Ex: `(f x1 .. xn)` becomes `@x (x @x1 .. @xn (f x1 .. x2) x1 .. x2)`.
///
/// The user must call this lazy thunk by calling the builtin
/// `undefer` function, or by applying `@x x` to the term.
fn defer(self) -> Term {
let free_vars = self.free_vars().into_keys().collect::<Vec<_>>();
let term = Term::rfold_lams(self, free_vars.iter().cloned().map(Some));
let term = Term::call(Term::Var { nam: Name::new("%x") }, [term]);
let term = Term::call(term, free_vars.iter().cloned().map(|nam| Term::Var { nam }));
Term::lam(Pattern::Var(Some(Name::new("%x"))), term)
}
}

View File

@ -2,11 +2,11 @@ pub mod apply_args;
pub mod definition_merge;
pub mod definition_pruning;
pub mod desugar_bend;
pub mod desugar_do_blocks;
pub mod desugar_fold;
pub mod desugar_match_defs;
pub mod desugar_open;
pub mod desugar_use;
pub mod desugar_with_blocks;
pub mod encode_adts;
pub mod encode_match_terms;
pub mod expand_generated;

View File

@ -108,7 +108,7 @@ pub fn desugar_book(
ctx.desugar_bend()?;
ctx.desugar_fold()?;
ctx.desugar_do_blocks()?;
ctx.desugar_with_blocks()?;
ctx.check_unbound_vars()?;
@ -305,10 +305,10 @@ impl OptLevel {
#[derive(Clone, Debug)]
pub struct CompileOpts {
/// Enables [hvmc::transform::eta_reduce].
/// Enables [hvm::eta_reduce].
pub eta: bool,
/// Enables [fun::transform::definition_pruning] and [hvmc_net::prune].
/// Enables [fun::transform::definition_pruning] and [hvm::prune].
pub prune: bool,
/// Enables [fun::transform::linearize_matches].
@ -320,7 +320,7 @@ pub struct CompileOpts {
/// Enables [fun::transform::definition_merge]
pub merge: bool,
/// Enables [hvmc::transform::inline].
/// Enables [hvm::inline].
pub inline: bool,
/// Enables [hvm::check_net_size].

View File

@ -49,7 +49,7 @@ enum Mode {
RunC(RunArgs),
/// Compiles the program and runs it with the Cuda HVM implementation.
RunCu(RunArgs),
/// Compiles the program to hvmc and prints to stdout.
/// Compiles the program to hvm and prints to stdout.
GenHvm(GenArgs),
/// Compiles the program to standalone C and prints to stdout.
GenC(GenArgs),

View File

@ -235,7 +235,7 @@ fn simplify_matches() {
ctx.fix_match_terms()?;
ctx.desugar_bend()?;
ctx.desugar_fold()?;
ctx.desugar_do_blocks()?;
ctx.desugar_with_blocks()?;
ctx.check_unbound_vars()?;
ctx.book.make_var_names_unique();
ctx.book.linearize_match_binds();
@ -284,7 +284,7 @@ fn encode_pattern_match() {
ctx.fix_match_terms()?;
ctx.desugar_bend()?;
ctx.desugar_fold()?;
ctx.desugar_do_blocks()?;
ctx.desugar_with_blocks()?;
ctx.check_unbound_vars()?;
ctx.book.make_var_names_unique();
ctx.book.linearize_match_binds();

View File

@ -1,6 +1,6 @@
type Result = (Ok val) | (Err val)
Result/bind (Result/Ok val) f = (f val)
Result/bind (Result/Ok val) f = ((undefer f) val)
Result/bind err _ = err
safe_div a b = switch b {

View File

@ -2,7 +2,7 @@
type Result = (Ok val) | (Err val)
Result/bind r nxt = match r {
Result/Ok: (nxt r.val)
Result/Ok: ((undefer nxt) r.val)
Result/Err: r
}

View File

@ -1,6 +1,6 @@
type Result = (Ok val) | (Err val)
Result/bind (Result/Ok val) f = (f val)
Result/bind (Result/Ok val) f = ((undefer f) val)
Result/bind err _ = err
Bar x = (Result/Err 0)

View File

@ -0,0 +1,15 @@
# This will only work if we make the call to `(Result/foo a b)` lazy (by converting it to a combinator).
type Result = (Ok val) | (Err val)
Result/bind = @val @nxt match val {
Result/Ok: ((undefer nxt) val.val)
Result/Err: (Result/Err val.val)
}
Result/foo x y =
with Result {
ask a = (Result/Ok x)
ask b = switch y { 0: (Result/Err a); _: (Result/Ok y-1) }
(Result/foo a b)
}
main = (Result/foo 1 2)

View File

@ -2,13 +2,15 @@
source: tests/golden_tests.rs
input_file: tests/golden_tests/desugar_file/bind_syntax.bend
---
(undefer) = λa (a λb b)
(Result/bind) = λa λb (a Result/bind__C2 b)
(safe_div) = λa λb (switch b { 0: λ* (Result/Err (String/Cons 68 (String/Cons 105 (String/Cons 118 (String/Cons 32 (String/Cons 98 (String/Cons 121 (String/Cons 32 (String/Cons 48 String/Nil))))))))); _: safe_div__C0; } a)
(safe_rem) = λa λb (switch b { 0: λ* (Result/Err (String/Cons 77 (String/Cons 111 (String/Cons 100 (String/Cons 32 (String/Cons 98 (String/Cons 121 (String/Cons 32 (String/Cons 48 String/Nil))))))))); _: safe_rem__C0; } a)
(Main) = (Result/bind Main__C1 Main__C0)
(Main) = (Result/bind Main__C3 Main__C2)
(String/Nil) = λa (a String/Nil/tag)
@ -26,11 +28,15 @@ input_file: tests/golden_tests/desugar_file/bind_syntax.bend
(Result/Err/tag) = 1
(Main__C0) = λa (Result/bind (safe_rem a 0) λb b)
(Main__C0) = λa (a λb b)
(Main__C1) = (safe_div 3 2)
(Main__C1) = λa (Result/bind (safe_rem a 0) Main__C0)
(Result/bind__C0) = λa λb (b a)
(Main__C2) = λa (a Main__C1)
(Main__C3) = (safe_div 3 2)
(Result/bind__C0) = λa λb (undefer b a)
(Result/bind__C1) = λ* λa λ* (Result/Err a)

View File

@ -2,30 +2,8 @@
source: tests/golden_tests.rs
input_file: tests/golden_tests/run_file/recursive_bind.bend
---
Errors:
The following functions contain recursive cycles incompatible with HVM's strict evaluation:
* Foo -> Foo
NumScott:
λa (a Result/Err/tag 0)
The greedy eager evaluation of HVM may cause infinite loops.
Refactor these functions to use lazy references instead of direct function calls.
A reference is strict when it's being called ('(Foo x)') or when it's used non-linearly ('let x = Foo; (x x)').
It is lazy when it's an argument ('(x Foo)') or when it's used linearly ('let x = Foo; (x 0)').
Try one of these strategies:
- Use pattern matching with 'match', 'fold', and 'bend' to automatically lift expressions to lazy references.
- Replace direct calls with combinators. For example, change:
'Foo = λa λb (b (λc (Foo a c)) a)'
to:
'Foo = λa λb (b (λc λa (Foo a c)) (λa a) a)'
which is lifted to:
'Foo = λa λb (b Foo__C1 Foo__C2 a)'
- Replace non-linear 'let' expressions with 'use' expressions. For example, change:
'Foo = λf let x = Foo; (f x x)'
to:
'Foo = λf use x = Foo; (f x x)'
which inlines to:
'Foo = λf (f Foo Foo)'
- If disabled, re-enable the default 'float-combinators' and 'linearize-matches' compiler options.
For more information, visit: https://github.com/HigherOrderCO/Bend/blob/main/docs/lazy-definitions.md.
To disable this check, use the "-Arecursion-cycle" compiler option.
Scott:
λ* λa (a 0)

View File

@ -0,0 +1,9 @@
---
source: tests/golden_tests.rs
input_file: tests/golden_tests/run_file/strict_monad_fn.bend
---
NumScott:
λa (a Result/Err/tag 1)
Scott:
λ* λa (a 1)