Mirror of https://github.com/QuivrHQ/quivr.git, synced 2024-12-14 17:03:29 +03:00
feat(refacto): removed duplicate lines

parent 921d7e2502
commit 4174f0693a
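In short, this commit replaces the per-extension if/elif chain in files.py with a file_processors dispatch dict and moves the shared load/split/embed logic into a new loaders/common.py helper, so each loader becomes a one-line delegation. Below is a minimal, runnable sketch of the dispatch pattern only; the stub processors stand in for the real process_txt, process_csv, process_markdown and process_audio functions and are illustrative, not part of the commit.

import os

# Stand-ins for the real processors (process_txt, process_audio, ...).
def process_text_stub(vector_store, file_name):
    print(f"text processor called for {file_name}")

def process_audio_stub(vector_store, file_name):
    print(f"audio processor called for {file_name}")

file_processors = {
    ".txt": process_text_stub,
    ".md": process_text_stub,
    ".mp3": process_audio_stub,
}

def dispatch(vector_store, file_name):
    # Same lookup files.py now performs instead of the old if/elif chain.
    file_extension = os.path.splitext(file_name)[-1]
    if file_extension in file_processors:
        file_processors[file_extension](vector_store, file_name)
    else:
        print(f"{file_name} is not a valid file type.")

dispatch(None, "notes.md")   # text processor called for notes.md
dispatch(None, "talk.mp3")   # audio processor called for talk.mp3
dispatch(None, "image.png")  # image.png is not a valid file type.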
files.py (40 changed lines)
@@ -1,29 +1,34 @@
 import streamlit as st
 from loaders.audio import process_audio
-from loaders.audio import process_audio
 from loaders.txt import process_txt
 from loaders.csv import process_csv
 from loaders.markdown import process_markdown
 from utils import compute_sha1_from_content

 def file_uploader(supabase, openai_key, vector_store):
-    files = st.file_uploader("Upload a file", accept_multiple_files=True, type=["txt", "csv", "md", "m4a", "mp3", "webm", "mp4", "mpga", "wav", "mpeg"])
+    file_processors = {
+        ".txt": process_txt,
+        ".csv": process_csv,
+        ".md": process_markdown,
+        ".m4a": process_audio,
+        ".mp3": process_audio,
+        ".webm": process_audio,
+        ".mp4": process_audio,
+        ".mpga": process_audio,
+        ".wav": process_audio,
+        ".mpeg": process_audio,
+    }
+
+    files = st.file_uploader("Upload a file", accept_multiple_files=True, type=list(file_processors.keys()))
     if st.button("Add to Database"):
         if files is not None:
             for file in files:
-            for file in files:
                 if file_already_exists(supabase, file):
                     st.write(f"😎 {file.name} is already in the database.")
                 else:
-                    if file.name.endswith(".txt"):
-                        process_txt(vector_store, file)
-                        st.write(f"✅ {file.name} ")
-                    elif file.name.endswith((".m4a", ".mp3", ".webm", ".mp4", ".mpga", ".wav", ".mpeg")):
-                        process_audio(openai_key,vector_store, file)
-                        st.write(f"✅ {file.name} ")
-                    elif file.name.endswith(".csv"):
-                        process_csv(vector_store, file)
-                        st.write(f"✅ {file.name} ")
-                    elif file.name.endswith(".md"):
-                        process_markdown(vector_store, file)
-                    else:
+                    file_extension = os.path.splitext(file.name)[-1]
+                    if file_extension in file_processors:
+                        file_processors[file_extension](vector_store, file)
+                        st.write(f"✅ {file.name} ")
+                    else:
                         st.write(f"❌ {file.name} is not a valid file type.")
@@ -31,7 +36,4 @@ def file_uploader(supabase, openai_key, vector_store):
 def file_already_exists(supabase, file):
     file_sha1 = compute_sha1_from_content(file.getvalue())
     response = supabase.table("documents").select("id").eq("metadata->>file_sha1", file_sha1).execute()
-    if len(response.data) > 0:
-        return True
-    else:
-        return False
+    return len(response.data) > 0
loaders/common.py (new file, 28 lines)
@@ -0,0 +1,28 @@
+import tempfile
+from utils import compute_sha1_from_file
+from langchain.schema import Document
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+def process_file(vector_store, file, loader_class, file_suffix):
+    documents = []
+    file_sha = ""
+    with tempfile.NamedTemporaryFile(delete=True, suffix=file_suffix) as tmp_file:
+        tmp_file.write(file.getvalue())
+        tmp_file.flush()
+
+        loader = loader_class(tmp_file.name)
+        documents = loader.load()
+        file_sha1 = compute_sha1_from_file(tmp_file.name)
+
+    chunk_size = st.session_state['chunk_size']
+    chunk_overlap = st.session_state['chunk_overlap']
+
+    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+
+    documents = text_splitter.split_documents(documents)
+
+    # Add the document sha1 as metadata to each document
+    docs_with_metadata = [Document(page_content=doc.page_content, metadata={"file_sha1": file_sha1}) for doc in documents]
+
+    vector_store.add_documents(docs_with_metadata)
+    return
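With process_file available, each loader in the hunks below collapses to a single delegation call. As a purely hypothetical illustration of the same pattern, a new loader could be written as follows; PyPDFLoader and the .pdf suffix are not part of this commit, and the sketch assumes streamlit is imported in common.py, since the committed version reads st.session_state without importing it.

# Hypothetical example, not in this commit: a PDF loader built on the new helper.
from .common import process_file
from langchain.document_loaders import PyPDFLoader

def process_pdf(vector_store, file):
    return process_file(vector_store, file, PyPDFLoader, ".pdf")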
loaders/csv.py
@@ -1,34 +1,5 @@
-import tempfile
-
-import streamlit as st
+from .common import process_file
 from langchain.document_loaders.csv_loader import CSVLoader
-from langchain.embeddings.openai import OpenAIEmbeddings
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from utils import compute_sha1_from_file
-from langchain.schema import Document
-

 def process_csv(vector_store, file):
-    documents = []
-    file_sha = ""
-    with tempfile.NamedTemporaryFile(delete=True, suffix=".csv") as tmp_file:
-        tmp_file.write(file.getvalue())
-        tmp_file.flush()
-
-        loader = CSVLoader(tmp_file.name)
-        documents = loader.load()
-        file_sha1 = compute_sha1_from_file(tmp_file.name)
-
-    chunk_size = st.session_state['chunk_size']
-    chunk_overlap = st.session_state['chunk_overlap']
-
-    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
-
-    documents = text_splitter.split_documents(documents)
-    # Add the document sha1 as metadata to each document
-    docs_with_metadata = [Document(page_content=doc.page_content, metadata={"file_sha1": file_sha1}) for doc in documents]
-
-
-    # We're using the default `documents` table here. You can modify this by passing in a `table_name` argument to the `from_documents` method.
-    vector_store.add_documents(docs_with_metadata)
-    return
+    return process_file(vector_store, file, CSVLoader, ".csv")
loaders/markdown.py
@@ -1,36 +1,5 @@
-import tempfile
-import streamlit as st
+from .common import process_file
 from langchain.document_loaders import UnstructuredMarkdownLoader
-from langchain.embeddings.openai import OpenAIEmbeddings
-from langchain.schema import Document
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from utils import compute_sha1_from_file
-

 def process_markdown(vector_store, file):
-    documents = []
-    file_sha = ""
-
-    with tempfile.NamedTemporaryFile(delete=True, suffix=".md") as tmp_file:
-        tmp_file.write(file.getvalue())
-        tmp_file.flush()
-
-        loader = UnstructuredMarkdownLoader(tmp_file.name)
-
-        documents = loader.load()
-        file_sha1 = compute_sha1_from_file(tmp_file.name)
-
-    ## Load chunk size and overlap from sidebar
-    chunk_size = st.session_state['chunk_size']
-    chunk_overlap = st.session_state['chunk_overlap']
-
-    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
-
-
-    documents = text_splitter.split_documents(documents)
-    # Add the document sha1 as metadata to each document
-    docs_with_metadata = [Document(page_content=doc.page_content, metadata={"file_sha1": file_sha1}) for doc in documents]
-
-    # We're using the default `documents` table here. You can modify this by passing in a `table_name` argument to the `from_documents` method.
-    vector_store.add_documents(docs_with_metadata)
-    return
+    return process_file(vector_store, file, UnstructuredMarkdownLoader, ".md")
loaders/txt.py
@@ -1,35 +1,5 @@
-import tempfile
-import streamlit as st
+from .common import process_file
 from langchain.document_loaders import TextLoader
-from langchain.embeddings.openai import OpenAIEmbeddings
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.schema import Document
-
-

 def process_txt(vector_store, file):
-    documents = []
-    file_sha = ""
-    with tempfile.NamedTemporaryFile(delete=True, suffix=".txt") as tmp_file:
-        tmp_file.write(file.getvalue())
-        tmp_file.flush()
-
-        loader = TextLoader(tmp_file.name)
-        documents = loader.load()
-        file_sha1 = compute_sha1_from_file(tmp_file.name)
-
-    ## Load chunk size and overlap from sidebar
-    chunk_size = st.session_state['chunk_size']
-    chunk_overlap = st.session_state['chunk_overlap']
-
-    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
-
-    docs = text_splitter.split_documents(documents)
-
-    # Add the document sha1 as metadata to each document
-    docs_with_metadata = [Document(page_content=doc.page_content, metadata={"file_sha1": file_sha1}) for doc in docs]
-
-
-    # We're using the default `documents` table here. You can modify this by passing in a `table_name` argument to the `from_documents` method.
-    vector_store.add_documents(docs_with_metadata)
-    return
+    return process_file(vector_store, file, TextLoader, ".txt")
main.py (2 changed lines)
@@ -13,7 +13,7 @@ from supabase import Client, create_client

 # supabase_url = "https://fqgpcifsfmamprzldyiv.supabase.co"
 supabase_url = st.secrets.supabase_url
-supabase_key = st.secrets.supabase_key
+supabase_key = st.secrets.supabase_service_key
 openai_api_key = st.secrets.openai_api_key
 supabase: Client = create_client(supabase_url, supabase_key)