Mirror of https://github.com/QuivrHQ/quivr.git (synced 2024-12-15 01:21:48 +03:00)
feat(brainpicking): simplified (#371)

* feat(functions): simplified
* refactor(openai): changed to brainpicking
* feat(functions): made them inherit from brainpicking
* feat(privatebrainpicking): added new class
* feat(history&context): added
* Delete test_brainpicking.py
* Delete __init__.py

parent 572fc7e1b0
commit 5fc837b250
backend/llm/BrainPickingOpenAIFunctions/BrainPickingOpenAIFunctions.py (new file, 198 lines)
@@ -0,0 +1,198 @@
from typing import Any, Dict, List, Optional

from langchain.chat_models import ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from llm.brainpicking import BrainPicking
from llm.BrainPickingOpenAIFunctions.models.OpenAiAnswer import OpenAiAnswer
from logger import get_logger
from models.settings import BrainSettings
from repository.chat.get_chat_history import get_chat_history
from supabase import Client, create_client
from vectorstore.supabase import CustomSupabaseVectorStore

from .utils.format_answer import format_answer

logger = get_logger(__name__)


class BrainPickingOpenAIFunctions(BrainPicking):
    DEFAULT_MODEL = "gpt-3.5-turbo-0613"
    DEFAULT_TEMPERATURE = 0.0
    DEFAULT_MAX_TOKENS = 256

    openai_client: ChatOpenAI = None
    user_email: str = None

    def __init__(
        self,
        model: str,
        chat_id: str,
        temperature: float,
        max_tokens: int,
        user_email: str,
        user_openai_api_key: str,
    ) -> None:
        # Call the constructor of the parent class (BrainPicking)
        super().__init__(
            model=model,
            user_id=user_email,
            chat_id=chat_id,
            max_tokens=max_tokens,
            user_openai_api_key=user_openai_api_key,
            temperature=temperature,
        )
        self.openai_client = ChatOpenAI(openai_api_key=self.settings.openai_api_key)
        self.user_email = user_email

    def _get_model_response(
        self,
        messages: List[Dict[str, str]],
        functions: Optional[List[Dict[str, Any]]] = None,
    ) -> Any:
        """
        Retrieve a model response for the given messages and, optionally, functions.
        """
        logger.info("Getting model response")
        kwargs = {
            "messages": messages,
            "model": self.llm_name,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }

        if functions:
            logger.info("Adding functions to model response")
            kwargs["functions"] = functions

        return self.openai_client.completion_with_retry(**kwargs)

    def _get_chat_history(self) -> List[Dict[str, str]]:
        """
        Retrieve the chat history as a formatted list of messages.
        """
        logger.info("Getting chat history")
        history = get_chat_history(self.chat_id)
        return [
            item
            for chat in history
            for item in [
                {"role": "user", "content": chat.user_message},
                {"role": "assistant", "content": chat.assistant},
            ]
        ]

    def _get_context(self, question: str) -> str:
        """
        Retrieve documents related to the question.
        """
        logger.info("Getting context")
        vector_store = CustomSupabaseVectorStore(
            self.supabase_client,
            self.embeddings,
            table_name="vectors",
            user_id=self.user_email,
        )

        return vector_store.similarity_search(query=question)

    def _construct_prompt(
        self, question: str, useContext: bool = False, useHistory: bool = False
    ) -> List[Dict[str, str]]:
        """
        Construct a prompt for the question, optionally including context and history.
        """
        logger.info("Constructing prompt")
        system_messages = [
            {
                "role": "system",
                "content": "Your name is Quivr. You are a second brain. A person will ask you a question and you will provide a helpful answer. Write the answer in the same language as the question. If you don't know the answer, just say that you don't know. Don't try to make up an answer. Your main goal is to answer questions about user uploaded documents. Unless basic questions or greetings, you should always refer to user uploaded documents by fetching them with the get_context function.",
            }
        ]

        if useHistory:
            logger.info("Adding chat history to prompt")
            history = self._get_chat_history()
            system_messages.append(
                {"role": "system", "content": "Previous messages are already in chat."}
            )
            system_messages.extend(history)

        if useContext:
            logger.info("Adding chat context to prompt")
            chat_context = self._get_context(question)
            context_message = f"Here is chat context: {chat_context if chat_context else 'No document found'}"
            system_messages.append({"role": "user", "content": context_message})

        system_messages.append({"role": "user", "content": question})

        return system_messages

    def generate_answer(self, question: str) -> str:
        """
        Main entry point: get an answer for the given question.
        """
        logger.info("Getting answer")
        functions = [
            {
                "name": "get_history",
                "description": "Used to get the chat history between the user and the assistant",
                "parameters": {"type": "object", "properties": {}},
            },
            {
                "name": "get_context",
                "description": "Used for retrieving documents related to the question to help answer the question",
                "parameters": {"type": "object", "properties": {}},
            },
            {
                "name": "get_history_and_context",
                "description": "Used for retrieving documents related to the question to help answer the question and the previous chat history",
                "parameters": {"type": "object", "properties": {}},
            },
        ]

        # First, try to get an answer using just the question
        response = self._get_model_response(
            messages=self._construct_prompt(question), functions=functions
        )
        formatted_response = format_answer(response)

        # If the model calls for history, try again with history included
        if (
            formatted_response.function_call
            and formatted_response.function_call.name == "get_history"
        ):
            logger.info("Model called for history")
            response = self._get_model_response(
                messages=self._construct_prompt(question, useHistory=True),
                functions=functions,
            )
            formatted_response = format_answer(response)

        # If the model calls for context, try again with context included
        if (
            formatted_response.function_call
            and formatted_response.function_call.name == "get_context"
        ):
            logger.info("Model called for context")
            response = self._get_model_response(
                messages=self._construct_prompt(
                    question, useContext=True, useHistory=False
                ),
                functions=functions,
            )
            formatted_response = format_answer(response)

        # If the model calls for both, try again with history and context included
        if (
            formatted_response.function_call
            and formatted_response.function_call.name == "get_history_and_context"
        ):
            logger.info("Model called for history and context")
            response = self._get_model_response(
                messages=self._construct_prompt(
                    question, useContext=True, useHistory=True
                ),
                functions=functions,
            )
            formatted_response = format_answer(response)

        return formatted_response.content or ""
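For orientation, the retry logic above keys off formatted_response.function_call, which format_answer extracts from the raw OpenAI chat-completion payload returned by completion_with_retry. Below is a minimal sketch of such a payload; the values are illustrative and not part of this commit, and only the choices[0]["message"] access appears in the format_answer diff that follows.

# Illustrative OpenAI function-calling response (values made up for the example).
example_model_response = {
    "choices": [
        {
            "message": {
                "role": "assistant",
                "content": None,
                # Present only when the model decided to call one of the declared functions
                "function_call": {"name": "get_context", "arguments": "{}"},
            },
            "finish_reason": "function_call",
        }
    ]
}

message = example_model_response["choices"][0]["message"]  # what format_answer reads
wants_context = message.get("function_call", {}).get("name") == "get_context"  # True here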
@@ -1,7 +1,8 @@
from llm.OpenAiFunctionBasedAnswerGenerator.models.OpenAiAnswer import OpenAiAnswer
from llm.OpenAiFunctionBasedAnswerGenerator.models.FunctionCall import FunctionCall
from typing import Any, Dict  # For type hinting

from llm.BrainPickingOpenAIFunctions.models.FunctionCall import FunctionCall
from llm.BrainPickingOpenAIFunctions.models.OpenAiAnswer import OpenAiAnswer


def format_answer(model_response: Dict[str, Any]) -> OpenAiAnswer:
    answer = model_response["choices"][0]["message"]
backend/llm/OpenAiFunctionBasedAnswerGenerator/OpenAiFunctionBasedAnswerGenerator.py (deleted file, 245 lines)
@@ -1,245 +0,0 @@
from typing import Optional

from typing import Any, Dict, List  # For type hinting
from langchain.chat_models import ChatOpenAI
from repository.chat.get_chat_history import get_chat_history
from .utils.format_answer import format_answer

# Importing various modules and classes from a custom library 'langchain' likely used for natural language processing
from langchain.embeddings.openai import OpenAIEmbeddings

from models.settings import BrainSettings  # Importing settings related to the 'brain'
from llm.OpenAiFunctionBasedAnswerGenerator.models.OpenAiAnswer import OpenAiAnswer

from supabase import Client, create_client  # For interacting with Supabase database
from vectorstore.supabase import (
    CustomSupabaseVectorStore,
)  # Custom class for handling vector storage with Supabase
from logger import get_logger

logger = get_logger(__name__)

get_context_function_name = "get_context"

prompt_template = """Your name is Quivr. You are a second brain.
A person will ask you a question and you will provide a helpful answer.
Write the answer in the same language as the question.
If you don't know the answer, just say that you don't know. Don't try to make up an answer.
Your main goal is to answer questions about user uploaded documents. Unless basic questions or greetings, you should always refer to user uploaded documents by fetching them with the {} function.""".format(
    get_context_function_name
)

get_history_schema = {
    "name": "get_history",
    "description": "Get current chat previous messages",
    "parameters": {
        "type": "object",
        "properties": {},
    },
}

get_context_schema = {
    "name": get_context_function_name,
    "description": "A function which returns user uploaded documents and which must be used when you don't now the answer to a question or when the question seems to refer to user uploaded documents",
    "parameters": {
        "type": "object",
        "properties": {},
    },
}


class OpenAiFunctionBasedAnswerGenerator:
    # Default class attributes
    model: str = "gpt-3.5-turbo-0613"
    temperature: float = 0.0
    max_tokens: int = 256
    chat_id: str
    supabase_client: Client = None
    embeddings: OpenAIEmbeddings = None
    settings = BrainSettings()
    openai_client: ChatOpenAI = None
    user_email: str

    def __init__(
        self,
        model: str,
        chat_id: str,
        temperature: float,
        max_tokens: int,
        user_email: str,
        user_openai_api_key: str,
    ) -> "OpenAiFunctionBasedAnswerGenerator":
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.chat_id = chat_id
        self.supabase_client = create_client(
            self.settings.supabase_url, self.settings.supabase_service_key
        )
        self.user_email = user_email
        if user_openai_api_key is not None:
            self.settings.openai_api_key = user_openai_api_key
        self.embeddings = OpenAIEmbeddings(openai_api_key=self.settings.openai_api_key)
        self.openai_client = ChatOpenAI(openai_api_key=self.settings.openai_api_key)

    def _get_model_response(
        self,
        messages: list[dict[str, str]] = [],
        functions: list[dict[str, Any]] = None,
    ):
        if functions is not None:
            model_response = self.openai_client.completion_with_retry(
                functions=functions,
                messages=messages,
                model=self.model,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )

        else:
            model_response = self.openai_client.completion_with_retry(
                messages=messages,
                model=self.model,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )

        return model_response

    def _get_formatted_history(self) -> List[Dict[str, str]]:
        formatted_history = []
        history = get_chat_history(self.chat_id)
        for chat in history:
            formatted_history.append({"role": "user", "content": chat.user_message})
            formatted_history.append({"role": "assistant", "content": chat.assistant})
        return formatted_history

    def _get_formatted_prompt(
        self,
        question: Optional[str],
        useContext: Optional[bool] = False,
        useHistory: Optional[bool] = False,
    ) -> list[dict[str, str]]:
        messages = [
            {"role": "system", "content": prompt_template},
        ]

        if not useHistory and not useContext:
            messages.append(
                {"role": "user", "content": question},
            )
            return messages

        if useHistory:
            history = self._get_formatted_history()
            if len(history):
                messages.append(
                    {
                        "role": "system",
                        "content": "Previous messages are already in chat.",
                    },
                )
                messages.extend(history)
            else:
                messages.append(
                    {
                        "role": "user",
                        "content": "This is the first message of the chat. There is no previous one",
                    }
                )

            messages.append(
                {
                    "role": "user",
                    "content": f"Question: {question}\n\nAnswer:",
                }
            )
        if useContext:
            chat_context = self._get_context(question)
            enhanced_question = f"Here is chat context: {chat_context if len(chat_context) else 'No document found'}"
            messages.append({"role": "user", "content": enhanced_question})
            messages.append(
                {
                    "role": "user",
                    "content": f"Question: {question}\n\nAnswer:",
                }
            )
        return messages

    def _get_context(self, question: str) -> str:
        # retrieve 5 nearest documents
        vector_store = CustomSupabaseVectorStore(
            self.supabase_client,
            self.embeddings,
            table_name="vectors",
            user_id=self.user_email,
        )

        return vector_store.similarity_search(
            query=question,
        )

    def _get_answer_from_question(self, question: str) -> OpenAiAnswer:
        functions = [get_history_schema, get_context_schema]

        model_response = self._get_model_response(
            messages=self._get_formatted_prompt(question=question),
            functions=functions,
        )

        return format_answer(model_response)

    def _get_answer_from_question_and_history(self, question: str) -> OpenAiAnswer:
        logger.info("Using chat history")

        functions = [
            get_context_schema,
        ]

        model_response = self._get_model_response(
            messages=self._get_formatted_prompt(question=question, useHistory=True),
            functions=functions,
        )

        return format_answer(model_response)

    def _get_answer_from_question_and_context(self, question: str) -> OpenAiAnswer:
        logger.info("Using documents ")

        functions = [
            get_history_schema,
        ]
        model_response = self._get_model_response(
            messages=self._get_formatted_prompt(question=question, useContext=True),
            functions=functions,
        )

        return format_answer(model_response)

    def _get_answer_from_question_and_context_and_history(
        self, question: str
    ) -> OpenAiAnswer:
        logger.info("Using context and history")
        model_response = self._get_model_response(
            messages=self._get_formatted_prompt(
                question, useContext=True, useHistory=True
            ),
        )
        return format_answer(model_response)

    def get_answer(self, question: str) -> str:
        response = self._get_answer_from_question(question)
        function_name = response.function_call.name if response.function_call else None

        if function_name == "get_history":
            response = self._get_answer_from_question_and_history(question)
        elif function_name == "get_context":
            response = self._get_answer_from_question_and_context(question)

        if response.function_call:
            response = self._get_answer_from_question_and_context_and_history(question)

        return response.content or ""
backend/llm/PrivateBrainPicking.py (new file, 80 lines)
@@ -0,0 +1,80 @@
from typing import Any, Dict

# Importing various modules and classes from a custom library 'langchain' likely used for natural language processing
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import GPT4All
from langchain.llms.base import LLM
from langchain.memory import ConversationBufferMemory
from llm.brainpicking import BrainPicking
from llm.prompt.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT
from logger import get_logger
from models.settings import BrainSettings  # Importing settings related to the 'brain'
from models.settings import LLMSettings  # For type hinting
from pydantic import BaseModel  # For data validation and settings management
from repository.chat.get_chat_history import get_chat_history
from supabase import Client  # For interacting with Supabase database
from supabase import create_client
from vectorstore.supabase import (
    CustomSupabaseVectorStore,
)  # Custom class for handling vector storage with Supabase

logger = get_logger(__name__)


class PrivateBrainPicking(BrainPicking):
    """
    This subclass of BrainPicking is used specifically to work with a private language model.
    """

    def __init__(
        self,
        model: str,
        user_id: str,
        chat_id: str,
        temperature: float,
        max_tokens: int,
        user_openai_api_key: str,
    ) -> "PrivateBrainPicking":
        """
        Initialize the PrivateBrainPicking class by calling the parent class's initializer.
        :param model: Language model name to be used.
        :param user_id: The user id to be used for CustomSupabaseVectorStore.
        :return: PrivateBrainPicking instance
        """
        # Call the parent class's initializer
        super().__init__(
            model=model,
            user_id=user_id,
            chat_id=chat_id,
            max_tokens=max_tokens,
            temperature=temperature,
            user_openai_api_key=user_openai_api_key,
        )

    def _determine_llm(
        self, private_model_args: dict, private: bool = True, model_name: str = None
    ) -> LLM:
        """
        Override the _determine_llm method to enforce the use of a private model.
        :param model_name: Language model name to be used.
        :param private_model_args: Dictionary containing model_path, n_ctx and n_batch.
        :param private: Boolean value to determine if private model is to be used. Defaults to True.
        :return: Language model instance
        """
        # Force the use of a private model by setting private to True.
        model_path = private_model_args["model_path"]
        model_n_ctx = private_model_args["n_ctx"]
        model_n_batch = private_model_args["n_batch"]

        logger.info("Using private model: %s", model_path)

        return GPT4All(
            model=model_path,
            n_ctx=model_n_ctx,
            n_batch=model_n_batch,
            backend="gptj",
            verbose=True,
        )
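A hedged usage sketch of the new class: all argument values below are placeholders rather than anything from the commit, and actually running it assumes the backend's Supabase and OpenAI settings are configured in the environment. It mainly illustrates the three keys _determine_llm reads from private_model_args.

from llm.PrivateBrainPicking import PrivateBrainPicking

# Illustrative placeholder values only.
brain = PrivateBrainPicking(
    model="gpt4all-j",
    user_id="user@example.com",
    chat_id="00000000-0000-0000-0000-000000000000",
    temperature=0.0,
    max_tokens=256,
    user_openai_api_key=None,
)

# In the application these values come from LLMSettings
# (model_path, model_n_ctx, model_n_batch); the path here is made up.
llm = brain._determine_llm(
    private_model_args={
        "model_path": "./local_models/ggml-gpt4all-j.bin",
        "n_ctx": 1000,
        "n_batch": 8,
    }
)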
backend/llm/__init__.py (new, empty file)
backend/llm/brainpicking.py
@@ -1,5 +1,4 @@
from typing import Any, Dict
from models.settings import LLMSettings  # For type hinting

# Importing various modules and classes from a custom library 'langchain' likely used for natural language processing
from langchain.chains import ConversationalRetrievalChain, LLMChain
@@ -9,17 +8,17 @@ from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import GPT4All
from langchain.llms.base import LLM
from langchain.memory import ConversationBufferMemory

from llm.prompt.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT
from logger import get_logger
from models.settings import BrainSettings  # Importing settings related to the 'brain'
from models.settings import LLMSettings  # For type hinting
from pydantic import BaseModel  # For data validation and settings management
from repository.chat.get_chat_history import get_chat_history
from supabase import Client  # For interacting with Supabase database
from supabase import create_client
from repository.chat.get_chat_history import get_chat_history
from vectorstore.supabase import (
    CustomSupabaseVectorStore,
)  # Custom class for handling vector storage with Supabase
from logger import get_logger

logger = get_logger(__name__)

@@ -58,6 +57,7 @@ class BrainPicking(BaseModel):

    # Default class attributes
    llm_name: str = "gpt-3.5-turbo"
    temperature: float = 0.0
    settings = BrainSettings()
    llm_config = LLMSettings()
    embeddings: OpenAIEmbeddings = None
@@ -77,6 +77,7 @@ class BrainPicking(BaseModel):
        self,
        model: str,
        user_id: str,
        temperature: float,
        chat_id: str,
        max_tokens: int,
        user_openai_api_key: str,
@@ -92,16 +93,19 @@ class BrainPicking(BaseModel):
            user_id=user_id,
            chat_id=chat_id,
            max_tokens=max_tokens,
            temperature=temperature,
            user_openai_api_key=user_openai_api_key,
        )
        # If user provided an API key, update the settings
        if user_openai_api_key is not None:
            self.settings.openai_api_key = user_openai_api_key

        self.temperature = temperature
        self.embeddings = OpenAIEmbeddings(openai_api_key=self.settings.openai_api_key)
        self.supabase_client = create_client(
            self.settings.supabase_url, self.settings.supabase_service_key
        )
        self.llm_name = model
        self.vector_store = CustomSupabaseVectorStore(
            self.supabase_client,
            self.embeddings,
@@ -135,22 +139,8 @@ class BrainPicking(BaseModel):
        :param private: Boolean value to determine if private model is to be used.
        :return: Language model instance
        """
        if private:
            model_path = private_model_args["model_path"]
            model_n_ctx = private_model_args["n_ctx"]
            model_n_batch = private_model_args["n_batch"]

            logger.info("Using private model: %s", model_path)

            return GPT4All(
                model=model_path,
                n_ctx=model_n_ctx,
                n_batch=model_n_batch,
                backend="gptj",
                verbose=True,
            )
        else:
            return ChatOpenAI(temperature=0, model_name=model_name)
        return ChatOpenAI(temperature=0, model_name=model_name)

    def _get_qa(
        self,
backend/models/__init__.py (new, empty file)
backend/models/settings.py
@@ -15,10 +15,10 @@ class BrainSettings(BaseSettings):


class LLMSettings(BaseSettings):
    private: bool
    model_path: str
    model_n_ctx: int
    model_n_batch: int
    private: bool = False
    model_path: str = "gpt2"
    model_n_ctx: int = 1000
    model_n_batch: int = 8


def common_dependencies() -> dict:
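Because LLMSettings is a pydantic BaseSettings class, the new defaults above can still be overridden from the environment (pydantic matches environment variables to field names case-insensitively by default, which is how the other settings in this file are loaded). A small illustrative sketch; the values are placeholders, not from the commit:

import os

# Illustrative overrides; a .env file works the same way with BaseSettings.
os.environ["PRIVATE"] = "True"
os.environ["MODEL_PATH"] = "./local_models/ggml-gpt4all-j.bin"  # placeholder path
os.environ["MODEL_N_CTX"] = "2048"
os.environ["MODEL_N_BATCH"] = "8"

from models.settings import LLMSettings

llm_settings = LLMSettings()
assert llm_settings.private is True  # would be False if PRIVATE were unset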
@@ -5,12 +5,13 @@ from uuid import UUID
from auth.auth_bearer import AuthBearer, get_current_user
from fastapi import APIRouter, Depends, HTTPException, Request
from llm.brainpicking import BrainPicking
from llm.OpenAiFunctionBasedAnswerGenerator.OpenAiFunctionBasedAnswerGenerator import (
    OpenAiFunctionBasedAnswerGenerator,
from llm.BrainPickingOpenAIFunctions.BrainPickingOpenAIFunctions import (
    BrainPickingOpenAIFunctions,
)
from llm.PrivateBrainPicking import PrivateBrainPicking
from models.chat import Chat, ChatHistory
from models.chats import ChatQuestion
from models.settings import common_dependencies
from models.settings import LLMSettings, common_dependencies
from models.users import User
from repository.chat.create_chat import CreateChatProperties, create_chat
from repository.chat.get_chat_by_id import get_chat_by_id
@@ -157,27 +158,39 @@ async def create_question_handler(
    try:
        user_openai_api_key = request.headers.get("Openai-Api-Key")
        check_user_limit(current_user.email, user_openai_api_key)
        llm_settings = LLMSettings()
        openai_function_compatible_models = [
            "gpt-3.5-turbo-0613",
            "gpt-4-0613",
        ]
        if chat_question.model in openai_function_compatible_models:
            # TODO: RBAC with current_user
            gpt_answer_generator = OpenAiFunctionBasedAnswerGenerator(
        if llm_settings.private:
            gpt_answer_generator = PrivateBrainPicking(
                model=chat_question.model,
                chat_id=chat_id,
                chat_id=str(chat_id),
                temperature=chat_question.temperature,
                max_tokens=chat_question.max_tokens,
                user_id=current_user.email,
                user_openai_api_key=user_openai_api_key,
            )
            answer = gpt_answer_generator.generate_answer(chat_question.question)
        elif chat_question.model in openai_function_compatible_models:
            # TODO: RBAC with current_user
            gpt_answer_generator = BrainPickingOpenAIFunctions(
                model=chat_question.model,
                chat_id=str(chat_id),
                temperature=chat_question.temperature,
                max_tokens=chat_question.max_tokens,
                # TODO: use user_id in vectors table instead of email
                user_email=current_user.email,
                user_openai_api_key=user_openai_api_key,
            )
            answer = gpt_answer_generator.get_answer(chat_question.question)
            answer = gpt_answer_generator.generate_answer(chat_question.question)
        else:
            brainPicking = BrainPicking(
                chat_id=str(chat_id),
                model=chat_question.model,
                max_tokens=chat_question.max_tokens,
                temperature=chat_question.temperature,
                user_id=current_user.email,
                user_openai_api_key=user_openai_api_key,
            )
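For reference, the handler above reads these fields from the incoming ChatQuestion body plus the optional Openai-Api-Key header; the values in this sketch are illustrative only, not taken from the commit.

# Illustrative request body matching the ChatQuestion fields used above.
chat_question_payload = {
    "model": "gpt-3.5-turbo-0613",  # routed to BrainPickingOpenAIFunctions unless LLMSettings().private is set
    "question": "What do my uploaded documents say about refunds?",
    "temperature": 0.0,
    "max_tokens": 256,
}
# A user-supplied OpenAI key, if any, is sent in the "Openai-Api-Key" request header.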