Added support for x[K] is V and x[K] is not V type narrowing forms. This addresses https://github.com/microsoft/pyright/issues/4453.

This commit is contained in:
Eric Traut 2023-01-13 16:53:18 -08:00
parent 8544b19923
commit a2784020f2
3 changed files with 47 additions and 2 deletions

View File

@ -176,7 +176,7 @@ In addition to assignment-based type narrowing, Pyright supports the following t
* `x.y is None` and `x.y is not None` (where x is a type that is distinguished by a field with a None)
* `x.y is E` and `x.y is not E` (where E is a literal enum or bool and x is a type that is distinguished by a field with a literal type)
* `x.y == L` and `x.y != L` (where L is a literal expression and x is a type that is distinguished by a field or property with a literal type)
* `x[K] == V` and `x[K] != V` (where K and V are literal expressions and x is a type that is distinguished by a TypedDict field with a literal type)
* `x[K] == V`, `x[K] != V`, `x[K] is V`, and `x[K] is not V` (where K and V are literal expressions and x is a type that is distinguished by a TypedDict field with a literal type)
* `x[I] == V` and `x[I] != V` (where I and V are literal expressions and x is a known-length tuple that is distinguished by the index indicated by I)
* `x[I] is None` and `x[I] is not None` (where I is a literal expression and x is a known-length tuple that is distinguished by the index indicated by I)
* `len(x) == L` and `len(x) != L` (where x is tuple and L is a literal integer)

View File

@ -246,6 +246,40 @@ export function getTypeNarrowingCallback(
};
}
}
// Look for X[<literal>] is <literal> or X[<literal>] is not <literal>
if (
testExpression.leftExpression.nodeType === ParseNodeType.Index &&
testExpression.leftExpression.items.length === 1 &&
!testExpression.leftExpression.trailingComma &&
testExpression.leftExpression.items[0].argumentCategory === ArgumentCategory.Simple &&
ParseTreeUtils.isMatchingExpression(reference, testExpression.leftExpression.baseExpression)
) {
const indexTypeResult = evaluator.getTypeOfExpression(
testExpression.leftExpression.items[0].valueExpression
);
const indexType = indexTypeResult.type;
if (isClassInstance(indexType) && isLiteralType(indexType)) {
if (ClassType.isBuiltIn(indexType, 'str')) {
const rightType = evaluator.getTypeOfExpression(testExpression.rightExpression).type;
if (isClassInstance(rightType) && rightType.literalValue !== undefined) {
return (type: Type) => {
return {
type: narrowTypeForDiscriminatedDictEntryComparison(
evaluator,
type,
indexType,
rightType,
adjIsPositiveTest
),
isIncomplete: !!indexTypeResult.isIncomplete,
};
};
}
}
}
}
}
if (equalsOrNotEqualsOperator) {

View File

@ -23,7 +23,7 @@ class OtherEvent(TypedDict):
Event = Union[NewJobEvent, CancelJobEvent, OtherEvent]
def process_event(event: Event) -> None:
def process_event1(event: Event) -> None:
if event["tag"] == "new-job":
reveal_type(event, expected_text="NewJobEvent")
event["job_name"]
@ -33,3 +33,14 @@ def process_event(event: Event) -> None:
else:
reveal_type(event, expected_text="OtherEvent")
event["message"]
def process_event2(event: Event) -> None:
if event["tag"] is "new-job":
reveal_type(event, expected_text="NewJobEvent")
event["job_name"]
elif event["tag"] is 2:
reveal_type(event, expected_text="CancelJobEvent")
event["job_id"]
else:
reveal_type(event, expected_text="OtherEvent")
event["message"]