fix: asyncpg pooling config fix (#2795)

# Description

closes #2782.

Changes `sqlalchemy` connection pooling config : 

- **pool_pre_ping=True** : pessimistic disconnect handling. 

> - It is critical to note that the pre-ping approach does not
accommodate **connections dropped in the middle of transactions or
other SQL operations**! But this should only happen if we lose the
database, either due to a network issue or a DB server restart.

- **pool_size=10**, with no DB-side pooling for now: if 6 uvicorn
process workers are spawned (i.e. 6 instances of the backend), we have a
pool of up to 60 connections to the database.
- **pool_recycle=1800** : recycles connections that have been pooled for more than 30 minutes

Added additional session config : 
- expire_on_commit=False, When True, all instances will be fully expired
after each commit, so that all attribute/object access subsequent to a
completed transaction will load from the most recent database state.
- autoflush=False, query operations will issue a Session.flush() call to
this Session before proceeding. We are calling `session.commit` (which
flushes) in each repository method, so there is no need to re-flush on
subsequent access.
This commit is contained in:
AmineDiro 2024-07-04 14:28:02 +02:00 committed by GitHub
parent 913217f682
commit 757bceeb95
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 54 additions and 33 deletions

View File

@ -3,6 +3,7 @@ from typing import List
from uuid import UUID
from fastapi import HTTPException
from quivr_api.logger import get_logger
from quivr_api.modules.brain.entity.brain_entity import Brain
from quivr_api.modules.brain.service.brain_service import BrainService
@ -52,7 +53,7 @@ class ChatService(BaseService[ChatRepository]):
return inserted_chat
def get_follow_up_question(
self, brain_id: UUID = None, question: str = None
self, brain_id: UUID | None = None, question: str = None
) -> [str]:
follow_up = [
"Summarize the conversation",

View File

@ -6,41 +6,20 @@ from uuid import uuid4
import pytest
import pytest_asyncio
import sqlalchemy
from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType
from quivr_api.modules.chat.dto.inputs import QuestionAndAnswer
from quivr_api.modules.chat.entity.chat import Chat, ChatHistory
from quivr_api.modules.chat.repository.chats import ChatRepository
from quivr_api.modules.chat.service.chat_service import ChatService
from quivr_api.modules.user.entity.user_identity import User
from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel import create_engine, select
from sqlmodel.ext.asyncio.session import AsyncSession
pg_database_url = "postgres:postgres@localhost:54322/postgres"
pg_database_base_url = "postgres:postgres@localhost:54322/postgres"
@pytest.fixture(scope="session", autouse=True)
def db_setup():
    """Session-wide sync engine fixture.

    Schemas are NOT created or dropped here — we rely on Supabase
    migrations to define them (see the commented-out SQLModel calls
    below).  The fixture only builds the engine and yields it.
    """
    # setup
    sync_engine = create_engine(
        "postgresql://" + pg_database_url,
        echo=True if os.getenv("ORM_DEBUG") else False,
    )
    # TODO(@amine) : for now don't drop anything
    yield sync_engine
    # teardown
    # NOTE: For now we rely on Supabase migrations for defining schemas
    # SQLModel.metadata.create_all(sync_engine, checkfirst=True)
    # SQLModel.metadata.drop_all(sync_engine)
@pytest_asyncio.fixture(scope="session")
async def async_engine():
engine = create_async_engine(
"postgresql+asyncpg://" + pg_database_url,
echo=True if os.getenv("ORM_DEBUG") else False,
)
yield engine
TestData = Tuple[Brain, User, List[Chat], List[ChatHistory]]
@pytest.fixture(scope="session")
@ -50,6 +29,20 @@ def event_loop(request: pytest.FixtureRequest):
loop.close()
@pytest_asyncio.fixture(scope="session")
async def async_engine():
    """Session-scoped asyncpg engine mirroring the production pool config.

    pool_recycle is deliberately tiny (0.1 s) so the pool-reconnect test
    exercises recycling of stale pooled connections; production uses 1800 s.
    """
    engine = create_async_engine(
        "postgresql+asyncpg://" + pg_database_base_url,
        # bool(...) is the idiomatic form of `True if x else False`
        echo=bool(os.getenv("ORM_DEBUG")),
        future=True,
        pool_pre_ping=True,  # pessimistic disconnect detection before checkout
        pool_size=10,
        pool_recycle=0.1,  # NOTE: test-only; forces aggressive recycling
    )
    yield engine
@pytest_asyncio.fixture()
async def session(async_engine):
async with async_engine.connect() as conn:
@ -69,7 +62,14 @@ async def session(async_engine):
yield async_session
TestData = Tuple[Brain, User, List[Chat], List[ChatHistory]]
@pytest.mark.asyncio
async def test_pool_reconnect(session: AsyncSession):
    """Fire 100 concurrent trivial queries through one session.

    With the fixture's aggressive pool_recycle (0.1 s) this exercises the
    pool's handling of recycled/stale connections.
    """
    responses = await asyncio.gather(
        *[session.exec(sqlalchemy.text("SELECT 1;")) for _ in range(100)]
    )
    results = [r.fetchall() for r in responses]
    # Check every query's result, not just the first one.
    assert all(list(rows) == [(1,)] for rows in results)
@pytest_asyncio.fixture()
@ -81,7 +81,11 @@ async def test_data(
await session.exec(select(User).where(User.email == "admin@quivr.app"))
).one()
# Brain data
brain_1 = Brain(name="test_brain", description="this is a test brain")
brain_1 = Brain(
name="test_brain",
description="this is a test brain",
brain_type=BrainType.integration,
)
# Chat data
chat_1 = Chat(chat_name="chat1", user=user_1)
chat_2 = Chat(chat_name="chat2", user=user_1)
@ -125,6 +129,15 @@ async def test_get_user_chats(session: AsyncSession, test_data: TestData):
assert len(query_chats) == len(chats)
@pytest.mark.asyncio
async def test_get_chat_history_close(session: AsyncSession, test_data: TestData):
    """Sanity-check the fixture data: the chat has an id and every end of
    the history carries a message_time."""
    _brain, _user, chats, history = test_data
    assert chats[0].chat_id
    assert history  # non-empty history
    last, first = history[-1], history[0]
    assert last.message_time
    assert first.message_time
@pytest.mark.asyncio
async def test_get_chat_history(session: AsyncSession, test_data: TestData):
brain_1, _, chats, chat_history = test_data
@ -159,7 +172,7 @@ async def test_add_qa(session: AsyncSession, test_data: TestData):
assert resp_chat.assistant == qa.answer
## CHAT SERVICE
# CHAT SERVICE
@pytest.mark.asyncio

View File

@ -2,10 +2,11 @@ import os
from typing import AsyncGenerator, Callable, Generic, Type, TypeVar
from fastapi import Depends
from quivr_api.models.settings import settings
from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel.ext.asyncio.session import AsyncSession
from quivr_api.models.settings import settings
class BaseRepository:
def __init__(self, session: AsyncSession):
@ -34,11 +35,17 @@ async_engine = create_async_engine(
settings.pg_database_async_url,
echo=True if os.getenv("ORM_DEBUG") else False,
future=True,
# NOTE: pessimistic bound on
pool_pre_ping=True,
pool_size=10, # NOTE: no bouncer for now, if 6 process workers => 6
pool_recycle=1800,
)
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
    """FastAPI dependency yielding an AsyncSession bound to the shared engine.

    expire_on_commit=False keeps ORM instances usable after commit instead
    of reloading from the DB; autoflush=False because repository methods
    call session.commit (which flushes) explicitly.
    """
    async with AsyncSession(
        async_engine, expire_on_commit=False, autoflush=False
    ) as session:
        yield session