Fix get_models in Airforce provider

This commit is contained in:
Heiner Lohaus 2024-12-05 11:59:32 +01:00
parent e2c269cc63
commit d9ddc70394

View File

@ -4,7 +4,6 @@ import random
import re import re
import requests import requests
from requests.packages.urllib3.exceptions import InsecureRequestWarning from requests.packages.urllib3.exceptions import InsecureRequestWarning
from urllib.parse import quote
from aiohttp import ClientSession from aiohttp import ClientSession
from ..typing import AsyncResult, Messages from ..typing import AsyncResult, Messages
@ -70,21 +69,23 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
def fetch_imagine_models(cls): def fetch_imagine_models(cls):
response = requests.get( response = requests.get(
'https://api.airforce/imagine/models', 'https://api.airforce/imagine/models',
'https://api.airforce/v1/imagine2/models',
verify=False verify=False
) )
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
image_models = fetch_imagine_models.__get__(None, object)() + additional_models_imagine
@classmethod @classmethod
def is_image_model(cls, model: str) -> bool: def is_image_model(cls, model: str) -> bool:
return model in cls.image_models return model in cls.image_models
models = list(dict.fromkeys([default_model] + @classmethod
fetch_completions_models.__get__(None, object)() + def get_models(cls):
image_models)) if not cls.models:
cls.image_models = cls.fetch_imagine_models() + cls.additional_models_imagine
cls.models = list(dict.fromkeys([cls.default_model] +
cls.fetch_completions_models() +
cls.image_models))
return cls.models
@classmethod @classmethod
async def check_api_key(cls, api_key: str) -> bool: async def check_api_key(cls, api_key: str) -> bool:
@ -133,7 +134,7 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
async with session.get(cls.api_endpoint_imagine2, params=params, proxy=proxy) as response: async with session.get(cls.api_endpoint_imagine2, params=params, proxy=proxy) as response:
if response.status == 200: if response.status == 200:
image_url = str(response.url) image_url = str(response.url)
yield ImageResponse(images=image_url, alt=f"Generated image: {prompt}") yield ImageResponse(images=image_url, alt=prompt)
else: else:
error_text = await response.text() error_text = await response.text()
raise RuntimeError(f"Image generation failed: {response.status} - {error_text}") raise RuntimeError(f"Image generation failed: {response.status} - {error_text}")
@ -215,12 +216,12 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
if not await cls.check_api_key(api_key): if not await cls.check_api_key(api_key):
pass pass
model = cls.get_model(model)
if cls.is_image_model(model): if cls.is_image_model(model):
if prompt is None: if prompt is None:
prompt = messages[-1]['content'] prompt = messages[-1]['content']
if seed is None: if seed is None:
seed = random.randint(0, 10000) seed = random.randint(0, 10000)
async for result in cls.generate_image(model, prompt, api_key, size, seed, proxy): async for result in cls.generate_image(model, prompt, api_key, size, seed, proxy):
yield result yield result
else: else: