mirror of
https://github.com/StanGirard/quivr.git
synced 2024-08-16 08:30:28 +03:00
feat(chat): use openai function for answer (#354)
* feat(chat): use openai function for answer (backend) * feat(chat): use openai function for answer (frontend) * chore: refacto BrainPicking * feat: update chat creation logic * feat: simplify chat system logic * feat: set default method to gpt-3.5-turbo-0613 * feat: use user own openai key * feat(chat): slightly improve prompts * feat: add global error interceptor * feat: remove unused endpoints * docs: update chat system doc * chore(linter): add unused import remove config * feat: improve dx * feat: improve OpenAiFunctionBasedAnswerGenerator prompt
This commit is contained in:
parent
83fde0aeea
commit
59fe7b089b
3
.vscode/settings.json
vendored
3
.vscode/settings.json
vendored
@ -2,7 +2,8 @@
|
||||
"python.formatting.provider": "black",
|
||||
"editor.codeActionsOnSave": {
|
||||
"source.organizeImports": true,
|
||||
"source.fixAll": true
|
||||
"source.fixAll": true,
|
||||
"source.unusedImports": true
|
||||
},
|
||||
"python.linting.enabled": true,
|
||||
"python.linting.flake8Enabled": true,
|
||||
|
@ -0,0 +1,245 @@
|
||||
from typing import Optional
|
||||
|
||||
|
||||
from typing import Any, Dict, List # For type hinting
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from repository.chat.get_chat_history import get_chat_history
|
||||
from .utils.format_answer import format_answer
|
||||
|
||||
|
||||
# Importing various modules and classes from a custom library 'langchain' likely used for natural language processing
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
|
||||
|
||||
from models.settings import BrainSettings # Importing settings related to the 'brain'
|
||||
from llm.OpenAiFunctionBasedAnswerGenerator.models.OpenAiAnswer import OpenAiAnswer
|
||||
|
||||
from supabase import Client, create_client # For interacting with Supabase database
|
||||
from vectorstore.supabase import (
|
||||
CustomSupabaseVectorStore,
|
||||
) # Custom class for handling vector storage with Supabase
|
||||
from logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
get_context_function_name = "get_context"
|
||||
|
||||
prompt_template = """Your name is Quivr. You are a second brain.
|
||||
A person will ask you a question and you will provide a helpful answer.
|
||||
Write the answer in the same language as the question.
|
||||
If you don't know the answer, just say that you don't know. Don't try to make up an answer.
|
||||
Your main goal is to answer questions about user uploaded documents. Unless basic questions or greetings, you should always refer to user uploaded documents by fetching them with the {} function.""".format(
|
||||
get_context_function_name
|
||||
)
|
||||
|
||||
get_history_schema = {
|
||||
"name": "get_history",
|
||||
"description": "Get current chat previous messages",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
},
|
||||
}
|
||||
|
||||
get_context_schema = {
|
||||
"name": get_context_function_name,
|
||||
"description": "A function which returns user uploaded documents and which must be used when you don't now the answer to a question or when the question seems to refer to user uploaded documents",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class OpenAiFunctionBasedAnswerGenerator:
|
||||
# Default class attributes
|
||||
model: str = "gpt-3.5-turbo-0613"
|
||||
temperature: float = 0.0
|
||||
max_tokens: int = 256
|
||||
chat_id: str
|
||||
supabase_client: Client = None
|
||||
embeddings: OpenAIEmbeddings = None
|
||||
settings = BrainSettings()
|
||||
openai_client: ChatOpenAI = None
|
||||
user_email: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
chat_id: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
user_email: str,
|
||||
user_openai_api_key: str,
|
||||
) -> "OpenAiFunctionBasedAnswerGenerator":
|
||||
self.model = model
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.chat_id = chat_id
|
||||
self.supabase_client = create_client(
|
||||
self.settings.supabase_url, self.settings.supabase_service_key
|
||||
)
|
||||
self.user_email = user_email
|
||||
if user_openai_api_key is not None:
|
||||
self.settings.openai_api_key = user_openai_api_key
|
||||
self.embeddings = OpenAIEmbeddings(openai_api_key=self.settings.openai_api_key)
|
||||
self.openai_client = ChatOpenAI(openai_api_key=self.settings.openai_api_key)
|
||||
|
||||
def _get_model_response(
|
||||
self,
|
||||
messages: list[dict[str, str]] = [],
|
||||
functions: list[dict[str, Any]] = None,
|
||||
):
|
||||
if functions is not None:
|
||||
model_response = self.openai_client.completion_with_retry(
|
||||
functions=functions,
|
||||
messages=messages,
|
||||
model=self.model,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
)
|
||||
|
||||
else:
|
||||
model_response = self.openai_client.completion_with_retry(
|
||||
messages=messages,
|
||||
model=self.model,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
)
|
||||
|
||||
return model_response
|
||||
|
||||
def _get_formatted_history(self) -> List[Dict[str, str]]:
|
||||
formatted_history = []
|
||||
history = get_chat_history(self.chat_id)
|
||||
for chat in history:
|
||||
formatted_history.append({"role": "user", "content": chat.user_message})
|
||||
formatted_history.append({"role": "assistant", "content": chat.assistant})
|
||||
return formatted_history
|
||||
|
||||
def _get_formatted_prompt(
|
||||
self,
|
||||
question: Optional[str],
|
||||
useContext: Optional[bool] = False,
|
||||
useHistory: Optional[bool] = False,
|
||||
) -> list[dict[str, str]]:
|
||||
messages = [
|
||||
{"role": "system", "content": prompt_template},
|
||||
]
|
||||
|
||||
if not useHistory and not useContext:
|
||||
messages.append(
|
||||
{"role": "user", "content": question},
|
||||
)
|
||||
return messages
|
||||
|
||||
if useHistory:
|
||||
history = self._get_formatted_history()
|
||||
if len(history):
|
||||
messages.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Previous messages are already in chat.",
|
||||
},
|
||||
)
|
||||
messages.extend(history)
|
||||
else:
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": "This is the first message of the chat. There is no previous one",
|
||||
}
|
||||
)
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Question: {question}\n\nAnswer:",
|
||||
}
|
||||
)
|
||||
if useContext:
|
||||
chat_context = self._get_context(question)
|
||||
enhanced_question = f"Here is chat context: {chat_context if len(chat_context) else 'No document found'}"
|
||||
messages.append({"role": "user", "content": enhanced_question})
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Question: {question}\n\nAnswer:",
|
||||
}
|
||||
)
|
||||
return messages
|
||||
|
||||
def _get_context(self, question: str) -> str:
|
||||
# retrieve 5 nearest documents
|
||||
vector_store = CustomSupabaseVectorStore(
|
||||
self.supabase_client,
|
||||
self.embeddings,
|
||||
table_name="vectors",
|
||||
user_id=self.user_email,
|
||||
)
|
||||
|
||||
return vector_store.similarity_search(
|
||||
query=question,
|
||||
)
|
||||
|
||||
def _get_answer_from_question(self, question: str) -> OpenAiAnswer:
|
||||
functions = [get_history_schema, get_context_schema]
|
||||
|
||||
model_response = self._get_model_response(
|
||||
messages=self._get_formatted_prompt(question=question),
|
||||
functions=functions,
|
||||
)
|
||||
|
||||
return format_answer(model_response)
|
||||
|
||||
def _get_answer_from_question_and_history(self, question: str) -> OpenAiAnswer:
|
||||
logger.info("Using chat history")
|
||||
|
||||
functions = [
|
||||
get_context_schema,
|
||||
]
|
||||
|
||||
model_response = self._get_model_response(
|
||||
messages=self._get_formatted_prompt(question=question, useHistory=True),
|
||||
functions=functions,
|
||||
)
|
||||
|
||||
return format_answer(model_response)
|
||||
|
||||
def _get_answer_from_question_and_context(self, question: str) -> OpenAiAnswer:
|
||||
logger.info("Using documents ")
|
||||
|
||||
functions = [
|
||||
get_history_schema,
|
||||
]
|
||||
model_response = self._get_model_response(
|
||||
messages=self._get_formatted_prompt(question=question, useContext=True),
|
||||
functions=functions,
|
||||
)
|
||||
|
||||
return format_answer(model_response)
|
||||
|
||||
def _get_answer_from_question_and_context_and_history(
|
||||
self, question: str
|
||||
) -> OpenAiAnswer:
|
||||
logger.info("Using context and history")
|
||||
model_response = self._get_model_response(
|
||||
messages=self._get_formatted_prompt(
|
||||
question, useContext=True, useHistory=True
|
||||
),
|
||||
)
|
||||
return format_answer(model_response)
|
||||
|
||||
def get_answer(self, question: str) -> str:
|
||||
response = self._get_answer_from_question(question)
|
||||
function_name = response.function_call.name if response.function_call else None
|
||||
|
||||
if function_name == "get_history":
|
||||
response = self._get_answer_from_question_and_history(question)
|
||||
elif function_name == "get_context":
|
||||
response = self._get_answer_from_question_and_context(question)
|
||||
|
||||
if response.function_call:
|
||||
response = self._get_answer_from_question_and_context_and_history(question)
|
||||
|
||||
return response.content or ""
|
@ -0,0 +1,12 @@
|
||||
from typing import Optional
|
||||
from typing import Any, Dict # For type hinting
|
||||
|
||||
|
||||
class FunctionCall:
|
||||
def __init__(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
arguments: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
self.name = name
|
||||
self.arguments = arguments
|
@ -0,0 +1,12 @@
|
||||
from typing import Optional
|
||||
from .FunctionCall import FunctionCall
|
||||
|
||||
|
||||
class OpenAiAnswer:
|
||||
def __init__(
|
||||
self,
|
||||
content: Optional[str] = None,
|
||||
function_call: FunctionCall = None,
|
||||
):
|
||||
self.content = content
|
||||
self.function_call = function_call
|
@ -0,0 +1,17 @@
|
||||
from llm.OpenAiFunctionBasedAnswerGenerator.models.OpenAiAnswer import OpenAiAnswer
|
||||
from llm.OpenAiFunctionBasedAnswerGenerator.models.FunctionCall import FunctionCall
|
||||
from typing import Any, Dict # For type hinting
|
||||
|
||||
|
||||
def format_answer(model_response: Dict[str, Any]) -> OpenAiAnswer:
|
||||
answer = model_response["choices"][0]["message"]
|
||||
content = answer["content"]
|
||||
function_call = None
|
||||
|
||||
if answer.get("function_call", None) is not None:
|
||||
function_call = FunctionCall(
|
||||
answer["function_call"]["name"],
|
||||
answer["function_call"]["arguments"],
|
||||
)
|
||||
|
||||
return OpenAiAnswer(content=content, function_call=function_call)
|
@ -1,32 +1,21 @@
|
||||
import os # A module to interact with the OS
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict
|
||||
from models.settings import LLMSettings # For type hinting
|
||||
|
||||
# Importing various modules and classes from a custom library 'langchain' likely used for natural language processing
|
||||
from langchain.chains import ConversationalRetrievalChain, LLMChain
|
||||
from langchain.chains.question_answering import load_qa_chain
|
||||
from langchain.chains.router.llm_router import LLMRouterChain, RouterOutputParser
|
||||
from langchain.chains.router.multi_prompt_prompt import MULTI_PROMPT_ROUTER_TEMPLATE
|
||||
from langchain.chat_models import ChatOpenAI, ChatVertexAI
|
||||
from langchain.chat_models.anthropic import ChatAnthropic
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||
from langchain.llms import GPT4All
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
from langchain.vectorstores import SupabaseVectorStore
|
||||
from llm.prompt import LANGUAGE_PROMPT
|
||||
|
||||
from llm.prompt.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT
|
||||
from models.chats import (
|
||||
ChatMessage,
|
||||
) # Importing a custom ChatMessage class for handling chat messages
|
||||
from models.settings import BrainSettings # Importing settings related to the 'brain'
|
||||
from pydantic import BaseModel # For data validation and settings management
|
||||
from pydantic import BaseSettings
|
||||
from supabase import Client # For interacting with Supabase database
|
||||
from supabase import create_client
|
||||
from repository.chat.get_chat_history import get_chat_history
|
||||
from vectorstore.supabase import (
|
||||
CustomSupabaseVectorStore,
|
||||
) # Custom class for handling vector storage with Supabase
|
||||
@ -34,6 +23,7 @@ from logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AnswerConversationBufferMemory(ConversationBufferMemory):
|
||||
"""
|
||||
This class is a specialized version of ConversationBufferMemory.
|
||||
@ -48,7 +38,7 @@ class AnswerConversationBufferMemory(ConversationBufferMemory):
|
||||
)
|
||||
|
||||
|
||||
def get_chat_history(inputs) -> str:
|
||||
def format_chat_history(inputs) -> str:
|
||||
"""
|
||||
Function to concatenate chat history into a single string.
|
||||
:param inputs: List of tuples containing human and AI messages.
|
||||
@ -76,18 +66,38 @@ class BrainPicking(BaseModel):
|
||||
llm: LLM = None
|
||||
question_generator: LLMChain = None
|
||||
doc_chain: ConversationalRetrievalChain = None
|
||||
chat_id: str
|
||||
max_tokens: int = 256
|
||||
|
||||
class Config:
|
||||
# Allowing arbitrary types for class validation
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def init(self, model: str, user_id: str) -> "BrainPicking":
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
user_id: str,
|
||||
chat_id: str,
|
||||
max_tokens: int,
|
||||
user_openai_api_key: str,
|
||||
) -> "BrainPicking":
|
||||
"""
|
||||
Initialize the BrainPicking class by setting embeddings, supabase client, vector store, language model and chains.
|
||||
:param model: Language model name to be used.
|
||||
:param user_id: The user id to be used for CustomSupabaseVectorStore.
|
||||
:return: BrainPicking instance
|
||||
"""
|
||||
super().__init__(
|
||||
model=model,
|
||||
user_id=user_id,
|
||||
chat_id=chat_id,
|
||||
max_tokens=max_tokens,
|
||||
user_openai_api_key=user_openai_api_key,
|
||||
)
|
||||
# If user provided an API key, update the settings
|
||||
if user_openai_api_key is not None:
|
||||
self.settings.openai_api_key = user_openai_api_key
|
||||
|
||||
self.embeddings = OpenAIEmbeddings(openai_api_key=self.settings.openai_api_key)
|
||||
self.supabase_client = create_client(
|
||||
self.settings.supabase_url, self.settings.supabase_service_key
|
||||
@ -98,7 +108,7 @@ class BrainPicking(BaseModel):
|
||||
table_name="vectors",
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
|
||||
self.llm = self._determine_llm(
|
||||
private_model_args={
|
||||
"model_path": self.llm_config.model_path,
|
||||
@ -112,7 +122,8 @@ class BrainPicking(BaseModel):
|
||||
llm=self.llm, prompt=CONDENSE_QUESTION_PROMPT
|
||||
)
|
||||
self.doc_chain = load_qa_chain(self.llm, chain_type="stuff")
|
||||
return self
|
||||
self.chat_id = chat_id
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
def _determine_llm(
|
||||
self, private_model_args: dict, private: bool = False, model_name: str = None
|
||||
@ -128,7 +139,7 @@ class BrainPicking(BaseModel):
|
||||
model_path = private_model_args["model_path"]
|
||||
model_n_ctx = private_model_args["n_ctx"]
|
||||
model_n_batch = private_model_args["n_batch"]
|
||||
|
||||
|
||||
logger.info("Using private model: %s", model_path)
|
||||
|
||||
return GPT4All(
|
||||
@ -142,7 +153,7 @@ class BrainPicking(BaseModel):
|
||||
return ChatOpenAI(temperature=0, model_name=model_name)
|
||||
|
||||
def _get_qa(
|
||||
self, chat_message: ChatMessage, user_openai_api_key
|
||||
self,
|
||||
) -> ConversationalRetrievalChain:
|
||||
"""
|
||||
Retrieves a QA chain for the given chat message and API key.
|
||||
@ -150,42 +161,34 @@ class BrainPicking(BaseModel):
|
||||
:param user_openai_api_key: The OpenAI API key to be used.
|
||||
:return: ConversationalRetrievalChain instance
|
||||
"""
|
||||
# If user provided an API key, update the settings
|
||||
if user_openai_api_key is not None and user_openai_api_key != "":
|
||||
self.settings.openai_api_key = user_openai_api_key
|
||||
|
||||
# Initialize and return a ConversationalRetrievalChain
|
||||
qa = ConversationalRetrievalChain(
|
||||
retriever=self.vector_store.as_retriever(),
|
||||
max_tokens_limit=chat_message.max_tokens,
|
||||
max_tokens_limit=self.max_tokens,
|
||||
question_generator=self.question_generator,
|
||||
combine_docs_chain=self.doc_chain,
|
||||
get_chat_history=get_chat_history,
|
||||
get_chat_history=format_chat_history,
|
||||
)
|
||||
return qa
|
||||
|
||||
def generate_answer(self, chat_message: ChatMessage, user_openai_api_key) -> str:
|
||||
def generate_answer(self, question: str) -> str:
|
||||
"""
|
||||
Generate an answer to a given chat message by interacting with the language model.
|
||||
:param chat_message: The chat message containing history.
|
||||
:param user_openai_api_key: The OpenAI API key to be used.
|
||||
Generate an answer to a given question by interacting with the language model.
|
||||
:param question: The question
|
||||
:return: The generated answer.
|
||||
"""
|
||||
transformed_history = []
|
||||
|
||||
# Get the QA chain
|
||||
qa = self._get_qa(chat_message, user_openai_api_key)
|
||||
qa = self._get_qa()
|
||||
history = get_chat_history(self.chat_id)
|
||||
|
||||
# Transform the chat history into a list of tuples
|
||||
for i in range(0, len(chat_message.history) - 1, 2):
|
||||
user_message = chat_message.history[i][1]
|
||||
assistant_message = chat_message.history[i + 1][1]
|
||||
transformed_history.append((user_message, assistant_message))
|
||||
# Format the chat history into a list of tuples (human, ai)
|
||||
transformed_history = [(chat.user_message, chat.assistant) for chat in history]
|
||||
|
||||
# Generate the model response using the QA chain
|
||||
model_response = qa(
|
||||
{"question": chat_message.question, "chat_history": transformed_history}
|
||||
)
|
||||
model_response = qa({"question": question, "chat_history": transformed_history})
|
||||
answer = model_response["answer"]
|
||||
|
||||
return answer
|
||||
|
@ -6,4 +6,4 @@ Chat History:
|
||||
{chat_history}
|
||||
Follow Up Input: {question}
|
||||
Standalone question:"""
|
||||
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
|
||||
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
|
||||
|
@ -9,4 +9,4 @@ Question: {question}
|
||||
Helpful Answer:"""
|
||||
QA_PROMPT = PromptTemplate(
|
||||
template=prompt_template, input_variables=["context", "question"]
|
||||
)
|
||||
)
|
||||
|
@ -6,8 +6,7 @@ def get_logger(logger_name, log_level=logging.INFO):
|
||||
logger.setLevel(log_level)
|
||||
logger.propagate = False # Prevent log propagation to avoid double logging
|
||||
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s [%(levelname)s] %(name)s: %(message)s')
|
||||
formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
|
||||
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setFormatter(formatter)
|
||||
|
@ -1,7 +1,7 @@
|
||||
import os
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
import pypandoc
|
||||
from fastapi import FastAPI
|
||||
from logger import get_logger
|
||||
from middlewares.cors import add_cors_middleware
|
||||
from routes.api_key_routes import api_key_router
|
||||
@ -37,3 +37,11 @@ app.include_router(upload_router)
|
||||
app.include_router(user_router)
|
||||
app.include_router(api_key_router)
|
||||
app.include_router(stream_router)
|
||||
|
||||
|
||||
@app.exception_handler(HTTPException)
|
||||
async def http_exception_handler(request, exc):
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={"detail": exc.detail},
|
||||
)
|
||||
|
31
backend/models/chat.py
Normal file
31
backend/models/chat.py
Normal file
@ -0,0 +1,31 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Chat:
|
||||
chat_id: str
|
||||
user_id: str
|
||||
creation_time: str
|
||||
chat_name: str
|
||||
|
||||
def __init__(self, chat_dict: dict):
|
||||
self.chat_id = chat_dict.get("chat_id")
|
||||
self.user_id = chat_dict.get("user_id")
|
||||
self.creation_time = chat_dict.get("creation_time")
|
||||
self.chat_name = chat_dict.get("chat_name")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatHistory:
|
||||
chat_id: str
|
||||
message_id: str
|
||||
user_message: str
|
||||
assistant: str
|
||||
message_time: str
|
||||
|
||||
def __init__(self, chat_dict: dict):
|
||||
self.chat_id = chat_dict.get("chat_id")
|
||||
self.message_id = chat_dict.get("message_id")
|
||||
self.user_message = chat_dict.get("user_message")
|
||||
self.assistant = chat_dict.get("assistant")
|
||||
self.message_time = chat_dict.get("message_time")
|
@ -16,5 +16,8 @@ class ChatMessage(BaseModel):
|
||||
chat_name: Optional[str] = None
|
||||
|
||||
|
||||
class ChatAttributes(BaseModel):
|
||||
chat_name: Optional[str] = None
|
||||
class ChatQuestion(BaseModel):
|
||||
model: str = "gpt-3.5-turbo-0613"
|
||||
question: str
|
||||
temperature: float = 0.0
|
||||
max_tokens: int = 256
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Annotated, Any, Dict, List, Tuple, Union
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
@ -13,26 +13,33 @@ class BrainSettings(BaseSettings):
|
||||
supabase_url: str
|
||||
supabase_service_key: str
|
||||
|
||||
|
||||
class LLMSettings(BaseSettings):
|
||||
private: bool
|
||||
model_path: str
|
||||
model_n_ctx: int
|
||||
model_n_batch: int
|
||||
|
||||
|
||||
def common_dependencies() -> dict:
|
||||
settings = BrainSettings()
|
||||
embeddings = OpenAIEmbeddings(openai_api_key=settings.openai_api_key)
|
||||
supabase_client: Client = create_client(settings.supabase_url, settings.supabase_service_key)
|
||||
supabase_client: Client = create_client(
|
||||
settings.supabase_url, settings.supabase_service_key
|
||||
)
|
||||
documents_vector_store = SupabaseVectorStore(
|
||||
supabase_client, embeddings, table_name="vectors")
|
||||
supabase_client, embeddings, table_name="vectors"
|
||||
)
|
||||
summaries_vector_store = SupabaseVectorStore(
|
||||
supabase_client, embeddings, table_name="summaries")
|
||||
|
||||
supabase_client, embeddings, table_name="summaries"
|
||||
)
|
||||
|
||||
return {
|
||||
"supabase": supabase_client,
|
||||
"embeddings": embeddings,
|
||||
"documents_vector_store": documents_vector_store,
|
||||
"summaries_vector_store": summaries_vector_store
|
||||
"summaries_vector_store": summaries_vector_store,
|
||||
}
|
||||
|
||||
CommonsDep = Annotated[dict, Depends(common_dependencies)]
|
||||
|
||||
CommonsDep = Annotated[dict, Depends(common_dependencies)]
|
||||
|
32
backend/repository/chat/create_chat.py
Normal file
32
backend/repository/chat/create_chat.py
Normal file
@ -0,0 +1,32 @@
|
||||
from logger import get_logger
|
||||
from models.settings import common_dependencies
|
||||
from dataclasses import dataclass
|
||||
from models.chat import Chat
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CreateChatProperties:
|
||||
name: str
|
||||
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
|
||||
def create_chat(user_id: str, chat_data: CreateChatProperties) -> Chat:
|
||||
commons = common_dependencies()
|
||||
|
||||
# Chat is created upon the user's first question asked
|
||||
logger.info(f"New chat entry in chats table for user {user_id}")
|
||||
|
||||
# Insert a new row into the chats table
|
||||
new_chat = {
|
||||
"user_id": user_id,
|
||||
"chat_name": chat_data.name,
|
||||
}
|
||||
insert_response = commons["supabase"].table("chats").insert(new_chat).execute()
|
||||
logger.info(f"Insert response {insert_response.data}")
|
||||
|
||||
return insert_response.data[0]
|
15
backend/repository/chat/get_chat_by_id.py
Normal file
15
backend/repository/chat/get_chat_by_id.py
Normal file
@ -0,0 +1,15 @@
|
||||
from models.chat import Chat
|
||||
from models.settings import common_dependencies
|
||||
|
||||
|
||||
def get_chat_by_id(chat_id: str) -> Chat:
|
||||
commons = common_dependencies()
|
||||
|
||||
response = (
|
||||
commons["supabase"]
|
||||
.from_("chats")
|
||||
.select("*")
|
||||
.filter("chat_id", "eq", chat_id)
|
||||
.execute()
|
||||
)
|
||||
return Chat(response.data[0])
|
19
backend/repository/chat/get_chat_history.py
Normal file
19
backend/repository/chat/get_chat_history.py
Normal file
@ -0,0 +1,19 @@
|
||||
from models.chat import ChatHistory
|
||||
from models.settings import common_dependencies
|
||||
from typing import List # For type hinting
|
||||
|
||||
|
||||
def get_chat_history(chat_id: str) -> List[ChatHistory]:
|
||||
commons = common_dependencies()
|
||||
history: List[ChatHistory] = (
|
||||
commons["supabase"]
|
||||
.from_("chat_history")
|
||||
.select("*")
|
||||
.filter("chat_id", "eq", chat_id)
|
||||
.order("message_time", desc=False) # Add the ORDER BY clause
|
||||
.execute()
|
||||
).data
|
||||
if history is None:
|
||||
return []
|
||||
else:
|
||||
return [ChatHistory(message) for message in history]
|
15
backend/repository/chat/get_user_chats.py
Normal file
15
backend/repository/chat/get_user_chats.py
Normal file
@ -0,0 +1,15 @@
|
||||
from models.settings import common_dependencies
|
||||
from models.chat import Chat
|
||||
|
||||
|
||||
def get_user_chats(user_id: str) -> list[Chat]:
|
||||
commons = common_dependencies()
|
||||
response = (
|
||||
commons["supabase"]
|
||||
.from_("chats")
|
||||
.select("chat_id,user_id,creation_time,chat_name")
|
||||
.filter("user_id", "eq", user_id)
|
||||
.execute()
|
||||
)
|
||||
chats = [Chat(chat_dict) for chat_dict in response.data]
|
||||
return chats
|
45
backend/repository/chat/update_chat.py
Normal file
45
backend/repository/chat/update_chat.py
Normal file
@ -0,0 +1,45 @@
|
||||
from logger import get_logger
|
||||
from models.chat import Chat
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
from models.settings import common_dependencies
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatUpdatableProperties:
|
||||
chat_name: Optional[str] = None
|
||||
|
||||
def __init__(self, chat_name: Optional[str]):
|
||||
self.chat_name = chat_name
|
||||
|
||||
|
||||
def update_chat(chat_id, chat_data: ChatUpdatableProperties) -> Chat:
|
||||
commons = common_dependencies()
|
||||
|
||||
if not chat_id:
|
||||
logger.error("No chat_id provided")
|
||||
return
|
||||
|
||||
updates = {}
|
||||
|
||||
if chat_data.chat_name is not None:
|
||||
updates["chat_name"] = chat_data.chat_name
|
||||
|
||||
updated_chat = None
|
||||
|
||||
if updates:
|
||||
updated_chat = (
|
||||
commons["supabase"]
|
||||
.table("chats")
|
||||
.update(updates)
|
||||
.match({"chat_id": chat_id})
|
||||
.execute()
|
||||
).data[0]
|
||||
logger.info(f"Chat {chat_id} updated")
|
||||
else:
|
||||
logger.info(f"No updates to apply for chat {chat_id}")
|
||||
return updated_chat
|
24
backend/repository/chat/update_chat_history.py
Normal file
24
backend/repository/chat/update_chat_history.py
Normal file
@ -0,0 +1,24 @@
|
||||
from models.chat import ChatHistory
|
||||
from models.settings import common_dependencies
|
||||
from typing import List # For type hinting
|
||||
|
||||
|
||||
def update_chat_history(
|
||||
chat_id: str, user_message: str, assistant_answer: str
|
||||
) -> ChatHistory:
|
||||
commons = common_dependencies()
|
||||
response: List[ChatHistory] = (
|
||||
commons["supabase"]
|
||||
.table("chat_history")
|
||||
.insert(
|
||||
{
|
||||
"chat_id": str(chat_id),
|
||||
"user_message": user_message,
|
||||
"assistant": assistant_answer,
|
||||
}
|
||||
)
|
||||
.execute()
|
||||
).data
|
||||
if len(response) == 0:
|
||||
raise Exception("Error while updating chat history")
|
||||
return response[0]
|
@ -1,5 +1,5 @@
|
||||
pymupdf==1.22.3
|
||||
langchain==0.0.200
|
||||
langchain==0.0.207
|
||||
Markdown==3.4.3
|
||||
openai==0.27.6
|
||||
pdf2image==1.16.3
|
||||
|
@ -1,5 +1,3 @@
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
@ -7,7 +5,7 @@ from auth.auth_bearer import AuthBearer, get_current_user
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from logger import get_logger
|
||||
from models.brains import Brain
|
||||
from models.settings import CommonsDep, common_dependencies
|
||||
from models.settings import common_dependencies
|
||||
from models.users import User
|
||||
from pydantic import BaseModel
|
||||
from utils.users import fetch_user_id_from_credentials
|
||||
@ -50,7 +48,7 @@ async def brain_endpoint(current_user: User = Depends(get_current_user)):
|
||||
@brain_router.get(
|
||||
"/brains/{brain_id}", dependencies=[Depends(AuthBearer())], tags=["Brain"]
|
||||
)
|
||||
async def brain_endpoint(brain_id: UUID):
|
||||
async def get_brain_endpoint(brain_id: UUID):
|
||||
"""
|
||||
Retrieve details of a specific brain by brain ID.
|
||||
|
||||
@ -76,7 +74,7 @@ async def brain_endpoint(brain_id: UUID):
|
||||
@brain_router.delete(
|
||||
"/brains/{brain_id}", dependencies=[Depends(AuthBearer())], tags=["Brain"]
|
||||
)
|
||||
async def brain_endpoint(brain_id: UUID):
|
||||
async def delete_brain_endpoint(brain_id: UUID):
|
||||
"""
|
||||
Delete a specific brain by brain ID.
|
||||
"""
|
||||
@ -97,7 +95,7 @@ class BrainObject(BaseModel):
|
||||
|
||||
# create new brain
|
||||
@brain_router.post("/brains", dependencies=[Depends(AuthBearer())], tags=["Brain"])
|
||||
async def brain_endpoint(
|
||||
async def create_brain_endpoint(
|
||||
request: Request,
|
||||
brain: BrainObject,
|
||||
current_user: User = Depends(get_current_user),
|
||||
@ -125,7 +123,7 @@ async def brain_endpoint(
|
||||
@brain_router.put(
|
||||
"/brains/{brain_id}", dependencies=[Depends(AuthBearer())], tags=["Brain"]
|
||||
)
|
||||
async def brain_endpoint(
|
||||
async def update_brain_endpoint(
|
||||
request: Request,
|
||||
brain_id: UUID,
|
||||
input_brain: Brain,
|
||||
|
@ -1,32 +1,37 @@
|
||||
import os
|
||||
import time
|
||||
from uuid import UUID
|
||||
from models.chat import ChatHistory
|
||||
from auth.auth_bearer import AuthBearer, get_current_user
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from llm.brainpicking import BrainPicking
|
||||
from models.chats import ChatMessage, ChatAttributes
|
||||
from models.settings import CommonsDep, common_dependencies
|
||||
from fastapi import HTTPException
|
||||
from models.chat import Chat
|
||||
from models.chats import ChatQuestion
|
||||
from models.settings import common_dependencies
|
||||
from models.users import User
|
||||
from utils.chats import (create_chat, get_chat_name_from_first_question,
|
||||
update_chat)
|
||||
from utils.users import (create_user, fetch_user_id_from_credentials,
|
||||
update_user_request_count)
|
||||
from http.client import HTTPException
|
||||
from repository.chat.get_chat_history import get_chat_history
|
||||
from repository.chat.update_chat import update_chat, ChatUpdatableProperties
|
||||
from repository.chat.create_chat import create_chat, CreateChatProperties
|
||||
from utils.users import (
|
||||
fetch_user_id_from_credentials,
|
||||
update_user_request_count,
|
||||
)
|
||||
|
||||
|
||||
from repository.chat.get_chat_by_id import get_chat_by_id
|
||||
|
||||
from repository.chat.get_user_chats import get_user_chats
|
||||
from repository.chat.update_chat_history import update_chat_history
|
||||
|
||||
|
||||
from llm.OpenAiFunctionBasedAnswerGenerator.OpenAiFunctionBasedAnswerGenerator import (
|
||||
OpenAiFunctionBasedAnswerGenerator,
|
||||
)
|
||||
|
||||
chat_router = APIRouter()
|
||||
|
||||
|
||||
def get_user_chats(commons, user_id):
|
||||
response = (
|
||||
commons["supabase"]
|
||||
.from_("chats")
|
||||
.select("chatId:chat_id, chatName:chat_name")
|
||||
.filter("user_id", "eq", user_id)
|
||||
.execute()
|
||||
)
|
||||
return response.data
|
||||
|
||||
|
||||
def get_chat_details(commons, chat_id):
|
||||
response = (
|
||||
commons["supabase"]
|
||||
@ -69,32 +74,14 @@ async def get_chats(current_user: User = Depends(get_current_user)):
|
||||
"""
|
||||
commons = common_dependencies()
|
||||
user_id = fetch_user_id_from_credentials(commons, {"email": current_user.email})
|
||||
chats = get_user_chats(commons, user_id)
|
||||
chats = get_user_chats(user_id)
|
||||
return {"chats": chats}
|
||||
|
||||
|
||||
# get one chat
|
||||
@chat_router.get("/chat/{chat_id}", dependencies=[Depends(AuthBearer())], tags=["Chat"])
|
||||
async def get_chat_handler(chat_id: UUID):
|
||||
"""
|
||||
Retrieve details of a specific chat by chat ID.
|
||||
|
||||
- `chat_id`: The ID of the chat to retrieve details for.
|
||||
- Returns the chat ID and its history.
|
||||
|
||||
This endpoint retrieves the details of a specific chat identified by the provided chat ID. It returns the chat ID and its
|
||||
history, which includes the chat messages exchanged in the chat.
|
||||
"""
|
||||
commons = common_dependencies()
|
||||
chats = get_chat_details(commons, chat_id)
|
||||
if len(chats) > 0:
|
||||
return {"chatId": chat_id, "history": chats[0]["history"]}
|
||||
else:
|
||||
return {"error": "Chat not found"}
|
||||
|
||||
|
||||
# delete one chat
|
||||
@chat_router.delete("/chat/{chat_id}", dependencies=[Depends(AuthBearer())], tags=["Chat"])
|
||||
@chat_router.delete(
|
||||
"/chat/{chat_id}", dependencies=[Depends(AuthBearer())], tags=["Chat"]
|
||||
)
|
||||
async def delete_chat(chat_id: UUID):
|
||||
"""
|
||||
Delete a specific chat by chat ID.
|
||||
@ -104,91 +91,118 @@ async def delete_chat(chat_id: UUID):
|
||||
return {"message": f"{chat_id} has been deleted."}
|
||||
|
||||
|
||||
# helper method for update and create chat
|
||||
def chat_handler(request, commons, chat_id, chat_message, email, is_new_chat=False):
|
||||
date = time.strftime("%Y%m%d")
|
||||
user_id = fetch_user_id_from_credentials(commons, {"email": email})
|
||||
max_requests_number = os.getenv("MAX_REQUESTS_NUMBER")
|
||||
user_openai_api_key = request.headers.get("Openai-Api-Key")
|
||||
|
||||
userItem = fetch_user_stats(commons, User(email=email), date)
|
||||
old_request_count = userItem["requests_count"]
|
||||
|
||||
history = chat_message.history
|
||||
history.append(("user", chat_message.question))
|
||||
|
||||
chat_name = chat_message.chat_name
|
||||
|
||||
if old_request_count == 0:
|
||||
create_user(commons, email=email, date=date)
|
||||
else:
|
||||
update_user_request_count(
|
||||
commons, email, date, requests_count=old_request_count + 1
|
||||
)
|
||||
if user_openai_api_key is None and old_request_count >= float(max_requests_number):
|
||||
history.append(("assistant", "You have reached your requests limit"))
|
||||
update_chat(commons, chat_id=chat_id, history=history, chat_name=chat_name)
|
||||
return {"history": history}
|
||||
|
||||
brainPicking = BrainPicking().init(chat_message.model, email)
|
||||
answer = brainPicking.generate_answer(chat_message, user_openai_api_key)
|
||||
history.append(("assistant", answer))
|
||||
|
||||
if is_new_chat:
|
||||
chat_name = get_chat_name_from_first_question(chat_message)
|
||||
new_chat = create_chat(commons, user_id, history, chat_name)
|
||||
chat_id = new_chat.data[0]["chat_id"]
|
||||
else:
|
||||
update_chat(commons, chat_id=chat_id, history=history, chat_name=chat_name)
|
||||
|
||||
return {"history": history, "chatId": chat_id}
|
||||
|
||||
|
||||
# update existing chat
|
||||
@chat_router.put("/chat/{chat_id}", dependencies=[Depends(AuthBearer())], tags=["Chat"])
|
||||
async def chat_endpoint(
|
||||
request: Request,
|
||||
commons: CommonsDep,
|
||||
chat_id: UUID,
|
||||
chat_message: ChatMessage,
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Update an existing chat with new chat messages.
|
||||
"""
|
||||
return chat_handler(request, commons, chat_id, chat_message, current_user.email)
|
||||
|
||||
|
||||
# update existing chat
|
||||
@chat_router.put("/chat/{chat_id}/metadata", dependencies=[Depends(AuthBearer())], tags=["Chat"])
|
||||
async def update_chat_attributes_handler(
|
||||
commons: CommonsDep,
|
||||
chat_message: ChatAttributes,
|
||||
# update existing chat metadata
|
||||
@chat_router.put(
|
||||
"/chat/{chat_id}/metadata", dependencies=[Depends(AuthBearer())], tags=["Chat"]
|
||||
)
|
||||
async def update_chat_metadata_handler(
|
||||
chat_data: ChatUpdatableProperties,
|
||||
chat_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
) -> Chat:
|
||||
"""
|
||||
Update chat attributes
|
||||
"""
|
||||
commons = common_dependencies()
|
||||
|
||||
user_id = fetch_user_id_from_credentials(commons, {"email": current_user.email})
|
||||
chat = get_chat_details(commons, chat_id)[0]
|
||||
if user_id != chat.get('user_id'):
|
||||
chat = get_chat_by_id(chat_id)
|
||||
if user_id != chat.user_id:
|
||||
raise HTTPException(status_code=403, detail="Chat not owned by user")
|
||||
return update_chat(commons=commons, chat_id=chat_id, chat_name=chat_message.chat_name)
|
||||
return update_chat(chat_id=chat_id, chat_data=chat_data)
|
||||
|
||||
|
||||
# helper method for update and create chat
|
||||
def check_user_limit(
|
||||
email,
|
||||
):
|
||||
date = time.strftime("%Y%m%d")
|
||||
max_requests_number = os.getenv("MAX_REQUESTS_NUMBER")
|
||||
commons = common_dependencies()
|
||||
userItem = fetch_user_stats(commons, User(email=email), date)
|
||||
old_request_count = userItem["requests_count"]
|
||||
|
||||
update_user_request_count(
|
||||
commons, email, date, requests_count=old_request_count + 1
|
||||
)
|
||||
if old_request_count >= float(max_requests_number):
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="You have reached the maximum number of requests for today.",
|
||||
)
|
||||
|
||||
|
||||
# create new chat
|
||||
@chat_router.post("/chat", dependencies=[Depends(AuthBearer())], tags=["Chat"])
|
||||
async def create_chat_handler(
|
||||
request: Request,
|
||||
commons: CommonsDep,
|
||||
chat_message: ChatMessage,
|
||||
chat_data: CreateChatProperties,
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Create a new chat with initial chat messages.
|
||||
"""
|
||||
return chat_handler(
|
||||
request, commons, None, chat_message, current_user.email, is_new_chat=True
|
||||
)
|
||||
|
||||
commons = common_dependencies()
|
||||
user_id = fetch_user_id_from_credentials(commons, {"email": current_user.email})
|
||||
return create_chat(user_id=user_id, chat_data=chat_data)
|
||||
|
||||
|
||||
# add new question to chat
|
||||
@chat_router.post(
|
||||
"/chat/{chat_id}/question", dependencies=[Depends(AuthBearer())], tags=["Chat"]
|
||||
)
|
||||
async def create_question_handler(
|
||||
request: Request,
|
||||
chat_question: ChatQuestion,
|
||||
chat_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> ChatHistory:
|
||||
try:
|
||||
check_user_limit(current_user.email)
|
||||
user_openai_api_key = request.headers.get("Openai-Api-Key")
|
||||
openai_function_compatible_models = [
|
||||
"gpt-3.5-turbo-0613",
|
||||
"gpt-4-0613",
|
||||
]
|
||||
if chat_question.model in openai_function_compatible_models:
|
||||
# TODO: RBAC with current_user
|
||||
gpt_answer_generator = OpenAiFunctionBasedAnswerGenerator(
|
||||
model=chat_question.model,
|
||||
chat_id=chat_id,
|
||||
temperature=chat_question.temperature,
|
||||
max_tokens=chat_question.max_tokens,
|
||||
# TODO: use user_id in vectors table instead of email
|
||||
user_email=current_user.email,
|
||||
user_openai_api_key=user_openai_api_key,
|
||||
)
|
||||
answer = gpt_answer_generator.get_answer(chat_question.question)
|
||||
else:
|
||||
brainPicking = BrainPicking(
|
||||
chat_id=str(chat_id),
|
||||
model=chat_question.model,
|
||||
max_tokens=chat_question.max_tokens,
|
||||
user_id=current_user.email,
|
||||
user_openai_api_key=user_openai_api_key,
|
||||
)
|
||||
answer = brainPicking.generate_answer(chat_question.question)
|
||||
|
||||
chat_answer = update_chat_history(
|
||||
chat_id=chat_id,
|
||||
user_message=chat_question.question,
|
||||
assistant_answer=answer,
|
||||
)
|
||||
return chat_answer
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
|
||||
|
||||
# get chat history
|
||||
@chat_router.get(
|
||||
"/chat/{chat_id}/history", dependencies=[Depends(AuthBearer())], tags=["Chat"]
|
||||
)
|
||||
async def get_chat_history_handler(
|
||||
chat_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> list[ChatHistory]:
|
||||
# TODO: RBAC with current_user
|
||||
return get_chat_history(chat_id)
|
||||
|
@ -5,49 +5,73 @@ from models.users import User
|
||||
|
||||
explore_router = APIRouter()
|
||||
|
||||
|
||||
def get_unique_user_data(commons, user):
|
||||
"""
|
||||
Retrieve unique user data vectors.
|
||||
"""
|
||||
response = commons['supabase'].table("vectors").select(
|
||||
"name:metadata->>file_name, size:metadata->>file_size", count="exact").filter("user_id", "eq", user.email).execute()
|
||||
response = (
|
||||
commons["supabase"]
|
||||
.table("vectors")
|
||||
.select("name:metadata->>file_name, size:metadata->>file_size", count="exact")
|
||||
.filter("user_id", "eq", user.email)
|
||||
.execute()
|
||||
)
|
||||
documents = response.data # Access the data from the response
|
||||
# Convert each dictionary to a tuple of items, then to a set to remove duplicates, and then back to a dictionary
|
||||
unique_data = [dict(t) for t in set(tuple(d.items()) for d in documents)]
|
||||
return unique_data
|
||||
|
||||
|
||||
@explore_router.get("/explore", dependencies=[Depends(AuthBearer())], tags=["Explore"])
|
||||
async def explore_endpoint( current_user: User = Depends(get_current_user)):
|
||||
async def explore_endpoint(current_user: User = Depends(get_current_user)):
|
||||
"""
|
||||
Retrieve and explore unique user data vectors.
|
||||
"""
|
||||
commons = common_dependencies()
|
||||
unique_data = get_unique_user_data(commons, current_user)
|
||||
unique_data.sort(key=lambda x: int(x['size']), reverse=True)
|
||||
unique_data.sort(key=lambda x: int(x["size"]), reverse=True)
|
||||
return {"documents": unique_data}
|
||||
|
||||
|
||||
@explore_router.delete("/explore/{file_name}", dependencies=[Depends(AuthBearer())], tags=["Explore"])
|
||||
async def delete_endpoint( file_name: str, credentials: dict = Depends(AuthBearer())):
|
||||
@explore_router.delete(
|
||||
"/explore/{file_name}", dependencies=[Depends(AuthBearer())], tags=["Explore"]
|
||||
)
|
||||
async def delete_endpoint(file_name: str, credentials: dict = Depends(AuthBearer())):
|
||||
"""
|
||||
Delete a specific user file by file name.
|
||||
"""
|
||||
commons = common_dependencies()
|
||||
user = User(email=credentials.get('email', 'none'))
|
||||
user = User(email=credentials.get("email", "none"))
|
||||
# Cascade delete the summary from the database first, because it has a foreign key constraint
|
||||
commons['supabase'].table("summaries").delete().match(
|
||||
{"metadata->>file_name": file_name}).execute()
|
||||
commons['supabase'].table("vectors").delete().match(
|
||||
{"metadata->>file_name": file_name, "user_id": user.email}).execute()
|
||||
commons["supabase"].table("summaries").delete().match(
|
||||
{"metadata->>file_name": file_name}
|
||||
).execute()
|
||||
commons["supabase"].table("vectors").delete().match(
|
||||
{"metadata->>file_name": file_name, "user_id": user.email}
|
||||
).execute()
|
||||
return {"message": f"{file_name} of user {user.email} has been deleted."}
|
||||
|
||||
@explore_router.get("/explore/{file_name}", dependencies=[Depends(AuthBearer())], tags=["Explore"])
|
||||
async def download_endpoint( file_name: str, current_user: User = Depends(get_current_user)):
|
||||
|
||||
@explore_router.get(
|
||||
"/explore/{file_name}", dependencies=[Depends(AuthBearer())], tags=["Explore"]
|
||||
)
|
||||
async def download_endpoint(
|
||||
file_name: str, current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Download a specific user file by file name.
|
||||
"""
|
||||
commons = common_dependencies()
|
||||
response = commons['supabase'].table("vectors").select(
|
||||
"metadata->>file_name, metadata->>file_size, metadata->>file_extension, metadata->>file_url", "content").match({"metadata->>file_name": file_name, "user_id": current_user.email}).execute()
|
||||
response = (
|
||||
commons["supabase"]
|
||||
.table("vectors")
|
||||
.select(
|
||||
"metadata->>file_name, metadata->>file_size, metadata->>file_extension, metadata->>file_url",
|
||||
"content",
|
||||
)
|
||||
.match({"metadata->>file_name": file_name, "user_id": current_user.email})
|
||||
.execute()
|
||||
)
|
||||
documents = response.data
|
||||
return {"documents": documents}
|
||||
|
@ -1,48 +1,10 @@
|
||||
from logger import get_logger
|
||||
from models.chats import ChatMessage
|
||||
from models.settings import CommonsDep
|
||||
from models.settings import common_dependencies
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def create_chat(commons: CommonsDep, user_id, history, chat_name):
|
||||
# Chat is created upon the user's first question asked
|
||||
logger.info(f"New chat entry in chats table for user {user_id}")
|
||||
|
||||
# Insert a new row into the chats table
|
||||
new_chat = {
|
||||
"user_id": user_id,
|
||||
"history": history, # Empty chat to start
|
||||
"chat_name": chat_name,
|
||||
}
|
||||
insert_response = commons["supabase"].table("chats").insert(new_chat).execute()
|
||||
logger.info(f"Insert response {insert_response.data}")
|
||||
|
||||
return insert_response
|
||||
|
||||
|
||||
def update_chat(commons: CommonsDep, chat_id, history=None, chat_name=None):
|
||||
if not chat_id:
|
||||
logger.error("No chat_id provided")
|
||||
return
|
||||
|
||||
updates = {}
|
||||
|
||||
if history is not None:
|
||||
updates["history"] = history
|
||||
|
||||
if chat_name is not None:
|
||||
updates["chat_name"] = chat_name
|
||||
|
||||
if updates:
|
||||
commons["supabase"].table("chats").update(updates).match(
|
||||
{"chat_id": chat_id}
|
||||
).execute()
|
||||
logger.info(f"Chat {chat_id} updated")
|
||||
else:
|
||||
logger.info(f"No updates to apply for chat {chat_id}")
|
||||
|
||||
|
||||
def get_chat_name_from_first_question(chat_message: ChatMessage):
|
||||
# Step 1: Get the summary of the first question
|
||||
# first_question_summary = summarize_as_title(chat_message.question)
|
||||
|
@ -6,32 +6,45 @@ from models.users import User
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def create_user(commons: CommonsDep, email, date):
|
||||
logger.info(f"New user entry in db document for user {email}")
|
||||
|
||||
return(commons['supabase'].table("users").insert(
|
||||
{"email": email, "date": date, "requests_count": 1}).execute())
|
||||
return (
|
||||
commons["supabase"]
|
||||
.table("users")
|
||||
.insert({"email": email, "date": date, "requests_count": 1})
|
||||
.execute()
|
||||
)
|
||||
|
||||
|
||||
def update_user_request_count(commons: CommonsDep, email, date, requests_count):
|
||||
logger.info(f"User {email} request count updated to {requests_count}")
|
||||
commons['supabase'].table("users").update(
|
||||
{ "requests_count": requests_count}).match({"email": email, "date": date}).execute()
|
||||
|
||||
def fetch_user_id_from_credentials(commons: CommonsDep,credentials):
|
||||
user = User(email=credentials.get('email', 'none'))
|
||||
commons["supabase"].table("users").update({"requests_count": requests_count}).match(
|
||||
{"email": email, "date": date}
|
||||
).execute()
|
||||
|
||||
|
||||
def fetch_user_id_from_credentials(commons: CommonsDep, credentials):
|
||||
user = User(email=credentials.get("email", "none"))
|
||||
|
||||
# Fetch the user's UUID based on their email
|
||||
response = commons['supabase'].from_('users').select('user_id').filter("email", "eq", user.email).execute()
|
||||
response = (
|
||||
commons["supabase"]
|
||||
.from_("users")
|
||||
.select("user_id")
|
||||
.filter("email", "eq", user.email)
|
||||
.execute()
|
||||
)
|
||||
|
||||
userItem = next(iter(response.data or []), {})
|
||||
|
||||
if userItem == {}:
|
||||
if userItem == {}:
|
||||
date = time.strftime("%Y%m%d")
|
||||
create_user_response = create_user(commons, email= user.email, date=date)
|
||||
user_id = create_user_response.data[0]['user_id']
|
||||
create_user_response = create_user(commons, email=user.email, date=date)
|
||||
user_id = create_user_response.data[0]["user_id"]
|
||||
|
||||
else:
|
||||
user_id = userItem['user_id']
|
||||
else:
|
||||
user_id = userItem["user_id"]
|
||||
|
||||
return user_id
|
||||
|
||||
|
@ -9,32 +9,44 @@ from pydantic import BaseModel
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class Neurons(BaseModel):
|
||||
|
||||
class Neurons(BaseModel):
|
||||
commons: CommonsDep
|
||||
settings = BrainSettings()
|
||||
|
||||
|
||||
def create_vector(self, user_id, doc, user_openai_api_key=None):
|
||||
logger.info(f"Creating vector for document")
|
||||
logger.info(f"Document: {doc}")
|
||||
if user_openai_api_key:
|
||||
self.commons['documents_vector_store']._embedding = OpenAIEmbeddings(openai_api_key=user_openai_api_key)
|
||||
self.commons["documents_vector_store"]._embedding = OpenAIEmbeddings(
|
||||
openai_api_key=user_openai_api_key
|
||||
)
|
||||
try:
|
||||
sids = self.commons['documents_vector_store'].add_documents([doc])
|
||||
sids = self.commons["documents_vector_store"].add_documents([doc])
|
||||
if sids and len(sids) > 0:
|
||||
self.commons['supabase'].table("vectors").update({"user_id": user_id}).match({"id": sids[0]}).execute()
|
||||
self.commons["supabase"].table("vectors").update(
|
||||
{"user_id": user_id}
|
||||
).match({"id": sids[0]}).execute()
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating vector for document {e}")
|
||||
|
||||
def create_embedding(self, content):
|
||||
return self.commons['embeddings'].embed_query(content)
|
||||
return self.commons["embeddings"].embed_query(content)
|
||||
|
||||
def similarity_search(self, query, table='match_summaries', top_k=5, threshold=0.5):
|
||||
def similarity_search(self, query, table="match_summaries", top_k=5, threshold=0.5):
|
||||
query_embedding = self.create_embedding(query)
|
||||
summaries = self.commons['supabase'].rpc(
|
||||
table, {'query_embedding': query_embedding,
|
||||
'match_count': top_k, 'match_threshold': threshold}
|
||||
).execute()
|
||||
summaries = (
|
||||
self.commons["supabase"]
|
||||
.rpc(
|
||||
table,
|
||||
{
|
||||
"query_embedding": query_embedding,
|
||||
"match_count": top_k,
|
||||
"match_threshold": threshold,
|
||||
},
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
return summaries.data
|
||||
|
||||
|
||||
@ -42,11 +54,10 @@ def create_summary(commons: CommonsDep, document_id, content, metadata):
|
||||
logger.info(f"Summarizing document {content[:100]}")
|
||||
summary = llm_summerize(content)
|
||||
logger.info(f"Summary: {summary}")
|
||||
metadata['document_id'] = document_id
|
||||
summary_doc_with_metadata = Document(
|
||||
page_content=summary, metadata=metadata)
|
||||
sids = commons['summaries_vector_store'].add_documents(
|
||||
[summary_doc_with_metadata])
|
||||
metadata["document_id"] = document_id
|
||||
summary_doc_with_metadata = Document(page_content=summary, metadata=metadata)
|
||||
sids = commons["summaries_vector_store"].add_documents([summary_doc_with_metadata])
|
||||
if sids and len(sids) > 0:
|
||||
commons['supabase'].table("summaries").update(
|
||||
{"document_id": document_id}).match({"id": sids[0]}).execute()
|
||||
commons["supabase"].table("summaries").update(
|
||||
{"document_id": document_id}
|
||||
).match({"id": sids[0]}).execute()
|
||||
|
@ -7,18 +7,26 @@ from supabase import Client
|
||||
|
||||
|
||||
class CustomSupabaseVectorStore(SupabaseVectorStore):
|
||||
'''A custom vector store that uses the match_vectors table instead of the vectors table.'''
|
||||
"""A custom vector store that uses the match_vectors table instead of the vectors table."""
|
||||
|
||||
user_id: str
|
||||
def __init__(self, client: Client, embedding: OpenAIEmbeddings, table_name: str, user_id: str = "none"):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: Client,
|
||||
embedding: OpenAIEmbeddings,
|
||||
table_name: str,
|
||||
user_id: str = "none",
|
||||
):
|
||||
super().__init__(client, embedding, table_name)
|
||||
self.user_id = user_id
|
||||
|
||||
|
||||
def similarity_search(
|
||||
self,
|
||||
query: str,
|
||||
table: str = "match_vectors",
|
||||
k: int = 6,
|
||||
threshold: float = 0.5,
|
||||
self,
|
||||
query: str,
|
||||
table: str = "match_vectors",
|
||||
k: int = 6,
|
||||
threshold: float = 0.5,
|
||||
**kwargs: Any
|
||||
) -> List[Document]:
|
||||
vectors = self._embedding.embed_documents([query])
|
||||
@ -46,4 +54,4 @@ class CustomSupabaseVectorStore(SupabaseVectorStore):
|
||||
|
||||
documents = [doc for doc, _ in match_result]
|
||||
|
||||
return documents
|
||||
return documents
|
||||
|
51
docs/docs/backend/api/chat.md
Normal file
51
docs/docs/backend/api/chat.md
Normal file
@ -0,0 +1,51 @@
|
||||
---
|
||||
sidebar_position: 2
|
||||
---
|
||||
|
||||
# Chat system
|
||||
|
||||
**URL**: https://api.quivr.app/chat
|
||||
|
||||
**Swagger**: https://api.quivr.app/docs
|
||||
|
||||
## Overview
|
||||
|
||||
Users can create multiple chat sessions, each with its own set of chat messages. The application provides endpoints to perform various operations such as retrieving all chats for the current user, deleting specific chats, updating chat attributes, creating new chats with initial messages, adding new questions to existing chats, and retrieving the chat history. These features enable users to communicate and interact with their data in a conversational manner.
|
||||
|
||||
1. **Retrieve all chats for the current user:**
|
||||
|
||||
- HTTP method: GET
|
||||
- Endpoint: `/chat`
|
||||
- Description: This endpoint retrieves all the chats associated with the current authenticated user. It returns a list of chat objects containing the chat ID and chat name for each chat.
|
||||
|
||||
2. **Delete a specific chat by chat ID:**
|
||||
|
||||
- HTTP method: DELETE
|
||||
- Endpoint: `/chat/{chat_id}`
|
||||
- Description: This endpoint allows deleting a specific chat identified by its chat ID.
|
||||
|
||||
3. **Update chat attributes:**
|
||||
|
||||
- HTTP method: PUT
|
||||
- Endpoint: `/chat/{chat_id}/metadata`
|
||||
- Description: This endpoint is used to update the attributes of a chat, such as the chat name.
|
||||
|
||||
4. **Create a new chat with initial chat messages:**
|
||||
|
||||
- HTTP method: POST
|
||||
- Endpoint: `/chat`
|
||||
- Description: This endpoint creates a new chat with initial chat messages. It expects the chat name in the request payload.
|
||||
|
||||
5. **Add a new question to a chat:**
|
||||
|
||||
- HTTP method: POST
|
||||
- Endpoint: `/chat/{chat_id}/question`
|
||||
- Description: This endpoint allows adding a new question to a chat. It generates an answer for the question using different models based on the provided model type.
|
||||
|
||||
Models like gpt-4-0613 and gpt-3.5-turbo-0613 use a custom OpenAI function-based answer generator.
|
||||
![Function based answer generator](../../../static/img/answer_schema.png)
|
||||
|
||||
6. **Get the chat history:**
|
||||
- HTTP method: GET
|
||||
- Endpoint: `/chat/{chat_id}/history`
|
||||
- Description: This endpoint retrieves the chat history for a specific chat identified by its chat ID.
|
BIN
docs/static/img/answer_schema.png
vendored
Normal file
BIN
docs/static/img/answer_schema.png
vendored
Normal file
Binary file not shown.
After Width: | Height: | Size: 2.0 MiB |
48
frontend/app/chat/[chatId]/context/ChatContext.tsx
Normal file
48
frontend/app/chat/[chatId]/context/ChatContext.tsx
Normal file
@ -0,0 +1,48 @@
|
||||
"use client";
|
||||
|
||||
import { createContext, useContext, useState } from "react";
|
||||
|
||||
import { ChatHistory } from "../types";
|
||||
|
||||
type ChatContextProps = {
|
||||
history: ChatHistory[];
|
||||
setHistory: (history: ChatHistory[]) => void;
|
||||
addToHistory: (message: ChatHistory) => void;
|
||||
};
|
||||
|
||||
export const ChatContext = createContext<ChatContextProps | undefined>(
|
||||
undefined
|
||||
);
|
||||
|
||||
export const ChatProvider = ({
|
||||
children,
|
||||
}: {
|
||||
children: JSX.Element | JSX.Element[];
|
||||
}): JSX.Element => {
|
||||
const [history, setHistory] = useState<ChatHistory[]>([]);
|
||||
const addToHistory = (message: ChatHistory) => {
|
||||
setHistory((prevHistory) => [...prevHistory, message]);
|
||||
};
|
||||
|
||||
return (
|
||||
<ChatContext.Provider
|
||||
value={{
|
||||
history,
|
||||
setHistory,
|
||||
addToHistory,
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
</ChatContext.Provider>
|
||||
);
|
||||
};
|
||||
|
||||
export const useChatContext = (): ChatContextProps => {
|
||||
const context = useContext(ChatContext);
|
||||
|
||||
if (context === undefined) {
|
||||
throw new Error("useChatContext must be used inside ChatProvider");
|
||||
}
|
||||
|
||||
return context;
|
||||
};
|
89
frontend/app/chat/[chatId]/hooks/useChat.ts
Normal file
89
frontend/app/chat/[chatId]/hooks/useChat.ts
Normal file
@ -0,0 +1,89 @@
|
||||
import { AxiosError } from "axios";
|
||||
import { useParams } from "next/navigation";
|
||||
import { useEffect, useState } from "react";
|
||||
|
||||
import { useBrainConfig } from "@/lib/context/BrainConfigProvider/hooks/useBrainConfig";
|
||||
import { useToast } from "@/lib/hooks";
|
||||
|
||||
import { useChatService } from "./useChatService";
|
||||
import { useChatContext } from "../context/ChatContext";
|
||||
import { ChatQuestion } from "../types";
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types
|
||||
export const useChat = () => {
|
||||
const params = useParams();
|
||||
const [chatId, setChatId] = useState<string | undefined>(
|
||||
params?.chatId as string | undefined
|
||||
);
|
||||
const [generatingAnswer, setGeneratingAnswer] = useState(false);
|
||||
const {
|
||||
config: { maxTokens, model, temperature },
|
||||
} = useBrainConfig();
|
||||
const { history, setHistory, addToHistory } = useChatContext();
|
||||
const { publish } = useToast();
|
||||
|
||||
const {
|
||||
createChat,
|
||||
getChatHistory,
|
||||
addQuestion: addQuestionToChat,
|
||||
} = useChatService();
|
||||
|
||||
useEffect(() => {
|
||||
const fetchHistory = async () => {
|
||||
const chatHistory = await getChatHistory(chatId);
|
||||
setHistory(chatHistory);
|
||||
};
|
||||
void fetchHistory();
|
||||
}, [chatId]);
|
||||
|
||||
const generateNewChatIdFromName = async (
|
||||
chatName: string
|
||||
): Promise<string> => {
|
||||
const rep = await createChat({ name: chatName });
|
||||
setChatId(rep.data.chat_id);
|
||||
|
||||
return rep.data.chat_id;
|
||||
};
|
||||
|
||||
const addQuestion = async (question: string, callback?: () => void) => {
|
||||
const chatQuestion: ChatQuestion = {
|
||||
model,
|
||||
question,
|
||||
temperature,
|
||||
max_tokens: maxTokens,
|
||||
};
|
||||
|
||||
try {
|
||||
setGeneratingAnswer(true);
|
||||
const currentChatId =
|
||||
chatId ??
|
||||
// if chatId is undefined, we need to create a new chat on fly
|
||||
(await generateNewChatIdFromName(
|
||||
question.split(" ").slice(0, 3).join(" ")
|
||||
));
|
||||
const answer = await addQuestionToChat(currentChatId, chatQuestion);
|
||||
addToHistory(answer);
|
||||
callback?.();
|
||||
} catch (error) {
|
||||
console.error({ error });
|
||||
|
||||
if ((error as AxiosError).response?.status === 429) {
|
||||
publish({
|
||||
variant: "danger",
|
||||
text: "You have reached the limit of requests, please try again later",
|
||||
});
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
publish({
|
||||
variant: "danger",
|
||||
text: "Error occurred while getting answer",
|
||||
});
|
||||
} finally {
|
||||
setGeneratingAnswer(false);
|
||||
}
|
||||
};
|
||||
|
||||
return { history, addQuestion, generatingAnswer };
|
||||
};
|
40
frontend/app/chat/[chatId]/hooks/useChatService.ts
Normal file
40
frontend/app/chat/[chatId]/hooks/useChatService.ts
Normal file
@ -0,0 +1,40 @@
|
||||
import { useAxios } from "@/lib/hooks";
|
||||
|
||||
import { ChatEntity, ChatHistory, ChatQuestion } from "../types";
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types
|
||||
export const useChatService = () => {
|
||||
const { axiosInstance } = useAxios();
|
||||
|
||||
const createChat = async ({ name }: { name: string }) => {
|
||||
return axiosInstance.post<ChatEntity>(`/chat`, { name });
|
||||
};
|
||||
|
||||
const getChatHistory = async (chatId: string | undefined) => {
|
||||
if (chatId === undefined) {
|
||||
return [];
|
||||
}
|
||||
const rep = (
|
||||
await axiosInstance.get<ChatHistory[]>(`/chat/${chatId}/history`)
|
||||
).data;
|
||||
|
||||
return rep;
|
||||
};
|
||||
const addQuestion = async (
|
||||
chatId: string,
|
||||
chatQuestion: ChatQuestion
|
||||
): Promise<ChatHistory> => {
|
||||
return (
|
||||
await axiosInstance.post<ChatHistory>(
|
||||
`/chat/${chatId}/question`,
|
||||
chatQuestion
|
||||
)
|
||||
).data;
|
||||
};
|
||||
|
||||
return {
|
||||
createChat,
|
||||
getChatHistory,
|
||||
addQuestion,
|
||||
};
|
||||
};
|
@ -1,31 +1,12 @@
|
||||
/* eslint-disable */
|
||||
"use client";
|
||||
import { UUID } from "crypto";
|
||||
import { useEffect } from "react";
|
||||
|
||||
import PageHeading from "@/lib/components/ui/PageHeading";
|
||||
import useChatsContext from "@/lib/context/ChatsProvider/hooks/useChatsContext";
|
||||
|
||||
import { ChatInput, ChatMessages } from "../components";
|
||||
import { ChatProvider } from "./context/ChatContext";
|
||||
|
||||
interface ChatPageProps {
|
||||
params: {
|
||||
chatId: UUID;
|
||||
};
|
||||
}
|
||||
|
||||
export default function ChatPage({ params }: ChatPageProps) {
|
||||
const chatId: UUID | undefined = params.chatId;
|
||||
|
||||
const { fetchChat, resetChat } = useChatsContext();
|
||||
|
||||
useEffect(() => {
|
||||
if (!chatId) {
|
||||
resetChat();
|
||||
}
|
||||
fetchChat(chatId);
|
||||
}, []);
|
||||
|
||||
export default function ChatPage() {
|
||||
return (
|
||||
<main className="flex flex-col w-full pt-10">
|
||||
<section className="flex flex-col flex-1 items-center w-full h-full min-h-screen">
|
||||
@ -33,12 +14,14 @@ export default function ChatPage({ params }: ChatPageProps) {
|
||||
title="Chat with your brain"
|
||||
subtitle="Talk to a language model about your uploaded data"
|
||||
/>
|
||||
<div className="relative h-full w-full flex flex-col flex-1 items-center">
|
||||
<div className="h-full flex-1 w-full flex flex-col items-center">
|
||||
<ChatMessages />
|
||||
<ChatProvider>
|
||||
<div className="relative h-full w-full flex flex-col flex-1 items-center">
|
||||
<div className="h-full flex-1 w-full flex flex-col items-center">
|
||||
<ChatMessages />
|
||||
</div>
|
||||
<ChatInput />
|
||||
</div>
|
||||
<ChatInput />
|
||||
</div>
|
||||
</ChatProvider>
|
||||
</section>
|
||||
</main>
|
||||
);
|
||||
|
22
frontend/app/chat/[chatId]/types/index.ts
Normal file
22
frontend/app/chat/[chatId]/types/index.ts
Normal file
@ -0,0 +1,22 @@
|
||||
import { UUID } from "crypto";
|
||||
|
||||
export type ChatQuestion = {
|
||||
model: string;
|
||||
question?: string;
|
||||
temperature: number;
|
||||
max_tokens: number;
|
||||
};
|
||||
export type ChatHistory = {
|
||||
chat_id: string;
|
||||
message_id: string;
|
||||
user_message: string;
|
||||
assistant: string;
|
||||
message_time: string;
|
||||
};
|
||||
|
||||
export type ChatEntity = {
|
||||
chat_id: UUID;
|
||||
user_id: string;
|
||||
creation_time: string;
|
||||
chat_name: string;
|
||||
};
|
@ -5,8 +5,14 @@ import { MdMic, MdMicOff } from "react-icons/md";
|
||||
import Button from "@/lib/components/ui/Button";
|
||||
import { useSpeech } from "@/lib/context/ChatsProvider/hooks/useSpeech";
|
||||
|
||||
export const MicButton = (): JSX.Element => {
|
||||
const { isListening, speechSupported, startListening } = useSpeech();
|
||||
type MicButtonProps = {
|
||||
setMessage: (newValue: string | ((prevValue: string) => string)) => void;
|
||||
};
|
||||
|
||||
export const MicButton = ({ setMessage }: MicButtonProps): JSX.Element => {
|
||||
const { isListening, speechSupported, startListening } = useSpeech({
|
||||
setMessage,
|
||||
});
|
||||
|
||||
return (
|
||||
<Button
|
||||
|
@ -1,35 +1,41 @@
|
||||
/* eslint-disable */
|
||||
"use client";
|
||||
import Button from "@/lib/components/ui/Button";
|
||||
import useChatsContext from "@/lib/context/ChatsProvider/hooks/useChatsContext";
|
||||
|
||||
import { useChat } from "@/app/chat/[chatId]/hooks/useChat";
|
||||
import { useState } from "react";
|
||||
import { ConfigButton } from "./ConfigButton";
|
||||
import { MicButton } from "./MicButton";
|
||||
|
||||
export const ChatInput = (): JSX.Element => {
|
||||
const { isSendingMessage, sendMessage, setMessage, message, chat } =
|
||||
useChatsContext();
|
||||
const [message, setMessage] = useState<string>(""); // for optimistic updates
|
||||
const { addQuestion, generatingAnswer } = useChat();
|
||||
|
||||
const submitQuestion = () => {
|
||||
addQuestion(message, () => setMessage(""));
|
||||
};
|
||||
return (
|
||||
<form
|
||||
onSubmit={(e) => {
|
||||
e.preventDefault();
|
||||
if (!isSendingMessage) {
|
||||
sendMessage(chat?.chatId);
|
||||
if (!generatingAnswer) {
|
||||
submitQuestion();
|
||||
}
|
||||
}}
|
||||
className="sticky bottom-0 p-5 bg-white dark:bg-black rounded-t-md border border-black/10 dark:border-white/25 border-b-0 w-full max-w-3xl flex items-center justify-center gap-2 z-20"
|
||||
>
|
||||
<textarea
|
||||
autoFocus
|
||||
value={message[1]}
|
||||
onChange={(e) => setMessage((msg) => [msg[0], e.target.value])}
|
||||
value={message}
|
||||
required
|
||||
onChange={(e) => setMessage(e.target.value)}
|
||||
onKeyDown={(e) => {
|
||||
if (message.length === 0) return;
|
||||
if (e.key === "Enter" && !e.shiftKey) {
|
||||
e.preventDefault(); // Prevents the newline from being entered in the textarea
|
||||
if (!isSendingMessage) {
|
||||
sendMessage(chat?.chatId);
|
||||
} // Call the submit function here
|
||||
if (!generatingAnswer) {
|
||||
submitQuestion();
|
||||
}
|
||||
}
|
||||
}}
|
||||
className="w-full p-2 border border-gray-300 dark:border-gray-500 outline-none rounded dark:bg-gray-800"
|
||||
@ -38,12 +44,12 @@ export const ChatInput = (): JSX.Element => {
|
||||
<Button
|
||||
className="px-3 py-2 sm:px-4 sm:py-2"
|
||||
type="submit"
|
||||
isLoading={isSendingMessage}
|
||||
isLoading={generatingAnswer}
|
||||
>
|
||||
{isSendingMessage ? "Thinking..." : "Chat"}
|
||||
{generatingAnswer ? "Thinking..." : "Chat"}
|
||||
</Button>
|
||||
<div className="flex items-center">
|
||||
<MicButton />
|
||||
<MicButton setMessage={setMessage} />
|
||||
<ConfigButton />
|
||||
</div>
|
||||
</form>
|
||||
|
@ -1,52 +1,52 @@
|
||||
/* eslint-disable */
|
||||
"use client";
|
||||
import { useEffect, useRef } from "react";
|
||||
import { useCallback, useEffect, useRef } from "react";
|
||||
|
||||
import Card from "@/lib/components/ui/Card";
|
||||
import useChatsContext from "@/lib/context/ChatsProvider/hooks/useChatsContext";
|
||||
import { useChat } from "../../[chatId]/hooks/useChat";
|
||||
import { ChatMessage } from "./ChatMessage";
|
||||
|
||||
export const ChatMessages = (): JSX.Element => {
|
||||
const lastChatRef = useRef<HTMLDivElement | null>(null);
|
||||
const { history } = useChat();
|
||||
|
||||
const { chat } = useChatsContext();
|
||||
const scrollToBottom = useCallback(() => {
|
||||
if (lastChatRef.current) {
|
||||
lastChatRef.current.scrollIntoView({
|
||||
behavior: "smooth",
|
||||
block: "start",
|
||||
});
|
||||
}
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (!chat || !lastChatRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
// if (chat.history.length > 2) {
|
||||
lastChatRef.current.scrollIntoView({
|
||||
behavior: "smooth",
|
||||
block: "end",
|
||||
});
|
||||
// }
|
||||
}, [chat, lastChatRef]);
|
||||
|
||||
if (!chat) {
|
||||
return <></>;
|
||||
}
|
||||
scrollToBottom();
|
||||
}, [history, scrollToBottom]);
|
||||
|
||||
return (
|
||||
<Card className="p-5 max-w-3xl w-full flex flex-col h-full mb-8">
|
||||
<div className="flex-1">
|
||||
{chat.history.length === 0 ? (
|
||||
{history.length === 0 ? (
|
||||
<div className="text-center opacity-50">
|
||||
Ask a question, or describe a task.
|
||||
</div>
|
||||
) : (
|
||||
chat.history.map(([speaker, text], idx) => {
|
||||
return (
|
||||
history.map(({ assistant, message_id, user_message }, idx) => (
|
||||
<>
|
||||
<ChatMessage
|
||||
ref={idx === chat.history.length - 1 ? lastChatRef : null}
|
||||
key={idx}
|
||||
speaker={speaker}
|
||||
text={text}
|
||||
key={message_id}
|
||||
speaker={"user"}
|
||||
text={user_message}
|
||||
/>
|
||||
);
|
||||
})
|
||||
<ChatMessage
|
||||
key={message_id}
|
||||
speaker={"assistant"}
|
||||
text={assistant}
|
||||
/>
|
||||
</>
|
||||
))
|
||||
)}
|
||||
<div ref={lastChatRef} />
|
||||
</div>
|
||||
</Card>
|
||||
);
|
||||
|
@ -5,15 +5,14 @@ import { useState } from "react";
|
||||
import { FiEdit, FiSave, FiTrash2 } from "react-icons/fi";
|
||||
import { MdChatBubbleOutline } from "react-icons/md";
|
||||
|
||||
|
||||
import { ChatEntity } from "@/app/chat/[chatId]/types";
|
||||
import { useAxios, useToast } from "@/lib/hooks";
|
||||
import { Chat, ChatResponse } from "@/lib/types/Chat";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
import { ChatName } from "./components/ChatName";
|
||||
|
||||
interface ChatsListItemProps {
|
||||
chat: Chat;
|
||||
chat: ChatEntity;
|
||||
deleteChat: (id: UUID) => void;
|
||||
}
|
||||
|
||||
@ -21,32 +20,31 @@ export const ChatsListItem = ({
|
||||
chat,
|
||||
deleteChat,
|
||||
}: ChatsListItemProps): JSX.Element => {
|
||||
console.log({ chat });
|
||||
const pathname = usePathname()?.split("/").at(-1);
|
||||
const selected = chat.chatId === pathname;
|
||||
const [chatName, setChatName] = useState(chat.chatName);
|
||||
const selected = chat.chat_id === pathname;
|
||||
const [chatName, setChatName] = useState(chat.chat_name);
|
||||
const { axiosInstance } = useAxios();
|
||||
const {publish} = useToast()
|
||||
const { publish } = useToast();
|
||||
const [editingName, setEditingName] = useState(false);
|
||||
|
||||
const updateChatName = async () => {
|
||||
if(chatName !== chat.chatName) {
|
||||
await axiosInstance.put<ChatResponse>(`/chat/${chat.chatId}/metadata`, {
|
||||
chat_name:chatName,
|
||||
|
||||
});
|
||||
publish({text:'Chat name updated',variant:'success'})
|
||||
const updateChatName = async () => {
|
||||
if (chatName !== chat.chat_name) {
|
||||
await axiosInstance.put<ChatEntity>(`/chat/${chat.chat_id}/metadata`, {
|
||||
chat_name: chatName,
|
||||
});
|
||||
publish({ text: "Chat name updated", variant: "success" });
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const handleEditNameClick = () => {
|
||||
if(editingName){
|
||||
setEditingName(false) ;
|
||||
void updateChatName()
|
||||
if (editingName) {
|
||||
setEditingName(false);
|
||||
void updateChatName();
|
||||
} else {
|
||||
setEditingName(true);
|
||||
}
|
||||
else {
|
||||
setEditingName(true)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div
|
||||
@ -59,35 +57,32 @@ export const ChatsListItem = ({
|
||||
>
|
||||
<Link
|
||||
className="flex flex-col flex-1 min-w-0 p-4"
|
||||
href={`/chat/${chat.chatId}`}
|
||||
key={chat.chatId}
|
||||
href={`/chat/${chat.chat_id}`}
|
||||
key={chat.chat_id}
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
|
||||
<MdChatBubbleOutline className="text-xl" />
|
||||
<ChatName setName={setChatName} editing={editingName} name={chatName} />
|
||||
<ChatName
|
||||
setName={setChatName}
|
||||
editing={editingName}
|
||||
name={chatName}
|
||||
/>
|
||||
</div>
|
||||
<div className="grid-cols-2 text-xs opacity-50 whitespace-nowrap">
|
||||
{chat.chatId}
|
||||
{chat.chat_id}
|
||||
</div>
|
||||
</Link>
|
||||
<div className="opacity-0 group-hover:opacity-100 flex items-center justify-center hover:text-red-700 bg-gradient-to-l from-white dark:from-black to-transparent z-10 transition-opacity">
|
||||
<button
|
||||
className="p-0"
|
||||
type="button"
|
||||
onClick={handleEditNameClick
|
||||
}
|
||||
>
|
||||
{editingName ? <FiSave/> : <FiEdit />}
|
||||
<button className="p-0" type="button" onClick={handleEditNameClick}>
|
||||
{editingName ? <FiSave /> : <FiEdit />}
|
||||
</button>
|
||||
<button
|
||||
className="p-5"
|
||||
type="button"
|
||||
onClick={() => deleteChat(chat.chatId)}
|
||||
onClick={() => deleteChat(chat.chat_id)}
|
||||
>
|
||||
<FiTrash2 />
|
||||
</button>
|
||||
|
||||
</div>
|
||||
|
||||
{/* Fade to white */}
|
||||
|
@ -1,18 +1,23 @@
|
||||
|
||||
|
||||
interface ChatNameProps {
|
||||
name: string;
|
||||
editing?: boolean;
|
||||
setName: (name:string) => void;
|
||||
setName: (name: string) => void;
|
||||
}
|
||||
|
||||
export const ChatName = ({setName,name,editing=false}:ChatNameProps):JSX.Element => {
|
||||
|
||||
if(editing) {
|
||||
return <input onChange={(event) => setName(event.target.value)} autoFocus value={name} />
|
||||
}
|
||||
|
||||
export const ChatName = ({
|
||||
setName,
|
||||
name,
|
||||
editing = false,
|
||||
}: ChatNameProps): JSX.Element => {
|
||||
if (editing) {
|
||||
return (
|
||||
<p>{name}</p>
|
||||
)
|
||||
}
|
||||
<input
|
||||
onChange={(event) => setName(event.target.value)}
|
||||
autoFocus
|
||||
value={name}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return <p>{name}</p>;
|
||||
};
|
||||
|
@ -47,7 +47,7 @@ export const ChatsList = (): JSX.Element => {
|
||||
<div className="flex flex-col gap-0">
|
||||
{allChats.map((chat) => (
|
||||
<ChatsListItem
|
||||
key={chat.chatId}
|
||||
key={chat.chat_id}
|
||||
chat={chat}
|
||||
deleteChat={deleteChat}
|
||||
/>
|
||||
|
@ -3,22 +3,16 @@ import { UUID } from "crypto";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useEffect, useState } from "react";
|
||||
|
||||
import { useBrainConfig } from "@/lib/context/BrainConfigProvider/hooks/useBrainConfig";
|
||||
import { useAxios } from "@/lib/hooks";
|
||||
import { useToast } from "@/lib/hooks/useToast";
|
||||
|
||||
import { Chat, ChatMessage } from "../../../types/Chat";
|
||||
import { ChatEntity } from "@/app/chat/[chatId]/types";
|
||||
|
||||
export default function useChats() {
|
||||
const [allChats, setAllChats] = useState<Chat[]>([]);
|
||||
const [chat, setChat] = useState<Chat | null>(null);
|
||||
const [isSendingMessage, setIsSendingMessage] = useState(false);
|
||||
const [message, setMessage] = useState<ChatMessage>(["", ""]); // for optimistic updates
|
||||
const [allChats, setAllChats] = useState<ChatEntity[]>([]);
|
||||
|
||||
const { axiosInstance } = useAxios();
|
||||
const {
|
||||
config: { maxTokens, model, temperature },
|
||||
} = useBrainConfig();
|
||||
|
||||
const router = useRouter();
|
||||
const { publish } = useToast();
|
||||
|
||||
@ -26,7 +20,7 @@ export default function useChats() {
|
||||
try {
|
||||
console.log("Fetching all chats");
|
||||
const response = await axiosInstance.get<{
|
||||
chats: Chat[];
|
||||
chats: ChatEntity[];
|
||||
}>(`/chat`);
|
||||
setAllChats(response.data.chats);
|
||||
console.log("Fetched all chats");
|
||||
@ -39,102 +33,10 @@ export default function useChats() {
|
||||
}
|
||||
};
|
||||
|
||||
const fetchChat = async (chatId?: UUID) => {
|
||||
if (!chatId) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
console.log(`Fetching chat ${chatId}`);
|
||||
const response = await axiosInstance.get<Chat>(`/chat/${chatId}`);
|
||||
console.log(response.data);
|
||||
|
||||
setChat(response.data);
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
publish({
|
||||
variant: "danger",
|
||||
text: `Error occured while fetching ${chatId}`,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
type ChatResponse = Omit<Chat, "chatId"> & { chatId: UUID | undefined };
|
||||
|
||||
const createChat = async ({
|
||||
options,
|
||||
}: {
|
||||
options: Record<string, string | unknown>;
|
||||
}) => {
|
||||
await fetchAllChats();
|
||||
|
||||
return axiosInstance.post<ChatResponse>(`/chat`, options);
|
||||
};
|
||||
|
||||
const updateChat = ({
|
||||
options,
|
||||
}: {
|
||||
options: Record<string, string | unknown>;
|
||||
}) => {
|
||||
return axiosInstance.put<ChatResponse>(`/chat/${options.chat_id}`, options);
|
||||
};
|
||||
|
||||
const sendMessage = async (chatId?: UUID, msg?: ChatMessage) => {
|
||||
setIsSendingMessage(true);
|
||||
|
||||
if (msg) {
|
||||
setMessage(msg);
|
||||
}
|
||||
const options: Record<string, unknown> = {
|
||||
chat_id: chatId,
|
||||
model,
|
||||
question: msg ? msg[1] : message[1],
|
||||
history: chat ? chat.history : [],
|
||||
temperature,
|
||||
max_tokens: maxTokens,
|
||||
use_summarization: false,
|
||||
};
|
||||
|
||||
const response = await (chatId !== undefined
|
||||
? updateChat({ options })
|
||||
: createChat({ options }));
|
||||
|
||||
// response.data.chatId can be undefined when the max number of requests has reached
|
||||
if (!response.data.chatId) {
|
||||
publish({
|
||||
text: "You have reached max number of requests.",
|
||||
variant: "danger",
|
||||
});
|
||||
setMessage(["", ""]);
|
||||
setIsSendingMessage(false);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const newChat = {
|
||||
chatId: response.data.chatId,
|
||||
history: response.data.history,
|
||||
chatName: response.data.chatName,
|
||||
};
|
||||
if (!chatId) {
|
||||
// Creating a new chat
|
||||
console.log("---- Creating a new chat ----");
|
||||
setAllChats((chats) => {
|
||||
console.log({ chats });
|
||||
|
||||
return [...chats, newChat];
|
||||
});
|
||||
setChat(newChat);
|
||||
router.push(`/chat/${response.data.chatId}`);
|
||||
}
|
||||
setChat(newChat);
|
||||
setMessage(["", ""]);
|
||||
setIsSendingMessage(false);
|
||||
};
|
||||
|
||||
const deleteChat = async (chatId: UUID) => {
|
||||
try {
|
||||
await axiosInstance.delete(`/chat/${chatId}`);
|
||||
setAllChats((chats) => chats.filter((chat) => chat.chatId !== chatId));
|
||||
setAllChats((chats) => chats.filter((chat) => chat.chat_id !== chatId));
|
||||
// TODO: Change route only when the current chat is being deleted
|
||||
router.push("/chat");
|
||||
publish({
|
||||
@ -147,26 +49,12 @@ export default function useChats() {
|
||||
}
|
||||
};
|
||||
|
||||
const resetChat = async () => {
|
||||
setChat(null);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
fetchAllChats();
|
||||
}, []);
|
||||
|
||||
return {
|
||||
allChats,
|
||||
chat,
|
||||
isSendingMessage,
|
||||
message,
|
||||
setMessage,
|
||||
resetChat,
|
||||
|
||||
fetchAllChats,
|
||||
fetchChat,
|
||||
|
||||
deleteChat,
|
||||
sendMessage,
|
||||
};
|
||||
}
|
||||
|
@ -3,14 +3,14 @@ import { useEffect, useState } from "react";
|
||||
|
||||
import { isSpeechRecognitionSupported } from "@/lib/helpers/isSpeechRecognitionSupported";
|
||||
|
||||
import useChatsContext from "./useChatsContext";
|
||||
type useSpeechProps = {
|
||||
setMessage: (newValue: string | ((prevValue: string) => string)) => void;
|
||||
};
|
||||
|
||||
export const useSpeech = () => {
|
||||
export const useSpeech = ({ setMessage }: useSpeechProps) => {
|
||||
const [isListening, setIsListening] = useState(false);
|
||||
const [speechSupported, setSpeechSupported] = useState(false);
|
||||
|
||||
const { setMessage } = useChatsContext();
|
||||
|
||||
useEffect(() => {
|
||||
if (isSpeechRecognitionSupported()) {
|
||||
setSpeechSupported(true);
|
||||
@ -39,7 +39,7 @@ export const useSpeech = () => {
|
||||
mic.onresult = (event: SpeechRecognitionEvent) => {
|
||||
const interimTranscript =
|
||||
event.results[event.results.length - 1][0].transcript;
|
||||
setMessage((prevMessage) => ["user", prevMessage + interimTranscript]);
|
||||
setMessage((prevMessage) => prevMessage + interimTranscript);
|
||||
};
|
||||
|
||||
if (isListening) {
|
||||
|
@ -1,14 +0,0 @@
|
||||
import { UUID } from "crypto";
|
||||
|
||||
export interface Chat {
|
||||
chatId: UUID;
|
||||
chatName: string;
|
||||
history: ChatHistory;
|
||||
}
|
||||
export type ChatMessage = [string, string];
|
||||
|
||||
export type ChatHistory = ChatMessage[];
|
||||
|
||||
export type ChatResponse = Omit<Chat, "chatId"> & {
|
||||
chatId: UUID | undefined;
|
||||
};
|
@ -15,6 +15,16 @@ CREATE TABLE IF NOT EXISTS chats(
|
||||
chat_name TEXT
|
||||
);
|
||||
|
||||
-- Create chat_history table
|
||||
CREATE TABLE IF NOT EXISTS chat_history (
|
||||
message_id UUID DEFAULT uuid_generate_v4(),
|
||||
chat_id UUID REFERENCES chats(chat_id),
|
||||
user_message TEXT,
|
||||
assistant TEXT,
|
||||
message_time TIMESTAMP DEFAULT current_timestamp,
|
||||
PRIMARY KEY (chat_id, message_id)
|
||||
);
|
||||
|
||||
-- Create vector extension
|
||||
CREATE EXTENSION IF NOT EXISTS vector;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user