Fix: debug.logging not work in retry provider

This commit is contained in:
hs_junxiang 2023-10-19 10:15:38 +08:00
parent cb3677a3e2
commit 042ee7633b
2 changed files with 6 additions and 6 deletions

View File

@ -5,13 +5,13 @@ import random
from typing import List, Type, Dict from typing import List, Type, Dict
from ..typing import CreateResult, Messages from ..typing import CreateResult, Messages
from .base_provider import BaseProvider, AsyncProvider from .base_provider import BaseProvider, AsyncProvider
from ..debug import logging
class RetryProvider(AsyncProvider): class RetryProvider(AsyncProvider):
__name__: str = "RetryProvider" __name__: str = "RetryProvider"
working: bool = True working: bool = True
supports_stream: bool = True supports_stream: bool = True
logging: bool = False
def __init__( def __init__(
self, self,
@ -21,7 +21,6 @@ class RetryProvider(AsyncProvider):
self.providers: List[Type[BaseProvider]] = providers self.providers: List[Type[BaseProvider]] = providers
self.shuffle: bool = shuffle self.shuffle: bool = shuffle
def create_completion( def create_completion(
self, self,
model: str, model: str,
@ -40,7 +39,7 @@ class RetryProvider(AsyncProvider):
started: bool = False started: bool = False
for provider in providers: for provider in providers:
try: try:
if logging: if self.logging:
print(f"Using {provider.__name__} provider") print(f"Using {provider.__name__} provider")
for token in provider.create_completion(model, messages, stream, **kwargs): for token in provider.create_completion(model, messages, stream, **kwargs):
yield token yield token
@ -49,7 +48,7 @@ class RetryProvider(AsyncProvider):
return return
except Exception as e: except Exception as e:
self.exceptions[provider.__name__] = e self.exceptions[provider.__name__] = e
if logging: if self.logging:
print(f"{provider.__name__}: {e.__class__.__name__}: {e}") print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
if started: if started:
raise e raise e
@ -72,11 +71,11 @@ class RetryProvider(AsyncProvider):
return await asyncio.wait_for(provider.create_async(model, messages, **kwargs), timeout=60) return await asyncio.wait_for(provider.create_async(model, messages, **kwargs), timeout=60)
except asyncio.TimeoutError as e: except asyncio.TimeoutError as e:
self.exceptions[provider.__name__] = e self.exceptions[provider.__name__] = e
if logging: if self.logging:
print(f"{provider.__name__}: TimeoutError: {e}") print(f"{provider.__name__}: TimeoutError: {e}")
except Exception as e: except Exception as e:
self.exceptions[provider.__name__] = e self.exceptions[provider.__name__] = e
if logging: if self.logging:
print(f"{provider.__name__}: {e.__class__.__name__}: {e}") print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
self.raise_exceptions() self.raise_exceptions()

View File

@ -48,6 +48,7 @@ def get_model_and_provider(model : Union[Model, str],
raise ValueError(f'{provider.__name__} does not support "stream" argument') raise ValueError(f'{provider.__name__} does not support "stream" argument')
if logging: if logging:
RetryProvider.logging = True
print(f'Using {provider.__name__} provider') print(f'Using {provider.__name__} provider')
return model, provider return model, provider