Add Cerebras and HuggingFace2 provider, Fix RubiksAI provider

Add support for image generation in Copilot provider
This commit is contained in:
Heiner Lohaus 2024-11-20 02:34:47 +01:00
parent 275404373f
commit 58fa409eef
14 changed files with 234 additions and 132 deletions

View File

@ -28,6 +28,9 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
image_models = [default_image_model, 'repomap'] image_models = [default_image_model, 'repomap']
text_models = [default_model, 'gpt-4o', 'gemini-pro', 'claude-sonnet-3.5', 'blackboxai-pro'] text_models = [default_model, 'gpt-4o', 'gemini-pro', 'claude-sonnet-3.5', 'blackboxai-pro']
vision_models = [default_model, 'gpt-4o', 'gemini-pro', 'blackboxai-pro'] vision_models = [default_model, 'gpt-4o', 'gemini-pro', 'blackboxai-pro']
model_aliases = {
"claude-3.5-sonnet": "claude-sonnet-3.5",
}
agentMode = { agentMode = {
default_image_model: {'mode': True, 'id': "ImageGenerationLV45LJp", 'name': "Image Generation"}, default_image_model: {'mode': True, 'id': "ImageGenerationLV45LJp", 'name': "Image Generation"},
} }
@ -198,6 +201,7 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
async with ClientSession(headers=headers) as session: async with ClientSession(headers=headers) as session:
async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response: async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response:
response.raise_for_status() response.raise_for_status()
is_first = False
async for chunk in response.content.iter_any(): async for chunk in response.content.iter_any():
text_chunk = chunk.decode(errors="ignore") text_chunk = chunk.decode(errors="ignore")
if model in cls.image_models: if model in cls.image_models:
@ -217,5 +221,9 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
for i, result in enumerate(search_results, 1): for i, result in enumerate(search_results, 1):
formatted_response += f"\n{i}. {result['title']}: {result['link']}" formatted_response += f"\n{i}. {result['title']}: {result['link']}"
yield formatted_response yield formatted_response
else: elif text_chunk:
yield text_chunk.strip() if is_first:
is_first = False
yield text_chunk.lstrip()
else:
yield text_chunk

View File

@ -21,8 +21,9 @@ from .helper import format_prompt
from ..typing import CreateResult, Messages, ImageType from ..typing import CreateResult, Messages, ImageType
from ..errors import MissingRequirementsError from ..errors import MissingRequirementsError
from ..requests.raise_for_status import raise_for_status from ..requests.raise_for_status import raise_for_status
from ..providers.helper import format_cookies
from ..requests import get_nodriver from ..requests import get_nodriver
from ..image import to_bytes, is_accepted_format from ..image import ImageResponse, to_bytes, is_accepted_format
from .. import debug from .. import debug
class Conversation(BaseConversation): class Conversation(BaseConversation):
@ -70,18 +71,21 @@ class Copilot(AbstractProvider):
access_token, cookies = asyncio.run(cls.get_access_token_and_cookies(proxy)) access_token, cookies = asyncio.run(cls.get_access_token_and_cookies(proxy))
else: else:
access_token = conversation.access_token access_token = conversation.access_token
websocket_url = f"{websocket_url}&acessToken={quote(access_token)}" debug.log(f"Copilot: Access token: {access_token[:7]}...{access_token[-5:]}")
headers = {"Authorization": f"Bearer {access_token}"} debug.log(f"Copilot: Cookies: {';'.join([*cookies])}")
websocket_url = f"{websocket_url}&accessToken={quote(access_token)}"
headers = {"authorization": f"Bearer {access_token}", "cookie": format_cookies(cookies)}
with Session( with Session(
timeout=timeout, timeout=timeout,
proxy=proxy, proxy=proxy,
impersonate="chrome", impersonate="chrome",
headers=headers, headers=headers,
cookies=cookies cookies=cookies,
) as session: ) as session:
response = session.get(f"{cls.url}/") response = session.get("https://copilot.microsoft.com/c/api/user")
raise_for_status(response) raise_for_status(response)
debug.log(f"Copilot: User: {response.json().get('firstName', 'null')}")
if conversation is None: if conversation is None:
response = session.post(cls.conversation_url) response = session.post(cls.conversation_url)
raise_for_status(response) raise_for_status(response)
@ -119,6 +123,7 @@ class Copilot(AbstractProvider):
is_started = False is_started = False
msg = None msg = None
image_prompt: str = None
while True: while True:
try: try:
msg = wss.recv()[0] msg = wss.recv()[0]
@ -128,7 +133,11 @@ class Copilot(AbstractProvider):
if msg.get("event") == "appendText": if msg.get("event") == "appendText":
is_started = True is_started = True
yield msg.get("text") yield msg.get("text")
elif msg.get("event") in ["done", "partCompleted"]: elif msg.get("event") == "generatingImage":
image_prompt = msg.get("prompt")
elif msg.get("event") == "imageGenerated":
yield ImageResponse(msg.get("url"), image_prompt, {"preview": msg.get("thumbnailUrl")})
elif msg.get("event") == "done":
break break
if not is_started: if not is_started:
raise RuntimeError(f"Last message: {msg}") raise RuntimeError(f"Last message: {msg}")
@ -152,7 +161,7 @@ class Copilot(AbstractProvider):
})() })()
""") """)
if access_token is None: if access_token is None:
asyncio.sleep(1) await asyncio.sleep(1)
cookies = {} cookies = {}
for c in await page.send(nodriver.cdp.network.get_cookies([cls.url])): for c in await page.send(nodriver.cdp.network.get_cookies([cls.url])):
cookies[c.name] = c.value cookies[c.name] = c.value

