Extract remaining analyses from codegen (#616)

This commit is contained in:
Ara Adkins 2020-03-24 10:28:03 +00:00 committed by GitHub
parent 6f8d3b73bb
commit 2c1d967dd6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 1423 additions and 330 deletions

View File

@ -1,14 +1,14 @@
import java.io.File import java.io.File
import sbt.Keys.scalacOptions
import scala.sys.process._
import org.enso.build.BenchTasks._ import org.enso.build.BenchTasks._
import org.enso.build.WithDebugCommand import org.enso.build.WithDebugCommand
import sbt.Keys.scalacOptions
import sbt.addCompilerPlugin import sbt.addCompilerPlugin
import sbtassembly.AssemblyPlugin.defaultUniversalScript import sbtassembly.AssemblyPlugin.defaultUniversalScript
import sbtcrossproject.CrossPlugin.autoImport.{crossProject, CrossType} import sbtcrossproject.CrossPlugin.autoImport.{crossProject, CrossType}
import scala.sys.process._
////////////////////////////// //////////////////////////////
//// Global Configuration //// //// Global Configuration ////
////////////////////////////// //////////////////////////////
@ -446,16 +446,16 @@ lazy val polyglot_api = project
lazy val language_server = (project in file("engine/language-server")) lazy val language_server = (project in file("engine/language-server"))
.settings( .settings(
libraryDependencies ++= akka ++ circe ++ Seq( libraryDependencies ++= akka ++ circe ++ Seq(
"ch.qos.logback" % "logback-classic" % "1.2.3", "ch.qos.logback" % "logback-classic" % "1.2.3",
"io.circe" %% "circe-generic-extras" % "0.12.2", "io.circe" %% "circe-generic-extras" % "0.12.2",
"io.circe" %% "circe-literal" % circeVersion, "io.circe" %% "circe-literal" % circeVersion,
"org.bouncycastle" % "bcpkix-jdk15on" % "1.64", "org.bouncycastle" % "bcpkix-jdk15on" % "1.64",
"dev.zio" %% "zio" % "1.0.0-RC18-2", "dev.zio" %% "zio" % "1.0.0-RC18-2",
akkaTestkit % Test, akkaTestkit % Test,
"commons-io" % "commons-io" % "2.6", "commons-io" % "commons-io" % "2.6",
"org.scalatest" %% "scalatest" % "3.2.0-M2" % Test, "org.scalatest" %% "scalatest" % "3.2.0-M2" % Test,
"org.scalacheck" %% "scalacheck" % "1.14.0" % Test, "org.scalacheck" %% "scalacheck" % "1.14.0" % Test,
"org.graalvm.sdk" % "polyglot-tck" % graalVersion % "provided" "org.graalvm.sdk" % "polyglot-tck" % graalVersion % "provided"
), ),
testOptions in Test += Tests testOptions in Test += Tests
.Argument(TestFrameworks.ScalaCheck, "-minSuccessfulTests", "1000") .Argument(TestFrameworks.ScalaCheck, "-minSuccessfulTests", "1000")

View File

@ -641,7 +641,8 @@ binds the function name. This means that:
### Methods ### Methods
Enso makes a distinction between functions and methods. In Enso, a method is a Enso makes a distinction between functions and methods. In Enso, a method is a
function where the first argument (known as the `this` argument) is associated function where the first argument (known as the `this` argument) is associated
with a given atom. with a given atom. Methods are dispatched dynamically based on the type of the
`this` argument, while functions are not.
Methods can be defined in Enso in two ways: Methods can be defined in Enso in two ways:
@ -669,6 +670,18 @@ Methods can be defined in Enso in two ways:
... ...
``` ```
3. **As a Function with an Explicit `this` Argument:** A function defined with
the type of the `this` argument specified to be a type.
```ruby
floor (this : Number) = case this of
Integer -> ...
```
If the user does not explicitly specify the `this` argument by name when
defining a method (e.g. they use the `Type.name` syntax), it is implicitly added
to the start of the argument list.
#### This vs. Self #### This vs. Self
Though it varies greatly between programming languages, we have chosen `this` to Though it varies greatly between programming languages, we have chosen `this` to
be the name of the 'current type' rather than `self`. This is a purely aesthetic be the name of the 'current type' rather than `self`. This is a purely aesthetic
@ -687,17 +700,10 @@ when calling it. To that end, Enso supports what is known as Uniform Call Syntax
- This is a needless constraint as both notations have their advantages. - This is a needless constraint as both notations have their advantages.
- Enso has two notations, but one unified semantics. - Enso has two notations, but one unified semantics.
The rules for the uniform syntax call translation in Enso are as follows. The rules for the uniform syntax call translation in Enso are as follows:
1. For an expression `t.fn`, this is equivalent to `fn (this = t)`. 1. For an expression `t.fn <args>`, this is equivalent to `fn t <args>`.
2. The `this` argument may occur at any position in the function. 2. For an expression `fn t <args>`, this is equivalent to `t.fn <args>`.
> The actionables for this section are:
>
> - Clarify exactly how this should work, and which argument should be
> translated.
> - We do not _currently_ implement the above-listed transformation, so we need
> to solidify these rules.
### Code Blocks ### Code Blocks
Top-level blocks in the language are evaluated immediately. This means that the Top-level blocks in the language are evaluated immediately. This means that the

View File

@ -70,15 +70,4 @@ public class ClosureRootNode extends EnsoRootNode {
state = FrameUtil.getObjectSafe(frame, this.getStateFrameSlot()); state = FrameUtil.getObjectSafe(frame, this.getStateFrameSlot());
return new Stateful(state, result); return new Stateful(state, result);
} }
/**
* Sets whether the node is tail-recursive.
*
* @param isTail whether or not the node is tail-recursive.
*/
@Override
public void setTail(boolean isTail) {
CompilerDirectives.transferToInterpreterAndInvalidate();
body.setTail(isTail);
}
} }

View File

