Fix issue with get_cookies from nodriver in OpenaiChat

This commit is contained in:
Heiner Lohaus 2024-12-18 02:28:26 +01:00
parent bbb858249b
commit af677717ee
7 changed files with 66 additions and 29 deletions

View File

@ -3,10 +3,16 @@ from __future__ import annotations
import asyncio
import json
try:
import nodriver
has_nodriver = True
except ImportError:
has_nodriver = False
from ..typing import AsyncResult, Messages, Cookies
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, get_running_loop
from ..requests import Session, StreamSession, get_args_from_nodriver, raise_for_status, merge_cookies, DEFAULT_HEADERS
from ..errors import ResponseStatusError, MissingRequirementsError
from ..errors import ResponseStatusError
class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
label = "Cloudflare AI"
@ -35,12 +41,15 @@ class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
def get_models(cls) -> str:
if not cls.models:
if cls._args is None:
if has_nodriver:
get_running_loop(check_nested=True)
args = get_args_from_nodriver(cls.url)
cls._args = asyncio.run(args)
else:
cls._args = {"headers": DEFAULT_HEADERS, "cookies": {}}
with Session(**cls._args) as session:
response = session.get(cls.models_url)
cls._args["cookies"] = merge_cookies(cls._args["cookies"] , response)
cls._args["cookies"] = merge_cookies(cls._args["cookies"], response)
try:
raise_for_status(response)
except ResponseStatusError:
@ -62,10 +71,10 @@ class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
**kwargs
) -> AsyncResult:
if cls._args is None:
try:
if has_nodriver:
cls._args = await get_args_from_nodriver(cls.url, proxy, timeout, cookies)
except MissingRequirementsError:
cls._args = {"headers": DEFAULT_HEADERS, cookies: {}}
else:
cls._args = {"headers": DEFAULT_HEADERS, "cookies": {}}
model = cls.get_model(model)
data = {
"messages": messages,

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import os
import json
import asyncio
import base64
@ -76,10 +77,13 @@ class Copilot(AbstractProvider, ProviderModelMixin):
cls._access_token, cls._cookies = readHAR(cls.url)
except NoValidHarFileError as h:
debug.log(f"Copilot: {h}")
try:
if has_nodriver:
login_url = os.environ.get("G4F_LOGIN_URL")
if login_url:
yield f"[Login to {cls.label}]({login_url})\n\n"
get_running_loop(check_nested=True)
cls._access_token, cls._cookies = asyncio.run(get_access_token_and_cookies(cls.url, proxy))
except MissingRequirementsError:
else:
raise h
debug.log(f"Copilot: Access token: {cls._access_token[:7]}...{cls._access_token[-5:]}")
websocket_url = f"{websocket_url}&accessToken={quote(cls._access_token)}"

View File

@ -81,7 +81,7 @@ class Gemini(AsyncGeneratorProvider, ProviderModelMixin):
browser = await get_nodriver(proxy=proxy, user_data_dir="gemini")
login_url = os.environ.get("G4F_LOGIN_URL")
if login_url:
yield f"Please login: [Google Gemini]({login_url})\n\n"
yield f"[Login to {cls.label}]({login_url})\n\n"
page = await browser.get(f"{cls.url}/app")
await page.select("div.ql-editor.textarea", 240)
cookies = {}

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import os
import re
import asyncio
import uuid
@ -8,6 +9,7 @@ import base64
import time
import requests
import random
from typing import AsyncIterator
from copy import copy
try:
@ -314,7 +316,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
RuntimeError: If an error occurs during processing.
"""
if cls.needs_auth:
await cls.login(proxy)
async for message in cls.login(proxy):
yield message
async with StreamSession(
proxy=proxy,
impersonate="chrome",
@ -504,7 +507,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
@classmethod
async def synthesize(cls, params: dict) -> AsyncIterator[bytes]:
await cls.login()
async for _ in cls.login():
pass
async with StreamSession(
impersonate="chrome",
timeout=0
@ -519,23 +523,27 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
yield chunk
@classmethod
async def login(cls, proxy: str = None):
async def login(cls, proxy: str = None) -> AsyncIterator[str]:
if cls._expires is not None and cls._expires < time.time():
cls._headers = cls._api_key = None
try:
await get_request_config(proxy)
cls._create_request_args(RequestConfig.cookies, RequestConfig.headers)
if RequestConfig.access_token is not None:
cls._set_api_key(RequestConfig.access_token)
except NoValidHarFileError:
if has_nodriver:
if cls._api_key is None:
login_url = os.environ.get("G4F_LOGIN_URL")
if login_url:
yield f"[Login to {cls.label}]({login_url})\n\n"
await cls.nodriver_auth(proxy)
else:
raise
@classmethod
async def nodriver_auth(cls, proxy: str = None):
browser = await get_nodriver(proxy=proxy, user_data_dir="chatgpt")
browser = await get_nodriver(proxy=proxy)
page = browser.main_tab
def on_request(event: nodriver.cdp.network.RequestWillBeSent):
if event.request.url == start_url or event.request.url.startswith(conversation_url):
@ -548,7 +556,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
if "OpenAI-Sentinel-Turnstile-Token" in event.request.headers:
RequestConfig.turnstile_token = event.request.headers["OpenAI-Sentinel-Turnstile-Token"]
if "Authorization" in event.request.headers:
cls._set_api_key(event.request.headers["Authorization"].split()[-1])
cls._api_key = event.request.headers["Authorization"].split()[-1]
elif event.request.url == arkose_url:
RequestConfig.arkose_request = arkReq(
arkURL=event.request.url,
@ -569,7 +577,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
if body:
match = re.search(r'"accessToken":"(.*?)"', body)
if match:
cls._set_api_key(match.group(1))
cls._api_key = match.group(1)
break
await asyncio.sleep(1)
while True:
@ -577,10 +585,11 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
break
await asyncio.sleep(1)
RequestConfig.data_build = await page.evaluate("document.documentElement.getAttribute('data-build')")
for c in await page.send(nodriver.cdp.network.get_cookies([cls.url])):
RequestConfig.cookies[c.name] = c.value
for c in await page.send(get_cookies([cls.url])):
RequestConfig.cookies[c["name"]] = c["value"]
await page.close()
cls._create_request_args(RequestConfig.cookies, RequestConfig.headers, user_agent=user_agent)
cls._set_api_key(cls._api_key)
@staticmethod
def get_default_headers() -> dict:
@ -624,3 +633,16 @@ class Conversation(BaseConversation):
self.message_id = message_id
self.finish_reason = finish_reason
self.is_recipient = False
def get_cookies(
urls: list[str] = None
):
params = dict()
if urls is not None:
params['urls'] = [i for i in urls]
cmd_dict = {
'method': 'Network.getCookies',
'params': params,
}
json = yield cmd_dict
return json['cookies']

View File

@ -81,6 +81,7 @@ body:not(.white) a:visited{
transform: translate(-50%, -50%);
filter: blur(var(--blur)) opacity(var(--opacity));
animation: zoom_gradient 6s infinite alternate;
display: none;
}
@keyframes zoom_gradient {
@ -116,6 +117,8 @@ body:not(.white) a:visited{
font-weight: 500;
background-color: rgba(0, 0, 0, 0.5);
color: var(--colour-3);
border: var(--colour-1) 1px solid;
border-radius: var(--border-radius-1);
}
.white .new_version {
@ -174,10 +177,6 @@ body:not(.white) a:visited{
color: var(--user-input)
}
body.white .gradient{
display: none;
}
.conversations {
display: flex;
flex-direction: column;
@ -826,6 +825,9 @@ select:hover,
.count_total {
padding-left: 98px;
}
body:not(.white) .gradient{
display: block;
}
}
.input-box {

View File

@ -601,6 +601,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
api_key: api_key,
ignored: ignored,
}, files, message_id);
if (content_map.inner.dataset.timeout) clearTimeout(content_map.inner.dataset.timeout);
if (!error_storage[message_id]) {
html = markdown_render(message_storage[message_id]);
content_map.inner.innerHTML = html;
@ -629,10 +630,9 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
regenerate
);
await safe_load_conversation(window.conversation_id, message_index == -1);
} else {
}
let cursorDiv = message_el.querySelector(".cursor");
if (cursorDiv) cursorDiv.parentNode.removeChild(cursorDiv);
}
if (message_index == -1) {
await scroll_to_bottom();
}

View File

@ -181,7 +181,7 @@ class Api:
def handle_provider(self, provider_handler, model):
if isinstance(provider_handler, IterListProvider) and provider_handler.last_provider is not None:
provider_handler = provider_handler.last_provider
if hasattr(provider_handler, "last_model") and provider_handler.last_model is not None:
if not model and hasattr(provider_handler, "last_model") and provider_handler.last_model is not None:
model = provider_handler.last_model
return self._format_json("provider", {**provider_handler.get_dict(), "model": model})