diff --git a/files.py b/files.py
index 3d4661485..ab999fd62 100644
--- a/files.py
+++ b/files.py
@@ -61,7 +61,7 @@ def filter_file(file, supabase, vector_store):
     return False
 
 def url_uploader(supabase, openai_key, vector_store):
-    url = st.text_area("## Add an url",placeholder="https://www.quivr.app")
+    url = st.text_area("**Add an url**",placeholder="https://www.quivr.app")
     button = st.button("Add the URL to the database")
     if button:
         html = get_html(url)
diff --git a/question.py b/question.py
index c3e13f96d..3aa270499 100644
--- a/question.py
+++ b/question.py
@@ -22,26 +22,53 @@ def count_tokens(question, model):
 
 def chat_with_doc(model, vector_store: SupabaseVectorStore):
+
+    if 'chat_history' not in st.session_state:
+        st.session_state['chat_history'] = []
+
+
     question = st.text_area("## Ask a question")
-    button = st.button("Ask")
-    count_button = st.button("Count Tokens", type='secondary')
+    columns = st.columns(3)
+    with columns[0]:
+        button = st.button("Ask")
+    with columns[1]:
+        count_button = st.button("Count Tokens", type='secondary')
+    with columns[2]:
+        clear_history = st.button("Clear History", type='secondary')
+
+    for speaker, text in st.session_state['chat_history']:
+        st.markdown(f"**{speaker}:** {text}")
+
+    if clear_history:
+        st.session_state['chat_history'] = []
+        st.experimental_rerun()
+
     if button:
+        qa = None
         if model.startswith("gpt"):
             logger.info('Using OpenAI model %s', model)
             qa = ConversationalRetrievalChain.from_llm(
                 OpenAI(
                     model_name=st.session_state['model'], openai_api_key=openai_api_key, temperature=st.session_state['temperature'], max_tokens=st.session_state['max_tokens']), vector_store.as_retriever(), memory=memory, verbose=True)
-            result = qa({"question": question})
-            logger.info('Result: %s', result)
-            st.write(result["answer"])
         elif anthropic_api_key and model.startswith("claude"):
             logger.info('Using Anthropics model %s', model)
             qa = ConversationalRetrievalChain.from_llm(
                 ChatAnthropic(
                     model=st.session_state['model'], anthropic_api_key=anthropic_api_key, temperature=st.session_state['temperature'], max_tokens_to_sample=st.session_state['max_tokens']), vector_store.as_retriever(), memory=memory, verbose=True, max_tokens_limit=102400)
-            result = qa({"question": question})
-            logger.info('Result: %s', result)
-            st.write(result["answer"])
+
+
+        st.session_state['chat_history'].append(("You", question))
+
+        # Generate model's response and add it to chat history
+        model_response = qa({"question": question})
+        logger.info('Result: %s', model_response)
+
+        st.session_state['chat_history'].append(("Quivr", model_response["answer"]))
+
+        # Display chat history
+        for speaker, text in st.session_state['chat_history']:
+            st.empty()
+            st.markdown(f"**{speaker}:** {text}")
 
     if count_button:
         st.write(count_tokens(question, model))
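
For reference, the session-state pattern the `question.py` hunk builds on can be exercised in isolation. Below is a minimal, self-contained Streamlit sketch of the same idiom: history is kept in `st.session_state` as `(speaker, text)` tuples so it survives the top-to-bottom script rerun Streamlit performs on every widget interaction. The `answer()` stub is hypothetical and stands in for the `ConversationalRetrievalChain` call; `st.experimental_rerun()` matches the API used in the diff (later Streamlit releases rename it `st.rerun`).

```python
import streamlit as st

# Chat history lives in session state as (speaker, text) tuples so it
# survives the full script rerun Streamlit performs on each interaction.
if 'chat_history' not in st.session_state:
    st.session_state['chat_history'] = []

def answer(question: str) -> str:
    # Hypothetical stand-in for qa({"question": question})["answer"].
    return f"(echo) {question}"

question = st.text_area("## Ask a question")
ask_col, clear_col = st.columns(2)
with ask_col:
    ask = st.button("Ask")
with clear_col:
    clear_history = st.button("Clear History", type='secondary')

if clear_history:
    st.session_state['chat_history'] = []
    st.experimental_rerun()  # rerun immediately so the cleared history renders

if ask and question:
    st.session_state['chat_history'].append(("You", question))
    st.session_state['chat_history'].append(("Quivr", answer(question)))

# Render the accumulated history exactly once, at the end of the run.
for speaker, text in st.session_state['chat_history']:
    st.markdown(f"**{speaker}:** {text}")
```

Run with `streamlit run app.py`. The sketch renders the history in a single loop after appending, which is one way to keep each message from being drawn twice within the run that produces it.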