mirror of
https://github.com/xtekky/gpt4free.git
synced 2024-12-23 11:02:40 +03:00
Fix get_models in Airforce provider
This commit is contained in:
parent
e2c269cc63
commit
d9ddc70394
@ -4,7 +4,6 @@ import random
|
||||
import re
|
||||
import requests
|
||||
from requests.packages.urllib3.exceptions import InsecureRequestWarning
|
||||
from urllib.parse import quote
|
||||
from aiohttp import ClientSession
|
||||
|
||||
from ..typing import AsyncResult, Messages
|
||||
@ -70,21 +69,23 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
def fetch_imagine_models(cls):
|
||||
response = requests.get(
|
||||
'https://api.airforce/imagine/models',
|
||||
'https://api.airforce/v1/imagine2/models',
|
||||
verify=False
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
image_models = fetch_imagine_models.__get__(None, object)() + additional_models_imagine
|
||||
|
||||
@classmethod
|
||||
def is_image_model(cls, model: str) -> bool:
|
||||
return model in cls.image_models
|
||||
|
||||
models = list(dict.fromkeys([default_model] +
|
||||
fetch_completions_models.__get__(None, object)() +
|
||||
image_models))
|
||||
@classmethod
|
||||
def get_models(cls):
|
||||
if not cls.models:
|
||||
cls.image_models = cls.fetch_imagine_models() + cls.additional_models_imagine
|
||||
cls.models = list(dict.fromkeys([cls.default_model] +
|
||||
cls.fetch_completions_models() +
|
||||
cls.image_models))
|
||||
return cls.models
|
||||
|
||||
@classmethod
|
||||
async def check_api_key(cls, api_key: str) -> bool:
|
||||
@ -133,7 +134,7 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
async with session.get(cls.api_endpoint_imagine2, params=params, proxy=proxy) as response:
|
||||
if response.status == 200:
|
||||
image_url = str(response.url)
|
||||
yield ImageResponse(images=image_url, alt=f"Generated image: {prompt}")
|
||||
yield ImageResponse(images=image_url, alt=prompt)
|
||||
else:
|
||||
error_text = await response.text()
|
||||
raise RuntimeError(f"Image generation failed: {response.status} - {error_text}")
|
||||
@ -215,12 +216,12 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
if not await cls.check_api_key(api_key):
|
||||
pass
|
||||
|
||||
model = cls.get_model(model)
|
||||
if cls.is_image_model(model):
|
||||
if prompt is None:
|
||||
prompt = messages[-1]['content']
|
||||
if seed is None:
|
||||
seed = random.randint(0, 10000)
|
||||
|
||||
async for result in cls.generate_image(model, prompt, api_key, size, seed, proxy):
|
||||
yield result
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user