mirror of
https://github.com/StanGirard/quivr.git
synced 2024-11-30 01:32:52 +03:00
feat(dead-code): removed composite & api (#2902)
# Description Please include a summary of the changes and the related issue. Please also include relevant motivation and context. ## Checklist before requesting a review Please delete options that are not relevant. - [ ] My code follows the style guidelines of this project - [ ] I have performed a self-review of my code - [ ] I have commented hard-to-understand areas - [ ] I have ideally added tests that prove my fix is effective or that my feature works - [ ] New and existing unit tests pass locally with my changes - [ ] Any dependent changes have been merged ## Screenshots (if appropriate):
This commit is contained in:
parent
cf5c56d74c
commit
a2721d3926
@ -14,7 +14,6 @@ from quivr_api.modules.api_key.controller import api_key_router
|
||||
from quivr_api.modules.assistant.controller import assistant_router
|
||||
from quivr_api.modules.brain.controller import brain_router
|
||||
from quivr_api.modules.chat.controller import chat_router
|
||||
from quivr_api.modules.contact_support.controller import contact_router
|
||||
from quivr_api.modules.knowledge.controller import knowledge_router
|
||||
from quivr_api.modules.misc.controller import misc_router
|
||||
from quivr_api.modules.onboarding.controller import onboarding_router
|
||||
@ -86,7 +85,6 @@ app.include_router(api_key_router)
|
||||
app.include_router(subscription_router)
|
||||
app.include_router(prompt_router)
|
||||
app.include_router(knowledge_router)
|
||||
app.include_router(contact_router)
|
||||
|
||||
PROFILING = os.getenv("PROFILING", "false").lower() == "true"
|
||||
|
||||
|
@ -20,6 +20,12 @@ class BrainRateLimiting(BaseSettings):
|
||||
max_brain_per_user: int = 5
|
||||
|
||||
|
||||
class SendEmailSettings(BaseSettings):
|
||||
model_config = SettingsConfigDict(validate_default=False)
|
||||
resend_contact_sales_from: str = "null"
|
||||
resend_contact_sales_to: str = "null"
|
||||
|
||||
|
||||
# The `PostHogSettings` class is used to initialize and interact with the PostHog analytics service.
|
||||
class PostHogSettings(BaseSettings):
|
||||
model_config = SettingsConfigDict(validate_default=False)
|
||||
|
@ -29,7 +29,6 @@ async def list_assistants(
|
||||
summary = summary_inputs()
|
||||
# difference = difference_inputs()
|
||||
# crawler = crawler_inputs()
|
||||
# audio_transcript = audio_transcript_inputs()
|
||||
return [summary]
|
||||
|
||||
|
||||
|
@ -1,88 +0,0 @@
|
||||
import os
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
from openai import OpenAI
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.modules.assistant.dto.outputs import (
|
||||
AssistantOutput,
|
||||
InputFile,
|
||||
Inputs,
|
||||
OutputBrain,
|
||||
OutputEmail,
|
||||
Outputs,
|
||||
)
|
||||
from quivr_api.modules.assistant.ito.ito import ITO
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AudioTranscriptAssistant(ITO):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def process_assistant(self):
|
||||
client = OpenAI()
|
||||
|
||||
logger.info(f"Processing audio file {self.uploadFile.filename}")
|
||||
|
||||
# Extract the original filename and create a temporary file with the same name
|
||||
filename = os.path.basename(self.uploadFile.filename)
|
||||
temp_file = NamedTemporaryFile(delete=False, suffix=filename)
|
||||
|
||||
# Write the uploaded file's data to the temporary file
|
||||
data = await self.uploadFile.read()
|
||||
temp_file.write(data)
|
||||
temp_file.close()
|
||||
|
||||
# Open the temporary file and pass it to the OpenAI API
|
||||
with open(temp_file.name, "rb") as file:
|
||||
transcription = client.audio.transcriptions.create(
|
||||
model="whisper-1", file=file, response_format="text"
|
||||
)
|
||||
logger.info(f"Transcription: {transcription}")
|
||||
|
||||
# Delete the temporary file
|
||||
os.remove(temp_file.name)
|
||||
|
||||
return await self.create_and_upload_processed_file(
|
||||
transcription, self.uploadFile.filename, "Audio Transcript"
|
||||
)
|
||||
|
||||
|
||||
def audio_transcript_inputs():
|
||||
output = AssistantOutput(
|
||||
name="Audio Transcript",
|
||||
description="Transcribes an audio file",
|
||||
tags=["new"],
|
||||
input_description="One audio file to transcribe",
|
||||
output_description="Transcription of the audio file",
|
||||
inputs=Inputs(
|
||||
files=[
|
||||
InputFile(
|
||||
key="audio_file",
|
||||
allowed_extensions=["mp3", "wav", "ogg", "m4a"],
|
||||
required=True,
|
||||
description="The audio file to transcribe",
|
||||
)
|
||||
]
|
||||
),
|
||||
outputs=Outputs(
|
||||
brain=OutputBrain(
|
||||
required=True,
|
||||
description="The brain to which to upload the document",
|
||||
type="uuid",
|
||||
),
|
||||
email=OutputEmail(
|
||||
required=True,
|
||||
description="Send the document by email",
|
||||
type="str",
|
||||
),
|
||||
),
|
||||
)
|
||||
return output
|
@ -10,10 +10,10 @@ from typing import List, Optional
|
||||
from fastapi import UploadFile
|
||||
from pydantic import BaseModel
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.models.settings import SendEmailSettings
|
||||
from quivr_api.modules.assistant.dto.inputs import InputAssistant
|
||||
from quivr_api.modules.assistant.ito.utils.pdf_generator import PDFGenerator, PDFModel
|
||||
from quivr_api.modules.chat.controller.chat.utils import update_user_usage
|
||||
from quivr_api.modules.contact_support.controller.settings import ContactsSettings
|
||||
from quivr_api.modules.upload.controller.upload_routes import upload_file
|
||||
from quivr_api.modules.user.entity.user_identity import UserIdentity
|
||||
from quivr_api.modules.user.service.user_usage import UserUsage
|
||||
@ -80,7 +80,7 @@ class ITO(BaseModel):
|
||||
custom_message: str,
|
||||
brain_id: str = None,
|
||||
):
|
||||
settings = ContactsSettings()
|
||||
settings = SendEmailSettings()
|
||||
file = await self.uploadfile_to_file(file)
|
||||
domain_quivr = os.getenv("QUIVR_DOMAIN", "https://chat.quivr.app/")
|
||||
|
||||
|
@ -1,499 +0,0 @@
|
||||
import json
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import jq
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
from litellm import completion
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.modules.brain.knowledge_brain_qa import KnowledgeBrainQA
|
||||
from quivr_api.modules.brain.qa_interface import QAInterface
|
||||
from quivr_api.modules.brain.service.brain_service import BrainService
|
||||
from quivr_api.modules.brain.service.call_brain_api import call_brain_api
|
||||
from quivr_api.modules.brain.service.get_api_brain_definition_as_json_schema import (
|
||||
get_api_brain_definition_as_json_schema,
|
||||
)
|
||||
from quivr_api.modules.chat.dto.chats import ChatQuestion
|
||||
from quivr_api.modules.chat.dto.inputs import CreateChatHistory
|
||||
from quivr_api.modules.chat.dto.outputs import GetChatHistoryOutput
|
||||
from quivr_api.modules.chat.service.chat_service import ChatService
|
||||
from quivr_api.modules.dependencies import get_service
|
||||
|
||||
brain_service = BrainService()
|
||||
chat_service = get_service(ChatService)()
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class UUIDEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if isinstance(obj, UUID):
|
||||
# if the object is uuid, we simply return the value of uuid
|
||||
return str(obj)
|
||||
return super().default(obj)
|
||||
|
||||
|
||||
class APIBrainQA(KnowledgeBrainQA, QAInterface):
|
||||
user_id: UUID
|
||||
raw: bool = False
|
||||
jq_instructions: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
brain_id: str,
|
||||
chat_id: str,
|
||||
streaming: bool = False,
|
||||
prompt_id: Optional[UUID] = None,
|
||||
raw: bool = False,
|
||||
jq_instructions: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
user_id = kwargs.get("user_id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=400, detail="Cannot find user id")
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
brain_id=brain_id,
|
||||
chat_id=chat_id,
|
||||
streaming=streaming,
|
||||
prompt_id=prompt_id,
|
||||
**kwargs,
|
||||
)
|
||||
self.user_id = user_id
|
||||
self.raw = raw
|
||||
self.jq_instructions = jq_instructions
|
||||
|
||||
def get_api_call_response_as_text(
|
||||
self, method, api_url, params, search_params, secrets
|
||||
) -> str:
|
||||
headers = {}
|
||||
|
||||
api_url_with_search_params = api_url
|
||||
if search_params:
|
||||
api_url_with_search_params += "?"
|
||||
for search_param in search_params:
|
||||
api_url_with_search_params += (
|
||||
f"{search_param}={search_params[search_param]}&"
|
||||
)
|
||||
|
||||
for secret in secrets:
|
||||
headers[secret] = secrets[secret]
|
||||
|
||||
try:
|
||||
if method in ["GET", "DELETE"]:
|
||||
response = requests.request(
|
||||
method,
|
||||
url=api_url_with_search_params,
|
||||
params=params or None,
|
||||
headers=headers or None,
|
||||
)
|
||||
elif method in ["POST", "PUT", "PATCH"]:
|
||||
response = requests.request(
|
||||
method,
|
||||
url=api_url_with_search_params,
|
||||
json=params or None,
|
||||
headers=headers or None,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid method: {method}")
|
||||
|
||||
return response.text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling API: {e}")
|
||||
return None
|
||||
|
||||
def log_steps(self, message: str, type: str):
|
||||
if "api" not in self.metadata:
|
||||
self.metadata["api"] = {}
|
||||
if "steps" not in self.metadata["api"]:
|
||||
self.metadata["api"]["steps"] = []
|
||||
self.metadata["api"]["steps"].append(
|
||||
{
|
||||
"number": len(self.metadata["api"]["steps"]),
|
||||
"type": type,
|
||||
"message": message,
|
||||
}
|
||||
)
|
||||
|
||||
async def make_completion(
|
||||
self,
|
||||
messages,
|
||||
functions,
|
||||
brain_id: UUID,
|
||||
recursive_count=0,
|
||||
should_log_steps=True,
|
||||
) -> str | None:
|
||||
if recursive_count > 5:
|
||||
self.log_steps(
|
||||
"The assistant is having issues and took more than 5 calls to the API. Please try again later or an other instruction.",
|
||||
"error",
|
||||
)
|
||||
return
|
||||
|
||||
if "api" not in self.metadata:
|
||||
self.metadata["api"] = {}
|
||||
if "raw" not in self.metadata["api"]:
|
||||
self.metadata["api"]["raw_enabled"] = self.raw
|
||||
|
||||
response = completion(
|
||||
model=self.model,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
messages=messages,
|
||||
functions=functions,
|
||||
stream=True,
|
||||
function_call="auto",
|
||||
)
|
||||
|
||||
function_call = {
|
||||
"name": None,
|
||||
"arguments": "",
|
||||
}
|
||||
for chunk in response:
|
||||
finish_reason = chunk.choices[0].finish_reason
|
||||
if finish_reason == "stop":
|
||||
self.log_steps("Quivr has finished", "info")
|
||||
break
|
||||
if (
|
||||
"function_call" in chunk.choices[0].delta
|
||||
and chunk.choices[0].delta["function_call"]
|
||||
):
|
||||
if chunk.choices[0].delta["function_call"].name:
|
||||
function_call["name"] = chunk.choices[0].delta["function_call"].name
|
||||
if chunk.choices[0].delta["function_call"].arguments:
|
||||
function_call["arguments"] += (
|
||||
chunk.choices[0].delta["function_call"].arguments
|
||||
)
|
||||
|
||||
elif finish_reason == "function_call":
|
||||
try:
|
||||
arguments = json.loads(function_call["arguments"])
|
||||
|
||||
except Exception:
|
||||
self.log_steps(f"Issues with {arguments}", "error")
|
||||
arguments = {}
|
||||
|
||||
self.log_steps(f"Calling {brain_id} with arguments {arguments}", "info")
|
||||
|
||||
try:
|
||||
api_call_response = call_brain_api(
|
||||
brain_id=brain_id,
|
||||
user_id=self.user_id,
|
||||
arguments=arguments,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info(f"Error while calling API: {e}")
|
||||
api_call_response = f"Error while calling API: {e}"
|
||||
function_name = function_call["name"]
|
||||
self.log_steps("Quivr has called the API", "info")
|
||||
messages.append(
|
||||
{
|
||||
"role": "function",
|
||||
"name": function_call["name"],
|
||||
"content": f"The function {function_name} was called and gave The following answer:(data from function) {api_call_response} (end of data from function). Don't call this function again unless there was an error or extremely necessary and asked specifically by the user. If an error, display it to the user in raw.",
|
||||
}
|
||||
)
|
||||
|
||||
self.metadata["api"]["raw_response"] = json.loads(api_call_response)
|
||||
if self.raw:
|
||||
# Yield the raw response in a format that can then be catched by the generate_stream function
|
||||
response_to_yield = f"````raw_response: {api_call_response}````"
|
||||
|
||||
yield response_to_yield
|
||||
return
|
||||
|
||||
async for value in self.make_completion(
|
||||
messages=messages,
|
||||
functions=functions,
|
||||
brain_id=brain_id,
|
||||
recursive_count=recursive_count + 1,
|
||||
should_log_steps=should_log_steps,
|
||||
):
|
||||
yield value
|
||||
|
||||
else:
|
||||
if (
|
||||
hasattr(chunk.choices[0], "delta")
|
||||
and chunk.choices[0].delta
|
||||
and hasattr(chunk.choices[0].delta, "content")
|
||||
):
|
||||
content = chunk.choices[0].delta.content
|
||||
yield content
|
||||
else: # pragma: no cover
|
||||
yield "**...**"
|
||||
break
|
||||
|
||||
async def generate_stream(
|
||||
self,
|
||||
chat_id: UUID,
|
||||
question: ChatQuestion,
|
||||
save_answer: bool = True,
|
||||
should_log_steps: Optional[bool] = True,
|
||||
):
|
||||
brain = brain_service.get_brain_by_id(self.brain_id)
|
||||
|
||||
if not brain:
|
||||
raise HTTPException(status_code=404, detail="Brain not found")
|
||||
|
||||
prompt_content = "You are a helpful assistant that can access functions to help answer questions. If there are information missing in the question, you can ask follow up questions to get more information to the user. Once all the information is available, you can call the function to get the answer."
|
||||
|
||||
if self.prompt_to_use:
|
||||
prompt_content += self.prompt_to_use.content
|
||||
|
||||
messages = [{"role": "system", "content": prompt_content}]
|
||||
|
||||
history = chat_service.get_chat_history(self.chat_id)
|
||||
|
||||
for message in history:
|
||||
formatted_message = [
|
||||
{"role": "user", "content": message.user_message},
|
||||
{"role": "assistant", "content": message.assistant},
|
||||
]
|
||||
messages.extend(formatted_message)
|
||||
|
||||
messages.append({"role": "user", "content": question.question})
|
||||
|
||||
if save_answer:
|
||||
streamed_chat_history = chat_service.update_chat_history(
|
||||
CreateChatHistory(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": "",
|
||||
"brain_id": self.brain_id,
|
||||
"prompt_id": self.prompt_to_use_id,
|
||||
}
|
||||
)
|
||||
)
|
||||
streamed_chat_history = GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": str(chat_id),
|
||||
"message_id": streamed_chat_history.message_id,
|
||||
"message_time": streamed_chat_history.message_time,
|
||||
"user_message": question.question,
|
||||
"assistant": "",
|
||||
"prompt_title": (
|
||||
self.prompt_to_use.title if self.prompt_to_use else None
|
||||
),
|
||||
"brain_name": brain.name if brain else None,
|
||||
"brain_id": str(self.brain_id),
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
)
|
||||
else:
|
||||
streamed_chat_history = GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": str(chat_id),
|
||||
"message_id": None,
|
||||
"message_time": None,
|
||||
"user_message": question.question,
|
||||
"assistant": "",
|
||||
"prompt_title": (
|
||||
self.prompt_to_use.title if self.prompt_to_use else None
|
||||
),
|
||||
"brain_name": brain.name if brain else None,
|
||||
"brain_id": str(self.brain_id),
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
)
|
||||
response_tokens = []
|
||||
async for value in self.make_completion(
|
||||
messages=messages,
|
||||
functions=[get_api_brain_definition_as_json_schema(brain)],
|
||||
brain_id=self.brain_id,
|
||||
should_log_steps=should_log_steps,
|
||||
):
|
||||
# Look if the value is a raw response
|
||||
if value.startswith("````raw_response:"):
|
||||
raw_value_cleaned = value.replace("````raw_response: ", "").replace(
|
||||
"````", ""
|
||||
)
|
||||
logger.info(f"Raw response: {raw_value_cleaned}")
|
||||
if self.jq_instructions:
|
||||
json_raw_value_cleaned = json.loads(raw_value_cleaned)
|
||||
raw_value_cleaned = (
|
||||
jq.compile(self.jq_instructions)
|
||||
.input_value(json_raw_value_cleaned)
|
||||
.first()
|
||||
)
|
||||
streamed_chat_history.assistant = raw_value_cleaned
|
||||
response_tokens.append(raw_value_cleaned)
|
||||
yield f"data: {json.dumps(streamed_chat_history.dict())}"
|
||||
else:
|
||||
streamed_chat_history.assistant = value
|
||||
response_tokens.append(value)
|
||||
yield f"data: {json.dumps(streamed_chat_history.dict())}"
|
||||
|
||||
if save_answer:
|
||||
chat_service.update_message_by_id(
|
||||
message_id=str(streamed_chat_history.message_id),
|
||||
user_message=question.question,
|
||||
assistant="".join(str(token) for token in response_tokens),
|
||||
metadata=self.metadata,
|
||||
)
|
||||
|
||||
def make_completion_without_streaming(
|
||||
self,
|
||||
messages,
|
||||
functions,
|
||||
brain_id: UUID,
|
||||
recursive_count=0,
|
||||
should_log_steps=False,
|
||||
):
|
||||
if recursive_count > 5:
|
||||
print(
|
||||
"The assistant is having issues and took more than 5 calls to the API. Please try again later or an other instruction."
|
||||
)
|
||||
return
|
||||
|
||||
if should_log_steps:
|
||||
print("🧠<Deciding what to do>🧠")
|
||||
|
||||
response = completion(
|
||||
model=self.model,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
messages=messages,
|
||||
functions=functions,
|
||||
stream=False,
|
||||
function_call="auto",
|
||||
)
|
||||
|
||||
response_message = response.choices[0].message
|
||||
finish_reason = response.choices[0].finish_reason
|
||||
|
||||
if finish_reason == "function_call":
|
||||
function_call = response_message.function_call
|
||||
try:
|
||||
arguments = json.loads(function_call.arguments)
|
||||
|
||||
except Exception:
|
||||
arguments = {}
|
||||
|
||||
if should_log_steps:
|
||||
self.log_steps(f"Calling {brain_id} with arguments {arguments}", "info")
|
||||
|
||||
try:
|
||||
api_call_response = call_brain_api(
|
||||
brain_id=brain_id,
|
||||
user_id=self.user_id,
|
||||
arguments=arguments,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Error while calling API: {e}",
|
||||
)
|
||||
|
||||
function_name = function_call.name
|
||||
messages.append(
|
||||
{
|
||||
"role": "function",
|
||||
"name": function_call.name,
|
||||
"content": f"The function {function_name} was called and gave The following answer:(data from function) {api_call_response} (end of data from function). Don't call this function again unless there was an error or extremely necessary and asked specifically by the user.",
|
||||
}
|
||||
)
|
||||
|
||||
return self.make_completion_without_streaming(
|
||||
messages=messages,
|
||||
functions=functions,
|
||||
brain_id=brain_id,
|
||||
recursive_count=recursive_count + 1,
|
||||
should_log_steps=should_log_steps,
|
||||
)
|
||||
|
||||
if finish_reason == "stop":
|
||||
return response_message
|
||||
|
||||
else:
|
||||
print("Never ending completion")
|
||||
|
||||
def generate_answer(
|
||||
self,
|
||||
chat_id: UUID,
|
||||
question: ChatQuestion,
|
||||
save_answer: bool = True,
|
||||
raw: bool = True,
|
||||
):
|
||||
if not self.brain_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No brain id provided in the question"
|
||||
)
|
||||
|
||||
brain = brain_service.get_brain_by_id(self.brain_id)
|
||||
|
||||
if not brain:
|
||||
raise HTTPException(status_code=404, detail="Brain not found")
|
||||
|
||||
prompt_content = "You are a helpful assistant that can access functions to help answer questions. If there are information missing in the question, you can ask follow up questions to get more information to the user. Once all the information is available, you can call the function to get the answer."
|
||||
|
||||
if self.prompt_to_use:
|
||||
prompt_content += self.prompt_to_use.content
|
||||
|
||||
messages = [{"role": "system", "content": prompt_content}]
|
||||
|
||||
history = chat_service.get_chat_history(self.chat_id)
|
||||
|
||||
for message in history:
|
||||
formatted_message = [
|
||||
{"role": "user", "content": message.user_message},
|
||||
{"role": "assistant", "content": message.assistant},
|
||||
]
|
||||
messages.extend(formatted_message)
|
||||
|
||||
messages.append({"role": "user", "content": question.question})
|
||||
|
||||
response = self.make_completion_without_streaming(
|
||||
messages=messages,
|
||||
functions=[get_api_brain_definition_as_json_schema(brain)],
|
||||
brain_id=self.brain_id,
|
||||
should_log_steps=False,
|
||||
raw=raw,
|
||||
)
|
||||
|
||||
answer = response.content
|
||||
if save_answer:
|
||||
new_chat = chat_service.update_chat_history(
|
||||
CreateChatHistory(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": answer,
|
||||
"brain_id": self.brain_id,
|
||||
"prompt_id": self.prompt_to_use_id,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
return GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": answer,
|
||||
"message_time": new_chat.message_time,
|
||||
"prompt_title": (
|
||||
self.prompt_to_use.title if self.prompt_to_use else None
|
||||
),
|
||||
"brain_name": brain.name if brain else None,
|
||||
"message_id": new_chat.message_id,
|
||||
"metadata": self.metadata,
|
||||
"brain_id": str(self.brain_id),
|
||||
}
|
||||
)
|
||||
return GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": answer,
|
||||
"message_time": "123",
|
||||
"prompt_title": None,
|
||||
"brain_name": brain.name,
|
||||
"message_id": None,
|
||||
"metadata": self.metadata,
|
||||
"brain_id": str(self.brain_id),
|
||||
}
|
||||
)
|
@ -1,592 +0,0 @@
|
||||
import json
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from litellm import completion
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.modules.brain.api_brain_qa import APIBrainQA
|
||||
from quivr_api.modules.brain.entity.brain_entity import BrainEntity, BrainType
|
||||
from quivr_api.modules.brain.knowledge_brain_qa import KnowledgeBrainQA
|
||||
from quivr_api.modules.brain.qa_headless import HeadlessQA
|
||||
from quivr_api.modules.brain.service.brain_service import BrainService
|
||||
from quivr_api.modules.chat.dto.chats import ChatQuestion
|
||||
from quivr_api.modules.chat.dto.inputs import CreateChatHistory
|
||||
from quivr_api.modules.chat.dto.outputs import (
|
||||
BrainCompletionOutput,
|
||||
CompletionMessage,
|
||||
CompletionResponse,
|
||||
GetChatHistoryOutput,
|
||||
)
|
||||
from quivr_api.modules.chat.service.chat_service import ChatService
|
||||
from quivr_api.modules.dependencies import get_service
|
||||
|
||||
brain_service = BrainService()
|
||||
chat_service = get_service(ChatService)()
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def format_brain_to_tool(brain):
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": str(brain.id),
|
||||
"description": brain.description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"question": {
|
||||
"type": "string",
|
||||
"description": "Question to ask the brain",
|
||||
},
|
||||
},
|
||||
"required": ["question"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class CompositeBrainQA(
|
||||
KnowledgeBrainQA,
|
||||
):
|
||||
user_id: UUID
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
brain_id: str,
|
||||
chat_id: str,
|
||||
streaming: bool = False,
|
||||
prompt_id: Optional[UUID] = None,
|
||||
**kwargs,
|
||||
):
|
||||
user_id = kwargs.get("user_id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=400, detail="Cannot find user id")
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
brain_id=brain_id,
|
||||
chat_id=chat_id,
|
||||
streaming=streaming,
|
||||
prompt_id=prompt_id,
|
||||
**kwargs,
|
||||
)
|
||||
self.user_id = user_id
|
||||
|
||||
def get_answer_generator_from_brain_type(self, brain: BrainEntity):
|
||||
if brain.brain_type == BrainType.composite:
|
||||
return self.generate_answer
|
||||
elif brain.brain_type == BrainType.api:
|
||||
return APIBrainQA(
|
||||
brain_id=str(brain.id),
|
||||
chat_id=self.chat_id,
|
||||
model=self.model,
|
||||
max_tokens=self.max_tokens,
|
||||
temperature=self.temperature,
|
||||
streaming=self.streaming,
|
||||
prompt_id=self.prompt_id,
|
||||
user_id=str(self.user_id),
|
||||
raw=brain.raw,
|
||||
jq_instructions=brain.jq_instructions,
|
||||
).generate_answer
|
||||
elif brain.brain_type == BrainType.doc:
|
||||
return KnowledgeBrainQA(
|
||||
brain_id=str(brain.id),
|
||||
chat_id=self.chat_id,
|
||||
max_tokens=self.max_tokens,
|
||||
temperature=self.temperature,
|
||||
streaming=self.streaming,
|
||||
prompt_id=self.prompt_id,
|
||||
).generate_answer
|
||||
|
||||
def generate_answer(
|
||||
self, chat_id: UUID, question: ChatQuestion, save_answer: bool
|
||||
) -> str:
|
||||
brain = brain_service.get_brain_by_id(question.brain_id)
|
||||
|
||||
connected_brains = brain_service.get_connected_brains(self.brain_id)
|
||||
|
||||
if not connected_brains:
|
||||
response = HeadlessQA(
|
||||
chat_id=chat_id,
|
||||
model=self.model,
|
||||
max_tokens=self.max_tokens,
|
||||
temperature=self.temperature,
|
||||
streaming=self.streaming,
|
||||
prompt_id=self.prompt_id,
|
||||
).generate_answer(chat_id, question, save_answer=False)
|
||||
if save_answer:
|
||||
new_chat = chat_service.update_chat_history(
|
||||
CreateChatHistory(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": response.assistant,
|
||||
"brain_id": question.brain_id,
|
||||
"prompt_id": self.prompt_to_use_id,
|
||||
}
|
||||
)
|
||||
)
|
||||
return GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": response.assistant,
|
||||
"message_time": new_chat.message_time,
|
||||
"prompt_title": (
|
||||
self.prompt_to_use.title if self.prompt_to_use else None
|
||||
),
|
||||
"brain_name": brain.name,
|
||||
"message_id": new_chat.message_id,
|
||||
"brain_id": str(brain.id),
|
||||
}
|
||||
)
|
||||
return GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": response.assistant,
|
||||
"message_time": None,
|
||||
"prompt_title": (
|
||||
self.prompt_to_use.title if self.prompt_to_use else None
|
||||
),
|
||||
"brain_name": brain.name,
|
||||
"message_id": None,
|
||||
"brain_id": str(brain.id),
|
||||
}
|
||||
)
|
||||
|
||||
tools = []
|
||||
available_functions = {}
|
||||
|
||||
connected_brains_details = {}
|
||||
for connected_brain_id in connected_brains:
|
||||
connected_brain = brain_service.get_brain_by_id(connected_brain_id)
|
||||
if connected_brain is None:
|
||||
continue
|
||||
|
||||
tools.append(format_brain_to_tool(connected_brain))
|
||||
|
||||
available_functions[connected_brain_id] = (
|
||||
self.get_answer_generator_from_brain_type(connected_brain)
|
||||
)
|
||||
|
||||
connected_brains_details[str(connected_brain.id)] = connected_brain
|
||||
|
||||
CHOOSE_BRAIN_FROM_TOOLS_PROMPT = (
|
||||
"Based on the provided user content, find the most appropriate tools to answer"
|
||||
+ "If you can't find any tool to answer and only then, and if you can answer without using any tool. In that case, let the user know that you are not using any particular brain (i.e tool) "
|
||||
)
|
||||
|
||||
messages = [{"role": "system", "content": CHOOSE_BRAIN_FROM_TOOLS_PROMPT}]
|
||||
|
||||
history = chat_service.get_chat_history(self.chat_id)
|
||||
|
||||
for message in history:
|
||||
formatted_message = [
|
||||
{"role": "user", "content": message.user_message},
|
||||
{"role": "assistant", "content": message.assistant},
|
||||
]
|
||||
messages.extend(formatted_message)
|
||||
|
||||
messages.append({"role": "user", "content": question.question})
|
||||
|
||||
response = completion(
|
||||
model="gpt-3.5-turbo-0125",
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
)
|
||||
|
||||
brain_completion_output = self.make_recursive_tool_calls(
|
||||
messages,
|
||||
question,
|
||||
chat_id,
|
||||
tools,
|
||||
available_functions,
|
||||
recursive_count=0,
|
||||
last_completion_response=response.choices[0],
|
||||
)
|
||||
|
||||
if brain_completion_output:
|
||||
answer = brain_completion_output.response.message.content
|
||||
new_chat = None
|
||||
if save_answer:
|
||||
new_chat = chat_service.update_chat_history(
|
||||
CreateChatHistory(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": answer,
|
||||
"brain_id": question.brain_id,
|
||||
"prompt_id": self.prompt_to_use_id,
|
||||
}
|
||||
)
|
||||
)
|
||||
return GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": brain_completion_output.response.message.content,
|
||||
"message_time": new_chat.message_time if new_chat else None,
|
||||
"prompt_title": (
|
||||
self.prompt_to_use.title if self.prompt_to_use else None
|
||||
),
|
||||
"brain_name": brain.name if brain else None,
|
||||
"message_id": new_chat.message_id if new_chat else None,
|
||||
"brain_id": str(brain.id) if brain else None,
|
||||
}
|
||||
)
|
||||
|
||||
def make_recursive_tool_calls(
|
||||
self,
|
||||
messages,
|
||||
question,
|
||||
chat_id,
|
||||
tools=[],
|
||||
available_functions={},
|
||||
recursive_count=0,
|
||||
last_completion_response: CompletionResponse = None,
|
||||
):
|
||||
if recursive_count > 5:
|
||||
print(
|
||||
"The assistant is having issues and took more than 5 calls to the tools. Please try again later or an other instruction."
|
||||
)
|
||||
return None
|
||||
|
||||
finish_reason = last_completion_response.finish_reason
|
||||
if finish_reason == "stop":
|
||||
messages.append(last_completion_response.message)
|
||||
return BrainCompletionOutput(
|
||||
**{
|
||||
"messages": messages,
|
||||
"question": question.question,
|
||||
"response": last_completion_response,
|
||||
}
|
||||
)
|
||||
|
||||
if finish_reason == "tool_calls":
|
||||
response_message: CompletionMessage = last_completion_response.message
|
||||
tool_calls = response_message.tool_calls
|
||||
|
||||
messages.append(response_message)
|
||||
|
||||
if (
|
||||
len(tool_calls) == 0
|
||||
or tool_calls is None
|
||||
or len(available_functions) == 0
|
||||
):
|
||||
return
|
||||
|
||||
for tool_call in tool_calls:
|
||||
function_name = tool_call.function.name
|
||||
function_to_call = available_functions[function_name]
|
||||
function_args = json.loads(tool_call.function.arguments)
|
||||
question = ChatQuestion(
|
||||
question=function_args["question"], brain_id=function_name
|
||||
)
|
||||
|
||||
# TODO: extract chat_id from generate_answer function of XBrainQA
|
||||
function_response = function_to_call(
|
||||
chat_id=chat_id,
|
||||
question=question,
|
||||
save_answer=False,
|
||||
)
|
||||
messages.append(
|
||||
{
|
||||
"tool_call_id": tool_call.id,
|
||||
"role": "tool",
|
||||
"name": function_name,
|
||||
"content": function_response.assistant,
|
||||
}
|
||||
)
|
||||
|
||||
PROMPT_2 = "If initial question can be answered by our conversation messages, then give an answer and end the conversation."
|
||||
|
||||
messages.append({"role": "system", "content": PROMPT_2})
|
||||
|
||||
for idx, msg in enumerate(messages):
|
||||
logger.info(
|
||||
f"Message {idx}: Role - {msg['role']}, Content - {msg['content']}"
|
||||
)
|
||||
|
||||
response_after_tools_answers = completion(
|
||||
model="gpt-3.5-turbo-0125",
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
)
|
||||
|
||||
return self.make_recursive_tool_calls(
|
||||
messages,
|
||||
question,
|
||||
chat_id,
|
||||
tools,
|
||||
available_functions,
|
||||
recursive_count=recursive_count + 1,
|
||||
last_completion_response=response_after_tools_answers.choices[0],
|
||||
)
|
||||
|
||||
async def generate_stream(
|
||||
self,
|
||||
chat_id: UUID,
|
||||
question: ChatQuestion,
|
||||
save_answer: bool,
|
||||
should_log_steps: Optional[bool] = True,
|
||||
):
|
||||
brain = brain_service.get_brain_by_id(question.brain_id)
|
||||
if save_answer:
|
||||
streamed_chat_history = chat_service.update_chat_history(
|
||||
CreateChatHistory(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": "",
|
||||
"brain_id": question.brain_id,
|
||||
"prompt_id": self.prompt_to_use_id,
|
||||
}
|
||||
)
|
||||
)
|
||||
streamed_chat_history = GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": str(chat_id),
|
||||
"message_id": streamed_chat_history.message_id,
|
||||
"message_time": streamed_chat_history.message_time,
|
||||
"user_message": question.question,
|
||||
"assistant": "",
|
||||
"prompt_title": (
|
||||
self.prompt_to_use.title if self.prompt_to_use else None
|
||||
),
|
||||
"brain_name": brain.name if brain else None,
|
||||
"brain_id": str(brain.id) if brain else None,
|
||||
}
|
||||
)
|
||||
else:
|
||||
streamed_chat_history = GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": str(chat_id),
|
||||
"message_id": None,
|
||||
"message_time": None,
|
||||
"user_message": question.question,
|
||||
"assistant": "",
|
||||
"prompt_title": (
|
||||
self.prompt_to_use.title if self.prompt_to_use else None
|
||||
),
|
||||
"brain_name": brain.name if brain else None,
|
||||
"brain_id": str(brain.id) if brain else None,
|
||||
}
|
||||
)
|
||||
|
||||
connected_brains = brain_service.get_connected_brains(self.brain_id)
|
||||
|
||||
if not connected_brains:
|
||||
headlesss_answer = HeadlessQA(
|
||||
chat_id=chat_id,
|
||||
model=self.model,
|
||||
max_tokens=self.max_tokens,
|
||||
temperature=self.temperature,
|
||||
streaming=self.streaming,
|
||||
prompt_id=self.prompt_id,
|
||||
).generate_stream(chat_id, question)
|
||||
|
||||
response_tokens = []
|
||||
async for value in headlesss_answer:
|
||||
streamed_chat_history.assistant = value
|
||||
response_tokens.append(value)
|
||||
yield f"data: {json.dumps(streamed_chat_history.dict())}"
|
||||
|
||||
if save_answer:
|
||||
chat_service.update_message_by_id(
|
||||
message_id=str(streamed_chat_history.message_id),
|
||||
user_message=question.question,
|
||||
assistant="".join(response_tokens),
|
||||
)
|
||||
|
||||
tools = []
|
||||
available_functions = {}
|
||||
|
||||
connected_brains_details = {}
|
||||
for brain_id in connected_brains:
|
||||
brain = brain_service.get_brain_by_id(brain_id)
|
||||
if brain == None:
|
||||
continue
|
||||
|
||||
tools.append(format_brain_to_tool(brain))
|
||||
|
||||
available_functions[brain_id] = self.get_answer_generator_from_brain_type(
|
||||
brain
|
||||
)
|
||||
|
||||
connected_brains_details[str(brain.id)] = brain
|
||||
|
||||
CHOOSE_BRAIN_FROM_TOOLS_PROMPT = (
|
||||
"Based on the provided user content, find the most appropriate tools to answer"
|
||||
+ "If you can't find any tool to answer and only then, and if you can answer without using any tool. In that case, let the user know that you are not using any particular brain (i.e tool) "
|
||||
)
|
||||
|
||||
messages = [{"role": "system", "content": CHOOSE_BRAIN_FROM_TOOLS_PROMPT}]
|
||||
|
||||
history = chat_service.get_chat_history(self.chat_id)
|
||||
|
||||
for message in history:
|
||||
formatted_message = [
|
||||
{"role": "user", "content": message.user_message},
|
||||
{"role": "assistant", "content": message.assistant},
|
||||
]
|
||||
if message.assistant is None:
|
||||
print(message)
|
||||
messages.extend(formatted_message)
|
||||
|
||||
messages.append({"role": "user", "content": question.question})
|
||||
|
||||
initial_response = completion(
|
||||
model="gpt-3.5-turbo-0125",
|
||||
stream=True,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
)
|
||||
|
||||
response_tokens = []
|
||||
tool_calls_aggregate = []
|
||||
for chunk in initial_response:
|
||||
content = chunk.choices[0].delta.content
|
||||
if content is not None:
|
||||
# Need to store it ?
|
||||
streamed_chat_history.assistant = content
|
||||
response_tokens.append(chunk.choices[0].delta.content)
|
||||
|
||||
if save_answer:
|
||||
yield f"data: {json.dumps(streamed_chat_history.dict())}"
|
||||
else:
|
||||
yield f"🧠<' {chunk.choices[0].delta.content}"
|
||||
|
||||
if (
|
||||
"tool_calls" in chunk.choices[0].delta
|
||||
and chunk.choices[0].delta.tool_calls is not None
|
||||
):
|
||||
tool_calls = chunk.choices[0].delta.tool_calls
|
||||
for tool_call in tool_calls:
|
||||
id = tool_call.id
|
||||
name = tool_call.function.name
|
||||
if id and name:
|
||||
tool_calls_aggregate += [
|
||||
{
|
||||
"id": tool_call.id,
|
||||
"function": {
|
||||
"arguments": tool_call.function.arguments,
|
||||
"name": tool_call.function.name,
|
||||
},
|
||||
"type": "function",
|
||||
}
|
||||
]
|
||||
|
||||
else:
|
||||
try:
|
||||
tool_calls_aggregate[tool_call.index]["function"][
|
||||
"arguments"
|
||||
] += tool_call.function.arguments
|
||||
except IndexError:
|
||||
print("TOOL_CALL_INDEX error", tool_call.index)
|
||||
print("TOOL_CALLS_AGGREGATE error", tool_calls_aggregate)
|
||||
|
||||
finish_reason = chunk.choices[0].finish_reason
|
||||
|
||||
if finish_reason == "stop":
|
||||
if save_answer:
|
||||
chat_service.update_message_by_id(
|
||||
message_id=str(streamed_chat_history.message_id),
|
||||
user_message=question.question,
|
||||
assistant="".join(
|
||||
[
|
||||
token
|
||||
for token in response_tokens
|
||||
if not token.startswith("🧠<")
|
||||
]
|
||||
),
|
||||
)
|
||||
break
|
||||
|
||||
if finish_reason == "tool_calls":
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": tool_calls_aggregate,
|
||||
"content": None,
|
||||
}
|
||||
)
|
||||
for tool_call in tool_calls_aggregate:
|
||||
function_name = tool_call["function"]["name"]
|
||||
queried_brain = connected_brains_details[function_name]
|
||||
function_to_call = available_functions[function_name]
|
||||
function_args = json.loads(tool_call["function"]["arguments"])
|
||||
print("function_args", function_args["question"])
|
||||
question = ChatQuestion(
|
||||
question=function_args["question"], brain_id=queried_brain.id
|
||||
)
|
||||
|
||||
# yield f"🧠< Querying the brain {queried_brain.name} with the following arguments: {function_args} >🧠",
|
||||
|
||||
print(
|
||||
f"🧠< Querying the brain {queried_brain.name} with the following arguments: {function_args}",
|
||||
)
|
||||
function_response = function_to_call(
|
||||
chat_id=chat_id,
|
||||
question=question,
|
||||
save_answer=False,
|
||||
)
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"tool_call_id": tool_call["id"],
|
||||
"role": "tool",
|
||||
"name": function_name,
|
||||
"content": function_response.assistant,
|
||||
}
|
||||
)
|
||||
|
||||
print("messages", messages)
|
||||
|
||||
PROMPT_2 = "If the last user's question can be answered by our conversation messages since then, then give an answer and end the conversation. If you need to ask question to the user to gather more information and give a more accurate answer, then ask the question and wait for the user's answer."
|
||||
# Otherwise, ask a new question to the assistant and choose brains you would like to ask questions."
|
||||
|
||||
messages.append({"role": "system", "content": PROMPT_2})
|
||||
|
||||
response_after_tools_answers = completion(
|
||||
model="gpt-3.5-turbo-0125",
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
stream=True,
|
||||
)
|
||||
|
||||
response_tokens = []
|
||||
for chunk in response_after_tools_answers:
|
||||
print("chunk_response_after_tools_answers", chunk)
|
||||
content = chunk.choices[0].delta.content
|
||||
if content:
|
||||
streamed_chat_history.assistant = content
|
||||
response_tokens.append(chunk.choices[0].delta.content)
|
||||
yield f"data: {json.dumps(streamed_chat_history.dict())}"
|
||||
|
||||
finish_reason = chunk.choices[0].finish_reason
|
||||
|
||||
if finish_reason == "stop":
|
||||
chat_service.update_message_by_id(
|
||||
message_id=str(streamed_chat_history.message_id),
|
||||
user_message=question.question,
|
||||
assistant="".join(
|
||||
[
|
||||
token
|
||||
for token in response_tokens
|
||||
if not token.startswith("🧠<")
|
||||
]
|
||||
),
|
||||
)
|
||||
break
|
||||
elif finish_reason is not None:
|
||||
# TODO: recursively call with tools (update prompt + create intermediary function )
|
||||
print("NO STOP")
|
||||
print(chunk.choices[0])
|
@ -1,4 +1,3 @@
|
||||
from typing import Dict
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
@ -9,7 +8,7 @@ from quivr_api.modules.brain.dto.inputs import (
|
||||
BrainUpdatableProperties,
|
||||
CreateBrainProperties,
|
||||
)
|
||||
from quivr_api.modules.brain.entity.brain_entity import PublicBrain, RoleEnum
|
||||
from quivr_api.modules.brain.entity.brain_entity import RoleEnum
|
||||
from quivr_api.modules.brain.entity.integration_brain import (
|
||||
IntegrationDescriptionEntity,
|
||||
)
|
||||
@ -44,7 +43,8 @@ integration_brain_description_service = IntegrationBrainDescriptionService()
|
||||
)
|
||||
async def get_integration_brain_description() -> list[IntegrationDescriptionEntity]:
|
||||
"""Retrieve the integration brain description."""
|
||||
return integration_brain_description_service.get_all_integration_descriptions()
|
||||
# TODO: Deprecated, remove this endpoint
|
||||
return []
|
||||
|
||||
|
||||
@brain_router.get("/brains/", dependencies=[Depends(AuthBearer())], tags=["Brain"])
|
||||
@ -56,14 +56,6 @@ async def retrieve_all_brains_for_user(
|
||||
return {"brains": brains}
|
||||
|
||||
|
||||
@brain_router.get(
|
||||
"/brains/public", dependencies=[Depends(AuthBearer())], tags=["Brain"]
|
||||
)
|
||||
async def retrieve_public_brains() -> list[PublicBrain]:
|
||||
"""Retrieve all Quivr public brains."""
|
||||
return brain_service.get_public_brains()
|
||||
|
||||
|
||||
@brain_router.get(
|
||||
"/brains/{brain_id}/",
|
||||
dependencies=[
|
||||
@ -156,67 +148,6 @@ async def update_existing_brain(
|
||||
return {"message": f"Brain {brain_id} has been updated."}
|
||||
|
||||
|
||||
@brain_router.put(
|
||||
"/brains/{brain_id}/secrets-values",
|
||||
dependencies=[
|
||||
Depends(AuthBearer()),
|
||||
],
|
||||
tags=["Brain"],
|
||||
)
|
||||
async def update_existing_brain_secrets(
|
||||
brain_id: UUID,
|
||||
secrets: Dict[str, str],
|
||||
current_user: UserIdentity = Depends(get_current_user),
|
||||
):
|
||||
"""Update an existing brain's secrets."""
|
||||
|
||||
existing_brain = brain_service.get_brain_details(brain_id, None)
|
||||
|
||||
if existing_brain is None:
|
||||
raise HTTPException(status_code=404, detail="Brain not found")
|
||||
|
||||
if (
|
||||
existing_brain.brain_definition is None
|
||||
or existing_brain.brain_definition.secrets is None
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="This brain does not support secrets.",
|
||||
)
|
||||
|
||||
is_brain_user = (
|
||||
brain_user_service.get_brain_for_user(
|
||||
user_id=current_user.id,
|
||||
brain_id=brain_id,
|
||||
)
|
||||
is not None
|
||||
)
|
||||
|
||||
if not is_brain_user:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="You are not authorized to update this brain.",
|
||||
)
|
||||
|
||||
secrets_names = [secret.name for secret in existing_brain.brain_definition.secrets]
|
||||
|
||||
for key, value in secrets.items():
|
||||
if key not in secrets_names:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Secret {key} is not a valid secret.",
|
||||
)
|
||||
if value:
|
||||
brain_service.update_secret_value(
|
||||
user_id=current_user.id,
|
||||
brain_id=brain_id,
|
||||
secret_name=key,
|
||||
secret_value=value,
|
||||
)
|
||||
|
||||
return {"message": f"Brain {brain_id} has been updated."}
|
||||
|
||||
|
||||
@brain_router.post(
|
||||
"/brains/{brain_id}/documents",
|
||||
dependencies=[Depends(AuthBearer()), Depends(has_brain_authorization())],
|
||||
|
@ -3,28 +3,12 @@ from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.modules.brain.entity.api_brain_definition_entity import (
|
||||
ApiBrainAllowedMethods,
|
||||
ApiBrainDefinitionEntity,
|
||||
ApiBrainDefinitionSchema,
|
||||
ApiBrainDefinitionSecret,
|
||||
)
|
||||
from quivr_api.modules.brain.entity.brain_entity import BrainType
|
||||
from quivr_api.modules.brain.entity.integration_brain import IntegrationType
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class CreateApiBrainDefinition(BaseModel, extra="ignore"):
|
||||
method: ApiBrainAllowedMethods
|
||||
url: str
|
||||
params: Optional[ApiBrainDefinitionSchema] = ApiBrainDefinitionSchema()
|
||||
search_params: ApiBrainDefinitionSchema = ApiBrainDefinitionSchema()
|
||||
secrets: Optional[list[ApiBrainDefinitionSecret]] = []
|
||||
raw: Optional[bool] = False
|
||||
jq_instructions: Optional[str] = None
|
||||
|
||||
|
||||
class CreateIntegrationBrain(BaseModel, extra="ignore"):
|
||||
integration_name: str
|
||||
integration_logo_url: str
|
||||
@ -52,9 +36,6 @@ class CreateBrainProperties(BaseModel, extra="ignore"):
|
||||
max_tokens: Optional[int] = 2000
|
||||
prompt_id: Optional[UUID] = None
|
||||
brain_type: Optional[BrainType] = BrainType.doc
|
||||
brain_definition: Optional[CreateApiBrainDefinition] = None
|
||||
brain_secrets_values: Optional[dict] = {}
|
||||
connected_brains_ids: Optional[list[UUID]] = []
|
||||
integration: Optional[BrainIntegrationSettings] = None
|
||||
|
||||
def dict(self, *args, **kwargs):
|
||||
@ -72,8 +53,6 @@ class BrainUpdatableProperties(BaseModel, extra="ignore"):
|
||||
max_tokens: Optional[int] = None
|
||||
status: Optional[str] = None
|
||||
prompt_id: Optional[UUID] = None
|
||||
brain_definition: Optional[ApiBrainDefinitionEntity] = None
|
||||
connected_brains_ids: Optional[list[UUID]] = []
|
||||
integration: Optional[BrainIntegrationUpdateSettings] = None
|
||||
|
||||
def dict(self, *args, **kwargs):
|
||||
|
@ -1 +0,0 @@
|
||||
from .api_brain_definition_entity import ApiBrainDefinitionEntity
|
@ -1,47 +0,0 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
|
||||
class ApiBrainDefinitionSchemaProperty(BaseModel, extra=Extra.forbid):
|
||||
type: str
|
||||
description: str
|
||||
enum: Optional[list] = None
|
||||
name: str
|
||||
|
||||
def dict(self, **kwargs):
|
||||
result = super().dict(**kwargs)
|
||||
if "enum" in result and result["enum"] is None:
|
||||
del result["enum"]
|
||||
return result
|
||||
|
||||
|
||||
class ApiBrainDefinitionSchema(BaseModel, extra=Extra.forbid):
|
||||
properties: list[ApiBrainDefinitionSchemaProperty] = []
|
||||
required: list[str] = []
|
||||
|
||||
|
||||
class ApiBrainDefinitionSecret(BaseModel, extra=Extra.forbid):
|
||||
name: str
|
||||
type: str
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class ApiBrainAllowedMethods(str, Enum):
|
||||
GET = "GET"
|
||||
POST = "POST"
|
||||
PUT = "PUT"
|
||||
DELETE = "DELETE"
|
||||
|
||||
|
||||
class ApiBrainDefinitionEntity(BaseModel, extra=Extra.forbid):
|
||||
brain_id: UUID
|
||||
method: ApiBrainAllowedMethods
|
||||
url: str
|
||||
params: ApiBrainDefinitionSchema
|
||||
search_params: ApiBrainDefinitionSchema
|
||||
secrets: list[ApiBrainDefinitionSecret]
|
||||
raw: bool = False
|
||||
jq_instructions: Optional[str] = None
|
@ -4,9 +4,6 @@ from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from quivr_api.modules.brain.entity.api_brain_definition_entity import (
|
||||
ApiBrainDefinitionEntity,
|
||||
)
|
||||
from quivr_api.modules.brain.entity.integration_brain import (
|
||||
IntegrationDescriptionEntity,
|
||||
IntegrationEntity,
|
||||
@ -82,10 +79,6 @@ class BrainEntity(BaseModel):
|
||||
prompt_id: Optional[UUID] = None
|
||||
last_update: datetime
|
||||
brain_type: BrainType
|
||||
brain_definition: Optional[ApiBrainDefinitionEntity] = None
|
||||
connected_brains_ids: Optional[List[UUID]] = None
|
||||
raw: Optional[bool] = None
|
||||
jq_instructions: Optional[str] = None
|
||||
integration: Optional[IntegrationEntity] = None
|
||||
integration_description: Optional[IntegrationDescriptionEntity] = None
|
||||
|
||||
@ -101,16 +94,6 @@ class BrainEntity(BaseModel):
|
||||
return data
|
||||
|
||||
|
||||
class PublicBrain(BaseModel):
|
||||
id: UUID
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
number_of_subscribers: int = 0
|
||||
last_update: str
|
||||
brain_type: BrainType
|
||||
brain_definition: Optional[ApiBrainDefinitionEntity] = None
|
||||
|
||||
|
||||
class RoleEnum(str, Enum):
|
||||
Viewer = "Viewer"
|
||||
Editor = "Editor"
|
||||
|
@ -1,8 +0,0 @@
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CompositeBrainConnectionEntity(BaseModel):
|
||||
composite_brain_id: UUID
|
||||
connected_brain_id: UUID
|
@ -19,12 +19,6 @@ logger = get_logger(__name__)
|
||||
class cited_answer(BaseModelV1):
|
||||
"""Answer the user question based only on the given sources, and cite the sources used."""
|
||||
|
||||
thoughts: str = FieldV1(
|
||||
...,
|
||||
description="""Description of the thought process, based only on the given sources.
|
||||
Cite the text as much as possible and give the document name it appears in. In the format : 'Doc_name states : cited_text'. Be the most
|
||||
procedural as possible.""",
|
||||
)
|
||||
answer: str = FieldV1(
|
||||
...,
|
||||
description="The answer to the user question, which is based only on the given sources.",
|
||||
@ -34,10 +28,6 @@ class cited_answer(BaseModelV1):
|
||||
description="The integer IDs of the SPECIFIC sources which justify the answer.",
|
||||
)
|
||||
|
||||
thoughts: str = FieldV1(
|
||||
...,
|
||||
description="Explain shortly what you did to find the answer and what you used by citing the sources by their name.",
|
||||
)
|
||||
followup_questions: List[str] = FieldV1(
|
||||
...,
|
||||
description="Generate up to 3 follow-up questions that could be asked based on the answer given or context provided.",
|
||||
|
@ -1,514 +0,0 @@
|
||||
import json
|
||||
from typing import AsyncIterable, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic_settings import BaseSettings
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.models.settings import BrainSettings
|
||||
from quivr_api.modules.brain.entity.brain_entity import BrainEntity
|
||||
from quivr_api.modules.brain.qa_interface import (
|
||||
QAInterface,
|
||||
model_compatible_with_function_calling,
|
||||
)
|
||||
from quivr_api.modules.brain.rags.quivr_rag import QuivrRAG
|
||||
from quivr_api.modules.brain.rags.rag_interface import RAGInterface
|
||||
from quivr_api.modules.brain.service.brain_service import BrainService
|
||||
from quivr_api.modules.brain.service.utils.format_chat_history import (
|
||||
format_chat_history,
|
||||
)
|
||||
from quivr_api.modules.brain.service.utils.get_prompt_to_use_id import (
|
||||
get_prompt_to_use_id,
|
||||
)
|
||||
from quivr_api.modules.chat.controller.chat.utils import (
|
||||
find_model_and_generate_metadata,
|
||||
update_user_usage,
|
||||
)
|
||||
from quivr_api.modules.chat.dto.chats import ChatQuestion, Sources
|
||||
from quivr_api.modules.chat.dto.inputs import CreateChatHistory
|
||||
from quivr_api.modules.chat.dto.outputs import GetChatHistoryOutput
|
||||
from quivr_api.modules.chat.service.chat_service import ChatService
|
||||
from quivr_api.modules.prompt.service.get_prompt_to_use import get_prompt_to_use
|
||||
from quivr_api.modules.upload.service.generate_file_signed_url import (
|
||||
generate_file_signed_url,
|
||||
)
|
||||
from quivr_api.modules.user.service.user_usage import UserUsage
|
||||
|
||||
logger = get_logger(__name__)
|
||||
QUIVR_DEFAULT_PROMPT = "Your name is Quivr. You're a helpful assistant. If you don't know the answer, just say that you don't know, don't try to make up an answer."
|
||||
|
||||
brain_service = BrainService()
|
||||
|
||||
|
||||
def is_valid_uuid(uuid_to_test, version=4):
|
||||
try:
|
||||
uuid_obj = UUID(uuid_to_test, version=version)
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
return str(uuid_obj) == uuid_to_test
|
||||
|
||||
|
||||
def generate_source(
|
||||
source_documents,
|
||||
brain_id: UUID,
|
||||
citations: List[int] | None = None,
|
||||
):
|
||||
"""
|
||||
Generate the sources list for the answer
|
||||
It takes in a list of sources documents and citations that points to the docs index that was used in the answer
|
||||
"""
|
||||
# Initialize an empty list for sources
|
||||
sources_list: List[Sources] = []
|
||||
|
||||
# Initialize a dictionary for storing generated URLs
|
||||
generated_urls = {}
|
||||
|
||||
# remove duplicate sources with same name and create a list of unique sources
|
||||
sources_url_cache = {}
|
||||
|
||||
# Get source documents from the result, default to an empty list if not found
|
||||
|
||||
# If source documents exist
|
||||
if source_documents:
|
||||
logger.info(f"Citations {citations}")
|
||||
# Iterate over each document
|
||||
for doc, index in zip(source_documents, range(len(source_documents))):
|
||||
logger.info(f"Processing source document {doc.metadata['file_name']}")
|
||||
if citations is not None:
|
||||
if index not in citations:
|
||||
logger.info(f"Skipping source document {doc.metadata['file_name']}")
|
||||
continue
|
||||
# Check if 'url' is in the document metadata
|
||||
is_url = (
|
||||
"original_file_name" in doc.metadata
|
||||
and doc.metadata["original_file_name"] is not None
|
||||
and doc.metadata["original_file_name"].startswith("http")
|
||||
)
|
||||
|
||||
# Determine the name based on whether it's a URL or a file
|
||||
name = (
|
||||
doc.metadata["original_file_name"]
|
||||
if is_url
|
||||
else doc.metadata["file_name"]
|
||||
)
|
||||
|
||||
# Determine the type based on whether it's a URL or a file
|
||||
type_ = "url" if is_url else "file"
|
||||
|
||||
# Determine the source URL based on whether it's a URL or a file
|
||||
if is_url:
|
||||
source_url = doc.metadata["original_file_name"]
|
||||
else:
|
||||
file_path = f"{brain_id}/{doc.metadata['file_name']}"
|
||||
# Check if the URL has already been generated
|
||||
if file_path in generated_urls:
|
||||
source_url = generated_urls[file_path]
|
||||
else:
|
||||
# Generate the URL
|
||||
if file_path in sources_url_cache:
|
||||
source_url = sources_url_cache[file_path]
|
||||
else:
|
||||
generated_url = generate_file_signed_url(file_path)
|
||||
if generated_url is not None:
|
||||
source_url = generated_url.get("signedURL", "")
|
||||
else:
|
||||
source_url = ""
|
||||
# Store the generated URL
|
||||
generated_urls[file_path] = source_url
|
||||
|
||||
# Append a new Sources object to the list
|
||||
sources_list.append(
|
||||
Sources(
|
||||
name=name,
|
||||
type=type_,
|
||||
source_url=source_url,
|
||||
original_file_name=name,
|
||||
citation=doc.page_content,
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.info("No source documents found or source_documents is not a list.")
|
||||
return sources_list
|
||||
|
||||
|
||||
class KnowledgeBrainQA(BaseModel, QAInterface):
|
||||
"""
|
||||
Main class for the Brain Picking functionality.
|
||||
It allows to initialize a Chat model, generate questions and retrieve answers using ConversationalRetrievalChain.
|
||||
It has two main methods: `generate_question` and `generate_stream`.
|
||||
One is for generating questions in a single request, the other is for generating questions in a streaming fashion.
|
||||
Both are the same, except that the streaming version streams the last message as a stream.
|
||||
Each have the same prompt template, which is defined in the `prompt_template` property.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
# Instantiate settings
|
||||
brain_settings: BaseSettings = BrainSettings()
|
||||
|
||||
# TODO: remove this !!!!! Only added for compatibility
|
||||
chat_service: ChatService
|
||||
|
||||
# Default class attributes
|
||||
model: str = "gpt-3.5-turbo-0125" # pyright: ignore reportPrivateUsage=none
|
||||
temperature: float = 0.1
|
||||
chat_id: str = None # pyright: ignore reportPrivateUsage=none
|
||||
brain_id: str = None # pyright: ignore reportPrivateUsage=none
|
||||
max_tokens: int = 2000
|
||||
max_input: int = 2000
|
||||
streaming: bool = False
|
||||
knowledge_qa: Optional[RAGInterface] = None
|
||||
brain: Optional[BrainEntity] = None
|
||||
user_id: str = None
|
||||
user_email: str = None
|
||||
user_usage: Optional[UserUsage] = None
|
||||
user_settings: Optional[dict] = None
|
||||
models_settings: Optional[List[dict]] = None
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
callbacks: List[AsyncIteratorCallbackHandler] = (
|
||||
None # pyright: ignore reportPrivateUsage=none
|
||||
)
|
||||
|
||||
prompt_id: Optional[UUID] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
brain_id: str,
|
||||
chat_id: str,
|
||||
chat_service: ChatService,
|
||||
user_id: str = None,
|
||||
user_email: str = None,
|
||||
streaming: bool = False,
|
||||
prompt_id: Optional[UUID] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
cost: int = 100,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
brain_id=brain_id,
|
||||
chat_id=chat_id,
|
||||
chat_service=chat_service,
|
||||
streaming=streaming,
|
||||
**kwargs,
|
||||
)
|
||||
self.chat_service = chat_service
|
||||
self.prompt_id = prompt_id
|
||||
self.user_id = user_id
|
||||
self.user_email = user_email
|
||||
self.user_usage = UserUsage(id=user_id, email=user_email)
|
||||
# TODO: we already have a brain before !!!
|
||||
self.brain = brain_service.get_brain_by_id(brain_id)
|
||||
self.user_settings = self.user_usage.get_user_settings()
|
||||
|
||||
# Get Model settings for the user
|
||||
self.models_settings = self.user_usage.get_models()
|
||||
self.increase_usage_user()
|
||||
self.knowledge_qa = QuivrRAG(
|
||||
model=self.brain.model if self.brain.model else self.model,
|
||||
brain_id=brain_id,
|
||||
chat_id=chat_id,
|
||||
streaming=streaming,
|
||||
max_input=self.max_input,
|
||||
max_tokens=self.max_tokens,
|
||||
**kwargs,
|
||||
) # type: ignore
|
||||
|
||||
@property
|
||||
def prompt_to_use(self):
|
||||
if self.brain_id and is_valid_uuid(self.brain_id):
|
||||
return get_prompt_to_use(UUID(self.brain_id), self.prompt_id)
|
||||
else:
|
||||
return None
|
||||
|
||||
@property
|
||||
def prompt_to_use_id(self) -> Optional[UUID]:
|
||||
# TODO: move to prompt service or instruction or something
|
||||
if self.brain_id and is_valid_uuid(self.brain_id):
|
||||
return get_prompt_to_use_id(UUID(self.brain_id), self.prompt_id)
|
||||
else:
|
||||
return None
|
||||
|
||||
def filter_history(
|
||||
self, chat_history, max_history: int = 10, max_tokens: int = 2000
|
||||
):
|
||||
"""
|
||||
Filter out the chat history to only include the messages that are relevant to the current question
|
||||
|
||||
Takes in a chat_history= [HumanMessage(content='Qui est Chloé ? '), AIMessage(content="Chloé est une salariée travaillant pour l'entreprise Quivr en tant qu'AI Engineer, sous la direction de son supérieur hiérarchique, Stanislas Girard."), HumanMessage(content='Dis moi en plus sur elle'), AIMessage(content=''), HumanMessage(content='Dis moi en plus sur elle'), AIMessage(content="Désolé, je n'ai pas d'autres informations sur Chloé à partir des fichiers fournis.")]
|
||||
Returns a filtered chat_history with in priority: first max_tokens, then max_history where a Human message and an AI message count as one pair
|
||||
a token is 4 characters
|
||||
"""
|
||||
chat_history = chat_history[::-1]
|
||||
total_tokens = 0
|
||||
total_pairs = 0
|
||||
filtered_chat_history = []
|
||||
for i in range(0, len(chat_history), 2):
|
||||
if i + 1 < len(chat_history):
|
||||
human_message = chat_history[i]
|
||||
ai_message = chat_history[i + 1]
|
||||
message_tokens = (
|
||||
len(human_message.content) + len(ai_message.content)
|
||||
) // 4
|
||||
if (
|
||||
total_tokens + message_tokens > max_tokens
|
||||
or total_pairs >= max_history
|
||||
):
|
||||
break
|
||||
filtered_chat_history.append(human_message)
|
||||
filtered_chat_history.append(ai_message)
|
||||
total_tokens += message_tokens
|
||||
total_pairs += 1
|
||||
chat_history = filtered_chat_history[::-1]
|
||||
|
||||
return chat_history
|
||||
|
||||
def increase_usage_user(self):
|
||||
# Raises an error if the user has consumed all of of his credits
|
||||
|
||||
update_user_usage(
|
||||
usage=self.user_usage,
|
||||
user_settings=self.user_settings,
|
||||
cost=self.calculate_pricing(),
|
||||
)
|
||||
|
||||
def calculate_pricing(self):
|
||||
model_to_use = find_model_and_generate_metadata(
|
||||
self.brain.model,
|
||||
self.user_settings,
|
||||
self.models_settings,
|
||||
)
|
||||
self.model = model_to_use.name
|
||||
self.max_input = model_to_use.max_input
|
||||
self.max_tokens = model_to_use.max_output
|
||||
user_choosen_model_price = 1000
|
||||
|
||||
for model_setting in self.models_settings:
|
||||
if model_setting["name"] == self.model:
|
||||
user_choosen_model_price = model_setting["price"]
|
||||
|
||||
return user_choosen_model_price
|
||||
|
||||
# TODO: deprecated
|
||||
async def generate_answer(
|
||||
self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
|
||||
) -> GetChatHistoryOutput:
|
||||
conversational_qa_chain = self.knowledge_qa.get_chain()
|
||||
transformed_history, _ = await self.initialize_streamed_chat_history(
|
||||
chat_id, question
|
||||
)
|
||||
metadata = self.metadata or {}
|
||||
citations = None
|
||||
answer = ""
|
||||
config = {"metadata": {"conversation_id": str(chat_id)}}
|
||||
|
||||
model_response = conversational_qa_chain.invoke(
|
||||
{
|
||||
"question": question.question,
|
||||
"chat_history": transformed_history,
|
||||
"custom_personality": (
|
||||
self.prompt_to_use.content if self.prompt_to_use else None
|
||||
),
|
||||
},
|
||||
config=config,
|
||||
)
|
||||
|
||||
if model_compatible_with_function_calling(model=self.model):
|
||||
if model_response["answer"].tool_calls:
|
||||
citations = model_response["answer"].tool_calls[-1]["args"]["citations"]
|
||||
followup_questions = model_response["answer"].tool_calls[-1]["args"][
|
||||
"followup_questions"
|
||||
]
|
||||
thoughts = model_response["answer"].tool_calls[-1]["args"]["thoughts"]
|
||||
if citations:
|
||||
citations = citations
|
||||
if followup_questions:
|
||||
metadata["followup_questions"] = followup_questions
|
||||
if thoughts:
|
||||
metadata["thoughts"] = thoughts
|
||||
answer = model_response["answer"].tool_calls[-1]["args"]["answer"]
|
||||
else:
|
||||
answer = model_response["answer"].content
|
||||
|
||||
sources = model_response["docs"] or []
|
||||
|
||||
if len(sources) > 0:
|
||||
sources_list = generate_source(sources, self.brain_id, citations=citations)
|
||||
serialized_sources_list = [source.dict() for source in sources_list]
|
||||
metadata["sources"] = serialized_sources_list
|
||||
|
||||
return self.save_non_streaming_answer(
|
||||
chat_id=chat_id, question=question, answer=answer, metadata=metadata
|
||||
)
|
||||
|
||||
async def generate_stream(
|
||||
self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
|
||||
) -> AsyncIterable:
|
||||
if hasattr(self, "get_chain") and callable(self.get_chain):
|
||||
conversational_qa_chain = self.get_chain()
|
||||
else:
|
||||
conversational_qa_chain = self.knowledge_qa.get_chain()
|
||||
(
|
||||
transformed_history,
|
||||
streamed_chat_history,
|
||||
) = await self.initialize_streamed_chat_history(chat_id, question)
|
||||
response_tokens = ""
|
||||
sources = []
|
||||
citations = []
|
||||
first = True
|
||||
config = {"metadata": {"conversation_id": str(chat_id)}}
|
||||
|
||||
async for chunk in conversational_qa_chain.astream(
|
||||
{
|
||||
"question": question.question,
|
||||
"chat_history": transformed_history,
|
||||
"custom_personality": (
|
||||
self.prompt_to_use.content if self.prompt_to_use else None
|
||||
),
|
||||
},
|
||||
config=config,
|
||||
):
|
||||
if not streamed_chat_history.metadata:
|
||||
streamed_chat_history.metadata = {}
|
||||
if model_compatible_with_function_calling(model=self.model):
|
||||
if chunk.get("answer"):
|
||||
if first:
|
||||
gathered = chunk["answer"]
|
||||
first = False
|
||||
else:
|
||||
gathered = gathered + chunk["answer"]
|
||||
if (
|
||||
gathered.tool_calls
|
||||
and gathered.tool_calls[-1].get("args")
|
||||
and "answer" in gathered.tool_calls[-1]["args"]
|
||||
):
|
||||
# Only send the difference between answer and response_tokens which was the previous answer
|
||||
answer = gathered.tool_calls[-1]["args"]["answer"]
|
||||
difference = answer[len(response_tokens) :]
|
||||
streamed_chat_history.assistant = difference
|
||||
response_tokens = answer
|
||||
|
||||
yield f"data: {json.dumps(streamed_chat_history.dict())}"
|
||||
if (
|
||||
gathered.tool_calls
|
||||
and gathered.tool_calls[-1].get("args")
|
||||
and "citations" in gathered.tool_calls[-1]["args"]
|
||||
):
|
||||
citations = gathered.tool_calls[-1]["args"]["citations"]
|
||||
if (
|
||||
gathered.tool_calls
|
||||
and gathered.tool_calls[-1].get("args")
|
||||
and "followup_questions" in gathered.tool_calls[-1]["args"]
|
||||
):
|
||||
followup_questions = gathered.tool_calls[-1]["args"][
|
||||
"followup_questions"
|
||||
]
|
||||
streamed_chat_history.metadata["followup_questions"] = (
|
||||
followup_questions
|
||||
)
|
||||
if (
|
||||
gathered.tool_calls
|
||||
and gathered.tool_calls[-1].get("args")
|
||||
and "thoughts" in gathered.tool_calls[-1]["args"]
|
||||
):
|
||||
thoughts = gathered.tool_calls[-1]["args"]["thoughts"]
|
||||
streamed_chat_history.metadata["thoughts"] = thoughts
|
||||
else:
|
||||
if chunk.get("answer"):
|
||||
response_tokens += chunk["answer"].content
|
||||
streamed_chat_history.assistant = chunk["answer"].content
|
||||
yield f"data: {streamed_chat_history.model_dump_json()}"
|
||||
|
||||
if chunk.get("docs"):
|
||||
sources = chunk["docs"]
|
||||
|
||||
sources_list = generate_source(sources, self.brain_id, citations)
|
||||
|
||||
# Serialize the sources list
|
||||
serialized_sources_list = [source.dict() for source in sources_list]
|
||||
streamed_chat_history.metadata["sources"] = serialized_sources_list
|
||||
yield f"data: {streamed_chat_history.model_dump_json()}"
|
||||
self.save_answer(question, response_tokens, streamed_chat_history, save_answer)
|
||||
|
||||
async def initialize_streamed_chat_history(self, chat_id, question):
|
||||
history = await self.chat_service.get_chat_history(self.chat_id)
|
||||
transformed_history = format_chat_history(history)
|
||||
brain = brain_service.get_brain_by_id(self.brain_id)
|
||||
|
||||
streamed_chat_history = self.chat_service.update_chat_history(
|
||||
CreateChatHistory(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": "",
|
||||
"brain_id": brain.brain_id,
|
||||
"prompt_id": self.prompt_to_use_id,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
streamed_chat_history = GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": str(chat_id),
|
||||
"message_id": streamed_chat_history.message_id,
|
||||
"message_time": streamed_chat_history.message_time,
|
||||
"user_message": question.question,
|
||||
"assistant": "",
|
||||
"prompt_title": (
|
||||
self.prompt_to_use.title if self.prompt_to_use else None
|
||||
),
|
||||
"brain_name": brain.name if brain else None,
|
||||
"brain_id": str(brain.brain_id) if brain else None,
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
)
|
||||
|
||||
return transformed_history, streamed_chat_history
|
||||
|
||||
def save_answer(
|
||||
self, question, response_tokens, streamed_chat_history, save_answer
|
||||
):
|
||||
assistant = "".join(response_tokens)
|
||||
|
||||
try:
|
||||
if save_answer:
|
||||
self.chat_service.update_message_by_id(
|
||||
message_id=str(streamed_chat_history.message_id),
|
||||
user_message=question.question,
|
||||
assistant=assistant,
|
||||
metadata=streamed_chat_history.metadata,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Error updating message by ID: %s", e)
|
||||
|
||||
def save_non_streaming_answer(self, chat_id, question, answer, metadata):
|
||||
new_chat = self.chat_service.update_chat_history(
|
||||
CreateChatHistory(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": answer,
|
||||
"brain_id": self.brain.brain_id,
|
||||
"prompt_id": self.prompt_to_use_id,
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
return GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": answer,
|
||||
"message_time": new_chat.message_time,
|
||||
"prompt_title": (
|
||||
self.prompt_to_use.title if self.prompt_to_use else None
|
||||
),
|
||||
"brain_name": self.brain.name if self.brain else None,
|
||||
"message_id": new_chat.message_id,
|
||||
"brain_id": str(self.brain.brain_id) if self.brain else None,
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
@ -1,267 +0,0 @@
|
||||
import asyncio
|
||||
import json
|
||||
from typing import AsyncIterable, Awaitable, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
|
||||
from langchain_community.chat_models import ChatLiteLLM
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.models.settings import (
|
||||
BrainSettings,
|
||||
) # Importing settings related to the 'brain'
|
||||
from quivr_api.modules.brain.qa_interface import QAInterface
|
||||
from quivr_api.modules.brain.service.utils.format_chat_history import (
|
||||
format_chat_history,
|
||||
format_history_to_openai_mesages,
|
||||
)
|
||||
from quivr_api.modules.brain.service.utils.get_prompt_to_use_id import (
|
||||
get_prompt_to_use_id,
|
||||
)
|
||||
from quivr_api.modules.chat.dto.chats import ChatQuestion
|
||||
from quivr_api.modules.chat.dto.inputs import CreateChatHistory
|
||||
from quivr_api.modules.chat.dto.outputs import GetChatHistoryOutput
|
||||
from quivr_api.modules.chat.service.chat_service import ChatService
|
||||
from quivr_api.modules.dependencies import get_service
|
||||
from quivr_api.modules.prompt.service.get_prompt_to_use import get_prompt_to_use
|
||||
|
||||
logger = get_logger(__name__)
|
||||
SYSTEM_MESSAGE = "Your name is Quivr. You're a helpful assistant. If you don't know the answer, just say that you don't know, don't try to make up an answer.When answering use markdown or any other techniques to display the content in a nice and aerated way."
|
||||
chat_service = get_service(ChatService)()
|
||||
|
||||
|
||||
class HeadlessQA(BaseModel, QAInterface):
|
||||
brain_settings = BrainSettings()
|
||||
model: str
|
||||
temperature: float = 0.0
|
||||
max_tokens: int = 2000
|
||||
streaming: bool = False
|
||||
chat_id: str
|
||||
callbacks: Optional[List[AsyncIteratorCallbackHandler]] = None
|
||||
prompt_id: Optional[UUID] = None
|
||||
|
||||
def _determine_streaming(self, streaming: bool) -> bool:
|
||||
"""If the model name allows for streaming and streaming is declared, set streaming to True."""
|
||||
return streaming
|
||||
|
||||
def _determine_callback_array(
|
||||
self, streaming
|
||||
) -> List[AsyncIteratorCallbackHandler]:
|
||||
"""If streaming is set, set the AsyncIteratorCallbackHandler as the only callback."""
|
||||
if streaming:
|
||||
return [AsyncIteratorCallbackHandler()]
|
||||
else:
|
||||
return []
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
self.streaming = self._determine_streaming(self.streaming)
|
||||
self.callbacks = self._determine_callback_array(self.streaming)
|
||||
|
||||
@property
|
||||
def prompt_to_use(self) -> str:
|
||||
return get_prompt_to_use(None, self.prompt_id)
|
||||
|
||||
@property
|
||||
def prompt_to_use_id(self) -> Optional[UUID]:
|
||||
return get_prompt_to_use_id(None, self.prompt_id)
|
||||
|
||||
def _create_llm(
|
||||
self,
|
||||
model,
|
||||
temperature=0,
|
||||
streaming=False,
|
||||
callbacks=None,
|
||||
) -> BaseChatModel:
|
||||
"""
|
||||
Determine the language model to be used.
|
||||
:param model: Language model name to be used.
|
||||
:param streaming: Whether to enable streaming of the model
|
||||
:param callbacks: Callbacks to be used for streaming
|
||||
:return: Language model instance
|
||||
"""
|
||||
api_base = None
|
||||
if self.brain_settings.ollama_api_base_url and model.startswith("ollama"):
|
||||
api_base = self.brain_settings.ollama_api_base_url
|
||||
|
||||
return ChatLiteLLM(
|
||||
temperature=temperature,
|
||||
model=model,
|
||||
streaming=streaming,
|
||||
verbose=True,
|
||||
callbacks=callbacks,
|
||||
max_tokens=self.max_tokens,
|
||||
api_base=api_base,
|
||||
)
|
||||
|
||||
def _create_prompt_template(self):
|
||||
messages = [
|
||||
HumanMessagePromptTemplate.from_template("{question}"),
|
||||
]
|
||||
CHAT_PROMPT = ChatPromptTemplate.from_messages(messages)
|
||||
return CHAT_PROMPT
|
||||
|
||||
def generate_answer(
|
||||
self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
|
||||
) -> GetChatHistoryOutput:
|
||||
# Move format_chat_history to chat service ?
|
||||
transformed_history = format_chat_history(
|
||||
chat_service.get_chat_history(self.chat_id)
|
||||
)
|
||||
prompt_content = (
|
||||
self.prompt_to_use.content if self.prompt_to_use else SYSTEM_MESSAGE
|
||||
)
|
||||
|
||||
messages = format_history_to_openai_mesages(
|
||||
transformed_history, prompt_content, question.question
|
||||
)
|
||||
answering_llm = self._create_llm(
|
||||
model=self.model,
|
||||
streaming=False,
|
||||
callbacks=self.callbacks,
|
||||
)
|
||||
model_prediction = answering_llm.predict_messages(messages)
|
||||
answer = model_prediction.content
|
||||
if save_answer:
|
||||
new_chat = chat_service.update_chat_history(
|
||||
CreateChatHistory(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": answer,
|
||||
"brain_id": None,
|
||||
"prompt_id": self.prompt_to_use_id,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
return GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": answer,
|
||||
"message_time": new_chat.message_time,
|
||||
"prompt_title": (
|
||||
self.prompt_to_use.title if self.prompt_to_use else None
|
||||
),
|
||||
"brain_name": None,
|
||||
"message_id": new_chat.message_id,
|
||||
}
|
||||
)
|
||||
else:
|
||||
return GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": answer,
|
||||
"message_time": None,
|
||||
"prompt_title": (
|
||||
self.prompt_to_use.title if self.prompt_to_use else None
|
||||
),
|
||||
"brain_name": None,
|
||||
"message_id": None,
|
||||
}
|
||||
)
|
||||
|
||||
async def generate_stream(
|
||||
self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
|
||||
) -> AsyncIterable:
|
||||
callback = AsyncIteratorCallbackHandler()
|
||||
self.callbacks = [callback]
|
||||
|
||||
transformed_history = format_chat_history(
|
||||
chat_service.get_chat_history(self.chat_id)
|
||||
)
|
||||
prompt_content = (
|
||||
self.prompt_to_use.content if self.prompt_to_use else SYSTEM_MESSAGE
|
||||
)
|
||||
|
||||
messages = format_history_to_openai_mesages(
|
||||
transformed_history, prompt_content, question.question
|
||||
)
|
||||
answering_llm = self._create_llm(
|
||||
model=self.model,
|
||||
streaming=True,
|
||||
callbacks=self.callbacks,
|
||||
)
|
||||
|
||||
CHAT_PROMPT = ChatPromptTemplate.from_messages(messages)
|
||||
headlessChain = LLMChain(llm=answering_llm, prompt=CHAT_PROMPT)
|
||||
|
||||
response_tokens = []
|
||||
|
||||
async def wrap_done(fn: Awaitable, event: asyncio.Event):
|
||||
try:
|
||||
await fn
|
||||
except Exception as e:
|
||||
logger.error(f"Caught exception: {e}")
|
||||
finally:
|
||||
event.set()
|
||||
|
||||
run = asyncio.create_task(
|
||||
wrap_done(
|
||||
headlessChain.acall({}),
|
||||
callback.done,
|
||||
),
|
||||
)
|
||||
|
||||
if save_answer:
|
||||
streamed_chat_history = chat_service.update_chat_history(
|
||||
CreateChatHistory(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": "",
|
||||
"brain_id": None,
|
||||
"prompt_id": self.prompt_to_use_id,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
streamed_chat_history = GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": str(chat_id),
|
||||
"message_id": streamed_chat_history.message_id,
|
||||
"message_time": streamed_chat_history.message_time,
|
||||
"user_message": question.question,
|
||||
"assistant": "",
|
||||
"prompt_title": (
|
||||
self.prompt_to_use.title if self.prompt_to_use else None
|
||||
),
|
||||
"brain_name": None,
|
||||
}
|
||||
)
|
||||
else:
|
||||
streamed_chat_history = GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": str(chat_id),
|
||||
"message_id": None,
|
||||
"message_time": None,
|
||||
"user_message": question.question,
|
||||
"assistant": "",
|
||||
"prompt_title": (
|
||||
self.prompt_to_use.title if self.prompt_to_use else None
|
||||
),
|
||||
"brain_name": None,
|
||||
}
|
||||
)
|
||||
|
||||
async for token in callback.aiter():
|
||||
response_tokens.append(token)
|
||||
streamed_chat_history.assistant = token
|
||||
yield f"data: {json.dumps(streamed_chat_history.dict())}"
|
||||
|
||||
await run
|
||||
assistant = "".join(response_tokens)
|
||||
|
||||
if save_answer:
|
||||
chat_service.update_message_by_id(
|
||||
message_id=str(streamed_chat_history.message_id),
|
||||
user_message=question.question,
|
||||
assistant=assistant,
|
||||
)
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
@ -1,58 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from uuid import UUID
|
||||
|
||||
from quivr_api.modules.chat.dto.chats import ChatQuestion
|
||||
|
||||
|
||||
def model_compatible_with_function_calling(model: str):
|
||||
return model in [
|
||||
"gpt-4o",
|
||||
"gpt-4-turbo",
|
||||
"gpt-4-turbo-2024-04-09",
|
||||
"gpt-4-turbo-preview",
|
||||
"gpt-4-0125-preview",
|
||||
"gpt-4-1106-preview",
|
||||
"gpt-4",
|
||||
"gpt-4-0613",
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-0125",
|
||||
"gpt-3.5-turbo-1106",
|
||||
"gpt-3.5-turbo-0613",
|
||||
]
|
||||
|
||||
|
||||
class QAInterface(ABC):
|
||||
"""
|
||||
Abstract class for all QA interfaces.
|
||||
This can be used to implement custom answer generation logic.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def calculate_pricing(self):
|
||||
raise NotImplementedError(
|
||||
"calculate_pricing is an abstract method and must be implemented"
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def generate_answer(
|
||||
self,
|
||||
chat_id: UUID,
|
||||
question: ChatQuestion,
|
||||
save_answer: bool,
|
||||
*custom_params: tuple,
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"generate_answer is an abstract method and must be implemented"
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
async def generate_stream(
|
||||
self,
|
||||
chat_id: UUID,
|
||||
question: ChatQuestion,
|
||||
save_answer: bool,
|
||||
*custom_params: tuple,
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"generate_stream is an abstract method and must be implemented"
|
||||
)
|
@ -41,12 +41,6 @@ logger = get_logger(__name__)
|
||||
class cited_answer(BaseModelV1):
|
||||
"""Answer the user question based only on the given sources, and cite the sources used."""
|
||||
|
||||
thoughts: str = FieldV1(
|
||||
...,
|
||||
description="""Description of the thought process, based only on the given sources.
|
||||
Cite the text as much as possible and give the document name it appears in. In the format : 'Doc_name states : cited_text'. Be the most
|
||||
procedural as possible. Write all the steps needed to find the answer until you find it.""",
|
||||
)
|
||||
answer: str = FieldV1(
|
||||
...,
|
||||
description="The answer to the user question, which is based only on the given sources.",
|
||||
@ -56,10 +50,6 @@ class cited_answer(BaseModelV1):
|
||||
description="The integer IDs of the SPECIFIC sources which justify the answer.",
|
||||
)
|
||||
|
||||
thoughts: str = FieldV1(
|
||||
...,
|
||||
description="Explain shortly what you did to find the answer and what you used by citing the sources by their name.",
|
||||
)
|
||||
followup_questions: List[str] = FieldV1(
|
||||
...,
|
||||
description="Generate up to 3 follow-up questions that could be asked based on the answer given or context provided.",
|
||||
|
@ -1,58 +0,0 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from quivr_api.models.settings import get_supabase_client
|
||||
from quivr_api.modules.brain.dto.inputs import CreateApiBrainDefinition
|
||||
from quivr_api.modules.brain.entity.api_brain_definition_entity import (
|
||||
ApiBrainDefinitionEntity,
|
||||
)
|
||||
from quivr_api.modules.brain.repository.interfaces import ApiBrainDefinitionsInterface
|
||||
|
||||
|
||||
class ApiBrainDefinitions(ApiBrainDefinitionsInterface):
|
||||
def __init__(self):
|
||||
self.db = get_supabase_client()
|
||||
|
||||
def get_api_brain_definition(
|
||||
self, brain_id: UUID
|
||||
) -> Optional[ApiBrainDefinitionEntity]:
|
||||
response = (
|
||||
self.db.table("api_brain_definition")
|
||||
.select("*")
|
||||
.filter("brain_id", "eq", brain_id)
|
||||
.execute()
|
||||
)
|
||||
if len(response.data) == 0:
|
||||
return None
|
||||
|
||||
return ApiBrainDefinitionEntity(**response.data[0])
|
||||
|
||||
def add_api_brain_definition(
|
||||
self, brain_id: UUID, api_brain_definition: CreateApiBrainDefinition
|
||||
) -> Optional[ApiBrainDefinitionEntity]:
|
||||
response = (
|
||||
self.db.table("api_brain_definition")
|
||||
.insert([{"brain_id": str(brain_id), **api_brain_definition.dict()}])
|
||||
.execute()
|
||||
)
|
||||
if len(response.data) == 0:
|
||||
return None
|
||||
return ApiBrainDefinitionEntity(**response.data[0])
|
||||
|
||||
def update_api_brain_definition(
|
||||
self, brain_id: UUID, api_brain_definition: ApiBrainDefinitionEntity
|
||||
) -> Optional[ApiBrainDefinitionEntity]:
|
||||
response = (
|
||||
self.db.table("api_brain_definition")
|
||||
.update(api_brain_definition.dict(exclude={"brain_id"}))
|
||||
.filter("brain_id", "eq", str(brain_id))
|
||||
.execute()
|
||||
)
|
||||
if len(response.data) == 0:
|
||||
return None
|
||||
return ApiBrainDefinitionEntity(**response.data[0])
|
||||
|
||||
def delete_api_brain_definition(self, brain_id: UUID) -> None:
|
||||
self.db.table("api_brain_definition").delete().filter(
|
||||
"brain_id", "eq", str(brain_id)
|
||||
).execute()
|
@ -7,7 +7,7 @@ from quivr_api.models.settings import (
|
||||
get_supabase_client,
|
||||
)
|
||||
from quivr_api.modules.brain.dto.inputs import BrainUpdatableProperties
|
||||
from quivr_api.modules.brain.entity.brain_entity import BrainEntity, PublicBrain
|
||||
from quivr_api.modules.brain.entity.brain_entity import BrainEntity
|
||||
from quivr_api.modules.brain.repository.interfaces.brains_interface import (
|
||||
BrainsInterface,
|
||||
)
|
||||
@ -29,9 +29,6 @@ class Brains(BrainsInterface):
|
||||
brain_meaning = embeddings.embed_query(string_to_embed)
|
||||
brain_dict = brain.dict(
|
||||
exclude={
|
||||
"brain_definition",
|
||||
"brain_secrets_values",
|
||||
"connected_brains_ids",
|
||||
"integration",
|
||||
}
|
||||
)
|
||||
@ -40,27 +37,6 @@ class Brains(BrainsInterface):
|
||||
|
||||
return BrainEntity(**response.data[0])
|
||||
|
||||
def get_public_brains(self):
|
||||
response = (
|
||||
self.db.from_("brains")
|
||||
.select(
|
||||
"id:brain_id, name, description, last_update, brain_type, brain_definition: api_brain_definition(*), number_of_subscribers:brains_users(count)"
|
||||
)
|
||||
.filter("status", "eq", "public")
|
||||
.execute()
|
||||
)
|
||||
public_brains: list[PublicBrain] = []
|
||||
|
||||
for item in response.data:
|
||||
item["number_of_subscribers"] = item["number_of_subscribers"][0]["count"]
|
||||
if not item["brain_definition"]:
|
||||
del item["brain_definition"]
|
||||
else:
|
||||
item["brain_definition"]["secrets"] = []
|
||||
|
||||
public_brains.append(PublicBrain(**item))
|
||||
return public_brains
|
||||
|
||||
def update_brain_last_update_time(self, brain_id):
|
||||
try:
|
||||
with self.pg_engine.begin() as connection:
|
||||
|
@ -1,63 +0,0 @@
|
||||
from uuid import UUID
|
||||
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.models.settings import get_supabase_client
|
||||
from quivr_api.modules.brain.entity.composite_brain_connection_entity import (
|
||||
CompositeBrainConnectionEntity,
|
||||
)
|
||||
from quivr_api.modules.brain.repository.interfaces import (
|
||||
CompositeBrainsConnectionsInterface,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class CompositeBrainsConnections(CompositeBrainsConnectionsInterface):
|
||||
def __init__(self):
|
||||
self.db = get_supabase_client()
|
||||
|
||||
def connect_brain(
|
||||
self, composite_brain_id: UUID, connected_brain_id: UUID
|
||||
) -> CompositeBrainConnectionEntity:
|
||||
response = (
|
||||
self.db.table("composite_brain_connections")
|
||||
.insert(
|
||||
{
|
||||
"composite_brain_id": str(composite_brain_id),
|
||||
"connected_brain_id": str(connected_brain_id),
|
||||
}
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
|
||||
return CompositeBrainConnectionEntity(**response.data[0])
|
||||
|
||||
def get_connected_brains(self, composite_brain_id: UUID) -> list[UUID]:
|
||||
response = (
|
||||
self.db.from_("composite_brain_connections")
|
||||
.select("connected_brain_id")
|
||||
.filter("composite_brain_id", "eq", str(composite_brain_id))
|
||||
.execute()
|
||||
)
|
||||
|
||||
return [item["connected_brain_id"] for item in response.data]
|
||||
|
||||
def disconnect_brain(
|
||||
self, composite_brain_id: UUID, connected_brain_id: UUID
|
||||
) -> None:
|
||||
self.db.table("composite_brain_connections").delete().match(
|
||||
{
|
||||
"composite_brain_id": composite_brain_id,
|
||||
"connected_brain_id": connected_brain_id,
|
||||
}
|
||||
).execute()
|
||||
|
||||
def is_connected_brain(self, brain_id: UUID) -> bool:
|
||||
response = (
|
||||
self.db.from_("composite_brain_connections")
|
||||
.select("connected_brain_id")
|
||||
.filter("connected_brain_id", "eq", str(brain_id))
|
||||
.execute()
|
||||
)
|
||||
|
||||
return len(response.data) > 0
|
@ -1,60 +0,0 @@
|
||||
from uuid import UUID
|
||||
|
||||
from quivr_api.models.settings import get_supabase_client
|
||||
from quivr_api.modules.brain.repository.interfaces.external_api_secrets_interface import (
|
||||
ExternalApiSecretsInterface,
|
||||
)
|
||||
|
||||
|
||||
def build_secret_unique_name(user_id: UUID, brain_id: UUID, secret_name: str):
|
||||
return f"{user_id}-{brain_id}-{secret_name}"
|
||||
|
||||
|
||||
class ExternalApiSecrets(ExternalApiSecretsInterface):
|
||||
def __init__(self):
|
||||
supabase_client = get_supabase_client()
|
||||
self.db = supabase_client
|
||||
|
||||
def create_secret(
|
||||
self, user_id: UUID, brain_id: UUID, secret_name: str, secret_value
|
||||
) -> UUID | None:
|
||||
response = self.db.rpc(
|
||||
"insert_secret",
|
||||
{
|
||||
"name": build_secret_unique_name(
|
||||
user_id=user_id, brain_id=brain_id, secret_name=secret_name
|
||||
),
|
||||
"secret": secret_value,
|
||||
},
|
||||
).execute()
|
||||
|
||||
return response.data
|
||||
|
||||
def read_secret(
|
||||
self,
|
||||
user_id: UUID,
|
||||
brain_id: UUID,
|
||||
secret_name: str,
|
||||
) -> UUID | None:
|
||||
response = self.db.rpc(
|
||||
"read_secret",
|
||||
{
|
||||
"secret_name": build_secret_unique_name(
|
||||
user_id=user_id, brain_id=brain_id, secret_name=secret_name
|
||||
),
|
||||
},
|
||||
).execute()
|
||||
|
||||
return response.data
|
||||
|
||||
def delete_secret(self, user_id: UUID, brain_id: UUID, secret_name: str) -> bool:
|
||||
response = self.db.rpc(
|
||||
"delete_secret",
|
||||
{
|
||||
"secret_name": build_secret_unique_name(
|
||||
user_id=user_id, brain_id=brain_id, secret_name=secret_name
|
||||
),
|
||||
},
|
||||
).execute()
|
||||
|
||||
return response.data
|
@ -1,9 +1,6 @@
|
||||
from .api_brain_definitions_interface import ApiBrainDefinitionsInterface
|
||||
# noqa
|
||||
from .brains_interface import BrainsInterface
|
||||
from .brains_users_interface import BrainsUsersInterface
|
||||
from .brains_vectors_interface import BrainsVectorsInterface
|
||||
from .composite_brains_connections_interface import \
|
||||
CompositeBrainsConnectionsInterface
|
||||
from .external_api_secrets_interface import ExternalApiSecretsInterface
|
||||
from .integration_brains_interface import (IntegrationBrainInterface,
|
||||
IntegrationDescriptionInterface)
|
||||
IntegrationDescriptionInterface)
|
||||
|
@ -1,38 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from quivr_api.modules.brain.dto.inputs import CreateApiBrainDefinition
|
||||
from quivr_api.modules.brain.entity.api_brain_definition_entity import (
|
||||
ApiBrainDefinitionEntity,
|
||||
)
|
||||
|
||||
|
||||
class ApiBrainDefinitionsInterface(ABC):
|
||||
@abstractmethod
|
||||
def get_api_brain_definition(
|
||||
self, brain_id: UUID
|
||||
) -> Optional[ApiBrainDefinitionEntity]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_api_brain_definition(
|
||||
self, brain_id: UUID, api_brain_definition: CreateApiBrainDefinition
|
||||
) -> Optional[ApiBrainDefinitionEntity]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_api_brain_definition(
|
||||
self, brain_id: UUID, api_brain_definition: ApiBrainDefinitionEntity
|
||||
) -> Optional[ApiBrainDefinitionEntity]:
|
||||
"""
|
||||
Get all public brains
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_api_brain_definition(self, brain_id: UUID) -> None:
|
||||
"""
|
||||
Update the last update time of the brain
|
||||
"""
|
||||
pass
|
@ -5,7 +5,7 @@ from quivr_api.modules.brain.dto.inputs import (
|
||||
BrainUpdatableProperties,
|
||||
CreateBrainProperties,
|
||||
)
|
||||
from quivr_api.modules.brain.entity.brain_entity import BrainEntity, PublicBrain
|
||||
from quivr_api.modules.brain.entity.brain_entity import BrainEntity
|
||||
|
||||
|
||||
class BrainsInterface(ABC):
|
||||
@ -16,13 +16,6 @@ class BrainsInterface(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_public_brains(self) -> list[PublicBrain]:
|
||||
"""
|
||||
Get all public brains
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_brain_details(self, brain_id: UUID, user_id: UUID) -> BrainEntity | None:
|
||||
"""
|
||||
|
@ -1,40 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from uuid import UUID
|
||||
|
||||
from quivr_api.modules.brain.entity.composite_brain_connection_entity import (
|
||||
CompositeBrainConnectionEntity,
|
||||
)
|
||||
|
||||
|
||||
class CompositeBrainsConnectionsInterface(ABC):
|
||||
@abstractmethod
|
||||
def connect_brain(
|
||||
self, composite_brain_id: UUID, connected_brain_id: UUID
|
||||
) -> CompositeBrainConnectionEntity:
|
||||
"""
|
||||
Connect a brain to a composite brain in the composite_brain_connections table
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_connected_brains(self, composite_brain_id: UUID) -> list[UUID]:
|
||||
"""
|
||||
Get all brains connected to a composite brain
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def disconnect_brain(
|
||||
self, composite_brain_id: UUID, connected_brain_id: UUID
|
||||
) -> None:
|
||||
"""
|
||||
Disconnect a brain from a composite brain
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_connected_brain(self, brain_id: UUID) -> bool:
|
||||
"""
|
||||
Check if a brain is connected to any composite brain
|
||||
"""
|
||||
pass
|
@ -1,29 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
class ExternalApiSecretsInterface(ABC):
|
||||
@abstractmethod
|
||||
def create_secret(
|
||||
self, user_id: UUID, brain_id: UUID, secret_name: str, secret_value
|
||||
) -> UUID | None:
|
||||
"""
|
||||
Create a new secret for the API Request in given brain
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def read_secret(
|
||||
self, user_id: UUID, brain_id: UUID, secret_name: str
|
||||
) -> UUID | None:
|
||||
"""
|
||||
Read a secret for the API Request in given brain
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_secret(self, user_id: UUID, brain_id: UUID, secret_name: str) -> bool:
|
||||
"""
|
||||
Delete a secret from a brain
|
||||
"""
|
||||
pass
|
@ -1,36 +0,0 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from quivr_api.modules.brain.dto.inputs import CreateApiBrainDefinition
|
||||
from quivr_api.modules.brain.entity.api_brain_definition_entity import (
|
||||
ApiBrainDefinitionEntity,
|
||||
)
|
||||
from quivr_api.modules.brain.repository.api_brain_definitions import ApiBrainDefinitions
|
||||
from quivr_api.modules.brain.repository.interfaces import ApiBrainDefinitionsInterface
|
||||
|
||||
|
||||
class ApiBrainDefinitionService:
|
||||
repository: ApiBrainDefinitionsInterface
|
||||
|
||||
def __init__(self):
|
||||
self.repository = ApiBrainDefinitions()
|
||||
|
||||
def add_api_brain_definition(
|
||||
self, brain_id: UUID, api_brain_definition: CreateApiBrainDefinition
|
||||
) -> None:
|
||||
self.repository.add_api_brain_definition(brain_id, api_brain_definition)
|
||||
|
||||
def delete_api_brain_definition(self, brain_id: UUID) -> None:
|
||||
self.repository.delete_api_brain_definition(brain_id)
|
||||
|
||||
def get_api_brain_definition(
|
||||
self, brain_id: UUID
|
||||
) -> Optional[ApiBrainDefinitionEntity]:
|
||||
return self.repository.get_api_brain_definition(brain_id)
|
||||
|
||||
def update_api_brain_definition(
|
||||
self, brain_id: UUID, api_brain_definition: ApiBrainDefinitionEntity
|
||||
) -> Optional[ApiBrainDefinitionEntity]:
|
||||
return self.repository.update_api_brain_definition(
|
||||
brain_id, api_brain_definition
|
||||
)
|
@ -8,11 +8,7 @@ from quivr_api.modules.brain.dto.inputs import (
|
||||
BrainUpdatableProperties,
|
||||
CreateBrainProperties,
|
||||
)
|
||||
from quivr_api.modules.brain.entity.brain_entity import (
|
||||
BrainEntity,
|
||||
BrainType,
|
||||
PublicBrain,
|
||||
)
|
||||
from quivr_api.modules.brain.entity.brain_entity import BrainEntity, BrainType
|
||||
from quivr_api.modules.brain.entity.integration_brain import IntegrationEntity
|
||||
from quivr_api.modules.brain.repository import (
|
||||
Brains,
|
||||
@ -21,24 +17,18 @@ from quivr_api.modules.brain.repository import (
|
||||
IntegrationBrain,
|
||||
IntegrationDescription,
|
||||
)
|
||||
from quivr_api.modules.brain.service.api_brain_definition_service import (
|
||||
ApiBrainDefinitionService,
|
||||
)
|
||||
from quivr_api.modules.brain.service.utils.validate_brain import validate_api_brain
|
||||
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
|
||||
from quivr_api.vectorstore.supabase import CustomSupabaseVectorStore
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
knowledge_service = KnowledgeService()
|
||||
api_brain_definition_service = ApiBrainDefinitionService()
|
||||
|
||||
|
||||
class BrainService:
|
||||
# brain_repository: BrainsInterface
|
||||
# brain_user_repository: BrainsUsersInterface
|
||||
# brain_vector_repository: BrainsVectorsInterface
|
||||
# external_api_secrets_repository: ExternalApiSecretsInterface
|
||||
# integration_brains_repository: IntegrationBrainInterface
|
||||
# integration_description_repository: IntegrationDescriptionInterface
|
||||
|
||||
@ -123,14 +113,7 @@ class BrainService:
|
||||
brain: Optional[CreateBrainProperties],
|
||||
) -> BrainEntity:
|
||||
if brain is None:
|
||||
brain = CreateBrainProperties() # type: ignore model and brain_definition
|
||||
|
||||
if brain.brain_type == BrainType.api:
|
||||
validate_api_brain(brain)
|
||||
return self.create_brain_api(user_id, brain)
|
||||
|
||||
if brain.brain_type == BrainType.composite:
|
||||
return self.create_brain_composite(brain)
|
||||
brain = CreateBrainProperties()
|
||||
|
||||
if brain.brain_type == BrainType.integration:
|
||||
return self.create_brain_integration(user_id, brain)
|
||||
@ -138,46 +121,6 @@ class BrainService:
|
||||
created_brain = self.brain_repository.create_brain(brain)
|
||||
return created_brain
|
||||
|
||||
def create_brain_api(
|
||||
self,
|
||||
user_id: UUID,
|
||||
brain: CreateBrainProperties,
|
||||
) -> BrainEntity:
|
||||
created_brain = self.brain_repository.create_brain(brain)
|
||||
|
||||
if brain.brain_definition is not None:
|
||||
api_brain_definition_service.add_api_brain_definition(
|
||||
brain_id=created_brain.brain_id,
|
||||
api_brain_definition=brain.brain_definition,
|
||||
)
|
||||
|
||||
secrets_values = brain.brain_secrets_values
|
||||
|
||||
for secret_name in secrets_values:
|
||||
self.external_api_secrets_repository.create_secret(
|
||||
user_id=user_id,
|
||||
brain_id=created_brain.brain_id,
|
||||
secret_name=secret_name,
|
||||
secret_value=secrets_values[secret_name],
|
||||
)
|
||||
|
||||
return created_brain
|
||||
|
||||
def create_brain_composite(
|
||||
self,
|
||||
brain: CreateBrainProperties,
|
||||
) -> BrainEntity:
|
||||
created_brain = self.brain_repository.create_brain(brain)
|
||||
|
||||
if brain.connected_brains_ids is not None:
|
||||
for connected_brain_id in brain.connected_brains_ids:
|
||||
self.composite_brains_connections_repository.connect_brain(
|
||||
composite_brain_id=created_brain.brain_id,
|
||||
connected_brain_id=connected_brain_id,
|
||||
)
|
||||
|
||||
return created_brain
|
||||
|
||||
def create_brain_integration(
|
||||
self,
|
||||
user_id: UUID,
|
||||
@ -203,37 +146,11 @@ class BrainService:
|
||||
)
|
||||
return created_brain
|
||||
|
||||
def delete_brain_secrets_values(self, brain_id: UUID) -> None:
|
||||
brain_definition = api_brain_definition_service.get_api_brain_definition(
|
||||
brain_id=brain_id
|
||||
)
|
||||
|
||||
if brain_definition is None:
|
||||
raise HTTPException(status_code=404, detail="Brain definition not found.")
|
||||
|
||||
secrets = brain_definition.secrets
|
||||
|
||||
if len(secrets) > 0:
|
||||
brain_users = self.brain_user_repository.get_brain_users(brain_id=brain_id)
|
||||
for user in brain_users:
|
||||
for secret in secrets:
|
||||
self.external_api_secrets_repository.delete_secret(
|
||||
user_id=user.user_id,
|
||||
brain_id=brain_id,
|
||||
secret_name=secret.name,
|
||||
)
|
||||
|
||||
def delete_brain(self, brain_id: UUID) -> dict[str, str]:
|
||||
brain_to_delete = self.get_brain_by_id(brain_id=brain_id)
|
||||
if brain_to_delete is None:
|
||||
raise HTTPException(status_code=404, detail="Brain not found.")
|
||||
|
||||
if brain_to_delete.brain_type == BrainType.api:
|
||||
self.delete_brain_secrets_values(
|
||||
brain_id=brain_id,
|
||||
)
|
||||
api_brain_definition_service.delete_api_brain_definition(brain_id=brain_id)
|
||||
else:
|
||||
knowledge_service.remove_brain_all_knowledge(brain_id)
|
||||
|
||||
self.brain_vector.delete_brain_vector(str(brain_id))
|
||||
@ -263,9 +180,7 @@ class BrainService:
|
||||
brain_update_answer = self.brain_repository.update_brain_by_id(
|
||||
brain_id,
|
||||
brain=BrainUpdatableProperties(
|
||||
**brain_new_values.dict(
|
||||
exclude={"brain_definition", "connected_brains_ids", "integration"}
|
||||
)
|
||||
**brain_new_values.dict(exclude={"integration"})
|
||||
),
|
||||
)
|
||||
|
||||
@ -275,35 +190,6 @@ class BrainService:
|
||||
detail=f"Brain with id {brain_id} not found",
|
||||
)
|
||||
|
||||
if (
|
||||
brain_update_answer.brain_type == BrainType.api
|
||||
and brain_new_values.brain_definition
|
||||
):
|
||||
existing_brain_secrets_definition = (
|
||||
existing_brain.brain_definition.secrets
|
||||
if existing_brain.brain_definition
|
||||
else None
|
||||
)
|
||||
brain_new_values_secrets_definition = (
|
||||
brain_new_values.brain_definition.secrets
|
||||
if brain_new_values.brain_definition
|
||||
else None
|
||||
)
|
||||
should_remove_existing_secrets_values = (
|
||||
existing_brain_secrets_definition
|
||||
and brain_new_values_secrets_definition
|
||||
and existing_brain_secrets_definition
|
||||
!= brain_new_values_secrets_definition
|
||||
)
|
||||
|
||||
if should_remove_existing_secrets_values:
|
||||
self.delete_brain_secrets_values(brain_id=brain_id)
|
||||
|
||||
api_brain_definition_service.update_api_brain_definition(
|
||||
brain_id,
|
||||
api_brain_definition=brain_new_values.brain_definition,
|
||||
)
|
||||
|
||||
if brain_update_answer is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
@ -345,9 +231,6 @@ class BrainService:
|
||||
brain_id
|
||||
)
|
||||
|
||||
def get_public_brains(self) -> list[PublicBrain]:
|
||||
return self.brain_repository.get_public_brains()
|
||||
|
||||
def update_secret_value(
|
||||
self,
|
||||
user_id: UUID,
|
||||
|
@ -5,43 +5,32 @@ from fastapi import HTTPException
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.modules.brain.entity.brain_entity import (
|
||||
BrainEntity,
|
||||
BrainType,
|
||||
BrainUser,
|
||||
MinimalUserBrainEntity,
|
||||
RoleEnum,
|
||||
)
|
||||
from quivr_api.modules.brain.repository.brains import Brains
|
||||
from quivr_api.modules.brain.repository.brains_users import BrainsUsers
|
||||
from quivr_api.modules.brain.repository.external_api_secrets import ExternalApiSecrets
|
||||
from quivr_api.modules.brain.repository.interfaces.brains_interface import (
|
||||
BrainsInterface,
|
||||
)
|
||||
from quivr_api.modules.brain.repository.interfaces.brains_users_interface import (
|
||||
BrainsUsersInterface,
|
||||
)
|
||||
from quivr_api.modules.brain.repository.interfaces.external_api_secrets_interface import (
|
||||
ExternalApiSecretsInterface,
|
||||
)
|
||||
from quivr_api.modules.brain.service.api_brain_definition_service import (
|
||||
ApiBrainDefinitionService,
|
||||
)
|
||||
from quivr_api.modules.brain.service.brain_service import BrainService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
brain_service = BrainService()
|
||||
api_brain_definition_service = ApiBrainDefinitionService()
|
||||
|
||||
|
||||
class BrainUserService:
|
||||
brain_repository: BrainsInterface
|
||||
brain_user_repository: BrainsUsersInterface
|
||||
external_api_secrets_repository: ExternalApiSecretsInterface
|
||||
|
||||
def __init__(self):
|
||||
self.brain_repository = Brains()
|
||||
self.brain_user_repository = BrainsUsers()
|
||||
self.external_api_secrets_repository = ExternalApiSecrets()
|
||||
|
||||
def get_user_default_brain(self, user_id: UUID) -> BrainEntity | None:
|
||||
brain_id = self.brain_user_repository.get_user_default_brain_id(user_id)
|
||||
@ -56,22 +45,6 @@ class BrainUserService:
|
||||
if brain_to_delete_user_from is None:
|
||||
raise HTTPException(status_code=404, detail="Brain not found.")
|
||||
|
||||
if brain_to_delete_user_from.brain_type == BrainType.api:
|
||||
brain_definition = api_brain_definition_service.get_api_brain_definition(
|
||||
brain_id=brain_id
|
||||
)
|
||||
if brain_definition is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Brain definition not found."
|
||||
)
|
||||
secrets = brain_definition.secrets
|
||||
for secret in secrets:
|
||||
self.external_api_secrets_repository.delete_secret(
|
||||
user_id=user_id,
|
||||
brain_id=brain_id,
|
||||
secret_name=secret.name,
|
||||
)
|
||||
|
||||
self.brain_user_repository.delete_brain_user_by_id(
|
||||
user_id=user_id,
|
||||
brain_id=brain_id,
|
||||
|
@ -1,116 +0,0 @@
|
||||
from uuid import UUID
|
||||
|
||||
import requests
|
||||
from quivr_api.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
from fastapi import HTTPException
|
||||
from quivr_api.modules.brain.entity.api_brain_definition_entity import (
|
||||
ApiBrainDefinitionSchema,
|
||||
)
|
||||
from quivr_api.modules.brain.service.api_brain_definition_service import (
|
||||
ApiBrainDefinitionService,
|
||||
)
|
||||
from quivr_api.modules.brain.service.brain_service import BrainService
|
||||
|
||||
brain_service = BrainService()
|
||||
api_brain_definition_service = ApiBrainDefinitionService()
|
||||
|
||||
|
||||
def get_api_call_response_as_text(
|
||||
method, api_url, params, search_params, secrets
|
||||
) -> str:
|
||||
headers = {}
|
||||
|
||||
api_url_with_search_params = api_url
|
||||
if search_params:
|
||||
api_url_with_search_params += "?"
|
||||
for search_param in search_params:
|
||||
api_url_with_search_params += (
|
||||
f"{search_param}={search_params[search_param]}&"
|
||||
)
|
||||
|
||||
for secret in secrets:
|
||||
headers[secret] = secrets[secret]
|
||||
|
||||
try:
|
||||
if method in ["GET", "DELETE"]:
|
||||
response = requests.request(
|
||||
method,
|
||||
url=api_url_with_search_params,
|
||||
params=params or None,
|
||||
headers=headers or None,
|
||||
)
|
||||
elif method in ["POST", "PUT", "PATCH"]:
|
||||
response = requests.request(
|
||||
method,
|
||||
url=api_url_with_search_params,
|
||||
json=params or None,
|
||||
headers=headers or None,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid method: {method}")
|
||||
|
||||
return response.text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling API: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def extract_api_brain_definition_values_from_llm_output(
|
||||
brain_schema: ApiBrainDefinitionSchema, arguments: dict
|
||||
) -> dict:
|
||||
params_values = {}
|
||||
properties = brain_schema.properties
|
||||
required_values = brain_schema.required
|
||||
for property in properties:
|
||||
if property.name in arguments:
|
||||
if property.type == "number":
|
||||
params_values[property.name] = float(arguments[property.name])
|
||||
else:
|
||||
params_values[property.name] = arguments[property.name]
|
||||
continue
|
||||
|
||||
if property.name in required_values:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Required parameter {property.name} not found in arguments",
|
||||
)
|
||||
|
||||
return params_values
|
||||
|
||||
|
||||
def call_brain_api(brain_id: UUID, user_id: UUID, arguments: dict) -> str:
|
||||
brain_definition = api_brain_definition_service.get_api_brain_definition(brain_id)
|
||||
|
||||
if brain_definition is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Brain definition {brain_id} not found"
|
||||
)
|
||||
|
||||
brain_params_values = extract_api_brain_definition_values_from_llm_output(
|
||||
brain_definition.params, arguments
|
||||
)
|
||||
|
||||
brain_search_params_values = extract_api_brain_definition_values_from_llm_output(
|
||||
brain_definition.search_params, arguments
|
||||
)
|
||||
|
||||
secrets = brain_definition.secrets
|
||||
secrets_values = {}
|
||||
|
||||
for secret in secrets:
|
||||
secret_value = brain_service.external_api_secrets_repository.read_secret(
|
||||
user_id=user_id, brain_id=brain_id, secret_name=secret.name
|
||||
)
|
||||
secrets_values[secret.name] = secret_value
|
||||
|
||||
return get_api_call_response_as_text(
|
||||
api_url=brain_definition.url,
|
||||
params=brain_params_values,
|
||||
search_params=brain_search_params_values,
|
||||
secrets=secrets_values,
|
||||
method=brain_definition.method,
|
||||
)
|
@ -1,64 +0,0 @@
|
||||
import re
|
||||
|
||||
from fastapi import HTTPException
|
||||
from quivr_api.modules.brain.entity.api_brain_definition_entity import (
|
||||
ApiBrainDefinitionSchemaProperty,
|
||||
)
|
||||
from quivr_api.modules.brain.entity.brain_entity import BrainEntity
|
||||
from quivr_api.modules.brain.service.api_brain_definition_service import (
|
||||
ApiBrainDefinitionService,
|
||||
)
|
||||
|
||||
api_brain_definition_service = ApiBrainDefinitionService()
|
||||
|
||||
|
||||
def sanitize_function_name(string):
|
||||
sanitized_string = re.sub(r"[^a-zA-Z0-9_-]", "", string)
|
||||
|
||||
return sanitized_string
|
||||
|
||||
|
||||
def format_api_brain_property(property: ApiBrainDefinitionSchemaProperty):
|
||||
property_data: dict = {
|
||||
"type": property.type,
|
||||
"description": property.description,
|
||||
}
|
||||
if property.enum:
|
||||
property_data["enum"] = property.enum
|
||||
return property_data
|
||||
|
||||
|
||||
def get_api_brain_definition_as_json_schema(brain: BrainEntity):
|
||||
api_brain_definition = api_brain_definition_service.get_api_brain_definition(
|
||||
brain.id
|
||||
)
|
||||
if not api_brain_definition:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Brain definition {brain.id} not found"
|
||||
)
|
||||
|
||||
required = []
|
||||
required.extend(api_brain_definition.params.required)
|
||||
required.extend(api_brain_definition.search_params.required)
|
||||
properties = {}
|
||||
|
||||
api_properties = (
|
||||
api_brain_definition.params.properties
|
||||
+ api_brain_definition.search_params.properties
|
||||
)
|
||||
|
||||
for property in api_properties:
|
||||
properties[property.name] = format_api_brain_property(property)
|
||||
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required,
|
||||
}
|
||||
schema = {
|
||||
"name": sanitize_function_name(brain.name),
|
||||
"description": brain.description,
|
||||
"parameters": parameters,
|
||||
}
|
||||
|
||||
return schema
|
@ -1 +0,0 @@
|
||||
from .validate_brain import validate_api_brain
|
@ -1,13 +1,2 @@
|
||||
from fastapi import HTTPException
|
||||
from quivr_api.modules.brain.dto.inputs import CreateBrainProperties
|
||||
|
||||
|
||||
def validate_api_brain(brain: CreateBrainProperties):
|
||||
if brain.brain_definition is None:
|
||||
raise HTTPException(status_code=404, detail="Brain definition not found")
|
||||
|
||||
if brain.brain_definition.url is None:
|
||||
raise HTTPException(status_code=404, detail="Brain url not found")
|
||||
|
||||
if brain.brain_definition.method is None:
|
||||
raise HTTPException(status_code=404, detail="Brain method not found")
|
||||
|
@ -1,111 +0,0 @@
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.modules.brain.entity.brain_entity import BrainType, RoleEnum
|
||||
from quivr_api.modules.brain.integrations.Big.Brain import BigBrain
|
||||
from quivr_api.modules.brain.integrations.GPT4.Brain import GPT4Brain
|
||||
from quivr_api.modules.brain.integrations.Multi_Contract.Brain import MultiContractBrain
|
||||
from quivr_api.modules.brain.integrations.Notion.Brain import NotionBrain
|
||||
from quivr_api.modules.brain.integrations.Proxy.Brain import ProxyBrain
|
||||
from quivr_api.modules.brain.integrations.Self.Brain import SelfBrain
|
||||
from quivr_api.modules.brain.integrations.SQL.Brain import SQLBrain
|
||||
from quivr_api.modules.brain.knowledge_brain_qa import KnowledgeBrainQA
|
||||
from quivr_api.modules.brain.service.api_brain_definition_service import (
|
||||
ApiBrainDefinitionService,
|
||||
)
|
||||
from quivr_api.modules.brain.service.brain_authorization_service import (
|
||||
validate_brain_authorization,
|
||||
)
|
||||
from quivr_api.modules.brain.service.brain_service import BrainService
|
||||
from quivr_api.modules.brain.service.integration_brain_service import (
|
||||
IntegrationBrainDescriptionService,
|
||||
)
|
||||
from quivr_api.modules.chat.controller.chat.interface import ChatInterface
|
||||
from quivr_api.modules.chat.service.chat_service import ChatService
|
||||
from quivr_api.modules.dependencies import get_service
|
||||
|
||||
chat_service = get_service(ChatService)()
|
||||
api_brain_definition_service = ApiBrainDefinitionService()
|
||||
integration_brain_description_service = IntegrationBrainDescriptionService()
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
models_supporting_function_calls = [
|
||||
"gpt-4",
|
||||
"gpt-4-1106-preview",
|
||||
"gpt-4-0613",
|
||||
"gpt-3.5-turbo-0125",
|
||||
"gpt-3.5-turbo-1106",
|
||||
"gpt-3.5-turbo-0613",
|
||||
"gpt-4-0125-preview",
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-4-turbo",
|
||||
"gpt-4o",
|
||||
]
|
||||
|
||||
|
||||
integration_list = {
|
||||
"notion": NotionBrain,
|
||||
"gpt4": GPT4Brain,
|
||||
"sql": SQLBrain,
|
||||
"big": BigBrain,
|
||||
"doc": KnowledgeBrainQA,
|
||||
"proxy": ProxyBrain,
|
||||
"self": SelfBrain,
|
||||
"multi-contract": MultiContractBrain,
|
||||
}
|
||||
|
||||
brain_service = BrainService()
|
||||
|
||||
|
||||
def validate_authorization(user_id, brain_id):
|
||||
if brain_id:
|
||||
validate_brain_authorization(
|
||||
brain_id=brain_id,
|
||||
user_id=user_id,
|
||||
required_roles=[RoleEnum.Viewer, RoleEnum.Editor, RoleEnum.Owner],
|
||||
)
|
||||
|
||||
|
||||
# TODO: redo this
|
||||
class BrainfulChat(ChatInterface):
|
||||
def get_answer_generator(
|
||||
self,
|
||||
brain,
|
||||
chat_id,
|
||||
chat_service,
|
||||
model,
|
||||
temperature,
|
||||
streaming,
|
||||
prompt_id,
|
||||
user_id,
|
||||
user_email,
|
||||
):
|
||||
if brain and brain.brain_type == BrainType.doc:
|
||||
return KnowledgeBrainQA(
|
||||
chat_service=chat_service,
|
||||
chat_id=chat_id,
|
||||
brain_id=str(brain.brain_id),
|
||||
streaming=streaming,
|
||||
prompt_id=prompt_id,
|
||||
user_id=user_id,
|
||||
user_email=user_email,
|
||||
)
|
||||
|
||||
if brain.brain_type == BrainType.integration:
|
||||
integration_brain = integration_brain_description_service.get_integration_description_by_user_brain_id(
|
||||
brain.brain_id, user_id
|
||||
)
|
||||
|
||||
integration_class = integration_list.get(
|
||||
integration_brain.integration_name.lower()
|
||||
)
|
||||
if integration_class:
|
||||
return integration_class(
|
||||
chat_service=chat_service,
|
||||
chat_id=chat_id,
|
||||
temperature=temperature,
|
||||
brain_id=str(brain.brain_id),
|
||||
streaming=streaming,
|
||||
prompt_id=prompt_id,
|
||||
user_id=user_id,
|
||||
user_email=user_email,
|
||||
)
|
@ -1,26 +0,0 @@
|
||||
from llm.qa_headless import HeadlessQA
|
||||
from quivr_api.modules.chat.controller.chat.interface import ChatInterface
|
||||
|
||||
|
||||
class BrainlessChat(ChatInterface):
|
||||
def validate_authorization(self, user_id, brain_id):
|
||||
pass
|
||||
|
||||
def get_answer_generator(
|
||||
self,
|
||||
chat_id,
|
||||
model,
|
||||
max_tokens,
|
||||
temperature,
|
||||
streaming,
|
||||
prompt_id,
|
||||
user_id,
|
||||
):
|
||||
return HeadlessQA(
|
||||
chat_id=chat_id,
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
streaming=streaming,
|
||||
prompt_id=prompt_id,
|
||||
)
|
@ -1,11 +0,0 @@
|
||||
from uuid import UUID
|
||||
|
||||
from .brainful_chat import BrainfulChat
|
||||
from .brainless_chat import BrainlessChat
|
||||
|
||||
|
||||
def get_chat_strategy(brain_id: UUID | None = None):
|
||||
if brain_id:
|
||||
return BrainfulChat()
|
||||
else:
|
||||
return BrainlessChat()
|
@ -5,25 +5,26 @@ from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.middlewares.auth import AuthBearer, get_current_user
|
||||
from quivr_api.models.settings import get_embedding_client, get_supabase_client
|
||||
from quivr_api.modules.brain.entity.brain_entity import RoleEnum
|
||||
from quivr_api.modules.brain.service.brain_authorization_service import (
|
||||
validate_brain_authorization,
|
||||
)
|
||||
from quivr_api.modules.brain.service.brain_service import BrainService
|
||||
from quivr_api.modules.chat.controller.chat.brainful_chat import (
|
||||
BrainfulChat, validate_authorization)
|
||||
from quivr_api.modules.chat.dto.chats import ChatItem, ChatQuestion
|
||||
from quivr_api.modules.chat.dto.inputs import (ChatMessageProperties,
|
||||
ChatUpdatableProperties,
|
||||
CreateChatProperties,
|
||||
QuestionAndAnswer)
|
||||
from quivr_api.modules.chat.dto.inputs import (
|
||||
ChatMessageProperties,
|
||||
ChatUpdatableProperties,
|
||||
CreateChatProperties,
|
||||
QuestionAndAnswer,
|
||||
)
|
||||
from quivr_api.modules.chat.entity.chat import Chat
|
||||
from quivr_api.modules.chat.service.chat_service import ChatService
|
||||
from quivr_api.modules.dependencies import get_service
|
||||
from quivr_api.modules.knowledge.repository.knowledges import \
|
||||
KnowledgeRepository
|
||||
from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
|
||||
from quivr_api.modules.prompt.service.prompt_service import PromptService
|
||||
from quivr_api.modules.rag_service import RAGService
|
||||
from quivr_api.modules.user.entity.user_identity import UserIdentity
|
||||
from quivr_api.packages.utils.telemetry import maybe_send_telemetry
|
||||
from quivr_api.vectorstore.supabase import CustomSupabaseVectorStore
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@ -37,52 +38,13 @@ ChatServiceDep = Annotated[ChatService, Depends(get_service(ChatService))]
|
||||
UserIdentityDep = Annotated[UserIdentity, Depends(get_current_user)]
|
||||
|
||||
|
||||
def init_vector_store(user_id: UUID) -> CustomSupabaseVectorStore:
|
||||
"""
|
||||
Initialize the vector store
|
||||
"""
|
||||
supabase_client = get_supabase_client()
|
||||
embedding_service = get_embedding_client()
|
||||
vector_store = CustomSupabaseVectorStore(
|
||||
supabase_client, embedding_service, table_name="vectors", user_id=user_id
|
||||
)
|
||||
|
||||
return vector_store
|
||||
|
||||
|
||||
async def get_answer_generator(
|
||||
chat_id: UUID,
|
||||
chat_question: ChatQuestion,
|
||||
chat_service: ChatService,
|
||||
brain_id: UUID | None,
|
||||
current_user: UserIdentity,
|
||||
):
|
||||
chat_instance = BrainfulChat()
|
||||
vector_store = init_vector_store(user_id=current_user.id)
|
||||
|
||||
# Get History only if needed
|
||||
if not brain_id:
|
||||
history = await chat_service.get_chat_history(chat_id)
|
||||
else:
|
||||
history = []
|
||||
|
||||
# TODO(@aminediro) : NOT USED anymore
|
||||
brain, metadata_brain = brain_service.find_brain_from_question(
|
||||
brain_id, chat_question.question, current_user, chat_id, history, vector_store
|
||||
)
|
||||
gpt_answer_generator = chat_instance.get_answer_generator(
|
||||
brain=brain,
|
||||
chat_id=str(chat_id),
|
||||
chat_service=chat_service,
|
||||
model=brain.model,
|
||||
temperature=0.1,
|
||||
streaming=True,
|
||||
prompt_id=chat_question.prompt_id,
|
||||
user_id=current_user.id,
|
||||
user_email=current_user.email,
|
||||
)
|
||||
|
||||
return gpt_answer_generator
|
||||
def validate_authorization(user_id, brain_id):
|
||||
if brain_id:
|
||||
validate_brain_authorization(
|
||||
brain_id=brain_id,
|
||||
user_id=user_id,
|
||||
required_roles=[RoleEnum.Viewer, RoleEnum.Editor, RoleEnum.Owner],
|
||||
)
|
||||
|
||||
|
||||
@chat_router.get("/chat/healthz", tags=["Health"])
|
||||
|
@ -14,7 +14,6 @@ class ChatMessage(BaseModel):
|
||||
history: List[Tuple[str, str]]
|
||||
temperature: float = 0.0
|
||||
max_tokens: int = 256
|
||||
use_summarization: bool = False
|
||||
chat_id: Optional[UUID] = None
|
||||
chat_name: Optional[str] = None
|
||||
|
||||
|
@ -1 +0,0 @@
|
||||
from .contact_routes import contact_router
|
@ -1,42 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.modules.contact_support.controller.settings import ContactsSettings
|
||||
from quivr_api.packages.emails.send_email import send_email
|
||||
|
||||
|
||||
class ContactMessage(BaseModel):
|
||||
customer_email: str
|
||||
content: str
|
||||
|
||||
|
||||
contact_router = APIRouter()
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def resend_contact_sales_email(customer_email: str, content: str):
|
||||
settings = ContactsSettings()
|
||||
mail_from = settings.resend_contact_sales_from
|
||||
mail_to = settings.resend_contact_sales_to
|
||||
body = f"""
|
||||
<p>Customer email: {customer_email}</p>
|
||||
<p>{content}</p>
|
||||
"""
|
||||
params = {
|
||||
"from": mail_from,
|
||||
"to": mail_to,
|
||||
"subject": "Contact sales",
|
||||
"reply_to": customer_email,
|
||||
"html": body,
|
||||
}
|
||||
|
||||
return send_email(params)
|
||||
|
||||
|
||||
@contact_router.post("/contact")
|
||||
def post_contact(message: ContactMessage):
|
||||
try:
|
||||
resend_contact_sales_email(message.customer_email, message.content)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return {"error": "There was an error sending the email"}
|
@ -1,7 +0,0 @@
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class ContactsSettings(BaseSettings):
|
||||
model_config = SettingsConfigDict(validate_default=False)
|
||||
resend_contact_sales_from: str = "null"
|
||||
resend_contact_sales_to: str = "null"
|
@ -11,8 +11,7 @@ from langchain_community.document_loaders import PlaywrightURLLoader
|
||||
from langchain_core.tools import BaseTool
|
||||
from pydantic import BaseModel
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.models.settings import BrainSettings
|
||||
from quivr_api.modules.contact_support.controller.settings import ContactsSettings
|
||||
from quivr_api.models.settings import BrainSettings, SendEmailSettings
|
||||
from quivr_api.packages.emails.send_email import send_email
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@ -32,7 +31,7 @@ class EmailSenderTool(BaseTool):
|
||||
description = "useful for when you need to send an email."
|
||||
args_schema: Type[BaseModel] = EmailInput
|
||||
brain_settings: BrainSettings = BrainSettings()
|
||||
contact_settings: ContactsSettings = ContactsSettings()
|
||||
contact_settings: SendEmailSettings = SendEmailSettings()
|
||||
|
||||
def _run(
|
||||
self, text: str, run_manager: Optional[CallbackManagerForToolRun] = None
|
||||
|
@ -33,11 +33,6 @@ notification_service = NotificationService()
|
||||
knowledge_service = KnowledgeService()
|
||||
|
||||
|
||||
@upload_router.get("/upload/healthz", tags=["Health"])
|
||||
async def healthz():
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@upload_router.post("/upload", dependencies=[Depends(AuthBearer())], tags=["Upload"])
|
||||
async def upload_file(
|
||||
uploadFile: UploadFile,
|
||||
|
@ -41,7 +41,6 @@ class Users(UsersInterface):
|
||||
|
||||
user_identity = response.data[0]
|
||||
|
||||
print("USER_IDENTITY", user_identity)
|
||||
return UserIdentity(id=user_id)
|
||||
|
||||
def get_user_identity(self, user_id):
|
||||
|
@ -1,202 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
import msal
|
||||
import requests
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
CLIENT_ID = "511dce23-02f3-4724-8684-05da226df5f3"
|
||||
AUTHORITY = "https://login.microsoftonline.com/common"
|
||||
REDIRECT_URI = "http://localhost:8000/oauth2callback"
|
||||
SCOPE = [
|
||||
"https://graph.microsoft.com/Files.Read",
|
||||
"https://graph.microsoft.com/User.Read",
|
||||
"https://graph.microsoft.com/Sites.Read.All",
|
||||
]
|
||||
|
||||
client = msal.PublicClientApplication(CLIENT_ID, authority=AUTHORITY)
|
||||
|
||||
|
||||
def get_token_data():
|
||||
if not os.path.exists("azure_token.json"):
|
||||
raise HTTPException(status_code=401, detail="User not authenticated")
|
||||
with open("azure_token.json", "r") as token_file:
|
||||
token_data = json.load(token_file)
|
||||
if "access_token" not in token_data:
|
||||
raise HTTPException(status_code=401, detail="Invalid token data")
|
||||
return token_data
|
||||
|
||||
|
||||
def refresh_token():
|
||||
if not os.path.exists("azure_token.json"):
|
||||
raise HTTPException(status_code=401, detail="User not authenticated")
|
||||
with open("azure_token.json", "r") as token_file:
|
||||
token_data = json.load(token_file)
|
||||
if "refresh_token" not in token_data:
|
||||
raise HTTPException(status_code=401, detail="No refresh token available")
|
||||
|
||||
result = client.acquire_token_by_refresh_token(
|
||||
token_data["refresh_token"], scopes=SCOPE
|
||||
)
|
||||
if "access_token" not in result:
|
||||
raise HTTPException(status_code=400, detail="Failed to refresh token")
|
||||
|
||||
with open("azure_token.json", "w") as token:
|
||||
json.dump(result, token)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_headers(token_data):
|
||||
return {
|
||||
"Authorization": f"Bearer {token_data['access_token']}",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
|
||||
@app.get("/authorize")
|
||||
def authorize():
|
||||
authorization_url = client.get_authorization_request_url(
|
||||
scopes=SCOPE, redirect_uri=REDIRECT_URI
|
||||
)
|
||||
return JSONResponse(content={"authorization_url": authorization_url})
|
||||
|
||||
|
||||
@app.get("/oauth2callback")
|
||||
def oauth2callback(request: Request):
|
||||
code = request.query_params.get("code")
|
||||
if not code:
|
||||
raise HTTPException(status_code=400, detail="Authorization code not found")
|
||||
|
||||
result = client.acquire_token_by_authorization_code(
|
||||
code, scopes=SCOPE, redirect_uri=REDIRECT_URI
|
||||
)
|
||||
if "access_token" not in result:
|
||||
print(f"Token acquisition failed: {result}")
|
||||
raise HTTPException(status_code=400, detail="Failed to acquire token")
|
||||
|
||||
with open("azure_token.json", "w") as token:
|
||||
json.dump(result, token)
|
||||
|
||||
return JSONResponse(content={"message": "Authentication successful"})
|
||||
|
||||
|
||||
@app.get("/list_sites")
|
||||
def list_sites(token_data: dict = Depends(get_token_data)):
|
||||
headers = get_headers(token_data)
|
||||
endpoint = "https://graph.microsoft.com/v1.0/sites?search=*"
|
||||
response = requests.get(endpoint, headers=headers)
|
||||
if response.status_code == 401:
|
||||
token_data = refresh_token()
|
||||
headers = get_headers(token_data)
|
||||
response = requests.get(endpoint, headers=headers)
|
||||
if response.status_code != 200:
|
||||
raise HTTPException(status_code=response.status_code, detail=response.text)
|
||||
sites = response.json().get("value", [])
|
||||
return JSONResponse(content={"sites": sites})
|
||||
|
||||
|
||||
def extract_files_and_folders(items, headers, page_size):
|
||||
result = []
|
||||
for item in items:
|
||||
entry = {
|
||||
"name": item.get("name"),
|
||||
"id": item.get("id"),
|
||||
"parentReference": item.get("parentReference"),
|
||||
"lastModifiedDateTime": item.get("lastModifiedDateTime"),
|
||||
"webUrl": item.get("webUrl"),
|
||||
"size": item.get("size"),
|
||||
"fileSystemInfo": item.get("fileSystemInfo"),
|
||||
"folder": item.get("folder"),
|
||||
"file": item.get("file"),
|
||||
}
|
||||
if "folder" in item:
|
||||
folder_endpoint = f"https://graph.microsoft.com/v1.0/me/drive/items/{item['id']}/children?$top={page_size}"
|
||||
children = []
|
||||
while folder_endpoint:
|
||||
folder_response = requests.get(folder_endpoint, headers=headers)
|
||||
if folder_response.status_code == 200:
|
||||
children_page = folder_response.json().get("value", [])
|
||||
children.extend(children_page)
|
||||
folder_endpoint = folder_response.json().get(
|
||||
"@odata.nextLink", None
|
||||
)
|
||||
else:
|
||||
break
|
||||
entry["children"] = extract_files_and_folders(children, headers, page_size)
|
||||
result.append(entry)
|
||||
return result
|
||||
|
||||
|
||||
def fetch_all_files(headers, page_size):
|
||||
endpoint = (
|
||||
f"https://graph.microsoft.com/v1.0/me/drive/root/children?$top={page_size}"
|
||||
)
|
||||
all_files = []
|
||||
while endpoint:
|
||||
response = requests.get(endpoint, headers=headers)
|
||||
if response.status_code == 401:
|
||||
token_data = refresh_token()
|
||||
headers = get_headers(token_data)
|
||||
response = requests.get(endpoint, headers=headers)
|
||||
if response.status_code != 200:
|
||||
raise HTTPException(status_code=response.status_code, detail=response.text)
|
||||
files = response.json().get("value", [])
|
||||
all_files.extend(files)
|
||||
endpoint = response.json().get("@odata.nextLink", None)
|
||||
return all_files
|
||||
|
||||
|
||||
@app.get("/list_files")
|
||||
def list_files(page_size: int = 1, token_data: dict = Depends(get_token_data)):
|
||||
headers = get_headers(token_data)
|
||||
all_files = fetch_all_files(headers, page_size)
|
||||
structured_files = extract_files_and_folders(all_files, headers, page_size)
|
||||
return JSONResponse(content={"files": structured_files})
|
||||
|
||||
|
||||
@app.get("/download_file/{file_id}")
|
||||
def download_file(file_id: str, token_data: dict = Depends(get_token_data)):
|
||||
headers = get_headers(token_data)
|
||||
metadata_endpoint = f"https://graph.microsoft.com/v1.0/me/drive/items/{file_id}"
|
||||
metadata_response = requests.get(metadata_endpoint, headers=headers)
|
||||
if metadata_response.status_code == 401:
|
||||
token_data = refresh_token()
|
||||
headers = get_headers(token_data)
|
||||
metadata_response = requests.get(metadata_endpoint, headers=headers)
|
||||
if metadata_response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=metadata_response.status_code, detail=metadata_response.text
|
||||
)
|
||||
metadata = metadata_response.json()
|
||||
if "folder" in metadata:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="The specified ID is a folder, not a file"
|
||||
)
|
||||
download_endpoint = (
|
||||
f"https://graph.microsoft.com/v1.0/me/drive/items/{file_id}/content"
|
||||
)
|
||||
download_response = requests.get(download_endpoint, headers=headers, stream=True)
|
||||
if download_response.status_code == 401:
|
||||
token_data = refresh_token()
|
||||
headers = get_headers(token_data)
|
||||
download_response = requests.get(
|
||||
download_endpoint, headers=headers, stream=True
|
||||
)
|
||||
if download_response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=download_response.status_code, detail=download_response.text
|
||||
)
|
||||
return StreamingResponse(
|
||||
download_response.iter_content(chunk_size=1024),
|
||||
headers={"Content-Disposition": f"attachment; filename={metadata.get('name')}"},
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
@ -1,91 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from google.auth.transport.requests import Request as GoogleRequest
|
||||
from google.oauth2.credentials import Credentials
|
||||
from google_auth_oauthlib.flow import Flow
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
SCOPES = ["https://www.googleapis.com/auth/drive.metadata.readonly"]
|
||||
CLIENT_SECRETS_FILE = "credentials.json"
|
||||
REDIRECT_URI = "http://localhost:8000/oauth2callback"
|
||||
|
||||
# Disable OAuthlib's HTTPS verification when running locally.
|
||||
os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1"
|
||||
|
||||
|
||||
@app.get("/authorize")
|
||||
def authorize():
|
||||
flow = Flow.from_client_secrets_file(
|
||||
CLIENT_SECRETS_FILE, scopes=SCOPES, redirect_uri=REDIRECT_URI
|
||||
)
|
||||
authorization_url, state = flow.authorization_url(
|
||||
access_type="offline", include_granted_scopes="true"
|
||||
)
|
||||
# Store the state in session to validate the callback later
|
||||
with open("state.json", "w") as state_file:
|
||||
json.dump({"state": state}, state_file)
|
||||
return JSONResponse(content={"authorization_url": authorization_url})
|
||||
|
||||
|
||||
@app.get("/oauth2callback")
|
||||
def oauth2callback(request: Request):
|
||||
state = request.query_params.get("state")
|
||||
with open("state.json", "r") as state_file:
|
||||
saved_state = json.load(state_file)["state"]
|
||||
|
||||
if state != saved_state:
|
||||
raise HTTPException(status_code=400, detail="Invalid state parameter")
|
||||
|
||||
flow = Flow.from_client_secrets_file(
|
||||
CLIENT_SECRETS_FILE, scopes=SCOPES, state=state, redirect_uri=REDIRECT_URI
|
||||
)
|
||||
flow.fetch_token(authorization_response=str(request.url))
|
||||
creds = flow.credentials
|
||||
|
||||
# Save the credentials for future use
|
||||
with open("token.json", "w") as token:
|
||||
token.write(creds.to_json())
|
||||
|
||||
return JSONResponse(content={"message": "Authentication successful"})
|
||||
|
||||
|
||||
@app.get("/list_files")
|
||||
def list_files():
|
||||
creds = None
|
||||
if os.path.exists("token.json"):
|
||||
creds = Credentials.from_authorized_user_file("token.json", SCOPES)
|
||||
|
||||
if not creds or not creds.valid:
|
||||
if creds and creds.expired and creds.refresh_token:
|
||||
creds.refresh(GoogleRequest())
|
||||
else:
|
||||
raise HTTPException(status_code=401, detail="Credentials are not valid")
|
||||
|
||||
try:
|
||||
service = build("drive", "v3", credentials=creds)
|
||||
results = (
|
||||
service.files()
|
||||
.list(pageSize=10, fields="nextPageToken, files(id, name)")
|
||||
.execute()
|
||||
)
|
||||
items = results.get("files", [])
|
||||
|
||||
if not items:
|
||||
return JSONResponse(content={"files": "No files found."})
|
||||
|
||||
files = [{"name": item["name"], "id": item["id"]} for item in items]
|
||||
return JSONResponse(content={"files": files})
|
||||
except HttpError as error:
|
||||
raise HTTPException(status_code=500, detail=f"An error occurred: {error}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
@ -8,9 +8,6 @@ from quivr_api.middlewares.auth.auth_bearer import AuthBearer, get_current_user
|
||||
from quivr_api.models.brains_subscription_invitations import BrainSubscription
|
||||
from quivr_api.modules.brain.entity.brain_entity import RoleEnum
|
||||
from quivr_api.modules.brain.repository import IntegrationBrain
|
||||
from quivr_api.modules.brain.service.api_brain_definition_service import (
|
||||
ApiBrainDefinitionService,
|
||||
)
|
||||
from quivr_api.modules.brain.service.brain_authorization_service import (
|
||||
has_brain_authorization,
|
||||
validate_brain_authorization,
|
||||
@ -35,7 +32,6 @@ user_service = UserService()
|
||||
prompt_service = PromptService()
|
||||
brain_user_service = BrainUserService()
|
||||
brain_service = BrainService()
|
||||
api_brain_definition_service = ApiBrainDefinitionService()
|
||||
integration_brains_repository = IntegrationBrain()
|
||||
|
||||
|
||||
@ -425,26 +421,6 @@ async def subscribe_to_brain_handler(
|
||||
status_code=403,
|
||||
detail="You are already subscribed to this brain",
|
||||
)
|
||||
if brain.brain_type == "api":
|
||||
brain_definition = api_brain_definition_service.get_api_brain_definition(
|
||||
brain_id
|
||||
)
|
||||
brain_secrets = brain_definition.secrets if brain_definition != None else []
|
||||
|
||||
for secret in brain_secrets:
|
||||
if not secrets[secret.name]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Please provide the secret {secret}",
|
||||
)
|
||||
|
||||
for secret in brain_secrets:
|
||||
brain_service.external_api_secrets_repository.create_secret(
|
||||
user_id=current_user.id,
|
||||
brain_id=brain_id,
|
||||
secret_name=secret.name,
|
||||
secret_value=secrets[secret.name],
|
||||
)
|
||||
|
||||
try:
|
||||
brain_user_service.create_brain_user(
|
||||
|
@ -142,7 +142,6 @@ def generate_replies(
|
||||
contexts = []
|
||||
test_questions = test_data.question.tolist()
|
||||
test_groundtruths = test_data.ground_truth.tolist()
|
||||
thoughts = []
|
||||
|
||||
for question in test_questions:
|
||||
response = brain_chain.invoke({"question": question, "chat_history": []})
|
||||
@ -151,14 +150,12 @@ def generate_replies(
|
||||
]["arguments"]
|
||||
cited_answer_obj = json.loads(cited_answer_data)
|
||||
answers.append(cited_answer_obj["answer"])
|
||||
thoughts.append(cited_answer_obj["thoughts"])
|
||||
contexts.append([context.page_content for context in response["docs"]])
|
||||
|
||||
return Dataset.from_dict(
|
||||
{
|
||||
"question": test_questions,
|
||||
"answer": answers,
|
||||
"thoughs": thoughts,
|
||||
"contexts": contexts,
|
||||
"ground_truth": test_groundtruths,
|
||||
}
|
||||
|
@ -17,12 +17,6 @@ class cited_answer(BaseModelV1):
|
||||
...,
|
||||
description="The answer to the user question, which is based only on the given sources.",
|
||||
)
|
||||
thoughts: str = FieldV1(
|
||||
...,
|
||||
description="""Description of the thought process, based only on the given sources.
|
||||
Cite the text as much as possible and give the document name it appears in. In the format : 'Doc_name states : cited_text'. Be the most
|
||||
procedural as possible. Write all the steps needed to find the answer until you find it.""",
|
||||
)
|
||||
citations: list[int] = FieldV1(
|
||||
...,
|
||||
description="The integer IDs of the SPECIFIC sources which justify the answer.",
|
||||
@ -63,7 +57,6 @@ class RawRAGResponse(TypedDict):
|
||||
|
||||
class RAGResponseMetadata(BaseModel):
|
||||
citations: list[int] | None = None
|
||||
thoughts: str | list[str] | None = None
|
||||
followup_questions: list[str] | None = None
|
||||
sources: list[Any] | None = None
|
||||
|
||||
|
@ -4,7 +4,6 @@ from typing import Any, List, Tuple, no_type_check
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.messages.ai import AIMessageChunk
|
||||
from langchain_core.prompts import format_document
|
||||
|
||||
from quivr_core.models import (
|
||||
ParsedRAGResponse,
|
||||
QuivrKnowledge,
|
||||
@ -71,10 +70,6 @@ def get_chunk_metadata(
|
||||
followup_questions = gathered_args["followup_questions"]
|
||||
metadata["followup_questions"] = followup_questions
|
||||
|
||||
if "thoughts" in gathered_args:
|
||||
thoughts = gathered_args["thoughts"]
|
||||
metadata["thoughts"] = thoughts
|
||||
|
||||
return RAGResponseMetadata(**metadata)
|
||||
|
||||
|
||||
@ -127,12 +122,8 @@ def parse_response(raw_response: RawRAGResponse, model_name: str) -> ParsedRAGRe
|
||||
followup_questions = raw_response["answer"].tool_calls[-1]["args"][
|
||||
"followup_questions"
|
||||
]
|
||||
thoughts = raw_response["answer"].tool_calls[-1]["args"]["thoughts"]
|
||||
if followup_questions:
|
||||
metadata["followup_questions"] = followup_questions
|
||||
if thoughts:
|
||||
metadata["thoughts"] = thoughts
|
||||
answer = raw_response["answer"].tool_calls[-1]["args"]["answer"]
|
||||
|
||||
parsed_response = ParsedRAGResponse(
|
||||
answer=answer, metadata=RAGResponseMetadata(**metadata)
|
||||
|
@ -1,7 +1,6 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from quivr_core.chat import ChatHistory
|
||||
from quivr_core.config import LLMEndpointConfig, RAGConfig
|
||||
from quivr_core.llm import LLMEndpoint
|
||||
@ -63,7 +62,6 @@ async def test_quivrqarag(
|
||||
# TODO(@aminediro) : test responses with sources
|
||||
assert last_response.metadata.sources == []
|
||||
assert last_response.metadata.citations == []
|
||||
assert last_response.metadata.thoughts and len(last_response.metadata.thoughts) > 0
|
||||
|
||||
# Assert whole response makes sense
|
||||
assert "".join([r.answer for r in stream_responses]) == full_response
|
||||
|
Loading…
Reference in New Issue
Block a user