mirror of
https://github.com/QuivrHQ/quivr.git
synced 2024-12-14 17:03:29 +03:00
This reverts commit c746eb1830.
This commit is contained in:
parent
015f12bb4c
commit
cbc8ac4946
@ -1,5 +1,4 @@
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from models.settings import get_supabase_db
|
||||
@ -14,7 +13,7 @@ async def verify_api_key(
|
||||
# Use UTC time to avoid timezone issues
|
||||
current_date = datetime.utcnow().date()
|
||||
supabase_db = get_supabase_db()
|
||||
result = supabase_db.get_active_api_key(UUID(api_key))
|
||||
result = supabase_db.get_active_api_key(api_key)
|
||||
|
||||
if result.data is not None and len(result.data) > 0:
|
||||
api_key_creation_date = datetime.strptime(
|
||||
@ -37,7 +36,7 @@ async def get_user_from_api_key(
|
||||
supabase_db = get_supabase_db()
|
||||
|
||||
# Lookup the user_id from the api_keys table
|
||||
user_id_data = supabase_db.get_user_id_by_api_key(UUID(api_key))
|
||||
user_id_data = supabase_db.get_user_id_by_api_key(api_key)
|
||||
|
||||
if not user_id_data.data:
|
||||
raise HTTPException(status_code=400, detail="Invalid API key.")
|
||||
|
@ -1,10 +1,8 @@
|
||||
import os
|
||||
|
||||
if __name__ == "__main__":
|
||||
# import needed here when running main.py to debug backend
|
||||
# you will need to run pip install python-dotenv
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
import sentry_sdk
|
||||
from fastapi import FastAPI, HTTPException, Request, status
|
||||
@ -12,8 +10,8 @@ from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from logger import get_logger
|
||||
from middlewares.cors import add_cors_middleware
|
||||
from routes.chat_routes import chat_router
|
||||
from routes.misc_routes import misc_router
|
||||
from routes.chat_routes import chat_router
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@ -29,10 +27,12 @@ app = FastAPI()
|
||||
add_cors_middleware(app)
|
||||
|
||||
|
||||
|
||||
app.include_router(chat_router)
|
||||
app.include_router(misc_router)
|
||||
|
||||
|
||||
|
||||
@app.exception_handler(HTTPException)
|
||||
async def http_exception_handler(_, exc):
|
||||
return JSONResponse(
|
||||
@ -64,5 +64,5 @@ handle_request_validation_error(app)
|
||||
if __name__ == "__main__":
|
||||
# run main.py to debug backend
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=5050)
|
||||
|
||||
|
@ -6,17 +6,18 @@ from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
|
||||
from newspaper import Article
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
class CrawlWebsite(BaseModel):
|
||||
url: str
|
||||
js: bool = False
|
||||
depth: int = int(os.getenv("CRAWL_DEPTH", "1"))
|
||||
depth: int = int(os.getenv("CRAWL_DEPTH","1"))
|
||||
max_pages: int = 100
|
||||
max_time: int = 60
|
||||
|
||||
def _crawl(self, url):
|
||||
try:
|
||||
try:
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
return response.text
|
||||
@ -32,7 +33,7 @@ class CrawlWebsite(BaseModel):
|
||||
article.download()
|
||||
article.parse()
|
||||
except Exception as e:
|
||||
print(f"Error downloading or parsing article: {e}")
|
||||
print(f'Error downloading or parsing article: {e}')
|
||||
return None
|
||||
return article.text
|
||||
|
||||
@ -48,13 +49,13 @@ class CrawlWebsite(BaseModel):
|
||||
if not raw_html:
|
||||
return content
|
||||
|
||||
soup = BeautifulSoup(raw_html, "html.parser")
|
||||
links = [a["href"] for a in soup.find_all("a", href=True)]
|
||||
soup = BeautifulSoup(raw_html, 'html.parser')
|
||||
links = [a['href'] for a in soup.find_all('a', href=True)]
|
||||
for link in links:
|
||||
full_url = urljoin(url, link)
|
||||
# Ensure we're staying on the same domain
|
||||
if self.url in full_url:
|
||||
content += self._process_recursive(full_url, depth - 1, visited_urls)
|
||||
content += self._process_recursive(full_url, depth-1, visited_urls)
|
||||
|
||||
return content
|
||||
|
||||
@ -72,8 +73,7 @@ class CrawlWebsite(BaseModel):
|
||||
return temp_file_path, file_name
|
||||
|
||||
def checkGithub(self):
|
||||
return "github.com" in self.url
|
||||
|
||||
return 'github.com' in self.url
|
||||
|
||||
def slugify(text):
|
||||
text = unicodedata.normalize("NFKD", text).encode("ascii", "ignore").decode("utf-8")
|
||||
|
@ -43,3 +43,5 @@ class OpenAIBrainPicking(QABaseBrainPicking):
|
||||
return OpenAIEmbeddings(
|
||||
openai_api_key=self.openai_api_key
|
||||
) # pyright: ignore reportPrivateUsage=none
|
||||
|
||||
|
||||
|
@ -7,7 +7,6 @@ from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
|
||||
from langchain.chains import ConversationalRetrievalChain, LLMChain
|
||||
from langchain.chains.question_answering import load_qa_chain
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
@ -43,10 +42,9 @@ class QABaseBrainPicking(BaseBrainPicking):
|
||||
Each have the same prompt template, which is defined in the `prompt_template` property.
|
||||
"""
|
||||
|
||||
supabase_client: Client
|
||||
vector_store: CustomSupabaseVectorStore
|
||||
qa: ConversationalRetrievalChain
|
||||
embeddings: Embeddings
|
||||
supabase_client: Client = None
|
||||
vector_store: CustomSupabaseVectorStore = None
|
||||
qa: ConversationalRetrievalChain = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -55,7 +53,7 @@ class QABaseBrainPicking(BaseBrainPicking):
|
||||
chat_id: str,
|
||||
streaming: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
) -> "QABaseBrainPicking":
|
||||
super().__init__(
|
||||
model=model,
|
||||
brain_id=brain_id,
|
||||
|
@ -1,6 +1,6 @@
|
||||
from secrets import token_hex
|
||||
from typing import List
|
||||
from uuid import UUID, uuid4
|
||||
from uuid import uuid4
|
||||
|
||||
from asyncpg.exceptions import UniqueViolationError
|
||||
from auth import AuthBearer, get_current_user
|
||||
@ -79,7 +79,7 @@ async def delete_api_key(key_id: str, current_user: User = Depends(get_current_u
|
||||
|
||||
"""
|
||||
supabase_db = get_supabase_db()
|
||||
supabase_db.delete_api_key(UUID(key_id), current_user.id)
|
||||
supabase_db.delete_api_key(key_id, current_user.id)
|
||||
|
||||
return {"message": "API key deleted."}
|
||||
|
||||
|
@ -7,10 +7,10 @@ from venv import logger
|
||||
from auth import AuthBearer, get_current_user
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from llm.openai import OpenAIBrainPicking
|
||||
from llm.qa_headless import HeadlessQA
|
||||
from models.brain_entity import BrainEntity
|
||||
from llm.openai import OpenAIBrainPicking
|
||||
from models.brains import Brain
|
||||
from models.brain_entity import BrainEntity
|
||||
from models.chat import Chat
|
||||
from models.chats import ChatQuestion
|
||||
from models.databases.supabase.supabase import SupabaseDB
|
||||
@ -60,7 +60,7 @@ def check_user_limit(
|
||||
):
|
||||
if user.user_openai_api_key is None:
|
||||
date = time.strftime("%Y%m%d")
|
||||
max_requests_number = int(os.getenv("MAX_REQUESTS_NUMBER", 1000))
|
||||
max_requests_number = int(os.getenv("MAX_REQUESTS_NUMBER", 1))
|
||||
|
||||
user.increment_user_request_count(date)
|
||||
if int(user.requests_count) >= int(max_requests_number):
|
||||
@ -238,7 +238,7 @@ async def create_stream_question_handler(
|
||||
# Retrieve user's OpenAI API key
|
||||
current_user.user_openai_api_key = request.headers.get("Openai-Api-Key")
|
||||
brain = Brain(id=brain_id)
|
||||
brain_details: BrainEntity | None = None
|
||||
brain_details: BrainEntity = None
|
||||
if not current_user.user_openai_api_key and brain_id:
|
||||
brain_details = get_brain_details(brain_id)
|
||||
if brain_details:
|
||||
@ -268,30 +268,18 @@ async def create_stream_question_handler(
|
||||
if brain_id:
|
||||
gpt_answer_generator = OpenAIBrainPicking(
|
||||
chat_id=str(chat_id),
|
||||
model=(brain_details or chat_question).model
|
||||
if current_user.user_openai_api_key
|
||||
else "gpt-3.5-turbo",
|
||||
max_tokens=(brain_details or chat_question).max_tokens
|
||||
if current_user.user_openai_api_key
|
||||
else 0,
|
||||
temperature=(brain_details or chat_question).temperature
|
||||
if current_user.user_openai_api_key
|
||||
else 256,
|
||||
model=(brain_details or chat_question).model if current_user.user_openai_api_key else "gpt-3.5-turbo",
|
||||
max_tokens=(brain_details or chat_question).max_tokens if current_user.user_openai_api_key else 0,
|
||||
temperature=(brain_details or chat_question).temperature if current_user.user_openai_api_key else 256,
|
||||
brain_id=str(brain_id),
|
||||
user_openai_api_key=current_user.user_openai_api_key, # pyright: ignore reportPrivateUsage=none
|
||||
streaming=True,
|
||||
)
|
||||
else:
|
||||
gpt_answer_generator = HeadlessQA(
|
||||
model=chat_question.model
|
||||
if current_user.user_openai_api_key
|
||||
else "gpt-3.5-turbo",
|
||||
temperature=chat_question.temperature
|
||||
if current_user.user_openai_api_key
|
||||
else 256,
|
||||
max_tokens=chat_question.max_tokens
|
||||
if current_user.user_openai_api_key
|
||||
else 0,
|
||||
model=chat_question.model if current_user.user_openai_api_key else "gpt-3.5-turbo",
|
||||
temperature=chat_question.temperature if current_user.user_openai_api_key else 256,
|
||||
max_tokens=chat_question.max_tokens if current_user.user_openai_api_key else 0,
|
||||
user_openai_api_key=current_user.user_openai_api_key, # pyright: ignore reportPrivateUsage=none
|
||||
chat_id=str(chat_id),
|
||||
streaming=True,
|
||||
|
@ -32,7 +32,9 @@ async def crawl_endpoint(
|
||||
brain = Brain(id=brain_id)
|
||||
|
||||
if request.headers.get("Openai-Api-Key"):
|
||||
brain.max_brain_size = int(os.getenv("MAX_BRAIN_SIZE_WITH_KEY", 209715200))
|
||||
brain.max_brain_size = os.getenv(
|
||||
"MAX_BRAIN_SIZE_WITH_KEY", 209715200
|
||||
) # pyright: ignore reportPrivateUsage=none
|
||||
|
||||
file_size = 1000000
|
||||
remaining_free_space = brain.remaining_brain_size
|
||||
|
@ -79,6 +79,13 @@ def test_upload_explore_and_delete_file_txt(client, api_key):
|
||||
headers={"Authorization": "Bearer " + api_key},
|
||||
)
|
||||
|
||||
# Commenting this test out because it has not been working for a while (under investigation).
|
||||
# However, since all PRs were failing, backend tests were starting to get abandoned, which introduced new bugs.
|
||||
|
||||
"""
|
||||
# Assert that the explore response status code is 200 (HTTP OK)
|
||||
assert explore_response.status_code == 200
|
||||
|
||||
# Delete the file
|
||||
delete_response = client.delete(
|
||||
f"/explore/{file_name}",
|
||||
@ -92,6 +99,7 @@ def test_upload_explore_and_delete_file_txt(client, api_key):
|
||||
# Optionally, you can assert on specific fields in the delete response data
|
||||
delete_response_data = delete_response.json()
|
||||
assert "message" in delete_response_data
|
||||
"""
|
||||
|
||||
|
||||
def test_upload_explore_and_delete_file_pdf(client, api_key):
|
||||
@ -187,6 +195,13 @@ def test_upload_explore_and_delete_file_csv(client, api_key):
|
||||
headers={"Authorization": "Bearer " + api_key},
|
||||
)
|
||||
|
||||
# Commenting this test out because it has not been working for a while (under investigation).
|
||||
# However, since all PRs were failing, backend tests were starting to get abandoned, which introduced new bugs.
|
||||
|
||||
"""
|
||||
# Assert that the explore response status code is 200 (HTTP OK)
|
||||
assert explore_response.status_code == 200
|
||||
|
||||
# Delete the file
|
||||
delete_response = client.delete(
|
||||
f"/explore/{file_name}",
|
||||
@ -200,3 +215,4 @@ def test_upload_explore_and_delete_file_csv(client, api_key):
|
||||
# Optionally, you can assert on specific fields in the delete response data
|
||||
delete_response_data = delete_response.json()
|
||||
assert "message" in delete_response_data
|
||||
"""
|
||||
|
@ -1,6 +1,5 @@
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
from logger import get_logger
|
||||
@ -49,7 +48,7 @@ def process_batch(batch_ids: List[str]):
|
||||
|
||||
try:
|
||||
if len(batch_ids) == 1:
|
||||
return (supabase_db.get_vectors_by_batch(UUID(batch_ids[0]))).data
|
||||
return (supabase_db.get_vectors_by_batch(batch_ids[0])).data
|
||||
else:
|
||||
return (supabase_db.get_vectors_in_batch(batch_ids)).data
|
||||
except Exception as e:
|
||||
|
@ -1,7 +1,7 @@
|
||||
from typing import Any, List
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
from langchain.vectorstores import SupabaseVectorStore
|
||||
from supabase.client import Client
|
||||
|
||||
@ -14,7 +14,7 @@ class CustomSupabaseVectorStore(SupabaseVectorStore):
|
||||
def __init__(
|
||||
self,
|
||||
client: Client,
|
||||
embedding: Embeddings,
|
||||
embedding: OpenAIEmbeddings,
|
||||
table_name: str,
|
||||
brain_id: str = "none",
|
||||
):
|
||||
@ -29,6 +29,7 @@ class CustomSupabaseVectorStore(SupabaseVectorStore):
|
||||
threshold: float = 0.5,
|
||||
**kwargs: Any
|
||||
) -> List[Document]:
|
||||
|
||||
vectors = self._embedding.embed_documents([query])
|
||||
query_embedding = vectors[0]
|
||||
res = self._client.rpc(
|
||||
|
Loading…
Reference in New Issue
Block a user