mirror of
https://github.com/enso-org/enso.git
synced 2024-11-23 08:08:34 +03:00
Implement suspension of default arguments (#170)
This commit is contained in:
parent
f8dea12e44
commit
ed8223c57c
@ -12,6 +12,7 @@ import org.enso.interpreter.runtime.scope.LocalScope;
|
||||
* runtime nodes used by the interpreter to guide function evaluation.
|
||||
*/
|
||||
public class CallArgFactory implements AstCallArgVisitor<CallArgument> {
|
||||
|
||||
private final LocalScope scope;
|
||||
private final Language language;
|
||||
private final String scopeName;
|
||||
|
@ -4,19 +4,42 @@ import com.oracle.truffle.api.RootCallTarget;
|
||||
import com.oracle.truffle.api.Truffle;
|
||||
import com.oracle.truffle.api.frame.FrameSlot;
|
||||
import com.oracle.truffle.api.nodes.RootNode;
|
||||
import org.enso.interpreter.*;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.function.Supplier;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
import org.enso.interpreter.AstArgDefinition;
|
||||
import org.enso.interpreter.AstCallArg;
|
||||
import org.enso.interpreter.AstCase;
|
||||
import org.enso.interpreter.AstCaseFunction;
|
||||
import org.enso.interpreter.AstExpression;
|
||||
import org.enso.interpreter.AstExpressionVisitor;
|
||||
import org.enso.interpreter.Language;
|
||||
import org.enso.interpreter.node.EnsoRootNode;
|
||||
import org.enso.interpreter.node.ExpressionNode;
|
||||
import org.enso.interpreter.node.callable.InvokeCallableNodeGen;
|
||||
import org.enso.interpreter.node.callable.argument.ReadArgumentNode;
|
||||
import org.enso.interpreter.node.callable.function.CreateFunctionNode;
|
||||
import org.enso.interpreter.node.callable.function.FunctionBodyNode;
|
||||
import org.enso.interpreter.node.controlflow.*;
|
||||
import org.enso.interpreter.node.controlflow.CaseNode;
|
||||
import org.enso.interpreter.node.controlflow.ConstructorCaseNode;
|
||||
import org.enso.interpreter.node.controlflow.DefaultFallbackNode;
|
||||
import org.enso.interpreter.node.controlflow.FallbackNode;
|
||||
import org.enso.interpreter.node.controlflow.IfZeroNode;
|
||||
import org.enso.interpreter.node.controlflow.MatchNode;
|
||||
import org.enso.interpreter.node.expression.builtin.PrintNode;
|
||||
import org.enso.interpreter.node.expression.constant.ConstructorNode;
|
||||
import org.enso.interpreter.node.expression.constant.DynamicSymbolNode;
|
||||
import org.enso.interpreter.node.expression.literal.IntegerLiteralNode;
|
||||
import org.enso.interpreter.node.expression.operator.*;
|
||||
import org.enso.interpreter.node.expression.operator.AddOperatorNodeGen;
|
||||
import org.enso.interpreter.node.expression.operator.DivideOperatorNodeGen;
|
||||
import org.enso.interpreter.node.expression.operator.ModOperatorNodeGen;
|
||||
import org.enso.interpreter.node.expression.operator.MultiplyOperatorNodeGen;
|
||||
import org.enso.interpreter.node.expression.operator.SubtractOperatorNodeGen;
|
||||
import org.enso.interpreter.node.scope.AssignmentNode;
|
||||
import org.enso.interpreter.node.scope.AssignmentNodeGen;
|
||||
import org.enso.interpreter.node.scope.ReadLocalTargetNodeGen;
|
||||
@ -27,16 +50,12 @@ import org.enso.interpreter.runtime.error.DuplicateArgumentNameException;
|
||||
import org.enso.interpreter.runtime.scope.ModuleScope;
|
||||
import org.enso.interpreter.runtime.scope.LocalScope;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.function.Supplier;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
/**
|
||||
* An {@code ExpressionFactory} is responsible for converting the majority of Enso's parsed AST into
|
||||
* nodes evaluated by the interpreter at runtime.
|
||||
*/
|
||||
public class ExpressionFactory implements AstExpressionVisitor<ExpressionNode> {
|
||||
|
||||
private final LocalScope scope;
|
||||
private final Language language;
|
||||
private final String scopeName;
|
||||
@ -128,11 +147,21 @@ public class ExpressionFactory implements AstExpressionVisitor<ExpressionNode> {
|
||||
String operator, AstExpression leftAst, AstExpression rightAst) {
|
||||
ExpressionNode left = leftAst.visit(this);
|
||||
ExpressionNode right = rightAst.visit(this);
|
||||
if (operator.equals("+")) return AddOperatorNodeGen.create(left, right);
|
||||
if (operator.equals("-")) return SubtractOperatorNodeGen.create(left, right);
|
||||
if (operator.equals("*")) return MultiplyOperatorNodeGen.create(left, right);
|
||||
if (operator.equals("/")) return DivideOperatorNodeGen.create(left, right);
|
||||
if (operator.equals("%")) return ModOperatorNodeGen.create(left, right);
|
||||
if (operator.equals("+")) {
|
||||
return AddOperatorNodeGen.create(left, right);
|
||||
}
|
||||
if (operator.equals("-")) {
|
||||
return SubtractOperatorNodeGen.create(left, right);
|
||||
}
|
||||
if (operator.equals("*")) {
|
||||
return MultiplyOperatorNodeGen.create(left, right);
|
||||
}
|
||||
if (operator.equals("/")) {
|
||||
return DivideOperatorNodeGen.create(left, right);
|
||||
}
|
||||
if (operator.equals("%")) {
|
||||
return ModOperatorNodeGen.create(left, right);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@ -300,7 +329,7 @@ public class ExpressionFactory implements AstExpressionVisitor<ExpressionNode> {
|
||||
*/
|
||||
@Override
|
||||
public ExpressionNode visitFunctionApplication(
|
||||
AstExpression function, List<AstCallArg> arguments) {
|
||||
AstExpression function, List<AstCallArg> arguments, boolean hasDefaultsSuspended) {
|
||||
CallArgFactory argFactory = new CallArgFactory(scope, language, scopeName, moduleScope);
|
||||
|
||||
List<CallArgument> callArgs = new ArrayList<>();
|
||||
@ -310,7 +339,7 @@ public class ExpressionFactory implements AstExpressionVisitor<ExpressionNode> {
|
||||
}
|
||||
|
||||
return InvokeCallableNodeGen.create(
|
||||
callArgs.stream().toArray(CallArgument[]::new), function.visit(this));
|
||||
callArgs.toArray(new CallArgument[0]), hasDefaultsSuspended, function.visit(this));
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -8,6 +8,7 @@ import com.oracle.truffle.api.frame.VirtualFrame;
|
||||
import com.oracle.truffle.api.nodes.ExplodeLoop;
|
||||
import com.oracle.truffle.api.nodes.NodeInfo;
|
||||
import com.oracle.truffle.api.profiles.ConditionProfile;
|
||||
import java.util.Arrays;
|
||||
import org.enso.interpreter.Constants;
|
||||
import org.enso.interpreter.node.ExpressionNode;
|
||||
import org.enso.interpreter.node.callable.argument.sorter.ArgumentSorterNode;
|
||||
@ -22,8 +23,6 @@ import org.enso.interpreter.runtime.error.MethodDoesNotExistException;
|
||||
import org.enso.interpreter.runtime.error.NotInvokableException;
|
||||
import org.enso.interpreter.runtime.type.TypesGen;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
/**
|
||||
* This node is responsible for organising callable calls so that they are ready to be made.
|
||||
*
|
||||
@ -33,15 +32,17 @@ import java.util.Arrays;
|
||||
@NodeInfo(shortName = "@", description = "Executes function")
|
||||
@NodeChild(value = "callable", type = ExpressionNode.class)
|
||||
public abstract class InvokeCallableNode extends ExpressionNode {
|
||||
@Children
|
||||
private @CompilationFinal(dimensions = 1) ExpressionNode[] argExpressions;
|
||||
|
||||
private final boolean canApplyThis;
|
||||
private final int thisArgumentPosition;
|
||||
@Children
|
||||
@CompilationFinal(dimensions = 1)
|
||||
private ExpressionNode[] argExpressions;
|
||||
|
||||
@Child private ArgumentSorterNode argumentSorter;
|
||||
@Child private MethodResolverNode methodResolverNode;
|
||||
|
||||
private final boolean canApplyThis;
|
||||
private final int thisArgumentPosition;
|
||||
|
||||
private final ConditionProfile methodCalledOnNonAtom = ConditionProfile.createCountingProfile();
|
||||
|
||||
/**
|
||||
@ -49,7 +50,7 @@ public abstract class InvokeCallableNode extends ExpressionNode {
|
||||
*
|
||||
* @param callArguments information on the arguments being passed to the {@link Function}
|
||||
*/
|
||||
public InvokeCallableNode(CallArgument[] callArguments) {
|
||||
public InvokeCallableNode(CallArgument[] callArguments, boolean hasDefaultsSuspended) {
|
||||
this.argExpressions =
|
||||
Arrays.stream(callArguments)
|
||||
.map(CallArgument::getExpression)
|
||||
@ -57,9 +58,9 @@ public abstract class InvokeCallableNode extends ExpressionNode {
|
||||
|
||||
CallArgumentInfo[] argSchema =
|
||||
Arrays.stream(callArguments).map(CallArgumentInfo::new).toArray(CallArgumentInfo[]::new);
|
||||
|
||||
boolean appliesThis = false;
|
||||
int idx = 0;
|
||||
|
||||
for (; idx < argSchema.length; idx++) {
|
||||
CallArgumentInfo arg = argSchema[idx];
|
||||
if (arg.isPositional()
|
||||
@ -68,15 +69,17 @@ public abstract class InvokeCallableNode extends ExpressionNode {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
this.canApplyThis = appliesThis;
|
||||
this.thisArgumentPosition = idx;
|
||||
|
||||
this.argumentSorter = ArgumentSorterNodeGen.create(argSchema);
|
||||
this.argumentSorter = ArgumentSorterNodeGen.create(argSchema, hasDefaultsSuspended);
|
||||
this.methodResolverNode = MethodResolverNodeGen.create();
|
||||
}
|
||||
|
||||
/**
|
||||
* Marks whether the {@code argumentSorter} child is tail–recursive.
|
||||
*
|
||||
* @param isTail whether or not the node is tail-recursive.
|
||||
*/
|
||||
@Override
|
||||
|
@ -16,18 +16,20 @@ import org.enso.interpreter.runtime.callable.function.Function;
|
||||
*/
|
||||
@NodeInfo(shortName = "ArgumentSorter")
|
||||
public abstract class ArgumentSorterNode extends BaseNode {
|
||||
|
||||
private @CompilationFinal(dimensions = 1) CallArgumentInfo[] schema;
|
||||
private final boolean hasDefaultsSuspended;
|
||||
|
||||
/**
|
||||
* Creates a node that performs the argument organisation for the provided schema.
|
||||
*
|
||||
* @param schema information about the call arguments in positional order
|
||||
*/
|
||||
public ArgumentSorterNode(CallArgumentInfo[] schema) {
|
||||
public ArgumentSorterNode(CallArgumentInfo[] schema, boolean hasDefaultsSuspended) {
|
||||
this.schema = schema;
|
||||
this.hasDefaultsSuspended = hasDefaultsSuspended;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Generates the argument mapping where it has already been computed and executes the function.
|
||||
*
|
||||
@ -49,7 +51,8 @@ public abstract class ArgumentSorterNode extends BaseNode {
|
||||
public Object invokeCached(
|
||||
Function function,
|
||||
Object[] arguments,
|
||||
@Cached("create(function, getSchema())") CachedArgumentSorterNode mappingNode,
|
||||
@Cached("create(function, getSchema(), hasDefaultsSuspended())")
|
||||
CachedArgumentSorterNode mappingNode,
|
||||
@Cached CallOptimiserNode optimiser) {
|
||||
Object[] mappedArguments = mappingNode.execute(function, arguments);
|
||||
if (mappingNode.appliesFully()) {
|
||||
@ -84,4 +87,13 @@ public abstract class ArgumentSorterNode extends BaseNode {
|
||||
CallArgumentInfo[] getSchema() {
|
||||
return schema;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks whether the function whose arguments are being sorted has suspended defaults arguments.
|
||||
*
|
||||
* @return {@code true} if it has suspended defaults, otherwise {@code false}
|
||||
*/
|
||||
boolean hasDefaultsSuspended() {
|
||||
return this.hasDefaultsSuspended;
|
||||
}
|
||||
}
|
||||
|
@ -13,6 +13,7 @@ import org.enso.interpreter.runtime.callable.function.Function;
|
||||
*/
|
||||
@NodeInfo(shortName = "CachedArgumentSorter")
|
||||
public class CachedArgumentSorterNode extends BaseNode {
|
||||
|
||||
private final Function originalFunction;
|
||||
private final @CompilationFinal(dimensions = 1) int[] mapping;
|
||||
private final ArgumentSchema postApplicationSchema;
|
||||
@ -22,9 +23,12 @@ public class CachedArgumentSorterNode extends BaseNode {
|
||||
* Creates a node that generates and then caches the argument mapping.
|
||||
*
|
||||
* @param function the function to sort arguments for
|
||||
* @param schema information on the calling arguments
|
||||
* @param schema information on the calling argument
|
||||
* @param hasDefaultsSuspended whether or not the function to which these arguments are applied
|
||||
* has its defaults suspended.
|
||||
*/
|
||||
public CachedArgumentSorterNode(Function function, CallArgumentInfo[] schema) {
|
||||
public CachedArgumentSorterNode(
|
||||
Function function, CallArgumentInfo[] schema, boolean hasDefaultsSuspended) {
|
||||
this.originalFunction = function;
|
||||
CallArgumentInfo.ArgumentMapping mapping =
|
||||
CallArgumentInfo.ArgumentMapping.generate(function.getSchema(), schema);
|
||||
@ -33,7 +37,10 @@ public class CachedArgumentSorterNode extends BaseNode {
|
||||
|
||||
boolean fullApplication = true;
|
||||
for (int i = 0; i < postApplicationSchema.getArgumentsCount(); i++) {
|
||||
if (!(postApplicationSchema.hasDefaultAt(i) || postApplicationSchema.hasPreAppliedAt(i))) {
|
||||
boolean hasValidDefault = postApplicationSchema.hasDefaultAt(i) && !hasDefaultsSuspended;
|
||||
boolean hasPreappliedArg = postApplicationSchema.hasPreAppliedAt(i);
|
||||
|
||||
if (!(hasValidDefault || hasPreappliedArg)) {
|
||||
fullApplication = false;
|
||||
break;
|
||||
}
|
||||
@ -48,8 +55,9 @@ public class CachedArgumentSorterNode extends BaseNode {
|
||||
* @param schema information on the calling arguments
|
||||
* @return a sorter node for the arguments in {@code schema} being passed to {@code callable}
|
||||
*/
|
||||
public static CachedArgumentSorterNode create(Function function, CallArgumentInfo[] schema) {
|
||||
return new CachedArgumentSorterNode(function, schema);
|
||||
public static CachedArgumentSorterNode create(
|
||||
Function function, CallArgumentInfo[] schema, boolean hasDefaultsSuspended) {
|
||||
return new CachedArgumentSorterNode(function, schema, hasDefaultsSuspended);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -7,6 +7,12 @@ import com.oracle.truffle.api.TruffleLanguage;
|
||||
import com.oracle.truffle.api.TruffleLanguage.Env;
|
||||
import com.oracle.truffle.api.frame.FrameDescriptor;
|
||||
import com.oracle.truffle.api.source.Source;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.Collectors;
|
||||
import org.enso.interpreter.AstGlobalScope;
|
||||
import org.enso.interpreter.Constants;
|
||||
import org.enso.interpreter.EnsoParser;
|
||||
@ -20,12 +26,6 @@ import org.enso.interpreter.util.ScalaConversions;
|
||||
import org.enso.pkg.Package;
|
||||
import org.enso.pkg.SourceFile;
|
||||
|
||||
import java.io.*;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* The language context is the internal state of the language that is associated with each thread in
|
||||
* a running Enso program.
|
||||
|
@ -28,7 +28,8 @@ trait AstExpressionVisitor[+T] {
|
||||
|
||||
def visitFunctionApplication(
|
||||
function: AstExpression,
|
||||
arguments: java.util.List[AstCallArg]
|
||||
arguments: java.util.List[AstCallArg],
|
||||
defaultsSuspended: Boolean
|
||||
): T
|
||||
|
||||
def visitIf(
|
||||
@ -156,10 +157,13 @@ case class AstVariable(name: String) extends AstExpression {
|
||||
visitor.visitVariable(name)
|
||||
}
|
||||
|
||||
case class AstApply(fun: AstExpression, args: List[AstCallArg])
|
||||
case class AstApply(
|
||||
fun: AstExpression,
|
||||
args: List[AstCallArg],
|
||||
hasDefaultsSuspended: Boolean)
|
||||
extends AstExpression {
|
||||
override def visit[T](visitor: AstExpressionVisitor[T]): T =
|
||||
visitor.visitFunctionApplication(fun, args.asJava)
|
||||
visitor.visitFunctionApplication(fun, args.asJava, hasDefaultsSuspended)
|
||||
}
|
||||
|
||||
case class AstFunction(
|
||||
@ -269,7 +273,7 @@ class EnsoParserInternal extends JavaTokenParsers {
|
||||
def variable: Parser[AstVariable] = ident ^^ AstVariable
|
||||
|
||||
def operand: Parser[AstExpression] =
|
||||
long | foreign | variable | "(" ~> expression <~ ")" | call
|
||||
long | foreign | variable | "(" ~> expression <~ ")" | functionCall
|
||||
|
||||
def arith: Parser[AstExpression] =
|
||||
operand ~ ((("+" | "-" | "*" | "/" | "%") ~ operand) ?) ^^ {
|
||||
@ -280,9 +284,17 @@ class EnsoParserInternal extends JavaTokenParsers {
|
||||
def expression: Parser[AstExpression] =
|
||||
ifZero | matchClause | arith | function
|
||||
|
||||
def call: Parser[AstApply] = "@" ~> expression ~ (argList ?) ^^ {
|
||||
case expr ~ args => AstApply(expr, args.getOrElse(Nil))
|
||||
}
|
||||
def functionCall: Parser[AstApply] =
|
||||
"@" ~> expression ~ (argList ?) ~ defaultSuspend ^^ {
|
||||
case expr ~ args ~ hasDefaultsSuspended =>
|
||||
AstApply(expr, args.getOrElse(Nil), hasDefaultsSuspended)
|
||||
}
|
||||
|
||||
def defaultSuspend: Parser[Boolean] =
|
||||
("..." ?) ^^ ({
|
||||
case Some(_) => true
|
||||
case None => false
|
||||
})
|
||||
|
||||
def assignment: Parser[AstAssignment] = ident ~ ("=" ~> expression) ^^ {
|
||||
case v ~ exp => AstAssignment(v, exp)
|
||||
|
@ -14,4 +14,34 @@ class CurryingTest extends LanguageTest {
|
||||
|""".stripMargin
|
||||
eval(code) shouldEqual 11
|
||||
}
|
||||
|
||||
"Functions" should "allow default arguments to be suspended" in {
|
||||
val code =
|
||||
"""
|
||||
|@{
|
||||
| fn = { |w, x, y = 10, z = 20| (w + x) + (y + z) };
|
||||
|
|
||||
| fn1 = @fn ...;
|
||||
| fn2 = @fn1 [1, 2] ...;
|
||||
| fn3 = @fn2 [3] ...;
|
||||
|
|
||||
| @fn3
|
||||
|}
|
||||
|""".stripMargin
|
||||
|
||||
eval(code) shouldEqual 26
|
||||
}
|
||||
|
||||
"Functions" should "allow defaults to be suspended in application chains" in {
|
||||
val code =
|
||||
"""
|
||||
|@{
|
||||
| fn = { |w, x, y = 10, z = 20| (w + x) + (y + z) };
|
||||
|
|
||||
| @(@fn [3, 6] ...) [3]
|
||||
|}
|
||||
|""".stripMargin
|
||||
|
||||
eval(code) shouldEqual 32
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user