Improve download of generated images, serve images in the api (#2391)

* Improve download of generated images, serve images in the api
Add support for conversation handling in the api

* Add orginal prompt to image response

* Add download images option in gui, fix loading model list in Airforce

* Add download images option in gui, fix loading model list in Airforce
This commit is contained in:
H Lohaus 2024-11-20 19:58:16 +01:00 committed by GitHub
parent c959d9b469
commit ffb4b0d162
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
29 changed files with 494 additions and 328 deletions

View File

@ -1,13 +1,17 @@
import requests
import json
import uuid
url = "http://localhost:1337/v1/chat/completions"
conversation_id = str(uuid.uuid4())
body = {
"model": "",
"provider": "",
"provider": "Copilot",
"stream": True,
"messages": [
{"role": "user", "content": "What can you do? Who are you?"}
]
{"role": "user", "content": "Hello, i am Heiner. How are you?"}
],
"conversation_id": conversation_id
}
response = requests.post(url, json=body, stream=True)
response.raise_for_status()
@ -21,4 +25,27 @@ for line in response.iter_lines():
print(json_data.get("choices", [{"delta": {}}])[0]["delta"].get("content", ""), end="")
except json.JSONDecodeError:
pass
print()
print()
print()
print()
body = {
"model": "",
"provider": "Copilot",
"stream": True,
"messages": [
{"role": "user", "content": "Tell me somethings about my name"}
],
"conversation_id": conversation_id
}
response = requests.post(url, json=body, stream=True)
response.raise_for_status()
for line in response.iter_lines():
if line.startswith(b"data: "):
try:
json_data = json.loads(line[6:])
if json_data.get("error"):
print(json_data)
break
print(json_data.get("choices", [{"delta": {}}])[0]["delta"].get("content", ""), end="")
except json.JSONDecodeError:
pass

View File

@ -20,7 +20,7 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
working = True
supports_system_message = True
supports_message_history = True
@classmethod
def fetch_completions_models(cls):
response = requests.get('https://api.airforce/models', verify=False)
@ -34,19 +34,20 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
response.raise_for_status()
return response.json()
completions_models = fetch_completions_models.__func__(None)
imagine_models = fetch_imagine_models.__func__(None)
default_model = "gpt-4o-mini"
default_image_model = "flux"
additional_models_imagine = ["stable-diffusion-xl-base", "stable-diffusion-xl-lightning", "Flux-1.1-Pro"]
text_models = completions_models
image_models = [*imagine_models, *additional_models_imagine]
models = [
*text_models,
*image_models,
]
@classmethod
def get_models(cls):
if not cls.models:
cls.image_models = [*cls.fetch_imagine_models(), *cls.additional_models_imagine]
cls.models = [
*cls.fetch_completions_models(),
*cls.image_models
]
return cls.models
model_aliases = {
### completions ###
# openchat
@ -100,7 +101,6 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
**kwargs
) -> AsyncResult:
model = cls.get_model(model)
if model in cls.image_models:
return cls._generate_image(model, messages, proxy, seed, size)
else:

View File

@ -20,15 +20,15 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
supports_system_message = True
supports_message_history = True
_last_validated_value = None
default_model = 'blackboxai'
default_vision_model = default_model
default_image_model = 'Image Generation'
image_models = ['Image Generation', 'repomap']
vision_models = [default_model, 'gpt-4o', 'gemini-pro', 'gemini-1.5-flash', 'llama-3.1-8b', 'llama-3.1-70b', 'llama-3.1-405b']
userSelectedModel = ['gpt-4o', 'gemini-pro', 'claude-sonnet-3.5', 'blackboxai-pro']
agentMode = {
'Image Generation': {'mode': True, 'id': "ImageGenerationLV45LJp", 'name': "Image Generation"},
}
@ -77,22 +77,21 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
}
additional_prefixes = {
'gpt-4o': '@gpt-4o',
'gemini-pro': '@gemini-pro',
'claude-sonnet-3.5': '@claude-sonnet'
}
'gpt-4o': '@gpt-4o',
'gemini-pro': '@gemini-pro',
'claude-sonnet-3.5': '@claude-sonnet'
}
model_prefixes = {
**{mode: f"@{value['id']}" for mode, value in trendingAgentMode.items()
if mode not in ["gemini-1.5-flash", "llama-3.1-8b", "llama-3.1-70b", "llama-3.1-405b", "repomap"]},
**additional_prefixes
}
**{
mode: f"@{value['id']}" for mode, value in trendingAgentMode.items()
if mode not in ["gemini-1.5-flash", "llama-3.1-8b", "llama-3.1-70b", "llama-3.1-405b", "repomap"]
},
**additional_prefixes
}
models = list(dict.fromkeys([default_model, *userSelectedModel, *list(agentMode.keys()), *list(trendingAgentMode.keys())]))
model_aliases = {
"gemini-flash": "gemini-1.5-flash",
"claude-3.5-sonnet": "claude-sonnet-3.5",
@ -131,12 +130,11 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
return cls._last_validated_value
@staticmethod
def generate_id(length=7):
characters = string.ascii_letters + string.digits
return ''.join(random.choice(characters) for _ in range(length))
@classmethod
def add_prefix_to_messages(cls, messages: Messages, model: str) -> Messages:
prefix = cls.model_prefixes.get(model, "")
@ -157,6 +155,7 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
cls,
model: str,
messages: Messages,
prompt: str = None,
proxy: str = None,
web_search: bool = False,
image: ImageType = None,
@ -191,7 +190,7 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
'sec-fetch-site': 'same-origin',
'user-agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36'
}
data = {
"messages": messages,
"id": message_id,
@ -221,26 +220,25 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response:
response.raise_for_status()
response_text = await response.text()
if model in cls.image_models:
image_matches = re.findall(r'!\[.*?\]\((https?://[^\)]+)\)', response_text)
if image_matches:
image_url = image_matches[0]
image_response = ImageResponse(images=[image_url], alt="Generated Image")
yield image_response
yield ImageResponse(image_url, prompt)
return
response_text = re.sub(r'Generated by BLACKBOX.AI, try unlimited chat https://www.blackbox.ai', '', response_text, flags=re.DOTALL)
json_match = re.search(r'\$~~~\$(.*?)\$~~~\$', response_text, re.DOTALL)
if json_match:
search_results = json.loads(json_match.group(1))
answer = response_text.split('$~~~$')[-1].strip()
formatted_response = f"{answer}\n\n**Source:**"
for i, result in enumerate(search_results, 1):
formatted_response += f"\n{i}. {result['title']}: {result['link']}"
yield formatted_response
else:
yield response_text.strip()

View File

@ -57,6 +57,7 @@ class Copilot(AbstractProvider):
image: ImageType = None,
conversation: Conversation = None,
return_conversation: bool = False,
web_search: bool = True,
**kwargs
) -> CreateResult:
if not has_curl_cffi:
@ -124,12 +125,14 @@ class Copilot(AbstractProvider):
is_started = False
msg = None
image_prompt: str = None
last_msg = None
while True:
try:
msg = wss.recv()[0]
msg = json.loads(msg)
except:
break
last_msg = msg
if msg.get("event") == "appendText":
is_started = True
yield msg.get("text")
@ -139,8 +142,12 @@ class Copilot(AbstractProvider):
yield ImageResponse(msg.get("url"), image_prompt, {"preview": msg.get("thumbnailUrl")})
elif msg.get("event") == "done":
break
elif msg.get("event") == "error":
raise RuntimeError(f"Error: {msg}")
elif msg.get("event") not in ["received", "startMessage", "citation", "partCompleted"]:
debug.log(f"Copilot Message: {msg}")
if not is_started:
raise RuntimeError(f"Last message: {msg}")
raise RuntimeError(f"Invalid response: {last_msg}")
@classmethod
async def get_access_token_and_cookies(cls, proxy: str = None):

