mirror of https://github.com/xtekky/gpt4free.git, synced 2024-11-26 09:57:24 +03:00

Add Groq and Openai interfaces, Add integration tests

commit d44b39b31c (parent 1e2cf48cba)
@@ -5,5 +5,6 @@ from .main import *
 from .model import *
 from .client import *
 from .include import *
+from .integration import *

 unittest.main()
etc/unittest/integration.py (new file, 25 lines)
@@ -0,0 +1,25 @@
+import unittest
+import json
+
+from g4f.client import Client, ChatCompletion
+from g4f.Provider import Bing, OpenaiChat
+
+DEFAULT_MESSAGES = [{"role": "system", "content": 'Response in json, Example: {"success: True"}'},
+                    {"role": "user", "content": "Say success true in json"}]
+
+class TestProviderIntegration(unittest.TestCase):
+
+    def test_bing(self):
+        client = Client(provider=Bing)
+        response = client.chat.completions.create(DEFAULT_MESSAGES, "", response_format={"type": "json_object"})
+        self.assertIsInstance(response, ChatCompletion)
+        self.assertIn("success", json.loads(response.choices[0].message.content))
+
+    def test_openai(self):
+        client = Client(provider=OpenaiChat)
+        response = client.chat.completions.create(DEFAULT_MESSAGES, "", response_format={"type": "json_object"})
+        self.assertIsInstance(response, ChatCompletion)
+        self.assertIn("success", json.loads(response.choices[0].message.content))
+
+if __name__ == '__main__':
+    unittest.main()
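These tests exercise live providers end to end, so they need working provider access. A hedged note on invocation, assuming the repository root as the working directory (the module path follows from the file location above):

    python -m etc.unittest.integration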
@@ -1,2 +1,3 @@
 from ..providers.base_provider import *
+from ..providers.types import FinishReason
 from .helper import get_cookies, format_prompt
g4f/Provider/needs_auth/Groq.py (new file, 23 lines)
@@ -0,0 +1,23 @@
+from __future__ import annotations
+
+from .Openai import Openai
+from ...typing import AsyncResult, Messages
+
+class Groq(Openai):
+    url = "https://console.groq.com/playground"
+    working = True
+    default_model = "mixtral-8x7b-32768"
+    models = ["mixtral-8x7b-32768", "llama2-70b-4096", "gemma-7b-it"]
+    model_aliases = {"mixtral-8x7b": "mixtral-8x7b-32768", "llama2-70b": "llama2-70b-4096"}
+
+    @classmethod
+    def create_async_generator(
+        cls,
+        model: str,
+        messages: Messages,
+        api_base: str = "https://api.groq.com/openai/v1",
+        **kwargs
+    ) -> AsyncResult:
+        return super().create_async_generator(
+            model, messages, api_base=api_base, **kwargs
+        )
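Groq serves an OpenAI-compatible API, so the subclass only swaps the api_base, model list, and aliases. A minimal usage sketch, mirroring the integration tests above; it assumes the api_key keyword is forwarded to the provider through the client's kwargs, and the key shown is a placeholder:

    from g4f.client import Client
    from g4f.Provider import Groq

    client = Client(provider=Groq)
    response = client.chat.completions.create(
        [{"role": "user", "content": "Say hello"}],
        "mixtral-8x7b",     # alias, resolved to "mixtral-8x7b-32768" via model_aliases
        api_key="gsk_...",  # placeholder value, assumed to be passed through
    )
    print(response.choices[0].message.content)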
g4f/Provider/needs_auth/Openai.py (new file, 74 lines)
@@ -0,0 +1,74 @@
+from __future__ import annotations
+
+import json
+
+from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, FinishReason
+from ...typing import AsyncResult, Messages
+from ...requests.raise_for_status import raise_for_status
+from ...requests import StreamSession
+from ...errors import MissingAuthError
+
+class Openai(AsyncGeneratorProvider, ProviderModelMixin):
+    url = "https://openai.com"
+    working = True
+    needs_auth = True
+    supports_message_history = True
+    supports_system_message = True
+
+    @classmethod
+    async def create_async_generator(
+        cls,
+        model: str,
+        messages: Messages,
+        proxy: str = None,
+        timeout: int = 120,
+        api_key: str = None,
+        api_base: str = "https://api.openai.com/v1",
+        temperature: float = None,
+        max_tokens: int = None,
+        top_p: float = None,
+        stop: str = None,
+        stream: bool = False,
+        **kwargs
+    ) -> AsyncResult:
+        if api_key is None:
+            raise MissingAuthError('Add a "api_key"')
+        async with StreamSession(
+            proxies={"all": proxy},
+            headers=cls.get_headers(api_key),
+            timeout=timeout
+        ) as session:
+            data = {
+                "messages": messages,
+                "model": cls.get_model(model),
+                "temperature": temperature,
+                "max_tokens": max_tokens,
+                "top_p": top_p,
+                "stop": stop,
+                "stream": stream,
+            }
+            async with session.post(f"{api_base.rstrip('/')}/chat/completions", json=data) as response:
+                await raise_for_status(response)
+                async for line in response.iter_lines():
+                    if line.startswith(b"data: ") or not stream:
+                        async for chunk in cls.read_line(line[6:] if stream else line, stream):
+                            yield chunk
+
+    @staticmethod
+    async def read_line(line: str, stream: bool):
+        if line == b"[DONE]":
+            return
+        choice = json.loads(line)["choices"][0]
+        if stream and "content" in choice["delta"] and choice["delta"]["content"]:
+            yield choice["delta"]["content"]
+        elif not stream and "content" in choice["message"]:
+            yield choice["message"]["content"]
+        if "finish_reason" in choice and choice["finish_reason"] is not None:
+            yield FinishReason(choice["finish_reason"])
+
+    @staticmethod
+    def get_headers(api_key: str) -> dict:
+        return {
+            "Authorization": f"Bearer {api_key}",
+            "Content-Type": "application/json",
+        }
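Each streamed line arrives as an SSE "data: " payload holding one JSON chunk; read_line extracts the delta content and turns a non-null finish_reason into a FinishReason marker. A minimal sketch of that behavior on a hand-written sample payload (illustrative only, not captured API output):

    import asyncio

    from g4f.Provider import Openai

    # One streaming chunk, as it would look after the "data: " prefix is stripped
    sample = b'{"choices": [{"delta": {"content": "Hel"}, "finish_reason": null}]}'

    async def demo():
        async for chunk in Openai.read_line(sample, stream=True):
            print(repr(chunk))  # prints 'Hel'

    asyncio.run(demo())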
@@ -5,3 +5,5 @@ from .ThebApi import ThebApi
 from .OpenaiChat import OpenaiChat
 from .OpenAssistant import OpenAssistant
 from .Poe import Poe
+from .Openai import Openai
+from .Groq import Groq
@@ -8,7 +8,7 @@ import string
 from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
 from .typing import Union, Iterator, Messages, ImageType
-from .providers.types import BaseProvider, ProviderType
+from .providers.types import BaseProvider, ProviderType, FinishReason
 from .image import ImageResponse as ImageProviderResponse
 from .errors import NoImageResponseError, RateLimitError, MissingAuthError
 from . import get_model_and_provider, get_last_provider
@@ -47,6 +47,9 @@ def iter_response(
     finish_reason = None
     completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
     for idx, chunk in enumerate(response):
+        if isinstance(chunk, FinishReason):
+            finish_reason = chunk.reason
+            break
         content += str(chunk)
         if max_tokens is not None and idx + 1 >= max_tokens:
             finish_reason = "length"
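With this change a provider may terminate its chunk stream with a FinishReason marker rather than plain text. A sketch of the stream shape iter_response now handles (values illustrative):

    from g4f.providers.types import FinishReason

    # What a provider can now yield: text chunks, then an optional marker.
    # iter_response concatenates "Hel" + "lo" into the message content and
    # records finish_reason = "stop" instead of appending the marker as text.
    chunks = ["Hel", "lo", FinishReason("stop")]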
@@ -6,9 +6,10 @@ from asyncio import AbstractEventLoop
 from concurrent.futures import ThreadPoolExecutor
 from abc import abstractmethod
 from inspect import signature, Parameter
-from ..typing import CreateResult, AsyncResult, Messages, Union
-from .types import BaseProvider
-from ..errors import NestAsyncioError, ModelNotSupportedError
+from typing import Callable, Union
+from ..typing import CreateResult, AsyncResult, Messages
+from .types import BaseProvider, FinishReason
+from ..errors import NestAsyncioError, ModelNotSupportedError, MissingRequirementsError
 from .. import debug

 if sys.version_info < (3, 10):
@@ -21,17 +22,23 @@ if sys.platform == 'win32':
     if isinstance(asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy):
         asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

-def get_running_loop() -> Union[AbstractEventLoop, None]:
+def get_running_loop(check_nested: bool) -> Union[AbstractEventLoop, None]:
     try:
         loop = asyncio.get_running_loop()
-        if not hasattr(loop.__class__, "_nest_patched"):
-            raise NestAsyncioError(
-                'Use "create_async" instead of "create" function in a running event loop. Or use "nest_asyncio" package.'
-            )
+        if check_nested and not hasattr(loop.__class__, "_nest_patched"):
+            try:
+                import nest_asyncio
+                nest_asyncio.apply(loop)
+            except ImportError:
+                raise MissingRequirementsError('Install "nest_asyncio" package')
         return loop
     except RuntimeError:
         pass

+# Fix for RuntimeError: async generator ignored GeneratorExit
+async def await_callback(callback: Callable):
+    return await callback()
+
 class AbstractProvider(BaseProvider):
     """
     Abstract class for providing asynchronous functionality to derived classes.
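The behavior change here: when a synchronous call happens inside an already-running event loop, the loop is now patched with nest_asyncio on the fly instead of raising NestAsyncioError, and a missing package surfaces as MissingRequirementsError. A sketch of the situation this enables, assuming nest_asyncio is installed (e.g. a Jupyter cell, where a loop is always running):

    import g4f

    # Previously this raised NestAsyncioError inside a running loop; now
    # get_running_loop(check_nested=True) applies nest_asyncio and proceeds.
    response = g4f.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hi"}],
    )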
@@ -132,7 +139,7 @@ class AsyncProvider(AbstractProvider):
         Returns:
             CreateResult: The result of the completion creation.
         """
-        get_running_loop()
+        get_running_loop(check_nested=True)
         yield asyncio.run(cls.create_async(model, messages, **kwargs))

     @staticmethod
@@ -158,7 +165,6 @@ class AsyncProvider(AbstractProvider):
         """
         raise NotImplementedError()

-
 class AsyncGeneratorProvider(AsyncProvider):
     """
     Provides asynchronous generator functionality for streaming results.
@@ -187,9 +193,9 @@ class AsyncGeneratorProvider(AsyncProvider):
         Returns:
             CreateResult: The result of the streaming completion creation.
         """
-        loop = get_running_loop()
+        loop = get_running_loop(check_nested=True)
         new_loop = False
-        if not loop:
+        if loop is None:
             loop = asyncio.new_event_loop()
             asyncio.set_event_loop(loop)
             new_loop = True
@@ -197,16 +203,11 @@ class AsyncGeneratorProvider(AsyncProvider):
         generator = cls.create_async_generator(model, messages, stream=stream, **kwargs)
         gen = generator.__aiter__()

-        # Fix for RuntimeError: async generator ignored GeneratorExit
-        async def await_callback(callback):
-            return await callback()
-
         try:
             while True:
                 yield loop.run_until_complete(await_callback(gen.__anext__))
         except StopAsyncIteration:
             ...
-        # Fix for: ResourceWarning: unclosed event loop
         finally:
             if new_loop:
                 loop.close()
@@ -233,7 +234,7 @@ class AsyncGeneratorProvider(AsyncProvider):
         """
         return "".join([
             chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)
-            if not isinstance(chunk, Exception)
+            if not isinstance(chunk, (Exception, FinishReason))
         ])

     @staticmethod
@@ -98,3 +98,7 @@ class BaseRetryProvider(BaseProvider):
     supports_stream: bool = True

 ProviderType = Union[Type[BaseProvider], BaseRetryProvider]
+
+class FinishReason():
+    def __init__(self, reason: str):
+        self.reason = reason
@@ -15,11 +15,19 @@ class StreamResponse(ClientResponse):
         async for chunk in self.content.iter_any():
             yield chunk

-    async def json(self) -> Any:
-        return await super().json(content_type=None)
+    async def json(self, content_type: str = None) -> Any:
+        return await super().json(content_type=content_type)

 class StreamSession(ClientSession):
-    def __init__(self, headers: dict = {}, timeout: int = None, proxies: dict = {}, impersonate = None, **kwargs):
+    def __init__(
+        self,
+        headers: dict = {},
+        timeout: int = None,
+        connector: BaseConnector = None,
+        proxies: dict = {},
+        impersonate = None,
+        **kwargs
+    ):
         if impersonate:
             headers = {
                 **DEFAULT_HEADERS,
@@ -29,7 +37,7 @@ class StreamSession(ClientSession):
             **kwargs,
             timeout=ClientTimeout(timeout) if timeout else None,
             response_class=StreamResponse,
-            connector=get_connector(kwargs.get("connector"), proxies.get("https")),
+            connector=get_connector(connector, proxies.get("all", proxies.get("https"))),
             headers=headers
         )
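The rewritten connector line gives an "all"-scheme proxy precedence, with "https" kept as a fallback, and promotes connector to an explicit parameter. A small sketch of the two spellings (the proxy URL is a placeholder):

    from g4f.requests import StreamSession

    # New: one proxy for all schemes; this is what Openai passes as {"all": proxy}
    session = StreamSession(proxies={"all": "http://127.0.0.1:8080"})

    # Old style still works as the fallback
    session = StreamSession(proxies={"https": "http://127.0.0.1:8080"})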