Added missing support for matrix multiply operator.

This commit is contained in:
Eric Traut 2019-04-01 17:48:27 -07:00
parent 38feba9610
commit 329ad74b2f
2 changed files with 31 additions and 11 deletions

View File

@ -1366,6 +1366,15 @@ export class ExpressionEvaluator {
private _getTypeFromUnaryExpression(node: UnaryExpressionNode, flags: EvaluatorFlags): TypeResult {
let exprType = this._getTypeFromExpression(node.expression, flags).type;
// Map unary operators to magic functions. Note that the bitwise
// invert has two magic functions that are aliases of each other.
const unaryOperatorMap: { [operator: number]: [string, string] } = {
[OperatorType.Add]: ['__pos__', ''],
[OperatorType.Subtract]: ['__neg__', ''],
[OperatorType.Not]: ['__not__', ''],
[OperatorType.BitwiseInvert]: ['__inv__', '__invert__']
};
let type: Type;
if (exprType.isAny()) {
type = exprType;
@ -1426,14 +1435,15 @@ export class ExpressionEvaluator {
this._getTypeFromExpression(node.rightExpression, flags);
});
const arithmeticOperatorMap: { [operator: number]: [string, string] } = {
[OperatorType.Add]: ['__add__', '__radd__'],
[OperatorType.Subtract]: ['__sub__', '__rsub__'],
[OperatorType.Multiply]: ['__mul__', '__rmul__'],
[OperatorType.FloorDivide]: ['__floordiv__', '__rfloordiv__'],
[OperatorType.Divide]: ['__truediv__', '__rtruediv__'],
[OperatorType.Mod]: ['__mod__', '__rmod__'],
[OperatorType.Power]: ['__power__', '__rpower__']
const arithmeticOperatorMap: { [operator: number]: [string, string, boolean] } = {
[OperatorType.Add]: ['__add__', '__radd__', true],
[OperatorType.Subtract]: ['__sub__', '__rsub__', true],
[OperatorType.Multiply]: ['__mul__', '__rmul__', true],
[OperatorType.FloorDivide]: ['__floordiv__', '__rfloordiv__', true],
[OperatorType.Divide]: ['__truediv__', '__rtruediv__', true],
[OperatorType.Mod]: ['__mod__', '__rmod__', true],
[OperatorType.Power]: ['__power__', '__rpower__', true],
[OperatorType.MatrixMultiply]: ['__matmul__', '', false]
};
const bitwiseOperatorMap: { [operator: number]: [string, string] } = {
@ -1478,17 +1488,19 @@ export class ExpressionEvaluator {
return foundMatch;
});
};
const leftClassMatches = getTypeMatch(leftType.getClassType());
const rightClassMatches = getTypeMatch(rightType.getClassType());
const supportsBuiltInTypes = arithmeticOperatorMap[node.operator][2];
if (leftClassMatches[0] && rightClassMatches[0]) {
if (supportsBuiltInTypes && leftClassMatches[0] && rightClassMatches[0]) {
// If they're both int types, the result is an int.
type = new ObjectType(builtInClassTypes[0]!);
} else if (leftClassMatches[1] && rightClassMatches[1]) {
} else if (supportsBuiltInTypes && leftClassMatches[1] && rightClassMatches[1]) {
// If they're both floats or one is a float and one is an int,
// the result is a float.
type = new ObjectType(builtInClassTypes[1]!);
} else if (leftClassMatches[2] && rightClassMatches[2]) {
} else if (supportsBuiltInTypes && leftClassMatches[2] && rightClassMatches[2]) {
// If one is complex and the other is complex, float or int,
// the result is complex.
type = new ObjectType(builtInClassTypes[2]!);

View File

@ -45,4 +45,12 @@ def returnsComplex1() -> complex:
return a + b % (b / a - c // a)
a = 3
b = 4
# This should generate an error because matrix multiply
# isn't supported for int.
c = (a @ b)