Merge pull request #1468 from hlohaus/upp

Refactor code with AI
This commit is contained in:
H Lohaus 2024-01-14 15:32:51 +01:00 committed by GitHub
commit 1ca80ed48b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 1839 additions and 629 deletions

19
.github/workflows/unittest.yml vendored Normal file
View File

@ -0,0 +1,19 @@
name: Unittest
on: [push]
jobs:
build:
name: Build unittest
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.x"
cache: 'pip'
- name: Install requirements
- run: pip install -r requirements.txt
- name: Run tests
run: python -m etc.unittest.main

73
etc/unittest/main.py Normal file
View File

@ -0,0 +1,73 @@
import sys
import pathlib
import unittest
from unittest.mock import MagicMock
sys.path.append(str(pathlib.Path(__file__).parent.parent.parent))
import g4f
from g4f import ChatCompletion, get_last_provider
from g4f.gui.server.backend import Backend_Api, get_error_message
from g4f.base_provider import BaseProvider
g4f.debug.logging = False
class MockProvider(BaseProvider):
working = True
def create_completion(
model, messages, stream, **kwargs
):
yield "Mock"
async def create_async(
model, messages, **kwargs
):
return "Mock"
class TestBackendApi(unittest.TestCase):
def setUp(self):
self.app = MagicMock()
self.api = Backend_Api(self.app)
def test_version(self):
response = self.api.get_version()
self.assertIn("version", response)
self.assertIn("latest_version", response)
class TestChatCompletion(unittest.TestCase):
def test_create(self):
messages = [{'role': 'user', 'content': 'Hello'}]
result = ChatCompletion.create(g4f.models.default, messages)
self.assertTrue("Hello" in result or "Good" in result)
def test_get_last_provider(self):
messages = [{'role': 'user', 'content': 'Hello'}]
ChatCompletion.create(g4f.models.default, messages, MockProvider)
self.assertEqual(get_last_provider(), MockProvider)
def test_bing_provider(self):
messages = [{'role': 'user', 'content': 'Hello'}]
provider = g4f.Provider.Bing
result = ChatCompletion.create(g4f.models.default, messages, provider)
self.assertTrue("Bing" in result)
class TestChatCompletionAsync(unittest.IsolatedAsyncioTestCase):
async def test_async(self):
messages = [{'role': 'user', 'content': 'Hello'}]
result = await ChatCompletion.create_async(g4f.models.default, messages, MockProvider)
self.assertTrue("Mock" in result)
class TestUtilityFunctions(unittest.TestCase):
def test_get_error_message(self):
g4f.debug.last_provider = g4f.Provider.Bing
exception = Exception("Message")
result = get_error_message(exception)
self.assertEqual("Bing: Exception: Message", result)
if __name__ == '__main__':
unittest.main()

View File

