test: unskip qa_headless.py linter tests (#1041)

Mamadou DICKO authored on 2023-08-25 14:03:57 +02:00, committed by GitHub
parent 66bafcf2c5
commit f1d6b7892c
3 changed files with 24 additions and 26 deletions

Changed file 1: qa_headless.py

@@ -6,7 +6,7 @@ from uuid import UUID
 from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
 from langchain.chains import LLMChain
 from langchain.chat_models import ChatOpenAI
-from langchain.llms.base import BaseLLM
+from langchain.chat_models.base import BaseChatModel
 from langchain.prompts.chat import (
     ChatPromptTemplate,
     HumanMessagePromptTemplate,
@@ -33,15 +33,15 @@ SYSTEM_MESSAGE = "Your name is Quivr. You're a helpful assistant. If you don't k
 class HeadlessQA(BaseModel):
-    model: str = None  # type: ignore
+    model: str
     temperature: float = 0.0
     max_tokens: int = 256
-    user_openai_api_key: str = None  # type: ignore
-    openai_api_key: str = None  # type: ignore
+    user_openai_api_key: Optional[str] = None
+    openai_api_key: Optional[str] = None
     streaming: bool = False
-    chat_id: str = None  # type: ignore
-    callbacks: List[AsyncIteratorCallbackHandler] = None  # type: ignore
-    prompt_id: Optional[UUID]
+    chat_id: str
+    callbacks: Optional[List[AsyncIteratorCallbackHandler]] = None
+    prompt_id: Optional[UUID] = None

     def _determine_api_key(self, openai_api_key, user_openai_api_key):
         """If user provided an API key, use it."""
@@ -50,31 +50,28 @@ class HeadlessQA(BaseModel):
         else:
             return openai_api_key

-    def _determine_streaming(self, model: str, streaming: bool) -> bool:
+    def _determine_streaming(self, streaming: bool) -> bool:
         """If the model name allows for streaming and streaming is declared, set streaming to True."""
         return streaming

     def _determine_callback_array(
         self, streaming
-    ) -> List[AsyncIteratorCallbackHandler]:  # pyright: ignore reportPrivateUsage=none
+    ) -> List[AsyncIteratorCallbackHandler]:
         """If streaming is set, set the AsyncIteratorCallbackHandler as the only callback."""
         if streaming:
-            return [
-                AsyncIteratorCallbackHandler()  # pyright: ignore reportPrivateUsage=none
-            ]
+            return [AsyncIteratorCallbackHandler()]
         else:
             return []

     def __init__(self, **data):
         super().__init__(**data)
         print("in HeadlessQA")
         self.openai_api_key = self._determine_api_key(
             self.openai_api_key, self.user_openai_api_key
         )
-        self.streaming = self._determine_streaming(
-            self.model, self.streaming
-        )  # pyright: ignore reportPrivateUsage=none
-        self.callbacks = self._determine_callback_array(
-            self.streaming
-        )  # pyright: ignore reportPrivateUsage=none
+        self.streaming = self._determine_streaming(self.streaming)
+        self.callbacks = self._determine_callback_array(self.streaming)

     @property
     def prompt_to_use(self) -> Optional[Prompt]:
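
For context on why `_determine_callback_array` matters: when `streaming=True`, the single `AsyncIteratorCallbackHandler` is what lets a caller consume tokens as they are produced. The diff does not show the consumer side, so the following is a hedged sketch of that pattern, assuming a valid `OPENAI_API_KEY` in the environment and the langchain 0.0.x API this repo pins:

import asyncio

from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage


async def stream_answer(question: str) -> str:
    # One handler per request; it exposes generated tokens as an async iterator.
    handler = AsyncIteratorCallbackHandler()
    llm = ChatOpenAI(streaming=True, callbacks=[handler], temperature=0.0)

    # Run generation in the background and drain tokens as they arrive.
    task = asyncio.create_task(llm.agenerate([[HumanMessage(content=question)]]))
    tokens = []
    async for token in handler.aiter():
        tokens.append(token)
    await task
    return "".join(tokens)


# asyncio.run(stream_answer("What is Quivr?"))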
@@ -86,7 +83,7 @@ class HeadlessQA(BaseModel):
     def _create_llm(
         self, model, temperature=0, streaming=False, callbacks=None
-    ) -> BaseLLM:
+    ) -> BaseChatModel:
         """
         Determine the language model to be used.

         :param model: Language model name to be used.
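
The `BaseLLM` to `BaseChatModel` change fixes the annotation rather than the behavior: `_create_llm` constructs a `ChatOpenAI`, which subclasses `BaseChatModel`, not the plain-completion `BaseLLM`. A quick check of that relationship, assuming the langchain 0.0.x class hierarchy (the API key is a placeholder):

from langchain.chat_models import ChatOpenAI
from langchain.chat_models.base import BaseChatModel
from langchain.llms.base import BaseLLM

llm = ChatOpenAI(temperature=0.0, openai_api_key="sk-placeholder")

# ChatOpenAI is a chat model; annotating the factory as -> BaseChatModel
# lets predict_messages() type-check without pyright suppressions.
assert isinstance(llm, BaseChatModel)
assert not isinstance(llm, BaseLLM)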
@@ -101,7 +98,7 @@
             verbose=True,
             callbacks=callbacks,
             openai_api_key=self.openai_api_key,
-        )  # pyright: ignore reportPrivateUsage=none
+        )

     def _create_prompt_template(self):
         messages = [
@@ -124,9 +121,7 @@
         answering_llm = self._create_llm(
             model=self.model, streaming=False, callbacks=self.callbacks
         )
-        model_prediction = answering_llm.predict_messages(
-            messages  # pyright: ignore reportPrivateUsage=none
-        )
+        model_prediction = answering_llm.predict_messages(messages)
         answer = model_prediction.content

         new_chat = update_chat_history(

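Follow-on to the last hunk: once the factory is typed `-> BaseChatModel`, `predict_messages` resolves on the annotated type and the inline suppressions can go. A small usage illustration with hypothetical message contents (running it requires a real OpenAI key):

from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage, SystemMessage

llm = ChatOpenAI(temperature=0.0, max_tokens=256)
prediction = llm.predict_messages(
    [
        SystemMessage(content="Your name is Quivr. You're a helpful assistant."),
        HumanMessage(content="What can you do?"),
    ]
)
print(prediction.content)  # the assistant's reply text, as used for `answer`
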
Changed file 2: chat history formatting helpers

@@ -1,5 +1,6 @@
 from typing import List, Tuple

-from langchain.schema import AIMessage, HumanMessage, SystemMessage
+from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage

+
 def format_chat_history(history) -> List[Tuple[str, str]]:
@@ -8,7 +9,9 @@ def format_chat_history(history) -> List[Tuple[str, str]]:
     return [(chat.user_message, chat.assistant) for chat in history]


-def format_history_to_openai_mesages(tuple_history: List[Tuple[str, str]], system_message: str, question: str) -> List[SystemMessage | HumanMessage | AIMessage]:
+def format_history_to_openai_mesages(
+    tuple_history: List[Tuple[str, str]], system_message: str, question: str
+) -> List[BaseMessage]:
     """Format the chat history into a list of Base Messages"""
     messages = []
     messages.append(SystemMessage(content=system_message))

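The second file's fix is about the return annotation: `SystemMessage | HumanMessage | AIMessage` needs PEP 604 union syntax (Python 3.10+ at runtime), while all three classes already share the `BaseMessage` base, so `List[BaseMessage]` is both simpler and version-safe. A sketch of the shape; the loop body is a plausible reconstruction of the elided function body, not a verbatim copy:

from typing import List, Tuple

from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage


def build_openai_messages(
    tuple_history: List[Tuple[str, str]], system_message: str, question: str
) -> List[BaseMessage]:
    # All three message classes subclass BaseMessage, so the base type
    # covers the heterogeneous list without a union annotation.
    messages: List[BaseMessage] = [SystemMessage(content=system_message)]
    for user_text, assistant_text in tuple_history:
        messages.append(HumanMessage(content=user_text))
        messages.append(AIMessage(content=assistant_text))
    messages.append(HumanMessage(content=question))
    return messages
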
Changed file 3: chat question route (create_question_handler)

@@ -220,7 +220,7 @@ async def create_question_handler(
             model=chat_question.model,
             temperature=chat_question.temperature,
             max_tokens=chat_question.max_tokens,
-            user_openai_api_key=current_user.openai_api_key,  # pyright: ignore reportPrivateUsage=none
+            user_openai_api_key=current_user.openai_api_key,
             chat_id=str(chat_id),
             prompt_id=chat_question.prompt_id,
         )
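
The route-side suppression disappears for the same reason as the field change: `current_user.openai_api_key` is presumably `Optional[str]`, which now flows into an `Optional[str]` field without complaint. In miniature, with hypothetical names:

from typing import Optional

from pydantic import BaseModel


class QAParams(BaseModel):
    # Mirrors the HeadlessQA fields the route passes; names are illustrative.
    user_openai_api_key: Optional[str] = None
    chat_id: str


maybe_key: Optional[str] = None  # e.g. current_user.openai_api_key
# An Optional[str] value passed into an Optional[str] field needs no
# `# pyright: ignore` once the model field is typed honestly.
params = QAParams(user_openai_api_key=maybe_key, chat_id="42")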