From e1a740472f4fb1b0ee3957d284727fd645dad315 Mon Sep 17 00:00:00 2001 From: Mamadou DICKO <63923024+mamadoudicko@users.noreply.github.com> Date: Tue, 20 Jun 2023 09:54:23 +0200 Subject: [PATCH] Feat: chat name edit (#343) * feat(chat): add name update * chore(linting): add flake8 * feat: add chat name edit --- .flake8 | 4 + .vscode/settings.json | 15 ++- backend/auth/api_key_handler.py | 53 +++++--- backend/auth/auth_bearer.py | 13 +- backend/auth/jwt_token_handler.py | 14 ++- backend/crawl/crawler.py | 30 ++--- backend/llm/summarization.py | 52 ++++---- backend/models/brains.py | 7 +- backend/models/chats.py | 7 +- backend/models/users.py | 2 +- backend/parsers/common.py | 46 +++++-- backend/parsers/csv.py | 18 ++- backend/requirements.txt | 2 + backend/routes/chat_routes.py | 113 ++++++++++++++---- backend/routes/user_routes.py | 4 + backend/utils/chats.py | 40 +++++-- frontend/app/chat/[chatId]/page.tsx | 3 +- .../ChatsListItem}/ChatsListItem.tsx | 49 +++++++- .../ChatsListItem/components/ChatName.tsx | 18 +++ .../components/ChatsListItem/index.ts | 1 + .../app/chat/components/ChatsList/index.tsx | 2 +- .../app/upload/components/Crawler/index.tsx | 2 +- .../context/ChatsProvider/hooks/useChats.ts | 45 ++++--- frontend/lib/types/Chat.ts | 4 + frontend/tsconfig.eslint.json | 3 + 25 files changed, 393 insertions(+), 154 deletions(-) create mode 100644 .flake8 rename frontend/app/chat/components/ChatsList/{ => components/ChatsListItem}/ChatsListItem.tsx (58%) create mode 100644 frontend/app/chat/components/ChatsList/components/ChatsListItem/components/ChatName.tsx create mode 100644 frontend/app/chat/components/ChatsList/components/ChatsListItem/index.ts diff --git a/.flake8 b/.flake8 new file mode 100644 index 000000000..292006cb1 --- /dev/null +++ b/.flake8 @@ -0,0 +1,4 @@ +[flake8] +; Minimal configuration for Flake8 to work with Black. +max-line-length = 100 +ignore = E101,E111,E112,E221,E222,E501,E711,E712,W503,W504,F401 diff --git a/.vscode/settings.json b/.vscode/settings.json index 403534473..08fbd3b94 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,10 +1,15 @@ { - "[python]": { - "editor.defaultFormatter": "ms-python.autopep8" - }, "python.formatting.provider": "black", "editor.codeActionsOnSave": { - "source.organizeImports": true + "source.organizeImports": true, + "source.fixAll":true }, - "python.linting.enabled": true + "python.linting.enabled": true, + "python.linting.flake8Enabled": true, + "editor.formatOnSave": true, + "[typescript]": { + "editor.defaultFormatter": "esbenp.prettier-vscode", + "editor.formatOnSave": true + }, + "editor.formatOnSaveMode": "modifications" } diff --git a/backend/auth/api_key_handler.py b/backend/auth/api_key_handler.py index 7fd96c1b6..e0ecfa800 100644 --- a/backend/auth/api_key_handler.py +++ b/backend/auth/api_key_handler.py @@ -1,4 +1,3 @@ - from datetime import datetime from fastapi import HTTPException @@ -6,13 +5,22 @@ from models.settings import CommonsDep from pydantic import DateError -async def verify_api_key(api_key: str, commons: CommonsDep): +async def verify_api_key(api_key: str, commons: CommonsDep): try: # Use UTC time to avoid timezone issues current_date = datetime.utcnow().date() - result = commons['supabase'].table('api_keys').select('api_key', 'creation_time').filter('api_key', 'eq', api_key).filter('is_active', 'eq', True).execute() + result = ( + commons["supabase"] + .table("api_keys") + .select("api_key", "creation_time") + .filter("api_key", "eq", api_key) + .filter("is_active", "eq", True) + .execute() + ) if result.data is not None and len(result.data) > 0: - api_key_creation_date = datetime.strptime(result.data[0]['creation_time'], "%Y-%m-%dT%H:%M:%S").date() + api_key_creation_date = datetime.strptime( + result.data[0]["creation_time"], "%Y-%m-%dT%H:%M:%S" + ).date() # Check if the API key was created today: Todo remove this check and use deleted_time instead. if api_key_creation_date == current_date: @@ -20,17 +28,34 @@ async def verify_api_key(api_key: str, commons: CommonsDep): return False except DateError: return False - + + async def get_user_from_api_key(api_key: str, commons: CommonsDep): # Lookup the user_id from the api_keys table - user_id_data = commons['supabase'].table('api_keys').select('user_id').filter('api_key', 'eq', api_key).execute() - - if not user_id_data.data: - raise HTTPException(status_code=400, detail="Invalid API key.") - - user_id = user_id_data.data[0]['user_id'] + user_id_data = ( + commons["supabase"] + .table("api_keys") + .select("user_id") + .filter("api_key", "eq", api_key) + .execute() + ) - # Lookup the email from the users table. Todo: remove and use user_id for credentials - user_email_data = commons['supabase'].table('users').select('email').filter('user_id', 'eq', user_id).execute() + if not user_id_data.data: + raise HTTPException(status_code=400, detail="Invalid API key.") - return {'email': user_email_data.data[0]['email']} if user_email_data.data else {'email': None} + user_id = user_id_data.data[0]["user_id"] + + # Lookup the email from the users table. Todo: remove and use user_id for credentials + user_email_data = ( + commons["supabase"] + .table("users") + .select("email") + .filter("user_id", "eq", user_id) + .execute() + ) + + return ( + {"email": user_email_data.data[0]["email"]} + if user_email_data.data + else {"email": None} + ) diff --git a/backend/auth/auth_bearer.py b/backend/auth/auth_bearer.py index 7d9c9c270..1a1760050 100644 --- a/backend/auth/auth_bearer.py +++ b/backend/auth/auth_bearer.py @@ -14,7 +14,9 @@ class AuthBearer(HTTPBearer): super().__init__(auto_error=auto_error) async def __call__(self, request: Request, commons: CommonsDep): - credentials: Optional[HTTPAuthorizationCredentials] = await super().__call__(request) + credentials: Optional[HTTPAuthorizationCredentials] = await super().__call__( + request + ) self.check_scheme(credentials) token = credentials.credentials return await self.authenticate(token, commons) @@ -33,10 +35,13 @@ class AuthBearer(HTTPBearer): elif await verify_api_key(token, commons): return await get_user_from_api_key(token, commons) else: - raise HTTPException(status_code=402, detail="Invalid token or expired token.") + raise HTTPException( + status_code=402, detail="Invalid token or expired token." + ) def get_test_user(self): - return {'email': 'test@example.com'} # replace with test user information + return {"email": "test@example.com"} # replace with test user information + def get_current_user(credentials: dict = Depends(AuthBearer())) -> User: - return User(email=credentials.get('email', 'none')) + return User(email=credentials.get("email", "none")) diff --git a/backend/auth/jwt_token_handler.py b/backend/auth/jwt_token_handler.py index 57fc64341..6975c7893 100644 --- a/backend/auth/jwt_token_handler.py +++ b/backend/auth/jwt_token_handler.py @@ -8,6 +8,7 @@ from jose.exceptions import JWTError SECRET_KEY = os.environ.get("JWT_SECRET_KEY") ALGORITHM = "HS256" + def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): to_encode = data.copy() if expires_delta: @@ -18,19 +19,24 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) return encoded_jwt + def decode_access_token(token: str): try: - payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM], options={"verify_aud": False}) + payload = jwt.decode( + token, SECRET_KEY, algorithms=[ALGORITHM], options={"verify_aud": False} + ) return payload - except JWTError as e: + except JWTError: return None - + + def verify_token(token: str): payload = decode_access_token(token) return payload is not None + def get_user_email_from_token(token: str): payload = decode_access_token(token) if payload: return payload.get("email") - return "none" \ No newline at end of file + return "none" diff --git a/backend/crawl/crawler.py b/backend/crawl/crawler.py index 1c31f60d5..0c2b24edc 100644 --- a/backend/crawl/crawler.py +++ b/backend/crawl/crawler.py @@ -9,11 +9,11 @@ from pydantic import BaseModel class CrawlWebsite(BaseModel): - url : str - js : bool = False - depth : int = 1 - max_pages : int = 100 - max_time : int = 60 + url: str + js: bool = False + depth: int = 1 + max_pages: int = 100 + max_time: int = 60 def _crawl(self, url): response = requests.get(url) @@ -24,18 +24,19 @@ class CrawlWebsite(BaseModel): def process(self): content = self._crawl(self.url) - - ## Create a file + + # Create a file file_name = slugify(self.url) + ".html" temp_file_path = os.path.join(tempfile.gettempdir(), file_name) - with open(temp_file_path, 'w') as temp_file: + with open(temp_file_path, "w") as temp_file: temp_file.write(content) - ## Process the file - + # Process the file + if content: return temp_file_path, file_name else: return None + def checkGithub(self): if "github.com" in self.url: return True @@ -43,9 +44,8 @@ class CrawlWebsite(BaseModel): return False - def slugify(text): - text = unicodedata.normalize('NFKD', text).encode('ascii', 'ignore').decode('utf-8') - text = re.sub(r'[^\w\s-]', '', text).strip().lower() - text = re.sub(r'[-\s]+', '-', text) - return text \ No newline at end of file + text = unicodedata.normalize("NFKD", text).encode("ascii", "ignore").decode("utf-8") + text = re.sub(r"[^\w\s-]", "", text).strip().lower() + text = re.sub(r"[-\s]+", "-", text) + return text diff --git a/backend/llm/summarization.py b/backend/llm/summarization.py index a14efecfd..1df98f5df 100644 --- a/backend/llm/summarization.py +++ b/backend/llm/summarization.py @@ -8,11 +8,12 @@ logger = get_logger(__name__) openai_api_key = os.environ.get("OPENAI_API_KEY") openai.api_key = openai_api_key -summary_llm = guidance.llms.OpenAI('gpt-3.5-turbo-0613', caching=False) +summary_llm = guidance.llms.OpenAI("gpt-3.5-turbo-0613", caching=False) def llm_summerize(document): - summary = guidance(""" + summary = guidance( + """ {{#system~}} You are a world best summarizer. \n Condense the text, capturing essential points and core ideas. Include relevant \ @@ -28,21 +29,23 @@ Summarize the following text: {{#assistant~}} {{gen 'summarization' temperature=0.2 max_tokens=100}} {{/assistant~}} -""", llm=summary_llm) +""", + llm=summary_llm, + ) summary = summary(document=document) - logger.info('Summarization: %s', summary) - return summary['summarization'] + logger.info("Summarization: %s", summary) + return summary["summarization"] def llm_evaluate_summaries(question, summaries, model): - if not model.startswith('gpt'): - logger.info( - f'Model {model} not supported. Using gpt-3.5-turbo instead.') - model = 'gpt-3.5-turbo-0613' - logger.info(f'Evaluating summaries with {model}') + if not model.startswith("gpt"): + logger.info(f"Model {model} not supported. Using gpt-3.5-turbo instead.") + model = "gpt-3.5-turbo-0613" + logger.info(f"Evaluating summaries with {model}") evaluation_llm = guidance.llms.OpenAI(model, caching=False) - evaluation = guidance(""" + evaluation = guidance( + """ {{#system~}} You are a world best evaluator. You evaluate the relevance of summaries based \ on user input question. Return evaluation in following csv format, csv headers \ @@ -73,23 +76,30 @@ Summary {{#assistant~}} {{gen 'evaluation' temperature=0.2 stop='<|im_end|>'}} {{/assistant~}} -""", llm=evaluation_llm) +""", + llm=evaluation_llm, + ) result = evaluation(question=question, summaries=summaries) evaluations = {} - for evaluation in result['evaluation'].split('\n'): - if evaluation == '' or not evaluation[0].isdigit(): + for evaluation in result["evaluation"].split("\n"): + if evaluation == "" or not evaluation[0].isdigit(): continue - logger.info('Evaluation Row: %s', evaluation) - summary_id, document_id, score, *reason = evaluation.split(',') + logger.info("Evaluation Row: %s", evaluation) + summary_id, document_id, score, *reason = evaluation.split(",") if not score.isdigit(): continue score = int(score) if score < 3 or score > 5: continue evaluations[summary_id] = { - 'evaluation': score, - 'reason': ','.join(reason), - 'summary_id': summary_id, - 'document_id': document_id, + "evaluation": score, + "reason": ",".join(reason), + "summary_id": summary_id, + "document_id": document_id, } - return [e for e in sorted(evaluations.values(), key=lambda x: x['evaluation'], reverse=True)] + return [ + e + for e in sorted( + evaluations.values(), key=lambda x: x["evaluation"], reverse=True + ) + ] diff --git a/backend/models/brains.py b/backend/models/brains.py index 83046bf46..6b73669ee 100644 --- a/backend/models/brains.py +++ b/backend/models/brains.py @@ -11,12 +11,13 @@ class Brain(BaseModel): model: str = "gpt-3.5-turbo-0613" temperature: float = 0.0 max_tokens: int = 256 - -class BrainToUpdate(BaseModel): + + +class BrainToUpdate(BaseModel): brain_id: UUID brain_name: Optional[str] = "New Brain" status: Optional[str] = "public" model: Optional[str] = "gpt-3.5-turbo-0613" temperature: Optional[float] = 0.0 max_tokens: Optional[int] = 256 - file_sha1: Optional[str] = '' \ No newline at end of file + file_sha1: Optional[str] = "" diff --git a/backend/models/chats.py b/backend/models/chats.py index 0e6ae629e..502afc841 100644 --- a/backend/models/chats.py +++ b/backend/models/chats.py @@ -12,4 +12,9 @@ class ChatMessage(BaseModel): temperature: float = 0.0 max_tokens: int = 256 use_summarization: bool = False - chat_id: Optional[UUID] = None, + chat_id: Optional[UUID] = None + chat_name: Optional[str] = None + + +class ChatAttributes(BaseModel): + chat_name: Optional[str] = None diff --git a/backend/models/users.py b/backend/models/users.py index 6b95efdca..d84f1d4f6 100644 --- a/backend/models/users.py +++ b/backend/models/users.py @@ -1,5 +1,5 @@ from pydantic import BaseModel -class User (BaseModel): +class User(BaseModel): email: str diff --git a/backend/parsers/common.py b/backend/parsers/common.py index 3ccc436b7..268471e08 100644 --- a/backend/parsers/common.py +++ b/backend/parsers/common.py @@ -13,7 +13,15 @@ from utils.file import compute_sha1_from_content, compute_sha1_from_file from utils.vectors import Neurons, create_summary -async def process_file(commons: CommonsDep, file: UploadFile, loader_class, file_suffix, enable_summarization, user, user_openai_api_key): +async def process_file( + commons: CommonsDep, + file: UploadFile, + loader_class, + file_suffix, + enable_summarization, + user, + user_openai_api_key, +): documents = [] file_name = file.filename file_size = file.file._file.tell() # Getting the size of the file @@ -36,7 +44,8 @@ async def process_file(commons: CommonsDep, file: UploadFile, loader_class, file chunk_overlap = 0 text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder( - chunk_size=chunk_size, chunk_overlap=chunk_overlap) + chunk_size=chunk_size, chunk_overlap=chunk_overlap + ) documents = text_splitter.split_documents(documents) @@ -48,17 +57,19 @@ async def process_file(commons: CommonsDep, file: UploadFile, loader_class, file "chunk_size": chunk_size, "chunk_overlap": chunk_overlap, "date": dateshort, - "summarization": "true" if enable_summarization else "false" + "summarization": "true" if enable_summarization else "false", } doc_with_metadata = Document( page_content=doc.page_content, metadata=metadata) neurons = Neurons(commons=commons) neurons.create_vector(user.email, doc_with_metadata, user_openai_api_key) - # add_usage(stats_db, "embedding", "audio", metadata={"file_name": file_meta_name,"file_type": ".txt", "chunk_size": chunk_size, "chunk_overlap": chunk_overlap}) + # add_usage(stats_db, "embedding", "audio", metadata={"file_name": file_meta_name,"file_type": ".txt", "chunk_size": chunk_size, "chunk_overlap": chunk_overlap}) - # Remove the enable_summarization and ids + # Remove the enable_summarization and ids if enable_summarization and ids and len(ids) > 0: - create_summary(commons, document_id=ids[0], content = doc.page_content, metadata = metadata) + create_summary( + commons, document_id=ids[0], content=doc.page_content, metadata=metadata + ) return @@ -66,13 +77,24 @@ async def file_already_exists(supabase, file, user): # TODO: user brain id instead of user file_content = await file.read() file_sha1 = compute_sha1_from_content(file_content) - response = supabase.table("vectors").select("id").filter("metadata->>file_sha1", "eq", file_sha1) \ - .filter("user_id", "eq", user.email).execute() + response = ( + supabase.table("vectors") + .select("id") + .filter("metadata->>file_sha1", "eq", file_sha1) + .filter("user_id", "eq", user.email) + .execute() + ) return len(response.data) > 0 + async def file_already_exists_from_content(supabase, file_content, user): - # TODO: user brain id instead of user + # TODO: user brain id instead of user file_sha1 = compute_sha1_from_content(file_content) - response = supabase.table("vectors").select("id").filter("metadata->>file_sha1", "eq", file_sha1) \ - .filter("user_id", "eq", user.email).execute() - return len(response.data) > 0 \ No newline at end of file + response = ( + supabase.table("vectors") + .select("id") + .filter("metadata->>file_sha1", "eq", file_sha1) + .filter("user_id", "eq", user.email) + .execute() + ) + return len(response.data) > 0 diff --git a/backend/parsers/csv.py b/backend/parsers/csv.py index 5fa2c59ea..5a8807b60 100644 --- a/backend/parsers/csv.py +++ b/backend/parsers/csv.py @@ -5,5 +5,19 @@ from models.settings import CommonsDep from .common import process_file -def process_csv(commons: CommonsDep, file: UploadFile, enable_summarization, user, user_openai_api_key): - return process_file(commons, file, CSVLoader, ".csv", enable_summarization, user, user_openai_api_key) +def process_csv( + commons: CommonsDep, + file: UploadFile, + enable_summarization, + user, + user_openai_api_key, +): + return process_file( + commons, + file, + CSVLoader, + ".csv", + enable_summarization, + user, + user_openai_api_key, + ) diff --git a/backend/requirements.txt b/backend/requirements.txt index dc2211cef..5b8ed2623 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -20,3 +20,5 @@ python-jose==3.3.0 google_cloud_aiplatform==1.25.0 transformers==4.30.1 asyncpg==0.27.0 +flake8==6.0.0 +flake8-black==0.3.6 diff --git a/backend/routes/chat_routes.py b/backend/routes/chat_routes.py index ca1ff388b..19a8f7865 100644 --- a/backend/routes/chat_routes.py +++ b/backend/routes/chat_routes.py @@ -1,36 +1,60 @@ import os import time from uuid import UUID - from auth.auth_bearer import AuthBearer, get_current_user from fastapi import APIRouter, Depends, Request from llm.brainpicking import BrainPicking -from models.chats import ChatMessage +from models.chats import ChatMessage, ChatAttributes from models.settings import CommonsDep, common_dependencies from models.users import User from utils.chats import (create_chat, get_chat_name_from_first_question, update_chat) from utils.users import (create_user, fetch_user_id_from_credentials, update_user_request_count) +from http.client import HTTPException chat_router = APIRouter() + def get_user_chats(commons, user_id): - response = commons['supabase'].from_('chats').select('chatId:chat_id, chatName:chat_name').filter("user_id", "eq", user_id).execute() + response = ( + commons["supabase"] + .from_("chats") + .select("chatId:chat_id, chatName:chat_name") + .filter("user_id", "eq", user_id) + .execute() + ) return response.data + def get_chat_details(commons, chat_id): - response = commons['supabase'].from_('chats').select('*').filter("chat_id", "eq", chat_id).execute() + response = ( + commons["supabase"] + .from_("chats") + .select("*") + .filter("chat_id", "eq", chat_id) + .execute() + ) return response.data + def delete_chat_from_db(commons, chat_id): - commons['supabase'].table("chats").delete().match({"chat_id": chat_id}).execute() + commons["supabase"].table("chats").delete().match({"chat_id": chat_id}).execute() + def fetch_user_stats(commons, user, date): - response = commons['supabase'].from_('users').select('*').filter("email", "eq", user.email).filter("date", "eq", date).execute() + response = ( + commons["supabase"] + .from_("users") + .select("*") + .filter("email", "eq", user.email) + .filter("date", "eq", date) + .execute() + ) userItem = next(iter(response.data or []), {"requests_count": 0}) return userItem + # get all chats @chat_router.get("/chat", dependencies=[Depends(AuthBearer())], tags=["Chat"]) async def get_chats(current_user: User = Depends(get_current_user)): @@ -48,9 +72,10 @@ async def get_chats(current_user: User = Depends(get_current_user)): chats = get_user_chats(commons, user_id) return {"chats": chats} + # get one chat @chat_router.get("/chat/{chat_id}", dependencies=[Depends(AuthBearer())], tags=["Chat"]) -async def get_chats( chat_id: UUID): +async def get_chat_handler(chat_id: UUID): """ Retrieve details of a specific chat by chat ID. @@ -63,13 +88,14 @@ async def get_chats( chat_id: UUID): commons = common_dependencies() chats = get_chat_details(commons, chat_id) if len(chats) > 0: - return {"chatId": chat_id, "history": chats[0]['history']} + return {"chatId": chat_id, "history": chats[0]["history"]} else: return {"error": "Chat not found"} + # delete one chat @chat_router.delete("/chat/{chat_id}", dependencies=[Depends(AuthBearer())], tags=["Chat"]) -async def delete_chat( chat_id: UUID): +async def delete_chat(chat_id: UUID): """ Delete a specific chat by chat ID. """ @@ -77,29 +103,32 @@ async def delete_chat( chat_id: UUID): delete_chat_from_db(commons, chat_id) return {"message": f"{chat_id} has been deleted."} + # helper method for update and create chat def chat_handler(request, commons, chat_id, chat_message, email, is_new_chat=False): date = time.strftime("%Y%m%d") user_id = fetch_user_id_from_credentials(commons, {"email": email}) max_requests_number = os.getenv("MAX_REQUESTS_NUMBER") - user_openai_api_key = request.headers.get('Openai-Api-Key') + user_openai_api_key = request.headers.get("Openai-Api-Key") userItem = fetch_user_stats(commons, User(email=email), date) - old_request_count = userItem['requests_count'] + old_request_count = userItem["requests_count"] history = chat_message.history history.append(("user", chat_message.question)) - - if old_request_count == 0: - create_user(commons, email= email, date=date) - else: - update_user_request_count(commons,email, date, requests_count=old_request_count + 1) - if user_openai_api_key is None and old_request_count >= float(max_requests_number): - history.append(('assistant', "You have reached your requests limit")) - update_chat(commons, chat_id=chat_id, history=history) - return {"history": history} + chat_name = chat_message.chat_name + if old_request_count == 0: + create_user(commons, email=email, date=date) + else: + update_user_request_count( + commons, email, date, requests_count=old_request_count + 1 + ) + if user_openai_api_key is None and old_request_count >= float(max_requests_number): + history.append(("assistant", "You have reached your requests limit")) + update_chat(commons, chat_id=chat_id, history=history, chat_name=chat_name) + return {"history": history} brainPicking = BrainPicking().init(chat_message.model, email) answer = brainPicking.generate_answer(chat_message, user_openai_api_key) @@ -108,24 +137,58 @@ def chat_handler(request, commons, chat_id, chat_message, email, is_new_chat=Fal if is_new_chat: chat_name = get_chat_name_from_first_question(chat_message) new_chat = create_chat(commons, user_id, history, chat_name) - chat_id = new_chat.data[0]['chat_id'] + chat_id = new_chat.data[0]["chat_id"] else: - update_chat(commons, chat_id=chat_id, history=history) + update_chat(commons, chat_id=chat_id, history=history, chat_name=chat_name) return {"history": history, "chatId": chat_id} + # update existing chat @chat_router.put("/chat/{chat_id}", dependencies=[Depends(AuthBearer())], tags=["Chat"]) -async def chat_endpoint(request: Request, commons: CommonsDep, chat_id: UUID, chat_message: ChatMessage, current_user: User = Depends(get_current_user)): +async def chat_endpoint( + request: Request, + commons: CommonsDep, + chat_id: UUID, + chat_message: ChatMessage, + current_user: User = Depends(get_current_user), +): """ Update an existing chat with new chat messages. """ return chat_handler(request, commons, chat_id, chat_message, current_user.email) + +# update existing chat +@chat_router.put("/chat/{chat_id}/metadata", dependencies=[Depends(AuthBearer())], tags=["Chat"]) +async def update_chat_attributes_handler( + commons: CommonsDep, + chat_message: ChatAttributes, + chat_id: UUID, + current_user: User = Depends(get_current_user), +): + """ + Update chat attributes + """ + + user_id = fetch_user_id_from_credentials(commons, {"email": current_user.email}) + chat = get_chat_details(commons, chat_id)[0] + if user_id != chat.get('user_id'): + raise HTTPException(status_code=403, detail="Chat not owned by user") + return update_chat(commons=commons, chat_id=chat_id, chat_name=chat_message.chat_name) + + # create new chat @chat_router.post("/chat", dependencies=[Depends(AuthBearer())], tags=["Chat"]) -async def chat_endpoint(request: Request, commons: CommonsDep, chat_message: ChatMessage, current_user: User = Depends(get_current_user)): +async def create_chat_handler( + request: Request, + commons: CommonsDep, + chat_message: ChatMessage, + current_user: User = Depends(get_current_user), +): """ Create a new chat with initial chat messages. """ - return chat_handler(request, commons, None, chat_message, current_user.email, is_new_chat=True) + return chat_handler( + request, commons, None, chat_message, current_user.email, is_new_chat=True + ) diff --git a/backend/routes/user_routes.py b/backend/routes/user_routes.py index a18334be2..3c36f7c68 100644 --- a/backend/routes/user_routes.py +++ b/backend/routes/user_routes.py @@ -10,10 +10,12 @@ user_router = APIRouter() MAX_BRAIN_SIZE_WITH_OWN_KEY = int(os.getenv("MAX_BRAIN_SIZE_WITH_KEY", 209715200)) + def get_unique_documents(vectors): # Convert each dictionary to a tuple of items, then to a set to remove duplicates, and then back to a dictionary return [dict(t) for t in set(tuple(d.items()) for d in vectors)] + def get_user_vectors(commons, email): # Access the supabase table and get the vectors user_vectors_response = commons['supabase'].table("vectors").select( @@ -22,11 +24,13 @@ def get_user_vectors(commons, email): .execute() return user_vectors_response.data + def get_user_request_stats(commons, email): requests_stats = commons['supabase'].from_('users').select( '*').filter("email", "eq", email).execute() return requests_stats.data + @user_router.get("/user", dependencies=[Depends(AuthBearer())], tags=["User"]) async def get_user_endpoint(request: Request, current_user: User = Depends(get_current_user)): """ diff --git a/backend/utils/chats.py b/backend/utils/chats.py index ec0d390b8..c6d44bd3c 100644 --- a/backend/utils/chats.py +++ b/backend/utils/chats.py @@ -8,31 +8,45 @@ logger = get_logger(__name__) def create_chat(commons: CommonsDep, user_id, history, chat_name): # Chat is created upon the user's first question asked logger.info(f"New chat entry in chats table for user {user_id}") - + # Insert a new row into the chats table new_chat = { "user_id": user_id, - "history": history, # Empty chat to start - "chat_name": chat_name + "history": history, # Empty chat to start + "chat_name": chat_name, } - insert_response = commons['supabase'].table('chats').insert(new_chat).execute() + insert_response = commons["supabase"].table("chats").insert(new_chat).execute() logger.info(f"Insert response {insert_response.data}") - return(insert_response) + return insert_response -def update_chat(commons: CommonsDep, chat_id, history): + +def update_chat(commons: CommonsDep, chat_id, history=None, chat_name=None): if not chat_id: logger.error("No chat_id provided") return - commons['supabase'].table("chats").update( - { "history": history}).match({"chat_id": chat_id}).execute() - logger.info(f"Chat {chat_id} updated") - -def get_chat_name_from_first_question( chat_message: ChatMessage): - # Step 1: Get the summary of the first question + updates = {} + + if history is not None: + updates["history"] = history + + if chat_name is not None: + updates["chat_name"] = chat_name + + if updates: + commons["supabase"].table("chats").update(updates).match( + {"chat_id": chat_id} + ).execute() + logger.info(f"Chat {chat_id} updated") + else: + logger.info(f"No updates to apply for chat {chat_id}") + + +def get_chat_name_from_first_question(chat_message: ChatMessage): + # Step 1: Get the summary of the first question # first_question_summary = summarize_as_title(chat_message.question) # Step 2: Process this summary to create a chat name by selecting the first three words - chat_name = ' '.join(chat_message.question.split()[:3]) + chat_name = " ".join(chat_message.question.split()[:3]) return chat_name diff --git a/frontend/app/chat/[chatId]/page.tsx b/frontend/app/chat/[chatId]/page.tsx index 3344d1caf..3980fd944 100644 --- a/frontend/app/chat/[chatId]/page.tsx +++ b/frontend/app/chat/[chatId]/page.tsx @@ -20,12 +20,11 @@ export default function ChatPage({ params }: ChatPageProps) { const { fetchChat, resetChat } = useChatsContext(); useEffect(() => { - // if (chatId) if (!chatId) { resetChat(); } fetchChat(chatId); - }, [fetchChat, chatId]); + }, []); return (
diff --git a/frontend/app/chat/components/ChatsList/ChatsListItem.tsx b/frontend/app/chat/components/ChatsList/components/ChatsListItem/ChatsListItem.tsx similarity index 58% rename from frontend/app/chat/components/ChatsList/ChatsListItem.tsx rename to frontend/app/chat/components/ChatsList/components/ChatsListItem/ChatsListItem.tsx index aa61a9d50..8810e2b24 100644 --- a/frontend/app/chat/components/ChatsList/ChatsListItem.tsx +++ b/frontend/app/chat/components/ChatsList/components/ChatsListItem/ChatsListItem.tsx @@ -2,24 +2,52 @@ import { UUID } from "crypto"; import Link from "next/link"; import { usePathname } from "next/navigation"; -import { FiTrash2 } from "react-icons/fi"; +import { FiEdit, FiSave, FiTrash2 } from "react-icons/fi"; import { MdChatBubbleOutline } from "react-icons/md"; import { cn } from "@/lib/utils"; -import { Chat } from "../../../../lib/types/Chat"; +import { useToast } from "@/lib/hooks/useToast"; +import { useAxios } from "@/lib/useAxios"; +import { useState } from "react"; +import { Chat, ChatResponse } from "../../../../../../lib/types/Chat"; +import { ChatName } from "./components/ChatName"; interface ChatsListItemProps { chat: Chat; deleteChat: (id: UUID) => void; } -const ChatsListItem = ({ +export const ChatsListItem = ({ chat, deleteChat, }: ChatsListItemProps): JSX.Element => { const pathname = usePathname()?.split("/").at(-1); const selected = chat.chatId === pathname; + const [chatName, setChatName] = useState(chat.chatName); + const { axiosInstance } = useAxios(); + const {publish} = useToast() + const [editingName, setEditingName] = useState(false); + + const updateChatName = async () => { + if(chatName !== chat.chatName) { + await axiosInstance.put(`/chat/${chat.chatId}/metadata`, { + chat_name:chatName, + + }); + publish({text:'Chat name updated',variant:'success'}) + } + } + + const handleEditNameClick = () => { + if(editingName){ + setEditingName(false) ; + void updateChatName() + } + else { + setEditingName(true) + } + } return (
+ - -

{chat.chatName}

+
{chat.chatId}
+ +
{/* Fade to white */} @@ -63,4 +100,4 @@ const ChatsListItem = ({ ); }; -export default ChatsListItem; + diff --git a/frontend/app/chat/components/ChatsList/components/ChatsListItem/components/ChatName.tsx b/frontend/app/chat/components/ChatsList/components/ChatsListItem/components/ChatName.tsx new file mode 100644 index 000000000..9852b8583 --- /dev/null +++ b/frontend/app/chat/components/ChatsList/components/ChatsListItem/components/ChatName.tsx @@ -0,0 +1,18 @@ + + +interface ChatNameProps { + name: string; + editing?: boolean; +setName: (name:string) => void; +} + +export const ChatName = ({setName,name,editing=false}:ChatNameProps):JSX.Element => { + + if(editing) { + return setName(event.target.value)} autoFocus value={name} /> + } + + return ( +

{name}

+ ) +} \ No newline at end of file diff --git a/frontend/app/chat/components/ChatsList/components/ChatsListItem/index.ts b/frontend/app/chat/components/ChatsList/components/ChatsListItem/index.ts new file mode 100644 index 000000000..4f061febe --- /dev/null +++ b/frontend/app/chat/components/ChatsList/components/ChatsListItem/index.ts @@ -0,0 +1 @@ +export * from "./ChatsListItem"; diff --git a/frontend/app/chat/components/ChatsList/index.tsx b/frontend/app/chat/components/ChatsList/index.tsx index 489f60130..0b5c78398 100644 --- a/frontend/app/chat/components/ChatsList/index.tsx +++ b/frontend/app/chat/components/ChatsList/index.tsx @@ -6,8 +6,8 @@ import { MotionConfig, motion } from "framer-motion"; import { useState } from "react"; import { MdChevronRight } from "react-icons/md"; -import ChatsListItem from "./ChatsListItem"; import { NewChatButton } from "./NewChatButton"; +import { ChatsListItem } from "./components/ChatsListItem/"; export const ChatsList = (): JSX.Element => { const { allChats, deleteChat } = useChatsContext(); diff --git a/frontend/app/upload/components/Crawler/index.tsx b/frontend/app/upload/components/Crawler/index.tsx index 71599a946..73c7c391c 100644 --- a/frontend/app/upload/components/Crawler/index.tsx +++ b/frontend/app/upload/components/Crawler/index.tsx @@ -24,7 +24,7 @@ export const Crawler = (): JSX.Element => { />
-
diff --git a/frontend/lib/context/ChatsProvider/hooks/useChats.ts b/frontend/lib/context/ChatsProvider/hooks/useChats.ts index dc5a42372..21e07ab2e 100644 --- a/frontend/lib/context/ChatsProvider/hooks/useChats.ts +++ b/frontend/lib/context/ChatsProvider/hooks/useChats.ts @@ -1,7 +1,7 @@ /* eslint-disable */ import { UUID } from "crypto"; import { useRouter } from "next/navigation"; -import { useCallback, useEffect, useState } from "react"; +import { useEffect, useState } from "react"; import { useBrainConfig } from "@/lib/context/BrainConfigProvider/hooks/useBrainConfig"; import { useToast } from "@/lib/hooks/useToast"; @@ -22,7 +22,7 @@ export default function useChats() { const router = useRouter(); const { publish } = useToast(); - const fetchAllChats = useCallback(async () => { + const fetchAllChats = async () => { try { console.log("Fetching all chats"); const response = await axiosInstance.get<{ @@ -37,29 +37,26 @@ export default function useChats() { text: "Error occured while fetching your chats", }); } - }, [axiosInstance, publish]); + }; - const fetchChat = useCallback( - async (chatId?: UUID) => { - if (!chatId) { - return; - } - try { - console.log(`Fetching chat ${chatId}`); - const response = await axiosInstance.get(`/chat/${chatId}`); - console.log(response.data); + const fetchChat = async (chatId?: UUID) => { + if (!chatId) { + return; + } + try { + console.log(`Fetching chat ${chatId}`); + const response = await axiosInstance.get(`/chat/${chatId}`); + console.log(response.data); - setChat(response.data); - } catch (error) { - console.error(error); - publish({ - variant: "danger", - text: `Error occured while fetching ${chatId}`, - }); - } - }, - [axiosInstance, publish] - ); + setChat(response.data); + } catch (error) { + console.error(error); + publish({ + variant: "danger", + text: `Error occured while fetching ${chatId}`, + }); + } + }; type ChatResponse = Omit & { chatId: UUID | undefined }; @@ -156,7 +153,7 @@ export default function useChats() { useEffect(() => { fetchAllChats(); - }, [fetchAllChats]); + }, []); return { allChats, diff --git a/frontend/lib/types/Chat.ts b/frontend/lib/types/Chat.ts index 3a71646bd..6c79157d6 100644 --- a/frontend/lib/types/Chat.ts +++ b/frontend/lib/types/Chat.ts @@ -8,3 +8,7 @@ export interface Chat { export type ChatMessage = [string, string]; export type ChatHistory = ChatMessage[]; + + export type ChatResponse = Omit & { + chatId: UUID | undefined; + }; diff --git a/frontend/tsconfig.eslint.json b/frontend/tsconfig.eslint.json index 8a268bfe0..99400ba6e 100644 --- a/frontend/tsconfig.eslint.json +++ b/frontend/tsconfig.eslint.json @@ -30,6 +30,9 @@ "name": "next" } ], + "paths": { + "@/*": ["./*"], + } }, "ts-node": { "compilerOptions": {