quivr/backend/utils.py

import hashlib
import os
from typing import Annotated, List, Tuple

from fastapi import Depends
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.schema import Document
from langchain.vectorstores import SupabaseVectorStore
from pydantic import BaseModel
from supabase import Client, create_client

from llm.summarization import llm_summerize
from logger import get_logger

logger = get_logger(__name__)

openai_api_key = os.environ.get("OPENAI_API_KEY")
anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY")
supabase_url = os.environ.get("SUPABASE_URL")
supabase_key = os.environ.get("SUPABASE_SERVICE_KEY")

# Shared clients and vector stores used across the backend.
embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
supabase_client: Client = create_client(supabase_url, supabase_key)
documents_vector_store = SupabaseVectorStore(
    supabase_client, embeddings, table_name="documents")
summaries_vector_store = SupabaseVectorStore(
    supabase_client, embeddings, table_name="summaries")
def compute_sha1_from_file(file_path):
    # Hash the raw file bytes so uploads can be de-duplicated.
    with open(file_path, "rb") as file:
        file_bytes = file.read()
        readable_hash = compute_sha1_from_content(file_bytes)
    return readable_hash


def compute_sha1_from_content(content):
    readable_hash = hashlib.sha1(content).hexdigest()
    return readable_hash
def common_dependencies():
    return {
        "supabase": supabase_client,
        "embeddings": embeddings,
        "documents_vector_store": documents_vector_store,
        "summaries_vector_store": summaries_vector_store,
    }


# FastAPI dependency that injects the shared clients into route handlers.
CommonsDep = Annotated[dict, Depends(common_dependencies)]
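
# Illustrative sketch: one way a FastAPI route could consume CommonsDep.
# The router, path, and response shape below are assumptions for
# demonstration only.
#
#     from fastapi import APIRouter
#
#     router = APIRouter()
#
#     @router.get("/search")
#     async def search(query: str, commons: CommonsDep):
#         docs = commons["documents_vector_store"].similarity_search(query)
#         return {"results": [d.page_content for d in docs]}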
class ChatMessage(BaseModel):
    model: str = "gpt-3.5-turbo"
    question: str
    # A list of tuples where each tuple is (speaker, text)
    history: List[Tuple[str, str]]
    temperature: float = 0.0
    max_tokens: int = 256
    use_summarization: bool = False
def create_summary(document_id, content, metadata):
    logger.info(f"Summarizing document {content[:100]}")
    summary = llm_summerize(content)
    logger.info(f"Summary: {summary}")
    metadata['document_id'] = document_id
    summary_doc_with_metadata = Document(
        page_content=summary, metadata=metadata)
    sids = summaries_vector_store.add_documents(
        [summary_doc_with_metadata])
    if sids:
        # Link the stored summary row back to its source document.
        supabase_client.table("summaries").update(
            {"document_id": document_id}).match({"id": sids[0]}).execute()
def create_embedding(content):
    return embeddings.embed_query(content)


def similarity_search(query, table='match_summaries', top_k=5, threshold=0.5):
    query_embedding = create_embedding(query)
    summaries = supabase_client.rpc(
        table, {'query_embedding': query_embedding,
                'match_count': top_k, 'match_threshold': threshold}
    ).execute()
    return summaries.data
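
# Illustrative sketch: exercising similarity_search directly. Running this
# assumes live SUPABASE_URL, SUPABASE_SERVICE_KEY, and OPENAI_API_KEY
# credentials plus a "match_summaries" RPC in the database; the query
# string is arbitrary.
if __name__ == "__main__":
    matches = similarity_search("example query", top_k=3)
    logger.info(f"Top matches: {matches}")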