Extended dataclass_transform mechanism to support implicit init argument values for field descriptors.

This commit is contained in:
Eric Traut 2021-09-03 19:13:07 -07:00
parent ad0e8acf14
commit e68dff84cb
4 changed files with 193 additions and 10 deletions

View File

@ -2329,6 +2329,51 @@ export function createTypeEvaluator(
if (value === false) {
includeInInit = false;
}
} else {
// See if the field constructor has an `init` parameter with
// a default value.
let callTarget: FunctionType | undefined;
if (isFunction(callType)) {
callTarget = callType;
} else if (isOverloadedFunction(callType)) {
callTarget = getBestOverloadForArguments(
statement.rightExpression,
callType,
statement.rightExpression.arguments
);
} else if (isInstantiableClass(callType)) {
const initCall = getBoundMethod(callType, '__init__');
if (initCall) {
if (isFunction(initCall)) {
callTarget = initCall;
} else if (isOverloadedFunction(initCall)) {
callTarget = getBestOverloadForArguments(
statement.rightExpression,
initCall,
statement.rightExpression.arguments
);
}
}
}
if (callTarget) {
const initParam = callTarget.details.parameters.find((p) => p.name === 'init');
if (
initParam &&
initParam.defaultValueExpression &&
initParam.hasDeclaredType
) {
if (
isClass(initParam.type) &&
ClassType.isBuiltIn(initParam.type, 'bool') &&
isLiteralType(initParam.type)
) {
if (initParam.type.literalValue === false) {
includeInInit = false;
}
}
}
}
}
const kwOnlyArg = statement.rightExpression.arguments.find(
@ -6981,6 +7026,40 @@ export function createTypeEvaluator(
return { argumentErrors: false, returnType: combineTypes(returnTypes), isTypeIncomplete };
}
function getBestOverloadForArguments(
errorNode: ExpressionNode,
type: OverloadedFunctionType,
argList: FunctionArgument[]
): FunctionType | undefined {
let firstMatch: FunctionType | undefined;
type.overloads.forEach((overload) => {
if (!firstMatch) {
useSpeculativeMode(errorNode, () => {
if (FunctionType.isOverloaded(overload)) {
const matchResults = matchFunctionArgumentsToParameters(errorNode, argList, overload);
if (!matchResults.argumentErrors) {
const callResult = validateFunctionArgumentTypes(
errorNode,
matchResults,
overload,
new TypeVarMap(getTypeVarScopeId(overload)),
/* skipUnknownArgCheck */ true,
/* expectedType */ undefined
);
if (callResult && !callResult.argumentErrors) {
firstMatch = overload;
}
}
}
});
}
});
return firstMatch;
}
function validateOverloadedFunctionArguments(
errorNode: ExpressionNode,
argList: FunctionArgument[],
@ -24419,7 +24498,7 @@ export function createTypeEvaluator(
}
// Specializes the specified function for the specified class,
// optionally stripping the first first paramter (the "self" or "cls")
// optionally stripping the first first parameter (the "self" or "cls")
// off of the specialized function in the process. The baseType
// is the type used to reference the member, and the memberClass
// is the class that provided the member (could be an ancestor of

View File

@ -0,0 +1,87 @@
# This sample tests the case where a field descriptor has an implicit
# "init" parameter type based on an overload.
from typing import (
Any,
Callable,
Literal,
Optional,
Tuple,
Type,
TypeVar,
Union,
overload,
)
T = TypeVar("T")
class ModelField:
def __init__(
self,
*,
default: Optional[Any] = ...,
init: Optional[bool] = True,
**kwargs: Any,
) -> None:
...
@overload
def field(
*,
default: Optional[str] = None,
resolver: Callable[[], Any],
init: Literal[False] = False,
) -> Any:
...
@overload
def field(
*,
default: Optional[str] = None,
resolver: None = None,
init: Literal[True] = True,
) -> Any:
...
def field(
*,
default: Optional[str] = None,
resolver: Optional[Callable[[], Any]] = None,
init: bool = True,
) -> Any:
...
def __dataclass_transform__(
*,
eq_default: bool = True,
order_default: bool = False,
kw_only_default: bool = False,
field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = (()),
) -> Callable[[T], T]:
# If used within a stub file, the following implementation can be
# replaced with "...".
return lambda a: a
@__dataclass_transform__(kw_only_default=True, field_descriptors=(field,))
def create_model(*, init: bool = True) -> Callable[[Type[T]], Type[T]]:
...
@create_model()
class CustomerModel:
id: int = field(resolver=lambda: 0)
name: str = field(default="Voldemort")
CustomerModel()
CustomerModel(name="hi")
# This should generate an error because "id" is not
# supposed to be part of the init function.
CustomerModel(id=1, name="hi")

View File

@ -824,6 +824,12 @@ test('DataclassTransform2', () => {
TestUtils.validateResults(analysisResults, 4);
});
test('DataclassTransform3', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['dataclassTransform3.py']);
TestUtils.validateResults(analysisResults, 1);
});
test('Async1', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['async1.py']);

View File

@ -273,7 +273,10 @@ can be positional and may use other names.
`init` is an optional bool parameter that indicates whether the field should
be included in the synthesized `__init__` method. If unspecified, it defaults
to True.
to True. Field descriptor functions can use overloads that implicitly specify
the value of `init` using a literal bool value type (Literal[False] or
Literal[True]).
`default` is an optional parameter that provides the default value for the
field.
@ -287,19 +290,27 @@ provided a value when the class is instantiated.
the field. This alternative name is used in the synthesized `__init__` method.
This example demonstrates
```python
# Library code (within type stub or inline):
class ModelField:
def __init__(
self,
@overload
def model_field(
*,
default: Optional[Any] = ...,
init: Optional[bool] = True,
**kwargs: Any
) -> None: ...
resolver: Callable[[], Any],
init: Literal[False] = False,
) -> Any: ...
@typing.dataclass_transform(kw_only_default=True, field_descriptors=(ModelField, ))
@overload
def model_field(
*,
default: Optional[Any] = ...,
resolver: None = None,
init: bool = True,
) -> Any: ...
@typing.dataclass_transform(kw_only_default=True, field_descriptors=(model_field, ))
def create_model(
*,
init: bool = True
@ -309,7 +320,7 @@ def create_model(
# Code that imports this library:
@create_model(init=False)
class CustomerModel:
id: int = ModelField(default=0)
id: int = ModelField(resolver=lambda : 0)
name: str
```