Merge pull request #1950 from hlohaus/leech

Update chatgpt url, uvloop support
This commit is contained in:
H Lohaus 2024-05-15 02:28:45 +02:00 committed by GitHub
commit 008ed60d98
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 222 additions and 86 deletions

View File

@ -17,12 +17,12 @@ except ImportError:
pass pass
from ... import debug from ... import debug
from ...typing import Messages, Cookies, ImageType, AsyncResult from ...typing import Messages, Cookies, ImageType, AsyncResult, AsyncIterator
from ..base_provider import AsyncGeneratorProvider from ..base_provider import AsyncGeneratorProvider
from ..helper import format_prompt, get_cookies from ..helper import format_prompt, get_cookies
from ...requests.raise_for_status import raise_for_status from ...requests.raise_for_status import raise_for_status
from ...errors import MissingAuthError, MissingRequirementsError from ...errors import MissingAuthError, MissingRequirementsError
from ...image import to_bytes, ImageResponse from ...image import to_bytes, to_data_uri, ImageResponse
from ...webdriver import get_browser, get_driver_cookies from ...webdriver import get_browser, get_driver_cookies
REQUEST_HEADERS = { REQUEST_HEADERS = {
@ -59,7 +59,7 @@ class Gemini(AsyncGeneratorProvider):
_cookies: Cookies = None _cookies: Cookies = None
@classmethod @classmethod
async def nodriver_login(cls) -> Cookies: async def nodriver_login(cls) -> AsyncIterator[str]:
try: try:
import nodriver as uc import nodriver as uc
except ImportError: except ImportError:
@ -72,6 +72,9 @@ class Gemini(AsyncGeneratorProvider):
if debug.logging: if debug.logging:
print(f"Open nodriver with user_dir: {user_data_dir}") print(f"Open nodriver with user_dir: {user_data_dir}")
browser = await uc.start(user_data_dir=user_data_dir) browser = await uc.start(user_data_dir=user_data_dir)
login_url = os.environ.get("G4F_LOGIN_URL")
if login_url:
yield f"Please login: [Google Gemini]({login_url})\n\n"
page = await browser.get(f"{cls.url}/app") page = await browser.get(f"{cls.url}/app")
await page.select("div.ql-editor.textarea", 240) await page.select("div.ql-editor.textarea", 240)
cookies = {} cookies = {}
@ -79,10 +82,10 @@ class Gemini(AsyncGeneratorProvider):
if c.domain.endswith(".google.com"): if c.domain.endswith(".google.com"):
cookies[c.name] = c.value cookies[c.name] = c.value
await page.close() await page.close()
return cookies cls._cookies = cookies
@classmethod @classmethod
async def webdriver_login(cls, proxy: str): async def webdriver_login(cls, proxy: str) -> AsyncIterator[str]:
driver = None driver = None
try: try:
driver = get_browser(proxy=proxy) driver = get_browser(proxy=proxy)
@ -131,13 +134,14 @@ class Gemini(AsyncGeneratorProvider):
) as session: ) as session:
snlm0e = await cls.fetch_snlm0e(session, cls._cookies) if cls._cookies else None snlm0e = await cls.fetch_snlm0e(session, cls._cookies) if cls._cookies else None
if not snlm0e: if not snlm0e:
cls._cookies = await cls.nodriver_login(); async for chunk in cls.nodriver_login():
yield chunk
if cls._cookies is None: if cls._cookies is None:
async for chunk in cls.webdriver_login(proxy): async for chunk in cls.webdriver_login(proxy):
yield chunk yield chunk
if not snlm0e: if not snlm0e:
if "__Secure-1PSID" not in cls._cookies: if cls._cookies is None or "__Secure-1PSID" not in cls._cookies:
raise MissingAuthError('Missing "__Secure-1PSID" cookie') raise MissingAuthError('Missing "__Secure-1PSID" cookie')
snlm0e = await cls.fetch_snlm0e(session, cls._cookies) snlm0e = await cls.fetch_snlm0e(session, cls._cookies)
if not snlm0e: if not snlm0e:
@ -193,6 +197,13 @@ class Gemini(AsyncGeneratorProvider):
image = fetch.headers["location"] image = fetch.headers["location"]
resolved_images.append(image) resolved_images.append(image)
preview.append(image.replace('=s512', '=s200')) preview.append(image.replace('=s512', '=s200'))
# preview_url = image.replace('=s512', '=s200')
# async with client.get(preview_url) as fetch:
# preview_data = to_data_uri(await fetch.content.read())
# async with client.get(image) as fetch:
# data = to_data_uri(await fetch.content.read())
# preview.append(preview_data)
# resolved_images.append(data)
yield ImageResponse(resolved_images, image_prompt, {"orginal_links": images, "preview": preview}) yield ImageResponse(resolved_images, image_prompt, {"orginal_links": images, "preview": preview})
def build_request( def build_request(

View File

@ -38,7 +38,7 @@ DEFAULT_HEADERS = {
"accept": "*/*", "accept": "*/*",
"accept-encoding": "gzip, deflate, br, zstd", "accept-encoding": "gzip, deflate, br, zstd",
"accept-language": "en-US,en;q=0.5", "accept-language": "en-US,en;q=0.5",
"referer": "https://chat.openai.com/", "referer": "https://chatgpt.com/",
"sec-ch-ua": "\"Brave\";v=\"123\", \"Not:A-Brand\";v=\"8\", \"Chromium\";v=\"123\"", "sec-ch-ua": "\"Brave\";v=\"123\", \"Not:A-Brand\";v=\"8\", \"Chromium\";v=\"123\"",
"sec-ch-ua-mobile": "?0", "sec-ch-ua-mobile": "?0",
"sec-ch-ua-platform": "\"Windows\"", "sec-ch-ua-platform": "\"Windows\"",
@ -53,15 +53,15 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
"""A class for creating and managing conversations with OpenAI chat service""" """A class for creating and managing conversations with OpenAI chat service"""
label = "OpenAI ChatGPT" label = "OpenAI ChatGPT"
url = "https://chat.openai.com" url = "https://chatgpt.com"
working = True working = True
supports_gpt_35_turbo = True supports_gpt_35_turbo = True
supports_gpt_4 = True supports_gpt_4 = True
supports_message_history = True supports_message_history = True
supports_system_message = True supports_system_message = True
default_model = None default_model = None
default_vision_model = "gpt-4-vision" default_vision_model = "gpt-4o"
models = ["gpt-3.5-turbo", "gpt-4", "gpt-4-gizmo"] models = ["gpt-3.5-turbo", "gpt-4", "gpt-4-gizmo", "gpt-4o"]
model_aliases = { model_aliases = {
"text-davinci-002-render-sha": "gpt-3.5-turbo", "text-davinci-002-render-sha": "gpt-3.5-turbo",
"": "gpt-3.5-turbo", "": "gpt-3.5-turbo",
@ -442,6 +442,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
try: try:
image_request = await cls.upload_image(session, cls._headers, image, image_name) if image else None image_request = await cls.upload_image(session, cls._headers, image, image_name) if image else None
except Exception as e: except Exception as e:
image_request = None
if debug.logging: if debug.logging:
print("OpenaiChat: Upload image failed") print("OpenaiChat: Upload image failed")
print(f"{e.__class__.__name__}: {e}") print(f"{e.__class__.__name__}: {e}")
@ -601,7 +602,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
this._fetch = this.fetch; this._fetch = this.fetch;
this.fetch = async (url, options) => { this.fetch = async (url, options) => {
const response = await this._fetch(url, options); const response = await this._fetch(url, options);
if (url == "https://chat.openai.com/backend-api/conversation") { if (url == "https://chatgpt.com/backend-api/conversation") {
this._headers = options.headers; this._headers = options.headers;
return response; return response;
} }
@ -637,7 +638,7 @@ this.fetch = async (url, options) => {
if debug.logging: if debug.logging:
print(f"Open nodriver with user_dir: {user_data_dir}") print(f"Open nodriver with user_dir: {user_data_dir}")
browser = await uc.start(user_data_dir=user_data_dir) browser = await uc.start(user_data_dir=user_data_dir)
page = await browser.get("https://chat.openai.com/") page = await browser.get("https://chatgpt.com/")
await page.select("[id^=headlessui-menu-button-]", 240) await page.select("[id^=headlessui-menu-button-]", 240)
api_key = await page.evaluate( api_key = await page.evaluate(
"(async () => {" "(async () => {"
@ -652,7 +653,7 @@ this.fetch = async (url, options) => {
) )
cookies = {} cookies = {}
for c in await page.browser.cookies.get_all(): for c in await page.browser.cookies.get_all():
if c.domain.endswith("chat.openai.com"): if c.domain.endswith("chatgpt.com"):
cookies[c.name] = c.value cookies[c.name] = c.value
user_agent = await page.evaluate("window.navigator.userAgent") user_agent = await page.evaluate("window.navigator.userAgent")
await page.close() await page.close()

View File

@ -26,7 +26,7 @@ class arkReq:
self.userAgent = userAgent self.userAgent = userAgent
arkPreURL = "https://tcr9i.chat.openai.com/fc/gt2/public_key/35536E1E-65B4-4D96-9D97-6ADB7EFF8147" arkPreURL = "https://tcr9i.chat.openai.com/fc/gt2/public_key/35536E1E-65B4-4D96-9D97-6ADB7EFF8147"
sessionUrl = "https://chat.openai.com/api/auth/session" sessionUrl = "https://chatgpt.com/api/auth/session"
chatArk: arkReq = None chatArk: arkReq = None
accessToken: str = None accessToken: str = None
cookies: dict = None cookies: dict = None

View File

@ -16,12 +16,9 @@ def generate_proof_token(required: bool, seed: str, difficulty: str, user_agent:
# Get current UTC time # Get current UTC time
now_utc = datetime.now(timezone.utc) now_utc = datetime.now(timezone.utc)
# Convert UTC time to Eastern Time parse_time = now_utc.strftime('%a, %d %b %Y %H:%M:%S GMT')
now_et = now_utc.astimezone(timezone(timedelta(hours=-5)))
parse_time = now_et.strftime('%a, %d %b %Y %H:%M:%S GMT') config = [core + screen, parse_time, None, 0, user_agent, "https://tcr9i.chat.openai.com/v2/35536E1E-65B4-4D96-9D97-6ADB7EFF8147/api.js","dpl=53d243de46ff04dadd88d293f088c2dd728f126f","en","en-US",442,"plugins[object PluginArray]","","alert"]
config = [core + screen, parse_time, 4294705152, 0, user_agent]
diff_len = len(difficulty) // 2 diff_len = len(difficulty) // 2

View File

@ -11,10 +11,9 @@ from .types import AsyncIterResponse, ImageProvider
from .image_models import ImageModels from .image_models import ImageModels
from .helper import filter_json, find_stop, filter_none, cast_iter_async from .helper import filter_json, find_stop, filter_none, cast_iter_async
from .service import get_last_provider, get_model_and_provider from .service import get_last_provider, get_model_and_provider
from ..typing import Union, Iterator, Messages, AsyncIterator, ImageType from ..typing import Union, Messages, AsyncIterator, ImageType
from ..errors import NoImageResponseError from ..errors import NoImageResponseError
from ..image import ImageResponse as ImageProviderResponse from ..image import ImageResponse as ImageProviderResponse
from ..providers.base_provider import AsyncGeneratorProvider
try: try:
anext anext
@ -88,7 +87,7 @@ def create_response(
api_key: str = None, api_key: str = None,
**kwargs **kwargs
): ):
has_asnyc = isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider) has_asnyc = hasattr(provider, "create_async_generator")
if has_asnyc: if has_asnyc:
create = provider.create_async_generator create = provider.create_async_generator
else: else:
@ -157,7 +156,7 @@ class Chat():
def __init__(self, client: AsyncClient, provider: ProviderType = None): def __init__(self, client: AsyncClient, provider: ProviderType = None):
self.completions = Completions(client, provider) self.completions = Completions(client, provider)
async def iter_image_response(response: Iterator) -> Union[ImagesResponse, None]: async def iter_image_response(response: AsyncIterator) -> Union[ImagesResponse, None]:
async for chunk in response: async for chunk in response:
if isinstance(chunk, ImageProviderResponse): if isinstance(chunk, ImageProviderResponse):
return ImagesResponse([Image(image) for image in chunk.get_list()]) return ImagesResponse([Image(image) for image in chunk.get_list()])
@ -182,7 +181,7 @@ class Images():
async def generate(self, prompt, model: str = "", **kwargs) -> ImagesResponse: async def generate(self, prompt, model: str = "", **kwargs) -> ImagesResponse:
provider = self.models.get(model, self.provider) provider = self.models.get(model, self.provider)
if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider): if hasattr(provider, "create_async_generator"):
response = create_image(self.client, provider, prompt, **kwargs) response = create_image(self.client, provider, prompt, **kwargs)
else: else:
response = await provider.create_async(prompt) response = await provider.create_async(prompt)
@ -195,7 +194,7 @@ class Images():
async def create_variation(self, image: ImageType, model: str = None, **kwargs): async def create_variation(self, image: ImageType, model: str = None, **kwargs):
provider = self.models.get(model, self.provider) provider = self.models.get(model, self.provider)
result = None result = None
if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider): if hasattr(provider, "create_async_generator"):
response = provider.create_async_generator( response = provider.create_async_generator(
"", "",
[{"role": "user", "content": "create a image like this"}], [{"role": "user", "content": "create a image like this"}],

View File

@ -19,8 +19,7 @@
<script src="/static/js/highlightjs-copy.min.js"></script> <script src="/static/js/highlightjs-copy.min.js"></script>
<script src="/static/js/chat.v1.js" defer></script> <script src="/static/js/chat.v1.js" defer></script>
<script src="https://cdn.jsdelivr.net/npm/markdown-it@13.0.1/dist/markdown-it.min.js"></script> <script src="https://cdn.jsdelivr.net/npm/markdown-it@13.0.1/dist/markdown-it.min.js"></script>
<link rel="stylesheet" <link rel="stylesheet" href="/static/css/dracula.min.css">
href="//cdn.jsdelivr.net/gh/highlightjs/cdn-release@11.7.0/build/styles/base16/dracula.min.css">
<script> <script>
MathJax = { MathJax = {
chtml: { chtml: {
@ -244,8 +243,5 @@
<div class="mobile-sidebar"> <div class="mobile-sidebar">
<i class="fa-solid fa-bars"></i> <i class="fa-solid fa-bars"></i>
</div> </div>
<script>
</script>
</body> </body>
</html> </html>

View File

@ -0,0 +1,7 @@
/*!
Theme: Dracula
Author: Mike Barkmin (http://github.com/mikebarkmin) based on Dracula Theme (http://github.com/dracula)
License: ~ MIT (or more permissive) [via base16-schemes-source]
Maintainer: @highlightjs/core-team
Version: 2021.09.0
*/pre code.hljs{display:block;overflow-x:auto;padding:1em}code.hljs{padding:3px 5px}.hljs{color:#e9e9f4;background:#282936}.hljs ::selection,.hljs::selection{background-color:#4d4f68;color:#e9e9f4}.hljs-comment{color:#626483}.hljs-tag{color:#62d6e8}.hljs-operator,.hljs-punctuation,.hljs-subst{color:#e9e9f4}.hljs-operator{opacity:.7}.hljs-bullet,.hljs-deletion,.hljs-name,.hljs-selector-tag,.hljs-template-variable,.hljs-variable{color:#ea51b2}.hljs-attr,.hljs-link,.hljs-literal,.hljs-number,.hljs-symbol,.hljs-variable.constant_{color:#b45bcf}.hljs-class .hljs-title,.hljs-title,.hljs-title.class_{color:#00f769}.hljs-strong{font-weight:700;color:#00f769}.hljs-addition,.hljs-code,.hljs-string,.hljs-title.class_.inherited__{color:#ebff87}.hljs-built_in,.hljs-doctag,.hljs-keyword.hljs-atrule,.hljs-quote,.hljs-regexp{color:#a1efe4}.hljs-attribute,.hljs-function .hljs-title,.hljs-section,.hljs-title.function_,.ruby .hljs-property{color:#62d6e8}.diff .hljs-meta,.hljs-keyword,.hljs-template-tag,.hljs-type{color:#b45bcf}.hljs-emphasis{color:#b45bcf;font-style:italic}.hljs-meta,.hljs-meta .hljs-keyword,.hljs-meta .hljs-string{color:#00f769}.hljs-meta .hljs-keyword,.hljs-meta-keyword{font-weight:700}

View File

@ -381,7 +381,8 @@ body {
} }
.message .count .fa-clipboard, .message .count .fa-clipboard,
.message .count .fa-volume-high { .message .count .fa-volume-high,
.message .count .fa-rotate {
z-index: 1000; z-index: 1000;
cursor: pointer; cursor: pointer;
} }

View File

@ -109,8 +109,9 @@ const register_message_buttons = async () => {
let playlist = []; let playlist = [];
function play_next() { function play_next() {
const next = playlist.shift(); const next = playlist.shift();
if (next) if (next && el.dataset.do_play) {
next.play(); next.play();
}
} }
if (el.dataset.stopped) { if (el.dataset.stopped) {
el.classList.remove("blink") el.classList.remove("blink")
@ -179,6 +180,20 @@ const register_message_buttons = async () => {
}); });
} }
}); });
document.querySelectorAll(".message .fa-rotate").forEach(async (el) => {
if (!("click" in el.dataset)) {
el.dataset.click = "true";
el.addEventListener("click", async () => {
const message_el = el.parentElement.parentElement.parentElement;
el.classList.add("clicked");
setTimeout(() => el.classList.remove("clicked"), 1000);
prompt_lock = true;
await hide_message(window.conversation_id, message_el.dataset.index);
window.token = message_id();
await ask_gpt(message_el.dataset.index);
})
}
});
} }
const delete_conversations = async () => { const delete_conversations = async () => {
@ -257,9 +272,9 @@ const remove_cancel_button = async () => {
}, 300); }, 300);
}; };
const prepare_messages = (messages, filter_last_message=true) => { const prepare_messages = (messages, message_index = -1) => {
// Removes none user messages at end // Removes none user messages at end
if (filter_last_message) { if (message_index == -1) {
let last_message; let last_message;
while (last_message = messages.pop()) { while (last_message = messages.pop()) {
if (last_message["role"] == "user") { if (last_message["role"] == "user") {
@ -267,14 +282,16 @@ const prepare_messages = (messages, filter_last_message=true) => {
break; break;
} }
} }
} else if (message_index >= 0) {
messages = messages.filter((_, index) => message_index >= index);
} }
// Remove history, if it's selected // Remove history, if it's selected
if (document.getElementById('history')?.checked) { if (document.getElementById('history')?.checked) {
if (filter_last_message) { if (message_index == null) {
messages = [messages.pop()];
} else {
messages = [messages.pop(), messages.pop()]; messages = [messages.pop(), messages.pop()];
} else {
messages = [messages.pop()];
} }
} }
@ -361,11 +378,11 @@ imageInput?.addEventListener("click", (e) => {
} }
}); });
const ask_gpt = async () => { const ask_gpt = async (message_index = -1) => {
regenerate.classList.add(`regenerate-hidden`); regenerate.classList.add(`regenerate-hidden`);
messages = await get_messages(window.conversation_id); messages = await get_messages(window.conversation_id);
total_messages = messages.length; total_messages = messages.length;
messages = prepare_messages(messages); messages = prepare_messages(messages, message_index);
stop_generating.classList.remove(`stop_generating-hidden`); stop_generating.classList.remove(`stop_generating-hidden`);
@ -528,6 +545,7 @@ const hide_option = async (conversation_id) => {
const span_el = document.createElement("span"); const span_el = document.createElement("span");
span_el.innerText = input_el.value; span_el.innerText = input_el.value;
span_el.classList.add("convo-title"); span_el.classList.add("convo-title");
span_el.onclick = () => set_conversation(conversation_id);
left_el.removeChild(input_el); left_el.removeChild(input_el);
left_el.appendChild(span_el); left_el.appendChild(span_el);
} }
@ -616,7 +634,7 @@ const load_conversation = async (conversation_id, scroll=true) => {
} }
if (window.GPTTokenizer_cl100k_base) { if (window.GPTTokenizer_cl100k_base) {
const filtered = prepare_messages(messages, false); const filtered = prepare_messages(messages, null);
if (filtered.length > 0) { if (filtered.length > 0) {
last_model = last_model?.startsWith("gpt-4") ? "gpt-4" : "gpt-3.5-turbo" last_model = last_model?.startsWith("gpt-4") ? "gpt-4" : "gpt-3.5-turbo"
let count_total = GPTTokenizer_cl100k_base?.encodeChat(filtered, last_model).length let count_total = GPTTokenizer_cl100k_base?.encodeChat(filtered, last_model).length
@ -683,15 +701,15 @@ async function save_system_message() {
await save_conversation(window.conversation_id, conversation); await save_conversation(window.conversation_id, conversation);
} }
} }
const hide_message = async (conversation_id, message_index =- 1) => {
const hide_last_message = async (conversation_id) => {
const conversation = await get_conversation(conversation_id) const conversation = await get_conversation(conversation_id)
const last_message = conversation.items.pop(); message_index = message_index == -1 ? conversation.items.length - 1 : message_index
const last_message = message_index in conversation.items ? conversation.items[message_index] : null;
if (last_message !== null) { if (last_message !== null) {
if (last_message["role"] == "assistant") { if (last_message["role"] == "assistant") {
last_message["regenerate"] = true; last_message["regenerate"] = true;
} }
conversation.items.push(last_message); conversation.items[message_index] = last_message;
} }
await save_conversation(conversation_id, conversation); await save_conversation(conversation_id, conversation);
}; };
@ -790,7 +808,7 @@ document.getElementById("cancelButton").addEventListener("click", async () => {
document.getElementById("regenerateButton").addEventListener("click", async () => { document.getElementById("regenerateButton").addEventListener("click", async () => {
prompt_lock = true; prompt_lock = true;
await hide_last_message(window.conversation_id); await hide_message(window.conversation_id);
window.token = message_id(); window.token = message_id();
await ask_gpt(); await ask_gpt();
}); });

View File

@ -3,18 +3,16 @@ from __future__ import annotations
import asyncio import asyncio
import random import random
from ..typing import Type, List, CreateResult, Messages, Iterator from ..typing import Type, List, CreateResult, Messages, Iterator, AsyncResult
from .types import BaseProvider, BaseRetryProvider from .types import BaseProvider, BaseRetryProvider
from .. import debug from .. import debug
from ..errors import RetryProviderError, RetryNoProviderError from ..errors import RetryProviderError, RetryNoProviderError
class RetryProvider(BaseRetryProvider): class NewBaseRetryProvider(BaseRetryProvider):
def __init__( def __init__(
self, self,
providers: List[Type[BaseProvider]], providers: List[Type[BaseProvider]],
shuffle: bool = True, shuffle: bool = True
single_provider_retry: bool = False,
max_retries: int = 3,
) -> None: ) -> None:
""" """
Initialize the BaseRetryProvider. Initialize the BaseRetryProvider.
@ -26,8 +24,6 @@ class RetryProvider(BaseRetryProvider):
""" """
self.providers = providers self.providers = providers
self.shuffle = shuffle self.shuffle = shuffle
self.single_provider_retry = single_provider_retry
self.max_retries = max_retries
self.working = True self.working = True
self.last_provider: Type[BaseProvider] = None self.last_provider: Type[BaseProvider] = None
@ -56,7 +52,146 @@ class RetryProvider(BaseRetryProvider):
exceptions = {} exceptions = {}
started: bool = False started: bool = False
for provider in providers:
self.last_provider = provider
try:
if debug.logging:
print(f"Using {provider.__name__} provider")
for token in provider.create_completion(model, messages, stream, **kwargs):
yield token
started = True
if started:
return
except Exception as e:
exceptions[provider.__name__] = e
if debug.logging:
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
if started:
raise e
raise_exceptions(exceptions)
async def create_async(
self,
model: str,
messages: Messages,
**kwargs,
) -> str:
"""
Asynchronously create a completion using available providers.
Args:
model (str): The model to be used for completion.
messages (Messages): The messages to be used for generating completion.
Returns:
str: The result of the asynchronous completion.
Raises:
Exception: Any exception encountered during the asynchronous completion process.
"""
providers = self.providers
if self.shuffle:
random.shuffle(providers)
exceptions = {}
for provider in providers:
self.last_provider = provider
try:
if debug.logging:
print(f"Using {provider.__name__} provider")
return await asyncio.wait_for(
provider.create_async(model, messages, **kwargs),
timeout=kwargs.get("timeout", 60),
)
except Exception as e:
exceptions[provider.__name__] = e
if debug.logging:
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
raise_exceptions(exceptions)
def get_providers(self, stream: bool):
providers = [p for p in self.providers if stream and p.supports_stream] if stream else self.providers
if self.shuffle:
random.shuffle(providers)
return providers
async def create_async_generator(
self,
model: str,
messages: Messages,
stream: bool = True,
**kwargs
) -> AsyncResult:
exceptions = {}
started: bool = False
for provider in self.get_providers(stream):
self.last_provider = provider
try:
if debug.logging:
print(f"Using {provider.__name__} provider")
if not stream:
yield await provider.create_async(model, messages, **kwargs)
elif hasattr(provider, "create_async_generator"):
async for token in provider.create_async_generator(model, messages, stream, **kwargs):
yield token
else:
for token in provider.create_completion(model, messages, stream, **kwargs):
yield token
started = True
if started:
return
except Exception as e:
exceptions[provider.__name__] = e
if debug.logging:
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
if started:
raise e
raise_exceptions(exceptions)
class RetryProvider(NewBaseRetryProvider):
def __init__(
self,
providers: List[Type[BaseProvider]],
shuffle: bool = True,
single_provider_retry: bool = False,
max_retries: int = 3,
) -> None:
"""
Initialize the BaseRetryProvider.
Args:
providers (List[Type[BaseProvider]]): List of providers to use.
shuffle (bool): Whether to shuffle the providers list.
single_provider_retry (bool): Whether to retry a single provider if it fails.
max_retries (int): Maximum number of retries for a single provider.
"""
super().__init__(providers, shuffle)
self.single_provider_retry = single_provider_retry
self.max_retries = max_retries
def create_completion(
self,
model: str,
messages: Messages,
stream: bool = False,
**kwargs,
) -> CreateResult:
"""
Create a completion using available providers, with an option to stream the response.
Args:
model (str): The model to be used for completion.
messages (Messages): The messages to be used for generating completion.
stream (bool, optional): Flag to indicate if the response should be streamed. Defaults to False.
Yields:
CreateResult: Tokens or results from the completion.
Raises:
Exception: Any exception encountered during the completion process.
"""
providers = self.get_providers(stream)
if self.single_provider_retry and len(providers) == 1: if self.single_provider_retry and len(providers) == 1:
exceptions = {}
started: bool = False
provider = providers[0] provider = providers[0]
self.last_provider = provider self.last_provider = provider
for attempt in range(self.max_retries): for attempt in range(self.max_retries):
@ -74,25 +209,9 @@ class RetryProvider(BaseRetryProvider):
print(f"{provider.__name__}: {e.__class__.__name__}: {e}") print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
if started: if started:
raise e raise e
raise_exceptions(exceptions)
else: else:
for provider in providers: yield from super().create_completion(model, messages, stream, **kwargs)
self.last_provider = provider
try:
if debug.logging:
print(f"Using {provider.__name__} provider")
for token in provider.create_completion(model, messages, stream, **kwargs):
yield token
started = True
if started:
return
except Exception as e:
exceptions[provider.__name__] = e
if debug.logging:
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
if started:
raise e
raise_exceptions(exceptions)
async def create_async( async def create_async(
self, self,
@ -131,22 +250,9 @@ class RetryProvider(BaseRetryProvider):
exceptions[provider.__name__] = e exceptions[provider.__name__] = e
if debug.logging: if debug.logging:
print(f"{provider.__name__}: {e.__class__.__name__}: {e}") print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
raise_exceptions(exceptions)
else: else:
for provider in providers: return await super().create_async(model, messages, **kwargs)
self.last_provider = provider
try:
if debug.logging:
print(f"Using {provider.__name__} provider")
return await asyncio.wait_for(
provider.create_async(model, messages, **kwargs),
timeout=kwargs.get("timeout", 60),
)
except Exception as e:
exceptions[provider.__name__] = e
if debug.logging:
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
raise_exceptions(exceptions)
class IterProvider(BaseRetryProvider): class IterProvider(BaseRetryProvider):
__name__ = "IterProvider" __name__ = "IterProvider"