update retryprovider

now works with one provider.
This commit is contained in:
abc 2024-04-12 22:29:43 +01:00
parent a053c29020
commit f57af704ba
3 changed files with 77 additions and 42 deletions

2
.gitignore vendored
View File

@ -58,3 +58,5 @@ hardir
node_modules
models
projects/windows/g4f
doc.txt
dist.py

View File

@ -334,6 +334,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
Raises:
RuntimeError: If an error occurs during processing.
"""
async with StreamSession(
proxies={"all": proxy},
impersonate="chrome",
@ -360,6 +361,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
print("OpenaiChat: Load default_model failed")
print(f"{e.__class__.__name__}: {e}")
arkose_token = None
if cls.default_model is None:
try:
@ -582,6 +584,7 @@ this.fetch = async (url, options) => {
user_data_dir = user_config_dir("g4f-nodriver")
except:
user_data_dir = None
browser = await uc.start(user_data_dir=user_data_dir)
page = await browser.get("https://chat.openai.com/")
while await page.query_selector("#prompt-textarea") is None:

View File

@ -12,46 +12,40 @@ class RetryProvider(BaseRetryProvider):
def __init__(
self,
providers: List[Type[BaseProvider]],
shuffle: bool = True
shuffle: bool = True,
single_provider_retry: bool = False,
max_retries: int = 3,
) -> None:
"""
Initialize the BaseRetryProvider.
Args:
providers (List[Type[BaseProvider]]): List of providers to use.
shuffle (bool): Whether to shuffle the providers list.
single_provider_retry (bool): Whether to retry a single provider if it fails.
max_retries (int): Maximum number of retries for a single provider.
"""
self.providers = providers
self.shuffle = shuffle
self.single_provider_retry = single_provider_retry
self.max_retries = max_retries
self.working = True
self.last_provider: Type[BaseProvider] = None
"""
A provider class to handle retries for creating completions with different providers.
Attributes:
providers (list): A list of provider instances.
shuffle (bool): A flag indicating whether to shuffle providers before use.
last_provider (BaseProvider): The last provider that was used.
"""
def create_completion(
self,
model: str,
messages: Messages,
stream: bool = False,
**kwargs
**kwargs,
) -> CreateResult:
"""
Create a completion using available providers, with an option to stream the response.
Args:
model (str): The model to be used for completion.
messages (Messages): The messages to be used for generating completion.
stream (bool, optional): Flag to indicate if the response should be streamed. Defaults to False.
Yields:
CreateResult: Tokens or results from the completion.
Raises:
Exception: Any exception encountered during the completion process.
"""
@ -61,6 +55,26 @@ class RetryProvider(BaseRetryProvider):
exceptions = {}
started: bool = False
if self.single_provider_retry and len(providers) == 1:
provider = providers[0]
self.last_provider = provider
for attempt in range(self.max_retries):
try:
if debug.logging:
print(f"Using {provider.__name__} provider (attempt {attempt + 1})")
for token in provider.create_completion(model, messages, stream, **kwargs):
yield token
started = True
if started:
return
except Exception as e:
exceptions[provider.__name__] = e
if debug.logging:
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
if started:
raise e
else:
for provider in providers:
self.last_provider = provider
try:
@ -84,18 +98,15 @@ class RetryProvider(BaseRetryProvider):
self,
model: str,
messages: Messages,
**kwargs
**kwargs,
) -> str:
"""
Asynchronously create a completion using available providers.
Args:
model (str): The model to be used for completion.
messages (Messages): The messages to be used for generating completion.
Returns:
str: The result of the asynchronous completion.
Raises:
Exception: Any exception encountered during the asynchronous completion process.
"""
@ -104,12 +115,31 @@ class RetryProvider(BaseRetryProvider):
random.shuffle(providers)
exceptions = {}
if self.single_provider_retry and len(providers) == 1:
provider = providers[0]
self.last_provider = provider
for attempt in range(self.max_retries):
try:
if debug.logging:
print(f"Using {provider.__name__} provider (attempt {attempt + 1})")
return await asyncio.wait_for(
provider.create_async(model, messages, **kwargs),
timeout=kwargs.get("timeout", 60),
)
except Exception as e:
exceptions[provider.__name__] = e
if debug.logging:
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
else:
for provider in providers:
self.last_provider = provider
try:
if debug.logging:
print(f"Using {provider.__name__} provider")
return await asyncio.wait_for(
provider.create_async(model, messages, **kwargs),
timeout=kwargs.get("timeout", 60)
timeout=kwargs.get("timeout", 60),
)
except Exception as e:
exceptions[provider.__name__] = e