mirror of https://github.com/xtekky/gpt4free.git
synced 2024-11-26 09:57:24 +03:00

Add Groq and Openai interfaces, Add integration tests

This commit is contained in:
parent 1e2cf48cba
commit d44b39b31c
@@ -5,5 +5,6 @@ from .main import *
 from .model import *
 from .client import *
 from .include import *
+from .integration import *

 unittest.main()
etc/unittest/integration.py (new file, 25 lines)
@@ -0,0 +1,25 @@
import unittest
import json

from g4f.client import Client, ChatCompletion
from g4f.Provider import Bing, OpenaiChat

DEFAULT_MESSAGES = [{"role": "system", "content": 'Response in json, Example: {"success: True"}'},
                    {"role": "user", "content": "Say success true in json"}]

class TestProviderIntegration(unittest.TestCase):

    def test_bing(self):
        client = Client(provider=Bing)
        response = client.chat.completions.create(DEFAULT_MESSAGES, "", response_format={"type": "json_object"})
        self.assertIsInstance(response, ChatCompletion)
        self.assertIn("success", json.loads(response.choices[0].message.content))

    def test_openai(self):
        client = Client(provider=OpenaiChat)
        response = client.chat.completions.create(DEFAULT_MESSAGES, "", response_format={"type": "json_object"})
        self.assertIsInstance(response, ChatCompletion)
        self.assertIn("success", json.loads(response.choices[0].message.content))

if __name__ == '__main__':
    unittest.main()
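For orientation, these tests drive the new OpenAI-style client added in this commit: the messages list and the model name are the first two positional arguments of chat.completions.create. A minimal standalone sketch of the same call shape (hypothetical usage; it assumes a reachable provider and network access, and the empty model string defers to the provider's default):

    from g4f.client import Client
    from g4f.Provider import Bing

    client = Client(provider=Bing)
    # Same call shape as the tests above; "" lets the provider pick its default model.
    response = client.chat.completions.create(
        [{"role": "user", "content": "Say success true in json"}],
        "",
        response_format={"type": "json_object"},
    )
    print(response.choices[0].message.content)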
@@ -1,2 +1,3 @@
 from ..providers.base_provider import *
+from ..providers.types import FinishReason
 from .helper import get_cookies, format_prompt
g4f/Provider/needs_auth/Groq.py (new file, 23 lines)
@@ -0,0 +1,23 @@
from __future__ import annotations

from .Openai import Openai
from ...typing import AsyncResult, Messages

class Groq(Openai):
    url = "https://console.groq.com/playground"
    working = True
    default_model = "mixtral-8x7b-32768"
    models = ["mixtral-8x7b-32768", "llama2-70b-4096", "gemma-7b-it"]
    model_aliases = {"mixtral-8x7b": "mixtral-8x7b-32768", "llama2-70b": "llama2-70b-4096"}

    @classmethod
    def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        api_base: str = "https://api.groq.com/openai/v1",
        **kwargs
    ) -> AsyncResult:
        return super().create_async_generator(
            model, messages, api_base=api_base, **kwargs
        )
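Since Groq only overrides api_base, the model list, and the aliases, a call through it runs the Openai flow against Groq's OpenAI-compatible endpoint. A minimal sketch of direct use (hypothetical; "gsk_..." is a placeholder for a real Groq API key, which the inherited create_async_generator requires):

    import asyncio
    from g4f.Provider.needs_auth import Groq

    async def main():
        async for chunk in Groq.create_async_generator(
            model="mixtral-8x7b",  # alias, resolved to "mixtral-8x7b-32768"
            messages=[{"role": "user", "content": "Hello"}],
            api_key="gsk_...",  # placeholder; MissingAuthError is raised without it
            stream=True,
        ):
            if isinstance(chunk, str):  # skip the trailing FinishReason marker
                print(chunk, end="")

    asyncio.run(main())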
g4f/Provider/needs_auth/Openai.py (new file, 74 lines)
@@ -0,0 +1,74 @@
from __future__ import annotations

import json

from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, FinishReason
from ...typing import AsyncResult, Messages
from ...requests.raise_for_status import raise_for_status
from ...requests import StreamSession
from ...errors import MissingAuthError

class Openai(AsyncGeneratorProvider, ProviderModelMixin):
    url = "https://openai.com"
    working = True
    needs_auth = True
    supports_message_history = True
    supports_system_message = True

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        proxy: str = None,
        timeout: int = 120,
        api_key: str = None,
        api_base: str = "https://api.openai.com/v1",
        temperature: float = None,
        max_tokens: int = None,
        top_p: float = None,
        stop: str = None,
        stream: bool = False,
        **kwargs
    ) -> AsyncResult:
        if api_key is None:
            raise MissingAuthError('Add a "api_key"')
        async with StreamSession(
            proxies={"all": proxy},
            headers=cls.get_headers(api_key),
            timeout=timeout
        ) as session:
            data = {
                "messages": messages,
                "model": cls.get_model(model),
                "temperature": temperature,
                "max_tokens": max_tokens,
                "top_p": top_p,
                "stop": stop,
                "stream": stream,
            }
            async with session.post(f"{api_base.rstrip('/')}/chat/completions", json=data) as response:
                await raise_for_status(response)
                async for line in response.iter_lines():
                    if line.startswith(b"data: ") or not stream:
                        async for chunk in cls.read_line(line[6:] if stream else line, stream):
                            yield chunk

    @staticmethod
    async def read_line(line: str, stream: bool):
        if line == b"[DONE]":
            return
        choice = json.loads(line)["choices"][0]
        if stream and "content" in choice["delta"] and choice["delta"]["content"]:
            yield choice["delta"]["content"]
        elif not stream and "content" in choice["message"]:
            yield choice["message"]["content"]
        if "finish_reason" in choice and choice["finish_reason"] is not None:
            yield FinishReason(choice["finish_reason"])

    @staticmethod
    def get_headers(api_key: str) -> dict:
        return {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }
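read_line above parses OpenAI-style server-sent-event payloads: streamed lines arrive as 'data: {...}' with a terminating 'data: [DONE]'. A quick check of the parsing on a hand-made payload (sample data, not captured API output; assumes the class is importable as shown):

    import asyncio
    import json
    from g4f.Provider.needs_auth import Openai

    async def demo():
        # One streamed payload, as it looks after the b"data: " prefix is stripped.
        sample = json.dumps(
            {"choices": [{"delta": {"content": "Hello"}, "finish_reason": "stop"}]}
        ).encode()
        async for part in Openai.read_line(sample, stream=True):
            print(repr(part))  # "Hello", then a FinishReason("stop") marker object

    asyncio.run(demo())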
@@ -4,4 +4,6 @@ from .Theb import Theb
 from .ThebApi import ThebApi
 from .OpenaiChat import OpenaiChat
 from .OpenAssistant import OpenAssistant
 from .Poe import Poe
+from .Openai import Openai
+from .Groq import Groq
@@ -8,7 +8,7 @@ import string

 from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
 from .typing import Union, Iterator, Messages, ImageType
-from .providers.types import BaseProvider, ProviderType
+from .providers.types import BaseProvider, ProviderType, FinishReason
 from .image import ImageResponse as ImageProviderResponse
 from .errors import NoImageResponseError, RateLimitError, MissingAuthError
 from . import get_model_and_provider, get_last_provider
@@ -47,6 +47,9 @@ def iter_response(
     finish_reason = None
     completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
     for idx, chunk in enumerate(response):
+        if isinstance(chunk, FinishReason):
+            finish_reason = chunk.reason
+            break
         content += str(chunk)
         if max_tokens is not None and idx + 1 >= max_tokens:
             finish_reason = "length"
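The effect of the new branch: a FinishReason chunk ends accumulation and is recorded as the finish reason instead of being stringified into the message content. A toy illustration of just that branch, using the FinishReason type added in this commit:

    from g4f.providers.types import FinishReason

    content, finish_reason = "", None
    for chunk in ["Hel", "lo", FinishReason("stop")]:
        if isinstance(chunk, FinishReason):
            finish_reason = chunk.reason
            break
        content += str(chunk)
    assert (content, finish_reason) == ("Hello", "stop")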
@@ -6,9 +6,10 @@ from asyncio import AbstractEventLoop
 from concurrent.futures import ThreadPoolExecutor
 from abc import abstractmethod
 from inspect import signature, Parameter
-from ..typing import CreateResult, AsyncResult, Messages, Union
-from .types import BaseProvider
-from ..errors import NestAsyncioError, ModelNotSupportedError
+from typing import Callable, Union
+from ..typing import CreateResult, AsyncResult, Messages
+from .types import BaseProvider, FinishReason
+from ..errors import NestAsyncioError, ModelNotSupportedError, MissingRequirementsError
 from .. import debug

 if sys.version_info < (3, 10):
@@ -21,17 +22,23 @@ if sys.platform == 'win32':
     if isinstance(asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy):
         asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

-def get_running_loop() -> Union[AbstractEventLoop, None]:
+def get_running_loop(check_nested: bool) -> Union[AbstractEventLoop, None]:
     try:
         loop = asyncio.get_running_loop()
-        if not hasattr(loop.__class__, "_nest_patched"):
-            raise NestAsyncioError(
-                'Use "create_async" instead of "create" function in a running event loop. Or use "nest_asyncio" package.'
-            )
+        if check_nested and not hasattr(loop.__class__, "_nest_patched"):
+            try:
+                import nest_asyncio
+                nest_asyncio.apply(loop)
+            except ImportError:
+                raise MissingRequirementsError('Install "nest_asyncio" package')
         return loop
     except RuntimeError:
         pass

+# Fix for RuntimeError: async generator ignored GeneratorExit
+async def await_callback(callback: Callable):
+    return await callback()
+
 class AbstractProvider(BaseProvider):
     """
     Abstract class for providing asynchronous functionality to derived classes.
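With this change, calling the synchronous entry points from inside a running event loop patches the loop via nest_asyncio instead of raising NestAsyncioError; only a missing package is an error now. A sketch of the new behavior (assumes nest_asyncio is installed and that the function stays importable from g4f.providers.base_provider):

    import asyncio
    from g4f.providers.base_provider import get_running_loop

    async def main():
        # Inside a running loop: nest_asyncio is applied and the loop handed back.
        loop = get_running_loop(check_nested=True)
        print(hasattr(loop.__class__, "_nest_patched"))  # True once patched

    asyncio.run(main())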
@@ -132,7 +139,7 @@ class AsyncProvider(AbstractProvider):
         Returns:
             CreateResult: The result of the completion creation.
         """
-        get_running_loop()
+        get_running_loop(check_nested=True)
         yield asyncio.run(cls.create_async(model, messages, **kwargs))

     @staticmethod
@@ -158,7 +165,6 @@
         """
         raise NotImplementedError()

-
 class AsyncGeneratorProvider(AsyncProvider):
     """
     Provides asynchronous generator functionality for streaming results.
@@ -187,9 +193,9 @@ class AsyncGeneratorProvider(AsyncProvider):
         Returns:
             CreateResult: The result of the streaming completion creation.
         """
-        loop = get_running_loop()
+        loop = get_running_loop(check_nested=True)
         new_loop = False
-        if not loop:
+        if loop is None:
             loop = asyncio.new_event_loop()
             asyncio.set_event_loop(loop)
             new_loop = True
@@ -197,16 +203,11 @@
         generator = cls.create_async_generator(model, messages, stream=stream, **kwargs)
         gen = generator.__aiter__()
-
-        # Fix for RuntimeError: async generator ignored GeneratorExit
-        async def await_callback(callback):
-            return await callback()
-
         try:
             while True:
                 yield loop.run_until_complete(await_callback(gen.__anext__))
         except StopAsyncIteration:
             ...
         # Fix for: ResourceWarning: unclosed event loop
         finally:
             if new_loop:
                 loop.close()
@@ -233,7 +234,7 @@
         """
         return "".join([
             chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)
-            if not isinstance(chunk, Exception)
+            if not isinstance(chunk, (Exception, FinishReason))
         ])

     @staticmethod
@@ -97,4 +97,8 @@ class BaseRetryProvider(BaseProvider):
     __name__: str = "RetryProvider"
     supports_stream: bool = True

-ProviderType = Union[Type[BaseProvider], BaseRetryProvider]
+ProviderType = Union[Type[BaseProvider], BaseRetryProvider]
+
+class FinishReason():
+    def __init__(self, reason: str):
+        self.reason = reason
@@ -15,11 +15,19 @@ class StreamResponse(ClientResponse):
         async for chunk in self.content.iter_any():
             yield chunk

-    async def json(self) -> Any:
-        return await super().json(content_type=None)
+    async def json(self, content_type: str = None) -> Any:
+        return await super().json(content_type=content_type)

 class StreamSession(ClientSession):
-    def __init__(self, headers: dict = {}, timeout: int = None, proxies: dict = {}, impersonate = None, **kwargs):
+    def __init__(
+        self,
+        headers: dict = {},
+        timeout: int = None,
+        connector: BaseConnector = None,
+        proxies: dict = {},
+        impersonate = None,
+        **kwargs
+    ):
         if impersonate:
             headers = {
                 **DEFAULT_HEADERS,

@@ -29,7 +37,7 @@ class StreamSession(ClientSession):
             **kwargs,
             timeout=ClientTimeout(timeout) if timeout else None,
             response_class=StreamResponse,
-            connector=get_connector(kwargs.get("connector"), proxies.get("https")),
+            connector=get_connector(connector, proxies.get("all", proxies.get("https"))),
             headers=headers
         )
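The constructor change makes the connector an explicit parameter and lets a single "all" proxy entry cover every scheme when get_connector builds the session. A usage sketch (hypothetical; the proxy URL is a placeholder, and it assumes the aiohttp-backed StreamSession is the one exported from g4f.requests):

    import asyncio
    from g4f.requests import StreamSession

    async def main():
        # "all" now reaches get_connector ahead of the old "https"-only lookup.
        async with StreamSession(proxies={"all": "http://127.0.0.1:8080"}, timeout=120) as session:
            async with session.get("https://example.com") as response:
                print(response.status)

    asyncio.run(main())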