View File

@ -3,7 +3,6 @@ from __future__ import annotations
from urllib.parse import quote
import random
import requests
from sys import maxsize
from aiohttp import ClientSession
from ..typing import AsyncResult, Messages
@ -40,6 +39,7 @@ class PollinationsAI(OpenaiAPI):
cls,
model: str,
messages: Messages,
prompt: str = None,
api_base: str = "https://text.pollinations.ai/openai",
api_key: str = None,
proxy: str = None,
@ -49,9 +49,10 @@ class PollinationsAI(OpenaiAPI):
if model:
model = cls.get_model(model)
if model in cls.image_models:
prompt = messages[-1]["content"]
if prompt is None:
prompt = messages[-1]["content"]
if seed is None:
seed = random.randint(0, maxsize)
seed = random.randint(0, 100000)
image = f"https://image.pollinations.ai/prompt/{quote(prompt)}?width=1024&height=1024&seed={int(seed)}&nofeed=true&nologo=true&model={quote(model)}"
yield ImageResponse(image, prompt)
return

View File

@ -6,6 +6,8 @@ from aiohttp import ClientSession, ContentTypeError
from ..typing import AsyncResult, Messages
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..requests.aiohttp import get_connector
from ..requests.raise_for_status import raise_for_status
from .helper import format_prompt
from ..image import ImageResponse
@ -32,10 +34,8 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin):
'yorickvp/llava-13b',
]
models = text_models + image_models
model_aliases = {
# image_models
"sd-3": "stability-ai/stable-diffusion-3",
@ -56,23 +56,14 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin):
# text_models
"google-deepmind/gemma-2b-it": "dff94eaf770e1fc211e425a50b51baa8e4cac6c39ef074681f9e39d778773626",
"yorickvp/llava-13b": "80537f9eead1a5bfa72d5ac6ea6414379be41d4d4f6679fd776e9535d1eb58bb",
}
@classmethod
def get_model(cls, model: str) -> str:
if model in cls.models:
return model
elif model in cls.model_aliases:
return cls.model_aliases[model]
else:
return cls.default_model
@classmethod
async def create_async_generator(
cls,
model: str,
messages: Messages,
prompt: str = None,
proxy: str = None,
**kwargs
) -> AsyncResult:
@ -96,29 +87,30 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin):
"user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36"
}
async with ClientSession(headers=headers) as session:
if model in cls.image_models:
prompt = messages[-1]['content'] if messages else ""
else:
prompt = format_prompt(messages)
async with ClientSession(headers=headers, connector=get_connector(proxy=proxy)) as session:
if prompt is None:
if model in cls.image_models:
prompt = messages[-1]['content']
else:
prompt = format_prompt(messages)
data = {
"model": model,
"version": cls.model_versions[model],
"input": {"prompt": prompt},
}
async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response:
response.raise_for_status()
async with session.post(cls.api_endpoint, json=data) as response:
await raise_for_status(response)
result = await response.json()
prediction_id = result['id']
poll_url = f"https://homepage.replicate.com/api/poll?id={prediction_id}"
max_attempts = 30
delay = 5
for _ in range(max_attempts):
async with session.get(poll_url, proxy=proxy) as response:
response.raise_for_status()
async with session.get(poll_url) as response:
await raise_for_status(response)
try:
result = await response.json()
except ContentTypeError:
@ -131,7 +123,7 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin):
if result['status'] == 'succeeded':
if model in cls.image_models:
image_url = result['output'][0]
yield ImageResponse(image_url, "Generated image")
yield ImageResponse(image_url, prompt)
return
else:
for chunk in result['output']:
@ -140,6 +132,6 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin):
elif result['status'] == 'failed':
raise Exception(f"Prediction failed: {result.get('error')}")
await asyncio.sleep(delay)
if result['status'] != 'succeeded':
raise Exception("Prediction timed out")

View File

@ -9,7 +9,7 @@ from urllib.parse import urlencode
from aiohttp import ClientSession
from ..typing import AsyncResult, Messages
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, Sources
from ..requests.raise_for_status import raise_for_status
class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
@ -23,7 +23,6 @@ class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
default_model = 'gpt-4o-mini'
models = [default_model, 'gpt-4o', 'o1-mini', 'claude-3.5-sonnet', 'grok-beta', 'gemini-1.5-pro', 'nova-pro']
model_aliases = {
"llama-3.1-70b": "llama-3.1-70b-versatile",
}
@ -118,7 +117,7 @@ class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
if 'url' in json_data and 'title' in json_data:
if web_search:
sources.append({'title': json_data['title'], 'url': json_data['url']})
sources.append(json_data)
elif 'choices' in json_data:
for choice in json_data['choices']:
@ -128,5 +127,4 @@ class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
yield content
if web_search and sources:
sources_text = '\n'.join([f"{i+1}. [{s['title']}]: {s['url']}" for i, s in enumerate(sources)])
yield f"\n\n**Source:**\n{sources_text}"
yield Sources(sources)

View File

@ -1,4 +1,4 @@
from ..providers.base_provider import *
from ..providers.types import FinishReason, Streaming
from ..providers.conversation import BaseConversation
from ..providers.types import Streaming
from ..providers.response import BaseConversation, Sources, FinishReason
from .helper import get_cookies, format_prompt

View File

@ -2,7 +2,7 @@ from __future__ import annotations
from ...requests import StreamSession, raise_for_status
from ...errors import RateLimitError
from ...providers.conversation import BaseConversation
from ...providers.response import BaseConversation
class Conversation(BaseConversation):
"""

View File

@ -28,13 +28,14 @@ class BingCreateImages(AsyncGeneratorProvider, ProviderModelMixin):
cls,
model: str,
messages: Messages,
prompt: str = None,
api_key: str = None,
cookies: Cookies = None,
proxy: str = None,
**kwargs
) -> AsyncResult:
session = BingCreateImages(cookies, proxy, api_key)
yield await session.generate(messages[-1]["content"])
yield await session.generate(messages[-1]["content"] if prompt is None else prompt)
async def generate(self, prompt: str) -> ImageResponse:
"""

View File

