2024-01-26 09:54:13 +03:00
|
|
|
from __future__ import annotations
|
|
|
|
|
2024-01-01 19:48:57 +03:00
|
|
|
from abc import ABC, abstractmethod
|
2024-01-14 09:45:41 +03:00
|
|
|
from typing import Optional, Union, List, Dict, Type
|
|
|
|
from .typing import Messages, CreateResult
|
|
|
|
|
2024-01-01 19:48:57 +03:00
|
|
|
class BaseProvider(ABC):
    """
    Abstract base class for a provider.

    Concrete providers subclass this and implement the two abstract
    classmethods ``create_completion`` and ``create_async``; until both are
    implemented the class cannot be instantiated.

    Attributes:
        url (Optional[str]): URL of the provider; ``None`` unless a subclass sets it.
        working (bool): Indicates if the provider is currently working.
        needs_auth (bool): Indicates if the provider needs authentication.
        supports_stream (bool): Indicates if the provider supports streaming.
        supports_gpt_35_turbo (bool): Indicates if the provider supports GPT-3.5 Turbo.
        supports_gpt_4 (bool): Indicates if the provider supports GPT-4.
        supports_message_history (bool): Indicates if the provider supports message history.
        params (str): Parameters of the provider, as a string. This is only
            annotated here (no value is assigned), so subclasses must provide it.
    """

    # None means "no URL configured"; concrete providers override this.
    url: Optional[str] = None
    working: bool = False
    needs_auth: bool = False
    supports_stream: bool = False
    supports_gpt_35_turbo: bool = False
    supports_gpt_4: bool = False
    supports_message_history: bool = False
    # Annotation only: no class attribute is created, so reading ``params`` on
    # a class that does not assign it raises AttributeError.
    params: str

    @classmethod
    @abstractmethod
    def create_completion(
        cls,
        model: str,
        messages: Messages,
        stream: bool,
        **kwargs
    ) -> CreateResult:
        """
        Create a completion with the given parameters.

        Args:
            model (str): The model to use.
            messages (Messages): The messages to process.
            stream (bool): Whether to use streaming.
            **kwargs: Additional keyword arguments.

        Returns:
            CreateResult: The result of the creation process.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError()

    @classmethod
    @abstractmethod
    async def create_async(
        cls,
        model: str,
        messages: Messages,
        **kwargs
    ) -> str:
        """
        Asynchronously create a completion with the given parameters.

        Args:
            model (str): The model to use.
            messages (Messages): The messages to process.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The result of the creation process.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError()

    @classmethod
    def get_dict(cls) -> Dict[str, Optional[str]]:
        """
        Get a dictionary representation of the provider.

        Returns:
            Dict[str, Optional[str]]: ``{'name': <class name>, 'url': <cls.url>}``.
            The ``'url'`` value is ``None`` when the provider has no URL set.
        """
        return {'name': cls.__name__, 'url': cls.url}
|
|
|
|
|
|
|
|
class BaseRetryProvider(BaseProvider):
    """
    Base class for a provider that implements retry logic.

    This class does not implement the abstract ``create_completion`` /
    ``create_async`` methods inherited from BaseProvider, so it cannot be
    instantiated directly; a concrete subclass must implement them.

    Attributes:
        providers (List[Type[BaseProvider]]): List of providers to use for retries.
        shuffle (bool): Whether to shuffle the providers list.
        exceptions (Dict[str, Exception]): Dictionary of exceptions encountered.
        last_provider (Optional[Type[BaseProvider]]): The last provider used,
            or ``None`` before any provider has been used.
    """

    # Only affects instance lookup: ``instance.__name__`` yields "RetryProvider",
    # while the class's own ``__name__`` (e.g. ``cls.__name__`` in get_dict)
    # remains the real class name because ``type.__name__`` is a data descriptor.
    __name__: str = "RetryProvider"
    supports_stream: bool = True

    def __init__(
        self,
        providers: List[Type[BaseProvider]],
        shuffle: bool = True
    ) -> None:
        """
        Initialize the BaseRetryProvider.

        Args:
            providers (List[Type[BaseProvider]]): List of providers to use.
                Stored by reference, not copied.
            shuffle (bool): Whether to shuffle the providers list. It is only
                stored here; no shuffling happens in this class.
        """
        self.providers = providers
        self.shuffle = shuffle
        # Instance-level override of the class default ``working = False``.
        self.working = True
        self.exceptions: Dict[str, Exception] = {}
        # None until a provider has been used.
        self.last_provider: Optional[Type[BaseProvider]] = None
|
2024-01-01 19:48:57 +03:00
|
|
|
|
2024-01-14 09:45:41 +03:00
|
|
|
# Anything accepted as a provider: either a BaseProvider subclass (the class
# object itself, hence Type[...]) or a BaseRetryProvider instance.
ProviderType = Union[Type[BaseProvider], BaseRetryProvider]
|