Fix: ChromeDriver only supports characters in the BMP

Add set_cookies helper, Show last used model
This commit is contained in:
Heiner Lohaus 2024-01-26 12:49:52 +01:00
parent feb83c168b
commit 1eb7dc05e5
18 changed files with 75 additions and 63 deletions

View File

@ -1,4 +1,3 @@
from . import include
import unittest import unittest
from unittest.mock import MagicMock from unittest.mock import MagicMock
from .mocks import ProviderMock from .mocks import ProviderMock
@ -13,7 +12,7 @@ class TestBackendApi(unittest.TestCase):
def setUp(self): def setUp(self):
if not has_requirements: if not has_requirements:
self.skipTest('"flask" not installed') self.skipTest("gui is not installed")
self.app = MagicMock() self.app = MagicMock()
self.api = Backend_Api(self.app) self.api = Backend_Api(self.app)
@ -36,7 +35,7 @@ class TestUtilityFunctions(unittest.TestCase):
def setUp(self): def setUp(self):
if not has_requirements: if not has_requirements:
self.skipTest('"flask" not installed') self.skipTest("gui is not installed")
def test_get_error_message(self): def test_get_error_message(self):
g4f.debug.last_provider = ProviderMock g4f.debug.last_provider = ProviderMock

View File

@ -10,6 +10,7 @@ from .helper import get_cookies, format_prompt
from ..typing import CreateResult, AsyncResult, Messages, Union from ..typing import CreateResult, AsyncResult, Messages, Union
from ..base_provider import BaseProvider from ..base_provider import BaseProvider
from ..errors import NestAsyncioError, ModelNotSupportedError from ..errors import NestAsyncioError, ModelNotSupportedError
from .. import debug
if sys.version_info < (3, 10): if sys.version_info < (3, 10):
NoneType = type(None) NoneType = type(None)
@ -266,9 +267,10 @@ class ProviderModelMixin:
@classmethod @classmethod
def get_model(cls, model: str) -> str: def get_model(cls, model: str) -> str:
if not model: if not model:
return cls.default_model model = cls.default_model
elif model in cls.model_aliases: elif model in cls.model_aliases:
return cls.model_aliases[model] model = cls.model_aliases[model]
elif model not in cls.get_models(): elif model not in cls.get_models():
raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}") raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
debug.last_model = model
return model return model

View File

@ -202,16 +202,15 @@ class CreateImagesBing:
Yields: Yields:
Generator[str, None, None]: The final output as markdown formatted string with images. Generator[str, None, None]: The final output as markdown formatted string with images.
""" """
try: cookies = self.cookies or get_cookies(".bing.com", False)
cookies = self.cookies or get_cookies(".bing.com")
except MissingRequirementsError as e:
raise MissingAccessToken(f'Missing "_U" cookie. {e}')
if "_U" not in cookies: if "_U" not in cookies:
login_url = os.environ.get("G4F_LOGIN_URL") login_url = os.environ.get("G4F_LOGIN_URL")
if login_url: if login_url:
yield f"Please login: [Bing]({login_url})\n\n" yield f"Please login: [Bing]({login_url})\n\n"
self.cookies = get_cookies_from_browser(self.proxy) try:
self.cookies = get_cookies_from_browser(self.proxy)
except MissingRequirementsError as e:
raise MissingAccessToken(f'Missing "_U" cookie. {e}')
yield asyncio.run(self.create_async(prompt)) yield asyncio.run(self.create_async(prompt))
async def create_async(self, prompt: str) -> ImageResponse: async def create_async(self, prompt: str) -> ImageResponse:
@ -224,10 +223,7 @@ class CreateImagesBing:
Returns: Returns:
str: Markdown formatted string with images. str: Markdown formatted string with images.
""" """
try: cookies = self.cookies or get_cookies(".bing.com", False)
cookies = self.cookies or get_cookies(".bing.com")
except MissingRequirementsError as e:
raise MissingAccessToken(f'Missing "_U" cookie. {e}')
if "_U" not in cookies: if "_U" not in cookies:
raise MissingAccessToken('Missing "_U" cookie') raise MissingAccessToken('Missing "_U" cookie')
proxy = os.environ.get("G4F_PROXY") proxy = os.environ.get("G4F_PROXY")

