mirror of https://github.com/StanGirard/quivr.git
feat(gpt4): image generation (#2569)
This pull request adds image generation using the OpenAI DALL-E 3 model. A new `ImageGeneratorTool` class handles the image generation and is registered as a tool on the `GPT4Brain`.
parent: 854cf9ef7c
commit: 4e5b0c0373
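As a quick illustration (not part of the diff), the new tool can be exercised directly. This is a minimal sketch: it assumes an OPENAI_API_KEY in the environment, and the import path is a hypothetical guess from the class names in the hunk headers below.

# Usage sketch; the module path is an assumption, adjust to the actual tree.
from modules.brain.integrations.GPT4.Brain import ImageGeneratorTool

tool = ImageGeneratorTool()
# BaseTool.run validates the input against ImageGenerationInput and then
# dispatches to _run, which calls DALL-E 3 and returns a markdown image link.
result = tool.run({"query": "a watercolor fox reading a book"})
print(result)  # "<revised prompt> \n ![Generated Image](https://...)"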
@@ -1,8 +1,15 @@
 import json
 import operator
-from typing import Annotated, AsyncIterable, List, Sequence, TypedDict
+from typing import Annotated, AsyncIterable, List, Optional, Sequence, Type, TypedDict
 from uuid import UUID

+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForToolRun,
+    CallbackManagerForToolRun,
+)
+from langchain.pydantic_v1 import BaseModel as BaseModelV1
+from langchain.pydantic_v1 import Field as FieldV1
+from langchain.tools import BaseTool
 from langchain_community.tools import DuckDuckGoSearchResults
 from langchain_core.messages import BaseMessage, ToolMessage
 from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
@@ -15,6 +22,8 @@ from modules.brain.knowledge_brain_qa import KnowledgeBrainQA
 from modules.chat.dto.chats import ChatQuestion
 from modules.chat.dto.outputs import GetChatHistoryOutput
 from modules.chat.service.chat_service import ChatService
+from openai import AsyncOpenAI, OpenAI
+from pydantic import BaseModel


 class AgentState(TypedDict):
@@ -28,6 +37,56 @@ logger = get_logger(__name__)
 chat_service = ChatService()


+class ImageGenerationInput(BaseModelV1):
+    query: str = FieldV1(
+        ...,
+        title="description",
+        description="A detailed prompt to generate the image from. Takes into account the history of the chat.",
+    )
+
+
+class ImageGeneratorTool(BaseTool):
+    name = "image-generator"
+    description = "useful for when you need to generate an image from a text description"
+    args_schema: Type[BaseModel] = ImageGenerationInput
+    return_direct = True
+
+    def _run(
+        self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None
+    ) -> str:
+        client = OpenAI()
+
+        response = client.images.generate(
+            model="dall-e-3",
+            prompt=query,
+            size="1024x1024",
+            quality="standard",
+            n=1,
+        )
+        logger.info(response.data[0])
+        image_url = response.data[0].url
+        revised_prompt = response.data[0].revised_prompt
+        # Make the url a markdown image
+        return f"{revised_prompt} \n ![Generated Image]({image_url}) "
+
+    async def _arun(
+        self, query: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None
+    ) -> str:
+        """Use the tool asynchronously."""
+        # Use the async OpenAI client so the event loop is not blocked.
+        client = AsyncOpenAI()
+        response = await client.images.generate(
+            model="dall-e-3",
+            prompt=query,
+            size="1024x1024",
+            quality="standard",
+            n=1,
+        )
+        image_url = response.data[0].url
+        # Make the url a markdown image
+        return f"![Generated Image]({image_url})"
+
+
 class GPT4Brain(KnowledgeBrainQA):
     """This is the GPT-4 brain class. It is a KnowledgeBrainQA, as the data is stored locally.
     It is going to call the Data Store internally to get the data.
@@ -36,7 +95,7 @@ class GPT4Brain(KnowledgeBrainQA):
         KnowledgeBrainQA (_type_): A brain that stores the knowledge internally
     """

-    tools: List[BaseTool] = [DuckDuckGoSearchResults()]
+    tools: List[BaseTool] = [DuckDuckGoSearchResults(), ImageGeneratorTool()]
     tool_executor: ToolExecutor = ToolExecutor(tools)
     model_function: ChatOpenAI = None
@@ -54,10 +113,16 @@ class GPT4Brain(KnowledgeBrainQA):
     def should_continue(self, state):
         messages = state["messages"]
         last_message = messages[-1]
+        # Make sure there is a previous message
+
+        if last_message.tool_calls:
+            name = last_message.tool_calls[0]["name"]
+            if name == "image-generator":
+                return "final"
         # If there is no function call, then we finish
         if not last_message.tool_calls:
             return "end"
-        # Otherwise if there is, we continue
+        # Otherwise if there is, we check if it's supposed to return direct
         else:
             return "continue"
@@ -76,6 +141,9 @@ class GPT4Brain(KnowledgeBrainQA):
         last_message = messages[-1]
         # We construct a ToolInvocation from the function_call
         tool_call = last_message.tool_calls[0]
+        tool_name = tool_call["name"]
+        arguments = tool_call["args"]
+
         action = ToolInvocation(
             tool=tool_call["name"],
             tool_input=tool_call["args"],
@@ -96,6 +164,7 @@ class GPT4Brain(KnowledgeBrainQA):
         # Define the two nodes we will cycle between
         workflow.add_node("agent", self.call_model)
         workflow.add_node("action", self.call_tool)
+        workflow.add_node("final", self.call_tool)

         # Set the entrypoint as `agent`
         # This means that this node is the first one called
@@ -117,6 +186,8 @@ class GPT4Brain(KnowledgeBrainQA):
             {
                 # If `tools`, then we call the tool node.
                 "continue": "action",
+                # Final call
+                "final": "final",
                 # Otherwise we finish.
                 "end": END,
             },
@@ -125,6 +196,7 @@ class GPT4Brain(KnowledgeBrainQA):
         # We now add a normal edge from `tools` to `agent`.
         # This means that after `tools` is called, `agent` node is called next.
         workflow.add_edge("action", "agent")
+        workflow.add_edge("final", END)

         # Finally, we compile it!
         # This compiles it into a LangChain Runnable,
@@ -196,6 +268,18 @@ class GPT4Brain(KnowledgeBrainQA):
                     print(f"Done tool: {event['name']}")
                     print(f"Tool output was: {event['data'].get('output')}")
                     print("--")
+                elif kind == "on_chain_end":
+                    output = event["data"]["output"]
+                    final_output = [item for item in output if "final" in item]
+                    if final_output:
+                        if (
+                            final_output[0]["final"]["messages"][0].name
+                            == "image-generator"
+                        ):
+                            final_message = final_output[0]["final"]["messages"][0].content
+                            response_tokens.append(final_message)
+                            streamed_chat_history.assistant = final_message
+                            yield f"data: {json.dumps(streamed_chat_history.dict())}"

         self.save_answer(question, response_tokens, streamed_chat_history, save_answer)
@@ -203,11 +287,10 @@ class GPT4Brain(KnowledgeBrainQA):
         self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
     ) -> GetChatHistoryOutput:
         conversational_qa_chain = self.get_chain()
-        transformed_history, streamed_chat_history = (
-            self.initialize_streamed_chat_history(chat_id, question)
+        transformed_history, _ = self.initialize_streamed_chat_history(
+            chat_id, question
         )
         filtered_history = self.filter_history(transformed_history, 20, 2000)
         response_tokens = []
         config = {"metadata": {"conversation_id": str(chat_id)}}

         prompt = ChatPromptTemplate.from_messages(
The second file in the diff (modules/brain/knowledge_brain_qa.py, per the import in the first file) gets the same call-site cleanup:

@@ -283,8 +283,8 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
         self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
     ) -> GetChatHistoryOutput:
         conversational_qa_chain = self.knowledge_qa.get_chain()
-        transformed_history, streamed_chat_history = (
-            self.initialize_streamed_chat_history(chat_id, question)
+        transformed_history, _ = self.initialize_streamed_chat_history(
+            chat_id, question
         )
         metadata = self.metadata or {}
         citations = None
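For readers unfamiliar with LangGraph, here is a condensed, self-contained sketch of the routing this commit sets up. It assumes langgraph is installed; call_model, call_tool, and the string-based state are simplified stand-ins for the real brain methods and message objects, not the project's actual code.

import operator
from typing import Annotated, Sequence, TypedDict

from langgraph.graph import END, StateGraph


class AgentState(TypedDict):
    # operator.add gives appending semantics: each node's returned
    # messages are concatenated onto the running list.
    messages: Annotated[Sequence[str], operator.add]


def call_model(state: AgentState) -> dict:
    # Stand-in for GPT4Brain.call_model (would invoke the LLM here).
    return {"messages": ["model output"]}


def call_tool(state: AgentState) -> dict:
    # Stand-in for GPT4Brain.call_tool (would execute the selected tool here).
    return {"messages": ["tool output"]}


def should_continue(state: AgentState) -> str:
    # Mirrors the routing added in this commit: an image-generator tool call
    # is routed to "final" so its output is returned directly to the user.
    last_message = state["messages"][-1]
    if "image-generator" in last_message:
        return "final"
    return "end"


workflow = StateGraph(AgentState)
workflow.add_node("agent", call_model)
workflow.add_node("action", call_tool)
workflow.add_node("final", call_tool)  # new node introduced by the diff
workflow.set_entry_point("agent")
workflow.add_conditional_edges(
    "agent",
    should_continue,
    {"continue": "action", "final": "final", "end": END},
)
workflow.add_edge("action", "agent")  # normal tools loop back to the agent
workflow.add_edge("final", END)       # the final node ends the run
app = workflow.compile()
print(app.invoke({"messages": ["hello"]}))

The key design point is the split between "action" and "final": both run call_tool, but "final" has no edge back to "agent", which is how return_direct tools like the image generator skip a second LLM pass.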