Fix get_models in Airforce provider

This commit is contained in:
Heiner Lohaus 2024-12-05 11:59:32 +01:00
parent e2c269cc63
commit d9ddc70394

View File

@ -4,7 +4,6 @@ import random
import re
import requests
from requests.packages.urllib3.exceptions import InsecureRequestWarning
from urllib.parse import quote
from aiohttp import ClientSession
from ..typing import AsyncResult, Messages
@ -70,21 +69,23 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
def fetch_imagine_models(cls):
response = requests.get(
'https://api.airforce/imagine/models',
'https://api.airforce/v1/imagine2/models',
verify=False
)
response.raise_for_status()
return response.json()
image_models = fetch_imagine_models.__get__(None, object)() + additional_models_imagine
@classmethod
def is_image_model(cls, model: str) -> bool:
return model in cls.image_models
models = list(dict.fromkeys([default_model] +
fetch_completions_models.__get__(None, object)() +
image_models))
@classmethod
def get_models(cls):
if not cls.models:
cls.image_models = cls.fetch_imagine_models() + cls.additional_models_imagine
cls.models = list(dict.fromkeys([cls.default_model] +
cls.fetch_completions_models() +
cls.image_models))
return cls.models
@classmethod
async def check_api_key(cls, api_key: str) -> bool:
@ -133,7 +134,7 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
async with session.get(cls.api_endpoint_imagine2, params=params, proxy=proxy) as response:
if response.status == 200:
image_url = str(response.url)
yield ImageResponse(images=image_url, alt=f"Generated image: {prompt}")
yield ImageResponse(images=image_url, alt=prompt)
else:
error_text = await response.text()
raise RuntimeError(f"Image generation failed: {response.status} - {error_text}")
@ -215,12 +216,12 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
if not await cls.check_api_key(api_key):
pass
model = cls.get_model(model)
if cls.is_image_model(model):
if prompt is None:
prompt = messages[-1]['content']
if seed is None:
seed = random.randint(0, 10000)
async for result in cls.generate_image(model, prompt, api_key, size, seed, proxy):
yield result
else: