Add Groq and Openai interfaces, Add integration tests

This commit is contained in:
Heiner Lohaus 2024-04-06 01:05:00 +02:00
parent 1e2cf48cba
commit d44b39b31c
10 changed files with 167 additions and 25 deletions

View File

@ -5,5 +5,6 @@ from .main import *
from .model import *
from .client import *
from .include import *
from .integration import *
unittest.main()

View File

@ -0,0 +1,25 @@
import unittest
import json
from g4f.client import Client, ChatCompletion
from g4f.Provider import Bing, OpenaiChat
DEFAULT_MESSAGES = [{"role": "system", "content": 'Response in json, Example: {"success: True"}'},
{"role": "user", "content": "Say success true in json"}]
class TestProviderIntegration(unittest.TestCase):
def test_bing(self):
client = Client(provider=Bing)
response = client.chat.completions.create(DEFAULT_MESSAGES, "", response_format={"type": "json_object"})
self.assertIsInstance(response, ChatCompletion)
self.assertIn("success", json.loads(response.choices[0].message.content))
def test_openai(self):
client = Client(provider=OpenaiChat)
response = client.chat.completions.create(DEFAULT_MESSAGES, "", response_format={"type": "json_object"})
self.assertIsInstance(response, ChatCompletion)
self.assertIn("success", json.loads(response.choices[0].message.content))
if __name__ == '__main__':
unittest.main()

View File

@ -1,2 +1,3 @@
from ..providers.base_provider import *
from ..providers.types import FinishReason
from .helper import get_cookies, format_prompt

View File

@ -0,0 +1,23 @@
from __future__ import annotations
from .Openai import Openai
from ...typing import AsyncResult, Messages
class Groq(Openai):
url = "https://console.groq.com/playground"
working = True
default_model = "mixtral-8x7b-32768"
models = ["mixtral-8x7b-32768", "llama2-70b-4096", "gemma-7b-it"]
model_aliases = {"mixtral-8x7b": "mixtral-8x7b-32768", "llama2-70b": "llama2-70b-4096"}
@classmethod
def create_async_generator(
cls,
model: str,
messages: Messages,
api_base: str = "https://api.groq.com/openai/v1",
**kwargs
) -> AsyncResult:
return super().create_async_generator(
model, messages, api_base=api_base, **kwargs
)

View File

