Enhanced the "aliased conditional" type narrowing capability to accommodate multiple assignments of the variable used within the aliased conditional as long as the variable isn't reassigned between the the aliased conditional assignment and the conditional check that uses the aliased value.

This commit is contained in:
Eric Traut 2021-12-31 12:52:20 -07:00
parent 20a47fea6f
commit 00067d22d8
2 changed files with 98 additions and 15 deletions

View File

@ -15,6 +15,7 @@ import {
isExpressionNode,
NameNode,
ParameterCategory,
ParseNode,
ParseNodeType,
} from '../parser/parseNodes';
import { KeywordType, OperatorType } from '../parser/tokenizerTypes';
@ -469,16 +470,45 @@ export function getTypeNarrowingCallback(
testExpression !== reference
) {
// Make sure the reference expression is a constant parameter or variable.
// If it is modified somewhere within the scope, it's not safe to apply
// this form of type narrowing.
if (getDeclForLocalConst(evaluator, reference) !== undefined) {
const testExprDecl = getDeclForLocalConst(evaluator, testExpression);
// If the reference expression is modified within the scope multiple times,
// we need to validate that it is not modified between the test expression
// evaluation and the conditional check.
const testExprDecl = getDeclsForLocalVar(evaluator, testExpression, testExpression);
if (testExprDecl && testExprDecl.length === 1 && testExprDecl[0].type === DeclarationType.Variable) {
const referenceDecls = getDeclsForLocalVar(evaluator, reference, testExpression);
if (testExprDecl && testExprDecl.type === DeclarationType.Variable) {
const initNode = testExprDecl.inferredTypeSource;
if (referenceDecls) {
let modifyingDecls: Declaration[] = [];
if (initNode && initNode !== testExpression && isExpressionNode(initNode)) {
return getTypeNarrowingCallback(evaluator, reference, initNode, isPositiveTest);
if (referenceDecls.length > 1) {
// If there is more than one assignment to the reference variable within
// the local scope, make sure that none of these assignments are done
// after the test expression but before the condition check.
//
// This is OK:
// val = None
// is_none = val is None
// if is_none: ...
//
// This is not OK:
// val = None
// is_none = val is None
// val = 1
// if is_none: ...
modifyingDecls = referenceDecls.filter((decl) => {
return (
evaluator.isNodeReachable(testExpression, decl.node) &&
evaluator.isNodeReachable(decl.node, testExprDecl[0].node)
);
});
}
if (modifyingDecls.length === 0) {
const initNode = testExprDecl[0].inferredTypeSource;
if (initNode && initNode !== testExpression && isExpressionNode(initNode)) {
return getTypeNarrowingCallback(evaluator, reference, initNode, isPositiveTest);
}
}
}
}
@ -497,9 +527,12 @@ export function getTypeNarrowingCallback(
}
// Determines whether the symbol is a local variable or parameter within
// the current scope _and_ is a constant (assigned only once). If so, it
// returns the declaration for the symbol.
function getDeclForLocalConst(evaluator: TypeEvaluator, name: NameNode): Declaration | undefined {
// the current scope.
function getDeclsForLocalVar(
evaluator: TypeEvaluator,
name: NameNode,
reachableFrom: ParseNode
): Declaration[] | undefined {
const scope = getScopeForNode(name);
if (scope?.type !== ScopeType.Function && scope?.type !== ScopeType.Module) {
return undefined;
@ -511,16 +544,33 @@ function getDeclForLocalConst(evaluator: TypeEvaluator, name: NameNode): Declara
}
const decls = symbol.getDeclarations();
if (decls.length !== 1) {
if (
decls.length === 0 ||
decls.some((decl) => decl.type !== DeclarationType.Variable && decl.type !== DeclarationType.Parameter)
) {
return undefined;
}
const primaryDecl = decls[0];
if (primaryDecl.type !== DeclarationType.Variable && primaryDecl.type !== DeclarationType.Parameter) {
// If there are any assignments within different scopes (e.g. via a "global" or
// "nonlocal" reference), don't consider it a local variable.
let prevDeclScope: ParseNode | undefined;
if (
decls.some((decl) => {
const nodeToConsider = decl.type === DeclarationType.Parameter ? decl.node.name! : decl.node;
const declScopeNode = ParseTreeUtils.getExecutionScopeNode(nodeToConsider);
if (prevDeclScope && declScopeNode !== prevDeclScope) {
return true;
}
prevDeclScope = declScopeNode;
return false;
})
) {
return undefined;
}
return primaryDecl;
const reachableDecls = decls.filter((decl) => evaluator.isNodeReachable(reachableFrom, decl.node));
return reachableDecls.length > 0 ? reachableDecls : undefined;
}
// Handle type narrowing for expressions of the form "a[I] is None" and "a[I] is not None" where

View File

@ -1,5 +1,6 @@
# This sample tests the case where a local (constant) variable that
# is assigned a narrowing expression can be used in a type guard condition.
# These are sometimes referred to as "aliased conditional expressions".
from typing import Literal, Optional, Union
@ -83,3 +84,35 @@ def func6(x: Union[A, B]) -> None:
if random.random() < 0.5:
x = B()
def get_string() -> str:
...
def get_optional_string() -> Optional[str]:
...
def func7(val: Optional[str] = None):
val = get_optional_string()
val_is_none = val is None
if val_is_none:
val = get_string()
t1: Literal["str"] = reveal_type(val)
def func8(val: Optional[str] = None):
val = get_optional_string()
val_is_none = val is None
val = get_optional_string()
if val_is_none:
val = get_string()
t1: Literal["str | None"] = reveal_type(val)