diff --git a/docs/type-concepts.md b/docs/type-concepts.md index 193c4c954..f707dd608 100644 --- a/docs/type-concepts.md +++ b/docs/type-concepts.md @@ -178,7 +178,7 @@ In addition to assignment-based type narrowing, Pyright supports the following t * `x[I] == V` and `x[I] != V` (where I and V are literal expressions and x is a known-length tuple that is distinguished by the index indicated by I) * `x[I] is None` and `x[I] is not None` (where I is a literal expression and x is a known-length tuple that is distinguished by the index indicated by I) * `len(x) == L` and `len(x) != L` (where x is tuple and L is a literal integer) -* `x in y` (where y is instance of list, set, frozenset, deque, or tuple) +* `x in y` or `x not in y` (where y is instance of list, set, frozenset, deque, or tuple) * `S in D` and `S not in D` (where S is a string literal and D is a TypedDict) * `isinstance(x, T)` (where T is a type or a tuple of types) * `issubclass(x, T)` (where T is a type or a tuple of types) diff --git a/packages/pyright-internal/src/analyzer/binder.ts b/packages/pyright-internal/src/analyzer/binder.ts index 5d3576cd3..b7877ce65 100644 --- a/packages/pyright-internal/src/analyzer/binder.ts +++ b/packages/pyright-internal/src/analyzer/binder.ts @@ -2814,8 +2814,8 @@ export class Binder extends ParseTreeWalker { } } - // Look for "X in Y". - if (expression.operator === OperatorType.In) { + // Look for "X in Y" or "X not in Y". + if (expression.operator === OperatorType.In || expression.operator === OperatorType.NotIn) { return this._isNarrowingExpression( expression.leftExpression, expressionList, diff --git a/packages/pyright-internal/src/analyzer/typeGuards.ts b/packages/pyright-internal/src/analyzer/typeGuards.ts index 47a62d605..f84ee33f1 100644 --- a/packages/pyright-internal/src/analyzer/typeGuards.ts +++ b/packages/pyright-internal/src/analyzer/typeGuards.ts @@ -389,17 +389,20 @@ export function getTypeNarrowingCallback( } } - if (testExpression.operator === OperatorType.In) { - // Look for "x in y" where y is one of several built-in types. - if (isPositiveTest && ParseTreeUtils.isMatchingExpression(reference, testExpression.leftExpression)) { - const rightType = evaluator.getTypeOfExpression(testExpression.rightExpression).type; - return (type: Type) => { - return narrowTypeForContains(evaluator, type, rightType); - }; - } - } - if (testExpression.operator === OperatorType.In || testExpression.operator === OperatorType.NotIn) { + // Look for "x in y" or "x not in y" where y is one of several built-in types. + if (ParseTreeUtils.isMatchingExpression(reference, testExpression.leftExpression)) { + const rightType = evaluator.getTypeOfExpression(testExpression.rightExpression).type; + const adjIsPositiveTest = + testExpression.operator === OperatorType.In ? isPositiveTest : !isPositiveTest; + + if (adjIsPositiveTest) { + return (type: Type) => { + return narrowTypeForContains(evaluator, type, rightType); + }; + } + } + if (ParseTreeUtils.isMatchingExpression(reference, testExpression.rightExpression)) { // Look for in y where y is a union that contains // one or more TypedDicts. @@ -1342,16 +1345,18 @@ function narrowTypeForContains(evaluator: TypeEvaluator, referenceType: Type, co const elementTypeWithoutLiteral = stripLiteralValue(elementType); const narrowedType = mapSubtypes(referenceType, (referenceSubtype) => { - if (isAnyOrUnknown(referenceSubtype)) { + const concreteReferenceType = evaluator.makeTopLevelTypeVarsConcrete(referenceSubtype); + + if (isAnyOrUnknown(concreteReferenceType)) { canNarrow = false; return referenceSubtype; } - if (evaluator.assignType(elementType, referenceSubtype)) { + if (evaluator.assignType(elementType, concreteReferenceType)) { return referenceSubtype; } - if (evaluator.assignType(elementTypeWithoutLiteral, referenceSubtype)) { + if (evaluator.assignType(elementTypeWithoutLiteral, concreteReferenceType)) { return mapSubtypes(elementType, (elementSubtype) => { if (isClassInstance(elementSubtype) && isSameWithoutLiteralValue(referenceSubtype, elementSubtype)) { return elementSubtype; diff --git a/packages/pyright-internal/src/tests/samples/typeNarrowingIn1.py b/packages/pyright-internal/src/tests/samples/typeNarrowingIn1.py index 4df0e43f7..875bfad36 100644 --- a/packages/pyright-internal/src/tests/samples/typeNarrowingIn1.py +++ b/packages/pyright-internal/src/tests/samples/typeNarrowingIn1.py @@ -69,3 +69,17 @@ def func2(a: Literal[1, 2, 3]): reveal_type(a, expected_text="Literal[1, 2]") else: reveal_type(a, expected_text="Literal[1, 2, 3]") + + +def func3(val: str | None, container: frozenset[str]): + if val in container: + reveal_type(val, expected_text="str") + else: + reveal_type(val, expected_text="str | None") + + +def func4(val: str | None, container: list[str]): + if val not in container: + reveal_type(val, expected_text="str | None") + else: + reveal_type(val, expected_text="str")