mirror of
https://github.com/OpenBMB/ChatDev.git
synced 2024-10-06 15:48:26 +03:00
Resolve #284: update openai api to 1.3 while retaining support for the older api & remove embodied agent
This commit is contained in:
parent
82e155fe62
commit
452df57cca
@ -17,7 +17,6 @@ from .task_agent import TaskPlannerAgent, TaskSpecifyAgent
|
||||
from .critic_agent import CriticAgent
|
||||
from .tool_agents.base import BaseToolAgent
|
||||
from .tool_agents.hugging_face_tool_agent import HuggingFaceToolAgent
|
||||
from .embodied_agent import EmbodiedAgent
|
||||
from .role_playing import RolePlaying
|
||||
|
||||
__all__ = [
|
||||
@ -28,6 +27,5 @@ __all__ = [
|
||||
'CriticAgent',
|
||||
'BaseToolAgent',
|
||||
'HuggingFaceToolAgent',
|
||||
'EmbodiedAgent',
|
||||
'RolePlaying',
|
||||
]
|
||||
|
@ -29,7 +29,12 @@ from camel.utils import (
|
||||
openai_api_key_required,
|
||||
)
|
||||
|
||||
from openai.types.chat import ChatCompletion
|
||||
try:
|
||||
from openai.types.chat import ChatCompletion
|
||||
|
||||
openai_new_api = True # new openai api version
|
||||
except ImportError:
|
||||
openai_new_api = False # old openai api version
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@ -191,19 +196,34 @@ class ChatAgent(BaseAgent):
|
||||
|
||||
if num_tokens < self.model_token_limit:
|
||||
response = self.model_backend.run(messages=openai_messages)
|
||||
if not isinstance(response, ChatCompletion):
|
||||
raise RuntimeError("OpenAI returned unexpected struct")
|
||||
output_messages = [
|
||||
ChatMessage(role_name=self.role_name, role_type=self.role_type,
|
||||
meta_dict=dict(), **dict(choice.message))
|
||||
for choice in response.choices
|
||||
]
|
||||
info = self.get_info(
|
||||
response.id,
|
||||
response.usage,
|
||||
[str(choice.finish_reason) for choice in response.choices],
|
||||
num_tokens,
|
||||
)
|
||||
if openai_new_api:
|
||||
if not isinstance(response, ChatCompletion):
|
||||
raise RuntimeError("OpenAI returned unexpected struct")
|
||||
output_messages = [
|
||||
ChatMessage(role_name=self.role_name, role_type=self.role_type,
|
||||
meta_dict=dict(), **dict(choice.message))
|
||||
for choice in response.choices
|
||||
]
|
||||
info = self.get_info(
|
||||
response.id,
|
||||
response.usage,
|
||||
[str(choice.finish_reason) for choice in response.choices],
|
||||
num_tokens,
|
||||
)
|
||||
else:
|
||||
if not isinstance(response, dict):
|
||||
raise RuntimeError("OpenAI returned unexpected struct")
|
||||
output_messages = [
|
||||
ChatMessage(role_name=self.role_name, role_type=self.role_type,
|
||||
meta_dict=dict(), **dict(choice["message"]))
|
||||
for choice in response["choices"]
|
||||
]
|
||||
info = self.get_info(
|
||||
response["id"],
|
||||
response["usage"],
|
||||
[str(choice["finish_reason"]) for choice in response["choices"]],
|
||||
num_tokens,
|
||||
)
|
||||
|
||||
# TODO strict <INFO> check, only in the beginning of the line
|
||||
# if "<INFO>" in output_messages[0].content:
|
||||
|
@ -1,132 +0,0 @@
|
||||
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
|
||||
# Licensed under the Apache License, Version 2.0 (the “License”);
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an “AS IS” BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from colorama import Fore
|
||||
|
||||
from camel.agents import BaseToolAgent, ChatAgent, HuggingFaceToolAgent
|
||||
from camel.messages import ChatMessage, SystemMessage
|
||||
from camel.typing import ModelType
|
||||
from camel.utils import print_text_animated
|
||||
|
||||
|
||||
class EmbodiedAgent(ChatAgent):
|
||||
r"""Class for managing conversations of CAMEL Embodied Agents.
|
||||
|
||||
Args:
|
||||
system_message (SystemMessage): The system message for the chat agent.
|
||||
model (ModelType, optional): The LLM model to use for generating
|
||||
responses. (default :obj:`ModelType.GPT_4_TURBO`)
|
||||
model_config (Any, optional): Configuration options for the LLM model.
|
||||
(default: :obj:`None`)
|
||||
message_window_size (int, optional): The maximum number of previous
|
||||
messages to include in the context window. If `None`, no windowing
|
||||
is performed. (default: :obj:`None`)
|
||||
action_space (List[Any], optional): The action space for the embodied
|
||||
agent. (default: :obj:`None`)
|
||||
verbose (bool, optional): Whether to print the critic's messages.
|
||||
logger_color (Any): The color of the logger displayed to the user.
|
||||
(default: :obj:`Fore.MAGENTA`)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
system_message: SystemMessage,
|
||||
model: ModelType = ModelType.GPT_4_TURBO,
|
||||
model_config: Optional[Any] = None,
|
||||
message_window_size: Optional[int] = None,
|
||||
action_space: Optional[List[BaseToolAgent]] = None,
|
||||
verbose: bool = False,
|
||||
logger_color: Any = Fore.MAGENTA,
|
||||
) -> None:
|
||||
default_action_space = [
|
||||
HuggingFaceToolAgent('hugging_face_tool_agent', model=model.value),
|
||||
]
|
||||
self.action_space = action_space or default_action_space
|
||||
action_space_prompt = self.get_action_space_prompt()
|
||||
system_message.content = system_message.content.format(
|
||||
action_space=action_space_prompt)
|
||||
self.verbose = verbose
|
||||
self.logger_color = logger_color
|
||||
super().__init__(
|
||||
system_message=system_message,
|
||||
model=model,
|
||||
model_config=model_config,
|
||||
message_window_size=message_window_size,
|
||||
)
|
||||
|
||||
def get_action_space_prompt(self) -> str:
|
||||
r"""Returns the action space prompt.
|
||||
|
||||
Returns:
|
||||
str: The action space prompt.
|
||||
"""
|
||||
return "\n".join([
|
||||
f"*** {action.name} ***:\n {action.description}"
|
||||
for action in self.action_space
|
||||
])
|
||||
|
||||
def step(
|
||||
self,
|
||||
input_message: ChatMessage,
|
||||
) -> Tuple[ChatMessage, bool, Dict[str, Any]]:
|
||||
r"""Performs a step in the conversation.
|
||||
|
||||
Args:
|
||||
input_message (ChatMessage): The input message.
|
||||
|
||||
Returns:
|
||||
Tuple[ChatMessage, bool, Dict[str, Any]]: A tuple
|
||||
containing the output messages, termination status, and
|
||||
additional information.
|
||||
"""
|
||||
response = super().step(input_message)
|
||||
|
||||
if response.msgs is None or len(response.msgs) == 0:
|
||||
raise RuntimeError("Got None output messages.")
|
||||
if response.terminated:
|
||||
raise RuntimeError(f"{self.__class__.__name__} step failed.")
|
||||
|
||||
# NOTE: Only single output messages are supported
|
||||
explanations, codes = response.msg.extract_text_and_code_prompts()
|
||||
|
||||
if self.verbose:
|
||||
for explanation, code in zip(explanations, codes):
|
||||
print_text_animated(self.logger_color +
|
||||
f"> Explanation:\n{explanation}")
|
||||
print_text_animated(self.logger_color + f"> Code:\n{code}")
|
||||
|
||||
if len(explanations) > len(codes):
|
||||
print_text_animated(self.logger_color +
|
||||
f"> Explanation:\n{explanations}")
|
||||
|
||||
content = response.msg.content
|
||||
|
||||
if codes is not None:
|
||||
content = "\n> Executed Results:"
|
||||
global_vars = {action.name: action for action in self.action_space}
|
||||
for code in codes:
|
||||
executed_outputs = code.execute(global_vars)
|
||||
content += (
|
||||
f"- Python standard output:\n{executed_outputs[0]}\n"
|
||||
f"- Local variables:\n{executed_outputs[1]}\n")
|
||||
content += "*" * 50 + "\n"
|
||||
|
||||
# TODO: Handle errors
|
||||
content = input_message.content + (Fore.RESET +
|
||||
f"\n> Embodied Actions:\n{content}")
|
||||
message = ChatMessage(input_message.role_name, input_message.role_type,
|
||||
input_message.meta_dict, input_message.role,
|
||||
content)
|
||||
return message, response.terminated, response.info
|
@ -24,8 +24,13 @@ from camel.messages import (
|
||||
from camel.prompts import CodePrompt, TextPrompt
|
||||
from camel.typing import ModelType, RoleType
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall
|
||||
from openai.types.chat.chat_completion_message import FunctionCall
|
||||
try:
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall
|
||||
from openai.types.chat.chat_completion_message import FunctionCall
|
||||
|
||||
openai_new_api = True # new openai api version
|
||||
except ImportError:
|
||||
openai_new_api = False # old openai api version
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -47,8 +52,9 @@ class BaseMessage:
|
||||
meta_dict: Optional[Dict[str, str]]
|
||||
role: str
|
||||
content: str
|
||||
function_call: Optional[FunctionCall] = None
|
||||
tool_calls: Optional[ChatCompletionMessageToolCall] = None
|
||||
if openai_new_api:
|
||||
function_call: Optional[FunctionCall] = None
|
||||
tool_calls: Optional[ChatCompletionMessageToolCall] = None
|
||||
|
||||
def __getattribute__(self, name: str) -> Any:
|
||||
r"""Get attribute override to delegate string methods to the
|
||||
|
@ -17,8 +17,14 @@ from typing import Dict, Optional
|
||||
from camel.messages import BaseMessage
|
||||
from camel.typing import RoleType
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall
|
||||
from openai.types.chat.chat_completion_message import FunctionCall
|
||||
try:
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall
|
||||
from openai.types.chat.chat_completion_message import FunctionCall
|
||||
|
||||
openai_new_api = True # new openai api version
|
||||
except ImportError:
|
||||
openai_new_api = False # old openai api version
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatMessage(BaseMessage):
|
||||
@ -38,8 +44,9 @@ class ChatMessage(BaseMessage):
|
||||
meta_dict: Optional[Dict[str, str]]
|
||||
role: str
|
||||
content: str = ""
|
||||
function_call: Optional[FunctionCall] = None
|
||||
tool_calls: Optional[ChatCompletionMessageToolCall] = None
|
||||
if openai_new_api:
|
||||
function_call: Optional[FunctionCall] = None
|
||||
tool_calls: Optional[ChatCompletionMessageToolCall] = None
|
||||
|
||||
def set_user_role_at_backend(self: BaseMessage):
|
||||
return self.__class__(
|
||||
|
@ -21,7 +21,20 @@ from camel.typing import ModelType
|
||||
from chatdev.statistics import prompt_cost
|
||||
from chatdev.utils import log_visualize
|
||||
|
||||
from openai.types.chat import ChatCompletion
|
||||
try:
|
||||
from openai.types.chat import ChatCompletion
|
||||
|
||||
openai_new_api = True # new openai api version
|
||||
except ImportError:
|
||||
openai_new_api = False # old openai api version
|
||||
|
||||
import os
|
||||
|
||||
OPENAI_API_KEY = os.environ['OPENAI_API_KEY']
|
||||
if 'BASE_URL' in os.environ:
|
||||
BASE_URL = os.environ['BASE_URL']
|
||||
else:
|
||||
BASE_URL = None
|
||||
|
||||
|
||||
class ModelBackend(ABC):
|
||||
@ -29,7 +42,7 @@ class ModelBackend(ABC):
|
||||
May be OpenAI API, a local LLM, a stub for unit tests, etc."""
|
||||
|
||||
@abstractmethod
|
||||
def run(self, *args, **kwargs) -> ChatCompletion:
|
||||
def run(self, *args, **kwargs):
|
||||
r"""Runs the query to the backend model.
|
||||
|
||||
Raises:
|
||||
@ -49,47 +62,87 @@ class OpenAIModel(ModelBackend):
|
||||
super().__init__()
|
||||
self.model_type = model_type
|
||||
self.model_config_dict = model_config_dict
|
||||
|
||||
def run(self, *args, **kwargs) -> Dict[str, Any]:
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
string = "\n".join([message["content"] for message in kwargs["messages"]])
|
||||
encoding = tiktoken.encoding_for_model(self.model_type.value)
|
||||
num_prompt_tokens = len(encoding.encode(string))
|
||||
gap_between_send_receive = 15 * len(kwargs["messages"])
|
||||
num_prompt_tokens += gap_between_send_receive
|
||||
|
||||
num_max_token_map = {
|
||||
"gpt-3.5-turbo": 4096,
|
||||
"gpt-3.5-turbo-16k": 16384,
|
||||
"gpt-3.5-turbo-0613": 4096,
|
||||
"gpt-3.5-turbo-16k-0613": 16384,
|
||||
"gpt-4": 8192,
|
||||
"gpt-4-0613": 8192,
|
||||
"gpt-4-32k": 32768,
|
||||
"gpt-4-1106-preview": 4096,
|
||||
"gpt-4-1106-vision-preview": 4096,
|
||||
}
|
||||
num_max_token = num_max_token_map[self.model_type.value]
|
||||
num_max_completion_tokens = num_max_token - num_prompt_tokens
|
||||
self.model_config_dict['max_tokens'] = num_max_completion_tokens
|
||||
if openai_new_api:
|
||||
# Experimental, add base_url
|
||||
if BASE_URL:
|
||||
client = openai.OpenAI(
|
||||
api_key=OPENAI_API_KEY,
|
||||
base_url=BASE_URL,
|
||||
)
|
||||
else:
|
||||
client = openai.OpenAI(
|
||||
api_key=OPENAI_API_KEY
|
||||
)
|
||||
|
||||
try:
|
||||
response = openai.chat.completions.create(*args, **kwargs, model=self.model_type.value, **self.model_config_dict)
|
||||
except AttributeError:
|
||||
response = openai.chat.completions.create(*args, **kwargs, model=self.model_type.value, **self.model_config_dict)
|
||||
num_max_token_map = {
|
||||
"gpt-3.5-turbo": 4096,
|
||||
"gpt-3.5-turbo-16k": 16384,
|
||||
"gpt-3.5-turbo-0613": 4096,
|
||||
"gpt-3.5-turbo-16k-0613": 16384,
|
||||
"gpt-4": 8192,
|
||||
"gpt-4-0613": 8192,
|
||||
"gpt-4-32k": 32768,
|
||||
"gpt-4-1106-preview": 4096,
|
||||
"gpt-4-1106-vision-preview": 4096,
|
||||
}
|
||||
num_max_token = num_max_token_map[self.model_type.value]
|
||||
num_max_completion_tokens = num_max_token - num_prompt_tokens
|
||||
self.model_config_dict['max_tokens'] = num_max_completion_tokens
|
||||
|
||||
cost = prompt_cost(
|
||||
self.model_type.value,
|
||||
num_prompt_tokens=response.usage.prompt_tokens,
|
||||
response = client.chat.completions.create(*args, **kwargs, model=self.model_type.value,
|
||||
**self.model_config_dict)
|
||||
|
||||
cost = prompt_cost(
|
||||
self.model_type.value,
|
||||
num_prompt_tokens=response.usage.prompt_tokens,
|
||||
num_completion_tokens=response.usage.completion_tokens
|
||||
)
|
||||
)
|
||||
|
||||
log_visualize(
|
||||
"**[OpenAI_Usage_Info Receive]**\nprompt_tokens: {}\ncompletion_tokens: {}\ntotal_tokens: {}\ncost: ${:.6f}\n".format(
|
||||
response.usage.prompt_tokens, response.usage.completion_tokens,
|
||||
response.usage.total_tokens, cost))
|
||||
if not isinstance(response, ChatCompletion):
|
||||
raise RuntimeError("Unexpected return from OpenAI API")
|
||||
return response
|
||||
log_visualize(
|
||||
"**[OpenAI_Usage_Info Receive]**\nprompt_tokens: {}\ncompletion_tokens: {}\ntotal_tokens: {}\ncost: ${:.6f}\n".format(
|
||||
response.usage.prompt_tokens, response.usage.completion_tokens,
|
||||
response.usage.total_tokens, cost))
|
||||
if not isinstance(response, ChatCompletion):
|
||||
raise RuntimeError("Unexpected return from OpenAI API")
|
||||
return response
|
||||
else:
|
||||
num_max_token_map = {
|
||||
"gpt-3.5-turbo": 4096,
|
||||
"gpt-3.5-turbo-16k": 16384,
|
||||
"gpt-3.5-turbo-0613": 4096,
|
||||
"gpt-3.5-turbo-16k-0613": 16384,
|
||||
"gpt-4": 8192,
|
||||
"gpt-4-0613": 8192,
|
||||
"gpt-4-32k": 32768,
|
||||
}
|
||||
num_max_token = num_max_token_map[self.model_type.value]
|
||||
num_max_completion_tokens = num_max_token - num_prompt_tokens
|
||||
self.model_config_dict['max_tokens'] = num_max_completion_tokens
|
||||
|
||||
response = openai.ChatCompletion.create(*args, **kwargs, model=self.model_type.value,
|
||||
**self.model_config_dict)
|
||||
|
||||
cost = prompt_cost(
|
||||
self.model_type.value,
|
||||
num_prompt_tokens=response["usage"]["prompt_tokens"],
|
||||
num_completion_tokens=response["usage"]["completion_tokens"]
|
||||
)
|
||||
|
||||
log_visualize(
|
||||
"**[OpenAI_Usage_Info Receive]**\nprompt_tokens: {}\ncompletion_tokens: {}\ntotal_tokens: {}\ncost: ${:.6f}\n".format(
|
||||
response["usage"]["prompt_tokens"], response["usage"]["completion_tokens"],
|
||||
response["usage"]["total_tokens"], cost))
|
||||
if not isinstance(response, Dict):
|
||||
raise RuntimeError("Unexpected return from OpenAI API")
|
||||
return response
|
||||
|
||||
|
||||
class StubModel(ModelBackend):
|
||||
@ -123,7 +176,12 @@ class ModelFactory:
|
||||
default_model_type = ModelType.GPT_3_5_TURBO
|
||||
|
||||
if model_type in {
|
||||
ModelType.GPT_3_5_TURBO, ModelType.GPT_4, ModelType.GPT_4_32k, ModelType.GPT_4_TURBO,
|
||||
ModelType.GPT_3_5_TURBO,
|
||||
ModelType.GPT_3_5_TURBO_NEW,
|
||||
ModelType.GPT_4,
|
||||
ModelType.GPT_4_32k,
|
||||
ModelType.GPT_4_TURBO,
|
||||
ModelType.GPT_4_TURBO_V,
|
||||
None
|
||||
}:
|
||||
model_class = OpenAIModel
|
||||
|
@ -45,9 +45,12 @@ class RoleType(Enum):
|
||||
|
||||
class ModelType(Enum):
|
||||
GPT_3_5_TURBO = "gpt-3.5-turbo-16k-0613"
|
||||
GPT_3_5_TURBO_NEW = "gpt-3.5-turbo-16k"
|
||||
GPT_4 = "gpt-4"
|
||||
GPT_4_32k = "gpt-4-32k"
|
||||
GPT_4_TURBO = "gpt-4-1106-preview"
|
||||
GPT_4_TURBO_V = "gpt-4-1106-vision-preview"
|
||||
|
||||
STUB = "stub"
|
||||
|
||||
@property
|
||||
|
@ -83,7 +83,12 @@ def num_tokens_from_messages(
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
if model in {
|
||||
ModelType.GPT_3_5_TURBO, ModelType.GPT_4, ModelType.GPT_4_32k, ModelType.GPT_4_TURBO,
|
||||
ModelType.GPT_3_5_TURBO,
|
||||
ModelType.GPT_3_5_TURBO_NEW,
|
||||
ModelType.GPT_4,
|
||||
ModelType.GPT_4_32k,
|
||||
ModelType.GPT_4_TURBO,
|
||||
ModelType.GPT_4_TURBO_V,
|
||||
ModelType.STUB
|
||||
}:
|
||||
return count_tokens_openai_chat_models(messages, encoding)
|
||||
@ -109,6 +114,8 @@ def get_model_token_limit(model: ModelType) -> int:
|
||||
"""
|
||||
if model == ModelType.GPT_3_5_TURBO:
|
||||
return 16384
|
||||
elif model == ModelType.GPT_3_5_TURBO_NEW:
|
||||
return 16384
|
||||
elif model == ModelType.GPT_4:
|
||||
return 8192
|
||||
elif model == ModelType.GPT_4_32k:
|
||||
|
@ -14,6 +14,14 @@ from chatdev.documents import Documents
|
||||
from chatdev.roster import Roster
|
||||
from chatdev.utils import log_visualize
|
||||
|
||||
try:
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall
|
||||
from openai.types.chat.chat_completion_message import FunctionCall
|
||||
|
||||
openai_new_api = True # new openai api version
|
||||
except ImportError:
|
||||
openai_new_api = False # old openai api version
|
||||
|
||||
|
||||
class ChatEnvConfig:
|
||||
def __init__(self, clear_structure,
|
||||
@ -216,12 +224,20 @@ class ChatEnv:
|
||||
if desc.endswith(".png"):
|
||||
desc = desc.replace(".png", "")
|
||||
print("{}: {}".format(filename, desc))
|
||||
response = openai.images.generate(
|
||||
prompt=desc,
|
||||
n=1,
|
||||
size="256x256"
|
||||
)
|
||||
image_url = response.data[0].url
|
||||
if openai_new_api:
|
||||
response = openai.images.generate(
|
||||
prompt=desc,
|
||||
n=1,
|
||||
size="256x256"
|
||||
)
|
||||
image_url = response.data[0].url
|
||||
else:
|
||||
response = openai.Image.create(
|
||||
prompt=desc,
|
||||
n=1,
|
||||
size="256x256"
|
||||
)
|
||||
image_url = response['data'][0]['url']
|
||||
download(image_url, filename)
|
||||
|
||||
def get_proposed_images_from_message(self, messages):
|
||||
@ -258,12 +274,22 @@ class ChatEnv:
|
||||
if desc.endswith(".png"):
|
||||
desc = desc.replace(".png", "")
|
||||
print("{}: {}".format(filename, desc))
|
||||
response = openai.Image.create(
|
||||
prompt=desc,
|
||||
n=1,
|
||||
size="256x256"
|
||||
)
|
||||
image_url = response.data[0].url
|
||||
|
||||
if openai_new_api:
|
||||
response = openai.images.generate(
|
||||
prompt=desc,
|
||||
n=1,
|
||||
size="256x256"
|
||||
)
|
||||
image_url = response.data[0].url
|
||||
else:
|
||||
response = openai.Image.create(
|
||||
prompt=desc,
|
||||
n=1,
|
||||
size="256x256"
|
||||
)
|
||||
image_url = response['data'][0]['url']
|
||||
|
||||
download(image_url, filename)
|
||||
|
||||
return images
|
||||
|
23
run.py
23
run.py
@ -23,6 +23,18 @@ sys.path.append(root)
|
||||
|
||||
from chatdev.chat_chain import ChatChain
|
||||
|
||||
try:
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall
|
||||
from openai.types.chat.chat_completion_message import FunctionCall
|
||||
|
||||
openai_new_api = True # new openai api version
|
||||
except ImportError:
|
||||
openai_new_api = False # old openai api version
|
||||
print(
|
||||
"Warning: Your OpenAI version is outdated. \n "
|
||||
"Please update as specified in requirement.txt. \n "
|
||||
"The old API interface is deprecated and will no longer be supported.")
|
||||
|
||||
|
||||
def get_config(company):
|
||||
"""
|
||||
@ -78,8 +90,15 @@ args = parser.parse_args()
|
||||
# Init ChatChain
|
||||
# ----------------------------------------
|
||||
config_path, config_phase_path, config_role_path = get_config(args.config)
|
||||
args2type = {'GPT_3_5_TURBO': ModelType.GPT_3_5_TURBO, 'GPT_4': ModelType.GPT_4, \
|
||||
'GPT_4_32K': ModelType.GPT_4_32k, 'GPT_4_TURBO': ModelType.GPT_4_TURBO}
|
||||
args2type = {'GPT_3_5_TURBO': ModelType.GPT_3_5_TURBO,
|
||||
'GPT_4': ModelType.GPT_4,
|
||||
'GPT_4_32K': ModelType.GPT_4_32k,
|
||||
'GPT_4_TURBO': ModelType.GPT_4_TURBO,
|
||||
'GPT_4_TURBO_V': ModelType.GPT_4_TURBO_V
|
||||
}
|
||||
if openai_new_api:
|
||||
args2type['GPT_3_5_TURBO'] = ModelType.GPT_3_5_TURBO_NEW
|
||||
|
||||
chat_chain = ChatChain(config_path=config_path,
|
||||
config_phase_path=config_phase_path,
|
||||
config_role_path=config_role_path,
|
||||
|
Loading…
Reference in New Issue
Block a user