diff --git a/etc/unittest/__main__.py b/etc/unittest/__main__.py
index 0acc5865..e49dec30 100644
--- a/etc/unittest/__main__.py
+++ b/etc/unittest/__main__.py
@@ -6,5 +6,6 @@ from .main import *
 from .model import *
 from .client import *
 from .include import *
+from .retry_provider import *
 
-unittest.main()
+unittest.main()
\ No newline at end of file
diff --git a/etc/unittest/mocks.py b/etc/unittest/mocks.py
index 102730fa..c2058e34 100644
--- a/etc/unittest/mocks.py
+++ b/etc/unittest/mocks.py
@@ -34,9 +34,40 @@ class ModelProviderMock(AbstractProvider):
 
 class YieldProviderMock(AsyncGeneratorProvider):
     working = True
-    
+
     async def create_async_generator(
         model, messages, stream, **kwargs
     ):
         for message in messages:
-            yield message["content"]
\ No newline at end of file
+            yield message["content"]
+
+class RaiseExceptionProviderMock(AbstractProvider):
+    """Synchronous provider mock whose completion raises RuntimeError when iterated."""
+    working = True
+
+    @classmethod
+    def create_completion(
+        cls, model, messages, stream, **kwargs
+    ):
+        raise RuntimeError(cls.__name__)
+        yield cls.__name__  # never reached; its presence makes this function a generator
+
+class AsyncRaiseExceptionProviderMock(AsyncGeneratorProvider):
+    """Async generator provider mock that raises RuntimeError when iterated."""
+    working = True
+
+    @classmethod
+    async def create_async_generator(
+        cls, model, messages, stream, **kwargs
+    ):
+        raise RuntimeError(cls.__name__)
+        yield cls.__name__  # never reached; its presence makes this function an async generator
+
+class YieldNoneProviderMock(AsyncGeneratorProvider):
+    """Async generator provider mock that yields a single None chunk."""
+    working = True
+
+    async def create_async_generator(
+        model, messages, stream, **kwargs
+    ):
+        yield None
\ No newline at end of file
diff --git a/etc/unittest/retry_provider.py b/etc/unittest/retry_provider.py
new file mode 100644
index 00000000..6d41ef94
--- /dev/null
+++ b/etc/unittest/retry_provider.py
@@ -0,0 +1,67 @@
+from __future__ import annotations
+
+import unittest
+
+from g4f.client import AsyncClient, ChatCompletion, ChatCompletionChunk
+from g4f.providers.retry_provider import IterListProvider
+from .mocks import YieldProviderMock, RaiseExceptionProviderMock, AsyncRaiseExceptionProviderMock, YieldNoneProviderMock
+
+DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}]
+
+class TestIterListProvider(unittest.IsolatedAsyncioTestCase):
+    """Checks how IterListProvider handles failing, empty and duplicate providers."""
+
+    async def test_skip_provider(self):
+        """A provider that raises is skipped and the next provider answers."""
+        client = AsyncClient(provider=IterListProvider([RaiseExceptionProviderMock, YieldProviderMock], False))
+        response = await client.chat.completions.create(DEFAULT_MESSAGES, "")
+        self.assertIsInstance(response, ChatCompletion)
+        self.assertEqual("Hello", response.choices[0].message.content)
+
+    async def test_only_one_result(self):
+        """With two working providers the answer is not duplicated."""
+        client = AsyncClient(provider=IterListProvider([YieldProviderMock, YieldProviderMock]))
+        response = await client.chat.completions.create(DEFAULT_MESSAGES, "")
+        self.assertIsInstance(response, ChatCompletion)
+        self.assertEqual("Hello", response.choices[0].message.content)
+
+    async def test_stream_skip_provider(self):
+        """When streaming, a raising async provider is skipped and chunks are still delivered."""
+        client = AsyncClient(provider=IterListProvider([AsyncRaiseExceptionProviderMock, YieldProviderMock], False))
+        messages = [{'role': 'user', 'content': chunk} for chunk in ["How ", "are ", "you", "?"]]
+        response = client.chat.completions.create(messages, "Hello", stream=True)
+        async for chunk in response:
+            chunk: ChatCompletionChunk = chunk
+            self.assertIsInstance(chunk, ChatCompletionChunk)
+            if chunk.choices[0].delta.content is not None:
+                self.assertIsInstance(chunk.choices[0].delta.content, str)
+
+    async def test_stream_only_one_result(self):
+        """When streaming with two working providers, exactly three chunks are produced."""
+        client = AsyncClient(provider=IterListProvider([YieldProviderMock, YieldProviderMock], False))
+        messages = [{'role': 'user', 'content': chunk} for chunk in ["You ", "You "]]
+        response = client.chat.completions.create(messages, "Hello", stream=True, max_tokens=2)
+        response_list = []
+        async for chunk in response:
+            response_list.append(chunk)
+        self.assertEqual(len(response_list), 3)
+        for chunk in response_list:
+            if chunk.choices[0].delta.content is not None:
+                self.assertEqual(chunk.choices[0].delta.content, "You ")
+
+    async def test_skip_none(self):
+        """A provider that only yields None is skipped and the next provider answers."""
+        client = AsyncClient(provider=IterListProvider([YieldNoneProviderMock, YieldProviderMock], False))
+        response = await client.chat.completions.create(DEFAULT_MESSAGES, "")
+        self.assertIsInstance(response, ChatCompletion)
+        self.assertEqual("Hello", response.choices[0].message.content)
+
+    async def test_stream_skip_none(self):
+        """When streaming, a None-only provider is skipped; two chunks are produced."""
+        client = AsyncClient(provider=IterListProvider([YieldNoneProviderMock, YieldProviderMock], False))
+        response = client.chat.completions.create(DEFAULT_MESSAGES, "", stream=True)
+        response_list = [chunk async for chunk in response]
+        self.assertEqual(len(response_list), 2)
+        for chunk in response_list:
+            if chunk.choices[0].delta.content is not None:
+                self.assertEqual(chunk.choices[0].delta.content, "Hello")
\ No newline at end of file
diff --git a/g4f/client/__init__.py b/g4f/client/__init__.py
index b00c5a65..ea47ec73 100644
--- a/g4f/client/__init__.py
+++ b/g4f/client/__init__.py
@@ -474,7 +474,7 @@ class AsyncCompletions:
             **kwargs
         )
 