View File

@ -1,7 +1,6 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import aiohttp
import random import random
import string import string
import json import json
@ -11,34 +10,24 @@ from aiohttp import ClientSession
from ..typing import AsyncResult, Messages from ..typing import AsyncResult, Messages
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .helper import format_prompt from ..requests.raise_for_status import raise_for_status
class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin): class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
label = "Rubiks AI" label = "Rubiks AI"
url = "https://rubiks.ai" url = "https://rubiks.ai"
api_endpoint = "https://rubiks.ai/search/api.php" api_endpoint = "https://rubiks.ai/search/api/"
working = True working = True
supports_stream = True supports_stream = True
supports_system_message = True supports_system_message = True
supports_message_history = True supports_message_history = True
default_model = 'llama-3.1-70b-versatile' default_model = 'gpt-4o-mini'
models = [default_model, 'gpt-4o-mini'] models = [default_model, 'gpt-4o', 'o1-mini', 'claude-3.5-sonnet', 'grok-beta', 'gemini-1.5-pro', 'nova-pro']
model_aliases = { model_aliases = {
"llama-3.1-70b": "llama-3.1-70b-versatile", "llama-3.1-70b": "llama-3.1-70b-versatile",
} }
@classmethod
def get_model(cls, model: str) -> str:
if model in cls.models:
return model
elif model in cls.model_aliases:
return cls.model_aliases[model]
else:
return cls.default_model
@staticmethod @staticmethod
def generate_mid() -> str: def generate_mid() -> str:
""" """
@ -70,7 +59,8 @@ class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
model: str, model: str,
messages: Messages, messages: Messages,
proxy: str = None, proxy: str = None,
websearch: bool = False, web_search: bool = False,
temperature: float = 0.6,
**kwargs **kwargs
) -> AsyncResult: ) -> AsyncResult:
""" """
@ -80,20 +70,18 @@ class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
- model (str): The model to use in the request. - model (str): The model to use in the request.
- messages (Messages): The messages to send as a prompt. - messages (Messages): The messages to send as a prompt.
- proxy (str, optional): Proxy URL, if needed. - proxy (str, optional): Proxy URL, if needed.
- websearch (bool, optional): Indicates whether to include search sources in the response. Defaults to False. - web_search (bool, optional): Indicates whether to include search sources in the response. Defaults to False.
""" """
model = cls.get_model(model) model = cls.get_model(model)
prompt = format_prompt(messages)
q_value = prompt
mid_value = cls.generate_mid() mid_value = cls.generate_mid()
referer = cls.create_referer(q=q_value, mid=mid_value, model=model) referer = cls.create_referer(q=messages[-1]["content"], mid=mid_value, model=model)
url = cls.api_endpoint data = {
params = { "messages": messages,
'q': q_value, "model": model,
'model': model, "search": web_search,
'id': '', "stream": True,
'mid': mid_value "temperature": temperature
} }
headers = { headers = {
@ -111,52 +99,34 @@ class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
'sec-ch-ua-mobile': '?0', 'sec-ch-ua-mobile': '?0',
'sec-ch-ua-platform': '"Linux"' 'sec-ch-ua-platform': '"Linux"'
} }
async with ClientSession() as session:
async with session.post(cls.api_endpoint, headers=headers, json=data, proxy=proxy) as response:
await raise_for_status(response)
try: sources = []
timeout = aiohttp.ClientTimeout(total=None) async for line in response.content:
async with ClientSession(timeout=timeout) as session: decoded_line = line.decode('utf-8').strip()
async with session.get(url, headers=headers, params=params, proxy=proxy) as response: if not decoded_line.startswith('data: '):
if response.status != 200: continue
yield f"Request ended with status code {response.status}" data = decoded_line[6:]
return if data in ('[DONE]', '{"done": ""}'):
break
try:
json_data = json.loads(data)
except json.JSONDecodeError:
continue
assistant_text = '' if 'url' in json_data and 'title' in json_data:
sources = [] if web_search:
sources.append({'title': json_data['title'], 'url': json_data['url']})
async for line in response.content: elif 'choices' in json_data:
decoded_line = line.decode('utf-8').strip() for choice in json_data['choices']:
if not decoded_line.startswith('data: '): delta = choice.get('delta', {})
continue content = delta.get('content', '')
data = decoded_line[6:] if content:
if data in ('[DONE]', '{"done": ""}'): yield content
break
try:
json_data = json.loads(data)
except json.JSONDecodeError:
continue
if 'url' in json_data and 'title' in json_data: if web_search and sources:
if websearch: sources_text = '\n'.join([f"{i+1}. [{s['title']}]: {s['url']}" for i, s in enumerate(sources)])
sources.append({'title': json_data['title'], 'url': json_data['url']}) yield f"\n\n**Source:**\n{sources_text}"
elif 'choices' in json_data:
for choice in json_data['choices']:
delta = choice.get('delta', {})
content = delta.get('content', '')
role = delta.get('role', '')
if role == 'assistant':
continue
assistant_text += content
if websearch and sources:
sources_text = '\n'.join([f"{i+1}. [{s['title']}]: {s['url']}" for i, s in enumerate(sources)])
assistant_text += f"\n\n**Source:**\n{sources_text}"
yield assistant_text
except asyncio.CancelledError:
yield "The request was cancelled."
except aiohttp.ClientError as e:
yield f"An error occurred during the request: {e}"
except Exception as e:
yield f"An unexpected error occurred: {e}"

