Fix unit tests

This commit is contained in:
Heiner Lohaus 2024-11-26 23:38:48 +01:00
parent 4ae3d98df8
commit 16a11f991f
3 changed files with 17 additions and 14 deletions

View File

@ -8,7 +8,6 @@ import asyncio
import base64
from typing import Union, AsyncIterator, Iterator, Coroutine, Optional
from ..providers.base_provider import AsyncGeneratorProvider
from ..image import ImageResponse, copy_images, images_dir
from ..typing import Messages, Image, ImageType
from ..providers.types import ProviderType
@ -292,6 +291,7 @@ class Images:
**kwargs
) -> ImagesResponse:
provider_handler = await self.get_provider_handler(model, provider, BingCreateImages)
provider_name = provider.__name__ if hasattr(provider, "__name__") else type(provider).__name__
if proxy is None:
proxy = self.client.proxy
@ -317,17 +317,17 @@ class Images:
response = item
break
else:
raise ValueError(f"Provider {getattr(provider_handler, '__name__')} does not support image generation")
raise ValueError(f"Provider {provider_name} does not support image generation")
if isinstance(response, ImageResponse):
return await self._process_image_response(
response,
response_format,
proxy,
model,
getattr(provider_handler, "__name__", None)
provider_name
)
if response is None:
raise NoImageResponseError(f"No image response from {getattr(provider_handler, '__name__')}")
raise NoImageResponseError(f"No image response from {provider_name}")
raise NoImageResponseError(f"Unexpected response type: {type(response)}")
def create_variation(
@ -352,6 +352,7 @@ class Images:
**kwargs
) -> ImagesResponse:
provider_handler = await self.get_provider_handler(model, provider, OpenaiAccount)
provider_name = provider.__name__ if hasattr(provider, "__name__") else type(provider).__name__
if proxy is None:
proxy = self.client.proxy
@ -372,14 +373,14 @@ class Images:
else:
response = provider_handler.create_variation(image, model=model, response_format=response_format, proxy=proxy, **kwargs)
else:
raise NoImageResponseError(f"Provider {provider} does not support image variation")
raise NoImageResponseError(f"Provider {provider_name} does not support image variation")
if isinstance(response, str):
response = ImageResponse([response])
if isinstance(response, ImageResponse):
return self._process_image_response(response, response_format, proxy, model, getattr(provider, "__name__", None))
return self._process_image_response(response, response_format, proxy, model, provider_name)
if response is None:
raise NoImageResponseError(f"No image response from {getattr(provider, '__name__')}")
raise NoImageResponseError(f"No image response from {provider_name}")
raise NoImageResponseError(f"Unexpected response type: {type(response)}")
async def _process_image_response(

View File

@ -74,11 +74,13 @@ def get_model_and_provider(model : Union[Model, str],
if not provider:
raise ProviderNotFoundError(f'No provider found for model: {model}')
provider_name = provider.__name__ if hasattr(provider, "__name__") else type(provider).__name__
if isinstance(model, Model):
model = model.name
if not ignore_working and not provider.working:
raise ProviderNotWorkingError(f'{provider.__name__} is not working')
raise ProviderNotWorkingError(f"{provider_name} is not working")
if isinstance(provider, BaseRetryProvider):
if not ignore_working:
@ -87,12 +89,12 @@ def get_model_and_provider(model : Union[Model, str],
provider.providers = [p for p in provider.providers if p.__name__ not in ignored]
if not ignore_stream and not provider.supports_stream and stream:
raise StreamNotSupportedError(f'{provider.__name__} does not support "stream" argument')
raise StreamNotSupportedError(f'{provider_name} does not support "stream" argument')
if model:
debug.log(f'Using {type(provider).__name__} provider and {model} model')
debug.log(f'Using {provider_name} provider and {model} model')
else:
debug.log(f'Using {type(provider).__name__} provider')
debug.log(f'Using {provider_name} provider')
debug.last_provider = provider
debug.last_model = model
@ -115,7 +117,7 @@ def get_last_provider(as_dict: bool = False) -> Union[ProviderType, dict[str, st
if as_dict:
if last:
return {
"name": type(last).__name__,
"name": last.__name__ if hasattr(last, "__name__") else type(last).__name__,
"url": last.url,
"model": debug.last_model,
"label": getattr(last, "label", None) if hasattr(last, "label") else None