-        if not isinstance(response, AsyncIterator):
+        if not hasattr(response, "__aiter__"):
             response = to_async_iterator(response)
         response = async_iter_response(response, stream, response_format, max_tokens, stop)
         response = async_iter_append_model_and_provider(response)
diff --git a/g4f/client/service.py b/g4f/client/service.py
index 45230c79..80dc70df 100644
--- a/g4f/client/service.py
+++ b/g4f/client/service.py
@@ -7,14 +7,14 @@ from ..errors import ProviderNotFoundError, ModelNotFoundError, ProviderNotWorki
 from ..models import Model, ModelUtils, default
 from ..Provider import ProviderUtils
 from ..providers.types import BaseRetryProvider, ProviderType
-from ..providers.retry_provider import IterProvider
+from ..providers.retry_provider import IterListProvider
 
 def convert_to_provider(provider: str) -> ProviderType:
     if " " in provider:
         provider_list = [ProviderUtils.convert[p] for p in provider.split() if p in ProviderUtils.convert]
         if not provider_list:
             raise ProviderNotFoundError(f'Providers not found: {provider}')
-        provider = IterProvider(provider_list)
+        provider = IterListProvider(provider_list, False)
     elif provider in ProviderUtils.convert:
         provider = ProviderUtils.convert[provider]
     elif provider:
diff --git a/g4f/providers/base_provider.py b/g4f/providers/base_provider.py
index 80a9e09d..e8a47154 100644
--- a/g4f/providers/base_provider.py
+++ b/g4f/providers/base_provider.py
@@ -57,7 +57,9 @@ class AbstractProvider(BaseProvider):
         loop = loop or asyncio.get_running_loop()
 
         def create_func() -> str:
