mirror of
https://github.com/Chia-Network/chia-blockchain.git
synced 2024-10-26 20:40:51 +03:00
add datacases()
and named_datacases()
(#15265)
* add `datacases()` and `named_datacases()` * correct DataCasesProtocol * add back the tests for testing the test utilities
This commit is contained in:
parent
d1ffa43ea8
commit
217429a126
@ -9,6 +9,10 @@ asyncio_mode = strict
|
|||||||
markers =
|
markers =
|
||||||
benchmark
|
benchmark
|
||||||
data_layer: Mark as a data layer related test.
|
data_layer: Mark as a data layer related test.
|
||||||
|
test_mark_a1: used in testing test utilities
|
||||||
|
test_mark_a2: used in testing test utilities
|
||||||
|
test_mark_b1: used in testing test utilities
|
||||||
|
test_mark_b2: used in testing test utilities
|
||||||
testpaths = tests
|
testpaths = tests
|
||||||
filterwarnings =
|
filterwarnings =
|
||||||
error
|
error
|
||||||
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
import contextlib
|
import contextlib
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import enum
|
import enum
|
||||||
|
import functools
|
||||||
import gc
|
import gc
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@ -13,10 +14,10 @@ from statistics import mean
|
|||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
from time import thread_time
|
from time import thread_time
|
||||||
from types import TracebackType
|
from types import TracebackType
|
||||||
from typing import Any, Callable, Iterator, List, Optional, Type, Union
|
from typing import Any, Callable, Collection, Iterator, List, Optional, Type, Union
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from typing_extensions import final
|
from typing_extensions import Protocol, final
|
||||||
|
|
||||||
from tests.core.data_layer.util import ChiaRoot
|
from tests.core.data_layer.util import ChiaRoot
|
||||||
|
|
||||||
@ -303,3 +304,31 @@ def closing_chia_root_popen(chia_root: ChiaRoot, args: List[str]) -> Iterator[su
|
|||||||
process.wait(timeout=10)
|
process.wait(timeout=10)
|
||||||
except subprocess.TimeoutExpired:
|
except subprocess.TimeoutExpired:
|
||||||
process.kill()
|
process.kill()
|
||||||
|
|
||||||
|
|
||||||
|
# https://github.com/pytest-dev/pytest/blob/7.3.1/src/_pytest/mark/__init__.py#L45
|
||||||
|
Marks = Union[pytest.MarkDecorator, Collection[Union[pytest.MarkDecorator, pytest.Mark]]]
|
||||||
|
|
||||||
|
|
||||||
|
class DataCase(Protocol):
|
||||||
|
marks: Marks
|
||||||
|
|
||||||
|
@property
|
||||||
|
def id(self) -> str:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def datacases(*cases: DataCase, _name: str = "case") -> pytest.MarkDecorator:
|
||||||
|
return pytest.mark.parametrize(
|
||||||
|
argnames=_name,
|
||||||
|
argvalues=[pytest.param(case, id=case.id, marks=case.marks) for case in cases],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DataCasesDecorator(Protocol):
|
||||||
|
def __call__(self, *cases: DataCase, _name: str = "case") -> pytest.MarkDecorator:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def named_datacases(name: str) -> DataCasesDecorator:
|
||||||
|
return functools.partial(datacases, _name=name)
|
||||||
|
38
tests/util/test_tests_misc.py
Normal file
38
tests/util/test_tests_misc.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tests.util.misc import Marks, datacases, named_datacases
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DataCase:
|
||||||
|
id: str
|
||||||
|
marks: Marks
|
||||||
|
|
||||||
|
|
||||||
|
sample_cases = [
|
||||||
|
DataCase(id="id_a", marks=[pytest.mark.test_mark_a1, pytest.mark.test_mark_a2]),
|
||||||
|
DataCase(id="id_b", marks=[pytest.mark.test_mark_b1, pytest.mark.test_mark_b2]),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def sample_result(name: str) -> pytest.MarkDecorator:
|
||||||
|
return pytest.mark.parametrize(
|
||||||
|
argnames=name,
|
||||||
|
argvalues=[pytest.param(case, id=case.id, marks=case.marks) for case in sample_cases],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_datacases() -> None:
|
||||||
|
result = datacases(*sample_cases)
|
||||||
|
|
||||||
|
assert result == sample_result(name="case")
|
||||||
|
|
||||||
|
|
||||||
|
def test_named_datacases() -> None:
|
||||||
|
result = named_datacases("Sharrilanda")(*sample_cases)
|
||||||
|
|
||||||
|
assert result == sample_result(name="Sharrilanda")
|
Loading…
Reference in New Issue
Block a user