mirror of
https://github.com/xtekky/gpt4free.git
synced 2024-12-23 11:02:40 +03:00
Add more flux dev image providers
This commit is contained in:
parent
54a6d91cfc
commit
1bfb36176c
58
g4f/Provider/Flux.py
Normal file
58
g4f/Provider/Flux.py
Normal file
@ -0,0 +1,58 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from aiohttp import ClientSession
|
||||
|
||||
from ..typing import AsyncResult, Messages
|
||||
from ..image import ImageResponse, ImagePreview
|
||||
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
|
||||
class Flux(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
label = "Flux Provider"
|
||||
url = "https://black-forest-labs-flux-1-dev.hf.space"
|
||||
api_endpoint = "/gradio_api/call/infer"
|
||||
working = True
|
||||
default_model = 'flux-1-dev'
|
||||
models = [default_model]
|
||||
image_models = [default_model]
|
||||
|
||||
@classmethod
|
||||
async def create_async_generator(
|
||||
cls, model: str, messages: Messages, prompt: str = None, api_key: str = None, proxy: str = None, **kwargs
|
||||
) -> AsyncResult:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
if api_key is not None:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
async with ClientSession(headers=headers) as session:
|
||||
prompt = messages[-1]["content"] if prompt is None else prompt
|
||||
data = {
|
||||
"data": [prompt, 0, True, 1024, 1024, 3.5, 28]
|
||||
}
|
||||
async with session.post(f"{cls.url}{cls.api_endpoint}", json=data, proxy=proxy) as response:
|
||||
response.raise_for_status()
|
||||
event_id = (await response.json()).get("event_id")
|
||||
async with session.get(f"{cls.url}{cls.api_endpoint}/{event_id}") as event_response:
|
||||
event_response.raise_for_status()
|
||||
event = None
|
||||
async for chunk in event_response.content:
|
||||
if chunk.startswith(b"event: "):
|
||||
event = chunk[7:].decode(errors="replace").strip()
|
||||
if chunk.startswith(b"data: "):
|
||||
if event == "error":
|
||||
raise RuntimeError(f"GPU token limit exceeded: {chunk.decode(errors='replace')}")
|
||||
if event in ("complete", "generating"):
|
||||
try:
|
||||
data = json.loads(chunk[6:])
|
||||
if data is None:
|
||||
continue
|
||||
url = data[0]["url"]
|
||||
except (json.JSONDecodeError, KeyError, TypeError) as e:
|
||||
raise RuntimeError(f"Failed to parse image URL: {chunk.decode(errors='replace')}", e)
|
||||
if event == "generating":
|
||||
yield ImagePreview(url, prompt)
|
||||
else:
|
||||
yield ImageResponse(url, prompt)
|
||||
break
|
@ -39,6 +39,7 @@ from .TeachAnything import TeachAnything
|
||||
from .Upstage import Upstage
|
||||
from .You import You
|
||||
from .Mhystical import Mhystical
|
||||
from .Flux import Flux
|
||||
|
||||
import sys
|
||||
|
||||
@ -59,4 +60,4 @@ __map__: dict[str, ProviderType] = dict([
|
||||
])
|
||||
|
||||
class ProviderUtils:
|
||||
convert: dict[str, ProviderType] = __map__
|
||||
convert: dict[str, ProviderType] = __map__
|
@ -12,6 +12,7 @@ from ...typing import CreateResult, Messages, Cookies
|
||||
from ...errors import MissingRequirementsError
|
||||
from ...requests.raise_for_status import raise_for_status
|
||||
from ...cookies import get_cookies
|
||||
from ...image import ImageResponse
|
||||
from ..base_provider import ProviderModelMixin, AbstractProvider, BaseConversation
|
||||
from ..helper import format_prompt
|
||||
from ... import debug
|
||||
@ -26,10 +27,12 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
|
||||
working = True
|
||||
supports_stream = True
|
||||
needs_auth = True
|
||||
default_model = "meta-llama/Meta-Llama-3.1-70B-Instruct"
|
||||
|
||||
default_model = "Qwen/Qwen2.5-72B-Instruct"
|
||||
image_models = [
|
||||
"black-forest-labs/FLUX.1-dev"
|
||||
]
|
||||
models = [
|
||||
'Qwen/Qwen2.5-72B-Instruct',
|
||||
default_model,
|
||||
'meta-llama/Meta-Llama-3.1-70B-Instruct',
|
||||
'CohereForAI/c4ai-command-r-plus-08-2024',
|
||||
'Qwen/QwQ-32B-Preview',
|
||||
@ -39,8 +42,8 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
|
||||
'NousResearch/Hermes-3-Llama-3.1-8B',
|
||||
'mistralai/Mistral-Nemo-Instruct-2407',
|
||||
'microsoft/Phi-3.5-mini-instruct',
|
||||
*image_models
|
||||
]
|
||||
|
||||
model_aliases = {
|
||||
"qwen-2.5-72b": "Qwen/Qwen2.5-72B-Instruct",
|
||||
"llama-3.1-70b": "meta-llama/Meta-Llama-3.1-70B-Instruct",
|
||||
@ -52,6 +55,7 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
|
||||
"hermes-3": "NousResearch/Hermes-3-Llama-3.1-8B",
|
||||
"mistral-nemo": "mistralai/Mistral-Nemo-Instruct-2407",
|
||||
"phi-3.5-mini": "microsoft/Phi-3.5-mini-instruct",
|
||||
"flux-dev": "black-forest-labs/FLUX.1-dev",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@ -109,7 +113,7 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
|
||||
"is_retry": False,
|
||||
"is_continue": False,
|
||||
"web_search": web_search,
|
||||
"tools": []
|
||||
"tools": ["000000000000000000000001"] if model in cls.image_models else [],
|
||||
}
|
||||
|
||||
headers = {
|
||||
@ -162,14 +166,18 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
|
||||
|
||||
elif line["type"] == "finalAnswer":
|
||||
break
|
||||
|
||||
full_response = full_response.replace('<|im_end|', '').replace('\u0000', '').strip()
|
||||
elif line["type"] == "file":
|
||||
url = f"https://huggingface.co/chat/conversation/{conversation.conversation_id}/output/{line['sha']}"
|
||||
yield ImageResponse(url, alt=messages[-1]["content"], options={"cookies": cookies})
|
||||
|
||||
full_response = full_response.replace('<|im_end|', '').replace('\u0000', '').strip()
|
||||
if not stream:
|
||||
yield full_response
|
||||
|
||||
@classmethod
|
||||
def create_conversation(cls, session: Session, model: str):
|
||||
if model in cls.image_models:
|
||||
model = cls.default_model
|
||||
json_data = {
|
||||
'model': model,
|
||||
}
|
||||
|
@ -1,21 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import base64
|
||||
import random
|
||||
|
||||
from ...typing import AsyncResult, Messages
|
||||
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from ...errors import ModelNotFoundError
|
||||
from ...requests import StreamSession, raise_for_status
|
||||
from ...image import ImageResponse
|
||||
|
||||
from .HuggingChat import HuggingChat
|
||||
|
||||
class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
url = "https://huggingface.co/chat"
|
||||
working = True
|
||||
needs_auth = True
|
||||
supports_message_history = True
|
||||
default_model = HuggingChat.default_model
|
||||
models = HuggingChat.models
|
||||
default_image_model = "black-forest-labs/FLUX.1-dev"
|
||||
models = [*HuggingChat.models, default_image_model]
|
||||
image_models = [default_image_model]
|
||||
model_aliases = HuggingChat.model_aliases
|
||||
|
||||
@classmethod
|
||||
@ -29,6 +33,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
api_key: str = None,
|
||||
max_new_tokens: int = 1024,
|
||||
temperature: float = 0.7,
|
||||
prompt: str = None,
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
model = cls.get_model(model)
|
||||
@ -50,16 +55,22 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
}
|
||||
if api_key is not None:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
params = {
|
||||
"return_full_text": False,
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"temperature": temperature,
|
||||
**kwargs
|
||||
}
|
||||
payload = {"inputs": format_prompt(messages), "parameters": params, "stream": stream}
|
||||
if model in cls.image_models:
|
||||
stream = False
|
||||
prompt = messages[-1]["content"] if prompt is None else prompt
|
||||
payload = {"inputs": prompt, "parameters": {"seed": random.randint(0, 2**32)}}
|
||||
else:
|
||||
params = {
|
||||
"return_full_text": False,
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"temperature": temperature,
|
||||
**kwargs
|
||||
}
|
||||
payload = {"inputs": format_prompt(messages), "parameters": params, "stream": stream}
|
||||
async with StreamSession(
|
||||
headers=headers,
|
||||
proxy=proxy
|
||||
proxy=proxy,
|
||||
timeout=600
|
||||
) as session:
|
||||
async with session.post(f"{api_base.rstrip('/')}/models/{model}", json=payload) as response:
|
||||
if response.status == 404:
|
||||
@ -78,7 +89,12 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
if chunk:
|
||||
yield chunk
|
||||
else:
|
||||
yield (await response.json())[0]["generated_text"].strip()
|
||||
if response.headers["content-type"].startswith("image/"):
|
||||
base64_data = base64.b64encode(b"".join([chunk async for chunk in response.iter_content()]))
|
||||
url = f"data:{response.headers['content-type']};base64,{base64_data.decode()}"
|
||||
yield ImageResponse(url, prompt)
|
||||
else:
|
||||
yield (await response.json())[0]["generated_text"].strip()
|
||||
|
||||
def format_prompt(messages: Messages) -> str:
|
||||
system_messages = [message["content"] for message in messages if message["role"] == "system"]
|
||||
|
@ -34,8 +34,8 @@ try:
|
||||
|
||||
browsers = [
|
||||
_g4f,
|
||||
chrome, chromium, opera, opera_gx,
|
||||
brave, edge, vivaldi, firefox,
|
||||
chrome, chromium, firefox, opera, opera_gx,
|
||||
brave, edge, vivaldi,
|
||||
]
|
||||
has_browser_cookie3 = True
|
||||
except ImportError:
|
||||
|
@ -504,6 +504,8 @@ async function add_message_chunk(message, message_id) {
|
||||
p.innerText = message.error;
|
||||
log_storage.appendChild(p);
|
||||
} else if (message.type == "preview") {
|
||||
if (content_map.inner.clientHeight > 200)
|
||||
content_map.inner.style.height = content_map.inner.clientHeight + "px";
|
||||
content_map.inner.innerHTML = markdown_render(message.preview);
|
||||
} else if (message.type == "content") {
|
||||
message_storage[message_id] += message.content;
|
||||
@ -522,6 +524,7 @@ async function add_message_chunk(message, message_id) {
|
||||
content_map.inner.innerHTML = html;
|
||||
content_map.count.innerText = count_words_and_tokens(message_storage[message_id], provider_storage[message_id]?.model);
|
||||
highlight(content_map.inner);
|
||||
content_map.inner.style.height = "";
|
||||
} else if (message.type == "log") {
|
||||
let p = document.createElement("p");
|
||||
p.innerText = message.log;
|
||||
|
@ -123,22 +123,21 @@ class Api:
|
||||
print(text)
|
||||
debug.log_handler = log_handler
|
||||
proxy = os.environ.get("G4F_PROXY")
|
||||
provider = kwargs.get("provider")
|
||||
model, provider_handler = get_model_and_provider(
|
||||
kwargs.get("model"), provider,
|
||||
stream=True,
|
||||
ignore_stream=True
|
||||
)
|
||||
first = True
|
||||
try:
|
||||
model, provider = get_model_and_provider(
|
||||
kwargs.get("model"), kwargs.get("provider"),
|
||||
stream=True,
|
||||
ignore_stream=True
|
||||
)
|
||||
result = ChatCompletion.create(**{**kwargs, "model": model, "provider": provider})
|
||||
first = True
|
||||
result = ChatCompletion.create(**{**kwargs, "model": model, "provider": provider_handler})
|
||||
for chunk in result:
|
||||
if first:
|
||||
first = False
|
||||
if isinstance(provider, IterListProvider):
|
||||
provider = provider.last_provider
|
||||
yield self._format_json("provider", {**provider.get_dict(), "model": model})
|
||||
yield self.handle_provider(provider_handler, model)
|
||||
if isinstance(chunk, BaseConversation):
|
||||
if provider:
|
||||
if provider is not None:
|
||||
if provider not in conversations:
|
||||
conversations[provider] = {}
|
||||
conversations[provider][conversation_id] = chunk
|
||||
@ -165,6 +164,8 @@ class Api:
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
yield self._format_json('error', get_error_message(e))
|
||||
if first:
|
||||
yield self.handle_provider(provider_handler, model)
|
||||
|
||||
def _format_json(self, response_type: str, content):
|
||||
return {
|
||||
@ -172,9 +173,12 @@ class Api:
|
||||
response_type: content
|
||||
}
|
||||
|
||||
def handle_provider(self, provider_handler, model):
|
||||
if isinstance(provider_handler, IterListProvider):
|
||||
provider_handler = provider_handler.last_provider
|
||||
if issubclass(provider_handler, ProviderModelMixin) and provider_handler.last_model is not None:
|
||||
model = provider_handler.last_model
|
||||
return self._format_json("provider", {**provider_handler.get_dict(), "model": model})
|
||||
|
||||
def get_error_message(exception: Exception) -> str:
|
||||
message = f"{type(exception).__name__}: {exception}"
|
||||
provider = get_last_provider()
|
||||
if provider is None:
|
||||
return message
|
||||
return f"{provider.__name__}: {message}"
|
||||
return f"{type(exception).__name__}: {exception}"
|
@ -38,6 +38,7 @@ from .Provider import (
|
||||
RubiksAI,
|
||||
TeachAnything,
|
||||
Upstage,
|
||||
Flux,
|
||||
)
|
||||
|
||||
@dataclass(unsafe_hash=True)
|
||||
@ -599,7 +600,7 @@ flux_pro = ImageModel(
|
||||
flux_dev = ImageModel(
|
||||
name = 'flux-dev',
|
||||
base_provider = 'Flux AI',
|
||||
best_provider = AmigoChat
|
||||
best_provider = IterListProvider([Flux, AmigoChat, HuggingChat, HuggingFace])
|
||||
)
|
||||
|
||||
flux_realism = ImageModel(
|
||||
|
@ -98,7 +98,7 @@ class AbstractProvider(BaseProvider):
|
||||
default_value = f'"{param.default}"' if isinstance(param.default, str) else param.default
|
||||
args += f" = {default_value}" if param.default is not Parameter.empty else ""
|
||||
args += ","
|
||||
|
||||
|
||||
return f"g4f.Provider.{cls.__name__} supports: ({args}\n)"
|
||||
|
||||
class AsyncProvider(AbstractProvider):
|
||||
@ -240,6 +240,7 @@ class ProviderModelMixin:
|
||||
models: list[str] = []
|
||||
model_aliases: dict[str, str] = {}
|
||||
image_models: list = None
|
||||
last_model: str = None
|
||||
|
||||
@classmethod
|
||||
def get_models(cls) -> list[str]:
|
||||
@ -255,5 +256,6 @@ class ProviderModelMixin:
|
||||
model = cls.model_aliases[model]
|
||||
elif model not in cls.get_models() and cls.models:
|
||||
raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
|
||||
cls.last_model = model
|
||||
debug.last_model = model
|
||||
return model
|
Loading…
Reference in New Issue
Block a user