mirror of
https://github.com/enso-org/enso.git
synced 2024-10-26 13:14:43 +03:00
Implement and benchmark ArrowOperationPlus
node (#10150)
Prototype of #10056 showing `+` operation implemented in the _Arrow language_.
This commit is contained in:
parent
19c50ceff9
commit
aaaebcabf8
@ -11,36 +11,42 @@ public final class ArrowParser {
|
||||
public record Result(PhysicalLayout physicalLayout, LogicalLayout logicalLayout, Mode mode) {}
|
||||
|
||||
public static Result parse(Source source) {
|
||||
String src = source.getCharacters().toString();
|
||||
Matcher m = NEW_ARRAY_CONSTR.matcher(src);
|
||||
String src = source.getCharacters().toString().replace('\n', ' ').trim();
|
||||
Matcher m = PATTERN.matcher(src);
|
||||
if (m.find()) {
|
||||
try {
|
||||
var layout = LogicalLayout.valueOf(m.group(1));
|
||||
return new Result(PhysicalLayout.Primitive, layout, Mode.Allocate);
|
||||
var layout = LogicalLayout.valueOf(m.group(2));
|
||||
var mode = Mode.parse(m.group(1));
|
||||
if (layout != null && mode != null) {
|
||||
return new Result(PhysicalLayout.Primitive, layout, mode);
|
||||
}
|
||||
} catch (IllegalArgumentException iae) {
|
||||
// propagate warning
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
m = CAST_PATTERN.matcher(src);
|
||||
if (m.find()) {
|
||||
try {
|
||||
var layout = LogicalLayout.valueOf(m.group(1));
|
||||
return new Result(PhysicalLayout.Primitive, layout, Mode.Cast);
|
||||
} catch (IllegalArgumentException iae) {
|
||||
// propagate warning
|
||||
return null;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private static final Pattern NEW_ARRAY_CONSTR = Pattern.compile("^new\\[(.+)\\]$");
|
||||
private static final Pattern CAST_PATTERN = Pattern.compile("^cast\\[(.+)\\]$");
|
||||
private static final Pattern PATTERN = Pattern.compile("^([a-z\\+]+)\\[(.+)\\]$");
|
||||
|
||||
public enum Mode {
|
||||
Allocate,
|
||||
Cast
|
||||
Allocate("new"),
|
||||
Cast("cast"),
|
||||
Plus("+");
|
||||
|
||||
private final String op;
|
||||
|
||||
private Mode(String text) {
|
||||
this.op = text;
|
||||
}
|
||||
|
||||
static Mode parse(String operation) {
|
||||
for (var m : values()) {
|
||||
if (m.op.equals(operation)) {
|
||||
return m;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,16 +0,0 @@
|
||||
package org.enso.interpreter.arrow.node;
|
||||
|
||||
import com.oracle.truffle.api.nodes.Node;
|
||||
import org.enso.interpreter.arrow.LogicalLayout;
|
||||
import org.enso.interpreter.arrow.runtime.ArrowCastToFixedSizeArrayFactory;
|
||||
|
||||
public class ArrowCastFixedSizeNode extends Node {
|
||||
|
||||
static ArrowCastFixedSizeNode create() {
|
||||
return new ArrowCastFixedSizeNode();
|
||||
}
|
||||
|
||||
public Object execute(LogicalLayout layoutType) {
|
||||
return new ArrowCastToFixedSizeArrayFactory(layoutType);
|
||||
}
|
||||
}
|
@ -6,13 +6,13 @@ import com.oracle.truffle.api.frame.VirtualFrame;
|
||||
import com.oracle.truffle.api.nodes.RootNode;
|
||||
import org.enso.interpreter.arrow.ArrowLanguage;
|
||||
import org.enso.interpreter.arrow.ArrowParser;
|
||||
import org.enso.interpreter.arrow.runtime.ArrowCastToFixedSizeArrayFactory;
|
||||
import org.enso.interpreter.arrow.runtime.ArrowFixedSizeArrayFactory;
|
||||
import org.enso.interpreter.arrow.runtime.ArrowOperationPlus;
|
||||
|
||||
public class ArrowEvalNode extends RootNode {
|
||||
private final ArrowParser.Result code;
|
||||
|
||||
@Child private ArrowFixedSizeNode fixedPhysicalLayout = ArrowFixedSizeNode.create();
|
||||
@Child private ArrowCastFixedSizeNode castToFixedPhysicalLayout = ArrowCastFixedSizeNode.create();
|
||||
|
||||
public static ArrowEvalNode create(ArrowLanguage language, ArrowParser.Result code) {
|
||||
return new ArrowEvalNode(language, code);
|
||||
}
|
||||
@ -25,8 +25,10 @@ public class ArrowEvalNode extends RootNode {
|
||||
public Object execute(VirtualFrame frame) {
|
||||
return switch (code.physicalLayout()) {
|
||||
case Primitive -> switch (code.mode()) {
|
||||
case Allocate -> fixedPhysicalLayout.execute(code.logicalLayout());
|
||||
case Cast -> castToFixedPhysicalLayout.execute(code.logicalLayout());
|
||||
case Allocate -> new ArrowFixedSizeArrayFactory(code.logicalLayout());
|
||||
case Cast -> new ArrowCastToFixedSizeArrayFactory(code.logicalLayout());
|
||||
case Plus -> new ArrowOperationPlus(code.logicalLayout());
|
||||
default -> throw CompilerDirectives.shouldNotReachHere("unsupported mode");
|
||||
};
|
||||
default -> throw CompilerDirectives.shouldNotReachHere("unsupported physical layout");
|
||||
};
|
||||
|
@ -1,16 +0,0 @@
|
||||
package org.enso.interpreter.arrow.node;
|
||||
|
||||
import com.oracle.truffle.api.nodes.Node;
|
||||
import org.enso.interpreter.arrow.LogicalLayout;
|
||||
import org.enso.interpreter.arrow.runtime.ArrowFixedSizeArrayFactory;
|
||||
|
||||
public class ArrowFixedSizeNode extends Node {
|
||||
|
||||
static ArrowFixedSizeNode create() {
|
||||
return new ArrowFixedSizeNode();
|
||||
}
|
||||
|
||||
public Object execute(LogicalLayout layoutType) {
|
||||
return new ArrowFixedSizeArrayFactory(layoutType);
|
||||
}
|
||||
}
|
@ -1,13 +1,23 @@
|
||||
package org.enso.interpreter.arrow.runtime;
|
||||
|
||||
import com.oracle.truffle.api.CompilerDirectives;
|
||||
import com.oracle.truffle.api.dsl.Bind;
|
||||
import com.oracle.truffle.api.dsl.Cached;
|
||||
import com.oracle.truffle.api.dsl.ImportStatic;
|
||||
import com.oracle.truffle.api.dsl.NeverDefault;
|
||||
import com.oracle.truffle.api.dsl.Specialization;
|
||||
import com.oracle.truffle.api.interop.InteropLibrary;
|
||||
import com.oracle.truffle.api.interop.InvalidArrayIndexException;
|
||||
import com.oracle.truffle.api.interop.StopIterationException;
|
||||
import com.oracle.truffle.api.interop.TruffleObject;
|
||||
import com.oracle.truffle.api.interop.UnsupportedMessageException;
|
||||
import com.oracle.truffle.api.library.CachedLibrary;
|
||||
import com.oracle.truffle.api.library.ExportLibrary;
|
||||
import com.oracle.truffle.api.library.ExportMessage;
|
||||
import com.oracle.truffle.api.nodes.Node;
|
||||
import com.oracle.truffle.api.profiles.InlinedExactClassProfile;
|
||||
import java.nio.BufferOverflowException;
|
||||
import java.nio.ByteBuffer;
|
||||
import org.enso.interpreter.arrow.LogicalLayout;
|
||||
|
||||
@ExportLibrary(InteropLibrary.class)
|
||||
@ -27,7 +37,24 @@ public final class ArrowFixedArrayInt implements TruffleObject {
|
||||
}
|
||||
|
||||
@ExportMessage
|
||||
public boolean hasArrayElements() {
|
||||
boolean hasArrayElements() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@ExportMessage
|
||||
Object getIterator(
|
||||
@Cached(value = "this.getUnit()", allowUncached = true) LogicalLayout cachedUnit)
|
||||
throws UnsupportedMessageException {
|
||||
if (cachedUnit == LogicalLayout.Int64) {
|
||||
var dataIt = new LongIterator(buffer.getDataBuffer(), cachedUnit.sizeInBytes());
|
||||
var nullIt = new NullIterator(dataIt, buffer.getBitmapBuffer());
|
||||
return nullIt;
|
||||
}
|
||||
return new GenericIterator(this);
|
||||
}
|
||||
|
||||
@ExportMessage
|
||||
boolean hasIterator() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -65,13 +92,18 @@ public final class ArrowFixedArrayInt implements TruffleObject {
|
||||
}
|
||||
|
||||
@Specialization(guards = "receiver.getUnit() == Int64")
|
||||
public static Object doLong(ArrowFixedArrayInt receiver, long index)
|
||||
public static Object doLong(
|
||||
ArrowFixedArrayInt receiver,
|
||||
long index,
|
||||
@Bind("$node") Node node,
|
||||
@CachedLibrary("receiver") InteropLibrary iop,
|
||||
@Cached InlinedExactClassProfile bufferClazz)
|
||||
throws UnsupportedMessageException, InvalidArrayIndexException {
|
||||
var at = adjustedIndex(receiver.buffer, receiver.unit, receiver.size, index);
|
||||
var at = adjustedIndex(receiver.buffer, LogicalLayout.Int64, receiver.size, index);
|
||||
if (receiver.buffer.isNull((int) index)) {
|
||||
return NullValue.get();
|
||||
}
|
||||
return receiver.buffer.getLong(at);
|
||||
return receiver.buffer.getLong(at, iop, bufferClazz);
|
||||
}
|
||||
}
|
||||
|
||||
@ -96,4 +128,133 @@ public final class ArrowFixedArrayInt implements TruffleObject {
|
||||
private static int typeAdjustedIndex(long index, SizeInBytes unit) {
|
||||
return Math.toIntExact(index * unit.sizeInBytes());
|
||||
}
|
||||
|
||||
@ExportLibrary(InteropLibrary.class)
|
||||
static final class LongIterator implements TruffleObject {
|
||||
private int at;
|
||||
private final ByteBuffer buffer;
|
||||
@NeverDefault final int step;
|
||||
|
||||
LongIterator(ByteBuffer buffer, int step) {
|
||||
assert step != 0;
|
||||
this.buffer = buffer;
|
||||
this.step = step;
|
||||
}
|
||||
|
||||
@ExportMessage
|
||||
Object getIteratorNextElement(
|
||||
@Bind("$node") Node node,
|
||||
@Cached("this.step") int step,
|
||||
@Cached InlinedExactClassProfile bufferTypeProfile)
|
||||
throws StopIterationException {
|
||||
var buf = bufferTypeProfile.profile(node, buffer);
|
||||
try {
|
||||
var res = buf.getLong(at);
|
||||
at += step;
|
||||
return res;
|
||||
} catch (BufferOverflowException ex) {
|
||||
CompilerDirectives.transferToInterpreter();
|
||||
throw StopIterationException.create();
|
||||
}
|
||||
}
|
||||
|
||||
@ExportMessage
|
||||
boolean isIterator() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@ExportMessage
|
||||
boolean hasIteratorNextElement() throws UnsupportedMessageException {
|
||||
return at < buffer.limit();
|
||||
}
|
||||
}
|
||||
|
||||
@ExportLibrary(value = InteropLibrary.class)
|
||||
static final class NullIterator implements TruffleObject {
|
||||
private final TruffleObject it;
|
||||
private final ByteBuffer buffer;
|
||||
private byte byteMask;
|
||||
private byte byteValue;
|
||||
|
||||
NullIterator(TruffleObject delegate, ByteBuffer buffer) {
|
||||
this.it = delegate;
|
||||
this.buffer = buffer;
|
||||
}
|
||||
|
||||
final TruffleObject it() {
|
||||
return it;
|
||||
}
|
||||
|
||||
@ExportMessage(limit = "3")
|
||||
Object getIteratorNextElement(
|
||||
@Bind("$node") Node node,
|
||||
@CachedLibrary("this.it()") InteropLibrary iopIt,
|
||||
@Cached InlinedExactClassProfile bufferTypeProfile)
|
||||
throws StopIterationException, UnsupportedMessageException {
|
||||
var element = iopIt.getIteratorNextElement(it);
|
||||
if (buffer != null) {
|
||||
var buf = bufferTypeProfile.profile(node, buffer);
|
||||
if (byteMask == 0) {
|
||||
// (byte) (0x01 << 8) ==> 0
|
||||
byteValue = buf.get();
|
||||
byteMask = 0x01;
|
||||
}
|
||||
var include = byteValue & byteMask;
|
||||
byteMask = (byte) (byteMask << 1);
|
||||
if (include == 0) {
|
||||
return NullValue.get();
|
||||
}
|
||||
}
|
||||
return element;
|
||||
}
|
||||
|
||||
@ExportMessage
|
||||
boolean isIterator() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@ExportMessage(limit = "3")
|
||||
boolean hasIteratorNextElement(@CachedLibrary("this.it()") InteropLibrary iopIt)
|
||||
throws UnsupportedMessageException {
|
||||
return iopIt.hasIteratorNextElement(it);
|
||||
}
|
||||
}
|
||||
|
||||
@ExportLibrary(InteropLibrary.class)
|
||||
static final class GenericIterator implements TruffleObject {
|
||||
private int at;
|
||||
private final TruffleObject array;
|
||||
|
||||
GenericIterator(TruffleObject array) {
|
||||
assert InteropLibrary.getUncached().hasArrayElements(array);
|
||||
this.array = array;
|
||||
}
|
||||
|
||||
TruffleObject array() {
|
||||
return array;
|
||||
}
|
||||
|
||||
@ExportMessage(limit = "3")
|
||||
Object getIteratorNextElement(@CachedLibrary("this.array()") InteropLibrary iop)
|
||||
throws StopIterationException {
|
||||
try {
|
||||
var res = iop.readArrayElement(array, at);
|
||||
at++;
|
||||
return res;
|
||||
} catch (UnsupportedMessageException | InvalidArrayIndexException ex) {
|
||||
throw StopIterationException.create();
|
||||
}
|
||||
}
|
||||
|
||||
@ExportMessage
|
||||
boolean isIterator() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@ExportMessage(limit = "3")
|
||||
boolean hasIteratorNextElement(@CachedLibrary("this.array()") InteropLibrary iop)
|
||||
throws UnsupportedMessageException {
|
||||
return at < iop.getArraySize(array);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,22 +1,27 @@
|
||||
package org.enso.interpreter.arrow.runtime;
|
||||
|
||||
import com.oracle.truffle.api.CompilerDirectives;
|
||||
import com.oracle.truffle.api.dsl.Cached;
|
||||
import com.oracle.truffle.api.dsl.Cached.Shared;
|
||||
import com.oracle.truffle.api.dsl.GenerateInline;
|
||||
import com.oracle.truffle.api.dsl.GenerateUncached;
|
||||
import com.oracle.truffle.api.dsl.Specialization;
|
||||
import com.oracle.truffle.api.interop.InteropLibrary;
|
||||
import com.oracle.truffle.api.interop.TruffleObject;
|
||||
import com.oracle.truffle.api.interop.UnknownIdentifierException;
|
||||
import com.oracle.truffle.api.interop.UnsupportedMessageException;
|
||||
import com.oracle.truffle.api.interop.UnsupportedTypeException;
|
||||
import com.oracle.truffle.api.library.CachedLibrary;
|
||||
import com.oracle.truffle.api.library.ExportLibrary;
|
||||
import com.oracle.truffle.api.library.ExportMessage;
|
||||
import com.oracle.truffle.api.nodes.Node;
|
||||
import org.enso.interpreter.arrow.LogicalLayout;
|
||||
|
||||
@ExportLibrary(InteropLibrary.class)
|
||||
public final class ArrowFixedSizeArrayBuilder implements TruffleObject {
|
||||
private final ByteBufferDirect buffer;
|
||||
private final LogicalLayout unit;
|
||||
private final int size;
|
||||
private int index;
|
||||
private boolean sealed;
|
||||
private ByteBufferDirect buffer;
|
||||
|
||||
private static final String APPEND_OP = "append";
|
||||
private static final String BUILD_OP = "build";
|
||||
@ -25,8 +30,6 @@ public final class ArrowFixedSizeArrayBuilder implements TruffleObject {
|
||||
this.size = size;
|
||||
this.unit = unit;
|
||||
this.buffer = ByteBufferDirect.forSize(size, unit);
|
||||
this.index = 0;
|
||||
this.sealed = false;
|
||||
}
|
||||
|
||||
public LogicalLayout getUnit() {
|
||||
@ -34,7 +37,7 @@ public final class ArrowFixedSizeArrayBuilder implements TruffleObject {
|
||||
}
|
||||
|
||||
public boolean isSealed() {
|
||||
return sealed;
|
||||
return buffer == null;
|
||||
}
|
||||
|
||||
public ByteBufferDirect getBuffer() {
|
||||
@ -53,7 +56,7 @@ public final class ArrowFixedSizeArrayBuilder implements TruffleObject {
|
||||
@ExportMessage
|
||||
public boolean isMemberInvocable(String member) {
|
||||
return switch (member) {
|
||||
case APPEND_OP -> !this.sealed;
|
||||
case APPEND_OP -> buffer != null;
|
||||
case BUILD_OP -> true;
|
||||
default -> false;
|
||||
};
|
||||
@ -65,33 +68,83 @@ public final class ArrowFixedSizeArrayBuilder implements TruffleObject {
|
||||
}
|
||||
|
||||
@ExportMessage
|
||||
Object invokeMember(
|
||||
String name,
|
||||
Object[] args,
|
||||
@Cached(value = "buildWriterOrNull(name)", neverDefault = true)
|
||||
WriteToBuilderNode writeToBuilderNode)
|
||||
Object invokeMember(String name, Object[] args, @Cached AppendNode append)
|
||||
throws UnsupportedMessageException, UnknownIdentifierException, UnsupportedTypeException {
|
||||
switch (name) {
|
||||
case BUILD_OP:
|
||||
sealed = true;
|
||||
return switch (unit) {
|
||||
case Date32, Date64 -> new ArrowFixedArrayDate(buffer, size, unit);
|
||||
case Int8, Int16, Int32, Int64 -> new ArrowFixedArrayInt(buffer, size, unit);
|
||||
};
|
||||
case APPEND_OP:
|
||||
if (sealed) {
|
||||
throw UnsupportedMessageException.create();
|
||||
}
|
||||
var current = index;
|
||||
writeToBuilderNode.executeWrite(this, current, args[0]);
|
||||
index += 1;
|
||||
return NullValue.get();
|
||||
default:
|
||||
throw UnknownIdentifierException.create(name);
|
||||
return switch (name) {
|
||||
case BUILD_OP -> build();
|
||||
case APPEND_OP -> {
|
||||
append.executeAppend(this, args[0]);
|
||||
yield NullValue.get();
|
||||
}
|
||||
default -> throw UnknownIdentifierException.create(name);
|
||||
};
|
||||
}
|
||||
|
||||
private final TruffleObject build() throws UnsupportedMessageException {
|
||||
var b = buffer;
|
||||
if (b == null) {
|
||||
throw UnsupportedMessageException.create();
|
||||
}
|
||||
buffer = null;
|
||||
return switch (unit) {
|
||||
case Date32, Date64 -> new ArrowFixedArrayDate(b, size, unit);
|
||||
case Int8, Int16, Int32, Int64 -> new ArrowFixedArrayInt(b, size, unit);
|
||||
};
|
||||
}
|
||||
|
||||
@GenerateUncached
|
||||
@GenerateInline(false)
|
||||
abstract static class AppendNode extends Node {
|
||||
abstract void executeAppend(ArrowFixedSizeArrayBuilder builder, Object value)
|
||||
throws UnsupportedTypeException, UnsupportedMessageException;
|
||||
|
||||
@Specialization(
|
||||
limit = "3",
|
||||
guards = {"builder.getUnit() == cachedUnit"})
|
||||
static void writeToBuffer(
|
||||
ArrowFixedSizeArrayBuilder builder,
|
||||
Object value,
|
||||
@Cached(value = "builder.getUnit()", allowUncached = true) LogicalLayout cachedUnit,
|
||||
@Shared("put") @Cached ByteBufferDirect.PutNode put,
|
||||
@Shared("value") @Cached ValueToNumberNode valueNode,
|
||||
@Shared("iop") @CachedLibrary(limit = "3") InteropLibrary iop)
|
||||
throws UnsupportedTypeException, UnsupportedMessageException {
|
||||
if (iop.isNull(value)) {
|
||||
put.putNull(builder.buffer, cachedUnit);
|
||||
return;
|
||||
}
|
||||
var number = valueNode.executeAdjust(cachedUnit, value);
|
||||
switch (number) {
|
||||
case Byte b -> put.put(builder.buffer, b);
|
||||
case Short s -> put.putShort(builder.buffer, s);
|
||||
case Integer i -> put.putInt(builder.buffer, i);
|
||||
case Long l -> put.putLong(builder.buffer, l);
|
||||
default -> throw CompilerDirectives.shouldNotReachHere();
|
||||
}
|
||||
}
|
||||
|
||||
@Specialization(replaces = "writeToBuffer")
|
||||
static void writeToBufferUncached(
|
||||
ArrowFixedSizeArrayBuilder builder,
|
||||
Object value,
|
||||
@Shared("put") @Cached ByteBufferDirect.PutNode put,
|
||||
@Shared("value") @Cached ValueToNumberNode valueNode,
|
||||
@Shared("iop") @CachedLibrary(limit = "3") InteropLibrary iop)
|
||||
throws UnsupportedTypeException, UnsupportedMessageException {
|
||||
writeToBuffer(builder, value, builder.getUnit(), put, valueNode, iop);
|
||||
}
|
||||
}
|
||||
|
||||
static WriteToBuilderNode buildWriterOrNull(String op) {
|
||||
return APPEND_OP.equals(op) ? WriteToBuilderNode.build() : WriteToBuilderNodeGen.getUncached();
|
||||
@GenerateUncached
|
||||
@GenerateInline(false)
|
||||
abstract static class BuildNode extends Node {
|
||||
abstract TruffleObject executeBuild(ArrowFixedSizeArrayBuilder builder)
|
||||
throws UnsupportedMessageException;
|
||||
|
||||
@Specialization
|
||||
static TruffleObject buildIt(ArrowFixedSizeArrayBuilder builder)
|
||||
throws UnsupportedMessageException {
|
||||
return builder.build();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2,8 +2,8 @@ package org.enso.interpreter.arrow.runtime;
|
||||
|
||||
import com.oracle.truffle.api.CompilerDirectives;
|
||||
import com.oracle.truffle.api.dsl.Cached;
|
||||
import com.oracle.truffle.api.dsl.Fallback;
|
||||
import com.oracle.truffle.api.dsl.ImportStatic;
|
||||
import com.oracle.truffle.api.dsl.GenerateInline;
|
||||
import com.oracle.truffle.api.dsl.GenerateUncached;
|
||||
import com.oracle.truffle.api.dsl.Specialization;
|
||||
import com.oracle.truffle.api.interop.InteropLibrary;
|
||||
import com.oracle.truffle.api.interop.TruffleObject;
|
||||
@ -11,10 +11,11 @@ import com.oracle.truffle.api.interop.UnsupportedMessageException;
|
||||
import com.oracle.truffle.api.library.CachedLibrary;
|
||||
import com.oracle.truffle.api.library.ExportLibrary;
|
||||
import com.oracle.truffle.api.library.ExportMessage;
|
||||
import com.oracle.truffle.api.nodes.Node;
|
||||
import org.enso.interpreter.arrow.LogicalLayout;
|
||||
|
||||
@ExportLibrary(InteropLibrary.class)
|
||||
public class ArrowFixedSizeArrayFactory implements TruffleObject {
|
||||
public final class ArrowFixedSizeArrayFactory implements TruffleObject {
|
||||
|
||||
private final LogicalLayout logicalLayout;
|
||||
|
||||
@ -23,7 +24,7 @@ public class ArrowFixedSizeArrayFactory implements TruffleObject {
|
||||
}
|
||||
|
||||
@ExportMessage
|
||||
public boolean isInstantiable() {
|
||||
boolean isInstantiable() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -32,79 +33,36 @@ public class ArrowFixedSizeArrayFactory implements TruffleObject {
|
||||
}
|
||||
|
||||
@ExportMessage
|
||||
@ImportStatic(LogicalLayout.class)
|
||||
static class Instantiate {
|
||||
@Specialization(guards = "receiver.getLayout() == Date32")
|
||||
static Object doDate32(
|
||||
ArrowFixedSizeArrayFactory receiver,
|
||||
Object[] args,
|
||||
@Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop)
|
||||
throws UnsupportedMessageException {
|
||||
return new ArrowFixedSizeArrayBuilder(arraySize(args, iop), receiver.logicalLayout);
|
||||
}
|
||||
ArrowFixedSizeArrayBuilder instantiate(
|
||||
Object[] args,
|
||||
@Cached InstantiateNode instantiate,
|
||||
@CachedLibrary(limit = "1") InteropLibrary iop)
|
||||
throws UnsupportedMessageException {
|
||||
var size = arraySize(args, iop);
|
||||
return instantiate.allocateBuilder(logicalLayout, size);
|
||||
}
|
||||
|
||||
@Specialization(guards = "receiver.getLayout() == Date64")
|
||||
static Object doDate64(
|
||||
ArrowFixedSizeArrayFactory receiver,
|
||||
Object[] args,
|
||||
@Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop)
|
||||
throws UnsupportedMessageException {
|
||||
return new ArrowFixedSizeArrayBuilder(arraySize(args, iop), receiver.logicalLayout);
|
||||
private static int arraySize(Object[] args, InteropLibrary interop)
|
||||
throws UnsupportedMessageException {
|
||||
if (args.length != 1 || !interop.isNumber(args[0]) || !interop.fitsInInt(args[0])) {
|
||||
throw UnsupportedMessageException.create();
|
||||
}
|
||||
return interop.asInt(args[0]);
|
||||
}
|
||||
|
||||
@Specialization(guards = "receiver.getLayout() == Int8")
|
||||
static Object doInt8(
|
||||
ArrowFixedSizeArrayFactory receiver,
|
||||
Object[] args,
|
||||
@Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop)
|
||||
throws UnsupportedMessageException {
|
||||
return new ArrowFixedSizeArrayBuilder(arraySize(args, iop), receiver.logicalLayout);
|
||||
}
|
||||
@GenerateUncached
|
||||
@GenerateInline(false)
|
||||
abstract static class InstantiateNode extends Node {
|
||||
abstract ArrowFixedSizeArrayBuilder executeNew(LogicalLayout logicalLayout, long size);
|
||||
|
||||
@Specialization(guards = "receiver.getLayout() == Int16")
|
||||
static Object doInt16(
|
||||
ArrowFixedSizeArrayFactory receiver,
|
||||
Object[] args,
|
||||
@Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop)
|
||||
throws UnsupportedMessageException {
|
||||
return new ArrowFixedSizeArrayBuilder(arraySize(args, iop), receiver.logicalLayout);
|
||||
}
|
||||
|
||||
@Specialization(guards = "receiver.getLayout() == Int32")
|
||||
static Object doInt32(
|
||||
ArrowFixedSizeArrayFactory receiver,
|
||||
Object[] args,
|
||||
@Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop)
|
||||
throws UnsupportedMessageException {
|
||||
return new ArrowFixedSizeArrayBuilder(arraySize(args, iop), receiver.logicalLayout);
|
||||
}
|
||||
|
||||
@Specialization(guards = "receiver.getLayout() == Int64")
|
||||
static Object doInt64(
|
||||
ArrowFixedSizeArrayFactory receiver,
|
||||
Object[] args,
|
||||
@Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop)
|
||||
throws UnsupportedMessageException {
|
||||
return new ArrowFixedSizeArrayBuilder(arraySize(args, iop), receiver.logicalLayout);
|
||||
}
|
||||
|
||||
@CompilerDirectives.TruffleBoundary
|
||||
private static int arraySize(Object[] args, InteropLibrary interop)
|
||||
throws UnsupportedMessageException {
|
||||
if (args.length != 1 || !interop.isNumber(args[0]) || !interop.fitsInInt(args[0])) {
|
||||
throw UnsupportedMessageException.create();
|
||||
@Specialization
|
||||
final ArrowFixedSizeArrayBuilder allocateBuilder(LogicalLayout logicalLayout, long size) {
|
||||
try {
|
||||
return new ArrowFixedSizeArrayBuilder(Math.toIntExact(size), logicalLayout);
|
||||
} catch (ArithmeticException ex) {
|
||||
CompilerDirectives.transferToInterpreter();
|
||||
throw ex;
|
||||
}
|
||||
return interop.asInt(args[0]);
|
||||
}
|
||||
|
||||
@Fallback
|
||||
static Object doOther(ArrowFixedSizeArrayFactory receiver, Object[] args) {
|
||||
throw CompilerDirectives.shouldNotReachHere(unknownLayoutMessage(receiver.getLayout()));
|
||||
}
|
||||
|
||||
@CompilerDirectives.TruffleBoundary
|
||||
private static String unknownLayoutMessage(SizeInBytes layout) {
|
||||
return "unknown layout: " + layout.toString();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,97 @@
|
||||
package org.enso.interpreter.arrow.runtime;
|
||||
|
||||
import com.oracle.truffle.api.dsl.Bind;
|
||||
import com.oracle.truffle.api.dsl.Cached;
|
||||
import com.oracle.truffle.api.interop.ArityException;
|
||||
import com.oracle.truffle.api.interop.InteropLibrary;
|
||||
import com.oracle.truffle.api.interop.StopIterationException;
|
||||
import com.oracle.truffle.api.interop.TruffleObject;
|
||||
import com.oracle.truffle.api.interop.UnsupportedMessageException;
|
||||
import com.oracle.truffle.api.interop.UnsupportedTypeException;
|
||||
import com.oracle.truffle.api.library.CachedLibrary;
|
||||
import com.oracle.truffle.api.library.ExportLibrary;
|
||||
import com.oracle.truffle.api.library.ExportMessage;
|
||||
import com.oracle.truffle.api.nodes.Node;
|
||||
import com.oracle.truffle.api.profiles.InlinedExactClassProfile;
|
||||
import org.enso.interpreter.arrow.LogicalLayout;
|
||||
|
||||
@ExportLibrary(InteropLibrary.class)
|
||||
public final class ArrowOperationPlus implements TruffleObject {
|
||||
private final LogicalLayout layout;
|
||||
|
||||
public ArrowOperationPlus(LogicalLayout layout) {
|
||||
this.layout = layout;
|
||||
}
|
||||
|
||||
final LogicalLayout layout() {
|
||||
return layout;
|
||||
}
|
||||
|
||||
@ExportMessage
|
||||
boolean isExecutable() {
|
||||
return true;
|
||||
}
|
||||
|
||||
static Object args(Object[] args, int index) throws ArityException {
|
||||
if (args.length != 2) {
|
||||
throw ArityException.create(2, 2, args.length);
|
||||
}
|
||||
return args[index];
|
||||
}
|
||||
|
||||
static Object it(Object[] args, InteropLibrary iop, int index)
|
||||
throws ArityException, UnsupportedMessageException {
|
||||
if (args.length != 2) {
|
||||
throw ArityException.create(2, 2, args.length);
|
||||
}
|
||||
return iop.getIterator(args[index]);
|
||||
}
|
||||
|
||||
ScalarOperationNode createScalarOp(boolean cached) {
|
||||
return cached ? OperationPlus.create() : OperationPlus.getUncached();
|
||||
}
|
||||
|
||||
@ExportMessage(limit = "3")
|
||||
Object execute(
|
||||
Object[] args,
|
||||
@Bind("$node") Node node,
|
||||
@Cached(value = "this.layout()", allowUncached = true) LogicalLayout cachedLayout,
|
||||
@Cached ArrowFixedSizeArrayFactory.InstantiateNode factory,
|
||||
@CachedLibrary("args(args, 0)") InteropLibrary iopArray0,
|
||||
@CachedLibrary("args(args, 1)") InteropLibrary iopArray1,
|
||||
@CachedLibrary("it(args, iopArray0, 0)") InteropLibrary iopIt0,
|
||||
@CachedLibrary("it(args, iopArray1, 1)") InteropLibrary iopIt1,
|
||||
@CachedLibrary(limit = "3") InteropLibrary iopElem,
|
||||
@Cached(value = "this.createScalarOp(true)", uncached = "this.createScalarOp(false)")
|
||||
ScalarOperationNode opNode,
|
||||
@Cached ArrowFixedSizeArrayBuilder.AppendNode append,
|
||||
@Cached ArrowFixedSizeArrayBuilder.BuildNode build,
|
||||
@Cached InlinedExactClassProfile typeOfBuf0,
|
||||
@Cached InlinedExactClassProfile typeOfBuf1)
|
||||
throws ArityException, UnsupportedTypeException, UnsupportedMessageException {
|
||||
var arr0 = args[0];
|
||||
var arr1 = args[1];
|
||||
if (!iopArray0.hasArrayElements(arr0) || !iopArray1.hasArrayElements(arr1)) {
|
||||
throw UnsupportedTypeException.create(args);
|
||||
}
|
||||
var len = iopArray0.getArraySize(arr0);
|
||||
if (len != iopArray1.getArraySize(arr1)) {
|
||||
throw UnsupportedTypeException.create(args, "Arrays must have the same length");
|
||||
}
|
||||
var it0 = iopArray0.getIterator(arr0);
|
||||
var it1 = iopArray1.getIterator(arr1);
|
||||
var builder = factory.allocateBuilder(cachedLayout, len);
|
||||
|
||||
for (long i = 0; i < len; i++) {
|
||||
try {
|
||||
var elem0 = iopIt0.getIteratorNextElement(it0);
|
||||
var elem1 = iopIt1.getIteratorNextElement(it1);
|
||||
var res = opNode.executeOp(elem0, elem1);
|
||||
append.executeAppend(builder, res);
|
||||
} catch (StopIterationException ex) {
|
||||
throw UnsupportedTypeException.create(new Object[] {it0, it1});
|
||||
}
|
||||
}
|
||||
return build.executeBuild(builder);
|
||||
}
|
||||
}
|
@ -1,14 +1,25 @@
|
||||
package org.enso.interpreter.arrow.runtime;
|
||||
|
||||
import com.oracle.truffle.api.CompilerDirectives;
|
||||
import com.oracle.truffle.api.dsl.Bind;
|
||||
import com.oracle.truffle.api.dsl.Cached;
|
||||
import com.oracle.truffle.api.dsl.GenerateInline;
|
||||
import com.oracle.truffle.api.dsl.GenerateUncached;
|
||||
import com.oracle.truffle.api.dsl.NeverDefault;
|
||||
import com.oracle.truffle.api.dsl.Specialization;
|
||||
import com.oracle.truffle.api.interop.UnsupportedMessageException;
|
||||
import com.oracle.truffle.api.nodes.Node;
|
||||
import com.oracle.truffle.api.profiles.InlinedExactClassProfile;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import org.enso.interpreter.arrow.LogicalLayout;
|
||||
import org.enso.interpreter.arrow.runtime.ByteBufferDirect.DataBufferNode;
|
||||
import org.enso.interpreter.arrow.util.MemoryUtil;
|
||||
|
||||
final class ByteBufferDirect implements AutoCloseable {
|
||||
private final ByteBuffer allocated;
|
||||
private final ByteBuffer dataBuffer;
|
||||
private final ByteBuffer bitmapBuffer;
|
||||
private ByteBuffer bitmapBuffer;
|
||||
|
||||
/**
|
||||
* Creates a fresh buffer with an empty non-null bitmap..
|
||||
@ -22,10 +33,7 @@ final class ByteBufferDirect implements AutoCloseable {
|
||||
|
||||
this.allocated = buffer;
|
||||
this.dataBuffer = buffer.slice(0, padded.getDataBufferSizeInBytes());
|
||||
this.bitmapBuffer = buffer.slice(dataBuffer.capacity(), padded.getValidityBitmapSizeInBytes());
|
||||
for (int i = 0; i < bitmapBuffer.capacity(); i++) {
|
||||
bitmapBuffer.put(i, (byte) 0);
|
||||
}
|
||||
this.bitmapBuffer = null;
|
||||
}
|
||||
|
||||
/**
|
||||
@ -53,8 +61,13 @@ final class ByteBufferDirect implements AutoCloseable {
|
||||
this.dataBuffer = dataBuffer;
|
||||
this.bitmapBuffer = allocated.slice(dataBuffer.capacity(), bitmapSizeInBytes);
|
||||
for (int i = 0; i < bitmapBuffer.capacity(); i++) {
|
||||
bitmapBuffer.put(i, (byte) 255);
|
||||
bitmapBuffer.put(i, (byte) 0xff);
|
||||
}
|
||||
bitmapBuffer.rewind();
|
||||
}
|
||||
|
||||
static ByteBufferDirect forBuffer(ByteBuffer buf) {
|
||||
return new ByteBufferDirect(buf, buf, null);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -93,119 +106,200 @@ final class ByteBufferDirect implements AutoCloseable {
|
||||
return new ByteBufferDirect(allocated, dataBuffer, bitmapBuffer);
|
||||
}
|
||||
|
||||
public void put(byte b) throws UnsupportedMessageException {
|
||||
setValidityBitmap(0, 1);
|
||||
dataBuffer.put(b);
|
||||
@CompilerDirectives.TruffleBoundary
|
||||
ByteBuffer initializeBitmapBuffer() {
|
||||
assert bitmapBuffer == null;
|
||||
bitmapBuffer =
|
||||
allocated.slice(dataBuffer.capacity(), allocated.capacity() - dataBuffer.capacity());
|
||||
for (var i = 0; i < bitmapBuffer.capacity(); i++) {
|
||||
bitmapBuffer.put(i, (byte) 0xff);
|
||||
}
|
||||
return bitmapBuffer;
|
||||
}
|
||||
|
||||
final ByteBuffer getDataBuffer() {
|
||||
return dataBuffer;
|
||||
}
|
||||
|
||||
final ByteBuffer getBitmapBuffer() {
|
||||
return bitmapBuffer;
|
||||
}
|
||||
|
||||
@GenerateInline(false)
|
||||
@GenerateUncached
|
||||
abstract static class DataBufferNode extends Node {
|
||||
static DataBufferNode create() {
|
||||
return ByteBufferDirectFactory.DataBufferNodeGen.create();
|
||||
}
|
||||
|
||||
static DataBufferNode getUncached() {
|
||||
return ByteBufferDirectFactory.DataBufferNodeGen.getUncached();
|
||||
}
|
||||
|
||||
abstract ByteBuffer executeDataBuffer(ByteBufferDirect direct);
|
||||
|
||||
@Specialization
|
||||
static ByteBuffer profiledDataBuffer(
|
||||
ByteBufferDirect direct,
|
||||
@Bind("$node") Node node,
|
||||
@Cached InlinedExactClassProfile bufferClazz) {
|
||||
return bufferClazz.profile(node, direct.dataBuffer);
|
||||
}
|
||||
}
|
||||
|
||||
@GenerateInline(false)
|
||||
@GenerateUncached
|
||||
abstract static class BitmapBufferNode extends Node {
|
||||
static BitmapBufferNode create() {
|
||||
return ByteBufferDirectFactory.BitmapBufferNodeGen.create();
|
||||
}
|
||||
|
||||
static BitmapBufferNode getUncached() {
|
||||
return ByteBufferDirectFactory.BitmapBufferNodeGen.getUncached();
|
||||
}
|
||||
|
||||
abstract ByteBuffer executeBitmapBuffer(ByteBufferDirect direct, boolean forceCreation);
|
||||
|
||||
@Specialization
|
||||
static ByteBuffer profiledBitmapBuffer(
|
||||
ByteBufferDirect direct,
|
||||
boolean forceCreation,
|
||||
@Bind("$node") Node node,
|
||||
@Cached InlinedExactClassProfile bufferClazz) {
|
||||
|
||||
if (direct.bitmapBuffer == null) {
|
||||
if (forceCreation) {
|
||||
direct.bitmapBuffer = direct.initializeBitmapBuffer();
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
return bufferClazz.profile(node, direct.bitmapBuffer);
|
||||
}
|
||||
}
|
||||
|
||||
static final class PutNode extends Node {
|
||||
private static final PutNode UNCACHED =
|
||||
new PutNode(DataBufferNode.getUncached(), BitmapBufferNode.getUncached());
|
||||
private @Child DataBufferNode dataBuffer;
|
||||
private @Child BitmapBufferNode bitmapBuffer;
|
||||
|
||||
private PutNode(DataBufferNode dbn, BitmapBufferNode bbn) {
|
||||
this.dataBuffer = dbn;
|
||||
this.bitmapBuffer = bbn;
|
||||
}
|
||||
|
||||
@NeverDefault
|
||||
static PutNode create() {
|
||||
return new PutNode(DataBufferNode.create(), BitmapBufferNode.create());
|
||||
}
|
||||
|
||||
@NeverDefault
|
||||
static PutNode getUncached() {
|
||||
return UNCACHED;
|
||||
}
|
||||
|
||||
final void put(ByteBufferDirect direct, byte b) {
|
||||
var db = dataBuffer.executeDataBuffer(direct);
|
||||
addValidityBitmap(direct, db.position(), 1);
|
||||
db.put(b);
|
||||
}
|
||||
|
||||
final void putNull(ByteBufferDirect direct, LogicalLayout unit) {
|
||||
var db = dataBuffer.executeDataBuffer(direct);
|
||||
var index = db.position() / unit.sizeInBytes();
|
||||
|
||||
var bb = bitmapBuffer.executeBitmapBuffer(direct, true);
|
||||
|
||||
var bufferIndex = index >> 3;
|
||||
var slot = bb.get(bufferIndex);
|
||||
var byteIndex = index & BYTE_MASK;
|
||||
var mask = ~(1 << byteIndex);
|
||||
bb.put(bufferIndex, (byte) (slot & mask));
|
||||
|
||||
db.position(db.position() + unit.sizeInBytes());
|
||||
}
|
||||
|
||||
final void putShort(ByteBufferDirect direct, short value) {
|
||||
var db = dataBuffer.executeDataBuffer(direct);
|
||||
addValidityBitmap(direct, db.position(), 2);
|
||||
db.putShort(value);
|
||||
}
|
||||
|
||||
final void putInt(ByteBufferDirect direct, int value) {
|
||||
var db = dataBuffer.executeDataBuffer(direct);
|
||||
addValidityBitmap(direct, db.position(), 4);
|
||||
db.putInt(value);
|
||||
}
|
||||
|
||||
final void putLong(ByteBufferDirect direct, long value) throws UnsupportedMessageException {
|
||||
var db = dataBuffer.executeDataBuffer(direct);
|
||||
addValidityBitmap(direct, db.position(), 8);
|
||||
db.putLong(value);
|
||||
}
|
||||
|
||||
private void addValidityBitmap(ByteBufferDirect direct, int pos, int size) {
|
||||
var bb = bitmapBuffer.executeBitmapBuffer(direct, false);
|
||||
if (bb == null) {
|
||||
return;
|
||||
}
|
||||
var index = pos / size;
|
||||
var bufferIndex = index >> 3;
|
||||
var slot = bb.get(bufferIndex);
|
||||
var byteIndex = index & BYTE_MASK;
|
||||
|
||||
var mask = 1 << byteIndex;
|
||||
var updated = (slot | mask);
|
||||
bb.put(bufferIndex, (byte) (updated));
|
||||
}
|
||||
}
|
||||
|
||||
public byte get(int index) throws UnsupportedMessageException {
|
||||
return dataBuffer.get(index);
|
||||
}
|
||||
|
||||
public void put(int index, byte b) throws UnsupportedMessageException {
|
||||
setValidityBitmap(index, 1);
|
||||
dataBuffer.put(index, b);
|
||||
}
|
||||
|
||||
public void putShort(short value) throws UnsupportedMessageException {
|
||||
setValidityBitmap(0, 2);
|
||||
dataBuffer.putShort(value);
|
||||
}
|
||||
|
||||
public short getShort(int index) throws UnsupportedMessageException {
|
||||
return dataBuffer.getShort(index);
|
||||
}
|
||||
|
||||
public void putShort(int index, short value) throws UnsupportedMessageException {
|
||||
setValidityBitmap(index, 2);
|
||||
dataBuffer.putShort(index, value);
|
||||
}
|
||||
|
||||
public void putInt(int value) throws UnsupportedMessageException {
|
||||
setValidityBitmap(0, 4);
|
||||
dataBuffer.putInt(value);
|
||||
}
|
||||
|
||||
public int getInt(int index) throws UnsupportedMessageException {
|
||||
return dataBuffer.getInt(index);
|
||||
}
|
||||
|
||||
public void putInt(int index, int value) {
|
||||
setValidityBitmap(index, 4);
|
||||
dataBuffer.putInt(index, value);
|
||||
}
|
||||
|
||||
public void putLong(long value) throws UnsupportedMessageException {
|
||||
setValidityBitmap(0, 8);
|
||||
dataBuffer.putLong(value);
|
||||
}
|
||||
|
||||
public long getLong(int index) throws UnsupportedMessageException {
|
||||
return dataBuffer.getLong(index);
|
||||
}
|
||||
|
||||
public void putLong(int index, long value) {
|
||||
setValidityBitmap(index, 8);
|
||||
dataBuffer.putLong(index, value);
|
||||
}
|
||||
|
||||
public void putFloat(float value) throws UnsupportedMessageException {
|
||||
setValidityBitmap(0, 4);
|
||||
dataBuffer.putFloat(value);
|
||||
}
|
||||
|
||||
public float getFloat(int index) throws UnsupportedMessageException {
|
||||
return dataBuffer.getFloat(index);
|
||||
}
|
||||
|
||||
public void putFloat(int index, float value) throws UnsupportedMessageException {
|
||||
setValidityBitmap(index, 4);
|
||||
dataBuffer.putFloat(index, value);
|
||||
}
|
||||
|
||||
public void putDouble(double value) throws UnsupportedMessageException {
|
||||
setValidityBitmap(0, 8);
|
||||
dataBuffer.putDouble(value);
|
||||
}
|
||||
|
||||
public double getDouble(int index) throws UnsupportedMessageException {
|
||||
return dataBuffer.getDouble(index);
|
||||
}
|
||||
|
||||
public void putDouble(int index, double value) throws UnsupportedMessageException {
|
||||
setValidityBitmap(index, 8);
|
||||
dataBuffer.putDouble(index, value);
|
||||
public long getLong(int index, Node node, InlinedExactClassProfile profile)
|
||||
throws UnsupportedMessageException {
|
||||
var buf = profile.profile(node, dataBuffer);
|
||||
return buf.getLong(index);
|
||||
}
|
||||
|
||||
public int capacity() throws UnsupportedMessageException {
|
||||
return dataBuffer.capacity();
|
||||
}
|
||||
|
||||
boolean hasNulls() {
|
||||
return bitmapBuffer != null;
|
||||
}
|
||||
|
||||
public boolean isNull(int index) {
|
||||
if (bitmapBuffer == null) {
|
||||
return false;
|
||||
}
|
||||
return checkForNull(index);
|
||||
}
|
||||
|
||||
private boolean checkForNull(int index) {
|
||||
var bufferIndex = index >> 3;
|
||||
var slot = bitmapBuffer.get(bufferIndex);
|
||||
var byteIndex = index & ~(1 << 3);
|
||||
var byteIndex = index & BYTE_MASK;
|
||||
var mask = 1 << byteIndex;
|
||||
return (slot & mask) == 0;
|
||||
}
|
||||
|
||||
public void setNull(int index) {
|
||||
var bufferIndex = index >> 3;
|
||||
var slot = bitmapBuffer.get(bufferIndex);
|
||||
var byteIndex = index & ~(1 << 3);
|
||||
var mask = ~(1 << byteIndex);
|
||||
bitmapBuffer.put(bufferIndex, (byte) (slot & mask));
|
||||
}
|
||||
|
||||
private void setValidityBitmap(int index0, int unitSize) {
|
||||
var index = index0 / unitSize;
|
||||
var bufferIndex = index >> 3;
|
||||
var slot = bitmapBuffer.get(bufferIndex);
|
||||
var byteIndex = index & ~(1 << 3);
|
||||
var mask = 1 << byteIndex;
|
||||
var updated = (slot | mask);
|
||||
bitmapBuffer.put(bufferIndex, (byte) (updated));
|
||||
}
|
||||
private static final int BYTE_MASK = ~(~(1 << 3) + 1); // 7
|
||||
|
||||
@Override
|
||||
public void close() throws Exception {
|
||||
|
@ -0,0 +1,69 @@
|
||||
package org.enso.interpreter.arrow.runtime;
|
||||
|
||||
import com.oracle.truffle.api.dsl.Cached.Shared;
|
||||
import com.oracle.truffle.api.dsl.GenerateInline;
|
||||
import com.oracle.truffle.api.dsl.GenerateUncached;
|
||||
import com.oracle.truffle.api.dsl.Specialization;
|
||||
import com.oracle.truffle.api.interop.InteropLibrary;
|
||||
import com.oracle.truffle.api.interop.UnsupportedMessageException;
|
||||
import com.oracle.truffle.api.library.CachedLibrary;
|
||||
|
||||
@GenerateUncached
|
||||
@GenerateInline(false)
|
||||
abstract class OperationPlus extends ScalarOperationNode {
|
||||
@Override
|
||||
abstract Object executeOp(Object a, Object b) throws UnsupportedMessageException;
|
||||
|
||||
static OperationPlus create() {
|
||||
return OperationPlusNodeGen.create();
|
||||
}
|
||||
|
||||
static OperationPlus getUncached() {
|
||||
return OperationPlusNodeGen.getUncached();
|
||||
}
|
||||
|
||||
@Specialization(rewriteOn = ArithmeticException.class)
|
||||
long doLongs(long a, long b) {
|
||||
return Math.addExact(a, b);
|
||||
}
|
||||
|
||||
@Specialization(replaces = "doLongs")
|
||||
Object doLongsWithOverflowCheck(long a, long b) {
|
||||
long res = a + b;
|
||||
long check1 = a ^ res;
|
||||
long check2 = b ^ res;
|
||||
long checkBoth = check1 & check2;
|
||||
if (checkBoth < 0) {
|
||||
return NullValue.get();
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
@Specialization(
|
||||
guards = {"iop.fitsInLong(a)", "iop.fitsInLong(b)"},
|
||||
rewriteOn = ArithmeticException.class)
|
||||
Object doFitInLong(
|
||||
Object a, Object b, @Shared("iop") @CachedLibrary(limit = "3") InteropLibrary iop)
|
||||
throws UnsupportedMessageException {
|
||||
var la = iop.asLong(a);
|
||||
var lb = iop.asLong(b);
|
||||
return doLongs(la, lb);
|
||||
}
|
||||
|
||||
@Specialization(
|
||||
guards = {"iop.fitsInLong(a)", "iop.fitsInLong(b)"},
|
||||
replaces = "doFitInLong")
|
||||
Object doFitInLongWithOverflowCheck(
|
||||
Object a, Object b, @Shared("iop") @CachedLibrary(limit = "3") InteropLibrary iop)
|
||||
throws UnsupportedMessageException {
|
||||
var la = iop.asLong(a);
|
||||
var lb = iop.asLong(b);
|
||||
return doLongsWithOverflowCheck(la, lb);
|
||||
}
|
||||
|
||||
@Specialization(guards = {"iop.isNull(a) || iop.isNull(b)"})
|
||||
NullValue nothing(
|
||||
Object a, Object b, @Shared("iop") @CachedLibrary(limit = "3") InteropLibrary iop) {
|
||||
return NullValue.get();
|
||||
}
|
||||
}
|
@ -19,16 +19,16 @@
|
||||
|
||||
package org.enso.interpreter.arrow.runtime;
|
||||
|
||||
class RoundingUtil {
|
||||
final class RoundingUtil {
|
||||
|
||||
/** The mask for rounding an integer to a multiple of 8. (i.e. clear the lowest 3 bits) */
|
||||
static int ROUND_8_MASK_INT = 0xFFFFFFF8;
|
||||
static final int ROUND_8_MASK_INT = 0xFFFFFFF8;
|
||||
|
||||
/** The mask for rounding a long integer to a multiple of 8. (i.e. clear the lowest 3 bits) */
|
||||
static long ROUND_8_MASK_LONG = 0xFFFFFFFFFFFFFFF8L;
|
||||
static final long ROUND_8_MASK_LONG = 0xFFFFFFFFFFFFFFF8L;
|
||||
|
||||
/** The number of bits to shift for dividing by 8. */
|
||||
static int DIVIDE_BY_8_SHIFT_BITS = 3;
|
||||
static final int DIVIDE_BY_8_SHIFT_BITS = 3;
|
||||
|
||||
private RoundingUtil() {}
|
||||
|
||||
@ -91,13 +91,15 @@ class RoundingUtil {
|
||||
return (int) (dataBufferSize + validityBitmapSize);
|
||||
}
|
||||
|
||||
private long validityBitmapSize;
|
||||
private long dataBufferSize;
|
||||
private final long validityBitmapSize;
|
||||
private final long dataBufferSize;
|
||||
|
||||
private PaddedSize(int valueCount, SizeInBytes unit) {
|
||||
this.valueCount = valueCount;
|
||||
this.unit = unit;
|
||||
computeBufferSize(valueCount, unit);
|
||||
var pair = computeBufferSize(valueCount, unit);
|
||||
this.validityBitmapSize = pair[0];
|
||||
this.dataBufferSize = pair[1];
|
||||
}
|
||||
|
||||
private long defaultRoundedSize(long val) {
|
||||
@ -127,7 +129,7 @@ class RoundingUtil {
|
||||
return defaultRoundedSize(bufferSize);
|
||||
}
|
||||
|
||||
private void computeBufferSize(int valueCount, SizeInBytes unit) {
|
||||
private long[] computeBufferSize(int valueCount, SizeInBytes unit) {
|
||||
var typeWidth = unit.sizeInBytes();
|
||||
long bufferSize = computeCombinedBufferSize(valueCount, typeWidth);
|
||||
assert bufferSize <= Long.MAX_VALUE;
|
||||
@ -149,8 +151,7 @@ class RoundingUtil {
|
||||
--actualCount;
|
||||
} while (true);
|
||||
}
|
||||
this.validityBitmapSize = validityBufferSize;
|
||||
this.dataBufferSize = dataBufferSize;
|
||||
return new long[] {validityBufferSize, dataBufferSize};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,10 @@
|
||||
package org.enso.interpreter.arrow.runtime;
|
||||
|
||||
import com.oracle.truffle.api.dsl.GenerateInline;
|
||||
import com.oracle.truffle.api.interop.UnsupportedMessageException;
|
||||
import com.oracle.truffle.api.nodes.Node;
|
||||
|
||||
@GenerateInline(false)
|
||||
abstract class ScalarOperationNode extends Node {
|
||||
abstract Object executeOp(Object a, Object b) throws UnsupportedMessageException;
|
||||
}
|
@ -1,7 +1,13 @@
|
||||
package org.enso.interpreter.arrow.runtime;
|
||||
|
||||
import com.oracle.truffle.api.CompilerDirectives;
|
||||
import com.oracle.truffle.api.dsl.*;
|
||||
import com.oracle.truffle.api.dsl.Cached;
|
||||
import com.oracle.truffle.api.dsl.Fallback;
|
||||
import com.oracle.truffle.api.dsl.GenerateInline;
|
||||
import com.oracle.truffle.api.dsl.GenerateUncached;
|
||||
import com.oracle.truffle.api.dsl.ImportStatic;
|
||||
import com.oracle.truffle.api.dsl.NeverDefault;
|
||||
import com.oracle.truffle.api.dsl.Specialization;
|
||||
import com.oracle.truffle.api.interop.InteropLibrary;
|
||||
import com.oracle.truffle.api.interop.UnsupportedMessageException;
|
||||
import com.oracle.truffle.api.interop.UnsupportedTypeException;
|
||||
@ -17,58 +23,56 @@ import org.enso.interpreter.arrow.LogicalLayout;
|
||||
@ImportStatic(LogicalLayout.class)
|
||||
@GenerateUncached
|
||||
@GenerateInline(value = false)
|
||||
abstract class WriteToBuilderNode extends Node {
|
||||
|
||||
public abstract void executeWrite(ArrowFixedSizeArrayBuilder receiver, long index, Object value)
|
||||
throws UnsupportedTypeException;
|
||||
abstract class ValueToNumberNode extends Node {
|
||||
/**
|
||||
* Converts {@code value} to a suitable representation to be stored in an appropriate
|
||||
* DirectBuffer.
|
||||
*
|
||||
* @param unit type of layout
|
||||
* @param value a value to convert
|
||||
* @return byte, short, int or long
|
||||
* @throws UnsupportedTypeException if the conversion isn't possible
|
||||
*/
|
||||
abstract Number executeAdjust(LogicalLayout unit, Object value) throws UnsupportedTypeException;
|
||||
|
||||
@NeverDefault
|
||||
static WriteToBuilderNode build() {
|
||||
return WriteToBuilderNodeGen.create();
|
||||
static ValueToNumberNode build() {
|
||||
return ValueToNumberNodeGen.create();
|
||||
}
|
||||
|
||||
@Specialization(guards = "receiver.getUnit() == Date32")
|
||||
void doWriteDay(
|
||||
ArrowFixedSizeArrayBuilder receiver,
|
||||
long index,
|
||||
@NeverDefault
|
||||
static ValueToNumberNode getUncached() {
|
||||
return ValueToNumberNodeGen.getUncached();
|
||||
}
|
||||
|
||||
@Specialization(guards = "unit == Date32")
|
||||
Integer doDay(
|
||||
LogicalLayout unit,
|
||||
Object value,
|
||||
@Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop)
|
||||
throws UnsupportedTypeException {
|
||||
validAccess(receiver, index);
|
||||
if (iop.isNull(value)) {
|
||||
receiver.getBuffer().setNull((int) index);
|
||||
return;
|
||||
}
|
||||
if (!iop.isDate(value)) {
|
||||
throw UnsupportedTypeException.create(new Object[] {value}, "value is not a date");
|
||||
}
|
||||
var at = ArrowFixedArrayDate.typeAdjustedIndex(index, 4);
|
||||
long time;
|
||||
try {
|
||||
time = iop.asDate(value).toEpochDay();
|
||||
} catch (UnsupportedMessageException e) {
|
||||
throw UnsupportedTypeException.create(new Object[] {value}, "value is not a date");
|
||||
}
|
||||
receiver.getBuffer().putInt(at, Math.toIntExact(time));
|
||||
return Math.toIntExact(time);
|
||||
}
|
||||
|
||||
@Specialization(guards = {"receiver.getUnit() == Date64"})
|
||||
void doWriteMilliseconds(
|
||||
ArrowFixedSizeArrayBuilder receiver,
|
||||
long index,
|
||||
@Specialization(guards = {"unit == Date64"})
|
||||
Long doMilliseconds(
|
||||
LogicalLayout unit,
|
||||
Object value,
|
||||
@Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop)
|
||||
throws UnsupportedTypeException {
|
||||
validAccess(receiver, index);
|
||||
if (iop.isNull(value)) {
|
||||
receiver.getBuffer().setNull((int) index);
|
||||
return;
|
||||
}
|
||||
if (!iop.isDate(value) || !iop.isTime(value)) {
|
||||
throw UnsupportedTypeException.create(new Object[] {value}, "value is not a date and a time");
|
||||
}
|
||||
|
||||
var at = ArrowFixedArrayDate.typeAdjustedIndex(index, 8);
|
||||
if (iop.isTimeZone(value)) {
|
||||
Instant zoneDateTimeInstant;
|
||||
try {
|
||||
@ -84,7 +88,7 @@ abstract class WriteToBuilderNode extends Node {
|
||||
var secondsPlusNano =
|
||||
zoneDateTimeInstant.getEpochSecond() * ArrowFixedArrayDate.NANO_DIV
|
||||
+ zoneDateTimeInstant.getNano();
|
||||
receiver.getBuffer().putLong(at, secondsPlusNano);
|
||||
return secondsPlusNano;
|
||||
} else {
|
||||
Instant dateTime;
|
||||
try {
|
||||
@ -94,7 +98,7 @@ abstract class WriteToBuilderNode extends Node {
|
||||
}
|
||||
var secondsPlusNano =
|
||||
dateTime.getEpochSecond() * ArrowFixedArrayDate.NANO_DIV + dateTime.getNano();
|
||||
receiver.getBuffer().putLong(at, secondsPlusNano);
|
||||
return secondsPlusNano;
|
||||
}
|
||||
}
|
||||
|
||||
@ -109,109 +113,75 @@ abstract class WriteToBuilderNode extends Node {
|
||||
return date.atTime(time).toInstant(offset);
|
||||
}
|
||||
|
||||
@Specialization(guards = "receiver.getUnit() == Int8")
|
||||
void doWriteByte(
|
||||
ArrowFixedSizeArrayBuilder receiver,
|
||||
long index,
|
||||
@Specialization(guards = "unit == Int8")
|
||||
Byte doByte(
|
||||
LogicalLayout unit,
|
||||
Object value,
|
||||
@Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop)
|
||||
throws UnsupportedTypeException {
|
||||
validAccess(receiver, index);
|
||||
if (iop.isNull(value)) {
|
||||
receiver.getBuffer().setNull((int) index);
|
||||
return;
|
||||
}
|
||||
if (!iop.fitsInByte(value)) {
|
||||
throw UnsupportedTypeException.create(new Object[] {value}, "value does not fit a byte");
|
||||
}
|
||||
try {
|
||||
receiver.getBuffer().put(typeAdjustedIndex(index, receiver.getUnit()), (iop.asByte(value)));
|
||||
return iop.asByte(value);
|
||||
} catch (UnsupportedMessageException e) {
|
||||
throw UnsupportedTypeException.create(new Object[] {value}, "value is not a byte");
|
||||
}
|
||||
}
|
||||
|
||||
@Specialization(guards = "receiver.getUnit() == Int16")
|
||||
void doWriteShort(
|
||||
ArrowFixedSizeArrayBuilder receiver,
|
||||
long index,
|
||||
@Specialization(guards = "unit == Int16")
|
||||
Short doShort(
|
||||
LogicalLayout unit,
|
||||
Object value,
|
||||
@Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop)
|
||||
throws UnsupportedTypeException {
|
||||
validAccess(receiver, index);
|
||||
if (iop.isNull(value)) {
|
||||
receiver.getBuffer().setNull((int) index);
|
||||
return;
|
||||
}
|
||||
if (!iop.fitsInShort(value)) {
|
||||
throw UnsupportedTypeException.create(
|
||||
new Object[] {value}, "value does not fit a 2 byte short");
|
||||
}
|
||||
try {
|
||||
receiver
|
||||
.getBuffer()
|
||||
.putShort(typeAdjustedIndex(index, receiver.getUnit()), (iop.asShort(value)));
|
||||
return iop.asShort(value);
|
||||
} catch (UnsupportedMessageException e) {
|
||||
throw UnsupportedTypeException.create(new Object[] {value}, "value is not a short");
|
||||
}
|
||||
}
|
||||
|
||||
@Specialization(guards = "receiver.getUnit() == Int32")
|
||||
void doWriteInt(
|
||||
ArrowFixedSizeArrayBuilder receiver,
|
||||
long index,
|
||||
@Specialization(guards = "unit == Int32")
|
||||
Integer doInt(
|
||||
LogicalLayout unit,
|
||||
int value,
|
||||
@Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop)
|
||||
throws UnsupportedTypeException {
|
||||
validAccess(receiver, index);
|
||||
if (iop.isNull(value)) {
|
||||
receiver.getBuffer().setNull((int) index);
|
||||
return;
|
||||
}
|
||||
if (!iop.fitsInInt(value)) {
|
||||
throw UnsupportedTypeException.create(
|
||||
new Object[] {value}, "value does not fit a 4 byte int");
|
||||
}
|
||||
try {
|
||||
receiver.getBuffer().putInt(typeAdjustedIndex(index, receiver.getUnit()), (iop.asInt(value)));
|
||||
return iop.asInt(value);
|
||||
} catch (UnsupportedMessageException e) {
|
||||
throw UnsupportedTypeException.create(new Object[] {value}, "value is not an int");
|
||||
}
|
||||
}
|
||||
|
||||
@Specialization(guards = "receiver.getUnit() == Int64")
|
||||
public static void doWriteLong(
|
||||
ArrowFixedSizeArrayBuilder receiver,
|
||||
long index,
|
||||
long value,
|
||||
@Specialization(guards = "unit == Int64")
|
||||
static Long doLong(
|
||||
LogicalLayout unit,
|
||||
Object value,
|
||||
@Cached.Shared("interop") @CachedLibrary(limit = "1") InteropLibrary iop)
|
||||
throws UnsupportedTypeException {
|
||||
validAccess(receiver, index);
|
||||
if (iop.isNull(value)) {
|
||||
receiver.getBuffer().setNull((int) index);
|
||||
return;
|
||||
if (!iop.fitsInLong(value)) {
|
||||
throw UnsupportedTypeException.create(
|
||||
new Object[] {value}, "value does not fit a 8 byte int");
|
||||
}
|
||||
try {
|
||||
return iop.asLong(value);
|
||||
} catch (UnsupportedMessageException e) {
|
||||
throw UnsupportedTypeException.create(new Object[] {value}, "value is not a long");
|
||||
}
|
||||
receiver.getBuffer().putLong(typeAdjustedIndex(index, receiver.getUnit()), value);
|
||||
}
|
||||
|
||||
@Fallback
|
||||
void doWriteOther(ArrowFixedSizeArrayBuilder receiver, long index, Object value)
|
||||
throws UnsupportedTypeException {
|
||||
throw UnsupportedTypeException.create(new Object[] {index, value}, "unknown type of receiver");
|
||||
}
|
||||
|
||||
private static void validAccess(ArrowFixedSizeArrayBuilder receiver, long index)
|
||||
throws UnsupportedTypeException {
|
||||
if (receiver.isSealed()) {
|
||||
throw UnsupportedTypeException.create(
|
||||
new Object[] {receiver}, "receiver is not an unsealed buffer");
|
||||
}
|
||||
if (index >= receiver.getSize() || index < 0) {
|
||||
throw UnsupportedTypeException.create(new Object[] {index}, "index is out of range");
|
||||
}
|
||||
}
|
||||
|
||||
private static int typeAdjustedIndex(long index, SizeInBytes unit) {
|
||||
return ArrowFixedArrayDate.typeAdjustedIndex(index, unit.sizeInBytes());
|
||||
Number doOther(LogicalLayout unit, Object value) throws UnsupportedTypeException {
|
||||
throw UnsupportedTypeException.create(new Object[] {unit, value}, "unknown type");
|
||||
}
|
||||
}
|
@ -0,0 +1,142 @@
|
||||
package org.enso.interpreter.arrow;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
import org.graalvm.polyglot.Context;
|
||||
import org.graalvm.polyglot.io.IOAccess;
|
||||
import org.junit.AfterClass;
|
||||
import org.junit.BeforeClass;
|
||||
import org.junit.Test;
|
||||
|
||||
public class AddArrowTest {
|
||||
private static Context ctx;
|
||||
|
||||
@BeforeClass
|
||||
public static void initEnsoContext() {
|
||||
ctx =
|
||||
Context.newBuilder()
|
||||
.allowExperimentalOptions(true)
|
||||
.allowIO(IOAccess.ALL)
|
||||
.out(System.out)
|
||||
.err(System.err)
|
||||
.allowAllAccess(true)
|
||||
.build();
|
||||
}
|
||||
|
||||
@AfterClass
|
||||
public static void closeEnsoContext() throws Exception {
|
||||
if (ctx != null) {
|
||||
ctx.close();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void addTwoInt8ArrowArrays() {
|
||||
var arrow = ctx.getEngine().getLanguages().get("arrow");
|
||||
assertNotNull("Arrow is available", arrow);
|
||||
var int8Constr = ctx.eval("arrow", "new[Int8]");
|
||||
assertNotNull(int8Constr);
|
||||
|
||||
var arrLength = 10;
|
||||
var builder1 = int8Constr.newInstance(arrLength);
|
||||
var builder2 = int8Constr.newInstance(arrLength);
|
||||
|
||||
for (var i = 0; i < arrLength; i++) {
|
||||
var ni = arrLength - i - 1;
|
||||
var v = i * i;
|
||||
builder1.invokeMember("append", i, (byte) v);
|
||||
builder2.invokeMember("append", ni, (byte) v);
|
||||
}
|
||||
|
||||
var arr1 = builder1.invokeMember("build");
|
||||
assertEquals("Right size of arr1", arrLength, arr1.getArraySize());
|
||||
var arr2 = builder2.invokeMember("build");
|
||||
assertEquals("Right size of arr2", arrLength, arr2.getArraySize());
|
||||
|
||||
var int8Plus = ctx.eval("arrow", "+[Int8]");
|
||||
var resultArr = int8Plus.execute(arr1, arr2);
|
||||
|
||||
assertTrue("Result is an array", resultArr.hasArrayElements());
|
||||
assertEquals("Right size", arrLength, resultArr.getArraySize());
|
||||
|
||||
for (var i = 0; i < arrLength; i++) {
|
||||
var ni = arrLength - i - 1;
|
||||
var v1 = resultArr.getArrayElement(i).asLong();
|
||||
var v2 = resultArr.getArrayElement(ni).asLong();
|
||||
|
||||
assertEquals("Values at " + i + " and " + ni + " are the same", v1, v2);
|
||||
assertTrue("Values are always bigger than zero: " + v1, v1 > 0);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void addTwoInt64ArrowArraysWithNulls() {
|
||||
var arrow = ctx.getEngine().getLanguages().get("arrow");
|
||||
assertNotNull("Arrow is available", arrow);
|
||||
var constr = ctx.eval("arrow", "new[Int64]");
|
||||
assertNotNull(constr);
|
||||
|
||||
var arrLength = 10;
|
||||
var builder1 = constr.newInstance(arrLength);
|
||||
for (int i = 0; i < arrLength; i++) {
|
||||
if (i % 7 < 2) {
|
||||
builder1.invokeMember("append", Long.MAX_VALUE);
|
||||
} else {
|
||||
builder1.invokeMember("append", i);
|
||||
}
|
||||
}
|
||||
|
||||
var builder2 = constr.newInstance(arrLength);
|
||||
for (var i = 0; i < arrLength; i++) {
|
||||
builder2.invokeMember("append", 10 + i);
|
||||
}
|
||||
|
||||
var arr1 = builder1.invokeMember("build");
|
||||
assertEquals("Right size of arr1", arrLength, arr1.getArraySize());
|
||||
var addArr = builder2.invokeMember("build");
|
||||
assertEquals("Right size of arr2", arrLength, addArr.getArraySize());
|
||||
|
||||
var plus = ctx.eval("arrow", "+[Int64]");
|
||||
var res1 = plus.execute(arr1, addArr);
|
||||
|
||||
assertTrue("Result is an array", res1.hasArrayElements());
|
||||
assertEquals("Right size", arrLength, res1.getArraySize());
|
||||
|
||||
assertTrue("is null", res1.getArrayElement(0).isNull());
|
||||
assertTrue("is null", res1.getArrayElement(1).isNull());
|
||||
assertTrue("is null", res1.getArrayElement(7).isNull());
|
||||
assertTrue("is null", res1.getArrayElement(8).isNull());
|
||||
|
||||
var countNulls = 0;
|
||||
for (var i = 0; i < arrLength; i++) {
|
||||
var v = res1.getArrayElement(i);
|
||||
if (v.isNull()) {
|
||||
countNulls++;
|
||||
} else {
|
||||
assertEquals(i * 2 + 10, v.asLong());
|
||||
}
|
||||
}
|
||||
assertEquals("Four nulls", 4, countNulls);
|
||||
|
||||
var res2 = plus.execute(res1, addArr);
|
||||
|
||||
assertTrue("Result is an array", res2.hasArrayElements());
|
||||
assertEquals("Right size", arrLength, res2.getArraySize());
|
||||
|
||||
assertTrue("is null", res2.getArrayElement(0).isNull());
|
||||
assertTrue("is null", res2.getArrayElement(1).isNull());
|
||||
assertTrue("is null", res2.getArrayElement(7).isNull());
|
||||
assertTrue("is null", res2.getArrayElement(8).isNull());
|
||||
|
||||
var countNullsAgain = 0;
|
||||
for (var i = 0; i < arrLength; i++) {
|
||||
var v = res2.getArrayElement(i);
|
||||
if (v.isNull()) {
|
||||
countNullsAgain++;
|
||||
} else {
|
||||
assertEquals(i * 3 + 20, v.asLong());
|
||||
}
|
||||
}
|
||||
assertEquals("Four nulls", 4, countNullsAgain);
|
||||
}
|
||||
}
|
@ -19,6 +19,7 @@ import org.apache.arrow.vector.BaseFixedWidthVector;
|
||||
import org.apache.arrow.vector.BigIntVector;
|
||||
import org.apache.arrow.vector.IntVector;
|
||||
import org.graalvm.polyglot.Context;
|
||||
import org.graalvm.polyglot.PolyglotException;
|
||||
import org.graalvm.polyglot.Value;
|
||||
import org.graalvm.polyglot.io.IOAccess;
|
||||
import org.junit.AfterClass;
|
||||
@ -84,7 +85,7 @@ public class VerifyArrowTest {
|
||||
date32ArrayBuilder.invokeMember("build");
|
||||
assertFalse(date32ArrayBuilder.canInvokeMember("append"));
|
||||
assertThrows(
|
||||
UnsupportedOperationException.class,
|
||||
PolyglotException.class,
|
||||
() -> finalDate32ArrayBuilder.invokeMember("append", startDateTime));
|
||||
assertFalse(date32Array.canInvokeMember("append"));
|
||||
}
|
||||
@ -153,6 +154,33 @@ public class VerifyArrowTest {
|
||||
assertEquals((byte) 5, v.asByte());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void arrowInt64() {
|
||||
var arrow = ctx.getEngine().getLanguages().get("arrow");
|
||||
assertNotNull("Arrow is available", arrow);
|
||||
var constr = ctx.eval("arrow", "new[Int64]");
|
||||
assertNotNull(constr);
|
||||
|
||||
var arrLength = 48;
|
||||
Value builder = constr.newInstance(arrLength);
|
||||
for (var i = 0; i < arrLength; i++) {
|
||||
builder.invokeMember("append", i);
|
||||
}
|
||||
var arr = builder.invokeMember("build");
|
||||
assertEquals(arrLength, arr.getArraySize());
|
||||
for (var i = 0; i < arrLength; i++) {
|
||||
var ith = arr.getArrayElement(i);
|
||||
assertEquals("Checking value at " + i, i, ith.asLong());
|
||||
}
|
||||
|
||||
var plus = ctx.eval("arrow", "+[Int64]");
|
||||
var doubled = plus.execute(arr, arr);
|
||||
for (var i = 0; i < arrLength; i++) {
|
||||
var ith = doubled.getArrayElement(i);
|
||||
assertEquals("Checking double value at " + i, 2 * i, ith.asInt());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void castInt() {
|
||||
var typeLength = LogicalLayout.Int32;
|
||||
|
@ -9,7 +9,6 @@ import org.enso.interpreter.node.ExpressionNode;
|
||||
import org.enso.interpreter.runtime.callable.argument.CallArgument;
|
||||
import org.enso.interpreter.runtime.callable.argument.CallArgumentInfo;
|
||||
import org.enso.interpreter.runtime.callable.function.Function;
|
||||
import org.enso.interpreter.runtime.state.State;
|
||||
|
||||
/**
|
||||
* This node is responsible for organising callable calls so that they are ready to be made.
|
||||
@ -92,10 +91,10 @@ public class ApplicationNode extends ExpressionNode {
|
||||
*/
|
||||
@Override
|
||||
public Object executeGeneric(VirtualFrame frame) {
|
||||
State state = Function.ArgumentsHelper.getState(frame.getArguments());
|
||||
Object[] evaluatedArguments = evaluateArguments(frame);
|
||||
return this.invokeCallableNode.execute(
|
||||
this.callable.executeGeneric(frame), frame, state, evaluatedArguments);
|
||||
var state = Function.ArgumentsHelper.getState(frame.getArguments());
|
||||
var evaluatedArguments = evaluateArguments(frame);
|
||||
var self = this.callable.executeGeneric(frame);
|
||||
return this.invokeCallableNode.execute(self, frame, state, evaluatedArguments);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -1,6 +1,7 @@
|
||||
package org.enso.interpreter.node.callable;
|
||||
|
||||
import com.oracle.truffle.api.CompilerDirectives;
|
||||
import com.oracle.truffle.api.dsl.Bind;
|
||||
import com.oracle.truffle.api.dsl.Cached;
|
||||
import com.oracle.truffle.api.dsl.Cached.Shared;
|
||||
import com.oracle.truffle.api.dsl.Fallback;
|
||||
@ -12,6 +13,7 @@ import com.oracle.truffle.api.interop.UnsupportedMessageException;
|
||||
import com.oracle.truffle.api.interop.UnsupportedTypeException;
|
||||
import com.oracle.truffle.api.library.CachedLibrary;
|
||||
import com.oracle.truffle.api.nodes.Node;
|
||||
import com.oracle.truffle.api.profiles.InlinedBranchProfile;
|
||||
import com.oracle.truffle.api.source.SourceSection;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.locks.Lock;
|
||||
@ -337,32 +339,37 @@ public abstract class InvokeCallableNode extends BaseNode {
|
||||
"!types.hasSpecialDispatch(self)",
|
||||
"iop.isExecutable(self)",
|
||||
})
|
||||
Object doPolyglot(
|
||||
static Object doPolyglot(
|
||||
Object self,
|
||||
VirtualFrame frame,
|
||||
State state,
|
||||
Object[] arguments,
|
||||
@Bind("$node") Node node,
|
||||
@CachedLibrary(limit = "3") InteropLibrary iop,
|
||||
@Shared("warnings") @CachedLibrary(limit = "3") WarningsLibrary warnings,
|
||||
@CachedLibrary(limit = "3") TypesLibrary types,
|
||||
@Cached ThunkExecutorNode thunkNode) {
|
||||
var errors = EnsoContext.get(this).getBuiltins().error();
|
||||
@Cached ThunkExecutorNode thunkNode,
|
||||
@Cached InlinedBranchProfile errorNeedsToBeReported) {
|
||||
var errors = EnsoContext.get(node).getBuiltins().error();
|
||||
try {
|
||||
for (int i = 0; i < arguments.length; i++) {
|
||||
arguments[i] = thunkNode.executeThunk(frame, arguments[i], state, TailStatus.NOT_TAIL);
|
||||
}
|
||||
return iop.execute(self, arguments);
|
||||
} catch (UnsupportedTypeException ex) {
|
||||
errorNeedsToBeReported.enter(node);
|
||||
var err = errors.makeUnsupportedArgumentsError(ex.getSuppliedValues(), ex.getMessage());
|
||||
throw new PanicException(err, this);
|
||||
throw new PanicException(err, node);
|
||||
} catch (ArityException ex) {
|
||||
errorNeedsToBeReported.enter(node);
|
||||
var err =
|
||||
errors.makeArityError(
|
||||
ex.getExpectedMinArity(), ex.getExpectedMaxArity(), arguments.length);
|
||||
throw new PanicException(err, this);
|
||||
throw new PanicException(err, node);
|
||||
} catch (UnsupportedMessageException ex) {
|
||||
errorNeedsToBeReported.enter(node);
|
||||
var err = errors.makeNotInvokable(self);
|
||||
throw new PanicException(err, this);
|
||||
throw new PanicException(err, node);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -8,45 +8,118 @@ polyglot java import java.lang.Long as Java_Long
|
||||
|
||||
options = Bench.options . set_warmup (Bench.phase_conf 3 5) . set_measure (Bench.phase_conf 3 5)
|
||||
|
||||
|
||||
create_table : Table
|
||||
create_table num_rows =
|
||||
create_vectors num_rows =
|
||||
x = Vector.new num_rows i->
|
||||
i+1
|
||||
y = Vector.new num_rows i->
|
||||
if i % 10 < 2 then Java_Long.MAX_VALUE else i+1
|
||||
u = Vector.new num_rows i->
|
||||
10 + (i % 100)
|
||||
z = Vector.new num_rows i->
|
||||
if i % 10 < 2 then Nothing else i+1
|
||||
|
||||
t = Table.new [["X", x], ["Y", y], ["U", u]]
|
||||
[x, y, u, z]
|
||||
|
||||
assert condition =
|
||||
if condition.not then Panic.throw "Assertion failed"
|
||||
create_table : Table
|
||||
create_table num_rows =
|
||||
v = create_vectors num_rows
|
||||
x = v.at 0
|
||||
y = v.at 1
|
||||
u = v.at 2
|
||||
z = v.at 3
|
||||
|
||||
assert ((t.at "X" . value_type) == Value_Type.Integer)
|
||||
assert ((t.at "Y" . value_type) == Value_Type.Integer)
|
||||
assert ((t.at "U" . value_type) == Value_Type.Integer)
|
||||
t = Table.new [["X", x], ["Y", y], ["U", u], ["Z", z]]
|
||||
|
||||
Runtime.assert ((t.at "X" . value_type) == Value_Type.Integer)
|
||||
Runtime.assert ((t.at "Y" . value_type) == Value_Type.Integer)
|
||||
Runtime.assert ((t.at "U" . value_type) == Value_Type.Integer)
|
||||
Runtime.assert ((t.at "Z" . value_type) == Value_Type.Integer)
|
||||
t
|
||||
|
||||
create_arrow_columns num_rows =
|
||||
column_to_arrow v:Vector -> Array =
|
||||
builder = int64_new.new v.length
|
||||
v.map e-> builder.append e
|
||||
builder.build
|
||||
|
||||
v = create_vectors num_rows
|
||||
|
||||
x = column_to_arrow (v.at 0)
|
||||
y = column_to_arrow (v.at 1)
|
||||
u = column_to_arrow (v.at 2)
|
||||
z = column_to_arrow (v.at 3)
|
||||
[int64_plus, x, y, u, z]
|
||||
|
||||
foreign arrow int64_new = """
|
||||
new[Int64]
|
||||
|
||||
foreign arrow int64_plus = """
|
||||
+[Int64]
|
||||
|
||||
type Data
|
||||
Value ~table
|
||||
private Value ~table ~arrow
|
||||
|
||||
create num_rows = Data.Value (create_table num_rows)
|
||||
arrow_plus self = self.arrow.at 0
|
||||
arrow_x self = self.arrow.at 1
|
||||
arrow_y self = self.arrow.at 2
|
||||
arrow_u self = self.arrow.at 3
|
||||
arrow_z self = self.arrow.at 4
|
||||
|
||||
create num_rows = Data.Value (create_table num_rows) (create_arrow_columns num_rows)
|
||||
|
||||
collect_benches = Bench.build builder->
|
||||
column_arithmetic_plus_fitting d =
|
||||
(d.table.at "X") + (d.table.at "U")
|
||||
|
||||
column_arithmetic_plus_overflowing d =
|
||||
(d.table.at "Y") + (d.table.at "U")
|
||||
|
||||
column_arithmetic_plus_nothing d =
|
||||
(d.table.at "Z") + (d.table.at "U")
|
||||
|
||||
column_arithmetic_multiply_fitting d =
|
||||
(d.table.at "X") * (d.table.at "U")
|
||||
|
||||
column_arithmetic_multiply_overflowing d =
|
||||
(d.table.at "Y") * (d.table.at "U")
|
||||
|
||||
arrow_arithmetic_plus_fitting d =
|
||||
d.arrow_plus d.arrow_x d.arrow_u
|
||||
|
||||
arrow_arithmetic_plus_overflowing d =
|
||||
d.arrow_plus d.arrow_y d.arrow_u
|
||||
|
||||
arrow_arithmetic_plus_nothing d =
|
||||
d.arrow_plus d.arrow_y d.arrow_u
|
||||
|
||||
num_rows = 1000000
|
||||
data = Data.create num_rows
|
||||
|
||||
Runtime.assert ((column_arithmetic_plus_fitting data . to_vector) == (arrow_arithmetic_plus_fitting data)) "Column and arrow correctness check one"
|
||||
Runtime.assert ((column_arithmetic_plus_overflowing data . to_vector) == (arrow_arithmetic_plus_overflowing data)) "Column and arrow correctness check two"
|
||||
Runtime.assert ((column_arithmetic_plus_nothing data . to_vector) == (arrow_arithmetic_plus_nothing data)) "Column and arrow correctness check three"
|
||||
|
||||
builder.group ("Column_Arithmetic_" + num_rows.to_text) options group_builder->
|
||||
group_builder.specify "Plus_Fitting" <|
|
||||
(data.table.at "X") + (data.table.at "U")
|
||||
column_arithmetic_plus_fitting data
|
||||
group_builder.specify "Plus_Overflowing" <|
|
||||
(data.table.at "Y") + (data.table.at "U")
|
||||
column_arithmetic_plus_overflowing data
|
||||
group_builder.specify "Plus_Nothing" <|
|
||||
column_arithmetic_plus_nothing data
|
||||
|
||||
group_builder.specify "Multiply_Fitting" <|
|
||||
(data.table.at "X") * (data.table.at "U")
|
||||
column_arithmetic_multiply_fitting data
|
||||
group_builder.specify "Multiply_Overflowing" <|
|
||||
(data.table.at "Y") * (data.table.at "U")
|
||||
column_arithmetic_multiply_overflowing data
|
||||
|
||||
builder.group ("Arrow_Arithmetic_" + num_rows.to_text) options group_builder->
|
||||
group_builder.specify "Plus_Fitting" <|
|
||||
arrow_arithmetic_plus_fitting data
|
||||
|
||||
group_builder.specify "Plus_Overflowing" <|
|
||||
arrow_arithmetic_plus_overflowing data
|
||||
|
||||
group_builder.specify "Plus_Nothing" <|
|
||||
arrow_arithmetic_plus_nothing data
|
||||
|
||||
main = collect_benches . run_main
|
||||
|
Loading…
Reference in New Issue
Block a user