Show only enabled models in gui

This commit is contained in:
Heiner Lohaus 2024-12-07 22:50:28 +01:00
parent 20ad08021a
commit eff97ed344
5 changed files with 62 additions and 28 deletions

View File

@ -276,16 +276,12 @@ class Api:
HTTP_200_OK: {"model": List[ModelResponseModel]}, HTTP_200_OK: {"model": List[ModelResponseModel]},
}) })
async def models(): async def models():
model_list = dict(
(model, g4f.models.ModelUtils.convert[model])
for model in g4f.Model.__all__()
)
return [{ return [{
'id': model_id, 'id': model_id,
'object': 'model', 'object': 'model',
'created': 0, 'created': 0,
'owned_by': model.base_provider 'owned_by': model.base_provider
} for model_id, model in model_list.items()] } for model_id, model in g4f.models.ModelUtils.convert.items()]
@self.app.get("/v1/models/{model_name}", responses={ @self.app.get("/v1/models/{model_name}", responses={
HTTP_200_OK: {"model": ModelResponseModel}, HTTP_200_OK: {"model": ModelResponseModel},

View File

@ -778,7 +778,7 @@ select:hover,
background-color: var(--button-hover); background-color: var(--button-hover);
} }
#provider option:disabled[value] { #provider option:disabled[value], #model option:disabled[value] {
display: none; display: none;
} }

View File

