diff --git a/g4f/api/__init__.py b/g4f/api/__init__.py index 1d79b7d0..b036603e 100644 --- a/g4f/api/__init__.py +++ b/g4f/api/__init__.py @@ -276,16 +276,12 @@ class Api: HTTP_200_OK: {"model": List[ModelResponseModel]}, }) async def models(): - model_list = dict( - (model, g4f.models.ModelUtils.convert[model]) - for model in g4f.Model.__all__() - ) return [{ 'id': model_id, 'object': 'model', 'created': 0, 'owned_by': model.base_provider - } for model_id, model in model_list.items()] + } for model_id, model in g4f.models.ModelUtils.convert.items()] @self.app.get("/v1/models/{model_name}", responses={ HTTP_200_OK: {"model": ModelResponseModel}, diff --git a/g4f/gui/client/static/css/style.css b/g4f/gui/client/static/css/style.css index b6e9c9da..e3c6dc27 100644 --- a/g4f/gui/client/static/css/style.css +++ b/g4f/gui/client/static/css/style.css @@ -778,7 +778,7 @@ select:hover, background-color: var(--button-hover); } -#provider option:disabled[value] { +#provider option:disabled[value], #model option:disabled[value] { display: none; } diff --git a/g4f/gui/client/static/js/chat.v1.js b/g4f/gui/client/static/js/chat.v1.js index 8d20f7b2..c08f72f0 100644 --- a/g4f/gui/client/static/js/chat.v1.js +++ b/g4f/gui/client/static/js/chat.v1.js @@ -1296,6 +1296,13 @@ async function on_load() { const load_provider_option = (input, provider_name) => { if (input.checked) { + modelSelect.querySelectorAll(`option[data-disabled_providers*="${provider_name}"]`).forEach( + (el) => { + el.dataset.disabled_providers = el.dataset.disabled_providers ? el.dataset.disabled_providers.split(" ").filter((provider) => provider!=provider_name).join(" ") : ""; + el.dataset.providers = (el.dataset.providers ? el.dataset.providers + " " : "") + provider_name; + modelSelect.querySelectorAll(`option[value="${el.value}"]`).forEach((o)=>o.removeAttribute("disabled", "disabled")) + } + ); providerSelect.querySelectorAll(`option[value="${provider_name}"]`).forEach( (el) => el.removeAttribute("disabled") ); @@ -1303,6 +1310,13 @@ const load_provider_option = (input, provider_name) => { (el) => el.removeAttribute("disabled") ); } else { + modelSelect.querySelectorAll(`option[data-providers*="${provider_name}"]`).forEach( + (el) => { + el.dataset.providers = el.dataset.providers ? el.dataset.providers.split(" ").filter((provider) => provider!=provider_name).join(" ") : ""; + el.dataset.disabled_providers = (el.dataset.disabled_providers ? el.dataset.disabled_providers + " " : "") + provider_name; + if (!el.dataset.providers) modelSelect.querySelectorAll(`option[value="${el.value}"]`).forEach((o)=>o.setAttribute("disabled", "disabled")) + } + ); providerSelect.querySelectorAll(`option[value="${provider_name}"]`).forEach( (el) => el.setAttribute("disabled", "disabled") ); @@ -1342,7 +1356,9 @@ async function on_api() { models = await api("models"); models.forEach((model) => { let option = document.createElement("option"); - option.value = option.text = option.dataset.label = model; + option.value = model.name; + option.text = model.name + (model.image ? " (Image Generation)" : ""); + option.dataset.providers = model.providers.join(" "); modelSelect.appendChild(option); }); providers = await api("providers") diff --git a/g4f/gui/server/api.py b/g4f/gui/server/api.py index 48028dba..d96edc2c 100644 --- a/g4f/gui/server/api.py +++ b/g4f/gui/server/api.py @@ -13,7 +13,7 @@ from g4f.errors import VersionNotFoundError from g4f.image import ImagePreview, ImageResponse, copy_images, ensure_images_dir, images_dir from g4f.Provider import ProviderType, __providers__, __map__ from g4f.providers.base_provider import ProviderModelMixin -from g4f.providers.retry_provider import BaseRetryProvider +from g4f.providers.retry_provider import IterListProvider from g4f.providers.response import BaseConversation, FinishReason, SynthesizeData from g4f.client.service import convert_to_provider from g4f import debug @@ -24,7 +24,15 @@ conversations: dict[dict[str, BaseConversation]] = {} class Api: @staticmethod def get_models(): - return models._all_models + return [{ + "name": model.name, + "image": isinstance(model, models.ImageModel), + "providers": [ + getattr(provider, "parent", provider.__name__) + for provider in providers + ] + } + for model, providers in models.__models__.values()] @staticmethod def get_provider_models(provider: str, api_key: str = None): @@ -126,7 +134,7 @@ class Api: for chunk in result: if first: first = False - if isinstance(provider, BaseRetryProvider): + if isinstance(provider, IterListProvider): provider = provider.last_provider yield self._format_json("provider", {**provider.get_dict(), "model": model}) if isinstance(chunk, BaseConversation): diff --git a/g4f/models.py b/g4f/models.py index 9cc4e901..58d5067b 100644 --- a/g4f/models.py +++ b/g4f/models.py @@ -59,6 +59,9 @@ class Model: """Returns a list of all model names.""" return _all_models +class ImageModel(Model): + pass + ### Default ### default = Model( name = "", @@ -559,100 +562,98 @@ any_uncensored = Model( ############# ### Stability AI ### -sdxl = Model( +sdxl = ImageModel( name = 'sdxl', base_provider = 'Stability AI', best_provider = IterListProvider([ReplicateHome, Airforce]) ) -sd_3 = Model( +sd_3 = ImageModel( name = 'sd-3', base_provider = 'Stability AI', best_provider = ReplicateHome - ) ### Playground ### -playground_v2_5 = Model( +playground_v2_5 = ImageModel( name = 'playground-v2.5', base_provider = 'Playground AI', best_provider = ReplicateHome - ) ### Flux AI ### -flux = Model( +flux = ImageModel( name = 'flux', base_provider = 'Flux AI', best_provider = IterListProvider([Blackbox, Airforce]) ) -flux_pro = Model( +flux_pro = ImageModel( name = 'flux-pro', base_provider = 'Flux AI', best_provider = Airforce ) -flux_dev = Model( +flux_dev = ImageModel( name = 'flux-dev', base_provider = 'Flux AI', best_provider = AmigoChat ) -flux_realism = Model( +flux_realism = ImageModel( name = 'flux-realism', base_provider = 'Flux AI', best_provider = IterListProvider([Airforce, AmigoChat]) ) -flux_anime = Model( +flux_anime = ImageModel( name = 'flux-anime', base_provider = 'Flux AI', best_provider = Airforce ) -flux_3d = Model( +flux_3d = ImageModel( name = 'flux-3d', base_provider = 'Flux AI', best_provider = Airforce ) -flux_disney = Model( +flux_disney = ImageModel( name = 'flux-disney', base_provider = 'Flux AI', best_provider = Airforce ) -flux_pixel = Model( +flux_pixel = ImageModel( name = 'flux-pixel', base_provider = 'Flux AI', best_provider = Airforce ) -flux_4o = Model( +flux_4o = ImageModel( name = 'flux-4o', base_provider = 'Flux AI', best_provider = Airforce ) ### OpenAI ### -dall_e_3 = Model( +dall_e_3 = ImageModel( name = 'dall-e-3', base_provider = 'OpenAI', best_provider = IterListProvider([Airforce, CopilotAccount, OpenaiAccount, MicrosoftDesigner, BingCreateImages]) ) ### Recraft ### -recraft_v3 = Model( +recraft_v3 = ImageModel( name = 'recraft-v3', base_provider = 'Recraft', best_provider = AmigoChat ) ### Other ### -any_dark = Model( +any_dark = ImageModel( name = 'any-dark', base_provider = 'Other', best_provider = Airforce @@ -863,4 +864,17 @@ class ModelUtils: 'any-dark': any_dark, } -_all_models = list(ModelUtils.convert.keys()) +# Create a list of all working models +__models__ = {model.name: (model, providers) for model, providers in [ + (model, [provider for provider in providers if provider.working]) + for model, providers in [ + (model, model.best_provider.providers + if isinstance(model.best_provider, IterListProvider) + else [model.best_provider] + if model.best_provider is not None + else []) + for model in ModelUtils.convert.values()] + ] if providers} +# Update the ModelUtils.convert with the working models +ModelUtils.convert = {model.name: model for model, _ in __models__.values()} +_all_models = list(ModelUtils.convert.keys()) \ No newline at end of file