fix: update max token logic (#1725)

Fix the headless and doc-based brains overwriting the configured max_tokens value
Authored by Mamadou DICKO on 2023-11-27 11:21:26 +01:00; committed by GitHub
parent 404a9f1573
commit b252aa1794
2 changed files with 20 additions and 10 deletions
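
The diff below removes the hard-coded max_tokens=256 default from _create_llm and reads the brain's own max_tokens field instead. The following is a minimal sketch of that pattern, not the actual Quivr code: BrainSketch, the dict return value, and the sample values (1000, 2000, "gpt-3.5-turbo") are illustrative; the real classes are QABaseBrainPicking and HeadlessQA, and the real LLM object is ChatLiteLLM. Only the 256 default comes from the diff itself.

# Sketch of the bug pattern this commit fixes (illustrative names, see note above).
from pydantic import BaseModel


class BrainSketch(BaseModel):
    model: str = "gpt-3.5-turbo"
    max_tokens: int = 1000  # configured per brain / per request

    def _create_llm_old(self, model, streaming=False, max_tokens=256):
        # Before: callers that do not forward max_tokens silently fall back to
        # the hard-coded 256 default, overriding the configured value.
        return {"model": model, "streaming": streaming, "max_tokens": max_tokens}

    def _create_llm_new(self, model, streaming=False):
        # After: the helper reads the instance attribute directly, so the
        # configured value is always respected.
        return {"model": model, "streaming": streaming, "max_tokens": self.max_tokens}


brain = BrainSketch(max_tokens=2000)
print(brain._create_llm_old(brain.model)["max_tokens"])  # 256  <- configured value lost
print(brain._create_llm_new(brain.model)["max_tokens"])  # 2000 <- configured value kept
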

Changed file 1 of 2:

@@ -14,8 +14,6 @@ from langchain.prompts.chat import (
     HumanMessagePromptTemplate,
     SystemMessagePromptTemplate,
 )
-from llm.utils.get_prompt_to_use import get_prompt_to_use
-from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id
 from logger import get_logger
 from models import BrainSettings  # Importing settings related to the 'brain'
 from models.chats import ChatQuestion
@@ -32,6 +30,9 @@ from repository.chat import (
 from supabase.client import Client, create_client
 from vectorstore.supabase import CustomSupabaseVectorStore
+from llm.utils.get_prompt_to_use import get_prompt_to_use
+from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id
+
 from .prompts.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT

 logger = get_logger(__name__)
@@ -133,7 +134,7 @@ class QABaseBrainPicking(BaseModel):
     )

     def _create_llm(
-        self, model, temperature=0, streaming=False, callbacks=None, max_tokens=256
+        self, model, temperature=0, streaming=False, callbacks=None
     ) -> BaseLLM:
         """
         Determine the language model to be used.
@@ -144,7 +145,7 @@ class QABaseBrainPicking(BaseModel):
         """
         return ChatLiteLLM(
             temperature=temperature,
-            max_tokens=max_tokens,
+            max_tokens=self.max_tokens,
             model=model,
             streaming=streaming,
             verbose=False,
@@ -179,7 +180,9 @@ class QABaseBrainPicking(BaseModel):
     ) -> GetChatHistoryOutput:
         transformed_history = format_chat_history(get_chat_history(self.chat_id))
         answering_llm = self._create_llm(
-            model=self.model, streaming=False, callbacks=self.callbacks
+            model=self.model,
+            streaming=False,
+            callbacks=self.callbacks,
         )

         # The Chain that generates the answer to the question
@@ -255,7 +258,6 @@ class QABaseBrainPicking(BaseModel):
             model=self.model,
             streaming=True,
             callbacks=self.callbacks,
-            max_tokens=self.max_tokens,
         )

         # The Chain that generates the answer to the question

Changed file 2 of 2:

@@ -8,8 +8,6 @@ from langchain.chains import LLMChain
 from langchain.chat_models import ChatLiteLLM
 from langchain.chat_models.base import BaseChatModel
 from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
-from llm.utils.get_prompt_to_use import get_prompt_to_use
-from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id
 from logger import get_logger
 from models.chats import ChatQuestion
 from models.databases.supabase.chats import CreateChatHistory
@@ -24,6 +22,9 @@ from repository.chat import (
     update_message_by_id,
 )
+from llm.utils.get_prompt_to_use import get_prompt_to_use
+from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id
+
 logger = get_logger(__name__)

 SYSTEM_MESSAGE = "Your name is Quivr. You're a helpful assistant. If you don't know the answer, just say that you don't know, don't try to make up an answer.When answering use markdown or any other techniques to display the content in a nice and aerated way."
@@ -64,7 +65,11 @@ class HeadlessQA(BaseModel):
         return get_prompt_to_use_id(None, self.prompt_id)

     def _create_llm(
-        self, model, temperature=0, streaming=False, callbacks=None
+        self,
+        model,
+        temperature=0,
+        streaming=False,
+        callbacks=None,
     ) -> BaseChatModel:
         """
         Determine the language model to be used.
@@ -79,6 +84,7 @@ class HeadlessQA(BaseModel):
             streaming=streaming,
             verbose=True,
             callbacks=callbacks,
+            max_tokens=self.max_tokens,
         )

     def _create_prompt_template(self):
@@ -100,7 +106,9 @@ class HeadlessQA(BaseModel):
             transformed_history, prompt_content, question.question
         )
         answering_llm = self._create_llm(
-            model=self.model, streaming=False, callbacks=self.callbacks
+            model=self.model,
+            streaming=False,
+            callbacks=self.callbacks,
         )
         model_prediction = answering_llm.predict_messages(messages)
         answer = model_prediction.content
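
Note that the two files needed slightly different fixes, as the hunks above show: in QABaseBrainPicking the _create_llm helper dropped its hard-coded max_tokens=256 default in favor of self.max_tokens, while in HeadlessQA the ChatLiteLLM call previously did not forward max_tokens at all and now passes max_tokens=self.max_tokens. In both cases the brain's configured limit, rather than a default, is what reaches the model.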