Mirror of https://github.com/QuivrHQ/quivr.git, synced 2024-12-14 17:03:29 +03:00
fix: update max token logic (#1725)
Fixes the max_tokens value being overwritten for headless and document-based brains.
parent 404a9f1573
commit b252aa1794
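What the change does: both QABaseBrainPicking (the document-based brain) and HeadlessQA let _create_llm's own defaults win over the max_tokens value configured on the brain, either through a hard-coded max_tokens=256 parameter or by not forwarding the setting to ChatLiteLLM at all. A minimal, self-contained sketch of the before/after pattern (a toy class for illustration, not the actual Quivr models):

from pydantic import BaseModel


class Brain(BaseModel):
    model: str = "gpt-3.5-turbo"
    max_tokens: int = 1024  # limit configured per brain

    def _create_llm_before(self, model, max_tokens=256):
        # Before: the hard-coded default silently overrides the
        # brain's setting whenever the caller omits max_tokens.
        return {"model": model, "max_tokens": max_tokens}

    def _create_llm_after(self, model):
        # After: the limit is always read from the instance, so a
        # call site can no longer overwrite it by accident.
        return {"model": model, "max_tokens": self.max_tokens}


brain = Brain()
print(brain._create_llm_before(brain.model))  # max_tokens: 256 (the bug)
print(brain._create_llm_after(brain.model))   # max_tokens: 1024 (configured value)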
@@ -14,8 +14,6 @@ from langchain.prompts.chat import (
     HumanMessagePromptTemplate,
     SystemMessagePromptTemplate,
 )
-from llm.utils.get_prompt_to_use import get_prompt_to_use
-from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id
 from logger import get_logger
 from models import BrainSettings  # Importing settings related to the 'brain'
 from models.chats import ChatQuestion
@@ -32,6 +30,9 @@ from repository.chat import (
 from supabase.client import Client, create_client
 from vectorstore.supabase import CustomSupabaseVectorStore

+from llm.utils.get_prompt_to_use import get_prompt_to_use
+from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id
+
 from .prompts.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT

 logger = get_logger(__name__)
@@ -133,7 +134,7 @@ class QABaseBrainPicking(BaseModel):
     )

     def _create_llm(
-        self, model, temperature=0, streaming=False, callbacks=None, max_tokens=256
+        self, model, temperature=0, streaming=False, callbacks=None
     ) -> BaseLLM:
         """
         Determine the language model to be used.
@@ -144,7 +145,7 @@ class QABaseBrainPicking(BaseModel):
         """
         return ChatLiteLLM(
             temperature=temperature,
-            max_tokens=max_tokens,
+            max_tokens=self.max_tokens,
             model=model,
             streaming=streaming,
             verbose=False,
@@ -179,7 +180,9 @@ class QABaseBrainPicking(BaseModel):
     ) -> GetChatHistoryOutput:
         transformed_history = format_chat_history(get_chat_history(self.chat_id))
         answering_llm = self._create_llm(
-            model=self.model, streaming=False, callbacks=self.callbacks
+            model=self.model,
+            streaming=False,
+            callbacks=self.callbacks,
         )

         # The Chain that generates the answer to the question
@@ -255,7 +258,6 @@ class QABaseBrainPicking(BaseModel):
             model=self.model,
             streaming=True,
             callbacks=self.callbacks,
-            max_tokens=self.max_tokens,
         )

         # The Chain that generates the answer to the question
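For QABaseBrainPicking the net effect is that _create_llm no longer accepts a max_tokens argument at all: the ChatLiteLLM it returns is always built from self.max_tokens, and the streaming call site stops passing the value explicitly. A hedged, standalone sketch of the resulting construction, using only arguments visible in the hunks above (assumes the langchain and litellm packages are installed; the literal values are placeholders):

from langchain.chat_models import ChatLiteLLM

brain_max_tokens = 512  # placeholder for the brain's configured max_tokens field

llm = ChatLiteLLM(
    temperature=0,
    max_tokens=brain_max_tokens,  # previously a hard-coded 256 default could end up here
    model="gpt-3.5-turbo",
    streaming=False,
    verbose=False,
)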
@@ -8,8 +8,6 @@ from langchain.chains import LLMChain
 from langchain.chat_models import ChatLiteLLM
 from langchain.chat_models.base import BaseChatModel
 from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
-from llm.utils.get_prompt_to_use import get_prompt_to_use
-from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id
 from logger import get_logger
 from models.chats import ChatQuestion
 from models.databases.supabase.chats import CreateChatHistory
@@ -24,6 +22,9 @@ from repository.chat import (
     update_message_by_id,
 )

+from llm.utils.get_prompt_to_use import get_prompt_to_use
+from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id
+
 logger = get_logger(__name__)
 SYSTEM_MESSAGE = "Your name is Quivr. You're a helpful assistant. If you don't know the answer, just say that you don't know, don't try to make up an answer.When answering use markdown or any other techniques to display the content in a nice and aerated way."

@@ -64,7 +65,11 @@ class HeadlessQA(BaseModel):
         return get_prompt_to_use_id(None, self.prompt_id)

     def _create_llm(
-        self, model, temperature=0, streaming=False, callbacks=None
+        self,
+        model,
+        temperature=0,
+        streaming=False,
+        callbacks=None,
     ) -> BaseChatModel:
         """
         Determine the language model to be used.
@@ -79,6 +84,7 @@ class HeadlessQA(BaseModel):
             streaming=streaming,
             verbose=True,
             callbacks=callbacks,
+            max_tokens=self.max_tokens,
         )

     def _create_prompt_template(self):
@@ -100,7 +106,9 @@ class HeadlessQA(BaseModel):
             transformed_history, prompt_content, question.question
         )
         answering_llm = self._create_llm(
-            model=self.model, streaming=False, callbacks=self.callbacks
+            model=self.model,
+            streaming=False,
+            callbacks=self.callbacks,
         )
         model_prediction = answering_llm.predict_messages(messages)
         answer = model_prediction.content
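In HeadlessQA the problem was the opposite omission: its ChatLiteLLM call never forwarded max_tokens, so the library default applied regardless of the brain's configuration; the added max_tokens=self.max_tokens line closes that gap, while the signature and call-site reflow is purely cosmetic. An illustrative check of the intended behaviour (not a test from the repository; assumes langchain is installed):

from langchain.chat_models import ChatLiteLLM

configured_max_tokens = 2048  # illustrative value a user might set on a headless brain

llm = ChatLiteLLM(model="gpt-3.5-turbo", max_tokens=configured_max_tokens)

# After the fix both brain types build their LLM this way, so the
# configured limit is what the model actually receives.
assert llm.max_tokens == configured_max_tokens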