diff --git a/backend/api/quivr_api/modules/chat/service/chat_service.py b/backend/api/quivr_api/modules/chat/service/chat_service.py index 6318caf49..9aa763093 100644 --- a/backend/api/quivr_api/modules/chat/service/chat_service.py +++ b/backend/api/quivr_api/modules/chat/service/chat_service.py @@ -3,6 +3,7 @@ from typing import List from uuid import UUID from fastapi import HTTPException + from quivr_api.logger import get_logger from quivr_api.modules.brain.entity.brain_entity import Brain from quivr_api.modules.brain.service.brain_service import BrainService @@ -52,7 +53,7 @@ class ChatService(BaseService[ChatRepository]): return inserted_chat def get_follow_up_question( - self, brain_id: UUID = None, question: str = None + self, brain_id: UUID | None = None, question: str = None ) -> [str]: follow_up = [ "Summarize the conversation", diff --git a/backend/api/quivr_api/modules/chat/tests/conftest.py b/backend/api/quivr_api/modules/chat/tests/conftest.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/backend/api/quivr_api/modules/chat/tests/test_chats.py b/backend/api/quivr_api/modules/chat/tests/test_chats.py index 625b8a245..2ea7ddf55 100644 --- a/backend/api/quivr_api/modules/chat/tests/test_chats.py +++ b/backend/api/quivr_api/modules/chat/tests/test_chats.py @@ -6,41 +6,20 @@ from uuid import uuid4 import pytest import pytest_asyncio import sqlalchemy +from sqlalchemy.ext.asyncio import create_async_engine +from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession + from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType from quivr_api.modules.chat.dto.inputs import QuestionAndAnswer from quivr_api.modules.chat.entity.chat import Chat, ChatHistory from quivr_api.modules.chat.repository.chats import ChatRepository from quivr_api.modules.chat.service.chat_service import ChatService from quivr_api.modules.user.entity.user_identity import User -from sqlalchemy.ext.asyncio import create_async_engine -from sqlmodel import create_engine, select -from sqlmodel.ext.asyncio.session import AsyncSession -pg_database_url = "postgres:postgres@localhost:54322/postgres" +pg_database_base_url = "postgres:postgres@localhost:54322/postgres" - -@pytest.fixture(scope="session", autouse=True) -def db_setup(): - # setup - sync_engine = create_engine( - "postgresql://" + pg_database_url, - echo=True if os.getenv("ORM_DEBUG") else False, - ) - # TODO(@amine) : for now don't drop anything - yield sync_engine - # teardown - # NOTE: For now we rely on Supabase migrations for defining schemas - # SQLModel.metadata.create_all(sync_engine, checkfirst=True) - # SQLModel.metadata.drop_all(sync_engine) - - -@pytest_asyncio.fixture(scope="session") -async def async_engine(): - engine = create_async_engine( - "postgresql+asyncpg://" + pg_database_url, - echo=True if os.getenv("ORM_DEBUG") else False, - ) - yield engine +TestData = Tuple[Brain, User, List[Chat], List[ChatHistory]] @pytest.fixture(scope="session") @@ -50,6 +29,20 @@ def event_loop(request: pytest.FixtureRequest): loop.close() +@pytest_asyncio.fixture(scope="session") +async def async_engine(): + engine = create_async_engine( + "postgresql+asyncpg://" + pg_database_base_url, + echo=True if os.getenv("ORM_DEBUG") else False, + future=True, + pool_pre_ping=True, + pool_size=10, + pool_recycle=0.1, + ) + + yield engine + + @pytest_asyncio.fixture() async def session(async_engine): async with async_engine.connect() as conn: @@ -69,7 +62,14 @@ async def session(async_engine): yield async_session -TestData = Tuple[Brain, User, List[Chat], List[ChatHistory]] +@pytest.mark.asyncio +async def test_pool_reconnect(session: AsyncSession): + # time.sleep(10) + response = await asyncio.gather( + *[session.exec(sqlalchemy.text("SELECT 1;")) for _ in range(100)] + ) + result = [r.fetchall() for r in response] + assert list(result[0]) == [(1,)] @pytest_asyncio.fixture() @@ -81,7 +81,11 @@ async def test_data( await session.exec(select(User).where(User.email == "admin@quivr.app")) ).one() # Brain data - brain_1 = Brain(name="test_brain", description="this is a test brain") + brain_1 = Brain( + name="test_brain", + description="this is a test brain", + brain_type=BrainType.integration, + ) # Chat data chat_1 = Chat(chat_name="chat1", user=user_1) chat_2 = Chat(chat_name="chat2", user=user_1) @@ -125,6 +129,15 @@ async def test_get_user_chats(session: AsyncSession, test_data: TestData): assert len(query_chats) == len(chats) +@pytest.mark.asyncio +async def test_get_chat_history_close(session: AsyncSession, test_data: TestData): + brain_1, _, chats, chat_history = test_data + assert chats[0].chat_id + assert len(chat_history) > 0 + assert chat_history[-1].message_time + assert chat_history[0].message_time + + @pytest.mark.asyncio async def test_get_chat_history(session: AsyncSession, test_data: TestData): brain_1, _, chats, chat_history = test_data @@ -159,7 +172,7 @@ async def test_add_qa(session: AsyncSession, test_data: TestData): assert resp_chat.assistant == qa.answer -## CHAT SERVICE +# CHAT SERVICE @pytest.mark.asyncio diff --git a/backend/api/quivr_api/modules/dependencies.py b/backend/api/quivr_api/modules/dependencies.py index 0bd6c49d0..94a6ce767 100644 --- a/backend/api/quivr_api/modules/dependencies.py +++ b/backend/api/quivr_api/modules/dependencies.py @@ -2,10 +2,11 @@ import os from typing import AsyncGenerator, Callable, Generic, Type, TypeVar from fastapi import Depends -from quivr_api.models.settings import settings from sqlalchemy.ext.asyncio import create_async_engine from sqlmodel.ext.asyncio.session import AsyncSession +from quivr_api.models.settings import settings + class BaseRepository: def __init__(self, session: AsyncSession): @@ -34,11 +35,17 @@ async_engine = create_async_engine( settings.pg_database_async_url, echo=True if os.getenv("ORM_DEBUG") else False, future=True, + # NOTE: pessimistic bound on + pool_pre_ping=True, + pool_size=10, # NOTE: no bouncer for now, if 6 process workers => 6 + pool_recycle=1800, ) async def get_async_session() -> AsyncGenerator[AsyncSession, None]: - async with AsyncSession(async_engine) as session: + async with AsyncSession( + async_engine, expire_on_commit=False, autoflush=False + ) as session: yield session