mirror of
https://github.com/QuivrHQ/quivr.git
synced 2024-07-07 11:06:28 +03:00
Feat: chat name edit (#343)
* feat(chat): add name update * chore(linting): add flake8 * feat: add chat name edit
This commit is contained in:
parent
8ed8a2c9ef
commit
e1a740472f
4
.flake8
Normal file
4
.flake8
Normal file
|
@ -0,0 +1,4 @@
|
|||
[flake8]
|
||||
; Minimal configuration for Flake8 to work with Black.
|
||||
max-line-length = 100
|
||||
ignore = E101,E111,E112,E221,E222,E501,E711,E712,W503,W504,F401
|
15
.vscode/settings.json
vendored
15
.vscode/settings.json
vendored
|
@ -1,10 +1,15 @@
|
|||
{
|
||||
"[python]": {
|
||||
"editor.defaultFormatter": "ms-python.autopep8"
|
||||
},
|
||||
"python.formatting.provider": "black",
|
||||
"editor.codeActionsOnSave": {
|
||||
"source.organizeImports": true
|
||||
"source.organizeImports": true,
|
||||
"source.fixAll":true
|
||||
},
|
||||
"python.linting.enabled": true
|
||||
"python.linting.enabled": true,
|
||||
"python.linting.flake8Enabled": true,
|
||||
"editor.formatOnSave": true,
|
||||
"[typescript]": {
|
||||
"editor.defaultFormatter": "esbenp.prettier-vscode",
|
||||
"editor.formatOnSave": true
|
||||
},
|
||||
"editor.formatOnSaveMode": "modifications"
|
||||
}
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
@ -6,13 +5,22 @@ from models.settings import CommonsDep
|
|||
from pydantic import DateError
|
||||
|
||||
|
||||
async def verify_api_key(api_key: str, commons: CommonsDep):
|
||||
async def verify_api_key(api_key: str, commons: CommonsDep):
|
||||
try:
|
||||
# Use UTC time to avoid timezone issues
|
||||
current_date = datetime.utcnow().date()
|
||||
result = commons['supabase'].table('api_keys').select('api_key', 'creation_time').filter('api_key', 'eq', api_key).filter('is_active', 'eq', True).execute()
|
||||
result = (
|
||||
commons["supabase"]
|
||||
.table("api_keys")
|
||||
.select("api_key", "creation_time")
|
||||
.filter("api_key", "eq", api_key)
|
||||
.filter("is_active", "eq", True)
|
||||
.execute()
|
||||
)
|
||||
if result.data is not None and len(result.data) > 0:
|
||||
api_key_creation_date = datetime.strptime(result.data[0]['creation_time'], "%Y-%m-%dT%H:%M:%S").date()
|
||||
api_key_creation_date = datetime.strptime(
|
||||
result.data[0]["creation_time"], "%Y-%m-%dT%H:%M:%S"
|
||||
).date()
|
||||
|
||||
# Check if the API key was created today: Todo remove this check and use deleted_time instead.
|
||||
if api_key_creation_date == current_date:
|
||||
|
@ -20,17 +28,34 @@ async def verify_api_key(api_key: str, commons: CommonsDep):
|
|||
return False
|
||||
except DateError:
|
||||
return False
|
||||
|
||||
|
||||
|
||||
async def get_user_from_api_key(api_key: str, commons: CommonsDep):
|
||||
# Lookup the user_id from the api_keys table
|
||||
user_id_data = commons['supabase'].table('api_keys').select('user_id').filter('api_key', 'eq', api_key).execute()
|
||||
|
||||
if not user_id_data.data:
|
||||
raise HTTPException(status_code=400, detail="Invalid API key.")
|
||||
|
||||
user_id = user_id_data.data[0]['user_id']
|
||||
user_id_data = (
|
||||
commons["supabase"]
|
||||
.table("api_keys")
|
||||
.select("user_id")
|
||||
.filter("api_key", "eq", api_key)
|
||||
.execute()
|
||||
)
|
||||
|
||||
# Lookup the email from the users table. Todo: remove and use user_id for credentials
|
||||
user_email_data = commons['supabase'].table('users').select('email').filter('user_id', 'eq', user_id).execute()
|
||||
if not user_id_data.data:
|
||||
raise HTTPException(status_code=400, detail="Invalid API key.")
|
||||
|
||||
return {'email': user_email_data.data[0]['email']} if user_email_data.data else {'email': None}
|
||||
user_id = user_id_data.data[0]["user_id"]
|
||||
|
||||
# Lookup the email from the users table. Todo: remove and use user_id for credentials
|
||||
user_email_data = (
|
||||
commons["supabase"]
|
||||
.table("users")
|
||||
.select("email")
|
||||
.filter("user_id", "eq", user_id)
|
||||
.execute()
|
||||
)
|
||||
|
||||
return (
|
||||
{"email": user_email_data.data[0]["email"]}
|
||||
if user_email_data.data
|
||||
else {"email": None}
|
||||
)
|
||||
|
|
|
@ -14,7 +14,9 @@ class AuthBearer(HTTPBearer):
|
|||
super().__init__(auto_error=auto_error)
|
||||
|
||||
async def __call__(self, request: Request, commons: CommonsDep):
|
||||
credentials: Optional[HTTPAuthorizationCredentials] = await super().__call__(request)
|
||||
credentials: Optional[HTTPAuthorizationCredentials] = await super().__call__(
|
||||
request
|
||||
)
|
||||
self.check_scheme(credentials)
|
||||
token = credentials.credentials
|
||||
return await self.authenticate(token, commons)
|
||||
|
@ -33,10 +35,13 @@ class AuthBearer(HTTPBearer):
|
|||
elif await verify_api_key(token, commons):
|
||||
return await get_user_from_api_key(token, commons)
|
||||
else:
|
||||
raise HTTPException(status_code=402, detail="Invalid token or expired token.")
|
||||
raise HTTPException(
|
||||
status_code=402, detail="Invalid token or expired token."
|
||||
)
|
||||
|
||||
def get_test_user(self):
|
||||
return {'email': 'test@example.com'} # replace with test user information
|
||||
return {"email": "test@example.com"} # replace with test user information
|
||||
|
||||
|
||||
def get_current_user(credentials: dict = Depends(AuthBearer())) -> User:
|
||||
return User(email=credentials.get('email', 'none'))
|
||||
return User(email=credentials.get("email", "none"))
|
||||
|
|
|
@ -8,6 +8,7 @@ from jose.exceptions import JWTError
|
|||
SECRET_KEY = os.environ.get("JWT_SECRET_KEY")
|
||||
ALGORITHM = "HS256"
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
|
@ -18,19 +19,24 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
|
|||
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def decode_access_token(token: str):
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM], options={"verify_aud": False})
|
||||
payload = jwt.decode(
|
||||
token, SECRET_KEY, algorithms=[ALGORITHM], options={"verify_aud": False}
|
||||
)
|
||||
return payload
|
||||
except JWTError as e:
|
||||
except JWTError:
|
||||
return None
|
||||
|
||||
|
||||
|
||||
def verify_token(token: str):
|
||||
payload = decode_access_token(token)
|
||||
return payload is not None
|
||||
|
||||
|
||||
def get_user_email_from_token(token: str):
|
||||
payload = decode_access_token(token)
|
||||
if payload:
|
||||
return payload.get("email")
|
||||
return "none"
|
||||
return "none"
|
||||
|
|
|
@ -9,11 +9,11 @@ from pydantic import BaseModel
|
|||
|
||||
|
||||
class CrawlWebsite(BaseModel):
|
||||
url : str
|
||||
js : bool = False
|
||||
depth : int = 1
|
||||
max_pages : int = 100
|
||||
max_time : int = 60
|
||||
url: str
|
||||
js: bool = False
|
||||
depth: int = 1
|
||||
max_pages: int = 100
|
||||
max_time: int = 60
|
||||
|
||||
def _crawl(self, url):
|
||||
response = requests.get(url)
|
||||
|
@ -24,18 +24,19 @@ class CrawlWebsite(BaseModel):
|
|||
|
||||
def process(self):
|
||||
content = self._crawl(self.url)
|
||||
|
||||
## Create a file
|
||||
|
||||
# Create a file
|
||||
file_name = slugify(self.url) + ".html"
|
||||
temp_file_path = os.path.join(tempfile.gettempdir(), file_name)
|
||||
with open(temp_file_path, 'w') as temp_file:
|
||||
with open(temp_file_path, "w") as temp_file:
|
||||
temp_file.write(content)
|
||||
## Process the file
|
||||
|
||||
# Process the file
|
||||
|
||||
if content:
|
||||
return temp_file_path, file_name
|
||||
else:
|
||||
return None
|
||||
|
||||
def checkGithub(self):
|
||||
if "github.com" in self.url:
|
||||
return True
|
||||
|
@ -43,9 +44,8 @@ class CrawlWebsite(BaseModel):
|
|||
return False
|
||||
|
||||
|
||||
|
||||
def slugify(text):
|
||||
text = unicodedata.normalize('NFKD', text).encode('ascii', 'ignore').decode('utf-8')
|
||||
text = re.sub(r'[^\w\s-]', '', text).strip().lower()
|
||||
text = re.sub(r'[-\s]+', '-', text)
|
||||
return text
|
||||
text = unicodedata.normalize("NFKD", text).encode("ascii", "ignore").decode("utf-8")
|
||||
text = re.sub(r"[^\w\s-]", "", text).strip().lower()
|
||||
text = re.sub(r"[-\s]+", "-", text)
|
||||
return text
|
||||
|
|
|
@ -8,11 +8,12 @@ logger = get_logger(__name__)
|
|||
|
||||
openai_api_key = os.environ.get("OPENAI_API_KEY")
|
||||
openai.api_key = openai_api_key
|
||||
summary_llm = guidance.llms.OpenAI('gpt-3.5-turbo-0613', caching=False)
|
||||
summary_llm = guidance.llms.OpenAI("gpt-3.5-turbo-0613", caching=False)
|
||||
|
||||
|
||||
def llm_summerize(document):
|
||||
summary = guidance("""
|
||||
summary = guidance(
|
||||
"""
|
||||
{{#system~}}
|
||||
You are a world best summarizer. \n
|
||||
Condense the text, capturing essential points and core ideas. Include relevant \
|
||||
|
@ -28,21 +29,23 @@ Summarize the following text:
|
|||
{{#assistant~}}
|
||||
{{gen 'summarization' temperature=0.2 max_tokens=100}}
|
||||
{{/assistant~}}
|
||||
""", llm=summary_llm)
|
||||
""",
|
||||
llm=summary_llm,
|
||||
)
|
||||
|
||||
summary = summary(document=document)
|
||||
logger.info('Summarization: %s', summary)
|
||||
return summary['summarization']
|
||||
logger.info("Summarization: %s", summary)
|
||||
return summary["summarization"]
|
||||
|
||||
|
||||
def llm_evaluate_summaries(question, summaries, model):
|
||||
if not model.startswith('gpt'):
|
||||
logger.info(
|
||||
f'Model {model} not supported. Using gpt-3.5-turbo instead.')
|
||||
model = 'gpt-3.5-turbo-0613'
|
||||
logger.info(f'Evaluating summaries with {model}')
|
||||
if not model.startswith("gpt"):
|
||||
logger.info(f"Model {model} not supported. Using gpt-3.5-turbo instead.")
|
||||
model = "gpt-3.5-turbo-0613"
|
||||
logger.info(f"Evaluating summaries with {model}")
|
||||
evaluation_llm = guidance.llms.OpenAI(model, caching=False)
|
||||
evaluation = guidance("""
|
||||
evaluation = guidance(
|
||||
"""
|
||||
{{#system~}}
|
||||
You are a world best evaluator. You evaluate the relevance of summaries based \
|
||||
on user input question. Return evaluation in following csv format, csv headers \
|
||||
|
@ -73,23 +76,30 @@ Summary
|
|||
{{#assistant~}}
|
||||
{{gen 'evaluation' temperature=0.2 stop='<|im_end|>'}}
|
||||
{{/assistant~}}
|
||||
""", llm=evaluation_llm)
|
||||
""",
|
||||
llm=evaluation_llm,
|
||||
)
|
||||
result = evaluation(question=question, summaries=summaries)
|
||||
evaluations = {}
|
||||
for evaluation in result['evaluation'].split('\n'):
|
||||
if evaluation == '' or not evaluation[0].isdigit():
|
||||
for evaluation in result["evaluation"].split("\n"):
|
||||
if evaluation == "" or not evaluation[0].isdigit():
|
||||
continue
|
||||
logger.info('Evaluation Row: %s', evaluation)
|
||||
summary_id, document_id, score, *reason = evaluation.split(',')
|
||||
logger.info("Evaluation Row: %s", evaluation)
|
||||
summary_id, document_id, score, *reason = evaluation.split(",")
|
||||
if not score.isdigit():
|
||||
continue
|
||||
score = int(score)
|
||||
if score < 3 or score > 5:
|
||||
continue
|
||||
evaluations[summary_id] = {
|
||||
'evaluation': score,
|
||||
'reason': ','.join(reason),
|
||||
'summary_id': summary_id,
|
||||
'document_id': document_id,
|
||||
"evaluation": score,
|
||||
"reason": ",".join(reason),
|
||||
"summary_id": summary_id,
|
||||
"document_id": document_id,
|
||||
}
|
||||
return [e for e in sorted(evaluations.values(), key=lambda x: x['evaluation'], reverse=True)]
|
||||
return [
|
||||
e
|
||||
for e in sorted(
|
||||
evaluations.values(), key=lambda x: x["evaluation"], reverse=True
|
||||
)
|
||||
]
|
||||
|
|
|
@ -11,12 +11,13 @@ class Brain(BaseModel):
|
|||
model: str = "gpt-3.5-turbo-0613"
|
||||
temperature: float = 0.0
|
||||
max_tokens: int = 256
|
||||
|
||||
class BrainToUpdate(BaseModel):
|
||||
|
||||
|
||||
class BrainToUpdate(BaseModel):
|
||||
brain_id: UUID
|
||||
brain_name: Optional[str] = "New Brain"
|
||||
status: Optional[str] = "public"
|
||||
model: Optional[str] = "gpt-3.5-turbo-0613"
|
||||
temperature: Optional[float] = 0.0
|
||||
max_tokens: Optional[int] = 256
|
||||
file_sha1: Optional[str] = ''
|
||||
file_sha1: Optional[str] = ""
|
||||
|
|
|
@ -12,4 +12,9 @@ class ChatMessage(BaseModel):
|
|||
temperature: float = 0.0
|
||||
max_tokens: int = 256
|
||||
use_summarization: bool = False
|
||||
chat_id: Optional[UUID] = None,
|
||||
chat_id: Optional[UUID] = None
|
||||
chat_name: Optional[str] = None
|
||||
|
||||
|
||||
class ChatAttributes(BaseModel):
|
||||
chat_name: Optional[str] = None
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class User (BaseModel):
|
||||
class User(BaseModel):
|
||||
email: str
|
||||
|
|
|
@ -13,7 +13,15 @@ from utils.file import compute_sha1_from_content, compute_sha1_from_file
|
|||
from utils.vectors import Neurons, create_summary
|
||||
|
||||
|
||||
async def process_file(commons: CommonsDep, file: UploadFile, loader_class, file_suffix, enable_summarization, user, user_openai_api_key):
|
||||
async def process_file(
|
||||
commons: CommonsDep,
|
||||
file: UploadFile,
|
||||
loader_class,
|
||||
file_suffix,
|
||||
enable_summarization,
|
||||
user,
|
||||
user_openai_api_key,
|
||||
):
|
||||
documents = []
|
||||
file_name = file.filename
|
||||
file_size = file.file._file.tell() # Getting the size of the file
|
||||
|
@ -36,7 +44,8 @@ async def process_file(commons: CommonsDep, file: UploadFile, loader_class, file
|
|||
chunk_overlap = 0
|
||||
|
||||
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
|
||||
chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
||||
chunk_size=chunk_size, chunk_overlap=chunk_overlap
|
||||
)
|
||||
|
||||
documents = text_splitter.split_documents(documents)
|
||||
|
||||
|
@ -48,17 +57,19 @@ async def process_file(commons: CommonsDep, file: UploadFile, loader_class, file
|
|||
"chunk_size": chunk_size,
|
||||
"chunk_overlap": chunk_overlap,
|
||||
"date": dateshort,
|
||||
"summarization": "true" if enable_summarization else "false"
|
||||
"summarization": "true" if enable_summarization else "false",
|
||||
}
|
||||
doc_with_metadata = Document(
|
||||
page_content=doc.page_content, metadata=metadata)
|
||||
neurons = Neurons(commons=commons)
|
||||
neurons.create_vector(user.email, doc_with_metadata, user_openai_api_key)
|
||||
# add_usage(stats_db, "embedding", "audio", metadata={"file_name": file_meta_name,"file_type": ".txt", "chunk_size": chunk_size, "chunk_overlap": chunk_overlap})
|
||||
# add_usage(stats_db, "embedding", "audio", metadata={"file_name": file_meta_name,"file_type": ".txt", "chunk_size": chunk_size, "chunk_overlap": chunk_overlap})
|
||||
|
||||
# Remove the enable_summarization and ids
|
||||
# Remove the enable_summarization and ids
|
||||
if enable_summarization and ids and len(ids) > 0:
|
||||
create_summary(commons, document_id=ids[0], content = doc.page_content, metadata = metadata)
|
||||
create_summary(
|
||||
commons, document_id=ids[0], content=doc.page_content, metadata=metadata
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
|
@ -66,13 +77,24 @@ async def file_already_exists(supabase, file, user):
|
|||
# TODO: user brain id instead of user
|
||||
file_content = await file.read()
|
||||
file_sha1 = compute_sha1_from_content(file_content)
|
||||
response = supabase.table("vectors").select("id").filter("metadata->>file_sha1", "eq", file_sha1) \
|
||||
.filter("user_id", "eq", user.email).execute()
|
||||
response = (
|
||||
supabase.table("vectors")
|
||||
.select("id")
|
||||
.filter("metadata->>file_sha1", "eq", file_sha1)
|
||||
.filter("user_id", "eq", user.email)
|
||||
.execute()
|
||||
)
|
||||
return len(response.data) > 0
|
||||
|
||||
|
||||
async def file_already_exists_from_content(supabase, file_content, user):
|
||||
# TODO: user brain id instead of user
|
||||
# TODO: user brain id instead of user
|
||||
file_sha1 = compute_sha1_from_content(file_content)
|
||||
response = supabase.table("vectors").select("id").filter("metadata->>file_sha1", "eq", file_sha1) \
|
||||
.filter("user_id", "eq", user.email).execute()
|
||||
return len(response.data) > 0
|
||||
response = (
|
||||
supabase.table("vectors")
|
||||
.select("id")
|
||||
.filter("metadata->>file_sha1", "eq", file_sha1)
|
||||
.filter("user_id", "eq", user.email)
|
||||
.execute()
|
||||
)
|
||||
return len(response.data) > 0
|
||||
|
|
|
@ -5,5 +5,19 @@ from models.settings import CommonsDep
|
|||
from .common import process_file
|
||||
|
||||
|
||||
def process_csv(commons: CommonsDep, file: UploadFile, enable_summarization, user, user_openai_api_key):
|
||||
return process_file(commons, file, CSVLoader, ".csv", enable_summarization, user, user_openai_api_key)
|
||||
def process_csv(
|
||||
commons: CommonsDep,
|
||||
file: UploadFile,
|
||||
enable_summarization,
|
||||
user,
|
||||
user_openai_api_key,
|
||||
):
|
||||
return process_file(
|
||||
commons,
|
||||
file,
|
||||
CSVLoader,
|
||||
".csv",
|
||||
enable_summarization,
|
||||
user,
|
||||
user_openai_api_key,
|
||||
)
|
||||
|
|
|
@ -20,3 +20,5 @@ python-jose==3.3.0
|
|||
google_cloud_aiplatform==1.25.0
|
||||
transformers==4.30.1
|
||||
asyncpg==0.27.0
|
||||
flake8==6.0.0
|
||||
flake8-black==0.3.6
|
||||
|
|
|
@ -1,36 +1,60 @@
|
|||
import os
|
||||
import time
|
||||
from uuid import UUID
|
||||
|
||||
from auth.auth_bearer import AuthBearer, get_current_user
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from llm.brainpicking import BrainPicking
|
||||
from models.chats import ChatMessage
|
||||
from models.chats import ChatMessage, ChatAttributes
|
||||
from models.settings import CommonsDep, common_dependencies
|
||||
from models.users import User
|
||||
from utils.chats import (create_chat, get_chat_name_from_first_question,
|
||||
update_chat)
|
||||
from utils.users import (create_user, fetch_user_id_from_credentials,
|
||||
update_user_request_count)
|
||||
from http.client import HTTPException
|
||||
|
||||
chat_router = APIRouter()
|
||||
|
||||
|
||||
def get_user_chats(commons, user_id):
|
||||
response = commons['supabase'].from_('chats').select('chatId:chat_id, chatName:chat_name').filter("user_id", "eq", user_id).execute()
|
||||
response = (
|
||||
commons["supabase"]
|
||||
.from_("chats")
|
||||
.select("chatId:chat_id, chatName:chat_name")
|
||||
.filter("user_id", "eq", user_id)
|
||||
.execute()
|
||||
)
|
||||
return response.data
|
||||
|
||||
|
||||
def get_chat_details(commons, chat_id):
|
||||
response = commons['supabase'].from_('chats').select('*').filter("chat_id", "eq", chat_id).execute()
|
||||
response = (
|
||||
commons["supabase"]
|
||||
.from_("chats")
|
||||
.select("*")
|
||||
.filter("chat_id", "eq", chat_id)
|
||||
.execute()
|
||||
)
|
||||
return response.data
|
||||
|
||||
|
||||
def delete_chat_from_db(commons, chat_id):
|
||||
commons['supabase'].table("chats").delete().match({"chat_id": chat_id}).execute()
|
||||
commons["supabase"].table("chats").delete().match({"chat_id": chat_id}).execute()
|
||||
|
||||
|
||||
def fetch_user_stats(commons, user, date):
|
||||
response = commons['supabase'].from_('users').select('*').filter("email", "eq", user.email).filter("date", "eq", date).execute()
|
||||
response = (
|
||||
commons["supabase"]
|
||||
.from_("users")
|
||||
.select("*")
|
||||
.filter("email", "eq", user.email)
|
||||
.filter("date", "eq", date)
|
||||
.execute()
|
||||
)
|
||||
userItem = next(iter(response.data or []), {"requests_count": 0})
|
||||
return userItem
|
||||
|
||||
|
||||
# get all chats
|
||||
@chat_router.get("/chat", dependencies=[Depends(AuthBearer())], tags=["Chat"])
|
||||
async def get_chats(current_user: User = Depends(get_current_user)):
|
||||
|
@ -48,9 +72,10 @@ async def get_chats(current_user: User = Depends(get_current_user)):
|
|||
chats = get_user_chats(commons, user_id)
|
||||
return {"chats": chats}
|
||||
|
||||
|
||||
# get one chat
|
||||
@chat_router.get("/chat/{chat_id}", dependencies=[Depends(AuthBearer())], tags=["Chat"])
|
||||
async def get_chats( chat_id: UUID):
|
||||
async def get_chat_handler(chat_id: UUID):
|
||||
"""
|
||||
Retrieve details of a specific chat by chat ID.
|
||||
|
||||
|
@ -63,13 +88,14 @@ async def get_chats( chat_id: UUID):
|
|||
commons = common_dependencies()
|
||||
chats = get_chat_details(commons, chat_id)
|
||||
if len(chats) > 0:
|
||||
return {"chatId": chat_id, "history": chats[0]['history']}
|
||||
return {"chatId": chat_id, "history": chats[0]["history"]}
|
||||
else:
|
||||
return {"error": "Chat not found"}
|
||||
|
||||
|
||||
# delete one chat
|
||||
@chat_router.delete("/chat/{chat_id}", dependencies=[Depends(AuthBearer())], tags=["Chat"])
|
||||
async def delete_chat( chat_id: UUID):
|
||||
async def delete_chat(chat_id: UUID):
|
||||
"""
|
||||
Delete a specific chat by chat ID.
|
||||
"""
|
||||
|
@ -77,29 +103,32 @@ async def delete_chat( chat_id: UUID):
|
|||
delete_chat_from_db(commons, chat_id)
|
||||
return {"message": f"{chat_id} has been deleted."}
|
||||
|
||||
|
||||
# helper method for update and create chat
|
||||
def chat_handler(request, commons, chat_id, chat_message, email, is_new_chat=False):
|
||||
date = time.strftime("%Y%m%d")
|
||||
user_id = fetch_user_id_from_credentials(commons, {"email": email})
|
||||
max_requests_number = os.getenv("MAX_REQUESTS_NUMBER")
|
||||
user_openai_api_key = request.headers.get('Openai-Api-Key')
|
||||
user_openai_api_key = request.headers.get("Openai-Api-Key")
|
||||
|
||||
userItem = fetch_user_stats(commons, User(email=email), date)
|
||||
old_request_count = userItem['requests_count']
|
||||
old_request_count = userItem["requests_count"]
|
||||
|
||||
history = chat_message.history
|
||||
history.append(("user", chat_message.question))
|
||||
|
||||
|
||||
if old_request_count == 0:
|
||||
create_user(commons, email= email, date=date)
|
||||
else:
|
||||
update_user_request_count(commons,email, date, requests_count=old_request_count + 1)
|
||||
if user_openai_api_key is None and old_request_count >= float(max_requests_number):
|
||||
history.append(('assistant', "You have reached your requests limit"))
|
||||
update_chat(commons, chat_id=chat_id, history=history)
|
||||
return {"history": history}
|
||||
chat_name = chat_message.chat_name
|
||||
|
||||
if old_request_count == 0:
|
||||
create_user(commons, email=email, date=date)
|
||||
else:
|
||||
update_user_request_count(
|
||||
commons, email, date, requests_count=old_request_count + 1
|
||||
)
|
||||
if user_openai_api_key is None and old_request_count >= float(max_requests_number):
|
||||
history.append(("assistant", "You have reached your requests limit"))
|
||||
update_chat(commons, chat_id=chat_id, history=history, chat_name=chat_name)
|
||||
return {"history": history}
|
||||
|
||||
brainPicking = BrainPicking().init(chat_message.model, email)
|
||||
answer = brainPicking.generate_answer(chat_message, user_openai_api_key)
|
||||
|
@ -108,24 +137,58 @@ def chat_handler(request, commons, chat_id, chat_message, email, is_new_chat=Fal
|
|||
if is_new_chat:
|
||||
chat_name = get_chat_name_from_first_question(chat_message)
|
||||
new_chat = create_chat(commons, user_id, history, chat_name)
|
||||
chat_id = new_chat.data[0]['chat_id']
|
||||
chat_id = new_chat.data[0]["chat_id"]
|
||||
else:
|
||||
update_chat(commons, chat_id=chat_id, history=history)
|
||||
update_chat(commons, chat_id=chat_id, history=history, chat_name=chat_name)
|
||||
|
||||
return {"history": history, "chatId": chat_id}
|
||||
|
||||
|
||||
# update existing chat
|
||||
@chat_router.put("/chat/{chat_id}", dependencies=[Depends(AuthBearer())], tags=["Chat"])
|
||||
async def chat_endpoint(request: Request, commons: CommonsDep, chat_id: UUID, chat_message: ChatMessage, current_user: User = Depends(get_current_user)):
|
||||
async def chat_endpoint(
|
||||
request: Request,
|
||||
commons: CommonsDep,
|
||||
chat_id: UUID,
|
||||
chat_message: ChatMessage,
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Update an existing chat with new chat messages.
|
||||
"""
|
||||
return chat_handler(request, commons, chat_id, chat_message, current_user.email)
|
||||
|
||||
|
||||
# update existing chat
|
||||
@chat_router.put("/chat/{chat_id}/metadata", dependencies=[Depends(AuthBearer())], tags=["Chat"])
|
||||
async def update_chat_attributes_handler(
|
||||
commons: CommonsDep,
|
||||
chat_message: ChatAttributes,
|
||||
chat_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Update chat attributes
|
||||
"""
|
||||
|
||||
user_id = fetch_user_id_from_credentials(commons, {"email": current_user.email})
|
||||
chat = get_chat_details(commons, chat_id)[0]
|
||||
if user_id != chat.get('user_id'):
|
||||
raise HTTPException(status_code=403, detail="Chat not owned by user")
|
||||
return update_chat(commons=commons, chat_id=chat_id, chat_name=chat_message.chat_name)
|
||||
|
||||
|
||||
# create new chat
|
||||
@chat_router.post("/chat", dependencies=[Depends(AuthBearer())], tags=["Chat"])
|
||||
async def chat_endpoint(request: Request, commons: CommonsDep, chat_message: ChatMessage, current_user: User = Depends(get_current_user)):
|
||||
async def create_chat_handler(
|
||||
request: Request,
|
||||
commons: CommonsDep,
|
||||
chat_message: ChatMessage,
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Create a new chat with initial chat messages.
|
||||
"""
|
||||
return chat_handler(request, commons, None, chat_message, current_user.email, is_new_chat=True)
|
||||
return chat_handler(
|
||||
request, commons, None, chat_message, current_user.email, is_new_chat=True
|
||||
)
|
||||
|
|
|
@ -10,10 +10,12 @@ user_router = APIRouter()
|
|||
|
||||
MAX_BRAIN_SIZE_WITH_OWN_KEY = int(os.getenv("MAX_BRAIN_SIZE_WITH_KEY", 209715200))
|
||||
|
||||
|
||||
def get_unique_documents(vectors):
|
||||
# Convert each dictionary to a tuple of items, then to a set to remove duplicates, and then back to a dictionary
|
||||
return [dict(t) for t in set(tuple(d.items()) for d in vectors)]
|
||||
|
||||
|
||||
def get_user_vectors(commons, email):
|
||||
# Access the supabase table and get the vectors
|
||||
user_vectors_response = commons['supabase'].table("vectors").select(
|
||||
|
@ -22,11 +24,13 @@ def get_user_vectors(commons, email):
|
|||
.execute()
|
||||
return user_vectors_response.data
|
||||
|
||||
|
||||
def get_user_request_stats(commons, email):
|
||||
requests_stats = commons['supabase'].from_('users').select(
|
||||
'*').filter("email", "eq", email).execute()
|
||||
return requests_stats.data
|
||||
|
||||
|
||||
@user_router.get("/user", dependencies=[Depends(AuthBearer())], tags=["User"])
|
||||
async def get_user_endpoint(request: Request, current_user: User = Depends(get_current_user)):
|
||||
"""
|
||||
|
|
|
@ -8,31 +8,45 @@ logger = get_logger(__name__)
|
|||
def create_chat(commons: CommonsDep, user_id, history, chat_name):
|
||||
# Chat is created upon the user's first question asked
|
||||
logger.info(f"New chat entry in chats table for user {user_id}")
|
||||
|
||||
|
||||
# Insert a new row into the chats table
|
||||
new_chat = {
|
||||
"user_id": user_id,
|
||||
"history": history, # Empty chat to start
|
||||
"chat_name": chat_name
|
||||
"history": history, # Empty chat to start
|
||||
"chat_name": chat_name,
|
||||
}
|
||||
insert_response = commons['supabase'].table('chats').insert(new_chat).execute()
|
||||
insert_response = commons["supabase"].table("chats").insert(new_chat).execute()
|
||||
logger.info(f"Insert response {insert_response.data}")
|
||||
|
||||
return(insert_response)
|
||||
return insert_response
|
||||
|
||||
def update_chat(commons: CommonsDep, chat_id, history):
|
||||
|
||||
def update_chat(commons: CommonsDep, chat_id, history=None, chat_name=None):
|
||||
if not chat_id:
|
||||
logger.error("No chat_id provided")
|
||||
return
|
||||
commons['supabase'].table("chats").update(
|
||||
{ "history": history}).match({"chat_id": chat_id}).execute()
|
||||
logger.info(f"Chat {chat_id} updated")
|
||||
|
||||
|
||||
def get_chat_name_from_first_question( chat_message: ChatMessage):
|
||||
# Step 1: Get the summary of the first question
|
||||
updates = {}
|
||||
|
||||
if history is not None:
|
||||
updates["history"] = history
|
||||
|
||||
if chat_name is not None:
|
||||
updates["chat_name"] = chat_name
|
||||
|
||||
if updates:
|
||||
commons["supabase"].table("chats").update(updates).match(
|
||||
{"chat_id": chat_id}
|
||||
).execute()
|
||||
logger.info(f"Chat {chat_id} updated")
|
||||
else:
|
||||
logger.info(f"No updates to apply for chat {chat_id}")
|
||||
|
||||
|
||||
def get_chat_name_from_first_question(chat_message: ChatMessage):
|
||||
# Step 1: Get the summary of the first question
|
||||
# first_question_summary = summarize_as_title(chat_message.question)
|
||||
# Step 2: Process this summary to create a chat name by selecting the first three words
|
||||
chat_name = ' '.join(chat_message.question.split()[:3])
|
||||
chat_name = " ".join(chat_message.question.split()[:3])
|
||||
|
||||
return chat_name
|
||||
|
|
|
@ -20,12 +20,11 @@ export default function ChatPage({ params }: ChatPageProps) {
|
|||
const { fetchChat, resetChat } = useChatsContext();
|
||||
|
||||
useEffect(() => {
|
||||
// if (chatId)
|
||||
if (!chatId) {
|
||||
resetChat();
|
||||
}
|
||||
fetchChat(chatId);
|
||||
}, [fetchChat, chatId]);
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<main className="flex flex-col w-full pt-10">
|
||||
|
|
|
@ -2,24 +2,52 @@
|
|||
import { UUID } from "crypto";
|
||||
import Link from "next/link";
|
||||
import { usePathname } from "next/navigation";
|
||||
import { FiTrash2 } from "react-icons/fi";
|
||||
import { FiEdit, FiSave, FiTrash2 } from "react-icons/fi";
|
||||
import { MdChatBubbleOutline } from "react-icons/md";
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
import { Chat } from "../../../../lib/types/Chat";
|
||||
import { useToast } from "@/lib/hooks/useToast";
|
||||
import { useAxios } from "@/lib/useAxios";
|
||||
import { useState } from "react";
|
||||
import { Chat, ChatResponse } from "../../../../../../lib/types/Chat";
|
||||
import { ChatName } from "./components/ChatName";
|
||||
|
||||
interface ChatsListItemProps {
|
||||
chat: Chat;
|
||||
deleteChat: (id: UUID) => void;
|
||||
}
|
||||
|
||||
const ChatsListItem = ({
|
||||
export const ChatsListItem = ({
|
||||
chat,
|
||||
deleteChat,
|
||||
}: ChatsListItemProps): JSX.Element => {
|
||||
const pathname = usePathname()?.split("/").at(-1);
|
||||
const selected = chat.chatId === pathname;
|
||||
const [chatName, setChatName] = useState(chat.chatName);
|
||||
const { axiosInstance } = useAxios();
|
||||
const {publish} = useToast()
|
||||
const [editingName, setEditingName] = useState(false);
|
||||
|
||||
const updateChatName = async () => {
|
||||
if(chatName !== chat.chatName) {
|
||||
await axiosInstance.put<ChatResponse>(`/chat/${chat.chatId}/metadata`, {
|
||||
chat_name:chatName,
|
||||
|
||||
});
|
||||
publish({text:'Chat name updated',variant:'success'})
|
||||
}
|
||||
}
|
||||
|
||||
const handleEditNameClick = () => {
|
||||
if(editingName){
|
||||
setEditingName(false) ;
|
||||
void updateChatName()
|
||||
}
|
||||
else {
|
||||
setEditingName(true)
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
|
@ -36,15 +64,23 @@ const ChatsListItem = ({
|
|||
key={chat.chatId}
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
|
||||
<MdChatBubbleOutline className="text-xl" />
|
||||
|
||||
<p className="min-w-0 flex-1 whitespace-nowrap">{chat.chatName}</p>
|
||||
<ChatName setName={setChatName} editing={editingName} name={chatName} />
|
||||
</div>
|
||||
<div className="grid-cols-2 text-xs opacity-50 whitespace-nowrap">
|
||||
{chat.chatId}
|
||||
</div>
|
||||
</Link>
|
||||
<div className="opacity-0 group-hover:opacity-100 flex items-center justify-center hover:text-red-700 bg-gradient-to-l from-white dark:from-black to-transparent z-10 transition-opacity">
|
||||
<button
|
||||
className="p-0"
|
||||
type="button"
|
||||
onClick={handleEditNameClick
|
||||
}
|
||||
>
|
||||
{editingName ? <FiSave/> : <FiEdit />}
|
||||
</button>
|
||||
<button
|
||||
className="p-5"
|
||||
type="button"
|
||||
|
@ -52,6 +88,7 @@ const ChatsListItem = ({
|
|||
>
|
||||
<FiTrash2 />
|
||||
</button>
|
||||
|
||||
</div>
|
||||
|
||||
{/* Fade to white */}
|
||||
|
@ -63,4 +100,4 @@ const ChatsListItem = ({
|
|||
);
|
||||
};
|
||||
|
||||
export default ChatsListItem;
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
|
||||
|
||||
interface ChatNameProps {
|
||||
name: string;
|
||||
editing?: boolean;
|
||||
setName: (name:string) => void;
|
||||
}
|
||||
|
||||
export const ChatName = ({setName,name,editing=false}:ChatNameProps):JSX.Element => {
|
||||
|
||||
if(editing) {
|
||||
return <input onChange={(event) => setName(event.target.value)} autoFocus value={name} />
|
||||
}
|
||||
|
||||
return (
|
||||
<p>{name}</p>
|
||||
)
|
||||
}
|
|
@ -0,0 +1 @@
|
|||
export * from "./ChatsListItem";
|
|
@ -6,8 +6,8 @@ import { MotionConfig, motion } from "framer-motion";
|
|||
import { useState } from "react";
|
||||
import { MdChevronRight } from "react-icons/md";
|
||||
|
||||
import ChatsListItem from "./ChatsListItem";
|
||||
import { NewChatButton } from "./NewChatButton";
|
||||
import { ChatsListItem } from "./components/ChatsListItem/";
|
||||
|
||||
export const ChatsList = (): JSX.Element => {
|
||||
const { allChats, deleteChat } = useChatsContext();
|
||||
|
|
|
@ -24,7 +24,7 @@ export const Crawler = (): JSX.Element => {
|
|||
/>
|
||||
</div>
|
||||
<div className="flex flex-col items-center justify-center gap-5">
|
||||
<Button isLoading={isCrawling} onClick={crawlWebsite}>
|
||||
<Button isLoading={isCrawling} onClick={() => void crawlWebsite()}>
|
||||
Crawl
|
||||
</Button>
|
||||
</div>
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
/* eslint-disable */
|
||||
import { UUID } from "crypto";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
import { useEffect, useState } from "react";
|
||||
|
||||
import { useBrainConfig } from "@/lib/context/BrainConfigProvider/hooks/useBrainConfig";
|
||||
import { useToast } from "@/lib/hooks/useToast";
|
||||
|
@ -22,7 +22,7 @@ export default function useChats() {
|
|||
const router = useRouter();
|
||||
const { publish } = useToast();
|
||||
|
||||
const fetchAllChats = useCallback(async () => {
|
||||
const fetchAllChats = async () => {
|
||||
try {
|
||||
console.log("Fetching all chats");
|
||||
const response = await axiosInstance.get<{
|
||||
|
@ -37,29 +37,26 @@ export default function useChats() {
|
|||
text: "Error occured while fetching your chats",
|
||||
});
|
||||
}
|
||||
}, [axiosInstance, publish]);
|
||||
};
|
||||
|
||||
const fetchChat = useCallback(
|
||||
async (chatId?: UUID) => {
|
||||
if (!chatId) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
console.log(`Fetching chat ${chatId}`);
|
||||
const response = await axiosInstance.get<Chat>(`/chat/${chatId}`);
|
||||
console.log(response.data);
|
||||
const fetchChat = async (chatId?: UUID) => {
|
||||
if (!chatId) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
console.log(`Fetching chat ${chatId}`);
|
||||
const response = await axiosInstance.get<Chat>(`/chat/${chatId}`);
|
||||
console.log(response.data);
|
||||
|
||||
setChat(response.data);
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
publish({
|
||||
variant: "danger",
|
||||
text: `Error occured while fetching ${chatId}`,
|
||||
});
|
||||
}
|
||||
},
|
||||
[axiosInstance, publish]
|
||||
);
|
||||
setChat(response.data);
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
publish({
|
||||
variant: "danger",
|
||||
text: `Error occured while fetching ${chatId}`,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
type ChatResponse = Omit<Chat, "chatId"> & { chatId: UUID | undefined };
|
||||
|
||||
|
@ -156,7 +153,7 @@ export default function useChats() {
|
|||
|
||||
useEffect(() => {
|
||||
fetchAllChats();
|
||||
}, [fetchAllChats]);
|
||||
}, []);
|
||||
|
||||
return {
|
||||
allChats,
|
||||
|
|
|
@ -8,3 +8,7 @@ export interface Chat {
|
|||
export type ChatMessage = [string, string];
|
||||
|
||||
export type ChatHistory = ChatMessage[];
|
||||
|
||||
export type ChatResponse = Omit<Chat, "chatId"> & {
|
||||
chatId: UUID | undefined;
|
||||
};
|
||||
|
|
|
@ -30,6 +30,9 @@
|
|||
"name": "next"
|
||||
}
|
||||
],
|
||||
"paths": {
|
||||
"@/*": ["./*"],
|
||||
}
|
||||
},
|
||||
"ts-node": {
|
||||
"compilerOptions": {
|
||||
|
|
Loading…
Reference in New Issue
Block a user