@ -29,9 +29,10 @@ class DeepInfraImage(AsyncGeneratorProvider, ProviderModelMixin):
cls,
model: str,
messages: Messages,
prompt: str = None,
**kwargs
) -> AsyncResult:
yield await cls.create_async(messages[-1]["content"], model, **kwargs)
yield await cls.create_async(messages[-1]["content"] if prompt is None else prompt, model, **kwargs)
@classmethod
async def create_async(

View File

@ -28,7 +28,7 @@ from ...requests.raise_for_status import raise_for_status
from ...requests.aiohttp import StreamSession
from ...image import ImageResponse, ImageRequest, to_image, to_bytes, is_accepted_format
from ...errors import MissingAuthError, ResponseError
from ...providers.conversation import BaseConversation
from ...providers.response import BaseConversation
from ..helper import format_cookies
from ..openai.har_file import get_request_config, NoValidHarFileError
from ..openai.har_file import RequestConfig, arkReq, arkose_url, start_url, conversation_url, backend_url, backend_anon_url

View File

@ -4,6 +4,7 @@ import logging
import json
import uvicorn
import secrets
import os
from fastapi import FastAPI, Response, Request
from fastapi.responses import StreamingResponse, RedirectResponse, HTMLResponse, JSONResponse
@ -13,13 +14,16 @@ from starlette.exceptions import HTTPException
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import FileResponse
from pydantic import BaseModel
from typing import Union, Optional
import g4f
import g4f.debug
from g4f.client import AsyncClient, ChatCompletion
from g4f.providers.response import BaseConversation
from g4f.client.helper import filter_none
from g4f.image import is_accepted_format, images_dir
from g4f.typing import Messages
from g4f.cookies import read_cookie_files
@ -63,6 +67,7 @@ class ChatCompletionsConfig(BaseModel):
api_key: Optional[str] = None
web_search: Optional[bool] = None
proxy: Optional[str] = None
conversation_id: str = None
class ImageGenerationConfig(BaseModel):
prompt: str
@ -98,6 +103,7 @@ class Api:
self.client = AsyncClient()
self.g4f_api_key = g4f_api_key
self.get_g4f_api_key = APIKeyHeader(name="g4f-api-key")
self.conversations: dict[str, dict[str, BaseConversation]] = {}
def register_authorization(self):
@self.app.middleware("http")
@ -179,12 +185,21 @@ class Api:
async def chat_completions(config: ChatCompletionsConfig, request: Request = None, provider: str = None):
try:
config.provider = provider if config.provider is None else config.provider
if config.provider is None:
config.provider = AppConfig.provider
if config.api_key is None and request is not None:
auth_header = request.headers.get("Authorization")
if auth_header is not None:
auth_header = auth_header.split(None, 1)[-1]
if auth_header and auth_header != "Bearer":
config.api_key = auth_header
api_key = auth_header.split(None, 1)[-1]
if api_key and api_key != "Bearer":
config.api_key = api_key
conversation = return_conversation = None
if config.conversation_id is not None and config.provider is not None:
return_conversation = True
if config.conversation_id in self.conversations:
if config.provider in self.conversations[config.conversation_id]:
conversation = self.conversations[config.conversation_id][config.provider]
# Create the completion response
response = self.client.chat.completions.create(
@ -194,6 +209,11 @@ class Api:
"provider": AppConfig.provider,
"proxy": AppConfig.proxy,
**config.dict(exclude_none=True),
**{
"conversation_id": None,
"return_conversation": return_conversation,
"conversation": conversation
}
},
ignored=AppConfig.ignored_providers
),
@ -206,7 +226,13 @@ class Api:
async def streaming():
try:
async for chunk in response:
yield f"data: {json.dumps(chunk.to_json())}\n\n"
if isinstance(chunk, BaseConversation):
if config.conversation_id is not None and config.provider is not None:
if config.conversation_id not in self.conversations:
self.conversations[config.conversation_id] = {}
self.conversations[config.conversation_id][config.provider] = chunk
else:
yield f"data: {json.dumps(chunk.to_json())}\n\n"
except GeneratorExit:
pass
except Exception as e:
@ -222,7 +248,13 @@ class Api:
@self.app.post("/v1/images/generate")
@self.app.post("/v1/images/generations")
async def generate_image(config: ImageGenerationConfig):
async def generate_image(config: ImageGenerationConfig, request: Request):
if config.api_key is None:
auth_header = request.headers.get("Authorization")
if auth_header is not None:
api_key = auth_header.split(None, 1)[-1]
if api_key and api_key != "Bearer":
config.api_key = api_key
try:
response = await self.client.images.generate(
prompt=config.prompt,
@ -234,6 +266,9 @@ class Api:
proxy = config.proxy
)
)
for image in response.data:
if hasattr(image, "url") and image.url.startswith("/"):
image.url = f"{request.base_url}{image.url.lstrip('/')}"
return JSONResponse(response.to_json())
except Exception as e:
logger.exception(e)
@ -243,6 +278,18 @@ class Api:
async def completions():
return Response(content=json.dumps({'info': 'Not working yet.'}, indent=4), media_type="application/json")
@self.app.get("/images/{filename}")
async def get_image(filename):
target = os.path.join(images_dir, filename)
if not os.path.isfile(target):
return Response(status_code=404)
with open(target, "rb") as f:
content_type = is_accepted_format(f.read(12))
return FileResponse(target, media_type=content_type)
def format_exception(e: Exception, config: Union[ChatCompletionsConfig, ImageGenerationConfig], image: bool = False) -> str:
last_provider = {} if not image else g4f.get_last_provider(True)
provider = (AppConfig.image_provider if image else AppConfig.provider) if config.provider is None else config.provider

View File

