Merge pull request #2389 from hlohaus/info

Add Cerebras and HuggingFace2 providers, fix RubiksAI provider
H Lohaus 2024-11-20 02:42:15 +01:00 committed by GitHub
commit 419264f966
15 changed files with 237 additions and 133 deletions

View File

@@ -28,6 +28,9 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
image_models = [default_image_model, 'repomap']
text_models = [default_model, 'gpt-4o', 'gemini-pro', 'claude-sonnet-3.5', 'blackboxai-pro']
vision_models = [default_model, 'gpt-4o', 'gemini-pro', 'blackboxai-pro']
model_aliases = {
"claude-3.5-sonnet": "claude-sonnet-3.5",
}
agentMode = {
default_image_model: {'mode': True, 'id': "ImageGenerationLV45LJp", 'name': "Image Generation"},
}
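The new model_aliases map lets callers request "claude-3.5-sonnet" by its public name while Blackbox resolves it to its internal "claude-sonnet-3.5" identifier. A minimal sketch of the lookup, assuming the usual alias fallback behavior of ProviderModelMixin.get_model:

# Sketch of alias resolution; the real lookup lives in ProviderModelMixin (assumed).
model_aliases = {"claude-3.5-sonnet": "claude-sonnet-3.5"}

def get_model(model: str) -> str:
    # return the internal name for an alias, or the name unchanged
    return model_aliases.get(model, model)

assert get_model("claude-3.5-sonnet") == "claude-sonnet-3.5"
assert get_model("gpt-4o") == "gpt-4o"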
@@ -198,6 +201,7 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
async with ClientSession(headers=headers) as session:
async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response:
response.raise_for_status()
is_first = True
async for chunk in response.content.iter_any():
text_chunk = chunk.decode(errors="ignore")
if model in cls.image_models:
@@ -217,5 +221,9 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
for i, result in enumerate(search_results, 1):
formatted_response += f"\n{i}. {result['title']}: {result['link']}"
yield formatted_response
else:
yield text_chunk.strip()
elif text_chunk:
if is_first:
is_first = False
yield text_chunk.lstrip()
else:
yield text_chunk
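The new is_first flag trims leading whitespace from the first streamed text chunk only, so the reply does not begin with stray blank lines while later chunks pass through verbatim. A standalone sketch of the pattern, where chunks stands in for the decoded response chunks:

# Sketch of the first-chunk trimming pattern; `chunks` is a hypothetical
# async iterator of decoded text chunks.
async def stream_text(chunks):
    is_first = True
    async for text_chunk in chunks:
        if not text_chunk:
            continue  # skip empty chunks
        if is_first:
            is_first = False
            yield text_chunk.lstrip()  # strip the reply's leading whitespace once
        else:
            yield text_chunk  # later chunks stream through untouched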

View File

@@ -21,8 +21,9 @@ from .helper import format_prompt
from ..typing import CreateResult, Messages, ImageType
from ..errors import MissingRequirementsError
from ..requests.raise_for_status import raise_for_status
from ..providers.helper import format_cookies
from ..requests import get_nodriver
from ..image import to_bytes, is_accepted_format
from ..image import ImageResponse, to_bytes, is_accepted_format
from .. import debug
class Conversation(BaseConversation):
@@ -70,18 +71,21 @@ class Copilot(AbstractProvider):
access_token, cookies = asyncio.run(cls.get_access_token_and_cookies(proxy))
else:
access_token = conversation.access_token
websocket_url = f"{websocket_url}&acessToken={quote(access_token)}"
headers = {"Authorization": f"Bearer {access_token}"}
debug.log(f"Copilot: Access token: {access_token[:7]}...{access_token[-5:]}")
debug.log(f"Copilot: Cookies: {';'.join([*cookies])}")
websocket_url = f"{websocket_url}&accessToken={quote(access_token)}"
headers = {"authorization": f"Bearer {access_token}", "cookie": format_cookies(cookies)}
with Session(
timeout=timeout,
proxy=proxy,
impersonate="chrome",
headers=headers,
cookies=cookies
cookies=cookies,
) as session:
response = session.get(f"{cls.url}/")
response = session.get("https://copilot.microsoft.com/c/api/user")
raise_for_status(response)
debug.log(f"Copilot: User: {response.json().get('firstName', 'null')}")
if conversation is None:
response = session.post(cls.conversation_url)
raise_for_status(response)
@@ -119,6 +123,7 @@ class Copilot(AbstractProvider):
is_started = False
msg = None
image_prompt: str = None
while True:
try:
msg = wss.recv()[0]
@@ -128,7 +133,11 @@
if msg.get("event") == "appendText":
is_started = True
yield msg.get("text")
elif msg.get("event") in ["done", "partCompleted"]:
elif msg.get("event") == "generatingImage":
image_prompt = msg.get("prompt")
elif msg.get("event") == "imageGenerated":
yield ImageResponse(msg.get("url"), image_prompt, {"preview": msg.get("thumbnailUrl")})
elif msg.get("event") == "done":
break
if not is_started:
raise RuntimeError(f"Last message: {msg}")
@@ -152,7 +161,7 @@
})()
""")
if access_token is None:
asyncio.sleep(1)
await asyncio.sleep(1)
cookies = {}
for c in await page.send(nodriver.cdp.network.get_cookies([cls.url])):
cookies[c.name] = c.value
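Besides correcting the accessToken query-parameter typo and the missing await on asyncio.sleep, Copilot now handles the image events on the websocket: generatingImage carries the prompt, and imageGenerated carries the result URL plus a thumbnail. A sketch of the dispatch, with messages standing in for the decoded websocket payloads and ImageResponse passed in to keep the sketch self-contained:

def handle_events(messages, ImageResponse):
    image_prompt = None
    for msg in messages:
        event = msg.get("event")
        if event == "appendText":
            yield msg.get("text")  # streamed text delta
        elif event == "generatingImage":
            image_prompt = msg.get("prompt")  # remember the prompt for the result
        elif event == "imageGenerated":
            # pair the image URL with its prompt and a preview thumbnail
            yield ImageResponse(msg.get("url"), image_prompt, {"preview": msg.get("thumbnailUrl")})
        elif event == "done":
            break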

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
import asyncio
import aiohttp
import random
import string
import json
@@ -11,34 +10,24 @@ from aiohttp import ClientSession
from ..typing import AsyncResult, Messages
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .helper import format_prompt
from ..requests.raise_for_status import raise_for_status
class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
label = "Rubiks AI"
url = "https://rubiks.ai"
api_endpoint = "https://rubiks.ai/search/api.php"
api_endpoint = "https://rubiks.ai/search/api/"
working = True
supports_stream = True
supports_system_message = True
supports_message_history = True
default_model = 'llama-3.1-70b-versatile'
models = [default_model, 'gpt-4o-mini']
default_model = 'gpt-4o-mini'
models = [default_model, 'gpt-4o', 'o1-mini', 'claude-3.5-sonnet', 'grok-beta', 'gemini-1.5-pro', 'nova-pro']
model_aliases = {
"llama-3.1-70b": "llama-3.1-70b-versatile",
}
@classmethod
def get_model(cls, model: str) -> str:
if model in cls.models:
return model
elif model in cls.model_aliases:
return cls.model_aliases[model]
else:
return cls.default_model
@staticmethod
def generate_mid() -> str:
"""
@@ -70,7 +59,8 @@ class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
model: str,
messages: Messages,
proxy: str = None,
websearch: bool = False,
web_search: bool = False,
temperature: float = 0.6,
**kwargs
) -> AsyncResult:
"""
@@ -80,20 +70,18 @@ class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
- model (str): The model to use in the request.
- messages (Messages): The messages to send as a prompt.
- proxy (str, optional): Proxy URL, if needed.
- websearch (bool, optional): Indicates whether to include search sources in the response. Defaults to False.
- web_search (bool, optional): Indicates whether to include search sources in the response. Defaults to False.
"""
model = cls.get_model(model)
prompt = format_prompt(messages)
q_value = prompt
mid_value = cls.generate_mid()
referer = cls.create_referer(q=q_value, mid=mid_value, model=model)
referer = cls.create_referer(q=messages[-1]["content"], mid=mid_value, model=model)
url = cls.api_endpoint
params = {
'q': q_value,
'model': model,
'id': '',
'mid': mid_value
data = {
"messages": messages,
"model": model,
"search": web_search,
"stream": True,
"temperature": temperature
}
headers = {
@@ -111,52 +99,34 @@ class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
'sec-ch-ua-mobile': '?0',
'sec-ch-ua-platform': '"Linux"'
}
async with ClientSession() as session:
async with session.post(cls.api_endpoint, headers=headers, json=data, proxy=proxy) as response:
await raise_for_status(response)
try:
timeout = aiohttp.ClientTimeout(total=None)
async with ClientSession(timeout=timeout) as session:
async with session.get(url, headers=headers, params=params, proxy=proxy) as response:
if response.status != 200:
yield f"Request ended with status code {response.status}"
return
sources = []
async for line in response.content:
decoded_line = line.decode('utf-8').strip()
if not decoded_line.startswith('data: '):
continue
data = decoded_line[6:]
if data in ('[DONE]', '{"done": ""}'):
break
try:
json_data = json.loads(data)
except json.JSONDecodeError:
continue
assistant_text = ''
sources = []
if 'url' in json_data and 'title' in json_data:
if web_search:
sources.append({'title': json_data['title'], 'url': json_data['url']})
async for line in response.content:
decoded_line = line.decode('utf-8').strip()
if not decoded_line.startswith('data: '):
continue
data = decoded_line[6:]
if data in ('[DONE]', '{"done": ""}'):
break
try:
json_data = json.loads(data)
except json.JSONDecodeError:
continue
elif 'choices' in json_data:
for choice in json_data['choices']:
delta = choice.get('delta', {})
content = delta.get('content', '')
if content:
yield content
if 'url' in json_data and 'title' in json_data:
if websearch:
sources.append({'title': json_data['title'], 'url': json_data['url']})
elif 'choices' in json_data:
for choice in json_data['choices']:
delta = choice.get('delta', {})
content = delta.get('content', '')
role = delta.get('role', '')
if role == 'assistant':
continue
assistant_text += content
if websearch and sources:
sources_text = '\n'.join([f"{i+1}. [{s['title']}]: {s['url']}" for i, s in enumerate(sources)])
assistant_text += f"\n\n**Source:**\n{sources_text}"
yield assistant_text
except asyncio.CancelledError:
yield "The request was cancelled."
except aiohttp.ClientError as e:
yield f"An error occurred during the request: {e}"
except Exception as e:
yield f"An unexpected error occurred: {e}"
if web_search and sources:
sources_text = '\n'.join([f"{i+1}. [{s['title']}]: {s['url']}" for i, s in enumerate(sources)])
yield f"\n\n**Source:**\n{sources_text}"

View File

@@ -0,0 +1,65 @@
from __future__ import annotations
import requests
from aiohttp import ClientSession
from .OpenaiAPI import OpenaiAPI
from ...typing import AsyncResult, Messages, Cookies
from ...requests.raise_for_status import raise_for_status
from ...cookies import get_cookies
class Cerebras(OpenaiAPI):
label = "Cerebras Inference"
url = "https://inference.cerebras.ai/"
working = True
default_model = "llama3.1-70b"
fallback_models = [
"llama3.1-70b",
"llama3.1-8b",
]
model_aliases = {"llama-3.1-70b": "llama3.1-70b", "llama-3.1-8b": "llama3.1-8b"}
@classmethod
def get_models(cls, api_key: str = None):
if not cls.models:
try:
headers = {}
if api_key:
headers["authorization"] = f"Bearer ${api_key}"
response = requests.get(f"https://api.cerebras.ai/v1/models", headers=headers)
raise_for_status(response)
data = response.json()
cls.models = [model.get("model") for model in data.get("models")]
except Exception:
cls.models = cls.fallback_models
return cls.models
@classmethod
async def create_async_generator(
cls,
model: str,
messages: Messages,
api_base: str = "https://api.cerebras.ai/v1",
api_key: str = None,
cookies: Cookies = None,
**kwargs
) -> AsyncResult:
if api_key is None and cookies is None:
cookies = get_cookies(".cerebras.ai")
async with ClientSession(cookies=cookies) as session:
async with session.get("https://inference.cerebras.ai/api/auth/session") as response:
raise_for_status(response)
data = await response.json()
if data:
api_key = data.get("user", {}).get("demoApiKey")
async for chunk in super().create_async_generator(
model, messages,
api_base=api_base,
impersonate="chrome",
api_key=api_key,
headers={
"User-Agent": "ex/JS 1.5.0",
},
**kwargs
):
yield chunk
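When no api_key is given, Cerebras reads browser cookies for .cerebras.ai and exchanges them for the demo key exposed by the session endpoint. A hypothetical usage sketch, assuming the import path implied by the __init__.py change below and a placeholder key:

import asyncio
from g4f.Provider.needs_auth import Cerebras  # import path assumed

async def main():
    async for chunk in Cerebras.create_async_generator(
        model="llama3.1-70b",
        messages=[{"role": "user", "content": "Hello"}],
        api_key="your-api-key",  # omit to fall back to cookie-based auth
    ):
        print(chunk, end="")

asyncio.run(main())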

View File

@@ -1,9 +1,12 @@
from __future__ import annotations
from ..base_provider import ProviderModelMixin
from ..Copilot import Copilot
class CopilotAccount(Copilot):
class CopilotAccount(Copilot, ProviderModelMixin):
needs_auth = True
parent = "Copilot"
default_model = "Copilot"
default_vision_model = default_model
default_vision_model = default_model
models = [default_model]
image_models = models

View File

@@ -0,0 +1,28 @@
from __future__ import annotations
from .OpenaiAPI import OpenaiAPI
from ..HuggingChat import HuggingChat
from ...typing import AsyncResult, Messages
class HuggingFace2(OpenaiAPI):
label = "HuggingFace (Inference API)"
url = "https://huggingface.co"
working = True
default_model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
default_vision_model = default_model
models = [
*HuggingChat.models
]
@classmethod
def create_async_generator(
cls,
model: str,
messages: Messages,
api_base: str = "https://api-inference.huggingface.co/v1",
max_tokens: int = 500,
**kwargs
) -> AsyncResult:
return super().create_async_generator(
model, messages, api_base=api_base, max_tokens=max_tokens, **kwargs
)
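HuggingFace2 is a thin OpenaiAPI subclass pointed at the OpenAI-compatible Inference API endpoint, reusing HuggingChat's model list and defaulting max_tokens to 500. A hypothetical usage sketch under the same assumptions as above:

import asyncio
from g4f.Provider.needs_auth import HuggingFace2  # import path assumed

async def main():
    async for chunk in HuggingFace2.create_async_generator(
        model="meta-llama/Llama-3.2-11B-Vision-Instruct",
        messages=[{"role": "user", "content": "Hello"}],
        api_key="hf_...",  # a HuggingFace access token (placeholder)
        max_tokens=256,    # overrides the 500-token default
    ):
        print(chunk, end="")

asyncio.run(main())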

View File

@@ -34,6 +34,7 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
stop: Union[str, list[str]] = None,
stream: bool = False,
headers: dict = None,
impersonate: str = None,
extra_data: dict = {},
**kwargs
) -> AsyncResult:
@@ -55,7 +56,8 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
async with StreamSession(
proxies={"all": proxy},
headers=cls.get_headers(stream, api_key, headers),
timeout=timeout
timeout=timeout,
impersonate=impersonate,
) as session:
data = filter_none(
messages=messages,

View File

@@ -1,6 +1,7 @@
from .gigachat import *
from .BingCreateImages import BingCreateImages
from .Cerebras import Cerebras
from .CopilotAccount import CopilotAccount
from .DeepInfra import DeepInfra
from .DeepInfraImage import DeepInfraImage
@@ -8,6 +9,7 @@ from .Gemini import Gemini
from .GeminiPro import GeminiPro
from .Groq import Groq
from .HuggingFace import HuggingFace
from .HuggingFace2 import HuggingFace2
from .MetaAI import MetaAI
from .MetaAIAccount import MetaAIAccount
from .OpenaiAPI import OpenaiAPI

View File

@@ -12,7 +12,7 @@ from typing import Union, AsyncIterator, Iterator, Coroutine
from ..providers.base_provider import AsyncGeneratorProvider
from ..image import ImageResponse, to_image, to_data_uri, is_accepted_format, EXTENSIONS_MAP
from ..typing import Messages, Cookies, Image
from ..typing import Messages, Image
from ..providers.types import ProviderType, FinishReason, BaseConversation
from ..errors import NoImageResponseError
from ..providers.retry_provider import IterListProvider
@@ -254,6 +254,8 @@ class Images:
provider_handler = self.models.get(model, provider or self.provider or BingCreateImages)
elif isinstance(provider, str):
provider_handler = convert_to_provider(provider)
else:
provider_handler = provider
if provider_handler is None:
raise ValueError(f"Unknown model: {model}")
if proxy is None:

View File

@@ -128,6 +128,10 @@
<label for="BingCreateImages-api_key" class="label" title="">Microsoft Designer in Bing:</label>
<textarea id="BingCreateImages-api_key" name="BingCreateImages[api_key]" placeholder="&quot;_U&quot; cookie"></textarea>
</div>
<div class="field box">
<label for="Cerebras-api_key" class="label" title="">Cerebras Inference:</label>
<textarea id="Cerebras-api_key" name="Cerebras[api_key]" placeholder="api_key"></textarea>
</div>
<div class="field box">
<label for="DeepInfra-api_key" class="label" title="">DeepInfra:</label>
<textarea id="DeepInfra-api_key" name="DeepInfra[api_key]" class="DeepInfraImage-api_key" placeholder="api_key"></textarea>
@@ -142,7 +146,7 @@
</div>
<div class="field box">
<label for="HuggingFace-api_key" class="label" title="">HuggingFace:</label>
<textarea id="HuggingFace-api_key" name="HuggingFace[api_key]" placeholder="api_key"></textarea>
<textarea id="HuggingFace-api_key" name="HuggingFace[api_key]" class="HuggingFace2-api_key" placeholder="api_key"></textarea>
</div>
<div class="field box">
<label for="Openai-api_key" class="label" title="">OpenAI API:</label>
@@ -192,7 +196,7 @@
<div class="stop_generating stop_generating-hidden">
<button id="cancelButton">
<span>Stop Generating</span>
<i class="fa-regular fa-stop"></i>
<i class="fa-solid fa-stop"></i>
</button>
</div>
<div class="regenerate">

View File

@@ -512,9 +512,7 @@ body {
@media only screen and (min-width: 40em) {
.stop_generating {
left: 50%;
transform: translateX(-50%);
right: auto;
right: 4px;
}
.toolbar .regenerate span {
display: block;

View File

@@ -215,7 +215,6 @@ const register_message_buttons = async () => {
const message_el = el.parentElement.parentElement.parentElement;
el.classList.add("clicked");
setTimeout(() => el.classList.remove("clicked"), 1000);
await hide_message(window.conversation_id, message_el.dataset.index);
await ask_gpt(message_el.dataset.index, get_message_id());
})
}
@@ -317,6 +316,7 @@ async function remove_cancel_button() {
regenerate.addEventListener("click", async () => {
regenerate.classList.add("regenerate-hidden");
setTimeout(()=>regenerate.classList.remove("regenerate-hidden"), 3000);
stop_generating.classList.remove("stop_generating-hidden");
await hide_message(window.conversation_id);
await ask_gpt(-1, get_message_id());
@@ -383,12 +383,12 @@ const prepare_messages = (messages, message_index = -1) => {
return new_messages;
}
async function add_message_chunk(message, message_index) {
content_map = content_storage[message_index];
async function add_message_chunk(message, message_id) {
content_map = content_storage[message_id];
if (message.type == "conversation") {
console.info("Conversation used:", message.conversation)
} else if (message.type == "provider") {
provider_storage[message_index] = message.provider;
provider_storage[message_id] = message.provider;
content_map.content.querySelector('.provider').innerHTML = `
<a href="${message.provider.url}" target="_blank">
${message.provider.label ? message.provider.label : message.provider.name}
@@ -398,7 +398,7 @@ async function add_message_chunk(message, message_index) {
} else if (message.type == "message") {
console.error(message.message)
} else if (message.type == "error") {
error_storage[message_index] = message.error
error_storage[message_id] = message.error
console.error(message.error);
content_map.inner.innerHTML += `<p><strong>An error occurred:</strong> ${message.error}</p>`;
let p = document.createElement("p");
@@ -407,8 +407,8 @@ async function add_message_chunk(message, message_index) {
} else if (message.type == "preview") {
content_map.inner.innerHTML = markdown_render(message.preview);
} else if (message.type == "content") {
message_storage[message_index] += message.content;
html = markdown_render(message_storage[message_index]);
message_storage[message_id] += message.content;
html = markdown_render(message_storage[message_id]);
let lastElement, lastIndex = null;
for (element of ['</p>', '</code></pre>', '</p>\n</li>\n</ol>', '</li>\n</ol>', '</li>\n</ul>']) {
const index = html.lastIndexOf(element)
@@ -421,7 +421,7 @@ async function add_message_chunk(message, message_index) {
html = html.substring(0, lastIndex) + '<span class="cursor"></span>' + lastElement;
}
content_map.inner.innerHTML = html;
content_map.count.innerText = count_words_and_tokens(message_storage[message_index], provider_storage[message_index]?.model);
content_map.count.innerText = count_words_and_tokens(message_storage[message_id], provider_storage[message_id]?.model);
highlight(content_map.inner);
} else if (message.type == "log") {
let p = document.createElement("p");
@@ -453,7 +453,7 @@ const ask_gpt = async (message_index = -1, message_id) => {
let total_messages = messages.length;
messages = prepare_messages(messages, message_index);
message_index = total_messages
message_storage[message_index] = "";
message_storage[message_id] = "";
stop_generating.classList.remove(".stop_generating-hidden");
message_box.scrollTop = message_box.scrollHeight;
@@ -477,10 +477,10 @@ const ask_gpt = async (message_index = -1, message_id) => {
</div>
`;
controller_storage[message_index] = new AbortController();
controller_storage[message_id] = new AbortController();
let content_el = document.getElementById(`gpt_${message_id}`)
let content_map = content_storage[message_index] = {
let content_map = content_storage[message_id] = {
content: content_el,
inner: content_el.querySelector('.content_inner'),
count: content_el.querySelector('.count'),
@@ -492,12 +492,7 @@ const ask_gpt = async (message_index = -1, message_id) => {
const file = input && input.files.length > 0 ? input.files[0] : null;
const provider = providerSelect.options[providerSelect.selectedIndex].value;
const auto_continue = document.getElementById("auto_continue")?.checked;
let api_key = null;
if (provider) {
api_key = document.getElementById(`${provider}-api_key`)?.value || null;
if (api_key == null)
api_key = document.querySelector(`.${provider}-api_key`)?.value || null;
}
let api_key = get_api_key_by_provider(provider);
await api("conversation", {
id: message_id,
conversation_id: window.conversation_id,
@@ -506,10 +501,10 @@ const ask_gpt = async (message_index = -1, message_id) => {
provider: provider,
messages: messages,
auto_continue: auto_continue,
api_key: api_key
}, file, message_index);
if (!error_storage[message_index]) {
html = markdown_render(message_storage[message_index]);
api_key: api_key,
}, file, message_id);
if (!error_storage[message_id]) {
html = markdown_render(message_storage[message_id]);
content_map.inner.innerHTML = html;
highlight(content_map.inner);
@@ -520,14 +515,14 @@ const ask_gpt = async (message_index = -1, message_id) => {
} catch (e) {
console.error(e);
if (e.name != "AbortError") {
error_storage[message_index] = true;
error_storage[message_id] = true;
content_map.inner.innerHTML += `<p><strong>An error occurred:</strong> ${e}</p>`;
}
}
delete controller_storage[message_index];
if (!error_storage[message_index] && message_storage[message_index]) {
const message_provider = message_index in provider_storage ? provider_storage[message_index] : null;
await add_message(window.conversation_id, "assistant", message_storage[message_index], message_provider);
delete controller_storage[message_id];
if (!error_storage[message_id] && message_storage[message_id]) {
const message_provider = message_id in provider_storage ? provider_storage[message_id] : null;
await add_message(window.conversation_id, "assistant", message_storage[message_id], message_provider);
await safe_load_conversation(window.conversation_id);
} else {
let cursorDiv = message_box.querySelector(".cursor");
@@ -1156,7 +1151,7 @@ async function on_api() {
evt.preventDefault();
console.log("pressed enter");
prompt_lock = true;
setTimeout(()=>prompt_lock=false, 3);
setTimeout(()=>prompt_lock=false, 3000);
await handle_ask();
} else {
messageInput.style.removeProperty("height");
@@ -1167,7 +1162,7 @@ async function on_api() {
console.log("clicked send");
if (prompt_lock) return;
prompt_lock = true;
setTimeout(()=>prompt_lock=false, 3);
setTimeout(()=>prompt_lock=false, 3000);
await handle_ask();
});
messageInput.focus();
@@ -1189,8 +1184,8 @@ async function on_api() {
providerSelect.appendChild(option);
})
await load_provider_models(appStorage.getItem("provider"));
await load_settings_storage()
await load_provider_models(appStorage.getItem("provider"));
const hide_systemPrompt = document.getElementById("hide-systemPrompt")
const slide_systemPrompt_icon = document.querySelector(".slide-systemPrompt i");
@@ -1316,7 +1311,7 @@ function get_selected_model() {
}
}
async function api(ressource, args=null, file=null, message_index=null) {
async function api(ressource, args=null, file=null, message_id=null) {
if (window?.pywebview) {
if (args !== null) {
if (ressource == "models") {
@@ -1326,15 +1321,19 @@ async function api(ressource, args=null, file=null, message_index=null) {
}
return pywebview.api[`get_${ressource}`]();
}
let api_key;
if (ressource == "models" && args) {
api_key = get_api_key_by_provider(args);
ressource = `${ressource}/${args}`;
}
const url = `/backend-api/v2/${ressource}`;
const headers = {};
if (api_key) {
headers.authorization = `Bearer ${api_key}`;
}
if (ressource == "conversation") {
let body = JSON.stringify(args);
const headers = {
accept: 'text/event-stream'
}
headers.accept = 'text/event-stream';
if (file !== null) {
const formData = new FormData();
formData.append('file', file);
@@ -1345,17 +1344,17 @@
}
response = await fetch(url, {
method: 'POST',
signal: controller_storage[message_index].signal,
signal: controller_storage[message_id].signal,
headers: headers,
body: body
body: body,
});
return read_response(response, message_index);
return read_response(response, message_id);
}
response = await fetch(url);
response = await fetch(url, {headers: headers});
return await response.json();
}
async function read_response(response, message_index) {
async function read_response(response, message_id) {
const reader = response.body.pipeThrough(new TextDecoderStream()).getReader();
let buffer = ""
while (true) {
@@ -1368,7 +1367,7 @@
continue;
}
try {
add_message_chunk(JSON.parse(buffer + line), message_index);
add_message_chunk(JSON.parse(buffer + line), message_id);
buffer = "";
} catch {
buffer += line
@@ -1377,6 +1376,16 @@
}
}
function get_api_key_by_provider(provider) {
let api_key = null;
if (provider) {
api_key = document.getElementById(`${provider}-api_key`)?.value || null;
if (api_key == null)
api_key = document.querySelector(`.${provider}-api_key`)?.value || null;
}
return api_key;
}
async function load_provider_models(providerIndex=null) {
if (!providerIndex) {
providerIndex = providerSelect.selectedIndex;

View File

@@ -38,10 +38,11 @@ class Api:
return models._all_models
@staticmethod
def get_provider_models(provider: str) -> list[dict]:
def get_provider_models(provider: str, api_key: str = None) -> list[dict]:
if provider in __map__:
provider: ProviderType = __map__[provider]
if issubclass(provider, ProviderModelMixin):
models = provider.get_models() if api_key is None else provider.get_models(api_key=api_key)
return [
{
"model": model,
@@ -49,7 +50,7 @@ class Api:
"vision": getattr(provider, "default_vision_model", None) == model or model in getattr(provider, "vision_models", []),
"image": model in getattr(provider, "image_models", []),
}
for model in provider.get_models()
for model in models
]
return []
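get_provider_models now accepts an optional api_key and forwards it to providers whose get_models needs one; Cerebras above uses it to fetch its model list with a bearer token. A minimal sketch of the call the backend makes, with a placeholder key:

from g4f.gui.server.api import Api  # module path assumed

models = Api.get_provider_models("Cerebras", api_key="your-api-key")
print(models)  # e.g. [{"model": "llama3.1-70b", ...}, ...]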

View File

@@ -94,7 +94,8 @@ class Backend_Api(Api):
)
def get_provider_models(self, provider: str):
models = super().get_provider_models(provider)
api_key = None if request.authorization is None else request.authorization.token
models = super().get_provider_models(provider, api_key)
if models is None:
return 404, "Provider not found"
return models
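The route lifts an optional Bearer token from the request's Authorization header and passes it through, so the web UI can fetch per-provider model lists that require a key. A hypothetical client-side check, assuming the /backend-api/v2/models/<provider> route used by the UI and a locally running server:

import requests

response = requests.get(
    "http://localhost:8080/backend-api/v2/models/Cerebras",  # host and port assumed
    headers={"Authorization": "Bearer your-api-key"},
)
print(response.json())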

View File

@@ -11,7 +11,9 @@ class CloudflareError(ResponseStatusError):
...
def is_cloudflare(text: str) -> bool:
if "<title>Attention Required! | Cloudflare</title>" in text or 'id="cf-cloudflare-status"' in text:
if "Generated by cloudfront" in text:
return True
elif "<title>Attention Required! | Cloudflare</title>" in text or 'id="cf-cloudflare-status"' in text:
return True
return '<div id="cf-please-wait">' in text or "<title>Just a moment...</title>" in text