Implement and benchmark ArrowOperationPlus node (#10150)

Prototype of #10056 showing `+` operation implemented in the _Arrow language_.
This commit is contained in:
Jaroslav Tulach 2024-06-11 14:50:59 +02:00 committed by GitHub
parent 19c50ceff9
commit aaaebcabf8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 1017 additions and 379 deletions

View File

@ -11,36 +11,42 @@ public final class ArrowParser {
public record Result(PhysicalLayout physicalLayout, LogicalLayout logicalLayout, Mode mode) {}
public static Result parse(Source source) {
String src = source.getCharacters().toString();
Matcher m = NEW_ARRAY_CONSTR.matcher(src);
String src = source.getCharacters().toString().replace('\n', ' ').trim();
Matcher m = PATTERN.matcher(src);
if (m.find()) {
try {
var layout = LogicalLayout.valueOf(m.group(1));
return new Result(PhysicalLayout.Primitive, layout, Mode.Allocate);
var layout = LogicalLayout.valueOf(m.group(2));
var mode = Mode.parse(m.group(1));
if (layout != null && mode != null) {
return new Result(PhysicalLayout.Primitive, layout, mode);
}
} catch (IllegalArgumentException iae) {
// propagate warning
return null;
}
}
m = CAST_PATTERN.matcher(src);
if (m.find()) {
try {
var layout = LogicalLayout.valueOf(m.group(1));
return new Result(PhysicalLayout.Primitive, layout, Mode.Cast);
} catch (IllegalArgumentException iae) {
// propagate warning
return null;
}
}
return null;
}
private static final Pattern NEW_ARRAY_CONSTR = Pattern.compile("^new\\[(.+)\\]$");
private static final Pattern CAST_PATTERN = Pattern.compile("^cast\\[(.+)\\]$");
private static final Pattern PATTERN = Pattern.compile("^([a-z\\+]+)\\[(.+)\\]$");
public enum Mode {
Allocate,
Cast
Allocate("new"),
Cast("cast"),
Plus("+");
private final String op;
private Mode(String text) {
this.op = text;
}
static Mode parse(String operation) {
for (var m : values()) {
if (m.op.equals(operation)) {
return m;
}
}
return null;
}
}
}

View File

@ -1,16 +0,0 @@
package org.enso.interpreter.arrow.node;
import com.oracle.truffle.api.nodes.Node;
import org.enso.interpreter.arrow.LogicalLayout;
import org.enso.interpreter.arrow.runtime.ArrowCastToFixedSizeArrayFactory;
public class ArrowCastFixedSizeNode extends Node {
static ArrowCastFixedSizeNode create() {
return new ArrowCastFixedSizeNode();
}
public Object execute(LogicalLayout layoutType) {
return new ArrowCastToFixedSizeArrayFactory(layoutType);
}
}

View File

@ -6,13 +6,13 @@ import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.RootNode;
import org.enso.interpreter.arrow.ArrowLanguage;
import org.enso.interpreter.arrow.ArrowParser;
import org.enso.interpreter.arrow.runtime.ArrowCastToFixedSizeArrayFactory;
import org.enso.interpreter.arrow.runtime.ArrowFixedSizeArrayFactory;
import org.enso.interpreter.arrow.runtime.ArrowOperationPlus;
public class ArrowEvalNode extends RootNode {
private final ArrowParser.Result code;
@Child private ArrowFixedSizeNode fixedPhysicalLayout = ArrowFixedSizeNode.create();
@Child private ArrowCastFixedSizeNode castToFixedPhysicalLayout = ArrowCastFixedSizeNode.create();
public static ArrowEvalNode create(ArrowLanguage language, ArrowParser.Result code) {
return new ArrowEvalNode(language, code);
}
@ -25,8 +25,10 @@ public class ArrowEvalNode extends RootNode {
public Object execute(VirtualFrame frame) {
return switch (code.physicalLayout()) {
case Primitive -> switch (code.mode()) {
case Allocate -> fixedPhysicalLayout.execute(code.logicalLayout());
case Cast -> castToFixedPhysicalLayout.execute(code.logicalLayout());
case Allocate -> new ArrowFixedSizeArrayFactory(code.logicalLayout());
case Cast -> new ArrowCastToFixedSizeArrayFactory(code.logicalLayout());
case Plus -> new ArrowOperationPlus(code.logicalLayout());
default -> throw CompilerDirectives.shouldNotReachHere("unsupported mode");
};
default -> throw CompilerDirectives.shouldNotReachHere("unsupported physical layout");
};

View File

@ -1,16 +0,0 @@
package org.enso.interpreter.arrow.node;
import com.oracle.truffle.api.nodes.Node;
import org.enso.interpreter.arrow.LogicalLayout;
import org.enso.interpreter.arrow.runtime.ArrowFixedSizeArrayFactory;
public class ArrowFixedSizeNode extends Node {
static ArrowFixedSizeNode create() {
return new ArrowFixedSizeNode();
}
public Object execute(LogicalLayout layoutType) {
return new ArrowFixedSizeArrayFactory(layoutType);
}
}

View File

@ -1,13 +1,23 @@
package org.enso.interpreter.arrow.runtime;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.dsl.Bind;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.ImportStatic;
import com.oracle.truffle.api.dsl.NeverDefault;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.interop.InteropLibrary;
import com.oracle.truffle.api.interop.InvalidArrayIndexException;
import com.oracle.truffle.api.interop.StopIterationException;
import com.oracle.truffle.api.interop.TruffleObject;
import com.oracle.truffle.api.interop.UnsupportedMessageException;
import com.oracle.truffle.api.library.CachedLibrary;
import com.oracle.truffle.api.library.ExportLibrary;
import com.oracle.truffle.api.library.ExportMessage;
import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.profiles.InlinedExactClassProfile;
import java.nio.BufferOverflowException;
import java.nio.ByteBuffer;
import org.enso.interpreter.arrow.LogicalLayout;
@ExportLibrary(InteropLibrary.class)
@ -27,7 +37,24 @@ public final class ArrowFixedArrayInt implements TruffleObject {
}
@ExportMessage
public boolean hasArrayElements() {
boolean hasArrayElements() {
return true;
}
@ExportMessage
Object getIterator(
@Cached(value = "this.getUnit()", allowUncached = true) LogicalLayout cachedUnit)
throws UnsupportedMessageException {
if (cachedUnit == LogicalLayout.Int64) {
var dataIt = new LongIterator(buffer.getDataBuffer(), cachedUnit.sizeInBytes());
var nullIt = new NullIterator(dataIt, buffer.getBitmapBuffer());
return nullIt;
}
return new GenericIterator(this);
}
@ExportMessage
boolean hasIterator() {
return true;
}
@ -65,13 +92,18 @@ public final class ArrowFixedArrayInt implements TruffleObject {
}
@Specialization(guards = "receiver.getUnit() == Int64")
public static Object doLong(ArrowFixedArrayInt receiver, long index)
public static Object doLong(
ArrowFixedArrayInt receiver,
long index,
@Bind("$node") Node node,
@CachedLibrary("receiver") InteropLibrary iop,
@Cached InlinedExactClassProfile bufferClazz)
throws UnsupportedMessageException, InvalidArrayIndexException {
var at = adjustedIndex(receiver.buffer, receiver.unit, receiver.size, index);
var at = adjustedIndex(receiver.buffer, LogicalLayout.Int64, receiver.size, index);
if (receiver.buffer.isNull((int) index)) {
return NullValue.get();
}
return receiver.buffer.getLong(at);
return receiver.buffer.getLong(at, iop, bufferClazz);
}
}
@ -96,4 +128,133 @@ public final class ArrowFixedArrayInt implements TruffleObject {
private static int typeAdjustedIndex(long index, SizeInBytes unit) {
return Math.toIntExact(index * unit.sizeInBytes());
}
@ExportLibrary(InteropLibrary.class)
static final class LongIterator implements TruffleObject {
private int at;
private final ByteBuffer buffer;
@NeverDefault final int step;
LongIterator(ByteBuffer buffer, int step) {
assert step != 0;
this.buffer = buffer;
this.step = step;
}
@ExportMessage
Object getIteratorNextElement(
@Bind("$node") Node node,
@Cached("this.step") int step,
@Cached InlinedExactClassProfile bufferTypeProfile)
throws StopIterationException {
var buf = bufferTypeProfile.profile(node, buffer);
try {
var res = buf.getLong(at);
at += step;
return res;
} catch (BufferOverflowException ex) {
CompilerDirectives.transferToInterpreter();
throw StopIterationException.create();
}
}
@ExportMessage
boolean isIterator() {
return true;
}
@ExportMessage
boolean hasIteratorNextElement() throws UnsupportedMessageException {
return at < buffer.limit();
}
}
@ExportLibrary(value = InteropLibrary.class)
static final class NullIterator implements TruffleObject {
private final TruffleObject it;
private final ByteBuffer buffer;
private byte byteMask;
private byte byteValue;
NullIterator(TruffleObject delegate, ByteBuffer buffer) {
this.it = delegate;
this.buffer = buffer;
}
final TruffleObject it() {
return it;
}
@ExportMessage(limit = "3")
Object getIteratorNextElement(
@Bind("$node") Node node,
@CachedLibrary("this.it()") InteropLibrary iopIt,
@Cached InlinedExactClassProfile bufferTypeProfile)
throws StopIterationException, UnsupportedMessageException {
var element = iopIt.getIteratorNextElement(it);
if (buffer != null) {
var buf = bufferTypeProfile.profile(node, buffer);
if (byteMask == 0) {
// (byte) (0x01 << 8) ==> 0
byteValue = buf.get();
byteMask = 0x01;
}
var include = byteValue & byteMask;
byteMask = (byte) (byteMask << 1);
if (include == 0) {
return NullValue.get();
}
}
return element;
}
@ExportMessage
boolean isIterator() {
return true;
}
@ExportMessage(limit = "3")
boolean hasIteratorNextElement(@CachedLibrary("this.it()") InteropLibrary iopIt)
throws UnsupportedMessageException {
return iopIt.hasIteratorNextElement(it);
}
}
@ExportLibrary(InteropLibrary.class)
static final class GenericIterator implements TruffleObject {
private int at;
private final TruffleObject array;
GenericIterator(TruffleObject array) {
assert InteropLibrary.getUncached().hasArrayElements(array);
this.array = array;
}
TruffleObject array() {
return array;
}
@ExportMessage(limit = "3")
Object getIteratorNextElement(@CachedLibrary("this.array()") InteropLibrary iop)
throws StopIterationException {
try {
var res = iop.readArrayElement(array, at);
at++;
return res;
} catch (UnsupportedMessageException | InvalidArrayIndexException ex) {
throw StopIterationException.create();
}
}
@ExportMessage
boolean isIterator() {
return true;
}
@ExportMessage(limit = "3")
boolean hasIteratorNextElement(@CachedLibrary("this.array()") InteropLibrary iop)
throws UnsupportedMessageException {
return at < iop.getArraySize(array);
}
}
}

View File

@ -1,22 +1,27 @@
package org.enso.interpreter.arrow.runtime;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Cached.Shared;
import com.oracle.truffle.api.dsl.GenerateInline;
import com.oracle.truffle.api.dsl.GenerateUncached;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.interop.InteropLibrary;
import com.oracle.truffle.api.interop.TruffleObject;
import com.oracle.truffle.api.interop.UnknownIdentifierException;
import com.oracle.truffle.api.interop.UnsupportedMessageException;
import com.oracle.truffle.api.interop.UnsupportedTypeException;
import com.oracle.truffle.api.library.CachedLibrary;
import com.oracle.truffle.api.library.ExportLibrary;
import com.oracle.truffle.api.library.ExportMessage;
import com.oracle.truffle.api.nodes.Node;
import org.enso.interpreter.arrow.LogicalLayout;
@ExportLibrary(InteropLibrary.class)
public final class ArrowFixedSizeArrayBuilder implements TruffleObject {
private final ByteBufferDirect buffer;
private final LogicalLayout unit;
private final int size;
private int index;
private boolean sealed;
private ByteBufferDirect buffer;
private static final String APPEND_OP = "append";
private static final String BUILD_OP = "build";
@ -25,8 +30,6 @@ public final class ArrowFixedSizeArrayBuilder implements TruffleObject {
this.size = size;
this.unit = unit;
this.buffer = ByteBufferDirect.forSize(size, unit);
this.index = 0;
this.sealed = false;
}
public LogicalLayout getUnit() {
@ -34,7 +37,7 @@ public final class ArrowFixedSizeArrayBuilder implements TruffleObject {
}
public boolean isSealed() {
return sealed;
return buffer == null;
}
public ByteBufferDirect getBuffer() {
@ -53,7 +56,7 @@ public final class ArrowFixedSizeArrayBuilder implements TruffleObject {
@ExportMessage
public boolean isMemberInvocable(String member) {
return switch (member) {
case APPEND_OP -> !this.sealed;
case APPEND_OP -> buffer != null;
case BUILD_OP -> true;
default -> false;
};
@ -65,33 +68,83 @@ public final class ArrowFixedSizeArrayBuilder implements TruffleObject {
}
@ExportMessage
Object invokeMember(
String name,
Object[] args,
@Cached(value = "buildWriterOrNull(name)", neverDefault = true)
WriteToBuilderNode writeToBuilderNode)
Object invokeMember(String name, Object[] args, @Cached AppendNode append)
throws UnsupportedMessageException, UnknownIdentifierException, UnsupportedTypeException {
switch (name) {
case BUILD_OP:
sealed = true;
return switch (unit) {
case Date32, Date64 -> new ArrowFixedArrayDate(buffer, size, unit);
case Int8, Int16, Int32, Int64 -> new ArrowFixedArrayInt(buffer, size, unit);
};
case APPEND_OP:
if (sealed) {
throw UnsupportedMessageException.create();
}
var current = index;
writeToBuilderNode.executeWrite(this, current, args[0]);
index += 1;
return NullValue.get();
default:
throw UnknownIdentifierException.create(name);
return switch (name) {
case BUILD_OP -> build();
case APPEND_OP -> {
append.executeAppend(this, args[0]);
yield NullValue.get();
}
default -> throw UnknownIdentifierException.create(name);
};
}
private final TruffleObject build() throws UnsupportedMessageException {
var b = buffer;
if (b == null) {
throw UnsupportedMessageException.create();
}
buffer = null;
return switch (unit) {
case Date32, Date64 -> new ArrowFixedArrayDate(b, size, unit);
case Int8, Int16, Int32, Int64 -> new ArrowFixedArrayInt(b, size, unit);
};
}
@GenerateUncached
@GenerateInline(false)
abstract static class AppendNode extends Node {
abstract void executeAppend(ArrowFixedSizeArrayBuilder builder, Object value)
throws UnsupportedTypeException, UnsupportedMessageException;
@Specialization(
limit = "3",
guards = {"builder.getUnit() == cachedUnit"})
static void writeToBuffer(
ArrowFixedSizeArrayBuilder builder,
Object value,
@Cached(value = "builder.getUnit()", allowUncached = true) LogicalLayout cachedUnit,
@Shared("put") @Cached ByteBufferDirect.PutNode put,
@Shared("value") @Cached ValueToNumberNode valueNode,
@Shared("iop") @CachedLibrary(limit = "3") InteropLibrary iop)
throws UnsupportedTypeException, UnsupportedMessageException {
if (iop.isNull(value)) {
put.putNull(builder.buffer, cachedUnit);
return;
}
var number = valueNode.executeAdjust(cachedUnit, value);
switch (number) {
case Byte b -> put.put(builder.buffer, b);
case Short s -> put.putShort(builder.buffer, s);
case Integer i -> put.putInt(builder.buffer, i);
case Long l -> put.putLong(builder.buffer, l);
default -> throw CompilerDirectives.shouldNotReachHere();
}
}
@Specialization(replaces = "writeToBuffer")
static void writeToBufferUncached(
ArrowFixedSizeArrayBuilder builder,
Object value,
@Shared("put") @Cached ByteBufferDirect.PutNode put,
@Shared("value") @Cached ValueToNumberNode valueNode,
@Shared("iop") @CachedLibrary(limit = "3") InteropLibrary iop)
throws UnsupportedTypeException, UnsupportedMessageException {
writeToBuffer(builder, value, builder.getUnit(), put, valueNode, iop);
}
}
static WriteToBuilderNode buildWriterOrNull(String op) {
return APPEND_OP.equals(op) ? WriteToBuilderNode.build() : WriteToBuilderNodeGen.getUncached();
@GenerateUncached
@GenerateInline(false)
abstract static class BuildNode extends Node {
abstract TruffleObject executeBuild(ArrowFixedSizeArrayBuilder builder)
throws UnsupportedMessageException;
@Specialization
static TruffleObject buildIt(ArrowFixedSizeArrayBuilder builder)
throws UnsupportedMessageException {
return builder.build();
}
}
}

View File

@ -2,8 +2,8 @@ package org.enso.interpreter.arrow.runtime;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Fallback;
import com.oracle.truffle.api.dsl.ImportStatic;
import com.oracle.truffle.api.dsl.GenerateInline;
import com.oracle.truffle.api.dsl.GenerateUncached;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.interop.InteropLibrary;
import com.oracle.truffle.api.interop.TruffleObject;
@ -11,10 +11,11 @@ import com.oracle.truffle.api.interop.UnsupportedMessageException;
import com.oracle.truffle.api.library.CachedLibrary;
import com.oracle.truffle.api.library.ExportLibrary;
import com.oracle.truffle.api.library.ExportMessage;
import com.oracle.truffle.api.nodes.Node;
import org.enso.interpreter.arrow.LogicalLayout;
@ExportLibrary(InteropLibrary.class)
public class ArrowFixedSizeArrayFactory implements TruffleObject {
public final class ArrowFixedSizeArrayFactory implements TruffleObject {
private final LogicalLayout logicalLayout;
@ -23,7 +24,7 @@ public class ArrowFixedSizeArrayFactory implements TruffleObject {
}
@ExportMessage
public boolean isInstantiable() {
boolean isInstantiable() {
return true;
}
@ -32,79 +33,36 @@ public class ArrowFixedSizeArrayFactory implements TruffleObject {
}
@ExportMessage
@ImportStatic(LogicalLayout.class)
static class Instantiate {
@Specialization(guards = "receiver.getLayout() == Date32")
static Object doDate32(
ArrowFixedSizeArrayFactory receiver,
Object[] args,
@Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop)
throws UnsupportedMessageException {
return new ArrowFixedSizeArrayBuilder(arraySize(args, iop), receiver.logicalLayout);
}
ArrowFixedSizeArrayBuilder instantiate(
Object[] args,
@Cached InstantiateNode instantiate,
@CachedLibrary(limit = "1") InteropLibrary iop)
throws UnsupportedMessageException {
var size = arraySize(args, iop);
return instantiate.allocateBuilder(logicalLayout, size);
}
@Specialization(guards = "receiver.getLayout() == Date64")
static Object doDate64(
ArrowFixedSizeArrayFactory receiver,
Object[] args,
@Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop)
throws UnsupportedMessageException {
return new ArrowFixedSizeArrayBuilder(arraySize(args, iop), receiver.logicalLayout);
private static int arraySize(Object[] args, InteropLibrary interop)
throws UnsupportedMessageException {
if (args.length != 1 || !interop.isNumber(args[0]) || !interop.fitsInInt(args[0])) {
throw UnsupportedMessageException.create();
}
return interop.asInt(args[0]);
}
@Specialization(guards = "receiver.getLayout() == Int8")
static Object doInt8(
ArrowFixedSizeArrayFactory receiver,
Object[] args,
@Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop)
throws UnsupportedMessageException {
return new ArrowFixedSizeArrayBuilder(arraySize(args, iop), receiver.logicalLayout);
}
@GenerateUncached
@GenerateInline(false)
abstract static class InstantiateNode extends Node {
abstract ArrowFixedSizeArrayBuilder executeNew(LogicalLayout logicalLayout, long size);
@Specialization(guards = "receiver.getLayout() == Int16")
static Object doInt16(
ArrowFixedSizeArrayFactory receiver,
Object[] args,
@Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop)
throws UnsupportedMessageException {
return new ArrowFixedSizeArrayBuilder(arraySize(args, iop), receiver.logicalLayout);
}
@Specialization(guards = "receiver.getLayout() == Int32")
static Object doInt32(
ArrowFixedSizeArrayFactory receiver,
Object[] args,
@Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop)
throws UnsupportedMessageException {
return new ArrowFixedSizeArrayBuilder(arraySize(args, iop), receiver.logicalLayout);
}
@Specialization(guards = "receiver.getLayout() == Int64")
static Object doInt64(
ArrowFixedSizeArrayFactory receiver,
Object[] args,
@Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop)
throws UnsupportedMessageException {
return new ArrowFixedSizeArrayBuilder(arraySize(args, iop), receiver.logicalLayout);
}
@CompilerDirectives.TruffleBoundary
private static int arraySize(Object[] args, InteropLibrary interop)
throws UnsupportedMessageException {
if (args.length != 1 || !interop.isNumber(args[0]) || !interop.fitsInInt(args[0])) {
throw UnsupportedMessageException.create();
@Specialization
final ArrowFixedSizeArrayBuilder allocateBuilder(LogicalLayout logicalLayout, long size) {
try {
return new ArrowFixedSizeArrayBuilder(Math.toIntExact(size), logicalLayout);
} catch (ArithmeticException ex) {
CompilerDirectives.transferToInterpreter();
throw ex;
}
return interop.asInt(args[0]);
}
@Fallback
static Object doOther(ArrowFixedSizeArrayFactory receiver, Object[] args) {
throw CompilerDirectives.shouldNotReachHere(unknownLayoutMessage(receiver.getLayout()));
}
@CompilerDirectives.TruffleBoundary
private static String unknownLayoutMessage(SizeInBytes layout) {
return "unknown layout: " + layout.toString();
}
}
}

View File

@ -0,0 +1,97 @@
package org.enso.interpreter.arrow.runtime;
import com.oracle.truffle.api.dsl.Bind;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.interop.ArityException;
import com.oracle.truffle.api.interop.InteropLibrary;
import com.oracle.truffle.api.interop.StopIterationException;
import com.oracle.truffle.api.interop.TruffleObject;
import com.oracle.truffle.api.interop.UnsupportedMessageException;
import com.oracle.truffle.api.interop.UnsupportedTypeException;
import com.oracle.truffle.api.library.CachedLibrary;
import com.oracle.truffle.api.library.ExportLibrary;
import com.oracle.truffle.api.library.ExportMessage;
import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.profiles.InlinedExactClassProfile;
import org.enso.interpreter.arrow.LogicalLayout;
@ExportLibrary(InteropLibrary.class)
public final class ArrowOperationPlus implements TruffleObject {
private final LogicalLayout layout;
public ArrowOperationPlus(LogicalLayout layout) {
this.layout = layout;
}
final LogicalLayout layout() {
return layout;
}
@ExportMessage
boolean isExecutable() {
return true;
}
static Object args(Object[] args, int index) throws ArityException {
if (args.length != 2) {
throw ArityException.create(2, 2, args.length);
}
return args[index];
}
static Object it(Object[] args, InteropLibrary iop, int index)
throws ArityException, UnsupportedMessageException {
if (args.length != 2) {
throw ArityException.create(2, 2, args.length);
}
return iop.getIterator(args[index]);
}
ScalarOperationNode createScalarOp(boolean cached) {
return cached ? OperationPlus.create() : OperationPlus.getUncached();
}
@ExportMessage(limit = "3")
Object execute(
Object[] args,
@Bind("$node") Node node,
@Cached(value = "this.layout()", allowUncached = true) LogicalLayout cachedLayout,
@Cached ArrowFixedSizeArrayFactory.InstantiateNode factory,
@CachedLibrary("args(args, 0)") InteropLibrary iopArray0,
@CachedLibrary("args(args, 1)") InteropLibrary iopArray1,
@CachedLibrary("it(args, iopArray0, 0)") InteropLibrary iopIt0,
@CachedLibrary("it(args, iopArray1, 1)") InteropLibrary iopIt1,
@CachedLibrary(limit = "3") InteropLibrary iopElem,
@Cached(value = "this.createScalarOp(true)", uncached = "this.createScalarOp(false)")
ScalarOperationNode opNode,
@Cached ArrowFixedSizeArrayBuilder.AppendNode append,
@Cached ArrowFixedSizeArrayBuilder.BuildNode build,
@Cached InlinedExactClassProfile typeOfBuf0,
@Cached InlinedExactClassProfile typeOfBuf1)
throws ArityException, UnsupportedTypeException, UnsupportedMessageException {
var arr0 = args[0];
var arr1 = args[1];
if (!iopArray0.hasArrayElements(arr0) || !iopArray1.hasArrayElements(arr1)) {
throw UnsupportedTypeException.create(args);
}
var len = iopArray0.getArraySize(arr0);
if (len != iopArray1.getArraySize(arr1)) {
throw UnsupportedTypeException.create(args, "Arrays must have the same length");
}
var it0 = iopArray0.getIterator(arr0);
var it1 = iopArray1.getIterator(arr1);
var builder = factory.allocateBuilder(cachedLayout, len);
for (long i = 0; i < len; i++) {
try {
var elem0 = iopIt0.getIteratorNextElement(it0);
var elem1 = iopIt1.getIteratorNextElement(it1);
var res = opNode.executeOp(elem0, elem1);
append.executeAppend(builder, res);
} catch (StopIterationException ex) {
throw UnsupportedTypeException.create(new Object[] {it0, it1});
}
}
return build.executeBuild(builder);
}
}

View File

@ -1,14 +1,25 @@
package org.enso.interpreter.arrow.runtime;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.dsl.Bind;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.GenerateInline;
import com.oracle.truffle.api.dsl.GenerateUncached;
import com.oracle.truffle.api.dsl.NeverDefault;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.interop.UnsupportedMessageException;
import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.profiles.InlinedExactClassProfile;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import org.enso.interpreter.arrow.LogicalLayout;
import org.enso.interpreter.arrow.runtime.ByteBufferDirect.DataBufferNode;
import org.enso.interpreter.arrow.util.MemoryUtil;
final class ByteBufferDirect implements AutoCloseable {
private final ByteBuffer allocated;
private final ByteBuffer dataBuffer;
private final ByteBuffer bitmapBuffer;
private ByteBuffer bitmapBuffer;
/**
* Creates a fresh buffer with an empty non-null bitmap..
@ -22,10 +33,7 @@ final class ByteBufferDirect implements AutoCloseable {
this.allocated = buffer;
this.dataBuffer = buffer.slice(0, padded.getDataBufferSizeInBytes());
this.bitmapBuffer = buffer.slice(dataBuffer.capacity(), padded.getValidityBitmapSizeInBytes());
for (int i = 0; i < bitmapBuffer.capacity(); i++) {
bitmapBuffer.put(i, (byte) 0);
}
this.bitmapBuffer = null;
}
/**
@ -53,8 +61,13 @@ final class ByteBufferDirect implements AutoCloseable {
this.dataBuffer = dataBuffer;
this.bitmapBuffer = allocated.slice(dataBuffer.capacity(), bitmapSizeInBytes);
for (int i = 0; i < bitmapBuffer.capacity(); i++) {
bitmapBuffer.put(i, (byte) 255);
bitmapBuffer.put(i, (byte) 0xff);
}
bitmapBuffer.rewind();
}
static ByteBufferDirect forBuffer(ByteBuffer buf) {
return new ByteBufferDirect(buf, buf, null);
}
/**
@ -93,119 +106,200 @@ final class ByteBufferDirect implements AutoCloseable {
return new ByteBufferDirect(allocated, dataBuffer, bitmapBuffer);
}
public void put(byte b) throws UnsupportedMessageException {
setValidityBitmap(0, 1);
dataBuffer.put(b);
@CompilerDirectives.TruffleBoundary
ByteBuffer initializeBitmapBuffer() {
assert bitmapBuffer == null;
bitmapBuffer =
allocated.slice(dataBuffer.capacity(), allocated.capacity() - dataBuffer.capacity());
for (var i = 0; i < bitmapBuffer.capacity(); i++) {
bitmapBuffer.put(i, (byte) 0xff);
}
return bitmapBuffer;
}
final ByteBuffer getDataBuffer() {
return dataBuffer;
}
final ByteBuffer getBitmapBuffer() {
return bitmapBuffer;
}
@GenerateInline(false)
@GenerateUncached
abstract static class DataBufferNode extends Node {
static DataBufferNode create() {
return ByteBufferDirectFactory.DataBufferNodeGen.create();
}
static DataBufferNode getUncached() {
return ByteBufferDirectFactory.DataBufferNodeGen.getUncached();
}
abstract ByteBuffer executeDataBuffer(ByteBufferDirect direct);
@Specialization
static ByteBuffer profiledDataBuffer(
ByteBufferDirect direct,
@Bind("$node") Node node,
@Cached InlinedExactClassProfile bufferClazz) {
return bufferClazz.profile(node, direct.dataBuffer);
}
}
@GenerateInline(false)
@GenerateUncached
abstract static class BitmapBufferNode extends Node {
static BitmapBufferNode create() {
return ByteBufferDirectFactory.BitmapBufferNodeGen.create();
}
static BitmapBufferNode getUncached() {
return ByteBufferDirectFactory.BitmapBufferNodeGen.getUncached();
}
abstract ByteBuffer executeBitmapBuffer(ByteBufferDirect direct, boolean forceCreation);
@Specialization
static ByteBuffer profiledBitmapBuffer(
ByteBufferDirect direct,
boolean forceCreation,
@Bind("$node") Node node,
@Cached InlinedExactClassProfile bufferClazz) {
if (direct.bitmapBuffer == null) {
if (forceCreation) {
direct.bitmapBuffer = direct.initializeBitmapBuffer();
} else {
return null;
}
}
return bufferClazz.profile(node, direct.bitmapBuffer);
}
}
static final class PutNode extends Node {
private static final PutNode UNCACHED =
new PutNode(DataBufferNode.getUncached(), BitmapBufferNode.getUncached());
private @Child DataBufferNode dataBuffer;
private @Child BitmapBufferNode bitmapBuffer;
private PutNode(DataBufferNode dbn, BitmapBufferNode bbn) {
this.dataBuffer = dbn;
this.bitmapBuffer = bbn;
}
@NeverDefault
static PutNode create() {
return new PutNode(DataBufferNode.create(), BitmapBufferNode.create());
}
@NeverDefault
static PutNode getUncached() {
return UNCACHED;
}
final void put(ByteBufferDirect direct, byte b) {
var db = dataBuffer.executeDataBuffer(direct);
addValidityBitmap(direct, db.position(), 1);
db.put(b);
}
final void putNull(ByteBufferDirect direct, LogicalLayout unit) {
var db = dataBuffer.executeDataBuffer(direct);
var index = db.position() / unit.sizeInBytes();
var bb = bitmapBuffer.executeBitmapBuffer(direct, true);
var bufferIndex = index >> 3;
var slot = bb.get(bufferIndex);
var byteIndex = index & BYTE_MASK;
var mask = ~(1 << byteIndex);
bb.put(bufferIndex, (byte) (slot & mask));
db.position(db.position() + unit.sizeInBytes());
}
final void putShort(ByteBufferDirect direct, short value) {
var db = dataBuffer.executeDataBuffer(direct);
addValidityBitmap(direct, db.position(), 2);
db.putShort(value);
}
final void putInt(ByteBufferDirect direct, int value) {
var db = dataBuffer.executeDataBuffer(direct);
addValidityBitmap(direct, db.position(), 4);
db.putInt(value);
}
final void putLong(ByteBufferDirect direct, long value) throws UnsupportedMessageException {
var db = dataBuffer.executeDataBuffer(direct);
addValidityBitmap(direct, db.position(), 8);
db.putLong(value);
}
private void addValidityBitmap(ByteBufferDirect direct, int pos, int size) {
var bb = bitmapBuffer.executeBitmapBuffer(direct, false);
if (bb == null) {
return;
}
var index = pos / size;
var bufferIndex = index >> 3;
var slot = bb.get(bufferIndex);
var byteIndex = index & BYTE_MASK;
var mask = 1 << byteIndex;
var updated = (slot | mask);
bb.put(bufferIndex, (byte) (updated));
}
}
public byte get(int index) throws UnsupportedMessageException {
return dataBuffer.get(index);
}
public void put(int index, byte b) throws UnsupportedMessageException {
setValidityBitmap(index, 1);
dataBuffer.put(index, b);
}
public void putShort(short value) throws UnsupportedMessageException {
setValidityBitmap(0, 2);
dataBuffer.putShort(value);
}
public short getShort(int index) throws UnsupportedMessageException {
return dataBuffer.getShort(index);
}
public void putShort(int index, short value) throws UnsupportedMessageException {
setValidityBitmap(index, 2);
dataBuffer.putShort(index, value);
}
public void putInt(int value) throws UnsupportedMessageException {
setValidityBitmap(0, 4);
dataBuffer.putInt(value);
}
public int getInt(int index) throws UnsupportedMessageException {
return dataBuffer.getInt(index);
}
public void putInt(int index, int value) {
setValidityBitmap(index, 4);
dataBuffer.putInt(index, value);
}
public void putLong(long value) throws UnsupportedMessageException {
setValidityBitmap(0, 8);
dataBuffer.putLong(value);
}
public long getLong(int index) throws UnsupportedMessageException {
return dataBuffer.getLong(index);
}
public void putLong(int index, long value) {
setValidityBitmap(index, 8);
dataBuffer.putLong(index, value);
}
public void putFloat(float value) throws UnsupportedMessageException {
setValidityBitmap(0, 4);
dataBuffer.putFloat(value);
}
public float getFloat(int index) throws UnsupportedMessageException {
return dataBuffer.getFloat(index);
}
public void putFloat(int index, float value) throws UnsupportedMessageException {
setValidityBitmap(index, 4);
dataBuffer.putFloat(index, value);
}
public void putDouble(double value) throws UnsupportedMessageException {
setValidityBitmap(0, 8);
dataBuffer.putDouble(value);
}
public double getDouble(int index) throws UnsupportedMessageException {
return dataBuffer.getDouble(index);
}
public void putDouble(int index, double value) throws UnsupportedMessageException {
setValidityBitmap(index, 8);
dataBuffer.putDouble(index, value);
public long getLong(int index, Node node, InlinedExactClassProfile profile)
throws UnsupportedMessageException {
var buf = profile.profile(node, dataBuffer);
return buf.getLong(index);
}
public int capacity() throws UnsupportedMessageException {
return dataBuffer.capacity();
}
boolean hasNulls() {
return bitmapBuffer != null;
}
public boolean isNull(int index) {
if (bitmapBuffer == null) {
return false;
}
return checkForNull(index);
}
private boolean checkForNull(int index) {
var bufferIndex = index >> 3;
var slot = bitmapBuffer.get(bufferIndex);
var byteIndex = index & ~(1 << 3);
var byteIndex = index & BYTE_MASK;
var mask = 1 << byteIndex;
return (slot & mask) == 0;
}
public void setNull(int index) {
var bufferIndex = index >> 3;
var slot = bitmapBuffer.get(bufferIndex);
var byteIndex = index & ~(1 << 3);
var mask = ~(1 << byteIndex);
bitmapBuffer.put(bufferIndex, (byte) (slot & mask));
}
private void setValidityBitmap(int index0, int unitSize) {
var index = index0 / unitSize;
var bufferIndex = index >> 3;
var slot = bitmapBuffer.get(bufferIndex);
var byteIndex = index & ~(1 << 3);
var mask = 1 << byteIndex;
var updated = (slot | mask);
bitmapBuffer.put(bufferIndex, (byte) (updated));
}
private static final int BYTE_MASK = ~(~(1 << 3) + 1); // 7
@Override
public void close() throws Exception {

View File

@ -0,0 +1,69 @@
package org.enso.interpreter.arrow.runtime;
import com.oracle.truffle.api.dsl.Cached.Shared;
import com.oracle.truffle.api.dsl.GenerateInline;
import com.oracle.truffle.api.dsl.GenerateUncached;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.interop.InteropLibrary;
import com.oracle.truffle.api.interop.UnsupportedMessageException;
import com.oracle.truffle.api.library.CachedLibrary;
@GenerateUncached
@GenerateInline(false)
abstract class OperationPlus extends ScalarOperationNode {
@Override
abstract Object executeOp(Object a, Object b) throws UnsupportedMessageException;
static OperationPlus create() {
return OperationPlusNodeGen.create();
}
static OperationPlus getUncached() {
return OperationPlusNodeGen.getUncached();
}
@Specialization(rewriteOn = ArithmeticException.class)
long doLongs(long a, long b) {
return Math.addExact(a, b);
}
@Specialization(replaces = "doLongs")
Object doLongsWithOverflowCheck(long a, long b) {
long res = a + b;
long check1 = a ^ res;
long check2 = b ^ res;
long checkBoth = check1 & check2;
if (checkBoth < 0) {
return NullValue.get();
}
return res;
}
@Specialization(
guards = {"iop.fitsInLong(a)", "iop.fitsInLong(b)"},
rewriteOn = ArithmeticException.class)
Object doFitInLong(
Object a, Object b, @Shared("iop") @CachedLibrary(limit = "3") InteropLibrary iop)
throws UnsupportedMessageException {
var la = iop.asLong(a);
var lb = iop.asLong(b);
return doLongs(la, lb);
}
@Specialization(
guards = {"iop.fitsInLong(a)", "iop.fitsInLong(b)"},
replaces = "doFitInLong")
Object doFitInLongWithOverflowCheck(
Object a, Object b, @Shared("iop") @CachedLibrary(limit = "3") InteropLibrary iop)
throws UnsupportedMessageException {
var la = iop.asLong(a);
var lb = iop.asLong(b);
return doLongsWithOverflowCheck(la, lb);
}
@Specialization(guards = {"iop.isNull(a) || iop.isNull(b)"})
NullValue nothing(
Object a, Object b, @Shared("iop") @CachedLibrary(limit = "3") InteropLibrary iop) {
return NullValue.get();
}
}

View File

@ -19,16 +19,16 @@
package org.enso.interpreter.arrow.runtime;
class RoundingUtil {
final class RoundingUtil {
/** The mask for rounding an integer to a multiple of 8. (i.e. clear the lowest 3 bits) */
static int ROUND_8_MASK_INT = 0xFFFFFFF8;
static final int ROUND_8_MASK_INT = 0xFFFFFFF8;
/** The mask for rounding a long integer to a multiple of 8. (i.e. clear the lowest 3 bits) */
static long ROUND_8_MASK_LONG = 0xFFFFFFFFFFFFFFF8L;
static final long ROUND_8_MASK_LONG = 0xFFFFFFFFFFFFFFF8L;
/** The number of bits to shift for dividing by 8. */
static int DIVIDE_BY_8_SHIFT_BITS = 3;
static final int DIVIDE_BY_8_SHIFT_BITS = 3;
private RoundingUtil() {}
@ -91,13 +91,15 @@ class RoundingUtil {
return (int) (dataBufferSize + validityBitmapSize);
}
private long validityBitmapSize;
private long dataBufferSize;
private final long validityBitmapSize;
private final long dataBufferSize;
private PaddedSize(int valueCount, SizeInBytes unit) {
this.valueCount = valueCount;
this.unit = unit;
computeBufferSize(valueCount, unit);
var pair = computeBufferSize(valueCount, unit);
this.validityBitmapSize = pair[0];
this.dataBufferSize = pair[1];
}
private long defaultRoundedSize(long val) {
@ -127,7 +129,7 @@ class RoundingUtil {
return defaultRoundedSize(bufferSize);
}
private void computeBufferSize(int valueCount, SizeInBytes unit) {
private long[] computeBufferSize(int valueCount, SizeInBytes unit) {
var typeWidth = unit.sizeInBytes();
long bufferSize = computeCombinedBufferSize(valueCount, typeWidth);
assert bufferSize <= Long.MAX_VALUE;
@ -149,8 +151,7 @@ class RoundingUtil {
--actualCount;
} while (true);
}
this.validityBitmapSize = validityBufferSize;
this.dataBufferSize = dataBufferSize;
return new long[] {validityBufferSize, dataBufferSize};
}
}
}

View File

@ -0,0 +1,10 @@
package org.enso.interpreter.arrow.runtime;
import com.oracle.truffle.api.dsl.GenerateInline;
import com.oracle.truffle.api.interop.UnsupportedMessageException;
import com.oracle.truffle.api.nodes.Node;
@GenerateInline(false)
abstract class ScalarOperationNode extends Node {
abstract Object executeOp(Object a, Object b) throws UnsupportedMessageException;
}

View File

@ -1,7 +1,13 @@
package org.enso.interpreter.arrow.runtime;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.dsl.*;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Fallback;
import com.oracle.truffle.api.dsl.GenerateInline;
import com.oracle.truffle.api.dsl.GenerateUncached;
import com.oracle.truffle.api.dsl.ImportStatic;
import com.oracle.truffle.api.dsl.NeverDefault;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.interop.InteropLibrary;
import com.oracle.truffle.api.interop.UnsupportedMessageException;
import com.oracle.truffle.api.interop.UnsupportedTypeException;
@ -17,58 +23,56 @@ import org.enso.interpreter.arrow.LogicalLayout;
@ImportStatic(LogicalLayout.class)
@GenerateUncached
@GenerateInline(value = false)
abstract class WriteToBuilderNode extends Node {
public abstract void executeWrite(ArrowFixedSizeArrayBuilder receiver, long index, Object value)
throws UnsupportedTypeException;
abstract class ValueToNumberNode extends Node {
/**
* Converts {@code value} to a suitable representation to be stored in an appropriate
* DirectBuffer.
*
* @param unit type of layout
* @param value a value to convert
* @return byte, short, int or long
* @throws UnsupportedTypeException if the conversion isn't possible
*/
abstract Number executeAdjust(LogicalLayout unit, Object value) throws UnsupportedTypeException;
@NeverDefault
static WriteToBuilderNode build() {
return WriteToBuilderNodeGen.create();
static ValueToNumberNode build() {
return ValueToNumberNodeGen.create();
}
@Specialization(guards = "receiver.getUnit() == Date32")
void doWriteDay(
ArrowFixedSizeArrayBuilder receiver,
long index,
@NeverDefault
static ValueToNumberNode getUncached() {
return ValueToNumberNodeGen.getUncached();
}
@Specialization(guards = "unit == Date32")
Integer doDay(
LogicalLayout unit,
Object value,
@Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop)
throws UnsupportedTypeException {
validAccess(receiver, index);
if (iop.isNull(value)) {
receiver.getBuffer().setNull((int) index);
return;
}
if (!iop.isDate(value)) {
throw UnsupportedTypeException.create(new Object[] {value}, "value is not a date");
}
var at = ArrowFixedArrayDate.typeAdjustedIndex(index, 4);
long time;
try {
time = iop.asDate(value).toEpochDay();
} catch (UnsupportedMessageException e) {
throw UnsupportedTypeException.create(new Object[] {value}, "value is not a date");
}
receiver.getBuffer().putInt(at, Math.toIntExact(time));
return Math.toIntExact(time);
}
@Specialization(guards = {"receiver.getUnit() == Date64"})
void doWriteMilliseconds(
ArrowFixedSizeArrayBuilder receiver,
long index,
@Specialization(guards = {"unit == Date64"})
Long doMilliseconds(
LogicalLayout unit,
Object value,
@Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop)
throws UnsupportedTypeException {
validAccess(receiver, index);
if (iop.isNull(value)) {
receiver.getBuffer().setNull((int) index);
return;
}
if (!iop.isDate(value) || !iop.isTime(value)) {
throw UnsupportedTypeException.create(new Object[] {value}, "value is not a date and a time");
}
var at = ArrowFixedArrayDate.typeAdjustedIndex(index, 8);
if (iop.isTimeZone(value)) {
Instant zoneDateTimeInstant;
try {
@ -84,7 +88,7 @@ abstract class WriteToBuilderNode extends Node {
var secondsPlusNano =
zoneDateTimeInstant.getEpochSecond() * ArrowFixedArrayDate.NANO_DIV
+ zoneDateTimeInstant.getNano();
receiver.getBuffer().putLong(at, secondsPlusNano);
return secondsPlusNano;
} else {
Instant dateTime;
try {
@ -94,7 +98,7 @@ abstract class WriteToBuilderNode extends Node {
}
var secondsPlusNano =
dateTime.getEpochSecond() * ArrowFixedArrayDate.NANO_DIV + dateTime.getNano();
receiver.getBuffer().putLong(at, secondsPlusNano);
return secondsPlusNano;
}
}
@ -109,109 +113,75 @@ abstract class WriteToBuilderNode extends Node {
return date.atTime(time).toInstant(offset);
}
@Specialization(guards = "receiver.getUnit() == Int8")
void doWriteByte(
ArrowFixedSizeArrayBuilder receiver,
long index,
@Specialization(guards = "unit == Int8")
Byte doByte(
LogicalLayout unit,
Object value,
@Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop)
throws UnsupportedTypeException {
validAccess(receiver, index);
if (iop.isNull(value)) {
receiver.getBuffer().setNull((int) index);
return;
}
if (!iop.fitsInByte(value)) {
throw UnsupportedTypeException.create(new Object[] {value}, "value does not fit a byte");
}
try {
receiver.getBuffer().put(typeAdjustedIndex(index, receiver.getUnit()), (iop.asByte(value)));
return iop.asByte(value);
} catch (UnsupportedMessageException e) {
throw UnsupportedTypeException.create(new Object[] {value}, "value is not a byte");
}
}
@Specialization(guards = "receiver.getUnit() == Int16")
void doWriteShort(
ArrowFixedSizeArrayBuilder receiver,
long index,
@Specialization(guards = "unit == Int16")
Short doShort(
LogicalLayout unit,
Object value,
@Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop)
throws UnsupportedTypeException {
validAccess(receiver, index);
if (iop.isNull(value)) {
receiver.getBuffer().setNull((int) index);
return;
}
if (!iop.fitsInShort(value)) {
throw UnsupportedTypeException.create(
new Object[] {value}, "value does not fit a 2 byte short");
}
try {
receiver
.getBuffer()
.putShort(typeAdjustedIndex(index, receiver.getUnit()), (iop.asShort(value)));
return iop.asShort(value);
} catch (UnsupportedMessageException e) {
throw UnsupportedTypeException.create(new Object[] {value}, "value is not a short");
}
}
@Specialization(guards = "receiver.getUnit() == Int32")
void doWriteInt(
ArrowFixedSizeArrayBuilder receiver,
long index,
@Specialization(guards = "unit == Int32")
Integer doInt(
LogicalLayout unit,
int value,
@Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop)
throws UnsupportedTypeException {
validAccess(receiver, index);
if (iop.isNull(value)) {
receiver.getBuffer().setNull((int) index);
return;
}
if (!iop.fitsInInt(value)) {
throw UnsupportedTypeException.create(
new Object[] {value}, "value does not fit a 4 byte int");
}
try {
receiver.getBuffer().putInt(typeAdjustedIndex(index, receiver.getUnit()), (iop.asInt(value)));
return iop.asInt(value);
} catch (UnsupportedMessageException e) {
throw UnsupportedTypeException.create(new Object[] {value}, "value is not an int");
}
}
@Specialization(guards = "receiver.getUnit() == Int64")
public static void doWriteLong(
ArrowFixedSizeArrayBuilder receiver,
long index,
long value,
@Specialization(guards = "unit == Int64")
static Long doLong(
LogicalLayout unit,
Object value,
@Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop)
throws UnsupportedTypeException {
validAccess(receiver, index);
if (iop.isNull(value)) {
receiver.getBuffer().setNull((int) index);
return;
if (!iop.fitsInLong(value)) {
throw UnsupportedTypeException.create(
new Object[] {value}, "value does not fit a 8 byte int");
}
try {
return iop.asLong(value);
} catch (UnsupportedMessageException e) {
throw UnsupportedTypeException.create(new Object[] {value}, "value is not a long");
}
receiver.getBuffer().putLong(typeAdjustedIndex(index, receiver.getUnit()), value);
}
@Fallback
void doWriteOther(ArrowFixedSizeArrayBuilder receiver, long index, Object value)
throws UnsupportedTypeException {
throw UnsupportedTypeException.create(new Object[] {index, value}, "unknown type of receiver");
}
private static void validAccess(ArrowFixedSizeArrayBuilder receiver, long index)
throws UnsupportedTypeException {
if (receiver.isSealed()) {
throw UnsupportedTypeException.create(
new Object[] {receiver}, "receiver is not an unsealed buffer");
}
if (index >= receiver.getSize() || index < 0) {
throw UnsupportedTypeException.create(new Object[] {index}, "index is out of range");
}
}
private static int typeAdjustedIndex(long index, SizeInBytes unit) {
return ArrowFixedArrayDate.typeAdjustedIndex(index, unit.sizeInBytes());
Number doOther(LogicalLayout unit, Object value) throws UnsupportedTypeException {
throw UnsupportedTypeException.create(new Object[] {unit, value}, "unknown type");
}
}

View File

@ -0,0 +1,142 @@
package org.enso.interpreter.arrow;
import static org.junit.Assert.*;
import org.graalvm.polyglot.Context;
import org.graalvm.polyglot.io.IOAccess;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
public class AddArrowTest {
private static Context ctx;
@BeforeClass
public static void initEnsoContext() {
ctx =
Context.newBuilder()
.allowExperimentalOptions(true)
.allowIO(IOAccess.ALL)
.out(System.out)
.err(System.err)
.allowAllAccess(true)
.build();
}
@AfterClass
public static void closeEnsoContext() throws Exception {
if (ctx != null) {
ctx.close();
}
}
@Test
public void addTwoInt8ArrowArrays() {
var arrow = ctx.getEngine().getLanguages().get("arrow");
assertNotNull("Arrow is available", arrow);
var int8Constr = ctx.eval("arrow", "new[Int8]");
assertNotNull(int8Constr);
var arrLength = 10;
var builder1 = int8Constr.newInstance(arrLength);
var builder2 = int8Constr.newInstance(arrLength);
for (var i = 0; i < arrLength; i++) {
var ni = arrLength - i - 1;
var v = i * i;
builder1.invokeMember("append", i, (byte) v);
builder2.invokeMember("append", ni, (byte) v);
}
var arr1 = builder1.invokeMember("build");
assertEquals("Right size of arr1", arrLength, arr1.getArraySize());
var arr2 = builder2.invokeMember("build");
assertEquals("Right size of arr2", arrLength, arr2.getArraySize());
var int8Plus = ctx.eval("arrow", "+[Int8]");
var resultArr = int8Plus.execute(arr1, arr2);
assertTrue("Result is an array", resultArr.hasArrayElements());
assertEquals("Right size", arrLength, resultArr.getArraySize());
for (var i = 0; i < arrLength; i++) {
var ni = arrLength - i - 1;
var v1 = resultArr.getArrayElement(i).asLong();
var v2 = resultArr.getArrayElement(ni).asLong();
assertEquals("Values at " + i + " and " + ni + " are the same", v1, v2);
assertTrue("Values are always bigger than zero: " + v1, v1 > 0);
}
}
@Test
public void addTwoInt64ArrowArraysWithNulls() {
var arrow = ctx.getEngine().getLanguages().get("arrow");
assertNotNull("Arrow is available", arrow);
var constr = ctx.eval("arrow", "new[Int64]");
assertNotNull(constr);
var arrLength = 10;
var builder1 = constr.newInstance(arrLength);
for (int i = 0; i < arrLength; i++) {
if (i % 7 < 2) {
builder1.invokeMember("append", Long.MAX_VALUE);
} else {
builder1.invokeMember("append", i);
}
}
var builder2 = constr.newInstance(arrLength);
for (var i = 0; i < arrLength; i++) {
builder2.invokeMember("append", 10 + i);
}
var arr1 = builder1.invokeMember("build");
assertEquals("Right size of arr1", arrLength, arr1.getArraySize());
var addArr = builder2.invokeMember("build");
assertEquals("Right size of arr2", arrLength, addArr.getArraySize());
var plus = ctx.eval("arrow", "+[Int64]");
var res1 = plus.execute(arr1, addArr);
assertTrue("Result is an array", res1.hasArrayElements());
assertEquals("Right size", arrLength, res1.getArraySize());
assertTrue("is null", res1.getArrayElement(0).isNull());
assertTrue("is null", res1.getArrayElement(1).isNull());
assertTrue("is null", res1.getArrayElement(7).isNull());
assertTrue("is null", res1.getArrayElement(8).isNull());
var countNulls = 0;
for (var i = 0; i < arrLength; i++) {
var v = res1.getArrayElement(i);
if (v.isNull()) {
countNulls++;
} else {
assertEquals(i * 2 + 10, v.asLong());
}
}
assertEquals("Four nulls", 4, countNulls);
var res2 = plus.execute(res1, addArr);
assertTrue("Result is an array", res2.hasArrayElements());
assertEquals("Right size", arrLength, res2.getArraySize());
assertTrue("is null", res2.getArrayElement(0).isNull());
assertTrue("is null", res2.getArrayElement(1).isNull());
assertTrue("is null", res2.getArrayElement(7).isNull());
assertTrue("is null", res2.getArrayElement(8).isNull());
var countNullsAgain = 0;
for (var i = 0; i < arrLength; i++) {
var v = res2.getArrayElement(i);
if (v.isNull()) {
countNullsAgain++;
} else {
assertEquals(i * 3 + 20, v.asLong());
}
}
assertEquals("Four nulls", 4, countNullsAgain);
}
}

View File

@ -19,6 +19,7 @@ import org.apache.arrow.vector.BaseFixedWidthVector;
import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.IntVector;
import org.graalvm.polyglot.Context;
import org.graalvm.polyglot.PolyglotException;
import org.graalvm.polyglot.Value;
import org.graalvm.polyglot.io.IOAccess;
import org.junit.AfterClass;
@ -84,7 +85,7 @@ public class VerifyArrowTest {
date32ArrayBuilder.invokeMember("build");
assertFalse(date32ArrayBuilder.canInvokeMember("append"));
assertThrows(
UnsupportedOperationException.class,
PolyglotException.class,
() -> finalDate32ArrayBuilder.invokeMember("append", startDateTime));
assertFalse(date32Array.canInvokeMember("append"));
}
@ -153,6 +154,33 @@ public class VerifyArrowTest {
assertEquals((byte) 5, v.asByte());
}
@Test
public void arrowInt64() {
var arrow = ctx.getEngine().getLanguages().get("arrow");
assertNotNull("Arrow is available", arrow);
var constr = ctx.eval("arrow", "new[Int64]");
assertNotNull(constr);
var arrLength = 48;
Value builder = constr.newInstance(arrLength);
for (var i = 0; i < arrLength; i++) {
builder.invokeMember("append", i);
}
var arr = builder.invokeMember("build");
assertEquals(arrLength, arr.getArraySize());
for (var i = 0; i < arrLength; i++) {
var ith = arr.getArrayElement(i);
assertEquals("Checking value at " + i, i, ith.asLong());
}
var plus = ctx.eval("arrow", "+[Int64]");
var doubled = plus.execute(arr, arr);
for (var i = 0; i < arrLength; i++) {
var ith = doubled.getArrayElement(i);
assertEquals("Checking double value at " + i, 2 * i, ith.asInt());
}
}
@Test
public void castInt() {
var typeLength = LogicalLayout.Int32;

View File

@ -9,7 +9,6 @@ import org.enso.interpreter.node.ExpressionNode;
import org.enso.interpreter.runtime.callable.argument.CallArgument;
import org.enso.interpreter.runtime.callable.argument.CallArgumentInfo;
import org.enso.interpreter.runtime.callable.function.Function;
import org.enso.interpreter.runtime.state.State;
/**
* This node is responsible for organising callable calls so that they are ready to be made.
@ -92,10 +91,10 @@ public class ApplicationNode extends ExpressionNode {
*/
@Override
public Object executeGeneric(VirtualFrame frame) {
State state = Function.ArgumentsHelper.getState(frame.getArguments());
Object[] evaluatedArguments = evaluateArguments(frame);
return this.invokeCallableNode.execute(
this.callable.executeGeneric(frame), frame, state, evaluatedArguments);
var state = Function.ArgumentsHelper.getState(frame.getArguments());
var evaluatedArguments = evaluateArguments(frame);
var self = this.callable.executeGeneric(frame);
return this.invokeCallableNode.execute(self, frame, state, evaluatedArguments);
}
/**

View File

@ -1,6 +1,7 @@
package org.enso.interpreter.node.callable;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.dsl.Bind;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Cached.Shared;
import com.oracle.truffle.api.dsl.Fallback;
@ -12,6 +13,7 @@ import com.oracle.truffle.api.interop.UnsupportedMessageException;
import com.oracle.truffle.api.interop.UnsupportedTypeException;
import com.oracle.truffle.api.library.CachedLibrary;
import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.profiles.InlinedBranchProfile;
import com.oracle.truffle.api.source.SourceSection;
import java.util.UUID;
import java.util.concurrent.locks.Lock;
@ -337,32 +339,37 @@ public abstract class InvokeCallableNode extends BaseNode {
"!types.hasSpecialDispatch(self)",
"iop.isExecutable(self)",
})
Object doPolyglot(
static Object doPolyglot(
Object self,
VirtualFrame frame,
State state,
Object[] arguments,
@Bind("$node") Node node,
@CachedLibrary(limit = "3") InteropLibrary iop,
@Shared("warnings") @CachedLibrary(limit = "3") WarningsLibrary warnings,
@CachedLibrary(limit = "3") TypesLibrary types,
@Cached ThunkExecutorNode thunkNode) {
var errors = EnsoContext.get(this).getBuiltins().error();
@Cached ThunkExecutorNode thunkNode,
@Cached InlinedBranchProfile errorNeedsToBeReported) {
var errors = EnsoContext.get(node).getBuiltins().error();
try {
for (int i = 0; i < arguments.length; i++) {
arguments[i] = thunkNode.executeThunk(frame, arguments[i], state, TailStatus.NOT_TAIL);
}
return iop.execute(self, arguments);
} catch (UnsupportedTypeException ex) {
errorNeedsToBeReported.enter(node);
var err = errors.makeUnsupportedArgumentsError(ex.getSuppliedValues(), ex.getMessage());
throw new PanicException(err, this);
throw new PanicException(err, node);
} catch (ArityException ex) {
errorNeedsToBeReported.enter(node);
var err =
errors.makeArityError(
ex.getExpectedMinArity(), ex.getExpectedMaxArity(), arguments.length);
throw new PanicException(err, this);
throw new PanicException(err, node);
} catch (UnsupportedMessageException ex) {
errorNeedsToBeReported.enter(node);
var err = errors.makeNotInvokable(self);
throw new PanicException(err, this);
throw new PanicException(err, node);
}
}

View File

@ -8,45 +8,118 @@ polyglot java import java.lang.Long as Java_Long
options = Bench.options . set_warmup (Bench.phase_conf 3 5) . set_measure (Bench.phase_conf 3 5)
create_table : Table
create_table num_rows =
create_vectors num_rows =
x = Vector.new num_rows i->
i+1
y = Vector.new num_rows i->
if i % 10 < 2 then Java_Long.MAX_VALUE else i+1
u = Vector.new num_rows i->
10 + (i % 100)
z = Vector.new num_rows i->
if i % 10 < 2 then Nothing else i+1
t = Table.new [["X", x], ["Y", y], ["U", u]]
[x, y, u, z]
assert condition =
if condition.not then Panic.throw "Assertion failed"
create_table : Table
create_table num_rows =
v = create_vectors num_rows
x = v.at 0
y = v.at 1
u = v.at 2
z = v.at 3
assert ((t.at "X" . value_type) == Value_Type.Integer)
assert ((t.at "Y" . value_type) == Value_Type.Integer)
assert ((t.at "U" . value_type) == Value_Type.Integer)
t = Table.new [["X", x], ["Y", y], ["U", u], ["Z", z]]
Runtime.assert ((t.at "X" . value_type) == Value_Type.Integer)
Runtime.assert ((t.at "Y" . value_type) == Value_Type.Integer)
Runtime.assert ((t.at "U" . value_type) == Value_Type.Integer)
Runtime.assert ((t.at "Z" . value_type) == Value_Type.Integer)
t
create_arrow_columns num_rows =
column_to_arrow v:Vector -> Array =
builder = int64_new.new v.length
v.map e-> builder.append e
builder.build
v = create_vectors num_rows
x = column_to_arrow (v.at 0)
y = column_to_arrow (v.at 1)
u = column_to_arrow (v.at 2)
z = column_to_arrow (v.at 3)
[int64_plus, x, y, u, z]
foreign arrow int64_new = """
new[Int64]
foreign arrow int64_plus = """
+[Int64]
type Data
Value ~table
private Value ~table ~arrow
create num_rows = Data.Value (create_table num_rows)
arrow_plus self = self.arrow.at 0
arrow_x self = self.arrow.at 1
arrow_y self = self.arrow.at 2
arrow_u self = self.arrow.at 3
arrow_z self = self.arrow.at 4
create num_rows = Data.Value (create_table num_rows) (create_arrow_columns num_rows)
collect_benches = Bench.build builder->
column_arithmetic_plus_fitting d =
(d.table.at "X") + (d.table.at "U")
column_arithmetic_plus_overflowing d =
(d.table.at "Y") + (d.table.at "U")
column_arithmetic_plus_nothing d =
(d.table.at "Z") + (d.table.at "U")
column_arithmetic_multiply_fitting d =
(d.table.at "X") * (d.table.at "U")
column_arithmetic_multiply_overflowing d =
(d.table.at "Y") * (d.table.at "U")
arrow_arithmetic_plus_fitting d =
d.arrow_plus d.arrow_x d.arrow_u
arrow_arithmetic_plus_overflowing d =
d.arrow_plus d.arrow_y d.arrow_u
arrow_arithmetic_plus_nothing d =
d.arrow_plus d.arrow_y d.arrow_u
num_rows = 1000000
data = Data.create num_rows
Runtime.assert ((column_arithmetic_plus_fitting data . to_vector) == (arrow_arithmetic_plus_fitting data)) "Column and arrow correctness check one"
Runtime.assert ((column_arithmetic_plus_overflowing data . to_vector) == (arrow_arithmetic_plus_overflowing data)) "Column and arrow correctness check two"
Runtime.assert ((column_arithmetic_plus_nothing data . to_vector) == (arrow_arithmetic_plus_nothing data)) "Column and arrow correctness check three"
builder.group ("Column_Arithmetic_" + num_rows.to_text) options group_builder->
group_builder.specify "Plus_Fitting" <|
(data.table.at "X") + (data.table.at "U")
column_arithmetic_plus_fitting data
group_builder.specify "Plus_Overflowing" <|
(data.table.at "Y") + (data.table.at "U")
column_arithmetic_plus_overflowing data
group_builder.specify "Plus_Nothing" <|
column_arithmetic_plus_nothing data
group_builder.specify "Multiply_Fitting" <|
(data.table.at "X") * (data.table.at "U")
column_arithmetic_multiply_fitting data
group_builder.specify "Multiply_Overflowing" <|
(data.table.at "Y") * (data.table.at "U")
column_arithmetic_multiply_overflowing data
builder.group ("Arrow_Arithmetic_" + num_rows.to_text) options group_builder->
group_builder.specify "Plus_Fitting" <|
arrow_arithmetic_plus_fitting data
group_builder.specify "Plus_Overflowing" <|
arrow_arithmetic_plus_overflowing data
group_builder.specify "Plus_Nothing" <|
arrow_arithmetic_plus_nothing data
main = collect_benches . run_main