@ -6,24 +6,27 @@ import random
import string
import asyncio
import base64
import aiohttp
import logging
from typing import Union, AsyncIterator, Iterator, Coroutine
from typing import Union, AsyncIterator, Iterator, Coroutine, Optional
from ..providers.base_provider import AsyncGeneratorProvider
from ..image import ImageResponse, to_image, to_data_uri, is_accepted_format, EXTENSIONS_MAP
from ..typing import Messages, Image
from ..providers.types import ProviderType, FinishReason, BaseConversation
from ..errors import NoImageResponseError
from ..image import ImageResponse, copy_images, images_dir
from ..typing import Messages, Image, ImageType
from ..providers.types import ProviderType
from ..providers.response import ResponseType, FinishReason, BaseConversation
from ..errors import NoImageResponseError, ModelNotFoundError
from ..providers.retry_provider import IterListProvider
from ..providers.base_provider import get_running_loop
from ..Provider.needs_auth.BingCreateImages import BingCreateImages
from ..requests.aiohttp import get_connector
from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
from .image_models import ImageModels
from .types import IterResponse, ImageProvider, Client as BaseClient
from .service import get_model_and_provider, get_last_provider, convert_to_provider
from .helper import find_stop, filter_json, filter_none, safe_aclose, to_sync_iter, to_async_iterator
ChatCompletionResponseType = Iterator[Union[ChatCompletion, ChatCompletionChunk, BaseConversation]]
AsyncChatCompletionResponseType = AsyncIterator[Union[ChatCompletion, ChatCompletionChunk, BaseConversation]]
try:
anext # Python 3.8+
except NameError:
@ -35,12 +38,12 @@ except NameError:
# Synchronous iter_response function
def iter_response(
response: Union[Iterator[str], AsyncIterator[str]],
response: Union[Iterator[Union[str, ResponseType]]],
stream: bool,
response_format: dict = None,
max_tokens: int = None,
stop: list = None
) -> Iterator[Union[ChatCompletion, ChatCompletionChunk]]:
response_format: Optional[dict] = None,
max_tokens: Optional[int] = None,
stop: Optional[list[str]] = None
) -> ChatCompletionResponseType:
content = ""
finish_reason = None
completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
@ -88,22 +91,23 @@ def iter_response(
yield ChatCompletion(content, finish_reason, completion_id, int(time.time()))
# Synchronous iter_append_model_and_provider function
def iter_append_model_and_provider(response: Iterator[ChatCompletionChunk]) -> Iterator[ChatCompletionChunk]:
def iter_append_model_and_provider(response: ChatCompletionResponseType) -> ChatCompletionResponseType:
last_provider = None
for chunk in response:
last_provider = get_last_provider(True) if last_provider is None else last_provider
chunk.model = last_provider.get("model")
chunk.provider = last_provider.get("name")
yield chunk
if isinstance(chunk, (ChatCompletion, ChatCompletionChunk)):
last_provider = get_last_provider(True) if last_provider is None else last_provider
chunk.model = last_provider.get("model")
chunk.provider = last_provider.get("name")
yield chunk
async def async_iter_response(
response: AsyncIterator[str],
response: AsyncIterator[Union[str, ResponseType]],
stream: bool,
response_format: dict = None,
max_tokens: int = None,
stop: list = None
) -> AsyncIterator[Union[ChatCompletion, ChatCompletionChunk]]:
response_format: Optional[dict] = None,
max_tokens: Optional[int] = None,
stop: Optional[list[str]] = None
) -> AsyncChatCompletionResponseType:
content = ""
finish_reason = None
completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
@ -149,13 +153,16 @@ async def async_iter_response(
if hasattr(response, 'aclose'):
await safe_aclose(response)
async def async_iter_append_model_and_provider(response: AsyncIterator[ChatCompletionChunk]) -> AsyncIterator:
async def async_iter_append_model_and_provider(
response: AsyncChatCompletionResponseType
) -> AsyncChatCompletionResponseType:
last_provider = None
try:
async for chunk in response:
last_provider = get_last_provider(True) if last_provider is None else last_provider
chunk.model = last_provider.get("model")
chunk.provider = last_provider.get("name")
if isinstance(chunk, (ChatCompletion, ChatCompletionChunk)):
last_provider = get_last_provider(True) if last_provider is None else last_provider
chunk.model = last_provider.get("model")
chunk.provider = last_provider.get("name")
yield chunk
finally:
if hasattr(response, 'aclose'):
@ -164,8 +171,8 @@ async def async_iter_append_model_and_provider(response: AsyncIterator[ChatCompl
class Client(BaseClient):
def __init__(
self,
provider: ProviderType = None,
image_provider: ImageProvider = None,
provider: Optional[ProviderType] = None,
image_provider: Optional[ImageProvider] = None,
**kwargs
) -> None:
super().__init__(**kwargs)
@ -173,7 +180,7 @@ class Client(BaseClient):
self.images: Images = Images(self, image_provider)
class Completions:
def __init__(self, client: Client, provider: ProviderType = None):
def __init__(self, client: Client, provider: Optional[ProviderType] = None):
self.client: Client = client
self.provider: ProviderType = provider
@ -181,16 +188,16 @@ class Completions:
self,
messages: Messages,
model: str,
provider: ProviderType = None,
stream: bool = False,
proxy: str = None,
response_format: dict = None,
max_tokens: int = None,
stop: Union[list[str], str] = None,
api_key: str = None,
ignored: list[str] = None,
ignore_working: bool = False,
ignore_stream: bool = False,
provider: Optional[ProviderType] = None,
stream: Optional[bool] = False,
proxy: Optional[str] = None,
response_format: Optional[dict] = None,
max_tokens: Optional[int] = None,
stop: Optional[Union[list[str], str]] = None,
api_key: Optional[str] = None,
ignored: Optional[list[str]] = None,
ignore_working: Optional[bool] = False,
ignore_stream: Optional[bool] = False,
**kwargs
) -> IterResponse:
model, provider = get_model_and_provider(
@ -234,22 +241,38 @@ class Completions:
class Chat:
completions: Completions
def __init__(self, client: Client, provider: ProviderType = None):
def __init__(self, client: Client, provider: Optional[ProviderType] = None):
self.completions = Completions(client, provider)
class Images:
def __init__(self, client: Client, provider: ProviderType = None):
def __init__(self, client: Client, provider: Optional[ProviderType] = None):
self.client: Client = client
self.provider: ProviderType = provider
self.provider: Optional[ProviderType] = provider
self.models: ImageModels = ImageModels(client)
def generate(self, prompt: str, model: str = None, provider: ProviderType = None, response_format: str = "url", proxy: str = None, **kwargs) -> ImagesResponse:
def generate(
self,
prompt: str,
model: str = None,
provider: Optional[ProviderType] = None,
response_format: str = "url",
proxy: Optional[str] = None,
**kwargs
) -> ImagesResponse:
"""
Synchronous generate method that runs the async_generate method in an event loop.
"""
return asyncio.run(self.async_generate(prompt, model, provider, response_format=response_format, proxy=proxy, **kwargs))
return asyncio.run(self.async_generate(prompt, model, provider, response_format, proxy, **kwargs))
async def async_generate(self, prompt: str, model: str = None, provider: ProviderType = None, response_format: str = "url", proxy: str = None, **kwargs) -> ImagesResponse:
async def async_generate(
self,
prompt: str,
model: Optional[str] = None,
provider: Optional[ProviderType] = None,
response_format: Optional[str] = "url",
proxy: Optional[str] = None,
**kwargs
) -> ImagesResponse:
if provider is None:
provider_handler = self.models.get(model, provider or self.provider or BingCreateImages)
elif isinstance(provider, str):
@ -257,97 +280,73 @@ class Images:
else:
provider_handler = provider
if provider_handler is None:
raise ValueError(f"Unknown model: {model}")
if proxy is None:
proxy = self.client.proxy
raise ModelNotFoundError(f"Unknown model: {model}")
if isinstance(provider_handler, IterListProvider):
if provider_handler.providers:
provider_handler = provider_handler.providers[0]
else:
raise ValueError(f"IterListProvider for model {model} has no providers")
raise ModelNotFoundError(f"IterListProvider for model {model} has no providers")
if proxy is None:
proxy = self.client.proxy
response = None
if hasattr(provider_handler, "create_async_generator"):
messages = [{"role": "user", "content": prompt}]
async for item in provider_handler.create_async_generator(model, messages, **kwargs):
if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider):
messages = [{"role": "user", "content": f"Generate a image: {prompt}"}]
async for item in provider_handler.create_async_generator(model, messages, prompt=prompt, **kwargs):
if isinstance(item, ImageResponse):
response = item
break
elif hasattr(provider, 'create'):
elif hasattr(provider_handler, 'create'):
if asyncio.iscoroutinefunction(provider_handler.create):
response = await provider_handler.create(prompt)
else:
response = provider_handler.create(prompt)
if isinstance(response, str):
response = ImageResponse([response], prompt)
elif hasattr(provider_handler, "create_completion"):
get_running_loop(check_nested=True)
messages = [{"role": "user", "content": f"Generate a image: {prompt}"}]
for item in provider_handler.create_completion(model, messages, prompt=prompt, **kwargs):
if isinstance(item, ImageResponse):
response = item
break
else:
raise ValueError(f"Provider {provider} does not support image generation")
if isinstance(response, ImageResponse):
return await self._process_image_response(response, response_format, proxy, model=model, provider=provider)
return await self._process_image_response(
response,
response_format,
proxy,
model,
getattr(provider_handler, "__name__", None)
)
raise NoImageResponseError(f"Unexpected response type: {type(response)}")
async def _process_image_response(self, response: ImageResponse, response_format: str, proxy: str = None, model: str = None, provider: str = None) -> ImagesResponse:
async def process_image_item(session: aiohttp.ClientSession, image_data: str):
image_data_bytes = None
if image_data.startswith("http://") or image_data.startswith("https://"):
if response_format == "url":
return Image(url=image_data, revised_prompt=response.alt)
elif response_format == "b64_json":
# Fetch the image data and convert it to base64
image_data_bytes = await self._fetch_image(session, image_data)
b64_json = base64.b64encode(image_data_bytes).decode("utf-8")
return Image(b64_json=b64_json, url=image_data, revised_prompt=response.alt)
else:
# Assume image_data is base64 data or binary
if response_format == "url":
if image_data.startswith("data:image"):
# Remove the data URL scheme and get the base64 data
base64_data = image_data.split(",", 1)[-1]
else:
base64_data = image_data
# Decode the base64 data
image_data_bytes = base64.b64decode(base64_data)
if image_data_bytes:
file_name = self._save_image(image_data_bytes)
return Image(url=file_name, revised_prompt=response.alt)
else:
raise ValueError("Unable to process image data")
last_provider = get_last_provider(True)
async with aiohttp.ClientSession(cookies=response.get("cookies"), connector=get_connector(proxy=proxy)) as session:
return ImagesResponse(
await asyncio.gather(*[process_image_item(session, image_data) for image_data in response.get_list()]),
model=last_provider.get("model") if model is None else model,
provider=last_provider.get("name") if provider is None else provider
)
async def _fetch_image(self, session: aiohttp.ClientSession, url: str) -> bytes:
# Asynchronously fetch image data from the URL
async with session.get(url) as resp:
if resp.status == 200:
return await resp.read()
else:
raise RuntimeError(f"Failed to fetch image from {url}, status code {resp.status}")
def _save_image(self, image_data_bytes: bytes) -> str:
os.makedirs('generated_images', exist_ok=True)
image = to_image(image_data_bytes)
file_name = f"generated_images/image_{int(time.time())}_{random.randint(0, 10000)}.{EXTENSIONS_MAP[is_accepted_format(image_data_bytes)]}"
image.save(file_name)
return file_name
def create_variation(self, image: Union[str, bytes], model: str = None, provider: ProviderType = None, response_format: str = "url", **kwargs) -> ImagesResponse:
def create_variation(
self,
image: Union[str, bytes],
model: str = None,
provider: Optional[ProviderType] = None,
response_format: str = "url",
**kwargs
) -> ImagesResponse:
return asyncio.run(self.async_create_variation(
image, model, provider, response_format, **kwargs
))
async def async_create_variation(self, image: Union[str, bytes], model: str = None, provider: ProviderType = None, response_format: str = "url", proxy: str = None, **kwargs) -> ImagesResponse:
async def async_create_variation(
self,
image: ImageType,
model: Optional[str] = None,
provider: Optional[ProviderType] = None,
response_format: str = "url",
proxy: Optional[str] = None,
**kwargs
) -> ImagesResponse:
if provider is None:
provider = self.models.get(model, provider or self.provider or BingCreateImages)
if provider is None:
raise ValueError(f"Unknown model: {model}")
raise ModelNotFoundError(f"Unknown model: {model}")
if isinstance(provider, str):
provider = convert_to_provider(provider)
if proxy is None:
@ -355,38 +354,61 @@ class Images:
if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider):
messages = [{"role": "user", "content": "create a variation of this image"}]
image_data = to_data_uri(image)
generator = None
try:
generator = provider.create_async_generator(model, messages, image=image_data, response_format=response_format, proxy=proxy, **kwargs)
async for response in generator:
if isinstance(response, ImageResponse):
return self._process_image_response(response)
except RuntimeError as e:
if "async generator ignored GeneratorExit" in str(e):
logging.warning("Generator ignored GeneratorExit in create_variation, handling gracefully")
else:
raise
generator = provider.create_async_generator(model, messages, image=image, response_format=response_format, proxy=proxy, **kwargs)
async for chunk in generator:
if isinstance(chunk, ImageResponse):
response = chunk
break
finally:
if generator and hasattr(generator, 'aclose'):
await safe_aclose(generator)
logging.info("AsyncGeneratorProvider processing completed in create_variation")
elif hasattr(provider, 'create_variation'):
if asyncio.iscoroutinefunction(provider.create_variation):
response = await provider.create_variation(image, model=model, response_format=response_format, proxy=proxy, **kwargs)
else:
response = provider.create_variation(image, model=model, response_format=response_format, proxy=proxy, **kwargs)
if isinstance(response, str):
response = ImageResponse([response])
return self._process_image_response(response)
else:
raise ValueError(f"Provider {provider} does not support image variation")
raise NoImageResponseError(f"Provider {provider} does not support image variation")
if isinstance(response, str):
response = ImageResponse([response])
if isinstance(response, ImageResponse):
return self._process_image_response(response, response_format, proxy, model, getattr(provider, "__name__", None))
raise NoImageResponseError(f"Unexpected response type: {type(response)}")
async def _process_image_response(
self,
response: ImageResponse,
response_format: str,
proxy: str = None,
model: Optional[str] = None,
provider: Optional[str] = None
) -> list[Image]:
if response_format in ("url", "b64_json"):
images = await copy_images(response.get_list(), response.options.get("cookies"), proxy)
async def process_image_item(image_file: str) -> Image:
if response_format == "b64_json":
with open(os.path.join(images_dir, os.path.basename(image_file)), "rb") as file:
image_data = base64.b64encode(file.read()).decode()
return Image(url=image_file, b64_json=image_data, revised_prompt=response.alt)
return Image(url=image_file, revised_prompt=response.alt)
images = await asyncio.gather(*[process_image_item(image) for image in images])
else:
images = [Image(url=image, revised_prompt=response.alt) for image in response.get_list()]
last_provider = get_last_provider(True)
return ImagesResponse(
images,
model=last_provider.get("model") if model is None else model,
provider=last_provider.get("name") if provider is None else provider
)
class AsyncClient(BaseClient):
def __init__(
self,
provider: ProviderType = None,
image_provider: ImageProvider = None,
provider: Optional[ProviderType] = None,
image_provider: Optional[ImageProvider] = None,
**kwargs
) -> None:
super().__init__(**kwargs)
@ -396,11 +418,11 @@ class AsyncClient(BaseClient):
class AsyncChat:
completions: AsyncCompletions
def __init__(self, client: AsyncClient, provider: ProviderType = None):
def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None):
self.completions = AsyncCompletions(client, provider)
class AsyncCompletions:
def __init__(self, client: AsyncClient, provider: ProviderType = None):
def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None):
self.client: AsyncClient = client
self.provider: ProviderType = provider
@ -408,18 +430,18 @@ class AsyncCompletions:
self,
messages: Messages,
model: str,
provider: ProviderType = None,
stream: bool = False,
proxy: str = None,
response_format: dict = None,
max_tokens: int = None,
stop: Union[list[str], str] = None,
api_key: str = None,
ignored: list[str] = None,
ignore_working: bool = False,
ignore_stream: bool = False,
provider: Optional[ProviderType] = None,
stream: Optional[bool] = False,
proxy: Optional[str] = None,
response_format: Optional[dict] = None,
max_tokens: Optional[int] = None,
stop: Optional[Union[list[str], str]] = None,
api_key: Optional[str] = None,
ignored: Optional[list[str]] = None,
ignore_working: Optional[bool] = False,
ignore_stream: Optional[bool] = False,
**kwargs
) -> Union[Coroutine[ChatCompletion], AsyncIterator[ChatCompletionChunk]]:
) -> Union[Coroutine[ChatCompletion], AsyncIterator[ChatCompletionChunk, BaseConversation]]:
model, provider = get_model_and_provider(
model,
self.provider if provider is None else provider,
@ -450,15 +472,29 @@ class AsyncCompletions:
return response if stream else anext(response)
class AsyncImages(Images):
def __init__(self, client: AsyncClient, provider: ImageProvider = None):
def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None):
self.client: AsyncClient = client
self.provider: ImageProvider = provider
self.provider: Optional[ProviderType] = provider
self.models: ImageModels = ImageModels(client)
async def generate(self, prompt: str, model: str = None, provider: ProviderType = None, response_format: str = "url", **kwargs) -> ImagesResponse:
async def generate(
self,
prompt: str,
model: Optional[str] = None,
provider: Optional[ProviderType] = None,
response_format: str = "url",
**kwargs
) -> ImagesResponse:
return await self.async_generate(prompt, model, provider, response_format, **kwargs)
async def create_variation(self, image: Union[str, bytes], model: str = None, provider: ProviderType = None, response_format: str = "url", **kwargs) -> ImagesResponse:
async def create_variation(
self,
image: ImageType,
model: str = None,
provider: ProviderType = None,
response_format: str = "url",
**kwargs
) -> ImagesResponse:
return await self.async_create_variation(
image, model, provider, response_format, **kwargs
)
)

