Fix api with default providers, add unittests for RetryProvider

This commit is contained in:
Heiner Lohaus 2024-11-28 17:46:46 +01:00
parent 971a01eb5c
commit c31f5435c4
7 changed files with 161 additions and 91 deletions

View File

@ -6,5 +6,6 @@ from .main import *
from .model import *
from .client import *
from .include import *
from .retry_provider import *
unittest.main()

View File

@ -40,3 +40,31 @@ class YieldProviderMock(AsyncGeneratorProvider):
):
for message in messages:
yield message["content"]
class RaiseExceptionProviderMock(AbstractProvider):
working = True
@classmethod
def create_completion(
cls, model, messages, stream, **kwargs
):
raise RuntimeError(cls.__name__)
yield cls.__name__
class AsyncRaiseExceptionProviderMock(AsyncGeneratorProvider):
working = True
@classmethod
async def create_async_generator(
cls, model, messages, stream, **kwargs
):
raise RuntimeError(cls.__name__)
yield cls.__name__
class YieldNoneProviderMock(AsyncGeneratorProvider):
working = True
async def create_async_generator(
model, messages, stream, **kwargs
):
yield None

View File

@ -0,0 +1,60 @@
from __future__ import annotations
import unittest
from g4f.client import AsyncClient, ChatCompletion, ChatCompletionChunk
from g4f.providers.retry_provider import IterListProvider
from .mocks import YieldProviderMock, RaiseExceptionProviderMock, AsyncRaiseExceptionProviderMock, YieldNoneProviderMock
DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}]
class TestIterListProvider(unittest.IsolatedAsyncioTestCase):
async def test_skip_provider(self):
client = AsyncClient(provider=IterListProvider([RaiseExceptionProviderMock, YieldProviderMock], False))
response = await client.chat.completions.create(DEFAULT_MESSAGES, "")
self.assertIsInstance(response, ChatCompletion)
self.assertEqual("Hello", response.choices[0].message.content)
async def test_only_one_result(self):
client = AsyncClient(provider=IterListProvider([YieldProviderMock, YieldProviderMock]))
response = await client.chat.completions.create(DEFAULT_MESSAGES, "")
self.assertIsInstance(response, ChatCompletion)
self.assertEqual("Hello", response.choices[0].message.content)
async def test_stream_skip_provider(self):
client = AsyncClient(provider=IterListProvider([AsyncRaiseExceptionProviderMock, YieldProviderMock], False))
messages = [{'role': 'user', 'content': chunk} for chunk in ["How ", "are ", "you", "?"]]
response = client.chat.completions.create(messages, "Hello", stream=True)
async for chunk in response:
chunk: ChatCompletionChunk = chunk
self.assertIsInstance(chunk, ChatCompletionChunk)
if chunk.choices[0].delta.content is not None:
self.assertIsInstance(chunk.choices[0].delta.content, str)
async def test_stream_only_one_result(self):
client = AsyncClient(provider=IterListProvider([YieldProviderMock, YieldProviderMock], False))
messages = [{'role': 'user', 'content': chunk} for chunk in ["You ", "You "]]
response = client.chat.completions.create(messages, "Hello", stream=True, max_tokens=2)
response_list = []
async for chunk in response:
response_list.append(chunk)
self.assertEqual(len(response_list), 3)
for chunk in response_list:
if chunk.choices[0].delta.content is not None:
self.assertEqual(chunk.choices[0].delta.content, "You ")
async def test_skip_none(self):
client = AsyncClient(provider=IterListProvider([YieldNoneProviderMock, YieldProviderMock], False))
response = await client.chat.completions.create(DEFAULT_MESSAGES, "")
self.assertIsInstance(response, ChatCompletion)
self.assertEqual("Hello", response.choices[0].message.content)
async def test_stream_skip_none(self):
client = AsyncClient(provider=IterListProvider([YieldNoneProviderMock, YieldProviderMock], False))
response = client.chat.completions.create(DEFAULT_MESSAGES, "", stream=True)
response_list = [chunk async for chunk in response]
self.assertEqual(len(response_list), 2)
for chunk in response_list:
if chunk.choices[0].delta.content is not None:
self.assertEqual(chunk.choices[0].delta.content, "Hello")

View File

@ -474,7 +474,7 @@ class AsyncCompletions:
**kwargs
)
if not isinstance(response, AsyncIterator):
if not hasattr(response, "__aiter__"):
response = to_async_iterator(response)
response = async_iter_response(response, stream, response_format, max_tokens, stop)
response = async_iter_append_model_and_provider(response)

View File

@ -7,14 +7,14 @@ from ..errors import ProviderNotFoundError, ModelNotFoundError, ProviderNotWorki
from ..models import Model, ModelUtils, default
from ..Provider import ProviderUtils
from ..providers.types import BaseRetryProvider, ProviderType
from ..providers.retry_provider import IterProvider
from ..providers.retry_provider import IterListProvider
def convert_to_provider(provider: str) -> ProviderType:
if " " in provider:
provider_list = [ProviderUtils.convert[p] for p in provider.split() if p in ProviderUtils.convert]
if not provider_list:
raise ProviderNotFoundError(f'Providers not found: {provider}')
provider = IterProvider(provider_list)
provider = IterListProvider(provider_list, False)
elif provider in ProviderUtils.convert:
provider = ProviderUtils.convert[provider]
elif provider:

View File

