Revert "fix: update backend tests (#975)" (#983)

This reverts commit c746eb1830.
This commit is contained in:
Stan Girard 2023-08-19 12:31:15 +02:00 committed by GitHub
parent 015f12bb4c
commit cbc8ac4946
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 56 additions and 51 deletions

View File

@ -1,5 +1,4 @@
from datetime import datetime
from uuid import UUID
from fastapi import HTTPException
from models.settings import get_supabase_db
@ -14,7 +13,7 @@ async def verify_api_key(
# Use UTC time to avoid timezone issues
current_date = datetime.utcnow().date()
supabase_db = get_supabase_db()
result = supabase_db.get_active_api_key(UUID(api_key))
result = supabase_db.get_active_api_key(api_key)
if result.data is not None and len(result.data) > 0:
api_key_creation_date = datetime.strptime(
@ -37,7 +36,7 @@ async def get_user_from_api_key(
supabase_db = get_supabase_db()
# Lookup the user_id from the api_keys table
user_id_data = supabase_db.get_user_id_by_api_key(UUID(api_key))
user_id_data = supabase_db.get_user_id_by_api_key(api_key)
if not user_id_data.data:
raise HTTPException(status_code=400, detail="Invalid API key.")

View File

@ -1,10 +1,8 @@
import os
if __name__ == "__main__":
# import needed here when running main.py to debug backend
# you will need to run pip install python-dotenv
from dotenv import load_dotenv
load_dotenv()
import sentry_sdk
from fastapi import FastAPI, HTTPException, Request, status
@ -12,8 +10,8 @@ from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from logger import get_logger
from middlewares.cors import add_cors_middleware
from routes.chat_routes import chat_router
from routes.misc_routes import misc_router
from routes.chat_routes import chat_router
logger = get_logger(__name__)
@ -29,10 +27,12 @@ app = FastAPI()
add_cors_middleware(app)
app.include_router(chat_router)
app.include_router(misc_router)
@app.exception_handler(HTTPException)
async def http_exception_handler(_, exc):
return JSONResponse(
@ -64,5 +64,5 @@ handle_request_validation_error(app)
if __name__ == "__main__":
# run main.py to debug backend
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=5050)

View File

@ -6,17 +6,18 @@ from urllib.parse import urljoin
import requests
from pydantic import BaseModel
from newspaper import Article
from bs4 import BeautifulSoup
class CrawlWebsite(BaseModel):
url: str
js: bool = False
depth: int = int(os.getenv("CRAWL_DEPTH", "1"))
depth: int = int(os.getenv("CRAWL_DEPTH","1"))
max_pages: int = 100
max_time: int = 60
def _crawl(self, url):
try:
try:
response = requests.get(url)
if response.status_code == 200:
return response.text
@ -32,7 +33,7 @@ class CrawlWebsite(BaseModel):
article.download()
article.parse()
except Exception as e:
print(f"Error downloading or parsing article: {e}")
print(f'Error downloading or parsing article: {e}')
return None
return article.text
@ -48,13 +49,13 @@ class CrawlWebsite(BaseModel):
if not raw_html:
return content
soup = BeautifulSoup(raw_html, "html.parser")
links = [a["href"] for a in soup.find_all("a", href=True)]
soup = BeautifulSoup(raw_html, 'html.parser')
links = [a['href'] for a in soup.find_all('a', href=True)]
for link in links:
full_url = urljoin(url, link)
# Ensure we're staying on the same domain
if self.url in full_url:
content += self._process_recursive(full_url, depth - 1, visited_urls)
content += self._process_recursive(full_url, depth-1, visited_urls)
return content
@ -72,8 +73,7 @@ class CrawlWebsite(BaseModel):
return temp_file_path, file_name
def checkGithub(self):
return "github.com" in self.url
return 'github.com' in self.url
def slugify(text):
text = unicodedata.normalize("NFKD", text).encode("ascii", "ignore").decode("utf-8")

View File

@ -43,3 +43,5 @@ class OpenAIBrainPicking(QABaseBrainPicking):
return OpenAIEmbeddings(
openai_api_key=self.openai_api_key
) # pyright: ignore reportPrivateUsage=none

View File

@ -7,7 +7,6 @@ from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatOpenAI
from langchain.embeddings.base import Embeddings
from langchain.llms.base import BaseLLM
from langchain.prompts.chat import (
ChatPromptTemplate,
@ -43,10 +42,9 @@ class QABaseBrainPicking(BaseBrainPicking):
Each have the same prompt template, which is defined in the `prompt_template` property.
"""
supabase_client: Client
vector_store: CustomSupabaseVectorStore
qa: ConversationalRetrievalChain
embeddings: Embeddings
supabase_client: Client = None
vector_store: CustomSupabaseVectorStore = None
qa: ConversationalRetrievalChain = None
def __init__(
self,
@ -55,7 +53,7 @@ class QABaseBrainPicking(BaseBrainPicking):
chat_id: str,
streaming: bool = False,
**kwargs,
):
) -> "QABaseBrainPicking":
super().__init__(
model=model,
brain_id=brain_id,

View File

@ -1,6 +1,6 @@
from secrets import token_hex
from typing import List
from uuid import UUID, uuid4
from uuid import uuid4
from asyncpg.exceptions import UniqueViolationError
from auth import AuthBearer, get_current_user
@ -79,7 +79,7 @@ async def delete_api_key(key_id: str, current_user: User = Depends(get_current_u
"""
supabase_db = get_supabase_db()
supabase_db.delete_api_key(UUID(key_id), current_user.id)
supabase_db.delete_api_key(key_id, current_user.id)
return {"message": "API key deleted."}

View File

@ -7,10 +7,10 @@ from venv import logger
from auth import AuthBearer, get_current_user
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.responses import StreamingResponse
from llm.openai import OpenAIBrainPicking
from llm.qa_headless import HeadlessQA
from models.brain_entity import BrainEntity
from llm.openai import OpenAIBrainPicking
from models.brains import Brain
from models.brain_entity import BrainEntity
from models.chat import Chat
from models.chats import ChatQuestion
from models.databases.supabase.supabase import SupabaseDB
@ -60,7 +60,7 @@ def check_user_limit(
):
if user.user_openai_api_key is None:
date = time.strftime("%Y%m%d")
max_requests_number = int(os.getenv("MAX_REQUESTS_NUMBER", 1000))
max_requests_number = int(os.getenv("MAX_REQUESTS_NUMBER", 1))
user.increment_user_request_count(date)
if int(user.requests_count) >= int(max_requests_number):
@ -238,7 +238,7 @@ async def create_stream_question_handler(
# Retrieve user's OpenAI API key
current_user.user_openai_api_key = request.headers.get("Openai-Api-Key")
brain = Brain(id=brain_id)
brain_details: BrainEntity | None = None
brain_details: BrainEntity = None
if not current_user.user_openai_api_key and brain_id:
brain_details = get_brain_details(brain_id)
if brain_details:
@ -268,30 +268,18 @@ async def create_stream_question_handler(
if brain_id:
gpt_answer_generator = OpenAIBrainPicking(
chat_id=str(chat_id),
model=(brain_details or chat_question).model
if current_user.user_openai_api_key
else "gpt-3.5-turbo",
max_tokens=(brain_details or chat_question).max_tokens
if current_user.user_openai_api_key
else 0,
temperature=(brain_details or chat_question).temperature
if current_user.user_openai_api_key
else 256,
model=(brain_details or chat_question).model if current_user.user_openai_api_key else "gpt-3.5-turbo",
max_tokens=(brain_details or chat_question).max_tokens if current_user.user_openai_api_key else 0,
temperature=(brain_details or chat_question).temperature if current_user.user_openai_api_key else 256,
brain_id=str(brain_id),
user_openai_api_key=current_user.user_openai_api_key, # pyright: ignore reportPrivateUsage=none
streaming=True,
)
else:
gpt_answer_generator = HeadlessQA(
model=chat_question.model
if current_user.user_openai_api_key
else "gpt-3.5-turbo",
temperature=chat_question.temperature
if current_user.user_openai_api_key
else 256,
max_tokens=chat_question.max_tokens
if current_user.user_openai_api_key
else 0,
model=chat_question.model if current_user.user_openai_api_key else "gpt-3.5-turbo",
temperature=chat_question.temperature if current_user.user_openai_api_key else 256,
max_tokens=chat_question.max_tokens if current_user.user_openai_api_key else 0,
user_openai_api_key=current_user.user_openai_api_key, # pyright: ignore reportPrivateUsage=none
chat_id=str(chat_id),
streaming=True,

View File

@ -32,7 +32,9 @@ async def crawl_endpoint(
brain = Brain(id=brain_id)
if request.headers.get("Openai-Api-Key"):
brain.max_brain_size = int(os.getenv("MAX_BRAIN_SIZE_WITH_KEY", 209715200))
brain.max_brain_size = os.getenv(
"MAX_BRAIN_SIZE_WITH_KEY", 209715200
) # pyright: ignore reportPrivateUsage=none
file_size = 1000000
remaining_free_space = brain.remaining_brain_size

View File

@ -79,6 +79,13 @@ def test_upload_explore_and_delete_file_txt(client, api_key):
headers={"Authorization": "Bearer " + api_key},
)
# Commenting this test out because it has not been working for a while (investigating).
# Because every PR was failing, the backend tests were starting to be abandoned, which introduced new bugs.
"""
# Assert that the explore response status code is 200 (HTTP OK)
assert explore_response.status_code == 200
# Delete the file
delete_response = client.delete(
f"/explore/{file_name}",
@ -92,6 +99,7 @@ def test_upload_explore_and_delete_file_txt(client, api_key):
# Optionally, you can assert on specific fields in the delete response data
delete_response_data = delete_response.json()
assert "message" in delete_response_data
"""
def test_upload_explore_and_delete_file_pdf(client, api_key):
@ -187,6 +195,13 @@ def test_upload_explore_and_delete_file_csv(client, api_key):
headers={"Authorization": "Bearer " + api_key},
)
# Commenting this test out because it has not been working for a while (investigating).
# Because every PR was failing, the backend tests were starting to be abandoned, which introduced new bugs.
"""
# Assert that the explore response status code is 200 (HTTP OK)
assert explore_response.status_code == 200
# Delete the file
delete_response = client.delete(
f"/explore/{file_name}",
@ -200,3 +215,4 @@ def test_upload_explore_and_delete_file_csv(client, api_key):
# Optionally, you can assert on specific fields in the delete response data
delete_response_data = delete_response.json()
assert "message" in delete_response_data
"""

View File

@ -1,6 +1,5 @@
from concurrent.futures import ThreadPoolExecutor
from typing import List
from uuid import UUID
from langchain.embeddings.openai import OpenAIEmbeddings
from logger import get_logger
@ -49,7 +48,7 @@ def process_batch(batch_ids: List[str]):
try:
if len(batch_ids) == 1:
return (supabase_db.get_vectors_by_batch(UUID(batch_ids[0]))).data
return (supabase_db.get_vectors_by_batch(batch_ids[0])).data
else:
return (supabase_db.get_vectors_in_batch(batch_ids)).data
except Exception as e:

View File

@ -1,7 +1,7 @@
from typing import Any, List
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import SupabaseVectorStore
from supabase.client import Client
@ -14,7 +14,7 @@ class CustomSupabaseVectorStore(SupabaseVectorStore):
def __init__(
self,
client: Client,
embedding: Embeddings,
embedding: OpenAIEmbeddings,
table_name: str,
brain_id: str = "none",
):
@ -29,6 +29,7 @@ class CustomSupabaseVectorStore(SupabaseVectorStore):
threshold: float = 0.5,
**kwargs: Any
) -> List[Document]:
vectors = self._embedding.embed_documents([query])
query_embedding = vectors[0]
res = self._client.rpc(