gpt4free/g4f/Provider/base_provider.py

249 lines
8.2 KiB
Python
Raw Normal View History

from __future__ import annotations
2023-11-20 22:00:56 +03:00
import sys
import asyncio
from asyncio import AbstractEventLoop
from concurrent.futures import ThreadPoolExecutor
from abc import abstractmethod
from inspect import signature, Parameter
from .helper import get_event_loop, get_cookies, format_prompt
from ..typing import CreateResult, AsyncResult, Messages
from ..base_provider import BaseProvider
2024-01-20 20:23:54 +03:00
from ..errors import NestAsyncioError
2023-07-28 13:07:17 +03:00
2023-11-20 22:00:56 +03:00
# ``types.NoneType`` is only importable from Python 3.10 onwards; older
# interpreters get an equivalent alias derived from ``type(None)``.
if sys.version_info >= (3, 10):
    from types import NoneType
else:
    NoneType = type(None)
2023-07-28 13:07:17 +03:00
# Set Windows event loop policy for better compatibility with asyncio and curl_cffi.
# The platform test comes first so the Windows-only asyncio attributes are
# never looked up on other operating systems.
if sys.platform == 'win32' and isinstance(
    asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy
):
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
class AbstractProvider(BaseProvider):
    """
    Abstract class for providing asynchronous functionality to derived classes.
    """

    @classmethod
    async def create_async(
        cls,
        model: str,
        messages: Messages,
        *,
        loop: AbstractEventLoop = None,
        executor: ThreadPoolExecutor = None,
        **kwargs
    ) -> str:
        """
        Asynchronously creates a result based on the given model and messages.

        The synchronous ``create_completion`` generator is drained inside an
        executor and its chunks are joined into a single string.

        Args:
            cls (type): The class on which this method is called.
            model (str): The model to use for creation.
            messages (Messages): The messages to process.
            loop (AbstractEventLoop, optional): The event loop to use.
                Defaults to the currently running loop.
            executor (ThreadPoolExecutor, optional): The executor that runs the
                blocking call. ``None`` selects the loop's default executor.
            **kwargs: Additional keyword arguments forwarded to
                ``create_completion``. A ``timeout`` entry, if present, is also
                used as the time limit (in seconds) for the whole call.

        Returns:
            str: The created result as a string.

        Raises:
            asyncio.TimeoutError: If ``timeout`` was given and elapsed first.
        """
        loop = loop or asyncio.get_running_loop()

        def create_func() -> str:
            return "".join(cls.create_completion(model, messages, False, **kwargs))

        # A missing timeout must be ``None`` ("wait without limit"). A default
        # of 0 makes ``wait_for`` cancel any future that is not already done,
        # so every call without an explicit timeout would fail immediately.
        return await asyncio.wait_for(
            loop.run_in_executor(executor, create_func),
            timeout=kwargs.get("timeout")
        )

    # NOTE(review): chaining ``@classmethod`` with ``@property`` was deprecated
    # in Python 3.11 and removed in 3.13. It is kept unchanged here so the
    # attribute-style access ``Provider.params`` keeps working on the
    # interpreters this file supports; replacing it needs a dedicated
    # descriptor or metaclass.
    @classmethod
    @property
    def params(cls) -> str:
        """
        Returns the parameters supported by the provider.

        The signature inspected is ``create_async_generator`` for
        ``AsyncGeneratorProvider`` subclasses, ``create_async`` for
        ``AsyncProvider`` subclasses and ``create_completion`` otherwise.
        ``self`` and ``kwargs`` are always skipped; ``stream`` is skipped when
        the provider does not support streaming.

        Args:
            cls (type): The class on which this property is called.

        Returns:
            str: A string listing the supported parameters.
        """
        sig = signature(
            cls.create_async_generator if issubclass(cls, AsyncGeneratorProvider) else
            cls.create_async if issubclass(cls, AsyncProvider) else
            cls.create_completion
        )

        def get_type_name(annotation: type) -> str:
            # String annotations have no ``__name__`` and are rendered as-is.
            return annotation.__name__ if hasattr(annotation, "__name__") else str(annotation)

        entries = []
        for name, param in sig.parameters.items():
            if name in ("self", "kwargs") or (name == "stream" and not cls.supports_stream):
                continue
            entry = f"\n {name}"
            if param.annotation is not Parameter.empty:
                entry += f": {get_type_name(param.annotation)}"
            if param.default == "":
                # Quote empty-string defaults so they stay visible in the output.
                entry += f' = "{param.default}"'
            elif param.default is not Parameter.empty:
                entry += f" = {param.default}"
            entries.append(entry)
        args = "".join(entries)

        return f"g4f.Provider.{cls.__name__} supports: ({args}\n)"