-            return "".join(cls.create_completion(model, messages, False, **kwargs))
+            chunks = [str(chunk) for chunk in cls.create_completion(model, messages, False, **kwargs) if chunk]
+            if chunks:
+                return "".join(chunks)
 
         return await asyncio.wait_for(
             loop.run_in_executor(executor, create_func),
@@ -205,7 +207,7 @@ class AsyncGeneratorProvider(AsyncProvider):
         """
         return "".join([
             str(chunk) async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)
-            if not isinstance(chunk, (Exception, FinishReason, BaseConversation, SynthesizeData))
+            if chunk and not isinstance(chunk, (Exception, FinishReason, BaseConversation, SynthesizeData))
         ])
 
     @staticmethod
diff --git a/g4f/providers/retry_provider.py b/g4f/providers/retry_provider.py
index efcae375..92386955 100644
--- a/g4f/providers/retry_provider.py
+++ b/g4f/providers/retry_provider.py
@@ -8,6 +8,8 @@ from .types import BaseProvider, BaseRetryProvider, ProviderType
 from .. import debug
 from ..errors import RetryProviderError, RetryNoProviderError
 
+DEFAULT_TIMEOUT = 60
+
 class IterListProvider(BaseRetryProvider):
     def __init__(
         self,
@@ -50,12 +52,12 @@ class IterListProvider(BaseRetryProvider):
 
         for provider in self.get_providers(stream):
             self.last_provider = provider
+            debug.log(f"Using {provider.__name__} provider")
             try:
-                if debug.logging:
-                    print(f"Using {provider.__name__} provider")
-                for token in provider.create_completion(model, messages, stream, **kwargs):
-                    yield token
-                    started = True
+                for chunk in provider.create_completion(model, messages, stream, **kwargs):
+                    if chunk:
+                        yield chunk
+                        started = True
                 if started:
                     return
             except Exception as e:
@@ -87,13 +89,14 @@ class IterListProvider(BaseRetryProvider):
 
         for provider in self.get_providers(False):
             self.last_provider = provider
+            debug.log(f"Using {provider.__name__} provider")
             try:
-                if debug.logging:
-                    print(f"Using {provider.__name__} provider")
-                return await asyncio.wait_for(
+                chunk = await asyncio.wait_for(
                     provider.create_async(model, messages, **kwargs),
-                    timeout=kwargs.get("timeout", 60),
+                    timeout=kwargs.get("timeout", DEFAULT_TIMEOUT),
                 )
+                if chunk:
+                    return chunk
             except Exception as e:
                 exceptions[provider.__name__] = e
                 if debug.logging:
@@ -119,16 +122,21 @@ class IterListProvider(BaseRetryProvider):
 
         for provider in self.get_providers(stream):
             self.last_provider = provider
+            debug.log(f"Using {provider.__name__} provider")
             try:
-                if debug.logging:
-                    print(f"Using {provider.__name__} provider")
                 if not stream:
-                    yield await provider.create_async(model, messages, **kwargs)
-                    started = True
-                elif hasattr(provider, "create_async_generator"):
-                    async for token in provider.create_async_generator(model, messages, stream=stream, **kwargs):
-                        yield token
+                    chunk = await asyncio.wait_for(
+                        provider.create_async(model, messages, **kwargs),
+                        timeout=kwargs.get("timeout", DEFAULT_TIMEOUT),
+                    )
+                    if chunk:
+                        yield chunk
                         started = True
+                elif hasattr(provider, "create_async_generator"):
+                    async for chunk in provider.create_async_generator(model, messages, stream=stream, **kwargs):
+                        if chunk:
+                            yield chunk
+                            started = True
                 else:
                     for token in provider.create_completion(model, messages, stream, **kwargs):
                         yield token
@@ -137,8 +145,7 @@ class IterListProvider(BaseRetryProvider):
                     return
             except Exception as e:
                 exceptions[provider.__name__] = e
-                if debug.logging:
-                    print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
+                debug.log(f"{provider.__name__}: {e.__class__.__name__}: {e}")
                 if started:
                     raise e
 
@@ -243,76 +250,61 @@ class RetryProvider(IterListProvider):
         else:
             return await super().create_async(model, messages, **kwargs)
 
-class IterProvider(BaseRetryProvider):
-    __name__ = "IterProvider"
-
-    def __init__(
-        self,
-        providers: List[BaseProvider],
-    ) -> None:
-        providers.reverse()
-        self.providers: List[BaseProvider] = providers
-        self.working: bool = True
-        self.last_provider: BaseProvider = None
-
-    def create_completion(
+    async def create_async_generator(
         self,
         model: str,
         messages: Messages,
-        stream: bool = False,
+        stream: bool = True,
         **kwargs
-    ) -> CreateResult:
-        exceptions: dict = {}
-        started: bool = False
-        for provider in self.iter_providers():
-            if stream and not provider.supports_stream:
-                continue
-            try:
-                for token in provider.create_completion(model, messages, stream, **kwargs):
-                    yield token
-                    started = True
-                if started:
-                    return
-            except Exception as e:
-                exceptions[provider.__name__] = e
-                if debug.logging:
-                    print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
-                if started:
-                    raise e
-        raise_exceptions(exceptions)
+    ) -> AsyncResult:
+        """
+        Yield response chunks, optionally retrying a single provider.
+
+        When ``single_provider_retry`` is set, only the first provider is
+        used and it is attempted up to ``max_retries`` times; otherwise the
+        call is delegated to the parent class implementation.
+        Falsy chunks from the async code paths are skipped.
+        """
+        exceptions = {}
+        started = False
 
-    async def create_async(
-        self,
-        model: str,
-        messages: Messages,
-        **kwargs
-    ) -> str:
-        exceptions: dict = {}
-        for provider in self.iter_providers():
-            try:
-                return await asyncio.wait_for(
-                    provider.create_async(model, messages, **kwargs),
-                    timeout=kwargs.get("timeout", 60)
-                )
-            except Exception as e:
-                exceptions[provider.__name__] = e
-                if debug.logging:
-                    print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
-        raise_exceptions(exceptions)
-
-    def iter_providers(self) -> Iterator[BaseProvider]:
-        used_provider = []
-        try:
-            while self.providers:
-                provider = self.providers.pop()
-                used_provider.append(provider)
-                self.last_provider = provider
-                if debug.logging:
-                    print(f"Using {provider.__name__} provider")
-                yield provider
-        finally:
-            used_provider.reverse()
-            self.providers = [*used_provider, *self.providers]
+        if self.single_provider_retry:
+            provider = self.providers[0]
+            self.last_provider = provider
+            for attempt in range(self.max_retries):
+                try:
+                    debug.log(f"Using {provider.__name__} provider (attempt {attempt + 1})")
+                    if not stream:
+                        # Non-streaming: await the whole answer, bounded by the timeout.
+                        chunk = await asyncio.wait_for(
+                            provider.create_async(model, messages, **kwargs),
+                            timeout=kwargs.get("timeout", DEFAULT_TIMEOUT),
+                        )
+                        if chunk:
+                            yield chunk
+                            started = True
+                    elif hasattr(provider, "create_async_generator"):
+                        async for chunk in provider.create_async_generator(model, messages, stream=stream, **kwargs):
+                            if chunk:
+                                yield chunk
+                                started = True
+                    else:
+                        # Fallback for providers without create_async_generator.
+                        for token in provider.create_completion(model, messages, stream, **kwargs):
+                            yield token
+                            started = True
+                    if started:
+                        return
+                except Exception as e:
+                    exceptions[provider.__name__] = e
+                    debug.log(f"{provider.__name__}: {e.__class__.__name__}: {e}")
+                    if started:
+                        # Output was already yielded; another attempt would repeat it.
+                        raise e
+            raise_exceptions(exceptions)
+        else:
+            async for chunk in super().create_async_generator(model, messages, stream, **kwargs):
+                yield chunk
 
 def raise_exceptions(exceptions: dict) -> None:
     """