Mirror of https://github.com/xtekky/gpt4free.git (synced 2024-11-23 09:10:13 +03:00)
Add WebSocket support in OpenaiChat

commit ac86e576d2
parent 84812b9632
@@ -4,6 +4,8 @@ import asyncio
 import uuid
 import json
 import os
+import base64
+from aiohttp import ClientWebSocketResponse

 try:
     from py_arkose_generator.arkose import get_values_for_request
@@ -22,7 +24,7 @@ except ImportError:
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from ..helper import get_cookies
 from ...webdriver import get_browser
-from ...typing import AsyncResult, Messages, Cookies, ImageType, Union
+from ...typing import AsyncResult, Messages, Cookies, ImageType, Union, AsyncIterator
 from ...requests import get_args_from_browser
 from ...requests.aiohttp import StreamSession
 from ...image import to_image, to_bytes, ImageResponse, ImageRequest
@@ -38,10 +40,14 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
     supports_gpt_35_turbo = True
     supports_gpt_4 = True
     supports_message_history = True
+    supports_system_message = True
     default_model = None
     models = ["gpt-3.5-turbo", "gpt-4", "gpt-4-gizmo"]
-    model_aliases = {"text-davinci-002-render-sha": "gpt-3.5-turbo"}
-    _args: dict = None
+    model_aliases = {"text-davinci-002-render-sha": "gpt-3.5-turbo", "": "gpt-3.5-turbo"}
+    _api_key: str = None
+    _headers: dict = None
+    _cookies: Cookies = None
+    _last_message: int = 0

     @classmethod
     async def create(
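Note: the single cached _args dict gives way to explicit class-level caches that several helpers below mutate in place. A minimal sketch of the pattern, with invented names, not the provider's real code:

    class CachedState:
        # Class attributes serve as a process-wide cache, like the
        # _api_key/_headers/_cookies attributes introduced above.
        _api_key: str = None
        _headers: dict = None

        @classmethod
        def set_api_key(cls, api_key: str) -> None:
            # Simplified analogue of the _set_api_key helper added
            # later in this commit.
            cls._api_key = api_key
            cls._headers = {"Authorization": f"Bearer {api_key}"}

    CachedState.set_api_key("token")   # every later call sees the cached value
    assert CachedState._api_key == "token"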
@@ -299,6 +305,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
         conversation_id: str = None,
         parent_id: str = None,
         image: ImageType = None,
+        image_name: str = None,
         response_fields: bool = False,
         **kwargs
     ) -> AsyncResult:
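Note: the new image_name parameter travels with image into the rewritten upload_image call in the next hunk, where it supplies the filename for the upload. A hypothetical caller; the g4f entry point and kwargs forwarding are assumed, only image/image_name come from this diff:

    import asyncio
    import g4f
    from g4f.Provider import OpenaiChat

    async def describe(path: str) -> str:
        # image is the raw file content; image_name (new in this commit)
        # names the uploaded file.
        with open(path, "rb") as f:
            return await g4f.ChatCompletion.create_async(
                model="gpt-4",
                messages=[{"role": "user", "content": "Describe this image."}],
                provider=OpenaiChat,
                image=f.read(),
                image_name=path,
            )

    asyncio.run(describe("cat.png"))  # "cat.png" is a placeholder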
@@ -332,67 +339,64 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
         if not parent_id:
             parent_id = str(uuid.uuid4())

-        # Read api_key from args
+        # Read api_key from arguments
         api_key = kwargs["access_token"] if "access_token" in kwargs else api_key
-        # If no cached args
-        if cls._args is None:
-            if api_key is None:
-                # Read api_key from cookies
-                cookies = get_cookies("chat.openai.com", False) if cookies is None else cookies
-                api_key = cookies["access_token"] if "access_token" in cookies else api_key
-            cls._args = cls._create_request_args(cookies)
-        else:
-            # Read api_key from cache
-            api_key = cls._args["headers"]["Authorization"] if "Authorization" in cls._args["headers"] else None

         async with StreamSession(
             proxies={"https": proxy},
             impersonate="chrome",
             timeout=timeout
         ) as session:
-            # Read api_key from session cookies
+            # Read api_key and cookies from cache / browser config
+            if cls._headers is None:
+                if api_key is None:
+                    # Read api_key from cookies
+                    cookies = get_cookies("chat.openai.com", False) if cookies is None else cookies
+                    api_key = cookies["access_token"] if "access_token" in cookies else api_key
+                cls._create_request_args(cookies)
+            else:
+                api_key = cls._api_key if api_key is None else api_key
+            # Read api_key with session cookies
             if api_key is None and cookies:
-                api_key = await cls.fetch_access_token(session, cls._args["headers"])
+                api_key = await cls.fetch_access_token(session, cls._headers)
             # Load default model
-            if cls.default_model is None:
+            if cls.default_model is None and api_key is not None:
                 try:
-                    if cookies and not model and api_key is not None:
-                        cls._args["headers"]["Authorization"] = api_key
-                        cls.default_model = cls.get_model(await cls.get_default_model(session, cls._args["headers"]))
-                    elif api_key:
-                        cls.default_model = cls.get_model(model or "gpt-3.5-turbo")
+                    if not model:
+                        cls._set_api_key(api_key)
+                        cls.default_model = cls.get_model(await cls.get_default_model(session, cls._headers))
+                    else:
+                        cls.default_model = cls.get_model(model)
                 except Exception as e:
                     if debug.logging:
                         print("OpenaiChat: Load default_model failed")
                         print(f"{e.__class__.__name__}: {e}")
-            # Browse api_key and update default model
+            # Browse api_key and default model
            if api_key is None or cls.default_model is None:
                 login_url = os.environ.get("G4F_LOGIN_URL")
                 if login_url:
                     yield f"Please login: [ChatGPT]({login_url})\n\n"
                 try:
-                    cls._args = cls.browse_access_token(proxy)
+                    cls.browse_access_token(proxy)
                 except MissingRequirementsError:
                     raise MissingAuthError(f'Missing "access_token". Add a "api_key" please')
-                cls.default_model = cls.get_model(await cls.get_default_model(session, cls._args["headers"]))
+                cls.default_model = cls.get_model(await cls.get_default_model(session, cls._headers))
             else:
-                cls._args["headers"]["Authorization"] = api_key
+                cls._set_api_key(api_key)

             try:
-                image_response = await cls.upload_image(
-                    session,
-                    cls._args["headers"],
-                    image,
-                    kwargs.get("image_name")
-                ) if image else None
+                image_request = await cls.upload_image(session, cls._headers, image, image_name) if image else None
             except Exception as e:
                 yield e
                 if debug.logging:
                     print("OpenaiChat: Upload image failed")
                     print(f"{e.__class__.__name__}: {e}")

-            end_turn = EndTurn()
-            model = cls.get_model(model)
-            model = "text-davinci-002-render-sha" if model == "gpt-3.5-turbo" else model
-            while not end_turn.is_end:
+            model = cls.get_model(model).replace("gpt-3.5-turbo", "text-davinci-002-render-sha")
+            fields = ResponseFields()
+            while fields.finish_reason is None:
                 arkose_token = await cls.get_arkose_token(session)
+                conversation_id = conversation_id if fields.conversation_id is None else fields.conversation_id
+                parent_id = parent_id if fields.message_id is None else fields.message_id
                 data = {
                     "action": action,
                     "arkose_token": arkose_token,
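Note: after this rewrite api_key is resolved in a fixed order: the access_token kwarg, then browser cookies, then the cached cls._api_key, and only as a last resort an interactive browser login via browse_access_token. Supplying the token explicitly skips the browser; a sketch assuming the create entry point shown above (token value is a placeholder):

    from g4f.Provider import OpenaiChat

    async def chat() -> None:
        # access_token is the ChatGPT web session token; value assumed.
        async for chunk in OpenaiChat.create(
            messages=[{"role": "user", "content": "Hello"}],
            access_token="eyJhbGciOiJSUzI1NiIs...",
        ):
            print(chunk, end="", flush=True)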
@@ -405,8 +409,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
                     "history_and_training_disabled": history_disabled and not auto_continue,
                 }
                 if action != "continue":
-                    messages = messages if not conversation_id else [messages[-1]]
-                    data["messages"] = cls.create_messages(messages, image_response)
+                    messages = messages if conversation_id is None else [messages[-1]]
+                    data["messages"] = cls.create_messages(messages, image_request)

                 async with session.post(
                     f"{cls.url}/backend-api/conversation",
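Note: the trimming condition now treats only None as "no conversation yet", so a falsy-but-present id no longer resends the whole history. A tiny illustration with made-up values:

    messages = [{"role": "user", "content": "Hi"}, {"role": "user", "content": "Continue"}]
    conversation_id = "89c0..."  # hypothetical id
    # Same expression as in the hunk above: resend history only for new conversations.
    messages = messages if conversation_id is None else [messages[-1]]
    assert messages == [{"role": "user", "content": "Continue"}]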
@@ -414,63 +418,88 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
                     headers={
                         "Accept": "text/event-stream",
                         "OpenAI-Sentinel-Arkose-Token": arkose_token,
-                        **cls._args["headers"]
+                        **cls._headers
                     }
                 ) as response:
                     cls._update_request_args(session)
                     if not response.ok:
-                        message = f"{await response.text()} headers:\n{json.dumps(cls._args['headers'], indent=4)}"
-                        raise RuntimeError(f"Response {response.status}: {message}")
-                    last_message: int = 0
-                    async for line in response.iter_lines():
-                        if not line.startswith(b"data: "):
-                            continue
-                        elif line.startswith(b"data: [DONE]"):
-                            break
-                        try:
-                            line = json.loads(line[6:])
-                        except:
-                            continue
-                        if "message" not in line:
-                            continue
-                        if "error" in line and line["error"]:
-                            raise RuntimeError(line["error"])
-                        if "message_type" not in line["message"]["metadata"]:
-                            continue
-                        try:
-                            image_response = await cls.get_generated_image(session, cls._args["headers"], line)
-                            if image_response is not None:
-                                yield image_response
-                        except Exception as e:
-                            yield e
-                        if line["message"]["author"]["role"] != "assistant":
-                            continue
-                        if line["message"]["content"]["content_type"] != "text":
-                            continue
-                        if line["message"]["metadata"]["message_type"] not in ("next", "continue", "variant"):
-                            continue
-                        conversation_id = line["conversation_id"]
-                        parent_id = line["message"]["id"]
+                        raise RuntimeError(f"Response {response.status}: {await response.text()}")
+                    async for chunk in cls.iter_messages_chunk(response.iter_lines(), session, fields):
                         if response_fields:
                             response_fields = False
-                            yield ResponseFields(conversation_id, parent_id, end_turn)
-                        if "parts" in line["message"]["content"]:
-                            new_message = line["message"]["content"]["parts"][0]
-                            if len(new_message) > last_message:
-                                yield new_message[last_message:]
-                            last_message = len(new_message)
-                        if "finish_details" in line["message"]["metadata"]:
-                            if line["message"]["metadata"]["finish_details"]["type"] == "stop":
-                                end_turn.end()
+                            yield fields
+                        yield chunk
                 if not auto_continue:
                     break
                 action = "continue"
                 await asyncio.sleep(5)
             if history_disabled and auto_continue:
-                await cls.delete_conversation(session, cls._args["headers"], conversation_id)
+                await cls.delete_conversation(session, cls._headers, conversation_id)

+    @staticmethod
+    async def iter_messages_ws(ws: ClientWebSocketResponse) -> AsyncIterator:
+        while True:
+            yield base64.b64decode((await ws.receive_json())["body"])
+
     @classmethod
-    def browse_access_token(cls, proxy: str = None, timeout: int = 1200) -> tuple[str, dict]:
+    async def iter_messages_chunk(cls, messages: AsyncIterator, session: StreamSession, fields: ResponseFields) -> AsyncIterator:
+        last_message: int = 0
+        async for message in messages:
+            if message.startswith(b'{"wss_url":'):
+                async with session.ws_connect(json.loads(message)["wss_url"]) as ws:
+                    async for chunk in cls.iter_messages_chunk(cls.iter_messages_ws(ws), session, fields):
+                        yield chunk
+                    break
+            async for chunk in cls.iter_messages_line(session, message, fields):
+                if fields.finish_reason is not None:
+                    break
+                elif isinstance(chunk, str):
+                    if len(chunk) > last_message:
+                        yield chunk[last_message:]
+                    last_message = len(chunk)
+                else:
+                    yield chunk
+            if fields.finish_reason is not None:
+                break
+
+    @classmethod
+    async def iter_messages_line(cls, session: StreamSession, line: bytes, fields: ResponseFields) -> AsyncIterator:
+        if not line.startswith(b"data: "):
+            return
+        elif line.startswith(b"data: [DONE]"):
+            return
+        try:
+            line = json.loads(line[6:])
+        except:
+            return
+        if "message" not in line:
+            return
+        if "error" in line and line["error"]:
+            raise RuntimeError(line["error"])
+        if "message_type" not in line["message"]["metadata"]:
+            return
+        try:
+            image_response = await cls.get_generated_image(session, cls._headers, line)
+            if image_response is not None:
+                yield image_response
+        except Exception as e:
+            yield e
+        if line["message"]["author"]["role"] != "assistant":
+            return
+        if line["message"]["content"]["content_type"] != "text":
+            return
+        if line["message"]["metadata"]["message_type"] not in ("next", "continue", "variant"):
+            return
+        if fields.conversation_id is None:
+            fields.conversation_id = line["conversation_id"]
+            fields.message_id = line["message"]["id"]
+        if "parts" in line["message"]["content"]:
+            yield line["message"]["content"]["parts"][0]
+        if "finish_details" in line["message"]["metadata"]:
+            fields.finish_reason = line["message"]["metadata"]["finish_details"]["type"]
+
+    @classmethod
+    def browse_access_token(cls, proxy: str = None, timeout: int = 1200) -> None:
         """
         Browse to obtain an access token.

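Note: this hunk is the core of the WebSocket support. The conversation POST may now answer with a single JSON line such as {"wss_url": "wss://..."} instead of an SSE stream; iter_messages_chunk then connects to that URL and re-enters itself over iter_messages_ws, where every frame is a JSON object whose base64-encoded "body" holds one ordinary "data: ..." line. A self-contained sketch of the same handshake using plain aiohttp (field names as in the diff; the concrete URL and surrounding plumbing are assumed):

    import base64
    import json

    import aiohttp

    async def stream_over_ws(first_line: bytes) -> None:
        # first_line is the body the conversation POST returned instead
        # of an SSE stream, e.g. b'{"wss_url": "wss://..."}' (assumed).
        wss_url = json.loads(first_line)["wss_url"]
        async with aiohttp.ClientSession() as client:
            async with client.ws_connect(wss_url) as ws:
                while True:
                    frame = await ws.receive_json()
                    # Each frame wraps one base64-encoded "data: ..." SSE line.
                    line = base64.b64decode(frame["body"])
                    if line.startswith(b"data: [DONE]"):
                        break
                    print(line.decode())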
@@ -493,9 +522,10 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
                 "return accessToken;"
             )
             args = get_args_from_browser(f"{cls.url}/", driver, do_bypass_cloudflare=False)
-            args["headers"]["Authorization"] = f"Bearer {access_token}"
-            args["headers"]["Cookie"] = cls._format_cookies(args["cookies"])
-            return args
+            cls._headers = args["headers"]
+            cls._cookies = args["cookies"]
+            cls._update_cookie_header()
+            cls._set_api_key(access_token)
         finally:
             driver.close()

@@ -546,16 +576,24 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):

     @classmethod
     def _create_request_args(cls, cookies: Union[Cookies, None]):
-        return {
-            "headers": {} if cookies is None else {"Cookie": cls._format_cookies(cookies)},
-            "cookies": {} if cookies is None else cookies
-        }
+        cls._headers = {}
+        cls._cookies = {} if cookies is None else cookies
+        cls._update_cookie_header()

     @classmethod
     def _update_request_args(cls, session: StreamSession):
         for c in session.cookie_jar if hasattr(session, "cookie_jar") else session.cookies.jar:
-            cls._args["cookies"][c.name if hasattr(c, "name") else c.key] = c.value
-        cls._args["headers"]["Cookie"] = cls._format_cookies(cls._args["cookies"])
+            cls._cookies[c.name if hasattr(c, "name") else c.key] = c.value
+        cls._update_cookie_header()

+    @classmethod
+    def _set_api_key(cls, api_key: str):
+        cls._api_key = api_key
+        cls._headers["Authorization"] = f"Bearer {api_key}"
+
+    @classmethod
+    def _update_cookie_header(cls):
+        cls._headers["Cookie"] = cls._format_cookies(cls._cookies)
+
 class EndTurn:
     """
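Note: the Cookie header is now rebuilt in exactly one place, _update_cookie_header, whenever cls._cookies changes. _format_cookies itself is outside this diff; a plausible stand-in producing the conventional "name=value; ..." shape would be:

    def _format_cookies(cookies: dict) -> str:
        # Plausible implementation, assumed: the real helper is not shown
        # in this commit.
        return "; ".join(f"{name}={value}" for name, value in cookies.items())

    assert _format_cookies({"_puid": "abc", "cf_clearance": "xyz"}) == "_puid=abc; cf_clearance=xyz"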
@@ -571,10 +609,10 @@ class ResponseFields:
     """
     Class to encapsulate response fields.
     """
-    def __init__(self, conversation_id: str, message_id: str, end_turn: EndTurn):
+    def __init__(self, conversation_id: str = None, message_id: str = None, finish_reason: str = None):
         self.conversation_id = conversation_id
         self.message_id = message_id
-        self._end_turn = end_turn
+        self.finish_reason = finish_reason

 class Response():
     """
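Note: ResponseFields replaces the EndTurn flag as the loop's stop signal. It starts empty, iter_messages_line fills in conversation_id, message_id and finally finish_reason, and that last assignment is what ends the while loop in create. With response_fields=True the caller receives the object as the first chunk; a hypothetical consumer (module path assumed for the import):

    from g4f.Provider.needs_auth.OpenaiChat import OpenaiChat, ResponseFields

    async def chat_and_remember(messages: list) -> tuple:
        conversation_id = message_id = None
        text = ""
        async for chunk in OpenaiChat.create(messages=messages, response_fields=True):
            if isinstance(chunk, ResponseFields):
                # Ids needed to continue the thread later via
                # conversation_id=... / parent_id=... on the next call.
                conversation_id, message_id = chunk.conversation_id, chunk.message_id
            else:
                text += str(chunk)
        return text, conversation_id, message_id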
@@ -608,7 +646,7 @@ class Response():
         self._message = "".join(chunks)
         if not self._fields:
             raise RuntimeError("Missing response fields")
-        self.is_end = self._fields._end_turn.is_end
+        self.is_end = self._fields.end_turn

     def __aiter__(self):
         return self.generator()
@@ -270,7 +270,7 @@ class ProviderModelMixin:

     @classmethod
     def get_model(cls, model: str) -> str:
-        if not model:
+        if not model and cls.default_model is not None:
             model = cls.default_model
         elif model in cls.model_aliases:
             model = cls.model_aliases[model]
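Note: this guard works together with the new "": "gpt-3.5-turbo" alias added above: an empty model name can now be resolved through the alias table before default_model has been fetched, instead of becoming None. A self-contained illustration (a return is added to make the sketch usable):

    class Demo:
        default_model = None
        model_aliases = {"text-davinci-002-render-sha": "gpt-3.5-turbo", "": "gpt-3.5-turbo"}

        @classmethod
        def get_model(cls, model: str) -> str:
            # Same branching as the hunk above.
            if not model and cls.default_model is not None:
                model = cls.default_model
            elif model in cls.model_aliases:
                model = cls.model_aliases[model]
            return model

    assert Demo.get_model("") == "gpt-3.5-turbo"   # falls through to the alias
    assert Demo.get_model("text-davinci-002-render-sha") == "gpt-3.5-turbo"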
@@ -26,6 +26,7 @@ class BaseProvider(ABC):
     supports_gpt_35_turbo: bool = False
     supports_gpt_4: bool = False
     supports_message_history: bool = False
+    supports_system_message: bool = False
     params: str

     @classmethod