Add unitests for the client

Fix: async generator ignored GeneratorExit
Fix: ResourceWarning: unclosed event loop
This commit is contained in:
Heiner Lohaus 2024-02-14 09:21:57 +01:00
parent 151f8b8b0e
commit e1a0b3ffa2
8 changed files with 172 additions and 132 deletions

View File

@ -8,6 +8,7 @@ import unittest
import g4f import g4f
from g4f import ChatCompletion from g4f import ChatCompletion
from g4f.client import Client
from .mocks import ProviderMock, AsyncProviderMock, AsyncGeneratorProviderMock from .mocks import ProviderMock, AsyncProviderMock, AsyncGeneratorProviderMock
DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}] DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}]
@ -24,11 +25,16 @@ class TestChatCompletion(unittest.TestCase):
def test_create(self): def test_create(self):
result = ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, AsyncProviderMock) result = ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, AsyncProviderMock)
self.assertEqual("Mock",result) self.assertEqual("Mock", result)
def test_create_generator(self): def test_create_generator(self):
result = ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, AsyncGeneratorProviderMock) result = ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, AsyncGeneratorProviderMock)
self.assertEqual("Mock",result) self.assertEqual("Mock", result)
def test_await_callback(self):
client = Client(provider=AsyncGeneratorProviderMock)
response = client.chat.completions.create(DEFAULT_MESSAGES, "", max_tokens=0)
self.assertEqual("Mock", response.choices[0].message.content)
class TestChatCompletionAsync(unittest.IsolatedAsyncioTestCase): class TestChatCompletionAsync(unittest.IsolatedAsyncioTestCase):

54
etc/unittest/client.py Normal file
View File

@ -0,0 +1,54 @@
import unittest
from g4f.client import Client, ChatCompletion, ChatCompletionChunk
from .mocks import AsyncGeneratorProviderMock, ModelProviderMock, YieldProviderMock
DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}]
class TestPassModel(unittest.TestCase):
def test_response(self):
client = Client(provider=AsyncGeneratorProviderMock)
response = client.chat.completions.create(DEFAULT_MESSAGES, "")
self.assertIsInstance(response, ChatCompletion)
self.assertEqual("Mock", response.choices[0].message.content)
def test_pass_model(self):
client = Client(provider=ModelProviderMock)
response = client.chat.completions.create(DEFAULT_MESSAGES, "Hello")
self.assertIsInstance(response, ChatCompletion)
self.assertEqual("Hello", response.choices[0].message.content)
def test_max_tokens(self):
client = Client(provider=YieldProviderMock)
messages = [{'role': 'user', 'content': chunk} for chunk in ["How ", "are ", "you", "?"]]
response = client.chat.completions.create(messages, "Hello", max_tokens=1)
self.assertIsInstance(response, ChatCompletion)
self.assertEqual("How ", response.choices[0].message.content)
response = client.chat.completions.create(messages, "Hello", max_tokens=2)
self.assertIsInstance(response, ChatCompletion)
self.assertEqual("How are ", response.choices[0].message.content)
def test_max_stream(self):
client = Client(provider=YieldProviderMock)
messages = [{'role': 'user', 'content': chunk} for chunk in ["How ", "are ", "you", "?"]]
response = client.chat.completions.create(messages, "Hello", stream=True)
for chunk in response:
self.assertIsInstance(chunk, ChatCompletionChunk)
self.assertIsInstance(chunk.choices[0].delta.content, str)
messages = [{'role': 'user', 'content': chunk} for chunk in ["You ", "You ", "Other", "?"]]
response = client.chat.completions.create(messages, "Hello", stream=True, max_tokens=2)
response = list(response)
self.assertEqual(len(response), 2)
for chunk in response:
self.assertEqual(chunk.choices[0].delta.content, "You ")
def no_test_stop(self):
client = Client(provider=YieldProviderMock)
messages = [{'role': 'user', 'content': chunk} for chunk in ["How ", "are ", "you", "?"]]
response = client.chat.completions.create(messages, "Hello", stop=["and"])
self.assertIsInstance(response, ChatCompletion)
self.assertEqual("How are you?", response.choices[0].message.content)
if __name__ == '__main__':
unittest.main()

View File

@ -31,3 +31,12 @@ class ModelProviderMock(AbstractProvider):
model, messages, stream, **kwargs model, messages, stream, **kwargs
): ):
yield model yield model
class YieldProviderMock(AsyncGeneratorProvider):
working = True
async def create_async_generator(
model, messages, stream, **kwargs
):
for message in messages:
yield message["content"]

View File

