only open SQLite log file once per db wrapper (#13754)

* only open SQLite log file once per db wrapper

* actually remember the log file

* actually remember to close the connections...
This commit is contained in:
Kyle Altendorf 2022-11-01 16:43:19 -04:00 committed by GitHub
parent 73acfc7409
commit 7897c80ed1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 62 additions and 58 deletions

View File

@ -69,7 +69,7 @@ from chia.util.bech32m import encode_puzzle_hash
from chia.util.check_fork_next_block import check_fork_next_block
from chia.util.condition_tools import pkm_pairs
from chia.util.config import PEER_DB_PATH_KEY_DEPRECATED, process_config_start_method
from chia.util.db_wrapper import DBWrapper2, create_connection
from chia.util.db_wrapper import DBWrapper2, manage_connection
from chia.util.errors import ConsensusError, Err, ValidationError
from chia.util.ints import uint8, uint32, uint64, uint128
from chia.util.path import path_from_root
@ -323,7 +323,7 @@ class FullNode:
self._respond_transaction_semaphore = asyncio.Semaphore(200)
# create the store (db) and full node instance
# TODO: is this standardized and thus able to be handled by DBWrapper2?
async with create_connection(self.db_path) as db_connection:
async with manage_connection(self.db_path) as db_connection:
db_version = await lookup_db_version(db_connection)
self.log.info(f"using blockchain database {self.db_path}, which is version {db_version}")

View File

@ -4,11 +4,9 @@ import asyncio
import contextlib
import functools
import sqlite3
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from types import TracebackType
from typing import Any, AsyncIterator, Dict, Generator, Iterable, Optional, Type, Union
from typing import Any, AsyncIterator, Dict, Iterable, Optional, TextIO, Type, Union
import aiosqlite
from typing_extensions import final
@ -28,55 +26,46 @@ async def execute_fetchone(
return None
@dataclass
class create_connection:
"""Create an object that can both be `await`ed and `async with`ed to get a
connection.
"""
async def _create_connection(
database: Union[str, Path],
uri: bool = False,
log_file: Optional[TextIO] = None,
name: Optional[str] = None,
) -> aiosqlite.Connection:
connection = await aiosqlite.connect(database=database, uri=uri)
# def create_connection( (for searchability
database: Union[str, Path]
uri: bool = False
log_path: Optional[Path] = None
name: Optional[str] = None
_connection: Optional[aiosqlite.Connection] = field(init=False, default=None)
if log_file is not None:
await connection.set_trace_callback(functools.partial(sql_trace_callback, file=log_file, name=name))
async def _create(self) -> aiosqlite.Connection:
self._connection = await aiosqlite.connect(database=self.database, uri=self.uri)
if self.log_path is not None:
await self._connection.set_trace_callback(
functools.partial(sql_trace_callback, path=self.log_path, name=self.name)
)
return self._connection
def __await__(self) -> Generator[Any, None, Any]:
return self._create().__await__()
async def __aenter__(self) -> aiosqlite.Connection:
self._connection = await self._create()
return self._connection
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
if self._connection is None:
raise RuntimeError("exiting while self._connection is None")
await self._connection.close()
return connection
def sql_trace_callback(req: str, path: Path, name: Optional[str] = None) -> None:
@contextlib.asynccontextmanager
async def manage_connection(
database: Union[str, Path],
uri: bool = False,
log_path: Optional[Path] = None,
name: Optional[str] = None,
) -> AsyncIterator[aiosqlite.Connection]:
if log_path is not None:
with log_path.open("a", encoding="utf-8") as file:
connection = await _create_connection(database=database, uri=uri, log_file=file, name=name)
else:
connection = await _create_connection(database=database, uri=uri, name=name)
try:
yield connection
finally:
await connection.close()
def sql_trace_callback(req: str, file: TextIO, name: Optional[str] = None) -> None:
timestamp = datetime.now().strftime("%H:%M:%S.%f")
with path.open(mode="a") as log:
if name is not None:
line = f"{timestamp} {name} {req}\n"
else:
line = f"{timestamp} {req}\n"
log.write(line)
if name is not None:
line = f"{timestamp} {name} {req}\n"
else:
line = f"{timestamp} {req}\n"
file.write(line)
@final
@ -89,6 +78,7 @@ class DBWrapper2:
_in_use: Dict[asyncio.Task, aiosqlite.Connection]
_current_writer: Optional[asyncio.Task]
_savepoint_name: int
_log_file: Optional[TextIO]
async def add_connection(self, c: aiosqlite.Connection) -> None:
# this guarantees that reader connections can only be used for reading
@ -97,7 +87,12 @@ class DBWrapper2:
self._read_connections.put_nowait(c)
self._num_read_connections += 1
def __init__(self, connection: aiosqlite.Connection, db_version: int = 1) -> None:
def __init__(
self,
connection: aiosqlite.Connection,
db_version: int = 1,
log_file: Optional[TextIO] = None,
) -> None:
self._read_connections = asyncio.Queue()
self._write_connection = connection
self._lock = asyncio.Lock()
@ -106,6 +101,7 @@ class DBWrapper2:
self._in_use = {}
self._current_writer = None
self._savepoint_name = 0
self._log_file = log_file
@classmethod
async def create(
@ -121,7 +117,11 @@ class DBWrapper2:
foreign_keys: bool = False,
row_factory: Optional[Type[aiosqlite.Row]] = None,
) -> DBWrapper2:
write_connection = await create_connection(database=database, uri=uri, log_path=log_path, name="writer")
if log_path is None:
log_file = None
else:
log_file = log_path.open("a", encoding="utf-8")
write_connection = await _create_connection(database=database, uri=uri, log_file=log_file, name="writer")
await (await write_connection.execute(f"pragma journal_mode={journal_mode}")).close()
if synchronous is not None:
await (await write_connection.execute(f"pragma synchronous={synchronous}")).close()
@ -130,13 +130,13 @@ class DBWrapper2:
write_connection.row_factory = row_factory
self = cls(connection=write_connection, db_version=db_version)
self = cls(connection=write_connection, db_version=db_version, log_file=log_file)
for index in range(reader_count):
read_connection = await create_connection(
read_connection = await _create_connection(
database=database,
uri=uri,
log_path=log_path,
log_file=log_file,
name=f"reader-{index}",
)
read_connection.row_factory = row_factory
@ -145,10 +145,14 @@ class DBWrapper2:
return self
async def close(self) -> None:
while self._num_read_connections > 0:
await (await self._read_connections.get()).close()
self._num_read_connections -= 1
await self._write_connection.close()
try:
while self._num_read_connections > 0:
await (await self._read_connections.get()).close()
self._num_read_connections -= 1
await self._write_connection.close()
finally:
if self._log_file is not None:
self._log_file.close()
def _next_savepoint(self) -> str:
name = f"s{self._savepoint_name}"