diff --git a/build.sbt b/build.sbt index c7ff7a6f04..4da6a90b94 100644 --- a/build.sbt +++ b/build.sbt @@ -1,14 +1,14 @@ import java.io.File -import sbt.Keys.scalacOptions - -import scala.sys.process._ import org.enso.build.BenchTasks._ import org.enso.build.WithDebugCommand +import sbt.Keys.scalacOptions import sbt.addCompilerPlugin import sbtassembly.AssemblyPlugin.defaultUniversalScript import sbtcrossproject.CrossPlugin.autoImport.{crossProject, CrossType} +import scala.sys.process._ + ////////////////////////////// //// Global Configuration //// ////////////////////////////// @@ -446,16 +446,16 @@ lazy val polyglot_api = project lazy val language_server = (project in file("engine/language-server")) .settings( libraryDependencies ++= akka ++ circe ++ Seq( - "ch.qos.logback" % "logback-classic" % "1.2.3", + "ch.qos.logback" % "logback-classic" % "1.2.3", "io.circe" %% "circe-generic-extras" % "0.12.2", - "io.circe" %% "circe-literal" % circeVersion, - "org.bouncycastle" % "bcpkix-jdk15on" % "1.64", - "dev.zio" %% "zio" % "1.0.0-RC18-2", + "io.circe" %% "circe-literal" % circeVersion, + "org.bouncycastle" % "bcpkix-jdk15on" % "1.64", + "dev.zio" %% "zio" % "1.0.0-RC18-2", akkaTestkit % Test, - "commons-io" % "commons-io" % "2.6", - "org.scalatest" %% "scalatest" % "3.2.0-M2" % Test, - "org.scalacheck" %% "scalacheck" % "1.14.0" % Test, - "org.graalvm.sdk" % "polyglot-tck" % graalVersion % "provided" + "commons-io" % "commons-io" % "2.6", + "org.scalatest" %% "scalatest" % "3.2.0-M2" % Test, + "org.scalacheck" %% "scalacheck" % "1.14.0" % Test, + "org.graalvm.sdk" % "polyglot-tck" % graalVersion % "provided" ), testOptions in Test += Tests .Argument(TestFrameworks.ScalaCheck, "-minSuccessfulTests", "1000") diff --git a/doc/design/syntax/syntax.md b/doc/design/syntax/syntax.md index 3ee1343277..a9210d193f 100644 --- a/doc/design/syntax/syntax.md +++ b/doc/design/syntax/syntax.md @@ -641,7 +641,8 @@ binds the function name. This means that: ### Methods Enso makes a distinction between functions and methods. In Enso, a method is a function where the first argument (known as the `this` argument) is associated -with a given atom. +with a given atom. Methods are dispatched dynamically based on the type of the +`this` argument, while functions are not. Methods can be defined in Enso in two ways: @@ -669,6 +670,18 @@ Methods can be defined in Enso in two ways: ... ``` +3. **As a Function with an Explicit `this` Argument:** A function defined with + the type of the `this` argument specified to be a type. + + ```ruby + floor (this : Number) = case this of + Integer -> ... + ``` + +If the user does not explicitly specify the `this` argument by name when +defining a method (e.g. they use the `Type.name` syntax), it is implicitly added +to the start of the argument list. + #### This vs. Self Though it varies greatly between programming languages, we have chosen `this` to be the name of the 'current type' rather than `self`. This is a purely aesthetic @@ -687,17 +700,10 @@ when calling it. To that end, Enso supports what is known as Uniform Call Syntax - This is a needless constraint as both notations have their advantages. - Enso has two notations, but one unified semantics. -The rules for the uniform syntax call translation in Enso are as follows. +The rules for the uniform syntax call translation in Enso are as follows: -1. For an expression `t.fn`, this is equivalent to `fn (this = t)`. -2. The `this` argument may occur at any position in the function. - -> The actionables for this section are: -> -> - Clarify exactly how this should work, and which argument should be -> translated. -> - We do not _currently_ implement the above-listed transformation, so we need -> to solidify these rules. +1. For an expression `t.fn `, this is equivalent to `fn t `. +2. For an expression `fn t `, this is equivalent to `t.fn `. ### Code Blocks Top-level blocks in the language are evaluated immediately. This means that the diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/ClosureRootNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/ClosureRootNode.java index 6000b0933c..ca2474c538 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/ClosureRootNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/ClosureRootNode.java @@ -70,15 +70,4 @@ public class ClosureRootNode extends EnsoRootNode { state = FrameUtil.getObjectSafe(frame, this.getStateFrameSlot()); return new Stateful(state, result); } - - /** - * Sets whether the node is tail-recursive. - * - * @param isTail whether or not the node is tail-recursive. - */ - @Override - public void setTail(boolean isTail) { - CompilerDirectives.transferToInterpreterAndInvalidate(); - body.setTail(isTail); - } } diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/EnsoRootNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/EnsoRootNode.java index dd2c357ca4..dddce8bf99 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/EnsoRootNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/EnsoRootNode.java @@ -83,13 +83,6 @@ public abstract class EnsoRootNode extends RootNode { return this.name; } - /** - * Sets whether the node is tail-recursive. - * - * @param isTail whether or not the node is tail-recursive - */ - public abstract void setTail(boolean isTail); - /** * Gets the frame slot containing the program state. * diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/callable/function/BlockNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/callable/function/BlockNode.java index afffb77b0d..9c0dc030df 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/callable/function/BlockNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/callable/function/BlockNode.java @@ -30,16 +30,6 @@ public class BlockNode extends ExpressionNode { return new BlockNode(expressions, returnExpr); } - /** - * Sets whether or not the function is tail-recursive. - * - * @param isTail whether or not the function is tail-recursive. - */ - @Override - public void setTail(boolean isTail) { - returnExpr.setTail(isTail); - } - /** * Executes the body of the function. * diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/callable/function/CreateFunctionNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/callable/function/CreateFunctionNode.java index 6da0dab973..f13664492f 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/callable/function/CreateFunctionNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/callable/function/CreateFunctionNode.java @@ -37,16 +37,6 @@ public class CreateFunctionNode extends ExpressionNode { return new CreateFunctionNode(callTarget, args); } - /** - * Sets the tail-recursiveness of the function. - * - * @param isTail whether or not the function is tail-recursive - */ - @Override - public void setTail(boolean isTail) { - ((ClosureRootNode) callTarget.getRootNode()).setTail(isTail); - } - /** * Generates the provided function definition in the given stack {@code frame}. * diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/ConstructorCaseNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/ConstructorCaseNode.java index 26fb328da2..c3ea845dbe 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/ConstructorCaseNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/ConstructorCaseNode.java @@ -38,16 +38,6 @@ public class ConstructorCaseNode extends CaseNode { return new ConstructorCaseNode(matcher, branch); } - /** - * Sets whether or not the case expression is tail recursive. - * - * @param isTail whether or not the case expression is tail-recursive - */ - @Override - public void setTail(boolean isTail) { - branch.setTail(isTail); - } - /** * Handles the atom scrutinee case. * diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/MatchNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/MatchNode.java index a15c1656bc..34da6aed7f 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/MatchNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/controlflow/MatchNode.java @@ -43,20 +43,6 @@ public abstract class MatchNode extends ExpressionNode { return MatchNodeGen.create(cases, fallback, scrutinee); } - /** - * Sets whether or not the pattern match is tail-recursive. - * - * @param isTail whether or not the expression is tail-recursive - */ - @Override - @ExplodeLoop - public void setTail(boolean isTail) { - for (CaseNode caseNode : cases) { - caseNode.setTail(isTail); - } - fallback.setTail(isTail); - } - @Specialization Object doError(VirtualFrame frame, RuntimeError error) { return error; diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/expression/debug/EvalNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/expression/debug/EvalNode.java index 21a5c908bf..af2aeb5f67 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/expression/debug/EvalNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/expression/debug/EvalNode.java @@ -5,6 +5,7 @@ import com.oracle.truffle.api.Truffle; import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.nodes.NodeInfo; +import org.enso.compiler.InlineContext; import org.enso.interpreter.Constants; import org.enso.interpreter.Language; import org.enso.interpreter.node.BaseNode; @@ -16,6 +17,7 @@ import org.enso.interpreter.runtime.callable.argument.Thunk; import org.enso.interpreter.runtime.scope.LocalScope; import org.enso.interpreter.runtime.scope.ModuleScope; import org.enso.interpreter.runtime.state.Stateful; +import scala.Some; /** Node running Enso expressions passed to it as strings. */ @NodeInfo(shortName = "Eval", description = "Evaluates code passed to it as string") @@ -57,11 +59,12 @@ public abstract class EvalNode extends BaseNode { RootCallTarget parseExpression(LocalScope scope, ModuleScope moduleScope, String expression) { LocalScope localScope = scope.createChild(); Language language = lookupLanguageReference(Language.class).get(); + InlineContext inlineContext = InlineContext.fromJava(localScope, moduleScope, isTail()); ExpressionNode expr = lookupContextReference(Language.class) .get() .compiler() - .runInline(expression, localScope, moduleScope) + .runInline(expression, inlineContext) .getOrElse(null); if (expr == null) { throw new RuntimeException("Invalid code passed to `eval`: " + expression); @@ -78,7 +81,6 @@ public abstract class EvalNode extends BaseNode { expr, null, ""); - framedNode.setTail(isTail()); return Truffle.getRuntime().createCallTarget(framedNode); } diff --git a/engine/runtime/src/main/scala/org/enso/compiler/Compiler.scala b/engine/runtime/src/main/scala/org/enso/compiler/Compiler.scala index 351b7ebc17..2d55c0e1ff 100644 --- a/engine/runtime/src/main/scala/org/enso/compiler/Compiler.scala +++ b/engine/runtime/src/main/scala/org/enso/compiler/Compiler.scala @@ -7,9 +7,18 @@ import com.oracle.truffle.api.source.Source import org.enso.compiler.codegen.{AstToIR, IRToTruffle} import org.enso.compiler.core.IR import org.enso.compiler.core.IR.{Expression, Module} +import org.enso.compiler.exception.CompilerError import org.enso.compiler.pass.IRPass -import org.enso.compiler.pass.analyse.{AliasAnalysis, ApplicationSaturation} -import org.enso.compiler.pass.desugar.{LiftSpecialOperators, OperatorToFunction} +import org.enso.compiler.pass.analyse.{ + AliasAnalysis, + ApplicationSaturation, + TailCall +} +import org.enso.compiler.pass.desugar.{ + GenerateMethodBodies, + LiftSpecialOperators, + OperatorToFunction +} import org.enso.interpreter.Language import org.enso.interpreter.node.{ExpressionNode => RuntimeExpression} import org.enso.interpreter.runtime.Context @@ -40,10 +49,12 @@ class Compiler( * they nevertheless exist. */ val compilerPhaseOrdering: List[IRPass] = List( + GenerateMethodBodies, LiftSpecialOperators, OperatorToFunction, AliasAnalysis, - ApplicationSaturation() + ApplicationSaturation(), + TailCall ) /** @@ -105,14 +116,13 @@ class Compiler( * Processes the source in the context of given local and module scopes. * * @param srcString string representing the expression to process - * @param localScope local scope to process the source in - * @param moduleScope module scope to process the source in + * @param inlineContext a context object that contains the information needed + * for inline evaluation * @return an expression node representing the parsed and analyzed source */ def runInline( srcString: String, - localScope: LocalScope, - moduleScope: ModuleScope + inlineContext: InlineContext ): Option[RuntimeExpression] = { val source = Source .newBuilder( @@ -124,17 +134,8 @@ class Compiler( val parsed: AST = parse(source) generateIRInline(parsed).flatMap { ir => - Some({ - val compilerOutput = - runCompilerPhasesInline(ir, localScope, moduleScope) - - truffleCodegenInline( - compilerOutput, - source, - moduleScope, - localScope - ) - }) + val compilerOutput = runCompilerPhasesInline(ir, inlineContext) + Some(truffleCodegenInline(compilerOutput, source, inlineContext)) } } @@ -147,7 +148,7 @@ class Compiler( * @param qualifiedName the qualified name of the module * @return the scope containing all definitions in the requested module */ - def requestProcess(qualifiedName: String): ModuleScope = { + def processImport(qualifiedName: String): ModuleScope = { val module = topScope.getModule(qualifiedName) if (module.isPresent) { module.get().getScope(context) @@ -205,16 +206,17 @@ class Compiler( /** Runs the various compiler passes in an inline context. * * @param ir the compiler intermediate representation to transform + * @param inlineContext a context object that contains the information needed + * for inline evaluation * @return the output result of the */ def runCompilerPhasesInline( ir: IR.Expression, - localScope: LocalScope, - moduleScope: ModuleScope + inlineContext: InlineContext ): IR.Expression = { compilerPhaseOrdering.foldLeft(ir)( (intermediateIR, pass) => - pass.runExpression(intermediateIR, Some(localScope), Some(moduleScope)) + pass.runExpression(intermediateIR, inlineContext) ) } @@ -236,18 +238,27 @@ class Compiler( * * @param ir the prorgam to translate * @param source the source code of the program represented by `ir` - * @param moduleScope the module scope in which the code is to be generated - * @param localScope the local scope in which the inline code is to be - * located + * @param inlineContext a context object that contains the information needed + * for inline evaluation * @return the runtime representation of the program represented by `ir` */ def truffleCodegenInline( ir: IR.Expression, source: Source, - moduleScope: ModuleScope, - localScope: LocalScope + inlineContext: InlineContext ): RuntimeExpression = { - new IRToTruffle(this.language, source, moduleScope) - .runInline(ir, localScope, "") + new IRToTruffle( + this.language, + source, + inlineContext.moduleScope.getOrElse( + throw new CompilerError( + "Cannot perform inline codegen with a missing module scope." + ) + ) + ).runInline( + ir, + inlineContext.localScope.getOrElse(LocalScope.root), + "" + ) } } diff --git a/engine/runtime/src/main/scala/org/enso/compiler/InlineContext.scala b/engine/runtime/src/main/scala/org/enso/compiler/InlineContext.scala new file mode 100644 index 0000000000..7aedac3378 --- /dev/null +++ b/engine/runtime/src/main/scala/org/enso/compiler/InlineContext.scala @@ -0,0 +1,41 @@ +package org.enso.compiler + +import org.enso.interpreter.runtime.scope.{LocalScope, ModuleScope} + +/** A type containing the information about the execution context for an inline + * expression. + * + * @param localScope the local scope in which the expression is being executed + * @param moduleScope the module scope in which the expression is being + * executed + * @param isInTailPosition whether or not the inline expression occurs in tail + * position ([[None]] indicates no information) + */ +case class InlineContext( + localScope: Option[LocalScope] = None, + moduleScope: Option[ModuleScope] = None, + isInTailPosition: Option[Boolean] = None +) +object InlineContext { + + /** Implements a null-safe conversion from nullable objects to Scala's option + * internally. + * + * @param localScope the local scope instance + * @param moduleScope the module scope instance + * @param isInTailPosition whether or not the inline expression occurs in a + * tail position + * @return the [[InlineContext]] instance corresponding to the arguments + */ + def fromJava( + localScope: LocalScope, + moduleScope: ModuleScope, + isInTailPosition: Boolean + ): InlineContext = { + InlineContext( + Option(localScope), + Option(moduleScope), + Option(isInTailPosition) + ) + } +} diff --git a/engine/runtime/src/main/scala/org/enso/compiler/codegen/AstToIR.scala b/engine/runtime/src/main/scala/org/enso/compiler/codegen/AstToIR.scala index c097096452..b3b4e53e84 100644 --- a/engine/runtime/src/main/scala/org/enso/compiler/codegen/AstToIR.scala +++ b/engine/runtime/src/main/scala/org/enso/compiler/codegen/AstToIR.scala @@ -140,17 +140,11 @@ object AstToIR { (Constants.Names.CURRENT_MODULE, None) } - val nameStr = name match { case AST.Ident.Var.any(name) => name } - val defExpression = translateExpression(definition) - val defExpr: Function.Lambda = defExpression match { - case fun: Function.Lambda => fun - case expr => - Function.Lambda(List(), expr, expr.location) - } + val nameStr = name match { case AST.Ident.Var.any(name) => name } Module.Scope.Definition.Method( Name.Literal(path, pathLoc), Name.Literal(nameStr.name, nameStr.location), - defExpr, + translateExpression(definition), inputAST.location ) case _ => @@ -180,10 +174,14 @@ object AstToIR { case AstView.Assignment(name, expr) => translateBinding(inputAST.location, name, expr) case AstView.MethodCall(target, name, args) => + val (validArguments, hasDefaultsSuspended) = + calculateDefaultsSuspension(args) + + // Note [Uniform Call Syntax Translation] Application.Prefix( translateExpression(name), - (target :: args).map(translateCallArgument), - false, + (target :: validArguments).map(translateCallArgument), + hasDefaultsSuspended = hasDefaultsSuspended, inputAST.location ) case AstView.CaseExpression(scrutinee, branches) => @@ -226,6 +224,16 @@ object AstToIR { } } + /* Note [Uniform Call Syntax Translation] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * As the uniform call syntax must work for both methods and functions, the + * conversion can't take advantage of any by-name application semantics at the + * current time. + * + * This means that it is a purely _positional_ conversion on the first + * argument and cannot be performed any other way. + */ + /** Translates a program literal from its [[AST]] representation into * [[Core]]. * @@ -289,7 +297,7 @@ object AstToIR { DefinitionArgument.Specified( Name.Literal(name.name, name.location), Some(translateExpression(value)), - true, + suspended = true, arg.location ) case AstView.LazyArgument(arg) => @@ -331,6 +339,28 @@ object AstToIR { CallArgument.Specified(None, translateExpression(arg), arg.location) } + /** Calculates whether a set of arguments has its defaults suspended, and + * processes the argument list to remove that operator. + * + * @param args the list of arguments + * @return the list of arguments with the suspension operator removed, and + * whether or not the defaults are suspended + */ + def calculateDefaultsSuspension(args: List[AST]): (List[AST], Boolean) = { + val validArguments = args.filter { + case AstView.SuspendDefaultsOperator(_) => false + case _ => true + } + + val suspendPositions = args.zipWithIndex.collect { + case (AstView.SuspendDefaultsOperator(_), ix) => ix + } + + val hasDefaultsSuspended = suspendPositions.contains(args.length - 1) + + (validArguments, hasDefaultsSuspended) + } + /** Translates an arbitrary expression that takes the form of a syntactic * application from its [[AST]] representation into [[Core]]. * @@ -348,16 +378,8 @@ object AstToIR { case AstView.ForcedTerm(term) => Application.Force(translateExpression(term), callable.location) case AstView.Application(name, args) => - val validArguments = args.filter { - case AstView.SuspendDefaultsOperator(_) => false - case _ => true - } - - val suspendPositions = args.zipWithIndex.collect { - case (AstView.SuspendDefaultsOperator(_), ix) => ix - } - - val hasDefaultsSuspended = suspendPositions.contains(args.length - 1) + val (validArguments, hasDefaultsSuspended) = + calculateDefaultsSuspension(args) Application.Prefix( translateExpression(name), diff --git a/engine/runtime/src/main/scala/org/enso/compiler/codegen/AstView.scala b/engine/runtime/src/main/scala/org/enso/compiler/codegen/AstView.scala index cd2ca91fc7..2333efc18e 100644 --- a/engine/runtime/src/main/scala/org/enso/compiler/codegen/AstView.scala +++ b/engine/runtime/src/main/scala/org/enso/compiler/codegen/AstView.scala @@ -355,6 +355,11 @@ object AstView { def unapply(ast: AST): Option[(AST, AST.Ident, List[AST])] = ast match { case OperatorDot(target, Application(ConsOrVar(ident), args)) => Some((target, ident, args)) + case AST.App.Section.Left( + MethodCall(target, ident, List()), + susp @ SuspendDefaultsOperator(_) + ) => + Some((target, ident, List(susp))) case OperatorDot(target, ConsOrVar(ident)) => Some((target, ident, List())) case _ => None diff --git a/engine/runtime/src/main/scala/org/enso/compiler/codegen/IRToTruffle.scala b/engine/runtime/src/main/scala/org/enso/compiler/codegen/IRToTruffle.scala index b78b25f184..e8aa975521 100644 --- a/engine/runtime/src/main/scala/org/enso/compiler/codegen/IRToTruffle.scala +++ b/engine/runtime/src/main/scala/org/enso/compiler/codegen/IRToTruffle.scala @@ -4,7 +4,11 @@ import com.oracle.truffle.api.Truffle import com.oracle.truffle.api.source.{Source, SourceSection} import org.enso.compiler.core.IR import org.enso.compiler.exception.{CompilerError, UnhandledEntity} -import org.enso.compiler.pass.analyse.{AliasAnalysis, ApplicationSaturation} +import org.enso.compiler.pass.analyse.{ + AliasAnalysis, + ApplicationSaturation, + TailCall +} import org.enso.compiler.pass.analyse.AliasAnalysis.Graph.{Scope => AliasScope} import org.enso.compiler.pass.analyse.AliasAnalysis.{Graph => AliasGraph} import org.enso.interpreter.node.callable.argument.ReadArgumentNode @@ -127,7 +131,7 @@ class IRToTruffle( // Register the imports in scope imports.foreach( - i => this.moduleScope.addImport(context.compiler.requestProcess(i.name)) + i => this.moduleScope.addImport(context.compiler.processImport(i.name)) ) // Register the atoms and their constructors in scope @@ -138,7 +142,7 @@ class IRToTruffle( atomConstructors .zip(atomDefs) .foreach { - case (atomCons, atomDefn) => { + case (atomCons, atomDefn) => val scopeInfo = atomDefn .getMetadata[AliasAnalysis.Info.Scope.Root] .getOrElse( @@ -157,7 +161,6 @@ class IRToTruffle( } atomCons.initializeFields(argDefs: _*) - } } // Register the method definitions in scope @@ -166,7 +169,7 @@ class IRToTruffle( val scopeInfo = methodDef .getMetadata[AliasAnalysis.Info.Scope.Root] .getOrElse( - throw new CompilerError(("Missing scope information for method.")) + throw new CompilerError("Missing scope information for method.") ) val typeName = @@ -182,6 +185,12 @@ class IRToTruffle( scopeInfo.graph.rootScope ) + val methodFunIsTail = methodDef.body + .getMetadata[TailCall.Metadata] + .getOrElse( + throw new CompilerError("Method body missing tail call info.") + ) + val funNode = methodDef.body match { case fn: IR.Function => expressionProcessor.processFunctionBody( @@ -190,14 +199,12 @@ class IRToTruffle( fn.location ) case _ => - expressionProcessor.processFunctionBody( - List(), - methodDef.body, - methodDef.body.location + throw new CompilerError( + "Method bodies must be functions at the point of codegen." ) } - funNode.markTail() + funNode.setTail(methodFunIsTail) val function = new RuntimeFunction( funNode.getCallTarget, @@ -266,7 +273,7 @@ class IRToTruffle( val scopeName: String ) { - private var currentVarName = "anonymous"; + private var currentVarName = "anonymous" // === Construction ======================================================= @@ -299,25 +306,36 @@ class IRToTruffle( * @param ir the IR to generate code for * @return a truffle expression that represents the same program as `ir` */ - def run(ir: IR): RuntimeExpression = ir match { - case block: IR.Expression.Block => processBlock(block) - case literal: IR.Literal => processLiteral(literal) - case app: IR.Application => processApplication(app) - case name: IR.Name => processName(name) - case function: IR.Function => processFunction(function) - case binding: IR.Expression.Binding => processBinding(binding) - case caseExpr: IR.Case => processCase(caseExpr) - case comment: IR.Comment => processComment(comment) - case err: IR.Error => - throw new CompilerError( - s"No errors should remain by the point of truffle codegen, but " + - s"found $err." + def run(ir: IR): RuntimeExpression = { + val tailMeta = ir + .getMetadata[TailCall.Metadata] + .getOrElse( + throw new CompilerError(s"Missing tail call metadata for $ir") ) - case IR.Foreign.Definition(_, _, _, _) => - throw new CompilerError( - s"Foreign expressions not yet implemented: $ir." - ) - case _ => throw new UnhandledEntity(ir, "run") + + val runtimeExpression = ir match { + case block: IR.Expression.Block => processBlock(block) + case literal: IR.Literal => processLiteral(literal) + case app: IR.Application => processApplication(app) + case name: IR.Name => processName(name) + case function: IR.Function => processFunction(function) + case binding: IR.Expression.Binding => processBinding(binding) + case caseExpr: IR.Case => processCase(caseExpr) + case comment: IR.Comment => processComment(comment) + case err: IR.Error => + throw new CompilerError( + s"No errors should remain by the point of truffle codegen, but " + + s"found $err." + ) + case IR.Foreign.Definition(_, _, _, _) => + throw new CompilerError( + s"Foreign expressions not yet implemented: $ir." + ) + case _ => throw new UnhandledEntity(ir, "run") + } + + runtimeExpression.setTail(tailMeta) + runtimeExpression } /** Executes the expression processor on a piece of code that has been @@ -328,7 +346,6 @@ class IRToTruffle( */ def runInline(ir: IR.Expression): RuntimeExpression = { val expression = run(ir) - expression.markNotTail() expression } @@ -393,9 +410,21 @@ class IRToTruffle( val cases = branches .map( - branch => - ConstructorCaseNode + branch => { + val caseIsTail = branch + .getMetadata[TailCall.Metadata] + .getOrElse( + throw new CompilerError( + "Case branch missing tail position information." + ) + ) + + val caseNode = ConstructorCaseNode .build(this.run(branch.pattern), this.run(branch.expression)) + caseNode.setTail(caseIsTail) + + caseNode + } ) .toArray[CaseNode] @@ -466,7 +495,6 @@ class IRToTruffle( function.body, function.location ) - fn.setTail(function.canBeTCO) fn } @@ -485,7 +513,7 @@ class IRToTruffle( throw new CompilerError("No occurence on variable usage.") ) - val slot = scope.getFramePointer(useInfo.id) + val slot = scope.getFramePointer(useInfo.id) val atomCons = moduleScope.getConstructor(nameStr).toScala if (nameStr == Constants.Names.CURRENT_MODULE) { ConstructorNode.build(moduleScope.getAssociatedType) @@ -567,6 +595,14 @@ class IRToTruffle( } else seenArgNames.add(argName) } + val bodyIsTail = body + .getMetadata[TailCall.Metadata] + .getOrElse( + throw new CompilerError( + "Function body missing tail call information." + ) + ) + val bodyExpr = this.run(body) val fnBodyNode = BlockNode.build(argExpressions.toArray, bodyExpr) @@ -582,6 +618,8 @@ class IRToTruffle( val expr = CreateFunctionNode.build(callTarget, argDefinitions) + fnBodyNode.setTail(bodyIsTail) + setLocation(expr, location) } @@ -696,7 +734,16 @@ class IRToTruffle( val childScope = scope.createChild(scopeInfo.scope) val argumentExpression = new ExpressionProcessor(childScope, scopeName).run(value) - argumentExpression.markTail() + + val argExpressionIsTail = value + .getMetadata[TailCall.Metadata] + .getOrElse( + throw new CompilerError( + "Argument with missing tail call information." + ) + ) + + argumentExpression.setTail(argExpressionIsTail) val displayName = s"call_argument<${name.getOrElse(String.valueOf(position))}>" diff --git a/engine/runtime/src/main/scala/org/enso/compiler/core/IR.scala b/engine/runtime/src/main/scala/org/enso/compiler/core/IR.scala index ef9b7ebe0a..1b0cd37d31 100644 --- a/engine/runtime/src/main/scala/org/enso/compiler/core/IR.scala +++ b/engine/runtime/src/main/scala/org/enso/compiler/core/IR.scala @@ -2,7 +2,7 @@ package org.enso.compiler.core import org.enso.compiler.core.IR.Expression import org.enso.syntax.text.ast.Doc -import org.enso.syntax.text.{AST, Location} +import org.enso.syntax.text.{AST, Debug, Location} import scala.collection.immutable.{Set => ISet} import scala.reflect.ClassTag @@ -51,6 +51,12 @@ sealed trait IR { * @return `this`, potentially having had its children transformed by `fn` */ def mapExpressions(fn: Expression => Expression): IR + + /** Pretty prints the IR. + * + * @return a pretty-printed representation of the IR + */ + def pretty: String = Debug.pretty(this.toString) } object IR { @@ -185,6 +191,8 @@ object IR { * @param location the source location that the node corresponds to * @param passData the pass metadata associated with this node */ + // TODO [AA] Separate Method into Method.Binding and Method.Explicit to + // account for syntax sugar later. sealed case class Method( typeName: IR.Name, methodName: IR.Name, @@ -850,7 +858,7 @@ object IR { right: Expression, override val location: Option[Location], override val passData: ISet[Metadata] = ISet() - ) extends Application + ) extends Operator with IRKind.Sugar { override def addMetadata(newData: Metadata): Binary = { copy(passData = this.passData + newData) @@ -1035,7 +1043,7 @@ object IR { code: String, override val location: Option[Location], override val passData: ISet[Metadata] = ISet() - ) extends Expression + ) extends Foreign with IRKind.Primitive { override def addMetadata(newData: Metadata): Definition = { copy(passData = this.passData + newData) diff --git a/engine/runtime/src/main/scala/org/enso/compiler/pass/IRPass.scala b/engine/runtime/src/main/scala/org/enso/compiler/pass/IRPass.scala index 344d642898..978a89b7bc 100644 --- a/engine/runtime/src/main/scala/org/enso/compiler/pass/IRPass.scala +++ b/engine/runtime/src/main/scala/org/enso/compiler/pass/IRPass.scala @@ -1,7 +1,7 @@ package org.enso.compiler.pass +import org.enso.compiler.InlineContext import org.enso.compiler.core.IR -import org.enso.interpreter.runtime.scope.{LocalScope, ModuleScope} /** A representation of a compiler pass that runs on the [[IR]] type. */ trait IRPass { @@ -22,14 +22,13 @@ trait IRPass { * or annotated version of `ir` in an inline context. * * @param ir the Enso IR to process - * @param localScope the local scope in which the expression is executed - * @param moduleScope the module scope in which the expression is executed + * @param inlineContext a context object that contains the information needed + * for inline evaluation * @return `ir`, possibly having made transformations or annotations to that * IR. */ def runExpression( ir: IR.Expression, - localScope: Option[LocalScope] = None, - moduleScope: Option[ModuleScope] = None + inlineContext: InlineContext ): IR.Expression } diff --git a/engine/runtime/src/main/scala/org/enso/compiler/pass/analyse/AliasAnalysis.scala b/engine/runtime/src/main/scala/org/enso/compiler/pass/analyse/AliasAnalysis.scala index db84f424fc..f53018c183 100644 --- a/engine/runtime/src/main/scala/org/enso/compiler/pass/analyse/AliasAnalysis.scala +++ b/engine/runtime/src/main/scala/org/enso/compiler/pass/analyse/AliasAnalysis.scala @@ -1,10 +1,10 @@ package org.enso.compiler.pass.analyse +import org.enso.compiler.InlineContext import org.enso.compiler.core.IR import org.enso.compiler.exception.CompilerError import org.enso.compiler.pass.IRPass import org.enso.compiler.pass.analyse.AliasAnalysis.Graph.{Occurrence, Scope} -import org.enso.interpreter.runtime.scope.{LocalScope, ModuleScope} import org.enso.syntax.text.Debug import scala.reflect.ClassTag @@ -52,17 +52,16 @@ case object AliasAnalysis extends IRPass { * provided scope. * * @param ir the Enso IR to process - * @param localScope the local scope in which the expression is executed - * @param moduleScope the module scope in which the expression is executed + * @param inlineContext a context object that contains the information needed + * for inline evaluation * @return `ir`, possibly having made transformations or annotations to that * IR. */ override def runExpression( ir: IR.Expression, - localScope: Option[LocalScope] = None, - moduleScope: Option[ModuleScope] = None + inlineContext: InlineContext ): IR.Expression = - localScope + inlineContext.localScope .map { localScope => val scope = localScope.scope val graph = localScope.aliasingGraph @@ -90,31 +89,23 @@ case object AliasAnalysis extends IRPass { ir match { case m @ IR.Module.Scope.Definition.Method(_, _, body, _, _) => - val bodyWithThisArg = body match { - case lam @ IR.Function.Lambda(args, _, _, _, _) => - lam.copy( - arguments = IR.DefinitionArgument.Specified( - IR.Name.This(None), - None, - suspended = false, - None - ) :: args - ) + body match { + case _: IR.Function => + m.copy( + body = analyseExpression( + body, + topLevelGraph, + topLevelGraph.rootScope, + lambdaReuseScope = true, + blockReuseScope = true + ) + ) + .addMetadata(Info.Scope.Root(topLevelGraph)) case _ => throw new CompilerError( - "The body of a method should always be a lambda by." + "The body of a method should always be a function." ) } - m.copy( - body = analyseExpression( - bodyWithThisArg, - topLevelGraph, - topLevelGraph.rootScope, - lambdaReuseScope = true, - blockReuseScope = true - ) - ) - .addMetadata(Info.Scope.Root(topLevelGraph)) case a @ IR.Module.Scope.Definition.Atom(_, args, _, _) => a.copy( arguments = diff --git a/engine/runtime/src/main/scala/org/enso/compiler/pass/analyse/ApplicationSaturation.scala b/engine/runtime/src/main/scala/org/enso/compiler/pass/analyse/ApplicationSaturation.scala index 4e07b08ce9..1009a8ab38 100644 --- a/engine/runtime/src/main/scala/org/enso/compiler/pass/analyse/ApplicationSaturation.scala +++ b/engine/runtime/src/main/scala/org/enso/compiler/pass/analyse/ApplicationSaturation.scala @@ -1,26 +1,22 @@ package org.enso.compiler.pass.analyse +import org.enso.compiler.InlineContext import org.enso.compiler.core.IR import org.enso.compiler.exception.CompilerError import org.enso.compiler.pass.IRPass -import org.enso.compiler.pass.analyse.ApplicationSaturation.{CallSaturation, Default, FunctionSpec, PassConfiguration} +import org.enso.compiler.pass.analyse.ApplicationSaturation.{ + CallSaturation, + Default, + FunctionSpec, + PassConfiguration +} import org.enso.interpreter.node.{ExpressionNode => RuntimeExpression} import org.enso.interpreter.runtime.callable.argument.CallArgument -import org.enso.interpreter.runtime.scope.{LocalScope, ModuleScope} - -import scala.annotation.unused /** This optimisation pass recognises fully-saturated applications of known * functions and writes analysis data that allows optimisation of them to * specific nodes at codegen time. * - * PLEASE NOTE: This implementation is _incomplete_ as the analysis it performs - * is _unconditional_ at this stage. This means that, until we have alias - * analysis information, - * - * PLEASE NOTE: This implementation is _incomplete_ as the analysis it performs - * only operates for functions where the arguments are applied positionally. - * * @param knownFunctions a mapping from known function names to information * about that function that can be used for optimisation */ @@ -39,7 +35,9 @@ case class ApplicationSaturation( * IR. */ override def runModule(ir: IR.Module): IR.Module = - ir.transformExpressions({ case x => runExpression(x) }) + ir.transformExpressions({ + case x => runExpression(x, new InlineContext) + }) /** Executes the analysis pass, marking functions with information about their * argument saturation. @@ -50,8 +48,7 @@ case class ApplicationSaturation( */ override def runExpression( ir: IR.Expression, - @unused localScope: Option[LocalScope] = None, - @unused moduleScope: Option[ModuleScope] = None + inlineContext: InlineContext ): IR.Expression = { ir.transformExpressions { case func @ IR.Application.Prefix(fn, args, _, _, meta) => @@ -82,7 +79,8 @@ case class ApplicationSaturation( func.copy( arguments = args.map( _.mapExpressions( - (ir: IR.Expression) => runExpression(ir) + (ir: IR.Expression) => + runExpression(ir, inlineContext) ) ), passData = meta + saturationInfo @@ -92,7 +90,8 @@ case class ApplicationSaturation( func.copy( arguments = args.map( _.mapExpressions( - (ir: IR.Expression) => runExpression(ir) + (ir: IR.Expression) => + runExpression(ir, inlineContext) ) ), passData = meta + CallSaturation.Over(args.length - arity) @@ -101,7 +100,8 @@ case class ApplicationSaturation( func.copy( arguments = args.map( _.mapExpressions( - (ir: IR.Expression) => runExpression(ir) + (ir: IR.Expression) => + runExpression(ir, inlineContext) ) ), passData = meta + CallSaturation.Partial( @@ -112,22 +112,26 @@ case class ApplicationSaturation( case None => func.copy( arguments = args.map( - _.mapExpressions((ir: IR.Expression) => runExpression(ir)) + _.mapExpressions( + (ir: IR.Expression) => runExpression(ir, inlineContext) + ) ), passData = meta + CallSaturation.Unknown() ) } } else { func.copy( - function = runExpression(fn), - arguments = args.map(_.mapExpressions(runExpression(_))), + function = runExpression(fn, inlineContext), + arguments = + args.map(_.mapExpressions(runExpression(_, inlineContext))), passData = meta + CallSaturation.Unknown() ) } case _ => func.copy( - function = runExpression(fn), - arguments = args.map(_.mapExpressions(runExpression(_))), + function = runExpression(fn, inlineContext), + arguments = + args.map(_.mapExpressions(runExpression(_, inlineContext))), passData = meta + CallSaturation.Unknown() ) } diff --git a/engine/runtime/src/main/scala/org/enso/compiler/pass/analyse/TailCall.scala b/engine/runtime/src/main/scala/org/enso/compiler/pass/analyse/TailCall.scala new file mode 100644 index 0000000000..e9f71162e0 --- /dev/null +++ b/engine/runtime/src/main/scala/org/enso/compiler/pass/analyse/TailCall.scala @@ -0,0 +1,380 @@ +package org.enso.compiler.pass.analyse + +import org.enso.compiler.InlineContext +import org.enso.compiler.core.IR +import org.enso.compiler.exception.CompilerError +import org.enso.compiler.pass.IRPass + +/** This pass performs tail call analysis on the Enso IR. + * + * It is responsible for marking every single expression with whether it is in + * tail position or not. This allows the code generator to correctly create the + * Truffle nodes. + */ +case object TailCall extends IRPass { + + /** The annotation metadata type associated with IR nodes by this pass. */ + override type Metadata = TailPosition + + /** Analyses tail call state for expressions in a module. + * + * @param ir the Enso IR to process + * @return `ir`, possibly having made transformations or annotations to that + * IR. + */ + override def runModule(ir: IR.Module): IR.Module = { + ir.copy(bindings = ir.bindings.map(analyseModuleBinding)) + } + + /** Analyses tail call state for an arbitrary expression. + * + * @param ir the Enso IR to process + * @param inlineContext a context object that contains the information needed + * for inline evaluation + * @return `ir`, possibly having made transformations or annotations to that + * IR. + */ + override def runExpression( + ir: IR.Expression, + inlineContext: InlineContext + ): IR.Expression = + analyseExpression( + ir, + inlineContext.isInTailPosition.getOrElse( + throw new CompilerError( + "Information about the tail position for an inline expression " + + "must be known by the point of tail call analysis." + ) + ) + ) + + /** Performs tail call analysis on a top-level definition in a module. + * + * @param definition the top-level definition to analyse + * @return `definition`, annotated with tail call information + */ + def analyseModuleBinding( + definition: IR.Module.Scope.Definition + ): IR.Module.Scope.Definition = { + definition match { + case method @ IR.Module.Scope.Definition.Method(_, _, body, _, _) => + method + .copy( + body = analyseExpression(body, isInTailPosition = true) + ) + .addMetadata(TailPosition.Tail) + case atom @ IR.Module.Scope.Definition.Atom(_, args, _, _) => + atom + .copy( + arguments = args.map(analyseDefArgument) + ) + .addMetadata(TailPosition.Tail) + } + } + + /** Performs tail call analysis on an arbitrary expression. + * + * @param expression the expression to analyse + * @param isInTailPosition whether or not the expression is occurring in tail + * position + * @return `expression`, annotated with tail position metadata + */ + def analyseExpression( + expression: IR.Expression, + isInTailPosition: Boolean + ): IR.Expression = { + expression match { + case function: IR.Function => analyseFunction(function, isInTailPosition) + case caseExpr: IR.Case => analyseCase(caseExpr, isInTailPosition) + case typ: IR.Type => analyseType(typ, isInTailPosition) + case app: IR.Application => analyseApplication(app, isInTailPosition) + case name: IR.Name => analyseName(name, isInTailPosition) + case foreign: IR.Foreign => foreign.addMetadata(TailPosition.NotTail) + case literal: IR.Literal => analyseLiteral(literal, isInTailPosition) + case comment: IR.Comment => analyseComment(comment, isInTailPosition) + case block @ IR.Expression.Block(expressions, returnValue, _, _, _) => + block + .copy( + expressions = + expressions.map(analyseExpression(_, isInTailPosition = false)), + returnValue = analyseExpression(returnValue, isInTailPosition) + ) + .addMetadata(TailPosition.fromBool(isInTailPosition)) + case binding @ IR.Expression.Binding(_, expression, _, _) => + binding + .copy( + expression = analyseExpression(expression, isInTailPosition) + ) + .addMetadata(TailPosition.fromBool(isInTailPosition)) + case err: IR.Error => err + } + } + + /** Performs tail call analysis on an occurrence of a name. + * + * @param name the name to analyse + * @param isInTailPosition whether the name occurs in tail position or not + * @return `name`, annotated with tail position metadata + */ + def analyseName(name: IR.Name, isInTailPosition: Boolean): IR.Name = { + name.addMetadata(TailPosition.fromBool(isInTailPosition)) + } + + /** Performs tail call analysis on a comment occurrence. + * + * @param comment the comment to analyse + * @param isInTailPosition whether the comment occurs in tail position or not + * @return `comment`, annotated with tail position metadata + */ + def analyseComment( + comment: IR.Comment, + isInTailPosition: Boolean + ): IR.Comment = { + comment match { + case doc @ IR.Comment.Documentation(expr, _, _, _) => + doc + .copy(commented = analyseExpression(expr, isInTailPosition)) + .addMetadata(TailPosition.fromBool(isInTailPosition)) + } + } + + /** Performs tail call analysis on a literal. + * + * @param literal the literal to analyse + * @param isInTailPosition whether or not the literal occurs in tail position + * or not + * @return `literal`, annotated with tail position metdata + */ + def analyseLiteral( + literal: IR.Literal, + isInTailPosition: Boolean + ): IR.Literal = { + literal.addMetadata(TailPosition.fromBool(isInTailPosition)) + } + + /** Performs tail call analysis on an application. + * + * @param application the application to analyse + * @param isInTailPosition whether or not the application is occurring in + * tail position + * @return `application`, annotated with tail position metadata + */ + def analyseApplication( + application: IR.Application, + isInTailPosition: Boolean + ): IR.Application = { + application match { + case app @ IR.Application.Prefix(fn, args, _, _, _) => + app + .copy( + function = analyseExpression(fn, isInTailPosition = false), + arguments = args.map(analyseCallArg) + ) + .addMetadata(TailPosition.fromBool(isInTailPosition)) + case force @ IR.Application.Force(target, _, _) => + force + .copy( + target = analyseExpression(target, isInTailPosition) + ) + .addMetadata(TailPosition.fromBool(isInTailPosition)) + case _: IR.Application.Operator => + throw new CompilerError("Unexpected binary operator.") + } + } + + /** Performs tail call analysis on a call site argument. + * + * @param argument the argument to analyse + * @return `argument`, annotated with tail position metadata + */ + def analyseCallArg(argument: IR.CallArgument): IR.CallArgument = { + argument match { + case arg @ IR.CallArgument.Specified(_, expr, _, _) => + arg + .copy( + // Note [Call Argument Tail Position] + value = analyseExpression(expr, isInTailPosition = true) + ) + .addMetadata(TailPosition.Tail) + } + } + + /* Note [Call Argument Tail Position] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * In order to efficiently deal with Enso's ability to suspend function + * arguments, we behave as if all arguments to a function are passed as + * thunks. This means that the _function_ becomes responsible for deciding + * when to evaluate its arguments. + * + * Conceptually, this results in a desugaring as follows: + * + * ``` + * foo a b c + * ``` + * + * Becomes: + * + * ``` + * foo ({} -> a) ({} -> b) ({} -> c) + * ``` + * + * Quite obviously, the arguments `a`, `b` and `c` are in tail position in + * these closures, and hence should be marked as tail. + */ + + /** Performs tail call analysis on an expression involving type operators. + * + * @param value the type operator expression + * @param isInTailPosition whether or not the type operator occurs in a tail + * call position + * @return `value`, annotated with tail position metadata + */ + def analyseType(value: IR.Type, isInTailPosition: Boolean): IR.Type = { + value + .mapExpressions(analyseExpression(_, isInTailPosition = false)) + .addMetadata(TailPosition.fromBool(isInTailPosition)) + } + + /** Performs tail call analysis on a case expression. + * + * @param caseExpr the case expression to analyse + * @param isInTailPosition whether or not the case expression occurs in a tail + * call position + * @return `caseExpr`, annotated with tail position metadata + */ + def analyseCase(caseExpr: IR.Case, isInTailPosition: Boolean): IR.Case = { + caseExpr match { + case caseExpr @ IR.Case.Expr(scrutinee, branches, fallback, _, _) => + caseExpr + .copy( + scrutinee = analyseExpression(scrutinee, isInTailPosition = false), + // Note [Analysing Branches in Case Expressions] + branches = branches.map(analyseCaseBranch(_, isInTailPosition)), + fallback = fallback.map(analyseExpression(_, isInTailPosition)) + ) + .addMetadata(TailPosition.fromBool(isInTailPosition)) + case _: IR.Case.Branch => + throw new CompilerError("Unexpected case branch.") + } + } + + /* Note [Analysing Branches in Case Expressions] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * When performing tail call analysis on a case expression it is very + * important to recognise that the branches of a case expression should all + * have the same tail call state. The branches should only be marked as being + * in tail position when the case expression _itself_ is in tail position. + * + * As only one branch is ever executed, it is hence safe to mark _all_ + * branches as being in tail position if the case expression is. + */ + + /** Performs tail call analysis on a case branch. + * + * @param branch the branch to analyse + * @param isInTailPosition whether or not the branch occurs in a tail call + * position + * @return `branch`, annotated with tail position metadata + */ + def analyseCaseBranch( + branch: IR.Case.Branch, + isInTailPosition: Boolean + ): IR.Case.Branch = { + branch + .copy( + pattern = analyseExpression(branch.pattern, isInTailPosition = false), + expression = analyseExpression( + branch.expression, + isInTailPosition + ) + ) + .addMetadata(TailPosition.fromBool(isInTailPosition)) + } + + /** Performs tail call analysis on a function definition. + * + * @param function the function to analyse + * @param isInTailPosition whether or not the function definition occurs in a + * tail position + * @return `function`, annotated with tail position metadata + */ + def analyseFunction( + function: IR.Function, + isInTailPosition: Boolean + ): IR.Function = { + val canBeTCO = function.canBeTCO + val markAsTail = (!canBeTCO && isInTailPosition) || canBeTCO + + val resultFunction = function match { + case lambda @ IR.Function.Lambda(args, body, _, _, _) => + lambda.copy( + arguments = args.map(analyseDefArgument), + body = analyseExpression(body, isInTailPosition = markAsTail) + ) + } + + resultFunction.addMetadata( + TailPosition.fromBool(isInTailPosition) + ) + } + + /** Performs tail call analysis on a function definition argument. + * + * @param arg the argument definition to analyse + * @return `arg`, annotated with tail position metadata + */ + def analyseDefArgument(arg: IR.DefinitionArgument): IR.DefinitionArgument = { + arg match { + case arg @ IR.DefinitionArgument.Specified(_, default, _, _, _) => + arg + .copy( + defaultValue = default.map( + x => + analyseExpression(x, isInTailPosition = false) + .addMetadata(TailPosition.NotTail) + ) + ) + .addMetadata(TailPosition.NotTail) + case err: IR.Error.Redefined.Argument => err + } + } + + /** Expresses the tail call state of an IR Node. */ + sealed trait TailPosition extends IR.Metadata { + + /** A boolean representation of the expression's tail state. */ + def isTail: Boolean + } + object TailPosition { + + /** The expression is in a tail position and can be tail call optimised. */ + final case object Tail extends TailPosition { + override def isTail: Boolean = true + } + + /** The expression is not in a tail position and cannot be tail call + * optimised. + */ + final case object NotTail extends TailPosition { + override def isTail: Boolean = false + } + + /** Implicitly converts a boolean to a [[TailPosition]] value. + * + * @param isTail the boolean + * @return the tail position value corresponding to `bool` + */ + implicit def fromBool(isTail: Boolean): TailPosition = { + if (isTail) TailPosition.Tail else TailPosition.NotTail + } + + /** Implicitly converts the tail position data into a boolean. + * + * @param tailPosition the tail position value + * @return the boolean value corresponding to `tailPosition` + */ + implicit def toBool(tailPosition: TailPosition): Boolean = { + tailPosition.isTail + } + } +} diff --git a/engine/runtime/src/main/scala/org/enso/compiler/pass/desugar/GenerateMethodBodies.scala b/engine/runtime/src/main/scala/org/enso/compiler/pass/desugar/GenerateMethodBodies.scala new file mode 100644 index 0000000000..a196345629 --- /dev/null +++ b/engine/runtime/src/main/scala/org/enso/compiler/pass/desugar/GenerateMethodBodies.scala @@ -0,0 +1,111 @@ +package org.enso.compiler.pass.desugar + +import org.enso.compiler.InlineContext +import org.enso.compiler.core.IR +import org.enso.compiler.pass.IRPass + +/** This pass is responsible for ensuring that method bodies are in the correct + * format. + * + * The correct format as far as the rest of the compiler pipeline is concerned + * is as follows: + * + * - The body is a function (lambda) + * - The body has `this` at the start of its argument list. + */ +case object GenerateMethodBodies extends IRPass { + + /** This is a desugaring pass and performs no analysis */ + override type Metadata = IR.Metadata.Empty + + /** Generates and consolidates method bodies. + * + * @param ir the Enso IR to process + * @return `ir`, possibly having made transformations or annotations to that + * IR. + */ + override def runModule(ir: IR.Module): IR.Module = { + ir.copy( + bindings = ir.bindings.map { + case m: IR.Module.Scope.Definition.Method => processMethodDef(m) + case x => x + } + ) + } + + /** Processes a method definition, ensuring that it's in the correct format. + * + * @param ir the method definition to process + * @return `ir` potentially with alterations to ensure that it's in the + * correct format + */ + def processMethodDef( + ir: IR.Module.Scope.Definition.Method + ): IR.Module.Scope.Definition.Method = { + ir.copy( + body = ir.body match { + case fun: IR.Function => processBodyFunction(fun) + case expression => processBodyExpression(expression) + } + ) + } + + /** Processes the method body if it's a function. + * + * This is solely responsible for prepending the `this` argument to the list + * of arguments. + * + * @param fun the body function + * @return the body function with the `this` argument + */ + def processBodyFunction(fun: IR.Function): IR.Function = { + fun match { + case lam @ IR.Function.Lambda(args, _, _, _, _) => + lam.copy( + arguments = genThisArgument :: args + ) + } + } + + /** Processes the method body if it's an expression. + * + * @param expr the body expression + * @return `expr` converted to a function taking the `this` argument + */ + def processBodyExpression(expr: IR.Expression): IR.Expression = { + IR.Function.Lambda( + arguments = List(genThisArgument), + body = expr, + location = expr.location + ) + } + + /** Generates a definition of the `this` argument for method definitions. + * + * @return the `this` argument + */ + def genThisArgument: IR.DefinitionArgument.Specified = { + IR.DefinitionArgument.Specified( + IR.Name.This(None), + None, + suspended = false, + None + ) + } + + /** Executes the pass on an expression. + * + * It is a identity operation on expressions as method definitions are not + * expressions. + * + * @param ir the Enso IR to process + * @param inlineContext a context object that contains the information needed + * for inline evaluation + * @return `ir`, possibly having made transformations or annotations to that + * IR. + */ + override def runExpression( + ir: IR.Expression, + inlineContext: InlineContext + ): IR.Expression = ir +} diff --git a/engine/runtime/src/main/scala/org/enso/compiler/pass/desugar/LiftSpecialOperators.scala b/engine/runtime/src/main/scala/org/enso/compiler/pass/desugar/LiftSpecialOperators.scala index a640e5b334..dc596f5b6d 100644 --- a/engine/runtime/src/main/scala/org/enso/compiler/pass/desugar/LiftSpecialOperators.scala +++ b/engine/runtime/src/main/scala/org/enso/compiler/pass/desugar/LiftSpecialOperators.scala @@ -1,10 +1,8 @@ package org.enso.compiler.pass.desugar +import org.enso.compiler.InlineContext import org.enso.compiler.core.IR import org.enso.compiler.pass.IRPass -import org.enso.interpreter.runtime.scope.{LocalScope, ModuleScope} - -import scala.annotation.unused /** This pass lifts any special operators (ones reserved by the language * implementation) into their own special IR constructs. @@ -21,45 +19,89 @@ case object LiftSpecialOperators extends IRPass { * IR. */ override def runModule(ir: IR.Module): IR.Module = - ir.transformExpressions({ case x => runExpression(x) }) + ir.transformExpressions({ + case x => runExpression(x, new InlineContext) + }) /** Executes the lifting pass in an inline context. * * @param ir the Enso IR to process + * @param inlineContext a context object that contains the information needed + * for inline evaluation * @return `ir`, possibly having made transformations or annotations to that * IR. */ override def runExpression( ir: IR.Expression, - @unused localScope: Option[LocalScope] = None, - @unused moduleScope: Option[ModuleScope] = None + inlineContext: InlineContext ): IR.Expression = ir.transformExpressions({ case IR.Application.Operator.Binary(l, op, r, loc, meta) => op.name match { case IR.Type.Ascription.name => - IR.Type.Ascription(runExpression(l), runExpression(r), loc, meta) + IR.Type.Ascription( + runExpression(l, inlineContext), + runExpression(r, inlineContext), + loc, + meta + ) case IR.Type.Set.Subsumption.name => IR.Type.Set - .Subsumption(runExpression(l), runExpression(r), loc, meta) + .Subsumption( + runExpression(l, inlineContext), + runExpression(r, inlineContext), + loc, + meta + ) case IR.Type.Set.Equality.name => IR.Type.Set - .Equality(runExpression(l), runExpression(r), loc, meta) + .Equality( + runExpression(l, inlineContext), + runExpression(r, inlineContext), + loc, + meta + ) case IR.Type.Set.Concat.name => IR.Type.Set - .Concat(runExpression(l), runExpression(r), loc, meta) + .Concat( + runExpression(l, inlineContext), + runExpression(r, inlineContext), + loc, + meta + ) case IR.Type.Set.Union.name => IR.Type.Set - .Union(runExpression(l), runExpression(r), loc, meta) + .Union( + runExpression(l, inlineContext), + runExpression(r, inlineContext), + loc, + meta + ) case IR.Type.Set.Intersection.name => IR.Type.Set - .Intersection(runExpression(l), runExpression(r), loc, meta) + .Intersection( + runExpression(l, inlineContext), + runExpression(r, inlineContext), + loc, + meta + ) case IR.Type.Set.Subtraction.name => IR.Type.Set - .Subtraction(runExpression(l), runExpression(r), loc, meta) + .Subtraction( + runExpression(l, inlineContext), + runExpression(r, inlineContext), + loc, + meta + ) case _ => IR.Application.Operator - .Binary(runExpression(l), op, runExpression(r), loc, meta) + .Binary( + runExpression(l, inlineContext), + op, + runExpression(r, inlineContext), + loc, + meta + ) } }) diff --git a/engine/runtime/src/main/scala/org/enso/compiler/pass/desugar/OperatorToFunction.scala b/engine/runtime/src/main/scala/org/enso/compiler/pass/desugar/OperatorToFunction.scala index 89ca82a2e8..3bb1a41013 100644 --- a/engine/runtime/src/main/scala/org/enso/compiler/pass/desugar/OperatorToFunction.scala +++ b/engine/runtime/src/main/scala/org/enso/compiler/pass/desugar/OperatorToFunction.scala @@ -1,8 +1,8 @@ package org.enso.compiler.pass.desugar +import org.enso.compiler.InlineContext import org.enso.compiler.core.IR import org.enso.compiler.pass.IRPass -import org.enso.interpreter.runtime.scope.{LocalScope, ModuleScope} /** This pass converts usages of operators to calls to standard functions. */ case object OperatorToFunction extends IRPass { @@ -17,26 +17,31 @@ case object OperatorToFunction extends IRPass { * IR. */ override def runModule(ir: IR.Module): IR.Module = - ir.transformExpressions({ case x => runExpression(x) }) + ir.transformExpressions({ + case x => runExpression(x, new InlineContext) + }) /** Executes the conversion pass in an inline context. * * @param ir the Enso IR to process + * @param inlineContext a context object that contains the information needed + * for inline evaluation * @return `ir`, possibly having made transformations or annotations to that * IR. */ override def runExpression( ir: IR.Expression, - localScope: Option[LocalScope] = None, - moduleScope: Option[ModuleScope] = None + inlineContext: InlineContext ): IR.Expression = ir.transformExpressions { case IR.Application.Operator.Binary(l, op, r, loc, passData) => IR.Application.Prefix( op, List( - IR.CallArgument.Specified(None, runExpression(l), l.location), - IR.CallArgument.Specified(None, runExpression(r), r.location) + IR.CallArgument + .Specified(None, runExpression(l, inlineContext), l.location), + IR.CallArgument + .Specified(None, runExpression(r, inlineContext), r.location) ), hasDefaultsSuspended = false, loc, diff --git a/engine/runtime/src/test/scala/org/enso/compiler/test/CompilerTest.scala b/engine/runtime/src/test/scala/org/enso/compiler/test/CompilerTest.scala index a1844dc8cb..1940ad186a 100644 --- a/engine/runtime/src/test/scala/org/enso/compiler/test/CompilerTest.scala +++ b/engine/runtime/src/test/scala/org/enso/compiler/test/CompilerTest.scala @@ -1,9 +1,9 @@ package org.enso.compiler.test +import org.enso.compiler.InlineContext import org.enso.compiler.codegen.AstToIR import org.enso.compiler.core.IR import org.enso.compiler.pass.IRPass -import org.enso.interpreter.runtime.scope.LocalScope import org.enso.syntax.text.{AST, Parser} import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpecLike @@ -93,12 +93,12 @@ trait CompilerRunner { */ def runPasses( passes: List[IRPass], - localScope: Option[LocalScope] = None + inlineContext: InlineContext ): IR = ir match { case expr: IR.Expression => passes.foldLeft(expr)( (intermediate, pass) => - pass.runExpression(intermediate, localScope = localScope) + pass.runExpression(intermediate, inlineContext) ) case mod: IR.Module => passes.foldLeft(mod)( diff --git a/engine/runtime/src/test/scala/org/enso/compiler/test/pass/analyse/AliasAnalysisTest.scala b/engine/runtime/src/test/scala/org/enso/compiler/test/pass/analyse/AliasAnalysisTest.scala index e82ecc67d5..1f6f9178b7 100644 --- a/engine/runtime/src/test/scala/org/enso/compiler/test/pass/analyse/AliasAnalysisTest.scala +++ b/engine/runtime/src/test/scala/org/enso/compiler/test/pass/analyse/AliasAnalysisTest.scala @@ -1,12 +1,17 @@ package org.enso.compiler.test.pass.analyse +import org.enso.compiler.InlineContext import org.enso.compiler.core.IR import org.enso.compiler.core.IR.Module.Scope.Definition.{Atom, Method} import org.enso.compiler.pass.IRPass import org.enso.compiler.pass.analyse.AliasAnalysis -import org.enso.compiler.pass.analyse.AliasAnalysis.{Graph, Info} import org.enso.compiler.pass.analyse.AliasAnalysis.Graph.{Link, Occurrence} -import org.enso.compiler.pass.desugar.{LiftSpecialOperators, OperatorToFunction} +import org.enso.compiler.pass.analyse.AliasAnalysis.{Graph, Info} +import org.enso.compiler.pass.desugar.{ + GenerateMethodBodies, + LiftSpecialOperators, + OperatorToFunction +} import org.enso.compiler.test.CompilerTest class AliasAnalysisTest extends CompilerTest { @@ -19,6 +24,7 @@ class AliasAnalysisTest extends CompilerTest { */ implicit class Preprocess(source: String) { val precursorPasses: List[IRPass] = List( + GenerateMethodBodies, LiftSpecialOperators, OperatorToFunction ) @@ -28,7 +34,9 @@ class AliasAnalysisTest extends CompilerTest { * @return IR appropriate for testing the alias analysis pass as a module */ def preprocessModule: IR.Module = { - source.toIrModule.runPasses(precursorPasses).asInstanceOf[IR.Module] + source.toIrModule + .runPasses(precursorPasses, InlineContext()) + .asInstanceOf[IR.Module] } /** Translates the source code into appropriate IR for testing this pass @@ -36,9 +44,12 @@ class AliasAnalysisTest extends CompilerTest { * @return IR appropriate for testing the alias analysis pass as an * expression */ - def preprocessExpression: Option[IR.Expression] = { + def preprocessExpression( + inlineContext: InlineContext + ): Option[IR.Expression] = { source.toIrExpression.map( - _.runPasses(precursorPasses).asInstanceOf[IR.Expression] + _.runPasses(precursorPasses, inlineContext) + .asInstanceOf[IR.Expression] ) } } @@ -68,8 +79,8 @@ class AliasAnalysisTest extends CompilerTest { * * @return [[ir]], with attached aliasing information */ - def analyse: IR.Expression = { - AliasAnalysis.runExpression(ir) + def analyse(inlineContext: InlineContext): IR.Expression = { + AliasAnalysis.runExpression(ir, inlineContext) } } diff --git a/engine/runtime/src/test/scala/org/enso/compiler/test/pass/analyse/ApplicationSaturationTest.scala b/engine/runtime/src/test/scala/org/enso/compiler/test/pass/analyse/ApplicationSaturationTest.scala index 56903091bb..03f78edc04 100644 --- a/engine/runtime/src/test/scala/org/enso/compiler/test/pass/analyse/ApplicationSaturationTest.scala +++ b/engine/runtime/src/test/scala/org/enso/compiler/test/pass/analyse/ApplicationSaturationTest.scala @@ -1,13 +1,14 @@ package org.enso.compiler.test.pass.analyse +import org.enso.compiler.InlineContext import org.enso.compiler.core.IR import org.enso.compiler.core.IR.Metadata -import org.enso.compiler.pass.analyse.{AliasAnalysis, ApplicationSaturation} import org.enso.compiler.pass.analyse.ApplicationSaturation.{ CallSaturation, FunctionSpec, PassConfiguration } +import org.enso.compiler.pass.analyse.{AliasAnalysis, ApplicationSaturation} import org.enso.compiler.pass.desugar.{LiftSpecialOperators, OperatorToFunction} import org.enso.compiler.test.CompilerTest import org.enso.interpreter.node.ExpressionNode @@ -58,39 +59,54 @@ class ApplicationSaturationTest extends CompilerTest { val localScope = Some(LocalScope.root) + val ctx = new InlineContext(localScope = localScope) + // === The Tests ============================================================ "Known applications" should { - val plusFn = IR.Application.Prefix( - IR.Name.Literal("+", None), - genNArgs(2), - hasDefaultsSuspended = false, - None - ).runPasses(passes, localScope).asInstanceOf[IR.Application.Prefix] + val plusFn = IR.Application + .Prefix( + IR.Name.Literal("+", None), + genNArgs(2), + hasDefaultsSuspended = false, + None + ) + .runPasses(passes, ctx) + .asInstanceOf[IR.Application.Prefix] - val bazFn = IR.Application.Prefix( - IR.Name.Literal("baz", None), - genNArgs(2), - hasDefaultsSuspended = false, - None - ).runPasses(passes, localScope).asInstanceOf[IR.Application.Prefix] + val bazFn = IR.Application + .Prefix( + IR.Name.Literal("baz", None), + genNArgs(2), + hasDefaultsSuspended = false, + None + ) + .runPasses(passes, ctx) + .asInstanceOf[IR.Application.Prefix] - val fooFn = IR.Application.Prefix( - IR.Name.Literal("foo", None), - genNArgs(5), - hasDefaultsSuspended = false, - None - ).runPasses(passes, localScope).asInstanceOf[IR.Application.Prefix] + val fooFn = IR.Application + .Prefix( + IR.Name.Literal("foo", None), + genNArgs(5), + hasDefaultsSuspended = false, + None + ) + .runPasses(passes, ctx) + .asInstanceOf[IR.Application.Prefix] - val fooFnByName = IR.Application.Prefix( - IR.Name.Literal("foo", None), - genNArgs(4, positional = false), - hasDefaultsSuspended = false, - None - ).runPasses(passes, localScope).asInstanceOf[IR.Application.Prefix] + val fooFnByName = IR.Application + .Prefix( + IR.Name.Literal("foo", None), + genNArgs(4, positional = false), + hasDefaultsSuspended = false, + None + ) + .runPasses(passes, ctx) + .asInstanceOf[IR.Application.Prefix] "be tagged with full saturation where possible" in { - val resultIR = ApplicationSaturation(knownFunctions).runExpression(plusFn) + val resultIR = + ApplicationSaturation(knownFunctions).runExpression(plusFn, ctx) resultIR.getMetadata[CallSaturation].foreach { case _: CallSaturation.Exact => succeed @@ -99,14 +115,16 @@ class ApplicationSaturationTest extends CompilerTest { } "be tagged with partial saturation where possible" in { - val resultIR = ApplicationSaturation(knownFunctions).runExpression(bazFn) + val resultIR = + ApplicationSaturation(knownFunctions).runExpression(bazFn, ctx) val expected = Some(CallSaturation.Partial(1)) resultIR.getMetadata[CallSaturation] shouldEqual expected } "be tagged with over saturation where possible" in { - val resultIR = ApplicationSaturation(knownFunctions).runExpression(fooFn) + val resultIR = + ApplicationSaturation(knownFunctions).runExpression(fooFn, ctx) val expected = Some(CallSaturation.Over(1)) resultIR.getMetadata[CallSaturation] shouldEqual expected @@ -114,7 +132,7 @@ class ApplicationSaturationTest extends CompilerTest { "be tagged with by name if applied by name" in { val resultIR = - ApplicationSaturation(knownFunctions).runExpression(fooFnByName) + ApplicationSaturation(knownFunctions).runExpression(fooFnByName, ctx) val expected = Some(CallSaturation.ExactButByName()) resultIR.getMetadata[CallSaturation] shouldEqual expected @@ -122,16 +140,19 @@ class ApplicationSaturationTest extends CompilerTest { } "Unknown applications" should { - val unknownFn = IR.Application.Prefix( - IR.Name.Literal("unknown", None), - genNArgs(10), - hasDefaultsSuspended = false, - None - ).runPasses(passes, localScope).asInstanceOf[IR.Application.Prefix] + val unknownFn = IR.Application + .Prefix( + IR.Name.Literal("unknown", None), + genNArgs(10), + hasDefaultsSuspended = false, + None + ) + .runPasses(passes, ctx) + .asInstanceOf[IR.Application.Prefix] "be tagged with unknown saturation" in { val resultIR = - ApplicationSaturation(knownFunctions).runExpression(unknownFn) + ApplicationSaturation(knownFunctions).runExpression(unknownFn, ctx) val expected = Some(CallSaturation.Unknown()) resultIR.getMetadata[CallSaturation] shouldEqual expected @@ -140,26 +161,35 @@ class ApplicationSaturationTest extends CompilerTest { "Known applications containing known applications" should { val empty = IR.Empty(None) - val knownPlus = IR.Application.Prefix( - IR.Name.Literal("+", None), - genNArgs(2), - hasDefaultsSuspended = false, - None - ).runPasses(passes, localScope).asInstanceOf[IR.Application.Prefix] + val knownPlus = IR.Application + .Prefix( + IR.Name.Literal("+", None), + genNArgs(2), + hasDefaultsSuspended = false, + None + ) + .runPasses(passes, ctx) + .asInstanceOf[IR.Application.Prefix] - val undersaturatedPlus = IR.Application.Prefix( - IR.Name.Literal("+", None), - genNArgs(1), - hasDefaultsSuspended = false, - None - ).runPasses(passes,localScope).asInstanceOf[IR.Application.Prefix] + val undersaturatedPlus = IR.Application + .Prefix( + IR.Name.Literal("+", None), + genNArgs(1), + hasDefaultsSuspended = false, + None + ) + .runPasses(passes, ctx) + .asInstanceOf[IR.Application.Prefix] - val oversaturatedPlus = IR.Application.Prefix( - IR.Name.Literal("+", None), - genNArgs(3), - hasDefaultsSuspended = false, - None - ).runPasses(passes, localScope).asInstanceOf[IR.Application.Prefix] + val oversaturatedPlus = IR.Application + .Prefix( + IR.Name.Literal("+", None), + genNArgs(3), + hasDefaultsSuspended = false, + None + ) + .runPasses(passes, ctx) + .asInstanceOf[IR.Application.Prefix] implicit class InnerMeta(ir: IR.Expression) { def getInnerMetadata[T <: Metadata: ClassTag]: Option[T] = { @@ -173,21 +203,25 @@ class ApplicationSaturationTest extends CompilerTest { } def outerPlus(argExpr: IR.Expression): IR.Application.Prefix = { - IR.Application.Prefix( - IR.Name.Literal("+", None), - List( - IR.CallArgument.Specified(None, argExpr, None), - IR.CallArgument.Specified(None, empty, None) - ), - hasDefaultsSuspended = false, - None - ).runPasses(passes, localScope).asInstanceOf[IR.Application.Prefix] + IR.Application + .Prefix( + IR.Name.Literal("+", None), + List( + IR.CallArgument.Specified(None, argExpr, None), + IR.CallArgument.Specified(None, empty, None) + ), + hasDefaultsSuspended = false, + None + ) + .runPasses(passes, ctx) + .asInstanceOf[IR.Application.Prefix] } "have fully saturated applications tagged correctly" in { val result = ApplicationSaturation(knownFunctions).runExpression( - outerPlus(knownPlus) + outerPlus(knownPlus), + ctx ) // The outer should be reported as fully saturated @@ -206,7 +240,8 @@ class ApplicationSaturationTest extends CompilerTest { "have non-fully saturated applications tagged correctly" in { val result = ApplicationSaturation(knownFunctions).runExpression( - outerPlus(undersaturatedPlus) + outerPlus(undersaturatedPlus), + ctx ) val expectedInnerMeta = CallSaturation.Partial(1) @@ -225,7 +260,8 @@ class ApplicationSaturationTest extends CompilerTest { "have a mixture of application saturations tagged correctly" in { val result = ApplicationSaturation(knownFunctions).runExpression( - outerPlus(oversaturatedPlus) + outerPlus(oversaturatedPlus), + ctx ) val expectedInnerMeta = CallSaturation.Over(1) @@ -252,11 +288,11 @@ class ApplicationSaturationTest extends CompilerTest { |""".stripMargin.toIR val inputIR = rawIR - .runPasses(passes, localScope = localScope) + .runPasses(passes, ctx) .asInstanceOf[IR.Expression] val result = ApplicationSaturation(knownFunctions) - .runExpression(inputIR, localScope = localScope) + .runExpression(inputIR, ctx) .asInstanceOf[IR.Expression.Binding] "be tagged as unknown even if their name is known" in { diff --git a/engine/runtime/src/test/scala/org/enso/compiler/test/pass/analyse/TailCallTest.scala b/engine/runtime/src/test/scala/org/enso/compiler/test/pass/analyse/TailCallTest.scala new file mode 100644 index 0000000000..f927873cfb --- /dev/null +++ b/engine/runtime/src/test/scala/org/enso/compiler/test/pass/analyse/TailCallTest.scala @@ -0,0 +1,335 @@ +package org.enso.compiler.test.pass.analyse + +import org.enso.compiler.InlineContext +import org.enso.compiler.core.IR +import org.enso.compiler.core.IR.Module.Scope.Definition.Method +import org.enso.compiler.exception.CompilerError +import org.enso.compiler.pass.IRPass +import org.enso.compiler.pass.analyse.TailCall.TailPosition +import org.enso.compiler.pass.analyse.{ + AliasAnalysis, + ApplicationSaturation, + TailCall +} +import org.enso.compiler.pass.desugar.{ + GenerateMethodBodies, + LiftSpecialOperators, + OperatorToFunction +} +import org.enso.compiler.test.CompilerTest +import org.enso.interpreter.runtime.scope.LocalScope + +class TailCallTest extends CompilerTest { + + // === Test Setup =========================================================== + + val tailCtx = InlineContext( + localScope = Some(LocalScope.root), + isInTailPosition = Some(true) + ) + + val noTailCtx = InlineContext( + localScope = Some(LocalScope.root), + isInTailPosition = Some(false) + ) + + val precursorPasses: List[IRPass] = List( + GenerateMethodBodies, + LiftSpecialOperators, + OperatorToFunction, + AliasAnalysis, + ApplicationSaturation() + ) + + /** Adds an extension method to preprocess source code as an Enso module. + * + * @param code the source code to preprocess + */ + implicit class PreprocessModule(code: String) { + + /** Preprocesses the provided source code into an [[IR.Module]]. + * + * @return the IR representation of [[code]] + */ + def runTCAModule: IR.Module = { + val preprocessed = code.toIrModule + .runPasses(precursorPasses, tailCtx) + .asInstanceOf[IR.Module] + + TailCall.runModule(preprocessed) + } + } + + /** Adds an extension method to preprocess source code as an Enso expression. + * + * @param code the source code to preprocess + */ + implicit class PreprocessExpression(code: String) { + + /** Preprocesses the provided source code into an [[IR.Expression]]. + * + * @return the IR representation of [[code]] + */ + def runTCAExpression(context: InlineContext): IR.Expression = { + val preprocessed = code.toIrExpression + .getOrElse( + throw new CompilerError("Code was not a valid expression.") + ) + .runPasses(precursorPasses, context) + .asInstanceOf[IR.Expression] + + TailCall.runExpression(preprocessed, context) + } + } + + // === The Tests ============================================================ + + "Tail call analysis on modules" should { + val ir = + """ + |Foo.bar = a b c -> + | d = a + b + | + | case c of + | Baz a b -> a * b * d + | _ -> d + | + |type MyAtom a b c + |""".stripMargin.runTCAModule + + "mark methods as tail" in { + ir.bindings.head + .getMetadata[TailCall.Metadata] shouldEqual Some(TailPosition.Tail) + } + + "mark atoms as tail" in { + ir.bindings(1) + .getMetadata[TailCall.Metadata] shouldEqual Some(TailPosition.Tail) + } + } + + "Tail call analysis on expressions" should { + val code = + """ + |x y z -> x y z + |""".stripMargin + + "mark the expression as tail if the context requires it" in { + val ir = code.runTCAExpression(tailCtx) + + ir.getMetadata[TailCall.Metadata] shouldEqual Some(TailPosition.Tail) + } + + "not mark the expression as tail if the context doesn't require it" in { + val ir = code.runTCAExpression(noTailCtx) + + ir.getMetadata[TailCall.Metadata] shouldEqual Some(TailPosition.NotTail) + } + } + + "Tail call analysis on functions" should { + val ir = + """ + |a b c -> + | d = a + b + | e = a * c + | d + e + |""".stripMargin + .runTCAExpression(tailCtx) + .asInstanceOf[IR.Function.Lambda] + + val fnBody = ir.body.asInstanceOf[IR.Expression.Block] + + "mark the last expression of the function as tail" in { + fnBody.returnValue.getMetadata[TailCall.Metadata] shouldEqual Some( + TailPosition.Tail + ) + } + + "mark the other expressions in the function as not tail" in { + fnBody.expressions.foreach( + expr => + expr.getMetadata[TailCall.Metadata] shouldEqual Some( + TailPosition.NotTail + ) + ) + } + } + + "Tail call analysis on case expressions" should { + "not mark any portion of the branch functions as tail by default" in { + val ir = + """ + |Foo.bar = a -> + | x = case a of + | Lambda fn arg -> fn arg + | + | x + |""".stripMargin.runTCAModule + + val caseExpr = ir.bindings.head + .asInstanceOf[Method] + .body + .asInstanceOf[IR.Function.Lambda] + .body + .asInstanceOf[IR.Expression.Block] + .expressions + .head + .asInstanceOf[IR.Expression.Binding] + .expression + .asInstanceOf[IR.Case.Expr] + + caseExpr.getMetadata[TailCall.Metadata] shouldEqual Some( + TailPosition.NotTail + ) + caseExpr.branches.foreach(branch => { + val branchExpression = + branch.expression.asInstanceOf[IR.Function.Lambda] + + branchExpression.getMetadata[TailPosition] shouldEqual Some( + TailPosition.NotTail + ) + branchExpression.body.getMetadata[TailPosition] shouldEqual Some( + TailPosition.NotTail + ) + }) + } + + "only mark the branches as tail if the expression is in tail position" in { + val ir = + """ + |Foo.bar = a -> + | case a of + | Lambda fn arg -> fn arg + |""".stripMargin.runTCAModule + + val caseExpr = ir.bindings.head + .asInstanceOf[Method] + .body + .asInstanceOf[IR.Function.Lambda] + .body + .asInstanceOf[IR.Expression.Block] + .returnValue + .asInstanceOf[IR.Case.Expr] + + caseExpr.getMetadata[TailCall.Metadata] shouldEqual Some( + TailPosition.Tail + ) + caseExpr.branches.foreach(branch => { + val branchExpression = + branch.expression.asInstanceOf[IR.Function.Lambda] + + branchExpression.getMetadata[TailPosition] shouldEqual Some( + TailPosition.Tail + ) + branchExpression.body.getMetadata[TailPosition] shouldEqual Some( + TailPosition.Tail + ) + }) + } + } + + "Tail call analysis on function calls" should { + val tailCall = + """ + |Foo.bar = + | IO.println "AAAAA" + |""".stripMargin.runTCAModule.bindings.head.asInstanceOf[Method] + val tailCallBody = tailCall.body + .asInstanceOf[IR.Function.Lambda] + .body + .asInstanceOf[IR.Expression.Block] + + val nonTailCall = + """ + |Foo.bar = + | a = b c d + | a + |""".stripMargin.runTCAModule.bindings.head.asInstanceOf[Method] + val nonTailCallBody = nonTailCall.body + .asInstanceOf[IR.Function.Lambda] + .body + .asInstanceOf[IR.Expression.Block] + + "mark the arguments as tail" in { + nonTailCallBody.expressions.head + .asInstanceOf[IR.Expression.Binding] + .expression + .asInstanceOf[IR.Application.Prefix] + .arguments + .foreach( + arg => + arg.getMetadata[TailCall.Metadata] shouldEqual Some( + TailPosition.Tail + ) + ) + + tailCallBody.returnValue + .asInstanceOf[IR.Application.Prefix] + .arguments + .foreach( + arg => + arg.getMetadata[TailCall.Metadata] shouldEqual Some( + TailPosition.Tail + ) + ) + } + + "mark the function call as tail if it is in a tail position" in { + tailCallBody.returnValue.getMetadata[TailCall.Metadata] shouldEqual Some( + TailPosition.Tail + ) + } + + "mark the function call as not tail if it is in a tail position" in { + nonTailCallBody.expressions.head + .asInstanceOf[IR.Expression.Binding] + .expression + .getMetadata[TailCall.Metadata] shouldEqual Some(TailPosition.NotTail) + } + } + + "Tail call analysis on blocks" should { + val ir = + """ + |Foo.bar = a b c -> + | d = a + b + | mul = a b -> a * b + | mul c d + |""".stripMargin.runTCAModule.bindings.head.asInstanceOf[Method] + + val block = ir.body + .asInstanceOf[IR.Function.Lambda] + .body + .asInstanceOf[IR.Expression.Block] + + "mark the bodies of bound functions as tail properly" in { + block + .expressions(1) + .asInstanceOf[IR.Expression.Binding] + .expression + .asInstanceOf[IR.Function.Lambda] + .body + .getMetadata[TailCall.Metadata] shouldEqual Some(TailPosition.Tail) + } + + "mark the block expressions as not tail" in { + block.expressions.foreach( + expr => + expr.getMetadata[TailCall.Metadata] shouldEqual Some( + TailPosition.NotTail + ) + ) + } + + "mark the final expression of the block as tail" in { + block.returnValue.getMetadata[TailCall.Metadata] shouldEqual Some( + TailPosition.Tail + ) + } + + "mark the block as tail if it is in a tail position" in { + block.getMetadata[TailCall.Metadata] shouldEqual Some(TailPosition.Tail) + } + } +} diff --git a/engine/runtime/src/test/scala/org/enso/compiler/test/pass/desugar/GenerateMethodBodiesTest.scala b/engine/runtime/src/test/scala/org/enso/compiler/test/pass/desugar/GenerateMethodBodiesTest.scala new file mode 100644 index 0000000000..157a7abce2 --- /dev/null +++ b/engine/runtime/src/test/scala/org/enso/compiler/test/pass/desugar/GenerateMethodBodiesTest.scala @@ -0,0 +1,77 @@ +package org.enso.compiler.test.pass.desugar + +import org.enso.compiler.core.IR +import org.enso.compiler.core.IR.Module.Scope.Definition.Method +import org.enso.compiler.pass.desugar.GenerateMethodBodies +import org.enso.compiler.test.CompilerTest + +class GenerateMethodBodiesTest extends CompilerTest { + + // === The Tests ============================================================ + + "Methods with functions as bodies" should { + val ir = + """ + |Unit.method = a b c -> a + b + c + |""".stripMargin.toIrModule + val irMethod = ir.bindings.head.asInstanceOf[Method] + + val irResult = GenerateMethodBodies.runModule(ir) + val irResultMethod = irResult.bindings.head.asInstanceOf[Method] + + "have the `this` argument prepended to the argument list" in { + val resultArgs = + irResultMethod.body.asInstanceOf[IR.Function.Lambda].arguments + + resultArgs.head + .asInstanceOf[IR.DefinitionArgument.Specified] + .name shouldEqual IR.Name.This(None) + + resultArgs.tail shouldEqual irMethod.body + .asInstanceOf[IR.Function.Lambda] + .arguments + } + + "have the body of the function remain untouched" in { + val inputBody = irMethod.body.asInstanceOf[IR.Function.Lambda].body + val resultBody = irResultMethod.body.asInstanceOf[IR.Function.Lambda].body + + inputBody shouldEqual resultBody + } + } + + "Methods with expressions as bodies" should { + val ir = + """ + |Unit.method = 1 + |""".stripMargin.toIrModule + val irMethod = ir.bindings.head.asInstanceOf[Method] + + val irResult = GenerateMethodBodies.runModule(ir) + val irResultMethod = irResult.bindings.head.asInstanceOf[Method] + + "have the expression converted into a function" in { + irResultMethod.body shouldBe an[IR.Function.Lambda] + } + + "have the resultant function take the `this` argument" in { + val bodyArgs = + irResultMethod.body.asInstanceOf[IR.Function.Lambda].arguments + + bodyArgs.length shouldEqual 1 + bodyArgs.head + .asInstanceOf[IR.DefinitionArgument.Specified] + .name shouldEqual IR.Name.This(None) + } + + "have the body of the function be equivalent to the expression" in { + irResultMethod.body + .asInstanceOf[IR.Function.Lambda] + .body shouldEqual irMethod.body + } + + "have the body function's location equivalent to the original body" in { + irMethod.body.location shouldEqual irResultMethod.body.location + } + } +} diff --git a/engine/runtime/src/test/scala/org/enso/compiler/test/pass/desugar/LiftSpecialOperatorsTest.scala b/engine/runtime/src/test/scala/org/enso/compiler/test/pass/desugar/LiftSpecialOperatorsTest.scala index de947a33f9..c77181adcd 100644 --- a/engine/runtime/src/test/scala/org/enso/compiler/test/pass/desugar/LiftSpecialOperatorsTest.scala +++ b/engine/runtime/src/test/scala/org/enso/compiler/test/pass/desugar/LiftSpecialOperatorsTest.scala @@ -1,5 +1,6 @@ package org.enso.compiler.test.pass.desugar +import org.enso.compiler.InlineContext import org.enso.compiler.core.IR import org.enso.compiler.pass.desugar.LiftSpecialOperators import org.enso.compiler.test.CompilerTest @@ -9,6 +10,8 @@ class LiftSpecialOperatorsTest extends CompilerTest { // === Utilities ============================================================ + val ctx = new InlineContext + /** Tests whether a given operator is lifted correctly into the corresponding * special construct. * @@ -42,7 +45,7 @@ class LiftSpecialOperatorsTest extends CompilerTest { "be lifted by the pass in an inline context" in { LiftSpecialOperators - .runExpression(expressionIR) shouldEqual outputExpressionIR + .runExpression(expressionIR, ctx) shouldEqual outputExpressionIR } "be lifted by the pass in a module context" in { @@ -62,7 +65,7 @@ class LiftSpecialOperatorsTest extends CompilerTest { ) LiftSpecialOperators - .runExpression(recursiveIR) shouldEqual recursiveIROutput + .runExpression(recursiveIR, ctx) shouldEqual recursiveIROutput } } diff --git a/engine/runtime/src/test/scala/org/enso/compiler/test/pass/desugar/OperatorToFunctionTest.scala b/engine/runtime/src/test/scala/org/enso/compiler/test/pass/desugar/OperatorToFunctionTest.scala index ea2db625a4..f2ba4a76bb 100644 --- a/engine/runtime/src/test/scala/org/enso/compiler/test/pass/desugar/OperatorToFunctionTest.scala +++ b/engine/runtime/src/test/scala/org/enso/compiler/test/pass/desugar/OperatorToFunctionTest.scala @@ -1,5 +1,6 @@ package org.enso.compiler.test.pass.desugar +import org.enso.compiler.InlineContext import org.enso.compiler.core.IR import org.enso.compiler.pass.desugar.OperatorToFunction import org.enso.compiler.test.CompilerTest @@ -9,6 +10,8 @@ class OperatorToFunctionTest extends CompilerTest { // === Utilities ============================================================ + val ctx = new InlineContext + /** Generates an operator and its corresponding function. * * @param name @@ -47,7 +50,7 @@ class OperatorToFunctionTest extends CompilerTest { val (operator, operatorFn) = genOprAndFn(opName, left, right) "be translated to functions" in { - OperatorToFunction.runExpression(operator) shouldEqual operatorFn + OperatorToFunction.runExpression(operator, ctx) shouldEqual operatorFn } "be translated in module contexts" in { @@ -70,7 +73,7 @@ class OperatorToFunctionTest extends CompilerTest { None ) - OperatorToFunction.runExpression(recursiveIR) shouldEqual recursiveIRResult + OperatorToFunction.runExpression(recursiveIR, ctx) shouldEqual recursiveIRResult } } } diff --git a/engine/runtime/src/test/scala/org/enso/interpreter/test/semantic/CurryingTest.scala b/engine/runtime/src/test/scala/org/enso/interpreter/test/semantic/CurryingTest.scala index e6276a506a..9b13d6e4b1 100644 --- a/engine/runtime/src/test/scala/org/enso/interpreter/test/semantic/CurryingTest.scala +++ b/engine/runtime/src/test/scala/org/enso/interpreter/test/semantic/CurryingTest.scala @@ -57,4 +57,20 @@ class CurryingTest extends InterpreterTest { eval(code) shouldEqual 32 } + + "Method call syntax" should "allow default arguments to be suspended" in { + val code = + """ + |Unit.fn = w x (y = 10) (z = 20) -> w + x + y + z + | + |main = + | fn1 = Unit.fn ... + | fn2 = fn1 1 2 ... + | fn3 = fn2 3 ... + | + | fn3.call + |""".stripMargin + + eval(code) shouldEqual 26 + } }