mirror of https://github.com/xtekky/gpt4free.git
synced 2024-12-27 21:21:41 +03:00
commit 1ca80ed48b

.github/workflows/unittest.yml (vendored, new file, 19 lines)
@@ -0,0 +1,19 @@
+name: Unittest
+
+on: [push]
+
+jobs:
+  build:
+    name: Build unittest
+    runs-on: ubuntu-latest
+    steps:
+    - uses: actions/checkout@v4
+    - name: Set up Python
+      uses: actions/setup-python@v4
+      with:
+        python-version: "3.x"
+        cache: 'pip'
+    - name: Install requirements
+      run: pip install -r requirements.txt
+    - name: Run tests
+      run: python -m etc.unittest.main
etc/unittest/main.py (new file, 73 lines)
@@ -0,0 +1,73 @@
+import sys
+import pathlib
+import unittest
+from unittest.mock import MagicMock
+
+sys.path.append(str(pathlib.Path(__file__).parent.parent.parent))
+
+import g4f
+from g4f import ChatCompletion, get_last_provider
+from g4f.gui.server.backend import Backend_Api, get_error_message
+from g4f.base_provider import BaseProvider
+
+g4f.debug.logging = False
+
+class MockProvider(BaseProvider):
+    working = True
+
+    def create_completion(
+        model, messages, stream, **kwargs
+    ):
+        yield "Mock"
+
+    async def create_async(
+        model, messages, **kwargs
+    ):
+        return "Mock"
+
+class TestBackendApi(unittest.TestCase):
+
+    def setUp(self):
+        self.app = MagicMock()
+        self.api = Backend_Api(self.app)
+
+    def test_version(self):
+        response = self.api.get_version()
+        self.assertIn("version", response)
+        self.assertIn("latest_version", response)
+
+class TestChatCompletion(unittest.TestCase):
+
+    def test_create(self):
+        messages = [{'role': 'user', 'content': 'Hello'}]
+        result = ChatCompletion.create(g4f.models.default, messages)
+        self.assertTrue("Hello" in result or "Good" in result)
+
+    def test_get_last_provider(self):
+        messages = [{'role': 'user', 'content': 'Hello'}]
+        ChatCompletion.create(g4f.models.default, messages, MockProvider)
+        self.assertEqual(get_last_provider(), MockProvider)
+
+    def test_bing_provider(self):
+        messages = [{'role': 'user', 'content': 'Hello'}]
+        provider = g4f.Provider.Bing
+        result = ChatCompletion.create(g4f.models.default, messages, provider)
+        self.assertTrue("Bing" in result)
+
+class TestChatCompletionAsync(unittest.IsolatedAsyncioTestCase):
+
+    async def test_async(self):
+        messages = [{'role': 'user', 'content': 'Hello'}]
+        result = await ChatCompletion.create_async(g4f.models.default, messages, MockProvider)
+        self.assertTrue("Mock" in result)
+
+class TestUtilityFunctions(unittest.TestCase):
+
+    def test_get_error_message(self):
+        g4f.debug.last_provider = g4f.Provider.Bing
+        exception = Exception("Message")
+        result = get_error_message(exception)
+        self.assertEqual("Bing: Exception: Message", result)
+
+if __name__ == '__main__':
+    unittest.main()
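The suite mirrors the CI step "python -m etc.unittest.main". A minimal sketch for driving the same tests programmatically (assumes the repository root is on sys.path, as the module itself arranges):

import unittest
from etc.unittest import main  # the module added above

# Load every TestCase defined in the new module and run it, mirroring
# the "Run tests" step of the workflow.
suite = unittest.defaultTestLoader.loadTestsFromModule(main)
unittest.TextTestRunner(verbosity=2).run(suite)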
g4f/Provider/Bing.py
@@ -15,12 +15,18 @@ from .bing.upload_image import upload_image
 from .bing.create_images import create_images
 from .bing.conversation import Conversation, create_conversation, delete_conversation
 
-class Tones():
+class Tones:
+    """
+    Defines the different tone options for the Bing provider.
+    """
     creative = "Creative"
     balanced = "Balanced"
     precise = "Precise"
 
 class Bing(AsyncGeneratorProvider):
+    """
+    Bing provider for generating responses using the Bing API.
+    """
     url = "https://bing.com/chat"
     working = True
     supports_message_history = True
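The Tones constants are plain strings, so they can be forwarded through the keyword arguments of the high-level API. A hedged usage sketch (the import path g4f.Provider.Bing is assumed from this file's location):

import g4f
from g4f.Provider.Bing import Tones  # assumed import path, for illustration

response = g4f.ChatCompletion.create(
    model=g4f.models.default,
    messages=[{"role": "user", "content": "Write a haiku about the sea"}],
    provider=g4f.Provider.Bing,
    tone=Tones.creative,  # "Creative"; Tones.balanced and Tones.precise also exist
)
print(response)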
@@ -38,6 +44,19 @@ class Bing(AsyncGeneratorProvider):
         web_search: bool = False,
         **kwargs
     ) -> AsyncResult:
+        """
+        Creates an asynchronous generator for producing responses from Bing.
+
+        :param model: The model to use.
+        :param messages: Messages to process.
+        :param proxy: Proxy to use for requests.
+        :param timeout: Timeout for requests.
+        :param cookies: Cookies for the session.
+        :param tone: The tone of the response.
+        :param image: The image type to be used.
+        :param web_search: Flag to enable or disable web search.
+        :return: An asynchronous result object.
+        """
         if len(messages) < 2:
             prompt = messages[0]["content"]
             context = None
@@ -56,65 +75,48 @@ class Bing(AsyncGeneratorProvider):
 
         return stream_generate(prompt, tone, image, context, proxy, cookies, web_search, gpt4_turbo, timeout)
 
-def create_context(messages: Messages):
+def create_context(messages: Messages) -> str:
+    """
+    Creates a context string from a list of messages.
+
+    :param messages: A list of message dictionaries.
+    :return: A string representing the context created from the messages.
+    """
     return "".join(
-        f"[{message['role']}]" + ("(#message)" if message['role']!="system" else "(#additional_instructions)") + f"\n{message['content']}\n\n"
+        f"[{message['role']}]" + ("(#message)" if message['role'] != "system" else "(#additional_instructions)") + f"\n{message['content']}\n\n"
         for message in messages
     )
 
 class Defaults:
+    """
+    Default settings and configurations for the Bing provider.
+    """
     delimiter = "\x1e"
     ip_address = f"13.{random.randint(104, 107)}.{random.randint(0, 255)}.{random.randint(0, 255)}"
 
+    # List of allowed message types for Bing responses
     allowedMessageTypes = [
-        "ActionRequest",
-        "Chat",
-        "Context",
-        # "Disengaged", unwanted
-        "Progress",
-        # "AdsQuery", unwanted
-        "SemanticSerp",
-        "GenerateContentQuery",
-        "SearchQuery",
-        # The following message types should not be added so that it does not flood with
-        # useless messages (such as "Analyzing images" or "Searching the web") while it's retrieving the AI response
-        # "InternalSearchQuery",
-        # "InternalSearchResult",
-        "RenderCardRequest",
-        # "RenderContentRequest"
+        "ActionRequest", "Chat", "Context", "Progress", "SemanticSerp",
+        "GenerateContentQuery", "SearchQuery", "RenderCardRequest"
     ]
 
     sliceIds = [
-        'abv2',
-        'srdicton',
-        'convcssclick',
-        'stylewv2',
-        'contctxp2tf',
-        '802fluxv1pc_a',
-        '806log2sphs0',
-        '727savemem',
-        '277teditgnds0',
-        '207hlthgrds0',
+        'abv2', 'srdicton', 'convcssclick', 'stylewv2', 'contctxp2tf',
+        '802fluxv1pc_a', '806log2sphs0', '727savemem', '277teditgnds0', '207hlthgrds0'
    ]
 
+    # Default location settings
     location = {
-        "locale": "en-US",
-        "market": "en-US",
-        "region": "US",
-        "locationHints": [
-            {
-                "country": "United States",
-                "state": "California",
-                "city": "Los Angeles",
-                "timezoneoffset": 8,
-                "countryConfidence": 8,
-                "Center": {"Latitude": 34.0536909, "Longitude": -118.242766},
-                "RegionType": 2,
-                "SourceType": 1,
-            }
-        ],
+        "locale": "en-US", "market": "en-US", "region": "US",
+        "locationHints": [{
+            "country": "United States", "state": "California", "city": "Los Angeles",
+            "timezoneoffset": 8, "countryConfidence": 8,
+            "Center": {"Latitude": 34.0536909, "Longitude": -118.242766},
+            "RegionType": 2, "SourceType": 1
+        }],
     }
 
+    # Default headers for requests
     headers = {
         'accept': '*/*',
         'accept-language': 'en-US,en;q=0.9',
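A small standalone illustration of the string create_context builds (the one-liner rewritten for readability; system messages become additional instructions, everything else a message):

messages = [
    {"role": "system", "content": "Answer briefly."},
    {"role": "user", "content": "Hi"},
]
context = "".join(
    f"[{m['role']}]"
    + ("(#message)" if m["role"] != "system" else "(#additional_instructions)")
    + f"\n{m['content']}\n\n"
    for m in messages
)
print(context)
# [system](#additional_instructions)
# Answer briefly.
#
# [user](#message)
# Hi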
@@ -139,23 +141,13 @@ class Defaults:
     }
 
     optionsSets = [
-        'nlu_direct_response_filter',
-        'deepleo',
-        'disable_emoji_spoken_text',
-        'responsible_ai_policy_235',
-        'enablemm',
-        'iyxapbing',
-        'iycapbing',
-        'gencontentv3',
-        'fluxsrtrunc',
-        'fluxtrunc',
-        'fluxv1',
-        'rai278',
-        'replaceurl',
-        'eredirecturl',
-        'nojbfedge'
+        'nlu_direct_response_filter', 'deepleo', 'disable_emoji_spoken_text',
+        'responsible_ai_policy_235', 'enablemm', 'iyxapbing', 'iycapbing',
+        'gencontentv3', 'fluxsrtrunc', 'fluxtrunc', 'fluxv1', 'rai278',
+        'replaceurl', 'eredirecturl', 'nojbfedge'
     ]
 
+    # Default cookies
     cookies = {
         'SRCHD' : 'AF=NOFORM',
         'PPLState' : '1',
@@ -166,6 +158,12 @@ class Defaults:
     }
 
 def format_message(msg: dict) -> str:
+    """
+    Formats a message dictionary into a JSON string with a delimiter.
+
+    :param msg: The message dictionary to format.
+    :return: A formatted string representation of the message.
+    """
     return json.dumps(msg, ensure_ascii=False) + Defaults.delimiter
 
 def create_message(
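The delimiter framing in isolation: every frame sent over the ChatHub WebSocket is a JSON object terminated by the record separator "\x1e" (Defaults.delimiter). A minimal sketch:

import json

DELIMITER = "\x1e"  # Defaults.delimiter

def format_message(msg: dict) -> str:
    # Serialize and append the record separator, as in the function above.
    return json.dumps(msg, ensure_ascii=False) + DELIMITER

frame = format_message({"protocol": "json", "version": 1})
print(repr(frame))  # '{"protocol": "json", "version": 1}\x1e'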
@@ -177,7 +175,20 @@ def create_message(
     web_search: bool = False,
     gpt4_turbo: bool = False
 ) -> str:
+    """
+    Creates a message for the Bing API with specified parameters.
+
+    :param conversation: The current conversation object.
+    :param prompt: The user's input prompt.
+    :param tone: The desired tone for the response.
+    :param context: Additional context for the prompt.
+    :param image_response: The response if an image is involved.
+    :param web_search: Flag to enable web search.
+    :param gpt4_turbo: Flag to enable GPT-4 Turbo.
+    :return: A formatted string message for the Bing API.
+    """
     options_sets = Defaults.optionsSets
+    # Append tone-specific options
     if tone == Tones.creative:
         options_sets.append("h3imaginative")
     elif tone == Tones.precise:
@@ -186,54 +197,49 @@ def create_message(
         options_sets.append("galileo")
     else:
         options_sets.append("harmonyv3")
 
+    # Additional configurations based on parameters
     if not web_search:
         options_sets.append("nosearchall")
 
     if gpt4_turbo:
         options_sets.append("dlgpt4t")
 
     request_id = str(uuid.uuid4())
     struct = {
-        'arguments': [
-            {
-                'source': 'cib',
-                'optionsSets': options_sets,
-                'allowedMessageTypes': Defaults.allowedMessageTypes,
-                'sliceIds': Defaults.sliceIds,
-                'traceId': os.urandom(16).hex(),
-                'isStartOfSession': True,
-                'requestId': request_id,
-                'message': {**Defaults.location, **{
-                    'author': 'user',
-                    'inputMethod': 'Keyboard',
-                    'text': prompt,
-                    'messageType': 'Chat',
-                    'requestId': request_id,
-                    'messageId': request_id,
-                }},
-                "verbosity": "verbose",
-                "scenario": "SERP",
-                "plugins":[
-                    {"id":"c310c353-b9f0-4d76-ab0d-1dd5e979cf68", "category": 1}
-                ] if web_search else [],
-                'tone': tone,
-                'spokenTextMode': 'None',
-                'conversationId': conversation.conversationId,
-                'participant': {
-                    'id': conversation.clientId
-                },
-            }
-        ],
+        'arguments': [{
+            'source': 'cib', 'optionsSets': options_sets,
+            'allowedMessageTypes': Defaults.allowedMessageTypes,
+            'sliceIds': Defaults.sliceIds,
+            'traceId': os.urandom(16).hex(), 'isStartOfSession': True,
+            'requestId': request_id,
+            'message': {
+                **Defaults.location,
+                'author': 'user',
+                'inputMethod': 'Keyboard',
+                'text': prompt,
+                'messageType': 'Chat',
+                'requestId': request_id,
+                'messageId': request_id
+            },
+            "verbosity": "verbose",
+            "scenario": "SERP",
+            "plugins": [{"id": "c310c353-b9f0-4d76-ab0d-1dd5e979cf68", "category": 1}] if web_search else [],
+            'tone': tone,
+            'spokenTextMode': 'None',
+            'conversationId': conversation.conversationId,
+            'participant': {'id': conversation.clientId},
+        }],
         'invocationId': '1',
         'target': 'chat',
         'type': 4
     }
-    if image_response.get('imageUrl') and image_response.get('originalImageUrl'):
+
+    if image_response and image_response.get('imageUrl') and image_response.get('originalImageUrl'):
         struct['arguments'][0]['message']['originalImageUrl'] = image_response.get('originalImageUrl')
         struct['arguments'][0]['message']['imageUrl'] = image_response.get('imageUrl')
         struct['arguments'][0]['experienceType'] = None
         struct['arguments'][0]['attachedFileInfo'] = {"fileName": None, "fileType": None}
+
     if context:
         struct['arguments'][0]['previousMessages'] = [{
             "author": "user",
@@ -242,30 +248,46 @@ def create_message(
             "messageType": "Context",
             "messageId": "discover-web--page-ping-mriduna-----"
         }]
 
     return format_message(struct)
 
 async def stream_generate(
-        prompt: str,
-        tone: str,
-        image: ImageType = None,
-        context: str = None,
-        proxy: str = None,
-        cookies: dict = None,
-        web_search: bool = False,
-        gpt4_turbo: bool = False,
-        timeout: int = 900
-    ):
+    prompt: str,
+    tone: str,
+    image: ImageType = None,
+    context: str = None,
+    proxy: str = None,
+    cookies: dict = None,
+    web_search: bool = False,
+    gpt4_turbo: bool = False,
+    timeout: int = 900
+):
+    """
+    Asynchronously streams generated responses from the Bing API.
+
+    :param prompt: The user's input prompt.
+    :param tone: The desired tone for the response.
+    :param image: The image type involved in the response.
+    :param context: Additional context for the prompt.
+    :param proxy: Proxy settings for the request.
+    :param cookies: Cookies for the session.
+    :param web_search: Flag to enable web search.
+    :param gpt4_turbo: Flag to enable GPT-4 Turbo.
+    :param timeout: Timeout for the request.
+    :return: An asynchronous generator yielding responses.
+    """
     headers = Defaults.headers
     if cookies:
         headers["Cookie"] = "; ".join(f"{k}={v}" for k, v in cookies.items())
 
     async with ClientSession(
-        timeout=ClientTimeout(total=timeout),
-        headers=headers
+        timeout=ClientTimeout(total=timeout), headers=headers
     ) as session:
         conversation = await create_conversation(session, proxy)
         image_response = await upload_image(session, image, tone, proxy) if image else None
         if image_response:
             yield image_response
 
         try:
             async with session.ws_connect(
                 'wss://sydney.bing.com/sydney/ChatHub',
@@ -289,7 +311,7 @@ async def stream_generate(
                 if obj is None or not obj:
                     continue
                 response = json.loads(obj)
-                if response.get('type') == 1 and response['arguments'][0].get('messages'):
+                if response and response.get('type') == 1 and response['arguments'][0].get('messages'):
                     message = response['arguments'][0]['messages'][0]
                     image_response = None
                     if (message['contentOrigin'] != 'Apology'):
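stream_generate is normally driven indirectly through the high-level API; provider-specific flags such as web_search travel via **kwargs. A hedged end-to-end sketch (a real run additionally needs working Bing cookies):

import g4f

for chunk in g4f.ChatCompletion.create(
    model=g4f.models.default,
    messages=[{"role": "user", "content": "What is the weather in Paris?"}],
    provider=g4f.Provider.Bing,
    stream=True,
    web_search=False,  # forwarded to create_async_generator above
):
    print(chunk, end="", flush=True)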
g4f/Provider/FreeChatgpt.py
@@ -1,16 +1,20 @@
 from __future__ import annotations
 
-import json
+import json, random
 from aiohttp import ClientSession
 
 from ..typing import AsyncResult, Messages
 from .base_provider import AsyncGeneratorProvider
 
 models = {
-    "claude-v2": "claude-2.0",
-    "gemini-pro": "google-gemini-pro"
+    "claude-v2": "claude-2.0",
+    "claude-v2.1":"claude-2.1",
+    "gemini-pro": "google-gemini-pro"
 }
+urls = [
+    "https://free.chatgpt.org.uk",
+    "https://ai.chatgpt.org.uk"
+]
 
 class FreeChatgpt(AsyncGeneratorProvider):
     url = "https://free.chatgpt.org.uk"
@@ -31,6 +35,7 @@ class FreeChatgpt(AsyncGeneratorProvider):
             model = models[model]
         elif not model:
             model = "gpt-3.5-turbo"
+        url = random.choice(urls)
         headers = {
             "Accept": "application/json, text/event-stream",
             "Content-Type":"application/json",
@@ -55,7 +60,7 @@ class FreeChatgpt(AsyncGeneratorProvider):
             "top_p":1,
             **kwargs
         }
-        async with session.post(f'{cls.url}/api/openai/v1/chat/completions', json=data, proxy=proxy) as response:
+        async with session.post(f'{url}/api/openai/v1/chat/completions', json=data, proxy=proxy) as response:
             response.raise_for_status()
             started = False
             async for line in response.content:
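The two FreeChatgpt changes in miniature: model aliases resolve through the map, and each request draws one mirror from the url pool instead of always hitting cls.url. A minimal sketch:

import random

models = {"claude-v2": "claude-2.0", "claude-v2.1": "claude-2.1", "gemini-pro": "google-gemini-pro"}
urls = ["https://free.chatgpt.org.uk", "https://ai.chatgpt.org.uk"]

model = models.get("claude-v2.1") or "gpt-3.5-turbo"  # alias lookup with fallback
url = random.choice(urls)                              # spread load across mirrors
print(model, f"{url}/api/openai/v1/chat/completions")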
g4f/Provider/Phind.py
@@ -59,12 +59,16 @@ class Phind(AsyncGeneratorProvider):
             "rewrittenQuestion": prompt,
             "challenge": 0.21132115912208504
         }
-        async with session.post(f"{cls.url}/api/infer/followup/answer", headers=headers, json=data) as response:
+        async with session.post(f"https://https.api.phind.com/infer/", headers=headers, json=data) as response:
             new_line = False
             async for line in response.iter_lines():
                 if line.startswith(b"data: "):
                     chunk = line[6:]
-                    if chunk.startswith(b"<PHIND_METADATA>") or chunk.startswith(b"<PHIND_INDICATOR>"):
+                    if chunk.startswith(b'<PHIND_DONE/>'):
+                        break
+                    if chunk.startswith(b'<PHIND_WEBRESULTS>') or chunk.startswith(b'<PHIND_FOLLOWUP>'):
+                        pass
+                    elif chunk.startswith(b"<PHIND_METADATA>") or chunk.startswith(b"<PHIND_INDICATOR>"):
                         pass
                     elif chunk:
                         yield chunk.decode()
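The new chunk handling as a standalone generator: the done sentinel terminates the stream, the other control tags are swallowed, and everything else is decoded and yielded. A sketch over raw SSE lines:

def filter_chunks(lines):
    for line in lines:
        if not line.startswith(b"data: "):
            continue
        chunk = line[6:]
        if chunk.startswith(b"<PHIND_DONE/>"):
            break  # end of the answer
        if chunk.startswith(b"<PHIND_WEBRESULTS>") or chunk.startswith(b"<PHIND_FOLLOWUP>"):
            pass  # control frames: ignored
        elif chunk.startswith(b"<PHIND_METADATA>") or chunk.startswith(b"<PHIND_INDICATOR>"):
            pass
        elif chunk:
            yield chunk.decode()

print(list(filter_chunks([b"data: Hello", b"data: <PHIND_DONE/>", b"data: ignored"])))
# ['Hello']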
@ -1,28 +1,29 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import asyncio
|
||||
from asyncio import AbstractEventLoop
|
||||
from asyncio import AbstractEventLoop
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from abc import abstractmethod
|
||||
from inspect import signature, Parameter
|
||||
from .helper import get_event_loop, get_cookies, format_prompt
|
||||
from ..typing import CreateResult, AsyncResult, Messages
|
||||
from ..base_provider import BaseProvider
|
||||
from abc import abstractmethod
|
||||
from inspect import signature, Parameter
|
||||
from .helper import get_event_loop, get_cookies, format_prompt
|
||||
from ..typing import CreateResult, AsyncResult, Messages
|
||||
from ..base_provider import BaseProvider
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
NoneType = type(None)
|
||||
else:
|
||||
from types import NoneType
|
||||
|
||||
# Change event loop policy on windows for curl_cffi
|
||||
# Set Windows event loop policy for better compatibility with asyncio and curl_cffi
|
||||
if sys.platform == 'win32':
|
||||
if isinstance(
|
||||
asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy
|
||||
):
|
||||
if isinstance(asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy):
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
class AbstractProvider(BaseProvider):
|
||||
"""
|
||||
Abstract class for providing asynchronous functionality to derived classes.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
async def create_async(
|
||||
cls,
|
||||
@@ -33,62 +34,67 @@ class AbstractProvider(BaseProvider):
         executor: ThreadPoolExecutor = None,
         **kwargs
     ) -> str:
-        if not loop:
-            loop = get_event_loop()
+        """
+        Asynchronously creates a result based on the given model and messages.
+
+        Args:
+            cls (type): The class on which this method is called.
+            model (str): The model to use for creation.
+            messages (Messages): The messages to process.
+            loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
+            executor (ThreadPoolExecutor, optional): The executor for running async tasks. Defaults to None.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The created result as a string.
+        """
+        loop = loop or get_event_loop()
 
         def create_func() -> str:
-            return "".join(cls.create_completion(
-                model,
-                messages,
-                False,
-                **kwargs
-            ))
+            return "".join(cls.create_completion(model, messages, False, **kwargs))
 
         return await asyncio.wait_for(
-            loop.run_in_executor(
-                executor,
-                create_func
-            ),
+            loop.run_in_executor(executor, create_func),
             timeout=kwargs.get("timeout", 0)
         )
 
     @classmethod
     @property
     def params(cls) -> str:
-        if issubclass(cls, AsyncGeneratorProvider):
-            sig = signature(cls.create_async_generator)
-        elif issubclass(cls, AsyncProvider):
-            sig = signature(cls.create_async)
-        else:
-            sig = signature(cls.create_completion)
+        """
+        Returns the parameters supported by the provider.
+
+        Args:
+            cls (type): The class on which this property is called.
+
+        Returns:
+            str: A string listing the supported parameters.
+        """
+        sig = signature(
+            cls.create_async_generator if issubclass(cls, AsyncGeneratorProvider) else
+            cls.create_async if issubclass(cls, AsyncProvider) else
+            cls.create_completion
+        )
 
         def get_type_name(annotation: type) -> str:
-            if hasattr(annotation, "__name__"):
-                annotation = annotation.__name__
-            elif isinstance(annotation, NoneType):
-                annotation = "None"
-            return str(annotation)
+            return annotation.__name__ if hasattr(annotation, "__name__") else str(annotation)
 
         args = ""
         for name, param in sig.parameters.items():
-            if name in ("self", "kwargs"):
+            if name in ("self", "kwargs") or (name == "stream" and not cls.supports_stream):
                 continue
-            if name == "stream" and not cls.supports_stream:
-                continue
-            if args:
-                args += ", "
-            args += "\n " + name
-            if name != "model" and param.annotation is not Parameter.empty:
-                args += f": {get_type_name(param.annotation)}"
-            if param.default == "":
-                args += ' = ""'
-            elif param.default is not Parameter.empty:
-                args += f" = {param.default}"
+            args += f"\n {name}"
+            args += f": {get_type_name(param.annotation)}" if param.annotation is not Parameter.empty else ""
+            args += f' = "{param.default}"' if param.default == "" else f" = {param.default}" if param.default is not Parameter.empty else ""
 
         return f"g4f.Provider.{cls.__name__} supports: ({args}\n)"
 
 
 class AsyncProvider(AbstractProvider):
+    """
+    Provides asynchronous functionality for creating completions.
+    """
+
     @classmethod
     def create_completion(
         cls,
@@ -99,8 +105,21 @@ class AsyncProvider(AbstractProvider):
         loop: AbstractEventLoop = None,
         **kwargs
     ) -> CreateResult:
-        if not loop:
-            loop = get_event_loop()
+        """
+        Creates a completion result synchronously.
+
+        Args:
+            cls (type): The class on which this method is called.
+            model (str): The model to use for creation.
+            messages (Messages): The messages to process.
+            stream (bool): Indicates whether to stream the results. Defaults to False.
+            loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            CreateResult: The result of the completion creation.
+        """
+        loop = loop or get_event_loop()
         coro = cls.create_async(model, messages, **kwargs)
         yield loop.run_until_complete(coro)
@@ -111,10 +130,27 @@ class AsyncProvider(AbstractProvider):
         messages: Messages,
         **kwargs
     ) -> str:
+        """
+        Abstract method for creating asynchronous results.
+
+        Args:
+            model (str): The model to use for creation.
+            messages (Messages): The messages to process.
+            **kwargs: Additional keyword arguments.
+
+        Raises:
+            NotImplementedError: If this method is not overridden in derived classes.
+
+        Returns:
+            str: The created result as a string.
+        """
         raise NotImplementedError()
 
 
 class AsyncGeneratorProvider(AsyncProvider):
+    """
+    Provides asynchronous generator functionality for streaming results.
+    """
     supports_stream = True
 
     @classmethod
@@ -127,15 +163,24 @@ class AsyncGeneratorProvider(AsyncProvider):
         loop: AbstractEventLoop = None,
         **kwargs
     ) -> CreateResult:
-        if not loop:
-            loop = get_event_loop()
-        generator = cls.create_async_generator(
-            model,
-            messages,
-            stream=stream,
-            **kwargs
-        )
+        """
+        Creates a streaming completion result synchronously.
+
+        Args:
+            cls (type): The class on which this method is called.
+            model (str): The model to use for creation.
+            messages (Messages): The messages to process.
+            stream (bool): Indicates whether to stream the results. Defaults to True.
+            loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            CreateResult: The result of the streaming completion creation.
+        """
+        loop = loop or get_event_loop()
+        generator = cls.create_async_generator(model, messages, stream=stream, **kwargs)
         gen = generator.__aiter__()
 
         while True:
             try:
                 yield loop.run_until_complete(gen.__anext__())
@@ -149,21 +194,44 @@ class AsyncGeneratorProvider(AsyncProvider):
         messages: Messages,
         **kwargs
     ) -> str:
+        """
+        Asynchronously creates a result from a generator.
+
+        Args:
+            cls (type): The class on which this method is called.
+            model (str): The model to use for creation.
+            messages (Messages): The messages to process.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            str: The created result as a string.
+        """
         return "".join([
-            chunk async for chunk in cls.create_async_generator(
-                model,
-                messages,
-                stream=False,
-                **kwargs
-            ) if not isinstance(chunk, Exception)
+            chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)
+            if not isinstance(chunk, Exception)
         ])
 
     @staticmethod
     @abstractmethod
-    def create_async_generator(
+    async def create_async_generator(
         model: str,
         messages: Messages,
         stream: bool = True,
         **kwargs
     ) -> AsyncResult:
-        raise NotImplementedError()
+        """
+        Abstract method for creating an asynchronous generator.
+
+        Args:
+            model (str): The model to use for creation.
+            messages (Messages): The messages to process.
+            stream (bool): Indicates whether to stream the results. Defaults to True.
+            **kwargs: Additional keyword arguments.
+
+        Raises:
+            NotImplementedError: If this method is not overridden in derived classes.
+
+        Returns:
+            AsyncResult: An asynchronous generator yielding results.
+        """
+        raise NotImplementedError()
g4f/Provider/bing/conversation.py
@@ -1,13 +1,33 @@
 from aiohttp import ClientSession
 
 
-class Conversation():
+class Conversation:
+    """
+    Represents a conversation with specific attributes.
+    """
     def __init__(self, conversationId: str, clientId: str, conversationSignature: str) -> None:
+        """
+        Initialize a new conversation instance.
+
+        Args:
+            conversationId (str): Unique identifier for the conversation.
+            clientId (str): Client identifier.
+            conversationSignature (str): Signature for the conversation.
+        """
         self.conversationId = conversationId
         self.clientId = clientId
         self.conversationSignature = conversationSignature
 
 async def create_conversation(session: ClientSession, proxy: str = None) -> Conversation:
+    """
+    Create a new conversation asynchronously.
+
+    Args:
+        session (ClientSession): An instance of aiohttp's ClientSession.
+        proxy (str, optional): Proxy URL. Defaults to None.
+
+    Returns:
+        Conversation: An instance representing the created conversation.
+    """
     url = 'https://www.bing.com/turing/conversation/create?bundleVersion=1.1199.4'
     async with session.get(url, proxy=proxy) as response:
         try:
@@ -24,12 +44,32 @@ async def create_conversation(session: ClientSession, proxy: str = None) -> Conv
     return Conversation(conversationId, clientId, conversationSignature)
 
 async def list_conversations(session: ClientSession) -> list:
+    """
+    List all conversations asynchronously.
+
+    Args:
+        session (ClientSession): An instance of aiohttp's ClientSession.
+
+    Returns:
+        list: A list of conversations.
+    """
     url = "https://www.bing.com/turing/conversation/chats"
     async with session.get(url) as response:
         response = await response.json()
         return response["chats"]
 
 async def delete_conversation(session: ClientSession, conversation: Conversation, proxy: str = None) -> bool:
+    """
+    Delete a conversation asynchronously.
+
+    Args:
+        session (ClientSession): An instance of aiohttp's ClientSession.
+        conversation (Conversation): The conversation to delete.
+        proxy (str, optional): Proxy URL. Defaults to None.
+
+    Returns:
+        bool: True if deletion was successful, False otherwise.
+    """
     url = "https://sydney.bing.com/sydney/DeleteSingleConversation"
     json = {
         "conversationId": conversation.conversationId,
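A hedged usage sketch of the conversation helpers (real calls additionally need authenticated Bing cookies in the session headers):

import asyncio
from aiohttp import ClientSession
from g4f.Provider.bing.conversation import create_conversation, delete_conversation

async def main():
    async with ClientSession() as session:
        conversation = await create_conversation(session)
        print(conversation.conversationId)
        await delete_conversation(session, conversation)

asyncio.run(main())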
g4f/Provider/bing/create_images.py
@@ -1,9 +1,16 @@
+"""
+This module provides functionalities for creating and managing images using Bing's service.
+It includes functions for user login, session creation, image creation, and processing.
+"""
+
 import asyncio
-import time, json, os
+import time
+import json
+import os
 from aiohttp import ClientSession
 from bs4 import BeautifulSoup
 from urllib.parse import quote
-from typing import Generator
+from typing import Generator, List, Dict
 
 from ..create_images import CreateImagesProvider
 from ..helper import get_cookies, get_event_loop
@@ -12,23 +19,47 @@ from ...base_provider import ProviderType
 from ...image import format_images_markdown
 
 BING_URL = "https://www.bing.com"
+TIMEOUT_LOGIN = 1200
+TIMEOUT_IMAGE_CREATION = 300
+ERRORS = [
+    "this prompt is being reviewed",
+    "this prompt has been blocked",
+    "we're working hard to offer image creator in more languages",
+    "we can't create your images right now"
+]
+BAD_IMAGES = [
+    "https://r.bing.com/rp/in-2zU3AJUdkgFe7ZKv19yPBHVs.png",
+    "https://r.bing.com/rp/TX9QuO3WzcCJz1uaaSwQAz39Kb0.jpg",
+]
 
-def wait_for_login(driver: WebDriver, timeout: int = 1200) -> None:
+def wait_for_login(driver: WebDriver, timeout: int = TIMEOUT_LOGIN) -> None:
+    """
+    Waits for the user to log in within a given timeout period.
+
+    Args:
+        driver (WebDriver): Webdriver for browser automation.
+        timeout (int): Maximum waiting time in seconds.
+
+    Raises:
+        RuntimeError: If the login process exceeds the timeout.
+    """
     driver.get(f"{BING_URL}/")
     value = driver.get_cookie("_U")
     if value:
         return
     start_time = time.time()
-    while True:
+    while not driver.get_cookie("_U"):
         if time.time() - start_time > timeout:
             raise RuntimeError("Timeout error")
-        value = driver.get_cookie("_U")
-        if value:
-            time.sleep(1)
-            return
         time.sleep(0.5)
 
-def create_session(cookies: dict) -> ClientSession:
+def create_session(cookies: Dict[str, str]) -> ClientSession:
+    """
+    Creates a new client session with specified cookies and headers.
+
+    Args:
+        cookies (Dict[str, str]): Cookies to be used for the session.
+
+    Returns:
+        ClientSession: The created client session.
+    """
     headers = {
         "accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
         "accept-encoding": "gzip, deflate, br",
@@ -47,28 +78,32 @@ def create_session(cookies: dict) -> ClientSession:
         "upgrade-insecure-requests": "1",
     }
     if cookies:
-        headers["cookie"] = "; ".join(f"{k}={v}" for k, v in cookies.items())
+        headers["Cookie"] = "; ".join(f"{k}={v}" for k, v in cookies.items())
     return ClientSession(headers=headers)
 
-async def create_images(session: ClientSession, prompt: str, proxy: str = None, timeout: int = 300) -> list:
-    url_encoded_prompt = quote(prompt)
+async def create_images(session: ClientSession, prompt: str, proxy: str = None, timeout: int = TIMEOUT_IMAGE_CREATION) -> List[str]:
+    """
+    Creates images based on a given prompt using Bing's service.
+
+    Args:
+        session (ClientSession): Active client session.
+        prompt (str): Prompt to generate images.
+        proxy (str, optional): Proxy configuration.
+        timeout (int): Timeout for the request.
+
+    Returns:
+        List[str]: A list of URLs to the created images.
+
+    Raises:
+        RuntimeError: If image creation fails or times out.
+    """
+    url_encoded_prompt = quote(prompt)
     payload = f"q={url_encoded_prompt}&rt=4&FORM=GENCRE"
     url = f"{BING_URL}/images/create?q={url_encoded_prompt}&rt=4&FORM=GENCRE"
-    async with session.post(
-        url,
-        allow_redirects=False,
-        data=payload,
-        timeout=timeout,
-    ) as response:
+    async with session.post(url, allow_redirects=False, data=payload, timeout=timeout) as response:
         response.raise_for_status()
-        errors = [
-            "this prompt is being reviewed",
-            "this prompt has been blocked",
-            "we're working hard to offer image creator in more languages",
-            "we can't create your images right now"
-        ]
         text = (await response.text()).lower()
-        for error in errors:
+        for error in ERRORS:
             if error in text:
                 raise RuntimeError(f"Create images failed: {error}")
         if response.status != 302:
@@ -107,54 +142,109 @@ async def create_images(session: ClientSession, prompt: str, proxy: str = None,
             raise RuntimeError(error)
     return read_images(text)
 
-def read_images(text: str) -> list:
-    html_soup = BeautifulSoup(text, "html.parser")
-    tags = html_soup.find_all("img")
-    image_links = [img["src"] for img in tags if "mimg" in img["class"]]
-    images = [link.split("?w=")[0] for link in image_links]
-    bad_images = [
-        "https://r.bing.com/rp/in-2zU3AJUdkgFe7ZKv19yPBHVs.png",
-        "https://r.bing.com/rp/TX9QuO3WzcCJz1uaaSwQAz39Kb0.jpg",
-    ]
-    if any(im in bad_images for im in images):
+def read_images(html_content: str) -> List[str]:
+    """
+    Extracts image URLs from the HTML content.
+
+    Args:
+        html_content (str): HTML content containing image URLs.
+
+    Returns:
+        List[str]: A list of image URLs.
+    """
+    soup = BeautifulSoup(html_content, "html.parser")
+    tags = soup.find_all("img", class_="mimg")
+    images = [img["src"].split("?w=")[0] for img in tags]
+    if any(im in BAD_IMAGES for im in images):
         raise RuntimeError("Bad images found")
     if not images:
         raise RuntimeError("No images found")
     return images
 
-async def create_images_markdown(cookies: dict, prompt: str, proxy: str = None) -> str:
-    session = create_session(cookies)
-    try:
+async def create_images_markdown(cookies: Dict[str, str], prompt: str, proxy: str = None) -> str:
+    """
+    Creates markdown formatted string with images based on the prompt.
+
+    Args:
+        cookies (Dict[str, str]): Cookies to be used for the session.
+        prompt (str): Prompt to generate images.
+        proxy (str, optional): Proxy configuration.
+
+    Returns:
+        str: Markdown formatted string with images.
+    """
+    async with create_session(cookies) as session:
         images = await create_images(session, prompt, proxy)
         return format_images_markdown(images, prompt)
-    finally:
-        await session.close()
 
-def get_cookies_from_browser(proxy: str = None) -> dict:
-    driver = get_browser(proxy=proxy)
-    try:
+def get_cookies_from_browser(proxy: str = None) -> Dict[str, str]:
+    """
+    Retrieves cookies from the browser using webdriver.
+
+    Args:
+        proxy (str, optional): Proxy configuration.
+
+    Returns:
+        Dict[str, str]: Retrieved cookies.
+    """
+    with get_browser(proxy=proxy) as driver:
         wait_for_login(driver)
         time.sleep(1)
         return get_driver_cookies(driver)
-    finally:
-        driver.quit()
 
-def create_completion(prompt: str, cookies: dict = None, proxy: str = None) -> Generator:
-    loop = get_event_loop()
-    if not cookies:
-        cookies = get_cookies(".bing.com")
-        if "_U" not in cookies:
-            login_url = os.environ.get("G4F_LOGIN_URL")
-            if login_url:
-                yield f"Please login: [Bing]({login_url})\n\n"
-            cookies = get_cookies_from_browser(proxy)
-    yield loop.run_until_complete(create_images_markdown(cookies, prompt, proxy))
+class CreateImagesBing:
+    """A class for creating images using Bing."""
 
-async def create_async(prompt: str, cookies: dict = None, proxy: str = None) -> str:
-    if not cookies:
-        cookies = get_cookies(".bing.com")
-        if "_U" not in cookies:
-            cookies = get_cookies_from_browser(proxy)
-    return await create_images_markdown(cookies, prompt, proxy)
+    _cookies: Dict[str, str] = {}
+
+    @classmethod
+    def create_completion(cls, prompt: str, cookies: Dict[str, str] = None, proxy: str = None) -> Generator[str, None, None]:
+        """
+        Generator for creating image completion based on a prompt.
+
+        Args:
+            prompt (str): Prompt to generate images.
+            cookies (Dict[str, str], optional): Cookies for the session. If None, cookies are retrieved automatically.
+            proxy (str, optional): Proxy configuration.
+
+        Yields:
+            Generator[str, None, None]: The final output as markdown formatted string with images.
+        """
+        loop = get_event_loop()
+        cookies = cookies or cls._cookies or get_cookies(".bing.com")
+        if "_U" not in cookies:
+            login_url = os.environ.get("G4F_LOGIN_URL")
+            if login_url:
+                yield f"Please login: [Bing]({login_url})\n\n"
+            cls._cookies = cookies = get_cookies_from_browser(proxy)
+        yield loop.run_until_complete(create_images_markdown(cookies, prompt, proxy))
+
+    @classmethod
+    async def create_async(cls, prompt: str, cookies: Dict[str, str] = None, proxy: str = None) -> str:
+        """
+        Asynchronously creates a markdown formatted string with images based on the prompt.
+
+        Args:
+            prompt (str): Prompt to generate images.
+            cookies (Dict[str, str], optional): Cookies for the session. If None, cookies are retrieved automatically.
+            proxy (str, optional): Proxy configuration.
+
+        Returns:
+            str: Markdown formatted string with images.
+        """
+        cookies = cookies or cls._cookies or get_cookies(".bing.com")
+        if "_U" not in cookies:
+            cls._cookies = cookies = get_cookies_from_browser(proxy)
+        return await create_images_markdown(cookies, prompt, proxy)
 
 def patch_provider(provider: ProviderType) -> CreateImagesProvider:
-    return CreateImagesProvider(provider, create_completion, create_async)
+    """
+    Patches a provider to include image creation capabilities.
+
+    Args:
+        provider (ProviderType): The provider to be patched.
+
+    Returns:
+        CreateImagesProvider: The patched provider with image creation capabilities.
+    """
+    return CreateImagesProvider(provider, CreateImagesBing.create_completion, CreateImagesBing.create_async)
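A hedged sketch of patch_provider in use: the wrapped provider replaces <img data-prompt="..."> tags in model output with generated images (import path assumed from this file's location):

import g4f
from g4f.Provider.bing.create_images import patch_provider

provider = patch_provider(g4f.Provider.Bing)
response = g4f.ChatCompletion.create(
    model=g4f.models.default,
    messages=[{"role": "user", "content": "Draw a cat"}],
    provider=provider,
)
print(response)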
g4f/Provider/bing/upload_image.py
@@ -1,64 +1,107 @@
-from __future__ import annotations
+"""
+Module to handle image uploading and processing for Bing AI integrations.
+"""
+
+from __future__ import annotations
 import string
 import random
 import json
 import math
-from ...typing import ImageType
 from aiohttp import ClientSession
+from PIL import Image
+
+from ...typing import ImageType, Tuple
 from ...image import to_image, process_image, to_base64, ImageResponse
 
-image_config = {
+IMAGE_CONFIG = {
     "maxImagePixels": 360000,
     "imageCompressionRate": 0.7,
-    "enableFaceBlurDebug": 0,
+    "enableFaceBlurDebug": False,
 }
 
 async def upload_image(
-        session: ClientSession,
-        image: ImageType,
-        tone: str,
+    session: ClientSession,
+    image_data: ImageType,
+    tone: str,
     proxy: str = None
 ) -> ImageResponse:
-    image = to_image(image)
-    width, height = image.size
-    max_image_pixels = image_config['maxImagePixels']
-    if max_image_pixels / (width * height) < 1:
-        new_width = int(width * math.sqrt(max_image_pixels / (width * height)))
-        new_height = int(height * math.sqrt(max_image_pixels / (width * height)))
-    else:
-        new_width = width
-        new_height = height
-    new_img = process_image(image, new_width, new_height)
-    new_img_binary_data = to_base64(new_img, image_config['imageCompressionRate'])
-    data, boundary = build_image_upload_api_payload(new_img_binary_data, tone)
-    headers = session.headers.copy()
-    headers["content-type"] = f'multipart/form-data; boundary={boundary}'
-    headers["referer"] = 'https://www.bing.com/search?q=Bing+AI&showconv=1&FORM=hpcodx'
-    headers["origin"] = 'https://www.bing.com'
+    """
+    Uploads an image to Bing's AI service and returns the image response.
+
+    Args:
+        session (ClientSession): The active session.
+        image_data (bytes): The image data to be uploaded.
+        tone (str): The tone of the conversation.
+        proxy (str, optional): Proxy if any. Defaults to None.
+
+    Raises:
+        RuntimeError: If the image upload fails.
+
+    Returns:
+        ImageResponse: The response from the image upload.
+    """
+    image = to_image(image_data)
+    new_width, new_height = calculate_new_dimensions(image)
+    processed_img = process_image(image, new_width, new_height)
+    img_binary_data = to_base64(processed_img, IMAGE_CONFIG['imageCompressionRate'])
+
+    data, boundary = build_image_upload_payload(img_binary_data, tone)
+    headers = prepare_headers(session, boundary)
+
     async with session.post("https://www.bing.com/images/kblob", data=data, headers=headers, proxy=proxy) as response:
         if response.status != 200:
             raise RuntimeError("Failed to upload image.")
-        image_info = await response.json()
-        if not image_info.get('blobId'):
-            raise RuntimeError("Failed to parse image info.")
-        result = {'bcid': image_info.get('blobId', "")}
-        result['blurredBcid'] = image_info.get('processedBlobId', "")
-        if result['blurredBcid'] != "":
-            result["imageUrl"] = "https://www.bing.com/images/blob?bcid=" + result['blurredBcid']
-        elif result['bcid'] != "":
-            result["imageUrl"] = "https://www.bing.com/images/blob?bcid=" + result['bcid']
-        result['originalImageUrl'] = (
-            "https://www.bing.com/images/blob?bcid="
-            + result['blurredBcid']
-            if image_config["enableFaceBlurDebug"]
-            else "https://www.bing.com/images/blob?bcid="
-            + result['bcid']
-        )
-        return ImageResponse(result["imageUrl"], "", result)
+        return parse_image_response(await response.json())
 
-def build_image_upload_api_payload(image_bin: str, tone: str):
-    payload = {
+def calculate_new_dimensions(image: Image.Image) -> Tuple[int, int]:
+    """
+    Calculates the new dimensions for the image based on the maximum allowed pixels.
+
+    Args:
+        image (Image): The PIL Image object.
+
+    Returns:
+        Tuple[int, int]: The new width and height for the image.
+    """
+    width, height = image.size
+    max_image_pixels = IMAGE_CONFIG['maxImagePixels']
+    if max_image_pixels / (width * height) < 1:
+        scale_factor = math.sqrt(max_image_pixels / (width * height))
+        return int(width * scale_factor), int(height * scale_factor)
+    return width, height
+
+def build_image_upload_payload(image_bin: str, tone: str) -> Tuple[str, str]:
+    """
+    Builds the payload for image uploading.
+
+    Args:
+        image_bin (str): Base64 encoded image binary data.
+        tone (str): The tone of the conversation.
+
+    Returns:
+        Tuple[str, str]: The data and boundary for the payload.
+    """
+    boundary = "----WebKitFormBoundary" + ''.join(random.choices(string.ascii_letters + string.digits, k=16))
+    data = f"--{boundary}\r\n" \
+           f"Content-Disposition: form-data; name=\"knowledgeRequest\"\r\n\r\n" \
+           f"{json.dumps(build_knowledge_request(tone), ensure_ascii=False)}\r\n" \
+           f"--{boundary}\r\n" \
+           f"Content-Disposition: form-data; name=\"imageBase64\"\r\n\r\n" \
+           f"{image_bin}\r\n" \
+           f"--{boundary}--\r\n"
+    return data, boundary
+
+def build_knowledge_request(tone: str) -> dict:
+    """
+    Builds the knowledge request payload.
+
+    Args:
+        tone (str): The tone of the conversation.
+
+    Returns:
+        dict: The knowledge request payload.
+    """
+    return {
         'invokedSkills': ["ImageById"],
         'subscriptionId': "Bing.Chat.Multimodal",
         'invokedSkillsRequestData': {
@@ -69,21 +112,46 @@ def build_image_upload_api_payload(image_bin: str, tone: str):
             'convotone': tone
         }
     }
-    knowledge_request = {
-        'imageInfo': {},
-        'knowledgeRequest': payload
-    }
-    boundary="----WebKitFormBoundary" + ''.join(random.choices(string.ascii_letters + string.digits, k=16))
-    data = (
-        f'--{boundary}'
-        + '\r\nContent-Disposition: form-data; name="knowledgeRequest"\r\n\r\n'
-        + json.dumps(knowledge_request, ensure_ascii=False)
-        + "\r\n--"
-        + boundary
-        + '\r\nContent-Disposition: form-data; name="imageBase64"\r\n\r\n'
-        + image_bin
-        + "\r\n--"
-        + boundary
-        + "--\r\n"
-    )
-    return data, boundary
+
+def prepare_headers(session: ClientSession, boundary: str) -> dict:
+    """
+    Prepares the headers for the image upload request.
+
+    Args:
+        session (ClientSession): The active session.
+        boundary (str): The boundary string for the multipart/form-data.
+
+    Returns:
+        dict: The headers for the request.
+    """
+    headers = session.headers.copy()
+    headers["Content-Type"] = f'multipart/form-data; boundary={boundary}'
+    headers["Referer"] = 'https://www.bing.com/search?q=Bing+AI&showconv=1&FORM=hpcodx'
+    headers["Origin"] = 'https://www.bing.com'
+    return headers
+
+def parse_image_response(response: dict) -> ImageResponse:
+    """
+    Parses the response from the image upload.
+
+    Args:
+        response (dict): The response dictionary.
+
+    Raises:
+        RuntimeError: If parsing the image info fails.
+
+    Returns:
+        ImageResponse: The parsed image response.
+    """
+    if not response.get('blobId'):
+        raise RuntimeError("Failed to parse image info.")
+
+    result = {'bcid': response.get('blobId', ""), 'blurredBcid': response.get('processedBlobId', "")}
+    result["imageUrl"] = f"https://www.bing.com/images/blob?bcid={result['blurredBcid'] or result['bcid']}"
+
+    result['originalImageUrl'] = (
+        f"https://www.bing.com/images/blob?bcid={result['blurredBcid']}"
+        if IMAGE_CONFIG["enableFaceBlurDebug"] else
+        f"https://www.bing.com/images/blob?bcid={result['bcid']}"
+    )
+    return ImageResponse(result["imageUrl"], "", result)
g4f/Provider/create_images.py
@@ -8,13 +8,31 @@ from ..base_provider import BaseProvider, ProviderType
 
 system_message = """
 You can generate custom images with the DALL-E 3 image generator.
-To generate a image with a prompt, do this:
+To generate an image with a prompt, do this:
 <img data-prompt=\"keywords for the image\">
 Don't use images with data uri. It is important to use a prompt instead.
 <img data-prompt=\"image caption\">
 """
 
 class CreateImagesProvider(BaseProvider):
+    """
+    Provider class for creating images based on text prompts.
+
+    This provider handles image creation requests embedded within message content,
+    using provided image creation functions.
+
+    Attributes:
+        provider (ProviderType): The underlying provider to handle non-image related tasks.
+        create_images (callable): A function to create images synchronously.
+        create_images_async (callable): A function to create images asynchronously.
+        system_message (str): A message that explains the image creation capability.
+        include_placeholder (bool): Flag to determine whether to include the image placeholder in the output.
+        __name__ (str): Name of the provider.
+        url (str): URL of the provider.
+        working (bool): Indicates if the provider is operational.
+        supports_stream (bool): Indicates if the provider supports streaming.
+    """
+
     def __init__(
         self,
         provider: ProviderType,
@@ -23,6 +41,16 @@ class CreateImagesProvider(BaseProvider):
         system_message: str = system_message,
         include_placeholder: bool = True
     ) -> None:
+        """
+        Initializes the CreateImagesProvider.
+
+        Args:
+            provider (ProviderType): The underlying provider.
+            create_images (callable): Function to create images synchronously.
+            create_async (callable): Function to create images asynchronously.
+            system_message (str, optional): System message to be prefixed to messages. Defaults to a predefined message.
+            include_placeholder (bool, optional): Whether to include image placeholders in the output. Defaults to True.
+        """
         self.provider = provider
         self.create_images = create_images
         self.create_images_async = create_async
@@ -40,6 +68,22 @@ class CreateImagesProvider(BaseProvider):
         stream: bool = False,
         **kwargs
     ) -> CreateResult:
+        """
+        Creates a completion result, processing any image creation prompts found within the messages.
+
+        Args:
+            model (str): The model to use for creation.
+            messages (Messages): The messages to process, which may contain image prompts.
+            stream (bool, optional): Indicates whether to stream the results. Defaults to False.
+            **kwargs: Additional keyword arguments for the provider.
+
+        Yields:
+            CreateResult: Yields chunks of the processed messages, including image data if applicable.
+
+        Note:
+            This method processes messages to detect image creation prompts. When such a prompt is found,
+            it calls the synchronous image creation function and includes the resulting image in the output.
+        """
         messages.insert(0, {"role": "system", "content": self.system_message})
         buffer = ""
         for chunk in self.provider.create_completion(model, messages, stream, **kwargs):
@@ -71,6 +115,21 @@ class CreateImagesProvider(BaseProvider):
         messages: Messages,
         **kwargs
     ) -> str:
+        """
+        Asynchronously creates a response, processing any image creation prompts found within the messages.
+
+        Args:
+            model (str): The model to use for creation.
+            messages (Messages): The messages to process, which may contain image prompts.
+            **kwargs: Additional keyword arguments for the provider.
+
+        Returns:
+            str: The processed response string, including asynchronously generated image data if applicable.
+
+        Note:
+            This method processes messages to detect image creation prompts. When such a prompt is found,
+            it calls the asynchronous image creation function and includes the resulting image in the output.
+        """
         messages.insert(0, {"role": "system", "content": self.system_message})
         response = await self.provider.create_async(model, messages, **kwargs)
         matches = re.findall(r'(<img data-prompt="(.*?)">)', response)
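The prompt-extraction step from create_async in miniature: the regex captures both the full placeholder tag and the prompt text inside it:

import re

response = 'Here you go: <img data-prompt="a red fox in snow">'
matches = re.findall(r'(<img data-prompt="(.*?)">)', response)
for placeholder, prompt in matches:
    print(placeholder, "->", prompt)
# <img data-prompt="a red fox in snow"> -> a red fox in snow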
g4f/Provider/helper.py
@@ -1,36 +1,31 @@
 from __future__ import annotations
 
 import asyncio
-import webbrowser
-import random
-import string
-import secrets
 import os
-from os import path
+import random
+import secrets
+import string
 from asyncio import AbstractEventLoop, BaseEventLoop
 from platformdirs import user_config_dir
 from browser_cookie3 import (
-    chrome,
-    chromium,
-    opera,
-    opera_gx,
-    brave,
-    edge,
-    vivaldi,
-    firefox,
-    _LinuxPasswordManager
+    chrome, chromium, opera, opera_gx,
+    brave, edge, vivaldi, firefox,
+    _LinuxPasswordManager, BrowserCookieError
 )
 
 from ..typing import Dict, Messages
 from .. import debug
 
-# Local Cookie Storage
+# Global variable to store cookies
 _cookies: Dict[str, Dict[str, str]] = {}
 
-# If loop closed or not set, create new event loop.
-# If event loop is already running, handle nested event loops.
-# If "nest_asyncio" is installed, patch the event loop.
 def get_event_loop() -> AbstractEventLoop:
+    """
+    Get the current asyncio event loop. If the loop is closed or not set, create a new event loop.
+    If a loop is running, handle nested event loops. Patch the loop if 'nest_asyncio' is installed.
+
+    Returns:
+        AbstractEventLoop: The current or new event loop.
+    """
     try:
         loop = asyncio.get_event_loop()
         if isinstance(loop, BaseEventLoop):
@@ -39,61 +34,50 @@ def get_event_loop() -> AbstractEventLoop:
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)
     try:
         # Is running event loop
         asyncio.get_running_loop()
         if not hasattr(loop.__class__, "_nest_patched"):
             import nest_asyncio
             nest_asyncio.apply(loop)
     except RuntimeError:
         # No running event loop
         pass
     except ImportError:
         raise RuntimeError(
-            'Use "create_async" instead of "create" function in a running event loop. Or install the "nest_asyncio" package.'
+            'Use "create_async" instead of "create" function in a running event loop. Or install "nest_asyncio" package.'
         )
     return loop
 
-def init_cookies():
-    urls = [
-        'https://chat-gpt.org',
-        'https://www.aitianhu.com',
-        'https://chatgptfree.ai',
-        'https://gptchatly.com',
-        'https://bard.google.com',
-        'https://huggingface.co/chat',
-        'https://open-assistant.io/chat'
-    ]
-
-    browsers = ['google-chrome', 'chrome', 'firefox', 'safari']
-
-    def open_urls_in_browser(browser):
-        b = webbrowser.get(browser)
-        for url in urls:
-            b.open(url, new=0, autoraise=True)
-
-    for browser in browsers:
-        try:
-            open_urls_in_browser(browser)
-            break
-        except webbrowser.Error:
-            continue
-
 # Check for broken dbus address in docker image
 if os.environ.get('DBUS_SESSION_BUS_ADDRESS') == "/dev/null":
     _LinuxPasswordManager.get_password = lambda a, b: b"secret"
 
-# Load cookies for a domain from all supported browsers.
-# Cache the results in the "_cookies" variable.
-def get_cookies(domain_name=''):
+def get_cookies(domain_name: str = '') -> Dict[str, str]:
+    """
+    Load cookies for a given domain from all supported browsers and cache the results.
+
+    Args:
+        domain_name (str): The domain for which to load cookies.
+
+    Returns:
+        Dict[str, str]: A dictionary of cookie names and values.
+    """
     if domain_name in _cookies:
         return _cookies[domain_name]
-    def g4f(domain_name):
-        user_data_dir = user_config_dir("g4f")
-        cookie_file = path.join(user_data_dir, "Default", "Cookies")
-        return [] if not path.exists(cookie_file) else chrome(cookie_file, domain_name)
 
+    cookies = _load_cookies_from_browsers(domain_name)
+    _cookies[domain_name] = cookies
+    return cookies
+
+def _load_cookies_from_browsers(domain_name: str) -> Dict[str, str]:
+    """
+    Helper function to load cookies from various browsers.
+
+    Args:
+        domain_name (str): The domain for which to load cookies.
+
+    Returns:
+        Dict[str, str]: A dictionary of cookie names and values.
+    """
     cookies = {}
-    for cookie_fn in [g4f, chrome, chromium, opera, opera_gx, brave, edge, vivaldi, firefox]:
+    for cookie_fn in [_g4f, chrome, chromium, opera, opera_gx, brave, edge, vivaldi, firefox]:
         try:
             cookie_jar = cookie_fn(domain_name=domain_name)
             if len(cookie_jar) and debug.logging:
@ -101,13 +85,38 @@ def get_cookies(domain_name=''):
|
||||
for cookie in cookie_jar:
|
||||
if cookie.name not in cookies:
|
||||
cookies[cookie.name] = cookie.value
|
||||
except:
|
||||
except BrowserCookieError:
|
||||
pass
|
||||
_cookies[domain_name] = cookies
|
||||
return _cookies[domain_name]
|
||||
except Exception as e:
|
||||
if debug.logging:
|
||||
print(f"Error reading cookies from {cookie_fn.__name__} for {domain_name}: {e}")
|
||||
return cookies
|
||||
|
||||
def _g4f(domain_name: str) -> list:
|
||||
"""
|
||||
Load cookies from the 'g4f' browser (if exists).
|
||||
|
||||
Args:
|
||||
domain_name (str): The domain for which to load cookies.
|
||||
|
||||
Returns:
|
||||
list: List of cookies.
|
||||
"""
|
||||
user_data_dir = user_config_dir("g4f")
|
||||
cookie_file = os.path.join(user_data_dir, "Default", "Cookies")
|
||||
return [] if not os.path.exists(cookie_file) else chrome(cookie_file, domain_name)
|
||||
|
||||
def format_prompt(messages: Messages, add_special_tokens=False) -> str:
|
||||
"""
|
||||
Format a series of messages into a single string, optionally adding special tokens.
|
||||
|
||||
Args:
|
||||
messages (Messages): A list of message dictionaries, each containing 'role' and 'content'.
|
||||
add_special_tokens (bool): Whether to add special formatting tokens.
|
||||
|
||||
Returns:
|
||||
str: A formatted string containing all messages.
|
||||
"""
|
||||
if not add_special_tokens and len(messages) <= 1:
|
||||
return messages[0]["content"]
|
||||
formatted = "\n".join([
|
||||
@ -116,12 +125,26 @@ def format_prompt(messages: Messages, add_special_tokens=False) -> str:
|
||||
])
|
||||
return f"{formatted}\nAssistant:"
|
||||
|
||||
|
||||
def get_random_string(length: int = 10) -> str:
|
||||
"""
|
||||
Generate a random string of specified length, containing lowercase letters and digits.
|
||||
|
||||
Args:
|
||||
length (int, optional): Length of the random string to generate. Defaults to 10.
|
||||
|
||||
Returns:
|
||||
str: A random string of the specified length.
|
||||
"""
|
||||
return ''.join(
|
||||
random.choice(string.ascii_lowercase + string.digits)
|
||||
for _ in range(length)
|
||||
)
|
||||
|
||||
def get_random_hex() -> str:
|
||||
"""
|
||||
Generate a random hexadecimal string of a fixed length.
|
||||
|
||||
Returns:
|
||||
str: A random hexadecimal string of 32 characters (16 bytes).
|
||||
"""
|
||||
return secrets.token_hex(16).zfill(32)
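A short usage sketch for the prompt and token helpers above; outputs are illustrative, and the exact per-message line format lives in the elided hunk body.

from g4f.Provider.helper import format_prompt, get_random_string, get_random_hex  # assumed module path

messages = [
    {"role": "system", "content": "You are concise."},
    {"role": "user", "content": "Hello"},
]
print(format_prompt(messages))  # joined transcript ending in "\nAssistant:"
print(get_random_string(8))     # eight lowercase letters/digits
print(get_random_hex())         # 32 hex characters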

@@ -1,6 +1,9 @@
from __future__ import annotations
import asyncio
import uuid
import json
import os

import uuid, json, asyncio, os
from py_arkose_generator.arkose import get_values_for_request
from async_property import async_cached_property
from selenium.webdriver.common.by import By
@@ -14,7 +17,8 @@ from ...typing import AsyncResult, Messages
from ...requests import StreamSession
from ...image import to_image, to_bytes, ImageType, ImageResponse

models = {
# Aliases for model names
MODELS = {
    "gpt-3.5": "text-davinci-002-render-sha",
    "gpt-3.5-turbo": "text-davinci-002-render-sha",
    "gpt-4": "gpt-4",
@@ -22,13 +26,15 @@ models = {
}

class OpenaiChat(AsyncGeneratorProvider):
    url = "https://chat.openai.com"
    working = True
    needs_auth = True
    """A class for creating and managing conversations with OpenAI chat service"""

    url = "https://chat.openai.com"
    working = True
    needs_auth = True
    supports_gpt_35_turbo = True
    supports_gpt_4 = True
    _cookies: dict = {}
    _default_model: str = None
    supports_gpt_4 = True
    _cookies: dict = {}
    _default_model: str = None

    @classmethod
    async def create(
@@ -43,6 +49,23 @@ class OpenaiChat(AsyncGeneratorProvider):
        image: ImageType = None,
        **kwargs
    ) -> Response:
        """Create a new conversation or continue an existing one

        Args:
            prompt: The user input to start or continue the conversation
            model: The name of the model to use for generating responses
            messages: The list of previous messages in the conversation
            history_disabled: A flag indicating if the history and training should be disabled
            action: The type of action to perform, either "next", "continue", or "variant"
            conversation_id: The ID of the existing conversation, if any
            parent_id: The ID of the parent message, if any
            image: The image to include in the user input, if any
            **kwargs: Additional keyword arguments to pass to the generator

        Returns:
            A Response object that contains the generator, action, messages, and options
        """
        # Add the user input to the messages list
        if prompt:
            messages.append({
                "role": "user",
@@ -67,20 +90,33 @@ class OpenaiChat(AsyncGeneratorProvider):
        )

    @classmethod
    async def upload_image(
    async def _upload_image(
        cls,
        session: StreamSession,
        headers: dict,
        image: ImageType
    ) -> ImageResponse:
        """Upload an image to the service and get the download URL

        Args:
            session: The StreamSession object to use for requests
            headers: The headers to include in the requests
            image: The image to upload, either a PIL Image object or a bytes object

        Returns:
            An ImageResponse object that contains the download URL, file name, and other data
        """
        # Convert the image to a PIL Image object and get the extension
        image = to_image(image)
        extension = image.format.lower()
        # Convert the image to a bytes object and get the size
        data_bytes = to_bytes(image)
        data = {
            "file_name": f"{image.width}x{image.height}.{extension}",
            "file_size": len(data_bytes),
            "use_case": "multimodal"
        }
        # Post the image data to the service and get the image data
        async with session.post(f"{cls.url}/backend-api/files", json=data, headers=headers) as response:
            response.raise_for_status()
            image_data = {
@@ -91,6 +127,7 @@ class OpenaiChat(AsyncGeneratorProvider):
                "height": image.height,
                "width": image.width
            }
        # Put the image bytes to the upload URL and check the status
        async with session.put(
            image_data["upload_url"],
            data=data_bytes,
@@ -100,6 +137,7 @@ class OpenaiChat(AsyncGeneratorProvider):
            }
        ) as response:
            response.raise_for_status()
        # Post the file ID to the service and get the download URL
        async with session.post(
            f"{cls.url}/backend-api/files/{image_data['file_id']}/uploaded",
            json={},
@@ -110,24 +148,45 @@ class OpenaiChat(AsyncGeneratorProvider):
            return ImageResponse(download_url, image_data["file_name"], image_data)

    @classmethod
    async def get_default_model(cls, session: StreamSession, headers: dict):
    async def _get_default_model(cls, session: StreamSession, headers: dict):
        """Get the default model name from the service

        Args:
            session: The StreamSession object to use for requests
            headers: The headers to include in the requests

        Returns:
            The default model name as a string
        """
        # Check the cache for the default model
        if cls._default_model:
            model = cls._default_model
        else:
            async with session.get(f"{cls.url}/backend-api/models", headers=headers) as response:
                data = await response.json()
                if "categories" in data:
                    model = data["categories"][-1]["default_model"]
                else:
                    RuntimeError(f"Response: {data}")
            cls._default_model = model
        return model
            return cls._default_model
        # Get the models data from the service
        async with session.get(f"{cls.url}/backend-api/models", headers=headers) as response:
            data = await response.json()
            if "categories" in data:
                cls._default_model = data["categories"][-1]["default_model"]
            else:
                raise RuntimeError(f"Response: {data}")
        return cls._default_model

    @classmethod
    def create_messages(cls, prompt: str, image_response: ImageResponse = None):
    def _create_messages(cls, prompt: str, image_response: ImageResponse = None):
        """Create a list of messages for the user input

        Args:
            prompt: The user input as a string
            image_response: The image response object, if any

        Returns:
            A list of messages with the user input and the image, if any
        """
        # Check if there is an image response
        if not image_response:
            # Create a content object with the text type and the prompt
            content = {"content_type": "text", "parts": [prompt]}
        else:
            # Create a content object with the multimodal text type and the image and the prompt
            content = {
                "content_type": "multimodal_text",
                "parts": [{
@@ -137,12 +196,15 @@ class OpenaiChat(AsyncGeneratorProvider):
                    "width": image_response.get("width"),
                }, prompt]
            }
        # Create a message object with the user role and the content
        messages = [{
            "id": str(uuid.uuid4()),
            "author": {"role": "user"},
            "content": content,
        }]
        # Check if there is an image response
        if image_response:
            # Add the metadata object with the attachments
            messages[0]["metadata"] = {
                "attachments": [{
                    "height": image_response.get("height"),
@@ -156,19 +218,38 @@ class OpenaiChat(AsyncGeneratorProvider):
        return messages

    @classmethod
    async def get_image_response(cls, session: StreamSession, headers: dict, line: dict):
        if "parts" in line["message"]["content"]:
            part = line["message"]["content"]["parts"][0]
            if "asset_pointer" in part and part["metadata"]:
                file_id = part["asset_pointer"].split("file-service://", 1)[1]
                prompt = part["metadata"]["dalle"]["prompt"]
                async with session.get(
                    f"{cls.url}/backend-api/files/{file_id}/download",
                    headers=headers
                ) as response:
                    response.raise_for_status()
                    download_url = (await response.json())["download_url"]
                    return ImageResponse(download_url, prompt)
    async def _get_generated_image(cls, session: StreamSession, headers: dict, line: dict) -> ImageResponse:
        """
        Retrieves the image response based on the message content.

        :param session: The StreamSession object.
        :param headers: HTTP headers for the request.
        :param line: The line of response containing image information.
        :return: An ImageResponse object with the image details.
        """
        if "parts" not in line["message"]["content"]:
            return
        first_part = line["message"]["content"]["parts"][0]
        if "asset_pointer" not in first_part or "metadata" not in first_part:
            return
        file_id = first_part["asset_pointer"].split("file-service://", 1)[1]
        prompt = first_part["metadata"]["dalle"]["prompt"]
        try:
            async with session.get(f"{cls.url}/backend-api/files/{file_id}/download", headers=headers) as response:
                response.raise_for_status()
                download_url = (await response.json())["download_url"]
                return ImageResponse(download_url, prompt)
        except Exception as e:
            raise RuntimeError(f"Error in downloading image: {e}")

    @classmethod
    async def _delete_conversation(cls, session: StreamSession, headers: dict, conversation_id: str):
        async with session.patch(
            f"{cls.url}/backend-api/conversation/{conversation_id}",
            json={"is_visible": False},
            headers=headers
        ) as response:
            response.raise_for_status()

    @classmethod
    async def create_async_generator(
@@ -188,26 +269,47 @@ class OpenaiChat(AsyncGeneratorProvider):
        response_fields: bool = False,
        **kwargs
    ) -> AsyncResult:
        if model in models:
            model = models[model]
        """
        Create an asynchronous generator for the conversation.

        Args:
            model (str): The model name.
            messages (Messages): The list of previous messages.
            proxy (str): Proxy to use for requests.
            timeout (int): Timeout for requests.
            access_token (str): Access token for authentication.
            cookies (dict): Cookies to use for authentication.
            auto_continue (bool): Flag to automatically continue the conversation.
            history_disabled (bool): Flag to disable history and training.
            action (str): Type of action ('next', 'continue', 'variant').
            conversation_id (str): ID of the conversation.
            parent_id (str): ID of the parent message.
            image (ImageType): Image to include in the conversation.
            response_fields (bool): Flag to include response fields in the output.
            **kwargs: Additional keyword arguments.

        Yields:
            AsyncResult: Asynchronous results from the generator.

        Raises:
            RuntimeError: If an error occurs during processing.
        """
        model = MODELS.get(model, model)
        if not parent_id:
            parent_id = str(uuid.uuid4())
        if not cookies:
            cookies = cls._cookies
        if not access_token:
            if not cookies:
                cls._cookies = cookies = get_cookies("chat.openai.com")
            if "access_token" in cookies:
                access_token = cookies["access_token"]
            cookies = cls._cookies or get_cookies("chat.openai.com")
        if not access_token and "access_token" in cookies:
            access_token = cookies["access_token"]
        if not access_token:
            login_url = os.environ.get("G4F_LOGIN_URL")
            if login_url:
                yield f"Please login: [ChatGPT]({login_url})\n\n"
            access_token, cookies = cls.browse_access_token(proxy)
            access_token, cookies = cls._browse_access_token(proxy)
            cls._cookies = cookies
        headers = {
            "Authorization": f"Bearer {access_token}",
        }

        headers = {"Authorization": f"Bearer {access_token}"}

        async with StreamSession(
            proxies={"https": proxy},
            impersonate="chrome110",
@@ -215,11 +317,11 @@ class OpenaiChat(AsyncGeneratorProvider):
            cookies=dict([(name, value) for name, value in cookies.items() if name == "_puid"])
        ) as session:
            if not model:
                model = await cls.get_default_model(session, headers)
                model = await cls._get_default_model(session, headers)
            try:
                image_response = None
                if image:
                    image_response = await cls.upload_image(session, headers, image)
                    image_response = await cls._upload_image(session, headers, image)
                    yield image_response
            except Exception as e:
                yield e
@@ -227,7 +329,7 @@ class OpenaiChat(AsyncGeneratorProvider):
            while not end_turn.is_end:
                data = {
                    "action": action,
                    "arkose_token": await cls.get_arkose_token(session),
                    "arkose_token": await cls._get_arkose_token(session),
                    "conversation_id": conversation_id,
                    "parent_message_id": parent_id,
                    "model": model,
@@ -235,7 +337,7 @@ class OpenaiChat(AsyncGeneratorProvider):
                }
                if action != "continue":
                    prompt = format_prompt(messages) if not conversation_id else messages[-1]["content"]
                    data["messages"] = cls.create_messages(prompt, image_response)
                    data["messages"] = cls._create_messages(prompt, image_response)
                async with session.post(
                    f"{cls.url}/backend-api/conversation",
                    json=data,
@@ -261,62 +363,80 @@ class OpenaiChat(AsyncGeneratorProvider):
                        if "message_type" not in line["message"]["metadata"]:
                            continue
                        try:
                            image_response = await cls.get_image_response(session, headers, line)
                            image_response = await cls._get_generated_image(session, headers, line)
                            if image_response:
                                yield image_response
                        except Exception as e:
                            yield e
                        if line["message"]["author"]["role"] != "assistant":
                            continue
                        if line["message"]["metadata"]["message_type"] in ("next", "continue", "variant"):
                            conversation_id = line["conversation_id"]
                            parent_id = line["message"]["id"]
                            if response_fields:
                                response_fields = False
                                yield ResponseFields(conversation_id, parent_id, end_turn)
                            if "parts" in line["message"]["content"]:
                                new_message = line["message"]["content"]["parts"][0]
                                if len(new_message) > last_message:
                                    yield new_message[last_message:]
                                last_message = len(new_message)
                        if line["message"]["content"]["content_type"] != "text":
                            continue
                        if line["message"]["metadata"]["message_type"] not in ("next", "continue", "variant"):
                            continue
                        conversation_id = line["conversation_id"]
                        parent_id = line["message"]["id"]
                        if response_fields:
                            response_fields = False
                            yield ResponseFields(conversation_id, parent_id, end_turn)
                        if "parts" in line["message"]["content"]:
                            new_message = line["message"]["content"]["parts"][0]
                            if len(new_message) > last_message:
                                yield new_message[last_message:]
                                last_message = len(new_message)
                        if "finish_details" in line["message"]["metadata"]:
                            if line["message"]["metadata"]["finish_details"]["type"] == "stop":
                                end_turn.end()
                break
            except Exception as e:
                yield e
                raise e
                if not auto_continue:
                    break
                action = "continue"
                await asyncio.sleep(5)
            if history_disabled:
                async with session.patch(
                    f"{cls.url}/backend-api/conversation/{conversation_id}",
                    json={"is_visible": False},
                    headers=headers
                ) as response:
                    response.raise_for_status()
            if history_disabled and auto_continue:
                await cls._delete_conversation(session, headers, conversation_id)

    @classmethod
    def browse_access_token(cls, proxy: str = None) -> tuple[str, dict]:
    def _browse_access_token(cls, proxy: str = None) -> tuple[str, dict]:
        """
        Browse to obtain an access token.

        Args:
            proxy (str): Proxy to use for browsing.

        Returns:
            tuple[str, dict]: A tuple containing the access token and cookies.
        """
        driver = get_browser(proxy=proxy)
        try:
            driver.get(f"{cls.url}/")
            WebDriverWait(driver, 1200).until(
                EC.presence_of_element_located((By.ID, "prompt-textarea"))
            WebDriverWait(driver, 1200).until(EC.presence_of_element_located((By.ID, "prompt-textarea")))
            access_token = driver.execute_script(
                "let session = await fetch('/api/auth/session');"
                "let data = await session.json();"
                "let accessToken = data['accessToken'];"
                "let expires = new Date(); expires.setTime(expires.getTime() + 60 * 60 * 24 * 7);"
                "document.cookie = 'access_token=' + accessToken + ';expires=' + expires.toUTCString() + ';path=/';"
                "return accessToken;"
            )
            javascript = """
access_token = (await (await fetch('/api/auth/session')).json())['accessToken'];
expires = new Date(); expires.setTime(expires.getTime() + 60 * 60 * 24 * 7); // One week
document.cookie = 'access_token=' + access_token + ';expires=' + expires.toUTCString() + ';path=/';
return access_token;
"""
            return driver.execute_script(javascript), get_driver_cookies(driver)
            return access_token, get_driver_cookies(driver)
        finally:
            driver.quit()

    @classmethod
    async def get_arkose_token(cls, session: StreamSession) -> str:
    @classmethod
    async def _get_arkose_token(cls, session: StreamSession) -> str:
        """
        Obtain an Arkose token for the session.

        Args:
            session (StreamSession): The session object.

        Returns:
            str: The Arkose token.

        Raises:
            RuntimeError: If unable to retrieve the token.
        """
        config = {
            "pkey": "3D86FBBA-9D22-402A-B512-3420086BA6CC",
            "surl": "https://tcr9i.chat.openai.com",
@@ -332,26 +452,30 @@ return access_token;
        if "token" in decoded_json:
            return decoded_json["token"]
        raise RuntimeError(f"Response: {decoded_json}")

class EndTurn():

class EndTurn:
    """
    Class to represent the end of a conversation turn.
    """
    def __init__(self):
        self.is_end = False

    def end(self):
        self.is_end = True

class ResponseFields():
    def __init__(
        self,
        conversation_id: str,
        message_id: str,
        end_turn: EndTurn
    ):
class ResponseFields:
    """
    Class to encapsulate response fields.
    """
    def __init__(self, conversation_id: str, message_id: str, end_turn: EndTurn):
        self.conversation_id = conversation_id
        self.message_id = message_id
        self._end_turn = end_turn

class Response():
    """
    Class to encapsulate a response from the chat service.
    """
    def __init__(
        self,
        generator: AsyncResult,
@@ -360,13 +484,13 @@ class Response():
        options: dict
    ):
        self._generator = generator
        self.action: str = action
        self.is_end: bool = False
        self.action = action
        self.is_end = False
        self._message = None
        self._messages = messages
        self._options = options
        self._fields = None

    async def generator(self):
        if self._generator:
            self._generator = None
@@ -384,19 +508,16 @@ class Response():

    def __aiter__(self):
        return self.generator()

    @async_cached_property
    async def message(self) -> str:
        [_ async for _ in self.generator()]
        await self.generator()
        return self._message

    async def get_fields(self):
        [_ async for _ in self.generator()]
        return {
            "conversation_id": self._fields.conversation_id,
            "parent_id": self._fields.message_id,
        }
        await self.generator()
        return {"conversation_id": self._fields.conversation_id, "parent_id": self._fields.message_id}

    async def next(self, prompt: str, **kwargs) -> Response:
        return await OpenaiChat.create(
            **self._options,
@@ -406,7 +527,7 @@ class Response():
            **await self.get_fields(),
            **kwargs
        )

    async def do_continue(self, **kwargs) -> Response:
        fields = await self.get_fields()
        if self.is_end:
@@ -418,7 +539,7 @@ class Response():
            **fields,
            **kwargs
        )

    async def variant(self, **kwargs) -> Response:
        if self.action != "next":
            raise RuntimeError("Can't create variant from continue or variant request.")
@@ -429,11 +550,9 @@ class Response():
            **await self.get_fields(),
            **kwargs
        )

    @async_cached_property
    async def messages(self):
        messages = self._messages
        messages.append({
            "role": "assistant", "content": await self.message
        })
        messages.append({"role": "assistant", "content": await self.message})
        return messages
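A hedged usage sketch for the provider above, driven through the public g4f API. It assumes valid ChatGPT credentials are reachable via browser cookies or G4F_LOGIN_URL, and that OpenaiChat is re-exported from g4f.Provider.

import asyncio
import g4f
from g4f.Provider import OpenaiChat  # assumed re-export

async def main() -> None:
    answer = await g4f.ChatCompletion.create_async(
        model="gpt-3.5-turbo",  # resolved to a service model via the MODELS aliases
        messages=[{"role": "user", "content": "Say hello"}],
        provider=OpenaiChat,    # needs_auth: cookies or a browser login are required
    )
    print(answer)

asyncio.run(main())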

@@ -7,8 +7,17 @@ from ..base_provider import BaseRetryProvider
from .. import debug
from ..errors import RetryProviderError, RetryNoProviderError

class RetryProvider(BaseRetryProvider):
    """
    A provider class to handle retries for creating completions with different providers.

    Attributes:
        providers (list): A list of provider instances.
        shuffle (bool): A flag indicating whether to shuffle providers before use.
        exceptions (dict): A dictionary to store exceptions encountered during retries.
        last_provider (BaseProvider): The last provider that was used.
    """

    def create_completion(
        self,
        model: str,
@@ -16,10 +25,21 @@ class RetryProvider(BaseRetryProvider):
        stream: bool = False,
        **kwargs
    ) -> CreateResult:
        if stream:
            providers = [provider for provider in self.providers if provider.supports_stream]
        else:
            providers = self.providers
        """
        Create a completion using available providers, with an option to stream the response.

        Args:
            model (str): The model to be used for completion.
            messages (Messages): The messages to be used for generating completion.
            stream (bool, optional): Flag to indicate if the response should be streamed. Defaults to False.

        Yields:
            CreateResult: Tokens or results from the completion.

        Raises:
            Exception: Any exception encountered during the completion process.
        """
        providers = [p for p in self.providers if stream and p.supports_stream] if stream else self.providers
        if self.shuffle:
            random.shuffle(providers)

@@ -50,10 +70,23 @@ class RetryProvider(BaseRetryProvider):
        messages: Messages,
        **kwargs
    ) -> str:
        """
        Asynchronously create a completion using available providers.

        Args:
            model (str): The model to be used for completion.
            messages (Messages): The messages to be used for generating completion.

        Returns:
            str: The result of the asynchronous completion.

        Raises:
            Exception: Any exception encountered during the asynchronous completion process.
        """
        providers = self.providers
        if self.shuffle:
            random.shuffle(providers)

        self.exceptions = {}
        for provider in providers:
            self.last_provider = provider
@@ -66,13 +99,20 @@ class RetryProvider(BaseRetryProvider):
                self.exceptions[provider.__name__] = e
                if debug.logging:
                    print(f"{provider.__name__}: {e.__class__.__name__}: {e}")

        self.raise_exceptions()

    def raise_exceptions(self) -> None:
        """
        Raise a combined exception if any occurred during retries.

        Raises:
            RetryProviderError: If any provider encountered an exception.
            RetryNoProviderError: If no provider is found.
        """
        if self.exceptions:
            raise RetryProviderError("RetryProvider failed:\n" + "\n".join([
                f"{p}: {exception.__class__.__name__}: {exception}" for p, exception in self.exceptions.items()
            ]))

        raise RetryNoProviderError("No provider found")
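As a usage note, a minimal sketch of composing the retry logic above; the provider names are illustrative, and any working BaseProvider subclasses can be substituted.

from g4f.Provider import Bing, OpenaiChat               # illustrative providers
from g4f.Provider.retry_provider import RetryProvider   # assumed module path

provider = RetryProvider([Bing, OpenaiChat], shuffle=True)
tokens = provider.create_completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi"}],
    stream=False,
)
# create_completion yields chunks; join them for the full text
print("".join(str(token) for token in tokens))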

@@ -15,6 +15,26 @@ def get_model_and_provider(model : Union[Model, str],
                           ignored : list[str] = None,
                           ignore_working: bool = False,
                           ignore_stream: bool = False) -> tuple[str, ProviderType]:
    """
    Retrieves the model and provider based on input parameters.

    Args:
        model (Union[Model, str]): The model to use, either as an object or a string identifier.
        provider (Union[ProviderType, str, None]): The provider to use, either as an object, a string identifier, or None.
        stream (bool): Indicates if the operation should be performed as a stream.
        ignored (list[str], optional): List of provider names to be ignored.
        ignore_working (bool, optional): If True, ignores the working status of the provider.
        ignore_stream (bool, optional): If True, ignores the streaming capability of the provider.

    Returns:
        tuple[str, ProviderType]: A tuple containing the model name and the provider type.

    Raises:
        ProviderNotFoundError: If the provider is not found.
        ModelNotFoundError: If the model is not found.
        ProviderNotWorkingError: If the provider is not working.
        StreamNotSupportedError: If streaming is not supported by the provider.
    """
    if debug.version_check:
        debug.version_check = False
        version.utils.check_version()
@@ -70,7 +90,30 @@ class ChatCompletion:
               ignore_stream_and_auth: bool = False,
               patch_provider: callable = None,
               **kwargs) -> Union[CreateResult, str]:
        """
        Creates a chat completion using the specified model, provider, and messages.

        Args:
            model (Union[Model, str]): The model to use, either as an object or a string identifier.
            messages (Messages): The messages for which the completion is to be created.
            provider (Union[ProviderType, str, None], optional): The provider to use, either as an object, a string identifier, or None.
            stream (bool, optional): Indicates if the operation should be performed as a stream.
            auth (Union[str, None], optional): Authentication token or credentials, if required.
            ignored (list[str], optional): List of provider names to be ignored.
            ignore_working (bool, optional): If True, ignores the working status of the provider.
            ignore_stream_and_auth (bool, optional): If True, ignores the stream and authentication requirement checks.
            patch_provider (callable, optional): Function to modify the provider.
            **kwargs: Additional keyword arguments.

        Returns:
            Union[CreateResult, str]: The result of the chat completion operation.

        Raises:
            AuthenticationRequiredError: If authentication is required but not provided.
            ProviderNotFoundError, ModelNotFoundError: If the specified provider or model is not found.
            ProviderNotWorkingError: If the provider is not operational.
            StreamNotSupportedError: If streaming is requested but not supported by the provider.
        """
        model, provider = get_model_and_provider(model, provider, stream, ignored, ignore_working, ignore_stream_and_auth)

        if not ignore_stream_and_auth and provider.needs_auth and not auth:
@@ -98,7 +141,24 @@ class ChatCompletion:
                     ignored : list[str] = None,
                     patch_provider: callable = None,
                     **kwargs) -> Union[AsyncResult, str]:
        """
        Asynchronously creates a completion using the specified model and provider.

        Args:
            model (Union[Model, str]): The model to use, either as an object or a string identifier.
            messages (Messages): Messages to be processed.
            provider (Union[ProviderType, str, None]): The provider to use, either as an object, a string identifier, or None.
            stream (bool): Indicates if the operation should be performed as a stream.
            ignored (list[str], optional): List of provider names to be ignored.
            patch_provider (callable, optional): Function to modify the provider.
            **kwargs: Additional keyword arguments.

        Returns:
            Union[AsyncResult, str]: The result of the asynchronous chat completion operation.

        Raises:
            StreamNotSupportedError: If streaming is requested but not supported by the provider.
        """
        model, provider = get_model_and_provider(model, provider, False, ignored)

        if stream:
@@ -118,7 +178,23 @@ class Completion:
               provider : Union[ProviderType, None] = None,
               stream : bool = False,
               ignored : list[str] = None, **kwargs) -> Union[CreateResult, str]:
        """
        Creates a completion based on the provided model, prompt, and provider.

        Args:
            model (Union[Model, str]): The model to use, either as an object or a string identifier.
            prompt (str): The prompt text for which the completion is to be created.
            provider (Union[ProviderType, None], optional): The provider to use, either as an object or None.
            stream (bool, optional): Indicates if the operation should be performed as a stream.
            ignored (list[str], optional): List of provider names to be ignored.
            **kwargs: Additional keyword arguments.

        Returns:
            Union[CreateResult, str]: The result of the completion operation.

        Raises:
            ModelNotAllowedError: If the specified model is not allowed for use with this method.
        """
        allowed_models = [
            'code-davinci-002',
            'text-ada-001',
@@ -137,6 +213,15 @@ class Completion:
        return result if stream else ''.join(result)

def get_last_provider(as_dict: bool = False) -> Union[ProviderType, dict[str, str]]:
    """
    Retrieves the last used provider.

    Args:
        as_dict (bool, optional): If True, returns the provider information as a dictionary.

    Returns:
        Union[ProviderType, dict[str, str]]: The last used provider, either as an object or a dictionary.
    """
    last = debug.last_provider
    if isinstance(last, BaseRetryProvider):
        last = last.last_provider
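For illustration, a short sketch of the top-level API documented above; the dictionary keys shown for the as_dict branch follow get_dict in base_provider and are an assumption here.

import g4f
from g4f import ChatCompletion, get_last_provider

text = ChatCompletion.create(
    model=g4f.models.default,
    messages=[{"role": "user", "content": "Hello"}],
)
print(text)
# Inspect which provider actually served the request
print(get_last_provider(as_dict=True))  # e.g. {'name': '...', 'url': '...'}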

@@ -1,7 +1,22 @@
from abc import ABC, abstractmethod
from .typing import Messages, CreateResult, Union

from typing import Union, List, Dict, Type
from .typing import Messages, CreateResult

class BaseProvider(ABC):
    """
    Abstract base class for a provider.

    Attributes:
        url (str): URL of the provider.
        working (bool): Indicates if the provider is currently working.
        needs_auth (bool): Indicates if the provider needs authentication.
        supports_stream (bool): Indicates if the provider supports streaming.
        supports_gpt_35_turbo (bool): Indicates if the provider supports GPT-3.5 Turbo.
        supports_gpt_4 (bool): Indicates if the provider supports GPT-4.
        supports_message_history (bool): Indicates if the provider supports message history.
        params (str): List parameters for the provider.
    """

    url: str = None
    working: bool = False
    needs_auth: bool = False
@@ -20,6 +35,18 @@ class BaseProvider(ABC):
        stream: bool,
        **kwargs
    ) -> CreateResult:
        """
        Create a completion with the given parameters.

        Args:
            model (str): The model to use.
            messages (Messages): The messages to process.
            stream (bool): Whether to use streaming.
            **kwargs: Additional keyword arguments.

        Returns:
            CreateResult: The result of the creation process.
        """
        raise NotImplementedError()

    @classmethod
@@ -30,25 +57,59 @@ class BaseProvider(ABC):
        messages: Messages,
        **kwargs
    ) -> str:
        """
        Asynchronously create a completion with the given parameters.

        Args:
            model (str): The model to use.
            messages (Messages): The messages to process.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The result of the creation process.
        """
        raise NotImplementedError()

    @classmethod
    def get_dict(cls):
    def get_dict(cls) -> Dict[str, str]:
        """
        Get a dictionary representation of the provider.

        Returns:
            Dict[str, str]: A dictionary with provider's details.
        """
        return {'name': cls.__name__, 'url': cls.url}

class BaseRetryProvider(BaseProvider):
    """
    Base class for a provider that implements retry logic.

    Attributes:
        providers (List[Type[BaseProvider]]): List of providers to use for retries.
        shuffle (bool): Whether to shuffle the providers list.
        exceptions (Dict[str, Exception]): Dictionary of exceptions encountered.
        last_provider (Type[BaseProvider]): The last provider used.
    """

    __name__: str = "RetryProvider"
    supports_stream: bool = True

    def __init__(
        self,
        providers: list[type[BaseProvider]],
        providers: List[Type[BaseProvider]],
        shuffle: bool = True
    ) -> None:
        self.providers: list[type[BaseProvider]] = providers
        self.shuffle: bool = shuffle
        self.working: bool = True
        self.exceptions: dict[str, Exception] = {}
        self.last_provider: type[BaseProvider] = None
        """
        Initialize the BaseRetryProvider.

        Args:
            providers (List[Type[BaseProvider]]): List of providers to use.
            shuffle (bool): Whether to shuffle the providers list.
        """
        self.providers = providers
        self.shuffle = shuffle
        self.working = True
        self.exceptions: Dict[str, Exception] = {}
        self.last_provider: Type[BaseProvider] = None

ProviderType = Union[type[BaseProvider], BaseRetryProvider]
ProviderType = Union[Type[BaseProvider], BaseRetryProvider]
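A hedged sketch of subclassing the base class above. The provider is a toy, and whether create_completion is declared abstract or merely unimplemented is not visible in this hunk.

from g4f.base_provider import BaseProvider
from g4f.typing import Messages, CreateResult

class EchoProvider(BaseProvider):
    url = "https://example.invalid"  # placeholder, not a real endpoint
    working = True
    supports_stream = True

    @staticmethod
    def create_completion(model: str, messages: Messages, stream: bool, **kwargs) -> CreateResult:
        # Stream the last user message straight back
        yield messages[-1]["content"]

print(EchoProvider.get_dict())  # {'name': 'EchoProvider', 'url': 'https://example.invalid'}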

@@ -404,7 +404,7 @@ body {
    display: none;
}

#image {
#image, #file {
    display: none;
}

@@ -412,13 +412,22 @@ label[for="image"]:has(> input:valid){
    color: var(--accent);
}

label[for="image"] {
label[for="file"]:has(> input:valid){
    color: var(--accent);
}

label[for="image"], label[for="file"] {
    cursor: pointer;
    position: absolute;
    top: 10px;
    left: 10px;
}

label[for="file"] {
    top: 32px;
    left: 10px;
}

.buttons input[type="checkbox"] {
    height: 0;
    width: 0;

@@ -118,6 +118,10 @@
    <input type="file" id="image" name="image" accept="image/png, image/gif, image/jpeg" required/>
    <i class="fa-regular fa-image"></i>
</label>
<label for="file">
    <input type="file" id="file" name="file" accept="text/plain, text/html, text/xml, application/json, text/javascript, .sh, .py, .php, .css, .yaml, .sql, .svg, .log, .csv, .twig, .md" required/>
    <i class="fa-solid fa-paperclip"></i>
</label>
<div id="send-button">
    <i class="fa-solid fa-paper-plane-top"></i>
</div>
@@ -125,7 +129,14 @@
</div>
<div class="buttons">
    <div class="field">
        <select name="model" id="model"></select>
        <select name="model" id="model">
            <option value="">Model: Default</option>
            <option value="gpt-4">gpt-4</option>
            <option value="gpt-3.5-turbo">gpt-3.5-turbo</option>
            <option value="llama2-70b">llama2-70b</option>
            <option value="gemini-pro">gemini-pro</option>
            <option value="">----</option>
        </select>
    </div>
    <div class="field">
        <select name="jailbreak" id="jailbreak" style="display: none;">
@@ -138,7 +149,16 @@
            <option value="gpt-evil-1.0">evil 1.0</option>
        </select>
        <div class="field">
            <select name="provider" id="provider"></select>
            <select name="provider" id="provider">
                <option value="">Provider: Auto</option>
                <option value="Bing">Bing</option>
                <option value="OpenaiChat">OpenaiChat</option>
                <option value="HuggingChat">HuggingChat</option>
                <option value="Bard">Bard</option>
                <option value="Liaobots">Liaobots</option>
                <option value="Phind">Phind</option>
                <option value="">----</option>
            </select>
        </div>
    </div>
    <div class="field">
@@ -7,7 +7,9 @@ const spinner = box_conversations.querySelector(".spinner");
const stop_generating = document.querySelector(`.stop_generating`);
const regenerate = document.querySelector(`.regenerate`);
const send_button = document.querySelector(`#send-button`);
const imageInput = document.querySelector('#image') ;
const imageInput = document.querySelector('#image');
const fileInput = document.querySelector('#file');

let prompt_lock = false;

hljs.addPlugin(new CopyButtonPlugin());
@@ -42,6 +44,11 @@ const handle_ask = async () => {
    if (message.length > 0) {
        message_input.value = '';
        await add_conversation(window.conversation_id, message);
        if ("text" in fileInput.dataset) {
            message += '\n```' + fileInput.dataset.type + '\n';
            message += fileInput.dataset.text;
            message += '\n```'
        }
        await add_message(window.conversation_id, "user", message);
        window.token = message_id();
        message_box.innerHTML += `
@@ -55,6 +62,9 @@ const handle_ask = async () => {
            </div>
        </div>
        `;
        document.querySelectorAll('code:not(.hljs').forEach((el) => {
            hljs.highlightElement(el);
        });
        await ask_gpt();
    }
};
@@ -171,17 +181,30 @@ const ask_gpt = async () => {
                content_inner.innerHTML += "<p>An error occurred, please try again. If the problem persists, please use another model or provider.</p>";
            } else {
                html = markdown_render(text);
                html = html.substring(0, html.lastIndexOf('</p>')) + '<span id="cursor"></span></p>';
                let lastElement, lastIndex = null;
                for (element of ['</p>', '</code></pre>', '</li>\n</ol>']) {
                    const index = html.lastIndexOf(element)
                    if (index > lastIndex) {
                        lastElement = element;
                        lastIndex = index;
                    }
                }
                if (lastIndex) {
                    html = html.substring(0, lastIndex) + '<span id="cursor"></span>' + lastElement;
                }
                content_inner.innerHTML = html;
                document.querySelectorAll('code').forEach((el) => {
                document.querySelectorAll('code:not(.hljs').forEach((el) => {
                    hljs.highlightElement(el);
                });
            }

            window.scrollTo(0, 0);
            message_box.scrollTo({ top: message_box.scrollHeight, behavior: "auto" });
            if (message_box.scrollTop >= message_box.scrollHeight - message_box.clientHeight - 100) {
                message_box.scrollTo({ top: message_box.scrollHeight, behavior: "auto" });
            }
        }
        if (!error && imageInput) imageInput.value = "";
        if (!error && fileInput) fileInput.value = "";
    } catch (e) {
        console.error(e);

@@ -305,7 +328,7 @@ const load_conversation = async (conversation_id) => {
        `;
    }

    document.querySelectorAll(`code`).forEach((el) => {
    document.querySelectorAll('code:not(.hljs').forEach((el) => {
        hljs.highlightElement(el);
    });

@@ -400,7 +423,7 @@ const load_conversations = async (limit, offset, loader) => {
        `;
    }

    document.querySelectorAll(`code`).forEach((el) => {
    document.querySelectorAll('code:not(.hljs').forEach((el) => {
        hljs.highlightElement(el);
    });
};
@@ -602,14 +625,7 @@ observer.observe(message_input, { attributes: true });
(async () => {
    response = await fetch('/backend-api/v2/models')
    models = await response.json()

    let select = document.getElementById('model');
    select.textContent = '';

    let auto = document.createElement('option');
    auto.value = '';
    auto.text = 'Model: Default';
    select.appendChild(auto);

    for (model of models) {
        let option = document.createElement('option');
@@ -619,14 +635,7 @@ observer.observe(message_input, { attributes: true });

    response = await fetch('/backend-api/v2/providers')
    providers = await response.json()

    select = document.getElementById('provider');
    select.textContent = '';

    auto = document.createElement('option');
    auto.value = '';
    auto.text = 'Provider: Auto';
    select.appendChild(auto);

    for (provider of providers) {
        let option = document.createElement('option');
@@ -643,11 +652,34 @@ observer.observe(message_input, { attributes: true });

    document.title = 'g4f - gui - ' + versions["version"];
    text = "version ~ "
    if (versions["version"] != versions["lastet_version"]) {
        release_url = 'https://github.com/xtekky/gpt4free/releases/tag/' + versions["lastet_version"];
        text += '<a href="' + release_url +'" target="_blank" title="New version: ' + versions["lastet_version"] +'">' + versions["version"] + ' 🆕</a>';
    if (versions["version"] != versions["latest_version"]) {
        release_url = 'https://github.com/xtekky/gpt4free/releases/tag/' + versions["latest_version"];
        text += '<a href="' + release_url +'" target="_blank" title="New version: ' + versions["latest_version"] +'">' + versions["version"] + ' 🆕</a>';
    } else {
        text += versions["version"];
    }
    document.getElementById("version_text").innerHTML = text
})()
})()

fileInput.addEventListener('change', async (event) => {
    if (fileInput.files.length) {
        type = fileInput.files[0].type;
        if (type && type.indexOf('/')) {
            type = type.split('/').pop().replace('x-', '')
            type = type.replace('plain', 'plaintext')
                .replace('shellscript', 'sh')
                .replace('svg+xml', 'svg')
                .replace('vnd.trolltech.linguist', 'ts')
        } else {
            type = fileInput.files[0].name.split('.').pop()
        }
        fileInput.dataset.type = type
        const reader = new FileReader();
        reader.addEventListener('load', (event) => {
            fileInput.dataset.text = event.target.result;
        });
        reader.readAsText(fileInput.files[0]);
    } else {
        delete fileInput.dataset.text;
    }
});

@@ -1,6 +1,7 @@
import logging
import json
from flask import request, Flask
from typing import Generator
from g4f import debug, version, models
from g4f import _all_models, get_last_provider, ChatCompletion
from g4f.image import is_allowed_extension, to_image
@@ -11,60 +12,123 @@ from .internet import get_search_message
debug.logging = True

class Backend_Api:
    """
    Handles various endpoints in a Flask application for backend operations.

    This class provides methods to interact with models, providers, and to handle
    various functionalities like conversations, error handling, and version management.

    Attributes:
        app (Flask): A Flask application instance.
        routes (dict): A dictionary mapping API endpoints to their respective handlers.
    """
    def __init__(self, app: Flask) -> None:
        """
        Initialize the backend API with the given Flask application.

        Args:
            app (Flask): Flask application instance to attach routes to.
        """
        self.app: Flask = app
        self.routes = {
            '/backend-api/v2/models': {
                'function': self.models,
                'methods' : ['GET']
                'function': self.get_models,
                'methods': ['GET']
            },
            '/backend-api/v2/providers': {
                'function': self.providers,
                'methods' : ['GET']
                'function': self.get_providers,
                'methods': ['GET']
            },
            '/backend-api/v2/version': {
                'function': self.version,
                'methods' : ['GET']
                'function': self.get_version,
                'methods': ['GET']
            },
            '/backend-api/v2/conversation': {
                'function': self._conversation,
                'function': self.handle_conversation,
                'methods': ['POST']
            },
            '/backend-api/v2/gen.set.summarize:title': {
                'function': self._gen_title,
                'function': self.generate_title,
                'methods': ['POST']
            },
            '/backend-api/v2/error': {
                'function': self.error,
                'function': self.handle_error,
                'methods': ['POST']
            }
        }

    def error(self):
    def handle_error(self):
        """
        Print the error reported by the client and acknowledge receipt.
        """
        print(request.json)

        return 'ok', 200

    def models(self):
    def get_models(self):
        """
        Return a list of all models.

        Fetches and returns a list of all available models in the system.

        Returns:
            List[str]: A list of model names.
        """
        return _all_models

    def providers(self):
        return [
            provider.__name__ for provider in __providers__ if provider.working
        ]
    def get_providers(self):
        """
        Return a list of all working providers.
        """
        return [provider.__name__ for provider in __providers__ if provider.working]

    def version(self):
    def get_version(self):
        """
        Returns the current and latest version of the application.

        Returns:
            dict: A dictionary containing the current and latest version.
        """
        return {
            "version": version.utils.current_version,
            "lastet_version": version.get_latest_version(),
            "latest_version": version.get_latest_version(),
        }

    def _gen_title(self):
        return {
            'title': ''
        }
    def generate_title(self):
        """
        Generates and returns a title based on the request data.

        Returns:
            dict: A dictionary with the generated title.
        """
        return {'title': ''}

    def _conversation(self):
    def handle_conversation(self):
        """
        Handles conversation requests and streams responses back.

        Returns:
            Response: A Flask response object for streaming.
        """
        kwargs = self._prepare_conversation_kwargs()

        return self.app.response_class(
            self._create_response_stream(kwargs),
            mimetype='text/event-stream'
        )

    def _prepare_conversation_kwargs(self):
        """
        Prepares arguments for chat completion based on the request data.

        Reads the request and prepares the necessary arguments for handling
        a chat completion request.

        Returns:
            dict: Arguments prepared for chat completion.
        """
        kwargs = {}
        if 'image' in request.files:
            file = request.files['image']
@@ -87,47 +151,70 @@ class Backend_Api:
        messages[-1]["content"] = get_search_message(messages[-1]["content"])
        model = json_data.get('model')
        model = model if model else models.default
        provider = json_data.get('provider', '').replace('g4f.Provider.', '')
        provider = provider if provider and provider != "Auto" else None
        patch = patch_provider if json_data.get('patch_provider') else None

        def try_response():
            try:
                first = True
                for chunk in ChatCompletion.create(
                    model=model,
                    provider=provider,
                    messages=messages,
                    stream=True,
                    ignore_stream_and_auth=True,
                    patch_provider=patch,
                    **kwargs
                ):
                    if first:
                        first = False
                        yield json.dumps({
                            'type' : 'provider',
                            'provider': get_last_provider(True)
                        }) + "\n"
                    if isinstance(chunk, Exception):
                        logging.exception(chunk)
                        yield json.dumps({
                            'type' : 'message',
                            'message': get_error_message(chunk),
                        }) + "\n"
                    else:
                        yield json.dumps({
                            'type' : 'content',
                            'content': str(chunk),
                        }) + "\n"
            except Exception as e:
                logging.exception(e)
                yield json.dumps({
                    'type' : 'error',
                    'error': get_error_message(e)
                })
        return {
            "model": model,
            "provider": provider,
            "messages": messages,
            "stream": True,
            "ignore_stream_and_auth": True,
            "patch_provider": patch,
            **kwargs
        }

        return self.app.response_class(try_response(), mimetype='text/event-stream')
    def _create_response_stream(self, kwargs) -> Generator[str, None, None]:
        """
        Creates and returns a streaming response for the conversation.

        Args:
            kwargs (dict): Arguments for creating the chat completion.

        Yields:
            str: JSON formatted response chunks for the stream.

        Raises:
            Exception: If an error occurs during the streaming process.
        """
        try:
            first = True
            for chunk in ChatCompletion.create(**kwargs):
                if first:
                    first = False
                    yield self._format_json('provider', get_last_provider(True))
                if isinstance(chunk, Exception):
                    logging.exception(chunk)
                    yield self._format_json('message', get_error_message(chunk))
                else:
                    yield self._format_json('content', str(chunk))
        except Exception as e:
            logging.exception(e)
            yield self._format_json('error', get_error_message(e))

    def _format_json(self, response_type: str, content) -> str:
        """
        Formats and returns a JSON response.

        Args:
            response_type (str): The type of the response.
            content: The content to be included in the response.

        Returns:
            str: A JSON formatted string.
        """
        return json.dumps({
            'type': response_type,
            response_type: content
        }) + "\n"

def get_error_message(exception: Exception) -> str:
    """
    Generates a formatted error message from an exception.

    Args:
        exception (Exception): The exception to format.

    Returns:
        str: A formatted error message string.
    """
    return f"{get_last_provider().__name__}: {type(exception).__name__}: {exception}"
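To show the wire format, a hedged client sketch for the conversation route. The host and port are assumptions, as is the 'messages' body key (that part of the handler is elided above); only the 'model', 'provider', and 'patch_provider' keys are visible. Each streamed line is one JSON object shaped by _format_json.

import json
import requests  # any HTTP client works; requests is used for brevity

body = {
    "model": "",     # empty string falls back to models.default
    "provider": "",  # empty or "Auto" means no explicit provider
    "messages": [{"role": "user", "content": "Hello"}],  # assumed key
}
with requests.post("http://localhost:8080/backend-api/v2/conversation",  # assumed address
                   json=body, stream=True) as response:
    for raw_line in response.iter_lines():
        if not raw_line:
            continue
        chunk = json.loads(raw_line)
        # chunk['type'] is 'provider', 'content', 'message' or 'error';
        # the payload sits under the key named by chunk['type']
        if chunk["type"] == "content":
            print(chunk["content"], end="")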

105
g4f/image.py
@@ -4,9 +4,18 @@ import base64
from .typing import ImageType, Union
from PIL import Image

ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'}
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'webp'}

def to_image(image: ImageType) -> Image.Image:
    """
    Converts the input image to a PIL Image object.

    Args:
        image (Union[str, bytes, Image.Image]): The input image.

    Returns:
        Image.Image: The converted PIL Image object.
    """
    if isinstance(image, str):
        is_data_uri_an_image(image)
        image = extract_data_uri(image)
@@ -20,21 +29,48 @@ def to_image(image: ImageType) -> Image.Image:
        image = copy
    return image

def is_allowed_extension(filename) -> bool:
def is_allowed_extension(filename: str) -> bool:
    """
    Checks if the given filename has an allowed extension.

    Args:
        filename (str): The filename to check.

    Returns:
        bool: True if the extension is allowed, False otherwise.
    """
    return '.' in filename and \
        filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS

def is_data_uri_an_image(data_uri: str) -> bool:
    """
    Checks if the given data URI represents an image.

    Args:
        data_uri (str): The data URI to check.

    Raises:
        ValueError: If the data URI is invalid or the image format is not allowed.
    """
    # Check if the data URI starts with 'data:image' and contains an image format (e.g., jpeg, png, gif)
    if not re.match(r'data:image/(\w+);base64,', data_uri):
        raise ValueError("Invalid data URI image.")
    # Extract the image format from the data URI
    image_format = re.match(r'data:image/(\w+);base64,', data_uri).group(1)
    # Check if the image format is one of the allowed formats (jpg, jpeg, png, gif)
    if image_format.lower() not in ALLOWED_EXTENSIONS:
        raise ValueError("Invalid image format (from mime file type).")

def is_accepted_format(binary_data: bytes) -> bool:
    """
    Checks if the given binary data represents an image with an accepted format.

    Args:
        binary_data (bytes): The binary data to check.

    Raises:
        ValueError: If the image format is not allowed.
    """
    if binary_data.startswith(b'\xFF\xD8\xFF'):
        pass # It's a JPEG image
    elif binary_data.startswith(b'\x89PNG\r\n\x1a\n'):
@@ -49,13 +85,31 @@ def is_accepted_format(binary_data: bytes) -> bool:
        pass # It's a WebP image
    else:
        raise ValueError("Invalid image format (from magic code).")

def extract_data_uri(data_uri: str) -> bytes:
    """
    Extracts the binary data from the given data URI.

    Args:
        data_uri (str): The data URI.

    Returns:
        bytes: The extracted binary data.
    """
    data = data_uri.split(",")[1]
    data = base64.b64decode(data)
    return data

def get_orientation(image: Image.Image) -> int:
    """
    Gets the orientation of the given image.

    Args:
        image (Image.Image): The image.

    Returns:
        int: The orientation value.
    """
    exif_data = image.getexif() if hasattr(image, 'getexif') else image._getexif()
    if exif_data is not None:
        orientation = exif_data.get(274) # 274 corresponds to the orientation tag in EXIF
@@ -63,6 +117,17 @@ def get_orientation(image: Image.Image) -> int:
        return orientation

def process_image(img: Image.Image, new_width: int, new_height: int) -> Image.Image:
    """
    Processes the given image by adjusting its orientation and resizing it.

    Args:
        img (Image.Image): The image to process.
        new_width (int): The new width of the image.
        new_height (int): The new height of the image.

    Returns:
        Image.Image: The processed image.
    """
    orientation = get_orientation(img)
    if orientation:
        if orientation > 4:
@@ -75,13 +140,34 @@ def process_image(img: Image.Image, new_width: int, new_height: int) -> Image.Im
            img = img.transpose(Image.ROTATE_90)
    img.thumbnail((new_width, new_height))
    return img
|
||||
|
||||
|
||||
def to_base64(image: Image.Image, compression_rate: float) -> str:
|
||||
"""
|
||||
Converts the given image to a base64-encoded string.
|
||||
|
||||
Args:
|
||||
image (Image.Image): The image to convert.
|
||||
compression_rate (float): The compression rate (0.0 to 1.0).
|
||||
|
||||
Returns:
|
||||
str: The base64-encoded image.
|
||||
"""
|
||||
output_buffer = BytesIO()
|
||||
image.save(output_buffer, format="JPEG", quality=int(compression_rate * 100))
|
||||
return base64.b64encode(output_buffer.getvalue()).decode()
|
||||
|
||||
def format_images_markdown(images, prompt: str, preview: str="{image}?w=200&h=200") -> str:
|
||||
"""
|
||||
Formats the given images as a markdown string.
|
||||
|
||||
Args:
|
||||
images: The images to format.
|
||||
prompt (str): The prompt for the images.
|
||||
preview (str, optional): The preview URL format. Defaults to "{image}?w=200&h=200".
|
||||
|
||||
Returns:
|
||||
str: The formatted markdown string.
|
||||
"""
|
||||
if isinstance(images, list):
|
||||
images = [f"[![#{idx+1} {prompt}]({preview.replace('{image}', image)})]({image})" for idx, image in enumerate(images)]
|
||||
images = "\n".join(images)
|
||||
@ -92,6 +178,15 @@ def format_images_markdown(images, prompt: str, preview: str="{image}?w=200&h=20
|
||||
return f"\n{start_flag}{images}\n{end_flag}\n"
|
||||
|
||||
def to_bytes(image: Image.Image) -> bytes:
|
||||
"""
|
||||
Converts the given image to bytes.
|
||||
|
||||
Args:
|
||||
image (Image.Image): The image to convert.
|
||||
|
||||
Returns:
|
||||
bytes: The image as bytes.
|
||||
"""
|
||||
bytes_io = BytesIO()
|
||||
image.save(bytes_io, image.format)
|
||||
image.seek(0)
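Taken together, these helpers form a small preprocessing pipeline: validate the payload, normalize orientation, resize, and re-encode. A usage sketch; the file path and target size are illustrative, and to_image accepting raw bytes is assumed from its ImageType signature:

from g4f.image import is_accepted_format, to_image, process_image, to_base64

# Hypothetical input file; any local image works.
with open("photo.jpg", "rb") as f:
    binary_data = f.read()

is_accepted_format(binary_data)    # raises ValueError on unknown magic bytes
img = to_image(binary_data)        # bytes -> PIL.Image (assumed from ImageType)
img = process_image(img, new_width=300, new_height=300)  # EXIF-aware downscale
encoded = to_base64(img, compression_rate=0.7)           # JPEG at quality 70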
g4f/models.py
@@ -31,12 +31,21 @@ from .Provider import (

@dataclass(unsafe_hash=True)
class Model:
    """
    Represents a machine learning model configuration.

    Attributes:
        name (str): Name of the model.
        base_provider (str): Default provider for the model.
        best_provider (ProviderType): The preferred provider for the model, typically with retry logic.
    """
    name: str
    base_provider: str
    best_provider: ProviderType = None

    @staticmethod
    def __all__() -> list[str]:
        """Returns a list of all model names."""
        return _all_models

default = Model(
@@ -298,6 +307,12 @@ pi = Model(
)

class ModelUtils:
    """
    Utility class for mapping string identifiers to Model instances.

    Attributes:
        convert (dict[str, Model]): Dictionary mapping model string identifiers to Model instances.
    """
    convert: dict[str, Model] = {
        # gpt-3.5
        'gpt-3.5-turbo' : gpt_35_turbo,
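ModelUtils.convert is the table consulted when a model is passed by its string name rather than as a Model instance. A short sketch of how it can be used (the custom entry is purely illustrative):

import g4f
from g4f.models import Model, ModelUtils

# Resolve a model from its string identifier.
model = ModelUtils.convert['gpt-3.5-turbo']
print(model.name, model.base_provider)

# A custom configuration is just another (hashable) Model instance.
my_model = Model(name='my-model', base_provider='openai', best_provider=g4f.Provider.Bing)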
g4f/requests.py
@@ -1,7 +1,6 @@
from __future__ import annotations

import json
from contextlib import asynccontextmanager
from functools import partialmethod
from typing import AsyncGenerator
from urllib.parse import urlparse
@@ -9,27 +8,41 @@ from curl_cffi.requests import AsyncSession, Session, Response
from .webdriver import WebDriver, WebDriverSession, bypass_cloudflare, get_driver_cookies

class StreamResponse:
    """
    A wrapper class for handling asynchronous streaming responses.

    Attributes:
        inner (Response): The original Response object.
    """

    def __init__(self, inner: Response) -> None:
        """Initialize the StreamResponse with the provided Response object."""
        self.inner: Response = inner

    async def text(self) -> str:
        """Asynchronously get the response text."""
        return await self.inner.atext()

    def raise_for_status(self) -> None:
        """Raise an HTTPError if one occurred."""
        self.inner.raise_for_status()

    async def json(self, **kwargs) -> dict:
        """Asynchronously parse the JSON response content."""
        return json.loads(await self.inner.acontent(), **kwargs)

    async def iter_lines(self) -> AsyncGenerator[bytes, None]:
        """Asynchronously iterate over the lines of the response."""
        async for line in self.inner.aiter_lines():
            yield line

    async def iter_content(self) -> AsyncGenerator[bytes, None]:
        """Asynchronously iterate over the response content."""
        async for chunk in self.inner.aiter_content():
            yield chunk

    async def __aenter__(self):
        """Asynchronously enter the runtime context for the response object."""
        inner: Response = await self.inner
        self.inner = inner
        self.request = inner.request
@@ -39,24 +52,47 @@ class StreamResponse:
        self.headers = inner.headers
        self.cookies = inner.cookies
        return self

    async def __aexit__(self, *args):
        """Asynchronously exit the runtime context for the response object."""
        await self.inner.aclose()

class StreamSession(AsyncSession):
    """
    An asynchronous session class for handling HTTP requests with streaming.

    Inherits from AsyncSession.
    """

    def request(
        self, method: str, url: str, **kwargs
    ) -> StreamResponse:
        """Create and return a StreamResponse object for the given HTTP request."""
        return StreamResponse(super().request(method, url, stream=True, **kwargs))

    # Defining HTTP methods as partial methods of the request method.
    head = partialmethod(request, "HEAD")
    get = partialmethod(request, "GET")
    post = partialmethod(request, "POST")
    put = partialmethod(request, "PUT")
    patch = partialmethod(request, "PATCH")
    delete = partialmethod(request, "DELETE")

def get_session_from_browser(url: str, webdriver: WebDriver = None, proxy: str = None, timeout: int = 120) -> Session:
    """
    Create a Session object using a WebDriver to handle cookies and headers.

    Args:
        url (str): The URL to navigate to using the WebDriver.
        webdriver (WebDriver, optional): The WebDriver instance to use.
        proxy (str, optional): Proxy server to use for the Session.
        timeout (int, optional): Timeout in seconds for the WebDriver.

    Returns:
        Session: A Session object configured with cookies and headers from the WebDriver.
    """
    with WebDriverSession(webdriver, "", proxy=proxy, virtual_display=True) as driver:
        bypass_cloudflare(driver, url, timeout)
        cookies = get_driver_cookies(driver)
@@ -78,4 +114,4 @@ def get_session_from_browser(url: str, webdriver: WebDriver = None, proxy: str =
        proxies={"https": proxy, "http": proxy},
        timeout=timeout,
        impersonate="chrome110"
    )
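StreamSession keeps the requests-style verb shortcuts while deferring the network call until the response is awaited, which is why the response is entered with "async with". A usage sketch; the URL is illustrative and the module path g4f.requests is assumed from this diff:

import asyncio
from g4f.requests import StreamSession

async def main():
    async with StreamSession(impersonate="chrome110") as session:
        # get() returns a StreamResponse; the request itself is awaited in __aenter__.
        async with session.get("https://example.com/") as response:
            response.raise_for_status()
            async for line in response.iter_lines():
                print(line)

asyncio.run(main())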
101
g4f/version.py
@@ -5,45 +5,120 @@ from importlib.metadata import version as get_package_version, PackageNotFoundEr
from subprocess import check_output, CalledProcessError, PIPE
from .errors import VersionNotFoundError

def get_pypi_version(package_name: str) -> str:
    """
    Retrieves the latest version of a package from PyPI.

    Args:
        package_name (str): The name of the package for which to retrieve the version.

    Returns:
        str: The latest version of the specified package from PyPI.

    Raises:
        VersionNotFoundError: If there is an error in fetching the version from PyPI.
    """
    try:
        response = requests.get(f"https://pypi.org/pypi/{package_name}/json").json()
        return response["info"]["version"]
    except requests.RequestException as e:
        raise VersionNotFoundError(f"Failed to get PyPI version: {e}")

def get_github_version(repo: str) -> str:
    """
    Retrieves the latest release version from a GitHub repository.

    Args:
        repo (str): The name of the GitHub repository.

    Returns:
        str: The latest release version from the specified GitHub repository.

    Raises:
        VersionNotFoundError: If there is an error in fetching the version from GitHub.
    """
    try:
        response = requests.get(f"https://api.github.com/repos/{repo}/releases/latest").json()
        return response["tag_name"]
    except requests.RequestException as e:
        raise VersionNotFoundError(f"Failed to get GitHub release version: {e}")

def get_latest_version() -> str:
    """
    Retrieves the latest release version of the 'g4f' package from PyPI or GitHub.

    Returns:
        str: The latest release version of 'g4f'.

    Note:
        The function first tries to fetch the version from PyPI. If the package is not found,
        it retrieves the version from the GitHub repository.
    """
    try:
        # Is installed via package manager?
        get_package_version("g4f")
        return get_pypi_version("g4f")
    except PackageNotFoundError:
        # Else use GitHub version:
        return get_github_version("xtekky/gpt4free")

class VersionUtils:
    """
    Utility class for managing and comparing package versions of 'g4f'.
    """
    @cached_property
    def current_version(self) -> str:
        """
        Retrieves the current version of the 'g4f' package.

        Returns:
            str: The current version of 'g4f'.

        Raises:
            VersionNotFoundError: If the version cannot be determined from the package manager,
            Docker environment, or git repository.
        """
        # Read from package manager
        try:
            return get_package_version("g4f")
        except PackageNotFoundError:
            pass

        # Read from docker environment
        version = environ.get("G4F_VERSION")
        if version:
            return version

        # Read from git repository
        try:
            command = ["git", "describe", "--tags", "--abbrev=0"]
            return check_output(command, text=True, stderr=PIPE).strip()
        except CalledProcessError:
            pass

        raise VersionNotFoundError("Version not found")

    @cached_property
    def latest_version(self) -> str:
        """
        Retrieves the latest version of the 'g4f' package.

        Returns:
            str: The latest version of 'g4f'.
        """
        return get_latest_version()

    def check_version(self) -> None:
        """
        Checks if the current version of 'g4f' is up to date with the latest version.

        Note:
            If a newer version is available, it prints a message with the new version and update instructions.
        """
        try:
            if self.current_version != self.latest_version:
                print(f'New g4f version: {self.latest_version} (current: {self.current_version}) | pip install -U g4f')
        except Exception as e:
            print(f'Failed to check g4f version: {e}')

utils = VersionUtils()
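In short, current_version prefers installed package metadata and falls back to the G4F_VERSION environment variable and then the nearest git tag, while latest_version queries PyPI or GitHub. Typical use, with only names defined in this file:

from g4f.version import utils, get_latest_version

print(utils.current_version)  # package metadata, G4F_VERSION, or git tag
print(get_latest_version())   # PyPI when pip-installed, else GitHub releases
utils.check_version()         # prints an update hint when a newer version exists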
111
g4f/webdriver.py
@@ -1,5 +1,4 @@
from __future__ import annotations

from platformdirs import user_config_dir
from selenium.webdriver.remote.webdriver import WebDriver
from undetected_chromedriver import Chrome, ChromeOptions
@@ -21,7 +20,19 @@ def get_browser(
    proxy: str = None,
    options: ChromeOptions = None
) -> WebDriver:
    """
    Creates and returns a Chrome WebDriver with specified options.

    Args:
        user_data_dir (str, optional): Directory for user data. If None, uses default directory.
        headless (bool, optional): Whether to run the browser in headless mode. Defaults to False.
        proxy (str, optional): Proxy settings for the browser. Defaults to None.
        options (ChromeOptions, optional): ChromeOptions object with specific browser options. Defaults to None.

    Returns:
        WebDriver: An instance of WebDriver configured with the specified options.
    """
    if user_data_dir is None:
        user_data_dir = user_config_dir("g4f")
    if user_data_dir and debug.logging:
        print("Open browser with config dir:", user_data_dir)
@@ -39,36 +50,53 @@ def get_browser(
        headless=headless
    )

def get_driver_cookies(driver: WebDriver) -> dict:
    """
    Retrieves cookies from the specified WebDriver.

    Args:
        driver (WebDriver): The WebDriver instance from which to retrieve cookies.

    Returns:
        dict: A dictionary containing cookies with their names as keys and values as cookie values.
    """
    return {cookie["name"]: cookie["value"] for cookie in driver.get_cookies()}

def bypass_cloudflare(driver: WebDriver, url: str, timeout: int) -> None:
    """
    Attempts to bypass Cloudflare protection when accessing a URL using the provided WebDriver.

    Args:
        driver (WebDriver): The WebDriver to use for accessing the URL.
        url (str): The URL to access.
        timeout (int): Time in seconds to wait for the page to load.

    Raises:
        Exception: If there is an error while bypassing Cloudflare or loading the page.
    """
    driver.get(url)
    # Is cloudflare protection
    if driver.find_element(By.TAG_NAME, "body").get_attribute("class") == "no-js":
        if debug.logging:
            print("Cloudflare protection detected:", url)
        try:
            # Click button in iframe
            WebDriverWait(driver, 5).until(
                EC.presence_of_element_located((By.CSS_SELECTOR, "#turnstile-wrapper iframe"))
            )
            driver.switch_to.frame(driver.find_element(By.CSS_SELECTOR, "#turnstile-wrapper iframe"))
            WebDriverWait(driver, 5).until(
                EC.presence_of_element_located((By.CSS_SELECTOR, "#challenge-stage input"))
            ).click()
        except Exception as e:
            if debug.logging:
                print(f"Error bypassing Cloudflare: {e}")
        finally:
            driver.switch_to.default_content()
    # No cloudflare protection
    WebDriverWait(driver, timeout).until(
        EC.presence_of_element_located((By.CSS_SELECTOR, "body:not(.no-js)"))
    )

class WebDriverSession:
    """
    Manages a Selenium WebDriver session, including handling of virtual displays and proxies.
    """

    def __init__(
        self,
        webdriver: WebDriver = None,
@@ -78,12 +106,21 @@ class WebDriverSession():
        proxy: str = None,
        options: ChromeOptions = None
    ):
        """
        Initializes a new instance of the WebDriverSession.

        Args:
            webdriver (WebDriver, optional): A WebDriver instance for the session. Defaults to None.
            user_data_dir (str, optional): Directory for user data. Defaults to None.
            headless (bool, optional): Whether to run the browser in headless mode. Defaults to False.
            virtual_display (bool, optional): Whether to use a virtual display. Defaults to False.
            proxy (str, optional): Proxy settings for the browser. Defaults to None.
            options (ChromeOptions, optional): ChromeOptions for the browser. Defaults to None.
        """
        self.webdriver = webdriver
        self.user_data_dir = user_data_dir
        self.headless = headless
        self.virtual_display = Display(size=(1920, 1080)) if has_pyvirtualdisplay and virtual_display else None
        self.proxy = proxy
        self.options = options
        self.default_driver = None
@@ -94,8 +131,18 @@ class WebDriverSession():
        headless: bool = False,
        virtual_display: bool = False
    ) -> WebDriver:
        """
        Reopens the WebDriver session with new settings.

        Args:
            user_data_dir (str, optional): Directory for user data. Defaults to current value.
            headless (bool, optional): Whether to run the browser in headless mode. Defaults to current value.
            virtual_display (bool, optional): Whether to use a virtual display. Defaults to current value.

        Returns:
            WebDriver: The reopened WebDriver instance.
        """
        user_data_dir = user_data_dir or self.user_data_dir
        if self.default_driver:
            self.default_driver.quit()
        if not virtual_display and self.virtual_display:
@@ -105,6 +152,12 @@ class WebDriverSession():
        return self.default_driver

    def __enter__(self) -> WebDriver:
        """
        Context management method for entering a session. Initializes and returns a WebDriver instance.

        Returns:
            WebDriver: An instance of WebDriver for this session.
        """
        if self.webdriver:
            return self.webdriver
        if self.virtual_display:
@@ -113,11 +166,23 @@ class WebDriverSession():
        return self.default_driver

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Context management method for exiting a session. Closes and quits the WebDriver.

        Args:
            exc_type: Exception type.
            exc_val: Exception value.
            exc_tb: Exception traceback.

        Note:
            Closes the WebDriver and stops the virtual display if used.
        """
        if self.default_driver:
            try:
                self.default_driver.close()
            except Exception as e:
                if debug.logging:
                    print(f"Error closing WebDriver: {e}")
            self.default_driver.quit()
        if self.virtual_display:
            self.virtual_display.stop()