@ -1296,6 +1296,13 @@ async function on_load() {
const load_provider_option = (input, provider_name) => { const load_provider_option = (input, provider_name) => {
if (input.checked) { if (input.checked) {
modelSelect.querySelectorAll(`option[data-disabled_providers*="${provider_name}"]`).forEach(
(el) => {
el.dataset.disabled_providers = el.dataset.disabled_providers ? el.dataset.disabled_providers.split(" ").filter((provider) => provider!=provider_name).join(" ") : "";
el.dataset.providers = (el.dataset.providers ? el.dataset.providers + " " : "") + provider_name;
modelSelect.querySelectorAll(`option[value="${el.value}"]`).forEach((o)=>o.removeAttribute("disabled", "disabled"))
}
);
providerSelect.querySelectorAll(`option[value="${provider_name}"]`).forEach( providerSelect.querySelectorAll(`option[value="${provider_name}"]`).forEach(
(el) => el.removeAttribute("disabled") (el) => el.removeAttribute("disabled")
); );
@ -1303,6 +1310,13 @@ const load_provider_option = (input, provider_name) => {
(el) => el.removeAttribute("disabled") (el) => el.removeAttribute("disabled")
); );
} else { } else {
modelSelect.querySelectorAll(`option[data-providers*="${provider_name}"]`).forEach(
(el) => {
el.dataset.providers = el.dataset.providers ? el.dataset.providers.split(" ").filter((provider) => provider!=provider_name).join(" ") : "";
el.dataset.disabled_providers = (el.dataset.disabled_providers ? el.dataset.disabled_providers + " " : "") + provider_name;
if (!el.dataset.providers) modelSelect.querySelectorAll(`option[value="${el.value}"]`).forEach((o)=>o.setAttribute("disabled", "disabled"))
}
);
providerSelect.querySelectorAll(`option[value="${provider_name}"]`).forEach( providerSelect.querySelectorAll(`option[value="${provider_name}"]`).forEach(
(el) => el.setAttribute("disabled", "disabled") (el) => el.setAttribute("disabled", "disabled")
); );
@ -1342,7 +1356,9 @@ async function on_api() {
models = await api("models"); models = await api("models");
models.forEach((model) => { models.forEach((model) => {
let option = document.createElement("option"); let option = document.createElement("option");
option.value = option.text = option.dataset.label = model; option.value = model.name;
option.text = model.name + (model.image ? " (Image Generation)" : "");
option.dataset.providers = model.providers.join(" ");
modelSelect.appendChild(option); modelSelect.appendChild(option);
}); });
providers = await api("providers") providers = await api("providers")

View File

@ -13,7 +13,7 @@ from g4f.errors import VersionNotFoundError
from g4f.image import ImagePreview, ImageResponse, copy_images, ensure_images_dir, images_dir from g4f.image import ImagePreview, ImageResponse, copy_images, ensure_images_dir, images_dir
from g4f.Provider import ProviderType, __providers__, __map__ from g4f.Provider import ProviderType, __providers__, __map__
from g4f.providers.base_provider import ProviderModelMixin from g4f.providers.base_provider import ProviderModelMixin
from g4f.providers.retry_provider import BaseRetryProvider from g4f.providers.retry_provider import IterListProvider
from g4f.providers.response import BaseConversation, FinishReason, SynthesizeData from g4f.providers.response import BaseConversation, FinishReason, SynthesizeData
from g4f.client.service import convert_to_provider from g4f.client.service import convert_to_provider
from g4f import debug from g4f import debug
@ -24,7 +24,15 @@ conversations: dict[dict[str, BaseConversation]] = {}
class Api: class Api:
@staticmethod @staticmethod
def get_models(): def get_models():
return models._all_models return [{
"name": model.name,
"image": isinstance(model, models.ImageModel),
"providers": [
getattr(provider, "parent", provider.__name__)
for provider in providers
]
}
for model, providers in models.__models__.values()]
@staticmethod @staticmethod
def get_provider_models(provider: str, api_key: str = None): def get_provider_models(provider: str, api_key: str = None):
@ -126,7 +134,7 @@ class Api:
for chunk in result: for chunk in result:
if first: if first:
first = False first = False
if isinstance(provider, BaseRetryProvider): if isinstance(provider, IterListProvider):
provider = provider.last_provider provider = provider.last_provider
yield self._format_json("provider", {**provider.get_dict(), "model": model}) yield self._format_json("provider", {**provider.get_dict(), "model": model})
if isinstance(chunk, BaseConversation): if isinstance(chunk, BaseConversation):

View File

@ -59,6 +59,9 @@ class Model:
"""Returns a list of all model names.""" """Returns a list of all model names."""
return _all_models return _all_models
class ImageModel(Model):
pass
### Default ### ### Default ###
default = Model( default = Model(
name = "", name = "",
@ -559,100 +562,98 @@ any_uncensored = Model(
############# #############
### Stability AI ### ### Stability AI ###
sdxl = Model( sdxl = ImageModel(
name = 'sdxl', name = 'sdxl',
base_provider = 'Stability AI', base_provider = 'Stability AI',
best_provider = IterListProvider([ReplicateHome, Airforce]) best_provider = IterListProvider([ReplicateHome, Airforce])
) )
sd_3 = Model( sd_3 = ImageModel(
name = 'sd-3', name = 'sd-3',
base_provider = 'Stability AI', base_provider = 'Stability AI',
best_provider = ReplicateHome best_provider = ReplicateHome
) )
### Playground ### ### Playground ###
playground_v2_5 = Model( playground_v2_5 = ImageModel(
name = 'playground-v2.5', name = 'playground-v2.5',
base_provider = 'Playground AI', base_provider = 'Playground AI',
best_provider = ReplicateHome best_provider = ReplicateHome
) )
### Flux AI ### ### Flux AI ###
flux = Model( flux = ImageModel(
name = 'flux', name = 'flux',
base_provider = 'Flux AI', base_provider = 'Flux AI',
best_provider = IterListProvider([Blackbox, Airforce]) best_provider = IterListProvider([Blackbox, Airforce])
) )
flux_pro = Model( flux_pro = ImageModel(
name = 'flux-pro', name = 'flux-pro',
base_provider = 'Flux AI', base_provider = 'Flux AI',
best_provider = Airforce best_provider = Airforce
) )
flux_dev = Model( flux_dev = ImageModel(
name = 'flux-dev', name = 'flux-dev',
base_provider = 'Flux AI', base_provider = 'Flux AI',
best_provider = AmigoChat best_provider = AmigoChat
) )
flux_realism = Model( flux_realism = ImageModel(
name = 'flux-realism', name = 'flux-realism',
base_provider = 'Flux AI', base_provider = 'Flux AI',
best_provider = IterListProvider([Airforce, AmigoChat]) best_provider = IterListProvider([Airforce, AmigoChat])
) )
flux_anime = Model( flux_anime = ImageModel(
name = 'flux-anime', name = 'flux-anime',
base_provider = 'Flux AI', base_provider = 'Flux AI',
best_provider = Airforce best_provider = Airforce
) )
flux_3d = Model( flux_3d = ImageModel(
name = 'flux-3d', name = 'flux-3d',
base_provider = 'Flux AI', base_provider = 'Flux AI',
best_provider = Airforce best_provider = Airforce
) )
flux_disney = Model( flux_disney = ImageModel(
name = 'flux-disney', name = 'flux-disney',
base_provider = 'Flux AI', base_provider = 'Flux AI',
best_provider = Airforce best_provider = Airforce
) )
flux_pixel = Model( flux_pixel = ImageModel(
name = 'flux-pixel', name = 'flux-pixel',
base_provider = 'Flux AI', base_provider = 'Flux AI',
best_provider = Airforce best_provider = Airforce
) )
flux_4o = Model( flux_4o = ImageModel(
name = 'flux-4o', name = 'flux-4o',
base_provider = 'Flux AI', base_provider = 'Flux AI',
best_provider = Airforce best_provider = Airforce
) )
### OpenAI ### ### OpenAI ###
dall_e_3 = Model( dall_e_3 = ImageModel(
name = 'dall-e-3', name = 'dall-e-3',
base_provider = 'OpenAI', base_provider = 'OpenAI',
best_provider = IterListProvider([Airforce, CopilotAccount, OpenaiAccount, MicrosoftDesigner, BingCreateImages]) best_provider = IterListProvider([Airforce, CopilotAccount, OpenaiAccount, MicrosoftDesigner, BingCreateImages])
) )
### Recraft ### ### Recraft ###
recraft_v3 = Model( recraft_v3 = ImageModel(
name = 'recraft-v3', name = 'recraft-v3',
base_provider = 'Recraft', base_provider = 'Recraft',
best_provider = AmigoChat best_provider = AmigoChat
) )
### Other ### ### Other ###
any_dark = Model( any_dark = ImageModel(
name = 'any-dark', name = 'any-dark',
base_provider = 'Other', base_provider = 'Other',
best_provider = Airforce best_provider = Airforce
@ -863,4 +864,17 @@ class ModelUtils:
'any-dark': any_dark, 'any-dark': any_dark,
} }
# Create a list of all working models
__models__ = {model.name: (model, providers) for model, providers in [
(model, [provider for provider in providers if provider.working])
for model, providers in [
(model, model.best_provider.providers
if isinstance(model.best_provider, IterListProvider)
else [model.best_provider]
if model.best_provider is not None
else [])
for model in ModelUtils.convert.values()]
] if providers}
# Update the ModelUtils.convert with the working models
ModelUtils.convert = {model.name: model for model, _ in __models__.values()}
_all_models = list(ModelUtils.convert.keys()) _all_models = list(ModelUtils.convert.keys())