@ -196,15 +196,20 @@ class AsyncGeneratorProvider(AsyncProvider):
generator = cls.create_async_generator(model, messages, stream=stream, **kwargs) generator = cls.create_async_generator(model, messages, stream=stream, **kwargs)
gen = generator.__aiter__() gen = generator.__aiter__()
while True: # Fix for RuntimeError: async generator ignored GeneratorExit
try: async def await_callback(callback):
yield loop.run_until_complete(gen.__anext__()) return await callback()
except StopAsyncIteration:
break
if new_loop: try:
loop.close() while True:
asyncio.set_event_loop(None) yield loop.run_until_complete(await_callback(gen.__anext__))
except StopAsyncIteration:
...
# Fix for: ResourceWarning: unclosed event loop
finally:
if new_loop:
loop.close()
asyncio.set_event_loop(None)
@classmethod @classmethod
async def create_async( async def create_async(

View File

@ -385,7 +385,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
} }
) as response: ) as response:
if not response.ok: if not response.ok:
raise RuntimeError(f"Response {response.status_code}: {await response.text()}") raise RuntimeError(f"Response {response.status}: {await response.text()}")
last_message: int = 0 last_message: int = 0
async for line in response.iter_lines(): async for line in response.iter_lines():
if not line.startswith(b"data: "): if not line.startswith(b"data: "):

View File

