Mirror of https://github.com/xtekky/gpt4free.git, synced 2024-11-23 09:10:13 +03:00
Expire cache, Fix multiple websocket conversations in OpenaiChat
Map system messages to user messages in GeminiPro
parent eb48299195
commit cfa45e7016
@@ -26,38 +26,35 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
         stream: bool = False,
         proxy: str = None,
         api_key: str = None,
-        api_base: str = None,
-        use_auth_header: bool = True,
+        api_base: str = "https://generativelanguage.googleapis.com/v1beta",
+        use_auth_header: bool = False,
         image: ImageType = None,
         connector: BaseConnector = None,
         **kwargs
     ) -> AsyncResult:
-        model = "gemini-pro-vision" if not model and image else model
+        model = "gemini-pro-vision" if model is None and image is not None else model
         model = cls.get_model(model)
 
         if not api_key:
             raise MissingAuthError('Missing "api_key"')
 
         headers = params = None
-        if api_base and use_auth_header:
+        if use_auth_header:
             headers = {"Authorization": f"Bearer {api_key}"}
         else:
             params = {"key": api_key}
 
-        if not api_base:
-            api_base = f"https://generativelanguage.googleapis.com/v1beta"
-
         method = "streamGenerateContent" if stream else "generateContent"
         url = f"{api_base.rstrip('/')}/models/{model}:{method}"
         async with ClientSession(headers=headers, connector=get_connector(connector, proxy)) as session:
             contents = [
                 {
-                    "role": "model" if message["role"] == "assistant" else message["role"],
+                    "role": "model" if message["role"] == "assistant" else "user",
                     "parts": [{"text": message["content"]}]
                 }
                 for message in messages
             ]
-            if image:
+            if image is not None:
                 image = to_bytes(image)
                 contents[-1]["parts"].append({
                     "inline_data": {
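
For reference, a minimal standalone sketch of the role mapping in this hunk: Gemini's generateContent API only accepts the roles "user" and "model", so assistant turns become "model" and everything else, including system messages, is sent as "user". The map_messages helper below is illustrative and not part of the provider itself.

# Sketch only: map_messages is a hypothetical helper, not part of GeminiPro.
def map_messages(messages):
    # Assistant turns become "model"; every other role, including "system",
    # is forwarded as "user", matching the change above.
    return [
        {
            "role": "model" if message["role"] == "assistant" else "user",
            "parts": [{"text": message["content"]}],
        }
        for message in messages
    ]

print(map_messages([
    {"role": "system", "content": "Answer briefly."},
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi!"},
]))
# The system turn comes out with "role": "user", as the new mapping intends.
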
@@ -87,7 +84,8 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
                             lines = [b"{\n"]
                         elif chunk == b",\r\n" or chunk == b"]":
                             try:
-                                data = json.loads(b"".join(lines))
+                                data = b"".join(lines)
+                                data = json.loads(data)
                                 yield data["candidates"][0]["content"]["parts"][0]["text"]
                             except:
+                                data = data.decode() if isinstance(data, bytes) else data
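
A loose, self-contained sketch of the parsing idea in this hunk: the streaming endpoint delivers a JSON array in pieces, and each completed object is decoded on its own once a separator chunk arrives. The iter_texts helper and the sample chunks below are stand-ins for the real HTTP byte stream, not part of GeminiPro.

import json

def iter_texts(chunks):
    lines = []
    for chunk in chunks:
        if chunk == b"[{\n":
            lines = [b"{\n"]                      # start of the first object
        elif chunk == b",\r\n" or chunk == b"]":  # an object just finished
            data = b"".join(lines)
            try:
                data = json.loads(data)
                yield data["candidates"][0]["content"]["parts"][0]["text"]
            except (json.JSONDecodeError, KeyError, IndexError):
                # Surface the raw payload when it cannot be read, similar in
                # spirit to the decode added in the hunk above.
                detail = data.decode() if isinstance(data, bytes) else data
                raise RuntimeError(f"Read text failed: {detail}")
            lines = []
        else:
            lines.append(chunk)

chunks = [
    b"[{\n",
    b'"candidates": [{"content": {"parts": [{"text": "Hello"}]}}]}\n',
    b",\r\n",
    b'{\n"candidates": [{"content": {"parts": [{"text": " world"}]}}]}\n',
    b"]",
]
print(list(iter_texts(chunks)))  # ['Hello', ' world']
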
@@ -5,6 +5,7 @@ import uuid
 import json
 import os
 import base64
+import time
 from aiohttp import ClientWebSocketResponse
 
 try:
@@ -47,7 +48,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
     _api_key: str = None
     _headers: dict = None
     _cookies: Cookies = None
-    _last_message: int = 0
+    _expires: int = None
 
     @classmethod
     async def create(
@@ -348,7 +349,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
             timeout=timeout
         ) as session:
             # Read api_key and cookies from cache / browser config
-            if cls._headers is None:
+            if cls._headers is None or time.time() > cls._expires:
                 if api_key is None:
                     # Read api_key from cookies
                     cookies = get_cookies("chat.openai.com", False) if cookies is None else cookies
@@ -437,17 +438,20 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
                 await cls.delete_conversation(session, cls._headers, fields.conversation_id)
 
     @staticmethod
-    async def iter_messages_ws(ws: ClientWebSocketResponse) -> AsyncIterator:
+    async def iter_messages_ws(ws: ClientWebSocketResponse, conversation_id: str) -> AsyncIterator:
         while True:
-            yield base64.b64decode((await ws.receive_json())["body"])
+            message = await ws.receive_json()
+            if message["conversation_id"] == conversation_id:
+                yield base64.b64decode(message["body"])
 
     @classmethod
     async def iter_messages_chunk(cls, messages: AsyncIterator, session: StreamSession, fields: ResponseFields) -> AsyncIterator:
         last_message: int = 0
         async for message in messages:
             if message.startswith(b'{"wss_url":'):
-                async with session.ws_connect(json.loads(message)["wss_url"]) as ws:
-                    async for chunk in cls.iter_messages_chunk(cls.iter_messages_ws(ws), session, fields):
+                message = json.loads(message)
+                async with session.ws_connect(message["wss_url"]) as ws:
+                    async for chunk in cls.iter_messages_chunk(cls.iter_messages_ws(ws, message["conversation_id"]), session, fields):
                         yield chunk
                 break
             async for chunk in cls.iter_messages_line(session, message, fields):
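
A hedged sketch of the per-conversation filtering introduced here: only websocket messages whose conversation_id matches are decoded and yielded, so two conversations sharing one socket no longer receive each other's chunks. FakeWS and the None close signal are sketch-only stand-ins for aiohttp's ClientWebSocketResponse.

import asyncio
import base64

class FakeWS:
    # Stand-in for the real websocket; returns None once "closed".
    def __init__(self, messages):
        self._messages = list(messages)

    async def receive_json(self):
        return self._messages.pop(0) if self._messages else None

async def iter_messages_ws(ws, conversation_id: str):
    while True:
        message = await ws.receive_json()
        if message is None:  # sketch-only close convention
            return
        if message["conversation_id"] == conversation_id:
            yield base64.b64decode(message["body"])

async def main():
    encode = lambda text: base64.b64encode(text).decode()
    ws = FakeWS([
        {"conversation_id": "conv-a", "body": encode(b"chunk for a")},
        {"conversation_id": "conv-b", "body": encode(b"chunk for b")},
        {"conversation_id": "conv-a", "body": encode(b"more for a")},
    ])
    async for body in iter_messages_ws(ws, "conv-a"):
        print(body)  # only the two "conv-a" chunks are printed

asyncio.run(main())
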
@@ -589,6 +593,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
     @classmethod
     def _set_api_key(cls, api_key: str):
         cls._api_key = api_key
+        cls._expires = int(time.time()) + 60 * 60 * 4
         cls._headers["Authorization"] = f"Bearer {api_key}"
 
     @classmethod
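
Taken together with the `_expires` check in the earlier hunk, the cached api_key and headers are now refreshed once their timestamp lapses. A small illustrative sketch of that pattern follows; the CachedAuth class and its method names are hypothetical, not part of OpenaiChat.

import time

class CachedAuth:
    _api_key: str = None
    _headers: dict = None
    _expires: int = None

    @classmethod
    def set_api_key(cls, api_key: str, ttl: int = 60 * 60 * 4):
        # Cache the credentials and record when they stop being trusted;
        # four hours mirrors the value used in the commit.
        cls._api_key = api_key
        cls._headers = {"Authorization": f"Bearer {api_key}"}
        cls._expires = int(time.time()) + ttl

    @classmethod
    def needs_refresh(cls) -> bool:
        # Refresh when nothing is cached yet or the cached entry has expired.
        return cls._headers is None or time.time() > cls._expires

CachedAuth.set_api_key("sk-example", ttl=1)
print(CachedAuth.needs_refresh())  # False right after caching
time.sleep(2)
print(CachedAuth.needs_refresh())  # True once the TTL has passed
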