mirror of
https://github.com/xtekky/gpt4free.git
synced 2024-12-23 19:11:48 +03:00
Fix get_models in Airforce provider
This commit is contained in:
parent
e2c269cc63
commit
d9ddc70394
@ -4,7 +4,6 @@ import random
|
|||||||
import re
|
import re
|
||||||
import requests
|
import requests
|
||||||
from requests.packages.urllib3.exceptions import InsecureRequestWarning
|
from requests.packages.urllib3.exceptions import InsecureRequestWarning
|
||||||
from urllib.parse import quote
|
|
||||||
from aiohttp import ClientSession
|
from aiohttp import ClientSession
|
||||||
|
|
||||||
from ..typing import AsyncResult, Messages
|
from ..typing import AsyncResult, Messages
|
||||||
@ -70,21 +69,23 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
|
|||||||
def fetch_imagine_models(cls):
|
def fetch_imagine_models(cls):
|
||||||
response = requests.get(
|
response = requests.get(
|
||||||
'https://api.airforce/imagine/models',
|
'https://api.airforce/imagine/models',
|
||||||
'https://api.airforce/v1/imagine2/models',
|
|
||||||
verify=False
|
verify=False
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
image_models = fetch_imagine_models.__get__(None, object)() + additional_models_imagine
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_image_model(cls, model: str) -> bool:
|
def is_image_model(cls, model: str) -> bool:
|
||||||
return model in cls.image_models
|
return model in cls.image_models
|
||||||
|
|
||||||
models = list(dict.fromkeys([default_model] +
|
@classmethod
|
||||||
fetch_completions_models.__get__(None, object)() +
|
def get_models(cls):
|
||||||
image_models))
|
if not cls.models:
|
||||||
|
cls.image_models = cls.fetch_imagine_models() + cls.additional_models_imagine
|
||||||
|
cls.models = list(dict.fromkeys([cls.default_model] +
|
||||||
|
cls.fetch_completions_models() +
|
||||||
|
cls.image_models))
|
||||||
|
return cls.models
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def check_api_key(cls, api_key: str) -> bool:
|
async def check_api_key(cls, api_key: str) -> bool:
|
||||||
@ -133,7 +134,7 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
|
|||||||
async with session.get(cls.api_endpoint_imagine2, params=params, proxy=proxy) as response:
|
async with session.get(cls.api_endpoint_imagine2, params=params, proxy=proxy) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
image_url = str(response.url)
|
image_url = str(response.url)
|
||||||
yield ImageResponse(images=image_url, alt=f"Generated image: {prompt}")
|
yield ImageResponse(images=image_url, alt=prompt)
|
||||||
else:
|
else:
|
||||||
error_text = await response.text()
|
error_text = await response.text()
|
||||||
raise RuntimeError(f"Image generation failed: {response.status} - {error_text}")
|
raise RuntimeError(f"Image generation failed: {response.status} - {error_text}")
|
||||||
@ -215,12 +216,12 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
|
|||||||
if not await cls.check_api_key(api_key):
|
if not await cls.check_api_key(api_key):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
model = cls.get_model(model)
|
||||||
if cls.is_image_model(model):
|
if cls.is_image_model(model):
|
||||||
if prompt is None:
|
if prompt is None:
|
||||||
prompt = messages[-1]['content']
|
prompt = messages[-1]['content']
|
||||||
if seed is None:
|
if seed is None:
|
||||||
seed = random.randint(0, 10000)
|
seed = random.randint(0, 10000)
|
||||||
|
|
||||||
async for result in cls.generate_image(model, prompt, api_key, size, seed, proxy):
|
async for result in cls.generate_image(model, prompt, api_key, size, seed, proxy):
|
||||||
yield result
|
yield result
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user