@ -2,9 +2,9 @@ from __future__ import annotations
import re import re
from .typing import Union, Generator, AsyncGenerator, Messages, ImageType from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
from .typing import Union, Generator, Messages, ImageType
from .base_provider import BaseProvider, ProviderType from .base_provider import BaseProvider, ProviderType
from .Provider.base_provider import AsyncGeneratorProvider
from .image import ImageResponse as ImageProviderResponse from .image import ImageResponse as ImageProviderResponse
from .Provider import BingCreateImages, Gemini, OpenaiChat from .Provider import BingCreateImages, Gemini, OpenaiChat
from .errors import NoImageResponseError from .errors import NoImageResponseError
@ -36,14 +36,14 @@ def iter_response(
stop: list = None stop: list = None
) -> Generator: ) -> Generator:
content = "" content = ""
idx = 1 finish_reason = None
chunk = None last_chunk = None
finish_reason = "stop"
for idx, chunk in enumerate(response): for idx, chunk in enumerate(response):
if last_chunk is not None:
yield ChatCompletionChunk(last_chunk, finish_reason)
content += str(chunk) content += str(chunk)
if max_tokens is not None and idx > max_tokens: if max_tokens is not None and idx + 1 >= max_tokens:
finish_reason = "max_tokens" finish_reason = "max_tokens"
break
first = -1 first = -1
word = None word = None
if stop is not None: if stop is not None:
@ -52,98 +52,30 @@ def iter_response(
if first != -1: if first != -1:
content = content[:first] content = content[:first]
break break
if stream: if stream and first != -1:
first = chunk.find(word)
if first != -1: if first != -1:
first = chunk.find(word) chunk = chunk[:first]
if first != -1: else:
chunk = chunk[:first] first = 0
else:
first = 0
yield ChatCompletionChunk([ChatCompletionDeltaChoice(ChatCompletionDelta(chunk))])
if first != -1: if first != -1:
finish_reason = "stop"
if stream:
last_chunk = chunk
if finish_reason is not None:
break break
if last_chunk is not None:
yield ChatCompletionChunk(last_chunk, finish_reason)
if not stream: if not stream:
if response_format is not None and "type" in response_format: if response_format is not None and "type" in response_format:
if response_format["type"] == "json_object": if response_format["type"] == "json_object":
response = read_json(response) response = read_json(response)
yield ChatCompletion([ChatCompletionChoice(ChatCompletionMessage(response, finish_reason))]) yield ChatCompletion(content, finish_reason)
async def aiter_response(
response: aiter,
stream: bool,
response_format: dict = None,
max_tokens: int = None,
stop: list = None
) -> AsyncGenerator:
content = ""
try:
idx = 0
chunk = None
async for chunk in response:
content += str(chunk)
if max_tokens is not None and idx > max_tokens:
break
first = -1
word = None
if stop is not None:
for word in list(stop):
first = content.find(word)
if first != -1:
content = content[:first]
break
if stream:
if first != -1:
first = chunk.find(word)
if first != -1:
chunk = chunk[:first]
else:
first = 0
yield ChatCompletionChunk([ChatCompletionDeltaChoice(ChatCompletionDelta(chunk))])
if first != -1:
break
idx += 1
except:
...
if not stream:
if response_format is not None and "type" in response_format:
if response_format["type"] == "json_object":
response = read_json(response)
yield ChatCompletion([ChatCompletionChoice(ChatCompletionMessage(response))])
class Model():
def __getitem__(self, item):
return getattr(self, item)
class ChatCompletion(Model):
def __init__(self, choices: list):
self.choices = choices
class ChatCompletionChunk(Model):
def __init__(self, choices: list):
self.choices = choices
class ChatCompletionChoice(Model):
def __init__(self, message: ChatCompletionMessage):
self.message = message
class ChatCompletionMessage(Model):
def __init__(self, content: str, finish_reason: str):
self.content = content
self.finish_reason = finish_reason
self.index = 0
self.logprobs = None
class ChatCompletionDelta(Model):
def __init__(self, content: str):
self.content = content
class ChatCompletionDeltaChoice(Model):
def __init__(self, delta: ChatCompletionDelta):
self.delta = delta
class Client(): class Client():
proxies: Proxies = None proxies: Proxies = None
chat: Chat chat: Chat
images: Images
def __init__( def __init__(
self, self,
@ -152,9 +84,9 @@ class Client():
proxies: Proxies = None, proxies: Proxies = None,
**kwargs **kwargs
) -> None: ) -> None:
self.proxies: Proxies = proxies
self.images = Images(self, image_provider)
self.chat = Chat(self, provider) self.chat = Chat(self, provider)
self.images = Images(self, image_provider)
self.proxies: Proxies = proxies
def get_proxy(self) -> Union[str, None]: def get_proxy(self) -> Union[str, None]:
if isinstance(self.proxies, str) or self.proxies is None: if isinstance(self.proxies, str) or self.proxies is None:
@ -178,13 +110,13 @@ class Completions():
stream: bool = False, stream: bool = False,
response_format: dict = None, response_format: dict = None,
max_tokens: int = None, max_tokens: int = None,
stop: list = None, stop: Union[list. str] = None,
**kwargs **kwargs
) -> Union[dict, Generator]: ) -> Union[ChatCompletion, Generator[ChatCompletionChunk]]:
if max_tokens is not None: if max_tokens is not None:
kwargs["max_tokens"] = max_tokens kwargs["max_tokens"] = max_tokens
if stop: if stop:
kwargs["stop"] = list(stop) kwargs["stop"] = stop
model, provider = get_model_and_provider( model, provider = get_model_and_provider(
model, model,
self.provider if provider is None else provider, self.provider if provider is None else provider,
@ -192,10 +124,8 @@ class Completions():
**kwargs **kwargs
) )
response = provider.create_completion(model, messages, stream=stream, **kwargs) response = provider.create_completion(model, messages, stream=stream, **kwargs)
if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider): stop = [stop] if isinstance(stop, str) else stop
response = iter_response(response, stream, response_format) # max_tokens, stop response = iter_response(response, stream, response_format, max_tokens, stop)
else:
response = iter_response(response, stream, response_format, max_tokens, stop)
return response if stream else next(response) return response if stream else next(response)
class Chat(): class Chat():
@ -212,20 +142,8 @@ class ImageModels():
self.client = client self.client = client
self.default = BingCreateImages(proxy=self.client.get_proxy()) self.default = BingCreateImages(proxy=self.client.get_proxy())
def get(self, name: str) -> ImageProvider: def get(self, name: str, default: ImageProvider = None) -> ImageProvider:
return getattr(self, name) if hasattr(self, name) else self.default return getattr(self, name) if hasattr(self, name) else default or self.default
class ImagesResponse(Model):
data: list[Image]
def __init__(self, data: list) -> None:
self.data = data
class Image(Model):
url: str
def __init__(self, url: str) -> None:
self.url = url
class Images(): class Images():
def __init__(self, client: Client, provider: ImageProvider = None): def __init__(self, client: Client, provider: ImageProvider = None):
@ -234,7 +152,7 @@ class Images():
self.models: ImageModels = ImageModels(client) self.models: ImageModels = ImageModels(client)
def generate(self, prompt, model: str = None, **kwargs): def generate(self, prompt, model: str = None, **kwargs):
provider = self.models.get(model) if model else self.provider or self.models.get(model) provider = self.models.get(model, self.provider)
if isinstance(provider, BaseProvider) or isinstance(provider, type) and issubclass(provider, BaseProvider): if isinstance(provider, BaseProvider) or isinstance(provider, type) and issubclass(provider, BaseProvider):
prompt = f"create a image: {prompt}" prompt = f"create a image: {prompt}"
response = provider.create_completion( response = provider.create_completion(
@ -249,11 +167,12 @@ class Images():
for chunk in response: for chunk in response:
if isinstance(chunk, ImageProviderResponse): if isinstance(chunk, ImageProviderResponse):
return ImagesResponse([Image(image)for image in list(chunk.images)]) images = [chunk.images] if isinstance(chunk.images, str) else chunk.images
return ImagesResponse([Image(image) for image in images])
raise NoImageResponseError() raise NoImageResponseError()
def create_variation(self, image: ImageType, model: str = None, **kwargs): def create_variation(self, image: ImageType, model: str = None, **kwargs):
provider = self.models.get(model) if model else self.provider provider = self.models.get(model, self.provider)
result = None result = None
if isinstance(provider, type) and issubclass(provider, BaseProvider): if isinstance(provider, type) and issubclass(provider, BaseProvider):
response = provider.create_completion( response = provider.create_completion(

44
g4f/stubs.py Normal file
View File

@ -0,0 +1,44 @@
from __future__ import annotations
class Model():
def __getitem__(self, item):
return getattr(self, item)
class ChatCompletion(Model):
def __init__(self, content: str, finish_reason: str):
self.choices = [ChatCompletionChoice(ChatCompletionMessage(content, finish_reason))]
class ChatCompletionChunk(Model):
def __init__(self, content: str, finish_reason: str):
self.choices = [ChatCompletionDeltaChoice(ChatCompletionDelta(content, finish_reason))]
class ChatCompletionMessage(Model):
def __init__(self, content: str, finish_reason: str):
self.content = content
self.finish_reason = finish_reason
class ChatCompletionChoice(Model):
def __init__(self, message: ChatCompletionMessage):
self.message = message
class ChatCompletionDelta(Model):
def __init__(self, content: str, finish_reason: str):
self.content = content
self.finish_reason = finish_reason
class ChatCompletionDeltaChoice(Model):
def __init__(self, delta: ChatCompletionDelta):
self.delta = delta
class Image(Model):
url: str
def __init__(self, url: str) -> None:
self.url = url
class ImagesResponse(Model):
data: list[Image]
def __init__(self, data: list) -> None:
self.data = data

3
image.html Normal file
View File

@ -0,0 +1,3 @@
<style type="text/css">#designer_attribute_container{display:flex;justify-content:space-between;width:100%;margin-top:10px}#designer_attribute_container .des_attr_i{width:16px;margin-right:2px}#designer_attribute_container .des_attr_txt{height:18px;line-height:18px;font-size:14px;font-weight:400;font-family:"Roboto",Helvetica,sans-serif}#designer_attribute_container #dalle_attribute_container{margin-left:auto}#designer_attribute_container #dalle_attribute_container .des_attr_dal{display:flex;justify-content:center;height:14px;border-radius:4px;background-color:#5f5f5e;padding:2px 8px}#designer_attribute_container #dalle_attribute_container .des_attr_dal_txt{height:14px;line-height:14px;font-size:11px;font-weight:400;font-family:"Roboto",Helvetica,sans-serif;color:#fff}.des_attr_txt{color:#fff}</style><div id="gir_async"
class="giric gir_1" data-rewriteurl="/images/create/a-serene-garden-filled-with-colorful-flowers-in-fu/1-65c8a550c2d34e67a93b016cd1f3ade3?FORM=GENCRE" data-cis="512" data-vimgseturl="/images/create/async/viewimageset/1-65c8a550c2d34e67a93b016cd1f3ade3&amp;IG=062C548EC8DD4047A2AAE63FD928194A&amp;IID=images.vis"
fir-th="OIG2.EPxx_.JFG402kzMQYYhj" data-ctc="Image copied to clipboard" data-wide="" data-wide-mobile=""><a class="single-img-link" target="_blank" href="/images/create/a-serene-garden-filled-with-colorful-flowers-in-fu/1-65c8a550c2d34e67a93b016cd1f3ade3?id=LrsKPoRLQud1%2bT8YdxQDhA%3d%3d&amp;view=detailv2&amp;idpp=genimg&amp;FORM=GCRIDP" h="ID=images,5015.1"><img class="gir_mmimg" src="https://tse4.mm.bing.net/th/id/OIG2.EPxx_.JFG402kzMQYYhj?w=270&amp;h=270&amp;c=6&amp;r=0&amp;o=5&amp;pid=ImgGn" alt="a serene garden filled with colorful flowers in full bloom"/></a><div id="designer_attribute_container"><img class="des_attr_i rms_img" alt="Designer" src="https://r.bing.com/rp/gmZtdJVd-klWl3XWpa6-ni1FU3M.svg" /><span class="des_attr_txt des_attr_txt_clr">Designer</span><div id="dalle_attribute_container"><div class="des_attr_dal"><span class="des_attr_dal_txt">Powered by DALL&#183;E 3</span></div></div></div></div>