Add Groq and Openai interfaces, Add integration tests

This commit is contained in:
Heiner Lohaus 2024-04-06 01:05:00 +02:00
parent 1e2cf48cba
commit d44b39b31c
10 changed files with 167 additions and 25 deletions

View File

@ -5,5 +5,6 @@ from .main import *
from .model import * from .model import *
from .client import * from .client import *
from .include import * from .include import *
from .integration import *
unittest.main() unittest.main()

View File

@ -0,0 +1,25 @@
import unittest
import json
from g4f.client import Client, ChatCompletion
from g4f.Provider import Bing, OpenaiChat
DEFAULT_MESSAGES = [{"role": "system", "content": 'Response in json, Example: {"success: True"}'},
{"role": "user", "content": "Say success true in json"}]
class TestProviderIntegration(unittest.TestCase):
def test_bing(self):
client = Client(provider=Bing)
response = client.chat.completions.create(DEFAULT_MESSAGES, "", response_format={"type": "json_object"})
self.assertIsInstance(response, ChatCompletion)
self.assertIn("success", json.loads(response.choices[0].message.content))
def test_openai(self):
client = Client(provider=OpenaiChat)
response = client.chat.completions.create(DEFAULT_MESSAGES, "", response_format={"type": "json_object"})
self.assertIsInstance(response, ChatCompletion)
self.assertIn("success", json.loads(response.choices[0].message.content))
if __name__ == '__main__':
unittest.main()

View File

@ -1,2 +1,3 @@
from ..providers.base_provider import * from ..providers.base_provider import *
from ..providers.types import FinishReason
from .helper import get_cookies, format_prompt from .helper import get_cookies, format_prompt

View File

@ -0,0 +1,23 @@
from __future__ import annotations
from .Openai import Openai
from ...typing import AsyncResult, Messages
class Groq(Openai):
url = "https://console.groq.com/playground"
working = True
default_model = "mixtral-8x7b-32768"
models = ["mixtral-8x7b-32768", "llama2-70b-4096", "gemma-7b-it"]
model_aliases = {"mixtral-8x7b": "mixtral-8x7b-32768", "llama2-70b": "llama2-70b-4096"}
@classmethod
def create_async_generator(
cls,
model: str,
messages: Messages,
api_base: str = "https://api.groq.com/openai/v1",
**kwargs
) -> AsyncResult:
return super().create_async_generator(
model, messages, api_base=api_base, **kwargs
)

View File