View File

@ -21,7 +21,7 @@ try:
except ImportError: except ImportError:
has_browser_cookie3 = False has_browser_cookie3 = False
from ..typing import Dict, Messages, Optional from ..typing import Dict, Messages, Cookies, Optional
from ..errors import AiohttpSocksError, MissingRequirementsError from ..errors import AiohttpSocksError, MissingRequirementsError
from .. import debug from .. import debug
@ -48,6 +48,12 @@ def get_cookies(domain_name: str = '', raise_requirements_error: bool = True) ->
_cookies[domain_name] = cookies _cookies[domain_name] = cookies
return cookies return cookies
def set_cookies(domain_name: str, cookies: Cookies = None) -> None:
if cookies:
_cookies[domain_name] = cookies
else:
_cookies.pop(domain_name)
def load_cookies_from_browsers(domain_name: str, raise_requirements_error: bool = True) -> Dict[str, str]: def load_cookies_from_browsers(domain_name: str, raise_requirements_error: bool = True) -> Dict[str, str]:
""" """
Helper function to load cookies from various browsers. Helper function to load cookies from various browsers.

View File

@ -7,14 +7,14 @@ try:
from selenium.webdriver.common.by import By from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.common.keys import Keys
except ImportError: except ImportError:
pass pass
from ...typing import CreateResult, Messages from ...typing import CreateResult, Messages
from ..base_provider import AbstractProvider from ..base_provider import AbstractProvider
from ..helper import format_prompt from ..helper import format_prompt
from ...webdriver import WebDriver, WebDriverSession from ...webdriver import WebDriver, WebDriverSession, element_send_text
class Bard(AbstractProvider): class Bard(AbstractProvider):
url = "https://bard.google.com" url = "https://bard.google.com"
@ -68,13 +68,7 @@ XMLHttpRequest.prototype.open = function(method, url) {
""" """
driver.execute_script(script) driver.execute_script(script)
textarea = driver.find_element(By.CSS_SELECTOR, "div.ql-editor.textarea") element_send_text(driver.find_element(By.CSS_SELECTOR, "div.ql-editor.textarea"), prompt)
lines = prompt.splitlines()
for idx, line in enumerate(lines):
textarea.send_keys(line)
if (len(lines) - 1 != idx):
textarea.send_keys(Keys.SHIFT + "\n")
textarea.send_keys(Keys.ENTER)
while True: while True:
chunk = driver.execute_script("return window._message;") chunk = driver.execute_script("return window._message;")

View File

@ -16,9 +16,8 @@ try:
from selenium.webdriver.common.by import By from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC from selenium.webdriver.support import expected_conditions as EC
has_webdriver = True
except ImportError: except ImportError:
has_webdriver = False pass
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_prompt, get_cookies from ..helper import format_prompt, get_cookies
@ -332,13 +331,14 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
cookies = cls._cookies or get_cookies("chat.openai.com", False) cookies = cls._cookies or get_cookies("chat.openai.com", False)
if not access_token and "access_token" in cookies: if not access_token and "access_token" in cookies:
access_token = cookies["access_token"] access_token = cookies["access_token"]
if not access_token and not has_webdriver:
raise MissingAccessToken(f'Missing "access_token"')
if not access_token: if not access_token:
login_url = os.environ.get("G4F_LOGIN_URL") login_url = os.environ.get("G4F_LOGIN_URL")
if login_url: if login_url:
yield f"Please login: [ChatGPT]({login_url})\n\n" yield f"Please login: [ChatGPT]({login_url})\n\n"
access_token, cookies = cls.browse_access_token(proxy) try:
access_token, cookies = cls.browse_access_token(proxy)
except MissingRequirementsError:
raise MissingAccessToken(f'Missing "access_token"')
cls._cookies = cookies cls._cookies = cookies
headers = {"Authorization": f"Bearer {access_token}"} headers = {"Authorization": f"Bearer {access_token}"}

View File

@ -5,7 +5,7 @@ import time
from ...typing import CreateResult, Messages from ...typing import CreateResult, Messages
from ..base_provider import AbstractProvider from ..base_provider import AbstractProvider
from ..helper import format_prompt from ..helper import format_prompt
from ...webdriver import WebDriver, WebDriverSession from ...webdriver import WebDriver, WebDriverSession, element_send_text
models = { models = {
"meta-llama/Llama-2-7b-chat-hf": {"name": "Llama-2-7b"}, "meta-llama/Llama-2-7b-chat-hf": {"name": "Llama-2-7b"},
@ -89,7 +89,7 @@ class Poe(AbstractProvider):
else: else:
raise RuntimeError("Prompt textarea not found. You may not be logged in.") raise RuntimeError("Prompt textarea not found. You may not be logged in.")
driver.find_element(By.CSS_SELECTOR, "footer textarea[class^='GrowingTextArea']").send_keys(prompt) element_send_text(driver.find_element(By.CSS_SELECTOR, "footer textarea[class^='GrowingTextArea']"), prompt)
driver.find_element(By.CSS_SELECTOR, "footer button[class*='ChatMessageSendButton']").click() driver.find_element(By.CSS_SELECTOR, "footer button[class*='ChatMessageSendButton']").click()
script = """ script = """

View File

@ -5,7 +5,7 @@ import time
from ...typing import CreateResult, Messages from ...typing import CreateResult, Messages
from ..base_provider import AbstractProvider from ..base_provider import AbstractProvider
from ..helper import format_prompt from ..helper import format_prompt
from ...webdriver import WebDriver, WebDriverSession from ...webdriver import WebDriver, WebDriverSession, element_send_text
models = { models = {
"theb-ai": "TheB.AI", "theb-ai": "TheB.AI",
@ -118,8 +118,7 @@ window._last_message = "";
# Submit prompt # Submit prompt
wait.until(EC.visibility_of_element_located((By.ID, "textareaAutosize"))) wait.until(EC.visibility_of_element_located((By.ID, "textareaAutosize")))
driver.find_element(By.ID, "textareaAutosize").send_keys(prompt) element_send_text(driver.find_element(By.ID, "textareaAutosize"), prompt)
driver.find_element(By.ID, "textareaAutosize").send_keys(Keys.ENTER)
# Read response with reader # Read response with reader
script = """ script = """

View File

@ -6,7 +6,7 @@ import random
from ...typing import CreateResult, Messages from ...typing import CreateResult, Messages
from ..base_provider import AbstractProvider from ..base_provider import AbstractProvider
from ..helper import format_prompt, get_random_string from ..helper import format_prompt, get_random_string
from ...webdriver import WebDriver, WebDriverSession from ...webdriver import WebDriver, WebDriverSession, element_send_text
from ... import debug from ... import debug
class AItianhuSpace(AbstractProvider): class AItianhuSpace(AbstractProvider):
@ -91,8 +91,7 @@ XMLHttpRequest.prototype.open = function(method, url) {
driver.execute_script(script) driver.execute_script(script)
# Submit prompt # Submit prompt
driver.find_element(By.CSS_SELECTOR, "textarea.n-input__textarea-el").send_keys(prompt) element_send_text(driver.find_element(By.CSS_SELECTOR, "textarea.n-input__textarea-el"), prompt)
driver.find_element(By.CSS_SELECTOR, "button.n-button.n-button--primary-type.n-button--medium-type").click()
# Read response # Read response
while True: while True:

View File

@ -6,14 +6,13 @@ try:
from selenium.webdriver.common.by import By from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.common.keys import Keys
except ImportError: except ImportError:
pass pass
from ...typing import CreateResult, Messages from ...typing import CreateResult, Messages
from ..base_provider import AbstractProvider from ..base_provider import AbstractProvider
from ..helper import format_prompt from ..helper import format_prompt
from ...webdriver import WebDriver, WebDriverSession from ...webdriver import WebDriver, WebDriverSession, element_send_text
class PerplexityAi(AbstractProvider): class PerplexityAi(AbstractProvider):
url = "https://www.perplexity.ai" url = "https://www.perplexity.ai"
@ -83,8 +82,7 @@ WebSocket.prototype.send = function(...args) {
raise RuntimeError("You need a account for copilot") raise RuntimeError("You need a account for copilot")
# Submit prompt # Submit prompt
driver.find_element(By.CSS_SELECTOR, "textarea[placeholder='Ask anything...']").send_keys(prompt) element_send_text(driver.find_element(By.CSS_SELECTOR, "textarea[placeholder='Ask anything...']"), prompt)
driver.find_element(By.CSS_SELECTOR, "textarea[placeholder='Ask anything...']").send_keys(Keys.ENTER)
# Stream response # Stream response
script = """ script = """

View File

@ -42,7 +42,7 @@ class TalkAi(AbstractProvider):
"content": message["content"] "content": message["content"]
} for message in messages], } for message in messages],
"model": model if model else "gpt-3.5-turbo", "model": model if model else "gpt-3.5-turbo",
"max_tokens": 256, "max_tokens": 2048,
"temperature": 1, "temperature": 1,
"top_p": 1, "top_p": 1,
"presence_penalty": 0, "presence_penalty": 0,
@ -67,16 +67,15 @@ window._reader = response.body.pipeThrough(new TextDecoderStream()).getReader();
while True: while True:
chunk = driver.execute_script(""" chunk = driver.execute_script("""
chunk = await window._reader.read(); chunk = await window._reader.read();
if (chunk["done"]) { if (chunk.done) {
return null; return null;
} }
content = ""; content = "";
lines = chunk["value"].split("\\n") for (line of chunk.value.split("\\n")) {
lines.forEach((line, index) => {
if (line.startsWith('data: ')) { if (line.startsWith('data: ')) {
content += line.substring('data: '.length); content += line.substring('data: '.length);
} }
}); }
return content; return content;
""") """)
if chunk: if chunk:

View File

@ -3,11 +3,13 @@ from __future__ import annotations
import os import os
from .errors import * from .errors import *
from .models import Model, ModelUtils, _all_models from .models import Model, ModelUtils
from .Provider import AsyncGeneratorProvider, ProviderUtils from .Provider import AsyncGeneratorProvider, ProviderUtils
from .typing import Messages, CreateResult, AsyncResult, Union from .typing import Messages, CreateResult, AsyncResult, Union
from . import debug, version from . import debug, version
from .base_provider import BaseRetryProvider, ProviderType from .base_provider import BaseRetryProvider, ProviderType
from .Provider.helper import get_cookies, set_cookies
from .Provider.base_provider import ProviderModelMixin
def get_model_and_provider(model : Union[Model, str], def get_model_and_provider(model : Union[Model, str],
provider : Union[ProviderType, str, None], provider : Union[ProviderType, str, None],
@ -76,6 +78,7 @@ def get_model_and_provider(model : Union[Model, str],
print(f'Using {provider.__name__} provider') print(f'Using {provider.__name__} provider')
debug.last_provider = provider debug.last_provider = provider
debug.last_model = model
return model, provider return model, provider
@ -227,5 +230,10 @@ def get_last_provider(as_dict: bool = False) -> Union[ProviderType, dict[str, st
if isinstance(last, BaseRetryProvider): if isinstance(last, BaseRetryProvider):
last = last.last_provider last = last.last_provider
if last and as_dict: if last and as_dict:
return {"name": last.__name__, "url": last.url} return {
"name": last.__name__,
"url": last.url,
"model": debug.last_model,
"models": last.models if isinstance(last, ProviderModelMixin) else []
}
return last return last

View File

@ -2,4 +2,5 @@ from .base_provider import ProviderType
logging: bool = False logging: bool = False
version_check: bool = True version_check: bool = True
last_provider: ProviderType = None last_provider: ProviderType = None
last_model: str = None

View File

@ -31,7 +31,7 @@ class NestAsyncioError(Exception):
class ModelNotSupportedError(Exception): class ModelNotSupportedError(Exception):
pass pass
class MissingRequirementsError(Exception): class MissingRequirementsError(ImportError):
pass pass
class AiohttpSocksError(MissingRequirementsError): class AiohttpSocksError(MissingRequirementsError):

View File

@ -1,12 +1,12 @@
from ..errors import MissingRequirementsError
try: try:
from .server.app import app from .server.app import app
from .server.website import Website from .server.website import Website
from .server.backend import Backend_Api from .server.backend import Backend_Api
except ImportError: except ImportError:
from g4f.errors import MissingRequirementsError raise MissingRequirementsError('Install "flask" package for the gui')
raise MissingRequirementsError('Install "flask" and "werkzeug" package for gui')
def run_gui(host: str = '0.0.0.0', port: int = 80, debug: bool = False) -> None: def run_gui(host: str = '0.0.0.0', port: int = 8080, debug: bool = False) -> None:
config = { config = {
'host' : host, 'host' : host,
'port' : port, 'port' : port,

View File

@ -164,12 +164,16 @@ const ask_gpt = async () => {
for (const line of value.split("\n")) { for (const line of value.split("\n")) {
if (!line) continue; if (!line) continue;
const message = JSON.parse(line); const message = JSON.parse(line);
if (message["type"] == "content") { if (message.type == "content") {
text += message["content"]; text += message.content;
} else if (message["type"] == "provider") { } else if (message["type"] == "provider") {
provider = message["provider"]; provider = message.provider
content.querySelector('.provider').innerHTML = content.querySelector('.provider').innerHTML = `
'<a href="' + provider.url + '" target="_blank">' + provider.name + "</a>" <a href="${provider.url}" target="_blank">
${provider.name}
</a>
${provider.model ? ' with ' + provider.model : ''}
`
} else if (message["type"] == "error") { } else if (message["type"] == "error") {
error = message["error"]; error = message["error"];
} else if (message["type"] == "message") { } else if (message["type"] == "message") {

View File

@ -3,7 +3,7 @@ import json
from flask import request, Flask from flask import request, Flask
from typing import Generator from typing import Generator
from g4f import version, models from g4f import version, models
from g4f import _all_models, get_last_provider, ChatCompletion from g4f import get_last_provider, ChatCompletion
from g4f.image import is_allowed_extension, to_image from g4f.image import is_allowed_extension, to_image
from g4f.errors import VersionNotFoundError from g4f.errors import VersionNotFoundError
from g4f.Provider import __providers__ from g4f.Provider import __providers__
@ -76,7 +76,7 @@ class Backend_Api:
Returns: Returns:
List[str]: A list of model names. List[str]: A list of model names.
""" """
return _all_models return models._all_models
def get_providers(self): def get_providers(self):
""" """

View File

@ -3,13 +3,15 @@ from __future__ import annotations
try: try:
from platformdirs import user_config_dir from platformdirs import user_config_dir
from selenium.webdriver.remote.webdriver import WebDriver from selenium.webdriver.remote.webdriver import WebDriver
from selenium.webdriver.remote.webelement import WebElement
from undetected_chromedriver import Chrome, ChromeOptions from undetected_chromedriver import Chrome, ChromeOptions
from selenium.webdriver.common.by import By from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.common.keys import Keys
has_requirements = True has_requirements = True
except ImportError: except ImportError:
WebDriver = type from typing import Type as WebDriver
has_requirements = False has_requirements = False
from os import path from os import path
@ -197,4 +199,9 @@ class WebDriverSession:
print(f"Error closing WebDriver: {e}") print(f"Error closing WebDriver: {e}")
self.default_driver.quit() self.default_driver.quit()
if self.virtual_display: if self.virtual_display:
self.virtual_display.stop() self.virtual_display.stop()
def element_send_text(element: WebElement, text: str) -> None:
script = "arguments[0].innerText = arguments[1]"
element.parent.execute_script(script, element, text)
element.send_keys(Keys.ENTER)