refactor(backend): cleaning dead and unused code (#1432)
# Description

Please include a summary of the changes and the related issue. Please also include relevant motivation and context.

## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):
This commit is contained in:
parent 1cd99ae234
commit ca1ef8ccbd
@@ -1,11 +1,4 @@
-from .base import BaseBrainPicking
 from .qa_base import QABaseBrainPicking
-from .openai import OpenAIBrainPicking
 from .qa_headless import HeadlessQA
 
-__all__ = [
-    "BaseBrainPicking",
-    "QABaseBrainPicking",
-    "OpenAIBrainPicking",
-    "HeadlessQA"
-]
+__all__ = ["QABaseBrainPicking", "HeadlessQA"]
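With `BaseBrainPicking` and `OpenAIBrainPicking` gone, the package exports only the two surviving classes. A minimal caller-side sketch of the migration (assuming the package is imported as `llm`, as in the route modules below):

```python
# Before this commit (now broken):
# from llm import OpenAIBrainPicking

# After this commit:
from llm import HeadlessQA, QABaseBrainPicking
```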
@@ -1,101 +0,0 @@
-from abc import abstractmethod
-from typing import AsyncIterable, List
-
-from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
-from logger import get_logger
-from models import BrainSettings  # Importing settings related to the 'brain'
-from pydantic import BaseModel  # For data validation and settings management
-
-logger = get_logger(__name__)
-
-
-class BaseBrainPicking(BaseModel):
-    """
-    Base class for BrainPicking. Allows you to interact with LLMs (large language models).
-    Use this class to define abstract methods and the methods and properties common to all child classes.
-    """
-
-    # Instantiate settings
-    brain_settings = BrainSettings()  # type: ignore other parameters are optional
-
-    # Default class attributes
-    model: str = None  # pyright: ignore reportPrivateUsage=none
-    temperature: float = 0.1
-    chat_id: str = None  # pyright: ignore reportPrivateUsage=none
-    brain_id: str = None  # pyright: ignore reportPrivateUsage=none
-    max_tokens: int = 256
-    user_openai_api_key: str = None  # pyright: ignore reportPrivateUsage=none
-    streaming: bool = False
-
-    openai_api_key: str = None  # pyright: ignore reportPrivateUsage=none
-    callbacks: List[
-        AsyncIteratorCallbackHandler
-    ] = None  # pyright: ignore reportPrivateUsage=none
-
-    def _determine_api_key(self, openai_api_key, user_openai_api_key):
-        """If the user provided an API key, use it."""
-        if user_openai_api_key is not None:
-            return user_openai_api_key
-        else:
-            return openai_api_key
-
-    def _determine_streaming(self, model: str, streaming: bool) -> bool:
-        """If the model name allows for streaming and streaming is declared, set streaming to True."""
-        return streaming
-
-    def _determine_callback_array(
-        self, streaming
-    ) -> List[AsyncIteratorCallbackHandler]:  # pyright: ignore reportPrivateUsage=none
-        """If streaming is set, set the AsyncIteratorCallbackHandler as the only callback."""
-        if streaming:
-            return [
-                AsyncIteratorCallbackHandler()  # pyright: ignore reportPrivateUsage=none
-            ]
-
-    def __init__(self, **data):
-        super().__init__(**data)
-
-        self.openai_api_key = self._determine_api_key(
-            self.brain_settings.openai_api_key, self.user_openai_api_key
-        )
-        self.streaming = self._determine_streaming(
-            self.model, self.streaming
-        )  # pyright: ignore reportPrivateUsage=none
-        self.callbacks = self._determine_callback_array(
-            self.streaming
-        )  # pyright: ignore reportPrivateUsage=none
-
-    class Config:
-        """Configuration of the Pydantic Object"""
-
-        # Allowing arbitrary types for class validation
-        arbitrary_types_allowed = True
-
-    # The methods below define the names, arguments and return types of the most useful functions for the child classes. They should be overridden when used.
-    @abstractmethod
-    def generate_answer(self, question: str) -> str:
-        """
-        Generate an answer to a given question using the QA chain.
-        :param question: The question
-        :return: The generated answer.
-
-        This function should also call: _create_qa, get_chat_history and format_chat_history.
-        It should also update the chat_history in the DB.
-        """
-
-    @abstractmethod
-    async def generate_stream(self, question: str) -> AsyncIterable:
-        """
-        Generate a streaming answer to a given question using the QA chain.
-        :param question: The question
-        :return: An async iterable which generates the answer.
-
-        This function also has to:
-        - Update the chat history in the DB with the chat details (chat_id, question) -> return a message_id and timestamp
-        - Use the _acall_chain method inside asyncio.create_task to run the process on a child thread
-        - Append each token to the chat_history object from the DB and yield it from the function
-        - Append each token from the callback to an answer string -> used to update the chat history in the DB (update_message_by_id)
-        """
-        raise NotImplementedError(
-            "Async generation not implemented for this BrainPicking Class."
-        )
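The deleted base class mostly documented the streaming contract its children implement. For reference, a minimal sketch of the LangChain pattern the `generate_stream` docstring describes, with a hypothetical `run_chain` coroutine standing in for the real chain call:

```python
import asyncio

from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler


async def stream_answer(run_chain) -> None:
    # One handler per request; the LLM pushes tokens into it as they arrive.
    callback = AsyncIteratorCallbackHandler()
    # Run the chain on a background task so tokens can be consumed concurrently.
    task = asyncio.create_task(run_chain(callbacks=[callback]))
    async for token in callback.aiter():
        print(token, end="", flush=True)  # the app yields these to the client instead
    await task  # surface any chain errors once streaming ends
```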
@@ -1,12 +0,0 @@
-from typing import Optional
-from typing import Any, Dict
-
-
-class FunctionCall:
-    def __init__(
-        self,
-        name: Optional[str] = None,
-        arguments: Optional[Dict[str, Any]] = None,
-    ):
-        self.name = name
-        self.arguments = arguments
@@ -1,13 +0,0 @@
-from typing import Optional
-
-from .FunctionCall import FunctionCall
-
-
-class OpenAiAnswer:
-    def __init__(
-        self,
-        content: Optional[str] = None,
-        function_call: FunctionCall = None,  # pyright: ignore reportPrivateUsage=none
-    ):
-        self.content = content
-        self.function_call = function_call
@@ -1,50 +0,0 @@
-from typing import Optional
-from uuid import UUID
-
-from langchain.embeddings.openai import OpenAIEmbeddings
-from llm.qa_base import QABaseBrainPicking
-from logger import get_logger
-
-logger = get_logger(__name__)
-
-
-class OpenAIBrainPicking(QABaseBrainPicking):
-    """
-    Main class for the OpenAI Brain Picking functionality.
-    It allows to initialize a Chat model, generate questions and retrieve answers using ConversationalRetrievalChain.
-    """
-
-    # Default class attributes
-    model: str
-
-    def __init__(
-        self,
-        model: str,
-        brain_id: str,
-        temperature: float,
-        chat_id: str,
-        max_tokens: int,
-        user_openai_api_key: str,
-        prompt_id: Optional[UUID],
-        streaming: bool = False,
-    ) -> "OpenAIBrainPicking":  # pyright: ignore reportPrivateUsage=none
-        """
-        Initialize the BrainPicking class by setting embeddings, supabase client, vector store, language model and chains.
-        :return: OpenAIBrainPicking instance
-        """
-        super().__init__(
-            model=model,
-            brain_id=brain_id,
-            chat_id=chat_id,
-            max_tokens=max_tokens,
-            temperature=temperature,
-            user_openai_api_key=user_openai_api_key,
-            streaming=streaming,
-            prompt_id=prompt_id,
-        )
-
-    @property
-    def embeddings(self) -> OpenAIEmbeddings:
-        return OpenAIEmbeddings(
-            openai_api_key=self.openai_api_key
-        )  # pyright: ignore reportPrivateUsage=none
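`OpenAIBrainPicking` only forwarded its arguments to `QABaseBrainPicking.__init__` and added an `embeddings` property, which the qa_base.py hunk below absorbs into the base class. Callers therefore switch classes but keep the same keyword arguments; a hedged sketch of the change (variable names hypothetical):

```python
# Before: answer_generator = OpenAIBrainPicking(...)
answer_generator = QABaseBrainPicking(
    model=model,
    brain_id=brain_id,
    chat_id=chat_id,
    max_tokens=max_tokens,
    temperature=temperature,
    user_openai_api_key=user_openai_api_key,
    streaming=streaming,
    prompt_id=prompt_id,
)
```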
@@ -1,12 +0,0 @@
-from langchain.prompts.prompt import PromptTemplate
-
-prompt_template = """Your name is Quivr. You are a second brain. A person will ask you a question and you will provide a helpful answer. Write the answer in the same language as the question. If you don't know the answer, just say that you don't know. Don't try to make up an answer. Use the following context to answer the question:
-
-
-{context}
-
-Question: {question}
-Helpful Answer:"""
-QA_PROMPT = PromptTemplate(
-    template=prompt_template, input_variables=["context", "question"]
-)
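This hard-coded `QA_PROMPT` was dead code: qa_base.py resolves prompts per brain via `get_prompt_to_use` and falls back to `QUIVR_DEFAULT_PROMPT`. A sketch of that fallback, assuming `get_prompt_to_use` returns `None` when neither the brain nor the request supplies a custom prompt (the helper below is hypothetical):

```python
def resolve_system_prompt(brain_id, prompt_id) -> str:
    # get_prompt_to_use is imported in qa_base.py; signature assumed.
    prompt = get_prompt_to_use(brain_id, prompt_id)
    return prompt.content if prompt else QUIVR_DEFAULT_PROMPT
```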
@@ -1,12 +1,13 @@
 import asyncio
 import json
-from typing import AsyncIterable, Awaitable, Optional
+from typing import AsyncIterable, Awaitable, List, Optional
 from uuid import UUID
 
+from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
 from langchain.chains import ConversationalRetrievalChain, LLMChain
 from langchain.chains.question_answering import load_qa_chain
 from langchain.chat_models import ChatLiteLLM
 from langchain.embeddings.openai import OpenAIEmbeddings
 from langchain.llms.base import BaseLLM
 from langchain.prompts.chat import (
     ChatPromptTemplate,
@@ -16,8 +17,10 @@ from langchain.prompts.chat import (
 from llm.utils.get_prompt_to_use import get_prompt_to_use
 from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id
 from logger import get_logger
+from models import BrainSettings  # Importing settings related to the 'brain'
 from models.chats import ChatQuestion
 from models.databases.supabase.chats import CreateChatHistory
+from pydantic import BaseModel
 from repository.brain import get_brain_by_id
 from repository.chat import (
     GetChatHistoryOutput,
@@ -29,14 +32,13 @@ from repository.chat import (
 from supabase.client import Client, create_client
 from vectorstore.supabase import CustomSupabaseVectorStore
 
-from .base import BaseBrainPicking
 from .prompts.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT
 
 logger = get_logger(__name__)
 QUIVR_DEFAULT_PROMPT = "Your name is Quivr. You're a helpful assistant. If you don't know the answer, just say that you don't know, don't try to make up an answer."
 
 
-class QABaseBrainPicking(BaseBrainPicking):
+class QABaseBrainPicking(BaseModel):
     """
     Main class for the Brain Picking functionality.
     It allows to initialize a Chat model, generate questions and retrieve answers using ConversationalRetrievalChain.
@@ -46,6 +48,55 @@ class QABaseBrainPicking(BaseBrainPicking):
     Each have the same prompt template, which is defined in the `prompt_template` property.
     """
 
+    class Config:
+        """Configuration of the Pydantic Object"""
+
+        # Allowing arbitrary types for class validation
+        arbitrary_types_allowed = True
+
+    # Instantiate settings
+    brain_settings = BrainSettings()  # type: ignore other parameters are optional
+
+    # Default class attributes
+    model: str = None  # pyright: ignore reportPrivateUsage=none
+    temperature: float = 0.1
+    chat_id: str = None  # pyright: ignore reportPrivateUsage=none
+    brain_id: str = None  # pyright: ignore reportPrivateUsage=none
+    max_tokens: int = 256
+    user_openai_api_key: str = None  # pyright: ignore reportPrivateUsage=none
+    streaming: bool = False
+
+    openai_api_key: str = None  # pyright: ignore reportPrivateUsage=none
+    callbacks: List[
+        AsyncIteratorCallbackHandler
+    ] = None  # pyright: ignore reportPrivateUsage=none
+
+    def _determine_api_key(self, openai_api_key, user_openai_api_key):
+        """If the user provided an API key, use it."""
+        if user_openai_api_key is not None:
+            return user_openai_api_key
+        else:
+            return openai_api_key
+
+    def _determine_streaming(self, model: str, streaming: bool) -> bool:
+        """If the model name allows for streaming and streaming is declared, set streaming to True."""
+        return streaming
+
+    def _determine_callback_array(
+        self, streaming
+    ) -> List[AsyncIteratorCallbackHandler]:  # pyright: ignore reportPrivateUsage=none
+        """If streaming is set, set the AsyncIteratorCallbackHandler as the only callback."""
+        if streaming:
+            return [
+                AsyncIteratorCallbackHandler()  # pyright: ignore reportPrivateUsage=none
+            ]
+
+    @property
+    def embeddings(self) -> OpenAIEmbeddings:
+        return OpenAIEmbeddings(
+            openai_api_key=self.openai_api_key
+        )  # pyright: ignore reportPrivateUsage=none
+
     supabase_client: Optional[Client] = None
     vector_store: Optional[CustomSupabaseVectorStore] = None
     qa: Optional[ConversationalRetrievalChain] = None
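The key-resolution helper moves over unchanged, so the precedence rule is the same: a user-supplied key wins over the server-wide key from `BrainSettings`. A quick sanity sketch, assuming the remaining fields keep their defaults (all key values hypothetical):

```python
qa = QABaseBrainPicking(model="gpt-3.5-turbo", brain_id="<brain-uuid>", chat_id="<chat-uuid>")
key = qa._determine_api_key(
    openai_api_key="sk-server-key",      # hypothetical server-wide key
    user_openai_api_key="sk-user-key",   # hypothetical per-user key
)
assert key == "sk-user-key"  # the user's key takes precedence
```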
@@ -1,13 +1,16 @@
 from .create_brain import create_brain
-from .get_brain_by_id import get_brain_by_id
-from .update_brain import update_brain_by_id
-from .get_user_brains import get_user_brains
-from .get_brain_details import get_brain_details
 from .create_brain_user import create_brain_user
 from .delete_brain_users import delete_brain_users
+from .get_brain_by_id import get_brain_by_id
 from .get_public_brains import get_public_brains
+from .get_brain_details import get_brain_details
+from .get_brain_for_user import get_brain_for_user
 from .get_brain_prompt_id import get_brain_prompt_id
-from .update_user_rights import update_brain_user_rights
 from .get_default_user_brain import get_user_default_brain
-from .set_as_default_brain_for_user import set_as_default_brain_for_user
-from .get_default_user_brain_or_create_new import get_default_user_brain_or_create_new
+from .get_default_user_brain_or_create_new import \
+    get_default_user_brain_or_create_new
+from .get_question_context_from_brain import get_question_context_from_brain
+from .get_user_brains import get_user_brains
+from .set_as_default_brain_for_user import set_as_default_brain_for_user
+from .update_brain import update_brain_by_id
+from .update_user_rights import update_brain_user_rights
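The re-sorted list reads like isort output, except the wrapped import uses a backslash continuation. A parenthesized form would be the more idiomatic spelling:

```python
from .get_default_user_brain_or_create_new import (
    get_default_user_brain_or_create_new,
)
```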
@@ -13,40 +13,29 @@ from models.databases.supabase.brains import (
 from repository.brain import (
     create_brain,
+    create_brain_user,
+    delete_brain_users,
     get_brain_details,
     get_default_user_brain_or_create_new,
+    get_public_brains,
+    get_question_context_from_brain,
     get_user_brains,
     get_user_default_brain,
     set_as_default_brain_for_user,
     update_brain_by_id,
 )
-from repository.brain.delete_brain_users import delete_brain_users
-from repository.brain.get_public_brains import get_public_brains
 from repository.prompt import delete_prompt_by_id, get_prompt_by_id
-
 from routes.authorizations.brain_authorization import has_brain_authorization
 from routes.authorizations.types import RoleEnum
 
 logger = get_logger(__name__)
 
 brain_router = APIRouter()
 
 
-# get all brains
 @brain_router.get("/brains/", dependencies=[Depends(AuthBearer())], tags=["Brain"])
-async def brain_endpoint(
+async def retrieve_all_brains_for_user(
     current_user: UserIdentity = Depends(get_current_user),
 ):
-    """
-    Retrieve all brains for the current user.
-
-    - `current_user`: The current authenticated user.
-    - Returns a list of all brains registered for the user.
-
-    This endpoint retrieves all the brains associated with the current authenticated user. It returns a list of brains objects
-    containing the brain ID and brain name for each brain.
-    """
+    """Retrieve all brains for the current user."""
     brains = get_user_brains(current_user.id)
     return {"brains": brains}
@@ -54,30 +43,18 @@ async def brain_endpoint(
 @brain_router.get(
     "/brains/public", dependencies=[Depends(AuthBearer())], tags=["Brain"]
 )
-async def public_brains_endpoint() -> list[PublicBrain]:
-    """
-    Retrieve all Quivr public brains
-    """
+async def retrieve_public_brains() -> list[PublicBrain]:
+    """Retrieve all Quivr public brains."""
     return get_public_brains()
 
 
-# get default brain
 @brain_router.get(
     "/brains/default/", dependencies=[Depends(AuthBearer())], tags=["Brain"]
 )
-async def get_default_brain_endpoint(
+async def retrieve_default_brain(
     current_user: UserIdentity = Depends(get_current_user),
 ):
-    """
-    Retrieve the default brain for the current user. If the user doesn't have one, it creates one.
-
-    - `current_user`: The current authenticated user.
-    - Returns the default brain for the user.
-
-    This endpoint retrieves the default brain associated with the current authenticated user.
-    The default brain is defined as the brain marked as default in the brains_users table.
-    """
-
+    """Retrieve or create the default brain for the current user."""
     brain = get_default_user_brain_or_create_new(current_user)
     return {"id": brain.brain_id, "name": brain.name, "rights": "Owner"}
@@ -87,64 +64,35 @@ async def get_default_brain_endpoint(
     dependencies=[Depends(AuthBearer()), Depends(has_brain_authorization())],
     tags=["Brain"],
 )
-async def get_brain_endpoint(
-    brain_id: UUID,
-):
-    """
-    Retrieve details of a specific brain by brain ID.
-
-    - `brain_id`: The ID of the brain to retrieve details for.
-    - Returns the brain ID and its history.
-
-    This endpoint retrieves the details of a specific brain identified by the provided brain ID. It returns the brain ID and its
-    history, which includes the brain messages exchanged in the brain.
-    """
-
+async def retrieve_brain_by_id(brain_id: UUID):
+    """Retrieve details of a specific brain by its ID."""
     brain_details = get_brain_details(brain_id)
     if brain_details is None:
-        raise HTTPException(
-            status_code=404,
-            detail="Brain details not found",
-        )
-
+        raise HTTPException(status_code=404, detail="Brain details not found")
     return brain_details
 
 
-# create new brain
 @brain_router.post("/brains/", dependencies=[Depends(AuthBearer())], tags=["Brain"])
-async def create_brain_endpoint(
-    brain: CreateBrainProperties,
-    current_user: UserIdentity = Depends(get_current_user),
+async def create_new_brain(
+    brain: CreateBrainProperties, current_user: UserIdentity = Depends(get_current_user)
 ):
-    """
-    Create a new brain with given
-        name
-        status
-        model
-        max_tokens
-        temperature
-    In the brains table & in the brains_users table and put the creator user as 'Owner'
-    """
-
+    """Create a new brain for the user."""
     user_brains = get_user_brains(current_user.id)
-    userDailyUsage = UserUsage(
+    user_usage = UserUsage(
         id=current_user.id,
         email=current_user.email,
         openai_api_key=current_user.openai_api_key,
     )
-    userSettings = userDailyUsage.get_user_settings()
+    user_settings = user_usage.get_user_settings()
 
-    if len(user_brains) >= userSettings.get("max_brains", 5):
+    if len(user_brains) >= user_settings.get("max_brains", 5):
         raise HTTPException(
             status_code=429,
-            detail=f"Maximum number of brains reached ({userSettings.get('max_brains', 5)}).",
+            detail=f"Maximum number of brains reached ({user_settings.get('max_brains', 5)}).",
         )
 
-    new_brain = create_brain(
-        brain,
-    )
-    default_brain = get_user_default_brain(current_user.id)
-    if default_brain:
+    new_brain = create_brain(brain)
+    if get_user_default_brain(current_user.id):
         logger.info(f"Default brain already exists for user {current_user.id}")
         create_brain_user(
             user_id=current_user.id,
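The 429 branch is now easier to exercise. A hedged test sketch using FastAPI's TestClient, assuming an `app` with `brain_router` mounted, auth dependencies overridden for tests, and a user already at the default `max_brains` limit of 5:

```python
from fastapi.testclient import TestClient

client = TestClient(app)  # hypothetical test app
response = client.post("/brains/", json={"name": "sixth-brain"})
assert response.status_code == 429
assert "Maximum number of brains" in response.json()["detail"]
```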
@@ -153,9 +101,7 @@ async def create_brain_endpoint(
             is_default_brain=False,
         )
     else:
-        logger.info(
-            f"Default brain does not exist for user {current_user.id}. It will be created."
-        )
+        logger.info(f"Creating default brain for user {current_user.id}.")
         create_brain_user(
             user_id=current_user.id,
             brain_id=new_brain.brain_id,
@@ -163,97 +109,56 @@ async def create_brain_endpoint(
             is_default_brain=True,
         )
 
-    return {
-        "id": new_brain.brain_id,
-        "name": brain.name,
-        "rights": "Owner",
-    }
+    return {"id": new_brain.brain_id, "name": brain.name, "rights": "Owner"}
 
 
-# update existing brain
 @brain_router.put(
     "/brains/{brain_id}/",
     dependencies=[
-        Depends(
-            AuthBearer(),
-        ),
+        Depends(AuthBearer()),
         Depends(has_brain_authorization([RoleEnum.Editor, RoleEnum.Owner])),
     ],
     tags=["Brain"],
 )
-async def update_brain_endpoint(
-    brain_id: UUID,
-    brain_to_update: BrainUpdatableProperties,
+async def update_existing_brain(
+    brain_id: UUID, brain_update_data: BrainUpdatableProperties
 ):
-    """
-    Update an existing brain with new brain configuration
-    """
-
-    # Remove prompt if it is private and no longer used by brain
+    """Update an existing brain's configuration."""
     existing_brain = get_brain_details(brain_id)
     if existing_brain is None:
-        raise HTTPException(
-            status_code=404,
-            detail="Brain not found",
-        )
+        raise HTTPException(status_code=404, detail="Brain not found")
 
-    if brain_to_update.prompt_id is None:
-        prompt_id = existing_brain.prompt_id
-        if prompt_id is not None:
-            prompt = get_prompt_by_id(prompt_id)
-            if prompt is not None and prompt.status == "private":
-                delete_prompt_by_id(prompt_id)
+    if brain_update_data.prompt_id is None and existing_brain.prompt_id:
+        prompt = get_prompt_by_id(existing_brain.prompt_id)
+        if prompt and prompt.status == "private":
+            delete_prompt_by_id(existing_brain.prompt_id)
 
-    if brain_to_update.status == "private" and existing_brain.status == "public":
+    if brain_update_data.status == "private" and existing_brain.status == "public":
         delete_brain_users(brain_id)
 
-    update_brain_by_id(brain_id, brain_to_update)
-
+    update_brain_by_id(brain_id, brain_update_data)
     return {"message": f"Brain {brain_id} has been updated."}
 
 
-# set as default brain
 @brain_router.post(
     "/brains/{brain_id}/default",
-    dependencies=[
-        Depends(
-            AuthBearer(),
-        ),
-        Depends(has_brain_authorization()),
-    ],
+    dependencies=[Depends(AuthBearer()), Depends(has_brain_authorization())],
     tags=["Brain"],
 )
-async def set_as_default_brain_endpoint(
-    brain_id: UUID,
-    user: UserIdentity = Depends(get_current_user),
+async def set_brain_as_default(
+    brain_id: UUID, user: UserIdentity = Depends(get_current_user)
 ):
-    """
-    Set a brain as default for the current user.
-    """
-
+    """Set a brain as the default for the current user."""
     set_as_default_brain_for_user(user.id, brain_id)
-
     return {"message": f"Brain {brain_id} has been set as default brain."}
 
 
 @brain_router.post(
     "/brains/{brain_id}/question_context",
-    dependencies=[
-        Depends(
-            AuthBearer(),
-        ),
-        Depends(has_brain_authorization()),
-    ],
+    dependencies=[Depends(AuthBearer()), Depends(has_brain_authorization())],
     tags=["Brain"],
 )
-async def get_question_context_from_brain_endpoint(
-    brain_id: UUID,
-    request: BrainQuestionRequest,
-):
-    """
-    Get question context from brain
-    """
-
+async def get_question_context_for_brain(brain_id: UUID, request: BrainQuestionRequest):
+    """Retrieve the question context from a specific brain."""
     context = get_question_context_from_brain(brain_id, request.question)
-
     return {"context": context}
@@ -6,7 +6,7 @@ from venv import logger
 from auth import AuthBearer, get_current_user
 from fastapi import APIRouter, Depends, HTTPException, Query, Request
 from fastapi.responses import StreamingResponse
-from llm.openai import OpenAIBrainPicking
+from llm.qa_base import QABaseBrainPicking
 from llm.qa_headless import HeadlessQA
 from models import (
     Brain,
@@ -36,7 +36,6 @@ from repository.chat.get_chat_history_with_notifications import (
 )
 from repository.notification.remove_chat_notifications import remove_chat_notifications
 from repository.user_identity import get_user_identity
-
 from routes.authorizations.brain_authorization import validate_brain_authorization
 from routes.authorizations.types import RoleEnum
 
@@ -241,9 +240,9 @@ async def create_question_handler(
     try:
         check_user_requests_limit(current_user)
         is_model_ok = (brain_details or chat_question).model in userSettings.get("models", ["gpt-3.5-turbo"])  # type: ignore
-        gpt_answer_generator: HeadlessQA | OpenAIBrainPicking
+        gpt_answer_generator: HeadlessQA | QABaseBrainPicking
         if brain_id:
-            gpt_answer_generator = OpenAIBrainPicking(
+            gpt_answer_generator = QABaseBrainPicking(
                 chat_id=str(chat_id),
                 model=chat_question.model if is_model_ok else "gpt-3.5-turbo",  # type: ignore
                 max_tokens=chat_question.max_tokens,
@@ -331,14 +330,14 @@ async def create_stream_question_handler(
     try:
         logger.info(f"Streaming request for {chat_question.model}")
         check_user_requests_limit(current_user)
-        gpt_answer_generator: HeadlessQA | OpenAIBrainPicking
+        gpt_answer_generator: HeadlessQA | QABaseBrainPicking
         # TODO check if model is in the list of models available for the user
 
         print(userSettings.get("models", ["gpt-3.5-turbo"]))  # type: ignore
         is_model_ok = (brain_details or chat_question).model in userSettings.get("models", ["gpt-3.5-turbo"])  # type: ignore
 
         if brain_id:
-            gpt_answer_generator = OpenAIBrainPicking(
+            gpt_answer_generator = QABaseBrainPicking(
                 chat_id=str(chat_id),
                 model=(brain_details or chat_question).model if is_model_ok else "gpt-3.5-turbo",  # type: ignore
                 max_tokens=(brain_details or chat_question).max_tokens,  # type: ignore
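With `OpenAIBrainPicking` gone, both handlers type the generator as `HeadlessQA | QABaseBrainPicking` and branch on `brain_id`. A compressed sketch of that dispatch (hypothetical helper, arguments abridged):

```python
def make_answer_generator(brain_id, chat_id, model, **kwargs):
    # Brain-backed chats get retrieval-augmented QA; brainless chats go headless.
    if brain_id:
        return QABaseBrainPicking(
            brain_id=str(brain_id), chat_id=str(chat_id), model=model, **kwargs
        )
    return HeadlessQA(chat_id=str(chat_id), model=model, **kwargs)
```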