feat(refacto): removed duplicate lines

Stan Girard 2023-05-12 23:22:21 +02:00
parent 921d7e2502
commit 4174f0693a
6 changed files with 56 additions and 116 deletions

@@ -1,29 +1,34 @@
+import os
 import streamlit as st
 from loaders.audio import process_audio
-from loaders.audio import process_audio
 from loaders.txt import process_txt
 from loaders.csv import process_csv
 from loaders.markdown import process_markdown
 from utils import compute_sha1_from_content
 def file_uploader(supabase, openai_key, vector_store):
-    files = st.file_uploader("Upload a file", accept_multiple_files=True, type=["txt", "csv", "md", "m4a", "mp3", "webm", "mp4", "mpga", "wav", "mpeg"])
+    file_processors = {
+        ".txt": process_txt,
+        ".csv": process_csv,
+        ".md": process_markdown,
+        ".m4a": process_audio,
+        ".mp3": process_audio,
+        ".webm": process_audio,
+        ".mp4": process_audio,
+        ".mpga": process_audio,
+        ".wav": process_audio,
+        ".mpeg": process_audio,
+    }
+
+    files = st.file_uploader("Upload a file", accept_multiple_files=True, type=list(file_processors.keys()))
     if st.button("Add to Database"):
         if files is not None:
             for file in files:
-            for file in files:
                 if file_already_exists(supabase, file):
                     st.write(f"😎 {file.name} is already in the database.")
                 else:
-                    if file.name.endswith(".txt"):
-                        process_txt(vector_store, file)
-                        st.write(f"{file.name} ")
-                    elif file.name.endswith((".m4a", ".mp3", ".webm", ".mp4", ".mpga", ".wav", ".mpeg")):
-                        process_audio(openai_key, vector_store, file)
-                        st.write(f"{file.name} ")
-                    elif file.name.endswith(".csv"):
-                        process_csv(vector_store, file)
-                        st.write(f"{file.name} ")
-                    elif file.name.endswith(".md"):
-                        process_markdown(vector_store, file)
-                    else:
+                    file_extension = os.path.splitext(file.name)[-1]
+                    if file_extension in file_processors:
+                        file_processors[file_extension](vector_store, file)
+                        st.write(f"{file.name} ")
+                    else:
                         st.write(f"{file.name} is not a valid file type.")
@@ -31,7 +36,4 @@ def file_uploader(supabase, openai_key, vector_store):
 def file_already_exists(supabase, file):
     file_sha1 = compute_sha1_from_content(file.getvalue())
     response = supabase.table("documents").select("id").eq("metadata->>file_sha1", file_sha1).execute()
-    if len(response.data) > 0:
-        return True
-    else:
-        return False
+    return len(response.data) > 0
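file_already_exists now returns the length comparison directly instead of branching to return True or False. The deduplication key is a SHA-1 over the raw upload bytes, computed by compute_sha1_from_content from utils; that module is not part of this diff, but assuming a plain hex digest, a sketch:

import hashlib

def compute_sha1_from_content(content: bytes) -> str:
    # Assumed implementation -- utils.py is not shown in this commit.
    return hashlib.sha1(content).hexdigest()

print(compute_sha1_from_content(b"hello"))
# aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d

Hashing the content rather than the filename means the same document uploaded under two different names is still caught, which is what the metadata->>file_sha1 query checks against.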

loaders/common.py Normal file

@@ -0,0 +1,28 @@
+import tempfile
+import streamlit as st
+from utils import compute_sha1_from_file
+from langchain.schema import Document
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+def process_file(vector_store, file, loader_class, file_suffix):
+    documents = []
+    file_sha = ""
+    with tempfile.NamedTemporaryFile(delete=True, suffix=file_suffix) as tmp_file:
+        tmp_file.write(file.getvalue())
+        tmp_file.flush()
+
+        loader = loader_class(tmp_file.name)
+        documents = loader.load()
+
+        file_sha1 = compute_sha1_from_file(tmp_file.name)
+
+        chunk_size = st.session_state['chunk_size']
+        chunk_overlap = st.session_state['chunk_overlap']
+        text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+        documents = text_splitter.split_documents(documents)
+
+        # Add the document sha1 as metadata to each document
+        docs_with_metadata = [Document(page_content=doc.page_content, metadata={"file_sha1": file_sha1}) for doc in documents]
+
+        vector_store.add_documents(docs_with_metadata)
+        return
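Everything that differed between the old per-format copies is now a parameter: the loader class and the temp-file suffix. Supporting another format becomes a one-line delegation; a hypothetical example, assuming UnstructuredHTMLLoader is available in the installed langchain version (nothing below is part of this commit):

from langchain.document_loaders import UnstructuredHTMLLoader
from .common import process_file

def process_html(vector_store, file):
    # Hypothetical loaders/html.py module mirroring the ones below.
    return process_file(vector_store, file, UnstructuredHTMLLoader, ".html")

The new handler would still need a ".html": process_html entry in file_processors to be reachable from the uploader.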

loaders/csv.py

@@ -1,34 +1,5 @@
-import tempfile
-import streamlit as st
+from .common import process_file
 from langchain.document_loaders.csv_loader import CSVLoader
-from langchain.embeddings.openai import OpenAIEmbeddings
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from utils import compute_sha1_from_file
-from langchain.schema import Document
 
 def process_csv(vector_store, file):
-    documents = []
-    file_sha = ""
-    with tempfile.NamedTemporaryFile(delete=True, suffix=".csv") as tmp_file:
-        tmp_file.write(file.getvalue())
-        tmp_file.flush()
-
-        loader = CSVLoader(tmp_file.name)
-        documents = loader.load()
-
-        file_sha1 = compute_sha1_from_file(tmp_file.name)
-
-        chunk_size = st.session_state['chunk_size']
-        chunk_overlap = st.session_state['chunk_overlap']
-        text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
-        documents = text_splitter.split_documents(documents)
-
-        # Add the document sha1 as metadata to each document
-        docs_with_metadata = [Document(page_content=doc.page_content, metadata={"file_sha1": file_sha1}) for doc in documents]
-
-        # We're using the default `documents` table here. You can modify this by passing in a `table_name` argument to the `from_documents` method.
-        vector_store.add_documents(docs_with_metadata)
-        return
+    return process_file(vector_store, file, CSVLoader, ".csv")

loaders/markdown.py

@@ -1,36 +1,5 @@
-import tempfile
-import streamlit as st
+from .common import process_file
 from langchain.document_loaders import UnstructuredMarkdownLoader
-from langchain.embeddings.openai import OpenAIEmbeddings
-from langchain.schema import Document
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from utils import compute_sha1_from_file
 
 def process_markdown(vector_store, file):
-    documents = []
-    file_sha = ""
-    with tempfile.NamedTemporaryFile(delete=True, suffix=".md") as tmp_file:
-        tmp_file.write(file.getvalue())
-        tmp_file.flush()
-
-        loader = UnstructuredMarkdownLoader(tmp_file.name)
-        documents = loader.load()
-
-        file_sha1 = compute_sha1_from_file(tmp_file.name)
-
-        ## Load chunk size and overlap from sidebar
-        chunk_size = st.session_state['chunk_size']
-        chunk_overlap = st.session_state['chunk_overlap']
-        text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
-        documents = text_splitter.split_documents(documents)
-
-        # Add the document sha1 as metadata to each document
-        docs_with_metadata = [Document(page_content=doc.page_content, metadata={"file_sha1": file_sha1}) for doc in documents]
-
-        # We're using the default `documents` table here. You can modify this by passing in a `table_name` argument to the `from_documents` method.
-        vector_store.add_documents(docs_with_metadata)
-        return
+    return process_file(vector_store, file, UnstructuredMarkdownLoader, ".md")

loaders/txt.py

@@ -1,35 +1,5 @@
-import tempfile
-import streamlit as st
+from .common import process_file
 from langchain.document_loaders import TextLoader
-from langchain.embeddings.openai import OpenAIEmbeddings
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.schema import Document
-from utils import compute_sha1_from_file
 
 def process_txt(vector_store, file):
-    documents = []
-    file_sha = ""
-    with tempfile.NamedTemporaryFile(delete=True, suffix=".txt") as tmp_file:
-        tmp_file.write(file.getvalue())
-        tmp_file.flush()
-
-        loader = TextLoader(tmp_file.name)
-        documents = loader.load()
-
-        file_sha1 = compute_sha1_from_file(tmp_file.name)
-
-        ## Load chunk size and overlap from sidebar
-        chunk_size = st.session_state['chunk_size']
-        chunk_overlap = st.session_state['chunk_overlap']
-        text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
-        docs = text_splitter.split_documents(documents)
-
-        # Add the document sha1 as metadata to each document
-        docs_with_metadata = [Document(page_content=doc.page_content, metadata={"file_sha1": file_sha1}) for doc in docs]
-
-        # We're using the default `documents` table here. You can modify this by passing in a `table_name` argument to the `from_documents` method.
-        vector_store.add_documents(docs_with_metadata)
-        return
+    return process_file(vector_store, file, TextLoader, ".txt")
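All three rewrites are mechanical because the loaders involved (TextLoader, CSVLoader, UnstructuredMarkdownLoader) share one shape: a constructor taking a file path and a load() method returning documents. That shared shape is the entire contract behind process_file's loader_class parameter; a stub making it explicit (FakeLoader is invented here, and returns plain strings where the real loaders return langchain Document objects):

class FakeLoader:
    # Minimal stand-in for the LangChain loaders used above.
    def __init__(self, file_path):
        self.file_path = file_path

    def load(self):
        with open(self.file_path, encoding="utf-8") as f:
            return [f.read()]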

@@ -13,7 +13,7 @@ from supabase import Client, create_client
 # supabase_url = "https://fqgpcifsfmamprzldyiv.supabase.co"
 supabase_url = st.secrets.supabase_url
-supabase_key = st.secrets.supabase_key
+supabase_key = st.secrets.supabase_service_key
 openai_api_key = st.secrets.openai_api_key
 
 supabase: Client = create_client(supabase_url, supabase_key)
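The last change swaps the anon key for the Supabase service key, presumably because inserting rows into the documents table server-side calls for the service-role key, which bypasses Row Level Security. Both values resolve through Streamlit's secrets mechanism, which in a local setup reads .streamlit/secrets.toml; a sketch of the expected entries, with placeholder values:

import streamlit as st
from supabase import Client, create_client

# Expected keys in .streamlit/secrets.toml (placeholder values):
#   supabase_url = "https://<project-id>.supabase.co"
#   supabase_service_key = "<service-role key>"
#   openai_api_key = "sk-..."
supabase: Client = create_client(st.secrets.supabase_url, st.secrets.supabase_service_key)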