@ -57,7 +57,9 @@ class AbstractProvider(BaseProvider):
loop = loop or asyncio.get_running_loop()
def create_func() -> str:
return "".join(cls.create_completion(model, messages, False, **kwargs))
chunks = [str(chunk) for chunk in cls.create_completion(model, messages, False, **kwargs) if chunk]
if chunks:
return "".join(chunks)
return await asyncio.wait_for(
loop.run_in_executor(executor, create_func),
@ -205,7 +207,7 @@ class AsyncGeneratorProvider(AsyncProvider):
"""
return "".join([
str(chunk) async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)
if not isinstance(chunk, (Exception, FinishReason, BaseConversation, SynthesizeData))
if chunk and not isinstance(chunk, (Exception, FinishReason, BaseConversation, SynthesizeData))
])
@staticmethod

View File

@ -8,6 +8,8 @@ from .types import BaseProvider, BaseRetryProvider, ProviderType
from .. import debug
from ..errors import RetryProviderError, RetryNoProviderError
DEFAULT_TIMEOUT = 60
class IterListProvider(BaseRetryProvider):
def __init__(
self,
@ -50,11 +52,11 @@ class IterListProvider(BaseRetryProvider):
for provider in self.get_providers(stream):
self.last_provider = provider
debug.log(f"Using {provider.__name__} provider")
try:
if debug.logging:
print(f"Using {provider.__name__} provider")
for token in provider.create_completion(model, messages, stream, **kwargs):
yield token
for chunk in provider.create_completion(model, messages, stream, **kwargs):
if chunk:
yield chunk
started = True
if started:
return
@ -87,13 +89,14 @@ class IterListProvider(BaseRetryProvider):
for provider in self.get_providers(False):
self.last_provider = provider
debug.log(f"Using {provider.__name__} provider")
try:
if debug.logging:
print(f"Using {provider.__name__} provider")
return await asyncio.wait_for(
chunk = await asyncio.wait_for(
provider.create_async(model, messages, **kwargs),
timeout=kwargs.get("timeout", 60),
timeout=kwargs.get("timeout", DEFAULT_TIMEOUT),
)
if chunk:
return chunk
except Exception as e:
exceptions[provider.__name__] = e
if debug.logging:
@ -119,15 +122,20 @@ class IterListProvider(BaseRetryProvider):
for provider in self.get_providers(stream):
self.last_provider = provider
debug.log(f"Using {provider.__name__} provider")
try:
if debug.logging:
print(f"Using {provider.__name__} provider")
if not stream:
yield await provider.create_async(model, messages, **kwargs)
chunk = await asyncio.wait_for(
provider.create_async(model, messages, **kwargs),
timeout=kwargs.get("timeout", DEFAULT_TIMEOUT),
)
if chunk:
yield chunk
started = True
elif hasattr(provider, "create_async_generator"):
async for token in provider.create_async_generator(model, messages, stream=stream, **kwargs):
yield token
async for chunk in provider.create_async_generator(model, messages, stream=stream, **kwargs):
if chunk:
yield chunk
started = True
else:
for token in provider.create_completion(model, messages, stream, **kwargs):
@ -137,8 +145,7 @@ class IterListProvider(BaseRetryProvider):
return
except Exception as e:
exceptions[provider.__name__] = e
if debug.logging:
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
debug.log(f"{provider.__name__}: {e.__class__.__name__}: {e}")
if started:
raise e
@ -243,31 +250,35 @@ class RetryProvider(IterListProvider):
else:
return await super().create_async(model, messages, **kwargs)
class IterProvider(BaseRetryProvider):
__name__ = "IterProvider"
def __init__(
self,
providers: List[BaseProvider],
) -> None:
providers.reverse()
self.providers: List[BaseProvider] = providers
self.working: bool = True
self.last_provider: BaseProvider = None
def create_completion(
async def create_async_generator(
self,
model: str,
messages: Messages,
stream: bool = False,
stream: bool = True,
**kwargs
) -> CreateResult:
exceptions: dict = {}
started: bool = False
for provider in self.iter_providers():
if stream and not provider.supports_stream:
continue
) -> AsyncResult:
exceptions = {}
started = False
if self.single_provider_retry:
provider = self.providers[0]
self.last_provider = provider
for attempt in range(self.max_retries):
try:
debug.log(f"Using {provider.__name__} provider (attempt {attempt + 1})")
if not stream:
chunk = await asyncio.wait_for(
provider.create_async(model, messages, **kwargs),
timeout=kwargs.get("timeout", DEFAULT_TIMEOUT),
)
if chunk:
started = True
elif hasattr(provider, "create_async_generator"):
async for chunk in provider.create_async_generator(model, messages, stream=stream, **kwargs):
if chunk:
yield chunk
started = True
else:
for token in provider.create_completion(model, messages, stream, **kwargs):
yield token
started = True
@ -277,42 +288,10 @@ class IterProvider(BaseRetryProvider):
exceptions[provider.__name__] = e
if debug.logging:
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
if started:
raise e
raise_exceptions(exceptions)
async def create_async(
self,
model: str,
messages: Messages,
**kwargs
) -> str:
exceptions: dict = {}
for provider in self.iter_providers():
try:
return await asyncio.wait_for(
provider.create_async(model, messages, **kwargs),
timeout=kwargs.get("timeout", 60)
)
except Exception as e:
exceptions[provider.__name__] = e
if debug.logging:
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
raise_exceptions(exceptions)
def iter_providers(self) -> Iterator[BaseProvider]:
used_provider = []
try:
while self.providers:
provider = self.providers.pop()
used_provider.append(provider)
self.last_provider = provider
if debug.logging:
print(f"Using {provider.__name__} provider")
yield provider
finally:
used_provider.reverse()
self.providers = [*used_provider, *self.providers]
else:
async for chunk in super().create_async_generator(model, messages, stream, **kwargs):
yield chunk
def raise_exceptions(exceptions: dict) -> None:
"""