Show only enabled models in gui

This commit is contained in:
Heiner Lohaus 2024-12-07 22:50:28 +01:00
parent 20ad08021a
commit eff97ed344
5 changed files with 62 additions and 28 deletions

View File

@ -276,16 +276,12 @@ class Api:
HTTP_200_OK: {"model": List[ModelResponseModel]},
})
async def models():
model_list = dict(
(model, g4f.models.ModelUtils.convert[model])
for model in g4f.Model.__all__()
)
return [{
'id': model_id,
'object': 'model',
'created': 0,
'owned_by': model.base_provider
} for model_id, model in model_list.items()]
} for model_id, model in g4f.models.ModelUtils.convert.items()]
@self.app.get("/v1/models/{model_name}", responses={
HTTP_200_OK: {"model": ModelResponseModel},

View File

@ -778,7 +778,7 @@ select:hover,
background-color: var(--button-hover);
}
#provider option:disabled[value] {
#provider option:disabled[value], #model option:disabled[value] {
display: none;
}

View File

@ -1296,6 +1296,13 @@ async function on_load() {
const load_provider_option = (input, provider_name) => {
if (input.checked) {
modelSelect.querySelectorAll(`option[data-disabled_providers*="${provider_name}"]`).forEach(
(el) => {
el.dataset.disabled_providers = el.dataset.disabled_providers ? el.dataset.disabled_providers.split(" ").filter((provider) => provider!=provider_name).join(" ") : "";
el.dataset.providers = (el.dataset.providers ? el.dataset.providers + " " : "") + provider_name;
modelSelect.querySelectorAll(`option[value="${el.value}"]`).forEach((o)=>o.removeAttribute("disabled", "disabled"))
}
);
providerSelect.querySelectorAll(`option[value="${provider_name}"]`).forEach(
(el) => el.removeAttribute("disabled")
);
@ -1303,6 +1310,13 @@ const load_provider_option = (input, provider_name) => {
(el) => el.removeAttribute("disabled")
);
} else {
modelSelect.querySelectorAll(`option[data-providers*="${provider_name}"]`).forEach(
(el) => {
el.dataset.providers = el.dataset.providers ? el.dataset.providers.split(" ").filter((provider) => provider!=provider_name).join(" ") : "";
el.dataset.disabled_providers = (el.dataset.disabled_providers ? el.dataset.disabled_providers + " " : "") + provider_name;
if (!el.dataset.providers) modelSelect.querySelectorAll(`option[value="${el.value}"]`).forEach((o)=>o.setAttribute("disabled", "disabled"))
}
);
providerSelect.querySelectorAll(`option[value="${provider_name}"]`).forEach(
(el) => el.setAttribute("disabled", "disabled")
);
@ -1342,7 +1356,9 @@ async function on_api() {
models = await api("models");
models.forEach((model) => {
let option = document.createElement("option");
option.value = option.text = option.dataset.label = model;
option.value = model.name;
option.text = model.name + (model.image ? " (Image Generation)" : "");
option.dataset.providers = model.providers.join(" ");
modelSelect.appendChild(option);
});
providers = await api("providers")

View File

@ -13,7 +13,7 @@ from g4f.errors import VersionNotFoundError
from g4f.image import ImagePreview, ImageResponse, copy_images, ensure_images_dir, images_dir
from g4f.Provider import ProviderType, __providers__, __map__
from g4f.providers.base_provider import ProviderModelMixin
from g4f.providers.retry_provider import BaseRetryProvider
from g4f.providers.retry_provider import IterListProvider
from g4f.providers.response import BaseConversation, FinishReason, SynthesizeData
from g4f.client.service import convert_to_provider
from g4f import debug
@ -24,7 +24,15 @@ conversations: dict[dict[str, BaseConversation]] = {}
class Api:
@staticmethod
def get_models():
return models._all_models
return [{
"name": model.name,
"image": isinstance(model, models.ImageModel),
"providers": [
getattr(provider, "parent", provider.__name__)
for provider in providers
]
}
for model, providers in models.__models__.values()]
@staticmethod
def get_provider_models(provider: str, api_key: str = None):
@ -126,7 +134,7 @@ class Api:
for chunk in result:
if first:
first = False
if isinstance(provider, BaseRetryProvider):
if isinstance(provider, IterListProvider):
provider = provider.last_provider
yield self._format_json("provider", {**provider.get_dict(), "model": model})
if isinstance(chunk, BaseConversation):

View File

@ -59,6 +59,9 @@ class Model:
"""Returns a list of all model names."""
return _all_models
class ImageModel(Model):
pass
### Default ###
default = Model(
name = "",
@ -559,100 +562,98 @@ any_uncensored = Model(
#############
### Stability AI ###
sdxl = Model(
sdxl = ImageModel(
name = 'sdxl',
base_provider = 'Stability AI',
best_provider = IterListProvider([ReplicateHome, Airforce])
)
sd_3 = Model(
sd_3 = ImageModel(
name = 'sd-3',
base_provider = 'Stability AI',
best_provider = ReplicateHome
)
### Playground ###
playground_v2_5 = Model(
playground_v2_5 = ImageModel(
name = 'playground-v2.5',
base_provider = 'Playground AI',
best_provider = ReplicateHome
)
### Flux AI ###
flux = Model(
flux = ImageModel(
name = 'flux',
base_provider = 'Flux AI',
best_provider = IterListProvider([Blackbox, Airforce])
)
flux_pro = Model(
flux_pro = ImageModel(
name = 'flux-pro',
base_provider = 'Flux AI',
best_provider = Airforce
)
flux_dev = Model(
flux_dev = ImageModel(
name = 'flux-dev',
base_provider = 'Flux AI',
best_provider = AmigoChat
)
flux_realism = Model(
flux_realism = ImageModel(
name = 'flux-realism',
base_provider = 'Flux AI',
best_provider = IterListProvider([Airforce, AmigoChat])
)
flux_anime = Model(
flux_anime = ImageModel(
name = 'flux-anime',
base_provider = 'Flux AI',
best_provider = Airforce
)
flux_3d = Model(
flux_3d = ImageModel(
name = 'flux-3d',
base_provider = 'Flux AI',
best_provider = Airforce
)
flux_disney = Model(
flux_disney = ImageModel(
name = 'flux-disney',
base_provider = 'Flux AI',
best_provider = Airforce
)
flux_pixel = Model(
flux_pixel = ImageModel(
name = 'flux-pixel',
base_provider = 'Flux AI',
best_provider = Airforce
)
flux_4o = Model(
flux_4o = ImageModel(
name = 'flux-4o',
base_provider = 'Flux AI',
best_provider = Airforce
)
### OpenAI ###
dall_e_3 = Model(
dall_e_3 = ImageModel(
name = 'dall-e-3',
base_provider = 'OpenAI',
best_provider = IterListProvider([Airforce, CopilotAccount, OpenaiAccount, MicrosoftDesigner, BingCreateImages])
)
### Recraft ###
recraft_v3 = Model(
recraft_v3 = ImageModel(
name = 'recraft-v3',
base_provider = 'Recraft',
best_provider = AmigoChat
)
### Other ###
any_dark = Model(
any_dark = ImageModel(
name = 'any-dark',
base_provider = 'Other',
best_provider = Airforce
@ -863,4 +864,17 @@ class ModelUtils:
'any-dark': any_dark,
}
# Create a list of all working models
__models__ = {model.name: (model, providers) for model, providers in [
(model, [provider for provider in providers if provider.working])
for model, providers in [
(model, model.best_provider.providers
if isinstance(model.best_provider, IterListProvider)
else [model.best_provider]
if model.best_provider is not None
else [])
for model in ModelUtils.convert.values()]
] if providers}
# Update the ModelUtils.convert with the working models
ModelUtils.convert = {model.name: model for model, _ in __models__.values()}
_all_models = list(ModelUtils.convert.keys())