View File

@ -0,0 +1,65 @@
from __future__ import annotations
import requests
from aiohttp import ClientSession
from .OpenaiAPI import OpenaiAPI
from ...typing import AsyncResult, Messages, Cookies
from ...requests.raise_for_status import raise_for_status
from ...cookies import get_cookies
class Cerebras(OpenaiAPI):
label = "Cerebras Inference"
url = "https://inference.cerebras.ai/"
working = True
default_model = "llama3.1-70b"
fallback_models = [
"llama3.1-70b",
"llama3.1-8b",
]
model_aliases = {"llama-3.1-70b": "llama3.1-70b", "llama-3.1-8b": "llama3.1-8b"}
@classmethod
def get_models(cls, api_key: str = None):
if not cls.models:
try:
headers = {}
if api_key:
headers["authorization"] = f"Bearer ${api_key}"
response = requests.get(f"https://api.cerebras.ai/v1/models", headers=headers)
raise_for_status(response)
data = response.json()
cls.models = [model.get("model") for model in data.get("models")]
except Exception:
cls.models = cls.fallback_models
return cls.models
@classmethod
async def create_async_generator(
cls,
model: str,
messages: Messages,
api_base: str = "https://api.cerebras.ai/v1",
api_key: str = None,
cookies: Cookies = None,
**kwargs
) -> AsyncResult:
if api_key is None and cookies is None:
cookies = get_cookies(".cerebras.ai")
async with ClientSession(cookies=cookies) as session:
async with session.get("https://inference.cerebras.ai/api/auth/session") as response:
raise_for_status(response)
data = await response.json()
if data:
api_key = data.get("user", {}).get("demoApiKey")
async for chunk in super().create_async_generator(
model, messages,
api_base=api_base,
impersonate="chrome",
api_key=api_key,
headers={
"User-Agent": "ex/JS 1.5.0",
},
**kwargs
):
yield chunk