@ -83,13 +83,6 @@ public abstract class EnsoRootNode extends RootNode {
return this.name; return this.name;
} }
/**
* Sets whether the node is tail-recursive.
*
* @param isTail whether or not the node is tail-recursive
*/
public abstract void setTail(boolean isTail);
/** /**
* Gets the frame slot containing the program state. * Gets the frame slot containing the program state.
* *

View File

@ -30,16 +30,6 @@ public class BlockNode extends ExpressionNode {
return new BlockNode(expressions, returnExpr); return new BlockNode(expressions, returnExpr);
} }
/**
* Sets whether or not the function is tail-recursive.
*
* @param isTail whether or not the function is tail-recursive.
*/
@Override
public void setTail(boolean isTail) {
returnExpr.setTail(isTail);
}
/** /**
* Executes the body of the function. * Executes the body of the function.
* *

View File

@ -37,16 +37,6 @@ public class CreateFunctionNode extends ExpressionNode {
return new CreateFunctionNode(callTarget, args); return new CreateFunctionNode(callTarget, args);
} }
/**
* Sets the tail-recursiveness of the function.
*
* @param isTail whether or not the function is tail-recursive
*/
@Override
public void setTail(boolean isTail) {
((ClosureRootNode) callTarget.getRootNode()).setTail(isTail);
}
/** /**
* Generates the provided function definition in the given stack {@code frame}. * Generates the provided function definition in the given stack {@code frame}.
* *

View File

@ -38,16 +38,6 @@ public class ConstructorCaseNode extends CaseNode {
return new ConstructorCaseNode(matcher, branch); return new ConstructorCaseNode(matcher, branch);
} }
/**
* Sets whether or not the case expression is tail recursive.
*
* @param isTail whether or not the case expression is tail-recursive
*/
@Override
public void setTail(boolean isTail) {
branch.setTail(isTail);
}
/** /**
* Handles the atom scrutinee case. * Handles the atom scrutinee case.
* *

View File

@ -43,20 +43,6 @@ public abstract class MatchNode extends ExpressionNode {
return MatchNodeGen.create(cases, fallback, scrutinee); return MatchNodeGen.create(cases, fallback, scrutinee);
} }
/**
* Sets whether or not the pattern match is tail-recursive.
*
* @param isTail whether or not the expression is tail-recursive
*/
@Override
@ExplodeLoop
public void setTail(boolean isTail) {
for (CaseNode caseNode : cases) {
caseNode.setTail(isTail);
}
fallback.setTail(isTail);
}
@Specialization @Specialization
Object doError(VirtualFrame frame, RuntimeError error) { Object doError(VirtualFrame frame, RuntimeError error) {
return error; return error;

View File

@ -5,6 +5,7 @@ import com.oracle.truffle.api.Truffle;
import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.nodes.NodeInfo; import com.oracle.truffle.api.nodes.NodeInfo;
import org.enso.compiler.InlineContext;
import org.enso.interpreter.Constants; import org.enso.interpreter.Constants;
import org.enso.interpreter.Language; import org.enso.interpreter.Language;
import org.enso.interpreter.node.BaseNode; import org.enso.interpreter.node.BaseNode;
@ -16,6 +17,7 @@ import org.enso.interpreter.runtime.callable.argument.Thunk;
import org.enso.interpreter.runtime.scope.LocalScope; import org.enso.interpreter.runtime.scope.LocalScope;
import org.enso.interpreter.runtime.scope.ModuleScope; import org.enso.interpreter.runtime.scope.ModuleScope;
import org.enso.interpreter.runtime.state.Stateful; import org.enso.interpreter.runtime.state.Stateful;
import scala.Some;
/** Node running Enso expressions passed to it as strings. */ /** Node running Enso expressions passed to it as strings. */
@NodeInfo(shortName = "Eval", description = "Evaluates code passed to it as string") @NodeInfo(shortName = "Eval", description = "Evaluates code passed to it as string")
@ -57,11 +59,12 @@ public abstract class EvalNode extends BaseNode {
RootCallTarget parseExpression(LocalScope scope, ModuleScope moduleScope, String expression) { RootCallTarget parseExpression(LocalScope scope, ModuleScope moduleScope, String expression) {
LocalScope localScope = scope.createChild(); LocalScope localScope = scope.createChild();
Language language = lookupLanguageReference(Language.class).get(); Language language = lookupLanguageReference(Language.class).get();
InlineContext inlineContext = InlineContext.fromJava(localScope, moduleScope, isTail());
ExpressionNode expr = ExpressionNode expr =
lookupContextReference(Language.class) lookupContextReference(Language.class)
.get() .get()
.compiler() .compiler()
.runInline(expression, localScope, moduleScope) .runInline(expression, inlineContext)
.getOrElse(null); .getOrElse(null);
if (expr == null) { if (expr == null) {
throw new RuntimeException("Invalid code passed to `eval`: " + expression); throw new RuntimeException("Invalid code passed to `eval`: " + expression);
@ -78,7 +81,6 @@ public abstract class EvalNode extends BaseNode {
expr, expr,
null, null,
"<dynamic_eval>"); "<dynamic_eval>");
framedNode.setTail(isTail());
return Truffle.getRuntime().createCallTarget(framedNode); return Truffle.getRuntime().createCallTarget(framedNode);
} }

View File

@ -7,9 +7,18 @@ import com.oracle.truffle.api.source.Source
import org.enso.compiler.codegen.{AstToIR, IRToTruffle} import org.enso.compiler.codegen.{AstToIR, IRToTruffle}
import org.enso.compiler.core.IR import org.enso.compiler.core.IR
import org.enso.compiler.core.IR.{Expression, Module} import org.enso.compiler.core.IR.{Expression, Module}
import org.enso.compiler.exception.CompilerError
import org.enso.compiler.pass.IRPass import org.enso.compiler.pass.IRPass
import org.enso.compiler.pass.analyse.{AliasAnalysis, ApplicationSaturation} import org.enso.compiler.pass.analyse.{
import org.enso.compiler.pass.desugar.{LiftSpecialOperators, OperatorToFunction} AliasAnalysis,
ApplicationSaturation,
TailCall
}
import org.enso.compiler.pass.desugar.{
GenerateMethodBodies,
LiftSpecialOperators,
OperatorToFunction
}
import org.enso.interpreter.Language import org.enso.interpreter.Language
import org.enso.interpreter.node.{ExpressionNode => RuntimeExpression} import org.enso.interpreter.node.{ExpressionNode => RuntimeExpression}
import org.enso.interpreter.runtime.Context import org.enso.interpreter.runtime.Context
@ -40,10 +49,12 @@ class Compiler(
* they nevertheless exist. * they nevertheless exist.
*/ */
val compilerPhaseOrdering: List[IRPass] = List( val compilerPhaseOrdering: List[IRPass] = List(
GenerateMethodBodies,
LiftSpecialOperators, LiftSpecialOperators,
OperatorToFunction, OperatorToFunction,
AliasAnalysis, AliasAnalysis,
ApplicationSaturation() ApplicationSaturation(),
TailCall
) )
/** /**
@ -105,14 +116,13 @@ class Compiler(
* Processes the source in the context of given local and module scopes. * Processes the source in the context of given local and module scopes.
* *
* @param srcString string representing the expression to process * @param srcString string representing the expression to process
* @param localScope local scope to process the source in * @param inlineContext a context object that contains the information needed
* @param moduleScope module scope to process the source in * for inline evaluation
* @return an expression node representing the parsed and analyzed source * @return an expression node representing the parsed and analyzed source
*/ */
def runInline( def runInline(
srcString: String, srcString: String,
localScope: LocalScope, inlineContext: InlineContext
moduleScope: ModuleScope
): Option[RuntimeExpression] = { ): Option[RuntimeExpression] = {
val source = Source val source = Source
.newBuilder( .newBuilder(
@ -124,17 +134,8 @@ class Compiler(
val parsed: AST = parse(source) val parsed: AST = parse(source)
generateIRInline(parsed).flatMap { ir => generateIRInline(parsed).flatMap { ir =>
Some({ val compilerOutput = runCompilerPhasesInline(ir, inlineContext)
val compilerOutput = Some(truffleCodegenInline(compilerOutput, source, inlineContext))
runCompilerPhasesInline(ir, localScope, moduleScope)
truffleCodegenInline(
compilerOutput,
source,
moduleScope,
localScope
)
})
} }
} }
@ -147,7 +148,7 @@ class Compiler(
* @param qualifiedName the qualified name of the module * @param qualifiedName the qualified name of the module
* @return the scope containing all definitions in the requested module * @return the scope containing all definitions in the requested module
*/ */
def requestProcess(qualifiedName: String): ModuleScope = { def processImport(qualifiedName: String): ModuleScope = {
val module = topScope.getModule(qualifiedName) val module = topScope.getModule(qualifiedName)
if (module.isPresent) { if (module.isPresent) {
module.get().getScope(context) module.get().getScope(context)
@ -205,16 +206,17 @@ class Compiler(
/** Runs the various compiler passes in an inline context. /** Runs the various compiler passes in an inline context.
* *
* @param ir the compiler intermediate representation to transform * @param ir the compiler intermediate representation to transform
* @param inlineContext a context object that contains the information needed
* for inline evaluation
* @return the output result of the * @return the output result of the
*/ */
def runCompilerPhasesInline( def runCompilerPhasesInline(
ir: IR.Expression, ir: IR.Expression,
localScope: LocalScope, inlineContext: InlineContext
moduleScope: ModuleScope
): IR.Expression = { ): IR.Expression = {
compilerPhaseOrdering.foldLeft(ir)( compilerPhaseOrdering.foldLeft(ir)(
(intermediateIR, pass) => (intermediateIR, pass) =>
pass.runExpression(intermediateIR, Some(localScope), Some(moduleScope)) pass.runExpression(intermediateIR, inlineContext)
) )
} }
@ -236,18 +238,27 @@ class Compiler(
* *
* @param ir the prorgam to translate * @param ir the prorgam to translate
* @param source the source code of the program represented by `ir` * @param source the source code of the program represented by `ir`
* @param moduleScope the module scope in which the code is to be generated * @param inlineContext a context object that contains the information needed
* @param localScope the local scope in which the inline code is to be * for inline evaluation
* located
* @return the runtime representation of the program represented by `ir` * @return the runtime representation of the program represented by `ir`
*/ */
def truffleCodegenInline( def truffleCodegenInline(
ir: IR.Expression, ir: IR.Expression,
source: Source, source: Source,
moduleScope: ModuleScope, inlineContext: InlineContext
localScope: LocalScope
): RuntimeExpression = { ): RuntimeExpression = {
new IRToTruffle(this.language, source, moduleScope) new IRToTruffle(
.runInline(ir, localScope, "<inline_source>") this.language,
source,
inlineContext.moduleScope.getOrElse(
throw new CompilerError(
"Cannot perform inline codegen with a missing module scope."
)
)
).runInline(
ir,
inlineContext.localScope.getOrElse(LocalScope.root),
"<inline_source>"
)
} }
} }

View File

@ -0,0 +1,41 @@
package org.enso.compiler
import org.enso.interpreter.runtime.scope.{LocalScope, ModuleScope}
/** A type containing the information about the execution context for an inline
* expression.
*
* @param localScope the local scope in which the expression is being executed
* @param moduleScope the module scope in which the expression is being
* executed
* @param isInTailPosition whether or not the inline expression occurs in tail
* position ([[None]] indicates no information)
*/
case class InlineContext(
localScope: Option[LocalScope] = None,
moduleScope: Option[ModuleScope] = None,
isInTailPosition: Option[Boolean] = None
)
object InlineContext {
/** Implements a null-safe conversion from nullable objects to Scala's option
* internally.
*
* @param localScope the local scope instance
* @param moduleScope the module scope instance
* @param isInTailPosition whether or not the inline expression occurs in a
* tail position
* @return the [[InlineContext]] instance corresponding to the arguments
*/
def fromJava(
localScope: LocalScope,
moduleScope: ModuleScope,
isInTailPosition: Boolean
): InlineContext = {
InlineContext(
Option(localScope),
Option(moduleScope),
Option(isInTailPosition)
)
}
}

View File

@ -140,17 +140,11 @@ object AstToIR {
(Constants.Names.CURRENT_MODULE, None) (Constants.Names.CURRENT_MODULE, None)
} }
val nameStr = name match { case AST.Ident.Var.any(name) => name } val nameStr = name match { case AST.Ident.Var.any(name) => name }
val defExpression = translateExpression(definition)
val defExpr: Function.Lambda = defExpression match {
case fun: Function.Lambda => fun
case expr =>
Function.Lambda(List(), expr, expr.location)
}
Module.Scope.Definition.Method( Module.Scope.Definition.Method(
Name.Literal(path, pathLoc), Name.Literal(path, pathLoc),
Name.Literal(nameStr.name, nameStr.location), Name.Literal(nameStr.name, nameStr.location),
defExpr, translateExpression(definition),
inputAST.location inputAST.location
) )
case _ => case _ =>
@ -180,10 +174,14 @@ object AstToIR {
case AstView.Assignment(name, expr) => case AstView.Assignment(name, expr) =>
translateBinding(inputAST.location, name, expr) translateBinding(inputAST.location, name, expr)
case AstView.MethodCall(target, name, args) => case AstView.MethodCall(target, name, args) =>
val (validArguments, hasDefaultsSuspended) =
calculateDefaultsSuspension(args)
// Note [Uniform Call Syntax Translation]
Application.Prefix( Application.Prefix(
translateExpression(name), translateExpression(name),
(target :: args).map(translateCallArgument), (target :: validArguments).map(translateCallArgument),
false, hasDefaultsSuspended = hasDefaultsSuspended,
inputAST.location inputAST.location
) )
case AstView.CaseExpression(scrutinee, branches) => case AstView.CaseExpression(scrutinee, branches) =>
@ -226,6 +224,16 @@ object AstToIR {
} }
} }
/* Note [Uniform Call Syntax Translation]
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
* As the uniform call syntax must work for both methods and functions, the
* conversion can't take advantage of any by-name application semantics at the
* current time.
*
* This means that it is a purely _positional_ conversion on the first
* argument and cannot be performed any other way.
*/
/** Translates a program literal from its [[AST]] representation into /** Translates a program literal from its [[AST]] representation into
* [[Core]]. * [[Core]].
* *
@ -289,7 +297,7 @@ object AstToIR {
DefinitionArgument.Specified( DefinitionArgument.Specified(
Name.Literal(name.name, name.location), Name.Literal(name.name, name.location),
Some(translateExpression(value)), Some(translateExpression(value)),
true, suspended = true,
arg.location arg.location
) )
case AstView.LazyArgument(arg) => case AstView.LazyArgument(arg) =>
@ -331,6 +339,28 @@ object AstToIR {
CallArgument.Specified(None, translateExpression(arg), arg.location) CallArgument.Specified(None, translateExpression(arg), arg.location)
} }
/** Calculates whether a set of arguments has its defaults suspended, and
* processes the argument list to remove that operator.
*
* @param args the list of arguments
* @return the list of arguments with the suspension operator removed, and
* whether or not the defaults are suspended
*/
def calculateDefaultsSuspension(args: List[AST]): (List[AST], Boolean) = {
val validArguments = args.filter {
case AstView.SuspendDefaultsOperator(_) => false
case _ => true
}
val suspendPositions = args.zipWithIndex.collect {
case (AstView.SuspendDefaultsOperator(_), ix) => ix
}
val hasDefaultsSuspended = suspendPositions.contains(args.length - 1)
(validArguments, hasDefaultsSuspended)
}
/** Translates an arbitrary expression that takes the form of a syntactic /** Translates an arbitrary expression that takes the form of a syntactic
* application from its [[AST]] representation into [[Core]]. * application from its [[AST]] representation into [[Core]].
* *
@ -348,16 +378,8 @@ object AstToIR {
case AstView.ForcedTerm(term) => case AstView.ForcedTerm(term) =>
Application.Force(translateExpression(term), callable.location) Application.Force(translateExpression(term), callable.location)
case AstView.Application(name, args) => case AstView.Application(name, args) =>
val validArguments = args.filter { val (validArguments, hasDefaultsSuspended) =
case AstView.SuspendDefaultsOperator(_) => false calculateDefaultsSuspension(args)
case _ => true
}
val suspendPositions = args.zipWithIndex.collect {
case (AstView.SuspendDefaultsOperator(_), ix) => ix
}
val hasDefaultsSuspended = suspendPositions.contains(args.length - 1)
Application.Prefix( Application.Prefix(
translateExpression(name), translateExpression(name),

View File

@ -355,6 +355,11 @@ object AstView {
def unapply(ast: AST): Option[(AST, AST.Ident, List[AST])] = ast match { def unapply(ast: AST): Option[(AST, AST.Ident, List[AST])] = ast match {
case OperatorDot(target, Application(ConsOrVar(ident), args)) => case OperatorDot(target, Application(ConsOrVar(ident), args)) =>
Some((target, ident, args)) Some((target, ident, args))
case AST.App.Section.Left(
MethodCall(target, ident, List()),
susp @ SuspendDefaultsOperator(_)
) =>
Some((target, ident, List(susp)))
case OperatorDot(target, ConsOrVar(ident)) => case OperatorDot(target, ConsOrVar(ident)) =>
Some((target, ident, List())) Some((target, ident, List()))
case _ => None case _ => None

View File

@ -4,7 +4,11 @@ import com.oracle.truffle.api.Truffle
import com.oracle.truffle.api.source.{Source, SourceSection} import com.oracle.truffle.api.source.{Source, SourceSection}
import org.enso.compiler.core.IR import org.enso.compiler.core.IR
import org.enso.compiler.exception.{CompilerError, UnhandledEntity} import org.enso.compiler.exception.{CompilerError, UnhandledEntity}
import org.enso.compiler.pass.analyse.{AliasAnalysis, ApplicationSaturation} import org.enso.compiler.pass.analyse.{
AliasAnalysis,
ApplicationSaturation,
TailCall
}
import org.enso.compiler.pass.analyse.AliasAnalysis.Graph.{Scope => AliasScope} import org.enso.compiler.pass.analyse.AliasAnalysis.Graph.{Scope => AliasScope}
import org.enso.compiler.pass.analyse.AliasAnalysis.{Graph => AliasGraph} import org.enso.compiler.pass.analyse.AliasAnalysis.{Graph => AliasGraph}
import org.enso.interpreter.node.callable.argument.ReadArgumentNode import org.enso.interpreter.node.callable.argument.ReadArgumentNode
@ -127,7 +131,7 @@ class IRToTruffle(
// Register the imports in scope // Register the imports in scope
imports.foreach( imports.foreach(
i => this.moduleScope.addImport(context.compiler.requestProcess(i.name)) i => this.moduleScope.addImport(context.compiler.processImport(i.name))
) )
// Register the atoms and their constructors in scope // Register the atoms and their constructors in scope
@ -138,7 +142,7 @@ class IRToTruffle(
atomConstructors atomConstructors
.zip(atomDefs) .zip(atomDefs)
.foreach { .foreach {
case (atomCons, atomDefn) => { case (atomCons, atomDefn) =>
val scopeInfo = atomDefn val scopeInfo = atomDefn
.getMetadata[AliasAnalysis.Info.Scope.Root] .getMetadata[AliasAnalysis.Info.Scope.Root]
.getOrElse( .getOrElse(
@ -157,7 +161,6 @@ class IRToTruffle(
} }
atomCons.initializeFields(argDefs: _*) atomCons.initializeFields(argDefs: _*)
}
} }
// Register the method definitions in scope // Register the method definitions in scope
@ -166,7 +169,7 @@ class IRToTruffle(
val scopeInfo = methodDef val scopeInfo = methodDef
.getMetadata[AliasAnalysis.Info.Scope.Root] .getMetadata[AliasAnalysis.Info.Scope.Root]
.getOrElse( .getOrElse(
throw new CompilerError(("Missing scope information for method.")) throw new CompilerError("Missing scope information for method.")
) )
val typeName = val typeName =
@ -182,6 +185,12 @@ class IRToTruffle(
scopeInfo.graph.rootScope scopeInfo.graph.rootScope
) )
val methodFunIsTail = methodDef.body
.getMetadata[TailCall.Metadata]
.getOrElse(
throw new CompilerError("Method body missing tail call info.")
)
val funNode = methodDef.body match { val funNode = methodDef.body match {
case fn: IR.Function => case fn: IR.Function =>
expressionProcessor.processFunctionBody( expressionProcessor.processFunctionBody(
@ -190,14 +199,12 @@ class IRToTruffle(
fn.location fn.location
) )
case _ => case _ =>
expressionProcessor.processFunctionBody( throw new CompilerError(
List(), "Method bodies must be functions at the point of codegen."
methodDef.body,
methodDef.body.location
) )
} }
funNode.markTail() funNode.setTail(methodFunIsTail)
val function = new RuntimeFunction( val function = new RuntimeFunction(
funNode.getCallTarget, funNode.getCallTarget,
@ -266,7 +273,7 @@ class IRToTruffle(
val scopeName: String val scopeName: String
) { ) {
private var currentVarName = "anonymous"; private var currentVarName = "anonymous"
// === Construction ======================================================= // === Construction =======================================================
@ -299,25 +306,36 @@ class IRToTruffle(
* @param ir the IR to generate code for * @param ir the IR to generate code for
* @return a truffle expression that represents the same program as `ir` * @return a truffle expression that represents the same program as `ir`
*/ */
def run(ir: IR): RuntimeExpression = ir match { def run(ir: IR): RuntimeExpression = {
case block: IR.Expression.Block => processBlock(block) val tailMeta = ir
case literal: IR.Literal => processLiteral(literal) .getMetadata[TailCall.Metadata]
case app: IR.Application => processApplication(app) .getOrElse(
case name: IR.Name => processName(name) throw new CompilerError(s"Missing tail call metadata for $ir")
case function: IR.Function => processFunction(function)
case binding: IR.Expression.Binding => processBinding(binding)
case caseExpr: IR.Case => processCase(caseExpr)
case comment: IR.Comment => processComment(comment)
case err: IR.Error =>
throw new CompilerError(
s"No errors should remain by the point of truffle codegen, but " +
s"found $err."
) )
case IR.Foreign.Definition(_, _, _, _) =>
throw new CompilerError( val runtimeExpression = ir match {
s"Foreign expressions not yet implemented: $ir." case block: IR.Expression.Block => processBlock(block)
) case literal: IR.Literal => processLiteral(literal)
case _ => throw new UnhandledEntity(ir, "run") case app: IR.Application => processApplication(app)
case name: IR.Name => processName(name)
case function: IR.Function => processFunction(function)
case binding: IR.Expression.Binding => processBinding(binding)
case caseExpr: IR.Case => processCase(caseExpr)
case comment: IR.Comment => processComment(comment)
case err: IR.Error =>
throw new CompilerError(
s"No errors should remain by the point of truffle codegen, but " +
s"found $err."
)
case IR.Foreign.Definition(_, _, _, _) =>
throw new CompilerError(
s"Foreign expressions not yet implemented: $ir."
)
case _ => throw new UnhandledEntity(ir, "run")
}
runtimeExpression.setTail(tailMeta)
runtimeExpression
} }
/** Executes the expression processor on a piece of code that has been /** Executes the expression processor on a piece of code that has been
@ -328,7 +346,6 @@ class IRToTruffle(
*/ */
def runInline(ir: IR.Expression): RuntimeExpression = { def runInline(ir: IR.Expression): RuntimeExpression = {
val expression = run(ir) val expression = run(ir)
expression.markNotTail()
expression expression
} }
@ -393,9 +410,21 @@ class IRToTruffle(
val cases = branches val cases = branches
.map( .map(
branch => branch => {
ConstructorCaseNode val caseIsTail = branch
.getMetadata[TailCall.Metadata]
.getOrElse(
throw new CompilerError(
"Case branch missing tail position information."
)
)
val caseNode = ConstructorCaseNode
.build(this.run(branch.pattern), this.run(branch.expression)) .build(this.run(branch.pattern), this.run(branch.expression))
caseNode.setTail(caseIsTail)
caseNode
}
) )
.toArray[CaseNode] .toArray[CaseNode]
@ -466,7 +495,6 @@ class IRToTruffle(
function.body, function.body,
function.location function.location
) )
fn.setTail(function.canBeTCO)
fn fn
} }
@ -485,7 +513,7 @@ class IRToTruffle(
throw new CompilerError("No occurence on variable usage.") throw new CompilerError("No occurence on variable usage.")
) )
val slot = scope.getFramePointer(useInfo.id) val slot = scope.getFramePointer(useInfo.id)
val atomCons = moduleScope.getConstructor(nameStr).toScala val atomCons = moduleScope.getConstructor(nameStr).toScala
if (nameStr == Constants.Names.CURRENT_MODULE) { if (nameStr == Constants.Names.CURRENT_MODULE) {
ConstructorNode.build(moduleScope.getAssociatedType) ConstructorNode.build(moduleScope.getAssociatedType)
@ -567,6 +595,14 @@ class IRToTruffle(
} else seenArgNames.add(argName) } else seenArgNames.add(argName)
} }
val bodyIsTail = body
.getMetadata[TailCall.Metadata]
.getOrElse(
throw new CompilerError(
"Function body missing tail call information."
)
)
val bodyExpr = this.run(body) val bodyExpr = this.run(body)
val fnBodyNode = BlockNode.build(argExpressions.toArray, bodyExpr) val fnBodyNode = BlockNode.build(argExpressions.toArray, bodyExpr)
@ -582,6 +618,8 @@ class IRToTruffle(
val expr = CreateFunctionNode.build(callTarget, argDefinitions) val expr = CreateFunctionNode.build(callTarget, argDefinitions)
fnBodyNode.setTail(bodyIsTail)
setLocation(expr, location) setLocation(expr, location)
} }
@ -696,7 +734,16 @@ class IRToTruffle(
val childScope = scope.createChild(scopeInfo.scope) val childScope = scope.createChild(scopeInfo.scope)
val argumentExpression = val argumentExpression =
new ExpressionProcessor(childScope, scopeName).run(value) new ExpressionProcessor(childScope, scopeName).run(value)
argumentExpression.markTail()
val argExpressionIsTail = value
.getMetadata[TailCall.Metadata]
.getOrElse(
throw new CompilerError(
"Argument with missing tail call information."
)
)
argumentExpression.setTail(argExpressionIsTail)
val displayName = val displayName =
s"call_argument<${name.getOrElse(String.valueOf(position))}>" s"call_argument<${name.getOrElse(String.valueOf(position))}>"

View File

@ -2,7 +2,7 @@ package org.enso.compiler.core
import org.enso.compiler.core.IR.Expression import org.enso.compiler.core.IR.Expression
import org.enso.syntax.text.ast.Doc import org.enso.syntax.text.ast.Doc
import org.enso.syntax.text.{AST, Location} import org.enso.syntax.text.{AST, Debug, Location}
import scala.collection.immutable.{Set => ISet} import scala.collection.immutable.{Set => ISet}
import scala.reflect.ClassTag import scala.reflect.ClassTag
@ -51,6 +51,12 @@ sealed trait IR {
* @return `this`, potentially having had its children transformed by `fn` * @return `this`, potentially having had its children transformed by `fn`
*/ */
def mapExpressions(fn: Expression => Expression): IR def mapExpressions(fn: Expression => Expression): IR
/** Pretty prints the IR.
*
* @return a pretty-printed representation of the IR
*/
def pretty: String = Debug.pretty(this.toString)
} }
object IR { object IR {
@ -185,6 +191,8 @@ object IR {
* @param location the source location that the node corresponds to * @param location the source location that the node corresponds to
* @param passData the pass metadata associated with this node * @param passData the pass metadata associated with this node
*/ */
// TODO [AA] Separate Method into Method.Binding and Method.Explicit to
// account for syntax sugar later.
sealed case class Method( sealed case class Method(
typeName: IR.Name, typeName: IR.Name,
methodName: IR.Name, methodName: IR.Name,
@ -850,7 +858,7 @@ object IR {
right: Expression, right: Expression,
override val location: Option[Location], override val location: Option[Location],
override val passData: ISet[Metadata] = ISet() override val passData: ISet[Metadata] = ISet()
) extends Application ) extends Operator
with IRKind.Sugar { with IRKind.Sugar {
override def addMetadata(newData: Metadata): Binary = { override def addMetadata(newData: Metadata): Binary = {
copy(passData = this.passData + newData) copy(passData = this.passData + newData)
@ -1035,7 +1043,7 @@ object IR {
code: String, code: String,
override val location: Option[Location], override val location: Option[Location],
override val passData: ISet[Metadata] = ISet() override val passData: ISet[Metadata] = ISet()
) extends Expression ) extends Foreign
with IRKind.Primitive { with IRKind.Primitive {
override def addMetadata(newData: Metadata): Definition = { override def addMetadata(newData: Metadata): Definition = {
copy(passData = this.passData + newData) copy(passData = this.passData + newData)

View File

@ -1,7 +1,7 @@
package org.enso.compiler.pass package org.enso.compiler.pass
import org.enso.compiler.InlineContext
import org.enso.compiler.core.IR import org.enso.compiler.core.IR
import org.enso.interpreter.runtime.scope.{LocalScope, ModuleScope}
/** A representation of a compiler pass that runs on the [[IR]] type. */ /** A representation of a compiler pass that runs on the [[IR]] type. */
trait IRPass { trait IRPass {
@ -22,14 +22,13 @@ trait IRPass {
* or annotated version of `ir` in an inline context. * or annotated version of `ir` in an inline context.
* *
* @param ir the Enso IR to process * @param ir the Enso IR to process
* @param localScope the local scope in which the expression is executed * @param inlineContext a context object that contains the information needed
* @param moduleScope the module scope in which the expression is executed * for inline evaluation
* @return `ir`, possibly having made transformations or annotations to that * @return `ir`, possibly having made transformations or annotations to that
* IR. * IR.
*/ */
def runExpression( def runExpression(
ir: IR.Expression, ir: IR.Expression,
localScope: Option[LocalScope] = None, inlineContext: InlineContext
moduleScope: Option[ModuleScope] = None
): IR.Expression ): IR.Expression
} }

View File

@ -1,10 +1,10 @@
package org.enso.compiler.pass.analyse package org.enso.compiler.pass.analyse
import org.enso.compiler.InlineContext
import org.enso.compiler.core.IR import org.enso.compiler.core.IR
import org.enso.compiler.exception.CompilerError import org.enso.compiler.exception.CompilerError
import org.enso.compiler.pass.IRPass import org.enso.compiler.pass.IRPass
import org.enso.compiler.pass.analyse.AliasAnalysis.Graph.{Occurrence, Scope} import org.enso.compiler.pass.analyse.AliasAnalysis.Graph.{Occurrence, Scope}
import org.enso.interpreter.runtime.scope.{LocalScope, ModuleScope}
import org.enso.syntax.text.Debug import org.enso.syntax.text.Debug
import scala.reflect.ClassTag import scala.reflect.ClassTag
@ -52,17 +52,16 @@ case object AliasAnalysis extends IRPass {
* provided scope. * provided scope.
* *
* @param ir the Enso IR to process * @param ir the Enso IR to process
* @param localScope the local scope in which the expression is executed * @param inlineContext a context object that contains the information needed
* @param moduleScope the module scope in which the expression is executed * for inline evaluation
* @return `ir`, possibly having made transformations or annotations to that * @return `ir`, possibly having made transformations or annotations to that
* IR. * IR.
*/ */
override def runExpression( override def runExpression(
ir: IR.Expression, ir: IR.Expression,
localScope: Option[LocalScope] = None, inlineContext: InlineContext
moduleScope: Option[ModuleScope] = None
): IR.Expression = ): IR.Expression =
localScope inlineContext.localScope
.map { localScope => .map { localScope =>
val scope = localScope.scope val scope = localScope.scope
val graph = localScope.aliasingGraph val graph = localScope.aliasingGraph
@ -90,31 +89,23 @@ case object AliasAnalysis extends IRPass {
ir match { ir match {
case m @ IR.Module.Scope.Definition.Method(_, _, body, _, _) => case m @ IR.Module.Scope.Definition.Method(_, _, body, _, _) =>
val bodyWithThisArg = body match { body match {
case lam @ IR.Function.Lambda(args, _, _, _, _) => case _: IR.Function =>
lam.copy( m.copy(
arguments = IR.DefinitionArgument.Specified( body = analyseExpression(
IR.Name.This(None), body,
None, topLevelGraph,
suspended = false, topLevelGraph.rootScope,
None lambdaReuseScope = true,
) :: args blockReuseScope = true
) )
)
.addMetadata(Info.Scope.Root(topLevelGraph))
case _ => case _ =>
throw new CompilerError( throw new CompilerError(
"The body of a method should always be a lambda by." "The body of a method should always be a function."
) )
} }
m.copy(
body = analyseExpression(
bodyWithThisArg,
topLevelGraph,
topLevelGraph.rootScope,
lambdaReuseScope = true,
blockReuseScope = true
)
)
.addMetadata(Info.Scope.Root(topLevelGraph))
case a @ IR.Module.Scope.Definition.Atom(_, args, _, _) => case a @ IR.Module.Scope.Definition.Atom(_, args, _, _) =>
a.copy( a.copy(
arguments = arguments =

View File

@ -1,26 +1,22 @@
package org.enso.compiler.pass.analyse package org.enso.compiler.pass.analyse
import org.enso.compiler.InlineContext
import org.enso.compiler.core.IR import org.enso.compiler.core.IR
import org.enso.compiler.exception.CompilerError import org.enso.compiler.exception.CompilerError
import org.enso.compiler.pass.IRPass import org.enso.compiler.pass.IRPass
import org.enso.compiler.pass.analyse.ApplicationSaturation.{CallSaturation, Default, FunctionSpec, PassConfiguration} import org.enso.compiler.pass.analyse.ApplicationSaturation.{
CallSaturation,
Default,
FunctionSpec,
PassConfiguration
}
import org.enso.interpreter.node.{ExpressionNode => RuntimeExpression} import org.enso.interpreter.node.{ExpressionNode => RuntimeExpression}
import org.enso.interpreter.runtime.callable.argument.CallArgument import org.enso.interpreter.runtime.callable.argument.CallArgument
import org.enso.interpreter.runtime.scope.{LocalScope, ModuleScope}
import scala.annotation.unused
/** This optimisation pass recognises fully-saturated applications of known /** This optimisation pass recognises fully-saturated applications of known
* functions and writes analysis data that allows optimisation of them to * functions and writes analysis data that allows optimisation of them to
* specific nodes at codegen time. * specific nodes at codegen time.
* *
* PLEASE NOTE: This implementation is _incomplete_ as the analysis it performs
* is _unconditional_ at this stage. This means that, until we have alias
* analysis information,
*
* PLEASE NOTE: This implementation is _incomplete_ as the analysis it performs
* only operates for functions where the arguments are applied positionally.
*
* @param knownFunctions a mapping from known function names to information * @param knownFunctions a mapping from known function names to information
* about that function that can be used for optimisation * about that function that can be used for optimisation
*/ */
@ -39,7 +35,9 @@ case class ApplicationSaturation(
* IR. * IR.
*/ */
override def runModule(ir: IR.Module): IR.Module = override def runModule(ir: IR.Module): IR.Module =
ir.transformExpressions({ case x => runExpression(x) }) ir.transformExpressions({
case x => runExpression(x, new InlineContext)
})
/** Executes the analysis pass, marking functions with information about their /** Executes the analysis pass, marking functions with information about their
* argument saturation. * argument saturation.
@ -50,8 +48,7 @@ case class ApplicationSaturation(
*/ */
override def runExpression( override def runExpression(
ir: IR.Expression, ir: IR.Expression,
@unused localScope: Option[LocalScope] = None, inlineContext: InlineContext
@unused moduleScope: Option[ModuleScope] = None
): IR.Expression = { ): IR.Expression = {
ir.transformExpressions { ir.transformExpressions {
case func @ IR.Application.Prefix(fn, args, _, _, meta) => case func @ IR.Application.Prefix(fn, args, _, _, meta) =>
@ -82,7 +79,8 @@ case class ApplicationSaturation(
func.copy( func.copy(
arguments = args.map( arguments = args.map(
_.mapExpressions( _.mapExpressions(
(ir: IR.Expression) => runExpression(ir) (ir: IR.Expression) =>
runExpression(ir, inlineContext)
) )
), ),
passData = meta + saturationInfo passData = meta + saturationInfo
@ -92,7 +90,8 @@ case class ApplicationSaturation(
func.copy( func.copy(
arguments = args.map( arguments = args.map(
_.mapExpressions( _.mapExpressions(
(ir: IR.Expression) => runExpression(ir) (ir: IR.Expression) =>
runExpression(ir, inlineContext)
) )
), ),
passData = meta + CallSaturation.Over(args.length - arity) passData = meta + CallSaturation.Over(args.length - arity)
@ -101,7 +100,8 @@ case class ApplicationSaturation(
func.copy( func.copy(
arguments = args.map( arguments = args.map(
_.mapExpressions( _.mapExpressions(
(ir: IR.Expression) => runExpression(ir) (ir: IR.Expression) =>
runExpression(ir, inlineContext)
) )
), ),
passData = meta + CallSaturation.Partial( passData = meta + CallSaturation.Partial(
@ -112,22 +112,26 @@ case class ApplicationSaturation(
case None => case None =>
func.copy( func.copy(
arguments = args.map( arguments = args.map(
_.mapExpressions((ir: IR.Expression) => runExpression(ir)) _.mapExpressions(
(ir: IR.Expression) => runExpression(ir, inlineContext)
)
), ),
passData = meta + CallSaturation.Unknown() passData = meta + CallSaturation.Unknown()
) )
} }
} else { } else {
func.copy( func.copy(
function = runExpression(fn), function = runExpression(fn, inlineContext),
arguments = args.map(_.mapExpressions(runExpression(_))), arguments =
args.map(_.mapExpressions(runExpression(_, inlineContext))),
passData = meta + CallSaturation.Unknown() passData = meta + CallSaturation.Unknown()
) )
} }
case _ => case _ =>
func.copy( func.copy(
function = runExpression(fn), function = runExpression(fn, inlineContext),
arguments = args.map(_.mapExpressions(runExpression(_))), arguments =
args.map(_.mapExpressions(runExpression(_, inlineContext))),
passData = meta + CallSaturation.Unknown() passData = meta + CallSaturation.Unknown()
) )
} }

View File

@ -0,0 +1,380 @@
package org.enso.compiler.pass.analyse
import org.enso.compiler.InlineContext
import org.enso.compiler.core.IR
import org.enso.compiler.exception.CompilerError
import org.enso.compiler.pass.IRPass
/** This pass performs tail call analysis on the Enso IR.
*
* It is responsible for marking every single expression with whether it is in
* tail position or not. This allows the code generator to correctly create the
* Truffle nodes.
*/
case object TailCall extends IRPass {
/** The annotation metadata type associated with IR nodes by this pass. */
override type Metadata = TailPosition
/** Analyses tail call state for expressions in a module.
*
* @param ir the Enso IR to process
* @return `ir`, possibly having made transformations or annotations to that
* IR.
*/
override def runModule(ir: IR.Module): IR.Module = {
ir.copy(bindings = ir.bindings.map(analyseModuleBinding))
}
/** Analyses tail call state for an arbitrary expression.
*
* @param ir the Enso IR to process
* @param inlineContext a context object that contains the information needed
* for inline evaluation
* @return `ir`, possibly having made transformations or annotations to that
* IR.
*/
override def runExpression(
ir: IR.Expression,
inlineContext: InlineContext
): IR.Expression =
analyseExpression(
ir,
inlineContext.isInTailPosition.getOrElse(
throw new CompilerError(
"Information about the tail position for an inline expression " +
"must be known by the point of tail call analysis."
)
)
)
/** Performs tail call analysis on a top-level definition in a module.
*
* @param definition the top-level definition to analyse
* @return `definition`, annotated with tail call information
*/
def analyseModuleBinding(
definition: IR.Module.Scope.Definition
): IR.Module.Scope.Definition = {
definition match {
case method @ IR.Module.Scope.Definition.Method(_, _, body, _, _) =>
method
.copy(
body = analyseExpression(body, isInTailPosition = true)
)
.addMetadata(TailPosition.Tail)
case atom @ IR.Module.Scope.Definition.Atom(_, args, _, _) =>
atom
.copy(
arguments = args.map(analyseDefArgument)
)
.addMetadata(TailPosition.Tail)
}
}
/** Performs tail call analysis on an arbitrary expression.
*
* @param expression the expression to analyse
* @param isInTailPosition whether or not the expression is occurring in tail
* position
* @return `expression`, annotated with tail position metadata
*/
def analyseExpression(
expression: IR.Expression,
isInTailPosition: Boolean
): IR.Expression = {
expression match {
case function: IR.Function => analyseFunction(function, isInTailPosition)
case caseExpr: IR.Case => analyseCase(caseExpr, isInTailPosition)
case typ: IR.Type => analyseType(typ, isInTailPosition)
case app: IR.Application => analyseApplication(app, isInTailPosition)
case name: IR.Name => analyseName(name, isInTailPosition)
case foreign: IR.Foreign => foreign.addMetadata(TailPosition.NotTail)
case literal: IR.Literal => analyseLiteral(literal, isInTailPosition)
case comment: IR.Comment => analyseComment(comment, isInTailPosition)
case block @ IR.Expression.Block(expressions, returnValue, _, _, _) =>
block
.copy(
expressions =
expressions.map(analyseExpression(_, isInTailPosition = false)),
returnValue = analyseExpression(returnValue, isInTailPosition)
)
.addMetadata(TailPosition.fromBool(isInTailPosition))
case binding @ IR.Expression.Binding(_, expression, _, _) =>
binding
.copy(
expression = analyseExpression(expression, isInTailPosition)
)
.addMetadata(TailPosition.fromBool(isInTailPosition))
case err: IR.Error => err
}
}
/** Performs tail call analysis on an occurrence of a name.
*
* @param name the name to analyse
* @param isInTailPosition whether the name occurs in tail position or not
* @return `name`, annotated with tail position metadata
*/
def analyseName(name: IR.Name, isInTailPosition: Boolean): IR.Name = {
name.addMetadata(TailPosition.fromBool(isInTailPosition))
}
/** Performs tail call analysis on a comment occurrence.
*
* @param comment the comment to analyse
* @param isInTailPosition whether the comment occurs in tail position or not
* @return `comment`, annotated with tail position metadata
*/
def analyseComment(
comment: IR.Comment,
isInTailPosition: Boolean
): IR.Comment = {
comment match {
case doc @ IR.Comment.Documentation(expr, _, _, _) =>
doc
.copy(commented = analyseExpression(expr, isInTailPosition))
.addMetadata(TailPosition.fromBool(isInTailPosition))
}
}
/** Performs tail call analysis on a literal.
*
* @param literal the literal to analyse
* @param isInTailPosition whether or not the literal occurs in tail position
* or not
* @return `literal`, annotated with tail position metdata
*/
def analyseLiteral(
literal: IR.Literal,
isInTailPosition: Boolean
): IR.Literal = {
literal.addMetadata(TailPosition.fromBool(isInTailPosition))
}
/** Performs tail call analysis on an application.
*
* @param application the application to analyse
* @param isInTailPosition whether or not the application is occurring in
* tail position
* @return `application`, annotated with tail position metadata
*/
def analyseApplication(
application: IR.Application,
isInTailPosition: Boolean
): IR.Application = {
application match {
case app @ IR.Application.Prefix(fn, args, _, _, _) =>
app
.copy(
function = analyseExpression(fn, isInTailPosition = false),
arguments = args.map(analyseCallArg)
)
.addMetadata(TailPosition.fromBool(isInTailPosition))
case force @ IR.Application.Force(target, _, _) =>
force
.copy(
target = analyseExpression(target, isInTailPosition)
)
.addMetadata(TailPosition.fromBool(isInTailPosition))
case _: IR.Application.Operator =>
throw new CompilerError("Unexpected binary operator.")
}
}
/** Performs tail call analysis on a call site argument.
*
* @param argument the argument to analyse
* @return `argument`, annotated with tail position metadata
*/
def analyseCallArg(argument: IR.CallArgument): IR.CallArgument = {
argument match {
case arg @ IR.CallArgument.Specified(_, expr, _, _) =>
arg
.copy(
// Note [Call Argument Tail Position]
value = analyseExpression(expr, isInTailPosition = true)
)
.addMetadata(TailPosition.Tail)
}
}
/* Note [Call Argument Tail Position]
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
* In order to efficiently deal with Enso's ability to suspend function
* arguments, we behave as if all arguments to a function are passed as
* thunks. This means that the _function_ becomes responsible for deciding
* when to evaluate its arguments.
*
* Conceptually, this results in a desugaring as follows:
*
* ```
* foo a b c
* ```
*
* Becomes:
*
* ```
* foo ({} -> a) ({} -> b) ({} -> c)
* ```
*
* Quite obviously, the arguments `a`, `b` and `c` are in tail position in
* these closures, and hence should be marked as tail.
*/
/** Performs tail call analysis on an expression involving type operators.
*
* @param value the type operator expression
* @param isInTailPosition whether or not the type operator occurs in a tail
* call position
* @return `value`, annotated with tail position metadata
*/
def analyseType(value: IR.Type, isInTailPosition: Boolean): IR.Type = {
value
.mapExpressions(analyseExpression(_, isInTailPosition = false))
.addMetadata(TailPosition.fromBool(isInTailPosition))
}
/** Performs tail call analysis on a case expression.
*
* @param caseExpr the case expression to analyse
* @param isInTailPosition whether or not the case expression occurs in a tail
* call position
* @return `caseExpr`, annotated with tail position metadata
*/
def analyseCase(caseExpr: IR.Case, isInTailPosition: Boolean): IR.Case = {
caseExpr match {
case caseExpr @ IR.Case.Expr(scrutinee, branches, fallback, _, _) =>
caseExpr
.copy(
scrutinee = analyseExpression(scrutinee, isInTailPosition = false),
// Note [Analysing Branches in Case Expressions]
branches = branches.map(analyseCaseBranch(_, isInTailPosition)),
fallback = fallback.map(analyseExpression(_, isInTailPosition))
)
.addMetadata(TailPosition.fromBool(isInTailPosition))
case _: IR.Case.Branch =>
throw new CompilerError("Unexpected case branch.")
}
}
/* Note [Analysing Branches in Case Expressions]
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
* When performing tail call analysis on a case expression it is very
* important to recognise that the branches of a case expression should all
* have the same tail call state. The branches should only be marked as being
* in tail position when the case expression _itself_ is in tail position.
*
* As only one branch is ever executed, it is hence safe to mark _all_
* branches as being in tail position if the case expression is.
*/
/** Performs tail call analysis on a case branch.
*
* @param branch the branch to analyse
* @param isInTailPosition whether or not the branch occurs in a tail call
* position
* @return `branch`, annotated with tail position metadata
*/
def analyseCaseBranch(
branch: IR.Case.Branch,
isInTailPosition: Boolean
): IR.Case.Branch = {
branch
.copy(
pattern = analyseExpression(branch.pattern, isInTailPosition = false),
expression = analyseExpression(
branch.expression,
isInTailPosition
)
)
.addMetadata(TailPosition.fromBool(isInTailPosition))
}
/** Performs tail call analysis on a function definition.
*
* @param function the function to analyse
* @param isInTailPosition whether or not the function definition occurs in a
* tail position
* @return `function`, annotated with tail position metadata
*/
def analyseFunction(
function: IR.Function,
isInTailPosition: Boolean
): IR.Function = {
val canBeTCO = function.canBeTCO
val markAsTail = (!canBeTCO && isInTailPosition) || canBeTCO
val resultFunction = function match {
case lambda @ IR.Function.Lambda(args, body, _, _, _) =>
lambda.copy(
arguments = args.map(analyseDefArgument),
body = analyseExpression(body, isInTailPosition = markAsTail)
)
}
resultFunction.addMetadata(
TailPosition.fromBool(isInTailPosition)
)
}
/** Performs tail call analysis on a function definition argument.
*
* @param arg the argument definition to analyse
* @return `arg`, annotated with tail position metadata
*/
def analyseDefArgument(arg: IR.DefinitionArgument): IR.DefinitionArgument = {
arg match {
case arg @ IR.DefinitionArgument.Specified(_, default, _, _, _) =>
arg
.copy(
defaultValue = default.map(
x =>
analyseExpression(x, isInTailPosition = false)
.addMetadata(TailPosition.NotTail)
)
)
.addMetadata(TailPosition.NotTail)
case err: IR.Error.Redefined.Argument => err
}
}
/** Expresses the tail call state of an IR Node. */
sealed trait TailPosition extends IR.Metadata {
/** A boolean representation of the expression's tail state. */
def isTail: Boolean
}
object TailPosition {
/** The expression is in a tail position and can be tail call optimised. */
final case object Tail extends TailPosition {
override def isTail: Boolean = true
}
/** The expression is not in a tail position and cannot be tail call
* optimised.
*/
final case object NotTail extends TailPosition {
override def isTail: Boolean = false
}
/** Implicitly converts a boolean to a [[TailPosition]] value.
*
* @param isTail the boolean
* @return the tail position value corresponding to `bool`
*/
implicit def fromBool(isTail: Boolean): TailPosition = {
if (isTail) TailPosition.Tail else TailPosition.NotTail
}
/** Implicitly converts the tail position data into a boolean.
*
* @param tailPosition the tail position value
* @return the boolean value corresponding to `tailPosition`
*/
implicit def toBool(tailPosition: TailPosition): Boolean = {
tailPosition.isTail
}
}
}

View File

@ -0,0 +1,111 @@
package org.enso.compiler.pass.desugar
import org.enso.compiler.InlineContext
import org.enso.compiler.core.IR
import org.enso.compiler.pass.IRPass
/** This pass is responsible for ensuring that method bodies are in the correct
* format.
*
* The correct format as far as the rest of the compiler pipeline is concerned
* is as follows:
*
* - The body is a function (lambda)
* - The body has `this` at the start of its argument list.
*/
case object GenerateMethodBodies extends IRPass {
/** This is a desugaring pass and performs no analysis */
override type Metadata = IR.Metadata.Empty
/** Generates and consolidates method bodies.
*
* @param ir the Enso IR to process
* @return `ir`, possibly having made transformations or annotations to that
* IR.
*/
override def runModule(ir: IR.Module): IR.Module = {
ir.copy(
bindings = ir.bindings.map {
case m: IR.Module.Scope.Definition.Method => processMethodDef(m)
case x => x
}
)
}
/** Processes a method definition, ensuring that it's in the correct format.
*
* @param ir the method definition to process
* @return `ir` potentially with alterations to ensure that it's in the
* correct format
*/
def processMethodDef(
ir: IR.Module.Scope.Definition.Method
): IR.Module.Scope.Definition.Method = {
ir.copy(
body = ir.body match {
case fun: IR.Function => processBodyFunction(fun)
case expression => processBodyExpression(expression)
}
)
}
/** Processes the method body if it's a function.
*
* This is solely responsible for prepending the `this` argument to the list
* of arguments.
*
* @param fun the body function
* @return the body function with the `this` argument
*/
def processBodyFunction(fun: IR.Function): IR.Function = {
fun match {
case lam @ IR.Function.Lambda(args, _, _, _, _) =>
lam.copy(
arguments = genThisArgument :: args
)
}
}
/** Processes the method body if it's an expression.
*
* @param expr the body expression
* @return `expr` converted to a function taking the `this` argument
*/
def processBodyExpression(expr: IR.Expression): IR.Expression = {
IR.Function.Lambda(
arguments = List(genThisArgument),
body = expr,
location = expr.location
)
}
/** Generates a definition of the `this` argument for method definitions.
*
* @return the `this` argument
*/
def genThisArgument: IR.DefinitionArgument.Specified = {
IR.DefinitionArgument.Specified(
IR.Name.This(None),
None,
suspended = false,
None
)
}
/** Executes the pass on an expression.
*
* It is a identity operation on expressions as method definitions are not
* expressions.
*
* @param ir the Enso IR to process
* @param inlineContext a context object that contains the information needed
* for inline evaluation
* @return `ir`, possibly having made transformations or annotations to that
* IR.
*/
override def runExpression(
ir: IR.Expression,
inlineContext: InlineContext
): IR.Expression = ir
}

View File

@ -1,10 +1,8 @@
package org.enso.compiler.pass.desugar package org.enso.compiler.pass.desugar
import org.enso.compiler.InlineContext
import org.enso.compiler.core.IR import org.enso.compiler.core.IR
import org.enso.compiler.pass.IRPass import org.enso.compiler.pass.IRPass
import org.enso.interpreter.runtime.scope.{LocalScope, ModuleScope}
import scala.annotation.unused
/** This pass lifts any special operators (ones reserved by the language /** This pass lifts any special operators (ones reserved by the language
* implementation) into their own special IR constructs. * implementation) into their own special IR constructs.
@ -21,45 +19,89 @@ case object LiftSpecialOperators extends IRPass {
* IR. * IR.
*/ */
override def runModule(ir: IR.Module): IR.Module = override def runModule(ir: IR.Module): IR.Module =
ir.transformExpressions({ case x => runExpression(x) }) ir.transformExpressions({
case x => runExpression(x, new InlineContext)
})
/** Executes the lifting pass in an inline context. /** Executes the lifting pass in an inline context.
* *
* @param ir the Enso IR to process * @param ir the Enso IR to process
* @param inlineContext a context object that contains the information needed
* for inline evaluation
* @return `ir`, possibly having made transformations or annotations to that * @return `ir`, possibly having made transformations or annotations to that
* IR. * IR.
*/ */
override def runExpression( override def runExpression(
ir: IR.Expression, ir: IR.Expression,
@unused localScope: Option[LocalScope] = None, inlineContext: InlineContext
@unused moduleScope: Option[ModuleScope] = None
): IR.Expression = ): IR.Expression =
ir.transformExpressions({ ir.transformExpressions({
case IR.Application.Operator.Binary(l, op, r, loc, meta) => case IR.Application.Operator.Binary(l, op, r, loc, meta) =>
op.name match { op.name match {
case IR.Type.Ascription.name => case IR.Type.Ascription.name =>
IR.Type.Ascription(runExpression(l), runExpression(r), loc, meta) IR.Type.Ascription(
runExpression(l, inlineContext),
runExpression(r, inlineContext),
loc,
meta
)
case IR.Type.Set.Subsumption.name => case IR.Type.Set.Subsumption.name =>
IR.Type.Set IR.Type.Set
.Subsumption(runExpression(l), runExpression(r), loc, meta) .Subsumption(
runExpression(l, inlineContext),
runExpression(r, inlineContext),
loc,
meta
)
case IR.Type.Set.Equality.name => case IR.Type.Set.Equality.name =>
IR.Type.Set IR.Type.Set
.Equality(runExpression(l), runExpression(r), loc, meta) .Equality(
runExpression(l, inlineContext),
runExpression(r, inlineContext),
loc,
meta
)
case IR.Type.Set.Concat.name => case IR.Type.Set.Concat.name =>
IR.Type.Set IR.Type.Set
.Concat(runExpression(l), runExpression(r), loc, meta) .Concat(
runExpression(l, inlineContext),
runExpression(r, inlineContext),
loc,
meta
)
case IR.Type.Set.Union.name => case IR.Type.Set.Union.name =>
IR.Type.Set IR.Type.Set
.Union(runExpression(l), runExpression(r), loc, meta) .Union(
runExpression(l, inlineContext),
runExpression(r, inlineContext),
loc,
meta
)
case IR.Type.Set.Intersection.name => case IR.Type.Set.Intersection.name =>
IR.Type.Set IR.Type.Set
.Intersection(runExpression(l), runExpression(r), loc, meta) .Intersection(
runExpression(l, inlineContext),
runExpression(r, inlineContext),
loc,
meta
)
case IR.Type.Set.Subtraction.name => case IR.Type.Set.Subtraction.name =>
IR.Type.Set IR.Type.Set
.Subtraction(runExpression(l), runExpression(r), loc, meta) .Subtraction(
runExpression(l, inlineContext),
runExpression(r, inlineContext),
loc,
meta
)
case _ => case _ =>
IR.Application.Operator IR.Application.Operator
.Binary(runExpression(l), op, runExpression(r), loc, meta) .Binary(
runExpression(l, inlineContext),
op,
runExpression(r, inlineContext),
loc,
meta
)
} }
}) })

View File

@ -1,8 +1,8 @@
package org.enso.compiler.pass.desugar package org.enso.compiler.pass.desugar
import org.enso.compiler.InlineContext
import org.enso.compiler.core.IR import org.enso.compiler.core.IR
import org.enso.compiler.pass.IRPass import org.enso.compiler.pass.IRPass
import org.enso.interpreter.runtime.scope.{LocalScope, ModuleScope}
/** This pass converts usages of operators to calls to standard functions. */ /** This pass converts usages of operators to calls to standard functions. */
case object OperatorToFunction extends IRPass { case object OperatorToFunction extends IRPass {
@ -17,26 +17,31 @@ case object OperatorToFunction extends IRPass {
* IR. * IR.
*/ */
override def runModule(ir: IR.Module): IR.Module = override def runModule(ir: IR.Module): IR.Module =
ir.transformExpressions({ case x => runExpression(x) }) ir.transformExpressions({
case x => runExpression(x, new InlineContext)
})
/** Executes the conversion pass in an inline context. /** Executes the conversion pass in an inline context.
* *
* @param ir the Enso IR to process * @param ir the Enso IR to process
* @param inlineContext a context object that contains the information needed
* for inline evaluation
* @return `ir`, possibly having made transformations or annotations to that * @return `ir`, possibly having made transformations or annotations to that
* IR. * IR.
*/ */
override def runExpression( override def runExpression(
ir: IR.Expression, ir: IR.Expression,
localScope: Option[LocalScope] = None, inlineContext: InlineContext
moduleScope: Option[ModuleScope] = None
): IR.Expression = ): IR.Expression =
ir.transformExpressions { ir.transformExpressions {
case IR.Application.Operator.Binary(l, op, r, loc, passData) => case IR.Application.Operator.Binary(l, op, r, loc, passData) =>
IR.Application.Prefix( IR.Application.Prefix(
op, op,
List( List(
IR.CallArgument.Specified(None, runExpression(l), l.location), IR.CallArgument
IR.CallArgument.Specified(None, runExpression(r), r.location) .Specified(None, runExpression(l, inlineContext), l.location),
IR.CallArgument
.Specified(None, runExpression(r, inlineContext), r.location)
), ),
hasDefaultsSuspended = false, hasDefaultsSuspended = false,
loc, loc,

View File

@ -1,9 +1,9 @@
package org.enso.compiler.test package org.enso.compiler.test
import org.enso.compiler.InlineContext
import org.enso.compiler.codegen.AstToIR import org.enso.compiler.codegen.AstToIR
import org.enso.compiler.core.IR import org.enso.compiler.core.IR
import org.enso.compiler.pass.IRPass import org.enso.compiler.pass.IRPass
import org.enso.interpreter.runtime.scope.LocalScope
import org.enso.syntax.text.{AST, Parser} import org.enso.syntax.text.{AST, Parser}
import org.scalatest.matchers.should.Matchers import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpecLike import org.scalatest.wordspec.AnyWordSpecLike
@ -93,12 +93,12 @@ trait CompilerRunner {
*/ */
def runPasses( def runPasses(
passes: List[IRPass], passes: List[IRPass],
localScope: Option[LocalScope] = None inlineContext: InlineContext
): IR = ir match { ): IR = ir match {
case expr: IR.Expression => case expr: IR.Expression =>
passes.foldLeft(expr)( passes.foldLeft(expr)(
(intermediate, pass) => (intermediate, pass) =>
pass.runExpression(intermediate, localScope = localScope) pass.runExpression(intermediate, inlineContext)
) )
case mod: IR.Module => case mod: IR.Module =>
passes.foldLeft(mod)( passes.foldLeft(mod)(

View File

@ -1,12 +1,17 @@
package org.enso.compiler.test.pass.analyse package org.enso.compiler.test.pass.analyse
import org.enso.compiler.InlineContext
import org.enso.compiler.core.IR import org.enso.compiler.core.IR
import org.enso.compiler.core.IR.Module.Scope.Definition.{Atom, Method} import org.enso.compiler.core.IR.Module.Scope.Definition.{Atom, Method}
import org.enso.compiler.pass.IRPass import org.enso.compiler.pass.IRPass
import org.enso.compiler.pass.analyse.AliasAnalysis import org.enso.compiler.pass.analyse.AliasAnalysis
import org.enso.compiler.pass.analyse.AliasAnalysis.{Graph, Info}
import org.enso.compiler.pass.analyse.AliasAnalysis.Graph.{Link, Occurrence} import org.enso.compiler.pass.analyse.AliasAnalysis.Graph.{Link, Occurrence}
import org.enso.compiler.pass.desugar.{LiftSpecialOperators, OperatorToFunction} import org.enso.compiler.pass.analyse.AliasAnalysis.{Graph, Info}
import org.enso.compiler.pass.desugar.{
GenerateMethodBodies,
LiftSpecialOperators,
OperatorToFunction
}
import org.enso.compiler.test.CompilerTest import org.enso.compiler.test.CompilerTest
class AliasAnalysisTest extends CompilerTest { class AliasAnalysisTest extends CompilerTest {
@ -19,6 +24,7 @@ class AliasAnalysisTest extends CompilerTest {
*/ */
implicit class Preprocess(source: String) { implicit class Preprocess(source: String) {
val precursorPasses: List[IRPass] = List( val precursorPasses: List[IRPass] = List(
GenerateMethodBodies,
LiftSpecialOperators, LiftSpecialOperators,
OperatorToFunction OperatorToFunction
) )
@ -28,7 +34,9 @@ class AliasAnalysisTest extends CompilerTest {
* @return IR appropriate for testing the alias analysis pass as a module * @return IR appropriate for testing the alias analysis pass as a module
*/ */
def preprocessModule: IR.Module = { def preprocessModule: IR.Module = {
source.toIrModule.runPasses(precursorPasses).asInstanceOf[IR.Module] source.toIrModule
.runPasses(precursorPasses, InlineContext())
.asInstanceOf[IR.Module]
} }
/** Translates the source code into appropriate IR for testing this pass /** Translates the source code into appropriate IR for testing this pass
@ -36,9 +44,12 @@ class AliasAnalysisTest extends CompilerTest {
* @return IR appropriate for testing the alias analysis pass as an * @return IR appropriate for testing the alias analysis pass as an
* expression * expression
*/ */
def preprocessExpression: Option[IR.Expression] = { def preprocessExpression(
inlineContext: InlineContext
): Option[IR.Expression] = {
source.toIrExpression.map( source.toIrExpression.map(
_.runPasses(precursorPasses).asInstanceOf[IR.Expression] _.runPasses(precursorPasses, inlineContext)
.asInstanceOf[IR.Expression]
) )
} }
} }
@ -68,8 +79,8 @@ class AliasAnalysisTest extends CompilerTest {
* *
* @return [[ir]], with attached aliasing information * @return [[ir]], with attached aliasing information
*/ */
def analyse: IR.Expression = { def analyse(inlineContext: InlineContext): IR.Expression = {
AliasAnalysis.runExpression(ir) AliasAnalysis.runExpression(ir, inlineContext)
} }
} }

View File

@ -1,13 +1,14 @@
package org.enso.compiler.test.pass.analyse package org.enso.compiler.test.pass.analyse
import org.enso.compiler.InlineContext
import org.enso.compiler.core.IR import org.enso.compiler.core.IR
import org.enso.compiler.core.IR.Metadata import org.enso.compiler.core.IR.Metadata
import org.enso.compiler.pass.analyse.{AliasAnalysis, ApplicationSaturation}
import org.enso.compiler.pass.analyse.ApplicationSaturation.{ import org.enso.compiler.pass.analyse.ApplicationSaturation.{
CallSaturation, CallSaturation,
FunctionSpec, FunctionSpec,
PassConfiguration PassConfiguration
} }
import org.enso.compiler.pass.analyse.{AliasAnalysis, ApplicationSaturation}
import org.enso.compiler.pass.desugar.{LiftSpecialOperators, OperatorToFunction} import org.enso.compiler.pass.desugar.{LiftSpecialOperators, OperatorToFunction}
import org.enso.compiler.test.CompilerTest import org.enso.compiler.test.CompilerTest
import org.enso.interpreter.node.ExpressionNode import org.enso.interpreter.node.ExpressionNode
@ -58,39 +59,54 @@ class ApplicationSaturationTest extends CompilerTest {
val localScope = Some(LocalScope.root) val localScope = Some(LocalScope.root)
val ctx = new InlineContext(localScope = localScope)
// === The Tests ============================================================ // === The Tests ============================================================
"Known applications" should { "Known applications" should {
val plusFn = IR.Application.Prefix( val plusFn = IR.Application
IR.Name.Literal("+", None), .Prefix(
genNArgs(2), IR.Name.Literal("+", None),
hasDefaultsSuspended = false, genNArgs(2),
None hasDefaultsSuspended = false,
).runPasses(passes, localScope).asInstanceOf[IR.Application.Prefix] None
)
.runPasses(passes, ctx)
.asInstanceOf[IR.Application.Prefix]
val bazFn = IR.Application.Prefix( val bazFn = IR.Application
IR.Name.Literal("baz", None), .Prefix(
genNArgs(2), IR.Name.Literal("baz", None),
hasDefaultsSuspended = false, genNArgs(2),
None hasDefaultsSuspended = false,
).runPasses(passes, localScope).asInstanceOf[IR.Application.Prefix] None
)
.runPasses(passes, ctx)
.asInstanceOf[IR.Application.Prefix]
val fooFn = IR.Application.Prefix( val fooFn = IR.Application
IR.Name.Literal("foo", None), .Prefix(
genNArgs(5), IR.Name.Literal("foo", None),
hasDefaultsSuspended = false, genNArgs(5),
None hasDefaultsSuspended = false,
).runPasses(passes, localScope).asInstanceOf[IR.Application.Prefix] None
)
.runPasses(passes, ctx)
.asInstanceOf[IR.Application.Prefix]
val fooFnByName = IR.Application.Prefix( val fooFnByName = IR.Application
IR.Name.Literal("foo", None), .Prefix(
genNArgs(4, positional = false), IR.Name.Literal("foo", None),
hasDefaultsSuspended = false, genNArgs(4, positional = false),
None hasDefaultsSuspended = false,
).runPasses(passes, localScope).asInstanceOf[IR.Application.Prefix] None
)
.runPasses(passes, ctx)
.asInstanceOf[IR.Application.Prefix]
"be tagged with full saturation where possible" in { "be tagged with full saturation where possible" in {
val resultIR = ApplicationSaturation(knownFunctions).runExpression(plusFn) val resultIR =
ApplicationSaturation(knownFunctions).runExpression(plusFn, ctx)
resultIR.getMetadata[CallSaturation].foreach { resultIR.getMetadata[CallSaturation].foreach {
case _: CallSaturation.Exact => succeed case _: CallSaturation.Exact => succeed
@ -99,14 +115,16 @@ class ApplicationSaturationTest extends CompilerTest {
} }
"be tagged with partial saturation where possible" in { "be tagged with partial saturation where possible" in {
val resultIR = ApplicationSaturation(knownFunctions).runExpression(bazFn) val resultIR =
ApplicationSaturation(knownFunctions).runExpression(bazFn, ctx)
val expected = Some(CallSaturation.Partial(1)) val expected = Some(CallSaturation.Partial(1))
resultIR.getMetadata[CallSaturation] shouldEqual expected resultIR.getMetadata[CallSaturation] shouldEqual expected
} }
"be tagged with over saturation where possible" in { "be tagged with over saturation where possible" in {
val resultIR = ApplicationSaturation(knownFunctions).runExpression(fooFn) val resultIR =
ApplicationSaturation(knownFunctions).runExpression(fooFn, ctx)
val expected = Some(CallSaturation.Over(1)) val expected = Some(CallSaturation.Over(1))
resultIR.getMetadata[CallSaturation] shouldEqual expected resultIR.getMetadata[CallSaturation] shouldEqual expected
@ -114,7 +132,7 @@ class ApplicationSaturationTest extends CompilerTest {
"be tagged with by name if applied by name" in { "be tagged with by name if applied by name" in {
val resultIR = val resultIR =
ApplicationSaturation(knownFunctions).runExpression(fooFnByName) ApplicationSaturation(knownFunctions).runExpression(fooFnByName, ctx)
val expected = Some(CallSaturation.ExactButByName()) val expected = Some(CallSaturation.ExactButByName())
resultIR.getMetadata[CallSaturation] shouldEqual expected resultIR.getMetadata[CallSaturation] shouldEqual expected
@ -122,16 +140,19 @@ class ApplicationSaturationTest extends CompilerTest {
} }
"Unknown applications" should { "Unknown applications" should {
val unknownFn = IR.Application.Prefix( val unknownFn = IR.Application
IR.Name.Literal("unknown", None), .Prefix(
genNArgs(10), IR.Name.Literal("unknown", None),
hasDefaultsSuspended = false, genNArgs(10),
None hasDefaultsSuspended = false,
).runPasses(passes, localScope).asInstanceOf[IR.Application.Prefix] None
)
.runPasses(passes, ctx)
.asInstanceOf[IR.Application.Prefix]
"be tagged with unknown saturation" in { "be tagged with unknown saturation" in {
val resultIR = val resultIR =
ApplicationSaturation(knownFunctions).runExpression(unknownFn) ApplicationSaturation(knownFunctions).runExpression(unknownFn, ctx)
val expected = Some(CallSaturation.Unknown()) val expected = Some(CallSaturation.Unknown())
resultIR.getMetadata[CallSaturation] shouldEqual expected resultIR.getMetadata[CallSaturation] shouldEqual expected
@ -140,26 +161,35 @@ class ApplicationSaturationTest extends CompilerTest {
"Known applications containing known applications" should { "Known applications containing known applications" should {
val empty = IR.Empty(None) val empty = IR.Empty(None)
val knownPlus = IR.Application.Prefix( val knownPlus = IR.Application
IR.Name.Literal("+", None), .Prefix(
genNArgs(2), IR.Name.Literal("+", None),
hasDefaultsSuspended = false, genNArgs(2),
None hasDefaultsSuspended = false,
).runPasses(passes, localScope).asInstanceOf[IR.Application.Prefix] None
)
.runPasses(passes, ctx)
.asInstanceOf[IR.Application.Prefix]
val undersaturatedPlus = IR.Application.Prefix( val undersaturatedPlus = IR.Application
IR.Name.Literal("+", None), .Prefix(
genNArgs(1), IR.Name.Literal("+", None),
hasDefaultsSuspended = false, genNArgs(1),
None hasDefaultsSuspended = false,
).runPasses(passes,localScope).asInstanceOf[IR.Application.Prefix] None
)
.runPasses(passes, ctx)
.asInstanceOf[IR.Application.Prefix]
val oversaturatedPlus = IR.Application.Prefix( val oversaturatedPlus = IR.Application
IR.Name.Literal("+", None), .Prefix(
genNArgs(3), IR.Name.Literal("+", None),
hasDefaultsSuspended = false, genNArgs(3),
None hasDefaultsSuspended = false,
).runPasses(passes, localScope).asInstanceOf[IR.Application.Prefix] None
)
.runPasses(passes, ctx)
.asInstanceOf[IR.Application.Prefix]
implicit class InnerMeta(ir: IR.Expression) { implicit class InnerMeta(ir: IR.Expression) {
def getInnerMetadata[T <: Metadata: ClassTag]: Option[T] = { def getInnerMetadata[T <: Metadata: ClassTag]: Option[T] = {
@ -173,21 +203,25 @@ class ApplicationSaturationTest extends CompilerTest {
} }
def outerPlus(argExpr: IR.Expression): IR.Application.Prefix = { def outerPlus(argExpr: IR.Expression): IR.Application.Prefix = {
IR.Application.Prefix( IR.Application
IR.Name.Literal("+", None), .Prefix(
List( IR.Name.Literal("+", None),
IR.CallArgument.Specified(None, argExpr, None), List(
IR.CallArgument.Specified(None, empty, None) IR.CallArgument.Specified(None, argExpr, None),
), IR.CallArgument.Specified(None, empty, None)
hasDefaultsSuspended = false, ),
None hasDefaultsSuspended = false,
).runPasses(passes, localScope).asInstanceOf[IR.Application.Prefix] None
)
.runPasses(passes, ctx)
.asInstanceOf[IR.Application.Prefix]
} }
"have fully saturated applications tagged correctly" in { "have fully saturated applications tagged correctly" in {
val result = val result =
ApplicationSaturation(knownFunctions).runExpression( ApplicationSaturation(knownFunctions).runExpression(
outerPlus(knownPlus) outerPlus(knownPlus),
ctx
) )
// The outer should be reported as fully saturated // The outer should be reported as fully saturated
@ -206,7 +240,8 @@ class ApplicationSaturationTest extends CompilerTest {
"have non-fully saturated applications tagged correctly" in { "have non-fully saturated applications tagged correctly" in {
val result = val result =
ApplicationSaturation(knownFunctions).runExpression( ApplicationSaturation(knownFunctions).runExpression(
outerPlus(undersaturatedPlus) outerPlus(undersaturatedPlus),
ctx
) )
val expectedInnerMeta = CallSaturation.Partial(1) val expectedInnerMeta = CallSaturation.Partial(1)
@ -225,7 +260,8 @@ class ApplicationSaturationTest extends CompilerTest {
"have a mixture of application saturations tagged correctly" in { "have a mixture of application saturations tagged correctly" in {
val result = val result =
ApplicationSaturation(knownFunctions).runExpression( ApplicationSaturation(knownFunctions).runExpression(
outerPlus(oversaturatedPlus) outerPlus(oversaturatedPlus),
ctx
) )
val expectedInnerMeta = CallSaturation.Over(1) val expectedInnerMeta = CallSaturation.Over(1)
@ -252,11 +288,11 @@ class ApplicationSaturationTest extends CompilerTest {
|""".stripMargin.toIR |""".stripMargin.toIR
val inputIR = rawIR val inputIR = rawIR
.runPasses(passes, localScope = localScope) .runPasses(passes, ctx)
.asInstanceOf[IR.Expression] .asInstanceOf[IR.Expression]
val result = ApplicationSaturation(knownFunctions) val result = ApplicationSaturation(knownFunctions)
.runExpression(inputIR, localScope = localScope) .runExpression(inputIR, ctx)
.asInstanceOf[IR.Expression.Binding] .asInstanceOf[IR.Expression.Binding]
"be tagged as unknown even if their name is known" in { "be tagged as unknown even if their name is known" in {

View File

@ -0,0 +1,335 @@
package org.enso.compiler.test.pass.analyse
import org.enso.compiler.InlineContext
import org.enso.compiler.core.IR
import org.enso.compiler.core.IR.Module.Scope.Definition.Method
import org.enso.compiler.exception.CompilerError
import org.enso.compiler.pass.IRPass
import org.enso.compiler.pass.analyse.TailCall.TailPosition
import org.enso.compiler.pass.analyse.{
AliasAnalysis,
ApplicationSaturation,
TailCall
}
import org.enso.compiler.pass.desugar.{
GenerateMethodBodies,
LiftSpecialOperators,
OperatorToFunction
}
import org.enso.compiler.test.CompilerTest
import org.enso.interpreter.runtime.scope.LocalScope
class TailCallTest extends CompilerTest {
// === Test Setup ===========================================================
val tailCtx = InlineContext(
localScope = Some(LocalScope.root),
isInTailPosition = Some(true)
)
val noTailCtx = InlineContext(
localScope = Some(LocalScope.root),
isInTailPosition = Some(false)
)
val precursorPasses: List[IRPass] = List(
GenerateMethodBodies,
LiftSpecialOperators,
OperatorToFunction,
AliasAnalysis,
ApplicationSaturation()
)
/** Adds an extension method to preprocess source code as an Enso module.
*
* @param code the source code to preprocess
*/
implicit class PreprocessModule(code: String) {
/** Preprocesses the provided source code into an [[IR.Module]].
*
* @return the IR representation of [[code]]
*/
def runTCAModule: IR.Module = {
val preprocessed = code.toIrModule
.runPasses(precursorPasses, tailCtx)
.asInstanceOf[IR.Module]
TailCall.runModule(preprocessed)
}
}
/** Adds an extension method to preprocess source code as an Enso expression.
*
* @param code the source code to preprocess
*/
implicit class PreprocessExpression(code: String) {
/** Preprocesses the provided source code into an [[IR.Expression]].
*
* @return the IR representation of [[code]]
*/
def runTCAExpression(context: InlineContext): IR.Expression = {
val preprocessed = code.toIrExpression
.getOrElse(
throw new CompilerError("Code was not a valid expression.")
)
.runPasses(precursorPasses, context)
.asInstanceOf[IR.Expression]
TailCall.runExpression(preprocessed, context)
}
}
// === The Tests ============================================================
"Tail call analysis on modules" should {
val ir =
"""
|Foo.bar = a b c ->
| d = a + b
|
| case c of
| Baz a b -> a * b * d
| _ -> d
|
|type MyAtom a b c
|""".stripMargin.runTCAModule
"mark methods as tail" in {
ir.bindings.head
.getMetadata[TailCall.Metadata] shouldEqual Some(TailPosition.Tail)
}
"mark atoms as tail" in {
ir.bindings(1)
.getMetadata[TailCall.Metadata] shouldEqual Some(TailPosition.Tail)
}
}
"Tail call analysis on expressions" should {
val code =
"""
|x y z -> x y z
|""".stripMargin
"mark the expression as tail if the context requires it" in {
val ir = code.runTCAExpression(tailCtx)
ir.getMetadata[TailCall.Metadata] shouldEqual Some(TailPosition.Tail)
}
"not mark the expression as tail if the context doesn't require it" in {
val ir = code.runTCAExpression(noTailCtx)
ir.getMetadata[TailCall.Metadata] shouldEqual Some(TailPosition.NotTail)
}
}
"Tail call analysis on functions" should {
val ir =
"""
|a b c ->
| d = a + b
| e = a * c
| d + e
|""".stripMargin
.runTCAExpression(tailCtx)
.asInstanceOf[IR.Function.Lambda]
val fnBody = ir.body.asInstanceOf[IR.Expression.Block]
"mark the last expression of the function as tail" in {
fnBody.returnValue.getMetadata[TailCall.Metadata] shouldEqual Some(
TailPosition.Tail
)
}
"mark the other expressions in the function as not tail" in {
fnBody.expressions.foreach(
expr =>
expr.getMetadata[TailCall.Metadata] shouldEqual Some(
TailPosition.NotTail
)
)
}
}
"Tail call analysis on case expressions" should {
"not mark any portion of the branch functions as tail by default" in {
val ir =
"""
|Foo.bar = a ->
| x = case a of
| Lambda fn arg -> fn arg
|
| x
|""".stripMargin.runTCAModule
val caseExpr = ir.bindings.head
.asInstanceOf[Method]
.body
.asInstanceOf[IR.Function.Lambda]
.body
.asInstanceOf[IR.Expression.Block]
.expressions
.head
.asInstanceOf[IR.Expression.Binding]
.expression
.asInstanceOf[IR.Case.Expr]
caseExpr.getMetadata[TailCall.Metadata] shouldEqual Some(
TailPosition.NotTail
)
caseExpr.branches.foreach(branch => {
val branchExpression =
branch.expression.asInstanceOf[IR.Function.Lambda]
branchExpression.getMetadata[TailPosition] shouldEqual Some(
TailPosition.NotTail
)
branchExpression.body.getMetadata[TailPosition] shouldEqual Some(
TailPosition.NotTail
)
})
}
"only mark the branches as tail if the expression is in tail position" in {
val ir =
"""
|Foo.bar = a ->
| case a of
| Lambda fn arg -> fn arg
|""".stripMargin.runTCAModule
val caseExpr = ir.bindings.head
.asInstanceOf[Method]
.body
.asInstanceOf[IR.Function.Lambda]
.body
.asInstanceOf[IR.Expression.Block]
.returnValue
.asInstanceOf[IR.Case.Expr]
caseExpr.getMetadata[TailCall.Metadata] shouldEqual Some(
TailPosition.Tail
)
caseExpr.branches.foreach(branch => {
val branchExpression =
branch.expression.asInstanceOf[IR.Function.Lambda]
branchExpression.getMetadata[TailPosition] shouldEqual Some(
TailPosition.Tail
)
branchExpression.body.getMetadata[TailPosition] shouldEqual Some(
TailPosition.Tail
)
})
}
}
"Tail call analysis on function calls" should {
val tailCall =
"""
|Foo.bar =
| IO.println "AAAAA"
|""".stripMargin.runTCAModule.bindings.head.asInstanceOf[Method]
val tailCallBody = tailCall.body
.asInstanceOf[IR.Function.Lambda]
.body
.asInstanceOf[IR.Expression.Block]
val nonTailCall =
"""
|Foo.bar =
| a = b c d
| a
|""".stripMargin.runTCAModule.bindings.head.asInstanceOf[Method]
val nonTailCallBody = nonTailCall.body
.asInstanceOf[IR.Function.Lambda]
.body
.asInstanceOf[IR.Expression.Block]
"mark the arguments as tail" in {
nonTailCallBody.expressions.head
.asInstanceOf[IR.Expression.Binding]
.expression
.asInstanceOf[IR.Application.Prefix]
.arguments
.foreach(
arg =>
arg.getMetadata[TailCall.Metadata] shouldEqual Some(
TailPosition.Tail
)
)
tailCallBody.returnValue
.asInstanceOf[IR.Application.Prefix]
.arguments
.foreach(
arg =>
arg.getMetadata[TailCall.Metadata] shouldEqual Some(
TailPosition.Tail
)
)
}
"mark the function call as tail if it is in a tail position" in {
tailCallBody.returnValue.getMetadata[TailCall.Metadata] shouldEqual Some(
TailPosition.Tail
)
}
"mark the function call as not tail if it is in a tail position" in {
nonTailCallBody.expressions.head
.asInstanceOf[IR.Expression.Binding]
.expression
.getMetadata[TailCall.Metadata] shouldEqual Some(TailPosition.NotTail)
}
}
"Tail call analysis on blocks" should {
val ir =
"""
|Foo.bar = a b c ->
| d = a + b
| mul = a b -> a * b
| mul c d
|""".stripMargin.runTCAModule.bindings.head.asInstanceOf[Method]
val block = ir.body
.asInstanceOf[IR.Function.Lambda]
.body
.asInstanceOf[IR.Expression.Block]
"mark the bodies of bound functions as tail properly" in {
block
.expressions(1)
.asInstanceOf[IR.Expression.Binding]
.expression
.asInstanceOf[IR.Function.Lambda]
.body
.getMetadata[TailCall.Metadata] shouldEqual Some(TailPosition.Tail)
}
"mark the block expressions as not tail" in {
block.expressions.foreach(
expr =>
expr.getMetadata[TailCall.Metadata] shouldEqual Some(
TailPosition.NotTail
)
)
}
"mark the final expression of the block as tail" in {
block.returnValue.getMetadata[TailCall.Metadata] shouldEqual Some(
TailPosition.Tail
)
}
"mark the block as tail if it is in a tail position" in {
block.getMetadata[TailCall.Metadata] shouldEqual Some(TailPosition.Tail)
}
}
}

View File

@ -0,0 +1,77 @@
package org.enso.compiler.test.pass.desugar
import org.enso.compiler.core.IR
import org.enso.compiler.core.IR.Module.Scope.Definition.Method
import org.enso.compiler.pass.desugar.GenerateMethodBodies
import org.enso.compiler.test.CompilerTest
class GenerateMethodBodiesTest extends CompilerTest {
// === The Tests ============================================================
"Methods with functions as bodies" should {
val ir =
"""
|Unit.method = a b c -> a + b + c
|""".stripMargin.toIrModule
val irMethod = ir.bindings.head.asInstanceOf[Method]
val irResult = GenerateMethodBodies.runModule(ir)
val irResultMethod = irResult.bindings.head.asInstanceOf[Method]
"have the `this` argument prepended to the argument list" in {
val resultArgs =
irResultMethod.body.asInstanceOf[IR.Function.Lambda].arguments
resultArgs.head
.asInstanceOf[IR.DefinitionArgument.Specified]
.name shouldEqual IR.Name.This(None)
resultArgs.tail shouldEqual irMethod.body
.asInstanceOf[IR.Function.Lambda]
.arguments
}
"have the body of the function remain untouched" in {
val inputBody = irMethod.body.asInstanceOf[IR.Function.Lambda].body
val resultBody = irResultMethod.body.asInstanceOf[IR.Function.Lambda].body
inputBody shouldEqual resultBody
}
}
"Methods with expressions as bodies" should {
val ir =
"""
|Unit.method = 1
|""".stripMargin.toIrModule
val irMethod = ir.bindings.head.asInstanceOf[Method]
val irResult = GenerateMethodBodies.runModule(ir)
val irResultMethod = irResult.bindings.head.asInstanceOf[Method]
"have the expression converted into a function" in {
irResultMethod.body shouldBe an[IR.Function.Lambda]
}
"have the resultant function take the `this` argument" in {
val bodyArgs =
irResultMethod.body.asInstanceOf[IR.Function.Lambda].arguments
bodyArgs.length shouldEqual 1
bodyArgs.head
.asInstanceOf[IR.DefinitionArgument.Specified]
.name shouldEqual IR.Name.This(None)
}
"have the body of the function be equivalent to the expression" in {
irResultMethod.body
.asInstanceOf[IR.Function.Lambda]
.body shouldEqual irMethod.body
}
"have the body function's location equivalent to the original body" in {
irMethod.body.location shouldEqual irResultMethod.body.location
}
}
}

View File

@ -1,5 +1,6 @@
package org.enso.compiler.test.pass.desugar package org.enso.compiler.test.pass.desugar
import org.enso.compiler.InlineContext
import org.enso.compiler.core.IR import org.enso.compiler.core.IR
import org.enso.compiler.pass.desugar.LiftSpecialOperators import org.enso.compiler.pass.desugar.LiftSpecialOperators
import org.enso.compiler.test.CompilerTest import org.enso.compiler.test.CompilerTest
@ -9,6 +10,8 @@ class LiftSpecialOperatorsTest extends CompilerTest {
// === Utilities ============================================================ // === Utilities ============================================================
val ctx = new InlineContext
/** Tests whether a given operator is lifted correctly into the corresponding /** Tests whether a given operator is lifted correctly into the corresponding
* special construct. * special construct.
* *
@ -42,7 +45,7 @@ class LiftSpecialOperatorsTest extends CompilerTest {
"be lifted by the pass in an inline context" in { "be lifted by the pass in an inline context" in {
LiftSpecialOperators LiftSpecialOperators
.runExpression(expressionIR) shouldEqual outputExpressionIR .runExpression(expressionIR, ctx) shouldEqual outputExpressionIR
} }
"be lifted by the pass in a module context" in { "be lifted by the pass in a module context" in {
@ -62,7 +65,7 @@ class LiftSpecialOperatorsTest extends CompilerTest {
) )
LiftSpecialOperators LiftSpecialOperators
.runExpression(recursiveIR) shouldEqual recursiveIROutput .runExpression(recursiveIR, ctx) shouldEqual recursiveIROutput
} }
} }

View File

@ -1,5 +1,6 @@
package org.enso.compiler.test.pass.desugar package org.enso.compiler.test.pass.desugar
import org.enso.compiler.InlineContext
import org.enso.compiler.core.IR import org.enso.compiler.core.IR
import org.enso.compiler.pass.desugar.OperatorToFunction import org.enso.compiler.pass.desugar.OperatorToFunction
import org.enso.compiler.test.CompilerTest import org.enso.compiler.test.CompilerTest
@ -9,6 +10,8 @@ class OperatorToFunctionTest extends CompilerTest {
// === Utilities ============================================================ // === Utilities ============================================================
val ctx = new InlineContext
/** Generates an operator and its corresponding function. /** Generates an operator and its corresponding function.
* *
* @param name * @param name
@ -47,7 +50,7 @@ class OperatorToFunctionTest extends CompilerTest {
val (operator, operatorFn) = genOprAndFn(opName, left, right) val (operator, operatorFn) = genOprAndFn(opName, left, right)
"be translated to functions" in { "be translated to functions" in {
OperatorToFunction.runExpression(operator) shouldEqual operatorFn OperatorToFunction.runExpression(operator, ctx) shouldEqual operatorFn
} }
"be translated in module contexts" in { "be translated in module contexts" in {
@ -70,7 +73,7 @@ class OperatorToFunctionTest extends CompilerTest {
None None
) )
OperatorToFunction.runExpression(recursiveIR) shouldEqual recursiveIRResult OperatorToFunction.runExpression(recursiveIR, ctx) shouldEqual recursiveIRResult
} }
} }
} }

View File

@ -57,4 +57,20 @@ class CurryingTest extends InterpreterTest {
eval(code) shouldEqual 32 eval(code) shouldEqual 32
} }
"Method call syntax" should "allow default arguments to be suspended" in {
val code =
"""
|Unit.fn = w x (y = 10) (z = 20) -> w + x + y + z
|
|main =
| fn1 = Unit.fn ...
| fn2 = fn1 1 2 ...
| fn3 = fn2 3 ...
|
| fn3.call
|""".stripMargin
eval(code) shouldEqual 26
}
} }