feat(chat): use openai function for answer (#354)

* feat(chat): use openai function for answer (backend)

* feat(chat): use openai function for answer (frontend)

* chore: refactor BrainPicking

* feat: update chat creation logic

* feat: simplify chat system logic

* feat: set default model to gpt-3.5-turbo-0613

* feat: use user own openai key

* feat(chat): slightly improve prompts

* feat: add global error interceptor

* feat: remove unused endpoints

* docs: update chat system doc

* chore(linter): add unused import remove config

* feat: improve dx

* feat: improve OpenAiFunctionBasedAnswerGenerator prompt
Mamadou DICKO 2023-06-22 17:50:06 +02:00 committed by GitHub
parent 83fde0aeea
commit 59fe7b089b
44 changed files with 1167 additions and 520 deletions

View File

@ -2,7 +2,8 @@
"python.formatting.provider": "black",
"editor.codeActionsOnSave": {
"source.organizeImports": true,
"source.fixAll": true
"source.fixAll": true,
"source.unusedImports": true
},
"python.linting.enabled": true,
"python.linting.flake8Enabled": true,

View File

@ -0,0 +1,245 @@
from typing import Optional
from typing import Any, Dict, List # For type hinting
from langchain.chat_models import ChatOpenAI
from repository.chat.get_chat_history import get_chat_history
from .utils.format_answer import format_answer
# Importing various modules and classes from a custom library 'langchain' likely used for natural language processing
from langchain.embeddings.openai import OpenAIEmbeddings
from models.settings import BrainSettings # Importing settings related to the 'brain'
from llm.OpenAiFunctionBasedAnswerGenerator.models.OpenAiAnswer import OpenAiAnswer
from supabase import Client, create_client # For interacting with Supabase database
from vectorstore.supabase import (
CustomSupabaseVectorStore,
) # Custom class for handling vector storage with Supabase
from logger import get_logger
logger = get_logger(__name__)
get_context_function_name = "get_context"
prompt_template = """Your name is Quivr. You are a second brain.
A person will ask you a question and you will provide a helpful answer.
Write the answer in the same language as the question.
If you don't know the answer, just say that you don't know. Don't try to make up an answer.
Your main goal is to answer questions about user uploaded documents. Unless basic questions or greetings, you should always refer to user uploaded documents by fetching them with the {} function.""".format(
get_context_function_name
)
get_history_schema = {
"name": "get_history",
"description": "Get current chat previous messages",
"parameters": {
"type": "object",
"properties": {},
},
}
get_context_schema = {
"name": get_context_function_name,
"description": "A function which returns user uploaded documents and which must be used when you don't now the answer to a question or when the question seems to refer to user uploaded documents",
"parameters": {
"type": "object",
"properties": {},
},
}
class OpenAiFunctionBasedAnswerGenerator:
# Default class attributes
model: str = "gpt-3.5-turbo-0613"
temperature: float = 0.0
max_tokens: int = 256
chat_id: str
supabase_client: Client = None
embeddings: OpenAIEmbeddings = None
settings = BrainSettings()
openai_client: ChatOpenAI = None
user_email: str
def __init__(
self,
model: str,
chat_id: str,
temperature: float,
max_tokens: int,
user_email: str,
user_openai_api_key: str,
) -> "OpenAiFunctionBasedAnswerGenerator":
self.model = model
self.temperature = temperature
self.max_tokens = max_tokens
self.chat_id = chat_id
self.supabase_client = create_client(
self.settings.supabase_url, self.settings.supabase_service_key
)
self.user_email = user_email
if user_openai_api_key is not None:
self.settings.openai_api_key = user_openai_api_key
self.embeddings = OpenAIEmbeddings(openai_api_key=self.settings.openai_api_key)
self.openai_client = ChatOpenAI(openai_api_key=self.settings.openai_api_key)
def _get_model_response(
self,
messages: list[dict[str, str]] = [],
functions: list[dict[str, Any]] = None,
):
if functions is not None:
model_response = self.openai_client.completion_with_retry(
functions=functions,
messages=messages,
model=self.model,
temperature=self.temperature,
max_tokens=self.max_tokens,
)
else:
model_response = self.openai_client.completion_with_retry(
messages=messages,
model=self.model,
temperature=self.temperature,
max_tokens=self.max_tokens,
)
return model_response
def _get_formatted_history(self) -> List[Dict[str, str]]:
formatted_history = []
history = get_chat_history(self.chat_id)
for chat in history:
formatted_history.append({"role": "user", "content": chat.user_message})
formatted_history.append({"role": "assistant", "content": chat.assistant})
return formatted_history
def _get_formatted_prompt(
self,
question: Optional[str],
useContext: Optional[bool] = False,
useHistory: Optional[bool] = False,
) -> list[dict[str, str]]:
messages = [
{"role": "system", "content": prompt_template},
]
if not useHistory and not useContext:
messages.append(
{"role": "user", "content": question},
)
return messages
if useHistory:
history = self._get_formatted_history()
if len(history):
messages.append(
{
"role": "system",
"content": "Previous messages are already in chat.",
},
)
messages.extend(history)
else:
messages.append(
{
"role": "user",
"content": "This is the first message of the chat. There is no previous one",
}
)
messages.append(
{
"role": "user",
"content": f"Question: {question}\n\nAnswer:",
}
)
if useContext:
chat_context = self._get_context(question)
enhanced_question = f"Here is chat context: {chat_context if len(chat_context) else 'No document found'}"
messages.append({"role": "user", "content": enhanced_question})
messages.append(
{
"role": "user",
"content": f"Question: {question}\n\nAnswer:",
}
)
return messages
def _get_context(self, question: str) -> str:
# retrieve 5 nearest documents
vector_store = CustomSupabaseVectorStore(
self.supabase_client,
self.embeddings,
table_name="vectors",
user_id=self.user_email,
)
return vector_store.similarity_search(
query=question,
)
def _get_answer_from_question(self, question: str) -> OpenAiAnswer:
functions = [get_history_schema, get_context_schema]
model_response = self._get_model_response(
messages=self._get_formatted_prompt(question=question),
functions=functions,
)
return format_answer(model_response)
def _get_answer_from_question_and_history(self, question: str) -> OpenAiAnswer:
logger.info("Using chat history")
functions = [
get_context_schema,
]
model_response = self._get_model_response(
messages=self._get_formatted_prompt(question=question, useHistory=True),
functions=functions,
)
return format_answer(model_response)
def _get_answer_from_question_and_context(self, question: str) -> OpenAiAnswer:
logger.info("Using documents ")
functions = [
get_history_schema,
]
model_response = self._get_model_response(
messages=self._get_formatted_prompt(question=question, useContext=True),
functions=functions,
)
return format_answer(model_response)
def _get_answer_from_question_and_context_and_history(
self, question: str
) -> OpenAiAnswer:
logger.info("Using context and history")
model_response = self._get_model_response(
messages=self._get_formatted_prompt(
question, useContext=True, useHistory=True
),
)
return format_answer(model_response)
def get_answer(self, question: str) -> str:
response = self._get_answer_from_question(question)
function_name = response.function_call.name if response.function_call else None
if function_name == "get_history":
response = self._get_answer_from_question_and_history(question)
elif function_name == "get_context":
response = self._get_answer_from_question_and_context(question)
if response.function_call:
response = self._get_answer_from_question_and_context_and_history(question)
return response.content or ""
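
For orientation, here is a minimal sketch of how this generator is expected to be driven. The chat id, e-mail, and question are placeholders; the real construction happens in the chat routes further down in this diff, and the usual Quivr environment variables (Supabase URL/key, OpenAI key) are assumed to be set.

```python
# Illustrative driver only; real wiring lives in the chat routes.
from llm.OpenAiFunctionBasedAnswerGenerator.OpenAiFunctionBasedAnswerGenerator import (
    OpenAiFunctionBasedAnswerGenerator,
)

generator = OpenAiFunctionBasedAnswerGenerator(
    model="gpt-3.5-turbo-0613",
    chat_id="<existing-chat-uuid>",   # placeholder
    temperature=0.0,
    max_tokens=256,
    user_email="user@example.com",    # vectors are currently keyed by email
    user_openai_api_key=None,         # None -> fall back to the server-side key
)

# get_answer() first lets the model decide whether it needs chat history and/or
# document context, then re-prompts with that data and returns the final text.
print(generator.get_answer("What do my uploaded documents say about pricing?"))
```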

View File

@ -0,0 +1,12 @@
from typing import Optional
from typing import Any, Dict # For type hinting
class FunctionCall:
def __init__(
self,
name: Optional[str] = None,
arguments: Optional[Dict[str, Any]] = None,
):
self.name = name
self.arguments = arguments

View File

@ -0,0 +1,12 @@
from typing import Optional
from .FunctionCall import FunctionCall
class OpenAiAnswer:
def __init__(
self,
content: Optional[str] = None,
function_call: FunctionCall = None,
):
self.content = content
self.function_call = function_call

View File

@ -0,0 +1,17 @@
from llm.OpenAiFunctionBasedAnswerGenerator.models.OpenAiAnswer import OpenAiAnswer
from llm.OpenAiFunctionBasedAnswerGenerator.models.FunctionCall import FunctionCall
from typing import Any, Dict # For type hinting
def format_answer(model_response: Dict[str, Any]) -> OpenAiAnswer:
answer = model_response["choices"][0]["message"]
content = answer["content"]
function_call = None
if answer.get("function_call", None) is not None:
function_call = FunctionCall(
answer["function_call"]["name"],
answer["function_call"]["arguments"],
)
return OpenAiAnswer(content=content, function_call=function_call)
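
`format_answer` expects the raw completion payload returned by the OpenAI chat API. A small illustration with a fabricated payload follows; the absolute import path is inferred from the package layout shown above.

```python
# Fabricated payload for illustration; real payloads come from completion_with_retry.
from llm.OpenAiFunctionBasedAnswerGenerator.utils.format_answer import format_answer

model_response = {
    "choices": [
        {
            "message": {
                "content": None,  # the model chose to call a function instead of answering
                "function_call": {"name": "get_context", "arguments": "{}"},
            }
        }
    ]
}

answer = format_answer(model_response)
assert answer.content is None
assert answer.function_call.name == "get_context"
```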

View File

@ -1,32 +1,21 @@
import os # A module to interact with the OS
from typing import Any, Dict, List
from typing import Any, Dict
from models.settings import LLMSettings # For type hinting
# Importing various modules and classes from a custom library 'langchain' likely used for natural language processing
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.router.llm_router import LLMRouterChain, RouterOutputParser
from langchain.chains.router.multi_prompt_prompt import MULTI_PROMPT_ROUTER_TEMPLATE
from langchain.chat_models import ChatOpenAI, ChatVertexAI
from langchain.chat_models.anthropic import ChatAnthropic
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.chat_models import ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.llms import GPT4All
from langchain.llms.base import LLM
from langchain.memory import ConversationBufferMemory
from langchain.vectorstores import SupabaseVectorStore
from llm.prompt import LANGUAGE_PROMPT
from llm.prompt.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT
from models.chats import (
ChatMessage,
) # Importing a custom ChatMessage class for handling chat messages
from models.settings import BrainSettings # Importing settings related to the 'brain'
from pydantic import BaseModel # For data validation and settings management
from pydantic import BaseSettings
from supabase import Client # For interacting with Supabase database
from supabase import create_client
from repository.chat.get_chat_history import get_chat_history
from vectorstore.supabase import (
CustomSupabaseVectorStore,
) # Custom class for handling vector storage with Supabase
@ -34,6 +23,7 @@ from logger import get_logger
logger = get_logger(__name__)
class AnswerConversationBufferMemory(ConversationBufferMemory):
"""
This class is a specialized version of ConversationBufferMemory.
@ -48,7 +38,7 @@ class AnswerConversationBufferMemory(ConversationBufferMemory):
)
def get_chat_history(inputs) -> str:
def format_chat_history(inputs) -> str:
"""
Function to concatenate chat history into a single string.
:param inputs: List of tuples containing human and AI messages.
@ -76,18 +66,38 @@ class BrainPicking(BaseModel):
llm: LLM = None
question_generator: LLMChain = None
doc_chain: ConversationalRetrievalChain = None
chat_id: str
max_tokens: int = 256
class Config:
# Allowing arbitrary types for class validation
arbitrary_types_allowed = True
def init(self, model: str, user_id: str) -> "BrainPicking":
def __init__(
self,
model: str,
user_id: str,
chat_id: str,
max_tokens: int,
user_openai_api_key: str,
) -> "BrainPicking":
"""
Initialize the BrainPicking class by setting embeddings, supabase client, vector store, language model and chains.
:param model: Language model name to be used.
:param user_id: The user id to be used for CustomSupabaseVectorStore.
:return: BrainPicking instance
"""
super().__init__(
model=model,
user_id=user_id,
chat_id=chat_id,
max_tokens=max_tokens,
user_openai_api_key=user_openai_api_key,
)
# If user provided an API key, update the settings
if user_openai_api_key is not None:
self.settings.openai_api_key = user_openai_api_key
self.embeddings = OpenAIEmbeddings(openai_api_key=self.settings.openai_api_key)
self.supabase_client = create_client(
self.settings.supabase_url, self.settings.supabase_service_key
@ -98,7 +108,7 @@ class BrainPicking(BaseModel):
table_name="vectors",
user_id=user_id,
)
self.llm = self._determine_llm(
private_model_args={
"model_path": self.llm_config.model_path,
@ -112,7 +122,8 @@ class BrainPicking(BaseModel):
llm=self.llm, prompt=CONDENSE_QUESTION_PROMPT
)
self.doc_chain = load_qa_chain(self.llm, chain_type="stuff")
return self
self.chat_id = chat_id
self.max_tokens = max_tokens
def _determine_llm(
self, private_model_args: dict, private: bool = False, model_name: str = None
@ -128,7 +139,7 @@ class BrainPicking(BaseModel):
model_path = private_model_args["model_path"]
model_n_ctx = private_model_args["n_ctx"]
model_n_batch = private_model_args["n_batch"]
logger.info("Using private model: %s", model_path)
return GPT4All(
@ -142,7 +153,7 @@ class BrainPicking(BaseModel):
return ChatOpenAI(temperature=0, model_name=model_name)
def _get_qa(
self, chat_message: ChatMessage, user_openai_api_key
self,
) -> ConversationalRetrievalChain:
"""
Retrieves a QA chain for the given chat message and API key.
@ -150,42 +161,34 @@ class BrainPicking(BaseModel):
:param user_openai_api_key: The OpenAI API key to be used.
:return: ConversationalRetrievalChain instance
"""
# If user provided an API key, update the settings
if user_openai_api_key is not None and user_openai_api_key != "":
self.settings.openai_api_key = user_openai_api_key
# Initialize and return a ConversationalRetrievalChain
qa = ConversationalRetrievalChain(
retriever=self.vector_store.as_retriever(),
max_tokens_limit=chat_message.max_tokens,
max_tokens_limit=self.max_tokens,
question_generator=self.question_generator,
combine_docs_chain=self.doc_chain,
get_chat_history=get_chat_history,
get_chat_history=format_chat_history,
)
return qa
def generate_answer(self, chat_message: ChatMessage, user_openai_api_key) -> str:
def generate_answer(self, question: str) -> str:
"""
Generate an answer to a given chat message by interacting with the language model.
:param chat_message: The chat message containing history.
:param user_openai_api_key: The OpenAI API key to be used.
Generate an answer to a given question by interacting with the language model.
:param question: The question
:return: The generated answer.
"""
transformed_history = []
# Get the QA chain
qa = self._get_qa(chat_message, user_openai_api_key)
qa = self._get_qa()
history = get_chat_history(self.chat_id)
# Transform the chat history into a list of tuples
for i in range(0, len(chat_message.history) - 1, 2):
user_message = chat_message.history[i][1]
assistant_message = chat_message.history[i + 1][1]
transformed_history.append((user_message, assistant_message))
# Format the chat history into a list of tuples (human, ai)
transformed_history = [(chat.user_message, chat.assistant) for chat in history]
# Generate the model response using the QA chain
model_response = qa(
{"question": chat_message.question, "chat_history": transformed_history}
)
model_response = qa({"question": question, "chat_history": transformed_history})
answer = model_response["answer"]
return answer

View File

@ -6,4 +6,4 @@ Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)

View File

@ -9,4 +9,4 @@ Question: {question}
Helpful Answer:"""
QA_PROMPT = PromptTemplate(
template=prompt_template, input_variables=["context", "question"]
)
)

View File

@ -6,8 +6,7 @@ def get_logger(logger_name, log_level=logging.INFO):
logger.setLevel(log_level)
logger.propagate = False # Prevent log propagation to avoid double logging
formatter = logging.Formatter(
'%(asctime)s [%(levelname)s] %(name)s: %(message)s')
formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)

View File

@ -1,7 +1,7 @@
import os
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
import pypandoc
from fastapi import FastAPI
from logger import get_logger
from middlewares.cors import add_cors_middleware
from routes.api_key_routes import api_key_router
@ -37,3 +37,11 @@ app.include_router(upload_router)
app.include_router(user_router)
app.include_router(api_key_router)
app.include_router(stream_router)
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
return JSONResponse(
status_code=exc.status_code,
content={"detail": exc.detail},
)

backend/models/chat.py Normal file
View File

@ -0,0 +1,31 @@
from dataclasses import dataclass
@dataclass
class Chat:
chat_id: str
user_id: str
creation_time: str
chat_name: str
def __init__(self, chat_dict: dict):
self.chat_id = chat_dict.get("chat_id")
self.user_id = chat_dict.get("user_id")
self.creation_time = chat_dict.get("creation_time")
self.chat_name = chat_dict.get("chat_name")
@dataclass
class ChatHistory:
chat_id: str
message_id: str
user_message: str
assistant: str
message_time: str
def __init__(self, chat_dict: dict):
self.chat_id = chat_dict.get("chat_id")
self.message_id = chat_dict.get("message_id")
self.user_message = chat_dict.get("user_message")
self.assistant = chat_dict.get("assistant")
self.message_time = chat_dict.get("message_time")
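
Both wrappers simply copy fields out of the Supabase row dicts. For example (the row values below are made up):

```python
from models.chat import ChatHistory

# Hypothetical row for illustration; real rows come from the chat_history table.
row = {
    "chat_id": "1f2e...",
    "message_id": "9a8b...",
    "user_message": "Hello",
    "assistant": "Hi! How can I help?",
    "message_time": "2023-06-22T17:50:06",
}

message = ChatHistory(row)
print(message.user_message, "->", message.assistant)
```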

View File

@ -16,5 +16,8 @@ class ChatMessage(BaseModel):
chat_name: Optional[str] = None
class ChatAttributes(BaseModel):
chat_name: Optional[str] = None
class ChatQuestion(BaseModel):
model: str = "gpt-3.5-turbo-0613"
question: str
temperature: float = 0.0
max_tokens: int = 256

View File

@ -1,4 +1,4 @@
from typing import Annotated, Any, Dict, List, Tuple, Union
from typing import Annotated
from fastapi import Depends
from langchain.embeddings.openai import OpenAIEmbeddings
@ -13,26 +13,33 @@ class BrainSettings(BaseSettings):
supabase_url: str
supabase_service_key: str
class LLMSettings(BaseSettings):
private: bool
model_path: str
model_n_ctx: int
model_n_batch: int
def common_dependencies() -> dict:
settings = BrainSettings()
embeddings = OpenAIEmbeddings(openai_api_key=settings.openai_api_key)
supabase_client: Client = create_client(settings.supabase_url, settings.supabase_service_key)
supabase_client: Client = create_client(
settings.supabase_url, settings.supabase_service_key
)
documents_vector_store = SupabaseVectorStore(
supabase_client, embeddings, table_name="vectors")
supabase_client, embeddings, table_name="vectors"
)
summaries_vector_store = SupabaseVectorStore(
supabase_client, embeddings, table_name="summaries")
supabase_client, embeddings, table_name="summaries"
)
return {
"supabase": supabase_client,
"embeddings": embeddings,
"documents_vector_store": documents_vector_store,
"summaries_vector_store": summaries_vector_store
"summaries_vector_store": summaries_vector_store,
}
CommonsDep = Annotated[dict, Depends(common_dependencies)]
CommonsDep = Annotated[dict, Depends(common_dependencies)]

View File

@ -0,0 +1,32 @@
from logger import get_logger
from models.settings import common_dependencies
from dataclasses import dataclass
from models.chat import Chat
logger = get_logger(__name__)
@dataclass
class CreateChatProperties:
name: str
def __init__(self, name: str):
self.name = name
def create_chat(user_id: str, chat_data: CreateChatProperties) -> Chat:
commons = common_dependencies()
# Chat is created upon the user's first question asked
logger.info(f"New chat entry in chats table for user {user_id}")
# Insert a new row into the chats table
new_chat = {
"user_id": user_id,
"chat_name": chat_data.name,
}
insert_response = commons["supabase"].table("chats").insert(new_chat).execute()
logger.info(f"Insert response {insert_response.data}")
return insert_response.data[0]

View File

@ -0,0 +1,15 @@
from models.chat import Chat
from models.settings import common_dependencies
def get_chat_by_id(chat_id: str) -> Chat:
commons = common_dependencies()
response = (
commons["supabase"]
.from_("chats")
.select("*")
.filter("chat_id", "eq", chat_id)
.execute()
)
return Chat(response.data[0])

View File

@ -0,0 +1,19 @@
from models.chat import ChatHistory
from models.settings import common_dependencies
from typing import List # For type hinting
def get_chat_history(chat_id: str) -> List[ChatHistory]:
commons = common_dependencies()
history: List[ChatHistory] = (
commons["supabase"]
.from_("chat_history")
.select("*")
.filter("chat_id", "eq", chat_id)
.order("message_time", desc=False) # Add the ORDER BY clause
.execute()
).data
if history is None:
return []
else:
return [ChatHistory(message) for message in history]

View File

@ -0,0 +1,15 @@
from models.settings import common_dependencies
from models.chat import Chat
def get_user_chats(user_id: str) -> list[Chat]:
commons = common_dependencies()
response = (
commons["supabase"]
.from_("chats")
.select("chat_id,user_id,creation_time,chat_name")
.filter("user_id", "eq", user_id)
.execute()
)
chats = [Chat(chat_dict) for chat_dict in response.data]
return chats

View File

@ -0,0 +1,45 @@
from logger import get_logger
from models.chat import Chat
from typing import Optional
from dataclasses import dataclass
from models.settings import common_dependencies
logger = get_logger(__name__)
@dataclass
class ChatUpdatableProperties:
chat_name: Optional[str] = None
def __init__(self, chat_name: Optional[str]):
self.chat_name = chat_name
def update_chat(chat_id, chat_data: ChatUpdatableProperties) -> Chat:
commons = common_dependencies()
if not chat_id:
logger.error("No chat_id provided")
return
updates = {}
if chat_data.chat_name is not None:
updates["chat_name"] = chat_data.chat_name
updated_chat = None
if updates:
updated_chat = (
commons["supabase"]
.table("chats")
.update(updates)
.match({"chat_id": chat_id})
.execute()
).data[0]
logger.info(f"Chat {chat_id} updated")
else:
logger.info(f"No updates to apply for chat {chat_id}")
return updated_chat

View File

@ -0,0 +1,24 @@
from models.chat import ChatHistory
from models.settings import common_dependencies
from typing import List # For type hinting
def update_chat_history(
chat_id: str, user_message: str, assistant_answer: str
) -> ChatHistory:
commons = common_dependencies()
response: List[ChatHistory] = (
commons["supabase"]
.table("chat_history")
.insert(
{
"chat_id": str(chat_id),
"user_message": user_message,
"assistant": assistant_answer,
}
)
.execute()
).data
if len(response) == 0:
raise Exception("Error while updating chat history")
return response[0]

View File

@ -1,5 +1,5 @@
pymupdf==1.22.3
langchain==0.0.200
langchain==0.0.207
Markdown==3.4.3
openai==0.27.6
pdf2image==1.16.3

View File

@ -1,5 +1,3 @@
import os
import time
from typing import Optional
from uuid import UUID
@ -7,7 +5,7 @@ from auth.auth_bearer import AuthBearer, get_current_user
from fastapi import APIRouter, Depends, Request
from logger import get_logger
from models.brains import Brain
from models.settings import CommonsDep, common_dependencies
from models.settings import common_dependencies
from models.users import User
from pydantic import BaseModel
from utils.users import fetch_user_id_from_credentials
@ -50,7 +48,7 @@ async def brain_endpoint(current_user: User = Depends(get_current_user)):
@brain_router.get(
"/brains/{brain_id}", dependencies=[Depends(AuthBearer())], tags=["Brain"]
)
async def brain_endpoint(brain_id: UUID):
async def get_brain_endpoint(brain_id: UUID):
"""
Retrieve details of a specific brain by brain ID.
@ -76,7 +74,7 @@ async def brain_endpoint(brain_id: UUID):
@brain_router.delete(
"/brains/{brain_id}", dependencies=[Depends(AuthBearer())], tags=["Brain"]
)
async def brain_endpoint(brain_id: UUID):
async def delete_brain_endpoint(brain_id: UUID):
"""
Delete a specific brain by brain ID.
"""
@ -97,7 +95,7 @@ class BrainObject(BaseModel):
# create new brain
@brain_router.post("/brains", dependencies=[Depends(AuthBearer())], tags=["Brain"])
async def brain_endpoint(
async def create_brain_endpoint(
request: Request,
brain: BrainObject,
current_user: User = Depends(get_current_user),
@ -125,7 +123,7 @@ async def brain_endpoint(
@brain_router.put(
"/brains/{brain_id}", dependencies=[Depends(AuthBearer())], tags=["Brain"]
)
async def brain_endpoint(
async def update_brain_endpoint(
request: Request,
brain_id: UUID,
input_brain: Brain,

View File

@ -1,32 +1,37 @@
import os
import time
from uuid import UUID
from models.chat import ChatHistory
from auth.auth_bearer import AuthBearer, get_current_user
from fastapi import APIRouter, Depends, Request
from llm.brainpicking import BrainPicking
from models.chats import ChatMessage, ChatAttributes
from models.settings import CommonsDep, common_dependencies
from fastapi import HTTPException
from models.chat import Chat
from models.chats import ChatQuestion
from models.settings import common_dependencies
from models.users import User
from utils.chats import (create_chat, get_chat_name_from_first_question,
update_chat)
from utils.users import (create_user, fetch_user_id_from_credentials,
update_user_request_count)
from http.client import HTTPException
from repository.chat.get_chat_history import get_chat_history
from repository.chat.update_chat import update_chat, ChatUpdatableProperties
from repository.chat.create_chat import create_chat, CreateChatProperties
from utils.users import (
fetch_user_id_from_credentials,
update_user_request_count,
)
from repository.chat.get_chat_by_id import get_chat_by_id
from repository.chat.get_user_chats import get_user_chats
from repository.chat.update_chat_history import update_chat_history
from llm.OpenAiFunctionBasedAnswerGenerator.OpenAiFunctionBasedAnswerGenerator import (
OpenAiFunctionBasedAnswerGenerator,
)
chat_router = APIRouter()
def get_user_chats(commons, user_id):
response = (
commons["supabase"]
.from_("chats")
.select("chatId:chat_id, chatName:chat_name")
.filter("user_id", "eq", user_id)
.execute()
)
return response.data
def get_chat_details(commons, chat_id):
response = (
commons["supabase"]
@ -69,32 +74,14 @@ async def get_chats(current_user: User = Depends(get_current_user)):
"""
commons = common_dependencies()
user_id = fetch_user_id_from_credentials(commons, {"email": current_user.email})
chats = get_user_chats(commons, user_id)
chats = get_user_chats(user_id)
return {"chats": chats}
# get one chat
@chat_router.get("/chat/{chat_id}", dependencies=[Depends(AuthBearer())], tags=["Chat"])
async def get_chat_handler(chat_id: UUID):
"""
Retrieve details of a specific chat by chat ID.
- `chat_id`: The ID of the chat to retrieve details for.
- Returns the chat ID and its history.
This endpoint retrieves the details of a specific chat identified by the provided chat ID. It returns the chat ID and its
history, which includes the chat messages exchanged in the chat.
"""
commons = common_dependencies()
chats = get_chat_details(commons, chat_id)
if len(chats) > 0:
return {"chatId": chat_id, "history": chats[0]["history"]}
else:
return {"error": "Chat not found"}
# delete one chat
@chat_router.delete("/chat/{chat_id}", dependencies=[Depends(AuthBearer())], tags=["Chat"])
@chat_router.delete(
"/chat/{chat_id}", dependencies=[Depends(AuthBearer())], tags=["Chat"]
)
async def delete_chat(chat_id: UUID):
"""
Delete a specific chat by chat ID.
@ -104,91 +91,118 @@ async def delete_chat(chat_id: UUID):
return {"message": f"{chat_id} has been deleted."}
# helper method for update and create chat
def chat_handler(request, commons, chat_id, chat_message, email, is_new_chat=False):
date = time.strftime("%Y%m%d")
user_id = fetch_user_id_from_credentials(commons, {"email": email})
max_requests_number = os.getenv("MAX_REQUESTS_NUMBER")
user_openai_api_key = request.headers.get("Openai-Api-Key")
userItem = fetch_user_stats(commons, User(email=email), date)
old_request_count = userItem["requests_count"]
history = chat_message.history
history.append(("user", chat_message.question))
chat_name = chat_message.chat_name
if old_request_count == 0:
create_user(commons, email=email, date=date)
else:
update_user_request_count(
commons, email, date, requests_count=old_request_count + 1
)
if user_openai_api_key is None and old_request_count >= float(max_requests_number):
history.append(("assistant", "You have reached your requests limit"))
update_chat(commons, chat_id=chat_id, history=history, chat_name=chat_name)
return {"history": history}
brainPicking = BrainPicking().init(chat_message.model, email)
answer = brainPicking.generate_answer(chat_message, user_openai_api_key)
history.append(("assistant", answer))
if is_new_chat:
chat_name = get_chat_name_from_first_question(chat_message)
new_chat = create_chat(commons, user_id, history, chat_name)
chat_id = new_chat.data[0]["chat_id"]
else:
update_chat(commons, chat_id=chat_id, history=history, chat_name=chat_name)
return {"history": history, "chatId": chat_id}
# update existing chat
@chat_router.put("/chat/{chat_id}", dependencies=[Depends(AuthBearer())], tags=["Chat"])
async def chat_endpoint(
request: Request,
commons: CommonsDep,
chat_id: UUID,
chat_message: ChatMessage,
current_user: User = Depends(get_current_user),
):
"""
Update an existing chat with new chat messages.
"""
return chat_handler(request, commons, chat_id, chat_message, current_user.email)
# update existing chat
@chat_router.put("/chat/{chat_id}/metadata", dependencies=[Depends(AuthBearer())], tags=["Chat"])
async def update_chat_attributes_handler(
commons: CommonsDep,
chat_message: ChatAttributes,
# update existing chat metadata
@chat_router.put(
"/chat/{chat_id}/metadata", dependencies=[Depends(AuthBearer())], tags=["Chat"]
)
async def update_chat_metadata_handler(
chat_data: ChatUpdatableProperties,
chat_id: UUID,
current_user: User = Depends(get_current_user),
):
) -> Chat:
"""
Update chat attributes
"""
commons = common_dependencies()
user_id = fetch_user_id_from_credentials(commons, {"email": current_user.email})
chat = get_chat_details(commons, chat_id)[0]
if user_id != chat.get('user_id'):
chat = get_chat_by_id(chat_id)
if user_id != chat.user_id:
raise HTTPException(status_code=403, detail="Chat not owned by user")
return update_chat(commons=commons, chat_id=chat_id, chat_name=chat_message.chat_name)
return update_chat(chat_id=chat_id, chat_data=chat_data)
# helper method for update and create chat
def check_user_limit(
email,
):
date = time.strftime("%Y%m%d")
max_requests_number = os.getenv("MAX_REQUESTS_NUMBER")
commons = common_dependencies()
userItem = fetch_user_stats(commons, User(email=email), date)
old_request_count = userItem["requests_count"]
update_user_request_count(
commons, email, date, requests_count=old_request_count + 1
)
if old_request_count >= float(max_requests_number):
raise HTTPException(
status_code=429,
detail="You have reached the maximum number of requests for today.",
)
# create new chat
@chat_router.post("/chat", dependencies=[Depends(AuthBearer())], tags=["Chat"])
async def create_chat_handler(
request: Request,
commons: CommonsDep,
chat_message: ChatMessage,
chat_data: CreateChatProperties,
current_user: User = Depends(get_current_user),
):
"""
Create a new chat with initial chat messages.
"""
return chat_handler(
request, commons, None, chat_message, current_user.email, is_new_chat=True
)
commons = common_dependencies()
user_id = fetch_user_id_from_credentials(commons, {"email": current_user.email})
return create_chat(user_id=user_id, chat_data=chat_data)
# add new question to chat
@chat_router.post(
"/chat/{chat_id}/question", dependencies=[Depends(AuthBearer())], tags=["Chat"]
)
async def create_question_handler(
request: Request,
chat_question: ChatQuestion,
chat_id: UUID,
current_user: User = Depends(get_current_user),
) -> ChatHistory:
try:
check_user_limit(current_user.email)
user_openai_api_key = request.headers.get("Openai-Api-Key")
openai_function_compatible_models = [
"gpt-3.5-turbo-0613",
"gpt-4-0613",
]
if chat_question.model in openai_function_compatible_models:
# TODO: RBAC with current_user
gpt_answer_generator = OpenAiFunctionBasedAnswerGenerator(
model=chat_question.model,
chat_id=chat_id,
temperature=chat_question.temperature,
max_tokens=chat_question.max_tokens,
# TODO: use user_id in vectors table instead of email
user_email=current_user.email,
user_openai_api_key=user_openai_api_key,
)
answer = gpt_answer_generator.get_answer(chat_question.question)
else:
brainPicking = BrainPicking(
chat_id=str(chat_id),
model=chat_question.model,
max_tokens=chat_question.max_tokens,
user_id=current_user.email,
user_openai_api_key=user_openai_api_key,
)
answer = brainPicking.generate_answer(chat_question.question)
chat_answer = update_chat_history(
chat_id=chat_id,
user_message=chat_question.question,
assistant_answer=answer,
)
return chat_answer
except HTTPException as e:
raise e
# get chat history
@chat_router.get(
"/chat/{chat_id}/history", dependencies=[Depends(AuthBearer())], tags=["Chat"]
)
async def get_chat_history_handler(
chat_id: UUID,
current_user: User = Depends(get_current_user),
) -> list[ChatHistory]:
# TODO: RBAC with current_user
return get_chat_history(chat_id)
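
A rough client-side sketch of the new question endpoint: the bearer token, chat id, and question are placeholders, and the optional `Openai-Api-Key` header is what the route reads to run the request against the user's own key. Models outside the function-calling list (`gpt-3.5-turbo-0613`, `gpt-4-0613`) fall back to the `BrainPicking` chain on the server.

```python
# Sketch only: placeholders for the token, chat id, and question.
import requests

BASE_URL = "https://api.quivr.app"  # public API URL from the project docs
headers = {
    "Authorization": "Bearer <access-token>",  # AuthBearer-protected route
    "Openai-Api-Key": "<optional-user-key>",   # omit to use the server-side key
}
payload = {
    "model": "gpt-3.5-turbo-0613",
    "question": "Summarize my latest upload",
    "temperature": 0.0,
    "max_tokens": 256,
}

resp = requests.post(f"{BASE_URL}/chat/<chat-id>/question", json=payload, headers=headers)
resp.raise_for_status()
print(resp.json())  # one ChatHistory row: chat_id, message_id, user_message, assistant, message_time
```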

View File

@ -5,49 +5,73 @@ from models.users import User
explore_router = APIRouter()
def get_unique_user_data(commons, user):
"""
Retrieve unique user data vectors.
"""
response = commons['supabase'].table("vectors").select(
"name:metadata->>file_name, size:metadata->>file_size", count="exact").filter("user_id", "eq", user.email).execute()
response = (
commons["supabase"]
.table("vectors")
.select("name:metadata->>file_name, size:metadata->>file_size", count="exact")
.filter("user_id", "eq", user.email)
.execute()
)
documents = response.data # Access the data from the response
# Convert each dictionary to a tuple of items, then to a set to remove duplicates, and then back to a dictionary
unique_data = [dict(t) for t in set(tuple(d.items()) for d in documents)]
return unique_data
@explore_router.get("/explore", dependencies=[Depends(AuthBearer())], tags=["Explore"])
async def explore_endpoint( current_user: User = Depends(get_current_user)):
async def explore_endpoint(current_user: User = Depends(get_current_user)):
"""
Retrieve and explore unique user data vectors.
"""
commons = common_dependencies()
unique_data = get_unique_user_data(commons, current_user)
unique_data.sort(key=lambda x: int(x['size']), reverse=True)
unique_data.sort(key=lambda x: int(x["size"]), reverse=True)
return {"documents": unique_data}
@explore_router.delete("/explore/{file_name}", dependencies=[Depends(AuthBearer())], tags=["Explore"])
async def delete_endpoint( file_name: str, credentials: dict = Depends(AuthBearer())):
@explore_router.delete(
"/explore/{file_name}", dependencies=[Depends(AuthBearer())], tags=["Explore"]
)
async def delete_endpoint(file_name: str, credentials: dict = Depends(AuthBearer())):
"""
Delete a specific user file by file name.
"""
commons = common_dependencies()
user = User(email=credentials.get('email', 'none'))
user = User(email=credentials.get("email", "none"))
# Cascade delete the summary from the database first, because it has a foreign key constraint
commons['supabase'].table("summaries").delete().match(
{"metadata->>file_name": file_name}).execute()
commons['supabase'].table("vectors").delete().match(
{"metadata->>file_name": file_name, "user_id": user.email}).execute()
commons["supabase"].table("summaries").delete().match(
{"metadata->>file_name": file_name}
).execute()
commons["supabase"].table("vectors").delete().match(
{"metadata->>file_name": file_name, "user_id": user.email}
).execute()
return {"message": f"{file_name} of user {user.email} has been deleted."}
@explore_router.get("/explore/{file_name}", dependencies=[Depends(AuthBearer())], tags=["Explore"])
async def download_endpoint( file_name: str, current_user: User = Depends(get_current_user)):
@explore_router.get(
"/explore/{file_name}", dependencies=[Depends(AuthBearer())], tags=["Explore"]
)
async def download_endpoint(
file_name: str, current_user: User = Depends(get_current_user)
):
"""
Download a specific user file by file name.
"""
commons = common_dependencies()
response = commons['supabase'].table("vectors").select(
"metadata->>file_name, metadata->>file_size, metadata->>file_extension, metadata->>file_url", "content").match({"metadata->>file_name": file_name, "user_id": current_user.email}).execute()
response = (
commons["supabase"]
.table("vectors")
.select(
"metadata->>file_name, metadata->>file_size, metadata->>file_extension, metadata->>file_url",
"content",
)
.match({"metadata->>file_name": file_name, "user_id": current_user.email})
.execute()
)
documents = response.data
return {"documents": documents}

View File

@ -1,48 +1,10 @@
from logger import get_logger
from models.chats import ChatMessage
from models.settings import CommonsDep
from models.settings import common_dependencies
logger = get_logger(__name__)
def create_chat(commons: CommonsDep, user_id, history, chat_name):
# Chat is created upon the user's first question asked
logger.info(f"New chat entry in chats table for user {user_id}")
# Insert a new row into the chats table
new_chat = {
"user_id": user_id,
"history": history, # Empty chat to start
"chat_name": chat_name,
}
insert_response = commons["supabase"].table("chats").insert(new_chat).execute()
logger.info(f"Insert response {insert_response.data}")
return insert_response
def update_chat(commons: CommonsDep, chat_id, history=None, chat_name=None):
if not chat_id:
logger.error("No chat_id provided")
return
updates = {}
if history is not None:
updates["history"] = history
if chat_name is not None:
updates["chat_name"] = chat_name
if updates:
commons["supabase"].table("chats").update(updates).match(
{"chat_id": chat_id}
).execute()
logger.info(f"Chat {chat_id} updated")
else:
logger.info(f"No updates to apply for chat {chat_id}")
def get_chat_name_from_first_question(chat_message: ChatMessage):
# Step 1: Get the summary of the first question
# first_question_summary = summarize_as_title(chat_message.question)

View File

@ -6,32 +6,45 @@ from models.users import User
logger = get_logger(__name__)
def create_user(commons: CommonsDep, email, date):
logger.info(f"New user entry in db document for user {email}")
return(commons['supabase'].table("users").insert(
{"email": email, "date": date, "requests_count": 1}).execute())
return (
commons["supabase"]
.table("users")
.insert({"email": email, "date": date, "requests_count": 1})
.execute()
)
def update_user_request_count(commons: CommonsDep, email, date, requests_count):
logger.info(f"User {email} request count updated to {requests_count}")
commons['supabase'].table("users").update(
{ "requests_count": requests_count}).match({"email": email, "date": date}).execute()
def fetch_user_id_from_credentials(commons: CommonsDep,credentials):
user = User(email=credentials.get('email', 'none'))
commons["supabase"].table("users").update({"requests_count": requests_count}).match(
{"email": email, "date": date}
).execute()
def fetch_user_id_from_credentials(commons: CommonsDep, credentials):
user = User(email=credentials.get("email", "none"))
# Fetch the user's UUID based on their email
response = commons['supabase'].from_('users').select('user_id').filter("email", "eq", user.email).execute()
response = (
commons["supabase"]
.from_("users")
.select("user_id")
.filter("email", "eq", user.email)
.execute()
)
userItem = next(iter(response.data or []), {})
if userItem == {}:
if userItem == {}:
date = time.strftime("%Y%m%d")
create_user_response = create_user(commons, email= user.email, date=date)
user_id = create_user_response.data[0]['user_id']
create_user_response = create_user(commons, email=user.email, date=date)
user_id = create_user_response.data[0]["user_id"]
else:
user_id = userItem['user_id']
else:
user_id = userItem["user_id"]
return user_id

View File

@ -9,32 +9,44 @@ from pydantic import BaseModel
logger = get_logger(__name__)
class Neurons(BaseModel):
class Neurons(BaseModel):
commons: CommonsDep
settings = BrainSettings()
def create_vector(self, user_id, doc, user_openai_api_key=None):
logger.info(f"Creating vector for document")
logger.info(f"Document: {doc}")
if user_openai_api_key:
self.commons['documents_vector_store']._embedding = OpenAIEmbeddings(openai_api_key=user_openai_api_key)
self.commons["documents_vector_store"]._embedding = OpenAIEmbeddings(
openai_api_key=user_openai_api_key
)
try:
sids = self.commons['documents_vector_store'].add_documents([doc])
sids = self.commons["documents_vector_store"].add_documents([doc])
if sids and len(sids) > 0:
self.commons['supabase'].table("vectors").update({"user_id": user_id}).match({"id": sids[0]}).execute()
self.commons["supabase"].table("vectors").update(
{"user_id": user_id}
).match({"id": sids[0]}).execute()
except Exception as e:
logger.error(f"Error creating vector for document {e}")
def create_embedding(self, content):
return self.commons['embeddings'].embed_query(content)
return self.commons["embeddings"].embed_query(content)
def similarity_search(self, query, table='match_summaries', top_k=5, threshold=0.5):
def similarity_search(self, query, table="match_summaries", top_k=5, threshold=0.5):
query_embedding = self.create_embedding(query)
summaries = self.commons['supabase'].rpc(
table, {'query_embedding': query_embedding,
'match_count': top_k, 'match_threshold': threshold}
).execute()
summaries = (
self.commons["supabase"]
.rpc(
table,
{
"query_embedding": query_embedding,
"match_count": top_k,
"match_threshold": threshold,
},
)
.execute()
)
return summaries.data
@ -42,11 +54,10 @@ def create_summary(commons: CommonsDep, document_id, content, metadata):
logger.info(f"Summarizing document {content[:100]}")
summary = llm_summerize(content)
logger.info(f"Summary: {summary}")
metadata['document_id'] = document_id
summary_doc_with_metadata = Document(
page_content=summary, metadata=metadata)
sids = commons['summaries_vector_store'].add_documents(
[summary_doc_with_metadata])
metadata["document_id"] = document_id
summary_doc_with_metadata = Document(page_content=summary, metadata=metadata)
sids = commons["summaries_vector_store"].add_documents([summary_doc_with_metadata])
if sids and len(sids) > 0:
commons['supabase'].table("summaries").update(
{"document_id": document_id}).match({"id": sids[0]}).execute()
commons["supabase"].table("summaries").update(
{"document_id": document_id}
).match({"id": sids[0]}).execute()

View File

@ -7,18 +7,26 @@ from supabase import Client
class CustomSupabaseVectorStore(SupabaseVectorStore):
'''A custom vector store that uses the match_vectors table instead of the vectors table.'''
"""A custom vector store that uses the match_vectors table instead of the vectors table."""
user_id: str
def __init__(self, client: Client, embedding: OpenAIEmbeddings, table_name: str, user_id: str = "none"):
def __init__(
self,
client: Client,
embedding: OpenAIEmbeddings,
table_name: str,
user_id: str = "none",
):
super().__init__(client, embedding, table_name)
self.user_id = user_id
def similarity_search(
self,
query: str,
table: str = "match_vectors",
k: int = 6,
threshold: float = 0.5,
self,
query: str,
table: str = "match_vectors",
k: int = 6,
threshold: float = 0.5,
**kwargs: Any
) -> List[Document]:
vectors = self._embedding.embed_documents([query])
@ -46,4 +54,4 @@ class CustomSupabaseVectorStore(SupabaseVectorStore):
documents = [doc for doc, _ in match_result]
return documents
return documents

View File

@ -0,0 +1,51 @@
---
sidebar_position: 2
---
# Chat system
**URL**: https://api.quivr.app/chat
**Swagger**: https://api.quivr.app/docs
## Overview
Users can create multiple chat sessions, each with its own set of chat messages. The application provides endpoints to retrieve all chats for the current user, delete specific chats, update chat attributes, create new chats, add new questions to existing chats, and retrieve the chat history. These features let users interact with their data conversationally.
1. **Retrieve all chats for the current user:**
- HTTP method: GET
- Endpoint: `/chat`
- Description: This endpoint retrieves all the chats associated with the current authenticated user. It returns a list of chat objects containing the chat ID and chat name for each chat.
2. **Delete a specific chat by chat ID:**
- HTTP method: DELETE
- Endpoint: `/chat/{chat_id}`
- Description: This endpoint allows deleting a specific chat identified by its chat ID.
3. **Update chat attributes:**
- HTTP method: PUT
- Endpoint: `/chat/{chat_id}/metadata`
- Description: This endpoint is used to update the attributes of a chat, such as the chat name.
4. **Create a new chat:**
- HTTP method: POST
- Endpoint: `/chat`
   - Description: This endpoint creates a new, empty chat. It expects the chat name in the request payload; questions are added afterwards via the question endpoint.
5. **Add a new question to a chat:**
- HTTP method: POST
- Endpoint: `/chat/{chat_id}/question`
- Description: This endpoint allows adding a new question to a chat. It generates an answer for the question using different models based on the provided model type.
Models like gpt-4-0613 and gpt-3.5-turbo-0613 use a custom OpenAI function-based answer generator.
![Function based answer generator](../../../static/img/answer_schema.png)
6. **Get the chat history:**
- HTTP method: GET
- Endpoint: `/chat/{chat_id}/history`
- Description: This endpoint retrieves the chat history for a specific chat identified by its chat ID.
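
Putting the endpoints together, a minimal end-to-end sketch (the bearer token and question are placeholders; `requests` is used purely for illustration):

```python
# Illustrative flow only: create a chat, ask a question, read the history back.
import requests

BASE_URL = "https://api.quivr.app"
headers = {"Authorization": "Bearer <access-token>"}

# 1. Create a chat (POST /chat) with just its name.
chat = requests.post(f"{BASE_URL}/chat", json={"name": "My documents"}, headers=headers).json()

# 2. Ask a question in that chat (POST /chat/{chat_id}/question).
question = {
    "model": "gpt-3.5-turbo-0613",
    "question": "What is in my documents?",
    "temperature": 0.0,
    "max_tokens": 256,
}
requests.post(f"{BASE_URL}/chat/{chat['chat_id']}/question", json=question, headers=headers)

# 3. Retrieve the chat history (GET /chat/{chat_id}/history).
history = requests.get(f"{BASE_URL}/chat/{chat['chat_id']}/history", headers=headers).json()
for message in history:
    print(message["user_message"], "->", message["assistant"])
```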

docs/static/img/answer_schema.png — new binary file (2.0 MiB), not shown.

View File

@ -0,0 +1,48 @@
"use client";
import { createContext, useContext, useState } from "react";
import { ChatHistory } from "../types";
type ChatContextProps = {
history: ChatHistory[];
setHistory: (history: ChatHistory[]) => void;
addToHistory: (message: ChatHistory) => void;
};
export const ChatContext = createContext<ChatContextProps | undefined>(
undefined
);
export const ChatProvider = ({
children,
}: {
children: JSX.Element | JSX.Element[];
}): JSX.Element => {
const [history, setHistory] = useState<ChatHistory[]>([]);
const addToHistory = (message: ChatHistory) => {
setHistory((prevHistory) => [...prevHistory, message]);
};
return (
<ChatContext.Provider
value={{
history,
setHistory,
addToHistory,
}}
>
{children}
</ChatContext.Provider>
);
};
export const useChatContext = (): ChatContextProps => {
const context = useContext(ChatContext);
if (context === undefined) {
throw new Error("useChatContext must be used inside ChatProvider");
}
return context;
};

View File

@ -0,0 +1,89 @@
import { AxiosError } from "axios";
import { useParams } from "next/navigation";
import { useEffect, useState } from "react";
import { useBrainConfig } from "@/lib/context/BrainConfigProvider/hooks/useBrainConfig";
import { useToast } from "@/lib/hooks";
import { useChatService } from "./useChatService";
import { useChatContext } from "../context/ChatContext";
import { ChatQuestion } from "../types";
// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types
export const useChat = () => {
const params = useParams();
const [chatId, setChatId] = useState<string | undefined>(
params?.chatId as string | undefined
);
const [generatingAnswer, setGeneratingAnswer] = useState(false);
const {
config: { maxTokens, model, temperature },
} = useBrainConfig();
const { history, setHistory, addToHistory } = useChatContext();
const { publish } = useToast();
const {
createChat,
getChatHistory,
addQuestion: addQuestionToChat,
} = useChatService();
useEffect(() => {
const fetchHistory = async () => {
const chatHistory = await getChatHistory(chatId);
setHistory(chatHistory);
};
void fetchHistory();
}, [chatId]);
const generateNewChatIdFromName = async (
chatName: string
): Promise<string> => {
const rep = await createChat({ name: chatName });
setChatId(rep.data.chat_id);
return rep.data.chat_id;
};
const addQuestion = async (question: string, callback?: () => void) => {
const chatQuestion: ChatQuestion = {
model,
question,
temperature,
max_tokens: maxTokens,
};
try {
setGeneratingAnswer(true);
const currentChatId =
chatId ??
      // if chatId is undefined, we need to create a new chat on the fly
(await generateNewChatIdFromName(
question.split(" ").slice(0, 3).join(" ")
));
const answer = await addQuestionToChat(currentChatId, chatQuestion);
addToHistory(answer);
callback?.();
} catch (error) {
console.error({ error });
if ((error as AxiosError).response?.status === 429) {
publish({
variant: "danger",
text: "You have reached the limit of requests, please try again later",
});
return;
}
publish({
variant: "danger",
text: "Error occurred while getting answer",
});
} finally {
setGeneratingAnswer(false);
}
};
return { history, addQuestion, generatingAnswer };
};

View File

@ -0,0 +1,40 @@
import { useAxios } from "@/lib/hooks";
import { ChatEntity, ChatHistory, ChatQuestion } from "../types";
// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types
export const useChatService = () => {
const { axiosInstance } = useAxios();
const createChat = async ({ name }: { name: string }) => {
return axiosInstance.post<ChatEntity>(`/chat`, { name });
};
const getChatHistory = async (chatId: string | undefined) => {
if (chatId === undefined) {
return [];
}
const rep = (
await axiosInstance.get<ChatHistory[]>(`/chat/${chatId}/history`)
).data;
return rep;
};
const addQuestion = async (
chatId: string,
chatQuestion: ChatQuestion
): Promise<ChatHistory> => {
return (
await axiosInstance.post<ChatHistory>(
`/chat/${chatId}/question`,
chatQuestion
)
).data;
};
return {
createChat,
getChatHistory,
addQuestion,
};
};

View File

@ -1,31 +1,12 @@
/* eslint-disable */
"use client";
import { UUID } from "crypto";
import { useEffect } from "react";
import PageHeading from "@/lib/components/ui/PageHeading";
import useChatsContext from "@/lib/context/ChatsProvider/hooks/useChatsContext";
import { ChatInput, ChatMessages } from "../components";
import { ChatProvider } from "./context/ChatContext";
interface ChatPageProps {
params: {
chatId: UUID;
};
}
export default function ChatPage({ params }: ChatPageProps) {
const chatId: UUID | undefined = params.chatId;
const { fetchChat, resetChat } = useChatsContext();
useEffect(() => {
if (!chatId) {
resetChat();
}
fetchChat(chatId);
}, []);
export default function ChatPage() {
return (
<main className="flex flex-col w-full pt-10">
<section className="flex flex-col flex-1 items-center w-full h-full min-h-screen">
@ -33,12 +14,14 @@ export default function ChatPage({ params }: ChatPageProps) {
title="Chat with your brain"
subtitle="Talk to a language model about your uploaded data"
/>
<div className="relative h-full w-full flex flex-col flex-1 items-center">
<div className="h-full flex-1 w-full flex flex-col items-center">
<ChatMessages />
<ChatProvider>
<div className="relative h-full w-full flex flex-col flex-1 items-center">
<div className="h-full flex-1 w-full flex flex-col items-center">
<ChatMessages />
</div>
<ChatInput />
</div>
<ChatInput />
</div>
</ChatProvider>
</section>
</main>
);

View File

@ -0,0 +1,22 @@
import { UUID } from "crypto";
export type ChatQuestion = {
model: string;
question?: string;
temperature: number;
max_tokens: number;
};
export type ChatHistory = {
chat_id: string;
message_id: string;
user_message: string;
assistant: string;
message_time: string;
};
export type ChatEntity = {
chat_id: UUID;
user_id: string;
creation_time: string;
chat_name: string;
};

View File

@ -5,8 +5,14 @@ import { MdMic, MdMicOff } from "react-icons/md";
import Button from "@/lib/components/ui/Button";
import { useSpeech } from "@/lib/context/ChatsProvider/hooks/useSpeech";
export const MicButton = (): JSX.Element => {
const { isListening, speechSupported, startListening } = useSpeech();
type MicButtonProps = {
setMessage: (newValue: string | ((prevValue: string) => string)) => void;
};
export const MicButton = ({ setMessage }: MicButtonProps): JSX.Element => {
const { isListening, speechSupported, startListening } = useSpeech({
setMessage,
});
return (
<Button

View File

@ -1,35 +1,41 @@
/* eslint-disable */
"use client";
import Button from "@/lib/components/ui/Button";
import useChatsContext from "@/lib/context/ChatsProvider/hooks/useChatsContext";
import { useChat } from "@/app/chat/[chatId]/hooks/useChat";
import { useState } from "react";
import { ConfigButton } from "./ConfigButton";
import { MicButton } from "./MicButton";
export const ChatInput = (): JSX.Element => {
const { isSendingMessage, sendMessage, setMessage, message, chat } =
useChatsContext();
const [message, setMessage] = useState<string>(""); // for optimistic updates
const { addQuestion, generatingAnswer } = useChat();
const submitQuestion = () => {
addQuestion(message, () => setMessage(""));
};
return (
<form
onSubmit={(e) => {
e.preventDefault();
if (!isSendingMessage) {
sendMessage(chat?.chatId);
if (!generatingAnswer) {
submitQuestion();
}
}}
className="sticky bottom-0 p-5 bg-white dark:bg-black rounded-t-md border border-black/10 dark:border-white/25 border-b-0 w-full max-w-3xl flex items-center justify-center gap-2 z-20"
>
<textarea
autoFocus
value={message[1]}
onChange={(e) => setMessage((msg) => [msg[0], e.target.value])}
value={message}
required
onChange={(e) => setMessage(e.target.value)}
onKeyDown={(e) => {
if (message.length === 0) return;
if (e.key === "Enter" && !e.shiftKey) {
e.preventDefault(); // Prevents the newline from being entered in the textarea
if (!isSendingMessage) {
sendMessage(chat?.chatId);
} // Call the submit function here
if (!generatingAnswer) {
submitQuestion();
}
}
}}
className="w-full p-2 border border-gray-300 dark:border-gray-500 outline-none rounded dark:bg-gray-800"
@ -38,12 +44,12 @@ export const ChatInput = (): JSX.Element => {
<Button
className="px-3 py-2 sm:px-4 sm:py-2"
type="submit"
isLoading={isSendingMessage}
isLoading={generatingAnswer}
>
{isSendingMessage ? "Thinking..." : "Chat"}
{generatingAnswer ? "Thinking..." : "Chat"}
</Button>
<div className="flex items-center">
<MicButton />
<MicButton setMessage={setMessage} />
<ConfigButton />
</div>
</form>

View File

@ -1,52 +1,52 @@
/* eslint-disable */
"use client";
import { useEffect, useRef } from "react";
import { useCallback, useEffect, useRef } from "react";
import Card from "@/lib/components/ui/Card";
import useChatsContext from "@/lib/context/ChatsProvider/hooks/useChatsContext";
import { useChat } from "../../[chatId]/hooks/useChat";
import { ChatMessage } from "./ChatMessage";
export const ChatMessages = (): JSX.Element => {
const lastChatRef = useRef<HTMLDivElement | null>(null);
const { history } = useChat();
const { chat } = useChatsContext();
const scrollToBottom = useCallback(() => {
if (lastChatRef.current) {
lastChatRef.current.scrollIntoView({
behavior: "smooth",
block: "start",
});
}
}, []);
useEffect(() => {
if (!chat || !lastChatRef.current) {
return;
}
// if (chat.history.length > 2) {
lastChatRef.current.scrollIntoView({
behavior: "smooth",
block: "end",
});
// }
}, [chat, lastChatRef]);
if (!chat) {
return <></>;
}
scrollToBottom();
}, [history, scrollToBottom]);
return (
<Card className="p-5 max-w-3xl w-full flex flex-col h-full mb-8">
<div className="flex-1">
{chat.history.length === 0 ? (
{history.length === 0 ? (
<div className="text-center opacity-50">
Ask a question, or describe a task.
</div>
) : (
chat.history.map(([speaker, text], idx) => {
return (
history.map(({ assistant, message_id, user_message }, idx) => (
<>
<ChatMessage
ref={idx === chat.history.length - 1 ? lastChatRef : null}
key={idx}
speaker={speaker}
text={text}
key={message_id}
speaker={"user"}
text={user_message}
/>
);
})
<ChatMessage
key={message_id}
speaker={"assistant"}
text={assistant}
/>
</>
))
)}
<div ref={lastChatRef} />
</div>
</Card>
);

View File

@ -5,15 +5,14 @@ import { useState } from "react";
import { FiEdit, FiSave, FiTrash2 } from "react-icons/fi";
import { MdChatBubbleOutline } from "react-icons/md";
import { ChatEntity } from "@/app/chat/[chatId]/types";
import { useAxios, useToast } from "@/lib/hooks";
import { Chat, ChatResponse } from "@/lib/types/Chat";
import { cn } from "@/lib/utils";
import { ChatName } from "./components/ChatName";
interface ChatsListItemProps {
chat: Chat;
chat: ChatEntity;
deleteChat: (id: UUID) => void;
}
@ -21,32 +20,31 @@ export const ChatsListItem = ({
chat,
deleteChat,
}: ChatsListItemProps): JSX.Element => {
console.log({ chat });
const pathname = usePathname()?.split("/").at(-1);
const selected = chat.chatId === pathname;
const [chatName, setChatName] = useState(chat.chatName);
const selected = chat.chat_id === pathname;
const [chatName, setChatName] = useState(chat.chat_name);
const { axiosInstance } = useAxios();
const {publish} = useToast()
const { publish } = useToast();
const [editingName, setEditingName] = useState(false);
const updateChatName = async () => {
if(chatName !== chat.chatName) {
await axiosInstance.put<ChatResponse>(`/chat/${chat.chatId}/metadata`, {
chat_name:chatName,
});
publish({text:'Chat name updated',variant:'success'})
const updateChatName = async () => {
if (chatName !== chat.chat_name) {
await axiosInstance.put<ChatEntity>(`/chat/${chat.chat_id}/metadata`, {
chat_name: chatName,
});
publish({ text: "Chat name updated", variant: "success" });
}
}
};
const handleEditNameClick = () => {
if(editingName){
setEditingName(false) ;
void updateChatName()
if (editingName) {
setEditingName(false);
void updateChatName();
} else {
setEditingName(true);
}
else {
setEditingName(true)
}
}
};
return (
<div
@ -59,35 +57,32 @@ export const ChatsListItem = ({
>
<Link
className="flex flex-col flex-1 min-w-0 p-4"
href={`/chat/${chat.chatId}`}
key={chat.chatId}
href={`/chat/${chat.chat_id}`}
key={chat.chat_id}
>
<div className="flex items-center gap-2">
<MdChatBubbleOutline className="text-xl" />
<ChatName setName={setChatName} editing={editingName} name={chatName} />
<ChatName
setName={setChatName}
editing={editingName}
name={chatName}
/>
</div>
<div className="grid-cols-2 text-xs opacity-50 whitespace-nowrap">
{chat.chatId}
{chat.chat_id}
</div>
</Link>
<div className="opacity-0 group-hover:opacity-100 flex items-center justify-center hover:text-red-700 bg-gradient-to-l from-white dark:from-black to-transparent z-10 transition-opacity">
<button
className="p-0"
type="button"
onClick={handleEditNameClick
}
>
{editingName ? <FiSave/> : <FiEdit />}
<button className="p-0" type="button" onClick={handleEditNameClick}>
{editingName ? <FiSave /> : <FiEdit />}
</button>
<button
className="p-5"
type="button"
onClick={() => deleteChat(chat.chatId)}
onClick={() => deleteChat(chat.chat_id)}
>
<FiTrash2 />
</button>
</div>
{/* Fade to white */}
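The list item now reads the snake_case fields returned by the backend through the ChatEntity type imported from @/app/chat/[chatId]/types. A sketch of what that entity presumably contains; only chat_id and chat_name are used above, and any further fields are not shown in this excerpt:
import { UUID } from "crypto";

// Sketch only: limited to the fields ChatsListItem actually touches.
export type ChatEntity = {
  chat_id: UUID;
  chat_name: string;
};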

View File

@ -1,18 +1,23 @@
interface ChatNameProps {
name: string;
editing?: boolean;
setName: (name:string) => void;
setName: (name: string) => void;
}
export const ChatName = ({setName,name,editing=false}:ChatNameProps):JSX.Element => {
if(editing) {
return <input onChange={(event) => setName(event.target.value)} autoFocus value={name} />
}
export const ChatName = ({
setName,
name,
editing = false,
}: ChatNameProps): JSX.Element => {
if (editing) {
return (
<p>{name}</p>
)
}
<input
onChange={(event) => setName(event.target.value)}
autoFocus
value={name}
/>
);
}
return <p>{name}</p>;
};

View File

@ -47,7 +47,7 @@ export const ChatsList = (): JSX.Element => {
<div className="flex flex-col gap-0">
{allChats.map((chat) => (
<ChatsListItem
key={chat.chatId}
key={chat.chat_id}
chat={chat}
deleteChat={deleteChat}
/>

View File

@ -3,22 +3,16 @@ import { UUID } from "crypto";
import { useRouter } from "next/navigation";
import { useEffect, useState } from "react";
import { useBrainConfig } from "@/lib/context/BrainConfigProvider/hooks/useBrainConfig";
import { useAxios } from "@/lib/hooks";
import { useToast } from "@/lib/hooks/useToast";
import { Chat, ChatMessage } from "../../../types/Chat";
import { ChatEntity } from "@/app/chat/[chatId]/types";
export default function useChats() {
const [allChats, setAllChats] = useState<Chat[]>([]);
const [chat, setChat] = useState<Chat | null>(null);
const [isSendingMessage, setIsSendingMessage] = useState(false);
const [message, setMessage] = useState<ChatMessage>(["", ""]); // for optimistic updates
const [allChats, setAllChats] = useState<ChatEntity[]>([]);
const { axiosInstance } = useAxios();
const {
config: { maxTokens, model, temperature },
} = useBrainConfig();
const router = useRouter();
const { publish } = useToast();
@ -26,7 +20,7 @@ export default function useChats() {
try {
console.log("Fetching all chats");
const response = await axiosInstance.get<{
chats: Chat[];
chats: ChatEntity[];
}>(`/chat`);
setAllChats(response.data.chats);
console.log("Fetched all chats");
@ -39,102 +33,10 @@ export default function useChats() {
}
};
const fetchChat = async (chatId?: UUID) => {
if (!chatId) {
return;
}
try {
console.log(`Fetching chat ${chatId}`);
const response = await axiosInstance.get<Chat>(`/chat/${chatId}`);
console.log(response.data);
setChat(response.data);
} catch (error) {
console.error(error);
publish({
variant: "danger",
text: `Error occurred while fetching ${chatId}`,
});
}
};
type ChatResponse = Omit<Chat, "chatId"> & { chatId: UUID | undefined };
const createChat = async ({
options,
}: {
options: Record<string, string | unknown>;
}) => {
await fetchAllChats();
return axiosInstance.post<ChatResponse>(`/chat`, options);
};
const updateChat = ({
options,
}: {
options: Record<string, string | unknown>;
}) => {
return axiosInstance.put<ChatResponse>(`/chat/${options.chat_id}`, options);
};
const sendMessage = async (chatId?: UUID, msg?: ChatMessage) => {
setIsSendingMessage(true);
if (msg) {
setMessage(msg);
}
const options: Record<string, unknown> = {
chat_id: chatId,
model,
question: msg ? msg[1] : message[1],
history: chat ? chat.history : [],
temperature,
max_tokens: maxTokens,
use_summarization: false,
};
const response = await (chatId !== undefined
? updateChat({ options })
: createChat({ options }));
// response.data.chatId can be undefined when the max number of requests has been reached
if (!response.data.chatId) {
publish({
text: "You have reached max number of requests.",
variant: "danger",
});
setMessage(["", ""]);
setIsSendingMessage(false);
return;
}
const newChat = {
chatId: response.data.chatId,
history: response.data.history,
chatName: response.data.chatName,
};
if (!chatId) {
// Creating a new chat
console.log("---- Creating a new chat ----");
setAllChats((chats) => {
console.log({ chats });
return [...chats, newChat];
});
setChat(newChat);
router.push(`/chat/${response.data.chatId}`);
}
setChat(newChat);
setMessage(["", ""]);
setIsSendingMessage(false);
};
const deleteChat = async (chatId: UUID) => {
try {
await axiosInstance.delete(`/chat/${chatId}`);
setAllChats((chats) => chats.filter((chat) => chat.chatId !== chatId));
setAllChats((chats) => chats.filter((chat) => chat.chat_id !== chatId));
// TODO: Change route only when the current chat is being deleted
router.push("/chat");
publish({
@ -147,26 +49,12 @@ export default function useChats() {
}
};
const resetChat = async () => {
setChat(null);
};
useEffect(() => {
fetchAllChats();
}, []);
return {
allChats,
chat,
isSendingMessage,
message,
setMessage,
resetChat,
fetchAllChats,
fetchChat,
deleteChat,
sendMessage,
};
}
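With the removals above, useChats is reduced to listing and deleting chats; per-chat state and message sending move to the useChat hook under [chatId]. A usage sketch under that assumption — allChats and deleteChat are the only members the components in this commit consume, and the import path is assumed from the ChatsProvider layout seen elsewhere in the diff:
// Sketch of the slimmed-down hook in use; not part of this commit.
import useChats from "@/lib/context/ChatsProvider/hooks/useChats";

const ChatsDebugList = (): JSX.Element => {
  const { allChats, deleteChat } = useChats();
  return (
    <ul>
      {allChats.map((chat) => (
        <li key={chat.chat_id} onClick={() => deleteChat(chat.chat_id)}>
          {chat.chat_name}
        </li>
      ))}
    </ul>
  );
};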

View File

@ -3,14 +3,14 @@ import { useEffect, useState } from "react";
import { isSpeechRecognitionSupported } from "@/lib/helpers/isSpeechRecognitionSupported";
import useChatsContext from "./useChatsContext";
type useSpeechProps = {
setMessage: (newValue: string | ((prevValue: string) => string)) => void;
};
export const useSpeech = () => {
export const useSpeech = ({ setMessage }: useSpeechProps) => {
const [isListening, setIsListening] = useState(false);
const [speechSupported, setSpeechSupported] = useState(false);
const { setMessage } = useChatsContext();
useEffect(() => {
if (isSpeechRecognitionSupported()) {
setSpeechSupported(true);
@ -39,7 +39,7 @@ export const useSpeech = () => {
mic.onresult = (event: SpeechRecognitionEvent) => {
const interimTranscript =
event.results[event.results.length - 1][0].transcript;
setMessage((prevMessage) => ["user", prevMessage + interimTranscript]);
setMessage((prevMessage) => prevMessage + interimTranscript);
};
if (isListening) {
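useSpeech now receives setMessage through props instead of pulling it from ChatsContext, matching the <MicButton setMessage={setMessage} /> change in ChatInput above. A sketch of how MicButton presumably forwards it; MicButton's real markup and the hook's return shape are assumptions, only the useSpeechProps type is grounded in this diff:
// Sketch only; assumes the hook exposes isListening and lives alongside this component.
export const MicButtonSketch = ({ setMessage }: useSpeechProps): JSX.Element => {
  const { isListening } = useSpeech({ setMessage });
  return (
    <button type="button">{isListening ? "Listening..." : "Speak"}</button>
  );
};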

View File

@ -1,14 +0,0 @@
import { UUID } from "crypto";
export interface Chat {
chatId: UUID;
chatName: string;
history: ChatHistory;
}
export type ChatMessage = [string, string];
export type ChatHistory = ChatMessage[];
export type ChatResponse = Omit<Chat, "chatId"> & {
chatId: UUID | undefined;
};

View File

@ -15,6 +15,16 @@ CREATE TABLE IF NOT EXISTS chats(
chat_name TEXT
);
-- Create chat_history table
CREATE TABLE IF NOT EXISTS chat_history (
message_id UUID DEFAULT uuid_generate_v4(),
chat_id UUID REFERENCES chats(chat_id),
user_message TEXT,
assistant TEXT,
message_time TIMESTAMP DEFAULT current_timestamp,
PRIMARY KEY (chat_id, message_id)
);
-- Create vector extension
CREATE EXTENSION IF NOT EXISTS vector;
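The new chat_history table stores one row per question/answer exchange, keyed by (chat_id, message_id), which is the shape ChatMessages consumes ({ message_id, user_message, assistant }). A purely illustrative read, sketched with supabase-js even though this commit performs the queries on the backend; the project URL, key and chat id are placeholders:
import { createClient } from "@supabase/supabase-js";

// Illustrative only — not part of this commit's frontend code.
const supabase = createClient("https://<project>.supabase.co", "<anon-key>");

export const fetchHistory = async (chatId: string) => {
  const { data, error } = await supabase
    .from("chat_history")
    .select("message_id, user_message, assistant, message_time")
    .eq("chat_id", chatId)
    .order("message_time", { ascending: true });
  if (error) throw error;
  return data; // one row per user_message / assistant pair
};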