diff --git a/rust/ares/src/interpreter.rs b/rust/ares/src/interpreter.rs index 87e69c2..67dd162 100644 --- a/rust/ares/src/interpreter.rs +++ b/rust/ares/src/interpreter.rs @@ -9,6 +9,7 @@ use crate::jets::warm::Warm; use crate::jets::JetErr; use crate::mem::unifying_equality; use crate::mem::NockStack; +use crate::mem::Preserve; use crate::newt::Newt; use crate::noun; use crate::noun::{Atom, Cell, IndirectAtom, Noun, Slots, D, T}; @@ -282,6 +283,26 @@ impl Context { self.cold = saved.cold; self.warm = saved.warm; } + + /** + * For jets that need a stack frame internally. + * + * This ensures that the frame is cleaned up even if the closure short-circuites to an error + * result using e.g. the ? syntax. We need this method separately from with_frame to allow the + * jet to use the entire context without the borrow checker complaining about the mutable + * references. + */ + pub unsafe fn with_stack_frame(&mut self, slots: usize, f: F) -> O + where + F: FnOnce(&mut Context) -> O, + O: Preserve, + { + self.stack.frame_push(slots); + let mut ret = f(self); + ret.preserve(&mut self.stack); + self.stack.frame_pop(); + ret + } } #[derive(Clone, Copy, Debug)] @@ -292,6 +313,26 @@ pub enum Error { NonDeterministic(Noun), // trace } +impl Preserve for Error { + unsafe fn preserve(&mut self, stack: &mut NockStack) { + match self { + Error::ScryBlocked(ref mut path) => path.preserve(stack), + Error::ScryCrashed(ref mut trace) => trace.preserve(stack), + Error::Deterministic(ref mut trace) => trace.preserve(stack), + Error::NonDeterministic(ref mut trace) => trace.preserve(stack), + } + } + + unsafe fn assert_in_stack(&self, stack: &NockStack) { + match self { + Error::ScryBlocked(ref path) => path.assert_in_stack(stack), + Error::ScryCrashed(ref trace) => trace.assert_in_stack(stack), + Error::Deterministic(ref trace) => trace.assert_in_stack(stack), + Error::NonDeterministic(ref trace) => trace.assert_in_stack(stack), + } + } +} + impl From for Error { fn from(_: noun::Error) -> Self { Error::Deterministic(D(0)) diff --git a/rust/ares/src/jets.rs b/rust/ares/src/jets.rs index 9f3b4b2..5bff6b1 100644 --- a/rust/ares/src/jets.rs +++ b/rust/ares/src/jets.rs @@ -34,7 +34,7 @@ use crate::jets::sort::*; use crate::jets::tree::*; use crate::jets::warm::Warm; -use crate::mem::NockStack; +use crate::mem::{NockStack, Preserve}; use crate::newt::Newt; use crate::noun::{self, Noun, Slots, D}; use ares_macros::tas; @@ -55,6 +55,22 @@ pub enum JetErr { Fail(Error), // Error; do not retry } +impl Preserve for JetErr { + unsafe fn preserve(&mut self, stack: &mut NockStack) { + match self { + JetErr::Punt => {} + JetErr::Fail(ref mut err) => err.preserve(stack), + } + } + + unsafe fn assert_in_stack(&self, stack: &NockStack) { + match self { + JetErr::Punt => {} + JetErr::Fail(ref err) => err.assert_in_stack(stack), + } + } +} + impl From for JetErr { fn from(err: Error) -> Self { Self::Fail(err) diff --git a/rust/ares/src/mem.rs b/rust/ares/src/mem.rs index f53b63b..9ba1997 100644 --- a/rust/ares/src/mem.rs +++ b/rust/ares/src/mem.rs @@ -663,10 +663,12 @@ impl NockStack { */ pub unsafe fn with_frame(&mut self, num_locals: usize, f: F) -> O where - F: FnOnce() -> O, + F: FnOnce(&mut NockStack) -> O, + O: Preserve, { self.frame_push(num_locals); - let ret = f(); + let mut ret = f(self); + ret.preserve(self); self.frame_pop(); ret } @@ -1133,3 +1135,19 @@ impl Stack for NockStack { self.layout_alloc(layout) } } + +impl Preserve for Result { + unsafe fn preserve(&mut self, stack: &mut NockStack) { + match self.as_mut() { + Ok(t_ref) => t_ref.preserve(stack), + Err(e_ref) => e_ref.preserve(stack), + } + } + + unsafe fn assert_in_stack(&self, stack: &NockStack) { + match self.as_ref() { + Ok(t_ref) => t_ref.assert_in_stack(stack), + Err(e_ref) => e_ref.assert_in_stack(stack), + } + } +}