feat: add chat with models (#2933)

# Description

This PR adds the ability to chat with a model directly, without attaching a brain (see the sketch below):

- When no `brain_id` is provided, the question handlers (sync and streaming) route the request to a new `ChatLLMService` instead of `RAGService`; the service checks and updates user usage, builds a `ChatLLM` for the selected model, and saves the answer to the chat history.
- A new `models` module in `quivr_api` (SQLModel `Model` entity, repository, service and tests) plus a `GET /models` route registered in `main.py`.
- In `quivr_core`: a new `ChatLLM` class that answers from chat history alone (`answer` and `answer_astream`), a nullable `brain_id` on `ChatHistory`/`ChatMessage`, a `model_name` field on `RAGResponseMetadata`, and `gpt-4o-mini` added to the function-calling list.
- Chat history enrichment now handles messages that have no brain or prompt.
- Tooling: the default VSCode Python formatter switches to ruff and the Makefile gains a `test` target.
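
A minimal client-side sketch of the new flow (the exact route, `POST /chat/{chat_id}/question`, the payload fields and the base URL are assumptions about the deployment, since only the handlers appear in this diff):

```python
# Hypothetical sketch: ask a question against a bare model, with no brain attached.
# Assumed: a local quivr_api instance, a valid API key, and an existing chat_id.
import requests

BASE_URL = "http://localhost:5050"  # assumption: local API
API_KEY = "<your-api-key>"          # assumption: Quivr API key
CHAT_ID = "<existing-chat-id>"

response = requests.post(
    f"{BASE_URL}/chat/{CHAT_ID}/question",  # no brain_id -> routed to ChatLLMService
    headers={"Authorization": f"Bearer {API_KEY}"},
    json={"question": "What is Quivr?", "model": "gpt-4o-mini"},
    timeout=60,
)
response.raise_for_status()
print(response.json()["assistant"])  # GetChatHistoryOutput.assistant holds the answer
```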

## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that
my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):

---------

Co-authored-by: AmineDiro <aminedirhoussi1@gmail.com>
Stan Girard 2024-08-06 14:51:27 +02:00 committed by GitHub
parent 1c63608b4c
commit fccd197511
23 changed files with 624 additions and 32 deletions

View File

@ -18,7 +18,7 @@
},
"json.sortOnSave.enable": true,
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.defaultFormatter": "charliermarsh.ruff",
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.organizeImports": "explicit",

View File

@ -32,4 +32,7 @@ test-type:
fi
front:
cd frontend && yarn build && yarn start
test:
cd backend/core && ./scripts/run_tests.sh

View File

@ -16,6 +16,7 @@ from quivr_api.modules.brain.controller import brain_router
from quivr_api.modules.chat.controller import chat_router
from quivr_api.modules.knowledge.controller import knowledge_router
from quivr_api.modules.misc.controller import misc_router
from quivr_api.modules.models.controller.model_routes import model_router
from quivr_api.modules.onboarding.controller import onboarding_router
from quivr_api.modules.prompt.controller import prompt_router
from quivr_api.modules.sync.controller import sync_router
@ -78,13 +79,13 @@ app.include_router(sync_router)
app.include_router(onboarding_router)
app.include_router(misc_router)
app.include_router(analytics_router)
app.include_router(upload_router)
app.include_router(user_router)
app.include_router(api_key_router)
app.include_router(subscription_router)
app.include_router(prompt_router)
app.include_router(knowledge_router)
app.include_router(model_router)
PROFILING = os.getenv("PROFILING", "false").lower() == "true"

View File

@ -19,6 +19,7 @@ from quivr_api.modules.chat.dto.inputs import (
)
from quivr_api.modules.chat.entity.chat import Chat
from quivr_api.modules.chat.service.chat_service import ChatService
from quivr_api.modules.chat_llm_service.chat_llm_service import ChatLLMService
from quivr_api.modules.dependencies import get_service
from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
from quivr_api.modules.prompt.service.prompt_service import PromptService
@ -166,16 +167,25 @@ async def create_question_handler(
# TODO: check logic into middleware
validate_authorization(user_id=current_user.id, brain_id=brain_id)
try:
rag_service = RAGService(
current_user,
brain_id,
chat_id,
brain_service,
prompt_service,
chat_service,
knowledge_service,
)
chat_answer = await rag_service.generate_answer(chat_question.question)
service = None
if brain_id:
service = RAGService(
current_user,
brain_id,
chat_id,
brain_service,
prompt_service,
chat_service,
knowledge_service,
)
else:
service = ChatLLMService(
current_user,
chat_question.model,
chat_id,
chat_service,
)
chat_answer = await service.generate_answer(chat_question.question)
maybe_send_telemetry("question_asked", {"streaming": False}, request)
return chat_answer
@ -214,19 +224,28 @@ async def create_stream_question_handler(
)
try:
rag_service = RAGService(
current_user,
brain_id,
chat_id,
brain_service,
prompt_service,
chat_service,
knowledge_service,
)
service = None
if brain_id:
service = RAGService(
current_user,
brain_id,
chat_id,
brain_service,
prompt_service,
chat_service,
knowledge_service,
)
else:
service = ChatLLMService(
current_user,
chat_question.model,
chat_id,
chat_service,
)
maybe_send_telemetry("question_asked", {"streaming": True}, request)
return StreamingResponse(
rag_service.generate_answer_stream(chat_question.question),
service.generate_answer_stream(chat_question.question),
media_type="text/event-stream",
)
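
Both handlers return `GetChatHistoryOutput` data; the streaming variant emits it as `data: <json>` chunks over `text/event-stream`. A hedged client sketch for consuming that stream (the route, assumed here to be `POST /chat/{chat_id}/question/stream`, is not shown in this hunk):

```python
# Sketch: consume the SSE-style stream produced by generate_answer_stream.
# URL, token and payload are assumptions about the deployment.
import json

import requests

with requests.post(
    "http://localhost:5050/chat/<chat-id>/question/stream",  # assumption
    headers={"Authorization": "Bearer <api-key>"},
    json={"question": "Tell me a joke", "model": "gpt-4o-mini"},
    stream=True,
    timeout=120,
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if line and line.startswith("data: "):
            chunk = json.loads(line[len("data: "):])
            # Each chunk is a GetChatHistoryOutput; "assistant" carries the partial answer.
            print(chunk["assistant"], end="", flush=True)
```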

View File

@ -53,7 +53,7 @@ class ChatService(BaseService[ChatRepository]):
return inserted_chat
def get_follow_up_question(
self, brain_id: UUID | None = None, question: str = None
self, brain_id: UUID | None = None, question: str | None = None
) -> [str]:
follow_up = [
"Summarize the conversation",
@ -87,9 +87,15 @@ class ChatService(BaseService[ChatRepository]):
enriched_history: List[GetChatHistoryOutput] = []
if len(history) == 0:
return enriched_history
brain: Brain = await history[0].awaitable_attrs.brain
prompt: Prompt = await brain.awaitable_attrs.prompt
for message in history:
brain: Brain | None = (
await message.awaitable_attrs.brain if message.brain_id else None
)
prompt: Prompt | None = None
if brain:
prompt = (
await brain.awaitable_attrs.prompt if message.prompt_id else None
)
enriched_history.append(
# TODO : WHY bother with having ids here ??
GetChatHistoryOutput(

View File

@ -0,0 +1,3 @@
from .chat_llm_service import ChatLLMService
__all__ = ["ChatLLMService"]

View File

@ -0,0 +1,214 @@
import datetime
from uuid import UUID, uuid4
from quivr_core.chat import ChatHistory as ChatHistoryCore
from quivr_core.chat_llm import ChatLLM
from quivr_core.config import LLMEndpointConfig
from quivr_core.llm.llm_endpoint import LLMEndpoint
from quivr_core.models import ParsedRAGResponse, RAGResponseMetadata
from quivr_api.logger import get_logger
from quivr_api.models.settings import settings
from quivr_api.modules.brain.service.utils.format_chat_history import (
format_chat_history,
)
from quivr_api.modules.chat.controller.chat.utils import (
compute_cost,
find_model_and_generate_metadata,
update_user_usage,
)
from quivr_api.modules.chat.dto.inputs import CreateChatHistory
from quivr_api.modules.chat.dto.outputs import GetChatHistoryOutput
from quivr_api.modules.chat.service.chat_service import ChatService
from quivr_api.modules.user.entity.user_identity import UserIdentity
from quivr_api.modules.user.service.user_usage import UserUsage
logger = get_logger(__name__)
class ChatLLMService:
def __init__(
self,
current_user: UserIdentity,
model_name: str,
chat_id: UUID,
chat_service: ChatService,
):
# Services
self.chat_service = chat_service
# Base models
self.current_user = current_user
self.chat_id = chat_id
# check at init time
self.model_to_use = self.check_and_update_user_usage(
self.current_user, model_name
)
def _build_chat_history(
self,
history: list[GetChatHistoryOutput],
) -> ChatHistoryCore:
transformed_history = format_chat_history(history)
chat_history = ChatHistoryCore(brain_id=None, chat_id=self.chat_id)
[chat_history.append(m) for m in transformed_history]
return chat_history
def build_llm(self) -> ChatLLM:
ollama_url = (
settings.ollama_api_base_url
if settings.ollama_api_base_url
and self.model_to_use.name.startswith("ollama")
else None
)
chat_llm = ChatLLM(
llm=LLMEndpoint.from_config(
LLMEndpointConfig(
model=self.model_to_use.name,
llm_base_url=ollama_url,
llm_api_key="abc-123" if ollama_url else None,
temperature=(LLMEndpointConfig.model_fields["temperature"].default),
max_input=self.model_to_use.max_input,
max_tokens=self.model_to_use.max_output,
),
)
)
return chat_llm
def check_and_update_user_usage(self, user: UserIdentity, model_name: str):
"""Check user limits and raises if user reached his limits:
1. Raise if one of the conditions :
- User doesn't have access to brains
- Model of brain is not is user_settings.models
- Latest sum_30d(user_daily_user) < user_settings.max_monthly_usage
- Check sum(user_settings.daily_user_count)+ model_price < user_settings.monthly_chat_credits
2. Updates user usage
"""
# TODO(@aminediro) : THIS is bug prone, should retrieve it from DB here
user_usage = UserUsage(id=user.id, email=user.email)
user_settings = user_usage.get_user_settings()
all_models = user_usage.get_models()
# TODO(@aminediro): refactor this function
model_to_use = find_model_and_generate_metadata(
model_name,
user_settings,
all_models,
)
cost = compute_cost(model_to_use, all_models)
# Raises HTTP if user usage exceeds limits
update_user_usage(user_usage, user_settings, cost) # noqa: F821
return model_to_use
def save_answer(self, question: str, answer: ParsedRAGResponse):
logger.info(
f"Saving answer for chat {self.chat_id} with model {self.model_to_use.name}"
)
logger.info(answer)
return self.chat_service.update_chat_history(
CreateChatHistory(
**{
"chat_id": self.chat_id,
"user_message": question,
"assistant": answer.answer,
"brain_id": None,
"prompt_id": None,
"metadata": answer.metadata.model_dump() if answer.metadata else {},
}
)
)
async def generate_answer(
self,
question: str,
):
logger.info(
f"Creating question for chat {self.chat_id} with model {self.model_to_use.name} "
)
chat_llm = self.build_llm()
history = await self.chat_service.get_chat_history(self.chat_id)
# Format the history, sanitize the input
chat_history = self._build_chat_history(history)
parsed_response = chat_llm.answer(question, chat_history)
# Save the answer to db
new_chat_entry = self.save_answer(question, parsed_response)
# Format output to be correct
return GetChatHistoryOutput(
**{
"chat_id": self.chat_id,
"user_message": question,
"assistant": parsed_response.answer,
"message_time": new_chat_entry.message_time,
"prompt_title": None,
"brain_name": None,
"message_id": new_chat_entry.message_id,
"brain_id": None,
"metadata": (
parsed_response.metadata.model_dump()
if parsed_response.metadata
else {}
),
}
)
async def generate_answer_stream(
self,
question: str,
):
logger.info(
f"Creating question for chat {self.chat_id} with model {self.model_to_use.name} "
)
# Build the rag config
chat_llm = self.build_llm()
# Get chat history
history = await self.chat_service.get_chat_history(self.chat_id)
# Format the history, sanitize the input
chat_history = self._build_chat_history(history)
full_answer = ""
message_metadata = {
"chat_id": self.chat_id,
"message_id": uuid4(), # do we need it ?,
"user_message": question, # TODO: define result
"message_time": datetime.datetime.now(), # TODO: define result
"prompt_title": None,
"brain_name": None,
"brain_id": None,
}
async for response in chat_llm.answer_astream(question, chat_history):
# Format output to be correct
if not response.last_chunk:
streamed_chat_history = GetChatHistoryOutput(
assistant=response.answer,
metadata=response.metadata.model_dump(),
**message_metadata,
)
full_answer += response.answer
yield f"data: {streamed_chat_history.model_dump_json()}"
if response.last_chunk and full_answer == "":
full_answer += response.answer
# For last chunk parse the sources, and the full answer
streamed_chat_history = GetChatHistoryOutput(
assistant=full_answer,
metadata=response.metadata.model_dump(),
**message_metadata,
)
logger.info("Last chunk before saving")
self.save_answer(
question,
ParsedRAGResponse(
answer=full_answer,
metadata=RAGResponseMetadata(**streamed_chat_history.metadata),
),
)
yield f"data: {streamed_chat_history.model_dump_json()}"

View File

@ -0,0 +1,3 @@
from .model_routes import model_router
__all__ = ["model_router"]

View File

@ -0,0 +1,30 @@
from typing import Annotated, List
from fastapi import APIRouter, Depends
from quivr_api.logger import get_logger
from quivr_api.middlewares.auth import AuthBearer, get_current_user
from quivr_api.modules.dependencies import get_service
from quivr_api.modules.models.entity.model import Model
from quivr_api.modules.models.service.model_service import ModelService
from quivr_api.modules.user.entity.user_identity import UserIdentity
logger = get_logger(__name__)
model_router = APIRouter()
ModelServiceDep = Annotated[ModelService, Depends(get_service(ModelService))]
UserIdentityDep = Annotated[UserIdentity, Depends(get_current_user)]
# get all models
@model_router.get(
"/models",
response_model=List[Model],
dependencies=[Depends(AuthBearer())],
tags=["Models"],
)
async def get_models(current_user: UserIdentityDep, model_service: ModelServiceDep):
"""
Retrieve all models for the current user.
"""
models = await model_service.get_models()
return models
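
A quick way to exercise the new route once the API is up (base URL and token are assumptions about the deployment):

```python
# Sketch: list the models returned by GET /models.
import requests

resp = requests.get(
    "http://localhost:5050/models",  # assumption: local API
    headers={"Authorization": "Bearer <api-key>"},
    timeout=30,
)
resp.raise_for_status()
for model in resp.json():
    print(model["name"], model["price"], model["max_input"], model["max_output"])
```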

View File

@ -0,0 +1,13 @@
from sqlmodel import Field, SQLModel
class Model(SQLModel, table=True):
__tablename__ = "models"
name: str = Field(primary_key=True)
price: int = Field(default=1)
max_input: int = Field(default=2000)
max_output: int = Field(default=1000)
class Config:
arbitrary_types_allowed = True
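
The PR does not include a migration for the `models` table (the repository queries it through the async SQLModel session). For local experimentation, a hedged sketch that creates and seeds the table with SQLModel, using a throwaway SQLite URL as an assumption:

```python
# Local-only sketch: create the "models" table and insert one row.
from sqlmodel import Session, SQLModel, create_engine

from quivr_api.modules.models.entity.model import Model

engine = create_engine("sqlite:///models_demo.db")  # assumption: throwaway local DB
SQLModel.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Model(name="gpt-4o-mini", price=1, max_input=4000, max_output=2000))
    session.commit()
```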

View File

@ -0,0 +1,19 @@
from typing import Sequence
from quivr_api.models.settings import get_supabase_client
from quivr_api.modules.dependencies import BaseRepository
from quivr_api.modules.models.entity.model import Model
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
class ModelRepository(BaseRepository):
def __init__(self, session: AsyncSession):
super().__init__(session)
# TODO: for now use it instead of session
self.db = get_supabase_client()
async def get_models(self) -> Sequence[Model]:
query = select(Model)
response = await self.session.exec(query)
return response.all()

View File

@ -0,0 +1,12 @@
from abc import ABC, abstractmethod
from quivr_api.modules.models.entity.model import Model
class ModelsInterface(ABC):
@abstractmethod
def get_models(self) -> list[Model]:
"""
Get all models
"""
pass

View File

@ -0,0 +1,21 @@
from quivr_api.logger import get_logger
from quivr_api.modules.dependencies import BaseService
from quivr_api.modules.models.entity.model import Model
from quivr_api.modules.models.repository.model import ModelRepository
logger = get_logger(__name__)
class ModelService(BaseService[ModelRepository]):
repository_cls = ModelRepository
def __init__(self, repository: ModelRepository):
self.repository = repository
async def get_models(self) -> list[Model]:
logger.info("Getting models")
models = await self.repository.get_models()
logger.info(f"Insert response {models}")
return models

View File

@ -0,0 +1,70 @@
import pytest
import pytest_asyncio
from quivr_api.modules.models.entity.model import Model
@pytest_asyncio.fixture()
async def sample_models():
return [
Model(name="gpt-3.5-turbo", price=1, max_input=4000, max_output=2000),
Model(name="gpt-4", price=5, max_input=8000, max_output=4000),
]
@pytest.mark.asyncio
async def test_model_creation():
model = Model(name="test-model", price=2, max_input=1000, max_output=500)
assert model.name == "test-model"
assert model.price == 2
assert model.max_input == 1000
assert model.max_output == 500
@pytest.mark.asyncio
async def test_model_attributes(sample_models):
model = sample_models[0]
assert hasattr(model, "name")
assert hasattr(model, "price")
assert hasattr(model, "max_input")
assert hasattr(model, "max_output")
@pytest.mark.asyncio
async def test_model_validation():
# Test valid model creation
valid_model = Model(name="valid-model", price=3, max_input=5000, max_output=2500)
assert valid_model.name == "valid-model"
assert valid_model.price == 3
assert valid_model.max_input == 5000
assert valid_model.max_output == 2500
@pytest.mark.asyncio
async def test_model_default_values():
default_model = Model(name="default-model")
assert default_model.name == "default-model"
assert default_model.price == 1
assert default_model.max_input == 2000
assert default_model.max_output == 1000
@pytest.mark.asyncio
async def test_model_comparison():
model1 = Model(name="model1", price=2, max_input=3000, max_output=1500)
model2 = Model(name="model2", price=3, max_input=4000, max_output=2000)
model3 = Model(name="model1", price=2, max_input=3000, max_output=1500)
assert model1 != model2
assert model1 == model3
@pytest.mark.asyncio
async def test_model_dict_representation():
model = Model(name="test-model", price=2, max_input=3000, max_output=1500)
expected_dict = {
"name": "test-model",
"price": 2,
"max_input": 3000,
"max_output": 1500,
}
assert model.dict() == expected_dict

View File

@ -0,0 +1,12 @@
from quivr_core import ChatLLM
from quivr_core.config import LLMEndpointConfig
from quivr_core.llm import LLMEndpoint
if __name__ == "__main__":
llm_endpoint = LLMEndpoint.from_config(LLMEndpointConfig(model="gpt-4o-mini"))
chat_llm = ChatLLM(
llm=llm_endpoint,
)
print(chat_llm.llm_endpoint.info())
response = chat_llm.answer("Hello, what is your model?")
print(response)

View File

@ -1,9 +1,10 @@
from importlib.metadata import entry_points
from .brain import Brain
from .chat_llm import ChatLLM
from .processor.registry import register_processor, registry
__all__ = ["Brain", "registry", "register_processor"]
__all__ = ["Brain", "ChatLLM", "registry", "register_processor"]
def register_entries():

View File

@ -3,12 +3,11 @@ from typing import Any, Generator, Tuple
from uuid import UUID, uuid4
from langchain_core.messages import AIMessage, HumanMessage
from quivr_core.models import ChatMessage
class ChatHistory:
def __init__(self, chat_id: UUID, brain_id: UUID) -> None:
def __init__(self, chat_id: UUID, brain_id: UUID | None) -> None:
self.id = chat_id
self.brain_id = brain_id
# TODO(@aminediro): maybe use a deque() instead ?

View File

@ -0,0 +1,146 @@
import logging
from operator import itemgetter
from typing import AsyncGenerator
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages.ai import AIMessageChunk
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from quivr_core.chat import ChatHistory
from quivr_core.llm import LLMEndpoint
from quivr_core.models import (
ParsedRAGChunkResponse,
ParsedRAGResponse,
RAGResponseMetadata,
)
from quivr_core.utils import get_chunk_metadata, parse_chunk_response, parse_response
logger = logging.getLogger("quivr_core")
class ChatLLM:
def __init__(self, *, llm: LLMEndpoint):
self.llm_endpoint = llm
def filter_history(
self,
chat_history: ChatHistory | None,
):
"""
Filter the chat history down to the messages that fit within the model's limits.
Returns a filtered chat_history bounded first by max_tokens and then by a maximum number of message pairs, where a Human message and an AI message count as one pair.
A token is approximated as 4 characters.
"""
total_tokens = 0
total_pairs = 0
filtered_chat_history: list[AIMessage | HumanMessage] = []
if chat_history is None:
return filtered_chat_history
for human_message, ai_message in chat_history.iter_pairs():
# TODO: replace with tiktoken
message_tokens = (len(human_message.content) + len(ai_message.content)) // 4
if (
total_tokens + message_tokens > self.llm_endpoint._config.max_input
or total_pairs >= 10
):
break
filtered_chat_history.append(human_message)
filtered_chat_history.append(ai_message)
total_tokens += message_tokens
total_pairs += 1
return filtered_chat_history[::-1]
def build_chain(self):
loaded_memory = RunnablePassthrough.assign(
chat_history=RunnableLambda(
lambda x: self.filter_history(x["chat_history"]),
),
question=lambda x: x["question"],
)
logger.info(f"loaded_memory: {loaded_memory}")
prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"You are Quivr. You are an assistant.",
),
MessagesPlaceholder(variable_name="chat_history"),
("human", "{question}"),
]
)
final_inputs = {
"question": itemgetter("question"),
"chat_history": itemgetter("chat_history"),
}
llm = self.llm_endpoint._llm
answer = {"answer": final_inputs | prompt | llm, "docs": lambda _: []}
return loaded_memory | answer
def answer(
self, question: str, history: ChatHistory | None = None
) -> ParsedRAGResponse:
chain = self.build_chain()
raw_llm_response = chain.invoke({"question": question, "chat_history": history})
response = parse_response(raw_llm_response, self.llm_endpoint._config.model)
return response
async def answer_astream(
self, question: str, history: ChatHistory | None = None
) -> AsyncGenerator[ParsedRAGChunkResponse, ParsedRAGChunkResponse]:
chain = self.build_chain()
rolling_message = AIMessageChunk(content="")
prev_answer = ""
chunk_id = 0
async for chunk in chain.astream(
{"question": question, "chat_history": history}
):
if "answer" in chunk:
rolling_message, answer_str = parse_chunk_response(
rolling_message,
chunk,
self.llm_endpoint.supports_func_calling(),
)
if len(answer_str) > 0:
if self.llm_endpoint.supports_func_calling():
diff_answer = answer_str[len(prev_answer) :]
if len(diff_answer) > 0:
parsed_chunk = ParsedRAGChunkResponse(
answer=diff_answer,
metadata=RAGResponseMetadata(),
)
prev_answer += diff_answer
logger.debug(
f"answer_astream func_calling=True question={question} rolling_msg={rolling_message} chunk_id={chunk_id}, chunk={parsed_chunk}"
)
yield parsed_chunk
else:
parsed_chunk = ParsedRAGChunkResponse(
answer=answer_str,
metadata=RAGResponseMetadata(),
)
logger.debug(
f"answer_astream func_calling=False question={question} rolling_msg={rolling_message} chunk_id={chunk_id}, chunk={parsed_chunk}"
)
yield parsed_chunk
chunk_id += 1
# Last chunk provides metadata
last_chunk = ParsedRAGChunkResponse(
answer=rolling_message.content,
metadata=get_chunk_metadata(rolling_message),
last_chunk=True,
)
last_chunk.metadata.model_name = self.llm_endpoint._config.model
logger.debug(
f"answer_astream last_chunk={last_chunk} question={question} rolling_msg={rolling_message} chunk_id={chunk_id}"
)
yield last_chunk
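
For completeness, a sketch of driving the streaming API directly from `quivr_core`, mirroring the synchronous example script added earlier in this diff (assumes an OpenAI API key in the environment):

```python
# Sketch: stream an answer from ChatLLM without going through quivr_api.
import asyncio

from quivr_core.chat_llm import ChatLLM
from quivr_core.config import LLMEndpointConfig
from quivr_core.llm import LLMEndpoint


async def main() -> None:
    chat_llm = ChatLLM(
        llm=LLMEndpoint.from_config(LLMEndpointConfig(model="gpt-4o-mini"))
    )
    async for chunk in chat_llm.answer_astream("What can you do?"):
        if not chunk.last_chunk:
            print(chunk.answer, end="", flush=True)


asyncio.run(main())
```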

View File

@ -1,6 +1,5 @@
from langchain_core.language_models.chat_models import BaseChatModel
from pydantic.v1 import SecretStr
from quivr_core.brain.info import LLMInfo
from quivr_core.config import LLMEndpointConfig
from quivr_core.utils import model_supports_function_calling

View File

@ -31,7 +31,7 @@ class cited_answer(BaseModelV1):
class ChatMessage(BaseModelV1):
chat_id: UUID
message_id: UUID
brain_id: UUID
brain_id: UUID | None
msg: AIMessage | HumanMessage
message_time: datetime
metadata: dict[str, Any]
@ -59,6 +59,7 @@ class RAGResponseMetadata(BaseModel):
citations: list[int] | None = None
followup_questions: list[str] | None = None
sources: list[Any] | None = None
model_name: str | None = None
class ParsedRAGResponse(BaseModel):

View File

@ -4,6 +4,7 @@ from typing import Any, List, Tuple, no_type_check
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.messages.ai import AIMessageChunk
from langchain_core.prompts import format_document
from quivr_core.models import (
ParsedRAGResponse,
QuivrKnowledge,
@ -31,6 +32,7 @@ def model_supports_function_calling(model_name: str):
"gpt-3.5-turbo",
"gpt-4-turbo",
"gpt-4o",
"gpt-4o-mini",
]
return model_name in models_supporting_function_calls
@ -113,7 +115,8 @@ def parse_response(raw_response: RawRAGResponse, model_name: str) -> ParsedRAGRe
answer = raw_response["answer"].content
sources = raw_response["docs"] or []
metadata = {"sources": sources}
metadata = {"sources": sources, "model_name":model_name}
metadata["model_name"] = model_name
if model_supports_function_calling(model_name):
if raw_response["answer"].tool_calls:

View File

@ -0,0 +1,17 @@
import pytest
from quivr_core import ChatLLM
from quivr_core.chat_llm import ChatLLM
@pytest.mark.base
def test_chat_llm(fake_llm):
chat_llm = ChatLLM(
llm=fake_llm,
)
answer = chat_llm.answer("Hello, how are you?")
assert len(answer.answer) > 0
assert answer.metadata is not None
assert answer.metadata.citations is None
assert answer.metadata.followup_questions is None
assert answer.metadata.sources == []