mirror of
https://github.com/StanGirard/quivr.git
synced 2024-10-04 00:33:03 +03:00
fix: asyncpg pooling config fix (#2795)
# Description closes #2782. Changes `sqlalchemy` connection pooling config: - **pool_pre_ping=True**: pessimistic disconnect handling. > - It is critical to note that the pre-ping approach does not accommodate **for connections dropped in the middle of transactions or other SQL operations**! But this should only happen if we lose the database, either due to a network failure or a DB server restart. - **pool_size=10**, with no db-side pooling for now; if 6 uvicorn process workers are spawned (or 6 instances of the backend) we have a pool of 60 connections to the database. - **pool_recycle=1800**: recycles connections every 30 min. Added additional session config: - expire_on_commit=False: when True, all instances are fully expired after each commit, so that all attribute/object access subsequent to a completed transaction loads from the most recent database state. - autoflush=False: when True, query operations issue a Session.flush() call to the Session before proceeding. We call `session.commit` (which flushes) in each repository method, so there is no need to re-flush on subsequent access.
This commit is contained in:
parent
913217f682
commit
757bceeb95
@ -3,6 +3,7 @@ from typing import List
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.modules.brain.entity.brain_entity import Brain
|
||||
from quivr_api.modules.brain.service.brain_service import BrainService
|
||||
@ -52,7 +53,7 @@ class ChatService(BaseService[ChatRepository]):
|
||||
return inserted_chat
|
||||
|
||||
def get_follow_up_question(
|
||||
self, brain_id: UUID = None, question: str = None
|
||||
self, brain_id: UUID | None = None, question: str = None
|
||||
) -> [str]:
|
||||
follow_up = [
|
||||
"Summarize the conversation",
|
||||
|
@ -6,41 +6,20 @@ from uuid import uuid4
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import sqlalchemy
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType
|
||||
from quivr_api.modules.chat.dto.inputs import QuestionAndAnswer
|
||||
from quivr_api.modules.chat.entity.chat import Chat, ChatHistory
|
||||
from quivr_api.modules.chat.repository.chats import ChatRepository
|
||||
from quivr_api.modules.chat.service.chat_service import ChatService
|
||||
from quivr_api.modules.user.entity.user_identity import User
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlmodel import create_engine, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
pg_database_url = "postgres:postgres@localhost:54322/postgres"
|
||||
pg_database_base_url = "postgres:postgres@localhost:54322/postgres"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
def db_setup():
    """Session-scoped synchronous engine for test setup/teardown.

    Schema management is delegated to Supabase migrations, so this
    fixture only provides the engine; it creates and drops nothing.
    """
    # setup
    sync_engine = create_engine(
        "postgresql://" + pg_database_url,
        # bool(...) replaces the redundant `True if ... else False`.
        echo=bool(os.getenv("ORM_DEBUG")),
    )
    # TODO(@amine) : for now don't drop anything
    yield sync_engine
    # teardown
    # NOTE: For now we rely on Supabase migrations for defining schemas
    # SQLModel.metadata.create_all(sync_engine, checkfirst=True)
    # SQLModel.metadata.drop_all(sync_engine)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def async_engine():
|
||||
engine = create_async_engine(
|
||||
"postgresql+asyncpg://" + pg_database_url,
|
||||
echo=True if os.getenv("ORM_DEBUG") else False,
|
||||
)
|
||||
yield engine
|
||||
TestData = Tuple[Brain, User, List[Chat], List[ChatHistory]]
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@ -50,6 +29,20 @@ def event_loop(request: pytest.FixtureRequest):
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
async def async_engine():
    """Session-scoped async engine with an aggressively recycled pool.

    NOTE(review): pool_recycle=0.1 (100 ms) looks deliberately tiny so
    test_pool_reconnect exercises the reconnect path (production uses
    1800 s) — confirm this is intentional.
    """
    engine = create_async_engine(
        "postgresql+asyncpg://" + pg_database_base_url,
        # bool(...) replaces the redundant `True if ... else False`.
        echo=bool(os.getenv("ORM_DEBUG")),
        future=True,
        pool_pre_ping=True,
        pool_size=10,
        pool_recycle=0.1,
    )
    yield engine
    # Fix: dispose pooled connections on teardown so the test process
    # does not leak open sockets when the session ends.
    await engine.dispose()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture()
|
||||
async def session(async_engine):
|
||||
async with async_engine.connect() as conn:
|
||||
@ -69,7 +62,14 @@ async def session(async_engine):
|
||||
yield async_session
|
||||
|
||||
|
||||
TestData = Tuple[Brain, User, List[Chat], List[ChatHistory]]
|
||||
@pytest.mark.asyncio
async def test_pool_reconnect(session: AsyncSession):
    """Fire 100 concurrent trivial queries to exercise pool pre-ping/recycle.

    NOTE(review): a single AsyncSession is shared by all 100 coroutines;
    AsyncSession is not safe for concurrent use — confirm the intent is
    to stress the connection pool rather than the session itself.
    """
    # Removed dead commented-out `time.sleep(10)` debugging line.
    responses = await asyncio.gather(
        *[session.exec(sqlalchemy.text("SELECT 1;")) for _ in range(100)]
    )
    results = [r.fetchall() for r in responses]
    # Fix: assert every query returned (1,), not just the first one.
    assert all(list(rows) == [(1,)] for rows in results)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture()
|
||||
@ -81,7 +81,11 @@ async def test_data(
|
||||
await session.exec(select(User).where(User.email == "admin@quivr.app"))
|
||||
).one()
|
||||
# Brain data
|
||||
brain_1 = Brain(name="test_brain", description="this is a test brain")
|
||||
brain_1 = Brain(
|
||||
name="test_brain",
|
||||
description="this is a test brain",
|
||||
brain_type=BrainType.integration,
|
||||
)
|
||||
# Chat data
|
||||
chat_1 = Chat(chat_name="chat1", user=user_1)
|
||||
chat_2 = Chat(chat_name="chat2", user=user_1)
|
||||
@ -125,6 +129,15 @@ async def test_get_user_chats(session: AsyncSession, test_data: TestData):
|
||||
assert len(query_chats) == len(chats)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_chat_history_close(session: AsyncSession, test_data: TestData):
    """Sanity-check the seeded chat history: ids exist and timestamps are set."""
    # Fix: the brain and user elements were bound to named locals but
    # never used — discard them explicitly.
    _, _, chats, chat_history = test_data
    assert chats[0].chat_id
    assert len(chat_history) > 0
    # Both ends of the history carry a message_time.
    assert chat_history[-1].message_time
    assert chat_history[0].message_time
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_chat_history(session: AsyncSession, test_data: TestData):
|
||||
brain_1, _, chats, chat_history = test_data
|
||||
@ -159,7 +172,7 @@ async def test_add_qa(session: AsyncSession, test_data: TestData):
|
||||
assert resp_chat.assistant == qa.answer
|
||||
|
||||
|
||||
## CHAT SERVICE
|
||||
# CHAT SERVICE
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -2,10 +2,11 @@ import os
|
||||
from typing import AsyncGenerator, Callable, Generic, Type, TypeVar
|
||||
|
||||
from fastapi import Depends
|
||||
from quivr_api.models.settings import settings
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from quivr_api.models.settings import settings
|
||||
|
||||
|
||||
class BaseRepository:
|
||||
def __init__(self, session: AsyncSession):
|
||||
@ -34,11 +35,17 @@ async_engine = create_async_engine(
|
||||
settings.pg_database_async_url,
|
||||
echo=True if os.getenv("ORM_DEBUG") else False,
|
||||
future=True,
|
||||
# NOTE: pessimistic bound on
|
||||
pool_pre_ping=True,
|
||||
pool_size=10, # NOTE: no bouncer for now, if 6 process workers => 6
|
||||
pool_recycle=1800,
|
||||
)
|
||||
|
||||
|
||||
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield a request-scoped AsyncSession bound to the shared async engine.

    - expire_on_commit=False: instances stay usable after commit instead
      of being expired and reloaded from the database on next access.
    - autoflush=False: repository methods call `session.commit` (which
      flushes) themselves, so no implicit flush before queries is needed.

    Fix: the span contained both the old unconfigured `AsyncSession(async_engine)`
    line and the new configured call; consolidated to the configured one.
    """
    async with AsyncSession(
        async_engine, expire_on_commit=False, autoflush=False
    ) as session:
        yield session
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user