mirror of
https://github.com/xtekky/gpt4free.git
synced 2024-11-28 11:07:24 +03:00
Improve download of generated images, serve images in the api (#2391)
* Improve download of generated images, serve images in the api Add support for conversation handling in the api * Add orginal prompt to image response * Add download images option in gui, fix loading model list in Airforce * Add download images option in gui, fix loading model list in Airforce
This commit is contained in:
parent
c959d9b469
commit
ffb4b0d162
@ -1,13 +1,17 @@
|
||||
import requests
|
||||
import json
|
||||
import uuid
|
||||
|
||||
url = "http://localhost:1337/v1/chat/completions"
|
||||
conversation_id = str(uuid.uuid4())
|
||||
body = {
|
||||
"model": "",
|
||||
"provider": "",
|
||||
"provider": "Copilot",
|
||||
"stream": True,
|
||||
"messages": [
|
||||
{"role": "user", "content": "What can you do? Who are you?"}
|
||||
]
|
||||
{"role": "user", "content": "Hello, i am Heiner. How are you?"}
|
||||
],
|
||||
"conversation_id": conversation_id
|
||||
}
|
||||
response = requests.post(url, json=body, stream=True)
|
||||
response.raise_for_status()
|
||||
@ -21,4 +25,27 @@ for line in response.iter_lines():
|
||||
print(json_data.get("choices", [{"delta": {}}])[0]["delta"].get("content", ""), end="")
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
print()
|
||||
print()
|
||||
print()
|
||||
print()
|
||||
body = {
|
||||
"model": "",
|
||||
"provider": "Copilot",
|
||||
"stream": True,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Tell me somethings about my name"}
|
||||
],
|
||||
"conversation_id": conversation_id
|
||||
}
|
||||
response = requests.post(url, json=body, stream=True)
|
||||
response.raise_for_status()
|
||||
for line in response.iter_lines():
|
||||
if line.startswith(b"data: "):
|
||||
try:
|
||||
json_data = json.loads(line[6:])
|
||||
if json_data.get("error"):
|
||||
print(json_data)
|
||||
break
|
||||
print(json_data.get("choices", [{"delta": {}}])[0]["delta"].get("content", ""), end="")
|
||||
except json.JSONDecodeError:
|
||||
pass
|
@ -20,7 +20,7 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
working = True
|
||||
supports_system_message = True
|
||||
supports_message_history = True
|
||||
|
||||
|
||||
@classmethod
|
||||
def fetch_completions_models(cls):
|
||||
response = requests.get('https://api.airforce/models', verify=False)
|
||||
@ -34,19 +34,20 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
completions_models = fetch_completions_models.__func__(None)
|
||||
imagine_models = fetch_imagine_models.__func__(None)
|
||||
|
||||
default_model = "gpt-4o-mini"
|
||||
default_image_model = "flux"
|
||||
additional_models_imagine = ["stable-diffusion-xl-base", "stable-diffusion-xl-lightning", "Flux-1.1-Pro"]
|
||||
text_models = completions_models
|
||||
image_models = [*imagine_models, *additional_models_imagine]
|
||||
models = [
|
||||
*text_models,
|
||||
*image_models,
|
||||
]
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_models(cls):
|
||||
if not cls.models:
|
||||
cls.image_models = [*cls.fetch_imagine_models(), *cls.additional_models_imagine]
|
||||
cls.models = [
|
||||
*cls.fetch_completions_models(),
|
||||
*cls.image_models
|
||||
]
|
||||
return cls.models
|
||||
|
||||
model_aliases = {
|
||||
### completions ###
|
||||
# openchat
|
||||
@ -100,7 +101,6 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
model = cls.get_model(model)
|
||||
|
||||
if model in cls.image_models:
|
||||
return cls._generate_image(model, messages, proxy, seed, size)
|
||||
else:
|
||||
|
@ -20,15 +20,15 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
supports_system_message = True
|
||||
supports_message_history = True
|
||||
_last_validated_value = None
|
||||
|
||||
|
||||
default_model = 'blackboxai'
|
||||
default_vision_model = default_model
|
||||
default_image_model = 'Image Generation'
|
||||
image_models = ['Image Generation', 'repomap']
|
||||
vision_models = [default_model, 'gpt-4o', 'gemini-pro', 'gemini-1.5-flash', 'llama-3.1-8b', 'llama-3.1-70b', 'llama-3.1-405b']
|
||||
|
||||
|
||||
userSelectedModel = ['gpt-4o', 'gemini-pro', 'claude-sonnet-3.5', 'blackboxai-pro']
|
||||
|
||||
|
||||
agentMode = {
|
||||
'Image Generation': {'mode': True, 'id': "ImageGenerationLV45LJp", 'name': "Image Generation"},
|
||||
}
|
||||
@ -77,22 +77,21 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
}
|
||||
|
||||
additional_prefixes = {
|
||||
'gpt-4o': '@gpt-4o',
|
||||
'gemini-pro': '@gemini-pro',
|
||||
'claude-sonnet-3.5': '@claude-sonnet'
|
||||
}
|
||||
'gpt-4o': '@gpt-4o',
|
||||
'gemini-pro': '@gemini-pro',
|
||||
'claude-sonnet-3.5': '@claude-sonnet'
|
||||
}
|
||||
|
||||
model_prefixes = {
|
||||
**{mode: f"@{value['id']}" for mode, value in trendingAgentMode.items()
|
||||
if mode not in ["gemini-1.5-flash", "llama-3.1-8b", "llama-3.1-70b", "llama-3.1-405b", "repomap"]},
|
||||
**additional_prefixes
|
||||
}
|
||||
**{
|
||||
mode: f"@{value['id']}" for mode, value in trendingAgentMode.items()
|
||||
if mode not in ["gemini-1.5-flash", "llama-3.1-8b", "llama-3.1-70b", "llama-3.1-405b", "repomap"]
|
||||
},
|
||||
**additional_prefixes
|
||||
}
|
||||
|
||||
|
||||
models = list(dict.fromkeys([default_model, *userSelectedModel, *list(agentMode.keys()), *list(trendingAgentMode.keys())]))
|
||||
|
||||
|
||||
|
||||
model_aliases = {
|
||||
"gemini-flash": "gemini-1.5-flash",
|
||||
"claude-3.5-sonnet": "claude-sonnet-3.5",
|
||||
@ -131,12 +130,11 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
|
||||
return cls._last_validated_value
|
||||
|
||||
|
||||
@staticmethod
|
||||
def generate_id(length=7):
|
||||
characters = string.ascii_letters + string.digits
|
||||
return ''.join(random.choice(characters) for _ in range(length))
|
||||
|
||||
|
||||
@classmethod
|
||||
def add_prefix_to_messages(cls, messages: Messages, model: str) -> Messages:
|
||||
prefix = cls.model_prefixes.get(model, "")
|
||||
@ -157,6 +155,7 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
prompt: str = None,
|
||||
proxy: str = None,
|
||||
web_search: bool = False,
|
||||
image: ImageType = None,
|
||||
@ -191,7 +190,7 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
'sec-fetch-site': 'same-origin',
|
||||
'user-agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36'
|
||||
}
|
||||
|
||||
|
||||
data = {
|
||||
"messages": messages,
|
||||
"id": message_id,
|
||||
@ -221,26 +220,25 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response:
|
||||
response.raise_for_status()
|
||||
response_text = await response.text()
|
||||
|
||||
|
||||
if model in cls.image_models:
|
||||
image_matches = re.findall(r'!\[.*?\]\((https?://[^\)]+)\)', response_text)
|
||||
if image_matches:
|
||||
image_url = image_matches[0]
|
||||
image_response = ImageResponse(images=[image_url], alt="Generated Image")
|
||||
yield image_response
|
||||
yield ImageResponse(image_url, prompt)
|
||||
return
|
||||
|
||||
response_text = re.sub(r'Generated by BLACKBOX.AI, try unlimited chat https://www.blackbox.ai', '', response_text, flags=re.DOTALL)
|
||||
|
||||
|
||||
json_match = re.search(r'\$~~~\$(.*?)\$~~~\$', response_text, re.DOTALL)
|
||||
if json_match:
|
||||
search_results = json.loads(json_match.group(1))
|
||||
answer = response_text.split('$~~~$')[-1].strip()
|
||||
|
||||
|
||||
formatted_response = f"{answer}\n\n**Source:**"
|
||||
for i, result in enumerate(search_results, 1):
|
||||
formatted_response += f"\n{i}. {result['title']}: {result['link']}"
|
||||
|
||||
|
||||
yield formatted_response
|
||||
else:
|
||||
yield response_text.strip()
|
||||
|
@ -57,6 +57,7 @@ class Copilot(AbstractProvider):
|
||||
image: ImageType = None,
|
||||
conversation: Conversation = None,
|
||||
return_conversation: bool = False,
|
||||
web_search: bool = True,
|
||||
**kwargs
|
||||
) -> CreateResult:
|
||||
if not has_curl_cffi:
|
||||
@ -124,12 +125,14 @@ class Copilot(AbstractProvider):
|
||||
is_started = False
|
||||
msg = None
|
||||
image_prompt: str = None
|
||||
last_msg = None
|
||||
while True:
|
||||
try:
|
||||
msg = wss.recv()[0]
|
||||
msg = json.loads(msg)
|
||||
except:
|
||||
break
|
||||
last_msg = msg
|
||||
if msg.get("event") == "appendText":
|
||||
is_started = True
|
||||
yield msg.get("text")
|
||||
@ -139,8 +142,12 @@ class Copilot(AbstractProvider):
|
||||
yield ImageResponse(msg.get("url"), image_prompt, {"preview": msg.get("thumbnailUrl")})
|
||||
elif msg.get("event") == "done":
|
||||
break
|
||||
elif msg.get("event") == "error":
|
||||
raise RuntimeError(f"Error: {msg}")
|
||||
elif msg.get("event") not in ["received", "startMessage", "citation", "partCompleted"]:
|
||||
debug.log(f"Copilot Message: {msg}")
|
||||
if not is_started:
|
||||
raise RuntimeError(f"Last message: {msg}")
|
||||
raise RuntimeError(f"Invalid response: {last_msg}")
|
||||
|
||||
@classmethod
|
||||
async def get_access_token_and_cookies(cls, proxy: str = None):
|
||||
|
@ -3,7 +3,6 @@ from __future__ import annotations
|
||||
from urllib.parse import quote
|
||||
import random
|
||||
import requests
|
||||
from sys import maxsize
|
||||
from aiohttp import ClientSession
|
||||
|
||||
from ..typing import AsyncResult, Messages
|
||||
@ -40,6 +39,7 @@ class PollinationsAI(OpenaiAPI):
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
prompt: str = None,
|
||||
api_base: str = "https://text.pollinations.ai/openai",
|
||||
api_key: str = None,
|
||||
proxy: str = None,
|
||||
@ -49,9 +49,10 @@ class PollinationsAI(OpenaiAPI):
|
||||
if model:
|
||||
model = cls.get_model(model)
|
||||
if model in cls.image_models:
|
||||
prompt = messages[-1]["content"]
|
||||
if prompt is None:
|
||||
prompt = messages[-1]["content"]
|
||||
if seed is None:
|
||||
seed = random.randint(0, maxsize)
|
||||
seed = random.randint(0, 100000)
|
||||
image = f"https://image.pollinations.ai/prompt/{quote(prompt)}?width=1024&height=1024&seed={int(seed)}&nofeed=true&nologo=true&model={quote(model)}"
|
||||
yield ImageResponse(image, prompt)
|
||||
return
|
||||
|
@ -6,6 +6,8 @@ from aiohttp import ClientSession, ContentTypeError
|
||||
|
||||
from ..typing import AsyncResult, Messages
|
||||
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from ..requests.aiohttp import get_connector
|
||||
from ..requests.raise_for_status import raise_for_status
|
||||
from .helper import format_prompt
|
||||
from ..image import ImageResponse
|
||||
|
||||
@ -32,10 +34,8 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
'yorickvp/llava-13b',
|
||||
]
|
||||
|
||||
|
||||
|
||||
models = text_models + image_models
|
||||
|
||||
|
||||
model_aliases = {
|
||||
# image_models
|
||||
"sd-3": "stability-ai/stable-diffusion-3",
|
||||
@ -56,23 +56,14 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
# text_models
|
||||
"google-deepmind/gemma-2b-it": "dff94eaf770e1fc211e425a50b51baa8e4cac6c39ef074681f9e39d778773626",
|
||||
"yorickvp/llava-13b": "80537f9eead1a5bfa72d5ac6ea6414379be41d4d4f6679fd776e9535d1eb58bb",
|
||||
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_model(cls, model: str) -> str:
|
||||
if model in cls.models:
|
||||
return model
|
||||
elif model in cls.model_aliases:
|
||||
return cls.model_aliases[model]
|
||||
else:
|
||||
return cls.default_model
|
||||
|
||||
@classmethod
|
||||
async def create_async_generator(
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
prompt: str = None,
|
||||
proxy: str = None,
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
@ -96,29 +87,30 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
"user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36"
|
||||
}
|
||||
|
||||
async with ClientSession(headers=headers) as session:
|
||||
if model in cls.image_models:
|
||||
prompt = messages[-1]['content'] if messages else ""
|
||||
else:
|
||||
prompt = format_prompt(messages)
|
||||
|
||||
async with ClientSession(headers=headers, connector=get_connector(proxy=proxy)) as session:
|
||||
if prompt is None:
|
||||
if model in cls.image_models:
|
||||
prompt = messages[-1]['content']
|
||||
else:
|
||||
prompt = format_prompt(messages)
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
"version": cls.model_versions[model],
|
||||
"input": {"prompt": prompt},
|
||||
}
|
||||
|
||||
async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response:
|
||||
response.raise_for_status()
|
||||
|
||||
async with session.post(cls.api_endpoint, json=data) as response:
|
||||
await raise_for_status(response)
|
||||
result = await response.json()
|
||||
prediction_id = result['id']
|
||||
|
||||
|
||||
poll_url = f"https://homepage.replicate.com/api/poll?id={prediction_id}"
|
||||
max_attempts = 30
|
||||
delay = 5
|
||||
for _ in range(max_attempts):
|
||||
async with session.get(poll_url, proxy=proxy) as response:
|
||||
response.raise_for_status()
|
||||
async with session.get(poll_url) as response:
|
||||
await raise_for_status(response)
|
||||
try:
|
||||
result = await response.json()
|
||||
except ContentTypeError:
|
||||
@ -131,7 +123,7 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
if result['status'] == 'succeeded':
|
||||
if model in cls.image_models:
|
||||
image_url = result['output'][0]
|
||||
yield ImageResponse(image_url, "Generated image")
|
||||
yield ImageResponse(image_url, prompt)
|
||||
return
|
||||
else:
|
||||
for chunk in result['output']:
|
||||
@ -140,6 +132,6 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
elif result['status'] == 'failed':
|
||||
raise Exception(f"Prediction failed: {result.get('error')}")
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
|
||||
if result['status'] != 'succeeded':
|
||||
raise Exception("Prediction timed out")
|
||||
|
@ -9,7 +9,7 @@ from urllib.parse import urlencode
|
||||
from aiohttp import ClientSession
|
||||
|
||||
from ..typing import AsyncResult, Messages
|
||||
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, Sources
|
||||
from ..requests.raise_for_status import raise_for_status
|
||||
|
||||
class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
@ -23,7 +23,6 @@ class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
|
||||
default_model = 'gpt-4o-mini'
|
||||
models = [default_model, 'gpt-4o', 'o1-mini', 'claude-3.5-sonnet', 'grok-beta', 'gemini-1.5-pro', 'nova-pro']
|
||||
|
||||
model_aliases = {
|
||||
"llama-3.1-70b": "llama-3.1-70b-versatile",
|
||||
}
|
||||
@ -118,7 +117,7 @@ class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
|
||||
if 'url' in json_data and 'title' in json_data:
|
||||
if web_search:
|
||||
sources.append({'title': json_data['title'], 'url': json_data['url']})
|
||||
sources.append(json_data)
|
||||
|
||||
elif 'choices' in json_data:
|
||||
for choice in json_data['choices']:
|
||||
@ -128,5 +127,4 @@ class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
yield content
|
||||
|
||||
if web_search and sources:
|
||||
sources_text = '\n'.join([f"{i+1}. [{s['title']}]: {s['url']}" for i, s in enumerate(sources)])
|
||||
yield f"\n\n**Source:**\n{sources_text}"
|
||||
yield Sources(sources)
|
@ -1,4 +1,4 @@
|
||||
from ..providers.base_provider import *
|
||||
from ..providers.types import FinishReason, Streaming
|
||||
from ..providers.conversation import BaseConversation
|
||||
from ..providers.types import Streaming
|
||||
from ..providers.response import BaseConversation, Sources, FinishReason
|
||||
from .helper import get_cookies, format_prompt
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from ...requests import StreamSession, raise_for_status
|
||||
from ...errors import RateLimitError
|
||||
from ...providers.conversation import BaseConversation
|
||||
from ...providers.response import BaseConversation
|
||||
|
||||
class Conversation(BaseConversation):
|
||||
"""
|
||||
|
@ -28,13 +28,14 @@ class BingCreateImages(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
prompt: str = None,
|
||||
api_key: str = None,
|
||||
cookies: Cookies = None,
|
||||
proxy: str = None,
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
session = BingCreateImages(cookies, proxy, api_key)
|
||||
yield await session.generate(messages[-1]["content"])
|
||||
yield await session.generate(messages[-1]["content"] if prompt is None else prompt)
|
||||
|
||||
async def generate(self, prompt: str) -> ImageResponse:
|
||||
"""
|
||||
|
@ -29,9 +29,10 @@ class DeepInfraImage(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
prompt: str = None,
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
yield await cls.create_async(messages[-1]["content"], model, **kwargs)
|
||||
yield await cls.create_async(messages[-1]["content"] if prompt is None else prompt, model, **kwargs)
|
||||
|
||||
@classmethod
|
||||
async def create_async(
|
||||
|
@ -28,7 +28,7 @@ from ...requests.raise_for_status import raise_for_status
|
||||
from ...requests.aiohttp import StreamSession
|
||||
from ...image import ImageResponse, ImageRequest, to_image, to_bytes, is_accepted_format
|
||||
from ...errors import MissingAuthError, ResponseError
|
||||
from ...providers.conversation import BaseConversation
|
||||
from ...providers.response import BaseConversation
|
||||
from ..helper import format_cookies
|
||||
from ..openai.har_file import get_request_config, NoValidHarFileError
|
||||
from ..openai.har_file import RequestConfig, arkReq, arkose_url, start_url, conversation_url, backend_url, backend_anon_url
|
||||
|
@ -4,6 +4,7 @@ import logging
|
||||
import json
|
||||
import uvicorn
|
||||
import secrets
|
||||
import os
|
||||
|
||||
from fastapi import FastAPI, Response, Request
|
||||
from fastapi.responses import StreamingResponse, RedirectResponse, HTMLResponse, JSONResponse
|
||||
@ -13,13 +14,16 @@ from starlette.exceptions import HTTPException
|
||||
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from starlette.responses import FileResponse
|
||||
from pydantic import BaseModel
|
||||
from typing import Union, Optional
|
||||
|
||||
import g4f
|
||||
import g4f.debug
|
||||
from g4f.client import AsyncClient, ChatCompletion
|
||||
from g4f.providers.response import BaseConversation
|
||||
from g4f.client.helper import filter_none
|
||||
from g4f.image import is_accepted_format, images_dir
|
||||
from g4f.typing import Messages
|
||||
from g4f.cookies import read_cookie_files
|
||||
|
||||
@ -63,6 +67,7 @@ class ChatCompletionsConfig(BaseModel):
|
||||
api_key: Optional[str] = None
|
||||
web_search: Optional[bool] = None
|
||||
proxy: Optional[str] = None
|
||||
conversation_id: str = None
|
||||
|
||||
class ImageGenerationConfig(BaseModel):
|
||||
prompt: str
|
||||
@ -98,6 +103,7 @@ class Api:
|
||||
self.client = AsyncClient()
|
||||
self.g4f_api_key = g4f_api_key
|
||||
self.get_g4f_api_key = APIKeyHeader(name="g4f-api-key")
|
||||
self.conversations: dict[str, dict[str, BaseConversation]] = {}
|
||||
|
||||
def register_authorization(self):
|
||||
@self.app.middleware("http")
|
||||
@ -179,12 +185,21 @@ class Api:
|
||||
async def chat_completions(config: ChatCompletionsConfig, request: Request = None, provider: str = None):
|
||||
try:
|
||||
config.provider = provider if config.provider is None else config.provider
|
||||
if config.provider is None:
|
||||
config.provider = AppConfig.provider
|
||||
if config.api_key is None and request is not None:
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header is not None:
|
||||
auth_header = auth_header.split(None, 1)[-1]
|
||||
if auth_header and auth_header != "Bearer":
|
||||
config.api_key = auth_header
|
||||
api_key = auth_header.split(None, 1)[-1]
|
||||
if api_key and api_key != "Bearer":
|
||||
config.api_key = api_key
|
||||
|
||||
conversation = return_conversation = None
|
||||
if config.conversation_id is not None and config.provider is not None:
|
||||
return_conversation = True
|
||||
if config.conversation_id in self.conversations:
|
||||
if config.provider in self.conversations[config.conversation_id]:
|
||||
conversation = self.conversations[config.conversation_id][config.provider]
|
||||
|
||||
# Create the completion response
|
||||
response = self.client.chat.completions.create(
|
||||
@ -194,6 +209,11 @@ class Api:
|
||||
"provider": AppConfig.provider,
|
||||
"proxy": AppConfig.proxy,
|
||||
**config.dict(exclude_none=True),
|
||||
**{
|
||||
"conversation_id": None,
|
||||
"return_conversation": return_conversation,
|
||||
"conversation": conversation
|
||||
}
|
||||
},
|
||||
ignored=AppConfig.ignored_providers
|
||||
),
|
||||
@ -206,7 +226,13 @@ class Api:
|
||||
async def streaming():
|
||||
try:
|
||||
async for chunk in response:
|
||||
yield f"data: {json.dumps(chunk.to_json())}\n\n"
|
||||
if isinstance(chunk, BaseConversation):
|
||||
if config.conversation_id is not None and config.provider is not None:
|
||||
if config.conversation_id not in self.conversations:
|
||||
self.conversations[config.conversation_id] = {}
|
||||
self.conversations[config.conversation_id][config.provider] = chunk
|
||||
else:
|
||||
yield f"data: {json.dumps(chunk.to_json())}\n\n"
|
||||
except GeneratorExit:
|
||||
pass
|
||||
except Exception as e:
|
||||
@ -222,7 +248,13 @@ class Api:
|
||||
|
||||
@self.app.post("/v1/images/generate")
|
||||
@self.app.post("/v1/images/generations")
|
||||
async def generate_image(config: ImageGenerationConfig):
|
||||
async def generate_image(config: ImageGenerationConfig, request: Request):
|
||||
if config.api_key is None:
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header is not None:
|
||||
api_key = auth_header.split(None, 1)[-1]
|
||||
if api_key and api_key != "Bearer":
|
||||
config.api_key = api_key
|
||||
try:
|
||||
response = await self.client.images.generate(
|
||||
prompt=config.prompt,
|
||||
@ -234,6 +266,9 @@ class Api:
|
||||
proxy = config.proxy
|
||||
)
|
||||
)
|
||||
for image in response.data:
|
||||
if hasattr(image, "url") and image.url.startswith("/"):
|
||||
image.url = f"{request.base_url}{image.url.lstrip('/')}"
|
||||
return JSONResponse(response.to_json())
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
@ -243,6 +278,18 @@ class Api:
|
||||
async def completions():
|
||||
return Response(content=json.dumps({'info': 'Not working yet.'}, indent=4), media_type="application/json")
|
||||
|
||||
@self.app.get("/images/{filename}")
|
||||
async def get_image(filename):
|
||||
target = os.path.join(images_dir, filename)
|
||||
|
||||
if not os.path.isfile(target):
|
||||
return Response(status_code=404)
|
||||
|
||||
with open(target, "rb") as f:
|
||||
content_type = is_accepted_format(f.read(12))
|
||||
|
||||
return FileResponse(target, media_type=content_type)
|
||||
|
||||
def format_exception(e: Exception, config: Union[ChatCompletionsConfig, ImageGenerationConfig], image: bool = False) -> str:
|
||||
last_provider = {} if not image else g4f.get_last_provider(True)
|
||||
provider = (AppConfig.image_provider if image else AppConfig.provider) if config.provider is None else config.provider
|
||||
|
@ -6,24 +6,27 @@ import random
|
||||
import string
|
||||
import asyncio
|
||||
import base64
|
||||
import aiohttp
|
||||
import logging
|
||||
from typing import Union, AsyncIterator, Iterator, Coroutine
|
||||
from typing import Union, AsyncIterator, Iterator, Coroutine, Optional
|
||||
|
||||
from ..providers.base_provider import AsyncGeneratorProvider
|
||||
from ..image import ImageResponse, to_image, to_data_uri, is_accepted_format, EXTENSIONS_MAP
|
||||
from ..typing import Messages, Image
|
||||
from ..providers.types import ProviderType, FinishReason, BaseConversation
|
||||
from ..errors import NoImageResponseError
|
||||
from ..image import ImageResponse, copy_images, images_dir
|
||||
from ..typing import Messages, Image, ImageType
|
||||
from ..providers.types import ProviderType
|
||||
from ..providers.response import ResponseType, FinishReason, BaseConversation
|
||||
from ..errors import NoImageResponseError, ModelNotFoundError
|
||||
from ..providers.retry_provider import IterListProvider
|
||||
from ..providers.base_provider import get_running_loop
|
||||
from ..Provider.needs_auth.BingCreateImages import BingCreateImages
|
||||
from ..requests.aiohttp import get_connector
|
||||
from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
|
||||
from .image_models import ImageModels
|
||||
from .types import IterResponse, ImageProvider, Client as BaseClient
|
||||
from .service import get_model_and_provider, get_last_provider, convert_to_provider
|
||||
from .helper import find_stop, filter_json, filter_none, safe_aclose, to_sync_iter, to_async_iterator
|
||||
|
||||
ChatCompletionResponseType = Iterator[Union[ChatCompletion, ChatCompletionChunk, BaseConversation]]
|
||||
AsyncChatCompletionResponseType = AsyncIterator[Union[ChatCompletion, ChatCompletionChunk, BaseConversation]]
|
||||
|
||||
try:
|
||||
anext # Python 3.8+
|
||||
except NameError:
|
||||
@ -35,12 +38,12 @@ except NameError:
|
||||
|
||||
# Synchronous iter_response function
|
||||
def iter_response(
|
||||
response: Union[Iterator[str], AsyncIterator[str]],
|
||||
response: Union[Iterator[Union[str, ResponseType]]],
|
||||
stream: bool,
|
||||
response_format: dict = None,
|
||||
max_tokens: int = None,
|
||||
stop: list = None
|
||||
) -> Iterator[Union[ChatCompletion, ChatCompletionChunk]]:
|
||||
response_format: Optional[dict] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
stop: Optional[list[str]] = None
|
||||
) -> ChatCompletionResponseType:
|
||||
content = ""
|
||||
finish_reason = None
|
||||
completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
|
||||
@ -88,22 +91,23 @@ def iter_response(
|
||||
yield ChatCompletion(content, finish_reason, completion_id, int(time.time()))
|
||||
|
||||
# Synchronous iter_append_model_and_provider function
|
||||
def iter_append_model_and_provider(response: Iterator[ChatCompletionChunk]) -> Iterator[ChatCompletionChunk]:
|
||||
def iter_append_model_and_provider(response: ChatCompletionResponseType) -> ChatCompletionResponseType:
|
||||
last_provider = None
|
||||
|
||||
for chunk in response:
|
||||
last_provider = get_last_provider(True) if last_provider is None else last_provider
|
||||
chunk.model = last_provider.get("model")
|
||||
chunk.provider = last_provider.get("name")
|
||||
yield chunk
|
||||
if isinstance(chunk, (ChatCompletion, ChatCompletionChunk)):
|
||||
last_provider = get_last_provider(True) if last_provider is None else last_provider
|
||||
chunk.model = last_provider.get("model")
|
||||
chunk.provider = last_provider.get("name")
|
||||
yield chunk
|
||||
|
||||
async def async_iter_response(
|
||||
response: AsyncIterator[str],
|
||||
response: AsyncIterator[Union[str, ResponseType]],
|
||||
stream: bool,
|
||||
response_format: dict = None,
|
||||
max_tokens: int = None,
|
||||
stop: list = None
|
||||
) -> AsyncIterator[Union[ChatCompletion, ChatCompletionChunk]]:
|
||||
response_format: Optional[dict] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
stop: Optional[list[str]] = None
|
||||
) -> AsyncChatCompletionResponseType:
|
||||
content = ""
|
||||
finish_reason = None
|
||||
completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
|
||||
@ -149,13 +153,16 @@ async def async_iter_response(
|
||||
if hasattr(response, 'aclose'):
|
||||
await safe_aclose(response)
|
||||
|
||||
async def async_iter_append_model_and_provider(response: AsyncIterator[ChatCompletionChunk]) -> AsyncIterator:
|
||||
async def async_iter_append_model_and_provider(
|
||||
response: AsyncChatCompletionResponseType
|
||||
) -> AsyncChatCompletionResponseType:
|
||||
last_provider = None
|
||||
try:
|
||||
async for chunk in response:
|
||||
last_provider = get_last_provider(True) if last_provider is None else last_provider
|
||||
chunk.model = last_provider.get("model")
|
||||
chunk.provider = last_provider.get("name")
|
||||
if isinstance(chunk, (ChatCompletion, ChatCompletionChunk)):
|
||||
last_provider = get_last_provider(True) if last_provider is None else last_provider
|
||||
chunk.model = last_provider.get("model")
|
||||
chunk.provider = last_provider.get("name")
|
||||
yield chunk
|
||||
finally:
|
||||
if hasattr(response, 'aclose'):
|
||||
@ -164,8 +171,8 @@ async def async_iter_append_model_and_provider(response: AsyncIterator[ChatCompl
|
||||
class Client(BaseClient):
|
||||
def __init__(
|
||||
self,
|
||||
provider: ProviderType = None,
|
||||
image_provider: ImageProvider = None,
|
||||
provider: Optional[ProviderType] = None,
|
||||
image_provider: Optional[ImageProvider] = None,
|
||||
**kwargs
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
@ -173,7 +180,7 @@ class Client(BaseClient):
|
||||
self.images: Images = Images(self, image_provider)
|
||||
|
||||
class Completions:
|
||||
def __init__(self, client: Client, provider: ProviderType = None):
|
||||
def __init__(self, client: Client, provider: Optional[ProviderType] = None):
|
||||
self.client: Client = client
|
||||
self.provider: ProviderType = provider
|
||||
|
||||
@ -181,16 +188,16 @@ class Completions:
|
||||
self,
|
||||
messages: Messages,
|
||||
model: str,
|
||||
provider: ProviderType = None,
|
||||
stream: bool = False,
|
||||
proxy: str = None,
|
||||
response_format: dict = None,
|
||||
max_tokens: int = None,
|
||||
stop: Union[list[str], str] = None,
|
||||
api_key: str = None,
|
||||
ignored: list[str] = None,
|
||||
ignore_working: bool = False,
|
||||
ignore_stream: bool = False,
|
||||
provider: Optional[ProviderType] = None,
|
||||
stream: Optional[bool] = False,
|
||||
proxy: Optional[str] = None,
|
||||
response_format: Optional[dict] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
stop: Optional[Union[list[str], str]] = None,
|
||||
api_key: Optional[str] = None,
|
||||
ignored: Optional[list[str]] = None,
|
||||
ignore_working: Optional[bool] = False,
|
||||
ignore_stream: Optional[bool] = False,
|
||||
**kwargs
|
||||
) -> IterResponse:
|
||||
model, provider = get_model_and_provider(
|
||||
@ -234,22 +241,38 @@ class Completions:
|
||||
class Chat:
|
||||
completions: Completions
|
||||
|
||||
def __init__(self, client: Client, provider: ProviderType = None):
|
||||
def __init__(self, client: Client, provider: Optional[ProviderType] = None):
|
||||
self.completions = Completions(client, provider)
|
||||
|
||||
class Images:
|
||||
def __init__(self, client: Client, provider: ProviderType = None):
|
||||
def __init__(self, client: Client, provider: Optional[ProviderType] = None):
|
||||
self.client: Client = client
|
||||
self.provider: ProviderType = provider
|
||||
self.provider: Optional[ProviderType] = provider
|
||||
self.models: ImageModels = ImageModels(client)
|
||||
|
||||
def generate(self, prompt: str, model: str = None, provider: ProviderType = None, response_format: str = "url", proxy: str = None, **kwargs) -> ImagesResponse:
|
||||
def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str = None,
|
||||
provider: Optional[ProviderType] = None,
|
||||
response_format: str = "url",
|
||||
proxy: Optional[str] = None,
|
||||
**kwargs
|
||||
) -> ImagesResponse:
|
||||
"""
|
||||
Synchronous generate method that runs the async_generate method in an event loop.
|
||||
"""
|
||||
return asyncio.run(self.async_generate(prompt, model, provider, response_format=response_format, proxy=proxy, **kwargs))
|
||||
return asyncio.run(self.async_generate(prompt, model, provider, response_format, proxy, **kwargs))
|
||||
|
||||
async def async_generate(self, prompt: str, model: str = None, provider: ProviderType = None, response_format: str = "url", proxy: str = None, **kwargs) -> ImagesResponse:
|
||||
async def async_generate(
|
||||
self,
|
||||
prompt: str,
|
||||
model: Optional[str] = None,
|
||||
provider: Optional[ProviderType] = None,
|
||||
response_format: Optional[str] = "url",
|
||||
proxy: Optional[str] = None,
|
||||
**kwargs
|
||||
) -> ImagesResponse:
|
||||
if provider is None:
|
||||
provider_handler = self.models.get(model, provider or self.provider or BingCreateImages)
|
||||
elif isinstance(provider, str):
|
||||
@ -257,97 +280,73 @@ class Images:
|
||||
else:
|
||||
provider_handler = provider
|
||||
if provider_handler is None:
|
||||
raise ValueError(f"Unknown model: {model}")
|
||||
if proxy is None:
|
||||
proxy = self.client.proxy
|
||||
|
||||
raise ModelNotFoundError(f"Unknown model: {model}")
|
||||
if isinstance(provider_handler, IterListProvider):
|
||||
if provider_handler.providers:
|
||||
provider_handler = provider_handler.providers[0]
|
||||
else:
|
||||
raise ValueError(f"IterListProvider for model {model} has no providers")
|
||||
raise ModelNotFoundError(f"IterListProvider for model {model} has no providers")
|
||||
if proxy is None:
|
||||
proxy = self.client.proxy
|
||||
|
||||
response = None
|
||||
if hasattr(provider_handler, "create_async_generator"):
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
async for item in provider_handler.create_async_generator(model, messages, **kwargs):
|
||||
if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider):
|
||||
messages = [{"role": "user", "content": f"Generate a image: {prompt}"}]
|
||||
async for item in provider_handler.create_async_generator(model, messages, prompt=prompt, **kwargs):
|
||||
if isinstance(item, ImageResponse):
|
||||
response = item
|
||||
break
|
||||
elif hasattr(provider, 'create'):
|
||||
elif hasattr(provider_handler, 'create'):
|
||||
if asyncio.iscoroutinefunction(provider_handler.create):
|
||||
response = await provider_handler.create(prompt)
|
||||
else:
|
||||
response = provider_handler.create(prompt)
|
||||
if isinstance(response, str):
|
||||
response = ImageResponse([response], prompt)
|
||||
elif hasattr(provider_handler, "create_completion"):
|
||||
get_running_loop(check_nested=True)
|
||||
messages = [{"role": "user", "content": f"Generate a image: {prompt}"}]
|
||||
for item in provider_handler.create_completion(model, messages, prompt=prompt, **kwargs):
|
||||
if isinstance(item, ImageResponse):
|
||||
response = item
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"Provider {provider} does not support image generation")
|
||||
if isinstance(response, ImageResponse):
|
||||
return await self._process_image_response(response, response_format, proxy, model=model, provider=provider)
|
||||
|
||||
return await self._process_image_response(
|
||||
response,
|
||||
response_format,
|
||||
proxy,
|
||||
model,
|
||||
getattr(provider_handler, "__name__", None)
|
||||
)
|
||||
raise NoImageResponseError(f"Unexpected response type: {type(response)}")
|
||||
|
||||
async def _process_image_response(self, response: ImageResponse, response_format: str, proxy: str = None, model: str = None, provider: str = None) -> ImagesResponse:
|
||||
async def process_image_item(session: aiohttp.ClientSession, image_data: str):
|
||||
image_data_bytes = None
|
||||
if image_data.startswith("http://") or image_data.startswith("https://"):
|
||||
if response_format == "url":
|
||||
return Image(url=image_data, revised_prompt=response.alt)
|
||||
elif response_format == "b64_json":
|
||||
# Fetch the image data and convert it to base64
|
||||
image_data_bytes = await self._fetch_image(session, image_data)
|
||||
b64_json = base64.b64encode(image_data_bytes).decode("utf-8")
|
||||
return Image(b64_json=b64_json, url=image_data, revised_prompt=response.alt)
|
||||
else:
|
||||
# Assume image_data is base64 data or binary
|
||||
if response_format == "url":
|
||||
if image_data.startswith("data:image"):
|
||||
# Remove the data URL scheme and get the base64 data
|
||||
base64_data = image_data.split(",", 1)[-1]
|
||||
else:
|
||||
base64_data = image_data
|
||||
# Decode the base64 data
|
||||
image_data_bytes = base64.b64decode(base64_data)
|
||||
if image_data_bytes:
|
||||
file_name = self._save_image(image_data_bytes)
|
||||
return Image(url=file_name, revised_prompt=response.alt)
|
||||
else:
|
||||
raise ValueError("Unable to process image data")
|
||||
|
||||
last_provider = get_last_provider(True)
|
||||
async with aiohttp.ClientSession(cookies=response.get("cookies"), connector=get_connector(proxy=proxy)) as session:
|
||||
return ImagesResponse(
|
||||
await asyncio.gather(*[process_image_item(session, image_data) for image_data in response.get_list()]),
|
||||
model=last_provider.get("model") if model is None else model,
|
||||
provider=last_provider.get("name") if provider is None else provider
|
||||
)
|
||||
|
||||
async def _fetch_image(self, session: aiohttp.ClientSession, url: str) -> bytes:
|
||||
# Asynchronously fetch image data from the URL
|
||||
async with session.get(url) as resp:
|
||||
if resp.status == 200:
|
||||
return await resp.read()
|
||||
else:
|
||||
raise RuntimeError(f"Failed to fetch image from {url}, status code {resp.status}")
|
||||
|
||||
def _save_image(self, image_data_bytes: bytes) -> str:
|
||||
os.makedirs('generated_images', exist_ok=True)
|
||||
image = to_image(image_data_bytes)
|
||||
file_name = f"generated_images/image_{int(time.time())}_{random.randint(0, 10000)}.{EXTENSIONS_MAP[is_accepted_format(image_data_bytes)]}"
|
||||
image.save(file_name)
|
||||
return file_name
|
||||
|
||||
def create_variation(self, image: Union[str, bytes], model: str = None, provider: ProviderType = None, response_format: str = "url", **kwargs) -> ImagesResponse:
|
||||
def create_variation(
|
||||
self,
|
||||
image: Union[str, bytes],
|
||||
model: str = None,
|
||||
provider: Optional[ProviderType] = None,
|
||||
response_format: str = "url",
|
||||
**kwargs
|
||||
) -> ImagesResponse:
|
||||
return asyncio.run(self.async_create_variation(
|
||||
image, model, provider, response_format, **kwargs
|
||||
))
|
||||
|
||||
async def async_create_variation(self, image: Union[str, bytes], model: str = None, provider: ProviderType = None, response_format: str = "url", proxy: str = None, **kwargs) -> ImagesResponse:
|
||||
async def async_create_variation(
|
||||
self,
|
||||
image: ImageType,
|
||||
model: Optional[str] = None,
|
||||
provider: Optional[ProviderType] = None,
|
||||
response_format: str = "url",
|
||||
proxy: Optional[str] = None,
|
||||
**kwargs
|
||||
) -> ImagesResponse:
|
||||
if provider is None:
|
||||
provider = self.models.get(model, provider or self.provider or BingCreateImages)
|
||||
if provider is None:
|
||||
raise ValueError(f"Unknown model: {model}")
|
||||
raise ModelNotFoundError(f"Unknown model: {model}")
|
||||
if isinstance(provider, str):
|
||||
provider = convert_to_provider(provider)
|
||||
if proxy is None:
|
||||
@ -355,38 +354,61 @@ class Images:
|
||||
|
||||
if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider):
|
||||
messages = [{"role": "user", "content": "create a variation of this image"}]
|
||||
image_data = to_data_uri(image)
|
||||
generator = None
|
||||
try:
|
||||
generator = provider.create_async_generator(model, messages, image=image_data, response_format=response_format, proxy=proxy, **kwargs)
|
||||
async for response in generator:
|
||||
if isinstance(response, ImageResponse):
|
||||
return self._process_image_response(response)
|
||||
except RuntimeError as e:
|
||||
if "async generator ignored GeneratorExit" in str(e):
|
||||
logging.warning("Generator ignored GeneratorExit in create_variation, handling gracefully")
|
||||
else:
|
||||
raise
|
||||
generator = provider.create_async_generator(model, messages, image=image, response_format=response_format, proxy=proxy, **kwargs)
|
||||
async for chunk in generator:
|
||||
if isinstance(chunk, ImageResponse):
|
||||
response = chunk
|
||||
break
|
||||
finally:
|
||||
if generator and hasattr(generator, 'aclose'):
|
||||
await safe_aclose(generator)
|
||||
logging.info("AsyncGeneratorProvider processing completed in create_variation")
|
||||
elif hasattr(provider, 'create_variation'):
|
||||
if asyncio.iscoroutinefunction(provider.create_variation):
|
||||
response = await provider.create_variation(image, model=model, response_format=response_format, proxy=proxy, **kwargs)
|
||||
else:
|
||||
response = provider.create_variation(image, model=model, response_format=response_format, proxy=proxy, **kwargs)
|
||||
if isinstance(response, str):
|
||||
response = ImageResponse([response])
|
||||
return self._process_image_response(response)
|
||||
else:
|
||||
raise ValueError(f"Provider {provider} does not support image variation")
|
||||
raise NoImageResponseError(f"Provider {provider} does not support image variation")
|
||||
|
||||
if isinstance(response, str):
|
||||
response = ImageResponse([response])
|
||||
if isinstance(response, ImageResponse):
|
||||
return self._process_image_response(response, response_format, proxy, model, getattr(provider, "__name__", None))
|
||||
raise NoImageResponseError(f"Unexpected response type: {type(response)}")
|
||||
|
||||
async def _process_image_response(
|
||||
self,
|
||||
response: ImageResponse,
|
||||
response_format: str,
|
||||
proxy: str = None,
|
||||
model: Optional[str] = None,
|
||||
provider: Optional[str] = None
|
||||
) -> list[Image]:
|
||||
if response_format in ("url", "b64_json"):
|
||||
images = await copy_images(response.get_list(), response.options.get("cookies"), proxy)
|
||||
async def process_image_item(image_file: str) -> Image:
|
||||
if response_format == "b64_json":
|
||||
with open(os.path.join(images_dir, os.path.basename(image_file)), "rb") as file:
|
||||
image_data = base64.b64encode(file.read()).decode()
|
||||
return Image(url=image_file, b64_json=image_data, revised_prompt=response.alt)
|
||||
return Image(url=image_file, revised_prompt=response.alt)
|
||||
images = await asyncio.gather(*[process_image_item(image) for image in images])
|
||||
else:
|
||||
images = [Image(url=image, revised_prompt=response.alt) for image in response.get_list()]
|
||||
last_provider = get_last_provider(True)
|
||||
return ImagesResponse(
|
||||
images,
|
||||
model=last_provider.get("model") if model is None else model,
|
||||
provider=last_provider.get("name") if provider is None else provider
|
||||
)
|
||||
|
||||
class AsyncClient(BaseClient):
|
||||
def __init__(
|
||||
self,
|
||||
provider: ProviderType = None,
|
||||
image_provider: ImageProvider = None,
|
||||
provider: Optional[ProviderType] = None,
|
||||
image_provider: Optional[ImageProvider] = None,
|
||||
**kwargs
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
@ -396,11 +418,11 @@ class AsyncClient(BaseClient):
|
||||
class AsyncChat:
|
||||
completions: AsyncCompletions
|
||||
|
||||
def __init__(self, client: AsyncClient, provider: ProviderType = None):
|
||||
def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None):
|
||||
self.completions = AsyncCompletions(client, provider)
|
||||
|
||||
class AsyncCompletions:
|
||||
def __init__(self, client: AsyncClient, provider: ProviderType = None):
|
||||
def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None):
|
||||
self.client: AsyncClient = client
|
||||
self.provider: ProviderType = provider
|
||||
|
||||
@ -408,18 +430,18 @@ class AsyncCompletions:
|
||||
self,
|
||||
messages: Messages,
|
||||
model: str,
|
||||
provider: ProviderType = None,
|
||||
stream: bool = False,
|
||||
proxy: str = None,
|
||||
response_format: dict = None,
|
||||
max_tokens: int = None,
|
||||
stop: Union[list[str], str] = None,
|
||||
api_key: str = None,
|
||||
ignored: list[str] = None,
|
||||
ignore_working: bool = False,
|
||||
ignore_stream: bool = False,
|
||||
provider: Optional[ProviderType] = None,
|
||||
stream: Optional[bool] = False,
|
||||
proxy: Optional[str] = None,
|
||||
response_format: Optional[dict] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
stop: Optional[Union[list[str], str]] = None,
|
||||
api_key: Optional[str] = None,
|
||||
ignored: Optional[list[str]] = None,
|
||||
ignore_working: Optional[bool] = False,
|
||||
ignore_stream: Optional[bool] = False,
|
||||
**kwargs
|
||||
) -> Union[Coroutine[ChatCompletion], AsyncIterator[ChatCompletionChunk]]:
|
||||
) -> Union[Coroutine[ChatCompletion], AsyncIterator[ChatCompletionChunk, BaseConversation]]:
|
||||
model, provider = get_model_and_provider(
|
||||
model,
|
||||
self.provider if provider is None else provider,
|
||||
@ -450,15 +472,29 @@ class AsyncCompletions:
|
||||
return response if stream else anext(response)
|
||||
|
||||
class AsyncImages(Images):
|
||||
def __init__(self, client: AsyncClient, provider: ImageProvider = None):
|
||||
def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None):
|
||||
self.client: AsyncClient = client
|
||||
self.provider: ImageProvider = provider
|
||||
self.provider: Optional[ProviderType] = provider
|
||||
self.models: ImageModels = ImageModels(client)
|
||||
|
||||
async def generate(self, prompt: str, model: str = None, provider: ProviderType = None, response_format: str = "url", **kwargs) -> ImagesResponse:
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
model: Optional[str] = None,
|
||||
provider: Optional[ProviderType] = None,
|
||||
response_format: str = "url",
|
||||
**kwargs
|
||||
) -> ImagesResponse:
|
||||
return await self.async_generate(prompt, model, provider, response_format, **kwargs)
|
||||
|
||||
async def create_variation(self, image: Union[str, bytes], model: str = None, provider: ProviderType = None, response_format: str = "url", **kwargs) -> ImagesResponse:
|
||||
async def create_variation(
|
||||
self,
|
||||
image: ImageType,
|
||||
model: str = None,
|
||||
provider: ProviderType = None,
|
||||
response_format: str = "url",
|
||||
**kwargs
|
||||
) -> ImagesResponse:
|
||||
return await self.async_create_variation(
|
||||
image, model, provider, response_format, **kwargs
|
||||
)
|
||||
)
|
@ -6,7 +6,7 @@ import threading
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
from typing import AsyncIterator, Iterator, AsyncGenerator
|
||||
from typing import AsyncIterator, Iterator, AsyncGenerator, Optional
|
||||
|
||||
def filter_json(text: str) -> str:
|
||||
"""
|
||||
@ -23,7 +23,7 @@ def filter_json(text: str) -> str:
|
||||
return match.group("code")
|
||||
return text
|
||||
|
||||
def find_stop(stop, content: str, chunk: str = None):
|
||||
def find_stop(stop: Optional[list[str]], content: str, chunk: str = None):
|
||||
first = -1
|
||||
word = None
|
||||
if stop is not None:
|
||||
|
@ -1,7 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .types import Client, ImageProvider
|
||||
|
||||
from ..models import ModelUtils
|
||||
|
||||
class ImageModels():
|
||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import os
|
||||
|
||||
from .stubs import ChatCompletion, ChatCompletionChunk
|
||||
from ..providers.types import BaseProvider, ProviderType, FinishReason
|
||||
from ..providers.types import BaseProvider
|
||||
from typing import Union, Iterator, AsyncIterator
|
||||
|
||||
ImageProvider = Union[BaseProvider, object]
|
||||
|
@ -111,6 +111,11 @@
|
||||
<input type="checkbox" id="hide-systemPrompt" />
|
||||
<label for="hide-systemPrompt" class="toogle" title="For more space on phones"></label>
|
||||
</div>
|
||||
<div class="field">
|
||||
<span class="label">Download generated images</span>
|
||||
<input type="checkbox" id="download_images" checked/>
|
||||
<label for="download_images" class="toogle" title="Download and save generated images to /generated_images"></label>
|
||||
</div>
|
||||
<div class="field">
|
||||
<span class="label">Auto continue in ChatGPT</span>
|
||||
<input id="auto_continue" type="checkbox" name="auto_continue" checked/>
|
||||
|
@ -498,6 +498,7 @@ body {
|
||||
gap: 12px;
|
||||
cursor: pointer;
|
||||
animation: show_popup 0.4s;
|
||||
height: 28px;
|
||||
}
|
||||
|
||||
.toolbar .regenerate {
|
||||
|
@ -428,10 +428,14 @@ async function add_message_chunk(message, message_id) {
|
||||
p.innerText = message.log;
|
||||
log_storage.appendChild(p);
|
||||
}
|
||||
window.scrollTo(0, 0);
|
||||
if (message_box.scrollTop >= message_box.scrollHeight - message_box.clientHeight - 100) {
|
||||
message_box.scrollTo({ top: message_box.scrollHeight, behavior: "auto" });
|
||||
let scroll_down = ()=>{
|
||||
if (message_box.scrollTop >= message_box.scrollHeight - message_box.clientHeight - 100) {
|
||||
window.scrollTo(0, 0);
|
||||
message_box.scrollTo({ top: message_box.scrollHeight, behavior: "auto" });
|
||||
}
|
||||
}
|
||||
setTimeout(scroll_down, 200);
|
||||
setTimeout(scroll_down, 1000);
|
||||
}
|
||||
|
||||
cameraInput?.addEventListener("click", (e) => {
|
||||
@ -492,6 +496,7 @@ const ask_gpt = async (message_index = -1, message_id) => {
|
||||
const file = input && input.files.length > 0 ? input.files[0] : null;
|
||||
const provider = providerSelect.options[providerSelect.selectedIndex].value;
|
||||
const auto_continue = document.getElementById("auto_continue")?.checked;
|
||||
const download_images = document.getElementById("download_images")?.checked;
|
||||
let api_key = get_api_key_by_provider(provider);
|
||||
await api("conversation", {
|
||||
id: message_id,
|
||||
@ -501,13 +506,13 @@ const ask_gpt = async (message_index = -1, message_id) => {
|
||||
provider: provider,
|
||||
messages: messages,
|
||||
auto_continue: auto_continue,
|
||||
download_images: download_images,
|
||||
api_key: api_key,
|
||||
}, file, message_id);
|
||||
if (!error_storage[message_id]) {
|
||||
html = markdown_render(message_storage[message_id]);
|
||||
content_map.inner.innerHTML = html;
|
||||
highlight(content_map.inner);
|
||||
|
||||
if (imageInput) imageInput.value = "";
|
||||
if (cameraInput) cameraInput.value = "";
|
||||
if (fileInput) fileInput.value = "";
|
||||
@ -1239,8 +1244,7 @@ async function load_version() {
|
||||
if (versions["version"] != versions["latest_version"]) {
|
||||
let release_url = 'https://github.com/xtekky/gpt4free/releases/tag/' + versions["latest_version"];
|
||||
let title = `New version: ${versions["latest_version"]}`;
|
||||
text += `<a href="${release_url}" target="_blank" title="${title}">${versions["version"]}</a> `;
|
||||
text += `<i class="fa-solid fa-rotate"></i>`
|
||||
text += `<a href="${release_url}" target="_blank" title="${title}">${versions["version"]}</a> 🆕`;
|
||||
} else {
|
||||
text += versions["version"];
|
||||
}
|
||||
|
@ -2,34 +2,22 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
import asyncio
|
||||
import time
|
||||
from aiohttp import ClientSession
|
||||
from typing import Iterator, Optional
|
||||
from typing import Iterator
|
||||
from flask import send_from_directory
|
||||
from inspect import signature
|
||||
|
||||
from g4f import version, models
|
||||
from g4f import get_last_provider, ChatCompletion
|
||||
from g4f.errors import VersionNotFoundError
|
||||
from g4f.typing import Cookies
|
||||
from g4f.image import ImagePreview, ImageResponse, is_accepted_format, extract_data_uri
|
||||
from g4f.requests.aiohttp import get_connector
|
||||
from g4f.image import ImagePreview, ImageResponse, copy_images, ensure_images_dir, images_dir
|
||||
from g4f.Provider import ProviderType, __providers__, __map__
|
||||
from g4f.providers.base_provider import ProviderModelMixin, FinishReason
|
||||
from g4f.providers.conversation import BaseConversation
|
||||
from g4f.providers.base_provider import ProviderModelMixin
|
||||
from g4f.providers.response import BaseConversation, FinishReason
|
||||
from g4f.client.service import convert_to_provider
|
||||
from g4f import debug
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Define the directory for generated images
|
||||
images_dir = "./generated_images"
|
||||
|
||||
# Function to ensure the images directory exists
|
||||
def ensure_images_dir():
|
||||
if not os.path.exists(images_dir):
|
||||
os.makedirs(images_dir)
|
||||
|
||||
conversations: dict[dict[str, BaseConversation]] = {}
|
||||
|
||||
class Api:
|
||||
@ -42,7 +30,10 @@ class Api:
|
||||
if provider in __map__:
|
||||
provider: ProviderType = __map__[provider]
|
||||
if issubclass(provider, ProviderModelMixin):
|
||||
models = provider.get_models() if api_key is None else provider.get_models(api_key=api_key)
|
||||
if api_key is not None and "api_key" in signature(provider.get_models).parameters:
|
||||
models = provider.get_models(api_key=api_key)
|
||||
else:
|
||||
models = provider.get_models()
|
||||
return [
|
||||
{
|
||||
"model": model,
|
||||
@ -90,7 +81,7 @@ class Api:
|
||||
def get_providers() -> list[str]:
|
||||
return {
|
||||
provider.__name__: (provider.label if hasattr(provider, "label") else provider.__name__)
|
||||
+ (" (Image Generation)" if hasattr(provider, "image_models") else "")
|
||||
+ (" (Image Generation)" if getattr(provider, "image_models", None) else "")
|
||||
+ (" (Image Upload)" if getattr(provider, "default_vision_model", None) else "")
|
||||
+ (" (WebDriver)" if "webdriver" in provider.get_parameters() else "")
|
||||
+ (" (Auth)" if provider.needs_auth else "")
|
||||
@ -120,16 +111,23 @@ class Api:
|
||||
api_key = json_data.get("api_key")
|
||||
if api_key is not None:
|
||||
kwargs["api_key"] = api_key
|
||||
if json_data.get('web_search'):
|
||||
if provider:
|
||||
kwargs['web_search'] = True
|
||||
else:
|
||||
from .internet import get_search_message
|
||||
messages[-1]["content"] = get_search_message(messages[-1]["content"])
|
||||
do_web_search = json_data.get('web_search')
|
||||
if do_web_search and provider:
|
||||
provider_handler = convert_to_provider(provider)
|
||||
if hasattr(provider_handler, "get_parameters"):
|
||||
if "web_search" in provider_handler.get_parameters():
|
||||
kwargs['web_search'] = True
|
||||
do_web_search = False
|
||||
if do_web_search:
|
||||
from .internet import get_search_message
|
||||
messages[-1]["content"] = get_search_message(messages[-1]["content"])
|
||||
if json_data.get("auto_continue"):
|
||||
kwargs['auto_continue'] = True
|
||||
|
||||
conversation_id = json_data.get("conversation_id")
|
||||
if conversation_id and provider in conversations and conversation_id in conversations[provider]:
|
||||
kwargs["conversation"] = conversations[provider][conversation_id]
|
||||
if conversation_id and provider:
|
||||
if provider in conversations and conversation_id in conversations[provider]:
|
||||
kwargs["conversation"] = conversations[provider][conversation_id]
|
||||
|
||||
return {
|
||||
"model": model,
|
||||
@ -141,7 +139,7 @@ class Api:
|
||||
**kwargs
|
||||
}
|
||||
|
||||
def _create_response_stream(self, kwargs: dict, conversation_id: str, provider: str) -> Iterator:
|
||||
def _create_response_stream(self, kwargs: dict, conversation_id: str, provider: str, download_images: bool = True) -> Iterator:
|
||||
if debug.logging:
|
||||
debug.logs = []
|
||||
print_callback = debug.log_handler
|
||||
@ -163,18 +161,22 @@ class Api:
|
||||
first = False
|
||||
yield self._format_json("provider", get_last_provider(True))
|
||||
if isinstance(chunk, BaseConversation):
|
||||
if provider not in conversations:
|
||||
conversations[provider] = {}
|
||||
conversations[provider][conversation_id] = chunk
|
||||
yield self._format_json("conversation", conversation_id)
|
||||
if provider:
|
||||
if provider not in conversations:
|
||||
conversations[provider] = {}
|
||||
conversations[provider][conversation_id] = chunk
|
||||
yield self._format_json("conversation", conversation_id)
|
||||
elif isinstance(chunk, Exception):
|
||||
logger.exception(chunk)
|
||||
yield self._format_json("message", get_error_message(chunk))
|
||||
elif isinstance(chunk, ImagePreview):
|
||||
yield self._format_json("preview", chunk.to_string())
|
||||
elif isinstance(chunk, ImageResponse):
|
||||
images = asyncio.run(self._copy_images(chunk.get_list(), chunk.options.get("cookies")))
|
||||
yield self._format_json("content", str(ImageResponse(images, chunk.alt)))
|
||||
images = chunk
|
||||
if download_images:
|
||||
images = asyncio.run(copy_images(chunk.get_list(), chunk.options.get("cookies")))
|
||||
images = ImageResponse(images, chunk.alt)
|
||||
yield self._format_json("content", str(images))
|
||||
elif not isinstance(chunk, FinishReason):
|
||||
yield self._format_json("content", str(chunk))
|
||||
if debug.logs:
|
||||
@ -185,31 +187,6 @@ class Api:
|
||||
logger.exception(e)
|
||||
yield self._format_json('error', get_error_message(e))
|
||||
|
||||
async def _copy_images(self, images: list[str], cookies: Optional[Cookies] = None):
|
||||
ensure_images_dir()
|
||||
async with ClientSession(
|
||||
connector=get_connector(None, os.environ.get("G4F_PROXY")),
|
||||
cookies=cookies
|
||||
) as session:
|
||||
async def copy_image(image: str) -> str:
|
||||
target = os.path.join(images_dir, f"{int(time.time())}_{str(uuid.uuid4())}")
|
||||
if image.startswith("data:"):
|
||||
with open(target, "wb") as f:
|
||||
f.write(extract_data_uri(image))
|
||||
else:
|
||||
async with session.get(image) as response:
|
||||
with open(target, "wb") as f:
|
||||
async for chunk in response.content.iter_any():
|
||||
f.write(chunk)
|
||||
with open(target, "rb") as f:
|
||||
extension = is_accepted_format(f.read(12)).split("/")[-1]
|
||||
extension = "jpg" if extension == "jpeg" else extension
|
||||
new_target = f"{target}.{extension}"
|
||||
os.rename(target, new_target)
|
||||
return f"/images/{os.path.basename(new_target)}"
|
||||
|
||||
return await asyncio.gather(*[copy_image(image) for image in images])
|
||||
|
||||
def _format_json(self, response_type: str, content):
|
||||
return {
|
||||
'type': response_type,
|
||||
@ -221,4 +198,4 @@ def get_error_message(exception: Exception) -> str:
|
||||
provider = get_last_provider()
|
||||
if provider is None:
|
||||
return message
|
||||
return f"{provider.__name__}: {message}"
|
||||
return f"{provider.__name__}: {message}"
|
@ -89,7 +89,12 @@ class Backend_Api(Api):
|
||||
kwargs = self._prepare_conversation_kwargs(json_data, kwargs)
|
||||
|
||||
return self.app.response_class(
|
||||
self._create_response_stream(kwargs, json_data.get("conversation_id"), json_data.get("provider")),
|
||||
self._create_response_stream(
|
||||
kwargs,
|
||||
json_data.get("conversation_id"),
|
||||
json_data.get("provider"),
|
||||
json_data.get("download_images", True),
|
||||
),
|
||||
mimetype='text/event-stream'
|
||||
)
|
||||
|
||||
|
@ -8,12 +8,14 @@ try:
|
||||
except ImportError:
|
||||
has_requirements = False
|
||||
from ...errors import MissingRequirementsError
|
||||
|
||||
from ... import debug
|
||||
|
||||
import asyncio
|
||||
|
||||
class SearchResults():
|
||||
def __init__(self, results: list):
|
||||
def __init__(self, results: list, used_words: int):
|
||||
self.results = results
|
||||
self.used_words = used_words
|
||||
|
||||
def __iter__(self):
|
||||
yield from self.results
|
||||
@ -104,7 +106,8 @@ async def search(query: str, n_results: int = 5, max_words: int = 2500, add_text
|
||||
region="wt-wt",
|
||||
safesearch="moderate",
|
||||
timelimit="y",
|
||||
max_results=n_results
|
||||
max_results=n_results,
|
||||
backend="html"
|
||||
):
|
||||
results.append(SearchResultEntry(
|
||||
result["title"],
|
||||
@ -120,6 +123,7 @@ async def search(query: str, n_results: int = 5, max_words: int = 2500, add_text
|
||||
texts = await asyncio.gather(*requests)
|
||||
|
||||
formatted_results = []
|
||||
used_words = 0
|
||||
left_words = max_words
|
||||
for i, entry in enumerate(results):
|
||||
if add_text:
|
||||
@ -132,13 +136,14 @@ async def search(query: str, n_results: int = 5, max_words: int = 2500, add_text
|
||||
left_words -= entry.snippet.count(" ")
|
||||
if 0 > left_words:
|
||||
break
|
||||
used_words = max_words - left_words
|
||||
formatted_results.append(entry)
|
||||
|
||||
return SearchResults(formatted_results)
|
||||
return SearchResults(formatted_results, used_words)
|
||||
|
||||
def get_search_message(prompt) -> str:
|
||||
def get_search_message(prompt, n_results: int = 5, max_words: int = 2500) -> str:
|
||||
try:
|
||||
search_results = asyncio.run(search(prompt))
|
||||
search_results = asyncio.run(search(prompt, n_results, max_words))
|
||||
message = f"""
|
||||
{search_results}
|
||||
|
||||
@ -149,7 +154,8 @@ Make sure to add the sources of cites using [[Number]](Url) notation after the r
|
||||
User request:
|
||||
{prompt}
|
||||
"""
|
||||
debug.log(f"Web search: '{prompt.strip()[:50]}...' {search_results.used_words} Words")
|
||||
return message
|
||||
except Exception as e:
|
||||
print("Couldn't do web search:", e)
|
||||
debug.log(f"Couldn't do web search: {e.__class__.__name__}: {e}")
|
||||
return prompt
|
48
g4f/image.py
48
g4f/image.py
@ -1,9 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
import base64
|
||||
from .typing import ImageType, Union, Image
|
||||
import asyncio
|
||||
from aiohttp import ClientSession
|
||||
|
||||
try:
|
||||
from PIL.Image import open as open_image, new as new_image
|
||||
@ -12,7 +16,10 @@ try:
|
||||
except ImportError:
|
||||
has_requirements = False
|
||||
|
||||
from .typing import ImageType, Union, Image, Optional, Cookies
|
||||
from .errors import MissingRequirementsError
|
||||
from .providers.response import ResponseType
|
||||
from .requests.aiohttp import get_connector
|
||||
|
||||
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'webp', 'svg'}
|
||||
|
||||
@ -23,10 +30,12 @@ EXTENSIONS_MAP: dict[str, str] = {
|
||||
"image/webp": "webp",
|
||||
}
|
||||
|
||||
# Define the directory for generated images
|
||||
images_dir = "./generated_images"
|
||||
|
||||
def fix_url(url:str) -> str:
|
||||
""" replace ' ' by '+' (to be markdown compliant)"""
|
||||
return url.replace(" ","+")
|
||||
|
||||
|
||||
def to_image(image: ImageType, is_svg: bool = False) -> Image:
|
||||
"""
|
||||
@ -223,7 +232,6 @@ def format_images_markdown(images: Union[str, list], alt: str, preview: Union[st
|
||||
preview = [preview.replace('{image}', image) if preview else image for image in images]
|
||||
result = "\n".join(
|
||||
f"[![#{idx+1} {alt}]({fix_url(preview[idx])})]({fix_url(image)})"
|
||||
#f'[<img src="{preview[idx]}" width="200" alt="#{idx+1} {alt}">]({image})'
|
||||
for idx, image in enumerate(images)
|
||||
)
|
||||
start_flag = "<!-- generated images start -->\n"
|
||||
@ -260,7 +268,39 @@ def to_data_uri(image: ImageType) -> str:
|
||||
return f"data:{is_accepted_format(data)};base64,{data_base64}"
|
||||
return image
|
||||
|
||||
class ImageResponse:
|
||||
# Function to ensure the images directory exists
|
||||
def ensure_images_dir():
|
||||
if not os.path.exists(images_dir):
|
||||
os.makedirs(images_dir)
|
||||
|
||||
async def copy_images(images: list[str], cookies: Optional[Cookies] = None, proxy: Optional[str] = None):
|
||||
ensure_images_dir()
|
||||
async with ClientSession(
|
||||
connector=get_connector(
|
||||
proxy=os.environ.get("G4F_PROXY") if proxy is None else proxy
|
||||
),
|
||||
cookies=cookies
|
||||
) as session:
|
||||
async def copy_image(image: str) -> str:
|
||||
target = os.path.join(images_dir, f"{int(time.time())}_{str(uuid.uuid4())}")
|
||||
if image.startswith("data:"):
|
||||
with open(target, "wb") as f:
|
||||
f.write(extract_data_uri(image))
|
||||
else:
|
||||
async with session.get(image) as response:
|
||||
with open(target, "wb") as f:
|
||||
async for chunk in response.content.iter_chunked(4096):
|
||||
f.write(chunk)
|
||||
with open(target, "rb") as f:
|
||||
extension = is_accepted_format(f.read(12)).split("/")[-1]
|
||||
extension = "jpg" if extension == "jpeg" else extension
|
||||
new_target = f"{target}.{extension}"
|
||||
os.rename(target, new_target)
|
||||
return f"/images/{os.path.basename(new_target)}"
|
||||
|
||||
return await asyncio.gather(*[copy_image(image) for image in images])
|
||||
|
||||
class ImageResponse(ResponseType):
|
||||
def __init__(
|
||||
self,
|
||||
images: Union[str, list],
|
||||
|
@ -10,7 +10,8 @@ from inspect import signature, Parameter
|
||||
from typing import Callable, Union
|
||||
|
||||
from ..typing import CreateResult, AsyncResult, Messages
|
||||
from .types import BaseProvider, FinishReason
|
||||
from .types import BaseProvider
|
||||
from .response import FinishReason, BaseConversation
|
||||
from ..errors import NestAsyncioError, ModelNotSupportedError
|
||||
from .. import debug
|
||||
|
||||
@ -100,7 +101,7 @@ class AbstractProvider(BaseProvider):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_parameters(cls) -> dict:
|
||||
def get_parameters(cls) -> dict[str, Parameter]:
|
||||
return signature(
|
||||
cls.create_async_generator if issubclass(cls, AsyncGeneratorProvider) else
|
||||
cls.create_async if issubclass(cls, AsyncProvider) else
|
||||
@ -258,7 +259,7 @@ class AsyncGeneratorProvider(AsyncProvider):
|
||||
"""
|
||||
return "".join([
|
||||
str(chunk) async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)
|
||||
if not isinstance(chunk, (Exception, FinishReason))
|
||||
if not isinstance(chunk, (Exception, FinishReason, BaseConversation))
|
||||
])
|
||||
|
||||
@staticmethod
|
||||
@ -307,4 +308,4 @@ class ProviderModelMixin:
|
||||
elif model not in cls.get_models() and cls.models:
|
||||
raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
|
||||
debug.last_model = model
|
||||
return model
|
||||
return model
|
@ -1,2 +0,0 @@
|
||||
class BaseConversation:
|
||||
...
|
26
g4f/providers/response.py
Normal file
26
g4f/providers/response.py
Normal file
@ -0,0 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
|
||||
class ResponseType:
|
||||
@abstractmethod
|
||||
def __str__(self) -> str:
|
||||
pass
|
||||
|
||||
class FinishReason():
|
||||
def __init__(self, reason: str):
|
||||
self.reason = reason
|
||||
|
||||
def __str__(self) -> str:
|
||||
return ""
|
||||
|
||||
class Sources(ResponseType):
|
||||
def __init__(self, sources: list[dict[str, str]]) -> None:
|
||||
self.list = sources
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "\n\n" + ("\n".join([f"{idx+1}. [{link['title']}]({link['url']})" for idx, link in enumerate(self.list)]))
|
||||
|
||||
class BaseConversation(ResponseType):
|
||||
def __str__(self) -> str:
|
||||
return ""
|
@ -3,7 +3,6 @@ from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Union, Dict, Type
|
||||
from ..typing import Messages, CreateResult
|
||||
from .conversation import BaseConversation
|
||||
|
||||
class BaseProvider(ABC):
|
||||
"""
|
||||
@ -98,10 +97,6 @@ class BaseRetryProvider(BaseProvider):
|
||||
|
||||
ProviderType = Union[Type[BaseProvider], BaseRetryProvider]
|
||||
|
||||
class FinishReason():
|
||||
def __init__(self, reason: str):
|
||||
self.reason = reason
|
||||
|
||||
class Streaming():
|
||||
def __init__(self, data: str) -> None:
|
||||
self.data = data
|
||||
|
@ -10,11 +10,13 @@ if sys.version_info >= (3, 8):
|
||||
from typing import TypedDict
|
||||
else:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from .providers.response import ResponseType
|
||||
|
||||
SHA256 = NewType('sha_256_hash', str)
|
||||
CreateResult = Iterator[str]
|
||||
AsyncResult = AsyncIterator[str]
|
||||
Messages = List[Dict[str, Union[str,List[Dict[str,Union[str,Dict[str,str]]]]]]]
|
||||
CreateResult = Iterator[Union[str, ResponseType]]
|
||||
AsyncResult = AsyncIterator[Union[str, ResponseType]]
|
||||
Messages = List[Dict[str, Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]]]
|
||||
Cookies = Dict[str, str]
|
||||
ImageType = Union[str, bytes, IO, Image, None]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user