feat(gpt4): image generation (#2569)

This pull request adds image generation to the GPT4 brain using OpenAI's DALL-E 3 model. A new `ImageGeneratorTool` (a LangChain `BaseTool` with `return_direct = True`) performs the generation, and the LangGraph workflow gains a `final` node so the tool's markdown image output is streamed back to the user directly, without another model round-trip.
Stan Girard, 2024-05-09 19:00:51 +02:00 (committed by GitHub)
parent 854cf9ef7c · commit 4e5b0c0373
2 changed files with 91 additions and 8 deletions
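
For context, the new tool can be exercised on its own. Here is a minimal sketch, assuming a valid `OPENAI_API_KEY` in the environment; the import path is hypothetical and should be adjusted to wherever `ImageGeneratorTool` lives in the repo:

```python
# Minimal sketch: invoking the new tool directly.
# Assumes OPENAI_API_KEY is set; the import path below is hypothetical.
from modules.brain.integrations.GPT4.Brain import ImageGeneratorTool

tool = ImageGeneratorTool()
# BaseTool.run validates the input against ImageGenerationInput, then calls _run
markdown = tool.run({"query": "A watercolor painting of a lighthouse at dusk"})
print(markdown)  # "<revised prompt> \n ![Generated Image](https://...)"
```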


@@ -1,8 +1,15 @@
import json
import operator
-from typing import Annotated, AsyncIterable, List, Sequence, TypedDict
+from typing import Annotated, AsyncIterable, List, Optional, Sequence, Type, TypedDict
from uuid import UUID
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain.pydantic_v1 import BaseModel as BaseModelV1
from langchain.pydantic_v1 import Field as FieldV1
from langchain.tools import BaseTool
from langchain_community.tools import DuckDuckGoSearchResults
from langchain_core.messages import BaseMessage, ToolMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
@@ -15,6 +22,8 @@ from modules.brain.knowledge_brain_qa import KnowledgeBrainQA
from modules.chat.dto.chats import ChatQuestion
from modules.chat.dto.outputs import GetChatHistoryOutput
from modules.chat.service.chat_service import ChatService
from openai import AsyncOpenAI, OpenAI
from pydantic import BaseModel
class AgentState(TypedDict):
@@ -28,6 +37,56 @@ logger = get_logger(__name__)
chat_service = ChatService()
class ImageGenerationInput(BaseModelV1):
query: str = FieldV1(
...,
title="description",
description="A detailled prompt to generate the image from. Takes into account the history of the chat.",
)
class ImageGeneratorTool(BaseTool):
name = "image-generator"
description = "useful for when you need to answer questions about current events"
args_schema: Type[BaseModel] = ImageGenerationInput
return_direct = True
def _run(
self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None
) -> str:
client = OpenAI()
response = client.images.generate(
model="dall-e-3",
prompt=query,
size="1024x1024",
quality="standard",
n=1,
)
logger.info(response.data[0])
image_url = response.data[0].url
revised_prompt = response.data[0].revised_prompt
# Make the url a markdown image
return f"{revised_prompt} \n ![Generated Image]({image_url}) "
async def _arun(
self, query: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None
) -> str:
"""Use the tool asynchronously."""
# Use the async client: the sync OpenAI client cannot be awaited, and
# AsyncCallbackManagerForToolRun has no helper for running sync calls.
client = AsyncOpenAI()
response = await client.images.generate(
model="dall-e-3",
prompt=query,
size="1024x1024",
quality="standard",
n=1,
)
image_url = response.data[0].url
# Make the url a markdown image
return f"![Generated Image]({image_url})"
class GPT4Brain(KnowledgeBrainQA):
"""This is the Notion brain class. it is a KnowledgeBrainQA has the data is stored locally.
It is going to call the Data Store internally to get the data.
@@ -36,7 +95,7 @@ class GPT4Brain(KnowledgeBrainQA):
KnowledgeBrainQA (_type_): A brain that stores the knowledge internally
"""
-tools: List[BaseTool] = [DuckDuckGoSearchResults()]
+tools: List[BaseTool] = [DuckDuckGoSearchResults(), ImageGeneratorTool()]
tool_executor: ToolExecutor = ToolExecutor(tools)
model_function: ChatOpenAI = None
@@ -54,10 +113,16 @@ class GPT4Brain(KnowledgeBrainQA):
def should_continue(self, state):
messages = state["messages"]
last_message = messages[-1]
# If the first tool call is the image generator, return its result directly
if last_message.tool_calls:
name = last_message.tool_calls[0]["name"]
if name == "image-generator":
return "final"
# If there is no function call, then we finish
if not last_message.tool_calls:
return "end"
-# Otherwise if there is, we continue
+# Otherwise, if there is one, we check whether it is supposed to return directly
else:
return "continue"
@@ -76,6 +141,9 @@ class GPT4Brain(KnowledgeBrainQA):
last_message = messages[-1]
# We construct an ToolInvocation from the function_call
tool_call = last_message.tool_calls[0]
tool_name = tool_call["name"]
arguments = tool_call["args"]
action = ToolInvocation(
tool=tool_call["name"],
tool_input=tool_call["args"],
@@ -96,6 +164,7 @@ class GPT4Brain(KnowledgeBrainQA):
# Define the two nodes we will cycle between
workflow.add_node("agent", self.call_model)
workflow.add_node("action", self.call_tool)
workflow.add_node("final", self.call_tool)
# Set the entrypoint as `agent`
# This means that this node is the first one called
@@ -117,6 +186,8 @@ class GPT4Brain(KnowledgeBrainQA):
{
# If `tools`, then we call the tool node.
"continue": "action",
# If "final", the tool runs once and its output is returned directly.
"final": "final",
# Otherwise we finish.
"end": END,
},
@@ -125,6 +196,7 @@ class GPT4Brain(KnowledgeBrainQA):
# We now add a normal edge from `tools` to `agent`.
# This means that after `tools` is called, `agent` node is called next.
workflow.add_edge("action", "agent")
workflow.add_edge("final", END)
# Finally, we compile it!
# This compiles it into a LangChain Runnable,
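
Taken together, the graph now routes agent → action (which loops back to the agent) or agent → final (which terminates). Below is a condensed, self-contained sketch of just this wiring, with the node bodies stubbed out; it assumes langgraph's `StateGraph` API as used in this file:

```python
# Condensed sketch of the routing added above; node bodies are stubs.
import operator
from typing import Annotated, Sequence, TypedDict

from langchain_core.messages import BaseMessage
from langgraph.graph import END, StateGraph


class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], operator.add]


def call_model(state: AgentState) -> dict: ...  # invokes the LLM (stub)
def call_tool(state: AgentState) -> dict: ...   # runs the requested tool (stub)


def should_continue(state: AgentState) -> str:
    last_message = state["messages"][-1]
    if not last_message.tool_calls:
        return "end"                  # no tool call: finish
    if last_message.tool_calls[0]["name"] == "image-generator":
        return "final"                # return-direct tool: answer with its output
    return "continue"                 # any other tool: loop back through the agent


workflow = StateGraph(AgentState)
workflow.add_node("agent", call_model)
workflow.add_node("action", call_tool)
workflow.add_node("final", call_tool)  # same tool executor, but terminal
workflow.set_entry_point("agent")
workflow.add_conditional_edges(
    "agent", should_continue,
    {"continue": "action", "final": "final", "end": END},
)
workflow.add_edge("action", "agent")  # tool results feed back into the model
workflow.add_edge("final", END)       # the image markdown is the final answer
app = workflow.compile()
```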
@@ -196,6 +268,18 @@ class GPT4Brain(KnowledgeBrainQA):
print(f"Done tool: {event['name']}")
print(f"Tool output was: {event['data'].get('output')}")
print("--")
elif kind == "on_chain_end":
output = event["data"]["output"]
final_output = [item for item in output if "final" in item]
if final_output:
if (
final_output[0]["final"]["messages"][0].name
== "image-generator"
):
final_message = final_output[0]["final"]["messages"][0].content
response_tokens.append(final_message)
streamed_chat_history.assistant = final_message
yield f"data: {json.dumps(streamed_chat_history.dict())}"
self.save_answer(question, response_tokens, streamed_chat_history, save_answer)
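
For reference, the `on_chain_end` event consumed above carries the per-node state updates from the graph. An illustrative payload, with the shape inferred from how the handler indexes into it (values are made up):

```python
# Illustrative "on_chain_end" event as consumed by the handler above
# (shape inferred from the code; values are made up):
#
# event == {
#     "event": "on_chain_end",
#     "name": "LangGraph",
#     "data": {
#         "output": [
#             # one {node_name: state_update} entry per executed node
#             {"final": {"messages": [
#                 # a ToolMessage whose .name is "image-generator" and whose
#                 # .content is the markdown image returned by the tool
#             ]}},
#         ]
#     },
# }
```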
@@ -203,11 +287,10 @@ class GPT4Brain(KnowledgeBrainQA):
self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
) -> GetChatHistoryOutput:
conversational_qa_chain = self.get_chain()
-transformed_history, streamed_chat_history = (
-self.initialize_streamed_chat_history(chat_id, question)
+transformed_history, _ = self.initialize_streamed_chat_history(
+chat_id, question
)
filtered_history = self.filter_history(transformed_history, 20, 2000)
response_tokens = []
config = {"metadata": {"conversation_id": str(chat_id)}}
prompt = ChatPromptTemplate.from_messages(

The same call-site cleanup is applied in modules/brain/knowledge_brain_qa.py:

@@ -283,8 +283,8 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
) -> GetChatHistoryOutput:
conversational_qa_chain = self.knowledge_qa.get_chain()
-transformed_history, streamed_chat_history = (
-self.initialize_streamed_chat_history(chat_id, question)
+transformed_history, _ = self.initialize_streamed_chat_history(
+chat_id, question
)
metadata = self.metadata or {}
citations = None