mirror of
https://github.com/QuivrHQ/quivr.git
synced 2024-12-14 17:03:29 +03:00
test: unskip qa_headless.py linter tests (#1041)
This commit is contained in:
parent
66bafcf2c5
commit
f1d6b7892c
@ -6,7 +6,7 @@ from uuid import UUID
|
||||
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
@ -33,15 +33,15 @@ SYSTEM_MESSAGE = "Your name is Quivr. You're a helpful assistant. If you don't k
|
||||
|
||||
|
||||
class HeadlessQA(BaseModel):
|
||||
model: str = None # type: ignore
|
||||
model: str
|
||||
temperature: float = 0.0
|
||||
max_tokens: int = 256
|
||||
user_openai_api_key: str = None # type: ignore
|
||||
openai_api_key: str = None # type: ignore
|
||||
user_openai_api_key: Optional[str] = None
|
||||
openai_api_key: Optional[str] = None
|
||||
streaming: bool = False
|
||||
chat_id: str = None # type: ignore
|
||||
callbacks: List[AsyncIteratorCallbackHandler] = None # type: ignore
|
||||
prompt_id: Optional[UUID]
|
||||
chat_id: str
|
||||
callbacks: Optional[List[AsyncIteratorCallbackHandler]] = None
|
||||
prompt_id: Optional[UUID] = None
|
||||
|
||||
def _determine_api_key(self, openai_api_key, user_openai_api_key):
|
||||
"""If user provided an API key, use it."""
|
||||
@ -50,31 +50,28 @@ class HeadlessQA(BaseModel):
|
||||
else:
|
||||
return openai_api_key
|
||||
|
||||
def _determine_streaming(self, model: str, streaming: bool) -> bool:
|
||||
def _determine_streaming(self, streaming: bool) -> bool:
|
||||
"""If the model name allows for streaming and streaming is declared, set streaming to True."""
|
||||
return streaming
|
||||
|
||||
def _determine_callback_array(
|
||||
self, streaming
|
||||
) -> List[AsyncIteratorCallbackHandler]: # pyright: ignore reportPrivateUsage=none
|
||||
) -> List[AsyncIteratorCallbackHandler]:
|
||||
"""If streaming is set, set the AsyncIteratorCallbackHandler as the only callback."""
|
||||
if streaming:
|
||||
return [
|
||||
AsyncIteratorCallbackHandler() # pyright: ignore reportPrivateUsage=none
|
||||
]
|
||||
return [AsyncIteratorCallbackHandler()]
|
||||
else:
|
||||
return []
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
print("in HeadlessQA")
|
||||
|
||||
self.openai_api_key = self._determine_api_key(
|
||||
self.openai_api_key, self.user_openai_api_key
|
||||
)
|
||||
self.streaming = self._determine_streaming(
|
||||
self.model, self.streaming
|
||||
) # pyright: ignore reportPrivateUsage=none
|
||||
self.callbacks = self._determine_callback_array(
|
||||
self.streaming
|
||||
) # pyright: ignore reportPrivateUsage=none
|
||||
self.streaming = self._determine_streaming(self.streaming)
|
||||
self.callbacks = self._determine_callback_array(self.streaming)
|
||||
|
||||
@property
|
||||
def prompt_to_use(self) -> Optional[Prompt]:
|
||||
@ -86,7 +83,7 @@ class HeadlessQA(BaseModel):
|
||||
|
||||
def _create_llm(
|
||||
self, model, temperature=0, streaming=False, callbacks=None
|
||||
) -> BaseLLM:
|
||||
) -> BaseChatModel:
|
||||
"""
|
||||
Determine the language model to be used.
|
||||
:param model: Language model name to be used.
|
||||
@ -101,7 +98,7 @@ class HeadlessQA(BaseModel):
|
||||
verbose=True,
|
||||
callbacks=callbacks,
|
||||
openai_api_key=self.openai_api_key,
|
||||
) # pyright: ignore reportPrivateUsage=none
|
||||
)
|
||||
|
||||
def _create_prompt_template(self):
|
||||
messages = [
|
||||
@ -124,9 +121,7 @@ class HeadlessQA(BaseModel):
|
||||
answering_llm = self._create_llm(
|
||||
model=self.model, streaming=False, callbacks=self.callbacks
|
||||
)
|
||||
model_prediction = answering_llm.predict_messages(
|
||||
messages # pyright: ignore reportPrivateUsage=none
|
||||
)
|
||||
model_prediction = answering_llm.predict_messages(messages)
|
||||
answer = model_prediction.content
|
||||
|
||||
new_chat = update_chat_history(
|
||||
|
@ -1,5 +1,6 @@
|
||||
from typing import List, Tuple
|
||||
from langchain.schema import AIMessage, HumanMessage, SystemMessage
|
||||
|
||||
from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
|
||||
|
||||
def format_chat_history(history) -> List[Tuple[str, str]]:
|
||||
@ -8,7 +9,9 @@ def format_chat_history(history) -> List[Tuple[str, str]]:
|
||||
return [(chat.user_message, chat.assistant) for chat in history]
|
||||
|
||||
|
||||
def format_history_to_openai_mesages(tuple_history: List[Tuple[str, str]], system_message: str, question: str) -> List[SystemMessage | HumanMessage | AIMessage]:
|
||||
def format_history_to_openai_mesages(
|
||||
tuple_history: List[Tuple[str, str]], system_message: str, question: str
|
||||
) -> List[BaseMessage]:
|
||||
"""Format the chat history into a list of Base Messages"""
|
||||
messages = []
|
||||
messages.append(SystemMessage(content=system_message))
|
||||
|
@ -220,7 +220,7 @@ async def create_question_handler(
|
||||
model=chat_question.model,
|
||||
temperature=chat_question.temperature,
|
||||
max_tokens=chat_question.max_tokens,
|
||||
user_openai_api_key=current_user.openai_api_key, # pyright: ignore reportPrivateUsage=none
|
||||
user_openai_api_key=current_user.openai_api_key,
|
||||
chat_id=str(chat_id),
|
||||
prompt_id=chat_question.prompt_id,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user