gpt4free/g4f/Provider/base_provider.py

180 lines
4.6 KiB
Python
Raw Normal View History

from __future__ import annotations
2023-07-28 13:07:17 +03:00
import asyncio
import functools
from asyncio import SelectorEventLoop, AbstractEventLoop
from concurrent.futures import ThreadPoolExecutor
from abc import ABC, abstractmethod
2023-07-28 13:07:17 +03:00
import browser_cookie3
from ..typing import AsyncGenerator, CreateResult
2023-07-28 13:07:17 +03:00
class _ClassProperty:
    """Read-only descriptor that computes an attribute from the owning class.

    Replaces stacking ``@classmethod`` on ``@property``, which was deprecated in
    Python 3.11 and no longer produces a class-level property from Python 3.13.
    """

    def __init__(self, func):
        self.func = func
        self.__doc__ = func.__doc__

    def __get__(self, instance, owner=None):
        if owner is None:
            owner = type(instance)
        return self.func(owner)


class BaseProvider(ABC):
    """Abstract base class for completion providers.

    Subclasses implement ``create_completion`` and override the capability
    flags below as appropriate.
    """

    url: str

    # Capability flags; all default to False and are overridden per provider.
    working = False
    needs_auth = False
    supports_stream = False
    supports_gpt_35_turbo = False
    supports_gpt_4 = False

    @staticmethod
    @abstractmethod
    def create_completion(
        model: str,
        messages: list[dict[str, str]],
        stream: bool,
        **kwargs
    ) -> CreateResult:
        """Produce the completion for ``messages`` as an iterable of text chunks.

        Must be overridden by subclasses.
        """
        raise NotImplementedError()

    @classmethod
    async def create_async(
        cls,
        model: str,
        messages: list[dict[str, str]],
        *,
        loop: AbstractEventLoop = None,
        executor: ThreadPoolExecutor = None,
        **kwargs
    ) -> str:
        """Run the blocking ``create_completion`` in an executor and return the joined text.

        ``create_completion`` is called with ``stream=False``. Its chunks are
        joined inside the executor as well. If it is a generator, its work only
        happens while it is iterated, so joining on the event-loop thread would
        block the loop.

        Args:
            model: model name forwarded to ``create_completion``.
            messages: chat messages forwarded to ``create_completion``.
            loop: loop used to schedule the work; defaults to the running loop.
            executor: executor passed to ``run_in_executor``; ``None`` selects
                the loop's default executor.
            **kwargs: extra options forwarded to ``create_completion``.

        Returns:
            The concatenation of all chunks produced by ``create_completion``.
        """
        if loop is None:
            loop = asyncio.get_running_loop()

        def _create_and_join() -> str:
            # Create AND consume the result off the event-loop thread.
            return "".join(cls.create_completion(model, messages, False, **kwargs))

        return await loop.run_in_executor(executor, _create_and_join)

    @_ClassProperty
    def params(cls) -> str:
        """Describe the parameters supported by this provider, e.g.
        ``"g4f.provider.Name supports: (model: str, ...)"``."""
        params = [
            ("model", "str"),
            ("messages", "list[dict[str, str]]"),
            ("stream", "bool"),
        ]
        param = ", ".join([": ".join(p) for p in params])
        return f"g4f.provider.{cls.__name__} supports: ({param})"
class AsyncProvider(BaseProvider):
    """Provider whose work is done by a single coroutine, ``create_async``.

    The synchronous ``create_completion`` drives that coroutine on a private
    event loop and yields its whole result as one chunk.
    """

    @classmethod
    def create_completion(
        cls,
        model: str,
        messages: list[dict[str, str]],
        stream: bool = False,
        **kwargs
    ) -> CreateResult:
        # ``stream`` is accepted for interface compatibility only; it is not
        # forwarded because the coroutine returns the complete text at once.
        event_loop = create_event_loop()
        try:
            pending = cls.create_async(model, messages, **kwargs)
            full_text = event_loop.run_until_complete(pending)
            yield full_text
        finally:
            event_loop.close()

    @staticmethod
    @abstractmethod
    async def create_async(
        model: str,
        messages: list[dict[str, str]],
        **kwargs
    ) -> str:
        """Return the complete answer for ``messages``; subclasses must override."""
        raise NotImplementedError()
