fix: gpt4all (#595)

* fix: gpt4all

* fix: pyright

* Update backend/llm/openai.py

* fix: remove backend tag

* fix: typing

* feat: qa_base class

* fix: pyright

* fix: model_path not found
Matt 2023-07-11 19:15:56 +01:00 committed by GitHub
parent f837a6e9b9
commit 8fbb4b2d91
10 changed files with 290 additions and 230 deletions


@@ -12,8 +12,6 @@ MAX_REQUESTS_NUMBER=200
#Private LLM Variables
PRIVATE=False
MODEL_PATH=./local_models/ggml-gpt4all-j-v1.3-groovy.bin
MODEL_N_CTX=1000
MODEL_N_BATCH=8
#RESEND
RESEND_API_KEY=


@@ -1,10 +1,12 @@
from .base import BaseBrainPicking
from .qa_base import QABaseBrainPicking
from .openai import OpenAIBrainPicking
from .openai_functions import OpenAIFunctionsBrainPicking
from .private_gpt4all import PrivateGPT4AllBrainPicking
__all__ = [
"BaseBrainPicking",
"QABaseBrainPicking",
"OpenAIBrainPicking",
"OpenAIFunctionsBrainPicking",
"PrivateGPT4AllBrainPicking",


@@ -1,31 +1,13 @@
import asyncio
import json
from typing import AsyncIterable, Awaitable
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms.base import LLM
from langchain.llms.base import BaseLLM
from llm.qa_base import QABaseBrainPicking
from logger import get_logger
from models.chat import ChatHistory
from repository.chat.format_chat_history import format_chat_history
from repository.chat.get_chat_history import get_chat_history
from repository.chat.update_chat_history import update_chat_history
from repository.chat.update_message_by_id import update_message_by_id
from supabase.client import Client, create_client
from vectorstore.supabase import (
CustomSupabaseVectorStore,
)
# Custom class for handling vector storage with Supabase
from .base import BaseBrainPicking
from .prompts.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT
logger = get_logger(__name__)
class OpenAIBrainPicking(BaseBrainPicking):
class OpenAIBrainPicking(QABaseBrainPicking):
"""
Main class for the OpenAI Brain Picking functionality.
It allows initializing a Chat model, generating questions and retrieving answers using ConversationalRetrievalChain.
@@ -64,182 +46,17 @@ class OpenAIBrainPicking(BaseBrainPicking):
openai_api_key=self.openai_api_key
) # pyright: ignore reportPrivateUsage=none
@property
def supabase_client(self) -> Client:
return create_client(
self.brain_settings.supabase_url, self.brain_settings.supabase_service_key
)
@property
def vector_store(self) -> CustomSupabaseVectorStore:
return CustomSupabaseVectorStore(
self.supabase_client,
self.embeddings,
table_name="vectors",
brain_id=self.brain_id,
)
@property
def question_llm(self) -> LLM:
return self._create_llm(model=self.model, streaming=False)
@property
def doc_llm(self) -> LLM:
return self._create_llm(
model=self.model, streaming=self.streaming, callbacks=self.callbacks
)
@property
def question_generator(self) -> LLMChain:
return LLMChain(llm=self.question_llm, prompt=CONDENSE_QUESTION_PROMPT)
@property
def doc_chain(self) -> LLMChain:
return load_qa_chain(
llm=self.doc_llm, chain_type="stuff"
) # pyright: ignore reportPrivateUsage=none
@property
def qa(self) -> ConversationalRetrievalChain:
return ConversationalRetrievalChain(
retriever=self.vector_store.as_retriever(),
question_generator=self.question_generator,
combine_docs_chain=self.doc_chain, # pyright: ignore reportPrivateUsage=none
verbose=True,
)
def _create_llm(self, model, streaming=False, callbacks=None) -> LLM:
def _create_llm(self, model, streaming=False, callbacks=None) -> BaseLLM:
"""
Determine the language model to be used.
:param model: Language model name to be used.
:param private_model_args: Dictionary containing model_path, n_ctx and n_batch.
:param private: Boolean value to determine if private model is to be used.
:param streaming: Whether to enable streaming of the model
:param callbacks: Callbacks to be used for streaming
:return: Language model instance
"""
return ChatOpenAI(
temperature=0,
temperature=self.temperature,
model=model,
streaming=streaming,
callbacks=callbacks,
) # pyright: ignore reportPrivateUsage=none
def _call_chain(self, chain, question, history):
"""
Call a chain with a given question and history.
:param chain: The chain, e.g. QA (ConversationalRetrievalChain)
:param question: The user prompt
:param history: The chat history from DB
:return: The answer.
"""
return chain(
{
"question": question,
"chat_history": history,
}
)
def generate_answer(self, question: str) -> ChatHistory:
"""
Generate an answer to a given question by interacting with the language model.
:param question: The question
:return: The generated answer.
"""
transformed_history = []
# Get the history from the database
history = get_chat_history(self.chat_id)
# Format the chat history into a list of tuples (human, ai)
transformed_history = format_chat_history(history)
# Generate the model response using the QA chain
model_response = self._call_chain(self.qa, question, transformed_history)
answer = model_response["answer"]
# Update chat history
chat_answer = update_chat_history(
chat_id=self.chat_id,
user_message=question,
assistant=answer,
)
return chat_answer
async def _acall_chain(self, chain, question, history):
"""
Call a chain with a given question and history.
:param chain: The chain, e.g. QA (ConversationalRetrievalChain)
:param question: The user prompt
:param history: The chat history from DB
:return: The answer.
"""
return chain.acall(
{
"question": question,
"chat_history": history,
}
)
async def generate_stream(self, question: str) -> AsyncIterable:
"""
Generate a streaming answer to a given question by interacting with the language model.
:param question: The question
:return: An async iterable which generates the answer.
"""
history = get_chat_history(self.chat_id)
callback = self.callbacks[0]
transformed_history = []
# Format the chat history into a list of tuples (human, ai)
transformed_history = format_chat_history(history)
# Initialize a list to hold the tokens
response_tokens = []
# Wrap an awaitable with an event to signal when it's done or an exception is raised.
async def wrap_done(fn: Awaitable, event: asyncio.Event):
try:
await fn
except Exception as e:
logger.error(f"Caught exception: {e}")
finally:
event.set()
task = asyncio.create_task(
wrap_done(
self.qa._acall_chain( # pyright: ignore reportPrivateUsage=none
self.qa, question, transformed_history
),
callback.done, # pyright: ignore reportPrivateUsage=none
)
)
streamed_chat_history = update_chat_history(
chat_id=self.chat_id,
user_message=question,
assistant="",
)
# Use the aiter method of the callback to stream the response with server-sent-events
async for token in callback.aiter(): # pyright: ignore reportPrivateUsage=none
logger.info("Token: %s", token)
# Add the token to the response_tokens list
response_tokens.append(token)
streamed_chat_history.assistant = token
yield f"data: {json.dumps(streamed_chat_history.to_dict())}"
await task
# Join the tokens to create the assistant's response
assistant = "".join(response_tokens)
update_message_by_id(
message_id=streamed_chat_history.message_id,
user_message=question,
assistant=assistant,
)


@@ -2,6 +2,8 @@ from typing import Any, Dict, List, Optional
from langchain.chat_models import ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from llm.models.FunctionCall import FunctionCall
from llm.models.OpenAiAnswer import OpenAiAnswer
from logger import get_logger
from models.chat import ChatHistory
from repository.chat.get_chat_history import get_chat_history
@@ -9,9 +11,6 @@ from repository.chat.update_chat_history import update_chat_history
from supabase.client import Client, create_client
from vectorstore.supabase import CustomSupabaseVectorStore
from llm.models.FunctionCall import FunctionCall
from llm.models.OpenAiAnswer import OpenAiAnswer
from .base import BaseBrainPicking
logger = get_logger(__name__)
@@ -148,8 +147,8 @@ class OpenAIFunctionsBrainPicking(BaseBrainPicking):
{
"role": "system",
"content": """Your name is Quivr. You are an assistant that has access to a person's documents and that can answer questions about them.
A person will ask you a question and you will provide a helpful answer.
Write the answer in the same language as the question.
You have access to functions to help you answer the question.
If you don't know the answer, just say that you don't know but be helpful and explain why you can't answer""",
}


@@ -1,60 +1,73 @@
from langchain.llms.base import LLM
from typing import Optional
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms.base import BaseLLM
from langchain.llms.gpt4all import GPT4All
from llm.qa_base import QABaseBrainPicking
from logger import get_logger
from models.settings import LLMSettings
from .base import BaseBrainPicking
logger = get_logger(__name__)
class PrivateGPT4AllBrainPicking(BaseBrainPicking):
class PrivateGPT4AllBrainPicking(QABaseBrainPicking):
"""
This subclass of BrainPicking is used to specifically work with the private language model GPT4All.
"""
# Initialize class settings
llm_settings = LLMSettings()
# Define the default model path
model_path: str = "./local_models/ggml-gpt4all-j-v1.3-groovy.bin"
def __init__(
self,
chat_id: str,
brain_id: str,
user_openai_api_key: Optional[str],
streaming: bool,
) -> "PrivateGPT4AllBrainPicking": # pyright: ignore reportPrivateUsage=none
model_path: str,
) -> None:
"""
Initialize the PrivateBrainPicking class by calling the parent class's initializer.
:param brain_id: The brain_id in the DB.
:param chat_id: The id of the chat in the DB.
:param streaming: Whether to enable streaming of the model
:return: PrivateBrainPicking instance
:param model_path: The path to the model. If not provided, a default path is used.
"""
# set defaults
model = "gpt4all-j-1.3"
super().__init__(
model=model,
model="gpt4all-j-1.3",
brain_id=brain_id,
chat_id=chat_id,
user_openai_api_key=user_openai_api_key,
streaming=streaming,
)
def _create_llm(self) -> LLM:
# Set the model path
self.model_path = model_path
# TODO: Use private embeddings model. This involves some restructuring of how we store the embeddings.
@property
def embeddings(self) -> OpenAIEmbeddings:
return OpenAIEmbeddings(
openai_api_key=self.openai_api_key
) # pyright: ignore reportPrivateUsage=none
def _create_llm(
self,
model,
streaming=False,
callbacks=None,
) -> BaseLLM:
"""
Override the _create_llm method to enforce the use of a private model.
:param model: Language model name to be used.
:param streaming: Whether to enable streaming of the model
:param callbacks: Callbacks to be used for streaming
:return: Language model instance
"""
model_path = self.llm_settings.model_path
model_n_ctx = self.llm_settings.model_n_ctx
model_n_batch = self.llm_settings.model_n_batch
model_path = self.model_path
logger.info("Using private model: %s", model_path)
logger.info("Using private model: %s", model)
logger.info("Streaming is set to %s", streaming)
return GPT4All(
model=model_path,
n_ctx=model_n_ctx,
n_batch=model_n_batch,
backend="gptj",
verbose=True,
) # pyright: ignore reportPrivateUsage=none
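
For orientation, here is a hedged sketch of how the refactored class is now constructed, mirroring the arguments used in the chat routes further down; the chat/brain ids and the question string are illustrative placeholders, not values from this commit:

```python
# Sketch only: argument names follow the new __init__ signature above; the
# ids and the question are placeholders for illustration.
from llm import PrivateGPT4AllBrainPicking

gpt_answer_generator = PrivateGPT4AllBrainPicking(
    chat_id="chat-uuid",
    brain_id="brain-uuid",
    user_openai_api_key=None,  # embeddings still go through OpenAI (see TODO above)
    streaming=False,
    model_path="./local_models/ggml-gpt4all-j-v1.3-groovy.bin",
)

# generate_answer is inherited from the new QABaseBrainPicking base class.
chat_answer = gpt_answer_generator.generate_answer("What do my documents say?")
```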

backend/llm/qa_base.py (new file, 228 additions)

@@ -0,0 +1,228 @@
import asyncio
import json
from abc import abstractmethod, abstractproperty
from typing import AsyncIterable, Awaitable
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms.base import BaseLLM
from logger import get_logger
from models.chat import ChatHistory
from repository.chat.format_chat_history import format_chat_history
from repository.chat.get_chat_history import get_chat_history
from repository.chat.update_chat_history import update_chat_history
from repository.chat.update_message_by_id import update_message_by_id
from supabase.client import Client, create_client
from vectorstore.supabase import CustomSupabaseVectorStore
from .base import BaseBrainPicking
from .prompts.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT
logger = get_logger(__name__)
class QABaseBrainPicking(BaseBrainPicking):
"""
Base class for the Brain Picking functionality using the Conversational Retrieval Chain (QA) from Langchain.
It is not designed to be used directly, but to be subclassed by other classes which use the QA chain.
"""
def __init__(
self,
model: str,
brain_id: str,
chat_id: str,
streaming: bool = False,
**kwargs,
) -> "QABaseBrainPicking": # pyright: ignore reportPrivateUsage=none
"""
Initialize the QA BrainPicking class by setting embeddings, supabase client, vector store, language model and chains.
:return: QABrainPicking instance
"""
super().__init__(
model=model,
brain_id=brain_id,
chat_id=chat_id,
streaming=streaming,
**kwargs,
)
@abstractproperty
def embeddings(self) -> OpenAIEmbeddings:
raise NotImplementedError("This property should be overridden in a subclass.")
@property
def supabase_client(self) -> Client:
return create_client(
self.brain_settings.supabase_url, self.brain_settings.supabase_service_key
)
@property
def vector_store(self) -> CustomSupabaseVectorStore:
return CustomSupabaseVectorStore(
self.supabase_client,
self.embeddings,
table_name="vectors",
brain_id=self.brain_id,
)
@property
def question_llm(self):
return self._create_llm(model=self.model, streaming=False)
@property
def doc_llm(self):
return self._create_llm(
model=self.model, streaming=self.streaming, callbacks=self.callbacks
)
@property
def question_generator(self) -> LLMChain:
return LLMChain(llm=self.question_llm, prompt=CONDENSE_QUESTION_PROMPT)
@property
def doc_chain(self) -> LLMChain:
return load_qa_chain(
llm=self.doc_llm, chain_type="stuff"
) # pyright: ignore reportPrivateUsage=none
@property
def qa(self) -> ConversationalRetrievalChain:
return ConversationalRetrievalChain(
retriever=self.vector_store.as_retriever(),
question_generator=self.question_generator,
combine_docs_chain=self.doc_chain, # pyright: ignore reportPrivateUsage=none
verbose=True,
)
@abstractmethod
def _create_llm(self, model, streaming=False, callbacks=None) -> BaseLLM:
"""
Determine the language model to be used.
:param model: Language model name to be used.
:param streaming: Whether to enable streaming of the model
:param callbacks: Callbacks to be used for streaming
:return: Language model instance
"""
def _call_chain(self, chain, question, history):
"""
Call a chain with a given question and history.
:param chain: The chain, e.g. QA (ConversationalRetrievalChain)
:param question: The user prompt
:param history: The chat history from DB
:return: The answer.
"""
return chain(
{
"question": question,
"chat_history": history,
}
)
def generate_answer(self, question: str) -> ChatHistory:
"""
Generate an answer to a given question by interacting with the language model.
:param question: The question
:return: The generated answer.
"""
transformed_history = []
# Get the history from the database
history = get_chat_history(self.chat_id)
# Format the chat history into a list of tuples (human, ai)
transformed_history = format_chat_history(history)
# Generate the model response using the QA chain
model_response = self._call_chain(self.qa, question, transformed_history)
answer = model_response["answer"]
# Update chat history
chat_answer = update_chat_history(
chat_id=self.chat_id,
user_message=question,
assistant=answer,
)
return chat_answer
async def _acall_chain(self, chain, question, history):
"""
Call a chain with a given question and history.
:param chain: The chain, e.g. QA (ConversationalRetrievalChain)
:param question: The user prompt
:param history: The chat history from DB
:return: The answer.
"""
return chain.acall(
{
"question": question,
"chat_history": history,
}
)
async def generate_stream(self, question: str) -> AsyncIterable:
"""
Generate a streaming answer to a given question by interacting with the language model.
:param question: The question
:return: An async iterable which generates the answer.
"""
history = get_chat_history(self.chat_id)
callback = self.callbacks[0]
transformed_history = []
# Format the chat history into a list of tuples (human, ai)
transformed_history = format_chat_history(history)
# Initialize a list to hold the tokens
response_tokens = []
# Wrap an awaitable with an event to signal when it's done or an exception is raised.
async def wrap_done(fn: Awaitable, event: asyncio.Event):
try:
await fn
except Exception as e:
logger.error(f"Caught exception: {e}")
finally:
event.set()
task = asyncio.create_task(
wrap_done(
self.qa._acall_chain( # pyright: ignore reportPrivateUsage=none
self.qa, question, transformed_history
),
callback.done, # pyright: ignore reportPrivateUsage=none
)
)
streamed_chat_history = update_chat_history(
chat_id=self.chat_id,
user_message=question,
assistant="",
)
# Use the aiter method of the callback to stream the response with server-sent-events
async for token in callback.aiter(): # pyright: ignore reportPrivateUsage=none
logger.info("Token: %s", token)
# Add the token to the response_tokens list
response_tokens.append(token)
streamed_chat_history.assistant = token
yield f"data: {json.dumps(streamed_chat_history.to_dict())}"
await task
# Join the tokens to create the assistant's response
assistant = "".join(response_tokens)
update_message_by_id(
message_id=streamed_chat_history.message_id,
user_message=question,
assistant=assistant,
)
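
To make the intended extension point concrete, here is a hedged sketch of a subclass; it only fills in the two abstract members (`embeddings` and `_create_llm`) and reuses the same ChatOpenAI pattern shown in the openai.py diff above, while the vector store, chains, `generate_answer` and `generate_stream` all come from this base class. The name `MyBrainPicking` is illustrative, not part of the commit.

```python
# Illustrative subclass sketch; not part of this commit.
from langchain.chat_models import ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms.base import BaseLLM

from llm.qa_base import QABaseBrainPicking


class MyBrainPicking(QABaseBrainPicking):
    @property
    def embeddings(self) -> OpenAIEmbeddings:
        # Embeddings are still OpenAI-based everywhere in this commit.
        return OpenAIEmbeddings(openai_api_key=self.openai_api_key)

    def _create_llm(self, model, streaming=False, callbacks=None) -> BaseLLM:
        # Any LangChain model works here; ChatOpenAI matches OpenAIBrainPicking.
        return ChatOpenAI(
            temperature=self.temperature,
            model=model,
            streaming=streaming,
            callbacks=callbacks,
        )  # pyright: ignore reportPrivateUsage=none
```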


@@ -16,9 +16,7 @@ class BrainSettings(BaseSettings):
class LLMSettings(BaseSettings):
private: bool = False
model_path: str = "gpt2"
model_n_ctx: int = 1000
model_n_batch: int = 8
model_path: str = "./local_models/ggml-gpt4all-j-v1.3-groovy.bin"
def common_dependencies() -> dict:
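
For reference, a hedged sketch of how these settings are consumed: `LLMSettings` extends pydantic's `BaseSettings`, so the `PRIVATE` and `MODEL_PATH` environment variables from the backend `.env` map onto `private` and `model_path` (field-name matching is case-insensitive by default), which is how the chat routes below obtain the model path.

```python
# Sketch only: assumes the backend .env has been loaded into the environment.
from models.settings import LLMSettings

llm_settings = LLMSettings()

if llm_settings.private:
    # The same value the route handlers now pass as model_path.
    print("Private model path:", llm_settings.model_path)
```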


@@ -1,5 +1,5 @@
pymupdf==1.22.3
langchain==0.0.207
langchain==0.0.228
Markdown==3.4.3
openai==0.27.6
pdf2image==1.16.3


@@ -172,7 +172,9 @@ async def create_question_handler(
gpt_answer_generator = PrivateGPT4AllBrainPicking(
chat_id=str(chat_id),
brain_id=str(brain_id),
user_openai_api_key=current_user.user_openai_api_key,
streaming=False,
model_path=llm_settings.model_path,
)
elif chat_question.model in openai_function_compatible_models:
@@ -228,6 +230,7 @@ async def create_stream_question_handler(
try:
user_openai_api_key = request.headers.get("Openai-Api-Key")
streaming = True
check_user_limit(current_user)
llm_settings = LLMSettings()
@@ -235,7 +238,9 @@ async def create_stream_question_handler(
gpt_answer_generator = PrivateGPT4AllBrainPicking(
chat_id=str(chat_id),
brain_id=str(brain_id),
streaming=False,
user_openai_api_key=user_openai_api_key,
streaming=streaming,
model_path=llm_settings.model_path,
)
else:
gpt_answer_generator = OpenAIBrainPicking(
@@ -245,7 +250,7 @@ async def create_stream_question_handler(
temperature=chat_question.temperature,
brain_id=str(brain_id),
user_openai_api_key=user_openai_api_key, # pyright: ignore reportPrivateUsage=none
streaming=True,
streaming=streaming,
)
return StreamingResponse(


@@ -4,20 +4,20 @@ sidebar_position: 1
# Private LLM
Quivr now has the capability to use a private LLM model powered by GPT4All (other open source models coming soon).
This is similar to the functionality provided by the PrivateGPT project.
This means that your data never leaves the server. The LLM is downloaded to the server and runs inference on your question locally.
## How to use
Set the 'private' flag to True in the /backend/.env file. You can also set other model parameters in the .env file.
Download the GPT4All model from [here](https://gpt4all.io/models/ggml-gpt4all-j-v1.3-groovy.bin) and place it in the /backend/local_models folder. Or you can download any model from their ecosystem on their [website](https://gpt4all.io/index.html).
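
For reference, the relevant variables (these match the `.env` example updated in this commit; `MODEL_N_CTX` and `MODEL_N_BATCH` are no longer needed):

```env
PRIVATE=True
MODEL_PATH=./local_models/ggml-gpt4all-j-v1.3-groovy.bin
```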
## Future Plans
We are planning to add more models to the private LLM feature. We are also planning on using a local embedding model from Hugging Face to reduce our reliance on OpenAI's API.
We will also be adding the ability to use a private LLM model from the frontend and API. Currently it is only available if you self-host the backend.