diff --git a/engine/runtime/src/bench/scala/org/enso/interpreter/bench/fixtures/semantic/RecursionFixtures.scala b/engine/runtime/src/bench/scala/org/enso/interpreter/bench/fixtures/semantic/RecursionFixtures.scala index 054a3533cf..cab0badefa 100644 --- a/engine/runtime/src/bench/scala/org/enso/interpreter/bench/fixtures/semantic/RecursionFixtures.scala +++ b/engine/runtime/src/bench/scala/org/enso/interpreter/bench/fixtures/semantic/RecursionFixtures.scala @@ -50,14 +50,13 @@ class RecursionFixtures extends DefaultInterpreterRunner { val sumStateTCOCode = """ - |main = sumTo -> - | stateSum = n -> - | acc = State.get - | State.put (acc + n) - | if n == 0 then State.get else stateSum (n - 1) + |stateSum = n -> + | acc = State.get Number + | State.put Number (acc + n) + | if n == 0 then State.get Number else here.stateSum (n - 1) | - | State.put 0 - | res = stateSum sumTo + |main = sumTo -> + | res = State.run Number 0 (here.stateSum sumTo) | res |""".stripMargin val sumStateTCO = getMain(sumStateTCOCode) @@ -75,18 +74,17 @@ class RecursionFixtures extends DefaultInterpreterRunner { val nestedThunkSumCode = """ + |doNTimes = n -> ~block -> + | block + | if n == 1 then Unit else here.doNTimes n-1 block + | |main = n -> - | doNTimes = n -> ~block -> - | block - | if n == 1 then Unit else doNTimes n-1 block - | | block = - | x = State.get - | State.put x+1 + | x = State.get Number + | State.put Number x+1 | - | State.put 0 - | doNTimes n block - | State.get + | res = State.run Number 0 (here.doNTimes n block) + | res |""".stripMargin val nestedThunkSum = getMain(nestedThunkSumCode) } diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/CaseNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/CaseNode.java index 15e0913014..8fcc4cfb3b 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/CaseNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/CaseNode.java @@ -50,7 +50,7 @@ public abstract class CaseNode extends ExpressionNode { /** * Forwards an error in the case's scrutinee. * - * It is important that this is the first specialization. + *

It is important that this is the first specialization. * * @param frame the stack frame in which to execute * @param error the error being matched against @@ -81,7 +81,8 @@ public abstract class CaseNode extends ExpressionNode { } CompilerDirectives.transferToInterpreter(); throw new PanicException( - ctx.get().getBuiltins().inexhaustivePatternMatchError().newInstance(object), this); + ctx.get().getBuiltins().error().inexhaustivePatternMatchError().newInstance(object), + this); } catch (BranchSelectedException e) { // Note [Branch Selection Control Flow] frame.setObject(getStateFrameSlot(), e.getResult().getState()); diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/state/GetStateNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/state/GetStateNode.java index 9b40fd4971..fdfbca0a79 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/state/GetStateNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/state/GetStateNode.java @@ -1,15 +1,71 @@ package org.enso.interpreter.node.expression.builtin.state; +import com.oracle.truffle.api.TruffleLanguage; +import com.oracle.truffle.api.dsl.*; import com.oracle.truffle.api.nodes.Node; +import org.enso.interpreter.Language; import org.enso.interpreter.dsl.BuiltinMethod; import org.enso.interpreter.dsl.MonadicState; +import org.enso.interpreter.runtime.Context; +import org.enso.interpreter.runtime.state.data.EmptyMap; +import org.enso.interpreter.runtime.state.data.SingletonMap; +import org.enso.interpreter.runtime.state.data.SmallMap; +import org.enso.interpreter.runtime.error.PanicException; @BuiltinMethod( type = "State", name = "get", description = "Returns the current value of monadic state.") -public class GetStateNode extends Node { - Object execute(@MonadicState Object state, Object _this) { - return state; +@ImportStatic(SmallMap.class) +@ReportPolymorphism +public abstract class GetStateNode extends Node { + static GetStateNode build() { + return GetStateNodeGen.create(); + } + + abstract Object execute(@MonadicState Object state, Object _this, Object key); + + @Specialization(guards = {"state.getKey() == key"}) + Object doSingleton(SingletonMap state, Object _this, Object key) { + return state.getValue(); + } + + @Specialization( + guards = {"state.getKeys() == cachedKeys", "key == cachedKey", "idx != NOT_FOUND"}) + Object doMultiCached( + SmallMap state, + Object _this, + Object key, + @Cached("key") Object cachedKey, + @Cached(value = "state.getKeys()", dimensions = 1) Object[] cachedKeys, + @Cached("state.indexOf(key)") int idx) { + return state.getValues()[idx]; + } + + @Specialization + Object doMultiUncached( + SmallMap state, + Object _this, + Object key, + @CachedContext(Language.class) TruffleLanguage.ContextReference ctxRef) { + int idx = state.indexOf(key); + if (idx == SmallMap.NOT_FOUND) { + throw new PanicException( + ctxRef.get().getBuiltins().error().unitializedState().newInstance(key), this); + } else { + return state.getValues()[idx]; + } + } + + @Specialization + Object doEmpty( + EmptyMap state, Object _this, Object key, @CachedContext(Language.class) Context ctx) { + throw new PanicException(ctx.getBuiltins().error().unitializedState().newInstance(key), this); + } + + @Specialization + Object doSingletonError( + SingletonMap state, Object _this, Object key, @CachedContext(Language.class) Context ctx) { + throw new PanicException(ctx.getBuiltins().error().unitializedState().newInstance(key), this); } } diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/state/PutStateNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/state/PutStateNode.java index 7e4133e1a4..e138b1871d 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/state/PutStateNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/state/PutStateNode.java @@ -1,13 +1,72 @@ package org.enso.interpreter.node.expression.builtin.state; +import com.oracle.truffle.api.TruffleLanguage; +import com.oracle.truffle.api.dsl.*; import com.oracle.truffle.api.nodes.Node; +import org.enso.interpreter.Language; import org.enso.interpreter.dsl.BuiltinMethod; import org.enso.interpreter.dsl.MonadicState; +import org.enso.interpreter.runtime.Context; +import org.enso.interpreter.runtime.state.data.SingletonMap; +import org.enso.interpreter.runtime.state.data.SmallMap; +import org.enso.interpreter.runtime.error.PanicException; import org.enso.interpreter.runtime.state.Stateful; @BuiltinMethod(type = "State", name = "put", description = "Updates the value of monadic state.") -public class PutStateNode extends Node { - Stateful execute(@MonadicState Object state, Object _this, Object new_state) { - return new Stateful(new_state, state); +@ImportStatic(SmallMap.class) +@ReportPolymorphism +public abstract class PutStateNode extends Node { + static PutStateNode build() { + return PutStateNodeGen.create(); + } + + abstract Stateful execute(@MonadicState Object state, Object _this, Object key, Object new_state); + + @Specialization(guards = "state.getKey() == key") + Stateful doExistingSingleton(SingletonMap state, Object _this, Object key, Object new_state) { + return new Stateful(new SingletonMap(key, new_state), new_state); + } + + @Specialization( + guards = {"state.getKeys() == cachedKeys", "index != NOT_FOUND", "key == cachedKey"}) + Stateful doExistingMultiCached( + SmallMap state, + Object _this, + Object key, + Object new_state, + @Cached("key") Object cachedKey, + @Cached(value = "state.getKeys()", dimensions = 1) Object[] cachedKeys, + @Cached("state.indexOf(key)") int index) { + Object[] newVals = new Object[cachedKeys.length]; + System.arraycopy(state.getValues(), 0, newVals, 0, cachedKeys.length); + newVals[index] = new_state; + SmallMap newStateMap = new SmallMap(cachedKeys, newVals); + return new Stateful(newStateMap, new_state); + } + + @Specialization + Stateful doMultiUncached( + SmallMap state, + Object _this, + Object key, + Object new_state, + @CachedContext(Language.class) TruffleLanguage.ContextReference ctxRef) { + int index = state.indexOf(key); + if (index == SmallMap.NOT_FOUND) { + throw new PanicException( + ctxRef.get().getBuiltins().error().unitializedState().newInstance(key), this); + } else { + return doExistingMultiCached(state, _this, key, new_state, key, state.getKeys(), index); + } + } + + @Specialization + Stateful doError( + Object state, + Object _this, + Object key, + Object new_state, + @CachedContext(Language.class) Context ctx) { + throw new PanicException(ctx.getBuiltins().error().unitializedState().newInstance(key), this); } } diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/state/RunStateNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/state/RunStateNode.java index 451f1f99e8..d56e8324e8 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/state/RunStateNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/state/RunStateNode.java @@ -1,18 +1,163 @@ package org.enso.interpreter.node.expression.builtin.state; +import com.oracle.truffle.api.dsl.Cached; +import com.oracle.truffle.api.dsl.ImportStatic; +import com.oracle.truffle.api.dsl.ReportPolymorphism; +import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.nodes.Node; import org.enso.interpreter.dsl.BuiltinMethod; +import org.enso.interpreter.dsl.MonadicState; import org.enso.interpreter.node.callable.thunk.ThunkExecutorNode; import org.enso.interpreter.runtime.callable.argument.Thunk; +import org.enso.interpreter.runtime.state.data.EmptyMap; +import org.enso.interpreter.runtime.state.data.SingletonMap; +import org.enso.interpreter.runtime.state.data.SmallMap; +import org.enso.interpreter.runtime.state.Stateful; @BuiltinMethod( type = "State", name = "run", description = "Runs a stateful computation in a local state environment.") -public class RunStateNode extends Node { +@ReportPolymorphism +@ImportStatic(SmallMap.class) +public abstract class RunStateNode extends Node { + static RunStateNode build() { + return RunStateNodeGen.create(); + } + private @Child ThunkExecutorNode thunkExecutorNode = ThunkExecutorNode.build(); - Object execute(Object _this, Object local_state, Thunk computation) { - return thunkExecutorNode.executeThunk(computation, local_state, false).getValue(); + abstract Stateful execute( + @MonadicState Object state, Object _this, Object key, Object local_state, Thunk computation); + + @Specialization + Stateful doEmpty( + EmptyMap state, Object _this, Object key, Object local_state, Thunk computation) { + SingletonMap localStateMap = new SingletonMap(key, local_state); + Object result = thunkExecutorNode.executeThunk(computation, localStateMap, false).getValue(); + return new Stateful(state, result); + } + + @Specialization(guards = {"state.getKey() == key"}) + Stateful doSingletonSameKey( + SingletonMap state, Object _this, Object key, Object local_state, Thunk computation) { + SingletonMap localStateContainer = new SingletonMap(state.getKey(), local_state); + Stateful res = thunkExecutorNode.executeThunk(computation, localStateContainer, false); + return new Stateful(state, res.getValue()); + } + + @Specialization( + guards = { + "key == cachedNewKey", + "state.getKey() == cachedOldKey", + "cachedOldKey != cachedNewKey" + }) + Stateful doSingletonNewKeyCached( + SingletonMap state, + Object _this, + Object key, + Object local_state, + Thunk computation, + @Cached("key") Object cachedNewKey, + @Cached("state.getKey()") Object cachedOldKey, + @Cached(value = "buildSmallKeys(cachedNewKey, cachedOldKey)", dimensions = 1) + Object[] newKeys) { + SmallMap localStateMap = new SmallMap(newKeys, new Object[] {local_state, state.getValue()}); + Stateful res = thunkExecutorNode.executeThunk(computation, localStateMap, false); + Object newStateVal = ((SmallMap) res.getState()).getValues()[1]; + return new Stateful(new SingletonMap(cachedOldKey, newStateVal), res.getValue()); + } + + @Specialization + Stateful doSingletonNewKeyUncached( + SingletonMap state, Object _this, Object key, Object local_state, Thunk computation) { + return doSingletonNewKeyCached( + state, + _this, + key, + local_state, + computation, + key, + state.getKey(), + buildSmallKeys(key, state.getKey())); + } + + Object[] buildSmallKeys(Object k1, Object k2) { + return new Object[] {k1, k2}; + } + + @Specialization( + guards = {"key == cachedNewKey", "state.getKeys() == cachedOldKeys", "index == NOT_FOUND"}) + Stateful doMultiNewKeyCached( + SmallMap state, + Object _this, + Object key, + Object local_state, + Thunk computation, + @Cached("key") Object cachedNewKey, + @Cached(value = "state.getKeys()", dimensions = 1) Object[] cachedOldKeys, + @Cached("state.indexOf(key)") int index, + @Cached(value = "buildNewKeys(cachedNewKey, cachedOldKeys)", dimensions = 1) + Object[] newKeys) { + Object[] newValues = new Object[newKeys.length]; + System.arraycopy(state.getValues(), 0, newValues, 1, cachedOldKeys.length); + newValues[0] = local_state; + SmallMap localStateMap = new SmallMap(newKeys, newValues); + Stateful res = thunkExecutorNode.executeThunk(computation, localStateMap, false); + SmallMap resultStateMap = (SmallMap) res.getState(); + Object[] resultValues = new Object[cachedOldKeys.length]; + System.arraycopy(resultStateMap.getValues(), 1, resultValues, 0, cachedOldKeys.length); + return new Stateful(new SmallMap(cachedOldKeys, resultValues), res.getValue()); + } + + @Specialization( + guards = {"key == cachedNewKey", "state.getKeys() == cachedOldKeys", "index != NOT_FOUND"}) + Stateful doMultiExistingKeyCached( + SmallMap state, + Object _this, + Object key, + Object local_state, + Thunk computation, + @Cached("key") Object cachedNewKey, + @Cached(value = "state.getKeys()", dimensions = 1) Object[] cachedOldKeys, + @Cached("state.indexOf(key)") int index) { + Object[] newValues = new Object[cachedOldKeys.length]; + System.arraycopy(state.getValues(), 0, newValues, 0, cachedOldKeys.length); + newValues[index] = local_state; + SmallMap localStateMap = new SmallMap(cachedOldKeys, newValues); + Stateful res = thunkExecutorNode.executeThunk(computation, localStateMap, false); + SmallMap resultStateMap = (SmallMap) res.getState(); + Object[] resultValues = new Object[cachedOldKeys.length]; + System.arraycopy(resultStateMap.getValues(), 0, resultValues, 0, cachedOldKeys.length); + resultValues[index] = state.getValues()[index]; + return new Stateful(new SmallMap(cachedOldKeys, resultValues), res.getValue()); + } + + @Specialization + Stateful doMultiUncached( + SmallMap state, Object _this, Object key, Object local_state, Thunk computation) { + int idx = state.indexOf(key); + if (idx == SmallMap.NOT_FOUND) { + return doMultiNewKeyCached( + state, + _this, + key, + local_state, + computation, + key, + state.getKeys(), + idx, + buildNewKeys(key, state.getKeys())); + } else { + return doMultiExistingKeyCached( + state, _this, key, local_state, computation, key, state.getKeys(), idx); + } + } + + Object[] buildNewKeys(Object newKey, Object[] oldKeys) { + Object[] result = new Object[oldKeys.length + 1]; + System.arraycopy(oldKeys, 0, result, 1, oldKeys.length); + result[0] = newKey; + return result; } } diff --git a/engine/runtime/src/main/java/org/enso/interpreter/runtime/builtin/Builtins.java b/engine/runtime/src/main/java/org/enso/interpreter/runtime/builtin/Builtins.java index 51dcdea27c..c02f3cd258 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/runtime/builtin/Builtins.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/runtime/builtin/Builtins.java @@ -47,9 +47,7 @@ public class Builtins { private final AtomConstructor function; private final AtomConstructor text; private final AtomConstructor debug; - private final AtomConstructor syntaxError; - private final AtomConstructor compileError; - private final AtomConstructor inexhaustivePatternMatchError; + private final Error error; private final Bool bool; private final RootCallTarget interopDispatchRoot; @@ -70,21 +68,10 @@ public class Builtins { any = new AtomConstructor("Any", scope).initializeFields(); number = new AtomConstructor("Number", scope).initializeFields(); bool = new Bool(language, scope); + error = new Error(language, scope); function = new AtomConstructor("Function", scope).initializeFields(); text = new AtomConstructor("Text", scope).initializeFields(); debug = new AtomConstructor("Debug", scope).initializeFields(); - syntaxError = - new AtomConstructor("Syntax_Error", scope) - .initializeFields( - new ArgumentDefinition(0, "message", ArgumentDefinition.ExecutionMode.EXECUTE)); - compileError = - new AtomConstructor("Compile_Error", scope) - .initializeFields( - new ArgumentDefinition(0, "message", ArgumentDefinition.ExecutionMode.EXECUTE)); - inexhaustivePatternMatchError = - new AtomConstructor("Inexhaustive_Pattern_Match_Error", scope) - .initializeFields( - new ArgumentDefinition(0, "scrutinee", ArgumentDefinition.ExecutionMode.EXECUTE)); AtomConstructor nil = new AtomConstructor("Nil", scope).initializeFields(); AtomConstructor cons = @@ -118,9 +105,6 @@ public class Builtins { scope.registerConstructor(system); scope.registerConstructor(runtime); - scope.registerConstructor(syntaxError); - scope.registerConstructor(compileError); - scope.registerConstructor(java); scope.registerConstructor(thread); @@ -234,6 +218,11 @@ public class Builtins { return bool; } + /** @return the builtin Error types container. */ + public Error error() { + return error; + } + /** * Returns the {@code Any} atom constructor. * @@ -252,21 +241,6 @@ public class Builtins { return debug; } - /** @return the builtin {@code Syntax_Error} atom constructor. */ - public AtomConstructor syntaxError() { - return syntaxError; - } - - /** @return the builtin {@code Compile_Error} atom constructor. */ - public AtomConstructor compileError() { - return compileError; - } - - /** @return the builtin {@code Inexhaustive_Pattern_Match_Error} atom constructor. */ - public AtomConstructor inexhaustivePatternMatchError() { - return inexhaustivePatternMatchError; - } - /** * Returns the builtin module scope. * diff --git a/engine/runtime/src/main/java/org/enso/interpreter/runtime/builtin/Error.java b/engine/runtime/src/main/java/org/enso/interpreter/runtime/builtin/Error.java new file mode 100644 index 0000000000..d32f572f3c --- /dev/null +++ b/engine/runtime/src/main/java/org/enso/interpreter/runtime/builtin/Error.java @@ -0,0 +1,66 @@ +package org.enso.interpreter.runtime.builtin; + +import org.enso.interpreter.Language; +import org.enso.interpreter.runtime.callable.argument.ArgumentDefinition; +import org.enso.interpreter.runtime.callable.atom.AtomConstructor; +import org.enso.interpreter.runtime.scope.ModuleScope; + +/** + * Container for builtin Error types + */ +public class Error { + private final AtomConstructor syntaxError; + private final AtomConstructor compileError; + private final AtomConstructor inexhaustivePatternMatchError; + private final AtomConstructor unitializedState; + + /** + * Creates and registers the relevant constructors. + * + * @param language the current language instance. + * @param scope the scope to register constructors in. + */ + public Error(Language language, ModuleScope scope) { + syntaxError = + new AtomConstructor("Syntax_Error", scope) + .initializeFields( + new ArgumentDefinition(0, "message", ArgumentDefinition.ExecutionMode.EXECUTE)); + compileError = + new AtomConstructor("Compile_Error", scope) + .initializeFields( + new ArgumentDefinition(0, "message", ArgumentDefinition.ExecutionMode.EXECUTE)); + inexhaustivePatternMatchError = + new AtomConstructor("Inexhaustive_Pattern_Match_Error", scope) + .initializeFields( + new ArgumentDefinition(0, "scrutinee", ArgumentDefinition.ExecutionMode.EXECUTE)); + unitializedState = + new AtomConstructor("Uninitialized_State", scope) + .initializeFields( + new ArgumentDefinition(0, "key", ArgumentDefinition.ExecutionMode.EXECUTE)); + + scope.registerConstructor(syntaxError); + scope.registerConstructor(compileError); + scope.registerConstructor(inexhaustivePatternMatchError); + scope.registerConstructor(unitializedState); + } + + /** @return the builtin {@code Syntax_Error} atom constructor. */ + public AtomConstructor syntaxError() { + return syntaxError; + } + + /** @return the builtin {@code Compile_Error} atom constructor. */ + public AtomConstructor compileError() { + return compileError; + } + + /** @return the builtin {@code Inexhaustive_Pattern_Match_Error} atom constructor. */ + public AtomConstructor inexhaustivePatternMatchError() { + return inexhaustivePatternMatchError; + } + + /** @return the builtin {@code Uninitialized_State} atom constructor. */ + public AtomConstructor unitializedState() { + return unitializedState; + } +} diff --git a/engine/runtime/src/main/java/org/enso/interpreter/runtime/callable/function/Function.java b/engine/runtime/src/main/java/org/enso/interpreter/runtime/callable/function/Function.java index b46fb30e89..7b822cd3ff 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/runtime/callable/function/Function.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/runtime/callable/function/Function.java @@ -25,6 +25,7 @@ import org.enso.interpreter.runtime.Context; import org.enso.interpreter.runtime.callable.CallerInfo; import org.enso.interpreter.runtime.callable.argument.ArgumentDefinition; import org.enso.interpreter.runtime.callable.argument.Thunk; +import org.enso.interpreter.runtime.state.data.EmptyMap; import org.enso.interpreter.runtime.data.Vector; import org.enso.interpreter.runtime.type.Types; import org.enso.polyglot.MethodNames; @@ -198,7 +199,7 @@ public final class Function implements TruffleObject { Object[] arguments, @Cached InteropApplicationNode interopApplicationNode, @CachedContext(Language.class) Context context) { - return interopApplicationNode.execute(function, context.getBuiltins().unit(), arguments); + return interopApplicationNode.execute(function, EmptyMap.create(), arguments); } } diff --git a/engine/runtime/src/main/java/org/enso/interpreter/runtime/state/data/EmptyMap.java b/engine/runtime/src/main/java/org/enso/interpreter/runtime/state/data/EmptyMap.java new file mode 100644 index 0000000000..2b388b5567 --- /dev/null +++ b/engine/runtime/src/main/java/org/enso/interpreter/runtime/state/data/EmptyMap.java @@ -0,0 +1,15 @@ +package org.enso.interpreter.runtime.state.data; + +import com.oracle.truffle.api.interop.TruffleObject; + +/** A dummy type, denoting an empty map structure. */ +public final class EmptyMap implements TruffleObject { + private static final EmptyMap INSTANCE = new EmptyMap(); + + private EmptyMap() {} + + /** @return an instance of empty map. */ + public static EmptyMap create() { + return INSTANCE; + } +} diff --git a/engine/runtime/src/main/java/org/enso/interpreter/runtime/state/data/SingletonMap.java b/engine/runtime/src/main/java/org/enso/interpreter/runtime/state/data/SingletonMap.java new file mode 100644 index 0000000000..5a1505e657 --- /dev/null +++ b/engine/runtime/src/main/java/org/enso/interpreter/runtime/state/data/SingletonMap.java @@ -0,0 +1,28 @@ +package org.enso.interpreter.runtime.state.data; + +/** An object representing a single key-value pairing. */ +public final class SingletonMap { + private final Object key; + private final Object value; + + /** + * Creates a new key-value pair. + * + * @param key the key of this pair + * @param value the value of this pair + */ + public SingletonMap(Object key, Object value) { + this.key = key; + this.value = value; + } + + /** @return the key of this pair */ + public Object getKey() { + return key; + } + + /** @return the value of this pair */ + public Object getValue() { + return value; + } +} diff --git a/engine/runtime/src/main/java/org/enso/interpreter/runtime/state/data/SmallMap.java b/engine/runtime/src/main/java/org/enso/interpreter/runtime/state/data/SmallMap.java new file mode 100644 index 0000000000..db89db0261 --- /dev/null +++ b/engine/runtime/src/main/java/org/enso/interpreter/runtime/state/data/SmallMap.java @@ -0,0 +1,59 @@ +package org.enso.interpreter.runtime.state.data; + +import com.oracle.truffle.api.CompilerDirectives; +import com.oracle.truffle.api.interop.TruffleObject; + +/** + * Represents an arbitrary-size map-like structure. It is low-level and only works well for small + * numbers of keys. + */ +public final class SmallMap implements TruffleObject { + private final @CompilerDirectives.CompilationFinal(dimensions = 1) Object[] keys; + private final @CompilerDirectives.CompilationFinal(dimensions = 1) Object[] values; + private static final SmallMap EMPTY = new SmallMap(new Object[0], new Object[0]); + + public static final int NOT_FOUND = -1; + + /** @return an empty instance of this class */ + public static SmallMap empty() { + return EMPTY; + } + + /** + * Creates a map with given keys and values. + * + * @param keys the keys of this map. + * @param values the values of this map. Must have the same length as {@code keys}. + */ + public SmallMap(Object[] keys, Object[] values) { + this.keys = keys; + this.values = values; + } + + /** + * Returns the index of a given key in the keys array. Returns {@code NOT_FOUND} if the key is + * missing. + * + * @param key the key to lookup + * @return the key's index or {@code NOT_FOUND} + */ + @CompilerDirectives.TruffleBoundary + public int indexOf(Object key) { + for (int i = 0; i < keys.length; i++) { + if (key == keys[i]) { + return i; + } + } + return NOT_FOUND; + } + + /** @return the keys in this map. */ + public Object[] getKeys() { + return keys; + } + + /** @return the values in this map. */ + public Object[] getValues() { + return values; + } +} diff --git a/engine/runtime/src/main/scala/org/enso/compiler/codegen/IrToTruffle.scala b/engine/runtime/src/main/scala/org/enso/compiler/codegen/IrToTruffle.scala index 610356e354..fe9bd6b032 100644 --- a/engine/runtime/src/main/scala/org/enso/compiler/codegen/IrToTruffle.scala +++ b/engine/runtime/src/main/scala/org/enso/compiler/codegen/IrToTruffle.scala @@ -437,6 +437,7 @@ class IrToTruffle( setLocation( ErrorNode.build( context.getBuiltins + .error() .syntaxError() .newInstance( "Type operators are not currently supported at runtime." @@ -479,6 +480,7 @@ class IrToTruffle( val message = invalidBranches.map(_.message).mkString(", ") val error = context.getBuiltins + .error() .compileError() .newInstance(message) @@ -738,17 +740,17 @@ class IrToTruffle( case Error.InvalidIR(_, _, _) => throw new CompilerError("Unexpected Invalid IR during codegen.") case err: Error.Syntax => - context.getBuiltins.syntaxError().newInstance(err.message) + context.getBuiltins.error().syntaxError().newInstance(err.message) case err: Error.Redefined.Binding => - context.getBuiltins.compileError().newInstance(err.message) + context.getBuiltins.error().compileError().newInstance(err.message) case err: Error.Redefined.Method => - context.getBuiltins.compileError().newInstance(err.message) + context.getBuiltins.error().compileError().newInstance(err.message) case err: Error.Redefined.Atom => - context.getBuiltins.compileError().newInstance(err.message) + context.getBuiltins.error().compileError().newInstance(err.message) case err: Error.Redefined.ThisArg => - context.getBuiltins.compileError().newInstance(err.message) + context.getBuiltins.error().compileError().newInstance(err.message) case err: Error.Unexpected.TypeSignature => - context.getBuiltins.compileError().newInstance(err.message) + context.getBuiltins.error().compileError().newInstance(err.message) } setLocation(ErrorNode.build(payload), error.location) } @@ -899,6 +901,7 @@ class IrToTruffle( setLocation( ErrorNode.build( context.getBuiltins + .error() .syntaxError() .newInstance( "Typeset literals are not yet supported at runtime." diff --git a/engine/runtime/src/test/scala/org/enso/interpreter/test/instrument/ReplTest.scala b/engine/runtime/src/test/scala/org/enso/interpreter/test/instrument/ReplTest.scala index c5e7642d81..6d88aa0f08 100644 --- a/engine/runtime/src/test/scala/org/enso/interpreter/test/instrument/ReplTest.scala +++ b/engine/runtime/src/test/scala/org/enso/interpreter/test/instrument/ReplTest.scala @@ -93,14 +93,16 @@ class ReplTest extends InterpreterTest with BeforeAndAfter with EitherValues { "access and modify monadic state" in { val code = """ - |main = - | State.put 10 + |run = + | State.put Number 10 | Debug.breakpoint - | State.get + | State.get Number + | + |main = State.run Number 0 here.run |""".stripMargin setSessionManager { executor => - executor.evaluate("x = State.get") - executor.evaluate("State.put (x + 1)") + executor.evaluate("x = State.get Number") + executor.evaluate("State.put Number (x + 1)") executor.exit() } eval(code) shouldEqual 11 diff --git a/engine/runtime/src/test/scala/org/enso/interpreter/test/semantic/StateTest.scala b/engine/runtime/src/test/scala/org/enso/interpreter/test/semantic/StateTest.scala index f8d1e338d5..7aaa952183 100644 --- a/engine/runtime/src/test/scala/org/enso/interpreter/test/semantic/StateTest.scala +++ b/engine/runtime/src/test/scala/org/enso/interpreter/test/semantic/StateTest.scala @@ -12,11 +12,13 @@ class StateTest extends InterpreterTest { "be accessible from functions" in { val code = """ - |main = - | State.put 10 - | x = State.get - | State.put x+1 - | State.get + |stateful = + | State.put Number 10 + | x = State.get Number + | State.put Number x+1 + | State.get Number + | + |main = State.run Number 0 here.stateful |""".stripMargin eval(code) shouldEqual 11 @@ -25,81 +27,58 @@ class StateTest extends InterpreterTest { "be implicitly threaded through function executions" in { val code = """ - |Unit.incState = - | x = State.get - | State.put x+1 + |inc_state = + | x = State.get Number + | State.put Number x+1 | - |main = - | State.put 0 - | Unit.incState - | Unit.incState - | Unit.incState - | Unit.incState - | Unit.incState - | State.get + |run = + | here.inc_state + | here.inc_state + | here.inc_state + | here.inc_state + | here.inc_state + | State.get Number + | + |main = State.run Number 0 here.run |""".stripMargin eval(code) shouldEqual 5 } - "be localized with State.run" in { - val code = - """ - |main = - | State.put 20 - | myBlock = - | res = State.get - | State.put 0 - | res - | - | res2 = State.run 10 myBlock - | state = State.get - | res2 + state - |""".stripMargin - eval(code) shouldEqual 30 - } - "work well with recursive code" in { val code = """ |main = | stateSum = n -> - | acc = State.get - | State.put acc+n - | if n == 0 then State.get else stateSum n-1 + | acc = State.get Number + | State.put Number acc+n + | if n == 0 then State.get Number else stateSum n-1 | - | State.run 0 (stateSum 10) + | State.run Number 0 (stateSum 10) |""".stripMargin eval(code) shouldEqual 55 } - "be initialized to a Unit by default" in { - val code = - """ - |main = IO.println State.get - |""".stripMargin - eval(code) - consumeOut shouldEqual List("Unit") - } - "work with pattern matches" in { val code = """ - |main = + |run = | matcher = x -> case x of | Unit -> - | y = State.get - | State.put (y + 5) + | y = State.get Number + | State.put Number (y + 5) | Nil -> - | y = State.get - | State.put (y + 10) + | y = State.get Number + | State.put Number (y + 10) | - | State.put 1 + | State.put Number 1 | matcher Nil - | IO.println State.get + | IO.println (State.get Number) | matcher Unit - | IO.println State.get + | IO.println (State.get Number) | 0 + | + |main = State.run Number 0 here.run |""".stripMargin eval(code) consumeOut shouldEqual List("11", "16") @@ -108,16 +87,77 @@ class StateTest extends InterpreterTest { "undo changes on Panics" in { val code = """ - |main = - | panicker = - | State.put 400 - | Panic.throw Unit + |panicker = + | State.put Number 400 + | Panic.throw Unit | - | State.put 5 - | Panic.recover panicker - | State.get + |stater = + | State.put Number 5 + | Panic.recover here.panicker + | State.get Number + | + |main = State.run Number 0 here.stater |""".stripMargin eval(code) shouldEqual 5 } + + "localize properly with State.run when 1 key used" in { + val code = + """ + |inner = State.put Number 0 + | + |outer = + | State.put Number 1 + | State.run Number 2 here.inner + | State.get Number + | + |main = State.run Number 3 here.outer + |""".stripMargin + eval(code) shouldEqual 1 + } + + "localize properly with State.run when 2 states used" in { + val code = + """ + |type S1 + |type S2 + | + |inner = + | State.put S1 0 + | State.put S2 0 + | + |outer = + | State.put S1 1 + | State.run S2 2 here.inner + | State.get S1 + | + |main = State.run S1 3 here.outer + | + |""".stripMargin + eval(code) shouldEqual 0 + } + + "localize properly with State.run when multiple states used" in { + val code = + """ + |type S1 + |type S2 + |type S3 + | + |inner = + | State.put S1 0 + | State.put S2 0 + | + |outer = + | State.put S1 1 + | State.put S3 2 + | State.run S2 2 here.inner + | State.get S1 + State.get S2 + State.get S3 + | + |main = State.run S3 0 (State.run S2 5 (State.run S1 3 here.outer)) + | + |""".stripMargin + eval(code) shouldEqual 7 // S1 = 0, S2 = 5, S3 = 2 + } } } diff --git a/test/Benchmarks/package.yaml b/test/Benchmarks/package.yaml new file mode 100644 index 0000000000..9733feab55 --- /dev/null +++ b/test/Benchmarks/package.yaml @@ -0,0 +1,5 @@ +name: Benchmarks +version: 0.0.1 +license: MIT +author: enso-dev@enso.org +maintainer: enso-dev@enso.org diff --git a/test/Benchmarks/src/Main.enso b/test/Benchmarks/src/Main.enso new file mode 100644 index 0000000000..96c6e4bcc9 --- /dev/null +++ b/test/Benchmarks/src/Main.enso @@ -0,0 +1,40 @@ +import Base.Bench_Utils + +type Counter +type Sum + +sum_tco = sum_to -> + summator = acc -> current -> + if current == 0 then acc else summator acc+current current-1 + res = summator 0 sum_to + res + +sum_co_state_body = + n = State.get Counter + acc = State.get Sum + State.put Counter n-1 + State.put Sum acc+n + if n == 0 then acc else here.sum_co_state_body + +sum_co_state n = + res = State.run Counter n (State.run Sum 0 here.sum_co_state_body) + res + +sum_state_body n = + acc = State.get Number + State.put Number (acc + n) + if n == 0 then State.get Number else here.sum_state_body (n - 1) + +sum_state = sum_to -> + res = State.run Number 0 (here.sum_state_body sum_to) + res + +main = + hundred_mil = 100000000 + IO.println "Measuring SumTCO" + Bench_Utils.measure (_ -> here.sum_tco hundred_mil) "sum_tco" 100 20 + IO.println "Measuring State" + Bench_Utils.measure (_ -> here.sum_state hundred_mil) "sum_state" 100 20 + IO.println "Measuring Co-State" + Bench_Utils.measure (_ -> here.sum_co_state hundred_mil) "sum_co_state" 100 20 + IO.println "Bye."