@ -15,12 +15,18 @@ from .bing.upload_image import upload_image
from .bing.create_images import create_images
from .bing.conversation import Conversation, create_conversation, delete_conversation
class Tones():
class Tones:
"""
Defines the different tone options for the Bing provider.
"""
creative = "Creative"
balanced = "Balanced"
precise = "Precise"
class Bing(AsyncGeneratorProvider):
"""
Bing provider for generating responses using the Bing API.
"""
url = "https://bing.com/chat"
working = True
supports_message_history = True
@ -38,6 +44,19 @@ class Bing(AsyncGeneratorProvider):
web_search: bool = False,
**kwargs
) -> AsyncResult:
"""
Creates an asynchronous generator for producing responses from Bing.
:param model: The model to use.
:param messages: Messages to process.
:param proxy: Proxy to use for requests.
:param timeout: Timeout for requests.
:param cookies: Cookies for the session.
:param tone: The tone of the response.
:param image: The image type to be used.
:param web_search: Flag to enable or disable web search.
:return: An asynchronous result object.
"""
if len(messages) < 2:
prompt = messages[0]["content"]
context = None
@ -56,65 +75,48 @@ class Bing(AsyncGeneratorProvider):
return stream_generate(prompt, tone, image, context, proxy, cookies, web_search, gpt4_turbo, timeout)
def create_context(messages: Messages):
def create_context(messages: Messages) -> str:
"""
Creates a context string from a list of messages.
:param messages: A list of message dictionaries.
:return: A string representing the context created from the messages.
"""
return "".join(
f"[{message['role']}]" + ("(#message)" if message['role']!="system" else "(#additional_instructions)") + f"\n{message['content']}\n\n"
f"[{message['role']}]" + ("(#message)" if message['role'] != "system" else "(#additional_instructions)") + f"\n{message['content']}\n\n"
for message in messages
)
class Defaults:
"""
Default settings and configurations for the Bing provider.
"""
delimiter = "\x1e"
ip_address = f"13.{random.randint(104, 107)}.{random.randint(0, 255)}.{random.randint(0, 255)}"
# List of allowed message types for Bing responses
allowedMessageTypes = [
"ActionRequest",
"Chat",
"Context",
# "Disengaged", unwanted
"Progress",
# "AdsQuery", unwanted
"SemanticSerp",
"GenerateContentQuery",
"SearchQuery",
# The following message types should not be added so that it does not flood with
# useless messages (such as "Analyzing images" or "Searching the web") while it's retrieving the AI response
# "InternalSearchQuery",
# "InternalSearchResult",
"RenderCardRequest",
# "RenderContentRequest"
"ActionRequest", "Chat", "Context", "Progress", "SemanticSerp",
"GenerateContentQuery", "SearchQuery", "RenderCardRequest"
]
sliceIds = [
'abv2',
'srdicton',
'convcssclick',
'stylewv2',
'contctxp2tf',
'802fluxv1pc_a',
'806log2sphs0',
'727savemem',
'277teditgnds0',
'207hlthgrds0',
'abv2', 'srdicton', 'convcssclick', 'stylewv2', 'contctxp2tf',
'802fluxv1pc_a', '806log2sphs0', '727savemem', '277teditgnds0', '207hlthgrds0'
]
# Default location settings
location = {
"locale": "en-US",
"market": "en-US",
"region": "US",
"locationHints": [
{
"country": "United States",
"state": "California",
"city": "Los Angeles",
"timezoneoffset": 8,
"countryConfidence": 8,
"locale": "en-US", "market": "en-US", "region": "US",
"locationHints": [{
"country": "United States", "state": "California", "city": "Los Angeles",
"timezoneoffset": 8, "countryConfidence": 8,
"Center": {"Latitude": 34.0536909, "Longitude": -118.242766},
"RegionType": 2,
"SourceType": 1,
}
],
"RegionType": 2, "SourceType": 1
}],
}
# Default headers for requests
headers = {
'accept': '*/*',
'accept-language': 'en-US,en;q=0.9',
@ -139,23 +141,13 @@ class Defaults:
}
optionsSets = [
'nlu_direct_response_filter',
'deepleo',
'disable_emoji_spoken_text',
'responsible_ai_policy_235',
'enablemm',
'iyxapbing',
'iycapbing',
'gencontentv3',
'fluxsrtrunc',
'fluxtrunc',
'fluxv1',
'rai278',
'replaceurl',
'eredirecturl',
'nojbfedge'
'nlu_direct_response_filter', 'deepleo', 'disable_emoji_spoken_text',
'responsible_ai_policy_235', 'enablemm', 'iyxapbing', 'iycapbing',
'gencontentv3', 'fluxsrtrunc', 'fluxtrunc', 'fluxv1', 'rai278',
'replaceurl', 'eredirecturl', 'nojbfedge'
]
# Default cookies
cookies = {
'SRCHD' : 'AF=NOFORM',
'PPLState' : '1',
@ -166,6 +158,12 @@ class Defaults:
}
def format_message(msg: dict) -> str:
"""
Formats a message dictionary into a JSON string with a delimiter.
:param msg: The message dictionary to format.
:return: A formatted string representation of the message.
"""
return json.dumps(msg, ensure_ascii=False) + Defaults.delimiter
def create_message(
@ -177,7 +175,20 @@ def create_message(
web_search: bool = False,
gpt4_turbo: bool = False
) -> str:
"""
Creates a message for the Bing API with specified parameters.
:param conversation: The current conversation object.
:param prompt: The user's input prompt.
:param tone: The desired tone for the response.
:param context: Additional context for the prompt.
:param image_response: The response if an image is involved.
:param web_search: Flag to enable web search.
:param gpt4_turbo: Flag to enable GPT-4 Turbo.
:return: A formatted string message for the Bing API.
"""
options_sets = Defaults.optionsSets
# Append tone-specific options
if tone == Tones.creative:
options_sets.append("h3imaginative")
elif tone == Tones.precise:
@ -187,53 +198,48 @@ def create_message(
else:
options_sets.append("harmonyv3")
# Additional configurations based on parameters
if not web_search:
options_sets.append("nosearchall")
if gpt4_turbo:
options_sets.append("dlgpt4t")
request_id = str(uuid.uuid4())
struct = {
'arguments': [
{
'source': 'cib',
'optionsSets': options_sets,
'arguments': [{
'source': 'cib', 'optionsSets': options_sets,
'allowedMessageTypes': Defaults.allowedMessageTypes,
'sliceIds': Defaults.sliceIds,
'traceId': os.urandom(16).hex(),
'isStartOfSession': True,
'traceId': os.urandom(16).hex(), 'isStartOfSession': True,
'requestId': request_id,
'message': {**Defaults.location, **{
'message': {
**Defaults.location,
'author': 'user',
'inputMethod': 'Keyboard',
'text': prompt,
'messageType': 'Chat',
'requestId': request_id,
'messageId': request_id,
}},
'messageId': request_id
},
"verbosity": "verbose",
"scenario": "SERP",
"plugins":[
{"id":"c310c353-b9f0-4d76-ab0d-1dd5e979cf68", "category": 1}
] if web_search else [],
"plugins": [{"id": "c310c353-b9f0-4d76-ab0d-1dd5e979cf68", "category": 1}] if web_search else [],
'tone': tone,
'spokenTextMode': 'None',
'conversationId': conversation.conversationId,
'participant': {
'id': conversation.clientId
},
}
],
'participant': {'id': conversation.clientId},
}],
'invocationId': '1',
'target': 'chat',
'type': 4
}
if image_response.get('imageUrl') and image_response.get('originalImageUrl'):
if image_response and image_response.get('imageUrl') and image_response.get('originalImageUrl'):
struct['arguments'][0]['message']['originalImageUrl'] = image_response.get('originalImageUrl')
struct['arguments'][0]['message']['imageUrl'] = image_response.get('imageUrl')
struct['arguments'][0]['experienceType'] = None
struct['arguments'][0]['attachedFileInfo'] = {"fileName": None, "fileType": None}
if context:
struct['arguments'][0]['previousMessages'] = [{
"author": "user",
@ -242,6 +248,7 @@ def create_message(
"messageType": "Context",
"messageId": "discover-web--page-ping-mriduna-----"
}]
return format_message(struct)
async def stream_generate(
@ -254,18 +261,33 @@ async def stream_generate(
web_search: bool = False,
gpt4_turbo: bool = False,
timeout: int = 900
):
):
"""
Asynchronously streams generated responses from the Bing API.
:param prompt: The user's input prompt.
:param tone: The desired tone for the response.
:param image: The image type involved in the response.
:param context: Additional context for the prompt.
:param proxy: Proxy settings for the request.
:param cookies: Cookies for the session.
:param web_search: Flag to enable web search.
:param gpt4_turbo: Flag to enable GPT-4 Turbo.
:param timeout: Timeout for the request.
:return: An asynchronous generator yielding responses.
"""
headers = Defaults.headers
if cookies:
headers["Cookie"] = "; ".join(f"{k}={v}" for k, v in cookies.items())
async with ClientSession(
timeout=ClientTimeout(total=timeout),
headers=headers
timeout=ClientTimeout(total=timeout), headers=headers
) as session:
conversation = await create_conversation(session, proxy)
image_response = await upload_image(session, image, tone, proxy) if image else None
if image_response:
yield image_response
try:
async with session.ws_connect(
'wss://sydney.bing.com/sydney/ChatHub',
@ -289,7 +311,7 @@ async def stream_generate(
if obj is None or not obj:
continue
response = json.loads(obj)
if response.get('type') == 1 and response['arguments'][0].get('messages'):
if response and response.get('type') == 1 and response['arguments'][0].get('messages'):
message = response['arguments'][0]['messages'][0]
image_response = None
if (message['contentOrigin'] != 'Apology'):

View File

@ -1,16 +1,20 @@
from __future__ import annotations
import json
import json, random
from aiohttp import ClientSession
from ..typing import AsyncResult, Messages
from .base_provider import AsyncGeneratorProvider
models = {
"claude-v2": "claude-2.0",
"claude-v2.1":"claude-2.1",
"gemini-pro": "google-gemini-pro"
}
urls = [
"https://free.chatgpt.org.uk",
"https://ai.chatgpt.org.uk"
]
class FreeChatgpt(AsyncGeneratorProvider):
url = "https://free.chatgpt.org.uk"
@ -31,6 +35,7 @@ class FreeChatgpt(AsyncGeneratorProvider):
model = models[model]
elif not model:
model = "gpt-3.5-turbo"
url = random.choice(urls)
headers = {
"Accept": "application/json, text/event-stream",
"Content-Type":"application/json",
@ -55,7 +60,7 @@ class FreeChatgpt(AsyncGeneratorProvider):
"top_p":1,
**kwargs
}
async with session.post(f'{cls.url}/api/openai/v1/chat/completions', json=data, proxy=proxy) as response:
async with session.post(f'{url}/api/openai/v1/chat/completions', json=data, proxy=proxy) as response:
response.raise_for_status()
started = False
async for line in response.content:

View File

@ -59,12 +59,16 @@ class Phind(AsyncGeneratorProvider):
"rewrittenQuestion": prompt,
"challenge": 0.21132115912208504
}
async with session.post(f"{cls.url}/api/infer/followup/answer", headers=headers, json=data) as response:
async with session.post(f"https://https.api.phind.com/infer/", headers=headers, json=data) as response:
new_line = False
async for line in response.iter_lines():
if line.startswith(b"data: "):
chunk = line[6:]
if chunk.startswith(b"<PHIND_METADATA>") or chunk.startswith(b"<PHIND_INDICATOR>"):
if chunk.startswith(b'<PHIND_DONE/>'):
break
if chunk.startswith(b'<PHIND_WEBRESULTS>') or chunk.startswith(b'<PHIND_FOLLOWUP>'):
pass
elif chunk.startswith(b"<PHIND_METADATA>") or chunk.startswith(b"<PHIND_INDICATOR>"):
pass
elif chunk:
yield chunk.decode()

View File

@ -1,5 +1,4 @@
from __future__ import annotations
import sys
import asyncio
from asyncio import AbstractEventLoop
@ -15,14 +14,16 @@ if sys.version_info < (3, 10):
else:
from types import NoneType
# Change event loop policy on windows for curl_cffi
# Set Windows event loop policy for better compatibility with asyncio and curl_cffi
if sys.platform == 'win32':
if isinstance(
asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy
):
if isinstance(asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
class AbstractProvider(BaseProvider):
"""
Abstract class for providing asynchronous functionality to derived classes.
"""
@classmethod
async def create_async(
cls,
@ -33,62 +34,67 @@ class AbstractProvider(BaseProvider):
executor: ThreadPoolExecutor = None,
**kwargs
) -> str:
if not loop:
loop = get_event_loop()
"""
Asynchronously creates a result based on the given model and messages.
Args:
cls (type): The class on which this method is called.
model (str): The model to use for creation.
messages (Messages): The messages to process.
loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
executor (ThreadPoolExecutor, optional): The executor for running async tasks. Defaults to None.
**kwargs: Additional keyword arguments.
Returns:
str: The created result as a string.
"""
loop = loop or get_event_loop()
def create_func() -> str:
return "".join(cls.create_completion(
model,
messages,
False,
**kwargs
))
return "".join(cls.create_completion(model, messages, False, **kwargs))
return await asyncio.wait_for(
loop.run_in_executor(
executor,
create_func
),
loop.run_in_executor(executor, create_func),
timeout=kwargs.get("timeout", 0)
)
@classmethod
@property
def params(cls) -> str:
if issubclass(cls, AsyncGeneratorProvider):
sig = signature(cls.create_async_generator)
elif issubclass(cls, AsyncProvider):
sig = signature(cls.create_async)
else:
sig = signature(cls.create_completion)
"""
Returns the parameters supported by the provider.
Args:
cls (type): The class on which this property is called.
Returns:
str: A string listing the supported parameters.
"""
sig = signature(
cls.create_async_generator if issubclass(cls, AsyncGeneratorProvider) else
cls.create_async if issubclass(cls, AsyncProvider) else
cls.create_completion
)
def get_type_name(annotation: type) -> str:
if hasattr(annotation, "__name__"):
annotation = annotation.__name__
elif isinstance(annotation, NoneType):
annotation = "None"
return str(annotation)
return annotation.__name__ if hasattr(annotation, "__name__") else str(annotation)
args = ""
for name, param in sig.parameters.items():
if name in ("self", "kwargs"):
if name in ("self", "kwargs") or (name == "stream" and not cls.supports_stream):
continue
if name == "stream" and not cls.supports_stream:
continue
if args:
args += ", "
args += "\n " + name
if name != "model" and param.annotation is not Parameter.empty:
args += f": {get_type_name(param.annotation)}"
if param.default == "":
args += ' = ""'
elif param.default is not Parameter.empty:
args += f" = {param.default}"
args += f"\n {name}"
args += f": {get_type_name(param.annotation)}" if param.annotation is not Parameter.empty else ""
args += f' = "{param.default}"' if param.default == "" else f" = {param.default}" if param.default is not Parameter.empty else ""
return f"g4f.Provider.{cls.__name__} supports: ({args}\n)"
class AsyncProvider(AbstractProvider):
"""
Provides asynchronous functionality for creating completions.
"""
@classmethod
def create_completion(
cls,
@ -99,8 +105,21 @@ class AsyncProvider(AbstractProvider):
loop: AbstractEventLoop = None,
**kwargs
) -> CreateResult:
if not loop:
loop = get_event_loop()
"""
Creates a completion result synchronously.
Args:
cls (type): The class on which this method is called.
model (str): The model to use for creation.
messages (Messages): The messages to process.
stream (bool): Indicates whether to stream the results. Defaults to False.
loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
**kwargs: Additional keyword arguments.
Returns:
CreateResult: The result of the completion creation.
"""
loop = loop or get_event_loop()
coro = cls.create_async(model, messages, **kwargs)
yield loop.run_until_complete(coro)
@ -111,10 +130,27 @@ class AsyncProvider(AbstractProvider):
messages: Messages,
**kwargs
) -> str:
"""
Abstract method for creating asynchronous results.
Args:
model (str): The model to use for creation.
messages (Messages): The messages to process.
**kwargs: Additional keyword arguments.
Raises:
NotImplementedError: If this method is not overridden in derived classes.
Returns:
str: The created result as a string.
"""
raise NotImplementedError()
class AsyncGeneratorProvider(AsyncProvider):
"""
Provides asynchronous generator functionality for streaming results.
"""
supports_stream = True
@classmethod
@ -127,15 +163,24 @@ class AsyncGeneratorProvider(AsyncProvider):
loop: AbstractEventLoop = None,
**kwargs
) -> CreateResult:
if not loop:
loop = get_event_loop()
generator = cls.create_async_generator(
model,
messages,
stream=stream,
**kwargs
)
"""
Creates a streaming completion result synchronously.
Args:
cls (type): The class on which this method is called.
model (str): The model to use for creation.
messages (Messages): The messages to process.
stream (bool): Indicates whether to stream the results. Defaults to True.
loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
**kwargs: Additional keyword arguments.
Returns:
CreateResult: The result of the streaming completion creation.
"""
loop = loop or get_event_loop()
generator = cls.create_async_generator(model, messages, stream=stream, **kwargs)
gen = generator.__aiter__()
while True:
try:
yield loop.run_until_complete(gen.__anext__())
@ -149,21 +194,44 @@ class AsyncGeneratorProvider(AsyncProvider):
messages: Messages,
**kwargs
) -> str:
"""
Asynchronously creates a result from a generator.
Args:
cls (type): The class on which this method is called.
model (str): The model to use for creation.
messages (Messages): The messages to process.
**kwargs: Additional keyword arguments.
Returns:
str: The created result as a string.
"""
return "".join([
chunk async for chunk in cls.create_async_generator(
model,
messages,
stream=False,
**kwargs
) if not isinstance(chunk, Exception)
chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)
if not isinstance(chunk, Exception)
])
@staticmethod
@abstractmethod
def create_async_generator(
async def create_async_generator(
model: str,
messages: Messages,
stream: bool = True,
**kwargs
) -> AsyncResult:
"""
Abstract method for creating an asynchronous generator.
Args:
model (str): The model to use for creation.
messages (Messages): The messages to process.
stream (bool): Indicates whether to stream the results. Defaults to True.
**kwargs: Additional keyword arguments.
Raises:
NotImplementedError: If this method is not overridden in derived classes.
Returns:
AsyncResult: An asynchronous generator yielding results.
"""
raise NotImplementedError()

View File

@ -1,13 +1,33 @@
from aiohttp import ClientSession
class Conversation():
class Conversation:
"""
Represents a conversation with specific attributes.
"""
def __init__(self, conversationId: str, clientId: str, conversationSignature: str) -> None:
"""
Initialize a new conversation instance.
Args:
conversationId (str): Unique identifier for the conversation.
clientId (str): Client identifier.
conversationSignature (str): Signature for the conversation.
"""
self.conversationId = conversationId
self.clientId = clientId
self.conversationSignature = conversationSignature
async def create_conversation(session: ClientSession, proxy: str = None) -> Conversation:
"""
Create a new conversation asynchronously.
Args:
session (ClientSession): An instance of aiohttp's ClientSession.
proxy (str, optional): Proxy URL. Defaults to None.
Returns:
Conversation: An instance representing the created conversation.
"""
url = 'https://www.bing.com/turing/conversation/create?bundleVersion=1.1199.4'
async with session.get(url, proxy=proxy) as response:
try:
@ -24,12 +44,32 @@ async def create_conversation(session: ClientSession, proxy: str = None) -> Conv
return Conversation(conversationId, clientId, conversationSignature)
async def list_conversations(session: ClientSession) -> list:
"""
List all conversations asynchronously.
Args:
session (ClientSession): An instance of aiohttp's ClientSession.
Returns:
list: A list of conversations.
"""
url = "https://www.bing.com/turing/conversation/chats"
async with session.get(url) as response:
response = await response.json()
return response["chats"]
async def delete_conversation(session: ClientSession, conversation: Conversation, proxy: str = None) -> bool:
"""
Delete a conversation asynchronously.
Args:
session (ClientSession): An instance of aiohttp's ClientSession.
conversation (Conversation): The conversation to delete.
proxy (str, optional): Proxy URL. Defaults to None.
Returns:
bool: True if deletion was successful, False otherwise.
"""
url = "https://sydney.bing.com/sydney/DeleteSingleConversation"
json = {
"conversationId": conversation.conversationId,

View File

@ -1,9 +1,16 @@
"""
This module provides functionalities for creating and managing images using Bing's service.
It includes functions for user login, session creation, image creation, and processing.
"""
import asyncio
import time, json, os
import time
import json
import os
from aiohttp import ClientSession
from bs4 import BeautifulSoup
from urllib.parse import quote
from typing import Generator
from typing import Generator, List, Dict
from ..create_images import CreateImagesProvider
from ..helper import get_cookies, get_event_loop
@ -12,23 +19,47 @@ from ...base_provider import ProviderType
from ...image import format_images_markdown
BING_URL = "https://www.bing.com"
TIMEOUT_LOGIN = 1200
TIMEOUT_IMAGE_CREATION = 300
ERRORS = [
"this prompt is being reviewed",
"this prompt has been blocked",
"we're working hard to offer image creator in more languages",
"we can't create your images right now"
]
BAD_IMAGES = [
"https://r.bing.com/rp/in-2zU3AJUdkgFe7ZKv19yPBHVs.png",
"https://r.bing.com/rp/TX9QuO3WzcCJz1uaaSwQAz39Kb0.jpg",
]
def wait_for_login(driver: WebDriver, timeout: int = 1200) -> None:
def wait_for_login(driver: WebDriver, timeout: int = TIMEOUT_LOGIN) -> None:
"""
Waits for the user to log in within a given timeout period.
Args:
driver (WebDriver): Webdriver for browser automation.
timeout (int): Maximum waiting time in seconds.
Raises:
RuntimeError: If the login process exceeds the timeout.
"""
driver.get(f"{BING_URL}/")
value = driver.get_cookie("_U")
if value:
return
start_time = time.time()
while True:
while not driver.get_cookie("_U"):
if time.time() - start_time > timeout:
raise RuntimeError("Timeout error")
value = driver.get_cookie("_U")
if value:
time.sleep(1)
return
time.sleep(0.5)
def create_session(cookies: dict) -> ClientSession:
def create_session(cookies: Dict[str, str]) -> ClientSession:
"""
Creates a new client session with specified cookies and headers.
Args:
cookies (Dict[str, str]): Cookies to be used for the session.
Returns:
ClientSession: The created client session.
"""
headers = {
"accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
"accept-encoding": "gzip, deflate, br",
@ -47,28 +78,32 @@ def create_session(cookies: dict) -> ClientSession:
"upgrade-insecure-requests": "1",
}
if cookies:
headers["cookie"] = "; ".join(f"{k}={v}" for k, v in cookies.items())
headers["Cookie"] = "; ".join(f"{k}={v}" for k, v in cookies.items())
return ClientSession(headers=headers)
async def create_images(session: ClientSession, prompt: str, proxy: str = None, timeout: int = 300) -> list:
async def create_images(session: ClientSession, prompt: str, proxy: str = None, timeout: int = TIMEOUT_IMAGE_CREATION) -> List[str]:
"""
Creates images based on a given prompt using Bing's service.
Args:
session (ClientSession): Active client session.
prompt (str): Prompt to generate images.
proxy (str, optional): Proxy configuration.
timeout (int): Timeout for the request.
Returns:
List[str]: A list of URLs to the created images.
Raises:
RuntimeError: If image creation fails or times out.
"""
url_encoded_prompt = quote(prompt)
payload = f"q={url_encoded_prompt}&rt=4&FORM=GENCRE"
url = f"{BING_URL}/images/create?q={url_encoded_prompt}&rt=4&FORM=GENCRE"
async with session.post(
url,
allow_redirects=False,
data=payload,
timeout=timeout,
) as response:
async with session.post(url, allow_redirects=False, data=payload, timeout=timeout) as response:
response.raise_for_status()
errors = [
"this prompt is being reviewed",
"this prompt has been blocked",
"we're working hard to offer image creator in more languages",
"we can't create your images right now"
]
text = (await response.text()).lower()
for error in errors:
for error in ERRORS:
if error in text:
raise RuntimeError(f"Create images failed: {error}")
if response.status != 302:
@ -107,54 +142,109 @@ async def create_images(session: ClientSession, prompt: str, proxy: str = None,
raise RuntimeError(error)
return read_images(text)
def read_images(text: str) -> list:
html_soup = BeautifulSoup(text, "html.parser")
tags = html_soup.find_all("img")
image_links = [img["src"] for img in tags if "mimg" in img["class"]]
images = [link.split("?w=")[0] for link in image_links]
bad_images = [
"https://r.bing.com/rp/in-2zU3AJUdkgFe7ZKv19yPBHVs.png",
"https://r.bing.com/rp/TX9QuO3WzcCJz1uaaSwQAz39Kb0.jpg",
]
if any(im in bad_images for im in images):
def read_images(html_content: str) -> List[str]:
"""
Extracts image URLs from the HTML content.
Args:
html_content (str): HTML content containing image URLs.
Returns:
List[str]: A list of image URLs.
"""
soup = BeautifulSoup(html_content, "html.parser")
tags = soup.find_all("img", class_="mimg")
images = [img["src"].split("?w=")[0] for img in tags]
if any(im in BAD_IMAGES for im in images):
raise RuntimeError("Bad images found")
if not images:
raise RuntimeError("No images found")
return images
async def create_images_markdown(cookies: dict, prompt: str, proxy: str = None) -> str:
session = create_session(cookies)
try:
async def create_images_markdown(cookies: Dict[str, str], prompt: str, proxy: str = None) -> str:
"""
Creates markdown formatted string with images based on the prompt.
Args:
cookies (Dict[str, str]): Cookies to be used for the session.
prompt (str): Prompt to generate images.
proxy (str, optional): Proxy configuration.
Returns:
str: Markdown formatted string with images.
"""
async with create_session(cookies) as session:
images = await create_images(session, prompt, proxy)
return format_images_markdown(images, prompt)
finally:
await session.close()
def get_cookies_from_browser(proxy: str = None) -> dict:
driver = get_browser(proxy=proxy)
try:
def get_cookies_from_browser(proxy: str = None) -> Dict[str, str]:
"""
Retrieves cookies from the browser using webdriver.
Args:
proxy (str, optional): Proxy configuration.
Returns:
Dict[str, str]: Retrieved cookies.
"""
with get_browser(proxy=proxy) as driver:
wait_for_login(driver)
time.sleep(1)
return get_driver_cookies(driver)
finally:
driver.quit()
def create_completion(prompt: str, cookies: dict = None, proxy: str = None) -> Generator:
class CreateImagesBing:
"""A class for creating images using Bing."""
_cookies: Dict[str, str] = {}
@classmethod
def create_completion(cls, prompt: str, cookies: Dict[str, str] = None, proxy: str = None) -> Generator[str, None, None]:
"""
Generator for creating imagecompletion based on a prompt.
Args:
prompt (str): Prompt to generate images.
cookies (Dict[str, str], optional): Cookies for the session. If None, cookies are retrieved automatically.
proxy (str, optional): Proxy configuration.
Yields:
Generator[str, None, None]: The final output as markdown formatted string with images.
"""
loop = get_event_loop()
if not cookies:
cookies = get_cookies(".bing.com")
cookies = cookies or cls._cookies or get_cookies(".bing.com")
if "_U" not in cookies:
login_url = os.environ.get("G4F_LOGIN_URL")
if login_url:
yield f"Please login: [Bing]({login_url})\n\n"
cookies = get_cookies_from_browser(proxy)
cls._cookies = cookies = get_cookies_from_browser(proxy)
yield loop.run_until_complete(create_images_markdown(cookies, prompt, proxy))
async def create_async(prompt: str, cookies: dict = None, proxy: str = None) -> str:
if not cookies:
cookies = get_cookies(".bing.com")
@classmethod
async def create_async(cls, prompt: str, cookies: Dict[str, str] = None, proxy: str = None) -> str:
"""
Asynchronously creates a markdown formatted string with images based on the prompt.
Args:
prompt (str): Prompt to generate images.
cookies (Dict[str, str], optional): Cookies for the session. If None, cookies are retrieved automatically.
proxy (str, optional): Proxy configuration.
Returns:
str: Markdown formatted string with images.
"""
cookies = cookies or cls._cookies or get_cookies(".bing.com")
if "_U" not in cookies:
cookies = get_cookies_from_browser(proxy)
cls._cookies = cookies = get_cookies_from_browser(proxy)
return await create_images_markdown(cookies, prompt, proxy)
def patch_provider(provider: ProviderType) -> CreateImagesProvider:
return CreateImagesProvider(provider, create_completion, create_async)
"""
Patches a provider to include image creation capabilities.
Args:
provider (ProviderType): The provider to be patched.
Returns:
CreateImagesProvider: The patched provider with image creation capabilities.
"""
return CreateImagesProvider(provider, CreateImagesBing.create_completion, CreateImagesBing.create_async)

View File

@ -1,64 +1,107 @@
from __future__ import annotations
"""
Module to handle image uploading and processing for Bing AI integrations.
"""
from __future__ import annotations
import string
import random
import json
import math
from ...typing import ImageType
from aiohttp import ClientSession
from PIL import Image
from ...typing import ImageType, Tuple
from ...image import to_image, process_image, to_base64, ImageResponse
image_config = {
IMAGE_CONFIG = {
"maxImagePixels": 360000,
"imageCompressionRate": 0.7,
"enableFaceBlurDebug": 0,
"enableFaceBlurDebug": False,
}
async def upload_image(
session: ClientSession,
image: ImageType,
image_data: ImageType,
tone: str,
proxy: str = None
) -> ImageResponse:
image = to_image(image)
width, height = image.size
max_image_pixels = image_config['maxImagePixels']
if max_image_pixels / (width * height) < 1:
new_width = int(width * math.sqrt(max_image_pixels / (width * height)))
new_height = int(height * math.sqrt(max_image_pixels / (width * height)))
else:
new_width = width
new_height = height
new_img = process_image(image, new_width, new_height)
new_img_binary_data = to_base64(new_img, image_config['imageCompressionRate'])
data, boundary = build_image_upload_api_payload(new_img_binary_data, tone)
headers = session.headers.copy()
headers["content-type"] = f'multipart/form-data; boundary={boundary}'
headers["referer"] = 'https://www.bing.com/search?q=Bing+AI&showconv=1&FORM=hpcodx'
headers["origin"] = 'https://www.bing.com'
"""
Uploads an image to Bing's AI service and returns the image response.
Args:
session (ClientSession): The active session.
image_data (bytes): The image data to be uploaded.
tone (str): The tone of the conversation.
proxy (str, optional): Proxy if any. Defaults to None.
Raises:
RuntimeError: If the image upload fails.
Returns:
ImageResponse: The response from the image upload.
"""
image = to_image(image_data)
new_width, new_height = calculate_new_dimensions(image)
processed_img = process_image(image, new_width, new_height)
img_binary_data = to_base64(processed_img, IMAGE_CONFIG['imageCompressionRate'])
data, boundary = build_image_upload_payload(img_binary_data, tone)
headers = prepare_headers(session, boundary)
async with session.post("https://www.bing.com/images/kblob", data=data, headers=headers, proxy=proxy) as response:
if response.status != 200:
raise RuntimeError("Failed to upload image.")
image_info = await response.json()
if not image_info.get('blobId'):
raise RuntimeError("Failed to parse image info.")
result = {'bcid': image_info.get('blobId', "")}
result['blurredBcid'] = image_info.get('processedBlobId', "")
if result['blurredBcid'] != "":
result["imageUrl"] = "https://www.bing.com/images/blob?bcid=" + result['blurredBcid']
elif result['bcid'] != "":
result["imageUrl"] = "https://www.bing.com/images/blob?bcid=" + result['bcid']
result['originalImageUrl'] = (
"https://www.bing.com/images/blob?bcid="
+ result['blurredBcid']
if image_config["enableFaceBlurDebug"]
else "https://www.bing.com/images/blob?bcid="
+ result['bcid']
)
return ImageResponse(result["imageUrl"], "", result)
return parse_image_response(await response.json())
def build_image_upload_api_payload(image_bin: str, tone: str):
payload = {
def calculate_new_dimensions(image: Image.Image) -> Tuple[int, int]:
"""
Calculates the new dimensions for the image based on the maximum allowed pixels.
Args:
image (Image): The PIL Image object.
Returns:
Tuple[int, int]: The new width and height for the image.
"""
width, height = image.size
max_image_pixels = IMAGE_CONFIG['maxImagePixels']
if max_image_pixels / (width * height) < 1:
scale_factor = math.sqrt(max_image_pixels / (width * height))
return int(width * scale_factor), int(height * scale_factor)
return width, height
def build_image_upload_payload(image_bin: str, tone: str) -> Tuple[str, str]:
"""
Builds the payload for image uploading.
Args:
image_bin (str): Base64 encoded image binary data.
tone (str): The tone of the conversation.
Returns:
Tuple[str, str]: The data and boundary for the payload.
"""
boundary = "----WebKitFormBoundary" + ''.join(random.choices(string.ascii_letters + string.digits, k=16))
data = f"--{boundary}\r\n" \
f"Content-Disposition: form-data; name=\"knowledgeRequest\"\r\n\r\n" \
f"{json.dumps(build_knowledge_request(tone), ensure_ascii=False)}\r\n" \
f"--{boundary}\r\n" \
f"Content-Disposition: form-data; name=\"imageBase64\"\r\n\r\n" \
f"{image_bin}\r\n" \
f"--{boundary}--\r\n"
return data, boundary
def build_knowledge_request(tone: str) -> dict:
"""
Builds the knowledge request payload.
Args:
tone (str): The tone of the conversation.
Returns:
dict: The knowledge request payload.
"""
return {
'invokedSkills': ["ImageById"],
'subscriptionId': "Bing.Chat.Multimodal",
'invokedSkillsRequestData': {
@ -69,21 +112,46 @@ def build_image_upload_api_payload(image_bin: str, tone: str):
'convotone': tone
}
}
knowledge_request = {
'imageInfo': {},
'knowledgeRequest': payload
}
boundary="----WebKitFormBoundary" + ''.join(random.choices(string.ascii_letters + string.digits, k=16))
data = (
f'--{boundary}'
+ '\r\nContent-Disposition: form-data; name="knowledgeRequest"\r\n\r\n'
+ json.dumps(knowledge_request, ensure_ascii=False)
+ "\r\n--"
+ boundary
+ '\r\nContent-Disposition: form-data; name="imageBase64"\r\n\r\n'
+ image_bin
+ "\r\n--"
+ boundary
+ "--\r\n"
def prepare_headers(session: ClientSession, boundary: str) -> dict:
"""
Prepares the headers for the image upload request.
Args:
session (ClientSession): The active session.
boundary (str): The boundary string for the multipart/form-data.
Returns:
dict: The headers for the request.
"""
headers = session.headers.copy()
headers["Content-Type"] = f'multipart/form-data; boundary={boundary}'
headers["Referer"] = 'https://www.bing.com/search?q=Bing+AI&showconv=1&FORM=hpcodx'
headers["Origin"] = 'https://www.bing.com'
return headers
def parse_image_response(response: dict) -> ImageResponse:
"""
Parses the response from the image upload.
Args:
response (dict): The response dictionary.
Raises:
RuntimeError: If parsing the image info fails.
Returns:
ImageResponse: The parsed image response.
"""
if not response.get('blobId'):
raise RuntimeError("Failed to parse image info.")
result = {'bcid': response.get('blobId', ""), 'blurredBcid': response.get('processedBlobId', "")}
result["imageUrl"] = f"https://www.bing.com/images/blob?bcid={result['blurredBcid'] or result['bcid']}"
result['originalImageUrl'] = (
f"https://www.bing.com/images/blob?bcid={result['blurredBcid']}"
if IMAGE_CONFIG["enableFaceBlurDebug"] else
f"https://www.bing.com/images/blob?bcid={result['bcid']}"
)
return data, boundary
return ImageResponse(result["imageUrl"], "", result)

View File

@ -8,13 +8,31 @@ from ..base_provider import BaseProvider, ProviderType
system_message = """
You can generate custom images with the DALL-E 3 image generator.
To generate a image with a prompt, do this:
To generate an image with a prompt, do this:
<img data-prompt=\"keywords for the image\">
Don't use images with data uri. It is important to use a prompt instead.
<img data-prompt=\"image caption\">
"""
class CreateImagesProvider(BaseProvider):
"""
Provider class for creating images based on text prompts.
This provider handles image creation requests embedded within message content,
using provided image creation functions.
Attributes:
provider (ProviderType): The underlying provider to handle non-image related tasks.
create_images (callable): A function to create images synchronously.
create_images_async (callable): A function to create images asynchronously.
system_message (str): A message that explains the image creation capability.
include_placeholder (bool): Flag to determine whether to include the image placeholder in the output.
__name__ (str): Name of the provider.
url (str): URL of the provider.
working (bool): Indicates if the provider is operational.
supports_stream (bool): Indicates if the provider supports streaming.
"""
def __init__(
self,
provider: ProviderType,
@ -23,6 +41,16 @@ class CreateImagesProvider(BaseProvider):
system_message: str = system_message,
include_placeholder: bool = True
) -> None:
"""
Initializes the CreateImagesProvider.
Args:
provider (ProviderType): The underlying provider.
create_images (callable): Function to create images synchronously.
create_async (callable): Function to create images asynchronously.
system_message (str, optional): System message to be prefixed to messages. Defaults to a predefined message.
include_placeholder (bool, optional): Whether to include image placeholders in the output. Defaults to True.
"""
self.provider = provider
self.create_images = create_images
self.create_images_async = create_async
@ -40,6 +68,22 @@ class CreateImagesProvider(BaseProvider):
stream: bool = False,
**kwargs
) -> CreateResult:
"""
Creates a completion result, processing any image creation prompts found within the messages.
Args:
model (str): The model to use for creation.
messages (Messages): The messages to process, which may contain image prompts.
stream (bool, optional): Indicates whether to stream the results. Defaults to False.
**kwargs: Additional keywordarguments for the provider.
Yields:
CreateResult: Yields chunks of the processed messages, including image data if applicable.
Note:
This method processes messages to detect image creation prompts. When such a prompt is found,
it calls the synchronous image creation function and includes the resulting image in the output.
"""
messages.insert(0, {"role": "system", "content": self.system_message})
buffer = ""
for chunk in self.provider.create_completion(model, messages, stream, **kwargs):
@ -71,6 +115,21 @@ class CreateImagesProvider(BaseProvider):
messages: Messages,
**kwargs
) -> str:
"""
Asynchronously creates a response, processing any image creation prompts found within the messages.
Args:
model (str): The model to use for creation.
messages (Messages): The messages to process, which may contain image prompts.
**kwargs: Additional keyword arguments for the provider.
Returns:
str: The processed response string, including asynchronously generated image data if applicable.
Note:
This method processes messages to detect image creation prompts. When such a prompt is found,
it calls the asynchronous image creation function and includes the resulting image in the output.
"""
messages.insert(0, {"role": "system", "content": self.system_message})
response = await self.provider.create_async(model, messages, **kwargs)
matches = re.findall(r'(<img data-prompt="(.*?)">)', response)

View File

@ -1,36 +1,31 @@
from __future__ import annotations
import asyncio
import webbrowser
import random
import string
import secrets
import os
from os import path
import random
import secrets
import string
from asyncio import AbstractEventLoop, BaseEventLoop
from platformdirs import user_config_dir
from browser_cookie3 import (
chrome,
chromium,
opera,
opera_gx,
brave,
edge,
vivaldi,
firefox,
_LinuxPasswordManager
chrome, chromium, opera, opera_gx,
brave, edge, vivaldi, firefox,
_LinuxPasswordManager, BrowserCookieError
)
from ..typing import Dict, Messages
from .. import debug
# Local Cookie Storage
# Global variable to store cookies
_cookies: Dict[str, Dict[str, str]] = {}
# If loop closed or not set, create new event loop.
# If event loop is already running, handle nested event loops.
# If "nest_asyncio" is installed, patch the event loop.
def get_event_loop() -> AbstractEventLoop:
"""
Get the current asyncio event loop. If the loop is closed or not set, create a new event loop.
If a loop is running, handle nested event loops. Patch the loop if 'nest_asyncio' is installed.
Returns:
AbstractEventLoop: The current or new event loop.
"""
try:
loop = asyncio.get_event_loop()
if isinstance(loop, BaseEventLoop):
@ -39,61 +34,50 @@ def get_event_loop() -> AbstractEventLoop:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
# Is running event loop
asyncio.get_running_loop()
if not hasattr(loop.__class__, "_nest_patched"):
import nest_asyncio
nest_asyncio.apply(loop)
except RuntimeError:
# No running event loop
pass
except ImportError:
raise RuntimeError(
'Use "create_async" instead of "create" function in a running event loop. Or install the "nest_asyncio" package.'
'Use "create_async" instead of "create" function in a running event loop. Or install "nest_asyncio" package.'
)
return loop
def init_cookies():
urls = [
'https://chat-gpt.org',
'https://www.aitianhu.com',
'https://chatgptfree.ai',
'https://gptchatly.com',
'https://bard.google.com',
'https://huggingface.co/chat',
'https://open-assistant.io/chat'
]
browsers = ['google-chrome', 'chrome', 'firefox', 'safari']
def open_urls_in_browser(browser):
b = webbrowser.get(browser)
for url in urls:
b.open(url, new=0, autoraise=True)
for browser in browsers:
try:
open_urls_in_browser(browser)
break
except webbrowser.Error:
continue
# Check for broken dbus address in docker image
if os.environ.get('DBUS_SESSION_BUS_ADDRESS') == "/dev/null":
_LinuxPasswordManager.get_password = lambda a, b: b"secret"
# Load cookies for a domain from all supported browsers.
# Cache the results in the "_cookies" variable.
def get_cookies(domain_name=''):
def get_cookies(domain_name: str = '') -> Dict[str, str]:
"""
Load cookies for a given domain from all supported browsers and cache the results.
Args:
domain_name (str): The domain for which to load cookies.
Returns:
Dict[str, str]: A dictionary of cookie names and values.
"""
if domain_name in _cookies:
return _cookies[domain_name]
def g4f(domain_name):
user_data_dir = user_config_dir("g4f")
cookie_file = path.join(user_data_dir, "Default", "Cookies")
return [] if not path.exists(cookie_file) else chrome(cookie_file, domain_name)
cookies = _load_cookies_from_browsers(domain_name)
_cookies[domain_name] = cookies
return cookies
def _load_cookies_from_browsers(domain_name: str) -> Dict[str, str]:
"""
Helper function to load cookies from various browsers.
Args:
domain_name (str): The domain for which to load cookies.
Returns:
Dict[str, str]: A dictionary of cookie names and values.
"""
cookies = {}
for cookie_fn in [g4f, chrome, chromium, opera, opera_gx, brave, edge, vivaldi, firefox]:
for cookie_fn in [_g4f, chrome, chromium, opera, opera_gx, brave, edge, vivaldi, firefox]:
try:
cookie_jar = cookie_fn(domain_name=domain_name)
if len(cookie_jar) and debug.logging:
@ -101,13 +85,38 @@ def get_cookies(domain_name=''):
for cookie in cookie_jar:
if cookie.name not in cookies:
cookies[cookie.name] = cookie.value
except:
except BrowserCookieError:
pass
_cookies[domain_name] = cookies
return _cookies[domain_name]
except Exception as e:
if debug.logging:
print(f"Error reading cookies from {cookie_fn.__name__} for {domain_name}: {e}")
return cookies
def _g4f(domain_name: str) -> list:
"""
Load cookies from the 'g4f' browser (if exists).
Args:
domain_name (str): The domain for which to load cookies.
Returns:
list: List of cookies.
"""
user_data_dir = user_config_dir("g4f")
cookie_file = os.path.join(user_data_dir, "Default", "Cookies")
return [] if not os.path.exists(cookie_file) else chrome(cookie_file, domain_name)
def format_prompt(messages: Messages, add_special_tokens=False) -> str:
"""
Format a series of messages into a single string, optionally adding special tokens.
Args:
messages (Messages): A list of message dictionaries, each containing 'role' and 'content'.
add_special_tokens (bool): Whether to add special formatting tokens.
Returns:
str: A formatted string containing all messages.
"""
if not add_special_tokens and len(messages) <= 1:
return messages[0]["content"]
formatted = "\n".join([
@ -116,12 +125,26 @@ def format_prompt(messages: Messages, add_special_tokens=False) -> str:
])
return f"{formatted}\nAssistant:"
def get_random_string(length: int = 10) -> str:
"""
Generate a random string of specified length, containing lowercase letters and digits.
Args:
length (int, optional): Length of the random string to generate. Defaults to 10.
Returns:
str: A random string of the specified length.
"""
return ''.join(
random.choice(string.ascii_lowercase + string.digits)
for _ in range(length)
)
def get_random_hex() -> str:
"""
Generate a random hexadecimal string of a fixed length.
Returns:
str: A random hexadecimal string of 32 characters (16 bytes).
"""
return secrets.token_hex(16).zfill(32)

View File

@ -1,6 +1,9 @@
from __future__ import annotations
import asyncio
import uuid
import json
import os
import uuid, json, asyncio, os
from py_arkose_generator.arkose import get_values_for_request
from async_property import async_cached_property
from selenium.webdriver.common.by import By
@ -14,7 +17,8 @@ from ...typing import AsyncResult, Messages
from ...requests import StreamSession
from ...image import to_image, to_bytes, ImageType, ImageResponse
models = {
# Aliases for model names
MODELS = {
"gpt-3.5": "text-davinci-002-render-sha",
"gpt-3.5-turbo": "text-davinci-002-render-sha",
"gpt-4": "gpt-4",
@ -22,6 +26,8 @@ models = {
}
class OpenaiChat(AsyncGeneratorProvider):
"""A class for creating and managing conversations with OpenAI chat service"""
url = "https://chat.openai.com"
working = True
needs_auth = True
@ -43,6 +49,23 @@ class OpenaiChat(AsyncGeneratorProvider):
image: ImageType = None,
**kwargs
) -> Response:
"""Create a new conversation or continue an existing one
Args:
prompt: The user input to start or continue the conversation
model: The name of the model to use for generating responses
messages: The list of previous messages in the conversation
history_disabled: A flag indicating if the history and training should be disabled
action: The type of action to perform, either "next", "continue", or "variant"
conversation_id: The ID of the existing conversation, if any
parent_id: The ID of the parent message, if any
image: The image to include in the user input, if any
**kwargs: Additional keyword arguments to pass to the generator
Returns:
A Response object that contains the generator, action, messages, and options
"""
# Add the user input to the messages list
if prompt:
messages.append({
"role": "user",
@ -67,20 +90,33 @@ class OpenaiChat(AsyncGeneratorProvider):
)
@classmethod
async def upload_image(
async def _upload_image(
cls,
session: StreamSession,
headers: dict,
image: ImageType
) -> ImageResponse:
"""Upload an image to the service and get the download URL
Args:
session: The StreamSession object to use for requests
headers: The headers to include in the requests
image: The image to upload, either a PIL Image object or a bytes object
Returns:
An ImageResponse object that contains the download URL, file name, and other data
"""
# Convert the image to a PIL Image object and get the extension
image = to_image(image)
extension = image.format.lower()
# Convert the image to a bytes object and get the size
data_bytes = to_bytes(image)
data = {
"file_name": f"{image.width}x{image.height}.{extension}",
"file_size": len(data_bytes),
"use_case": "multimodal"
}
# Post the image data to the service and get the image data
async with session.post(f"{cls.url}/backend-api/files", json=data, headers=headers) as response:
response.raise_for_status()
image_data = {
@ -91,6 +127,7 @@ class OpenaiChat(AsyncGeneratorProvider):
"height": image.height,
"width": image.width
}
# Put the image bytes to the upload URL and check the status
async with session.put(
image_data["upload_url"],
data=data_bytes,
@ -100,6 +137,7 @@ class OpenaiChat(AsyncGeneratorProvider):
}
) as response:
response.raise_for_status()
# Post the file ID to the service and get the download URL
async with session.post(
f"{cls.url}/backend-api/files/{image_data['file_id']}/uploaded",
json={},
@ -110,24 +148,45 @@ class OpenaiChat(AsyncGeneratorProvider):
return ImageResponse(download_url, image_data["file_name"], image_data)
@classmethod
async def get_default_model(cls, session: StreamSession, headers: dict):
async def _get_default_model(cls, session: StreamSession, headers: dict):
"""Get the default model name from the service
Args:
session: The StreamSession object to use for requests
headers: The headers to include in the requests
Returns:
The default model name as a string
"""
# Check the cache for the default model
if cls._default_model:
model = cls._default_model
else:
return cls._default_model
# Get the models data from the service
async with session.get(f"{cls.url}/backend-api/models", headers=headers) as response:
data = await response.json()
if "categories" in data:
model = data["categories"][-1]["default_model"]
cls._default_model = data["categories"][-1]["default_model"]
else:
RuntimeError(f"Response: {data}")
cls._default_model = model
return model
raise RuntimeError(f"Response: {data}")
return cls._default_model
@classmethod
def create_messages(cls, prompt: str, image_response: ImageResponse = None):
def _create_messages(cls, prompt: str, image_response: ImageResponse = None):
"""Create a list of messages for the user input
Args:
prompt: The user input as a string
image_response: The image response object, if any
Returns:
A list of messages with the user input and the image, if any
"""
# Check if there is an image response
if not image_response:
# Create a content object with the text type and the prompt
content = {"content_type": "text", "parts": [prompt]}
else:
# Create a content object with the multimodal text type and the image and the prompt
content = {
"content_type": "multimodal_text",
"parts": [{
@ -137,12 +196,15 @@ class OpenaiChat(AsyncGeneratorProvider):
"width": image_response.get("width"),
}, prompt]
}
# Create a message object with the user role and the content
messages = [{
"id": str(uuid.uuid4()),
"author": {"role": "user"},
"content": content,
}]
# Check if there is an image response
if image_response:
# Add the metadata object with the attachments
messages[0]["metadata"] = {
"attachments": [{
"height": image_response.get("height"),
@ -156,19 +218,38 @@ class OpenaiChat(AsyncGeneratorProvider):
return messages
@classmethod
async def get_image_response(cls, session: StreamSession, headers: dict, line: dict):
if "parts" in line["message"]["content"]:
part = line["message"]["content"]["parts"][0]
if "asset_pointer" in part and part["metadata"]:
file_id = part["asset_pointer"].split("file-service://", 1)[1]
prompt = part["metadata"]["dalle"]["prompt"]
async with session.get(
f"{cls.url}/backend-api/files/{file_id}/download",
headers=headers
) as response:
async def _get_generated_image(cls, session: StreamSession, headers: dict, line: dict) -> ImageResponse:
"""
Retrieves the image response based on the message content.
:param session: The StreamSession object.
:param headers: HTTP headers for the request.
:param line: The line of response containing image information.
:return: An ImageResponse object with the image details.
"""
if "parts" not in line["message"]["content"]:
return
first_part = line["message"]["content"]["parts"][0]
if "asset_pointer" not in first_part or "metadata" not in first_part:
return
file_id = first_part["asset_pointer"].split("file-service://", 1)[1]
prompt = first_part["metadata"]["dalle"]["prompt"]
try:
async with session.get(f"{cls.url}/backend-api/files/{file_id}/download", headers=headers) as response:
response.raise_for_status()
download_url = (await response.json())["download_url"]
return ImageResponse(download_url, prompt)
except Exception as e:
raise RuntimeError(f"Error in downloading image: {e}")
@classmethod
async def _delete_conversation(cls, session: StreamSession, headers: dict, conversation_id: str):
async with session.patch(
f"{cls.url}/backend-api/conversation/{conversation_id}",
json={"is_visible": False},
headers=headers
) as response:
response.raise_for_status()
@classmethod
async def create_async_generator(
@ -188,26 +269,47 @@ class OpenaiChat(AsyncGeneratorProvider):
response_fields: bool = False,
**kwargs
) -> AsyncResult:
if model in models:
model = models[model]
"""
Create an asynchronous generator for the conversation.
Args:
model (str): The model name.
messages (Messages): The list of previous messages.
proxy (str): Proxy to use for requests.
timeout (int): Timeout for requests.
access_token (str): Access token for authentication.
cookies (dict): Cookies to use for authentication.
auto_continue (bool): Flag to automatically continue the conversation.
history_disabled (bool): Flag to disable history and training.
action (str): Type of action ('next', 'continue', 'variant').
conversation_id (str): ID of the conversation.
parent_id (str): ID of the parent message.
image (ImageType): Image to include in the conversation.
response_fields (bool): Flag to include response fields in the output.
**kwargs: Additional keyword arguments.
Yields:
AsyncResult: Asynchronous results from the generator.
Raises:
RuntimeError: If an error occurs during processing.
"""
model = MODELS.get(model, model)
if not parent_id:
parent_id = str(uuid.uuid4())
if not cookies:
cookies = cls._cookies
if not access_token:
if not cookies:
cls._cookies = cookies = get_cookies("chat.openai.com")
if "access_token" in cookies:
cookies = cls._cookies or get_cookies("chat.openai.com")
if not access_token and "access_token" in cookies:
access_token = cookies["access_token"]
if not access_token:
login_url = os.environ.get("G4F_LOGIN_URL")
if login_url:
yield f"Please login: [ChatGPT]({login_url})\n\n"
access_token, cookies = cls.browse_access_token(proxy)
access_token, cookies = cls._browse_access_token(proxy)
cls._cookies = cookies
headers = {
"Authorization": f"Bearer {access_token}",
}
headers = {"Authorization": f"Bearer {access_token}"}
async with StreamSession(
proxies={"https": proxy},
impersonate="chrome110",
@ -215,11 +317,11 @@ class OpenaiChat(AsyncGeneratorProvider):
cookies=dict([(name, value) for name, value in cookies.items() if name == "_puid"])
) as session:
if not model:
model = await cls.get_default_model(session, headers)
model = await cls._get_default_model(session, headers)
try:
image_response = None
if image:
image_response = await cls.upload_image(session, headers, image)
image_response = await cls._upload_image(session, headers, image)
yield image_response
except Exception as e:
yield e
@ -227,7 +329,7 @@ class OpenaiChat(AsyncGeneratorProvider):
while not end_turn.is_end:
data = {
"action": action,
"arkose_token": await cls.get_arkose_token(session),
"arkose_token": await cls._get_arkose_token(session),
"conversation_id": conversation_id,
"parent_message_id": parent_id,
"model": model,
@ -235,7 +337,7 @@ class OpenaiChat(AsyncGeneratorProvider):
}
if action != "continue":
prompt = format_prompt(messages) if not conversation_id else messages[-1]["content"]
data["messages"] = cls.create_messages(prompt, image_response)
data["messages"] = cls._create_messages(prompt, image_response)
async with session.post(
f"{cls.url}/backend-api/conversation",
json=data,
@ -261,14 +363,17 @@ class OpenaiChat(AsyncGeneratorProvider):
if "message_type" not in line["message"]["metadata"]:
continue
try:
image_response = await cls.get_image_response(session, headers, line)
image_response = await cls._get_generated_image(session, headers, line)
if image_response:
yield image_response
except Exception as e:
yield e
if line["message"]["author"]["role"] != "assistant":
continue
if line["message"]["metadata"]["message_type"] in ("next", "continue", "variant"):
if line["message"]["content"]["content_type"] != "text":
continue
if line["message"]["metadata"]["message_type"] not in ("next", "continue", "variant"):
continue
conversation_id = line["conversation_id"]
parent_id = line["message"]["id"]
if response_fields:
@ -282,41 +387,56 @@ class OpenaiChat(AsyncGeneratorProvider):
if "finish_details" in line["message"]["metadata"]:
if line["message"]["metadata"]["finish_details"]["type"] == "stop":
end_turn.end()
break
except Exception as e:
yield e
raise e
if not auto_continue:
break
action = "continue"
await asyncio.sleep(5)
if history_disabled:
async with session.patch(
f"{cls.url}/backend-api/conversation/{conversation_id}",
json={"is_visible": False},
headers=headers
) as response:
response.raise_for_status()
if history_disabled and auto_continue:
await cls._delete_conversation(session, headers, conversation_id)
@classmethod
def browse_access_token(cls, proxy: str = None) -> tuple[str, dict]:
def _browse_access_token(cls, proxy: str = None) -> tuple[str, dict]:
"""
Browse to obtain an access token.
Args:
proxy (str): Proxy to use for browsing.
Returns:
tuple[str, dict]: A tuple containing the access token and cookies.
"""
driver = get_browser(proxy=proxy)
try:
driver.get(f"{cls.url}/")
WebDriverWait(driver, 1200).until(
EC.presence_of_element_located((By.ID, "prompt-textarea"))
WebDriverWait(driver, 1200).until(EC.presence_of_element_located((By.ID, "prompt-textarea")))
access_token = driver.execute_script(
"let session = await fetch('/api/auth/session');"
"let data = await session.json();"
"let accessToken = data['accessToken'];"
"let expires = new Date(); expires.setTime(expires.getTime() + 60 * 60 * 24 * 7);"
"document.cookie = 'access_token=' + accessToken + ';expires=' + expires.toUTCString() + ';path=/';"
"return accessToken;"
)
javascript = """
access_token = (await (await fetch('/api/auth/session')).json())['accessToken'];
expires = new Date(); expires.setTime(expires.getTime() + 60 * 60 * 24 * 7); // One week
document.cookie = 'access_token=' + access_token + ';expires=' + expires.toUTCString() + ';path=/';
return access_token;
"""
return driver.execute_script(javascript), get_driver_cookies(driver)
return access_token, get_driver_cookies(driver)
finally:
driver.quit()
@classmethod
async def get_arkose_token(cls, session: StreamSession) -> str:
async def _get_arkose_token(cls, session: StreamSession) -> str:
"""
Obtain an Arkose token for the session.
Args:
session (StreamSession): The session object.
Returns:
str: The Arkose token.
Raises:
RuntimeError: If unable to retrieve the token.
"""
config = {
"pkey": "3D86FBBA-9D22-402A-B512-3420086BA6CC",
"surl": "https://tcr9i.chat.openai.com",
@ -333,25 +453,29 @@ return access_token;
return decoded_json["token"]
raise RuntimeError(f"Response: {decoded_json}")
class EndTurn():
class EndTurn:
"""
Class to represent the end of a conversation turn.
"""
def __init__(self):
self.is_end = False
def end(self):
self.is_end = True
class ResponseFields():
def __init__(
self,
conversation_id: str,
message_id: str,
end_turn: EndTurn
):
class ResponseFields:
"""
Class to encapsulate response fields.
"""
def __init__(self, conversation_id: str, message_id: str, end_turn: EndTurn):
self.conversation_id = conversation_id
self.message_id = message_id
self._end_turn = end_turn
class Response():
"""
Class to encapsulate a response from the chat service.
"""
def __init__(
self,
generator: AsyncResult,
@ -360,8 +484,8 @@ class Response():
options: dict
):
self._generator = generator
self.action: str = action
self.is_end: bool = False
self.action = action
self.is_end = False
self._message = None
self._messages = messages
self._options = options
@ -387,15 +511,12 @@ class Response():
@async_cached_property
async def message(self) -> str:
[_ async for _ in self.generator()]
await self.generator()
return self._message
async def get_fields(self):
[_ async for _ in self.generator()]
return {
"conversation_id": self._fields.conversation_id,
"parent_id": self._fields.message_id,
}
await self.generator()
return {"conversation_id": self._fields.conversation_id, "parent_id": self._fields.message_id}
async def next(self, prompt: str, **kwargs) -> Response:
return await OpenaiChat.create(
@ -433,7 +554,5 @@ class Response():
@async_cached_property
async def messages(self):
messages = self._messages
messages.append({
"role": "assistant", "content": await self.message
})
messages.append({"role": "assistant", "content": await self.message})
return messages

View File

@ -7,8 +7,17 @@ from ..base_provider import BaseRetryProvider
from .. import debug
from ..errors import RetryProviderError, RetryNoProviderError
class RetryProvider(BaseRetryProvider):
"""
A provider class to handle retries for creating completions with different providers.
Attributes:
providers (list): A list of provider instances.
shuffle (bool): A flag indicating whether to shuffle providers before use.
exceptions (dict): A dictionary to store exceptions encountered during retries.
last_provider (BaseProvider): The last provider that was used.
"""
def create_completion(
self,
model: str,
@ -16,10 +25,21 @@ class RetryProvider(BaseRetryProvider):
stream: bool = False,
**kwargs
) -> CreateResult:
if stream:
providers = [provider for provider in self.providers if provider.supports_stream]
else:
providers = self.providers
"""
Create a completion using available providers, with an option to stream the response.
Args:
model (str): The model to be used for completion.
messages (Messages): The messages to be used for generating completion.
stream (bool, optional): Flag to indicate if the response should be streamed. Defaults to False.
Yields:
CreateResult: Tokens or results from the completion.
Raises:
Exception: Any exception encountered during the completion process.
"""
providers = [p for p in self.providers if stream and p.supports_stream] if stream else self.providers
if self.shuffle:
random.shuffle(providers)
@ -50,6 +70,19 @@ class RetryProvider(BaseRetryProvider):
messages: Messages,
**kwargs
) -> str:
"""
Asynchronously create a completion using available providers.
Args:
model (str): The model to be used for completion.
messages (Messages): The messages to be used for generating completion.
Returns:
str: The result of the asynchronous completion.
Raises:
Exception: Any exception encountered during the asynchronous completion process.
"""
providers = self.providers
if self.shuffle:
random.shuffle(providers)
@ -70,6 +103,13 @@ class RetryProvider(BaseRetryProvider):
self.raise_exceptions()
def raise_exceptions(self) -> None:
"""
Raise a combined exception if any occurred during retries.
Raises:
RetryProviderError: If any provider encountered an exception.
RetryNoProviderError: If no provider is found.
"""
if self.exceptions:
raise RetryProviderError("RetryProvider failed:\n" + "\n".join([
f"{p}: {exception.__class__.__name__}: {exception}" for p, exception in self.exceptions.items()

View File

@ -15,6 +15,26 @@ def get_model_and_provider(model : Union[Model, str],
ignored : list[str] = None,
ignore_working: bool = False,
ignore_stream: bool = False) -> tuple[str, ProviderType]:
"""
Retrieves the model and provider based on input parameters.
Args:
model (Union[Model, str]): The model to use, either as an object or a string identifier.
provider (Union[ProviderType, str, None]): The provider to use, either as an object, a string identifier, or None.
stream (bool): Indicates if the operation should be performed as a stream.
ignored (list[str], optional): List of provider names to be ignored.
ignore_working (bool, optional): If True, ignores the working status of the provider.
ignore_stream (bool, optional): If True, ignores the streaming capability of the provider.
Returns:
tuple[str, ProviderType]: A tuple containing the model name and the provider type.
Raises:
ProviderNotFoundError: If the provider is not found.
ModelNotFoundError: If the model is not found.
ProviderNotWorkingError: If the provider is not working.
StreamNotSupportedError: If streaming is not supported by the provider.
"""
if debug.version_check:
debug.version_check = False
version.utils.check_version()
@ -70,7 +90,30 @@ class ChatCompletion:
ignore_stream_and_auth: bool = False,
patch_provider: callable = None,
**kwargs) -> Union[CreateResult, str]:
"""
Creates a chat completion using the specified model, provider, and messages.
Args:
model (Union[Model, str]): The model to use, either as an object or a string identifier.
messages (Messages): The messages for which the completion is to be created.
provider (Union[ProviderType, str, None], optional): The provider to use, either as an object, a string identifier, or None.
stream (bool, optional): Indicates if the operation should be performed as a stream.
auth (Union[str, None], optional): Authentication token or credentials, if required.
ignored (list[str], optional): List of provider names to be ignored.
ignore_working (bool, optional): If True, ignores the working status of the provider.
ignore_stream_and_auth (bool, optional): If True, ignores the stream and authentication requirement checks.
patch_provider (callable, optional): Function to modify the provider.
**kwargs: Additional keyword arguments.
Returns:
Union[CreateResult, str]: The result of the chat completion operation.
Raises:
AuthenticationRequiredError: If authentication is required but not provided.
ProviderNotFoundError, ModelNotFoundError: If the specified provider or model is not found.
ProviderNotWorkingError: If the provider is not operational.
StreamNotSupportedError: If streaming is requested but not supported by the provider.
"""
model, provider = get_model_and_provider(model, provider, stream, ignored, ignore_working, ignore_stream_and_auth)
if not ignore_stream_and_auth and provider.needs_auth and not auth:
@ -98,7 +141,24 @@ class ChatCompletion:
ignored : list[str] = None,
patch_provider: callable = None,
**kwargs) -> Union[AsyncResult, str]:
"""
Asynchronously creates a completion using the specified model and provider.
Args:
model (Union[Model, str]): The model to use, either as an object or a string identifier.
messages (Messages): Messages to be processed.
provider (Union[ProviderType, str, None]): The provider to use, either as an object, a string identifier, or None.
stream (bool): Indicates if the operation should be performed as a stream.
ignored (list[str], optional): List of provider names to be ignored.
patch_provider (callable, optional): Function to modify the provider.
**kwargs: Additional keyword arguments.
Returns:
Union[AsyncResult, str]: The result of the asynchronous chat completion operation.
Raises:
StreamNotSupportedError: If streaming is requested but not supported by the provider.
"""
model, provider = get_model_and_provider(model, provider, False, ignored)
if stream:
@ -118,7 +178,23 @@ class Completion:
provider : Union[ProviderType, None] = None,
stream : bool = False,
ignored : list[str] = None, **kwargs) -> Union[CreateResult, str]:
"""
Creates a completion based on the provided model, prompt, and provider.
Args:
model (Union[Model, str]): The model to use, either as an object or a string identifier.
prompt (str): The prompt text for which the completion is to be created.
provider (Union[ProviderType, None], optional): The provider to use, either as an object or None.
stream (bool, optional): Indicates if the operation should be performed as a stream.
ignored (list[str], optional): List of provider names to be ignored.
**kwargs: Additional keyword arguments.
Returns:
Union[CreateResult, str]: The result of the completion operation.
Raises:
ModelNotAllowedError: If the specified model is not allowed for use with this method.
"""
allowed_models = [
'code-davinci-002',
'text-ada-001',
@ -137,6 +213,15 @@ class Completion:
return result if stream else ''.join(result)
def get_last_provider(as_dict: bool = False) -> Union[ProviderType, dict[str, str]]:
"""
Retrieves the last used provider.
Args:
as_dict (bool, optional): If True, returns the provider information as a dictionary.
Returns:
Union[ProviderType, dict[str, str]]: The last used provider, either as an object or a dictionary.
"""
last = debug.last_provider
if isinstance(last, BaseRetryProvider):
last = last.last_provider

View File

@ -1,7 +1,22 @@
from abc import ABC, abstractmethod
from .typing import Messages, CreateResult, Union
from typing import Union, List, Dict, Type
from .typing import Messages, CreateResult
class BaseProvider(ABC):
"""
Abstract base class for a provider.
Attributes:
url (str): URL of the provider.
working (bool): Indicates if the provider is currently working.
needs_auth (bool): Indicates if the provider needs authentication.
supports_stream (bool): Indicates if the provider supports streaming.
supports_gpt_35_turbo (bool): Indicates if the provider supports GPT-3.5 Turbo.
supports_gpt_4 (bool): Indicates if the provider supports GPT-4.
supports_message_history (bool): Indicates if the provider supports message history.
params (str): List parameters for the provider.
"""
url: str = None
working: bool = False
needs_auth: bool = False
@ -20,6 +35,18 @@ class BaseProvider(ABC):
stream: bool,
**kwargs
) -> CreateResult:
"""
Create a completion with the given parameters.
Args:
model (str): The model to use.
messages (Messages): The messages to process.
stream (bool): Whether to use streaming.
**kwargs: Additional keyword arguments.
Returns:
CreateResult: The result of the creation process.
"""
raise NotImplementedError()
@classmethod
@ -30,25 +57,59 @@ class BaseProvider(ABC):
messages: Messages,
**kwargs
) -> str:
"""
Asynchronously create a completion with the given parameters.
Args:
model (str): The model to use.
messages (Messages): The messages to process.
**kwargs: Additional keyword arguments.
Returns:
str: The result of the creation process.
"""
raise NotImplementedError()
@classmethod
def get_dict(cls):
def get_dict(cls) -> Dict[str, str]:
"""
Get a dictionary representation of the provider.
Returns:
Dict[str, str]: A dictionary with provider's details.
"""
return {'name': cls.__name__, 'url': cls.url}
class BaseRetryProvider(BaseProvider):
"""
Base class for a provider that implements retry logic.
Attributes:
providers (List[Type[BaseProvider]]): List of providers to use for retries.
shuffle (bool): Whether to shuffle the providers list.
exceptions (Dict[str, Exception]): Dictionary of exceptions encountered.
last_provider (Type[BaseProvider]): The last provider used.
"""
__name__: str = "RetryProvider"
supports_stream: bool = True
def __init__(
self,
providers: list[type[BaseProvider]],
providers: List[Type[BaseProvider]],
shuffle: bool = True
) -> None:
self.providers: list[type[BaseProvider]] = providers
self.shuffle: bool = shuffle
self.working: bool = True
self.exceptions: dict[str, Exception] = {}
self.last_provider: type[BaseProvider] = None
"""
Initialize the BaseRetryProvider.
ProviderType = Union[type[BaseProvider], BaseRetryProvider]
Args:
providers (List[Type[BaseProvider]]): List of providers to use.
shuffle (bool): Whether to shuffle the providers list.
"""
self.providers = providers
self.shuffle = shuffle
self.working = True
self.exceptions: Dict[str, Exception] = {}
self.last_provider: Type[BaseProvider] = None
ProviderType = Union[Type[BaseProvider], BaseRetryProvider]

View File

@ -404,7 +404,7 @@ body {
display: none;
}
#image {
#image, #file {
display: none;
}
@ -412,13 +412,22 @@ label[for="image"]:has(> input:valid){
color: var(--accent);
}
label[for="image"] {
label[for="file"]:has(> input:valid){
color: var(--accent);
}
label[for="image"], label[for="file"] {
cursor: pointer;
position: absolute;
top: 10px;
left: 10px;
}
label[for="file"] {
top: 32px;
left: 10px;
}
.buttons input[type="checkbox"] {
height: 0;
width: 0;

View File

@ -118,6 +118,10 @@
<input type="file" id="image" name="image" accept="image/png, image/gif, image/jpeg" required/>
<i class="fa-regular fa-image"></i>
</label>
<label for="file">
<input type="file" id="file" name="file" accept="text/plain, text/html, text/xml, application/json, text/javascript, .sh, .py, .php, .css, .yaml, .sql, .svg, .log, .csv, .twig, .md" required/>
<i class="fa-solid fa-paperclip"></i>
</label>
<div id="send-button">
<i class="fa-solid fa-paper-plane-top"></i>
</div>
@ -125,7 +129,14 @@
</div>
<div class="buttons">
<div class="field">
<select name="model" id="model"></select>
<select name="model" id="model">
<option value="">Model: Default</option>
<option value="gpt-4">gpt-4</option>
<option value="gpt-3.5-turbo">gpt-3.5-turbo</option>
<option value="llama2-70b">llama2-70b</option>
<option value="gemini-pro">gemini-pro</option>
<option value="">----</option>
</select>
</div>
<div class="field">
<select name="jailbreak" id="jailbreak" style="display: none;">
@ -138,7 +149,16 @@
<option value="gpt-evil-1.0">evil 1.0</option>
</select>
<div class="field">
<select name="provider" id="provider"></select>
<select name="provider" id="provider">
<option value="">Provider: Auto</option>
<option value="Bing">Bing</option>
<option value="OpenaiChat">OpenaiChat</option>
<option value="HuggingChat">HuggingChat</option>
<option value="Bard">Bard</option>
<option value="Liaobots">Liaobots</option>
<option value="Phind">Phind</option>
<option value="">----</option>
</select>
</div>
</div>
<div class="field">

View File

@ -7,7 +7,9 @@ const spinner = box_conversations.querySelector(".spinner");
const stop_generating = document.querySelector(`.stop_generating`);
const regenerate = document.querySelector(`.regenerate`);
const send_button = document.querySelector(`#send-button`);
const imageInput = document.querySelector('#image') ;
const imageInput = document.querySelector('#image');
const fileInput = document.querySelector('#file');
let prompt_lock = false;
hljs.addPlugin(new CopyButtonPlugin());
@ -42,6 +44,11 @@ const handle_ask = async () => {
if (message.length > 0) {
message_input.value = '';
await add_conversation(window.conversation_id, message);
if ("text" in fileInput.dataset) {
message += '\n```' + fileInput.dataset.type + '\n';
message += fileInput.dataset.text;
message += '\n```'
}
await add_message(window.conversation_id, "user", message);
window.token = message_id();
message_box.innerHTML += `
@ -55,6 +62,9 @@ const handle_ask = async () => {
</div>
</div>
`;
document.querySelectorAll('code:not(.hljs').forEach((el) => {
hljs.highlightElement(el);
});
await ask_gpt();
}
};
@ -171,17 +181,30 @@ const ask_gpt = async () => {
content_inner.innerHTML += "<p>An error occured, please try again, if the problem persists, please use a other model or provider.</p>";
} else {
html = markdown_render(text);
html = html.substring(0, html.lastIndexOf('</p>')) + '<span id="cursor"></span></p>';
let lastElement, lastIndex = null;
for (element of ['</p>', '</code></pre>', '</li>\n</ol>']) {
const index = html.lastIndexOf(element)
if (index > lastIndex) {
lastElement = element;
lastIndex = index;
}
}
if (lastIndex) {
html = html.substring(0, lastIndex) + '<span id="cursor"></span>' + lastElement;
}
content_inner.innerHTML = html;
document.querySelectorAll('code').forEach((el) => {
document.querySelectorAll('code:not(.hljs').forEach((el) => {
hljs.highlightElement(el);
});
}
window.scrollTo(0, 0);
if (message_box.scrollTop >= message_box.scrollHeight - message_box.clientHeight - 100) {
message_box.scrollTo({ top: message_box.scrollHeight, behavior: "auto" });
}
}
if (!error && imageInput) imageInput.value = "";
if (!error && fileInput) fileInput.value = "";
} catch (e) {
console.error(e);
@ -305,7 +328,7 @@ const load_conversation = async (conversation_id) => {
`;
}
document.querySelectorAll(`code`).forEach((el) => {
document.querySelectorAll('code:not(.hljs').forEach((el) => {
hljs.highlightElement(el);
});
@ -400,7 +423,7 @@ const load_conversations = async (limit, offset, loader) => {
`;
}
document.querySelectorAll(`code`).forEach((el) => {
document.querySelectorAll('code:not(.hljs').forEach((el) => {
hljs.highlightElement(el);
});
};
@ -602,14 +625,7 @@ observer.observe(message_input, { attributes: true });
(async () => {
response = await fetch('/backend-api/v2/models')
models = await response.json()
let select = document.getElementById('model');
select.textContent = '';
let auto = document.createElement('option');
auto.value = '';
auto.text = 'Model: Default';
select.appendChild(auto);
for (model of models) {
let option = document.createElement('option');
@ -619,14 +635,7 @@ observer.observe(message_input, { attributes: true });
response = await fetch('/backend-api/v2/providers')
providers = await response.json()
select = document.getElementById('provider');
select.textContent = '';
auto = document.createElement('option');
auto.value = '';
auto.text = 'Provider: Auto';
select.appendChild(auto);
for (provider of providers) {
let option = document.createElement('option');
@ -643,11 +652,34 @@ observer.observe(message_input, { attributes: true });
document.title = 'g4f - gui - ' + versions["version"];
text = "version ~ "
if (versions["version"] != versions["lastet_version"]) {
release_url = 'https://github.com/xtekky/gpt4free/releases/tag/' + versions["lastet_version"];
text += '<a href="' + release_url +'" target="_blank" title="New version: ' + versions["lastet_version"] +'">' + versions["version"] + ' 🆕</a>';
if (versions["version"] != versions["latest_version"]) {
release_url = 'https://github.com/xtekky/gpt4free/releases/tag/' + versions["latest_version"];
text += '<a href="' + release_url +'" target="_blank" title="New version: ' + versions["latest_version"] +'">' + versions["version"] + ' 🆕</a>';
} else {
text += versions["version"];
}
document.getElementById("version_text").innerHTML = text
})()
fileInput.addEventListener('change', async (event) => {
if (fileInput.files.length) {
type = fileInput.files[0].type;
if (type && type.indexOf('/')) {
type = type.split('/').pop().replace('x-', '')
type = type.replace('plain', 'plaintext')
.replace('shellscript', 'sh')
.replace('svg+xml', 'svg')
.replace('vnd.trolltech.linguist', 'ts')
} else {
type = fileInput.files[0].name.split('.').pop()
}
fileInput.dataset.type = type
const reader = new FileReader();
reader.addEventListener('load', (event) => {
fileInput.dataset.text = event.target.result;
});
reader.readAsText(fileInput.files[0]);
} else {
delete fileInput.dataset.text;
}
});

View File

@ -1,6 +1,7 @@
import logging
import json
from flask import request, Flask
from typing import Generator
from g4f import debug, version, models
from g4f import _all_models, get_last_provider, ChatCompletion
from g4f.image import is_allowed_extension, to_image
@ -11,60 +12,123 @@ from .internet import get_search_message
debug.logging = True
class Backend_Api:
"""
Handles various endpoints in a Flask application for backend operations.
This class provides methods to interact with models, providers, and to handle
various functionalities like conversations, error handling, and version management.
Attributes:
app (Flask): A Flask application instance.
routes (dict): A dictionary mapping API endpoints to their respective handlers.
"""
def __init__(self, app: Flask) -> None:
"""
Initialize the backend API with the given Flask application.
Args:
app (Flask): Flask application instance to attach routes to.
"""
self.app: Flask = app
self.routes = {
'/backend-api/v2/models': {
'function': self.models,
'methods' : ['GET']
'function': self.get_models,
'methods': ['GET']
},
'/backend-api/v2/providers': {
'function': self.providers,
'methods' : ['GET']
'function': self.get_providers,
'methods': ['GET']
},
'/backend-api/v2/version': {
'function': self.version,
'methods' : ['GET']
'function': self.get_version,
'methods': ['GET']
},
'/backend-api/v2/conversation': {
'function': self._conversation,
'function': self.handle_conversation,
'methods': ['POST']
},
'/backend-api/v2/gen.set.summarize:title': {
'function': self._gen_title,
'function': self.generate_title,
'methods': ['POST']
},
'/backend-api/v2/error': {
'function': self.error,
'function': self.handle_error,
'methods': ['POST']
}
}
def error(self):
print(request.json)
def handle_error(self):
"""
Initialize the backend API with the given Flask application.
Args:
app (Flask): Flask application instance to attach routes to.
"""
print(request.json)
return 'ok', 200
def models(self):
def get_models(self):
"""
Return a list of all models.
Fetches and returns a list of all available models in the system.
Returns:
List[str]: A list of model names.
"""
return _all_models
def providers(self):
return [
provider.__name__ for provider in __providers__ if provider.working
]
def get_providers(self):
"""
Return a list of all working providers.
"""
return [provider.__name__ for provider in __providers__ if provider.working]
def version(self):
def get_version(self):
"""
Returns the current and latest version of the application.
Returns:
dict: A dictionary containing the current and latest version.
"""
return {
"version": version.utils.current_version,
"lastet_version": version.get_latest_version(),
"latest_version": version.get_latest_version(),
}
def _gen_title(self):
return {
'title': ''
}
def generate_title(self):
"""
Generates and returns a title based on the request data.
def _conversation(self):
Returns:
dict: A dictionary with the generated title.
"""
return {'title': ''}
def handle_conversation(self):
"""
Handles conversation requests and streams responses back.
Returns:
Response: A Flask response object for streaming.
"""
kwargs = self._prepare_conversation_kwargs()
return self.app.response_class(
self._create_response_stream(kwargs),
mimetype='text/event-stream'
)
def _prepare_conversation_kwargs(self):
"""
Prepares arguments for chat completion based on the request data.
Reads the request and prepares the necessary arguments for handling
a chat completion request.
Returns:
dict: Arguments prepared for chat completion.
"""
kwargs = {}
if 'image' in request.files:
file = request.files['image']
@ -87,47 +151,70 @@ class Backend_Api:
messages[-1]["content"] = get_search_message(messages[-1]["content"])
model = json_data.get('model')
model = model if model else models.default
provider = json_data.get('provider', '').replace('g4f.Provider.', '')
provider = provider if provider and provider != "Auto" else None
patch = patch_provider if json_data.get('patch_provider') else None
def try_response():
return {
"model": model,
"provider": provider,
"messages": messages,
"stream": True,
"ignore_stream_and_auth": True,
"patch_provider": patch,
**kwargs
}
def _create_response_stream(self, kwargs) -> Generator[str, None, None]:
"""
Creates and returns a streaming response for the conversation.
Args:
kwargs (dict): Arguments for creating the chat completion.
Yields:
str: JSON formatted response chunks for the stream.
Raises:
Exception: If an error occurs during the streaming process.
"""
try:
first = True
for chunk in ChatCompletion.create(
model=model,
provider=provider,
messages=messages,
stream=True,
ignore_stream_and_auth=True,
patch_provider=patch,
**kwargs
):
for chunk in ChatCompletion.create(**kwargs):
if first:
first = False
yield json.dumps({
'type' : 'provider',
'provider': get_last_provider(True)
}) + "\n"
yield self._format_json('provider', get_last_provider(True))
if isinstance(chunk, Exception):
logging.exception(chunk)
yield json.dumps({
'type' : 'message',
'message': get_error_message(chunk),
}) + "\n"
yield self._format_json('message', get_error_message(chunk))
else:
yield json.dumps({
'type' : 'content',
'content': str(chunk),
}) + "\n"
yield self._format_json('content', str(chunk))
except Exception as e:
logging.exception(e)
yield json.dumps({
'type' : 'error',
'error': get_error_message(e)
})
yield self._format_json('error', get_error_message(e))
return self.app.response_class(try_response(), mimetype='text/event-stream')
def _format_json(self, response_type: str, content) -> str:
"""
Formats and returns a JSON response.
Args:
response_type (str): The type of the response.
content: The content to be included in the response.
Returns:
str: A JSON formatted string.
"""
return json.dumps({
'type': response_type,
response_type: content
}) + "\n"
def get_error_message(exception: Exception) -> str:
"""
Generates a formatted error message from an exception.
Args:
exception (Exception): The exception to format.
Returns:
str: A formatted error message string.
"""
return f"{get_last_provider().__name__}: {type(exception).__name__}: {exception}"

View File

@ -4,9 +4,18 @@ import base64
from .typing import ImageType, Union
from PIL import Image
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'}
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'webp'}
def to_image(image: ImageType) -> Image.Image:
"""
Converts the input image to a PIL Image object.
Args:
image (Union[str, bytes, Image.Image]): The input image.
Returns:
Image.Image: The converted PIL Image object.
"""
if isinstance(image, str):
is_data_uri_an_image(image)
image = extract_data_uri(image)
@ -20,11 +29,29 @@ def to_image(image: ImageType) -> Image.Image:
image = copy
return image
def is_allowed_extension(filename) -> bool:
def is_allowed_extension(filename: str) -> bool:
"""
Checks if the given filename has an allowed extension.
Args:
filename (str): The filename to check.
Returns:
bool: True if the extension is allowed, False otherwise.
"""
return '.' in filename and \
filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
def is_data_uri_an_image(data_uri: str) -> bool:
"""
Checks if the given data URI represents an image.
Args:
data_uri (str): The data URI to check.
Raises:
ValueError: If the data URI is invalid or the image format is not allowed.
"""
# Check if the data URI starts with 'data:image' and contains an image format (e.g., jpeg, png, gif)
if not re.match(r'data:image/(\w+);base64,', data_uri):
raise ValueError("Invalid data URI image.")
@ -35,6 +62,15 @@ def is_data_uri_an_image(data_uri: str) -> bool:
raise ValueError("Invalid image format (from mime file type).")
def is_accepted_format(binary_data: bytes) -> bool:
"""
Checks if the given binary data represents an image with an accepted format.
Args:
binary_data (bytes): The binary data to check.
Raises:
ValueError: If the image format is not allowed.
"""
if binary_data.startswith(b'\xFF\xD8\xFF'):
pass # It's a JPEG image
elif binary_data.startswith(b'\x89PNG\r\n\x1a\n'):
@ -51,11 +87,29 @@ def is_accepted_format(binary_data: bytes) -> bool:
raise ValueError("Invalid image format (from magic code).")
def extract_data_uri(data_uri: str) -> bytes:
"""
Extracts the binary data from the given data URI.
Args:
data_uri (str): The data URI.
Returns:
bytes: The extracted binary data.
"""
data = data_uri.split(",")[1]
data = base64.b64decode(data)
return data
def get_orientation(image: Image.Image) -> int:
"""
Gets the orientation of the given image.
Args:
image (Image.Image): The image.
Returns:
int: The orientation value.
"""
exif_data = image.getexif() if hasattr(image, 'getexif') else image._getexif()
if exif_data is not None:
orientation = exif_data.get(274) # 274 corresponds to the orientation tag in EXIF
@ -63,6 +117,17 @@ def get_orientation(image: Image.Image) -> int:
return orientation
def process_image(img: Image.Image, new_width: int, new_height: int) -> Image.Image:
"""
Processes the given image by adjusting its orientation and resizing it.
Args:
img (Image.Image): The image to process.
new_width (int): The new width of the image.
new_height (int): The new height of the image.
Returns:
Image.Image: The processed image.
"""
orientation = get_orientation(img)
if orientation:
if orientation > 4:
@ -77,11 +142,32 @@ def process_image(img: Image.Image, new_width: int, new_height: int) -> Image.Im
return img
def to_base64(image: Image.Image, compression_rate: float) -> str:
"""
Converts the given image to a base64-encoded string.
Args:
image (Image.Image): The image to convert.
compression_rate (float): The compression rate (0.0 to 1.0).
Returns:
str: The base64-encoded image.
"""
output_buffer = BytesIO()
image.save(output_buffer, format="JPEG", quality=int(compression_rate * 100))
return base64.b64encode(output_buffer.getvalue()).decode()
def format_images_markdown(images, prompt: str, preview: str="{image}?w=200&h=200") -> str:
"""
Formats the given images as a markdown string.
Args:
images: The images to format.
prompt (str): The prompt for the images.
preview (str, optional): The preview URL format. Defaults to "{image}?w=200&h=200".
Returns:
str: The formatted markdown string.
"""
if isinstance(images, list):
images = [f"[![#{idx+1} {prompt}]({preview.replace('{image}', image)})]({image})" for idx, image in enumerate(images)]
images = "\n".join(images)
@ -92,6 +178,15 @@ def format_images_markdown(images, prompt: str, preview: str="{image}?w=200&h=20
return f"\n{start_flag}{images}\n{end_flag}\n"
def to_bytes(image: Image.Image) -> bytes:
"""
Converts the given image to bytes.
Args:
image (Image.Image): The image to convert.
Returns:
bytes: The image as bytes.
"""
bytes_io = BytesIO()
image.save(bytes_io, image.format)
image.seek(0)

View File

@ -31,12 +31,21 @@ from .Provider import (
@dataclass(unsafe_hash=True)
class Model:
"""
Represents a machine learning model configuration.
Attributes:
name (str): Name of the model.
base_provider (str): Default provider for the model.
best_provider (ProviderType): The preferred provider for the model, typically with retry logic.
"""
name: str
base_provider: str
best_provider: ProviderType = None
@staticmethod
def __all__() -> list[str]:
"""Returns a list of all model names."""
return _all_models
default = Model(
@ -298,6 +307,12 @@ pi = Model(
)
class ModelUtils:
"""
Utility class for mapping string identifiers to Model instances.
Attributes:
convert (dict[str, Model]): Dictionary mapping model string identifiers to Model instances.
"""
convert: dict[str, Model] = {
# gpt-3.5
'gpt-3.5-turbo' : gpt_35_turbo,

View File

@ -1,7 +1,6 @@
from __future__ import annotations
import json
from contextlib import asynccontextmanager
from functools import partialmethod
from typing import AsyncGenerator
from urllib.parse import urlparse
@ -9,27 +8,41 @@ from curl_cffi.requests import AsyncSession, Session, Response
from .webdriver import WebDriver, WebDriverSession, bypass_cloudflare, get_driver_cookies
class StreamResponse:
"""
A wrapper class for handling asynchronous streaming responses.
Attributes:
inner (Response): The original Response object.
"""
def __init__(self, inner: Response) -> None:
"""Initialize the StreamResponse with the provided Response object."""
self.inner: Response = inner
async def text(self) -> str:
"""Asynchronously get the response text."""
return await self.inner.atext()
def raise_for_status(self) -> None:
"""Raise an HTTPError if one occurred."""
self.inner.raise_for_status()
async def json(self, **kwargs) -> dict:
"""Asynchronously parse the JSON response content."""
return json.loads(await self.inner.acontent(), **kwargs)
async def iter_lines(self) -> AsyncGenerator[bytes, None]:
"""Asynchronously iterate over the lines of the response."""
async for line in self.inner.aiter_lines():
yield line
async def iter_content(self) -> AsyncGenerator[bytes, None]:
"""Asynchronously iterate over the response content."""
async for chunk in self.inner.aiter_content():
yield chunk
async def __aenter__(self):
"""Asynchronously enter the runtime context for the response object."""
inner: Response = await self.inner
self.inner = inner
self.request = inner.request
@ -41,14 +54,24 @@ class StreamResponse:
return self
async def __aexit__(self, *args):
"""Asynchronously exit the runtime context for the response object."""
await self.inner.aclose()
class StreamSession(AsyncSession):
"""
An asynchronous session class for handling HTTP requests with streaming.
Inherits from AsyncSession.
"""
def request(
self, method: str, url: str, **kwargs
) -> StreamResponse:
"""Create and return a StreamResponse object for the given HTTP request."""
return StreamResponse(super().request(method, url, stream=True, **kwargs))
# Defining HTTP methods as partial methods of the request method.
head = partialmethod(request, "HEAD")
get = partialmethod(request, "GET")
post = partialmethod(request, "POST")
@ -56,7 +79,20 @@ class StreamSession(AsyncSession):
patch = partialmethod(request, "PATCH")
delete = partialmethod(request, "DELETE")
def get_session_from_browser(url: str, webdriver: WebDriver = None, proxy: str = None, timeout: int = 120):
def get_session_from_browser(url: str, webdriver: WebDriver = None, proxy: str = None, timeout: int = 120) -> Session:
"""
Create a Session object using a WebDriver to handle cookies and headers.
Args:
url (str): The URL to navigate to using the WebDriver.
webdriver (WebDriver, optional): The WebDriver instance to use.
proxy (str, optional): Proxy server to use for the Session.
timeout (int, optional): Timeout in seconds for the WebDriver.
Returns:
Session: A Session object configured with cookies and headers from the WebDriver.
"""
with WebDriverSession(webdriver, "", proxy=proxy, virtual_display=True) as driver:
bypass_cloudflare(driver, url, timeout)
cookies = get_driver_cookies(driver)

View File

@ -5,41 +5,116 @@ from importlib.metadata import version as get_package_version, PackageNotFoundEr
from subprocess import check_output, CalledProcessError, PIPE
from .errors import VersionNotFoundError
def get_latest_version() -> str:
try:
get_package_version("g4f")
response = requests.get("https://pypi.org/pypi/g4f/json").json()
return response["info"]["version"]
except PackageNotFoundError:
url = "https://api.github.com/repos/xtekky/gpt4free/releases/latest"
response = requests.get(url).json()
return response["tag_name"]
def get_pypi_version(package_name: str) -> str:
"""
Retrieves the latest version of a package from PyPI.
class VersionUtils():
Args:
package_name (str): The name of the package for which to retrieve the version.
Returns:
str: The latest version of the specified package from PyPI.
Raises:
VersionNotFoundError: If there is an error in fetching the version from PyPI.
"""
try:
response = requests.get(f"https://pypi.org/pypi/{package_name}/json").json()
return response["info"]["version"]
except requests.RequestException as e:
raise VersionNotFoundError(f"Failed to get PyPI version: {e}")
def get_github_version(repo: str) -> str:
"""
Retrieves the latest release version from a GitHub repository.
Args:
repo (str): The name of the GitHub repository.
Returns:
str: The latest release version from the specified GitHub repository.
Raises:
VersionNotFoundError: If there is an error in fetching the version from GitHub.
"""
try:
response = requests.get(f"https://api.github.com/repos/{repo}/releases/latest").json()
return response["tag_name"]
except requests.RequestException as e:
raise VersionNotFoundError(f"Failed to get GitHub release version: {e}")
def get_latest_version() -> str:
"""
Retrieves the latest release version of the 'g4f' package from PyPI or GitHub.
Returns:
str: The latest release version of 'g4f'.
Note:
The function first tries to fetch the version from PyPI. If the package is not found,
it retrieves the version from the GitHub repository.
"""
try:
# Is installed via package manager?
get_package_version("g4f")
return get_pypi_version("g4f")
except PackageNotFoundError:
# Else use Github version:
return get_github_version("xtekky/gpt4free")
class VersionUtils:
"""
Utility class for managing and comparing package versions of 'g4f'.
"""
@cached_property
def current_version(self) -> str:
"""
Retrieves the current version of the 'g4f' package.
Returns:
str: The current version of 'g4f'.
Raises:
VersionNotFoundError: If the version cannot be determined from the package manager,
Docker environment, or git repository.
"""
# Read from package manager
try:
return get_package_version("g4f")
except PackageNotFoundError:
pass
# Read from docker environment
version = environ.get("G4F_VERSION")
if version:
return version
# Read from git repository
try:
command = ["git", "describe", "--tags", "--abbrev=0"]
return check_output(command, text=True, stderr=PIPE).strip()
except CalledProcessError:
pass
raise VersionNotFoundError("Version not found")
@cached_property
def latest_version(self) -> str:
"""
Retrieves the latest version of the 'g4f' package.
Returns:
str: The latest version of 'g4f'.
"""
return get_latest_version()
def check_version(self) -> None:
"""
Checks if the current version of 'g4f' is up to date with the latest version.
Note:
If a newer version is available, it prints a message with the new version and update instructions.
"""
try:
if self.current_version != self.latest_version:
print(f'New g4f version: {self.latest_version} (current: {self.current_version}) | pip install -U g4f')

View File

@ -1,5 +1,4 @@
from __future__ import annotations
from platformdirs import user_config_dir
from selenium.webdriver.remote.webdriver import WebDriver
from undetected_chromedriver import Chrome, ChromeOptions
@ -21,7 +20,19 @@ def get_browser(
proxy: str = None,
options: ChromeOptions = None
) -> WebDriver:
if user_data_dir == None:
"""
Creates and returns a Chrome WebDriver with specified options.
Args:
user_data_dir (str, optional): Directory for user data. If None, uses default directory.
headless (bool, optional): Whether to run the browser in headless mode. Defaults to False.
proxy (str, optional): Proxy settings for the browser. Defaults to None.
options (ChromeOptions, optional): ChromeOptions object with specific browser options. Defaults to None.
Returns:
WebDriver: An instance of WebDriver configured with the specified options.
"""
if user_data_dir is None:
user_data_dir = user_config_dir("g4f")
if user_data_dir and debug.logging:
print("Open browser with config dir:", user_data_dir)
@ -39,36 +50,53 @@ def get_browser(
headless=headless
)
def get_driver_cookies(driver: WebDriver):
return dict([(cookie["name"], cookie["value"]) for cookie in driver.get_cookies()])
def get_driver_cookies(driver: WebDriver) -> dict:
"""
Retrieves cookies from the specified WebDriver.
Args:
driver (WebDriver): The WebDriver instance from which to retrieve cookies.
Returns:
dict: A dictionary containing cookies with their names as keys and values as cookie values.
"""
return {cookie["name"]: cookie["value"] for cookie in driver.get_cookies()}
def bypass_cloudflare(driver: WebDriver, url: str, timeout: int) -> None:
# Open website
"""
Attempts to bypass Cloudflare protection when accessing a URL using the provided WebDriver.
Args:
driver (WebDriver): The WebDriver to use for accessing the URL.
url (str): The URL to access.
timeout (int): Time in seconds to wait for the page to load.
Raises:
Exception: If there is an error while bypassing Cloudflare or loading the page.
"""
driver.get(url)
# Is cloudflare protection
if driver.find_element(By.TAG_NAME, "body").get_attribute("class") == "no-js":
if debug.logging:
print("Cloudflare protection detected:", url)
try:
# Click button in iframe
WebDriverWait(driver, 5).until(
EC.presence_of_element_located((By.CSS_SELECTOR, "#turnstile-wrapper iframe"))
)
driver.switch_to.frame(driver.find_element(By.CSS_SELECTOR, "#turnstile-wrapper iframe"))
WebDriverWait(driver, 5).until(
EC.presence_of_element_located((By.CSS_SELECTOR, "#challenge-stage input"))
)
driver.find_element(By.CSS_SELECTOR, "#challenge-stage input").click()
except:
pass
).click()
except Exception as e:
if debug.logging:
print(f"Error bypassing Cloudflare: {e}")
finally:
driver.switch_to.default_content()
# No cloudflare protection
WebDriverWait(driver, timeout).until(
EC.presence_of_element_located((By.CSS_SELECTOR, "body:not(.no-js)"))
)
class WebDriverSession():
class WebDriverSession:
"""
Manages a Selenium WebDriver session, including handling of virtual displays and proxies.
"""
def __init__(
self,
webdriver: WebDriver = None,
@ -78,12 +106,21 @@ class WebDriverSession():
proxy: str = None,
options: ChromeOptions = None
):
"""
Initializes a new instance of the WebDriverSession.
Args:
webdriver (WebDriver, optional): A WebDriver instance for the session. Defaults to None.
user_data_dir (str, optional): Directory for user data. Defaults to None.
headless (bool, optional): Whether to run the browser in headless mode. Defaults to False.
virtual_display (bool, optional): Whether to use a virtual display. Defaults to False.
proxy (str, optional): Proxy settings for the browser. Defaults to None.
options (ChromeOptions, optional): ChromeOptions for the browser. Defaults to None.
"""
self.webdriver = webdriver
self.user_data_dir = user_data_dir
self.headless = headless
self.virtual_display = None
if has_pyvirtualdisplay and virtual_display:
self.virtual_display = Display(size=(1920, 1080))
self.virtual_display = Display(size=(1920, 1080)) if has_pyvirtualdisplay and virtual_display else None
self.proxy = proxy
self.options = options
self.default_driver = None
@ -94,8 +131,18 @@ class WebDriverSession():
headless: bool = False,
virtual_display: bool = False
) -> WebDriver:
if user_data_dir == None:
user_data_dir = self.user_data_dir
"""
Reopens the WebDriver session with new settings.
Args:
user_data_dir (str, optional): Directory for user data. Defaults to current value.
headless (bool, optional): Whether to run the browser in headless mode. Defaults to current value.
virtual_display (bool, optional): Whether to use a virtual display. Defaults to current value.
Returns:
WebDriver: The reopened WebDriver instance.
"""
user_data_dir = user_data_data_dir or self.user_data_dir
if self.default_driver:
self.default_driver.quit()
if not virtual_display and self.virtual_display:
@ -105,6 +152,12 @@ class WebDriverSession():
return self.default_driver
def __enter__(self) -> WebDriver:
"""
Context management method for entering a session. Initializes and returns a WebDriver instance.
Returns:
WebDriver: An instance of WebDriver for this session.
"""
if self.webdriver:
return self.webdriver
if self.virtual_display:
@ -113,11 +166,23 @@ class WebDriverSession():
return self.default_driver
def __exit__(self, exc_type, exc_val, exc_tb):
"""
Context management method for exiting a session. Closes and quits the WebDriver.
Args:
exc_type: Exception type.
exc_val: Exception value.
exc_tb: Exception traceback.
Note:
Closes the WebDriver and stops the virtual display if used.
"""
if self.default_driver:
try:
self.default_driver.close()
except:
pass
except Exception as e:
if debug.logging:
print(f"Error closing WebDriver: {e}")
self.default_driver.quit()
if self.virtual_display:
self.virtual_display.stop()