diff --git a/Cargo.toml b/Cargo.toml index a474068a..f3e339a6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,7 @@ license = "Apache-2.0" version = "0.2.28" edition = "2021" rust-version = "1.74" -exclude = ["tests/snapshots/"] +exclude = ["tests/"] [lib] name = "bend" diff --git a/GUIDE.md b/GUIDE.md index d7d31b66..4eac58ae 100644 --- a/GUIDE.md +++ b/GUIDE.md @@ -641,7 +641,7 @@ A Parallel Bitonic Sort The bitonic sort is a popular algorithm that sorts a set of numbers by moving them through a "circuit" (sorting network) and swapping as they pass through: -![bsort](https://upload.wikimedia.org/wikipedia/commons/thumb/b/bd/BitonicSort1.svg/1686px-BitonicSort1.svg.png) +![bitonic-sort](https://upload.wikimedia.org/wikipedia/commons/thumb/b/bd/BitonicSort1.svg/1686px-BitonicSort1.svg.png) In CUDA, this can be implemented by using mutable arrays and synchronization primitives. This is well known. What is less known is that it can also be diff --git a/cspell.json b/cspell.json index 72fa0a9d..05e82cca 100644 --- a/cspell.json +++ b/cspell.json @@ -4,14 +4,12 @@ "words": [ "anni", "annihilations", - "argn", "arities", "arity", "arrayvec", "behaviour", "bitand", "Bitonic", - "bsort", "builtins", "callcc", "chumsky", @@ -23,14 +21,12 @@ "concat", "ctrs", "cuda", - "Dall", "datatypes", "Deque", "destructures", "desugared", "desugars", "devs", - "dref", "dups", "effectful", "elif", @@ -39,8 +35,6 @@ "hasher", "hexdigit", "hvm's", - "hvmc", - "iexp", "indexmap", "inet", "inets", @@ -60,7 +54,6 @@ "linearization", "linearizes", "linearizing", - "lnet", "lnil", "lpthread", "mant", @@ -75,9 +68,6 @@ "nums", "OOM's", "oper", - "opre", - "oprune", - "oref", "parallelizable", "peekable", "postcondition", @@ -118,12 +108,15 @@ "TSPL", "tunr", "unbounds", + "undefer", "vectorize", "vectorizes", "walkdir", - "wopts" ], - "files": ["**/*.rs", "**/*.md"], + "files": [ + "**/*.rs", + "**/*.md" + ], "ignoreRegExpList": [ "HexValues", "/λ/g", diff --git a/docs/syntax.md b/docs/syntax.md index 7f469d83..5967a00e 100644 --- a/docs/syntax.md +++ b/docs/syntax.md @@ -362,22 +362,40 @@ with Result: return x ``` -A monadic with block. +A monadic `with` block. Where `x <- ...` performs a monadic operation. -Expects `Result` to be a type defined with `type` and a function `Result/bind` to be defined. +Expects `Result` to be a type defined with `type` or `object` and the function `Result/bind` to be defined. The monadic bind function should be of type `(Result a) -> (a -> Result b) -> Result b`, like this: ```python def Result/bind(res, nxt): match res: case Result/Ok: + nxt = undefer(nxt) return nxt(res.value) case Result/Err: return res ``` +However, the second argument, `nxt`, is actually a deferred call to the continuation, passing any free variables as arguments. +Therefore, all `bind` functions must call the builtin function `undefer` before using the value of `nxt`, as in the example above. +This is necessary to ensure that the continuation in recursive monadic functions stays lazy and doesn't expand infinitely. + +This is an example of a recursive function that would loop if passing the variable `a` to the recursive call `Result/foo(a, b)` was not deferred: +```python +def Result/foo(x, y): + with Result: + a <- Result/Ok(1) + if b: + b = Result/Err(x) + else: + b = Result/Ok(y) + b <- b + return Result/foo(a, b) +``` + Other statements are allowed inside the `with` block and it can both return a value at the end and bind a variable, like branching statements do. ```python @@ -389,7 +407,7 @@ return y ``` The name `wrap` is bound inside a `with` block as a shorthand for `Type/wrap`, -the equivalent as a `pure` function in other functional languages: +and it calls the unit function of the monad, also called `pure` in some languages: ```python def Result/wrap(x): @@ -968,8 +986,8 @@ match x { ### With block ```rust -Result/bind (Result/Ok val) f = (f val) -Result/bind err _ = err +Result/bind (Result/Ok val) nxt = (nxt val) +Result/bind err _nxt = err div a b = switch b { 0: (Result/Err "Div by 0") @@ -991,6 +1009,23 @@ Main = with Result { Receives a type defined with `type` and expects `Result/bind` to be defined as a monadic bind function. It should be of type `(Result a) -> (a -> Result b) -> Result b`, like in the example above. +However, the second argument, `nxt`, is actually a deferred call to the continuation, passing any free variables as arguments. +Therefore, all `bind` functions must call the builtin function `undefer` before using the value of `nxt`, as in the example above. +This is necessary to ensure that the continuation in recursive monadic functions stays lazy and doesn't expand infinitely. + +This is an example of a recursive function that would loop if passing the variable `a` to the recursive call `Result/foo(a, b)` was not deferred: +```python +Result/foo x y = with Result { + ask a = (Result/Ok 1) + ask b = if b { + (Result/Err x) + } else { + (Result/Ok y) + } + (Result/foo a b) +} +``` + Inside a `with` block, you can use `ask`, to access the continuation value of the monadic operation. ```rust diff --git a/src/fun/builtins.bend b/src/fun/builtins.bend index f9ec3d4f..d198887e 100644 --- a/src/fun/builtins.bend +++ b/src/fun/builtins.bend @@ -73,6 +73,7 @@ def IO/wrap(x): def IO/bind(a, b): match a: case IO/Done: + b = undefer(b) return b(a.expr) case IO/Call: return IO/Call(IO/MAGIC, a.func, a.argm, lambda x: IO/bind(a.cont(x), b)) @@ -89,3 +90,15 @@ print text = (IO/Call IO/MAGIC "PUT_TEXT" text @x (IO/Done IO/MAGIC x)) get_time = (IO/Call IO/MAGIC "GET_TIME" * @x (IO/Done IO/MAGIC x)) sleep hi_lo = (IO/Call IO/MAGIC "PUT_TIME" hi_lo @x (IO/Done IO/MAGIC x)) + + +# Lazy thunks +# We can defer the evaluation of a function by wrapping it in a thunk +# Ex: @x (x @arg1 @arg2 @arg3 (f arg1 arg2 arg3) arg1 arg2 arg3) +# This is only evaluated when we call it with 'undefer' (undefer my_thunk) +# We can build a defered call directly or by by using defer and defer_arg +# The example above can be written as: +# (defer_arg (defer_arg (defer_arg (defer @arg1 @arg2 @arg3 (f arg1 arg2 arg3)) arg1) arg2) arg3) +defer val = @x (x val) +defer_arg defered arg = @x (defered x arg) +undefer defered = (defered @x x) \ No newline at end of file diff --git a/src/fun/mod.rs b/src/fun/mod.rs index 4a1ad00b..93259597 100644 --- a/src/fun/mod.rs +++ b/src/fun/mod.rs @@ -2,7 +2,6 @@ use crate::{ diagnostics::{Diagnostics, DiagnosticsConfig}, maybe_grow, multi_iterator, ENTRY_POINT, }; -// use hvmc::ast::get_typ; use indexmap::{IndexMap, IndexSet}; use interner::global::{GlobalPool, GlobalString}; use itertools::Itertools; diff --git a/src/fun/transform/desugar_do_blocks.rs b/src/fun/transform/desugar_with_blocks.rs similarity index 61% rename from src/fun/transform/desugar_do_blocks.rs rename to src/fun/transform/desugar_with_blocks.rs index fedc1b2f..a290ff59 100644 --- a/src/fun/transform/desugar_do_blocks.rs +++ b/src/fun/transform/desugar_with_blocks.rs @@ -1,20 +1,20 @@ use crate::{ diagnostics::Diagnostics, - fun::{Ctx, Name, Term}, + fun::{Ctx, Name, Pattern, Term}, maybe_grow, }; use std::collections::HashSet; impl Ctx<'_> { /// Converts `ask` terms inside `do` blocks into calls to a monadic bind operation. - pub fn desugar_do_blocks(&mut self) -> Result<(), Diagnostics> { + pub fn desugar_with_blocks(&mut self) -> Result<(), Diagnostics> { self.info.start_pass(); let def_names = self.book.defs.keys().cloned().collect::>(); for def in self.book.defs.values_mut() { for rule in def.rules.iter_mut() { - if let Err(e) = rule.body.desugar_do_blocks(None, &def_names) { + if let Err(e) = rule.body.desugar_with_blocks(None, &def_names) { self.info.add_rule_error(e, def.name.clone()); } } @@ -25,14 +25,14 @@ impl Ctx<'_> { } impl Term { - pub fn desugar_do_blocks( + pub fn desugar_with_blocks( &mut self, cur_block: Option<&Name>, def_names: &HashSet, ) -> Result<(), String> { maybe_grow(|| { if let Term::With { typ, bod } = self { - bod.desugar_do_blocks(Some(typ), def_names)?; + bod.desugar_with_blocks(Some(typ), def_names)?; let wrap_ref = Term::r#ref(&format!("{typ}/wrap")); // let wrap_ref = if def_names.contains(&wrap_nam) { // Term::r#ref(&wrap_nam) @@ -47,8 +47,9 @@ impl Term { let bind_nam = Name::new(format!("{typ}/bind")); if def_names.contains(&bind_nam) { - // TODO: come up with a strategy for forwarding free vars to prevent infinite recursion. let nxt = Term::lam(*pat.clone(), std::mem::take(nxt)); + let nxt = nxt.defer(); + *self = Term::call(Term::Ref { nam: bind_nam }, [*val.clone(), nxt]); } else { return Err(format!("Could not find definition {bind_nam} for type {typ}.")); @@ -59,10 +60,25 @@ impl Term { } for children in self.children_mut() { - children.desugar_do_blocks(cur_block, def_names)?; + children.desugar_with_blocks(cur_block, def_names)?; } Ok(()) }) } + + /// Converts a term with free vars `(f x1 .. xn)` into a deferred + /// call that passes those vars to the term. + /// + /// Ex: `(f x1 .. xn)` becomes `@x (x @x1 .. @xn (f x1 .. x2) x1 .. x2)`. + /// + /// The user must call this lazy thunk by calling the builtin + /// `undefer` function, or by applying `@x x` to the term. + fn defer(self) -> Term { + let free_vars = self.free_vars().into_keys().collect::>(); + let term = Term::rfold_lams(self, free_vars.iter().cloned().map(Some)); + let term = Term::call(Term::Var { nam: Name::new("%x") }, [term]); + let term = Term::call(term, free_vars.iter().cloned().map(|nam| Term::Var { nam })); + Term::lam(Pattern::Var(Some(Name::new("%x"))), term) + } } diff --git a/src/fun/transform/mod.rs b/src/fun/transform/mod.rs index 90456baf..e4d4ced2 100644 --- a/src/fun/transform/mod.rs +++ b/src/fun/transform/mod.rs @@ -2,11 +2,11 @@ pub mod apply_args; pub mod definition_merge; pub mod definition_pruning; pub mod desugar_bend; -pub mod desugar_do_blocks; pub mod desugar_fold; pub mod desugar_match_defs; pub mod desugar_open; pub mod desugar_use; +pub mod desugar_with_blocks; pub mod encode_adts; pub mod encode_match_terms; pub mod expand_generated; diff --git a/src/lib.rs b/src/lib.rs index 4a5b91a0..42c15705 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -108,7 +108,7 @@ pub fn desugar_book( ctx.desugar_bend()?; ctx.desugar_fold()?; - ctx.desugar_do_blocks()?; + ctx.desugar_with_blocks()?; ctx.check_unbound_vars()?; @@ -305,10 +305,10 @@ impl OptLevel { #[derive(Clone, Debug)] pub struct CompileOpts { - /// Enables [hvmc::transform::eta_reduce]. + /// Enables [hvm::eta_reduce]. pub eta: bool, - /// Enables [fun::transform::definition_pruning] and [hvmc_net::prune]. + /// Enables [fun::transform::definition_pruning] and [hvm::prune]. pub prune: bool, /// Enables [fun::transform::linearize_matches]. @@ -320,7 +320,7 @@ pub struct CompileOpts { /// Enables [fun::transform::definition_merge] pub merge: bool, - /// Enables [hvmc::transform::inline]. + /// Enables [hvm::inline]. pub inline: bool, /// Enables [hvm::check_net_size]. diff --git a/src/main.rs b/src/main.rs index 4f27c9f0..35d61b64 100644 --- a/src/main.rs +++ b/src/main.rs @@ -49,7 +49,7 @@ enum Mode { RunC(RunArgs), /// Compiles the program and runs it with the Cuda HVM implementation. RunCu(RunArgs), - /// Compiles the program to hvmc and prints to stdout. + /// Compiles the program to hvm and prints to stdout. GenHvm(GenArgs), /// Compiles the program to standalone C and prints to stdout. GenC(GenArgs), diff --git a/tests/golden_tests.rs b/tests/golden_tests.rs index a5394502..ef313780 100644 --- a/tests/golden_tests.rs +++ b/tests/golden_tests.rs @@ -235,7 +235,7 @@ fn simplify_matches() { ctx.fix_match_terms()?; ctx.desugar_bend()?; ctx.desugar_fold()?; - ctx.desugar_do_blocks()?; + ctx.desugar_with_blocks()?; ctx.check_unbound_vars()?; ctx.book.make_var_names_unique(); ctx.book.linearize_match_binds(); @@ -284,7 +284,7 @@ fn encode_pattern_match() { ctx.fix_match_terms()?; ctx.desugar_bend()?; ctx.desugar_fold()?; - ctx.desugar_do_blocks()?; + ctx.desugar_with_blocks()?; ctx.check_unbound_vars()?; ctx.book.make_var_names_unique(); ctx.book.linearize_match_binds(); diff --git a/tests/golden_tests/desugar_file/bind_syntax.bend b/tests/golden_tests/desugar_file/bind_syntax.bend index 885a5367..ee706c23 100644 --- a/tests/golden_tests/desugar_file/bind_syntax.bend +++ b/tests/golden_tests/desugar_file/bind_syntax.bend @@ -1,6 +1,6 @@ type Result = (Ok val) | (Err val) -Result/bind (Result/Ok val) f = (f val) +Result/bind (Result/Ok val) f = ((undefer f) val) Result/bind err _ = err safe_div a b = switch b { diff --git a/tests/golden_tests/run_file/do_block_mixed.bend b/tests/golden_tests/run_file/do_block_mixed.bend index 1dd9cbfc..20948481 100644 --- a/tests/golden_tests/run_file/do_block_mixed.bend +++ b/tests/golden_tests/run_file/do_block_mixed.bend @@ -2,7 +2,7 @@ type Result = (Ok val) | (Err val) Result/bind r nxt = match r { - Result/Ok: (nxt r.val) + Result/Ok: ((undefer nxt) r.val) Result/Err: r } diff --git a/tests/golden_tests/run_file/recursive_bind.bend b/tests/golden_tests/run_file/recursive_bind.bend index 18e6c7e0..c753ff58 100644 --- a/tests/golden_tests/run_file/recursive_bind.bend +++ b/tests/golden_tests/run_file/recursive_bind.bend @@ -1,6 +1,6 @@ type Result = (Ok val) | (Err val) -Result/bind (Result/Ok val) f = (f val) +Result/bind (Result/Ok val) f = ((undefer f) val) Result/bind err _ = err Bar x = (Result/Err 0) diff --git a/tests/golden_tests/run_file/strict_monad_fn.bend b/tests/golden_tests/run_file/strict_monad_fn.bend new file mode 100644 index 00000000..f63c7564 --- /dev/null +++ b/tests/golden_tests/run_file/strict_monad_fn.bend @@ -0,0 +1,15 @@ +# This will only work if we make the call to `(Result/foo a b)` lazy (by converting it to a combinator). +type Result = (Ok val) | (Err val) + +Result/bind = @val @nxt match val { + Result/Ok: ((undefer nxt) val.val) + Result/Err: (Result/Err val.val) +} +Result/foo x y = + with Result { + ask a = (Result/Ok x) + ask b = switch y { 0: (Result/Err a); _: (Result/Ok y-1) } + (Result/foo a b) + } + +main = (Result/foo 1 2) \ No newline at end of file diff --git a/tests/snapshots/desugar_file__bind_syntax.bend.snap b/tests/snapshots/desugar_file__bind_syntax.bend.snap index 8071bb82..083a10aa 100644 --- a/tests/snapshots/desugar_file__bind_syntax.bend.snap +++ b/tests/snapshots/desugar_file__bind_syntax.bend.snap @@ -2,13 +2,15 @@ source: tests/golden_tests.rs input_file: tests/golden_tests/desugar_file/bind_syntax.bend --- +(undefer) = λa (a λb b) + (Result/bind) = λa λb (a Result/bind__C2 b) (safe_div) = λa λb (switch b { 0: λ* (Result/Err (String/Cons 68 (String/Cons 105 (String/Cons 118 (String/Cons 32 (String/Cons 98 (String/Cons 121 (String/Cons 32 (String/Cons 48 String/Nil))))))))); _: safe_div__C0; } a) (safe_rem) = λa λb (switch b { 0: λ* (Result/Err (String/Cons 77 (String/Cons 111 (String/Cons 100 (String/Cons 32 (String/Cons 98 (String/Cons 121 (String/Cons 32 (String/Cons 48 String/Nil))))))))); _: safe_rem__C0; } a) -(Main) = (Result/bind Main__C1 Main__C0) +(Main) = (Result/bind Main__C3 Main__C2) (String/Nil) = λa (a String/Nil/tag) @@ -26,11 +28,15 @@ input_file: tests/golden_tests/desugar_file/bind_syntax.bend (Result/Err/tag) = 1 -(Main__C0) = λa (Result/bind (safe_rem a 0) λb b) +(Main__C0) = λa (a λb b) -(Main__C1) = (safe_div 3 2) +(Main__C1) = λa (Result/bind (safe_rem a 0) Main__C0) -(Result/bind__C0) = λa λb (b a) +(Main__C2) = λa (a Main__C1) + +(Main__C3) = (safe_div 3 2) + +(Result/bind__C0) = λa λb (undefer b a) (Result/bind__C1) = λ* λa λ* (Result/Err a) diff --git a/tests/snapshots/run_file__recursive_bind.bend.snap b/tests/snapshots/run_file__recursive_bind.bend.snap index 3ea02af3..747c8c20 100644 --- a/tests/snapshots/run_file__recursive_bind.bend.snap +++ b/tests/snapshots/run_file__recursive_bind.bend.snap @@ -2,30 +2,8 @@ source: tests/golden_tests.rs input_file: tests/golden_tests/run_file/recursive_bind.bend --- -Errors: -The following functions contain recursive cycles incompatible with HVM's strict evaluation: - * Foo -> Foo +NumScott: +λa (a Result/Err/tag 0) -The greedy eager evaluation of HVM may cause infinite loops. -Refactor these functions to use lazy references instead of direct function calls. -A reference is strict when it's being called ('(Foo x)') or when it's used non-linearly ('let x = Foo; (x x)'). -It is lazy when it's an argument ('(x Foo)') or when it's used linearly ('let x = Foo; (x 0)'). - -Try one of these strategies: -- Use pattern matching with 'match', 'fold', and 'bend' to automatically lift expressions to lazy references. -- Replace direct calls with combinators. For example, change: - 'Foo = λa λb (b (λc (Foo a c)) a)' - to: - 'Foo = λa λb (b (λc λa (Foo a c)) (λa a) a)' - which is lifted to: - 'Foo = λa λb (b Foo__C1 Foo__C2 a)' -- Replace non-linear 'let' expressions with 'use' expressions. For example, change: - 'Foo = λf let x = Foo; (f x x)' - to: - 'Foo = λf use x = Foo; (f x x)' - which inlines to: - 'Foo = λf (f Foo Foo)' -- If disabled, re-enable the default 'float-combinators' and 'linearize-matches' compiler options. - -For more information, visit: https://github.com/HigherOrderCO/Bend/blob/main/docs/lazy-definitions.md. -To disable this check, use the "-Arecursion-cycle" compiler option. +Scott: +λ* λa (a 0) diff --git a/tests/snapshots/run_file__strict_monad_fn.bend.snap b/tests/snapshots/run_file__strict_monad_fn.bend.snap new file mode 100644 index 00000000..c5e0607d --- /dev/null +++ b/tests/snapshots/run_file__strict_monad_fn.bend.snap @@ -0,0 +1,9 @@ +--- +source: tests/golden_tests.rs +input_file: tests/golden_tests/run_file/strict_monad_fn.bend +--- +NumScott: +λa (a Result/Err/tag 1) + +Scott: +λ* λa (a 1)