class AsyncGeneratorProvider(AsyncProvider):
    """Provider implemented as an async generator of text chunks.

    Subclasses implement ``create_async_generator``. The sync and async entry
    points below are built on top of it.
    """

    supports_stream = True

    @classmethod
    def create_completion(
        cls,
        model: str,
        messages: list[dict[str, str]],
        stream: bool = True,
        **kwargs
    ) -> CreateResult:
        """Synchronously yield the chunks produced by ``create_async_generator``.

        Each chunk comes from advancing the async generator one step on a
        private event loop. That loop is always cleaned up, including when the
        caller stops iterating early.
        """
        loop = create_event_loop()
        try:
            generator = cls.create_async_generator(
                model,
                messages,
                stream=stream,
                **kwargs
            )
            gen = generator.__aiter__()
            while True:
                try:
                    chunk = loop.run_until_complete(gen.__anext__())
                except StopAsyncIteration:
                    break
                # Yield outside the try: only exhaustion of the async generator
                # should end the loop, not exceptions thrown in by the consumer.
                yield chunk
        finally:
            try:
                # Finalize async generators left suspended (e.g. the caller
                # stopped early) so their cleanup runs before the loop closes.
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                loop.close()

    @classmethod
    async def create_async(
        cls,
        model: str,
        messages: list[dict[str, str]],
        **kwargs
    ) -> str:
        """Collect every chunk of ``create_async_generator`` (with ``stream=False``) into one string."""
        return "".join([
            chunk async for chunk in cls.create_async_generator(
                model,
                messages,
                stream=False,
                **kwargs
            )
        ])

    @staticmethod
    @abstractmethod
    def create_async_generator(
        model: str,
        messages: list[dict[str, str]],
        **kwargs
    ) -> AsyncGenerator:
        """Return an async iterator of text chunks; subclasses must override."""
        raise NotImplementedError()
def create_event_loop() -> SelectorEventLoop:
    """Return a new SelectorEventLoop for driving async provider code synchronously.

    A selector loop is used on every platform, including Windows, where the
    default policy would otherwise choose a different loop type.

    Raises:
        RuntimeError: if an event loop is already running in this thread.
            Callers there should use the async API instead of starting a
            nested loop.
    """
    try:
        already_running = asyncio.get_running_loop()
    except RuntimeError:
        already_running = None
    if already_running is not None:
        raise RuntimeError(
            'Use "create_async" instead of "create" function in a running event loop.')
    return SelectorEventLoop()
# Per-domain cache of browser cookies: domain -> {cookie name: cookie value}.
_cookies: dict[str, dict[str, str]] = {}


def get_cookies(cookie_domain: str) -> dict:
    """Return the browser cookies for ``cookie_domain`` as a name -> value dict.

    The first call for a domain loads cookies via ``browser_cookie3.load``.
    The result is cached and returned on later calls for that domain.

    Loading is best effort: an ordinary error (missing browser, unreadable
    cookie store, ...) caches whatever was read before the failure, possibly
    an empty dict. ``KeyboardInterrupt``/``SystemExit`` propagate and leave
    nothing cached, so a later call can retry.
    """
    if cookie_domain not in _cookies:
        cookies: dict[str, str] = {}
        try:
            for cookie in browser_cookie3.load(cookie_domain):
                cookies[cookie.name] = cookie.value
        except Exception:
            # Best effort: cookies are optional, keep what was read so far.
            pass
        _cookies[cookie_domain] = cookies
    return _cookies[cookie_domain]
def format_prompt(messages: list[dict[str, str]], add_special_tokens=False) -> str:
    """Flatten chat ``messages`` into a single prompt string.

    With several messages, or when ``add_special_tokens`` is true, each message
    becomes a ``"Role: content"`` line (role capitalized). A trailing
    ``"Assistant:"`` line cues the reply. Otherwise the lone message's content
    is returned unchanged, and an empty list yields ``""``.
    """
    if not add_special_tokens and len(messages) <= 1:
        # Nothing to label: pass a single message through verbatim.
        return messages[0]["content"] if messages else ""
    formatted = "\n".join(
        f"{message['role'].capitalize()}: {message['content']}" for message in messages
    )
    return f"{formatted}\nAssistant:"