mirror of
https://github.com/xtekky/gpt4free.git
synced 2024-11-23 00:22:09 +03:00
Refactor code with AI
Add doctypes to many functions Add file upload for text files Add alternative url to FreeChatgpt Add webp to allowed image types
This commit is contained in:
parent
ceed364cb1
commit
5756586cde
@ -15,12 +15,18 @@ from .bing.upload_image import upload_image
|
||||
from .bing.create_images import create_images
|
||||
from .bing.conversation import Conversation, create_conversation, delete_conversation
|
||||
|
||||
class Tones():
|
||||
class Tones:
|
||||
"""
|
||||
Defines the different tone options for the Bing provider.
|
||||
"""
|
||||
creative = "Creative"
|
||||
balanced = "Balanced"
|
||||
precise = "Precise"
|
||||
|
||||
class Bing(AsyncGeneratorProvider):
|
||||
"""
|
||||
Bing provider for generating responses using the Bing API.
|
||||
"""
|
||||
url = "https://bing.com/chat"
|
||||
working = True
|
||||
supports_message_history = True
|
||||
@ -38,6 +44,19 @@ class Bing(AsyncGeneratorProvider):
|
||||
web_search: bool = False,
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
"""
|
||||
Creates an asynchronous generator for producing responses from Bing.
|
||||
|
||||
:param model: The model to use.
|
||||
:param messages: Messages to process.
|
||||
:param proxy: Proxy to use for requests.
|
||||
:param timeout: Timeout for requests.
|
||||
:param cookies: Cookies for the session.
|
||||
:param tone: The tone of the response.
|
||||
:param image: The image type to be used.
|
||||
:param web_search: Flag to enable or disable web search.
|
||||
:return: An asynchronous result object.
|
||||
"""
|
||||
if len(messages) < 2:
|
||||
prompt = messages[0]["content"]
|
||||
context = None
|
||||
@ -56,65 +75,48 @@ class Bing(AsyncGeneratorProvider):
|
||||
|
||||
return stream_generate(prompt, tone, image, context, proxy, cookies, web_search, gpt4_turbo, timeout)
|
||||
|
||||
def create_context(messages: Messages):
|
||||
def create_context(messages: Messages) -> str:
|
||||
"""
|
||||
Creates a context string from a list of messages.
|
||||
|
||||
:param messages: A list of message dictionaries.
|
||||
:return: A string representing the context created from the messages.
|
||||
"""
|
||||
return "".join(
|
||||
f"[{message['role']}]" + ("(#message)" if message['role']!="system" else "(#additional_instructions)") + f"\n{message['content']}\n\n"
|
||||
f"[{message['role']}]" + ("(#message)" if message['role'] != "system" else "(#additional_instructions)") + f"\n{message['content']}\n\n"
|
||||
for message in messages
|
||||
)
|
||||
|
||||
class Defaults:
|
||||
"""
|
||||
Default settings and configurations for the Bing provider.
|
||||
"""
|
||||
delimiter = "\x1e"
|
||||
ip_address = f"13.{random.randint(104, 107)}.{random.randint(0, 255)}.{random.randint(0, 255)}"
|
||||
|
||||
# List of allowed message types for Bing responses
|
||||
allowedMessageTypes = [
|
||||
"ActionRequest",
|
||||
"Chat",
|
||||
"Context",
|
||||
# "Disengaged", unwanted
|
||||
"Progress",
|
||||
# "AdsQuery", unwanted
|
||||
"SemanticSerp",
|
||||
"GenerateContentQuery",
|
||||
"SearchQuery",
|
||||
# The following message types should not be added so that it does not flood with
|
||||
# useless messages (such as "Analyzing images" or "Searching the web") while it's retrieving the AI response
|
||||
# "InternalSearchQuery",
|
||||
# "InternalSearchResult",
|
||||
"RenderCardRequest",
|
||||
# "RenderContentRequest"
|
||||
"ActionRequest", "Chat", "Context", "Progress", "SemanticSerp",
|
||||
"GenerateContentQuery", "SearchQuery", "RenderCardRequest"
|
||||
]
|
||||
|
||||
sliceIds = [
|
||||
'abv2',
|
||||
'srdicton',
|
||||
'convcssclick',
|
||||
'stylewv2',
|
||||
'contctxp2tf',
|
||||
'802fluxv1pc_a',
|
||||
'806log2sphs0',
|
||||
'727savemem',
|
||||
'277teditgnds0',
|
||||
'207hlthgrds0',
|
||||
'abv2', 'srdicton', 'convcssclick', 'stylewv2', 'contctxp2tf',
|
||||
'802fluxv1pc_a', '806log2sphs0', '727savemem', '277teditgnds0', '207hlthgrds0'
|
||||
]
|
||||
|
||||
# Default location settings
|
||||
location = {
|
||||
"locale": "en-US",
|
||||
"market": "en-US",
|
||||
"region": "US",
|
||||
"locationHints": [
|
||||
{
|
||||
"country": "United States",
|
||||
"state": "California",
|
||||
"city": "Los Angeles",
|
||||
"timezoneoffset": 8,
|
||||
"countryConfidence": 8,
|
||||
"Center": {"Latitude": 34.0536909, "Longitude": -118.242766},
|
||||
"RegionType": 2,
|
||||
"SourceType": 1,
|
||||
}
|
||||
],
|
||||
"locale": "en-US", "market": "en-US", "region": "US",
|
||||
"locationHints": [{
|
||||
"country": "United States", "state": "California", "city": "Los Angeles",
|
||||
"timezoneoffset": 8, "countryConfidence": 8,
|
||||
"Center": {"Latitude": 34.0536909, "Longitude": -118.242766},
|
||||
"RegionType": 2, "SourceType": 1
|
||||
}],
|
||||
}
|
||||
|
||||
# Default headers for requests
|
||||
headers = {
|
||||
'accept': '*/*',
|
||||
'accept-language': 'en-US,en;q=0.9',
|
||||
@ -139,23 +141,13 @@ class Defaults:
|
||||
}
|
||||
|
||||
optionsSets = [
|
||||
'nlu_direct_response_filter',
|
||||
'deepleo',
|
||||
'disable_emoji_spoken_text',
|
||||
'responsible_ai_policy_235',
|
||||
'enablemm',
|
||||
'iyxapbing',
|
||||
'iycapbing',
|
||||
'gencontentv3',
|
||||
'fluxsrtrunc',
|
||||
'fluxtrunc',
|
||||
'fluxv1',
|
||||
'rai278',
|
||||
'replaceurl',
|
||||
'eredirecturl',
|
||||
'nojbfedge'
|
||||
'nlu_direct_response_filter', 'deepleo', 'disable_emoji_spoken_text',
|
||||
'responsible_ai_policy_235', 'enablemm', 'iyxapbing', 'iycapbing',
|
||||
'gencontentv3', 'fluxsrtrunc', 'fluxtrunc', 'fluxv1', 'rai278',
|
||||
'replaceurl', 'eredirecturl', 'nojbfedge'
|
||||
]
|
||||
|
||||
# Default cookies
|
||||
cookies = {
|
||||
'SRCHD' : 'AF=NOFORM',
|
||||
'PPLState' : '1',
|
||||
@ -166,6 +158,12 @@ class Defaults:
|
||||
}
|
||||
|
||||
def format_message(msg: dict) -> str:
|
||||
"""
|
||||
Formats a message dictionary into a JSON string with a delimiter.
|
||||
|
||||
:param msg: The message dictionary to format.
|
||||
:return: A formatted string representation of the message.
|
||||
"""
|
||||
return json.dumps(msg, ensure_ascii=False) + Defaults.delimiter
|
||||
|
||||
def create_message(
|
||||
@ -177,7 +175,20 @@ def create_message(
|
||||
web_search: bool = False,
|
||||
gpt4_turbo: bool = False
|
||||
) -> str:
|
||||
"""
|
||||
Creates a message for the Bing API with specified parameters.
|
||||
|
||||
:param conversation: The current conversation object.
|
||||
:param prompt: The user's input prompt.
|
||||
:param tone: The desired tone for the response.
|
||||
:param context: Additional context for the prompt.
|
||||
:param image_response: The response if an image is involved.
|
||||
:param web_search: Flag to enable web search.
|
||||
:param gpt4_turbo: Flag to enable GPT-4 Turbo.
|
||||
:return: A formatted string message for the Bing API.
|
||||
"""
|
||||
options_sets = Defaults.optionsSets
|
||||
# Append tone-specific options
|
||||
if tone == Tones.creative:
|
||||
options_sets.append("h3imaginative")
|
||||
elif tone == Tones.precise:
|
||||
@ -186,54 +197,49 @@ def create_message(
|
||||
options_sets.append("galileo")
|
||||
else:
|
||||
options_sets.append("harmonyv3")
|
||||
|
||||
|
||||
# Additional configurations based on parameters
|
||||
if not web_search:
|
||||
options_sets.append("nosearchall")
|
||||
|
||||
if gpt4_turbo:
|
||||
options_sets.append("dlgpt4t")
|
||||
|
||||
|
||||
request_id = str(uuid.uuid4())
|
||||
struct = {
|
||||
'arguments': [
|
||||
{
|
||||
'source': 'cib',
|
||||
'optionsSets': options_sets,
|
||||
'allowedMessageTypes': Defaults.allowedMessageTypes,
|
||||
'sliceIds': Defaults.sliceIds,
|
||||
'traceId': os.urandom(16).hex(),
|
||||
'isStartOfSession': True,
|
||||
'arguments': [{
|
||||
'source': 'cib', 'optionsSets': options_sets,
|
||||
'allowedMessageTypes': Defaults.allowedMessageTypes,
|
||||
'sliceIds': Defaults.sliceIds,
|
||||
'traceId': os.urandom(16).hex(), 'isStartOfSession': True,
|
||||
'requestId': request_id,
|
||||
'message': {
|
||||
**Defaults.location,
|
||||
'author': 'user',
|
||||
'inputMethod': 'Keyboard',
|
||||
'text': prompt,
|
||||
'messageType': 'Chat',
|
||||
'requestId': request_id,
|
||||
'message': {**Defaults.location, **{
|
||||
'author': 'user',
|
||||
'inputMethod': 'Keyboard',
|
||||
'text': prompt,
|
||||
'messageType': 'Chat',
|
||||
'requestId': request_id,
|
||||
'messageId': request_id,
|
||||
}},
|
||||
"verbosity": "verbose",
|
||||
"scenario": "SERP",
|
||||
"plugins":[
|
||||
{"id":"c310c353-b9f0-4d76-ab0d-1dd5e979cf68", "category": 1}
|
||||
] if web_search else [],
|
||||
'tone': tone,
|
||||
'spokenTextMode': 'None',
|
||||
'conversationId': conversation.conversationId,
|
||||
'participant': {
|
||||
'id': conversation.clientId
|
||||
},
|
||||
}
|
||||
],
|
||||
'messageId': request_id
|
||||
},
|
||||
"verbosity": "verbose",
|
||||
"scenario": "SERP",
|
||||
"plugins": [{"id": "c310c353-b9f0-4d76-ab0d-1dd5e979cf68", "category": 1}] if web_search else [],
|
||||
'tone': tone,
|
||||
'spokenTextMode': 'None',
|
||||
'conversationId': conversation.conversationId,
|
||||
'participant': {'id': conversation.clientId},
|
||||
}],
|
||||
'invocationId': '1',
|
||||
'target': 'chat',
|
||||
'type': 4
|
||||
}
|
||||
if image_response.get('imageUrl') and image_response.get('originalImageUrl'):
|
||||
|
||||
if image_response and image_response.get('imageUrl') and image_response.get('originalImageUrl'):
|
||||
struct['arguments'][0]['message']['originalImageUrl'] = image_response.get('originalImageUrl')
|
||||
struct['arguments'][0]['message']['imageUrl'] = image_response.get('imageUrl')
|
||||
struct['arguments'][0]['experienceType'] = None
|
||||
struct['arguments'][0]['attachedFileInfo'] = {"fileName": None, "fileType": None}
|
||||
|
||||
if context:
|
||||
struct['arguments'][0]['previousMessages'] = [{
|
||||
"author": "user",
|
||||
@ -242,30 +248,46 @@ def create_message(
|
||||
"messageType": "Context",
|
||||
"messageId": "discover-web--page-ping-mriduna-----"
|
||||
}]
|
||||
|
||||
return format_message(struct)
|
||||
|
||||
async def stream_generate(
|
||||
prompt: str,
|
||||
tone: str,
|
||||
image: ImageType = None,
|
||||
context: str = None,
|
||||
proxy: str = None,
|
||||
cookies: dict = None,
|
||||
web_search: bool = False,
|
||||
gpt4_turbo: bool = False,
|
||||
timeout: int = 900
|
||||
):
|
||||
prompt: str,
|
||||
tone: str,
|
||||
image: ImageType = None,
|
||||
context: str = None,
|
||||
proxy: str = None,
|
||||
cookies: dict = None,
|
||||
web_search: bool = False,
|
||||
gpt4_turbo: bool = False,
|
||||
timeout: int = 900
|
||||
):
|
||||
"""
|
||||
Asynchronously streams generated responses from the Bing API.
|
||||
|
||||
:param prompt: The user's input prompt.
|
||||
:param tone: The desired tone for the response.
|
||||
:param image: The image type involved in the response.
|
||||
:param context: Additional context for the prompt.
|
||||
:param proxy: Proxy settings for the request.
|
||||
:param cookies: Cookies for the session.
|
||||
:param web_search: Flag to enable web search.
|
||||
:param gpt4_turbo: Flag to enable GPT-4 Turbo.
|
||||
:param timeout: Timeout for the request.
|
||||
:return: An asynchronous generator yielding responses.
|
||||
"""
|
||||
headers = Defaults.headers
|
||||
if cookies:
|
||||
headers["Cookie"] = "; ".join(f"{k}={v}" for k, v in cookies.items())
|
||||
|
||||
async with ClientSession(
|
||||
timeout=ClientTimeout(total=timeout),
|
||||
headers=headers
|
||||
timeout=ClientTimeout(total=timeout), headers=headers
|
||||
) as session:
|
||||
conversation = await create_conversation(session, proxy)
|
||||
image_response = await upload_image(session, image, tone, proxy) if image else None
|
||||
if image_response:
|
||||
yield image_response
|
||||
|
||||
try:
|
||||
async with session.ws_connect(
|
||||
'wss://sydney.bing.com/sydney/ChatHub',
|
||||
@ -289,7 +311,7 @@ async def stream_generate(
|
||||
if obj is None or not obj:
|
||||
continue
|
||||
response = json.loads(obj)
|
||||
if response.get('type') == 1 and response['arguments'][0].get('messages'):
|
||||
if response and response.get('type') == 1 and response['arguments'][0].get('messages'):
|
||||
message = response['arguments'][0]['messages'][0]
|
||||
image_response = None
|
||||
if (message['contentOrigin'] != 'Apology'):
|
||||
|
@ -1,16 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import json, random
|
||||
from aiohttp import ClientSession
|
||||
|
||||
from ..typing import AsyncResult, Messages
|
||||
from .base_provider import AsyncGeneratorProvider
|
||||
|
||||
|
||||
models = {
|
||||
"claude-v2": "claude-2.0",
|
||||
"gemini-pro": "google-gemini-pro"
|
||||
"claude-v2": "claude-2.0",
|
||||
"claude-v2.1":"claude-2.1",
|
||||
"gemini-pro": "google-gemini-pro"
|
||||
}
|
||||
urls = [
|
||||
"https://free.chatgpt.org.uk",
|
||||
"https://ai.chatgpt.org.uk"
|
||||
]
|
||||
|
||||
class FreeChatgpt(AsyncGeneratorProvider):
|
||||
url = "https://free.chatgpt.org.uk"
|
||||
@ -31,6 +35,7 @@ class FreeChatgpt(AsyncGeneratorProvider):
|
||||
model = models[model]
|
||||
elif not model:
|
||||
model = "gpt-3.5-turbo"
|
||||
url = random.choice(urls)
|
||||
headers = {
|
||||
"Accept": "application/json, text/event-stream",
|
||||
"Content-Type":"application/json",
|
||||
@ -55,7 +60,7 @@ class FreeChatgpt(AsyncGeneratorProvider):
|
||||
"top_p":1,
|
||||
**kwargs
|
||||
}
|
||||
async with session.post(f'{cls.url}/api/openai/v1/chat/completions', json=data, proxy=proxy) as response:
|
||||
async with session.post(f'{url}/api/openai/v1/chat/completions', json=data, proxy=proxy) as response:
|
||||
response.raise_for_status()
|
||||
started = False
|
||||
async for line in response.content:
|
||||
|
@ -1,28 +1,29 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import asyncio
|
||||
from asyncio import AbstractEventLoop
|
||||
from asyncio import AbstractEventLoop
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from abc import abstractmethod
|
||||
from inspect import signature, Parameter
|
||||
from .helper import get_event_loop, get_cookies, format_prompt
|
||||
from ..typing import CreateResult, AsyncResult, Messages
|
||||
from ..base_provider import BaseProvider
|
||||
from abc import abstractmethod
|
||||
from inspect import signature, Parameter
|
||||
from .helper import get_event_loop, get_cookies, format_prompt
|
||||
from ..typing import CreateResult, AsyncResult, Messages
|
||||
from ..base_provider import BaseProvider
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
NoneType = type(None)
|
||||
else:
|
||||
from types import NoneType
|
||||
|
||||
# Change event loop policy on windows for curl_cffi
|
||||
# Set Windows event loop policy for better compatibility with asyncio and curl_cffi
|
||||
if sys.platform == 'win32':
|
||||
if isinstance(
|
||||
asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy
|
||||
):
|
||||
if isinstance(asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy):
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
class AbstractProvider(BaseProvider):
|
||||
"""
|
||||
Abstract class for providing asynchronous functionality to derived classes.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
async def create_async(
|
||||
cls,
|
||||
@ -33,62 +34,50 @@ class AbstractProvider(BaseProvider):
|
||||
executor: ThreadPoolExecutor = None,
|
||||
**kwargs
|
||||
) -> str:
|
||||
if not loop:
|
||||
loop = get_event_loop()
|
||||
"""
|
||||
Asynchronously creates a result based on the given model and messages.
|
||||
"""
|
||||
loop = loop or get_event_loop()
|
||||
|
||||
def create_func() -> str:
|
||||
return "".join(cls.create_completion(
|
||||
model,
|
||||
messages,
|
||||
False,
|
||||
**kwargs
|
||||
))
|
||||
return "".join(cls.create_completion(model, messages, False, **kwargs))
|
||||
|
||||
return await asyncio.wait_for(
|
||||
loop.run_in_executor(
|
||||
executor,
|
||||
create_func
|
||||
),
|
||||
loop.run_in_executor(executor, create_func),
|
||||
timeout=kwargs.get("timeout", 0)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@property
|
||||
def params(cls) -> str:
|
||||
if issubclass(cls, AsyncGeneratorProvider):
|
||||
sig = signature(cls.create_async_generator)
|
||||
elif issubclass(cls, AsyncProvider):
|
||||
sig = signature(cls.create_async)
|
||||
else:
|
||||
sig = signature(cls.create_completion)
|
||||
"""
|
||||
Returns the parameters supported by the provider.
|
||||
"""
|
||||
sig = signature(
|
||||
cls.create_async_generator if issubclass(cls, AsyncGeneratorProvider) else
|
||||
cls.create_async if issubclass(cls, AsyncProvider) else
|
||||
cls.create_completion
|
||||
)
|
||||
|
||||
def get_type_name(annotation: type) -> str:
|
||||
if hasattr(annotation, "__name__"):
|
||||
annotation = annotation.__name__
|
||||
elif isinstance(annotation, NoneType):
|
||||
annotation = "None"
|
||||
return str(annotation)
|
||||
|
||||
return annotation.__name__ if hasattr(annotation, "__name__") else str(annotation)
|
||||
|
||||
args = ""
|
||||
for name, param in sig.parameters.items():
|
||||
if name in ("self", "kwargs"):
|
||||
if name in ("self", "kwargs") or (name == "stream" and not cls.supports_stream):
|
||||
continue
|
||||
if name == "stream" and not cls.supports_stream:
|
||||
continue
|
||||
if args:
|
||||
args += ", "
|
||||
args += "\n " + name
|
||||
if name != "model" and param.annotation is not Parameter.empty:
|
||||
args += f": {get_type_name(param.annotation)}"
|
||||
if param.default == "":
|
||||
args += ' = ""'
|
||||
elif param.default is not Parameter.empty:
|
||||
args += f" = {param.default}"
|
||||
args += f"\n {name}"
|
||||
args += f": {get_type_name(param.annotation)}" if param.annotation is not Parameter.empty else ""
|
||||
args += f' = "{param.default}"' if param.default == "" else f" = {param.default}" if param.default is not Parameter.empty else ""
|
||||
|
||||
return f"g4f.Provider.{cls.__name__} supports: ({args}\n)"
|
||||
|
||||
|
||||
class AsyncProvider(AbstractProvider):
|
||||
"""
|
||||
Provides asynchronous functionality for creating completions.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def create_completion(
|
||||
cls,
|
||||
@ -99,8 +88,10 @@ class AsyncProvider(AbstractProvider):
|
||||
loop: AbstractEventLoop = None,
|
||||
**kwargs
|
||||
) -> CreateResult:
|
||||
if not loop:
|
||||
loop = get_event_loop()
|
||||
"""
|
||||
Creates a completion result synchronously.
|
||||
"""
|
||||
loop = loop or get_event_loop()
|
||||
coro = cls.create_async(model, messages, **kwargs)
|
||||
yield loop.run_until_complete(coro)
|
||||
|
||||
@ -111,10 +102,16 @@ class AsyncProvider(AbstractProvider):
|
||||
messages: Messages,
|
||||
**kwargs
|
||||
) -> str:
|
||||
"""
|
||||
Abstract method for creating asynchronous results.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class AsyncGeneratorProvider(AsyncProvider):
|
||||
"""
|
||||
Provides asynchronous generator functionality for streaming results.
|
||||
"""
|
||||
supports_stream = True
|
||||
|
||||
@classmethod
|
||||
@ -127,15 +124,13 @@ class AsyncGeneratorProvider(AsyncProvider):
|
||||
loop: AbstractEventLoop = None,
|
||||
**kwargs
|
||||
) -> CreateResult:
|
||||
if not loop:
|
||||
loop = get_event_loop()
|
||||
generator = cls.create_async_generator(
|
||||
model,
|
||||
messages,
|
||||
stream=stream,
|
||||
**kwargs
|
||||
)
|
||||
"""
|
||||
Creates a streaming completion result synchronously.
|
||||
"""
|
||||
loop = loop or get_event_loop()
|
||||
generator = cls.create_async_generator(model, messages, stream=stream, **kwargs)
|
||||
gen = generator.__aiter__()
|
||||
|
||||
while True:
|
||||
try:
|
||||
yield loop.run_until_complete(gen.__anext__())
|
||||
@ -149,21 +144,23 @@ class AsyncGeneratorProvider(AsyncProvider):
|
||||
messages: Messages,
|
||||
**kwargs
|
||||
) -> str:
|
||||
"""
|
||||
Asynchronously creates a result from a generator.
|
||||
"""
|
||||
return "".join([
|
||||
chunk async for chunk in cls.create_async_generator(
|
||||
model,
|
||||
messages,
|
||||
stream=False,
|
||||
**kwargs
|
||||
) if not isinstance(chunk, Exception)
|
||||
chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)
|
||||
if not isinstance(chunk, Exception)
|
||||
])
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def create_async_generator(
|
||||
async def create_async_generator(
|
||||
model: str,
|
||||
messages: Messages,
|
||||
stream: bool = True,
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
raise NotImplementedError()
|
||||
"""
|
||||
Abstract method for creating an asynchronous generator.
|
||||
"""
|
||||
raise NotImplementedError()
|
@ -1,13 +1,33 @@
|
||||
from aiohttp import ClientSession
|
||||
|
||||
|
||||
class Conversation():
|
||||
class Conversation:
|
||||
"""
|
||||
Represents a conversation with specific attributes.
|
||||
"""
|
||||
def __init__(self, conversationId: str, clientId: str, conversationSignature: str) -> None:
|
||||
"""
|
||||
Initialize a new conversation instance.
|
||||
|
||||
Args:
|
||||
conversationId (str): Unique identifier for the conversation.
|
||||
clientId (str): Client identifier.
|
||||
conversationSignature (str): Signature for the conversation.
|
||||
"""
|
||||
self.conversationId = conversationId
|
||||
self.clientId = clientId
|
||||
self.conversationSignature = conversationSignature
|
||||
|
||||
async def create_conversation(session: ClientSession, proxy: str = None) -> Conversation:
|
||||
"""
|
||||
Create a new conversation asynchronously.
|
||||
|
||||
Args:
|
||||
session (ClientSession): An instance of aiohttp's ClientSession.
|
||||
proxy (str, optional): Proxy URL. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Conversation: An instance representing the created conversation.
|
||||
"""
|
||||
url = 'https://www.bing.com/turing/conversation/create?bundleVersion=1.1199.4'
|
||||
async with session.get(url, proxy=proxy) as response:
|
||||
try:
|
||||
@ -24,12 +44,32 @@ async def create_conversation(session: ClientSession, proxy: str = None) -> Conv
|
||||
return Conversation(conversationId, clientId, conversationSignature)
|
||||
|
||||
async def list_conversations(session: ClientSession) -> list:
|
||||
"""
|
||||
List all conversations asynchronously.
|
||||
|
||||
Args:
|
||||
session (ClientSession): An instance of aiohttp's ClientSession.
|
||||
|
||||
Returns:
|
||||
list: A list of conversations.
|
||||
"""
|
||||
url = "https://www.bing.com/turing/conversation/chats"
|
||||
async with session.get(url) as response:
|
||||
response = await response.json()
|
||||
return response["chats"]
|
||||
|
||||
async def delete_conversation(session: ClientSession, conversation: Conversation, proxy: str = None) -> bool:
|
||||
"""
|
||||
Delete a conversation asynchronously.
|
||||
|
||||
Args:
|
||||
session (ClientSession): An instance of aiohttp's ClientSession.
|
||||
conversation (Conversation): The conversation to delete.
|
||||
proxy (str, optional): Proxy URL. Defaults to None.
|
||||
|
||||
Returns:
|
||||
bool: True if deletion was successful, False otherwise.
|
||||
"""
|
||||
url = "https://sydney.bing.com/sydney/DeleteSingleConversation"
|
||||
json = {
|
||||
"conversationId": conversation.conversationId,
|
||||
|
@ -1,9 +1,16 @@
|
||||
"""
|
||||
This module provides functionalities for creating and managing images using Bing's service.
|
||||
It includes functions for user login, session creation, image creation, and processing.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time, json, os
|
||||
import time
|
||||
import json
|
||||
import os
|
||||
from aiohttp import ClientSession
|
||||
from bs4 import BeautifulSoup
|
||||
from urllib.parse import quote
|
||||
from typing import Generator
|
||||
from typing import Generator, List, Dict
|
||||
|
||||
from ..create_images import CreateImagesProvider
|
||||
from ..helper import get_cookies, get_event_loop
|
||||
@ -12,23 +19,47 @@ from ...base_provider import ProviderType
|
||||
from ...image import format_images_markdown
|
||||
|
||||
BING_URL = "https://www.bing.com"
|
||||
TIMEOUT_LOGIN = 1200
|
||||
TIMEOUT_IMAGE_CREATION = 300
|
||||
ERRORS = [
|
||||
"this prompt is being reviewed",
|
||||
"this prompt has been blocked",
|
||||
"we're working hard to offer image creator in more languages",
|
||||
"we can't create your images right now"
|
||||
]
|
||||
BAD_IMAGES = [
|
||||
"https://r.bing.com/rp/in-2zU3AJUdkgFe7ZKv19yPBHVs.png",
|
||||
"https://r.bing.com/rp/TX9QuO3WzcCJz1uaaSwQAz39Kb0.jpg",
|
||||
]
|
||||
|
||||
def wait_for_login(driver: WebDriver, timeout: int = 1200) -> None:
|
||||
def wait_for_login(driver: WebDriver, timeout: int = TIMEOUT_LOGIN) -> None:
|
||||
"""
|
||||
Waits for the user to log in within a given timeout period.
|
||||
|
||||
Args:
|
||||
driver (WebDriver): Webdriver for browser automation.
|
||||
timeout (int): Maximum waiting time in seconds.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the login process exceeds the timeout.
|
||||
"""
|
||||
driver.get(f"{BING_URL}/")
|
||||
value = driver.get_cookie("_U")
|
||||
if value:
|
||||
return
|
||||
start_time = time.time()
|
||||
while True:
|
||||
while not driver.get_cookie("_U"):
|
||||
if time.time() - start_time > timeout:
|
||||
raise RuntimeError("Timeout error")
|
||||
value = driver.get_cookie("_U")
|
||||
if value:
|
||||
time.sleep(1)
|
||||
return
|
||||
time.sleep(0.5)
|
||||
|
||||
def create_session(cookies: dict) -> ClientSession:
|
||||
def create_session(cookies: Dict[str, str]) -> ClientSession:
|
||||
"""
|
||||
Creates a new client session with specified cookies and headers.
|
||||
|
||||
Args:
|
||||
cookies (Dict[str, str]): Cookies to be used for the session.
|
||||
|
||||
Returns:
|
||||
ClientSession: The created client session.
|
||||
"""
|
||||
headers = {
|
||||
"accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
|
||||
"accept-encoding": "gzip, deflate, br",
|
||||
@ -47,28 +78,32 @@ def create_session(cookies: dict) -> ClientSession:
|
||||
"upgrade-insecure-requests": "1",
|
||||
}
|
||||
if cookies:
|
||||
headers["cookie"] = "; ".join(f"{k}={v}" for k, v in cookies.items())
|
||||
headers["Cookie"] = "; ".join(f"{k}={v}" for k, v in cookies.items())
|
||||
return ClientSession(headers=headers)
|
||||
|
||||
async def create_images(session: ClientSession, prompt: str, proxy: str = None, timeout: int = 300) -> list:
|
||||
url_encoded_prompt = quote(prompt)
|
||||
async def create_images(session: ClientSession, prompt: str, proxy: str = None, timeout: int = TIMEOUT_IMAGE_CREATION) -> List[str]:
|
||||
"""
|
||||
Creates images based on a given prompt using Bing's service.
|
||||
|
||||
Args:
|
||||
session (ClientSession): Active client session.
|
||||
prompt (str): Prompt to generate images.
|
||||
proxy (str, optional): Proxy configuration.
|
||||
timeout (int): Timeout for the request.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of URLs to the created images.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If image creation fails or times out.
|
||||
"""
|
||||
url_encoded_prompt = quote(prompt)
|
||||
payload = f"q={url_encoded_prompt}&rt=4&FORM=GENCRE"
|
||||
url = f"{BING_URL}/images/create?q={url_encoded_prompt}&rt=4&FORM=GENCRE"
|
||||
async with session.post(
|
||||
url,
|
||||
allow_redirects=False,
|
||||
data=payload,
|
||||
timeout=timeout,
|
||||
) as response:
|
||||
async with session.post(url, allow_redirects=False, data=payload, timeout=timeout) as response:
|
||||
response.raise_for_status()
|
||||
errors = [
|
||||
"this prompt is being reviewed",
|
||||
"this prompt has been blocked",
|
||||
"we're working hard to offer image creator in more languages",
|
||||
"we can't create your images right now"
|
||||
]
|
||||
text = (await response.text()).lower()
|
||||
for error in errors:
|
||||
for error in ERRORS:
|
||||
if error in text:
|
||||
raise RuntimeError(f"Create images failed: {error}")
|
||||
if response.status != 302:
|
||||
@ -107,54 +142,109 @@ async def create_images(session: ClientSession, prompt: str, proxy: str = None,
|
||||
raise RuntimeError(error)
|
||||
return read_images(text)
|
||||
|
||||
def read_images(text: str) -> list:
|
||||
html_soup = BeautifulSoup(text, "html.parser")
|
||||
tags = html_soup.find_all("img")
|
||||
image_links = [img["src"] for img in tags if "mimg" in img["class"]]
|
||||
images = [link.split("?w=")[0] for link in image_links]
|
||||
bad_images = [
|
||||
"https://r.bing.com/rp/in-2zU3AJUdkgFe7ZKv19yPBHVs.png",
|
||||
"https://r.bing.com/rp/TX9QuO3WzcCJz1uaaSwQAz39Kb0.jpg",
|
||||
]
|
||||
if any(im in bad_images for im in images):
|
||||
def read_images(html_content: str) -> List[str]:
|
||||
"""
|
||||
Extracts image URLs from the HTML content.
|
||||
|
||||
Args:
|
||||
html_content (str): HTML content containing image URLs.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of image URLs.
|
||||
"""
|
||||
soup = BeautifulSoup(html_content, "html.parser")
|
||||
tags = soup.find_all("img", class_="mimg")
|
||||
images = [img["src"].split("?w=")[0] for img in tags]
|
||||
if any(im in BAD_IMAGES for im in images):
|
||||
raise RuntimeError("Bad images found")
|
||||
if not images:
|
||||
raise RuntimeError("No images found")
|
||||
return images
|
||||
|
||||
async def create_images_markdown(cookies: dict, prompt: str, proxy: str = None) -> str:
|
||||
session = create_session(cookies)
|
||||
try:
|
||||
async def create_images_markdown(cookies: Dict[str, str], prompt: str, proxy: str = None) -> str:
|
||||
"""
|
||||
Creates markdown formatted string with images based on the prompt.
|
||||
|
||||
Args:
|
||||
cookies (Dict[str, str]): Cookies to be used for the session.
|
||||
prompt (str): Prompt to generate images.
|
||||
proxy (str, optional): Proxy configuration.
|
||||
|
||||
Returns:
|
||||
str: Markdown formatted string with images.
|
||||
"""
|
||||
async with create_session(cookies) as session:
|
||||
images = await create_images(session, prompt, proxy)
|
||||
return format_images_markdown(images, prompt)
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
def get_cookies_from_browser(proxy: str = None) -> dict:
|
||||
driver = get_browser(proxy=proxy)
|
||||
try:
|
||||
def get_cookies_from_browser(proxy: str = None) -> Dict[str, str]:
|
||||
"""
|
||||
Retrieves cookies from the browser using webdriver.
|
||||
|
||||
Args:
|
||||
proxy (str, optional): Proxy configuration.
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: Retrieved cookies.
|
||||
"""
|
||||
with get_browser(proxy=proxy) as driver:
|
||||
wait_for_login(driver)
|
||||
time.sleep(1)
|
||||
return get_driver_cookies(driver)
|
||||
finally:
|
||||
driver.quit()
|
||||
|
||||
def create_completion(prompt: str, cookies: dict = None, proxy: str = None) -> Generator:
|
||||
loop = get_event_loop()
|
||||
if not cookies:
|
||||
cookies = get_cookies(".bing.com")
|
||||
if "_U" not in cookies:
|
||||
login_url = os.environ.get("G4F_LOGIN_URL")
|
||||
if login_url:
|
||||
yield f"Please login: [Bing]({login_url})\n\n"
|
||||
cookies = get_cookies_from_browser(proxy)
|
||||
yield loop.run_until_complete(create_images_markdown(cookies, prompt, proxy))
|
||||
class CreateImagesBing:
|
||||
"""A class for creating images using Bing."""
|
||||
|
||||
async def create_async(prompt: str, cookies: dict = None, proxy: str = None) -> str:
|
||||
if not cookies:
|
||||
cookies = get_cookies(".bing.com")
|
||||
if "_U" not in cookies:
|
||||
cookies = get_cookies_from_browser(proxy)
|
||||
return await create_images_markdown(cookies, prompt, proxy)
|
||||
_cookies: Dict[str, str] = {}
|
||||
|
||||
@classmethod
|
||||
def create_completion(cls, prompt: str, cookies: Dict[str, str] = None, proxy: str = None) -> Generator[str]:
|
||||
"""
|
||||
Generator for creating imagecompletion based on a prompt.
|
||||
|
||||
Args:
|
||||
prompt (str): Prompt to generate images.
|
||||
cookies (Dict[str, str], optional): Cookies for the session. If None, cookies are retrieved automatically.
|
||||
proxy (str, optional): Proxy configuration.
|
||||
|
||||
Yields:
|
||||
Generator[str, None, None]: The final output as markdown formatted string with images.
|
||||
"""
|
||||
loop = get_event_loop()
|
||||
cookies = cookies or cls._cookies or get_cookies(".bing.com")
|
||||
if "_U" not in cookies:
|
||||
login_url = os.environ.get("G4F_LOGIN_URL")
|
||||
if login_url:
|
||||
yield f"Please login: [Bing]({login_url})\n\n"
|
||||
cls._cookies = cookies = get_cookies_from_browser(proxy)
|
||||
yield loop.run_until_complete(create_images_markdown(cookies, prompt, proxy))
|
||||
|
||||
@classmethod
|
||||
async def create_async(cls, prompt: str, cookies: Dict[str, str] = None, proxy: str = None) -> str:
|
||||
"""
|
||||
Asynchronously creates a markdown formatted string with images based on the prompt.
|
||||
|
||||
Args:
|
||||
prompt (str): Prompt to generate images.
|
||||
cookies (Dict[str, str], optional): Cookies for the session. If None, cookies are retrieved automatically.
|
||||
proxy (str, optional): Proxy configuration.
|
||||
|
||||
Returns:
|
||||
str: Markdown formatted string with images.
|
||||
"""
|
||||
cookies = cookies or cls._cookies or get_cookies(".bing.com")
|
||||
if "_U" not in cookies:
|
||||
cls._cookies = cookies = get_cookies_from_browser(proxy)
|
||||
return await create_images_markdown(cookies, prompt, proxy)
|
||||
|
||||
def patch_provider(provider: ProviderType) -> CreateImagesProvider:
|
||||
return CreateImagesProvider(provider, create_completion, create_async)
|
||||
"""
|
||||
Patches a provider to include image creation capabilities.
|
||||
|
||||
Args:
|
||||
provider (ProviderType): The provider to be patched.
|
||||
|
||||
Returns:
|
||||
CreateImagesProvider: The patched provider with image creation capabilities.
|
||||
"""
|
||||
return CreateImagesProvider(provider, CreateImagesBing.create_completion, CreateImagesBing.create_async)
|
@ -1,64 +1,107 @@
|
||||
from __future__ import annotations
|
||||
"""
|
||||
Module to handle image uploading and processing for Bing AI integrations.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import string
|
||||
import random
|
||||
import json
|
||||
import math
|
||||
from ...typing import ImageType
|
||||
from aiohttp import ClientSession
|
||||
from PIL import Image
|
||||
|
||||
from ...typing import ImageType, Tuple
|
||||
from ...image import to_image, process_image, to_base64, ImageResponse
|
||||
|
||||
image_config = {
|
||||
IMAGE_CONFIG = {
|
||||
"maxImagePixels": 360000,
|
||||
"imageCompressionRate": 0.7,
|
||||
"enableFaceBlurDebug": 0,
|
||||
"enableFaceBlurDebug": False,
|
||||
}
|
||||
|
||||
async def upload_image(
|
||||
session: ClientSession,
|
||||
image: ImageType,
|
||||
tone: str,
|
||||
session: ClientSession,
|
||||
image_data: ImageType,
|
||||
tone: str,
|
||||
proxy: str = None
|
||||
) -> ImageResponse:
|
||||
image = to_image(image)
|
||||
width, height = image.size
|
||||
max_image_pixels = image_config['maxImagePixels']
|
||||
if max_image_pixels / (width * height) < 1:
|
||||
new_width = int(width * math.sqrt(max_image_pixels / (width * height)))
|
||||
new_height = int(height * math.sqrt(max_image_pixels / (width * height)))
|
||||
else:
|
||||
new_width = width
|
||||
new_height = height
|
||||
new_img = process_image(image, new_width, new_height)
|
||||
new_img_binary_data = to_base64(new_img, image_config['imageCompressionRate'])
|
||||
data, boundary = build_image_upload_api_payload(new_img_binary_data, tone)
|
||||
headers = session.headers.copy()
|
||||
headers["content-type"] = f'multipart/form-data; boundary={boundary}'
|
||||
headers["referer"] = 'https://www.bing.com/search?q=Bing+AI&showconv=1&FORM=hpcodx'
|
||||
headers["origin"] = 'https://www.bing.com'
|
||||
"""
|
||||
Uploads an image to Bing's AI service and returns the image response.
|
||||
|
||||
Args:
|
||||
session (ClientSession): The active session.
|
||||
image_data (bytes): The image data to be uploaded.
|
||||
tone (str): The tone of the conversation.
|
||||
proxy (str, optional): Proxy if any. Defaults to None.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the image upload fails.
|
||||
|
||||
Returns:
|
||||
ImageResponse: The response from the image upload.
|
||||
"""
|
||||
image = to_image(image_data)
|
||||
new_width, new_height = calculate_new_dimensions(image)
|
||||
processed_img = process_image(image, new_width, new_height)
|
||||
img_binary_data = to_base64(processed_img, IMAGE_CONFIG['imageCompressionRate'])
|
||||
|
||||
data, boundary = build_image_upload_payload(img_binary_data, tone)
|
||||
headers = prepare_headers(session, boundary)
|
||||
|
||||
async with session.post("https://www.bing.com/images/kblob", data=data, headers=headers, proxy=proxy) as response:
|
||||
if response.status != 200:
|
||||
raise RuntimeError("Failed to upload image.")
|
||||
image_info = await response.json()
|
||||
if not image_info.get('blobId'):
|
||||
raise RuntimeError("Failed to parse image info.")
|
||||
result = {'bcid': image_info.get('blobId', "")}
|
||||
result['blurredBcid'] = image_info.get('processedBlobId', "")
|
||||
if result['blurredBcid'] != "":
|
||||
result["imageUrl"] = "https://www.bing.com/images/blob?bcid=" + result['blurredBcid']
|
||||
elif result['bcid'] != "":
|
||||
result["imageUrl"] = "https://www.bing.com/images/blob?bcid=" + result['bcid']
|
||||
result['originalImageUrl'] = (
|
||||
"https://www.bing.com/images/blob?bcid="
|
||||
+ result['blurredBcid']
|
||||
if image_config["enableFaceBlurDebug"]
|
||||
else "https://www.bing.com/images/blob?bcid="
|
||||
+ result['bcid']
|
||||
)
|
||||
return ImageResponse(result["imageUrl"], "", result)
|
||||
return parse_image_response(await response.json())
|
||||
|
||||
def build_image_upload_api_payload(image_bin: str, tone: str):
|
||||
payload = {
|
||||
def calculate_new_dimensions(image: Image.Image) -> Tuple[int, int]:
|
||||
"""
|
||||
Calculates the new dimensions for the image based on the maximum allowed pixels.
|
||||
|
||||
Args:
|
||||
image (Image): The PIL Image object.
|
||||
|
||||
Returns:
|
||||
Tuple[int, int]: The new width and height for the image.
|
||||
"""
|
||||
width, height = image.size
|
||||
max_image_pixels = IMAGE_CONFIG['maxImagePixels']
|
||||
if max_image_pixels / (width * height) < 1:
|
||||
scale_factor = math.sqrt(max_image_pixels / (width * height))
|
||||
return int(width * scale_factor), int(height * scale_factor)
|
||||
return width, height
|
||||
|
||||
def build_image_upload_payload(image_bin: str, tone: str) -> Tuple[str, str]:
|
||||
"""
|
||||
Builds the payload for image uploading.
|
||||
|
||||
Args:
|
||||
image_bin (str): Base64 encoded image binary data.
|
||||
tone (str): The tone of the conversation.
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: The data and boundary for the payload.
|
||||
"""
|
||||
boundary = "----WebKitFormBoundary" + ''.join(random.choices(string.ascii_letters + string.digits, k=16))
|
||||
data = f"--{boundary}\r\n" \
|
||||
f"Content-Disposition: form-data; name=\"knowledgeRequest\"\r\n\r\n" \
|
||||
f"{json.dumps(build_knowledge_request(tone), ensure_ascii=False)}\r\n" \
|
||||
f"--{boundary}\r\n" \
|
||||
f"Content-Disposition: form-data; name=\"imageBase64\"\r\n\r\n" \
|
||||
f"{image_bin}\r\n" \
|
||||
f"--{boundary}--\r\n"
|
||||
return data, boundary
|
||||
|
||||
def build_knowledge_request(tone: str) -> dict:
|
||||
"""
|
||||
Builds the knowledge request payload.
|
||||
|
||||
Args:
|
||||
tone (str): The tone of the conversation.
|
||||
|
||||
Returns:
|
||||
dict: The knowledge request payload.
|
||||
"""
|
||||
return {
|
||||
'invokedSkills': ["ImageById"],
|
||||
'subscriptionId': "Bing.Chat.Multimodal",
|
||||
'invokedSkillsRequestData': {
|
||||
@ -69,21 +112,46 @@ def build_image_upload_api_payload(image_bin: str, tone: str):
|
||||
'convotone': tone
|
||||
}
|
||||
}
|
||||
knowledge_request = {
|
||||
'imageInfo': {},
|
||||
'knowledgeRequest': payload
|
||||
}
|
||||
boundary="----WebKitFormBoundary" + ''.join(random.choices(string.ascii_letters + string.digits, k=16))
|
||||
data = (
|
||||
f'--{boundary}'
|
||||
+ '\r\nContent-Disposition: form-data; name="knowledgeRequest"\r\n\r\n'
|
||||
+ json.dumps(knowledge_request, ensure_ascii=False)
|
||||
+ "\r\n--"
|
||||
+ boundary
|
||||
+ '\r\nContent-Disposition: form-data; name="imageBase64"\r\n\r\n'
|
||||
+ image_bin
|
||||
+ "\r\n--"
|
||||
+ boundary
|
||||
+ "--\r\n"
|
||||
|
||||
def prepare_headers(session: ClientSession, boundary: str) -> dict:
|
||||
"""
|
||||
Prepares the headers for the image upload request.
|
||||
|
||||
Args:
|
||||
session (ClientSession): The active session.
|
||||
boundary (str): The boundary string for the multipart/form-data.
|
||||
|
||||
Returns:
|
||||
dict: The headers for the request.
|
||||
"""
|
||||
headers = session.headers.copy()
|
||||
headers["Content-Type"] = f'multipart/form-data; boundary={boundary}'
|
||||
headers["Referer"] = 'https://www.bing.com/search?q=Bing+AI&showconv=1&FORM=hpcodx'
|
||||
headers["Origin"] = 'https://www.bing.com'
|
||||
return headers
|
||||
|
||||
def parse_image_response(response: dict) -> ImageResponse:
|
||||
"""
|
||||
Parses the response from the image upload.
|
||||
|
||||
Args:
|
||||
response (dict): The response dictionary.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If parsing the image info fails.
|
||||
|
||||
Returns:
|
||||
ImageResponse: The parsed image response.
|
||||
"""
|
||||
if not response.get('blobId'):
|
||||
raise RuntimeError("Failed to parse image info.")
|
||||
|
||||
result = {'bcid': response.get('blobId', ""), 'blurredBcid': response.get('processedBlobId', "")}
|
||||
result["imageUrl"] = f"https://www.bing.com/images/blob?bcid={result['blurredBcid'] or result['bcid']}"
|
||||
|
||||
result['originalImageUrl'] = (
|
||||
f"https://www.bing.com/images/blob?bcid={result['blurredBcid']}"
|
||||
if IMAGE_CONFIG["enableFaceBlurDebug"] else
|
||||
f"https://www.bing.com/images/blob?bcid={result['bcid']}"
|
||||
)
|
||||
return data, boundary
|
||||
return ImageResponse(result["imageUrl"], "", result)
|
@ -1,36 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import webbrowser
|
||||
import random
|
||||
import string
|
||||
import secrets
|
||||
import os
|
||||
from os import path
|
||||
import random
|
||||
import secrets
|
||||
import string
|
||||
from asyncio import AbstractEventLoop, BaseEventLoop
|
||||
from platformdirs import user_config_dir
|
||||
from browser_cookie3 import (
|
||||
chrome,
|
||||
chromium,
|
||||
opera,
|
||||
opera_gx,
|
||||
brave,
|
||||
edge,
|
||||
vivaldi,
|
||||
firefox,
|
||||
_LinuxPasswordManager
|
||||
chrome, chromium, opera, opera_gx,
|
||||
brave, edge, vivaldi, firefox,
|
||||
_LinuxPasswordManager, BrowserCookieError
|
||||
)
|
||||
|
||||
from ..typing import Dict, Messages
|
||||
from .. import debug
|
||||
|
||||
# Local Cookie Storage
|
||||
# Global variable to store cookies
|
||||
_cookies: Dict[str, Dict[str, str]] = {}
|
||||
|
||||
# If loop closed or not set, create new event loop.
|
||||
# If event loop is already running, handle nested event loops.
|
||||
# If "nest_asyncio" is installed, patch the event loop.
|
||||
def get_event_loop() -> AbstractEventLoop:
|
||||
"""
|
||||
Get the current asyncio event loop. If the loop is closed or not set, create a new event loop.
|
||||
If a loop is running, handle nested event loops. Patch the loop if 'nest_asyncio' is installed.
|
||||
|
||||
Returns:
|
||||
AbstractEventLoop: The current or new event loop.
|
||||
"""
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if isinstance(loop, BaseEventLoop):
|
||||
@ -39,61 +34,50 @@ def get_event_loop() -> AbstractEventLoop:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
# Is running event loop
|
||||
asyncio.get_running_loop()
|
||||
if not hasattr(loop.__class__, "_nest_patched"):
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply(loop)
|
||||
except RuntimeError:
|
||||
# No running event loop
|
||||
pass
|
||||
except ImportError:
|
||||
raise RuntimeError(
|
||||
'Use "create_async" instead of "create" function in a running event loop. Or install the "nest_asyncio" package.'
|
||||
'Use "create_async" instead of "create" function in a running event loop. Or install "nest_asyncio" package.'
|
||||
)
|
||||
return loop
|
||||
|
||||
def init_cookies():
|
||||
urls = [
|
||||
'https://chat-gpt.org',
|
||||
'https://www.aitianhu.com',
|
||||
'https://chatgptfree.ai',
|
||||
'https://gptchatly.com',
|
||||
'https://bard.google.com',
|
||||
'https://huggingface.co/chat',
|
||||
'https://open-assistant.io/chat'
|
||||
]
|
||||
|
||||
browsers = ['google-chrome', 'chrome', 'firefox', 'safari']
|
||||
|
||||
def open_urls_in_browser(browser):
|
||||
b = webbrowser.get(browser)
|
||||
for url in urls:
|
||||
b.open(url, new=0, autoraise=True)
|
||||
|
||||
for browser in browsers:
|
||||
try:
|
||||
open_urls_in_browser(browser)
|
||||
break
|
||||
except webbrowser.Error:
|
||||
continue
|
||||
|
||||
# Check for broken dbus address in docker image
|
||||
if os.environ.get('DBUS_SESSION_BUS_ADDRESS') == "/dev/null":
|
||||
_LinuxPasswordManager.get_password = lambda a, b: b"secret"
|
||||
|
||||
# Load cookies for a domain from all supported browsers.
|
||||
# Cache the results in the "_cookies" variable.
|
||||
def get_cookies(domain_name=''):
|
||||
|
||||
def get_cookies(domain_name: str = '') -> Dict[str, str]:
|
||||
"""
|
||||
Load cookies for a given domain from all supported browsers and cache the results.
|
||||
|
||||
Args:
|
||||
domain_name (str): The domain for which to load cookies.
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: A dictionary of cookie names and values.
|
||||
"""
|
||||
if domain_name in _cookies:
|
||||
return _cookies[domain_name]
|
||||
def g4f(domain_name):
|
||||
user_data_dir = user_config_dir("g4f")
|
||||
cookie_file = path.join(user_data_dir, "Default", "Cookies")
|
||||
return [] if not path.exists(cookie_file) else chrome(cookie_file, domain_name)
|
||||
|
||||
cookies = _load_cookies_from_browsers(domain_name)
|
||||
_cookies[domain_name] = cookies
|
||||
return cookies
|
||||
|
||||
def _load_cookies_from_browsers(domain_name: str) -> Dict[str, str]:
|
||||
"""
|
||||
Helper function to load cookies from various browsers.
|
||||
|
||||
Args:
|
||||
domain_name (str): The domain for which to load cookies.
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: A dictionary of cookie names and values.
|
||||
"""
|
||||
cookies = {}
|
||||
for cookie_fn in [g4f, chrome, chromium, opera, opera_gx, brave, edge, vivaldi, firefox]:
|
||||
for cookie_fn in [_g4f, chrome, chromium, opera, opera_gx, brave, edge, vivaldi, firefox]:
|
||||
try:
|
||||
cookie_jar = cookie_fn(domain_name=domain_name)
|
||||
if len(cookie_jar) and debug.logging:
|
||||
@ -101,13 +85,38 @@ def get_cookies(domain_name=''):
|
||||
for cookie in cookie_jar:
|
||||
if cookie.name not in cookies:
|
||||
cookies[cookie.name] = cookie.value
|
||||
except:
|
||||
except BrowserCookieError:
|
||||
pass
|
||||
_cookies[domain_name] = cookies
|
||||
return _cookies[domain_name]
|
||||
except Exception as e:
|
||||
if debug.logging:
|
||||
print(f"Error reading cookies from {cookie_fn.__name__} for {domain_name}: {e}")
|
||||
return cookies
|
||||
|
||||
def _g4f(domain_name: str) -> list:
|
||||
"""
|
||||
Load cookies from the 'g4f' browser (if exists).
|
||||
|
||||
Args:
|
||||
domain_name (str): The domain for which to load cookies.
|
||||
|
||||
Returns:
|
||||
list: List of cookies.
|
||||
"""
|
||||
user_data_dir = user_config_dir("g4f")
|
||||
cookie_file = os.path.join(user_data_dir, "Default", "Cookies")
|
||||
return [] if not os.path.exists(cookie_file) else chrome(cookie_file, domain_name)
|
||||
|
||||
def format_prompt(messages: Messages, add_special_tokens=False) -> str:
|
||||
"""
|
||||
Format a series of messages into a single string, optionally adding special tokens.
|
||||
|
||||
Args:
|
||||
messages (Messages): A list of message dictionaries, each containing 'role' and 'content'.
|
||||
add_special_tokens (bool): Whether to add special formatting tokens.
|
||||
|
||||
Returns:
|
||||
str: A formatted string containing all messages.
|
||||
"""
|
||||
if not add_special_tokens and len(messages) <= 1:
|
||||
return messages[0]["content"]
|
||||
formatted = "\n".join([
|
||||
@ -116,12 +125,26 @@ def format_prompt(messages: Messages, add_special_tokens=False) -> str:
|
||||
])
|
||||
return f"{formatted}\nAssistant:"
|
||||
|
||||
|
||||
def get_random_string(length: int = 10) -> str:
|
||||
"""
|
||||
Generate a random string of specified length, containing lowercase letters and digits.
|
||||
|
||||
Args:
|
||||
length (int, optional): Length of the random string to generate. Defaults to 10.
|
||||
|
||||
Returns:
|
||||
str: A random string of the specified length.
|
||||
"""
|
||||
return ''.join(
|
||||
random.choice(string.ascii_lowercase + string.digits)
|
||||
for _ in range(length)
|
||||
)
|
||||
|
||||
def get_random_hex() -> str:
|
||||
"""
|
||||
Generate a random hexadecimal string of a fixed length.
|
||||
|
||||
Returns:
|
||||
str: A random hexadecimal string of 32 characters (16 bytes).
|
||||
"""
|
||||
return secrets.token_hex(16).zfill(32)
|
@ -1,6 +1,9 @@
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import uuid
|
||||
import json
|
||||
import os
|
||||
|
||||
import uuid, json, asyncio, os
|
||||
from py_arkose_generator.arkose import get_values_for_request
|
||||
from async_property import async_cached_property
|
||||
from selenium.webdriver.common.by import By
|
||||
@ -14,7 +17,8 @@ from ...typing import AsyncResult, Messages
|
||||
from ...requests import StreamSession
|
||||
from ...image import to_image, to_bytes, ImageType, ImageResponse
|
||||
|
||||
models = {
|
||||
# Aliases for model names
|
||||
MODELS = {
|
||||
"gpt-3.5": "text-davinci-002-render-sha",
|
||||
"gpt-3.5-turbo": "text-davinci-002-render-sha",
|
||||
"gpt-4": "gpt-4",
|
||||
@ -22,13 +26,15 @@ models = {
|
||||
}
|
||||
|
||||
class OpenaiChat(AsyncGeneratorProvider):
|
||||
url = "https://chat.openai.com"
|
||||
working = True
|
||||
needs_auth = True
|
||||
"""A class for creating and managing conversations with OpenAI chat service"""
|
||||
|
||||
url = "https://chat.openai.com"
|
||||
working = True
|
||||
needs_auth = True
|
||||
supports_gpt_35_turbo = True
|
||||
supports_gpt_4 = True
|
||||
_cookies: dict = {}
|
||||
_default_model: str = None
|
||||
supports_gpt_4 = True
|
||||
_cookies: dict = {}
|
||||
_default_model: str = None
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
@ -43,6 +49,23 @@ class OpenaiChat(AsyncGeneratorProvider):
|
||||
image: ImageType = None,
|
||||
**kwargs
|
||||
) -> Response:
|
||||
"""Create a new conversation or continue an existing one
|
||||
|
||||
Args:
|
||||
prompt: The user input to start or continue the conversation
|
||||
model: The name of the model to use for generating responses
|
||||
messages: The list of previous messages in the conversation
|
||||
history_disabled: A flag indicating if the history and training should be disabled
|
||||
action: The type of action to perform, either "next", "continue", or "variant"
|
||||
conversation_id: The ID of the existing conversation, if any
|
||||
parent_id: The ID of the parent message, if any
|
||||
image: The image to include in the user input, if any
|
||||
**kwargs: Additional keyword arguments to pass to the generator
|
||||
|
||||
Returns:
|
||||
A Response object that contains the generator, action, messages, and options
|
||||
"""
|
||||
# Add the user input to the messages list
|
||||
if prompt:
|
||||
messages.append({
|
||||
"role": "user",
|
||||
@ -67,20 +90,33 @@ class OpenaiChat(AsyncGeneratorProvider):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def upload_image(
|
||||
async def _upload_image(
|
||||
cls,
|
||||
session: StreamSession,
|
||||
headers: dict,
|
||||
image: ImageType
|
||||
) -> ImageResponse:
|
||||
"""Upload an image to the service and get the download URL
|
||||
|
||||
Args:
|
||||
session: The StreamSession object to use for requests
|
||||
headers: The headers to include in the requests
|
||||
image: The image to upload, either a PIL Image object or a bytes object
|
||||
|
||||
Returns:
|
||||
An ImageResponse object that contains the download URL, file name, and other data
|
||||
"""
|
||||
# Convert the image to a PIL Image object and get the extension
|
||||
image = to_image(image)
|
||||
extension = image.format.lower()
|
||||
# Convert the image to a bytes object and get the size
|
||||
data_bytes = to_bytes(image)
|
||||
data = {
|
||||
"file_name": f"{image.width}x{image.height}.{extension}",
|
||||
"file_size": len(data_bytes),
|
||||
"use_case": "multimodal"
|
||||
}
|
||||
# Post the image data to the service and get the image data
|
||||
async with session.post(f"{cls.url}/backend-api/files", json=data, headers=headers) as response:
|
||||
response.raise_for_status()
|
||||
image_data = {
|
||||
@ -91,6 +127,7 @@ class OpenaiChat(AsyncGeneratorProvider):
|
||||
"height": image.height,
|
||||
"width": image.width
|
||||
}
|
||||
# Put the image bytes to the upload URL and check the status
|
||||
async with session.put(
|
||||
image_data["upload_url"],
|
||||
data=data_bytes,
|
||||
@ -100,6 +137,7 @@ class OpenaiChat(AsyncGeneratorProvider):
|
||||
}
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
# Post the file ID to the service and get the download URL
|
||||
async with session.post(
|
||||
f"{cls.url}/backend-api/files/{image_data['file_id']}/uploaded",
|
||||
json={},
|
||||
@ -110,24 +148,45 @@ class OpenaiChat(AsyncGeneratorProvider):
|
||||
return ImageResponse(download_url, image_data["file_name"], image_data)
|
||||
|
||||
@classmethod
|
||||
async def get_default_model(cls, session: StreamSession, headers: dict):
|
||||
async def _get_default_model(cls, session: StreamSession, headers: dict):
|
||||
"""Get the default model name from the service
|
||||
|
||||
Args:
|
||||
session: The StreamSession object to use for requests
|
||||
headers: The headers to include in the requests
|
||||
|
||||
Returns:
|
||||
The default model name as a string
|
||||
"""
|
||||
# Check the cache for the default model
|
||||
if cls._default_model:
|
||||
model = cls._default_model
|
||||
else:
|
||||
async with session.get(f"{cls.url}/backend-api/models", headers=headers) as response:
|
||||
data = await response.json()
|
||||
if "categories" in data:
|
||||
model = data["categories"][-1]["default_model"]
|
||||
else:
|
||||
RuntimeError(f"Response: {data}")
|
||||
cls._default_model = model
|
||||
return model
|
||||
return cls._default_model
|
||||
# Get the models data from the service
|
||||
async with session.get(f"{cls.url}/backend-api/models", headers=headers) as response:
|
||||
data = await response.json()
|
||||
if "categories" in data:
|
||||
cls._default_model = data["categories"][-1]["default_model"]
|
||||
else:
|
||||
raise RuntimeError(f"Response: {data}")
|
||||
return cls._default_model
|
||||
|
||||
@classmethod
|
||||
def create_messages(cls, prompt: str, image_response: ImageResponse = None):
|
||||
def _create_messages(cls, prompt: str, image_response: ImageResponse = None):
|
||||
"""Create a list of messages for the user input
|
||||
|
||||
Args:
|
||||
prompt: The user input as a string
|
||||
image_response: The image response object, if any
|
||||
|
||||
Returns:
|
||||
A list of messages with the user input and the image, if any
|
||||
"""
|
||||
# Check if there is an image response
|
||||
if not image_response:
|
||||
# Create a content object with the text type and the prompt
|
||||
content = {"content_type": "text", "parts": [prompt]}
|
||||
else:
|
||||
# Create a content object with the multimodal text type and the image and the prompt
|
||||
content = {
|
||||
"content_type": "multimodal_text",
|
||||
"parts": [{
|
||||
@ -137,12 +196,15 @@ class OpenaiChat(AsyncGeneratorProvider):
|
||||
"width": image_response.get("width"),
|
||||
}, prompt]
|
||||
}
|
||||
# Create a message object with the user role and the content
|
||||
messages = [{
|
||||
"id": str(uuid.uuid4()),
|
||||
"author": {"role": "user"},
|
||||
"content": content,
|
||||
}]
|
||||
# Check if there is an image response
|
||||
if image_response:
|
||||
# Add the metadata object with the attachments
|
||||
messages[0]["metadata"] = {
|
||||
"attachments": [{
|
||||
"height": image_response.get("height"),
|
||||
@ -156,19 +218,38 @@ class OpenaiChat(AsyncGeneratorProvider):
|
||||
return messages
|
||||
|
||||
@classmethod
|
||||
async def get_image_response(cls, session: StreamSession, headers: dict, line: dict):
|
||||
if "parts" in line["message"]["content"]:
|
||||
part = line["message"]["content"]["parts"][0]
|
||||
if "asset_pointer" in part and part["metadata"]:
|
||||
file_id = part["asset_pointer"].split("file-service://", 1)[1]
|
||||
prompt = part["metadata"]["dalle"]["prompt"]
|
||||
async with session.get(
|
||||
f"{cls.url}/backend-api/files/{file_id}/download",
|
||||
headers=headers
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
download_url = (await response.json())["download_url"]
|
||||
return ImageResponse(download_url, prompt)
|
||||
async def _get_generated_image(cls, session: StreamSession, headers: dict, line: dict) -> ImageResponse:
|
||||
"""
|
||||
Retrieves the image response based on the message content.
|
||||
|
||||
:param session: The StreamSession object.
|
||||
:param headers: HTTP headers for the request.
|
||||
:param line: The line of response containing image information.
|
||||
:return: An ImageResponse object with the image details.
|
||||
"""
|
||||
if "parts" not in line["message"]["content"]:
|
||||
return
|
||||
first_part = line["message"]["content"]["parts"][0]
|
||||
if "asset_pointer" not in first_part or "metadata" not in first_part:
|
||||
return
|
||||
file_id = first_part["asset_pointer"].split("file-service://", 1)[1]
|
||||
prompt = first_part["metadata"]["dalle"]["prompt"]
|
||||
try:
|
||||
async with session.get(f"{cls.url}/backend-api/files/{file_id}/download", headers=headers) as response:
|
||||
response.raise_for_status()
|
||||
download_url = (await response.json())["download_url"]
|
||||
return ImageResponse(download_url, prompt)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error in downloading image: {e}")
|
||||
|
||||
@classmethod
|
||||
async def _delete_conversation(cls, session: StreamSession, headers: dict, conversation_id: str):
|
||||
async with session.patch(
|
||||
f"{cls.url}/backend-api/conversation/{conversation_id}",
|
||||
json={"is_visible": False},
|
||||
headers=headers
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
|
||||
@classmethod
|
||||
async def create_async_generator(
|
||||
@ -188,26 +269,47 @@ class OpenaiChat(AsyncGeneratorProvider):
|
||||
response_fields: bool = False,
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
if model in models:
|
||||
model = models[model]
|
||||
"""
|
||||
Create an asynchronous generator for the conversation.
|
||||
|
||||
Args:
|
||||
model (str): The model name.
|
||||
messages (Messages): The list of previous messages.
|
||||
proxy (str): Proxy to use for requests.
|
||||
timeout (int): Timeout for requests.
|
||||
access_token (str): Access token for authentication.
|
||||
cookies (dict): Cookies to use for authentication.
|
||||
auto_continue (bool): Flag to automatically continue the conversation.
|
||||
history_disabled (bool): Flag to disable history and training.
|
||||
action (str): Type of action ('next', 'continue', 'variant').
|
||||
conversation_id (str): ID of the conversation.
|
||||
parent_id (str): ID of the parent message.
|
||||
image (ImageType): Image to include in the conversation.
|
||||
response_fields (bool): Flag to include response fields in the output.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Yields:
|
||||
AsyncResult: Asynchronous results from the generator.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If an error occurs during processing.
|
||||
"""
|
||||
model = MODELS.get(model, model)
|
||||
if not parent_id:
|
||||
parent_id = str(uuid.uuid4())
|
||||
if not cookies:
|
||||
cookies = cls._cookies
|
||||
if not access_token:
|
||||
if not cookies:
|
||||
cls._cookies = cookies = get_cookies("chat.openai.com")
|
||||
if "access_token" in cookies:
|
||||
access_token = cookies["access_token"]
|
||||
cookies = cls._cookies or get_cookies("chat.openai.com")
|
||||
if not access_token and "access_token" in cookies:
|
||||
access_token = cookies["access_token"]
|
||||
if not access_token:
|
||||
login_url = os.environ.get("G4F_LOGIN_URL")
|
||||
if login_url:
|
||||
yield f"Please login: [ChatGPT]({login_url})\n\n"
|
||||
access_token, cookies = cls.browse_access_token(proxy)
|
||||
access_token, cookies = cls._browse_access_token(proxy)
|
||||
cls._cookies = cookies
|
||||
headers = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
|
||||
async with StreamSession(
|
||||
proxies={"https": proxy},
|
||||
impersonate="chrome110",
|
||||
@ -215,11 +317,11 @@ class OpenaiChat(AsyncGeneratorProvider):
|
||||
cookies=dict([(name, value) for name, value in cookies.items() if name == "_puid"])
|
||||
) as session:
|
||||
if not model:
|
||||
model = await cls.get_default_model(session, headers)
|
||||
model = await cls._get_default_model(session, headers)
|
||||
try:
|
||||
image_response = None
|
||||
if image:
|
||||
image_response = await cls.upload_image(session, headers, image)
|
||||
image_response = await cls._upload_image(session, headers, image)
|
||||
yield image_response
|
||||
except Exception as e:
|
||||
yield e
|
||||
@ -227,7 +329,7 @@ class OpenaiChat(AsyncGeneratorProvider):
|
||||
while not end_turn.is_end:
|
||||
data = {
|
||||
"action": action,
|
||||
"arkose_token": await cls.get_arkose_token(session),
|
||||
"arkose_token": await cls._get_arkose_token(session),
|
||||
"conversation_id": conversation_id,
|
||||
"parent_message_id": parent_id,
|
||||
"model": model,
|
||||
@ -235,7 +337,7 @@ class OpenaiChat(AsyncGeneratorProvider):
|
||||
}
|
||||
if action != "continue":
|
||||
prompt = format_prompt(messages) if not conversation_id else messages[-1]["content"]
|
||||
data["messages"] = cls.create_messages(prompt, image_response)
|
||||
data["messages"] = cls._create_messages(prompt, image_response)
|
||||
async with session.post(
|
||||
f"{cls.url}/backend-api/conversation",
|
||||
json=data,
|
||||
@ -261,62 +363,80 @@ class OpenaiChat(AsyncGeneratorProvider):
|
||||
if "message_type" not in line["message"]["metadata"]:
|
||||
continue
|
||||
try:
|
||||
image_response = await cls.get_image_response(session, headers, line)
|
||||
image_response = await cls._get_generated_image(session, headers, line)
|
||||
if image_response:
|
||||
yield image_response
|
||||
except Exception as e:
|
||||
yield e
|
||||
if line["message"]["author"]["role"] != "assistant":
|
||||
continue
|
||||
if line["message"]["metadata"]["message_type"] in ("next", "continue", "variant"):
|
||||
conversation_id = line["conversation_id"]
|
||||
parent_id = line["message"]["id"]
|
||||
if response_fields:
|
||||
response_fields = False
|
||||
yield ResponseFields(conversation_id, parent_id, end_turn)
|
||||
if "parts" in line["message"]["content"]:
|
||||
new_message = line["message"]["content"]["parts"][0]
|
||||
if len(new_message) > last_message:
|
||||
yield new_message[last_message:]
|
||||
last_message = len(new_message)
|
||||
if line["message"]["content"]["content_type"] != "text":
|
||||
continue
|
||||
if line["message"]["metadata"]["message_type"] not in ("next", "continue", "variant"):
|
||||
continue
|
||||
conversation_id = line["conversation_id"]
|
||||
parent_id = line["message"]["id"]
|
||||
if response_fields:
|
||||
response_fields = False
|
||||
yield ResponseFields(conversation_id, parent_id, end_turn)
|
||||
if "parts" in line["message"]["content"]:
|
||||
new_message = line["message"]["content"]["parts"][0]
|
||||
if len(new_message) > last_message:
|
||||
yield new_message[last_message:]
|
||||
last_message = len(new_message)
|
||||
if "finish_details" in line["message"]["metadata"]:
|
||||
if line["message"]["metadata"]["finish_details"]["type"] == "stop":
|
||||
end_turn.end()
|
||||
break
|
||||
except Exception as e:
|
||||
yield e
|
||||
raise e
|
||||
if not auto_continue:
|
||||
break
|
||||
action = "continue"
|
||||
await asyncio.sleep(5)
|
||||
if history_disabled:
|
||||
async with session.patch(
|
||||
f"{cls.url}/backend-api/conversation/{conversation_id}",
|
||||
json={"is_visible": False},
|
||||
headers=headers
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
if history_disabled and auto_continue:
|
||||
await cls._delete_conversation(session, headers, conversation_id)
|
||||
|
||||
@classmethod
|
||||
def browse_access_token(cls, proxy: str = None) -> tuple[str, dict]:
|
||||
def _browse_access_token(cls, proxy: str = None) -> tuple[str, dict]:
|
||||
"""
|
||||
Browse to obtain an access token.
|
||||
|
||||
Args:
|
||||
proxy (str): Proxy to use for browsing.
|
||||
|
||||
Returns:
|
||||
tuple[str, dict]: A tuple containing the access token and cookies.
|
||||
"""
|
||||
driver = get_browser(proxy=proxy)
|
||||
try:
|
||||
driver.get(f"{cls.url}/")
|
||||
WebDriverWait(driver, 1200).until(
|
||||
EC.presence_of_element_located((By.ID, "prompt-textarea"))
|
||||
WebDriverWait(driver, 1200).until(EC.presence_of_element_located((By.ID, "prompt-textarea")))
|
||||
access_token = driver.execute_script(
|
||||
"let session = await fetch('/api/auth/session');"
|
||||
"let data = await session.json();"
|
||||
"let accessToken = data['accessToken'];"
|
||||
"let expires = new Date(); expires.setTime(expires.getTime() + 60 * 60 * 24 * 7);"
|
||||
"document.cookie = 'access_token=' + accessToken + ';expires=' + expires.toUTCString() + ';path=/';"
|
||||
"return accessToken;"
|
||||
)
|
||||
javascript = """
|
||||
access_token = (await (await fetch('/api/auth/session')).json())['accessToken'];
|
||||
expires = new Date(); expires.setTime(expires.getTime() + 60 * 60 * 24 * 7); // One week
|
||||
document.cookie = 'access_token=' + access_token + ';expires=' + expires.toUTCString() + ';path=/';
|
||||
return access_token;
|
||||
"""
|
||||
return driver.execute_script(javascript), get_driver_cookies(driver)
|
||||
return access_token, get_driver_cookies(driver)
|
||||
finally:
|
||||
driver.quit()
|
||||
|
||||
@classmethod
|
||||
async def get_arkose_token(cls, session: StreamSession) -> str:
|
||||
@classmethod
|
||||
async def _get_arkose_token(cls, session: StreamSession) -> str:
|
||||
"""
|
||||
Obtain an Arkose token for the session.
|
||||
|
||||
Args:
|
||||
session (StreamSession): The session object.
|
||||
|
||||
Returns:
|
||||
str: The Arkose token.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If unable to retrieve the token.
|
||||
"""
|
||||
config = {
|
||||
"pkey": "3D86FBBA-9D22-402A-B512-3420086BA6CC",
|
||||
"surl": "https://tcr9i.chat.openai.com",
|
||||
@ -332,26 +452,30 @@ return access_token;
|
||||
if "token" in decoded_json:
|
||||
return decoded_json["token"]
|
||||
raise RuntimeError(f"Response: {decoded_json}")
|
||||
|
||||
class EndTurn():
|
||||
|
||||
class EndTurn:
|
||||
"""
|
||||
Class to represent the end of a conversation turn.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.is_end = False
|
||||
|
||||
def end(self):
|
||||
self.is_end = True
|
||||
|
||||
class ResponseFields():
|
||||
def __init__(
|
||||
self,
|
||||
conversation_id: str,
|
||||
message_id: str,
|
||||
end_turn: EndTurn
|
||||
):
|
||||
class ResponseFields:
|
||||
"""
|
||||
Class to encapsulate response fields.
|
||||
"""
|
||||
def __init__(self, conversation_id: str, message_id: str, end_turn: EndTurn):
|
||||
self.conversation_id = conversation_id
|
||||
self.message_id = message_id
|
||||
self._end_turn = end_turn
|
||||
|
||||
class Response():
|
||||
"""
|
||||
Class to encapsulate a response from the chat service.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
generator: AsyncResult,
|
||||
@ -360,13 +484,13 @@ class Response():
|
||||
options: dict
|
||||
):
|
||||
self._generator = generator
|
||||
self.action: str = action
|
||||
self.is_end: bool = False
|
||||
self.action = action
|
||||
self.is_end = False
|
||||
self._message = None
|
||||
self._messages = messages
|
||||
self._options = options
|
||||
self._fields = None
|
||||
|
||||
|
||||
async def generator(self):
|
||||
if self._generator:
|
||||
self._generator = None
|
||||
@ -384,19 +508,16 @@ class Response():
|
||||
|
||||
def __aiter__(self):
|
||||
return self.generator()
|
||||
|
||||
|
||||
@async_cached_property
|
||||
async def message(self) -> str:
|
||||
[_ async for _ in self.generator()]
|
||||
await self.generator()
|
||||
return self._message
|
||||
|
||||
|
||||
async def get_fields(self):
|
||||
[_ async for _ in self.generator()]
|
||||
return {
|
||||
"conversation_id": self._fields.conversation_id,
|
||||
"parent_id": self._fields.message_id,
|
||||
}
|
||||
|
||||
await self.generator()
|
||||
return {"conversation_id": self._fields.conversation_id, "parent_id": self._fields.message_id}
|
||||
|
||||
async def next(self, prompt: str, **kwargs) -> Response:
|
||||
return await OpenaiChat.create(
|
||||
**self._options,
|
||||
@ -406,7 +527,7 @@ class Response():
|
||||
**await self.get_fields(),
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
||||
async def do_continue(self, **kwargs) -> Response:
|
||||
fields = await self.get_fields()
|
||||
if self.is_end:
|
||||
@ -418,7 +539,7 @@ class Response():
|
||||
**fields,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
||||
async def variant(self, **kwargs) -> Response:
|
||||
if self.action != "next":
|
||||
raise RuntimeError("Can't create variant from continue or variant request.")
|
||||
@ -429,11 +550,9 @@ class Response():
|
||||
**await self.get_fields(),
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
||||
@async_cached_property
|
||||
async def messages(self):
|
||||
messages = self._messages
|
||||
messages.append({
|
||||
"role": "assistant", "content": await self.message
|
||||
})
|
||||
messages.append({"role": "assistant", "content": await self.message})
|
||||
return messages
|
@ -7,8 +7,17 @@ from ..base_provider import BaseRetryProvider
|
||||
from .. import debug
|
||||
from ..errors import RetryProviderError, RetryNoProviderError
|
||||
|
||||
|
||||
class RetryProvider(BaseRetryProvider):
|
||||
"""
|
||||
A provider class to handle retries for creating completions with different providers.
|
||||
|
||||
Attributes:
|
||||
providers (list): A list of provider instances.
|
||||
shuffle (bool): A flag indicating whether to shuffle providers before use.
|
||||
exceptions (dict): A dictionary to store exceptions encountered during retries.
|
||||
last_provider (BaseProvider): The last provider that was used.
|
||||
"""
|
||||
|
||||
def create_completion(
|
||||
self,
|
||||
model: str,
|
||||
@ -16,10 +25,21 @@ class RetryProvider(BaseRetryProvider):
|
||||
stream: bool = False,
|
||||
**kwargs
|
||||
) -> CreateResult:
|
||||
if stream:
|
||||
providers = [provider for provider in self.providers if provider.supports_stream]
|
||||
else:
|
||||
providers = self.providers
|
||||
"""
|
||||
Create a completion using available providers, with an option to stream the response.
|
||||
|
||||
Args:
|
||||
model (str): The model to be used for completion.
|
||||
messages (Messages): The messages to be used for generating completion.
|
||||
stream (bool, optional): Flag to indicate if the response should be streamed. Defaults to False.
|
||||
|
||||
Yields:
|
||||
CreateResult: Tokens or results from the completion.
|
||||
|
||||
Raises:
|
||||
Exception: Any exception encountered during the completion process.
|
||||
"""
|
||||
providers = [p for p in self.providers if stream and p.supports_stream] if stream else self.providers
|
||||
if self.shuffle:
|
||||
random.shuffle(providers)
|
||||
|
||||
@ -50,10 +70,23 @@ class RetryProvider(BaseRetryProvider):
|
||||
messages: Messages,
|
||||
**kwargs
|
||||
) -> str:
|
||||
"""
|
||||
Asynchronously create a completion using available providers.
|
||||
|
||||
Args:
|
||||
model (str): The model to be used for completion.
|
||||
messages (Messages): The messages to be used for generating completion.
|
||||
|
||||
Returns:
|
||||
str: The result of the asynchronous completion.
|
||||
|
||||
Raises:
|
||||
Exception: Any exception encountered during the asynchronous completion process.
|
||||
"""
|
||||
providers = self.providers
|
||||
if self.shuffle:
|
||||
random.shuffle(providers)
|
||||
|
||||
|
||||
self.exceptions = {}
|
||||
for provider in providers:
|
||||
self.last_provider = provider
|
||||
@ -66,13 +99,20 @@ class RetryProvider(BaseRetryProvider):
|
||||
self.exceptions[provider.__name__] = e
|
||||
if debug.logging:
|
||||
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
|
||||
|
||||
|
||||
self.raise_exceptions()
|
||||
|
||||
|
||||
def raise_exceptions(self) -> None:
|
||||
"""
|
||||
Raise a combined exception if any occurred during retries.
|
||||
|
||||
Raises:
|
||||
RetryProviderError: If any provider encountered an exception.
|
||||
RetryNoProviderError: If no provider is found.
|
||||
"""
|
||||
if self.exceptions:
|
||||
raise RetryProviderError("RetryProvider failed:\n" + "\n".join([
|
||||
f"{p}: {exception.__class__.__name__}: {exception}" for p, exception in self.exceptions.items()
|
||||
]))
|
||||
|
||||
|
||||
raise RetryNoProviderError("No provider found")
|
@ -15,6 +15,26 @@ def get_model_and_provider(model : Union[Model, str],
|
||||
ignored : list[str] = None,
|
||||
ignore_working: bool = False,
|
||||
ignore_stream: bool = False) -> tuple[str, ProviderType]:
|
||||
"""
|
||||
Retrieves the model and provider based on input parameters.
|
||||
|
||||
Args:
|
||||
model (Union[Model, str]): The model to use, either as an object or a string identifier.
|
||||
provider (Union[ProviderType, str, None]): The provider to use, either as an object, a string identifier, or None.
|
||||
stream (bool): Indicates if the operation should be performed as a stream.
|
||||
ignored (list[str], optional): List of provider names to be ignored.
|
||||
ignore_working (bool, optional): If True, ignores the working status of the provider.
|
||||
ignore_stream (bool, optional): If True, ignores the streaming capability of the provider.
|
||||
|
||||
Returns:
|
||||
tuple[str, ProviderType]: A tuple containing the model name and the provider type.
|
||||
|
||||
Raises:
|
||||
ProviderNotFoundError: If the provider is not found.
|
||||
ModelNotFoundError: If the model is not found.
|
||||
ProviderNotWorkingError: If the provider is not working.
|
||||
StreamNotSupportedError: If streaming is not supported by the provider.
|
||||
"""
|
||||
if debug.version_check:
|
||||
debug.version_check = False
|
||||
version.utils.check_version()
|
||||
@ -70,7 +90,30 @@ class ChatCompletion:
|
||||
ignore_stream_and_auth: bool = False,
|
||||
patch_provider: callable = None,
|
||||
**kwargs) -> Union[CreateResult, str]:
|
||||
"""
|
||||
Creates a chat completion using the specified model, provider, and messages.
|
||||
|
||||
Args:
|
||||
model (Union[Model, str]): The model to use, either as an object or a string identifier.
|
||||
messages (Messages): The messages for which the completion is to be created.
|
||||
provider (Union[ProviderType, str, None], optional): The provider to use, either as an object, a string identifier, or None.
|
||||
stream (bool, optional): Indicates if the operation should be performed as a stream.
|
||||
auth (Union[str, None], optional): Authentication token or credentials, if required.
|
||||
ignored (list[str], optional): List of provider names to be ignored.
|
||||
ignore_working (bool, optional): If True, ignores the working status of the provider.
|
||||
ignore_stream_and_auth (bool, optional): If True, ignores the stream and authentication requirement checks.
|
||||
patch_provider (callable, optional): Function to modify the provider.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
Union[CreateResult, str]: The result of the chat completion operation.
|
||||
|
||||
Raises:
|
||||
AuthenticationRequiredError: If authentication is required but not provided.
|
||||
ProviderNotFoundError, ModelNotFoundError: If the specified provider or model is not found.
|
||||
ProviderNotWorkingError: If the provider is not operational.
|
||||
StreamNotSupportedError: If streaming is requested but not supported by the provider.
|
||||
"""
|
||||
model, provider = get_model_and_provider(model, provider, stream, ignored, ignore_working, ignore_stream_and_auth)
|
||||
|
||||
if not ignore_stream_and_auth and provider.needs_auth and not auth:
|
||||
@ -98,7 +141,24 @@ class ChatCompletion:
|
||||
ignored : list[str] = None,
|
||||
patch_provider: callable = None,
|
||||
**kwargs) -> Union[AsyncResult, str]:
|
||||
"""
|
||||
Asynchronously creates a completion using the specified model and provider.
|
||||
|
||||
Args:
|
||||
model (Union[Model, str]): The model to use, either as an object or a string identifier.
|
||||
messages (Messages): Messages to be processed.
|
||||
provider (Union[ProviderType, str, None]): The provider to use, either as an object, a string identifier, or None.
|
||||
stream (bool): Indicates if the operation should be performed as a stream.
|
||||
ignored (list[str], optional): List of provider names to be ignored.
|
||||
patch_provider (callable, optional): Function to modify the provider.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
Union[AsyncResult, str]: The result of the asynchronous chat completion operation.
|
||||
|
||||
Raises:
|
||||
StreamNotSupportedError: If streaming is requested but not supported by the provider.
|
||||
"""
|
||||
model, provider = get_model_and_provider(model, provider, False, ignored)
|
||||
|
||||
if stream:
|
||||
@ -118,7 +178,23 @@ class Completion:
|
||||
provider : Union[ProviderType, None] = None,
|
||||
stream : bool = False,
|
||||
ignored : list[str] = None, **kwargs) -> Union[CreateResult, str]:
|
||||
"""
|
||||
Creates a completion based on the provided model, prompt, and provider.
|
||||
|
||||
Args:
|
||||
model (Union[Model, str]): The model to use, either as an object or a string identifier.
|
||||
prompt (str): The prompt text for which the completion is to be created.
|
||||
provider (Union[ProviderType, None], optional): The provider to use, either as an object or None.
|
||||
stream (bool, optional): Indicates if the operation should be performed as a stream.
|
||||
ignored (list[str], optional): List of provider names to be ignored.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
Union[CreateResult, str]: The result of the completion operation.
|
||||
|
||||
Raises:
|
||||
ModelNotAllowedError: If the specified model is not allowed for use with this method.
|
||||
"""
|
||||
allowed_models = [
|
||||
'code-davinci-002',
|
||||
'text-ada-001',
|
||||
@ -137,6 +213,15 @@ class Completion:
|
||||
return result if stream else ''.join(result)
|
||||
|
||||
def get_last_provider(as_dict: bool = False) -> Union[ProviderType, dict[str, str]]:
|
||||
"""
|
||||
Retrieves the last used provider.
|
||||
|
||||
Args:
|
||||
as_dict (bool, optional): If True, returns the provider information as a dictionary.
|
||||
|
||||
Returns:
|
||||
Union[ProviderType, dict[str, str]]: The last used provider, either as an object or a dictionary.
|
||||
"""
|
||||
last = debug.last_provider
|
||||
if isinstance(last, BaseRetryProvider):
|
||||
last = last.last_provider
|
||||
|
@ -1,7 +1,22 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from .typing import Messages, CreateResult, Union
|
||||
|
||||
from typing import Union, List, Dict, Type
|
||||
from .typing import Messages, CreateResult
|
||||
|
||||
class BaseProvider(ABC):
|
||||
"""
|
||||
Abstract base class for a provider.
|
||||
|
||||
Attributes:
|
||||
url (str): URL of the provider.
|
||||
working (bool): Indicates if the provider is currently working.
|
||||
needs_auth (bool): Indicates if the provider needs authentication.
|
||||
supports_stream (bool): Indicates if the provider supports streaming.
|
||||
supports_gpt_35_turbo (bool): Indicates if the provider supports GPT-3.5 Turbo.
|
||||
supports_gpt_4 (bool): Indicates if the provider supports GPT-4.
|
||||
supports_message_history (bool): Indicates if the provider supports message history.
|
||||
params (str): List parameters for the provider.
|
||||
"""
|
||||
|
||||
url: str = None
|
||||
working: bool = False
|
||||
needs_auth: bool = False
|
||||
@ -20,6 +35,18 @@ class BaseProvider(ABC):
|
||||
stream: bool,
|
||||
**kwargs
|
||||
) -> CreateResult:
|
||||
"""
|
||||
Create a completion with the given parameters.
|
||||
|
||||
Args:
|
||||
model (str): The model to use.
|
||||
messages (Messages): The messages to process.
|
||||
stream (bool): Whether to use streaming.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
CreateResult: The result of the creation process.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
@ -30,25 +57,59 @@ class BaseProvider(ABC):
|
||||
messages: Messages,
|
||||
**kwargs
|
||||
) -> str:
|
||||
"""
|
||||
Asynchronously create a completion with the given parameters.
|
||||
|
||||
Args:
|
||||
model (str): The model to use.
|
||||
messages (Messages): The messages to process.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
str: The result of the creation process.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def get_dict(cls):
|
||||
def get_dict(cls) -> Dict[str, str]:
|
||||
"""
|
||||
Get a dictionary representation of the provider.
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: A dictionary with provider's details.
|
||||
"""
|
||||
return {'name': cls.__name__, 'url': cls.url}
|
||||
|
||||
class BaseRetryProvider(BaseProvider):
|
||||
"""
|
||||
Base class for a provider that implements retry logic.
|
||||
|
||||
Attributes:
|
||||
providers (List[Type[BaseProvider]]): List of providers to use for retries.
|
||||
shuffle (bool): Whether to shuffle the providers list.
|
||||
exceptions (Dict[str, Exception]): Dictionary of exceptions encountered.
|
||||
last_provider (Type[BaseProvider]): The last provider used.
|
||||
"""
|
||||
|
||||
__name__: str = "RetryProvider"
|
||||
supports_stream: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
providers: list[type[BaseProvider]],
|
||||
providers: List[Type[BaseProvider]],
|
||||
shuffle: bool = True
|
||||
) -> None:
|
||||
self.providers: list[type[BaseProvider]] = providers
|
||||
self.shuffle: bool = shuffle
|
||||
self.working: bool = True
|
||||
self.exceptions: dict[str, Exception] = {}
|
||||
self.last_provider: type[BaseProvider] = None
|
||||
"""
|
||||
Initialize the BaseRetryProvider.
|
||||
|
||||
Args:
|
||||
providers (List[Type[BaseProvider]]): List of providers to use.
|
||||
shuffle (bool): Whether to shuffle the providers list.
|
||||
"""
|
||||
self.providers = providers
|
||||
self.shuffle = shuffle
|
||||
self.working = True
|
||||
self.exceptions: Dict[str, Exception] = {}
|
||||
self.last_provider: Type[BaseProvider] = None
|
||||
|
||||
ProviderType = Union[type[BaseProvider], BaseRetryProvider]
|
||||
ProviderType = Union[Type[BaseProvider], BaseRetryProvider]
|
@ -404,7 +404,7 @@ body {
|
||||
display: none;
|
||||
}
|
||||
|
||||
#image {
|
||||
#image, #file {
|
||||
display: none;
|
||||
}
|
||||
|
||||
@ -412,13 +412,22 @@ label[for="image"]:has(> input:valid){
|
||||
color: var(--accent);
|
||||
}
|
||||
|
||||
label[for="image"] {
|
||||
label[for="file"]:has(> input:valid){
|
||||
color: var(--accent);
|
||||
}
|
||||
|
||||
label[for="image"], label[for="file"] {
|
||||
cursor: pointer;
|
||||
position: absolute;
|
||||
top: 10px;
|
||||
left: 10px;
|
||||
}
|
||||
|
||||
label[for="file"] {
|
||||
top: 32px;
|
||||
left: 10px;
|
||||
}
|
||||
|
||||
.buttons input[type="checkbox"] {
|
||||
height: 0;
|
||||
width: 0;
|
||||
|
@ -118,6 +118,10 @@
|
||||
<input type="file" id="image" name="image" accept="image/png, image/gif, image/jpeg" required/>
|
||||
<i class="fa-regular fa-image"></i>
|
||||
</label>
|
||||
<label for="file">
|
||||
<input type="file" id="file" name="file" accept="text/plain, text/html, text/xml, application/json, text/javascript, .sh, .py, .php, .css, .yaml, .sql, .svg, .log, .csv, .twig, .md" required/>
|
||||
<i class="fa-solid fa-paperclip"></i>
|
||||
</label>
|
||||
<div id="send-button">
|
||||
<i class="fa-solid fa-paper-plane-top"></i>
|
||||
</div>
|
||||
@ -125,7 +129,14 @@
|
||||
</div>
|
||||
<div class="buttons">
|
||||
<div class="field">
|
||||
<select name="model" id="model"></select>
|
||||
<select name="model" id="model">
|
||||
<option value="">Model: Default</option>
|
||||
<option value="gpt-4">gpt-4</option>
|
||||
<option value="gpt-3.5-turbo">gpt-3.5-turbo</option>
|
||||
<option value="llama2-70b">llama2-70b</option>
|
||||
<option value="gemini-pro">gemini-pro</option>
|
||||
<option value="">----</option>
|
||||
</select>
|
||||
</div>
|
||||
<div class="field">
|
||||
<select name="jailbreak" id="jailbreak" style="display: none;">
|
||||
@ -138,7 +149,16 @@
|
||||
<option value="gpt-evil-1.0">evil 1.0</option>
|
||||
</select>
|
||||
<div class="field">
|
||||
<select name="provider" id="provider"></select>
|
||||
<select name="provider" id="provider">
|
||||
<option value="">Provider: Auto</option>
|
||||
<option value="Bing">Bing</option>
|
||||
<option value="OpenaiChat">OpenaiChat</option>
|
||||
<option value="HuggingChat">HuggingChat</option>
|
||||
<option value="Bard">Bard</option>
|
||||
<option value="Liaobots">Liaobots</option>
|
||||
<option value="Phind">Phind</option>
|
||||
<option value="">----</option>
|
||||
</select>
|
||||
</div>
|
||||
</div>
|
||||
<div class="field">
|
||||
|
@ -7,7 +7,9 @@ const spinner = box_conversations.querySelector(".spinner");
|
||||
const stop_generating = document.querySelector(`.stop_generating`);
|
||||
const regenerate = document.querySelector(`.regenerate`);
|
||||
const send_button = document.querySelector(`#send-button`);
|
||||
const imageInput = document.querySelector('#image') ;
|
||||
const imageInput = document.querySelector('#image');
|
||||
const fileInput = document.querySelector('#file');
|
||||
|
||||
let prompt_lock = false;
|
||||
|
||||
hljs.addPlugin(new CopyButtonPlugin());
|
||||
@ -42,6 +44,11 @@ const handle_ask = async () => {
|
||||
if (message.length > 0) {
|
||||
message_input.value = '';
|
||||
await add_conversation(window.conversation_id, message);
|
||||
if ("text" in fileInput.dataset) {
|
||||
message += '\n```' + fileInput.dataset.type + '\n';
|
||||
message += fileInput.dataset.text;
|
||||
message += '\n```'
|
||||
}
|
||||
await add_message(window.conversation_id, "user", message);
|
||||
window.token = message_id();
|
||||
message_box.innerHTML += `
|
||||
@ -55,6 +62,9 @@ const handle_ask = async () => {
|
||||
</div>
|
||||
</div>
|
||||
`;
|
||||
document.querySelectorAll('code:not(.hljs').forEach((el) => {
|
||||
hljs.highlightElement(el);
|
||||
});
|
||||
await ask_gpt();
|
||||
}
|
||||
};
|
||||
@ -171,17 +181,30 @@ const ask_gpt = async () => {
|
||||
content_inner.innerHTML += "<p>An error occured, please try again, if the problem persists, please use a other model or provider.</p>";
|
||||
} else {
|
||||
html = markdown_render(text);
|
||||
html = html.substring(0, html.lastIndexOf('</p>')) + '<span id="cursor"></span></p>';
|
||||
let lastElement, lastIndex = null;
|
||||
for (element of ['</p>', '</code></pre>', '</li>\n</ol>']) {
|
||||
const index = html.lastIndexOf(element)
|
||||
if (index > lastIndex) {
|
||||
lastElement = element;
|
||||
lastIndex = index;
|
||||
}
|
||||
}
|
||||
if (lastIndex) {
|
||||
html = html.substring(0, lastIndex) + '<span id="cursor"></span>' + lastElement;
|
||||
}
|
||||
content_inner.innerHTML = html;
|
||||
document.querySelectorAll('code').forEach((el) => {
|
||||
document.querySelectorAll('code:not(.hljs').forEach((el) => {
|
||||
hljs.highlightElement(el);
|
||||
});
|
||||
}
|
||||
|
||||
window.scrollTo(0, 0);
|
||||
message_box.scrollTo({ top: message_box.scrollHeight, behavior: "auto" });
|
||||
if (message_box.scrollTop >= message_box.scrollHeight - message_box.clientHeight - 100) {
|
||||
message_box.scrollTo({ top: message_box.scrollHeight, behavior: "auto" });
|
||||
}
|
||||
}
|
||||
if (!error && imageInput) imageInput.value = "";
|
||||
if (!error && fileInput) fileInput.value = "";
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
|
||||
@ -305,7 +328,7 @@ const load_conversation = async (conversation_id) => {
|
||||
`;
|
||||
}
|
||||
|
||||
document.querySelectorAll(`code`).forEach((el) => {
|
||||
document.querySelectorAll('code:not(.hljs').forEach((el) => {
|
||||
hljs.highlightElement(el);
|
||||
});
|
||||
|
||||
@ -400,7 +423,7 @@ const load_conversations = async (limit, offset, loader) => {
|
||||
`;
|
||||
}
|
||||
|
||||
document.querySelectorAll(`code`).forEach((el) => {
|
||||
document.querySelectorAll('code:not(.hljs').forEach((el) => {
|
||||
hljs.highlightElement(el);
|
||||
});
|
||||
};
|
||||
@ -602,14 +625,7 @@ observer.observe(message_input, { attributes: true });
|
||||
(async () => {
|
||||
response = await fetch('/backend-api/v2/models')
|
||||
models = await response.json()
|
||||
|
||||
let select = document.getElementById('model');
|
||||
select.textContent = '';
|
||||
|
||||
let auto = document.createElement('option');
|
||||
auto.value = '';
|
||||
auto.text = 'Model: Default';
|
||||
select.appendChild(auto);
|
||||
|
||||
for (model of models) {
|
||||
let option = document.createElement('option');
|
||||
@ -619,14 +635,7 @@ observer.observe(message_input, { attributes: true });
|
||||
|
||||
response = await fetch('/backend-api/v2/providers')
|
||||
providers = await response.json()
|
||||
|
||||
select = document.getElementById('provider');
|
||||
select.textContent = '';
|
||||
|
||||
auto = document.createElement('option');
|
||||
auto.value = '';
|
||||
auto.text = 'Provider: Auto';
|
||||
select.appendChild(auto);
|
||||
|
||||
for (provider of providers) {
|
||||
let option = document.createElement('option');
|
||||
@ -650,4 +659,27 @@ observer.observe(message_input, { attributes: true });
|
||||
text += versions["version"];
|
||||
}
|
||||
document.getElementById("version_text").innerHTML = text
|
||||
})()
|
||||
})()
|
||||
|
||||
fileInput.addEventListener('change', async (event) => {
|
||||
if (fileInput.files.length) {
|
||||
type = fileInput.files[0].type;
|
||||
if (type && type.indexOf('/')) {
|
||||
type = type.split('/').pop().replace('x-', '')
|
||||
type = type.replace('plain', 'plaintext')
|
||||
.replace('shellscript', 'sh')
|
||||
.replace('svg+xml', 'svg')
|
||||
.replace('vnd.trolltech.linguist', 'ts')
|
||||
} else {
|
||||
type = fileInput.files[0].name.split('.').pop()
|
||||
}
|
||||
fileInput.dataset.type = type
|
||||
const reader = new FileReader();
|
||||
reader.addEventListener('load', (event) => {
|
||||
fileInput.dataset.text = event.target.result;
|
||||
});
|
||||
reader.readAsText(fileInput.files[0]);
|
||||
} else {
|
||||
delete fileInput.dataset.text;
|
||||
}
|
||||
});
|
105
g4f/image.py
105
g4f/image.py
@ -4,9 +4,18 @@ import base64
|
||||
from .typing import ImageType, Union
|
||||
from PIL import Image
|
||||
|
||||
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'}
|
||||
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'webp'}
|
||||
|
||||
def to_image(image: ImageType) -> Image.Image:
|
||||
"""
|
||||
Converts the input image to a PIL Image object.
|
||||
|
||||
Args:
|
||||
image (Union[str, bytes, Image.Image]): The input image.
|
||||
|
||||
Returns:
|
||||
Image.Image: The converted PIL Image object.
|
||||
"""
|
||||
if isinstance(image, str):
|
||||
is_data_uri_an_image(image)
|
||||
image = extract_data_uri(image)
|
||||
@ -20,21 +29,48 @@ def to_image(image: ImageType) -> Image.Image:
|
||||
image = copy
|
||||
return image
|
||||
|
||||
def is_allowed_extension(filename) -> bool:
|
||||
def is_allowed_extension(filename: str) -> bool:
|
||||
"""
|
||||
Checks if the given filename has an allowed extension.
|
||||
|
||||
Args:
|
||||
filename (str): The filename to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the extension is allowed, False otherwise.
|
||||
"""
|
||||
return '.' in filename and \
|
||||
filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
|
||||
|
||||
def is_data_uri_an_image(data_uri: str) -> bool:
|
||||
"""
|
||||
Checks if the given data URI represents an image.
|
||||
|
||||
Args:
|
||||
data_uri (str): The data URI to check.
|
||||
|
||||
Raises:
|
||||
ValueError: If the data URI is invalid or the image format is not allowed.
|
||||
"""
|
||||
# Check if the data URI starts with 'data:image' and contains an image format (e.g., jpeg, png, gif)
|
||||
if not re.match(r'data:image/(\w+);base64,', data_uri):
|
||||
raise ValueError("Invalid data URI image.")
|
||||
# Extract the image format from the data URI
|
||||
# Extract the image format from the data URI
|
||||
image_format = re.match(r'data:image/(\w+);base64,', data_uri).group(1)
|
||||
# Check if the image format is one of the allowed formats (jpg, jpeg, png, gif)
|
||||
if image_format.lower() not in ALLOWED_EXTENSIONS:
|
||||
raise ValueError("Invalid image format (from mime file type).")
|
||||
|
||||
def is_accepted_format(binary_data: bytes) -> bool:
|
||||
"""
|
||||
Checks if the given binary data represents an image with an accepted format.
|
||||
|
||||
Args:
|
||||
binary_data (bytes): The binary data to check.
|
||||
|
||||
Raises:
|
||||
ValueError: If the image format is not allowed.
|
||||
"""
|
||||
if binary_data.startswith(b'\xFF\xD8\xFF'):
|
||||
pass # It's a JPEG image
|
||||
elif binary_data.startswith(b'\x89PNG\r\n\x1a\n'):
|
||||
@ -49,13 +85,31 @@ def is_accepted_format(binary_data: bytes) -> bool:
|
||||
pass # It's a WebP image
|
||||
else:
|
||||
raise ValueError("Invalid image format (from magic code).")
|
||||
|
||||
|
||||
def extract_data_uri(data_uri: str) -> bytes:
|
||||
"""
|
||||
Extracts the binary data from the given data URI.
|
||||
|
||||
Args:
|
||||
data_uri (str): The data URI.
|
||||
|
||||
Returns:
|
||||
bytes: The extracted binary data.
|
||||
"""
|
||||
data = data_uri.split(",")[1]
|
||||
data = base64.b64decode(data)
|
||||
return data
|
||||
|
||||
def get_orientation(image: Image.Image) -> int:
|
||||
"""
|
||||
Gets the orientation of the given image.
|
||||
|
||||
Args:
|
||||
image (Image.Image): The image.
|
||||
|
||||
Returns:
|
||||
int: The orientation value.
|
||||
"""
|
||||
exif_data = image.getexif() if hasattr(image, 'getexif') else image._getexif()
|
||||
if exif_data is not None:
|
||||
orientation = exif_data.get(274) # 274 corresponds to the orientation tag in EXIF
|
||||
@ -63,6 +117,17 @@ def get_orientation(image: Image.Image) -> int:
|
||||
return orientation
|
||||
|
||||
def process_image(img: Image.Image, new_width: int, new_height: int) -> Image.Image:
|
||||
"""
|
||||
Processes the given image by adjusting its orientation and resizing it.
|
||||
|
||||
Args:
|
||||
img (Image.Image): The image to process.
|
||||
new_width (int): The new width of the image.
|
||||
new_height (int): The new height of the image.
|
||||
|
||||
Returns:
|
||||
Image.Image: The processed image.
|
||||
"""
|
||||
orientation = get_orientation(img)
|
||||
if orientation:
|
||||
if orientation > 4:
|
||||
@ -75,13 +140,34 @@ def process_image(img: Image.Image, new_width: int, new_height: int) -> Image.Im
|
||||
img = img.transpose(Image.ROTATE_90)
|
||||
img.thumbnail((new_width, new_height))
|
||||
return img
|
||||
|
||||
|
||||
def to_base64(image: Image.Image, compression_rate: float) -> str:
|
||||
"""
|
||||
Converts the given image to a base64-encoded string.
|
||||
|
||||
Args:
|
||||
image (Image.Image): The image to convert.
|
||||
compression_rate (float): The compression rate (0.0 to 1.0).
|
||||
|
||||
Returns:
|
||||
str: The base64-encoded image.
|
||||
"""
|
||||
output_buffer = BytesIO()
|
||||
image.save(output_buffer, format="JPEG", quality=int(compression_rate * 100))
|
||||
return base64.b64encode(output_buffer.getvalue()).decode()
|
||||
|
||||
def format_images_markdown(images, prompt: str, preview: str="{image}?w=200&h=200") -> str:
|
||||
"""
|
||||
Formats the given images as a markdown string.
|
||||
|
||||
Args:
|
||||
images: The images to format.
|
||||
prompt (str): The prompt for the images.
|
||||
preview (str, optional): The preview URL format. Defaults to "{image}?w=200&h=200".
|
||||
|
||||
Returns:
|
||||
str: The formatted markdown string.
|
||||
"""
|
||||
if isinstance(images, list):
|
||||
images = [f"[![#{idx+1} {prompt}]({preview.replace('{image}', image)})]({image})" for idx, image in enumerate(images)]
|
||||
images = "\n".join(images)
|
||||
@ -92,6 +178,15 @@ def format_images_markdown(images, prompt: str, preview: str="{image}?w=200&h=20
|
||||
return f"\n{start_flag}{images}\n{end_flag}\n"
|
||||
|
||||
def to_bytes(image: Image.Image) -> bytes:
|
||||
"""
|
||||
Converts the given image to bytes.
|
||||
|
||||
Args:
|
||||
image (Image.Image): The image to convert.
|
||||
|
||||
Returns:
|
||||
bytes: The image as bytes.
|
||||
"""
|
||||
bytes_io = BytesIO()
|
||||
image.save(bytes_io, image.format)
|
||||
image.seek(0)
|
||||
|
@ -31,12 +31,21 @@ from .Provider import (
|
||||
|
||||
@dataclass(unsafe_hash=True)
|
||||
class Model:
|
||||
"""
|
||||
Represents a machine learning model configuration.
|
||||
|
||||
Attributes:
|
||||
name (str): Name of the model.
|
||||
base_provider (str): Default provider for the model.
|
||||
best_provider (ProviderType): The preferred provider for the model, typically with retry logic.
|
||||
"""
|
||||
name: str
|
||||
base_provider: str
|
||||
best_provider: ProviderType = None
|
||||
|
||||
@staticmethod
|
||||
def __all__() -> list[str]:
|
||||
"""Returns a list of all model names."""
|
||||
return _all_models
|
||||
|
||||
default = Model(
|
||||
@ -298,6 +307,12 @@ pi = Model(
|
||||
)
|
||||
|
||||
class ModelUtils:
|
||||
"""
|
||||
Utility class for mapping string identifiers to Model instances.
|
||||
|
||||
Attributes:
|
||||
convert (dict[str, Model]): Dictionary mapping model string identifiers to Model instances.
|
||||
"""
|
||||
convert: dict[str, Model] = {
|
||||
# gpt-3.5
|
||||
'gpt-3.5-turbo' : gpt_35_turbo,
|
||||
|
@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from contextlib import asynccontextmanager
|
||||
from functools import partialmethod
|
||||
from typing import AsyncGenerator
|
||||
from urllib.parse import urlparse
|
||||
@ -9,27 +8,41 @@ from curl_cffi.requests import AsyncSession, Session, Response
|
||||
from .webdriver import WebDriver, WebDriverSession, bypass_cloudflare, get_driver_cookies
|
||||
|
||||
class StreamResponse:
|
||||
"""
|
||||
A wrapper class for handling asynchronous streaming responses.
|
||||
|
||||
Attributes:
|
||||
inner (Response): The original Response object.
|
||||
"""
|
||||
|
||||
def __init__(self, inner: Response) -> None:
|
||||
"""Initialize the StreamResponse with the provided Response object."""
|
||||
self.inner: Response = inner
|
||||
|
||||
async def text(self) -> str:
|
||||
"""Asynchronously get the response text."""
|
||||
return await self.inner.atext()
|
||||
|
||||
def raise_for_status(self) -> None:
|
||||
"""Raise an HTTPError if one occurred."""
|
||||
self.inner.raise_for_status()
|
||||
|
||||
async def json(self, **kwargs) -> dict:
|
||||
"""Asynchronously parse the JSON response content."""
|
||||
return json.loads(await self.inner.acontent(), **kwargs)
|
||||
|
||||
async def iter_lines(self) -> AsyncGenerator[bytes, None]:
|
||||
"""Asynchronously iterate over the lines of the response."""
|
||||
async for line in self.inner.aiter_lines():
|
||||
yield line
|
||||
|
||||
async def iter_content(self) -> AsyncGenerator[bytes, None]:
|
||||
"""Asynchronously iterate over the response content."""
|
||||
async for chunk in self.inner.aiter_content():
|
||||
yield chunk
|
||||
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Asynchronously enter the runtime context for the response object."""
|
||||
inner: Response = await self.inner
|
||||
self.inner = inner
|
||||
self.request = inner.request
|
||||
@ -39,24 +52,47 @@ class StreamResponse:
|
||||
self.headers = inner.headers
|
||||
self.cookies = inner.cookies
|
||||
return self
|
||||
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
"""Asynchronously exit the runtime context for the response object."""
|
||||
await self.inner.aclose()
|
||||
|
||||
|
||||
class StreamSession(AsyncSession):
|
||||
"""
|
||||
An asynchronous session class for handling HTTP requests with streaming.
|
||||
|
||||
Inherits from AsyncSession.
|
||||
"""
|
||||
|
||||
def request(
|
||||
self, method: str, url: str, **kwargs
|
||||
) -> StreamResponse:
|
||||
"""Create and return a StreamResponse object for the given HTTP request."""
|
||||
return StreamResponse(super().request(method, url, stream=True, **kwargs))
|
||||
|
||||
# Defining HTTP methods as partial methods of the request method.
|
||||
head = partialmethod(request, "HEAD")
|
||||
get = partialmethod(request, "GET")
|
||||
post = partialmethod(request, "POST")
|
||||
put = partialmethod(request, "PUT")
|
||||
patch = partialmethod(request, "PATCH")
|
||||
delete = partialmethod(request, "DELETE")
|
||||
|
||||
def get_session_from_browser(url: str, webdriver: WebDriver = None, proxy: str = None, timeout: int = 120):
|
||||
|
||||
|
||||
def get_session_from_browser(url: str, webdriver: WebDriver = None, proxy: str = None, timeout: int = 120) -> Session:
|
||||
"""
|
||||
Create a Session object using a WebDriver to handle cookies and headers.
|
||||
|
||||
Args:
|
||||
url (str): The URL to navigate to using the WebDriver.
|
||||
webdriver (WebDriver, optional): The WebDriver instance to use.
|
||||
proxy (str, optional): Proxy server to use for the Session.
|
||||
timeout (int, optional): Timeout in seconds for the WebDriver.
|
||||
|
||||
Returns:
|
||||
Session: A Session object configured with cookies and headers from the WebDriver.
|
||||
"""
|
||||
with WebDriverSession(webdriver, "", proxy=proxy, virtual_display=True) as driver:
|
||||
bypass_cloudflare(driver, url, timeout)
|
||||
cookies = get_driver_cookies(driver)
|
||||
@ -78,4 +114,4 @@ def get_session_from_browser(url: str, webdriver: WebDriver = None, proxy: str =
|
||||
proxies={"https": proxy, "http": proxy},
|
||||
timeout=timeout,
|
||||
impersonate="chrome110"
|
||||
)
|
||||
)
|
@ -5,45 +5,94 @@ from importlib.metadata import version as get_package_version, PackageNotFoundEr
|
||||
from subprocess import check_output, CalledProcessError, PIPE
|
||||
from .errors import VersionNotFoundError
|
||||
|
||||
def get_latest_version() -> str:
|
||||
try:
|
||||
get_package_version("g4f")
|
||||
response = requests.get("https://pypi.org/pypi/g4f/json").json()
|
||||
return response["info"]["version"]
|
||||
except PackageNotFoundError:
|
||||
url = "https://api.github.com/repos/xtekky/gpt4free/releases/latest"
|
||||
response = requests.get(url).json()
|
||||
return response["tag_name"]
|
||||
def get_pypi_version(package_name: str) -> str:
|
||||
"""
|
||||
Get the latest version of a package from PyPI.
|
||||
|
||||
class VersionUtils():
|
||||
:param package_name: The name of the package.
|
||||
:return: The latest version of the package as a string.
|
||||
"""
|
||||
try:
|
||||
response = requests.get(f"https://pypi.org/pypi/{package_name}/json").json()
|
||||
return response["info"]["version"]
|
||||
except requests.RequestException as e:
|
||||
raise VersionNotFoundError(f"Failed to get PyPI version: {e}")
|
||||
|
||||
def get_github_version(repo: str) -> str:
|
||||
"""
|
||||
Get the latest release version from a GitHub repository.
|
||||
|
||||
:param repo: The name of the GitHub repository.
|
||||
:return: The latest release version as a string.
|
||||
"""
|
||||
try:
|
||||
response = requests.get(f"https://api.github.com/repos/{repo}/releases/latest").json()
|
||||
return response["tag_name"]
|
||||
except requests.RequestException as e:
|
||||
raise VersionNotFoundError(f"Failed to get GitHub release version: {e}")
|
||||
|
||||
def get_latest_version():
|
||||
"""
|
||||
Get the latest release version from PyPI or the GitHub repository.
|
||||
|
||||
:return: The latest release version as a string.
|
||||
"""
|
||||
try:
|
||||
# Is installed via package manager?
|
||||
get_package_version("g4f")
|
||||
return get_pypi_version("g4f")
|
||||
except PackageNotFoundError:
|
||||
# Else use Github version:
|
||||
return get_github_version("xtekky/gpt4free")
|
||||
|
||||
class VersionUtils:
|
||||
"""
|
||||
Utility class for managing and comparing package versions.
|
||||
"""
|
||||
@cached_property
|
||||
def current_version(self) -> str:
|
||||
"""
|
||||
Get the current version of the g4f package.
|
||||
|
||||
:return: The current version as a string.
|
||||
"""
|
||||
# Read from package manager
|
||||
try:
|
||||
return get_package_version("g4f")
|
||||
except PackageNotFoundError:
|
||||
pass
|
||||
|
||||
# Read from docker environment
|
||||
version = environ.get("G4F_VERSION")
|
||||
if version:
|
||||
return version
|
||||
|
||||
# Read from git repository
|
||||
try:
|
||||
command = ["git", "describe", "--tags", "--abbrev=0"]
|
||||
return check_output(command, text=True, stderr=PIPE).strip()
|
||||
except CalledProcessError:
|
||||
pass
|
||||
|
||||
raise VersionNotFoundError("Version not found")
|
||||
|
||||
|
||||
@cached_property
|
||||
def latest_version(self) -> str:
|
||||
"""
|
||||
Get the latest version of the g4f package.
|
||||
|
||||
:return: The latest version as a string.
|
||||
"""
|
||||
return get_latest_version()
|
||||
|
||||
|
||||
def check_version(self) -> None:
|
||||
"""
|
||||
Check if the current version is up to date with the latest version.
|
||||
"""
|
||||
try:
|
||||
if self.current_version != self.latest_version:
|
||||
print(f'New g4f version: {self.latest_version} (current: {self.current_version}) | pip install -U g4f')
|
||||
except Exception as e:
|
||||
print(f'Failed to check g4f version: {e}')
|
||||
|
||||
|
||||
utils = VersionUtils()
|
@ -1,5 +1,4 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from platformdirs import user_config_dir
|
||||
from selenium.webdriver.remote.webdriver import WebDriver
|
||||
from undetected_chromedriver import Chrome, ChromeOptions
|
||||
@ -21,7 +20,16 @@ def get_browser(
|
||||
proxy: str = None,
|
||||
options: ChromeOptions = None
|
||||
) -> WebDriver:
|
||||
if user_data_dir == None:
|
||||
"""
|
||||
Creates and returns a Chrome WebDriver with the specified options.
|
||||
|
||||
:param user_data_dir: Directory for user data. If None, uses default directory.
|
||||
:param headless: Boolean indicating whether to run the browser in headless mode.
|
||||
:param proxy: Proxy settings for the browser.
|
||||
:param options: ChromeOptions object with specific browser options.
|
||||
:return: An instance of WebDriver.
|
||||
"""
|
||||
if user_data_dir is None:
|
||||
user_data_dir = user_config_dir("g4f")
|
||||
if user_data_dir and debug.logging:
|
||||
print("Open browser with config dir:", user_data_dir)
|
||||
@ -39,36 +47,45 @@ def get_browser(
|
||||
headless=headless
|
||||
)
|
||||
|
||||
def get_driver_cookies(driver: WebDriver):
|
||||
return dict([(cookie["name"], cookie["value"]) for cookie in driver.get_cookies()])
|
||||
def get_driver_cookies(driver: WebDriver) -> dict:
|
||||
"""
|
||||
Retrieves cookies from the given WebDriver.
|
||||
|
||||
:param driver: WebDriver from which to retrieve cookies.
|
||||
:return: A dictionary of cookies.
|
||||
"""
|
||||
return {cookie["name"]: cookie["value"] for cookie in driver.get_cookies()}
|
||||
|
||||
def bypass_cloudflare(driver: WebDriver, url: str, timeout: int) -> None:
|
||||
# Open website
|
||||
"""
|
||||
Attempts to bypass Cloudflare protection when accessing a URL using the provided WebDriver.
|
||||
|
||||
:param driver: The WebDriver to use.
|
||||
:param url: URL to access.
|
||||
:param timeout: Time in seconds to wait for the page to load.
|
||||
"""
|
||||
driver.get(url)
|
||||
# Is cloudflare protection
|
||||
if driver.find_element(By.TAG_NAME, "body").get_attribute("class") == "no-js":
|
||||
if debug.logging:
|
||||
print("Cloudflare protection detected:", url)
|
||||
try:
|
||||
# Click button in iframe
|
||||
WebDriverWait(driver, 5).until(
|
||||
EC.presence_of_element_located((By.CSS_SELECTOR, "#turnstile-wrapper iframe"))
|
||||
)
|
||||
driver.switch_to.frame(driver.find_element(By.CSS_SELECTOR, "#turnstile-wrapper iframe"))
|
||||
WebDriverWait(driver, 5).until(
|
||||
EC.presence_of_element_located((By.CSS_SELECTOR, "#challenge-stage input"))
|
||||
)
|
||||
driver.find_element(By.CSS_SELECTOR, "#challenge-stage input").click()
|
||||
except:
|
||||
pass
|
||||
).click()
|
||||
except Exception as e:
|
||||
if debug.logging:
|
||||
print(f"Error bypassing Cloudflare: {e}")
|
||||
finally:
|
||||
driver.switch_to.default_content()
|
||||
# No cloudflare protection
|
||||
WebDriverWait(driver, timeout).until(
|
||||
EC.presence_of_element_located((By.CSS_SELECTOR, "body:not(.no-js)"))
|
||||
)
|
||||
|
||||
class WebDriverSession():
|
||||
class WebDriverSession:
|
||||
"""
|
||||
Manages a Selenium WebDriver session, including handling of virtual displays and proxies.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
webdriver: WebDriver = None,
|
||||
@ -81,9 +98,7 @@ class WebDriverSession():
|
||||
self.webdriver = webdriver
|
||||
self.user_data_dir = user_data_dir
|
||||
self.headless = headless
|
||||
self.virtual_display = None
|
||||
if has_pyvirtualdisplay and virtual_display:
|
||||
self.virtual_display = Display(size=(1920, 1080))
|
||||
self.virtual_display = Display(size=(1920, 1080)) if has_pyvirtualdisplay and virtual_display else None
|
||||
self.proxy = proxy
|
||||
self.options = options
|
||||
self.default_driver = None
|
||||
@ -94,8 +109,15 @@ class WebDriverSession():
|
||||
headless: bool = False,
|
||||
virtual_display: bool = False
|
||||
) -> WebDriver:
|
||||
if user_data_dir == None:
|
||||
user_data_dir = self.user_data_dir
|
||||
"""
|
||||
Reopens the WebDriver session with the specified parameters.
|
||||
|
||||
:param user_data_dir: Directory for user data.
|
||||
:param headless: Boolean indicating whether to run the browser in headless mode.
|
||||
:param virtual_display: Boolean indicating whether to use a virtual display.
|
||||
:return: An instance of WebDriver.
|
||||
"""
|
||||
user_data_dir = user_data_dir or self.user_data_dir
|
||||
if self.default_driver:
|
||||
self.default_driver.quit()
|
||||
if not virtual_display and self.virtual_display:
|
||||
@ -105,6 +127,10 @@ class WebDriverSession():
|
||||
return self.default_driver
|
||||
|
||||
def __enter__(self) -> WebDriver:
|
||||
"""
|
||||
Context management method for entering a session.
|
||||
:return: An instance of WebDriver.
|
||||
"""
|
||||
if self.webdriver:
|
||||
return self.webdriver
|
||||
if self.virtual_display:
|
||||
@ -113,11 +139,15 @@ class WebDriverSession():
|
||||
return self.default_driver
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""
|
||||
Context management method for exiting a session. Closes and quits the WebDriver.
|
||||
"""
|
||||
if self.default_driver:
|
||||
try:
|
||||
self.default_driver.close()
|
||||
except:
|
||||
pass
|
||||
except Exception as e:
|
||||
if debug.logging:
|
||||
print(f"Error closing WebDriver: {e}")
|
||||
self.default_driver.quit()
|
||||
if self.virtual_display:
|
||||
self.virtual_display.stop()
|
Loading…
Reference in New Issue
Block a user