@ -0,0 +1,74 @@
from __future__ import annotations
import json
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, FinishReason
from ...typing import AsyncResult, Messages
from ...requests.raise_for_status import raise_for_status
from ...requests import StreamSession
from ...errors import MissingAuthError
class Openai(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://openai.com"
working = True
needs_auth = True
supports_message_history = True
supports_system_message = True
@classmethod
async def create_async_generator(
cls,
model: str,
messages: Messages,
proxy: str = None,
timeout: int = 120,
api_key: str = None,
api_base: str = "https://api.openai.com/v1",
temperature: float = None,
max_tokens: int = None,
top_p: float = None,
stop: str = None,
stream: bool = False,
**kwargs
) -> AsyncResult:
if api_key is None:
raise MissingAuthError('Add a "api_key"')
async with StreamSession(
proxies={"all": proxy},
headers=cls.get_headers(api_key),
timeout=timeout
) as session:
data = {
"messages": messages,
"model": cls.get_model(model),
"temperature": temperature,
"max_tokens": max_tokens,
"top_p": top_p,
"stop": stop,
"stream": stream,
}
async with session.post(f"{api_base.rstrip('/')}/chat/completions", json=data) as response:
await raise_for_status(response)
async for line in response.iter_lines():
if line.startswith(b"data: ") or not stream:
async for chunk in cls.read_line(line[6:] if stream else line, stream):
yield chunk
@staticmethod
async def read_line(line: str, stream: bool):
if line == b"[DONE]":
return
choice = json.loads(line)["choices"][0]
if stream and "content" in choice["delta"] and choice["delta"]["content"]:
yield choice["delta"]["content"]
elif not stream and "content" in choice["message"]:
yield choice["message"]["content"]
if "finish_reason" in choice and choice["finish_reason"] is not None:
yield FinishReason(choice["finish_reason"])
@staticmethod
def get_headers(api_key: str) -> dict:
return {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}

View File

@ -4,4 +4,6 @@ from .Theb import Theb
from .ThebApi import ThebApi from .ThebApi import ThebApi
from .OpenaiChat import OpenaiChat from .OpenaiChat import OpenaiChat
from .OpenAssistant import OpenAssistant from .OpenAssistant import OpenAssistant
from .Poe import Poe from .Poe import Poe
from .Openai import Openai
from .Groq import Groq

View File

@ -8,7 +8,7 @@ import string
from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
from .typing import Union, Iterator, Messages, ImageType from .typing import Union, Iterator, Messages, ImageType
from .providers.types import BaseProvider, ProviderType from .providers.types import BaseProvider, ProviderType, FinishReason
from .image import ImageResponse as ImageProviderResponse from .image import ImageResponse as ImageProviderResponse
from .errors import NoImageResponseError, RateLimitError, MissingAuthError from .errors import NoImageResponseError, RateLimitError, MissingAuthError
from . import get_model_and_provider, get_last_provider from . import get_model_and_provider, get_last_provider
@ -47,6 +47,9 @@ def iter_response(
finish_reason = None finish_reason = None
completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28)) completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
for idx, chunk in enumerate(response): for idx, chunk in enumerate(response):
if isinstance(chunk, FinishReason):
finish_reason = chunk.reason
break
content += str(chunk) content += str(chunk)
if max_tokens is not None and idx + 1 >= max_tokens: if max_tokens is not None and idx + 1 >= max_tokens:
finish_reason = "length" finish_reason = "length"

View File

