Add more flux dev image providers

Heiner Lohaus 2024-12-08 04:13:09 +01:00
parent 54a6d91cfc
commit 1bfb36176c
9 changed files with 132 additions and 39 deletions

g4f/Provider/Flux.py (new file, 58 lines)

@@ -0,0 +1,58 @@
from __future__ import annotations

import json
from aiohttp import ClientSession

from ..typing import AsyncResult, Messages
from ..image import ImageResponse, ImagePreview
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin

class Flux(AsyncGeneratorProvider, ProviderModelMixin):
    label = "Flux Provider"
    url = "https://black-forest-labs-flux-1-dev.hf.space"
    api_endpoint = "/gradio_api/call/infer"
    working = True
    default_model = 'flux-1-dev'
    models = [default_model]
    image_models = [default_model]

    @classmethod
    async def create_async_generator(
        cls, model: str, messages: Messages, prompt: str = None, api_key: str = None, proxy: str = None, **kwargs
    ) -> AsyncResult:
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        if api_key is not None:
            headers["Authorization"] = f"Bearer {api_key}"
        async with ClientSession(headers=headers) as session:
            prompt = messages[-1]["content"] if prompt is None else prompt
            # Positional inputs for the Space's infer endpoint:
            # [prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps]
            data = {
                "data": [prompt, 0, True, 1024, 1024, 3.5, 28]
            }
            async with session.post(f"{cls.url}{cls.api_endpoint}", json=data, proxy=proxy) as response:
                response.raise_for_status()
                event_id = (await response.json()).get("event_id")
                # Second request: stream the server-sent events for this job.
                async with session.get(f"{cls.url}{cls.api_endpoint}/{event_id}") as event_response:
                    event_response.raise_for_status()
                    event = None
                    async for chunk in event_response.content:
                        if chunk.startswith(b"event: "):
                            event = chunk[7:].decode(errors="replace").strip()
                        if chunk.startswith(b"data: "):
                            if event == "error":
                                raise RuntimeError(f"GPU token limit exceeded: {chunk.decode(errors='replace')}")
                            if event in ("complete", "generating"):
                                try:
                                    data = json.loads(chunk[6:])
                                    if data is None:
                                        continue
                                    url = data[0]["url"]
                                except (json.JSONDecodeError, KeyError, TypeError) as e:
                                    raise RuntimeError(f"Failed to parse image URL: {chunk.decode(errors='replace')}", e)
                                if event == "generating":
                                    # Intermediate previews stream in before the final image.
                                    yield ImagePreview(url, prompt)
                                else:
                                    yield ImageResponse(url, prompt)
                                    break
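
Note: the provider drives the Space's Gradio queue in two steps: a POST to /gradio_api/call/infer returns an event_id, and a GET on the same path plus the event_id streams server-sent events until the image is ready. A standalone sketch of that protocol, assuming only aiohttp; the helper name is illustrative, while the endpoint and payload are taken from the code above:

    import json
    import asyncio
    from aiohttp import ClientSession

    async def gradio_infer(prompt: str) -> str:
        # Hypothetical helper mirroring the provider's request flow.
        url = "https://black-forest-labs-flux-1-dev.hf.space/gradio_api/call/infer"
        data = {"data": [prompt, 0, True, 1024, 1024, 3.5, 28]}
        async with ClientSession() as session:
            # Step 1: queue the job; the Space answers with an event_id.
            async with session.post(url, json=data) as response:
                response.raise_for_status()
                event_id = (await response.json())["event_id"]
            # Step 2: read the SSE stream; return the first image URL it
            # yields (a preview or the final image).
            async with session.get(f"{url}/{event_id}") as events:
                async for line in events.content:
                    if line.startswith(b"data: "):
                        payload = json.loads(line[6:])
                        if payload:
                            return payload[0]["url"]

    print(asyncio.run(gradio_infer("a red fox in the snow")))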

g4f/Provider/__init__.py

@@ -39,6 +39,7 @@ from .TeachAnything import TeachAnything
 from .Upstage import Upstage
 from .You import You
 from .Mhystical import Mhystical
+from .Flux import Flux

 import sys
@@ -59,4 +60,4 @@ __map__: dict[str, ProviderType] = dict([
 ])

 class ProviderUtils:
-    convert: dict[str, ProviderType] = __map__
\ No newline at end of file
+    convert: dict[str, ProviderType] = __map__

g4f/Provider/needs_auth/HuggingChat.py

@@ -12,6 +12,7 @@ from ...typing import CreateResult, Messages, Cookies
 from ...errors import MissingRequirementsError
 from ...requests.raise_for_status import raise_for_status
 from ...cookies import get_cookies
+from ...image import ImageResponse
 from ..base_provider import ProviderModelMixin, AbstractProvider, BaseConversation
 from ..helper import format_prompt
 from ... import debug
@@ -26,10 +27,12 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
     working = True
     supports_stream = True
     needs_auth = True
-    default_model = "meta-llama/Meta-Llama-3.1-70B-Instruct"
+    default_model = "Qwen/Qwen2.5-72B-Instruct"
+    image_models = [
+        "black-forest-labs/FLUX.1-dev"
+    ]
     models = [
-        'Qwen/Qwen2.5-72B-Instruct',
+        default_model,
         'meta-llama/Meta-Llama-3.1-70B-Instruct',
         'CohereForAI/c4ai-command-r-plus-08-2024',
         'Qwen/QwQ-32B-Preview',
@@ -39,8 +42,8 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
         'NousResearch/Hermes-3-Llama-3.1-8B',
         'mistralai/Mistral-Nemo-Instruct-2407',
         'microsoft/Phi-3.5-mini-instruct',
+        *image_models
     ]
     model_aliases = {
         "qwen-2.5-72b": "Qwen/Qwen2.5-72B-Instruct",
         "llama-3.1-70b": "meta-llama/Meta-Llama-3.1-70B-Instruct",
@@ -52,6 +55,7 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
         "hermes-3": "NousResearch/Hermes-3-Llama-3.1-8B",
         "mistral-nemo": "mistralai/Mistral-Nemo-Instruct-2407",
         "phi-3.5-mini": "microsoft/Phi-3.5-mini-instruct",
+        "flux-dev": "black-forest-labs/FLUX.1-dev",
     }

     @classmethod
@@ -109,7 +113,7 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
             "is_retry": False,
             "is_continue": False,
             "web_search": web_search,
-            "tools": []
+            "tools": ["000000000000000000000001"] if model in cls.image_models else [],
         }

         headers = {
@@ -162,14 +166,18 @@ class HuggingChat(AbstractProvider, ProviderModelMixin):
                 elif line["type"] == "finalAnswer":
                     break
+                elif line["type"] == "file":
+                    url = f"https://huggingface.co/chat/conversation/{conversation.conversation_id}/output/{line['sha']}"
+                    yield ImageResponse(url, alt=messages[-1]["content"], options={"cookies": cookies})

         full_response = full_response.replace('<|im_end|', '').replace('\u0000', '').strip()
         if not stream:
             yield full_response

     @classmethod
     def create_conversation(cls, session: Session, model: str):
+        if model in cls.image_models:
+            model = cls.default_model
         json_data = {
             'model': model,
         }
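
Note: image generation here rides on HuggingChat's built-in image tool rather than a separate endpoint. The conversation is still created against the default text model, and the hard-coded tool id switches the assistant into image mode. Illustratively, the request body for an image model now looks like the sketch below; field values other than "tools" are assumptions for the example:

    # Hypothetical request body sent for an image model, per the change above.
    settings = {
        "inputs": "a lighthouse at dusk",       # prompt taken from the last message
        "is_retry": False,
        "is_continue": False,
        "web_search": False,
        # Hard-coded id of HuggingChat's image-generation tool (from the diff);
        # for text models this list stays empty.
        "tools": ["000000000000000000000001"],
    }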

g4f/Provider/needs_auth/HuggingFace.py

@@ -1,21 +1,25 @@
 from __future__ import annotations

 import json
+import base64
+import random

 from ...typing import AsyncResult, Messages
 from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from ...errors import ModelNotFoundError
 from ...requests import StreamSession, raise_for_status
+from ...image import ImageResponse
 from .HuggingChat import HuggingChat

 class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
     url = "https://huggingface.co/chat"
     working = True
     needs_auth = True
     supports_message_history = True
     default_model = HuggingChat.default_model
-    models = HuggingChat.models
+    default_image_model = "black-forest-labs/FLUX.1-dev"
+    models = [*HuggingChat.models, default_image_model]
+    image_models = [default_image_model]
     model_aliases = HuggingChat.model_aliases
     @classmethod
@@ -29,6 +33,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
         api_key: str = None,
         max_new_tokens: int = 1024,
         temperature: float = 0.7,
+        prompt: str = None,
         **kwargs
     ) -> AsyncResult:
         model = cls.get_model(model)
@@ -50,16 +55,22 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
         }
         if api_key is not None:
             headers["Authorization"] = f"Bearer {api_key}"
-        params = {
-            "return_full_text": False,
-            "max_new_tokens": max_new_tokens,
-            "temperature": temperature,
-            **kwargs
-        }
-        payload = {"inputs": format_prompt(messages), "parameters": params, "stream": stream}
+        if model in cls.image_models:
+            stream = False
+            prompt = messages[-1]["content"] if prompt is None else prompt
+            payload = {"inputs": prompt, "parameters": {"seed": random.randint(0, 2**32)}}
+        else:
+            params = {
+                "return_full_text": False,
+                "max_new_tokens": max_new_tokens,
+                "temperature": temperature,
+                **kwargs
+            }
+            payload = {"inputs": format_prompt(messages), "parameters": params, "stream": stream}
         async with StreamSession(
             headers=headers,
-            proxy=proxy
+            proxy=proxy,
+            timeout=600
         ) as session:
             async with session.post(f"{api_base.rstrip('/')}/models/{model}", json=payload) as response:
                 if response.status == 404:
@@ -78,7 +89,12 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
                         if chunk:
                             yield chunk
                 else:
-                    yield (await response.json())[0]["generated_text"].strip()
+                    if response.headers["content-type"].startswith("image/"):
+                        base64_data = base64.b64encode(b"".join([chunk async for chunk in response.iter_content()]))
+                        url = f"data:{response.headers['content-type']};base64,{base64_data.decode()}"
+                        yield ImageResponse(url, prompt)
+                    else:
+                        yield (await response.json())[0]["generated_text"].strip()

 def format_prompt(messages: Messages) -> str:
     system_messages = [message["content"] for message in messages if message["role"] == "system"]
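
Note: for diffusion models the Hugging Face inference endpoint answers with raw image bytes rather than JSON, which is why the new branch checks the content-type and wraps the body in a data: URI. The encoding itself is plain stdlib; a minimal sketch:

    import base64

    def to_data_uri(image_bytes: bytes, content_type: str = "image/jpeg") -> str:
        # Raw bytes -> base64 -> RFC 2397 data URI, as in the branch above.
        return f"data:{content_type};base64,{base64.b64encode(image_bytes).decode()}"

    assert to_data_uri(b"\xff\xd8\xff").startswith("data:image/jpeg;base64,")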

g4f/cookies.py

@@ -34,8 +34,8 @@ try:
     browsers = [
         _g4f,
-        chrome, chromium, opera, opera_gx,
-        brave, edge, vivaldi, firefox,
+        chrome, chromium, firefox, opera, opera_gx,
+        brave, edge, vivaldi,
     ]
     has_browser_cookie3 = True
 except ImportError:

g4f/gui/client/static/js/chat.v1.js

@@ -504,6 +504,8 @@ async function add_message_chunk(message, message_id) {
         p.innerText = message.error;
         log_storage.appendChild(p);
     } else if (message.type == "preview") {
+        if (content_map.inner.clientHeight > 200)
+            content_map.inner.style.height = content_map.inner.clientHeight + "px";
         content_map.inner.innerHTML = markdown_render(message.preview);
     } else if (message.type == "content") {
         message_storage[message_id] += message.content;
@@ -522,6 +524,7 @@ async function add_message_chunk(message, message_id) {
         content_map.inner.innerHTML = html;
         content_map.count.innerText = count_words_and_tokens(message_storage[message_id], provider_storage[message_id]?.model);
         highlight(content_map.inner);
+        content_map.inner.style.height = "";
     } else if (message.type == "log") {
         let p = document.createElement("p");
         p.innerText = message.log;

g4f/gui/server/api.py

@@ -123,22 +123,21 @@ class Api:
             print(text)
         debug.log_handler = log_handler
         proxy = os.environ.get("G4F_PROXY")
+        provider = kwargs.get("provider")
+        model, provider_handler = get_model_and_provider(
+            kwargs.get("model"), provider,
+            stream=True,
+            ignore_stream=True
+        )
+        first = True
         try:
-            model, provider = get_model_and_provider(
-                kwargs.get("model"), kwargs.get("provider"),
-                stream=True,
-                ignore_stream=True
-            )
-            result = ChatCompletion.create(**{**kwargs, "model": model, "provider": provider})
-            first = True
+            result = ChatCompletion.create(**{**kwargs, "model": model, "provider": provider_handler})
             for chunk in result:
                 if first:
                     first = False
-                    if isinstance(provider, IterListProvider):
-                        provider = provider.last_provider
-                    yield self._format_json("provider", {**provider.get_dict(), "model": model})
+                    yield self.handle_provider(provider_handler, model)
                 if isinstance(chunk, BaseConversation):
-                    if provider:
+                    if provider is not None:
                         if provider not in conversations:
                             conversations[provider] = {}
                         conversations[provider][conversation_id] = chunk
@@ -165,6 +164,8 @@ class Api:
         except Exception as e:
             logger.exception(e)
             yield self._format_json('error', get_error_message(e))
+        if first:
+            yield self.handle_provider(provider_handler, model)

     def _format_json(self, response_type: str, content):
         return {
@@ -172,9 +173,12 @@ class Api:
             response_type: content
         }
+    def handle_provider(self, provider_handler, model):
+        if isinstance(provider_handler, IterListProvider):
+            provider_handler = provider_handler.last_provider
+        if issubclass(provider_handler, ProviderModelMixin) and provider_handler.last_model is not None:
+            model = provider_handler.last_model
+        return self._format_json("provider", {**provider_handler.get_dict(), "model": model})

 def get_error_message(exception: Exception) -> str:
-    message = f"{type(exception).__name__}: {exception}"
-    provider = get_last_provider()
-    if provider is None:
-        return message
-    return f"{provider.__name__}: {message}"
+    return f"{type(exception).__name__}: {exception}"

g4f/models.py

@@ -38,6 +38,7 @@ from .Provider import (
     RubiksAI,
     TeachAnything,
     Upstage,
+    Flux,
 )

 @dataclass(unsafe_hash=True)
@@ -599,7 +600,7 @@ flux_pro = ImageModel(
 flux_dev = ImageModel(
     name = 'flux-dev',
     base_provider = 'Flux AI',
-    best_provider = AmigoChat
+    best_provider = IterListProvider([Flux, AmigoChat, HuggingChat, HuggingFace])
 )

 flux_realism = ImageModel(
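
Note: with IterListProvider, a request for 'flux-dev' tries Flux first and falls through to AmigoChat, HuggingChat, and HuggingFace if a provider errors out. A minimal usage sketch, assuming g4f's client API; the prompt is illustrative, and the HuggingChat/HuggingFace legs need auth cookies or an api_key:

    from g4f.client import Client

    client = Client()
    # No explicit provider: "flux-dev" resolves to the IterListProvider chain above.
    response = client.images.generate(model="flux-dev", prompt="an isometric city block")
    print(response.data[0].url)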

g4f/providers/base_provider.py

@@ -98,7 +98,7 @@ class AbstractProvider(BaseProvider):
             default_value = f'"{param.default}"' if isinstance(param.default, str) else param.default
             args += f" = {default_value}" if param.default is not Parameter.empty else ""
             args += ","
         return f"g4f.Provider.{cls.__name__} supports: ({args}\n)"

 class AsyncProvider(AbstractProvider):
@@ -240,6 +240,7 @@ class ProviderModelMixin:
     models: list[str] = []
     model_aliases: dict[str, str] = {}
     image_models: list = None
+    last_model: str = None

     @classmethod
     def get_models(cls) -> list[str]:
@@ -255,5 +256,6 @@ class ProviderModelMixin:
             model = cls.model_aliases[model]
         elif model not in cls.get_models() and cls.models:
             raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
+        cls.last_model = model
         debug.last_model = model
         return model
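
Note: get_model now leaves a breadcrumb: the resolved model name is stored on the provider class, and the API layer's handle_provider reads it back so the UI reports the model that actually ran, not just the one requested. A small sketch of the behavior, assuming the new Flux provider is importable:

    from g4f.Provider import Flux

    # get_model resolves aliases and records the result on the class.
    resolved = Flux.get_model("flux-1-dev")
    assert resolved == "flux-1-dev"
    assert Flux.last_model == resolved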