View File

@ -6,7 +6,7 @@ import threading
import logging
import asyncio
from typing import AsyncIterator, Iterator, AsyncGenerator
from typing import AsyncIterator, Iterator, AsyncGenerator, Optional
def filter_json(text: str) -> str:
"""
@ -23,7 +23,7 @@ def filter_json(text: str) -> str:
return match.group("code")
return text
def find_stop(stop, content: str, chunk: str = None):
def find_stop(stop: Optional[list[str]], content: str, chunk: str = None):
first = -1
word = None
if stop is not None:

View File

@ -1,7 +1,5 @@
from __future__ import annotations
from .types import Client, ImageProvider
from ..models import ModelUtils
class ImageModels():

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import os
from .stubs import ChatCompletion, ChatCompletionChunk
from ..providers.types import BaseProvider, ProviderType, FinishReason
from ..providers.types import BaseProvider
from typing import Union, Iterator, AsyncIterator
ImageProvider = Union[BaseProvider, object]

View File

@ -111,6 +111,11 @@
<input type="checkbox" id="hide-systemPrompt" />
<label for="hide-systemPrompt" class="toogle" title="For more space on phones"></label>
</div>
<div class="field">
<span class="label">Download generated images</span>
<input type="checkbox" id="download_images" checked/>
<label for="download_images" class="toogle" title="Download and save generated images to /generated_images"></label>
</div>
<div class="field">
<span class="label">Auto continue in ChatGPT</span>
<input id="auto_continue" type="checkbox" name="auto_continue" checked/>

View File

@ -498,6 +498,7 @@ body {
gap: 12px;
cursor: pointer;
animation: show_popup 0.4s;
height: 28px;
}
.toolbar .regenerate {

View File

@ -428,10 +428,14 @@ async function add_message_chunk(message, message_id) {
p.innerText = message.log;
log_storage.appendChild(p);
}
window.scrollTo(0, 0);
if (message_box.scrollTop >= message_box.scrollHeight - message_box.clientHeight - 100) {
message_box.scrollTo({ top: message_box.scrollHeight, behavior: "auto" });
let scroll_down = ()=>{
if (message_box.scrollTop >= message_box.scrollHeight - message_box.clientHeight - 100) {
window.scrollTo(0, 0);
message_box.scrollTo({ top: message_box.scrollHeight, behavior: "auto" });
}
}
setTimeout(scroll_down, 200);
setTimeout(scroll_down, 1000);
}
cameraInput?.addEventListener("click", (e) => {
@ -492,6 +496,7 @@ const ask_gpt = async (message_index = -1, message_id) => {
const file = input && input.files.length > 0 ? input.files[0] : null;
const provider = providerSelect.options[providerSelect.selectedIndex].value;
const auto_continue = document.getElementById("auto_continue")?.checked;
const download_images = document.getElementById("download_images")?.checked;
let api_key = get_api_key_by_provider(provider);
await api("conversation", {
id: message_id,
@ -501,13 +506,13 @@ const ask_gpt = async (message_index = -1, message_id) => {
provider: provider,
messages: messages,
auto_continue: auto_continue,
download_images: download_images,
api_key: api_key,
}, file, message_id);
if (!error_storage[message_id]) {
html = markdown_render(message_storage[message_id]);
content_map.inner.innerHTML = html;
highlight(content_map.inner);
if (imageInput) imageInput.value = "";
if (cameraInput) cameraInput.value = "";
if (fileInput) fileInput.value = "";
@ -1239,8 +1244,7 @@ async function load_version() {
if (versions["version"] != versions["latest_version"]) {
let release_url = 'https://github.com/xtekky/gpt4free/releases/tag/' + versions["latest_version"];
let title = `New version: ${versions["latest_version"]}`;
text += `<a href="${release_url}" target="_blank" title="${title}">${versions["version"]}</a> `;
text += `<i class="fa-solid fa-rotate"></i>`
text += `<a href="${release_url}" target="_blank" title="${title}">${versions["version"]}</a> 🆕`;
} else {
text += versions["version"];
}

View File

@ -2,34 +2,22 @@ from __future__ import annotations
import logging
import os
import uuid
import asyncio
import time
from aiohttp import ClientSession
from typing import Iterator, Optional
from typing import Iterator
from flask import send_from_directory
from inspect import signature
from g4f import version, models
from g4f import get_last_provider, ChatCompletion
from g4f.errors import VersionNotFoundError
from g4f.typing import Cookies
from g4f.image import ImagePreview, ImageResponse, is_accepted_format, extract_data_uri
from g4f.requests.aiohttp import get_connector
from g4f.image import ImagePreview, ImageResponse, copy_images, ensure_images_dir, images_dir
from g4f.Provider import ProviderType, __providers__, __map__
from g4f.providers.base_provider import ProviderModelMixin, FinishReason
from g4f.providers.conversation import BaseConversation
from g4f.providers.base_provider import ProviderModelMixin
from g4f.providers.response import BaseConversation, FinishReason
from g4f.client.service import convert_to_provider
from g4f import debug
logger = logging.getLogger(__name__)
# Define the directory for generated images
images_dir = "./generated_images"
# Function to ensure the images directory exists
def ensure_images_dir():
if not os.path.exists(images_dir):
os.makedirs(images_dir)
conversations: dict[dict[str, BaseConversation]] = {}
class Api:
@ -42,7 +30,10 @@ class Api:
if provider in __map__:
provider: ProviderType = __map__[provider]
if issubclass(provider, ProviderModelMixin):
models = provider.get_models() if api_key is None else provider.get_models(api_key=api_key)
if api_key is not None and "api_key" in signature(provider.get_models).parameters:
models = provider.get_models(api_key=api_key)
else:
models = provider.get_models()
return [
{
"model": model,
@ -90,7 +81,7 @@ class Api:
def get_providers() -> list[str]:
return {
provider.__name__: (provider.label if hasattr(provider, "label") else provider.__name__)
+ (" (Image Generation)" if hasattr(provider, "image_models") else "")
+ (" (Image Generation)" if getattr(provider, "image_models", None) else "")
+ (" (Image Upload)" if getattr(provider, "default_vision_model", None) else "")
+ (" (WebDriver)" if "webdriver" in provider.get_parameters() else "")
+ (" (Auth)" if provider.needs_auth else "")
@ -120,16 +111,23 @@ class Api:
api_key = json_data.get("api_key")
if api_key is not None:
kwargs["api_key"] = api_key
if json_data.get('web_search'):
if provider:
kwargs['web_search'] = True
else:
from .internet import get_search_message
messages[-1]["content"] = get_search_message(messages[-1]["content"])
do_web_search = json_data.get('web_search')
if do_web_search and provider:
provider_handler = convert_to_provider(provider)
if hasattr(provider_handler, "get_parameters"):
if "web_search" in provider_handler.get_parameters():
kwargs['web_search'] = True
do_web_search = False
if do_web_search:
from .internet import get_search_message
messages[-1]["content"] = get_search_message(messages[-1]["content"])
if json_data.get("auto_continue"):
kwargs['auto_continue'] = True
conversation_id = json_data.get("conversation_id")
if conversation_id and provider in conversations and conversation_id in conversations[provider]:
kwargs["conversation"] = conversations[provider][conversation_id]
if conversation_id and provider:
if provider in conversations and conversation_id in conversations[provider]:
kwargs["conversation"] = conversations[provider][conversation_id]
return {
"model": model,
@ -141,7 +139,7 @@ class Api:
**kwargs
}
def _create_response_stream(self, kwargs: dict, conversation_id: str, provider: str) -> Iterator:
def _create_response_stream(self, kwargs: dict, conversation_id: str, provider: str, download_images: bool = True) -> Iterator:
if debug.logging:
debug.logs = []
print_callback = debug.log_handler
@ -163,18 +161,22 @@ class Api:
first = False
yield self._format_json("provider", get_last_provider(True))
if isinstance(chunk, BaseConversation):
if provider not in conversations:
conversations[provider] = {}
conversations[provider][conversation_id] = chunk
yield self._format_json("conversation", conversation_id)
if provider:
if provider not in conversations:
conversations[provider] = {}
conversations[provider][conversation_id] = chunk
yield self._format_json("conversation", conversation_id)
elif isinstance(chunk, Exception):
logger.exception(chunk)
yield self._format_json("message", get_error_message(chunk))
elif isinstance(chunk, ImagePreview):
yield self._format_json("preview", chunk.to_string())
elif isinstance(chunk, ImageResponse):
images = asyncio.run(self._copy_images(chunk.get_list(), chunk.options.get("cookies")))
yield self._format_json("content", str(ImageResponse(images, chunk.alt)))
images = chunk
if download_images:
images = asyncio.run(copy_images(chunk.get_list(), chunk.options.get("cookies")))
images = ImageResponse(images, chunk.alt)
yield self._format_json("content", str(images))
elif not isinstance(chunk, FinishReason):
yield self._format_json("content", str(chunk))
if debug.logs:
@ -185,31 +187,6 @@ class Api:
logger.exception(e)
yield self._format_json('error', get_error_message(e))
async def _copy_images(self, images: list[str], cookies: Optional[Cookies] = None):
ensure_images_dir()
async with ClientSession(
connector=get_connector(None, os.environ.get("G4F_PROXY")),
cookies=cookies
) as session:
async def copy_image(image: str) -> str:
target = os.path.join(images_dir, f"{int(time.time())}_{str(uuid.uuid4())}")
if image.startswith("data:"):
with open(target, "wb") as f:
f.write(extract_data_uri(image))
else:
async with session.get(image) as response:
with open(target, "wb") as f:
async for chunk in response.content.iter_any():
f.write(chunk)
with open(target, "rb") as f:
extension = is_accepted_format(f.read(12)).split("/")[-1]
extension = "jpg" if extension == "jpeg" else extension
new_target = f"{target}.{extension}"
os.rename(target, new_target)
return f"/images/{os.path.basename(new_target)}"
return await asyncio.gather(*[copy_image(image) for image in images])
def _format_json(self, response_type: str, content):
return {
'type': response_type,
@ -221,4 +198,4 @@ def get_error_message(exception: Exception) -> str:
provider = get_last_provider()
if provider is None:
return message
return f"{provider.__name__}: {message}"
return f"{provider.__name__}: {message}"

View File

@ -89,7 +89,12 @@ class Backend_Api(Api):
kwargs = self._prepare_conversation_kwargs(json_data, kwargs)
return self.app.response_class(
self._create_response_stream(kwargs, json_data.get("conversation_id"), json_data.get("provider")),
self._create_response_stream(
kwargs,
json_data.get("conversation_id"),
json_data.get("provider"),
json_data.get("download_images", True),
),
mimetype='text/event-stream'
)

View File

@ -8,12 +8,14 @@ try:
except ImportError:
has_requirements = False
from ...errors import MissingRequirementsError
from ... import debug
import asyncio
class SearchResults():
def __init__(self, results: list):
def __init__(self, results: list, used_words: int):
self.results = results
self.used_words = used_words
def __iter__(self):
yield from self.results
@ -104,7 +106,8 @@ async def search(query: str, n_results: int = 5, max_words: int = 2500, add_text
region="wt-wt",
safesearch="moderate",
timelimit="y",
max_results=n_results
max_results=n_results,
backend="html"
):
results.append(SearchResultEntry(
result["title"],
@ -120,6 +123,7 @@ async def search(query: str, n_results: int = 5, max_words: int = 2500, add_text
texts = await asyncio.gather(*requests)
formatted_results = []
used_words = 0
left_words = max_words
for i, entry in enumerate(results):
if add_text:
@ -132,13 +136,14 @@ async def search(query: str, n_results: int = 5, max_words: int = 2500, add_text
left_words -= entry.snippet.count(" ")
if 0 > left_words:
break
used_words = max_words - left_words
formatted_results.append(entry)
return SearchResults(formatted_results)
return SearchResults(formatted_results, used_words)
def get_search_message(prompt) -> str:
def get_search_message(prompt, n_results: int = 5, max_words: int = 2500) -> str:
try:
search_results = asyncio.run(search(prompt))
search_results = asyncio.run(search(prompt, n_results, max_words))
message = f"""
{search_results}
@ -149,7 +154,8 @@ Make sure to add the sources of cites using [[Number]](Url) notation after the r
User request:
{prompt}
"""
debug.log(f"Web search: '{prompt.strip()[:50]}...' {search_results.used_words} Words")
return message
except Exception as e:
print("Couldn't do web search:", e)
debug.log(f"Couldn't do web search: {e.__class__.__name__}: {e}")
return prompt

View File

@ -1,9 +1,13 @@
from __future__ import annotations
import os
import re
import time
import uuid
from io import BytesIO
import base64
from .typing import ImageType, Union, Image
import asyncio
from aiohttp import ClientSession
try:
from PIL.Image import open as open_image, new as new_image
@ -12,7 +16,10 @@ try:
except ImportError:
has_requirements = False
from .typing import ImageType, Union, Image, Optional, Cookies
from .errors import MissingRequirementsError
from .providers.response import ResponseType
from .requests.aiohttp import get_connector
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'webp', 'svg'}
@ -23,10 +30,12 @@ EXTENSIONS_MAP: dict[str, str] = {
"image/webp": "webp",
}
# Define the directory for generated images
images_dir = "./generated_images"
def fix_url(url:str) -> str:
""" replace ' ' by '+' (to be markdown compliant)"""
return url.replace(" ","+")
def to_image(image: ImageType, is_svg: bool = False) -> Image:
"""
@ -223,7 +232,6 @@ def format_images_markdown(images: Union[str, list], alt: str, preview: Union[st
preview = [preview.replace('{image}', image) if preview else image for image in images]
result = "\n".join(
f"[![#{idx+1} {alt}]({fix_url(preview[idx])})]({fix_url(image)})"
#f'[<img src="{preview[idx]}" width="200" alt="#{idx+1} {alt}">]({image})'
for idx, image in enumerate(images)
)
start_flag = "<!-- generated images start -->\n"
@ -260,7 +268,39 @@ def to_data_uri(image: ImageType) -> str:
return f"data:{is_accepted_format(data)};base64,{data_base64}"
return image
class ImageResponse:
# Function to ensure the images directory exists
def ensure_images_dir():
if not os.path.exists(images_dir):
os.makedirs(images_dir)
async def copy_images(images: list[str], cookies: Optional[Cookies] = None, proxy: Optional[str] = None):
ensure_images_dir()
async with ClientSession(
connector=get_connector(
proxy=os.environ.get("G4F_PROXY") if proxy is None else proxy
),
cookies=cookies
) as session:
async def copy_image(image: str) -> str:
target = os.path.join(images_dir, f"{int(time.time())}_{str(uuid.uuid4())}")
if image.startswith("data:"):
with open(target, "wb") as f:
f.write(extract_data_uri(image))
else:
async with session.get(image) as response:
with open(target, "wb") as f:
async for chunk in response.content.iter_chunked(4096):
f.write(chunk)
with open(target, "rb") as f:
extension = is_accepted_format(f.read(12)).split("/")[-1]
extension = "jpg" if extension == "jpeg" else extension
new_target = f"{target}.{extension}"
os.rename(target, new_target)
return f"/images/{os.path.basename(new_target)}"
return await asyncio.gather(*[copy_image(image) for image in images])
class ImageResponse(ResponseType):
def __init__(
self,
images: Union[str, list],

View File

@ -10,7 +10,8 @@ from inspect import signature, Parameter
from typing import Callable, Union
from ..typing import CreateResult, AsyncResult, Messages
from .types import BaseProvider, FinishReason
from .types import BaseProvider
from .response import FinishReason, BaseConversation
from ..errors import NestAsyncioError, ModelNotSupportedError
from .. import debug
@ -100,7 +101,7 @@ class AbstractProvider(BaseProvider):
)
@classmethod
def get_parameters(cls) -> dict:
def get_parameters(cls) -> dict[str, Parameter]:
return signature(
cls.create_async_generator if issubclass(cls, AsyncGeneratorProvider) else
cls.create_async if issubclass(cls, AsyncProvider) else
@ -258,7 +259,7 @@ class AsyncGeneratorProvider(AsyncProvider):
"""
return "".join([
str(chunk) async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)
if not isinstance(chunk, (Exception, FinishReason))
if not isinstance(chunk, (Exception, FinishReason, BaseConversation))
])
@staticmethod
@ -307,4 +308,4 @@ class ProviderModelMixin:
elif model not in cls.get_models() and cls.models:
raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
debug.last_model = model
return model
return model

View File

@ -1,2 +0,0 @@
class BaseConversation:
...

26
g4f/providers/response.py Normal file
View File

@ -0,0 +1,26 @@
from __future__ import annotations
from abc import abstractmethod
class ResponseType:
@abstractmethod
def __str__(self) -> str:
pass
class FinishReason():
def __init__(self, reason: str):
self.reason = reason
def __str__(self) -> str:
return ""
class Sources(ResponseType):
def __init__(self, sources: list[dict[str, str]]) -> None:
self.list = sources
def __str__(self) -> str:
return "\n\n" + ("\n".join([f"{idx+1}. [{link['title']}]({link['url']})" for idx, link in enumerate(self.list)]))
class BaseConversation(ResponseType):
def __str__(self) -> str:
return ""

View File

@ -3,7 +3,6 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Union, Dict, Type
from ..typing import Messages, CreateResult
from .conversation import BaseConversation
class BaseProvider(ABC):
"""
@ -98,10 +97,6 @@ class BaseRetryProvider(BaseProvider):
ProviderType = Union[Type[BaseProvider], BaseRetryProvider]
class FinishReason():
def __init__(self, reason: str):
self.reason = reason
class Streaming():
def __init__(self, data: str) -> None:
self.data = data

View File

@ -10,11 +10,13 @@ if sys.version_info >= (3, 8):
from typing import TypedDict
else:
from typing_extensions import TypedDict
from .providers.response import ResponseType
SHA256 = NewType('sha_256_hash', str)
CreateResult = Iterator[str]
AsyncResult = AsyncIterator[str]
Messages = List[Dict[str, Union[str,List[Dict[str,Union[str,Dict[str,str]]]]]]]
CreateResult = Iterator[Union[str, ResponseType]]
AsyncResult = AsyncIterator[Union[str, ResponseType]]
Messages = List[Dict[str, Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]]]
Cookies = Dict[str, str]
ImageType = Union[str, bytes, IO, Image, None]