class AsyncProvider(AbstractProvider):
    """
    Provides asynchronous functionality for creating completions.
    """

    @classmethod
    def create_completion(
        cls,
        model: str,
        messages: Messages,
        stream: bool = False,
        **kwargs
    ) -> CreateResult:
        """
        Creates a completion result synchronously.

        This is a generator function: nothing runs (and no error is raised)
        until the returned generator is iterated. It yields exactly one item,
        the full string produced by ``create_async``.

        Args:
            cls (type): The class on which this method is called.
            model (str): The model to use for creation.
            messages (Messages): The messages to process.
            stream (bool): Accepted for interface compatibility; not used by
                this implementation. Defaults to False.
            **kwargs: Additional keyword arguments forwarded to ``create_async``.

        Returns:
            CreateResult: The result of the completion creation.

        Raises:
            NestAsyncioError: If called while an event loop is already running
                and that loop's class has no ``_nest_patched`` attribute.
        """
        # Only the loop lookup lives in the ``try`` so that the guard error
        # below can never be caught by the ``except RuntimeError`` clause.
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread; asyncio.run() can be used.
            pass
        else:
            if not hasattr(loop.__class__, "_nest_patched"):
                raise NestAsyncioError(
                    'Use "create_async" instead of "create" function in a running event loop. Or use "nest_asyncio" package.'
                )

        yield asyncio.run(cls.create_async(model, messages, **kwargs))

    @staticmethod
    @abstractmethod
    async def create_async(
        model: str,
        messages: Messages,
        **kwargs
    ) -> str:
        """
        Abstract method for creating asynchronous results.

        Args:
            model (str): The model to use for creation.
            messages (Messages): The messages to process.
            **kwargs: Additional keyword arguments.

        Raises:
            NotImplementedError: If this method is not overridden in derived classes.

        Returns:
            str: The created result as a string.
        """
        raise NotImplementedError()
class AsyncGeneratorProvider(AsyncProvider):
    """
    Provides asynchronous generator functionality for streaming results.
    """
    supports_stream = True

    @classmethod
    def create_completion(
        cls,
        model: str,
        messages: Messages,
        stream: bool = True,
        **kwargs
    ) -> CreateResult:
        """
        Creates a streaming completion result synchronously.

        This is a generator function: nothing runs (and no error is raised)
        until the returned generator is iterated. Each item produced by
        ``create_async_generator`` is fetched by driving the event loop and is
        then yielded to the caller.

        Args:
            cls (type): The class on which this method is called.
            model (str): The model to use for creation.
            messages (Messages): The messages to process.
            stream (bool): Indicates whether to stream the results. Defaults to True.
            **kwargs: Additional keyword arguments forwarded to
                ``create_async_generator``.

        Returns:
            CreateResult: The result of the streaming completion creation.

        Raises:
            NestAsyncioError: If called while an event loop is already running
                and that loop's class has no ``_nest_patched`` attribute.
        """
        # Set only when this call creates the loop, so cleanup never touches a
        # loop that belongs to the caller.
        owned_loop = None
        # Only the loop lookup lives in the ``try`` so that the guard error
        # below can never be caught by the ``except RuntimeError`` clause.
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            owned_loop = loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        else:
            if not hasattr(loop.__class__, "_nest_patched"):
                raise NestAsyncioError(
                    'Use "create_async" instead of "create" function in a running event loop. Or use "nest_asyncio" package.'
                )

        try:
            generator = cls.create_async_generator(model, messages, stream=stream, **kwargs)
            gen = generator.__aiter__()
            while True:
                try:
                    yield loop.run_until_complete(gen.__anext__())
                except StopAsyncIteration:
                    break
        finally:
            # Runs on normal exhaustion, on errors, and when the consumer stops
            # iterating early; without it the private loop is never closed.
            if owned_loop is not None:
                try:
                    # Finalize async generators that were left unfinished.
                    owned_loop.run_until_complete(owned_loop.shutdown_asyncgens())
                finally:
                    # Do not leave a closed loop installed as the current one.
                    asyncio.set_event_loop(None)
                    owned_loop.close()

    @classmethod
    async def create_async(
        cls,
        model: str,
        messages: Messages,
        **kwargs
    ) -> str:
        """
        Asynchronously creates a result from a generator.

        ``create_async_generator`` is consumed with ``stream=False`` and its
        chunks are concatenated; chunks that are ``Exception`` instances are
        skipped rather than raised.

        Args:
            cls (type): The class on which this method is called.
            model (str): The model to use for creation.
            messages (Messages): The messages to process.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The created result as a string.
        """
        return "".join([
            chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)
            if not isinstance(chunk, Exception)
        ])

    @staticmethod
    @abstractmethod
    async def create_async_generator(
        model: str,
        messages: Messages,
        stream: bool = True,
        **kwargs
    ) -> AsyncResult:
        """
        Abstract method for creating an asynchronous generator.

        Args:
            model (str): The model to use for creation.
            messages (Messages): The messages to process.
            stream (bool): Indicates whether to stream the results. Defaults to True.
            **kwargs: Additional keyword arguments.

        Raises:
            NotImplementedError: If this method is not overridden in derived classes.

        Returns:
            AsyncResult: An asynchronous generator yielding results.
        """
        raise NotImplementedError()