quivr/backend/modules/brain/integrations/GPT4/Brain.py
Stan Girard f5cfa2f6fa
feat(gpt4): Add search functionality (#2566)
This pull request adds search functionality to the application The
search functionality allows users to search the internet for
information.
2024-05-09 07:01:33 -07:00

240 lines
8.8 KiB
Python

import json
import operator
from typing import Annotated, AsyncIterable, List, Sequence, TypedDict
from uuid import UUID
from langchain_community.tools import DuckDuckGoSearchResults
from langchain_core.messages import BaseMessage, ToolMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.tools import BaseTool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph
from langgraph.prebuilt import ToolExecutor, ToolInvocation
from logger import get_logger
from modules.brain.knowledge_brain_qa import KnowledgeBrainQA
from modules.chat.dto.chats import ChatQuestion
from modules.chat.dto.outputs import GetChatHistoryOutput
from modules.chat.service.chat_service import ChatService
class AgentState(TypedDict):
messages: Annotated[Sequence[BaseMessage], operator.add]
# Define the function that determines whether to continue or not
logger = get_logger(__name__)
chat_service = ChatService()
class GPT4Brain(KnowledgeBrainQA):
"""This is the Notion brain class. it is a KnowledgeBrainQA has the data is stored locally.
It is going to call the Data Store internally to get the data.
Args:
KnowledgeBrainQA (_type_): A brain that store the knowledge internaly
"""
tools: List[BaseTool] = [DuckDuckGoSearchResults()]
tool_executor: ToolExecutor = ToolExecutor(tools)
model_function: ChatOpenAI = None
def __init__(
self,
**kwargs,
):
super().__init__(
**kwargs,
)
def calculate_pricing(self):
return 3
def should_continue(self, state):
messages = state["messages"]
last_message = messages[-1]
# If there is no function call, then we finish
if not last_message.tool_calls:
return "end"
# Otherwise if there is, we continue
else:
return "continue"
# Define the function that calls the model
def call_model(self, state):
messages = state["messages"]
response = self.model_function.invoke(messages)
# We return a list, because this will get added to the existing list
return {"messages": [response]}
# Define the function to execute tools
def call_tool(self, state):
messages = state["messages"]
# Based on the continue condition
# we know the last message involves a function call
last_message = messages[-1]
# We construct an ToolInvocation from the function_call
tool_call = last_message.tool_calls[0]
action = ToolInvocation(
tool=tool_call["name"],
tool_input=tool_call["args"],
)
# We call the tool_executor and get back a response
response = self.tool_executor.invoke(action)
# We use the response to create a FunctionMessage
function_message = ToolMessage(
content=str(response), name=action.tool, tool_call_id=tool_call["id"]
)
# We return a list, because this will get added to the existing list
return {"messages": [function_message]}
def create_graph(self):
# Define a new graph
workflow = StateGraph(AgentState)
# Define the two nodes we will cycle between
workflow.add_node("agent", self.call_model)
workflow.add_node("action", self.call_tool)
# Set the entrypoint as `agent`
# This means that this node is the first one called
workflow.set_entry_point("agent")
# We now add a conditional edge
workflow.add_conditional_edges(
# First, we define the start node. We use `agent`.
# This means these are the edges taken after the `agent` node is called.
"agent",
# Next, we pass in the function that will determine which node is called next.
self.should_continue,
# Finally we pass in a mapping.
# The keys are strings, and the values are other nodes.
# END is a special node marking that the graph should finish.
# What will happen is we will call `should_continue`, and then the output of that
# will be matched against the keys in this mapping.
# Based on which one it matches, that node will then be called.
{
# If `tools`, then we call the tool node.
"continue": "action",
# Otherwise we finish.
"end": END,
},
)
# We now add a normal edge from `tools` to `agent`.
# This means that after `tools` is called, `agent` node is called next.
workflow.add_edge("action", "agent")
# Finally, we compile it!
# This compiles it into a LangChain Runnable,
# meaning you can use it as you would any other runnable
app = workflow.compile()
return app
def get_chain(self):
self.model_function = ChatOpenAI(
model="gpt-4-turbo", temperature=0, streaming=True
)
self.model_function = self.model_function.bind_tools(self.tools)
graph = self.create_graph()
return graph
async def generate_stream(
self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
) -> AsyncIterable:
conversational_qa_chain = self.get_chain()
transformed_history, streamed_chat_history = (
self.initialize_streamed_chat_history(chat_id, question)
)
filtered_history = self.filter_history(transformed_history, 20, 2000)
response_tokens = []
config = {"metadata": {"conversation_id": str(chat_id)}}
prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"You are GPT-4 powered by Quivr. You are an assistant. {custom_personality}",
),
MessagesPlaceholder(variable_name="chat_history"),
("human", "{question}"),
]
)
prompt_formated = prompt.format_messages(
chat_history=filtered_history,
question=question.question,
custom_personality=(
self.prompt_to_use.content if self.prompt_to_use else None
),
)
async for event in conversational_qa_chain.astream_events(
{"messages": prompt_formated},
config=config,
version="v1",
):
kind = event["event"]
if kind == "on_chat_model_stream":
content = event["data"]["chunk"].content
if content:
# Empty content in the context of OpenAI or Anthropic usually means
# that the model is asking for a tool to be invoked.
# So we only print non-empty content
response_tokens.append(content)
streamed_chat_history.assistant = content
yield f"data: {json.dumps(streamed_chat_history.dict())}"
elif kind == "on_tool_start":
print("--")
print(
f"Starting tool: {event['name']} with inputs: {event['data'].get('input')}"
)
elif kind == "on_tool_end":
print(f"Done tool: {event['name']}")
print(f"Tool output was: {event['data'].get('output')}")
print("--")
self.save_answer(question, response_tokens, streamed_chat_history, save_answer)
def generate_answer(
self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
) -> GetChatHistoryOutput:
conversational_qa_chain = self.get_chain()
transformed_history, streamed_chat_history = (
self.initialize_streamed_chat_history(chat_id, question)
)
filtered_history = self.filter_history(transformed_history, 20, 2000)
response_tokens = []
config = {"metadata": {"conversation_id": str(chat_id)}}
prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"You are GPT-4 powered by Quivr. You are an assistant. {custom_personality}",
),
MessagesPlaceholder(variable_name="chat_history"),
("human", "{question}"),
]
)
prompt_formated = prompt.format_messages(
chat_history=filtered_history,
question=question.question,
custom_personality=(
self.prompt_to_use.content if self.prompt_to_use else None
),
)
model_response = conversational_qa_chain.invoke(
{"messages": prompt_formated},
config=config,
)
answer = model_response["messages"][-1].content
return self.save_non_streaming_answer(
chat_id=chat_id, question=question, answer=answer, metadata={}
)