@ -6,9 +6,10 @@ from asyncio import AbstractEventLoop
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from abc import abstractmethod from abc import abstractmethod
from inspect import signature, Parameter from inspect import signature, Parameter
from ..typing import CreateResult, AsyncResult, Messages, Union from typing import Callable, Union
from .types import BaseProvider from ..typing import CreateResult, AsyncResult, Messages
from ..errors import NestAsyncioError, ModelNotSupportedError from .types import BaseProvider, FinishReason
from ..errors import NestAsyncioError, ModelNotSupportedError, MissingRequirementsError
from .. import debug from .. import debug
if sys.version_info < (3, 10): if sys.version_info < (3, 10):
@ -21,17 +22,23 @@ if sys.platform == 'win32':
if isinstance(asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy): if isinstance(asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
def get_running_loop() -> Union[AbstractEventLoop, None]: def get_running_loop(check_nested: bool) -> Union[AbstractEventLoop, None]:
try: try:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
if not hasattr(loop.__class__, "_nest_patched"): if check_nested and not hasattr(loop.__class__, "_nest_patched"):
raise NestAsyncioError( try:
'Use "create_async" instead of "create" function in a running event loop. Or use "nest_asyncio" package.' import nest_asyncio
) nest_asyncio.apply(loop)
except ImportError:
raise MissingRequirementsError('Install "nest_asyncio" package')
return loop return loop
except RuntimeError: except RuntimeError:
pass pass
# Fix for RuntimeError: async generator ignored GeneratorExit
async def await_callback(callback: Callable):
return await callback()
class AbstractProvider(BaseProvider): class AbstractProvider(BaseProvider):
""" """
Abstract class for providing asynchronous functionality to derived classes. Abstract class for providing asynchronous functionality to derived classes.
@ -132,7 +139,7 @@ class AsyncProvider(AbstractProvider):
Returns: Returns:
CreateResult: The result of the completion creation. CreateResult: The result of the completion creation.
""" """
get_running_loop() get_running_loop(check_nested=True)
yield asyncio.run(cls.create_async(model, messages, **kwargs)) yield asyncio.run(cls.create_async(model, messages, **kwargs))
@staticmethod @staticmethod
@ -158,7 +165,6 @@ class AsyncProvider(AbstractProvider):
""" """
raise NotImplementedError() raise NotImplementedError()
class AsyncGeneratorProvider(AsyncProvider): class AsyncGeneratorProvider(AsyncProvider):
""" """
Provides asynchronous generator functionality for streaming results. Provides asynchronous generator functionality for streaming results.
@ -187,9 +193,9 @@ class AsyncGeneratorProvider(AsyncProvider):
Returns: Returns:
CreateResult: The result of the streaming completion creation. CreateResult: The result of the streaming completion creation.
""" """
loop = get_running_loop() loop = get_running_loop(check_nested=True)
new_loop = False new_loop = False
if not loop: if loop is None:
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
new_loop = True new_loop = True
@ -197,16 +203,11 @@ class AsyncGeneratorProvider(AsyncProvider):
generator = cls.create_async_generator(model, messages, stream=stream, **kwargs) generator = cls.create_async_generator(model, messages, stream=stream, **kwargs)
gen = generator.__aiter__() gen = generator.__aiter__()
# Fix for RuntimeError: async generator ignored GeneratorExit
async def await_callback(callback):
return await callback()
try: try:
while True: while True:
yield loop.run_until_complete(await_callback(gen.__anext__)) yield loop.run_until_complete(await_callback(gen.__anext__))
except StopAsyncIteration: except StopAsyncIteration:
... ...
# Fix for: ResourceWarning: unclosed event loop
finally: finally:
if new_loop: if new_loop:
loop.close() loop.close()
@ -233,7 +234,7 @@ class AsyncGeneratorProvider(AsyncProvider):
""" """
return "".join([ return "".join([
chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs) chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)
if not isinstance(chunk, Exception) if not isinstance(chunk, (Exception, FinishReason))
]) ])
@staticmethod @staticmethod

View File

@ -97,4 +97,8 @@ class BaseRetryProvider(BaseProvider):
__name__: str = "RetryProvider" __name__: str = "RetryProvider"
supports_stream: bool = True supports_stream: bool = True
ProviderType = Union[Type[BaseProvider], BaseRetryProvider] ProviderType = Union[Type[BaseProvider], BaseRetryProvider]
class FinishReason():
def __init__(self, reason: str):
self.reason = reason

View File

@ -15,11 +15,19 @@ class StreamResponse(ClientResponse):
async for chunk in self.content.iter_any(): async for chunk in self.content.iter_any():
yield chunk yield chunk
async def json(self) -> Any: async def json(self, content_type: str = None) -> Any:
return await super().json(content_type=None) return await super().json(content_type=content_type)
class StreamSession(ClientSession): class StreamSession(ClientSession):
def __init__(self, headers: dict = {}, timeout: int = None, proxies: dict = {}, impersonate = None, **kwargs): def __init__(
self,
headers: dict = {},
timeout: int = None,
connector: BaseConnector = None,
proxies: dict = {},
impersonate = None,
**kwargs
):
if impersonate: if impersonate:
headers = { headers = {
**DEFAULT_HEADERS, **DEFAULT_HEADERS,
@ -29,7 +37,7 @@ class StreamSession(ClientSession):
**kwargs, **kwargs,
timeout=ClientTimeout(timeout) if timeout else None, timeout=ClientTimeout(timeout) if timeout else None,
response_class=StreamResponse, response_class=StreamResponse,
connector=get_connector(kwargs.get("connector"), proxies.get("https")), connector=get_connector(connector, proxies.get("all", proxies.get("https"))),
headers=headers headers=headers
) )