@ -0,0 +1,74 @@
from __future__ import annotations
import json
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, FinishReason
from ...typing import AsyncResult, Messages
from ...requests.raise_for_status import raise_for_status
from ...requests import StreamSession
from ...errors import MissingAuthError
class Openai(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://openai.com"
working = True
needs_auth = True
supports_message_history = True
supports_system_message = True
@classmethod
async def create_async_generator(
cls,
model: str,
messages: Messages,
proxy: str = None,
timeout: int = 120,
api_key: str = None,
api_base: str = "https://api.openai.com/v1",
temperature: float = None,
max_tokens: int = None,
top_p: float = None,
stop: str = None,
stream: bool = False,
**kwargs
) -> AsyncResult:
if api_key is None:
raise MissingAuthError('Add a "api_key"')
async with StreamSession(
proxies={"all": proxy},
headers=cls.get_headers(api_key),
timeout=timeout
) as session:
data = {
"messages": messages,
"model": cls.get_model(model),
"temperature": temperature,
"max_tokens": max_tokens,
"top_p": top_p,
"stop": stop,
"stream": stream,
}
async with session.post(f"{api_base.rstrip('/')}/chat/completions", json=data) as response:
await raise_for_status(response)
async for line in response.iter_lines():
if line.startswith(b"data: ") or not stream:
async for chunk in cls.read_line(line[6:] if stream else line, stream):
yield chunk
@staticmethod
async def read_line(line: str, stream: bool):
if line == b"[DONE]":
return
choice = json.loads(line)["choices"][0]
if stream and "content" in choice["delta"] and choice["delta"]["content"]:
yield choice["delta"]["content"]
elif not stream and "content" in choice["message"]:
yield choice["message"]["content"]
if "finish_reason" in choice and choice["finish_reason"] is not None:
yield FinishReason(choice["finish_reason"])
@staticmethod
def get_headers(api_key: str) -> dict:
return {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}

View File

@ -4,4 +4,6 @@ from .Theb import Theb
from .ThebApi import ThebApi
from .OpenaiChat import OpenaiChat
from .OpenAssistant import OpenAssistant
from .Poe import Poe
from .Poe import Poe
from .Openai import Openai
from .Groq import Groq

View File

@ -8,7 +8,7 @@ import string
from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
from .typing import Union, Iterator, Messages, ImageType
from .providers.types import BaseProvider, ProviderType
from .providers.types import BaseProvider, ProviderType, FinishReason
from .image import ImageResponse as ImageProviderResponse
from .errors import NoImageResponseError, RateLimitError, MissingAuthError
from . import get_model_and_provider, get_last_provider
@ -47,6 +47,9 @@ def iter_response(
finish_reason = None
completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
for idx, chunk in enumerate(response):
if isinstance(chunk, FinishReason):
finish_reason = chunk.reason
break
content += str(chunk)
if max_tokens is not None and idx + 1 >= max_tokens:
finish_reason = "length"

View File

@ -6,9 +6,10 @@ from asyncio import AbstractEventLoop
from concurrent.futures import ThreadPoolExecutor
from abc import abstractmethod
from inspect import signature, Parameter
from ..typing import CreateResult, AsyncResult, Messages, Union
from .types import BaseProvider
from ..errors import NestAsyncioError, ModelNotSupportedError
from typing import Callable, Union
from ..typing import CreateResult, AsyncResult, Messages
from .types import BaseProvider, FinishReason
from ..errors import NestAsyncioError, ModelNotSupportedError, MissingRequirementsError
from .. import debug
if sys.version_info < (3, 10):
@ -21,17 +22,23 @@ if sys.platform == 'win32':
if isinstance(asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
def get_running_loop() -> Union[AbstractEventLoop, None]:
def get_running_loop(check_nested: bool) -> Union[AbstractEventLoop, None]:
try:
loop = asyncio.get_running_loop()
if not hasattr(loop.__class__, "_nest_patched"):
raise NestAsyncioError(
'Use "create_async" instead of "create" function in a running event loop. Or use "nest_asyncio" package.'
)
if check_nested and not hasattr(loop.__class__, "_nest_patched"):
try:
import nest_asyncio
nest_asyncio.apply(loop)
except ImportError:
raise MissingRequirementsError('Install "nest_asyncio" package')
return loop
except RuntimeError:
pass
# Fix for RuntimeError: async generator ignored GeneratorExit
async def await_callback(callback: Callable):
return await callback()
class AbstractProvider(BaseProvider):
"""
Abstract class for providing asynchronous functionality to derived classes.
@ -132,7 +139,7 @@ class AsyncProvider(AbstractProvider):
Returns:
CreateResult: The result of the completion creation.
"""
get_running_loop()
get_running_loop(check_nested=True)
yield asyncio.run(cls.create_async(model, messages, **kwargs))
@staticmethod
@ -158,7 +165,6 @@ class AsyncProvider(AbstractProvider):
"""
raise NotImplementedError()
class AsyncGeneratorProvider(AsyncProvider):
"""
Provides asynchronous generator functionality for streaming results.
@ -187,9 +193,9 @@ class AsyncGeneratorProvider(AsyncProvider):
Returns:
CreateResult: The result of the streaming completion creation.
"""
loop = get_running_loop()
loop = get_running_loop(check_nested=True)
new_loop = False
if not loop:
if loop is None:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
new_loop = True
@ -197,16 +203,11 @@ class AsyncGeneratorProvider(AsyncProvider):
generator = cls.create_async_generator(model, messages, stream=stream, **kwargs)
gen = generator.__aiter__()
# Fix for RuntimeError: async generator ignored GeneratorExit
async def await_callback(callback):
return await callback()
try:
while True:
yield loop.run_until_complete(await_callback(gen.__anext__))
except StopAsyncIteration:
...
# Fix for: ResourceWarning: unclosed event loop
finally:
if new_loop:
loop.close()
@ -233,7 +234,7 @@ class AsyncGeneratorProvider(AsyncProvider):
"""
return "".join([
chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)
if not isinstance(chunk, Exception)
if not isinstance(chunk, (Exception, FinishReason))
])
@staticmethod

View File

@ -97,4 +97,8 @@ class BaseRetryProvider(BaseProvider):
__name__: str = "RetryProvider"
supports_stream: bool = True
ProviderType = Union[Type[BaseProvider], BaseRetryProvider]
ProviderType = Union[Type[BaseProvider], BaseRetryProvider]
class FinishReason():
def __init__(self, reason: str):
self.reason = reason

View File

@ -15,11 +15,19 @@ class StreamResponse(ClientResponse):
async for chunk in self.content.iter_any():
yield chunk
async def json(self) -> Any:
return await super().json(content_type=None)
async def json(self, content_type: str = None) -> Any:
return await super().json(content_type=content_type)
class StreamSession(ClientSession):
def __init__(self, headers: dict = {}, timeout: int = None, proxies: dict = {}, impersonate = None, **kwargs):
def __init__(
self,
headers: dict = {},
timeout: int = None,
connector: BaseConnector = None,
proxies: dict = {},
impersonate = None,
**kwargs
):
if impersonate:
headers = {
**DEFAULT_HEADERS,
@ -29,7 +37,7 @@ class StreamSession(ClientSession):
**kwargs,
timeout=ClientTimeout(timeout) if timeout else None,
response_class=StreamResponse,
connector=get_connector(kwargs.get("connector"), proxies.get("https")),
connector=get_connector(connector, proxies.get("all", proxies.get("https"))),
headers=headers
)