diff --git a/etc/examples/api.py b/etc/examples/api.py index f8f5d5ec..2485bade 100644 --- a/etc/examples/api.py +++ b/etc/examples/api.py @@ -1,13 +1,17 @@ import requests import json +import uuid + url = "http://localhost:1337/v1/chat/completions" +conversation_id = str(uuid.uuid4()) body = { "model": "", - "provider": "", + "provider": "Copilot", "stream": True, "messages": [ - {"role": "user", "content": "What can you do? Who are you?"} - ] + {"role": "user", "content": "Hello, i am Heiner. How are you?"} + ], + "conversation_id": conversation_id } response = requests.post(url, json=body, stream=True) response.raise_for_status() @@ -21,4 +25,27 @@ for line in response.iter_lines(): print(json_data.get("choices", [{"delta": {}}])[0]["delta"].get("content", ""), end="") except json.JSONDecodeError: pass -print() \ No newline at end of file +print() +print() +print() +body = { + "model": "", + "provider": "Copilot", + "stream": True, + "messages": [ + {"role": "user", "content": "Tell me somethings about my name"} + ], + "conversation_id": conversation_id +} +response = requests.post(url, json=body, stream=True) +response.raise_for_status() +for line in response.iter_lines(): + if line.startswith(b"data: "): + try: + json_data = json.loads(line[6:]) + if json_data.get("error"): + print(json_data) + break + print(json_data.get("choices", [{"delta": {}}])[0]["delta"].get("content", ""), end="") + except json.JSONDecodeError: + pass \ No newline at end of file diff --git a/g4f/Provider/Airforce.py b/g4f/Provider/Airforce.py index 54bb543b..f5bcfefa 100644 --- a/g4f/Provider/Airforce.py +++ b/g4f/Provider/Airforce.py @@ -20,7 +20,7 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin): working = True supports_system_message = True supports_message_history = True - + @classmethod def fetch_completions_models(cls): response = requests.get('https://api.airforce/models', verify=False) @@ -34,19 +34,20 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin): response.raise_for_status() return response.json() - completions_models = fetch_completions_models.__func__(None) - imagine_models = fetch_imagine_models.__func__(None) - default_model = "gpt-4o-mini" default_image_model = "flux" additional_models_imagine = ["stable-diffusion-xl-base", "stable-diffusion-xl-lightning", "Flux-1.1-Pro"] - text_models = completions_models - image_models = [*imagine_models, *additional_models_imagine] - models = [ - *text_models, - *image_models, - ] - + + @classmethod + def get_models(cls): + if not cls.models: + cls.image_models = [*cls.fetch_imagine_models(), *cls.additional_models_imagine] + cls.models = [ + *cls.fetch_completions_models(), + *cls.image_models + ] + return cls.models + model_aliases = { ### completions ### # openchat @@ -100,7 +101,6 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin): **kwargs ) -> AsyncResult: model = cls.get_model(model) - if model in cls.image_models: return cls._generate_image(model, messages, proxy, seed, size) else: diff --git a/g4f/Provider/Blackbox.py b/g4f/Provider/Blackbox.py index b259b4aa..41905537 100644 --- a/g4f/Provider/Blackbox.py +++ b/g4f/Provider/Blackbox.py @@ -20,15 +20,15 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin): supports_system_message = True supports_message_history = True _last_validated_value = None - + default_model = 'blackboxai' default_vision_model = default_model default_image_model = 'Image Generation' image_models = ['Image Generation', 'repomap'] vision_models = [default_model, 'gpt-4o', 'gemini-pro', 'gemini-1.5-flash', 'llama-3.1-8b', 'llama-3.1-70b', 'llama-3.1-405b'] - + userSelectedModel = ['gpt-4o', 'gemini-pro', 'claude-sonnet-3.5', 'blackboxai-pro'] - + agentMode = { 'Image Generation': {'mode': True, 'id': "ImageGenerationLV45LJp", 'name': "Image Generation"}, } @@ -77,22 +77,21 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin): } additional_prefixes = { - 'gpt-4o': '@gpt-4o', - 'gemini-pro': '@gemini-pro', - 'claude-sonnet-3.5': '@claude-sonnet' - } + 'gpt-4o': '@gpt-4o', + 'gemini-pro': '@gemini-pro', + 'claude-sonnet-3.5': '@claude-sonnet' + } model_prefixes = { - **{mode: f"@{value['id']}" for mode, value in trendingAgentMode.items() - if mode not in ["gemini-1.5-flash", "llama-3.1-8b", "llama-3.1-70b", "llama-3.1-405b", "repomap"]}, - **additional_prefixes - } + **{ + mode: f"@{value['id']}" for mode, value in trendingAgentMode.items() + if mode not in ["gemini-1.5-flash", "llama-3.1-8b", "llama-3.1-70b", "llama-3.1-405b", "repomap"] + }, + **additional_prefixes + } - models = list(dict.fromkeys([default_model, *userSelectedModel, *list(agentMode.keys()), *list(trendingAgentMode.keys())])) - - model_aliases = { "gemini-flash": "gemini-1.5-flash", "claude-3.5-sonnet": "claude-sonnet-3.5", @@ -131,12 +130,11 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin): return cls._last_validated_value - @staticmethod def generate_id(length=7): characters = string.ascii_letters + string.digits return ''.join(random.choice(characters) for _ in range(length)) - + @classmethod def add_prefix_to_messages(cls, messages: Messages, model: str) -> Messages: prefix = cls.model_prefixes.get(model, "") @@ -157,6 +155,7 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin): cls, model: str, messages: Messages, + prompt: str = None, proxy: str = None, web_search: bool = False, image: ImageType = None, @@ -191,7 +190,7 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin): 'sec-fetch-site': 'same-origin', 'user-agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36' } - + data = { "messages": messages, "id": message_id, @@ -221,26 +220,25 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin): async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response: response.raise_for_status() response_text = await response.text() - + if model in cls.image_models: image_matches = re.findall(r'!\[.*?\]\((https?://[^\)]+)\)', response_text) if image_matches: image_url = image_matches[0] - image_response = ImageResponse(images=[image_url], alt="Generated Image") - yield image_response + yield ImageResponse(image_url, prompt) return response_text = re.sub(r'Generated by BLACKBOX.AI, try unlimited chat https://www.blackbox.ai', '', response_text, flags=re.DOTALL) - + json_match = re.search(r'\$~~~\$(.*?)\$~~~\$', response_text, re.DOTALL) if json_match: search_results = json.loads(json_match.group(1)) answer = response_text.split('$~~~$')[-1].strip() - + formatted_response = f"{answer}\n\n**Source:**" for i, result in enumerate(search_results, 1): formatted_response += f"\n{i}. {result['title']}: {result['link']}" - + yield formatted_response else: yield response_text.strip() diff --git a/g4f/Provider/Copilot.py b/g4f/Provider/Copilot.py index e8eea0a5..e10a55e8 100644 --- a/g4f/Provider/Copilot.py +++ b/g4f/Provider/Copilot.py @@ -57,6 +57,7 @@ class Copilot(AbstractProvider): image: ImageType = None, conversation: Conversation = None, return_conversation: bool = False, + web_search: bool = True, **kwargs ) -> CreateResult: if not has_curl_cffi: @@ -124,12 +125,14 @@ class Copilot(AbstractProvider): is_started = False msg = None image_prompt: str = None + last_msg = None while True: try: msg = wss.recv()[0] msg = json.loads(msg) except: break + last_msg = msg if msg.get("event") == "appendText": is_started = True yield msg.get("text") @@ -139,8 +142,12 @@ class Copilot(AbstractProvider): yield ImageResponse(msg.get("url"), image_prompt, {"preview": msg.get("thumbnailUrl")}) elif msg.get("event") == "done": break + elif msg.get("event") == "error": + raise RuntimeError(f"Error: {msg}") + elif msg.get("event") not in ["received", "startMessage", "citation", "partCompleted"]: + debug.log(f"Copilot Message: {msg}") if not is_started: - raise RuntimeError(f"Last message: {msg}") + raise RuntimeError(f"Invalid response: {last_msg}") @classmethod async def get_access_token_and_cookies(cls, proxy: str = None): diff --git a/g4f/Provider/PollinationsAI.py b/g4f/Provider/PollinationsAI.py index 57597bf1..a30f896d 100644 --- a/g4f/Provider/PollinationsAI.py +++ b/g4f/Provider/PollinationsAI.py @@ -3,7 +3,6 @@ from __future__ import annotations from urllib.parse import quote import random import requests -from sys import maxsize from aiohttp import ClientSession from ..typing import AsyncResult, Messages @@ -40,6 +39,7 @@ class PollinationsAI(OpenaiAPI): cls, model: str, messages: Messages, + prompt: str = None, api_base: str = "https://text.pollinations.ai/openai", api_key: str = None, proxy: str = None, @@ -49,9 +49,10 @@ class PollinationsAI(OpenaiAPI): if model: model = cls.get_model(model) if model in cls.image_models: - prompt = messages[-1]["content"] + if prompt is None: + prompt = messages[-1]["content"] if seed is None: - seed = random.randint(0, maxsize) + seed = random.randint(0, 100000) image = f"https://image.pollinations.ai/prompt/{quote(prompt)}?width=1024&height=1024&seed={int(seed)}&nofeed=true&nologo=true&model={quote(model)}" yield ImageResponse(image, prompt) return diff --git a/g4f/Provider/ReplicateHome.py b/g4f/Provider/ReplicateHome.py index a7fc9b54..00de09e0 100644 --- a/g4f/Provider/ReplicateHome.py +++ b/g4f/Provider/ReplicateHome.py @@ -6,6 +6,8 @@ from aiohttp import ClientSession, ContentTypeError from ..typing import AsyncResult, Messages from .base_provider import AsyncGeneratorProvider, ProviderModelMixin +from ..requests.aiohttp import get_connector +from ..requests.raise_for_status import raise_for_status from .helper import format_prompt from ..image import ImageResponse @@ -32,10 +34,8 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin): 'yorickvp/llava-13b', ] - - models = text_models + image_models - + model_aliases = { # image_models "sd-3": "stability-ai/stable-diffusion-3", @@ -56,23 +56,14 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin): # text_models "google-deepmind/gemma-2b-it": "dff94eaf770e1fc211e425a50b51baa8e4cac6c39ef074681f9e39d778773626", "yorickvp/llava-13b": "80537f9eead1a5bfa72d5ac6ea6414379be41d4d4f6679fd776e9535d1eb58bb", - } - @classmethod - def get_model(cls, model: str) -> str: - if model in cls.models: - return model - elif model in cls.model_aliases: - return cls.model_aliases[model] - else: - return cls.default_model - @classmethod async def create_async_generator( cls, model: str, messages: Messages, + prompt: str = None, proxy: str = None, **kwargs ) -> AsyncResult: @@ -96,29 +87,30 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin): "user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36" } - async with ClientSession(headers=headers) as session: - if model in cls.image_models: - prompt = messages[-1]['content'] if messages else "" - else: - prompt = format_prompt(messages) - + async with ClientSession(headers=headers, connector=get_connector(proxy=proxy)) as session: + if prompt is None: + if model in cls.image_models: + prompt = messages[-1]['content'] + else: + prompt = format_prompt(messages) + data = { "model": model, "version": cls.model_versions[model], "input": {"prompt": prompt}, } - - async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response: - response.raise_for_status() + + async with session.post(cls.api_endpoint, json=data) as response: + await raise_for_status(response) result = await response.json() prediction_id = result['id'] - + poll_url = f"https://homepage.replicate.com/api/poll?id={prediction_id}" max_attempts = 30 delay = 5 for _ in range(max_attempts): - async with session.get(poll_url, proxy=proxy) as response: - response.raise_for_status() + async with session.get(poll_url) as response: + await raise_for_status(response) try: result = await response.json() except ContentTypeError: @@ -131,7 +123,7 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin): if result['status'] == 'succeeded': if model in cls.image_models: image_url = result['output'][0] - yield ImageResponse(image_url, "Generated image") + yield ImageResponse(image_url, prompt) return else: for chunk in result['output']: @@ -140,6 +132,6 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin): elif result['status'] == 'failed': raise Exception(f"Prediction failed: {result.get('error')}") await asyncio.sleep(delay) - + if result['status'] != 'succeeded': raise Exception("Prediction timed out") diff --git a/g4f/Provider/RubiksAI.py b/g4f/Provider/RubiksAI.py index c06e6c3d..816ea60c 100644 --- a/g4f/Provider/RubiksAI.py +++ b/g4f/Provider/RubiksAI.py @@ -9,7 +9,7 @@ from urllib.parse import urlencode from aiohttp import ClientSession from ..typing import AsyncResult, Messages -from .base_provider import AsyncGeneratorProvider, ProviderModelMixin +from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, Sources from ..requests.raise_for_status import raise_for_status class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin): @@ -23,7 +23,6 @@ class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin): default_model = 'gpt-4o-mini' models = [default_model, 'gpt-4o', 'o1-mini', 'claude-3.5-sonnet', 'grok-beta', 'gemini-1.5-pro', 'nova-pro'] - model_aliases = { "llama-3.1-70b": "llama-3.1-70b-versatile", } @@ -118,7 +117,7 @@ class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin): if 'url' in json_data and 'title' in json_data: if web_search: - sources.append({'title': json_data['title'], 'url': json_data['url']}) + sources.append(json_data) elif 'choices' in json_data: for choice in json_data['choices']: @@ -128,5 +127,4 @@ class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin): yield content if web_search and sources: - sources_text = '\n'.join([f"{i+1}. [{s['title']}]: {s['url']}" for i, s in enumerate(sources)]) - yield f"\n\n**Source:**\n{sources_text}" \ No newline at end of file + yield Sources(sources) \ No newline at end of file diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py index 667f6964..c0d8edf0 100644 --- a/g4f/Provider/base_provider.py +++ b/g4f/Provider/base_provider.py @@ -1,4 +1,4 @@ from ..providers.base_provider import * -from ..providers.types import FinishReason, Streaming -from ..providers.conversation import BaseConversation +from ..providers.types import Streaming +from ..providers.response import BaseConversation, Sources, FinishReason from .helper import get_cookies, format_prompt \ No newline at end of file diff --git a/g4f/Provider/bing/conversation.py b/g4f/Provider/bing/conversation.py index b5c237f9..43bcbb4d 100644 --- a/g4f/Provider/bing/conversation.py +++ b/g4f/Provider/bing/conversation.py @@ -2,7 +2,7 @@ from __future__ import annotations from ...requests import StreamSession, raise_for_status from ...errors import RateLimitError -from ...providers.conversation import BaseConversation +from ...providers.response import BaseConversation class Conversation(BaseConversation): """ diff --git a/g4f/Provider/needs_auth/BingCreateImages.py b/g4f/Provider/needs_auth/BingCreateImages.py index 80984d40..b95a78c3 100644 --- a/g4f/Provider/needs_auth/BingCreateImages.py +++ b/g4f/Provider/needs_auth/BingCreateImages.py @@ -28,13 +28,14 @@ class BingCreateImages(AsyncGeneratorProvider, ProviderModelMixin): cls, model: str, messages: Messages, + prompt: str = None, api_key: str = None, cookies: Cookies = None, proxy: str = None, **kwargs ) -> AsyncResult: session = BingCreateImages(cookies, proxy, api_key) - yield await session.generate(messages[-1]["content"]) + yield await session.generate(messages[-1]["content"] if prompt is None else prompt) async def generate(self, prompt: str) -> ImageResponse: """ diff --git a/g4f/Provider/needs_auth/DeepInfraImage.py b/g4f/Provider/needs_auth/DeepInfraImage.py index 24df04e3..44790561 100644 --- a/g4f/Provider/needs_auth/DeepInfraImage.py +++ b/g4f/Provider/needs_auth/DeepInfraImage.py @@ -29,9 +29,10 @@ class DeepInfraImage(AsyncGeneratorProvider, ProviderModelMixin): cls, model: str, messages: Messages, + prompt: str = None, **kwargs ) -> AsyncResult: - yield await cls.create_async(messages[-1]["content"], model, **kwargs) + yield await cls.create_async(messages[-1]["content"] if prompt is None else prompt, model, **kwargs) @classmethod async def create_async( diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py index 97515ec4..587c0a23 100644 --- a/g4f/Provider/needs_auth/OpenaiChat.py +++ b/g4f/Provider/needs_auth/OpenaiChat.py @@ -28,7 +28,7 @@ from ...requests.raise_for_status import raise_for_status from ...requests.aiohttp import StreamSession from ...image import ImageResponse, ImageRequest, to_image, to_bytes, is_accepted_format from ...errors import MissingAuthError, ResponseError -from ...providers.conversation import BaseConversation +from ...providers.response import BaseConversation from ..helper import format_cookies from ..openai.har_file import get_request_config, NoValidHarFileError from ..openai.har_file import RequestConfig, arkReq, arkose_url, start_url, conversation_url, backend_url, backend_anon_url diff --git a/g4f/api/__init__.py b/g4f/api/__init__.py index 02ba5260..21e69388 100644 --- a/g4f/api/__init__.py +++ b/g4f/api/__init__.py @@ -4,6 +4,7 @@ import logging import json import uvicorn import secrets +import os from fastapi import FastAPI, Response, Request from fastapi.responses import StreamingResponse, RedirectResponse, HTMLResponse, JSONResponse @@ -13,13 +14,16 @@ from starlette.exceptions import HTTPException from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN from fastapi.encoders import jsonable_encoder from fastapi.middleware.cors import CORSMiddleware +from starlette.responses import FileResponse from pydantic import BaseModel from typing import Union, Optional import g4f import g4f.debug from g4f.client import AsyncClient, ChatCompletion +from g4f.providers.response import BaseConversation from g4f.client.helper import filter_none +from g4f.image import is_accepted_format, images_dir from g4f.typing import Messages from g4f.cookies import read_cookie_files @@ -63,6 +67,7 @@ class ChatCompletionsConfig(BaseModel): api_key: Optional[str] = None web_search: Optional[bool] = None proxy: Optional[str] = None + conversation_id: str = None class ImageGenerationConfig(BaseModel): prompt: str @@ -98,6 +103,7 @@ class Api: self.client = AsyncClient() self.g4f_api_key = g4f_api_key self.get_g4f_api_key = APIKeyHeader(name="g4f-api-key") + self.conversations: dict[str, dict[str, BaseConversation]] = {} def register_authorization(self): @self.app.middleware("http") @@ -179,12 +185,21 @@ class Api: async def chat_completions(config: ChatCompletionsConfig, request: Request = None, provider: str = None): try: config.provider = provider if config.provider is None else config.provider + if config.provider is None: + config.provider = AppConfig.provider if config.api_key is None and request is not None: auth_header = request.headers.get("Authorization") if auth_header is not None: - auth_header = auth_header.split(None, 1)[-1] - if auth_header and auth_header != "Bearer": - config.api_key = auth_header + api_key = auth_header.split(None, 1)[-1] + if api_key and api_key != "Bearer": + config.api_key = api_key + + conversation = return_conversation = None + if config.conversation_id is not None and config.provider is not None: + return_conversation = True + if config.conversation_id in self.conversations: + if config.provider in self.conversations[config.conversation_id]: + conversation = self.conversations[config.conversation_id][config.provider] # Create the completion response response = self.client.chat.completions.create( @@ -194,6 +209,11 @@ class Api: "provider": AppConfig.provider, "proxy": AppConfig.proxy, **config.dict(exclude_none=True), + **{ + "conversation_id": None, + "return_conversation": return_conversation, + "conversation": conversation + } }, ignored=AppConfig.ignored_providers ), @@ -206,7 +226,13 @@ class Api: async def streaming(): try: async for chunk in response: - yield f"data: {json.dumps(chunk.to_json())}\n\n" + if isinstance(chunk, BaseConversation): + if config.conversation_id is not None and config.provider is not None: + if config.conversation_id not in self.conversations: + self.conversations[config.conversation_id] = {} + self.conversations[config.conversation_id][config.provider] = chunk + else: + yield f"data: {json.dumps(chunk.to_json())}\n\n" except GeneratorExit: pass except Exception as e: @@ -222,7 +248,13 @@ class Api: @self.app.post("/v1/images/generate") @self.app.post("/v1/images/generations") - async def generate_image(config: ImageGenerationConfig): + async def generate_image(config: ImageGenerationConfig, request: Request): + if config.api_key is None: + auth_header = request.headers.get("Authorization") + if auth_header is not None: + api_key = auth_header.split(None, 1)[-1] + if api_key and api_key != "Bearer": + config.api_key = api_key try: response = await self.client.images.generate( prompt=config.prompt, @@ -234,6 +266,9 @@ class Api: proxy = config.proxy ) ) + for image in response.data: + if hasattr(image, "url") and image.url.startswith("/"): + image.url = f"{request.base_url}{image.url.lstrip('/')}" return JSONResponse(response.to_json()) except Exception as e: logger.exception(e) @@ -243,6 +278,18 @@ class Api: async def completions(): return Response(content=json.dumps({'info': 'Not working yet.'}, indent=4), media_type="application/json") + @self.app.get("/images/{filename}") + async def get_image(filename): + target = os.path.join(images_dir, filename) + + if not os.path.isfile(target): + return Response(status_code=404) + + with open(target, "rb") as f: + content_type = is_accepted_format(f.read(12)) + + return FileResponse(target, media_type=content_type) + def format_exception(e: Exception, config: Union[ChatCompletionsConfig, ImageGenerationConfig], image: bool = False) -> str: last_provider = {} if not image else g4f.get_last_provider(True) provider = (AppConfig.image_provider if image else AppConfig.provider) if config.provider is None else config.provider diff --git a/g4f/client/__init__.py b/g4f/client/__init__.py index 549a244b..dea19a60 100644 --- a/g4f/client/__init__.py +++ b/g4f/client/__init__.py @@ -6,24 +6,27 @@ import random import string import asyncio import base64 -import aiohttp import logging -from typing import Union, AsyncIterator, Iterator, Coroutine +from typing import Union, AsyncIterator, Iterator, Coroutine, Optional from ..providers.base_provider import AsyncGeneratorProvider -from ..image import ImageResponse, to_image, to_data_uri, is_accepted_format, EXTENSIONS_MAP -from ..typing import Messages, Image -from ..providers.types import ProviderType, FinishReason, BaseConversation -from ..errors import NoImageResponseError +from ..image import ImageResponse, copy_images, images_dir +from ..typing import Messages, Image, ImageType +from ..providers.types import ProviderType +from ..providers.response import ResponseType, FinishReason, BaseConversation +from ..errors import NoImageResponseError, ModelNotFoundError from ..providers.retry_provider import IterListProvider +from ..providers.base_provider import get_running_loop from ..Provider.needs_auth.BingCreateImages import BingCreateImages -from ..requests.aiohttp import get_connector from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse from .image_models import ImageModels from .types import IterResponse, ImageProvider, Client as BaseClient from .service import get_model_and_provider, get_last_provider, convert_to_provider from .helper import find_stop, filter_json, filter_none, safe_aclose, to_sync_iter, to_async_iterator +ChatCompletionResponseType = Iterator[Union[ChatCompletion, ChatCompletionChunk, BaseConversation]] +AsyncChatCompletionResponseType = AsyncIterator[Union[ChatCompletion, ChatCompletionChunk, BaseConversation]] + try: anext # Python 3.8+ except NameError: @@ -35,12 +38,12 @@ except NameError: # Synchronous iter_response function def iter_response( - response: Union[Iterator[str], AsyncIterator[str]], + response: Union[Iterator[Union[str, ResponseType]]], stream: bool, - response_format: dict = None, - max_tokens: int = None, - stop: list = None -) -> Iterator[Union[ChatCompletion, ChatCompletionChunk]]: + response_format: Optional[dict] = None, + max_tokens: Optional[int] = None, + stop: Optional[list[str]] = None +) -> ChatCompletionResponseType: content = "" finish_reason = None completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28)) @@ -88,22 +91,23 @@ def iter_response( yield ChatCompletion(content, finish_reason, completion_id, int(time.time())) # Synchronous iter_append_model_and_provider function -def iter_append_model_and_provider(response: Iterator[ChatCompletionChunk]) -> Iterator[ChatCompletionChunk]: +def iter_append_model_and_provider(response: ChatCompletionResponseType) -> ChatCompletionResponseType: last_provider = None for chunk in response: - last_provider = get_last_provider(True) if last_provider is None else last_provider - chunk.model = last_provider.get("model") - chunk.provider = last_provider.get("name") - yield chunk + if isinstance(chunk, (ChatCompletion, ChatCompletionChunk)): + last_provider = get_last_provider(True) if last_provider is None else last_provider + chunk.model = last_provider.get("model") + chunk.provider = last_provider.get("name") + yield chunk async def async_iter_response( - response: AsyncIterator[str], + response: AsyncIterator[Union[str, ResponseType]], stream: bool, - response_format: dict = None, - max_tokens: int = None, - stop: list = None -) -> AsyncIterator[Union[ChatCompletion, ChatCompletionChunk]]: + response_format: Optional[dict] = None, + max_tokens: Optional[int] = None, + stop: Optional[list[str]] = None +) -> AsyncChatCompletionResponseType: content = "" finish_reason = None completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28)) @@ -149,13 +153,16 @@ async def async_iter_response( if hasattr(response, 'aclose'): await safe_aclose(response) -async def async_iter_append_model_and_provider(response: AsyncIterator[ChatCompletionChunk]) -> AsyncIterator: +async def async_iter_append_model_and_provider( + response: AsyncChatCompletionResponseType + ) -> AsyncChatCompletionResponseType: last_provider = None try: async for chunk in response: - last_provider = get_last_provider(True) if last_provider is None else last_provider - chunk.model = last_provider.get("model") - chunk.provider = last_provider.get("name") + if isinstance(chunk, (ChatCompletion, ChatCompletionChunk)): + last_provider = get_last_provider(True) if last_provider is None else last_provider + chunk.model = last_provider.get("model") + chunk.provider = last_provider.get("name") yield chunk finally: if hasattr(response, 'aclose'): @@ -164,8 +171,8 @@ async def async_iter_append_model_and_provider(response: AsyncIterator[ChatCompl class Client(BaseClient): def __init__( self, - provider: ProviderType = None, - image_provider: ImageProvider = None, + provider: Optional[ProviderType] = None, + image_provider: Optional[ImageProvider] = None, **kwargs ) -> None: super().__init__(**kwargs) @@ -173,7 +180,7 @@ class Client(BaseClient): self.images: Images = Images(self, image_provider) class Completions: - def __init__(self, client: Client, provider: ProviderType = None): + def __init__(self, client: Client, provider: Optional[ProviderType] = None): self.client: Client = client self.provider: ProviderType = provider @@ -181,16 +188,16 @@ class Completions: self, messages: Messages, model: str, - provider: ProviderType = None, - stream: bool = False, - proxy: str = None, - response_format: dict = None, - max_tokens: int = None, - stop: Union[list[str], str] = None, - api_key: str = None, - ignored: list[str] = None, - ignore_working: bool = False, - ignore_stream: bool = False, + provider: Optional[ProviderType] = None, + stream: Optional[bool] = False, + proxy: Optional[str] = None, + response_format: Optional[dict] = None, + max_tokens: Optional[int] = None, + stop: Optional[Union[list[str], str]] = None, + api_key: Optional[str] = None, + ignored: Optional[list[str]] = None, + ignore_working: Optional[bool] = False, + ignore_stream: Optional[bool] = False, **kwargs ) -> IterResponse: model, provider = get_model_and_provider( @@ -234,22 +241,38 @@ class Completions: class Chat: completions: Completions - def __init__(self, client: Client, provider: ProviderType = None): + def __init__(self, client: Client, provider: Optional[ProviderType] = None): self.completions = Completions(client, provider) class Images: - def __init__(self, client: Client, provider: ProviderType = None): + def __init__(self, client: Client, provider: Optional[ProviderType] = None): self.client: Client = client - self.provider: ProviderType = provider + self.provider: Optional[ProviderType] = provider self.models: ImageModels = ImageModels(client) - def generate(self, prompt: str, model: str = None, provider: ProviderType = None, response_format: str = "url", proxy: str = None, **kwargs) -> ImagesResponse: + def generate( + self, + prompt: str, + model: str = None, + provider: Optional[ProviderType] = None, + response_format: str = "url", + proxy: Optional[str] = None, + **kwargs + ) -> ImagesResponse: """ Synchronous generate method that runs the async_generate method in an event loop. """ - return asyncio.run(self.async_generate(prompt, model, provider, response_format=response_format, proxy=proxy, **kwargs)) + return asyncio.run(self.async_generate(prompt, model, provider, response_format, proxy, **kwargs)) - async def async_generate(self, prompt: str, model: str = None, provider: ProviderType = None, response_format: str = "url", proxy: str = None, **kwargs) -> ImagesResponse: + async def async_generate( + self, + prompt: str, + model: Optional[str] = None, + provider: Optional[ProviderType] = None, + response_format: Optional[str] = "url", + proxy: Optional[str] = None, + **kwargs + ) -> ImagesResponse: if provider is None: provider_handler = self.models.get(model, provider or self.provider or BingCreateImages) elif isinstance(provider, str): @@ -257,97 +280,73 @@ class Images: else: provider_handler = provider if provider_handler is None: - raise ValueError(f"Unknown model: {model}") - if proxy is None: - proxy = self.client.proxy - + raise ModelNotFoundError(f"Unknown model: {model}") if isinstance(provider_handler, IterListProvider): if provider_handler.providers: provider_handler = provider_handler.providers[0] else: - raise ValueError(f"IterListProvider for model {model} has no providers") + raise ModelNotFoundError(f"IterListProvider for model {model} has no providers") + if proxy is None: + proxy = self.client.proxy response = None - if hasattr(provider_handler, "create_async_generator"): - messages = [{"role": "user", "content": prompt}] - async for item in provider_handler.create_async_generator(model, messages, **kwargs): + if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider): + messages = [{"role": "user", "content": f"Generate a image: {prompt}"}] + async for item in provider_handler.create_async_generator(model, messages, prompt=prompt, **kwargs): if isinstance(item, ImageResponse): response = item break - elif hasattr(provider, 'create'): + elif hasattr(provider_handler, 'create'): if asyncio.iscoroutinefunction(provider_handler.create): response = await provider_handler.create(prompt) else: response = provider_handler.create(prompt) if isinstance(response, str): response = ImageResponse([response], prompt) + elif hasattr(provider_handler, "create_completion"): + get_running_loop(check_nested=True) + messages = [{"role": "user", "content": f"Generate a image: {prompt}"}] + for item in provider_handler.create_completion(model, messages, prompt=prompt, **kwargs): + if isinstance(item, ImageResponse): + response = item + break else: raise ValueError(f"Provider {provider} does not support image generation") if isinstance(response, ImageResponse): - return await self._process_image_response(response, response_format, proxy, model=model, provider=provider) - + return await self._process_image_response( + response, + response_format, + proxy, + model, + getattr(provider_handler, "__name__", None) + ) raise NoImageResponseError(f"Unexpected response type: {type(response)}") - async def _process_image_response(self, response: ImageResponse, response_format: str, proxy: str = None, model: str = None, provider: str = None) -> ImagesResponse: - async def process_image_item(session: aiohttp.ClientSession, image_data: str): - image_data_bytes = None - if image_data.startswith("http://") or image_data.startswith("https://"): - if response_format == "url": - return Image(url=image_data, revised_prompt=response.alt) - elif response_format == "b64_json": - # Fetch the image data and convert it to base64 - image_data_bytes = await self._fetch_image(session, image_data) - b64_json = base64.b64encode(image_data_bytes).decode("utf-8") - return Image(b64_json=b64_json, url=image_data, revised_prompt=response.alt) - else: - # Assume image_data is base64 data or binary - if response_format == "url": - if image_data.startswith("data:image"): - # Remove the data URL scheme and get the base64 data - base64_data = image_data.split(",", 1)[-1] - else: - base64_data = image_data - # Decode the base64 data - image_data_bytes = base64.b64decode(base64_data) - if image_data_bytes: - file_name = self._save_image(image_data_bytes) - return Image(url=file_name, revised_prompt=response.alt) - else: - raise ValueError("Unable to process image data") - - last_provider = get_last_provider(True) - async with aiohttp.ClientSession(cookies=response.get("cookies"), connector=get_connector(proxy=proxy)) as session: - return ImagesResponse( - await asyncio.gather(*[process_image_item(session, image_data) for image_data in response.get_list()]), - model=last_provider.get("model") if model is None else model, - provider=last_provider.get("name") if provider is None else provider - ) - - async def _fetch_image(self, session: aiohttp.ClientSession, url: str) -> bytes: - # Asynchronously fetch image data from the URL - async with session.get(url) as resp: - if resp.status == 200: - return await resp.read() - else: - raise RuntimeError(f"Failed to fetch image from {url}, status code {resp.status}") - - def _save_image(self, image_data_bytes: bytes) -> str: - os.makedirs('generated_images', exist_ok=True) - image = to_image(image_data_bytes) - file_name = f"generated_images/image_{int(time.time())}_{random.randint(0, 10000)}.{EXTENSIONS_MAP[is_accepted_format(image_data_bytes)]}" - image.save(file_name) - return file_name - - def create_variation(self, image: Union[str, bytes], model: str = None, provider: ProviderType = None, response_format: str = "url", **kwargs) -> ImagesResponse: + def create_variation( + self, + image: Union[str, bytes], + model: str = None, + provider: Optional[ProviderType] = None, + response_format: str = "url", + **kwargs + ) -> ImagesResponse: return asyncio.run(self.async_create_variation( image, model, provider, response_format, **kwargs )) - async def async_create_variation(self, image: Union[str, bytes], model: str = None, provider: ProviderType = None, response_format: str = "url", proxy: str = None, **kwargs) -> ImagesResponse: + async def async_create_variation( + self, + image: ImageType, + model: Optional[str] = None, + provider: Optional[ProviderType] = None, + response_format: str = "url", + proxy: Optional[str] = None, + **kwargs + ) -> ImagesResponse: if provider is None: provider = self.models.get(model, provider or self.provider or BingCreateImages) if provider is None: - raise ValueError(f"Unknown model: {model}") + raise ModelNotFoundError(f"Unknown model: {model}") if isinstance(provider, str): provider = convert_to_provider(provider) if proxy is None: @@ -355,38 +354,61 @@ class Images: if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider): messages = [{"role": "user", "content": "create a variation of this image"}] - image_data = to_data_uri(image) generator = None try: - generator = provider.create_async_generator(model, messages, image=image_data, response_format=response_format, proxy=proxy, **kwargs) - async for response in generator: - if isinstance(response, ImageResponse): - return self._process_image_response(response) - except RuntimeError as e: - if "async generator ignored GeneratorExit" in str(e): - logging.warning("Generator ignored GeneratorExit in create_variation, handling gracefully") - else: - raise + generator = provider.create_async_generator(model, messages, image=image, response_format=response_format, proxy=proxy, **kwargs) + async for chunk in generator: + if isinstance(chunk, ImageResponse): + response = chunk + break finally: if generator and hasattr(generator, 'aclose'): await safe_aclose(generator) - logging.info("AsyncGeneratorProvider processing completed in create_variation") elif hasattr(provider, 'create_variation'): if asyncio.iscoroutinefunction(provider.create_variation): response = await provider.create_variation(image, model=model, response_format=response_format, proxy=proxy, **kwargs) else: response = provider.create_variation(image, model=model, response_format=response_format, proxy=proxy, **kwargs) - if isinstance(response, str): - response = ImageResponse([response]) - return self._process_image_response(response) else: - raise ValueError(f"Provider {provider} does not support image variation") + raise NoImageResponseError(f"Provider {provider} does not support image variation") + + if isinstance(response, str): + response = ImageResponse([response]) + if isinstance(response, ImageResponse): + return self._process_image_response(response, response_format, proxy, model, getattr(provider, "__name__", None)) + raise NoImageResponseError(f"Unexpected response type: {type(response)}") + + async def _process_image_response( + self, + response: ImageResponse, + response_format: str, + proxy: str = None, + model: Optional[str] = None, + provider: Optional[str] = None + ) -> list[Image]: + if response_format in ("url", "b64_json"): + images = await copy_images(response.get_list(), response.options.get("cookies"), proxy) + async def process_image_item(image_file: str) -> Image: + if response_format == "b64_json": + with open(os.path.join(images_dir, os.path.basename(image_file)), "rb") as file: + image_data = base64.b64encode(file.read()).decode() + return Image(url=image_file, b64_json=image_data, revised_prompt=response.alt) + return Image(url=image_file, revised_prompt=response.alt) + images = await asyncio.gather(*[process_image_item(image) for image in images]) + else: + images = [Image(url=image, revised_prompt=response.alt) for image in response.get_list()] + last_provider = get_last_provider(True) + return ImagesResponse( + images, + model=last_provider.get("model") if model is None else model, + provider=last_provider.get("name") if provider is None else provider + ) class AsyncClient(BaseClient): def __init__( self, - provider: ProviderType = None, - image_provider: ImageProvider = None, + provider: Optional[ProviderType] = None, + image_provider: Optional[ImageProvider] = None, **kwargs ) -> None: super().__init__(**kwargs) @@ -396,11 +418,11 @@ class AsyncClient(BaseClient): class AsyncChat: completions: AsyncCompletions - def __init__(self, client: AsyncClient, provider: ProviderType = None): + def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None): self.completions = AsyncCompletions(client, provider) class AsyncCompletions: - def __init__(self, client: AsyncClient, provider: ProviderType = None): + def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None): self.client: AsyncClient = client self.provider: ProviderType = provider @@ -408,18 +430,18 @@ class AsyncCompletions: self, messages: Messages, model: str, - provider: ProviderType = None, - stream: bool = False, - proxy: str = None, - response_format: dict = None, - max_tokens: int = None, - stop: Union[list[str], str] = None, - api_key: str = None, - ignored: list[str] = None, - ignore_working: bool = False, - ignore_stream: bool = False, + provider: Optional[ProviderType] = None, + stream: Optional[bool] = False, + proxy: Optional[str] = None, + response_format: Optional[dict] = None, + max_tokens: Optional[int] = None, + stop: Optional[Union[list[str], str]] = None, + api_key: Optional[str] = None, + ignored: Optional[list[str]] = None, + ignore_working: Optional[bool] = False, + ignore_stream: Optional[bool] = False, **kwargs - ) -> Union[Coroutine[ChatCompletion], AsyncIterator[ChatCompletionChunk]]: + ) -> Union[Coroutine[ChatCompletion], AsyncIterator[ChatCompletionChunk, BaseConversation]]: model, provider = get_model_and_provider( model, self.provider if provider is None else provider, @@ -450,15 +472,29 @@ class AsyncCompletions: return response if stream else anext(response) class AsyncImages(Images): - def __init__(self, client: AsyncClient, provider: ImageProvider = None): + def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None): self.client: AsyncClient = client - self.provider: ImageProvider = provider + self.provider: Optional[ProviderType] = provider self.models: ImageModels = ImageModels(client) - async def generate(self, prompt: str, model: str = None, provider: ProviderType = None, response_format: str = "url", **kwargs) -> ImagesResponse: + async def generate( + self, + prompt: str, + model: Optional[str] = None, + provider: Optional[ProviderType] = None, + response_format: str = "url", + **kwargs + ) -> ImagesResponse: return await self.async_generate(prompt, model, provider, response_format, **kwargs) - async def create_variation(self, image: Union[str, bytes], model: str = None, provider: ProviderType = None, response_format: str = "url", **kwargs) -> ImagesResponse: + async def create_variation( + self, + image: ImageType, + model: str = None, + provider: ProviderType = None, + response_format: str = "url", + **kwargs + ) -> ImagesResponse: return await self.async_create_variation( image, model, provider, response_format, **kwargs - ) + ) \ No newline at end of file diff --git a/g4f/client/helper.py b/g4f/client/helper.py index 71bfd38a..93588c07 100644 --- a/g4f/client/helper.py +++ b/g4f/client/helper.py @@ -6,7 +6,7 @@ import threading import logging import asyncio -from typing import AsyncIterator, Iterator, AsyncGenerator +from typing import AsyncIterator, Iterator, AsyncGenerator, Optional def filter_json(text: str) -> str: """ @@ -23,7 +23,7 @@ def filter_json(text: str) -> str: return match.group("code") return text -def find_stop(stop, content: str, chunk: str = None): +def find_stop(stop: Optional[list[str]], content: str, chunk: str = None): first = -1 word = None if stop is not None: diff --git a/g4f/client/image_models.py b/g4f/client/image_models.py index edaa4592..0b97a56b 100644 --- a/g4f/client/image_models.py +++ b/g4f/client/image_models.py @@ -1,7 +1,5 @@ from __future__ import annotations -from .types import Client, ImageProvider - from ..models import ModelUtils class ImageModels(): diff --git a/g4f/client/types.py b/g4f/client/types.py index 4f252ba9..5010e098 100644 --- a/g4f/client/types.py +++ b/g4f/client/types.py @@ -3,7 +3,7 @@ from __future__ import annotations import os from .stubs import ChatCompletion, ChatCompletionChunk -from ..providers.types import BaseProvider, ProviderType, FinishReason +from ..providers.types import BaseProvider from typing import Union, Iterator, AsyncIterator ImageProvider = Union[BaseProvider, object] diff --git a/g4f/gui/client/index.html b/g4f/gui/client/index.html index 48214093..6c2ad8b6 100644 --- a/g4f/gui/client/index.html +++ b/g4f/gui/client/index.html @@ -111,6 +111,11 @@ +