mirror of
https://github.com/xtekky/gpt4free.git
synced 2024-12-25 12:16:17 +03:00
Merge pull request #2389 from hlohaus/info
Add Cerebras and HuggingFace2 provider, Fix RubiksAI provider
This commit is contained in:
commit
419264f966
@ -28,6 +28,9 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
image_models = [default_image_model, 'repomap']
|
||||
text_models = [default_model, 'gpt-4o', 'gemini-pro', 'claude-sonnet-3.5', 'blackboxai-pro']
|
||||
vision_models = [default_model, 'gpt-4o', 'gemini-pro', 'blackboxai-pro']
|
||||
model_aliases = {
|
||||
"claude-3.5-sonnet": "claude-sonnet-3.5",
|
||||
}
|
||||
agentMode = {
|
||||
default_image_model: {'mode': True, 'id': "ImageGenerationLV45LJp", 'name': "Image Generation"},
|
||||
}
|
||||
@ -198,6 +201,7 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
async with ClientSession(headers=headers) as session:
|
||||
async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response:
|
||||
response.raise_for_status()
|
||||
is_first = False
|
||||
async for chunk in response.content.iter_any():
|
||||
text_chunk = chunk.decode(errors="ignore")
|
||||
if model in cls.image_models:
|
||||
@ -217,5 +221,9 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
for i, result in enumerate(search_results, 1):
|
||||
formatted_response += f"\n{i}. {result['title']}: {result['link']}"
|
||||
yield formatted_response
|
||||
else:
|
||||
yield text_chunk.strip()
|
||||
elif text_chunk:
|
||||
if is_first:
|
||||
is_first = False
|
||||
yield text_chunk.lstrip()
|
||||
else:
|
||||
yield text_chunk
|
@ -21,8 +21,9 @@ from .helper import format_prompt
|
||||
from ..typing import CreateResult, Messages, ImageType
|
||||
from ..errors import MissingRequirementsError
|
||||
from ..requests.raise_for_status import raise_for_status
|
||||
from ..providers.helper import format_cookies
|
||||
from ..requests import get_nodriver
|
||||
from ..image import to_bytes, is_accepted_format
|
||||
from ..image import ImageResponse, to_bytes, is_accepted_format
|
||||
from .. import debug
|
||||
|
||||
class Conversation(BaseConversation):
|
||||
@ -70,18 +71,21 @@ class Copilot(AbstractProvider):
|
||||
access_token, cookies = asyncio.run(cls.get_access_token_and_cookies(proxy))
|
||||
else:
|
||||
access_token = conversation.access_token
|
||||
websocket_url = f"{websocket_url}&acessToken={quote(access_token)}"
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
debug.log(f"Copilot: Access token: {access_token[:7]}...{access_token[-5:]}")
|
||||
debug.log(f"Copilot: Cookies: {';'.join([*cookies])}")
|
||||
websocket_url = f"{websocket_url}&accessToken={quote(access_token)}"
|
||||
headers = {"authorization": f"Bearer {access_token}", "cookie": format_cookies(cookies)}
|
||||
|
||||
with Session(
|
||||
timeout=timeout,
|
||||
proxy=proxy,
|
||||
impersonate="chrome",
|
||||
headers=headers,
|
||||
cookies=cookies
|
||||
cookies=cookies,
|
||||
) as session:
|
||||
response = session.get(f"{cls.url}/")
|
||||
response = session.get("https://copilot.microsoft.com/c/api/user")
|
||||
raise_for_status(response)
|
||||
debug.log(f"Copilot: User: {response.json().get('firstName', 'null')}")
|
||||
if conversation is None:
|
||||
response = session.post(cls.conversation_url)
|
||||
raise_for_status(response)
|
||||
@ -119,6 +123,7 @@ class Copilot(AbstractProvider):
|
||||
|
||||
is_started = False
|
||||
msg = None
|
||||
image_prompt: str = None
|
||||
while True:
|
||||
try:
|
||||
msg = wss.recv()[0]
|
||||
@ -128,7 +133,11 @@ class Copilot(AbstractProvider):
|
||||
if msg.get("event") == "appendText":
|
||||
is_started = True
|
||||
yield msg.get("text")
|
||||
elif msg.get("event") in ["done", "partCompleted"]:
|
||||
elif msg.get("event") == "generatingImage":
|
||||
image_prompt = msg.get("prompt")
|
||||
elif msg.get("event") == "imageGenerated":
|
||||
yield ImageResponse(msg.get("url"), image_prompt, {"preview": msg.get("thumbnailUrl")})
|
||||
elif msg.get("event") == "done":
|
||||
break
|
||||
if not is_started:
|
||||
raise RuntimeError(f"Last message: {msg}")
|
||||
@ -152,7 +161,7 @@ class Copilot(AbstractProvider):
|
||||
})()
|
||||
""")
|
||||
if access_token is None:
|
||||
asyncio.sleep(1)
|
||||
await asyncio.sleep(1)
|
||||
cookies = {}
|
||||
for c in await page.send(nodriver.cdp.network.get_cookies([cls.url])):
|
||||
cookies[c.name] = c.value
|
||||
|
@ -1,7 +1,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import random
|
||||
import string
|
||||
import json
|
||||
@ -11,34 +10,24 @@ from aiohttp import ClientSession
|
||||
|
||||
from ..typing import AsyncResult, Messages
|
||||
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from .helper import format_prompt
|
||||
|
||||
from ..requests.raise_for_status import raise_for_status
|
||||
|
||||
class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
label = "Rubiks AI"
|
||||
url = "https://rubiks.ai"
|
||||
api_endpoint = "https://rubiks.ai/search/api.php"
|
||||
api_endpoint = "https://rubiks.ai/search/api/"
|
||||
working = True
|
||||
supports_stream = True
|
||||
supports_system_message = True
|
||||
supports_message_history = True
|
||||
|
||||
default_model = 'llama-3.1-70b-versatile'
|
||||
models = [default_model, 'gpt-4o-mini']
|
||||
default_model = 'gpt-4o-mini'
|
||||
models = [default_model, 'gpt-4o', 'o1-mini', 'claude-3.5-sonnet', 'grok-beta', 'gemini-1.5-pro', 'nova-pro']
|
||||
|
||||
model_aliases = {
|
||||
"llama-3.1-70b": "llama-3.1-70b-versatile",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_model(cls, model: str) -> str:
|
||||
if model in cls.models:
|
||||
return model
|
||||
elif model in cls.model_aliases:
|
||||
return cls.model_aliases[model]
|
||||
else:
|
||||
return cls.default_model
|
||||
|
||||
@staticmethod
|
||||
def generate_mid() -> str:
|
||||
"""
|
||||
@ -70,7 +59,8 @@ class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
model: str,
|
||||
messages: Messages,
|
||||
proxy: str = None,
|
||||
websearch: bool = False,
|
||||
web_search: bool = False,
|
||||
temperature: float = 0.6,
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
"""
|
||||
@ -80,20 +70,18 @@ class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
- model (str): The model to use in the request.
|
||||
- messages (Messages): The messages to send as a prompt.
|
||||
- proxy (str, optional): Proxy URL, if needed.
|
||||
- websearch (bool, optional): Indicates whether to include search sources in the response. Defaults to False.
|
||||
- web_search (bool, optional): Indicates whether to include search sources in the response. Defaults to False.
|
||||
"""
|
||||
model = cls.get_model(model)
|
||||
prompt = format_prompt(messages)
|
||||
q_value = prompt
|
||||
mid_value = cls.generate_mid()
|
||||
referer = cls.create_referer(q=q_value, mid=mid_value, model=model)
|
||||
referer = cls.create_referer(q=messages[-1]["content"], mid=mid_value, model=model)
|
||||
|
||||
url = cls.api_endpoint
|
||||
params = {
|
||||
'q': q_value,
|
||||
'model': model,
|
||||
'id': '',
|
||||
'mid': mid_value
|
||||
data = {
|
||||
"messages": messages,
|
||||
"model": model,
|
||||
"search": web_search,
|
||||
"stream": True,
|
||||
"temperature": temperature
|
||||
}
|
||||
|
||||
headers = {
|
||||
@ -111,52 +99,34 @@ class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
'sec-ch-ua-mobile': '?0',
|
||||
'sec-ch-ua-platform': '"Linux"'
|
||||
}
|
||||
async with ClientSession() as session:
|
||||
async with session.post(cls.api_endpoint, headers=headers, json=data, proxy=proxy) as response:
|
||||
await raise_for_status(response)
|
||||
|
||||
try:
|
||||
timeout = aiohttp.ClientTimeout(total=None)
|
||||
async with ClientSession(timeout=timeout) as session:
|
||||
async with session.get(url, headers=headers, params=params, proxy=proxy) as response:
|
||||
if response.status != 200:
|
||||
yield f"Request ended with status code {response.status}"
|
||||
return
|
||||
sources = []
|
||||
async for line in response.content:
|
||||
decoded_line = line.decode('utf-8').strip()
|
||||
if not decoded_line.startswith('data: '):
|
||||
continue
|
||||
data = decoded_line[6:]
|
||||
if data in ('[DONE]', '{"done": ""}'):
|
||||
break
|
||||
try:
|
||||
json_data = json.loads(data)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
assistant_text = ''
|
||||
sources = []
|
||||
if 'url' in json_data and 'title' in json_data:
|
||||
if web_search:
|
||||
sources.append({'title': json_data['title'], 'url': json_data['url']})
|
||||
|
||||
async for line in response.content:
|
||||
decoded_line = line.decode('utf-8').strip()
|
||||
if not decoded_line.startswith('data: '):
|
||||
continue
|
||||
data = decoded_line[6:]
|
||||
if data in ('[DONE]', '{"done": ""}'):
|
||||
break
|
||||
try:
|
||||
json_data = json.loads(data)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
elif 'choices' in json_data:
|
||||
for choice in json_data['choices']:
|
||||
delta = choice.get('delta', {})
|
||||
content = delta.get('content', '')
|
||||
if content:
|
||||
yield content
|
||||
|
||||
if 'url' in json_data and 'title' in json_data:
|
||||
if websearch:
|
||||
sources.append({'title': json_data['title'], 'url': json_data['url']})
|
||||
|
||||
elif 'choices' in json_data:
|
||||
for choice in json_data['choices']:
|
||||
delta = choice.get('delta', {})
|
||||
content = delta.get('content', '')
|
||||
role = delta.get('role', '')
|
||||
if role == 'assistant':
|
||||
continue
|
||||
assistant_text += content
|
||||
|
||||
if websearch and sources:
|
||||
sources_text = '\n'.join([f"{i+1}. [{s['title']}]: {s['url']}" for i, s in enumerate(sources)])
|
||||
assistant_text += f"\n\n**Source:**\n{sources_text}"
|
||||
|
||||
yield assistant_text
|
||||
|
||||
except asyncio.CancelledError:
|
||||
yield "The request was cancelled."
|
||||
except aiohttp.ClientError as e:
|
||||
yield f"An error occurred during the request: {e}"
|
||||
except Exception as e:
|
||||
yield f"An unexpected error occurred: {e}"
|
||||
if web_search and sources:
|
||||
sources_text = '\n'.join([f"{i+1}. [{s['title']}]: {s['url']}" for i, s in enumerate(sources)])
|
||||
yield f"\n\n**Source:**\n{sources_text}"
|
65
g4f/Provider/needs_auth/Cerebras.py
Normal file
65
g4f/Provider/needs_auth/Cerebras.py
Normal file
@ -0,0 +1,65 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import requests
|
||||
from aiohttp import ClientSession
|
||||
|
||||
from .OpenaiAPI import OpenaiAPI
|
||||
from ...typing import AsyncResult, Messages, Cookies
|
||||
from ...requests.raise_for_status import raise_for_status
|
||||
from ...cookies import get_cookies
|
||||
|
||||
class Cerebras(OpenaiAPI):
|
||||
label = "Cerebras Inference"
|
||||
url = "https://inference.cerebras.ai/"
|
||||
working = True
|
||||
default_model = "llama3.1-70b"
|
||||
fallback_models = [
|
||||
"llama3.1-70b",
|
||||
"llama3.1-8b",
|
||||
]
|
||||
model_aliases = {"llama-3.1-70b": "llama3.1-70b", "llama-3.1-8b": "llama3.1-8b"}
|
||||
|
||||
@classmethod
|
||||
def get_models(cls, api_key: str = None):
|
||||
if not cls.models:
|
||||
try:
|
||||
headers = {}
|
||||
if api_key:
|
||||
headers["authorization"] = f"Bearer ${api_key}"
|
||||
response = requests.get(f"https://api.cerebras.ai/v1/models", headers=headers)
|
||||
raise_for_status(response)
|
||||
data = response.json()
|
||||
cls.models = [model.get("model") for model in data.get("models")]
|
||||
except Exception:
|
||||
cls.models = cls.fallback_models
|
||||
return cls.models
|
||||
|
||||
@classmethod
|
||||
async def create_async_generator(
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
api_base: str = "https://api.cerebras.ai/v1",
|
||||
api_key: str = None,
|
||||
cookies: Cookies = None,
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
if api_key is None and cookies is None:
|
||||
cookies = get_cookies(".cerebras.ai")
|
||||
async with ClientSession(cookies=cookies) as session:
|
||||
async with session.get("https://inference.cerebras.ai/api/auth/session") as response:
|
||||
raise_for_status(response)
|
||||
data = await response.json()
|
||||
if data:
|
||||
api_key = data.get("user", {}).get("demoApiKey")
|
||||
async for chunk in super().create_async_generator(
|
||||
model, messages,
|
||||
api_base=api_base,
|
||||
impersonate="chrome",
|
||||
api_key=api_key,
|
||||
headers={
|
||||
"User-Agent": "ex/JS 1.5.0",
|
||||
},
|
||||
**kwargs
|
||||
):
|
||||
yield chunk
|
@ -1,9 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ..base_provider import ProviderModelMixin
|
||||
from ..Copilot import Copilot
|
||||
|
||||
class CopilotAccount(Copilot):
|
||||
class CopilotAccount(Copilot, ProviderModelMixin):
|
||||
needs_auth = True
|
||||
parent = "Copilot"
|
||||
default_model = "Copilot"
|
||||
default_vision_model = default_model
|
||||
default_vision_model = default_model
|
||||
models = [default_model]
|
||||
image_models = models
|
28
g4f/Provider/needs_auth/HuggingFace2.py
Normal file
28
g4f/Provider/needs_auth/HuggingFace2.py
Normal file
@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .OpenaiAPI import OpenaiAPI
|
||||
from ..HuggingChat import HuggingChat
|
||||
from ...typing import AsyncResult, Messages
|
||||
|
||||
class HuggingFace2(OpenaiAPI):
|
||||
label = "HuggingFace (Inference API)"
|
||||
url = "https://huggingface.co"
|
||||
working = True
|
||||
default_model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
|
||||
default_vision_model = default_model
|
||||
models = [
|
||||
*HuggingChat.models
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def create_async_generator(
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
api_base: str = "https://api-inference.huggingface.co/v1",
|
||||
max_tokens: int = 500,
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
return super().create_async_generator(
|
||||
model, messages, api_base=api_base, max_tokens=max_tokens, **kwargs
|
||||
)
|
@ -34,6 +34,7 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
stop: Union[str, list[str]] = None,
|
||||
stream: bool = False,
|
||||
headers: dict = None,
|
||||
impersonate: str = None,
|
||||
extra_data: dict = {},
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
@ -55,7 +56,8 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
async with StreamSession(
|
||||
proxies={"all": proxy},
|
||||
headers=cls.get_headers(stream, api_key, headers),
|
||||
timeout=timeout
|
||||
timeout=timeout,
|
||||
impersonate=impersonate,
|
||||
) as session:
|
||||
data = filter_none(
|
||||
messages=messages,
|
||||
|
@ -1,6 +1,7 @@
|
||||
from .gigachat import *
|
||||
|
||||
from .BingCreateImages import BingCreateImages
|
||||
from .Cerebras import Cerebras
|
||||
from .CopilotAccount import CopilotAccount
|
||||
from .DeepInfra import DeepInfra
|
||||
from .DeepInfraImage import DeepInfraImage
|
||||
@ -8,6 +9,7 @@ from .Gemini import Gemini
|
||||
from .GeminiPro import GeminiPro
|
||||
from .Groq import Groq
|
||||
from .HuggingFace import HuggingFace
|
||||
from .HuggingFace2 import HuggingFace2
|
||||
from .MetaAI import MetaAI
|
||||
from .MetaAIAccount import MetaAIAccount
|
||||
from .OpenaiAPI import OpenaiAPI
|
||||
|
@ -12,7 +12,7 @@ from typing import Union, AsyncIterator, Iterator, Coroutine
|
||||
|
||||
from ..providers.base_provider import AsyncGeneratorProvider
|
||||
from ..image import ImageResponse, to_image, to_data_uri, is_accepted_format, EXTENSIONS_MAP
|
||||
from ..typing import Messages, Cookies, Image
|
||||
from ..typing import Messages, Image
|
||||
from ..providers.types import ProviderType, FinishReason, BaseConversation
|
||||
from ..errors import NoImageResponseError
|
||||
from ..providers.retry_provider import IterListProvider
|
||||
@ -254,6 +254,8 @@ class Images:
|
||||
provider_handler = self.models.get(model, provider or self.provider or BingCreateImages)
|
||||
elif isinstance(provider, str):
|
||||
provider_handler = convert_to_provider(provider)
|
||||
else:
|
||||
provider_handler = provider
|
||||
if provider_handler is None:
|
||||
raise ValueError(f"Unknown model: {model}")
|
||||
if proxy is None:
|
||||
|
@ -128,6 +128,10 @@
|
||||
<label for="BingCreateImages-api_key" class="label" title="">Microsoft Designer in Bing:</label>
|
||||
<textarea id="BingCreateImages-api_key" name="BingCreateImages[api_key]" placeholder=""_U" cookie"></textarea>
|
||||
</div>
|
||||
<div class="field box">
|
||||
<label for="Cerebras-api_key" class="label" title="">Cerebras Inference:</label>
|
||||
<textarea id="Cerebras-api_key" name="Cerebras[api_key]" placeholder="api_key"></textarea>
|
||||
</div>
|
||||
<div class="field box">
|
||||
<label for="DeepInfra-api_key" class="label" title="">DeepInfra:</label>
|
||||
<textarea id="DeepInfra-api_key" name="DeepInfra[api_key]" class="DeepInfraImage-api_key" placeholder="api_key"></textarea>
|
||||
@ -142,7 +146,7 @@
|
||||
</div>
|
||||
<div class="field box">
|
||||
<label for="HuggingFace-api_key" class="label" title="">HuggingFace:</label>
|
||||
<textarea id="HuggingFace-api_key" name="HuggingFace[api_key]" placeholder="api_key"></textarea>
|
||||
<textarea id="HuggingFace-api_key" name="HuggingFace[api_key]" class="HuggingFace2-api_key" placeholder="api_key"></textarea>
|
||||
</div>
|
||||
<div class="field box">
|
||||
<label for="Openai-api_key" class="label" title="">OpenAI API:</label>
|
||||
@ -192,7 +196,7 @@
|
||||
<div class="stop_generating stop_generating-hidden">
|
||||
<button id="cancelButton">
|
||||
<span>Stop Generating</span>
|
||||
<i class="fa-regular fa-stop"></i>
|
||||
<i class="fa-solid fa-stop"></i>
|
||||
</button>
|
||||
</div>
|
||||
<div class="regenerate">
|
||||
|
@ -512,9 +512,7 @@ body {
|
||||
|
||||
@media only screen and (min-width: 40em) {
|
||||
.stop_generating {
|
||||
left: 50%;
|
||||
transform: translateX(-50%);
|
||||
right: auto;
|
||||
right: 4px;
|
||||
}
|
||||
.toolbar .regenerate span {
|
||||
display: block;
|
||||
|
@ -215,7 +215,6 @@ const register_message_buttons = async () => {
|
||||
const message_el = el.parentElement.parentElement.parentElement;
|
||||
el.classList.add("clicked");
|
||||
setTimeout(() => el.classList.remove("clicked"), 1000);
|
||||
await hide_message(window.conversation_id, message_el.dataset.index);
|
||||
await ask_gpt(message_el.dataset.index, get_message_id());
|
||||
})
|
||||
}
|
||||
@ -317,6 +316,7 @@ async function remove_cancel_button() {
|
||||
|
||||
regenerate.addEventListener("click", async () => {
|
||||
regenerate.classList.add("regenerate-hidden");
|
||||
setTimeout(()=>regenerate.classList.remove("regenerate-hidden"), 3000);
|
||||
stop_generating.classList.remove("stop_generating-hidden");
|
||||
await hide_message(window.conversation_id);
|
||||
await ask_gpt(-1, get_message_id());
|
||||
@ -383,12 +383,12 @@ const prepare_messages = (messages, message_index = -1) => {
|
||||
return new_messages;
|
||||
}
|
||||
|
||||
async function add_message_chunk(message, message_index) {
|
||||
content_map = content_storage[message_index];
|
||||
async function add_message_chunk(message, message_id) {
|
||||
content_map = content_storage[message_id];
|
||||
if (message.type == "conversation") {
|
||||
console.info("Conversation used:", message.conversation)
|
||||
} else if (message.type == "provider") {
|
||||
provider_storage[message_index] = message.provider;
|
||||
provider_storage[message_id] = message.provider;
|
||||
content_map.content.querySelector('.provider').innerHTML = `
|
||||
<a href="${message.provider.url}" target="_blank">
|
||||
${message.provider.label ? message.provider.label : message.provider.name}
|
||||
@ -398,7 +398,7 @@ async function add_message_chunk(message, message_index) {
|
||||
} else if (message.type == "message") {
|
||||
console.error(message.message)
|
||||
} else if (message.type == "error") {
|
||||
error_storage[message_index] = message.error
|
||||
error_storage[message_id] = message.error
|
||||
console.error(message.error);
|
||||
content_map.inner.innerHTML += `<p><strong>An error occured:</strong> ${message.error}</p>`;
|
||||
let p = document.createElement("p");
|
||||
@ -407,8 +407,8 @@ async function add_message_chunk(message, message_index) {
|
||||
} else if (message.type == "preview") {
|
||||
content_map.inner.innerHTML = markdown_render(message.preview);
|
||||
} else if (message.type == "content") {
|
||||
message_storage[message_index] += message.content;
|
||||
html = markdown_render(message_storage[message_index]);
|
||||
message_storage[message_id] += message.content;
|
||||
html = markdown_render(message_storage[message_id]);
|
||||
let lastElement, lastIndex = null;
|
||||
for (element of ['</p>', '</code></pre>', '</p>\n</li>\n</ol>', '</li>\n</ol>', '</li>\n</ul>']) {
|
||||
const index = html.lastIndexOf(element)
|
||||
@ -421,7 +421,7 @@ async function add_message_chunk(message, message_index) {
|
||||
html = html.substring(0, lastIndex) + '<span class="cursor"></span>' + lastElement;
|
||||
}
|
||||
content_map.inner.innerHTML = html;
|
||||
content_map.count.innerText = count_words_and_tokens(message_storage[message_index], provider_storage[message_index]?.model);
|
||||
content_map.count.innerText = count_words_and_tokens(message_storage[message_id], provider_storage[message_id]?.model);
|
||||
highlight(content_map.inner);
|
||||
} else if (message.type == "log") {
|
||||
let p = document.createElement("p");
|
||||
@ -453,7 +453,7 @@ const ask_gpt = async (message_index = -1, message_id) => {
|
||||
let total_messages = messages.length;
|
||||
messages = prepare_messages(messages, message_index);
|
||||
message_index = total_messages
|
||||
message_storage[message_index] = "";
|
||||
message_storage[message_id] = "";
|
||||
stop_generating.classList.remove(".stop_generating-hidden");
|
||||
|
||||
message_box.scrollTop = message_box.scrollHeight;
|
||||
@ -477,10 +477,10 @@ const ask_gpt = async (message_index = -1, message_id) => {
|
||||
</div>
|
||||
`;
|
||||
|
||||
controller_storage[message_index] = new AbortController();
|
||||
controller_storage[message_id] = new AbortController();
|
||||
|
||||
let content_el = document.getElementById(`gpt_${message_id}`)
|
||||
let content_map = content_storage[message_index] = {
|
||||
let content_map = content_storage[message_id] = {
|
||||
content: content_el,
|
||||
inner: content_el.querySelector('.content_inner'),
|
||||
count: content_el.querySelector('.count'),
|
||||
@ -492,12 +492,7 @@ const ask_gpt = async (message_index = -1, message_id) => {
|
||||
const file = input && input.files.length > 0 ? input.files[0] : null;
|
||||
const provider = providerSelect.options[providerSelect.selectedIndex].value;
|
||||
const auto_continue = document.getElementById("auto_continue")?.checked;
|
||||
let api_key = null;
|
||||
if (provider) {
|
||||
api_key = document.getElementById(`${provider}-api_key`)?.value || null;
|
||||
if (api_key == null)
|
||||
api_key = document.querySelector(`.${provider}-api_key`)?.value || null;
|
||||
}
|
||||
let api_key = get_api_key_by_provider(provider);
|
||||
await api("conversation", {
|
||||
id: message_id,
|
||||
conversation_id: window.conversation_id,
|
||||
@ -506,10 +501,10 @@ const ask_gpt = async (message_index = -1, message_id) => {
|
||||
provider: provider,
|
||||
messages: messages,
|
||||
auto_continue: auto_continue,
|
||||
api_key: api_key
|
||||
}, file, message_index);
|
||||
if (!error_storage[message_index]) {
|
||||
html = markdown_render(message_storage[message_index]);
|
||||
api_key: api_key,
|
||||
}, file, message_id);
|
||||
if (!error_storage[message_id]) {
|
||||
html = markdown_render(message_storage[message_id]);
|
||||
content_map.inner.innerHTML = html;
|
||||
highlight(content_map.inner);
|
||||
|
||||
@ -520,14 +515,14 @@ const ask_gpt = async (message_index = -1, message_id) => {
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
if (e.name != "AbortError") {
|
||||
error_storage[message_index] = true;
|
||||
error_storage[message_id] = true;
|
||||
content_map.inner.innerHTML += `<p><strong>An error occured:</strong> ${e}</p>`;
|
||||
}
|
||||
}
|
||||
delete controller_storage[message_index];
|
||||
if (!error_storage[message_index] && message_storage[message_index]) {
|
||||
const message_provider = message_index in provider_storage ? provider_storage[message_index] : null;
|
||||
await add_message(window.conversation_id, "assistant", message_storage[message_index], message_provider);
|
||||
delete controller_storage[message_id];
|
||||
if (!error_storage[message_id] && message_storage[message_id]) {
|
||||
const message_provider = message_id in provider_storage ? provider_storage[message_id] : null;
|
||||
await add_message(window.conversation_id, "assistant", message_storage[message_id], message_provider);
|
||||
await safe_load_conversation(window.conversation_id);
|
||||
} else {
|
||||
let cursorDiv = message_box.querySelector(".cursor");
|
||||
@ -1156,7 +1151,7 @@ async function on_api() {
|
||||
evt.preventDefault();
|
||||
console.log("pressed enter");
|
||||
prompt_lock = true;
|
||||
setTimeout(()=>prompt_lock=false, 3);
|
||||
setTimeout(()=>prompt_lock=false, 3000);
|
||||
await handle_ask();
|
||||
} else {
|
||||
messageInput.style.removeProperty("height");
|
||||
@ -1167,7 +1162,7 @@ async function on_api() {
|
||||
console.log("clicked send");
|
||||
if (prompt_lock) return;
|
||||
prompt_lock = true;
|
||||
setTimeout(()=>prompt_lock=false, 3);
|
||||
setTimeout(()=>prompt_lock=false, 3000);
|
||||
await handle_ask();
|
||||
});
|
||||
messageInput.focus();
|
||||
@ -1189,8 +1184,8 @@ async function on_api() {
|
||||
providerSelect.appendChild(option);
|
||||
})
|
||||
|
||||
await load_provider_models(appStorage.getItem("provider"));
|
||||
await load_settings_storage()
|
||||
await load_provider_models(appStorage.getItem("provider"));
|
||||
|
||||
const hide_systemPrompt = document.getElementById("hide-systemPrompt")
|
||||
const slide_systemPrompt_icon = document.querySelector(".slide-systemPrompt i");
|
||||
@ -1316,7 +1311,7 @@ function get_selected_model() {
|
||||
}
|
||||
}
|
||||
|
||||
async function api(ressource, args=null, file=null, message_index=null) {
|
||||
async function api(ressource, args=null, file=null, message_id=null) {
|
||||
if (window?.pywebview) {
|
||||
if (args !== null) {
|
||||
if (ressource == "models") {
|
||||
@ -1326,15 +1321,19 @@ async function api(ressource, args=null, file=null, message_index=null) {
|
||||
}
|
||||
return pywebview.api[`get_${ressource}`]();
|
||||
}
|
||||
let api_key;
|
||||
if (ressource == "models" && args) {
|
||||
api_key = get_api_key_by_provider(args);
|
||||
ressource = `${ressource}/${args}`;
|
||||
}
|
||||
const url = `/backend-api/v2/${ressource}`;
|
||||
const headers = {};
|
||||
if (api_key) {
|
||||
headers.authorization = `Bearer ${api_key}`;
|
||||
}
|
||||
if (ressource == "conversation") {
|
||||
let body = JSON.stringify(args);
|
||||
const headers = {
|
||||
accept: 'text/event-stream'
|
||||
}
|
||||
headers.accept = 'text/event-stream';
|
||||
if (file !== null) {
|
||||
const formData = new FormData();
|
||||
formData.append('file', file);
|
||||
@ -1345,17 +1344,17 @@ async function api(ressource, args=null, file=null, message_index=null) {
|
||||
}
|
||||
response = await fetch(url, {
|
||||
method: 'POST',
|
||||
signal: controller_storage[message_index].signal,
|
||||
signal: controller_storage[message_id].signal,
|
||||
headers: headers,
|
||||
body: body
|
||||
body: body,
|
||||
});
|
||||
return read_response(response, message_index);
|
||||
return read_response(response, message_id);
|
||||
}
|
||||
response = await fetch(url);
|
||||
response = await fetch(url, {headers: headers});
|
||||
return await response.json();
|
||||
}
|
||||
|
||||
async function read_response(response, message_index) {
|
||||
async function read_response(response, message_id) {
|
||||
const reader = response.body.pipeThrough(new TextDecoderStream()).getReader();
|
||||
let buffer = ""
|
||||
while (true) {
|
||||
@ -1368,7 +1367,7 @@ async function read_response(response, message_index) {
|
||||
continue;
|
||||
}
|
||||
try {
|
||||
add_message_chunk(JSON.parse(buffer + line), message_index);
|
||||
add_message_chunk(JSON.parse(buffer + line), message_id);
|
||||
buffer = "";
|
||||
} catch {
|
||||
buffer += line
|
||||
@ -1377,6 +1376,16 @@ async function read_response(response, message_index) {
|
||||
}
|
||||
}
|
||||
|
||||
function get_api_key_by_provider(provider) {
|
||||
let api_key = null;
|
||||
if (provider) {
|
||||
api_key = document.getElementById(`${provider}-api_key`)?.value || null;
|
||||
if (api_key == null)
|
||||
api_key = document.querySelector(`.${provider}-api_key`)?.value || null;
|
||||
}
|
||||
return api_key;
|
||||
}
|
||||
|
||||
async function load_provider_models(providerIndex=null) {
|
||||
if (!providerIndex) {
|
||||
providerIndex = providerSelect.selectedIndex;
|
||||
|
@ -38,10 +38,11 @@ class Api:
|
||||
return models._all_models
|
||||
|
||||
@staticmethod
|
||||
def get_provider_models(provider: str) -> list[dict]:
|
||||
def get_provider_models(provider: str, api_key: str = None) -> list[dict]:
|
||||
if provider in __map__:
|
||||
provider: ProviderType = __map__[provider]
|
||||
if issubclass(provider, ProviderModelMixin):
|
||||
models = provider.get_models() if api_key is None else provider.get_models(api_key=api_key)
|
||||
return [
|
||||
{
|
||||
"model": model,
|
||||
@ -49,7 +50,7 @@ class Api:
|
||||
"vision": getattr(provider, "default_vision_model", None) == model or model in getattr(provider, "vision_models", []),
|
||||
"image": model in getattr(provider, "image_models", []),
|
||||
}
|
||||
for model in provider.get_models()
|
||||
for model in models
|
||||
]
|
||||
return []
|
||||
|
||||
|
@ -94,7 +94,8 @@ class Backend_Api(Api):
|
||||
)
|
||||
|
||||
def get_provider_models(self, provider: str):
|
||||
models = super().get_provider_models(provider)
|
||||
api_key = None if request.authorization is None else request.authorization.token
|
||||
models = super().get_provider_models(provider, api_key)
|
||||
if models is None:
|
||||
return 404, "Provider not found"
|
||||
return models
|
||||
|
@ -11,7 +11,9 @@ class CloudflareError(ResponseStatusError):
|
||||
...
|
||||
|
||||
def is_cloudflare(text: str) -> bool:
|
||||
if "<title>Attention Required! | Cloudflare</title>" in text or 'id="cf-cloudflare-status"' in text:
|
||||
if "Generated by cloudfront" in text:
|
||||
return True
|
||||
elif "<title>Attention Required! | Cloudflare</title>" in text or 'id="cf-cloudflare-status"' in text:
|
||||
return True
|
||||
return '<div id="cf-please-wait">' in text or "<title>Just a moment...</title>" in text
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user