update retryprovider

now works with one provider.
This commit is contained in:
abc 2024-04-12 22:29:43 +01:00
parent a053c29020
commit f57af704ba
3 changed files with 77 additions and 42 deletions

2
.gitignore vendored
View File

@ -58,3 +58,5 @@ hardir
node_modules node_modules
models models
projects/windows/g4f projects/windows/g4f
doc.txt
dist.py

View File

@ -334,6 +334,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
Raises: Raises:
RuntimeError: If an error occurs during processing. RuntimeError: If an error occurs during processing.
""" """
async with StreamSession( async with StreamSession(
proxies={"all": proxy}, proxies={"all": proxy},
impersonate="chrome", impersonate="chrome",
@ -359,6 +360,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
if debug.logging: if debug.logging:
print("OpenaiChat: Load default_model failed") print("OpenaiChat: Load default_model failed")
print(f"{e.__class__.__name__}: {e}") print(f"{e.__class__.__name__}: {e}")
arkose_token = None arkose_token = None
if cls.default_model is None: if cls.default_model is None:
@ -582,6 +584,7 @@ this.fetch = async (url, options) => {
user_data_dir = user_config_dir("g4f-nodriver") user_data_dir = user_config_dir("g4f-nodriver")
except: except:
user_data_dir = None user_data_dir = None
browser = await uc.start(user_data_dir=user_data_dir) browser = await uc.start(user_data_dir=user_data_dir)
page = await browser.get("https://chat.openai.com/") page = await browser.get("https://chat.openai.com/")
while await page.query_selector("#prompt-textarea") is None: while await page.query_selector("#prompt-textarea") is None:

View File

@ -12,46 +12,40 @@ class RetryProvider(BaseRetryProvider):
def __init__( def __init__(
self, self,
providers: List[Type[BaseProvider]], providers: List[Type[BaseProvider]],
shuffle: bool = True shuffle: bool = True,
single_provider_retry: bool = False,
max_retries: int = 3,
) -> None: ) -> None:
""" """
Initialize the BaseRetryProvider. Initialize the BaseRetryProvider.
Args: Args:
providers (List[Type[BaseProvider]]): List of providers to use. providers (List[Type[BaseProvider]]): List of providers to use.
shuffle (bool): Whether to shuffle the providers list. shuffle (bool): Whether to shuffle the providers list.
single_provider_retry (bool): Whether to retry a single provider if it fails.
max_retries (int): Maximum number of retries for a single provider.
""" """
self.providers = providers self.providers = providers
self.shuffle = shuffle self.shuffle = shuffle
self.single_provider_retry = single_provider_retry
self.max_retries = max_retries
self.working = True self.working = True
self.last_provider: Type[BaseProvider] = None self.last_provider: Type[BaseProvider] = None
"""
A provider class to handle retries for creating completions with different providers.
Attributes:
providers (list): A list of provider instances.
shuffle (bool): A flag indicating whether to shuffle providers before use.
last_provider (BaseProvider): The last provider that was used.
"""
def create_completion( def create_completion(
self, self,
model: str, model: str,
messages: Messages, messages: Messages,
stream: bool = False, stream: bool = False,
**kwargs **kwargs,
) -> CreateResult: ) -> CreateResult:
""" """
Create a completion using available providers, with an option to stream the response. Create a completion using available providers, with an option to stream the response.
Args: Args:
model (str): The model to be used for completion. model (str): The model to be used for completion.
messages (Messages): The messages to be used for generating completion. messages (Messages): The messages to be used for generating completion.
stream (bool, optional): Flag to indicate if the response should be streamed. Defaults to False. stream (bool, optional): Flag to indicate if the response should be streamed. Defaults to False.
Yields: Yields:
CreateResult: Tokens or results from the completion. CreateResult: Tokens or results from the completion.
Raises: Raises:
Exception: Any exception encountered during the completion process. Exception: Any exception encountered during the completion process.
""" """
@ -61,22 +55,42 @@ class RetryProvider(BaseRetryProvider):
exceptions = {} exceptions = {}
started: bool = False started: bool = False
for provider in providers:
if self.single_provider_retry and len(providers) == 1:
provider = providers[0]
self.last_provider = provider self.last_provider = provider
try: for attempt in range(self.max_retries):
if debug.logging: try:
print(f"Using {provider.__name__} provider") if debug.logging:
for token in provider.create_completion(model, messages, stream, **kwargs): print(f"Using {provider.__name__} provider (attempt {attempt + 1})")
yield token for token in provider.create_completion(model, messages, stream, **kwargs):
yield token
started = True started = True
if started: if started:
return return
except Exception as e: except Exception as e:
exceptions[provider.__name__] = e exceptions[provider.__name__] = e
if debug.logging: if debug.logging:
print(f"{provider.__name__}: {e.__class__.__name__}: {e}") print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
if started: if started:
raise e raise e
else:
for provider in providers:
self.last_provider = provider
try:
if debug.logging:
print(f"Using {provider.__name__} provider")
for token in provider.create_completion(model, messages, stream, **kwargs):
yield token
started = True
if started:
return
except Exception as e:
exceptions[provider.__name__] = e
if debug.logging:
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
if started:
raise e
raise_exceptions(exceptions) raise_exceptions(exceptions)
@ -84,18 +98,15 @@ class RetryProvider(BaseRetryProvider):
self, self,
model: str, model: str,
messages: Messages, messages: Messages,
**kwargs **kwargs,
) -> str: ) -> str:
""" """
Asynchronously create a completion using available providers. Asynchronously create a completion using available providers.
Args: Args:
model (str): The model to be used for completion. model (str): The model to be used for completion.
messages (Messages): The messages to be used for generating completion. messages (Messages): The messages to be used for generating completion.
Returns: Returns:
str: The result of the asynchronous completion. str: The result of the asynchronous completion.
Raises: Raises:
Exception: Any exception encountered during the asynchronous completion process. Exception: Any exception encountered during the asynchronous completion process.
""" """
@ -104,17 +115,36 @@ class RetryProvider(BaseRetryProvider):
random.shuffle(providers) random.shuffle(providers)
exceptions = {} exceptions = {}
for provider in providers:
if self.single_provider_retry and len(providers) == 1:
provider = providers[0]
self.last_provider = provider self.last_provider = provider
try: for attempt in range(self.max_retries):
return await asyncio.wait_for( try:
provider.create_async(model, messages, **kwargs), if debug.logging:
timeout=kwargs.get("timeout", 60) print(f"Using {provider.__name__} provider (attempt {attempt + 1})")
) return await asyncio.wait_for(
except Exception as e: provider.create_async(model, messages, **kwargs),
exceptions[provider.__name__] = e timeout=kwargs.get("timeout", 60),
if debug.logging: )
print(f"{provider.__name__}: {e.__class__.__name__}: {e}") except Exception as e:
exceptions[provider.__name__] = e
if debug.logging:
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
else:
for provider in providers:
self.last_provider = provider
try:
if debug.logging:
print(f"Using {provider.__name__} provider")
return await asyncio.wait_for(
provider.create_async(model, messages, **kwargs),
timeout=kwargs.get("timeout", 60),
)
except Exception as e:
exceptions[provider.__name__] = e
if debug.logging:
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
raise_exceptions(exceptions) raise_exceptions(exceptions)