mirror of
https://github.com/xtekky/gpt4free.git
synced 2024-12-23 11:02:40 +03:00
Fix api with default providers, add unittests for RetryProvider
This commit is contained in:
parent
971a01eb5c
commit
c31f5435c4
@ -6,5 +6,6 @@ from .main import *
|
||||
from .model import *
|
||||
from .client import *
|
||||
from .include import *
|
||||
from .retry_provider import *
|
||||
|
||||
unittest.main()
|
@ -40,3 +40,31 @@ class YieldProviderMock(AsyncGeneratorProvider):
|
||||
):
|
||||
for message in messages:
|
||||
yield message["content"]
|
||||
|
||||
class RaiseExceptionProviderMock(AbstractProvider):
|
||||
working = True
|
||||
|
||||
@classmethod
|
||||
def create_completion(
|
||||
cls, model, messages, stream, **kwargs
|
||||
):
|
||||
raise RuntimeError(cls.__name__)
|
||||
yield cls.__name__
|
||||
|
||||
class AsyncRaiseExceptionProviderMock(AsyncGeneratorProvider):
|
||||
working = True
|
||||
|
||||
@classmethod
|
||||
async def create_async_generator(
|
||||
cls, model, messages, stream, **kwargs
|
||||
):
|
||||
raise RuntimeError(cls.__name__)
|
||||
yield cls.__name__
|
||||
|
||||
class YieldNoneProviderMock(AsyncGeneratorProvider):
|
||||
working = True
|
||||
|
||||
async def create_async_generator(
|
||||
model, messages, stream, **kwargs
|
||||
):
|
||||
yield None
|
60
etc/unittest/retry_provider.py
Normal file
60
etc/unittest/retry_provider.py
Normal file
@ -0,0 +1,60 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import unittest
|
||||
|
||||
from g4f.client import AsyncClient, ChatCompletion, ChatCompletionChunk
|
||||
from g4f.providers.retry_provider import IterListProvider
|
||||
from .mocks import YieldProviderMock, RaiseExceptionProviderMock, AsyncRaiseExceptionProviderMock, YieldNoneProviderMock
|
||||
|
||||
DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}]
|
||||
|
||||
class TestIterListProvider(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def test_skip_provider(self):
|
||||
client = AsyncClient(provider=IterListProvider([RaiseExceptionProviderMock, YieldProviderMock], False))
|
||||
response = await client.chat.completions.create(DEFAULT_MESSAGES, "")
|
||||
self.assertIsInstance(response, ChatCompletion)
|
||||
self.assertEqual("Hello", response.choices[0].message.content)
|
||||
|
||||
async def test_only_one_result(self):
|
||||
client = AsyncClient(provider=IterListProvider([YieldProviderMock, YieldProviderMock]))
|
||||
response = await client.chat.completions.create(DEFAULT_MESSAGES, "")
|
||||
self.assertIsInstance(response, ChatCompletion)
|
||||
self.assertEqual("Hello", response.choices[0].message.content)
|
||||
|
||||
async def test_stream_skip_provider(self):
|
||||
client = AsyncClient(provider=IterListProvider([AsyncRaiseExceptionProviderMock, YieldProviderMock], False))
|
||||
messages = [{'role': 'user', 'content': chunk} for chunk in ["How ", "are ", "you", "?"]]
|
||||
response = client.chat.completions.create(messages, "Hello", stream=True)
|
||||
async for chunk in response:
|
||||
chunk: ChatCompletionChunk = chunk
|
||||
self.assertIsInstance(chunk, ChatCompletionChunk)
|
||||
if chunk.choices[0].delta.content is not None:
|
||||
self.assertIsInstance(chunk.choices[0].delta.content, str)
|
||||
|
||||
async def test_stream_only_one_result(self):
|
||||
client = AsyncClient(provider=IterListProvider([YieldProviderMock, YieldProviderMock], False))
|
||||
messages = [{'role': 'user', 'content': chunk} for chunk in ["You ", "You "]]
|
||||
response = client.chat.completions.create(messages, "Hello", stream=True, max_tokens=2)
|
||||
response_list = []
|
||||
async for chunk in response:
|
||||
response_list.append(chunk)
|
||||
self.assertEqual(len(response_list), 3)
|
||||
for chunk in response_list:
|
||||
if chunk.choices[0].delta.content is not None:
|
||||
self.assertEqual(chunk.choices[0].delta.content, "You ")
|
||||
|
||||
async def test_skip_none(self):
|
||||
client = AsyncClient(provider=IterListProvider([YieldNoneProviderMock, YieldProviderMock], False))
|
||||
response = await client.chat.completions.create(DEFAULT_MESSAGES, "")
|
||||
self.assertIsInstance(response, ChatCompletion)
|
||||
self.assertEqual("Hello", response.choices[0].message.content)
|
||||
|
||||
async def test_stream_skip_none(self):
|
||||
client = AsyncClient(provider=IterListProvider([YieldNoneProviderMock, YieldProviderMock], False))
|
||||
response = client.chat.completions.create(DEFAULT_MESSAGES, "", stream=True)
|
||||
response_list = [chunk async for chunk in response]
|
||||
self.assertEqual(len(response_list), 2)
|
||||
for chunk in response_list:
|
||||
if chunk.choices[0].delta.content is not None:
|
||||
self.assertEqual(chunk.choices[0].delta.content, "Hello")
|
@ -474,7 +474,7 @@ class AsyncCompletions:
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if not isinstance(response, AsyncIterator):
|
||||
if not hasattr(response, "__aiter__"):
|
||||
response = to_async_iterator(response)
|
||||
response = async_iter_response(response, stream, response_format, max_tokens, stop)
|
||||
response = async_iter_append_model_and_provider(response)
|
||||
|
@ -7,14 +7,14 @@ from ..errors import ProviderNotFoundError, ModelNotFoundError, ProviderNotWorki
|
||||
from ..models import Model, ModelUtils, default
|
||||
from ..Provider import ProviderUtils
|
||||
from ..providers.types import BaseRetryProvider, ProviderType
|
||||
from ..providers.retry_provider import IterProvider
|
||||
from ..providers.retry_provider import IterListProvider
|
||||
|
||||
def convert_to_provider(provider: str) -> ProviderType:
|
||||
if " " in provider:
|
||||
provider_list = [ProviderUtils.convert[p] for p in provider.split() if p in ProviderUtils.convert]
|
||||
if not provider_list:
|
||||
raise ProviderNotFoundError(f'Providers not found: {provider}')
|
||||
provider = IterProvider(provider_list)
|
||||
provider = IterListProvider(provider_list, False)
|
||||
elif provider in ProviderUtils.convert:
|
||||
provider = ProviderUtils.convert[provider]
|
||||
elif provider:
|
||||
|
@ -57,7 +57,9 @@ class AbstractProvider(BaseProvider):
|
||||
loop = loop or asyncio.get_running_loop()
|
||||
|
||||
def create_func() -> str:
|
||||
return "".join(cls.create_completion(model, messages, False, **kwargs))
|
||||
chunks = [str(chunk) for chunk in cls.create_completion(model, messages, False, **kwargs) if chunk]
|
||||
if chunks:
|
||||
return "".join(chunks)
|
||||
|
||||
return await asyncio.wait_for(
|
||||
loop.run_in_executor(executor, create_func),
|
||||
@ -205,7 +207,7 @@ class AsyncGeneratorProvider(AsyncProvider):
|
||||
"""
|
||||
return "".join([
|
||||
str(chunk) async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)
|
||||
if not isinstance(chunk, (Exception, FinishReason, BaseConversation, SynthesizeData))
|
||||
if chunk and not isinstance(chunk, (Exception, FinishReason, BaseConversation, SynthesizeData))
|
||||
])
|
||||
|
||||
@staticmethod
|
||||
|
@ -8,6 +8,8 @@ from .types import BaseProvider, BaseRetryProvider, ProviderType
|
||||
from .. import debug
|
||||
from ..errors import RetryProviderError, RetryNoProviderError
|
||||
|
||||
DEFAULT_TIMEOUT = 60
|
||||
|
||||
class IterListProvider(BaseRetryProvider):
|
||||
def __init__(
|
||||
self,
|
||||
@ -50,11 +52,11 @@ class IterListProvider(BaseRetryProvider):
|
||||
|
||||
for provider in self.get_providers(stream):
|
||||
self.last_provider = provider
|
||||
debug.log(f"Using {provider.__name__} provider")
|
||||
try:
|
||||
if debug.logging:
|
||||
print(f"Using {provider.__name__} provider")
|
||||
for token in provider.create_completion(model, messages, stream, **kwargs):
|
||||
yield token
|
||||
for chunk in provider.create_completion(model, messages, stream, **kwargs):
|
||||
if chunk:
|
||||
yield chunk
|
||||
started = True
|
||||
if started:
|
||||
return
|
||||
@ -87,13 +89,14 @@ class IterListProvider(BaseRetryProvider):
|
||||
|
||||
for provider in self.get_providers(False):
|
||||
self.last_provider = provider
|
||||
debug.log(f"Using {provider.__name__} provider")
|
||||
try:
|
||||
if debug.logging:
|
||||
print(f"Using {provider.__name__} provider")
|
||||
return await asyncio.wait_for(
|
||||
chunk = await asyncio.wait_for(
|
||||
provider.create_async(model, messages, **kwargs),
|
||||
timeout=kwargs.get("timeout", 60),
|
||||
timeout=kwargs.get("timeout", DEFAULT_TIMEOUT),
|
||||
)
|
||||
if chunk:
|
||||
return chunk
|
||||
except Exception as e:
|
||||
exceptions[provider.__name__] = e
|
||||
if debug.logging:
|
||||
@ -119,15 +122,20 @@ class IterListProvider(BaseRetryProvider):
|
||||
|
||||
for provider in self.get_providers(stream):
|
||||
self.last_provider = provider
|
||||
debug.log(f"Using {provider.__name__} provider")
|
||||
try:
|
||||
if debug.logging:
|
||||
print(f"Using {provider.__name__} provider")
|
||||
if not stream:
|
||||
yield await provider.create_async(model, messages, **kwargs)
|
||||
chunk = await asyncio.wait_for(
|
||||
provider.create_async(model, messages, **kwargs),
|
||||
timeout=kwargs.get("timeout", DEFAULT_TIMEOUT),
|
||||
)
|
||||
if chunk:
|
||||
yield chunk
|
||||
started = True
|
||||
elif hasattr(provider, "create_async_generator"):
|
||||
async for token in provider.create_async_generator(model, messages, stream=stream, **kwargs):
|
||||
yield token
|
||||
async for chunk in provider.create_async_generator(model, messages, stream=stream, **kwargs):
|
||||
if chunk:
|
||||
yield chunk
|
||||
started = True
|
||||
else:
|
||||
for token in provider.create_completion(model, messages, stream, **kwargs):
|
||||
@ -137,8 +145,7 @@ class IterListProvider(BaseRetryProvider):
|
||||
return
|
||||
except Exception as e:
|
||||
exceptions[provider.__name__] = e
|
||||
if debug.logging:
|
||||
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
|
||||
debug.log(f"{provider.__name__}: {e.__class__.__name__}: {e}")
|
||||
if started:
|
||||
raise e
|
||||
|
||||
@ -243,31 +250,35 @@ class RetryProvider(IterListProvider):
|
||||
else:
|
||||
return await super().create_async(model, messages, **kwargs)
|
||||
|
||||
class IterProvider(BaseRetryProvider):
|
||||
__name__ = "IterProvider"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
providers: List[BaseProvider],
|
||||
) -> None:
|
||||
providers.reverse()
|
||||
self.providers: List[BaseProvider] = providers
|
||||
self.working: bool = True
|
||||
self.last_provider: BaseProvider = None
|
||||
|
||||
def create_completion(
|
||||
async def create_async_generator(
|
||||
self,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
stream: bool = False,
|
||||
stream: bool = True,
|
||||
**kwargs
|
||||
) -> CreateResult:
|
||||
exceptions: dict = {}
|
||||
started: bool = False
|
||||
for provider in self.iter_providers():
|
||||
if stream and not provider.supports_stream:
|
||||
continue
|
||||
) -> AsyncResult:
|
||||
exceptions = {}
|
||||
started = False
|
||||
|
||||
if self.single_provider_retry:
|
||||
provider = self.providers[0]
|
||||
self.last_provider = provider
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
debug.log(f"Using {provider.__name__} provider (attempt {attempt + 1})")
|
||||
if not stream:
|
||||
chunk = await asyncio.wait_for(
|
||||
provider.create_async(model, messages, **kwargs),
|
||||
timeout=kwargs.get("timeout", DEFAULT_TIMEOUT),
|
||||
)
|
||||
if chunk:
|
||||
started = True
|
||||
elif hasattr(provider, "create_async_generator"):
|
||||
async for chunk in provider.create_async_generator(model, messages, stream=stream, **kwargs):
|
||||
if chunk:
|
||||
yield chunk
|
||||
started = True
|
||||
else:
|
||||
for token in provider.create_completion(model, messages, stream, **kwargs):
|
||||
yield token
|
||||
started = True
|
||||
@ -277,42 +288,10 @@ class IterProvider(BaseRetryProvider):
|
||||
exceptions[provider.__name__] = e
|
||||
if debug.logging:
|
||||
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
|
||||
if started:
|
||||
raise e
|
||||
raise_exceptions(exceptions)
|
||||
|
||||
async def create_async(
|
||||
self,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
**kwargs
|
||||
) -> str:
|
||||
exceptions: dict = {}
|
||||
for provider in self.iter_providers():
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
provider.create_async(model, messages, **kwargs),
|
||||
timeout=kwargs.get("timeout", 60)
|
||||
)
|
||||
except Exception as e:
|
||||
exceptions[provider.__name__] = e
|
||||
if debug.logging:
|
||||
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
|
||||
raise_exceptions(exceptions)
|
||||
|
||||
def iter_providers(self) -> Iterator[BaseProvider]:
|
||||
used_provider = []
|
||||
try:
|
||||
while self.providers:
|
||||
provider = self.providers.pop()
|
||||
used_provider.append(provider)
|
||||
self.last_provider = provider
|
||||
if debug.logging:
|
||||
print(f"Using {provider.__name__} provider")
|
||||
yield provider
|
||||
finally:
|
||||
used_provider.reverse()
|
||||
self.providers = [*used_provider, *self.providers]
|
||||
else:
|
||||
async for chunk in super().create_async_generator(model, messages, stream, **kwargs):
|
||||
yield chunk
|
||||
|
||||
def raise_exceptions(exceptions: dict) -> None:
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user