Mirror of https://github.com/QuivrHQ/quivr.git (synced 2024-12-14 17:03:29 +03:00)
feat(citations): system added (#2498)
# Description

Adds a citations system: for models that support function calling, answers are produced through a `cited_answer` tool that returns the answer text together with the integer indexes of the source documents it used, and those citations are attached to the returned sources and persisted in the chat history metadata.

## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):
Commit b7ff2e77af (parent 19365c4bb5)
@@ -1,4 +1,5 @@
 import json
+import logging
 from typing import AsyncIterable, List, Optional
 from uuid import UUID

@@ -26,7 +27,7 @@ from modules.user.service.user_usage import UserUsage
 from pydantic import BaseModel, ConfigDict
 from pydantic_settings import BaseSettings

-logger = get_logger(__name__)
+logger = get_logger(__name__, log_level=logging.INFO)

 QUIVR_DEFAULT_PROMPT = "Your name is Quivr. You're a helpful assistant. If you don't know the answer, just say that you don't know, don't try to make up an answer."

@@ -43,7 +44,11 @@ def is_valid_uuid(uuid_to_test, version=4):
     return str(uuid_obj) == uuid_to_test


-def generate_source(source_documents, brain_id):
+def generate_source(source_documents, brain_id, citations: List[int] = None):
+    """
+    Generate the sources list for the answer
+    It takes in a list of sources documents and citations that points to the docs index that was used in the answer
+    """
     # Initialize an empty list for sources
     sources_list: List[Sources] = []

@@ -51,26 +56,26 @@ def generate_source(source_documents, brain_id):
     generated_urls = {}

     # remove duplicate sources with same name and create a list of unique sources
-    source_documents = list(
-        {v.metadata["file_name"]: v for v in source_documents}.values()
-    )
+    sources_url_cache = {}

     # Get source documents from the result, default to an empty list if not found

     # If source documents exist
     if source_documents:
-        logger.info(f"Source documents found: {source_documents}")
+        logger.info(f"Citations {citations}")
         # Iterate over each document
-        for doc in source_documents:
-            logger.info("Document: %s", doc)
+        for doc, index in zip(source_documents, range(len(source_documents))):
+            logger.info(f"Processing source document {doc.metadata['file_name']}")
+            if citations is not None:
+                if index not in citations:
+                    logger.info(f"Skipping source document {doc.metadata['file_name']}")
+                    continue
             # Check if 'url' is in the document metadata
-            logger.info(f"Metadata 1: {doc.metadata}")
             is_url = (
                 "original_file_name" in doc.metadata
                 and doc.metadata["original_file_name"] is not None
                 and doc.metadata["original_file_name"].startswith("http")
             )
-            logger.info(f"Is URL: {is_url}")

             # Determine the name based on whether it's a URL or a file
             name = (
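`citations` is a list of integer indexes into the retrieved documents, so the loop above skips every chunk the model did not cite. A minimal, dependency-free sketch of that index-based filtering, with plain dicts standing in for LangChain documents and invented contents:

```python
from typing import List, Optional


def filter_cited_docs(docs: List[dict], citations: Optional[List[int]] = None) -> List[dict]:
    """Keep only the documents whose position appears in the citations list.

    If citations is None, every document is kept (the pre-citation behaviour).
    """
    if citations is None:
        return list(docs)
    return [doc for index, doc in enumerate(docs) if index in citations]


docs = [
    {"file_name": "report.pdf", "page_content": "Q1 revenue grew 12%."},
    {"file_name": "notes.txt", "page_content": "Meeting notes."},
    {"file_name": "faq.md", "page_content": "Refunds take 5 days."},
]

# The model cited sources 0 and 2, so notes.txt is dropped from the sources list.
print([d["file_name"] for d in filter_cited_docs(docs, citations=[0, 2])])
# -> ['report.pdf', 'faq.md']
```

Passing `citations=None` preserves the old behaviour of returning every retrieved document.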
@@ -91,11 +96,15 @@ def generate_source(source_documents, brain_id):
             if file_path in generated_urls:
                 source_url = generated_urls[file_path]
             else:
-                generated_url = generate_file_signed_url(file_path)
-                if generated_url is not None:
-                    source_url = generated_url.get("signedURL", "")
+                # Generate the URL
+                if file_path in sources_url_cache:
+                    source_url = sources_url_cache[file_path]
                 else:
-                    source_url = ""
+                    generated_url = generate_file_signed_url(file_path)
+                    if generated_url is not None:
+                        source_url = generated_url.get("signedURL", "")
+                    else:
+                        source_url = ""
                 # Store the generated URL
                 generated_urls[file_path] = source_url

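Several chunks can come from the same file, so the rewritten branch consults a per-call cache before requesting a new signed URL. A small sketch of that memoization pattern; `fetch_signed_url` is a hypothetical stand-in for `generate_file_signed_url`:

```python
from typing import Dict, Optional


def fetch_signed_url(file_path: str) -> Optional[dict]:
    # Stand-in for generate_file_signed_url: pretend every call hits storage.
    print(f"signing {file_path}")
    return {"signedURL": f"https://storage.example/{file_path}?sig=abc"}


def signed_url_for(file_path: str, cache: Dict[str, str]) -> str:
    """Return a signed URL, generating it at most once per file path."""
    if file_path in cache:
        return cache[file_path]
    generated = fetch_signed_url(file_path)
    url = generated.get("signedURL", "") if generated is not None else ""
    cache[file_path] = url
    return url


cache: Dict[str, str] = {}
signed_url_for("brain/doc.pdf", cache)  # prints "signing brain/doc.pdf"
signed_url_for("brain/doc.pdf", cache)  # served from the cache, no second call
```

The cache only lives for one `generate_source` call, so repeated chunks from the same file reuse one signed URL instead of hitting storage once per chunk.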
@@ -106,6 +115,7 @@ def generate_source(source_documents, brain_id):
                         type=type_,
                         source_url=source_url,
                         original_file_name=name,
+                        citation=doc.page_content,
                     )
                 )
     else:
@@ -219,10 +229,6 @@ class KnowledgeBrainQA(BaseModel, QAInterface):

     def calculate_pricing(self):

-        logger.info("Calculating pricing")
-        logger.info(f"Model: {self.model}")
-        logger.info(f"User settings: {self.user_settings}")
-        logger.info(f"Models settings: {self.models_settings}")
         model_to_use = find_model_and_generate_metadata(
             self.chat_id,
             self.brain.model,
@@ -248,6 +254,8 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
             self.initialize_streamed_chat_history(chat_id, question)
         )
         metadata = self.metadata or {}
+        citations = None
+        answer = ""
         model_response = conversational_qa_chain.invoke(
             {
                 "question": question.question,
@@ -258,57 +266,23 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
             }
         )

+        if self.model_compatible_with_function_calling(model=self.model):
+            if model_response["answer"].tool_calls:
+                citations = model_response["answer"].tool_calls[-1]["args"]["citations"]
+                if citations:
+                    citations = citations
+                answer = model_response["answer"].tool_calls[-1]["args"]["answer"]
+                metadata["citations"] = citations
+        else:
+            answer = model_response["answer"].content
         sources = model_response["docs"] or []
         if len(sources) > 0:
-            sources_list = generate_source(sources, self.brain_id)
-            metadata["sources"] = sources_list
+            sources_list = generate_source(sources, self.brain_id, citations=citations)
+            serialized_sources_list = [source.dict() for source in sources_list]
+            metadata["sources"] = serialized_sources_list

-        answer = model_response["answer"].content
-        if save_answer:
-            # save the answer to the database or not -> add a variable
-            new_chat = chat_service.update_chat_history(
-                CreateChatHistory(
-                    **{
-                        "chat_id": chat_id,
-                        "user_message": question.question,
-                        "assistant": answer,
-                        "brain_id": self.brain.brain_id,
-                        "prompt_id": self.prompt_to_use_id,
-                    }
-                )
-            )
-
-            return GetChatHistoryOutput(
-                **{
-                    "chat_id": chat_id,
-                    "user_message": question.question,
-                    "assistant": answer,
-                    "message_time": new_chat.message_time,
-                    "prompt_title": (
-                        self.prompt_to_use.title if self.prompt_to_use else None
-                    ),
-                    "brain_name": self.brain.name if self.brain else None,
-                    "message_id": new_chat.message_id,
-                    "brain_id": str(self.brain.brain_id) if self.brain else None,
-                    "metadata": metadata,
-                }
-            )
-
-        return GetChatHistoryOutput(
-            **{
-                "chat_id": chat_id,
-                "user_message": question.question,
-                "assistant": answer,
-                "message_time": None,
-                "prompt_title": (
-                    self.prompt_to_use.title if self.prompt_to_use else None
-                ),
-                "brain_name": None,
-                "message_id": None,
-                "brain_id": str(self.brain.brain_id) if self.brain else None,
-                "metadata": metadata,
-            }
+        return self.save_non_streaming_answer(
+            chat_id=chat_id, question=question, answer=answer, metadata=metadata
         )

     async def generate_stream(
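For function-calling models the answer no longer arrives as plain message content: both the text and the citations are read from the arguments of the last tool call on the returned message. A self-contained sketch of that extraction path, with a small stand-in class mimicking the shape of the LangChain message and invented values:

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class FakeAnswerMessage:
    """Mimics the shape the chain returns: content plus OpenAI-style tool calls."""

    content: str = ""
    tool_calls: List[dict] = field(default_factory=list)


model_response = {
    "answer": FakeAnswerMessage(
        tool_calls=[
            {
                "name": "cited_answer",
                "args": {"answer": "Refunds take 5 days.", "citations": [2]},
                "id": "call_1",
            }
        ]
    ),
    "docs": [],
}

citations = None
answer = ""
if model_response["answer"].tool_calls:
    # The last tool call carries the structured cited_answer arguments.
    args = model_response["answer"].tool_calls[-1]["args"]
    citations = args["citations"]
    answer = args["answer"]
else:
    answer = model_response["answer"].content

print(answer, citations)  # -> Refunds take 5 days. [2]
```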
@@ -318,9 +292,10 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
         transformed_history, streamed_chat_history = (
             self.initialize_streamed_chat_history(chat_id, question)
         )
-        response_tokens = []
+        response_tokens = ""
         sources = []
+        citations = []
+        first = True
         async for chunk in conversational_qa_chain.astream(
             {
                 "question": question.question,
@@ -330,18 +305,47 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
                 ),
             }
         ):
-            if chunk.get("answer"):
-                logger.info(f"Chunk: {chunk}")
-                response_tokens.append(chunk["answer"].content)
-                streamed_chat_history.assistant = chunk["answer"].content
-                yield f"data: {json.dumps(streamed_chat_history.dict())}"
+            if not streamed_chat_history.metadata:
+                streamed_chat_history.metadata = {}
+            if self.model_compatible_with_function_calling(model=self.model):
+                if chunk.get("answer"):
+                    if first:
+                        gathered = chunk["answer"]
+                        first = False
+                    else:
+                        gathered = gathered + chunk["answer"]
+                    if (
+                        gathered.tool_calls
+                        and gathered.tool_calls[-1].get("args")
+                        and "answer" in gathered.tool_calls[-1]["args"]
+                    ):
+                        # Only send the difference between answer and response_tokens which was the previous answer
+                        answer = gathered.tool_calls[-1]["args"]["answer"]
+                        difference = answer[len(response_tokens) :]
+                        streamed_chat_history.assistant = difference
+                        response_tokens = answer
+
+                        yield f"data: {json.dumps(streamed_chat_history.dict())}"
+                    if (
+                        gathered.tool_calls
+                        and gathered.tool_calls[-1].get("args")
+                        and "citations" in gathered.tool_calls[-1]["args"]
+                    ):
+                        citations = gathered.tool_calls[-1]["args"]["citations"]
+            else:
+                if chunk.get("answer"):
+                    response_tokens += chunk["answer"].content
+                    streamed_chat_history.assistant = chunk["answer"].content
+                    yield f"data: {json.dumps(streamed_chat_history.dict())}"

             if chunk.get("docs"):
                 sources = chunk["docs"]

-        sources_list = generate_source(sources, self.brain_id)
-        if not streamed_chat_history.metadata:
-            streamed_chat_history.metadata = {}
-        # Serialize the sources list
+        sources_list = generate_source(sources, self.brain_id, citations)
+        streamed_chat_history.metadata["citations"] = citations
+
+        # Serialize the sources list
         serialized_sources_list = [source.dict() for source in sources_list]
         streamed_chat_history.metadata["sources"] = serialized_sources_list
         yield f"data: {json.dumps(streamed_chat_history.dict())}"
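While streaming, the tool-call arguments are re-assembled on every chunk, so the code keeps the text it has already sent in `response_tokens` and emits only the new suffix (`answer[len(response_tokens):]`). A self-contained sketch of that diffing logic over a growing partial answer (the partial strings are invented):

```python
def stream_deltas(partial_answers):
    """Yield only the text that was not sent yet, mirroring answer[len(response_tokens):]."""
    sent_so_far = ""
    for answer in partial_answers:
        difference = answer[len(sent_so_far):]
        sent_so_far = answer
        if difference:
            yield difference


partials = ["Ref", "Refunds ", "Refunds take 5", "Refunds take 5 days."]
print(list(stream_deltas(partials)))
# -> ['Ref', 'unds ', 'take 5', ' days.']
```

Because the accumulated tool-call arguments are cumulative rather than token deltas, this diffing is what keeps the client-facing stream incremental.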
@@ -398,7 +402,7 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
         except Exception as e:
             logger.error("Error updating message by ID: %s", e)

-    def save_non_streaming_answer(self, chat_id, question, answer):
+    def save_non_streaming_answer(self, chat_id, question, answer, metadata):
         new_chat = chat_service.update_chat_history(
             CreateChatHistory(
                 **{
@@ -407,6 +411,7 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
                     "assistant": answer,
                     "brain_id": self.brain.brain_id,
                     "prompt_id": self.prompt_to_use_id,
+                    "metadata": metadata,
                 }
             )
         )
@@ -423,5 +428,6 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
                 "brain_name": self.brain.name if self.brain else None,
                 "message_id": new_chat.message_id,
                 "brain_id": str(self.brain.brain_id) if self.brain else None,
+                "metadata": metadata,
             }
         )
@@ -39,3 +39,20 @@ class QAInterface(ABC):
         raise NotImplementedError(
             "generate_stream is an abstract method and must be implemented"
         )
+
+    def model_compatible_with_function_calling(self, model: str):
+        if model in [
+            "gpt-4-turbo",
+            "gpt-4-turbo-2024-04-09",
+            "gpt-4-turbo-preview",
+            "gpt-4-0125-preview",
+            "gpt-4-1106-preview",
+            "gpt-4",
+            "gpt-4-0613",
+            "gpt-3.5-turbo",
+            "gpt-3.5-turbo-0125",
+            "gpt-3.5-turbo-1106",
+            "gpt-3.5-turbo-0613",
+        ]:
+            return True
+        return False
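Whether citations are attempted at all is decided purely by this hard-coded model allow-list; anything outside it falls back to the plain answer path without citations. An equivalent set-based sketch of the check (the helper name is illustrative, not part of the diff):

```python
FUNCTION_CALLING_MODELS = {
    "gpt-4-turbo",
    "gpt-4-turbo-2024-04-09",
    "gpt-4-turbo-preview",
    "gpt-4-0125-preview",
    "gpt-4-1106-preview",
    "gpt-4",
    "gpt-4-0613",
    "gpt-3.5-turbo",
    "gpt-3.5-turbo-0125",
    "gpt-3.5-turbo-1106",
    "gpt-3.5-turbo-0613",
}


def supports_function_calling(model: str) -> bool:
    """Set membership is equivalent to the list check in the method above."""
    return model in FUNCTION_CALLING_MODELS


assert supports_function_calling("gpt-4-turbo")
assert not supports_function_calling("mistral-7b")
```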
@@ -1,6 +1,7 @@
+import logging
 import os
 from operator import itemgetter
-from typing import Optional
+from typing import List, Optional
 from uuid import UUID

 from langchain.chains import ConversationalRetrievalChain
@@ -14,8 +15,10 @@ from langchain_cohere import CohereRerank
 from langchain_community.chat_models import ChatLiteLLM
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
+from langchain_core.pydantic_v1 import BaseModel as BaseModelV1
+from langchain_core.pydantic_v1 import Field as FieldV1
 from langchain_core.runnables import RunnableLambda, RunnablePassthrough
-from langchain_openai import OpenAIEmbeddings
+from langchain_openai import ChatOpenAI, OpenAIEmbeddings
 from logger import get_logger
 from models import BrainSettings  # Importing settings related to the 'brain'
 from modules.brain.service.brain_service import BrainService
@@ -26,7 +29,20 @@ from pydantic_settings import BaseSettings
 from supabase.client import Client, create_client
 from vectorstore.supabase import CustomSupabaseVectorStore

-logger = get_logger(__name__)
+logger = get_logger(__name__, log_level=logging.INFO)
+
+
+class cited_answer(BaseModelV1):
+    """Answer the user question based only on the given sources, and cite the sources used."""
+
+    answer: str = FieldV1(
+        ...,
+        description="The answer to the user question, which is based only on the given sources.",
+    )
+    citations: List[int] = FieldV1(
+        ...,
+        description="The integer IDs of the SPECIFIC sources which justify the answer.",
+    )


 # First step is to create the Rephrasing Prompt
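`cited_answer` is the tool schema the model is forced to call: `answer` carries the generated text and `citations` the integer indexes of the sources it relied on. A small sketch that validates a tool-call payload against an equivalent model, using plain `pydantic` instead of `langchain_core.pydantic_v1` so it runs without LangChain installed (the payload values are invented):

```python
from typing import List

from pydantic import BaseModel, Field


class CitedAnswer(BaseModel):
    """Local mirror of the cited_answer tool schema, for validating tool-call args."""

    answer: str = Field(..., description="The answer to the user question, based only on the given sources.")
    citations: List[int] = Field(..., description="The integer IDs of the SPECIFIC sources which justify the answer.")


# Args as they come back in tool_calls[-1]["args"]
tool_args = {"answer": "Refunds take 5 days.", "citations": [0, 2]}
parsed = CitedAnswer(**tool_args)
print(parsed.citations)  # -> [0, 2]
```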
@@ -66,7 +82,9 @@ ANSWER_PROMPT = ChatPromptTemplate.from_messages(

 # How we format documents

-DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
+DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(
+    template="Source: {index} \n {page_content}"
+)


 def is_valid_uuid(uuid_to_test, version=4):
@@ -116,6 +134,23 @@ class QuivrRAG(BaseModel):
         else:
             return None

+    def model_compatible_with_function_calling(self):
+        if self.model in [
+            "gpt-4-turbo",
+            "gpt-4-turbo-2024-04-09",
+            "gpt-4-turbo-preview",
+            "gpt-4-0125-preview",
+            "gpt-4-1106-preview",
+            "gpt-4",
+            "gpt-4-0613",
+            "gpt-3.5-turbo",
+            "gpt-3.5-turbo-0125",
+            "gpt-3.5-turbo-1106",
+            "gpt-3.5-turbo-0613",
+        ]:
+            return True
+        return False
+
     supabase_client: Optional[Client] = None
     vector_store: Optional[CustomSupabaseVectorStore] = None
     qa: Optional[ConversationalRetrievalChain] = None
@@ -197,6 +232,9 @@ class QuivrRAG(BaseModel):
     def _combine_documents(
         self, docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
     ):
+        # for each docs, add an index in the metadata to be able to cite the sources
+        for doc, index in zip(docs, range(len(docs))):
+            doc.metadata["index"] = index
         doc_strings = [format_document(doc, document_prompt) for doc in docs]
         return document_separator.join(doc_strings)

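`_combine_documents` now stamps every retrieved chunk with its position before formatting, which is what makes the model's integer citations resolvable back to documents. A dependency-free sketch of the same numbering-and-joining step, with plain dicts standing in for LangChain documents and invented contents:

```python
DOCUMENT_TEMPLATE = "Source: {index} \n {page_content}"


def combine_documents(docs, separator="\n\n"):
    """Number each chunk so the model can cite it by index, then join them into one context string."""
    for index, doc in enumerate(docs):
        doc["metadata"]["index"] = index
    return separator.join(
        DOCUMENT_TEMPLATE.format(index=doc["metadata"]["index"], page_content=doc["page_content"])
        for doc in docs
    )


docs = [
    {"page_content": "Q1 revenue grew 12%.", "metadata": {"file_name": "report.pdf"}},
    {"page_content": "Refunds take 5 days.", "metadata": {"file_name": "faq.md"}},
]
print(combine_documents(docs))
# Source: 0
#  Q1 revenue grew 12%.
#
# Source: 1
#  Refunds take 5 days.
```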
@@ -287,14 +325,27 @@ class QuivrRAG(BaseModel):
             "question": itemgetter("question"),
             "custom_instructions": itemgetter("custom_instructions"),
         }
+        llm = ChatLiteLLM(
+            max_tokens=self.max_tokens,
+            model=self.model,
+            temperature=self.temperature,
+            api_base=api_base,
+        )
+        if self.model_compatible_with_function_calling():
+
+            # And finally, we do the part that returns the answers
+            llm_function = ChatOpenAI(
+                max_tokens=self.max_tokens,
+                model=self.model,
+                temperature=self.temperature,
+            )
+            llm = llm_function.bind_tools(
+                [cited_answer],
+                tool_choice="cited_answer",
+            )

-        # And finally, we do the part that returns the answers
         answer = {
-            "answer": final_inputs
-            | ANSWER_PROMPT
-            | ChatLiteLLM(
-                max_tokens=self.max_tokens, model=self.model, api_base=api_base
-            ),
+            "answer": final_inputs | ANSWER_PROMPT | llm,
             "docs": itemgetter("docs"),
         }

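The chain now chooses its LLM at build time: a plain ChatLiteLLM when the model cannot do function calling, otherwise a ChatOpenAI bound to the `cited_answer` tool so every answer is forced through the citation schema. A sketch of that conditional wiring, assuming `langchain_openai` and `langchain_community` are installed and an OpenAI key is configured; `build_llm` and the abbreviated allow-list are illustrative, not part of the diff:

```python
from typing import List, Optional

from langchain_community.chat_models import ChatLiteLLM
from langchain_core.pydantic_v1 import BaseModel as BaseModelV1
from langchain_core.pydantic_v1 import Field as FieldV1
from langchain_openai import ChatOpenAI


class cited_answer(BaseModelV1):
    """Answer the user question based only on the given sources, and cite the sources used."""

    answer: str = FieldV1(..., description="The answer, based only on the given sources.")
    citations: List[int] = FieldV1(..., description="The integer IDs of the sources which justify the answer.")


# Abbreviated allow-list; the real check is model_compatible_with_function_calling().
FUNCTION_CALLING_MODELS = {"gpt-4-turbo", "gpt-4", "gpt-3.5-turbo"}


def build_llm(model: str, max_tokens: int = 1000, temperature: float = 0.1, api_base: Optional[str] = None):
    """Plain ChatLiteLLM by default; a ChatOpenAI forced onto the cited_answer tool when supported."""
    llm = ChatLiteLLM(max_tokens=max_tokens, model=model, temperature=temperature, api_base=api_base)
    if model in FUNCTION_CALLING_MODELS:
        llm = ChatOpenAI(
            max_tokens=max_tokens, model=model, temperature=temperature
        ).bind_tools([cited_answer], tool_choice="cited_answer")
    return llm


# The chosen model then becomes the tail of the answer branch, e.g.:
# answer = {"answer": final_inputs | ANSWER_PROMPT | build_llm("gpt-4-turbo"), "docs": itemgetter("docs")}
```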
@@ -4,7 +4,7 @@ from uuid import UUID

 from modules.chat.dto.outputs import GetChatHistoryOutput
 from modules.notification.entity.notification import Notification
-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel


 class ChatMessage(BaseModel):
@@ -33,6 +33,7 @@ class Sources(BaseModel):
     source_url: str
     type: str
     original_file_name: str
+    citation: str


 class ChatItemType(Enum):
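Every returned source now also carries the exact chunk text that backed the answer, via the new `citation` field. A stripped-down sketch of the DTO showing only the fields visible in this hunk (values invented):

```python
from pydantic import BaseModel


class SourcesSketch(BaseModel):
    """Stripped-down mirror of the Sources DTO, limited to the fields visible in the hunk above."""

    source_url: str
    type: str
    original_file_name: str
    citation: str


source = SourcesSketch(
    source_url="https://storage.example/faq.md?sig=abc",
    type="file",
    original_file_name="faq.md",
    citation="Refunds take 5 days.",  # doc.page_content of the cited chunk
)
print(source.citation)
```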
@@ -11,6 +11,7 @@ class CreateChatHistory(BaseModel):
     assistant: str
     prompt_id: Optional[UUID] = None
     brain_id: Optional[UUID] = None
+    metadata: Optional[dict] = {}


 class QuestionAndAnswer(BaseModel):
@@ -74,6 +74,7 @@ class Chats(ChatsInterface):
                     "brain_id": (
                         str(chat_history.brain_id) if chat_history.brain_id else None
                     ),
+                    "metadata": chat_history.metadata if chat_history.metadata else {},
                 }
             )
             .execute()
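With `metadata` added to `CreateChatHistory` and written through the repository, citations and serialized sources are stored next to each chat message and returned with the history. A sketch of the payload shape that would end up in a chat_history row under these changes (all values invented):

```python
import json

# Shape of the metadata dict attached to a saved assistant message.
metadata = {
    "citations": [0, 2],
    "sources": [
        {
            "source_url": "https://storage.example/faq.md?sig=abc",
            "type": "file",
            "original_file_name": "faq.md",
            "citation": "Refunds take 5 days.",
        }
    ],
}

row = {
    "user_message": "How long do refunds take?",
    "assistant": "Refunds take 5 days.",
    "metadata": metadata if metadata else {},
}
print(json.dumps(row, indent=2))
```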
@@ -104,7 +105,9 @@ class Chats(ChatsInterface):
     def delete_chat_history(self, chat_id):
         self.db.table("chat_history").delete().match({"chat_id": chat_id}).execute()

-    def update_chat_message(self, chat_id, message_id, chat_message_properties: ChatMessageProperties ):
+    def update_chat_message(
+        self, chat_id, message_id, chat_message_properties: ChatMessageProperties
+    ):
         response = (
             self.db.table("chat_history")
             .update(chat_message_properties)