mirror of
https://github.com/xtekky/gpt4free.git
synced 2024-11-22 15:05:57 +03:00
Merge pull request #1950 from hlohaus/leech
Update chatgpt url, uvloop support
This commit is contained in:
commit
008ed60d98
@ -17,12 +17,12 @@ except ImportError:
|
||||
pass
|
||||
|
||||
from ... import debug
|
||||
from ...typing import Messages, Cookies, ImageType, AsyncResult
|
||||
from ...typing import Messages, Cookies, ImageType, AsyncResult, AsyncIterator
|
||||
from ..base_provider import AsyncGeneratorProvider
|
||||
from ..helper import format_prompt, get_cookies
|
||||
from ...requests.raise_for_status import raise_for_status
|
||||
from ...errors import MissingAuthError, MissingRequirementsError
|
||||
from ...image import to_bytes, ImageResponse
|
||||
from ...image import to_bytes, to_data_uri, ImageResponse
|
||||
from ...webdriver import get_browser, get_driver_cookies
|
||||
|
||||
REQUEST_HEADERS = {
|
||||
@ -59,7 +59,7 @@ class Gemini(AsyncGeneratorProvider):
|
||||
_cookies: Cookies = None
|
||||
|
||||
@classmethod
|
||||
async def nodriver_login(cls) -> Cookies:
|
||||
async def nodriver_login(cls) -> AsyncIterator[str]:
|
||||
try:
|
||||
import nodriver as uc
|
||||
except ImportError:
|
||||
@ -72,6 +72,9 @@ class Gemini(AsyncGeneratorProvider):
|
||||
if debug.logging:
|
||||
print(f"Open nodriver with user_dir: {user_data_dir}")
|
||||
browser = await uc.start(user_data_dir=user_data_dir)
|
||||
login_url = os.environ.get("G4F_LOGIN_URL")
|
||||
if login_url:
|
||||
yield f"Please login: [Google Gemini]({login_url})\n\n"
|
||||
page = await browser.get(f"{cls.url}/app")
|
||||
await page.select("div.ql-editor.textarea", 240)
|
||||
cookies = {}
|
||||
@ -79,10 +82,10 @@ class Gemini(AsyncGeneratorProvider):
|
||||
if c.domain.endswith(".google.com"):
|
||||
cookies[c.name] = c.value
|
||||
await page.close()
|
||||
return cookies
|
||||
cls._cookies = cookies
|
||||
|
||||
@classmethod
|
||||
async def webdriver_login(cls, proxy: str):
|
||||
async def webdriver_login(cls, proxy: str) -> AsyncIterator[str]:
|
||||
driver = None
|
||||
try:
|
||||
driver = get_browser(proxy=proxy)
|
||||
@ -131,13 +134,14 @@ class Gemini(AsyncGeneratorProvider):
|
||||
) as session:
|
||||
snlm0e = await cls.fetch_snlm0e(session, cls._cookies) if cls._cookies else None
|
||||
if not snlm0e:
|
||||
cls._cookies = await cls.nodriver_login();
|
||||
async for chunk in cls.nodriver_login():
|
||||
yield chunk
|
||||
if cls._cookies is None:
|
||||
async for chunk in cls.webdriver_login(proxy):
|
||||
yield chunk
|
||||
|
||||
if not snlm0e:
|
||||
if "__Secure-1PSID" not in cls._cookies:
|
||||
if cls._cookies is None or "__Secure-1PSID" not in cls._cookies:
|
||||
raise MissingAuthError('Missing "__Secure-1PSID" cookie')
|
||||
snlm0e = await cls.fetch_snlm0e(session, cls._cookies)
|
||||
if not snlm0e:
|
||||
@ -193,6 +197,13 @@ class Gemini(AsyncGeneratorProvider):
|
||||
image = fetch.headers["location"]
|
||||
resolved_images.append(image)
|
||||
preview.append(image.replace('=s512', '=s200'))
|
||||
# preview_url = image.replace('=s512', '=s200')
|
||||
# async with client.get(preview_url) as fetch:
|
||||
# preview_data = to_data_uri(await fetch.content.read())
|
||||
# async with client.get(image) as fetch:
|
||||
# data = to_data_uri(await fetch.content.read())
|
||||
# preview.append(preview_data)
|
||||
# resolved_images.append(data)
|
||||
yield ImageResponse(resolved_images, image_prompt, {"orginal_links": images, "preview": preview})
|
||||
|
||||
def build_request(
|
||||
|
@ -38,7 +38,7 @@ DEFAULT_HEADERS = {
|
||||
"accept": "*/*",
|
||||
"accept-encoding": "gzip, deflate, br, zstd",
|
||||
"accept-language": "en-US,en;q=0.5",
|
||||
"referer": "https://chat.openai.com/",
|
||||
"referer": "https://chatgpt.com/",
|
||||
"sec-ch-ua": "\"Brave\";v=\"123\", \"Not:A-Brand\";v=\"8\", \"Chromium\";v=\"123\"",
|
||||
"sec-ch-ua-mobile": "?0",
|
||||
"sec-ch-ua-platform": "\"Windows\"",
|
||||
@ -53,15 +53,15 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
"""A class for creating and managing conversations with OpenAI chat service"""
|
||||
|
||||
label = "OpenAI ChatGPT"
|
||||
url = "https://chat.openai.com"
|
||||
url = "https://chatgpt.com"
|
||||
working = True
|
||||
supports_gpt_35_turbo = True
|
||||
supports_gpt_4 = True
|
||||
supports_message_history = True
|
||||
supports_system_message = True
|
||||
default_model = None
|
||||
default_vision_model = "gpt-4-vision"
|
||||
models = ["gpt-3.5-turbo", "gpt-4", "gpt-4-gizmo"]
|
||||
default_vision_model = "gpt-4o"
|
||||
models = ["gpt-3.5-turbo", "gpt-4", "gpt-4-gizmo", "gpt-4o"]
|
||||
model_aliases = {
|
||||
"text-davinci-002-render-sha": "gpt-3.5-turbo",
|
||||
"": "gpt-3.5-turbo",
|
||||
@ -442,6 +442,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
try:
|
||||
image_request = await cls.upload_image(session, cls._headers, image, image_name) if image else None
|
||||
except Exception as e:
|
||||
image_request = None
|
||||
if debug.logging:
|
||||
print("OpenaiChat: Upload image failed")
|
||||
print(f"{e.__class__.__name__}: {e}")
|
||||
@ -601,7 +602,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
this._fetch = this.fetch;
|
||||
this.fetch = async (url, options) => {
|
||||
const response = await this._fetch(url, options);
|
||||
if (url == "https://chat.openai.com/backend-api/conversation") {
|
||||
if (url == "https://chatgpt.com/backend-api/conversation") {
|
||||
this._headers = options.headers;
|
||||
return response;
|
||||
}
|
||||
@ -637,7 +638,7 @@ this.fetch = async (url, options) => {
|
||||
if debug.logging:
|
||||
print(f"Open nodriver with user_dir: {user_data_dir}")
|
||||
browser = await uc.start(user_data_dir=user_data_dir)
|
||||
page = await browser.get("https://chat.openai.com/")
|
||||
page = await browser.get("https://chatgpt.com/")
|
||||
await page.select("[id^=headlessui-menu-button-]", 240)
|
||||
api_key = await page.evaluate(
|
||||
"(async () => {"
|
||||
@ -652,7 +653,7 @@ this.fetch = async (url, options) => {
|
||||
)
|
||||
cookies = {}
|
||||
for c in await page.browser.cookies.get_all():
|
||||
if c.domain.endswith("chat.openai.com"):
|
||||
if c.domain.endswith("chatgpt.com"):
|
||||
cookies[c.name] = c.value
|
||||
user_agent = await page.evaluate("window.navigator.userAgent")
|
||||
await page.close()
|
||||
|
@ -26,7 +26,7 @@ class arkReq:
|
||||
self.userAgent = userAgent
|
||||
|
||||
arkPreURL = "https://tcr9i.chat.openai.com/fc/gt2/public_key/35536E1E-65B4-4D96-9D97-6ADB7EFF8147"
|
||||
sessionUrl = "https://chat.openai.com/api/auth/session"
|
||||
sessionUrl = "https://chatgpt.com/api/auth/session"
|
||||
chatArk: arkReq = None
|
||||
accessToken: str = None
|
||||
cookies: dict = None
|
||||
|
@ -16,12 +16,9 @@ def generate_proof_token(required: bool, seed: str, difficulty: str, user_agent:
|
||||
|
||||
# Get current UTC time
|
||||
now_utc = datetime.now(timezone.utc)
|
||||
# Convert UTC time to Eastern Time
|
||||
now_et = now_utc.astimezone(timezone(timedelta(hours=-5)))
|
||||
parse_time = now_utc.strftime('%a, %d %b %Y %H:%M:%S GMT')
|
||||
|
||||
parse_time = now_et.strftime('%a, %d %b %Y %H:%M:%S GMT')
|
||||
|
||||
config = [core + screen, parse_time, 4294705152, 0, user_agent]
|
||||
config = [core + screen, parse_time, None, 0, user_agent, "https://tcr9i.chat.openai.com/v2/35536E1E-65B4-4D96-9D97-6ADB7EFF8147/api.js","dpl=53d243de46ff04dadd88d293f088c2dd728f126f","en","en-US",442,"plugins−[object PluginArray]","","alert"]
|
||||
|
||||
diff_len = len(difficulty) // 2
|
||||
|
||||
|
@ -11,10 +11,9 @@ from .types import AsyncIterResponse, ImageProvider
|
||||
from .image_models import ImageModels
|
||||
from .helper import filter_json, find_stop, filter_none, cast_iter_async
|
||||
from .service import get_last_provider, get_model_and_provider
|
||||
from ..typing import Union, Iterator, Messages, AsyncIterator, ImageType
|
||||
from ..typing import Union, Messages, AsyncIterator, ImageType
|
||||
from ..errors import NoImageResponseError
|
||||
from ..image import ImageResponse as ImageProviderResponse
|
||||
from ..providers.base_provider import AsyncGeneratorProvider
|
||||
|
||||
try:
|
||||
anext
|
||||
@ -88,7 +87,7 @@ def create_response(
|
||||
api_key: str = None,
|
||||
**kwargs
|
||||
):
|
||||
has_asnyc = isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider)
|
||||
has_asnyc = hasattr(provider, "create_async_generator")
|
||||
if has_asnyc:
|
||||
create = provider.create_async_generator
|
||||
else:
|
||||
@ -157,7 +156,7 @@ class Chat():
|
||||
def __init__(self, client: AsyncClient, provider: ProviderType = None):
|
||||
self.completions = Completions(client, provider)
|
||||
|
||||
async def iter_image_response(response: Iterator) -> Union[ImagesResponse, None]:
|
||||
async def iter_image_response(response: AsyncIterator) -> Union[ImagesResponse, None]:
|
||||
async for chunk in response:
|
||||
if isinstance(chunk, ImageProviderResponse):
|
||||
return ImagesResponse([Image(image) for image in chunk.get_list()])
|
||||
@ -182,7 +181,7 @@ class Images():
|
||||
|
||||
async def generate(self, prompt, model: str = "", **kwargs) -> ImagesResponse:
|
||||
provider = self.models.get(model, self.provider)
|
||||
if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider):
|
||||
if hasattr(provider, "create_async_generator"):
|
||||
response = create_image(self.client, provider, prompt, **kwargs)
|
||||
else:
|
||||
response = await provider.create_async(prompt)
|
||||
@ -195,7 +194,7 @@ class Images():
|
||||
async def create_variation(self, image: ImageType, model: str = None, **kwargs):
|
||||
provider = self.models.get(model, self.provider)
|
||||
result = None
|
||||
if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider):
|
||||
if hasattr(provider, "create_async_generator"):
|
||||
response = provider.create_async_generator(
|
||||
"",
|
||||
[{"role": "user", "content": "create a image like this"}],
|
||||
|
@ -19,8 +19,7 @@
|
||||
<script src="/static/js/highlightjs-copy.min.js"></script>
|
||||
<script src="/static/js/chat.v1.js" defer></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/markdown-it@13.0.1/dist/markdown-it.min.js"></script>
|
||||
<link rel="stylesheet"
|
||||
href="//cdn.jsdelivr.net/gh/highlightjs/cdn-release@11.7.0/build/styles/base16/dracula.min.css">
|
||||
<link rel="stylesheet" href="/static/css/dracula.min.css">
|
||||
<script>
|
||||
MathJax = {
|
||||
chtml: {
|
||||
@ -244,8 +243,5 @@
|
||||
<div class="mobile-sidebar">
|
||||
<i class="fa-solid fa-bars"></i>
|
||||
</div>
|
||||
<script>
|
||||
</script>
|
||||
</body>
|
||||
|
||||
</html>
|
||||
|
7
g4f/gui/client/static/css/dracula.min.css
vendored
Normal file
7
g4f/gui/client/static/css/dracula.min.css
vendored
Normal file
@ -0,0 +1,7 @@
|
||||
/*!
|
||||
Theme: Dracula
|
||||
Author: Mike Barkmin (http://github.com/mikebarkmin) based on Dracula Theme (http://github.com/dracula)
|
||||
License: ~ MIT (or more permissive) [via base16-schemes-source]
|
||||
Maintainer: @highlightjs/core-team
|
||||
Version: 2021.09.0
|
||||
*/pre code.hljs{display:block;overflow-x:auto;padding:1em}code.hljs{padding:3px 5px}.hljs{color:#e9e9f4;background:#282936}.hljs ::selection,.hljs::selection{background-color:#4d4f68;color:#e9e9f4}.hljs-comment{color:#626483}.hljs-tag{color:#62d6e8}.hljs-operator,.hljs-punctuation,.hljs-subst{color:#e9e9f4}.hljs-operator{opacity:.7}.hljs-bullet,.hljs-deletion,.hljs-name,.hljs-selector-tag,.hljs-template-variable,.hljs-variable{color:#ea51b2}.hljs-attr,.hljs-link,.hljs-literal,.hljs-number,.hljs-symbol,.hljs-variable.constant_{color:#b45bcf}.hljs-class .hljs-title,.hljs-title,.hljs-title.class_{color:#00f769}.hljs-strong{font-weight:700;color:#00f769}.hljs-addition,.hljs-code,.hljs-string,.hljs-title.class_.inherited__{color:#ebff87}.hljs-built_in,.hljs-doctag,.hljs-keyword.hljs-atrule,.hljs-quote,.hljs-regexp{color:#a1efe4}.hljs-attribute,.hljs-function .hljs-title,.hljs-section,.hljs-title.function_,.ruby .hljs-property{color:#62d6e8}.diff .hljs-meta,.hljs-keyword,.hljs-template-tag,.hljs-type{color:#b45bcf}.hljs-emphasis{color:#b45bcf;font-style:italic}.hljs-meta,.hljs-meta .hljs-keyword,.hljs-meta .hljs-string{color:#00f769}.hljs-meta .hljs-keyword,.hljs-meta-keyword{font-weight:700}
|
@ -381,7 +381,8 @@ body {
|
||||
}
|
||||
|
||||
.message .count .fa-clipboard,
|
||||
.message .count .fa-volume-high {
|
||||
.message .count .fa-volume-high,
|
||||
.message .count .fa-rotate {
|
||||
z-index: 1000;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
@ -109,8 +109,9 @@ const register_message_buttons = async () => {
|
||||
let playlist = [];
|
||||
function play_next() {
|
||||
const next = playlist.shift();
|
||||
if (next)
|
||||
if (next && el.dataset.do_play) {
|
||||
next.play();
|
||||
}
|
||||
}
|
||||
if (el.dataset.stopped) {
|
||||
el.classList.remove("blink")
|
||||
@ -179,6 +180,20 @@ const register_message_buttons = async () => {
|
||||
});
|
||||
}
|
||||
});
|
||||
document.querySelectorAll(".message .fa-rotate").forEach(async (el) => {
|
||||
if (!("click" in el.dataset)) {
|
||||
el.dataset.click = "true";
|
||||
el.addEventListener("click", async () => {
|
||||
const message_el = el.parentElement.parentElement.parentElement;
|
||||
el.classList.add("clicked");
|
||||
setTimeout(() => el.classList.remove("clicked"), 1000);
|
||||
prompt_lock = true;
|
||||
await hide_message(window.conversation_id, message_el.dataset.index);
|
||||
window.token = message_id();
|
||||
await ask_gpt(message_el.dataset.index);
|
||||
})
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
const delete_conversations = async () => {
|
||||
@ -257,9 +272,9 @@ const remove_cancel_button = async () => {
|
||||
}, 300);
|
||||
};
|
||||
|
||||
const prepare_messages = (messages, filter_last_message=true) => {
|
||||
const prepare_messages = (messages, message_index = -1) => {
|
||||
// Removes none user messages at end
|
||||
if (filter_last_message) {
|
||||
if (message_index == -1) {
|
||||
let last_message;
|
||||
while (last_message = messages.pop()) {
|
||||
if (last_message["role"] == "user") {
|
||||
@ -267,14 +282,16 @@ const prepare_messages = (messages, filter_last_message=true) => {
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else if (message_index >= 0) {
|
||||
messages = messages.filter((_, index) => message_index >= index);
|
||||
}
|
||||
|
||||
// Remove history, if it's selected
|
||||
if (document.getElementById('history')?.checked) {
|
||||
if (filter_last_message) {
|
||||
messages = [messages.pop()];
|
||||
} else {
|
||||
if (message_index == null) {
|
||||
messages = [messages.pop(), messages.pop()];
|
||||
} else {
|
||||
messages = [messages.pop()];
|
||||
}
|
||||
}
|
||||
|
||||
@ -361,11 +378,11 @@ imageInput?.addEventListener("click", (e) => {
|
||||
}
|
||||
});
|
||||
|
||||
const ask_gpt = async () => {
|
||||
const ask_gpt = async (message_index = -1) => {
|
||||
regenerate.classList.add(`regenerate-hidden`);
|
||||
messages = await get_messages(window.conversation_id);
|
||||
total_messages = messages.length;
|
||||
messages = prepare_messages(messages);
|
||||
messages = prepare_messages(messages, message_index);
|
||||
|
||||
stop_generating.classList.remove(`stop_generating-hidden`);
|
||||
|
||||
@ -528,6 +545,7 @@ const hide_option = async (conversation_id) => {
|
||||
const span_el = document.createElement("span");
|
||||
span_el.innerText = input_el.value;
|
||||
span_el.classList.add("convo-title");
|
||||
span_el.onclick = () => set_conversation(conversation_id);
|
||||
left_el.removeChild(input_el);
|
||||
left_el.appendChild(span_el);
|
||||
}
|
||||
@ -616,7 +634,7 @@ const load_conversation = async (conversation_id, scroll=true) => {
|
||||
}
|
||||
|
||||
if (window.GPTTokenizer_cl100k_base) {
|
||||
const filtered = prepare_messages(messages, false);
|
||||
const filtered = prepare_messages(messages, null);
|
||||
if (filtered.length > 0) {
|
||||
last_model = last_model?.startsWith("gpt-4") ? "gpt-4" : "gpt-3.5-turbo"
|
||||
let count_total = GPTTokenizer_cl100k_base?.encodeChat(filtered, last_model).length
|
||||
@ -683,15 +701,15 @@ async function save_system_message() {
|
||||
await save_conversation(window.conversation_id, conversation);
|
||||
}
|
||||
}
|
||||
|
||||
const hide_last_message = async (conversation_id) => {
|
||||
const hide_message = async (conversation_id, message_index =- 1) => {
|
||||
const conversation = await get_conversation(conversation_id)
|
||||
const last_message = conversation.items.pop();
|
||||
message_index = message_index == -1 ? conversation.items.length - 1 : message_index
|
||||
const last_message = message_index in conversation.items ? conversation.items[message_index] : null;
|
||||
if (last_message !== null) {
|
||||
if (last_message["role"] == "assistant") {
|
||||
last_message["regenerate"] = true;
|
||||
}
|
||||
conversation.items.push(last_message);
|
||||
conversation.items[message_index] = last_message;
|
||||
}
|
||||
await save_conversation(conversation_id, conversation);
|
||||
};
|
||||
@ -790,7 +808,7 @@ document.getElementById("cancelButton").addEventListener("click", async () => {
|
||||
|
||||
document.getElementById("regenerateButton").addEventListener("click", async () => {
|
||||
prompt_lock = true;
|
||||
await hide_last_message(window.conversation_id);
|
||||
await hide_message(window.conversation_id);
|
||||
window.token = message_id();
|
||||
await ask_gpt();
|
||||
});
|
||||
|
@ -3,18 +3,16 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import random
|
||||
|
||||
from ..typing import Type, List, CreateResult, Messages, Iterator
|
||||
from ..typing import Type, List, CreateResult, Messages, Iterator, AsyncResult
|
||||
from .types import BaseProvider, BaseRetryProvider
|
||||
from .. import debug
|
||||
from ..errors import RetryProviderError, RetryNoProviderError
|
||||
|
||||
class RetryProvider(BaseRetryProvider):
|
||||
class NewBaseRetryProvider(BaseRetryProvider):
|
||||
def __init__(
|
||||
self,
|
||||
providers: List[Type[BaseProvider]],
|
||||
shuffle: bool = True,
|
||||
single_provider_retry: bool = False,
|
||||
max_retries: int = 3,
|
||||
shuffle: bool = True
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the BaseRetryProvider.
|
||||
@ -26,8 +24,6 @@ class RetryProvider(BaseRetryProvider):
|
||||
"""
|
||||
self.providers = providers
|
||||
self.shuffle = shuffle
|
||||
self.single_provider_retry = single_provider_retry
|
||||
self.max_retries = max_retries
|
||||
self.working = True
|
||||
self.last_provider: Type[BaseProvider] = None
|
||||
|
||||
@ -56,7 +52,146 @@ class RetryProvider(BaseRetryProvider):
|
||||
exceptions = {}
|
||||
started: bool = False
|
||||
|
||||
for provider in providers:
|
||||
self.last_provider = provider
|
||||
try:
|
||||
if debug.logging:
|
||||
print(f"Using {provider.__name__} provider")
|
||||
for token in provider.create_completion(model, messages, stream, **kwargs):
|
||||
yield token
|
||||
started = True
|
||||
if started:
|
||||
return
|
||||
except Exception as e:
|
||||
exceptions[provider.__name__] = e
|
||||
if debug.logging:
|
||||
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
|
||||
if started:
|
||||
raise e
|
||||
|
||||
raise_exceptions(exceptions)
|
||||
|
||||
async def create_async(
|
||||
self,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
Asynchronously create a completion using available providers.
|
||||
Args:
|
||||
model (str): The model to be used for completion.
|
||||
messages (Messages): The messages to be used for generating completion.
|
||||
Returns:
|
||||
str: The result of the asynchronous completion.
|
||||
Raises:
|
||||
Exception: Any exception encountered during the asynchronous completion process.
|
||||
"""
|
||||
providers = self.providers
|
||||
if self.shuffle:
|
||||
random.shuffle(providers)
|
||||
|
||||
exceptions = {}
|
||||
|
||||
for provider in providers:
|
||||
self.last_provider = provider
|
||||
try:
|
||||
if debug.logging:
|
||||
print(f"Using {provider.__name__} provider")
|
||||
return await asyncio.wait_for(
|
||||
provider.create_async(model, messages, **kwargs),
|
||||
timeout=kwargs.get("timeout", 60),
|
||||
)
|
||||
except Exception as e:
|
||||
exceptions[provider.__name__] = e
|
||||
if debug.logging:
|
||||
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
|
||||
|
||||
raise_exceptions(exceptions)
|
||||
|
||||
def get_providers(self, stream: bool):
|
||||
providers = [p for p in self.providers if stream and p.supports_stream] if stream else self.providers
|
||||
if self.shuffle:
|
||||
random.shuffle(providers)
|
||||
return providers
|
||||
|
||||
async def create_async_generator(
|
||||
self,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
stream: bool = True,
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
exceptions = {}
|
||||
started: bool = False
|
||||
|
||||
for provider in self.get_providers(stream):
|
||||
self.last_provider = provider
|
||||
try:
|
||||
if debug.logging:
|
||||
print(f"Using {provider.__name__} provider")
|
||||
if not stream:
|
||||
yield await provider.create_async(model, messages, **kwargs)
|
||||
elif hasattr(provider, "create_async_generator"):
|
||||
async for token in provider.create_async_generator(model, messages, stream, **kwargs):
|
||||
yield token
|
||||
else:
|
||||
for token in provider.create_completion(model, messages, stream, **kwargs):
|
||||
yield token
|
||||
started = True
|
||||
if started:
|
||||
return
|
||||
except Exception as e:
|
||||
exceptions[provider.__name__] = e
|
||||
if debug.logging:
|
||||
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
|
||||
if started:
|
||||
raise e
|
||||
|
||||
raise_exceptions(exceptions)
|
||||
|
||||
class RetryProvider(NewBaseRetryProvider):
|
||||
def __init__(
|
||||
self,
|
||||
providers: List[Type[BaseProvider]],
|
||||
shuffle: bool = True,
|
||||
single_provider_retry: bool = False,
|
||||
max_retries: int = 3,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the BaseRetryProvider.
|
||||
Args:
|
||||
providers (List[Type[BaseProvider]]): List of providers to use.
|
||||
shuffle (bool): Whether to shuffle the providers list.
|
||||
single_provider_retry (bool): Whether to retry a single provider if it fails.
|
||||
max_retries (int): Maximum number of retries for a single provider.
|
||||
"""
|
||||
super().__init__(providers, shuffle)
|
||||
self.single_provider_retry = single_provider_retry
|
||||
self.max_retries = max_retries
|
||||
|
||||
def create_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
) -> CreateResult:
|
||||
"""
|
||||
Create a completion using available providers, with an option to stream the response.
|
||||
Args:
|
||||
model (str): The model to be used for completion.
|
||||
messages (Messages): The messages to be used for generating completion.
|
||||
stream (bool, optional): Flag to indicate if the response should be streamed. Defaults to False.
|
||||
Yields:
|
||||
CreateResult: Tokens or results from the completion.
|
||||
Raises:
|
||||
Exception: Any exception encountered during the completion process.
|
||||
"""
|
||||
providers = self.get_providers(stream)
|
||||
if self.single_provider_retry and len(providers) == 1:
|
||||
exceptions = {}
|
||||
started: bool = False
|
||||
provider = providers[0]
|
||||
self.last_provider = provider
|
||||
for attempt in range(self.max_retries):
|
||||
@ -74,25 +209,9 @@ class RetryProvider(BaseRetryProvider):
|
||||
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
|
||||
if started:
|
||||
raise e
|
||||
raise_exceptions(exceptions)
|
||||
else:
|
||||
for provider in providers:
|
||||
self.last_provider = provider
|
||||
try:
|
||||
if debug.logging:
|
||||
print(f"Using {provider.__name__} provider")
|
||||
for token in provider.create_completion(model, messages, stream, **kwargs):
|
||||
yield token
|
||||
started = True
|
||||
if started:
|
||||
return
|
||||
except Exception as e:
|
||||
exceptions[provider.__name__] = e
|
||||
if debug.logging:
|
||||
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
|
||||
if started:
|
||||
raise e
|
||||
|
||||
raise_exceptions(exceptions)
|
||||
yield from super().create_completion(model, messages, stream, **kwargs)
|
||||
|
||||
async def create_async(
|
||||
self,
|
||||
@ -131,22 +250,9 @@ class RetryProvider(BaseRetryProvider):
|
||||
exceptions[provider.__name__] = e
|
||||
if debug.logging:
|
||||
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
|
||||
raise_exceptions(exceptions)
|
||||
else:
|
||||
for provider in providers:
|
||||
self.last_provider = provider
|
||||
try:
|
||||
if debug.logging:
|
||||
print(f"Using {provider.__name__} provider")
|
||||
return await asyncio.wait_for(
|
||||
provider.create_async(model, messages, **kwargs),
|
||||
timeout=kwargs.get("timeout", 60),
|
||||
)
|
||||
except Exception as e:
|
||||
exceptions[provider.__name__] = e
|
||||
if debug.logging:
|
||||
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
|
||||
|
||||
raise_exceptions(exceptions)
|
||||
return await super().create_async(model, messages, **kwargs)
|
||||
|
||||
class IterProvider(BaseRetryProvider):
|
||||
__name__ = "IterProvider"
|
||||
|
Loading…
Reference in New Issue
Block a user