View File

@ -1,9 +1,12 @@
from __future__ import annotations from __future__ import annotations
from ..base_provider import ProviderModelMixin
from ..Copilot import Copilot from ..Copilot import Copilot
class CopilotAccount(Copilot): class CopilotAccount(Copilot, ProviderModelMixin):
needs_auth = True needs_auth = True
parent = "Copilot" parent = "Copilot"
default_model = "Copilot" default_model = "Copilot"
default_vision_model = default_model default_vision_model = default_model
models = [default_model]
image_models = models

View File

@ -0,0 +1,28 @@
from __future__ import annotations
from .OpenaiAPI import OpenaiAPI
from ..HuggingChat import HuggingChat
from ...typing import AsyncResult, Messages
class HuggingFace2(OpenaiAPI):
label = "HuggingFace (Inference API)"
url = "https://huggingface.co"
working = True
default_model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
default_vision_model = default_model
models = [
*HuggingChat.models
]
@classmethod
def create_async_generator(
cls,
model: str,
messages: Messages,
api_base: str = "https://api-inference.huggingface.co/v1",
max_tokens: int = 500,
**kwargs
) -> AsyncResult:
return super().create_async_generator(
model, messages, api_base=api_base, max_tokens=max_tokens, **kwargs
)

View File

@ -34,6 +34,7 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
stop: Union[str, list[str]] = None, stop: Union[str, list[str]] = None,
stream: bool = False, stream: bool = False,
headers: dict = None, headers: dict = None,
impersonate: str = None,
extra_data: dict = {}, extra_data: dict = {},
**kwargs **kwargs
) -> AsyncResult: ) -> AsyncResult:
@ -55,7 +56,8 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
async with StreamSession( async with StreamSession(
proxies={"all": proxy}, proxies={"all": proxy},
headers=cls.get_headers(stream, api_key, headers), headers=cls.get_headers(stream, api_key, headers),
timeout=timeout timeout=timeout,
impersonate=impersonate,
) as session: ) as session:
data = filter_none( data = filter_none(
messages=messages, messages=messages,

View File

@ -1,6 +1,7 @@
from .gigachat import * from .gigachat import *
from .BingCreateImages import BingCreateImages from .BingCreateImages import BingCreateImages
from .Cerebras import Cerebras
from .CopilotAccount import CopilotAccount from .CopilotAccount import CopilotAccount
from .DeepInfra import DeepInfra from .DeepInfra import DeepInfra
from .DeepInfraImage import DeepInfraImage from .DeepInfraImage import DeepInfraImage
@ -8,6 +9,7 @@ from .Gemini import Gemini
from .GeminiPro import GeminiPro from .GeminiPro import GeminiPro
from .Groq import Groq from .Groq import Groq
from .HuggingFace import HuggingFace from .HuggingFace import HuggingFace
from .HuggingFace2 import HuggingFace2
from .MetaAI import MetaAI from .MetaAI import MetaAI
from .MetaAIAccount import MetaAIAccount from .MetaAIAccount import MetaAIAccount
from .OpenaiAPI import OpenaiAPI from .OpenaiAPI import OpenaiAPI

View File

@ -128,6 +128,10 @@
<label for="BingCreateImages-api_key" class="label" title="">Microsoft Designer in Bing:</label> <label for="BingCreateImages-api_key" class="label" title="">Microsoft Designer in Bing:</label>
<textarea id="BingCreateImages-api_key" name="BingCreateImages[api_key]" placeholder="&quot;_U&quot; cookie"></textarea> <textarea id="BingCreateImages-api_key" name="BingCreateImages[api_key]" placeholder="&quot;_U&quot; cookie"></textarea>
</div> </div>
<div class="field box">
<label for="Cerebras-api_key" class="label" title="">Cerebras Inference:</label>
<textarea id="Cerebras-api_key" name="Cerebras[api_key]" placeholder="api_key"></textarea>
</div>
<div class="field box"> <div class="field box">
<label for="DeepInfra-api_key" class="label" title="">DeepInfra:</label> <label for="DeepInfra-api_key" class="label" title="">DeepInfra:</label>
<textarea id="DeepInfra-api_key" name="DeepInfra[api_key]" class="DeepInfraImage-api_key" placeholder="api_key"></textarea> <textarea id="DeepInfra-api_key" name="DeepInfra[api_key]" class="DeepInfraImage-api_key" placeholder="api_key"></textarea>
@ -142,7 +146,7 @@
</div> </div>
<div class="field box"> <div class="field box">
<label for="HuggingFace-api_key" class="label" title="">HuggingFace:</label> <label for="HuggingFace-api_key" class="label" title="">HuggingFace:</label>
<textarea id="HuggingFace-api_key" name="HuggingFace[api_key]" placeholder="api_key"></textarea> <textarea id="HuggingFace-api_key" name="HuggingFace[api_key]" class="HuggingFace2-api_key" placeholder="api_key"></textarea>
</div> </div>
<div class="field box"> <div class="field box">
<label for="Openai-api_key" class="label" title="">OpenAI API:</label> <label for="Openai-api_key" class="label" title="">OpenAI API:</label>
@ -192,7 +196,7 @@
<div class="stop_generating stop_generating-hidden"> <div class="stop_generating stop_generating-hidden">
<button id="cancelButton"> <button id="cancelButton">
<span>Stop Generating</span> <span>Stop Generating</span>
<i class="fa-regular fa-stop"></i> <i class="fa-solid fa-stop"></i>
</button> </button>
</div> </div>
<div class="regenerate"> <div class="regenerate">

View File

@ -512,9 +512,7 @@ body {
@media only screen and (min-width: 40em) { @media only screen and (min-width: 40em) {
.stop_generating { .stop_generating {
left: 50%; right: 4px;
transform: translateX(-50%);
right: auto;
} }
.toolbar .regenerate span { .toolbar .regenerate span {
display: block; display: block;

View File

@ -215,7 +215,6 @@ const register_message_buttons = async () => {
const message_el = el.parentElement.parentElement.parentElement; const message_el = el.parentElement.parentElement.parentElement;
el.classList.add("clicked"); el.classList.add("clicked");
setTimeout(() => el.classList.remove("clicked"), 1000); setTimeout(() => el.classList.remove("clicked"), 1000);
await hide_message(window.conversation_id, message_el.dataset.index);
await ask_gpt(message_el.dataset.index, get_message_id()); await ask_gpt(message_el.dataset.index, get_message_id());
}) })
} }
@ -317,6 +316,7 @@ async function remove_cancel_button() {
regenerate.addEventListener("click", async () => { regenerate.addEventListener("click", async () => {
regenerate.classList.add("regenerate-hidden"); regenerate.classList.add("regenerate-hidden");
setTimeout(()=>regenerate.classList.remove("regenerate-hidden"), 3000);
stop_generating.classList.remove("stop_generating-hidden"); stop_generating.classList.remove("stop_generating-hidden");
await hide_message(window.conversation_id); await hide_message(window.conversation_id);
await ask_gpt(-1, get_message_id()); await ask_gpt(-1, get_message_id());
@ -383,12 +383,12 @@ const prepare_messages = (messages, message_index = -1) => {
return new_messages; return new_messages;
} }
async function add_message_chunk(message, message_index) { async function add_message_chunk(message, message_id) {
content_map = content_storage[message_index]; content_map = content_storage[message_id];
if (message.type == "conversation") { if (message.type == "conversation") {
console.info("Conversation used:", message.conversation) console.info("Conversation used:", message.conversation)
} else if (message.type == "provider") { } else if (message.type == "provider") {
provider_storage[message_index] = message.provider; provider_storage[message_id] = message.provider;
content_map.content.querySelector('.provider').innerHTML = ` content_map.content.querySelector('.provider').innerHTML = `
<a href="${message.provider.url}" target="_blank"> <a href="${message.provider.url}" target="_blank">
${message.provider.label ? message.provider.label : message.provider.name} ${message.provider.label ? message.provider.label : message.provider.name}
@ -398,7 +398,7 @@ async function add_message_chunk(message, message_index) {
} else if (message.type == "message") { } else if (message.type == "message") {
console.error(message.message) console.error(message.message)
} else if (message.type == "error") { } else if (message.type == "error") {
error_storage[message_index] = message.error error_storage[message_id] = message.error
console.error(message.error); console.error(message.error);
content_map.inner.innerHTML += `<p><strong>An error occured:</strong> ${message.error}</p>`; content_map.inner.innerHTML += `<p><strong>An error occured:</strong> ${message.error}</p>`;
let p = document.createElement("p"); let p = document.createElement("p");
@ -407,8 +407,8 @@ async function add_message_chunk(message, message_index) {
} else if (message.type == "preview") { } else if (message.type == "preview") {
content_map.inner.innerHTML = markdown_render(message.preview); content_map.inner.innerHTML = markdown_render(message.preview);
} else if (message.type == "content") { } else if (message.type == "content") {
message_storage[message_index] += message.content; message_storage[message_id] += message.content;
html = markdown_render(message_storage[message_index]); html = markdown_render(message_storage[message_id]);
let lastElement, lastIndex = null; let lastElement, lastIndex = null;
for (element of ['</p>', '</code></pre>', '</p>\n</li>\n</ol>', '</li>\n</ol>', '</li>\n</ul>']) { for (element of ['</p>', '</code></pre>', '</p>\n</li>\n</ol>', '</li>\n</ol>', '</li>\n</ul>']) {
const index = html.lastIndexOf(element) const index = html.lastIndexOf(element)
@ -421,7 +421,7 @@ async function add_message_chunk(message, message_index) {
html = html.substring(0, lastIndex) + '<span class="cursor"></span>' + lastElement; html = html.substring(0, lastIndex) + '<span class="cursor"></span>' + lastElement;
} }
content_map.inner.innerHTML = html; content_map.inner.innerHTML = html;
content_map.count.innerText = count_words_and_tokens(message_storage[message_index], provider_storage[message_index]?.model); content_map.count.innerText = count_words_and_tokens(message_storage[message_id], provider_storage[message_id]?.model);
highlight(content_map.inner); highlight(content_map.inner);
} else if (message.type == "log") { } else if (message.type == "log") {
let p = document.createElement("p"); let p = document.createElement("p");
@ -453,7 +453,7 @@ const ask_gpt = async (message_index = -1, message_id) => {
let total_messages = messages.length; let total_messages = messages.length;
messages = prepare_messages(messages, message_index); messages = prepare_messages(messages, message_index);
message_index = total_messages message_index = total_messages
message_storage[message_index] = ""; message_storage[message_id] = "";
stop_generating.classList.remove(".stop_generating-hidden"); stop_generating.classList.remove(".stop_generating-hidden");
message_box.scrollTop = message_box.scrollHeight; message_box.scrollTop = message_box.scrollHeight;
@ -477,10 +477,10 @@ const ask_gpt = async (message_index = -1, message_id) => {
</div> </div>
`; `;
controller_storage[message_index] = new AbortController(); controller_storage[message_id] = new AbortController();
let content_el = document.getElementById(`gpt_${message_id}`) let content_el = document.getElementById(`gpt_${message_id}`)
let content_map = content_storage[message_index] = { let content_map = content_storage[message_id] = {
content: content_el, content: content_el,
inner: content_el.querySelector('.content_inner'), inner: content_el.querySelector('.content_inner'),
count: content_el.querySelector('.count'), count: content_el.querySelector('.count'),
@ -492,12 +492,7 @@ const ask_gpt = async (message_index = -1, message_id) => {
const file = input && input.files.length > 0 ? input.files[0] : null; const file = input && input.files.length > 0 ? input.files[0] : null;
const provider = providerSelect.options[providerSelect.selectedIndex].value; const provider = providerSelect.options[providerSelect.selectedIndex].value;
const auto_continue = document.getElementById("auto_continue")?.checked; const auto_continue = document.getElementById("auto_continue")?.checked;
let api_key = null; let api_key = get_api_key_by_provider(provider);
if (provider) {
api_key = document.getElementById(`${provider}-api_key`)?.value || null;
if (api_key == null)
api_key = document.querySelector(`.${provider}-api_key`)?.value || null;
}
await api("conversation", { await api("conversation", {
id: message_id, id: message_id,
conversation_id: window.conversation_id, conversation_id: window.conversation_id,
@ -506,10 +501,10 @@ const ask_gpt = async (message_index = -1, message_id) => {
provider: provider, provider: provider,
messages: messages, messages: messages,
auto_continue: auto_continue, auto_continue: auto_continue,
api_key: api_key api_key: api_key,
}, file, message_index); }, file, message_id);
if (!error_storage[message_index]) { if (!error_storage[message_id]) {
html = markdown_render(message_storage[message_index]); html = markdown_render(message_storage[message_id]);
content_map.inner.innerHTML = html; content_map.inner.innerHTML = html;
highlight(content_map.inner); highlight(content_map.inner);
@ -520,14 +515,14 @@ const ask_gpt = async (message_index = -1, message_id) => {
} catch (e) { } catch (e) {
console.error(e); console.error(e);
if (e.name != "AbortError") { if (e.name != "AbortError") {
error_storage[message_index] = true; error_storage[message_id] = true;
content_map.inner.innerHTML += `<p><strong>An error occured:</strong> ${e}</p>`; content_map.inner.innerHTML += `<p><strong>An error occured:</strong> ${e}</p>`;
} }
} }
delete controller_storage[message_index]; delete controller_storage[message_id];
if (!error_storage[message_index] && message_storage[message_index]) { if (!error_storage[message_id] && message_storage[message_id]) {
const message_provider = message_index in provider_storage ? provider_storage[message_index] : null; const message_provider = message_id in provider_storage ? provider_storage[message_id] : null;
await add_message(window.conversation_id, "assistant", message_storage[message_index], message_provider); await add_message(window.conversation_id, "assistant", message_storage[message_id], message_provider);
await safe_load_conversation(window.conversation_id); await safe_load_conversation(window.conversation_id);
} else { } else {
let cursorDiv = message_box.querySelector(".cursor"); let cursorDiv = message_box.querySelector(".cursor");
@ -1156,7 +1151,7 @@ async function on_api() {
evt.preventDefault(); evt.preventDefault();
console.log("pressed enter"); console.log("pressed enter");
prompt_lock = true; prompt_lock = true;
setTimeout(()=>prompt_lock=false, 3); setTimeout(()=>prompt_lock=false, 3000);
await handle_ask(); await handle_ask();
} else { } else {
messageInput.style.removeProperty("height"); messageInput.style.removeProperty("height");
@ -1167,7 +1162,7 @@ async function on_api() {
console.log("clicked send"); console.log("clicked send");
if (prompt_lock) return; if (prompt_lock) return;
prompt_lock = true; prompt_lock = true;
setTimeout(()=>prompt_lock=false, 3); setTimeout(()=>prompt_lock=false, 3000);
await handle_ask(); await handle_ask();
}); });
messageInput.focus(); messageInput.focus();
@ -1189,8 +1184,8 @@ async function on_api() {
providerSelect.appendChild(option); providerSelect.appendChild(option);
}) })
await load_provider_models(appStorage.getItem("provider"));
await load_settings_storage() await load_settings_storage()
await load_provider_models(appStorage.getItem("provider"));
const hide_systemPrompt = document.getElementById("hide-systemPrompt") const hide_systemPrompt = document.getElementById("hide-systemPrompt")
const slide_systemPrompt_icon = document.querySelector(".slide-systemPrompt i"); const slide_systemPrompt_icon = document.querySelector(".slide-systemPrompt i");
@ -1316,7 +1311,7 @@ function get_selected_model() {
} }
} }
async function api(ressource, args=null, file=null, message_index=null) { async function api(ressource, args=null, file=null, message_id=null) {
if (window?.pywebview) { if (window?.pywebview) {
if (args !== null) { if (args !== null) {
if (ressource == "models") { if (ressource == "models") {
@ -1326,15 +1321,19 @@ async function api(ressource, args=null, file=null, message_index=null) {
} }
return pywebview.api[`get_${ressource}`](); return pywebview.api[`get_${ressource}`]();
} }
let api_key;
if (ressource == "models" && args) { if (ressource == "models" && args) {
api_key = get_api_key_by_provider(args);
ressource = `${ressource}/${args}`; ressource = `${ressource}/${args}`;
} }
const url = `/backend-api/v2/${ressource}`; const url = `/backend-api/v2/${ressource}`;
const headers = {};
if (api_key) {
headers.authorization = `Bearer ${api_key}`;
}
if (ressource == "conversation") { if (ressource == "conversation") {
let body = JSON.stringify(args); let body = JSON.stringify(args);
const headers = { headers.accept = 'text/event-stream';
accept: 'text/event-stream'
}
if (file !== null) { if (file !== null) {
const formData = new FormData(); const formData = new FormData();
formData.append('file', file); formData.append('file', file);
@ -1345,17 +1344,17 @@ async function api(ressource, args=null, file=null, message_index=null) {
} }
response = await fetch(url, { response = await fetch(url, {
method: 'POST', method: 'POST',
signal: controller_storage[message_index].signal, signal: controller_storage[message_id].signal,
headers: headers, headers: headers,
body: body body: body,
}); });
return read_response(response, message_index); return read_response(response, message_id);
} }
response = await fetch(url); response = await fetch(url, {headers: headers});
return await response.json(); return await response.json();
} }
async function read_response(response, message_index) { async function read_response(response, message_id) {
const reader = response.body.pipeThrough(new TextDecoderStream()).getReader(); const reader = response.body.pipeThrough(new TextDecoderStream()).getReader();
let buffer = "" let buffer = ""
while (true) { while (true) {
@ -1368,7 +1367,7 @@ async function read_response(response, message_index) {
continue; continue;
} }
try { try {
add_message_chunk(JSON.parse(buffer + line), message_index); add_message_chunk(JSON.parse(buffer + line), message_id);
buffer = ""; buffer = "";
} catch { } catch {
buffer += line buffer += line
@ -1377,6 +1376,16 @@ async function read_response(response, message_index) {
} }
} }
function get_api_key_by_provider(provider) {
let api_key = null;
if (provider) {
api_key = document.getElementById(`${provider}-api_key`)?.value || null;
if (api_key == null)
api_key = document.querySelector(`.${provider}-api_key`)?.value || null;
}
return api_key;
}
async function load_provider_models(providerIndex=null) { async function load_provider_models(providerIndex=null) {
if (!providerIndex) { if (!providerIndex) {
providerIndex = providerSelect.selectedIndex; providerIndex = providerSelect.selectedIndex;

View File

@ -38,10 +38,11 @@ class Api:
return models._all_models return models._all_models
@staticmethod @staticmethod
def get_provider_models(provider: str) -> list[dict]: def get_provider_models(provider: str, api_key: str = None) -> list[dict]:
if provider in __map__: if provider in __map__:
provider: ProviderType = __map__[provider] provider: ProviderType = __map__[provider]
if issubclass(provider, ProviderModelMixin): if issubclass(provider, ProviderModelMixin):
models = provider.get_models() if api_key is None else provider.get_models(api_key=api_key)
return [ return [
{ {
"model": model, "model": model,
@ -49,7 +50,7 @@ class Api:
"vision": getattr(provider, "default_vision_model", None) == model or model in getattr(provider, "vision_models", []), "vision": getattr(provider, "default_vision_model", None) == model or model in getattr(provider, "vision_models", []),
"image": model in getattr(provider, "image_models", []), "image": model in getattr(provider, "image_models", []),
} }
for model in provider.get_models() for model in models
] ]
return [] return []

View File

@ -94,7 +94,8 @@ class Backend_Api(Api):
) )
def get_provider_models(self, provider: str): def get_provider_models(self, provider: str):
models = super().get_provider_models(provider) api_key = None if request.authorization is None else request.authorization.token
models = super().get_provider_models(provider, api_key)
if models is None: if models is None:
return 404, "Provider not found" return 404, "Provider not found"
return models return models

View File

@ -11,7 +11,9 @@ class CloudflareError(ResponseStatusError):
... ...
def is_cloudflare(text: str) -> bool: def is_cloudflare(text: str) -> bool:
if "<title>Attention Required! | Cloudflare</title>" in text or 'id="cf-cloudflare-status"' in text: if "Generated by cloudfront" in text:
return True
elif "<title>Attention Required! | Cloudflare</title>" in text or 'id="cf-cloudflare-status"' in text:
return True return True
return '<div id="cf-please-wait">' in text or "<title>Just a moment...</title>" in text return '<div id="cf-please-wait">' in text or "<title>Just a moment...</title>" in text