diff --git a/docs/type-concepts.md b/docs/type-concepts.md index 3b9b4a03a..84f62c0e2 100644 --- a/docs/type-concepts.md +++ b/docs/type-concepts.md @@ -187,6 +187,55 @@ def func2(val: Optional[int]): In the example of `func1`, the type was narrowed in both the positive and negative cases. In the example of `func2`, the type was narrowed only the positive case because the type of `val` might be either `int` (specifically, a value of 0) or `None` in the negative case. +### Narrowing Based on a Local Variable + +Pyright also supports a type guard expression `c`, where `c` is an identifier that refers to a local variable that is assigned one of the above supported type guard expression forms. For example, `c = a is not None` can be used to narrow the expression `a`. This pattern is supported only in cases where `c` is a local variable within a module or function scope and is assigned a value only once. It is also limited to cases where expression `a` is a simple identifier (as opposed to a member access expression or subscript expression), is local to the function or module scope, and is assigned only once within the scope. Unary `not` operators are allowed for expression `a`, but binary `and` and `or` are not. + +```python +def func1(x: str | None): + is_str = x is not None + + if is_str: + reveal_type(x) # str + else: + reveal_type(x) # None +``` + +```python +def func2(val: str | bytes): + is_str = not isinstance(val, bytes) + + if not is_str: + reveal_type(val) # bytes + else: + reveal_type(val) # str +``` + +```python +def func3(x: List[str | None]) -> str: + is_str = x[0] is not None + + if is_str: + # This technique doesn't work for subscript expressions, + # so x[0] is not narrowed in this case. + reveal_type(x[0]) # str | None +``` + +```python +def func4(x: str | None): + is_str = x is not None + + if is_str: + # This technique doesn't work in cases where the target + # expression is assigned elsewhere. Here `x` is assigned + # elsewhere in the function, so its type is not narrowed + # in this case. + reveal_type(x) # str | None + + x = "" +``` + + ### Narrowing for Implied Else When an “if” or “elif” clause is used without a corresponding “else”, Pyright will generally assume that the code can “fall through” without executing the “if” or “elif” block. However, there are cases where the analyzer can determine that a fall-through is not possible because the “if” or “elif” is guaranteed to be executed based on type analysis. diff --git a/packages/pyright-internal/src/analyzer/typeGuards.ts b/packages/pyright-internal/src/analyzer/typeGuards.ts index 7c3e0f72c..84e82405b 100644 --- a/packages/pyright-internal/src/analyzer/typeGuards.ts +++ b/packages/pyright-internal/src/analyzer/typeGuards.ts @@ -9,10 +9,20 @@ * negative ("else") narrowing cases. */ -import { ArgumentCategory, ExpressionNode, ParameterCategory, ParseNodeType } from '../parser/parseNodes'; +import { + ArgumentCategory, + ExpressionNode, + isExpressionNode, + NameNode, + ParameterCategory, + ParseNodeType, +} from '../parser/parseNodes'; import { KeywordType, OperatorType } from '../parser/tokenizerTypes'; import { getFileInfo } from './analyzerNodeInfo'; +import { Declaration, DeclarationType } from './declaration'; import * as ParseTreeUtils from './parseTreeUtils'; +import { ScopeType } from './scope'; +import { getScopeForNode } from './scopeUtils'; import { Symbol, SymbolFlags } from './symbol'; import { getTypedDictMembersForClass } from './typedDicts'; import { EvaluatorFlags, TypeEvaluator } from './typeEvaluatorTypes'; @@ -435,9 +445,64 @@ export function getTypeNarrowingCallback( }; } + // Is this a reference to a local variable that was assigned a value + // that can inform type narrowing of the reference expression? + if (testExpression.nodeType === ParseNodeType.Name && reference.nodeType === ParseNodeType.Name) { + // Make sure the reference expression is a constant parameter or variable. + // If it is modified somewhere within the scope, it's not safe to apply + // this form of type narrowing. + if (getDeclForLocalConst(evaluator, reference) !== undefined) { + const testExprDecl = getDeclForLocalConst(evaluator, testExpression); + + if (testExprDecl && testExprDecl.type === DeclarationType.Variable) { + const initNode = testExprDecl.inferredTypeSource; + + if (initNode && initNode !== testExpression && isExpressionNode(initNode)) { + return getTypeNarrowingCallback(evaluator, reference, initNode, isPositiveTest); + } + } + } + } + + // We normally won't find a "not" operator here because they are stripped out + // by the binder when it creates condition flow nodes, but we can find this + // in the case of local variables type narrowing. + if (testExpression.nodeType === ParseNodeType.UnaryOperation) { + if (testExpression.operator === OperatorType.Not) { + return getTypeNarrowingCallback(evaluator, reference, testExpression.expression, !isPositiveTest); + } + } + return undefined; } +// Determines whether the symbol is a local variable or parameter within +// the current scope _and_ is a constant (assigned only once). If so, it +// returns the declaration for the symbol. +function getDeclForLocalConst(evaluator: TypeEvaluator, name: NameNode): Declaration | undefined { + const scope = getScopeForNode(name); + if (scope?.type !== ScopeType.Function && scope?.type !== ScopeType.Module) { + return undefined; + } + + const symbol = scope.lookUpSymbol(name.value); + if (!symbol) { + return undefined; + } + + const decls = symbol.getDeclarations(); + if (decls.length !== 1) { + return undefined; + } + + const primaryDecl = decls[0]; + if (primaryDecl.type !== DeclarationType.Variable && primaryDecl.type !== DeclarationType.Parameter) { + return undefined; + } + + return primaryDecl; +} + // Handle type narrowing for expressions of the form "a[I] is None" and "a[I] is not None" where // I is an integer and a is a union of Tuples with known lengths and entry types. function narrowTupleTypeForIsNone(evaluator: TypeEvaluator, type: Type, isPositiveTest: boolean, indexValue: number) { diff --git a/packages/pyright-internal/src/tests/samples/typeNarrowingLocalConst1.py b/packages/pyright-internal/src/tests/samples/typeNarrowingLocalConst1.py new file mode 100644 index 000000000..d3bb61705 --- /dev/null +++ b/packages/pyright-internal/src/tests/samples/typeNarrowingLocalConst1.py @@ -0,0 +1,85 @@ +# This sample tests the case where a local (constant) variable that +# is assigned a narrowing expression can be used in a type guard condition. + + +from typing import Literal, Optional, Union +import random + + +class A: + a: int + + +class B: + b: int + + +def func1(x: Union[A, B]) -> None: + is_a = not not isinstance(x, A) + + if not is_a: + t1: Literal["B"] = reveal_type(x) + else: + t2: Literal["A"] = reveal_type(x) + + +def func2(x: Union[A, B]) -> None: + is_a = isinstance(x, A) + + if random.random() < 0.5: + x = B() + + if is_a: + t1: Literal["B | A"] = reveal_type(x) + else: + t2: Literal["B | A"] = reveal_type(x) + + +def func3(x: Optional[int]): + is_number = x != None + + if is_number: + t1: Literal["int"] = reveal_type(x) + else: + t2: Literal["None"] = reveal_type(x) + + +def func4() -> Optional[A]: + return A() if random.random() < 0.5 else None + + +maybe_a1 = func4() +is_a1 = maybe_a1 + +if is_a1: + t1: Literal["A"] = reveal_type(maybe_a1) +else: + t2: Literal["None"] = reveal_type(maybe_a1) + +maybe_a2 = func4() + + +def func5(): + global maybe_a2 + maybe_a2 = False + + +is_a2 = maybe_a2 + +if is_a2: + t3: Literal["A | None"] = reveal_type(maybe_a2) +else: + t4: Literal["A | None"] = reveal_type(maybe_a2) + + +def func6(x: Union[A, B]) -> None: + is_a = isinstance(x, A) + + for y in range(1): + if is_a: + t1: Literal["A | B"] = reveal_type(x) + else: + t2: Literal["A | B"] = reveal_type(x) + + if random.random() < 0.5: + x = B() diff --git a/packages/pyright-internal/src/tests/typeEvaluator1.test.ts b/packages/pyright-internal/src/tests/typeEvaluator1.test.ts index ec9dbaeb0..9bad20b96 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator1.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator1.test.ts @@ -409,6 +409,12 @@ test('TypeNarrowingFalsy1', () => { TestUtils.validateResults(analysisResults, 0); }); +test('TypeNarrowingLocalConst1', () => { + const analysisResults = TestUtils.typeAnalyzeSampleFiles(['typeNarrowingLocalConst1.py']); + + TestUtils.validateResults(analysisResults, 0); +}); + test('ReturnTypes1', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['returnTypes1.py']);