mirror of
https://github.com/xtekky/gpt4free.git
synced 2024-12-25 20:22:47 +03:00
Add Cerebras and HuggingFace2 provider, Fix RubiksAI provider
Add support for image generation in Copilot provider
This commit is contained in:
parent
275404373f
commit
58fa409eef
@ -28,6 +28,9 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
|
|||||||
image_models = [default_image_model, 'repomap']
|
image_models = [default_image_model, 'repomap']
|
||||||
text_models = [default_model, 'gpt-4o', 'gemini-pro', 'claude-sonnet-3.5', 'blackboxai-pro']
|
text_models = [default_model, 'gpt-4o', 'gemini-pro', 'claude-sonnet-3.5', 'blackboxai-pro']
|
||||||
vision_models = [default_model, 'gpt-4o', 'gemini-pro', 'blackboxai-pro']
|
vision_models = [default_model, 'gpt-4o', 'gemini-pro', 'blackboxai-pro']
|
||||||
|
model_aliases = {
|
||||||
|
"claude-3.5-sonnet": "claude-sonnet-3.5",
|
||||||
|
}
|
||||||
agentMode = {
|
agentMode = {
|
||||||
default_image_model: {'mode': True, 'id': "ImageGenerationLV45LJp", 'name': "Image Generation"},
|
default_image_model: {'mode': True, 'id': "ImageGenerationLV45LJp", 'name': "Image Generation"},
|
||||||
}
|
}
|
||||||
@ -198,6 +201,7 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
|
|||||||
async with ClientSession(headers=headers) as session:
|
async with ClientSession(headers=headers) as session:
|
||||||
async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response:
|
async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response:
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
is_first = False
|
||||||
async for chunk in response.content.iter_any():
|
async for chunk in response.content.iter_any():
|
||||||
text_chunk = chunk.decode(errors="ignore")
|
text_chunk = chunk.decode(errors="ignore")
|
||||||
if model in cls.image_models:
|
if model in cls.image_models:
|
||||||
@ -217,5 +221,9 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
|
|||||||
for i, result in enumerate(search_results, 1):
|
for i, result in enumerate(search_results, 1):
|
||||||
formatted_response += f"\n{i}. {result['title']}: {result['link']}"
|
formatted_response += f"\n{i}. {result['title']}: {result['link']}"
|
||||||
yield formatted_response
|
yield formatted_response
|
||||||
else:
|
elif text_chunk:
|
||||||
yield text_chunk.strip()
|
if is_first:
|
||||||
|
is_first = False
|
||||||
|
yield text_chunk.lstrip()
|
||||||
|
else:
|
||||||
|
yield text_chunk
|
@ -21,8 +21,9 @@ from .helper import format_prompt
|
|||||||
from ..typing import CreateResult, Messages, ImageType
|
from ..typing import CreateResult, Messages, ImageType
|
||||||
from ..errors import MissingRequirementsError
|
from ..errors import MissingRequirementsError
|
||||||
from ..requests.raise_for_status import raise_for_status
|
from ..requests.raise_for_status import raise_for_status
|
||||||
|
from ..providers.helper import format_cookies
|
||||||
from ..requests import get_nodriver
|
from ..requests import get_nodriver
|
||||||
from ..image import to_bytes, is_accepted_format
|
from ..image import ImageResponse, to_bytes, is_accepted_format
|
||||||
from .. import debug
|
from .. import debug
|
||||||
|
|
||||||
class Conversation(BaseConversation):
|
class Conversation(BaseConversation):
|
||||||
@ -70,18 +71,21 @@ class Copilot(AbstractProvider):
|
|||||||
access_token, cookies = asyncio.run(cls.get_access_token_and_cookies(proxy))
|
access_token, cookies = asyncio.run(cls.get_access_token_and_cookies(proxy))
|
||||||
else:
|
else:
|
||||||
access_token = conversation.access_token
|
access_token = conversation.access_token
|
||||||
websocket_url = f"{websocket_url}&acessToken={quote(access_token)}"
|
debug.log(f"Copilot: Access token: {access_token[:7]}...{access_token[-5:]}")
|
||||||
headers = {"Authorization": f"Bearer {access_token}"}
|
debug.log(f"Copilot: Cookies: {';'.join([*cookies])}")
|
||||||
|
websocket_url = f"{websocket_url}&accessToken={quote(access_token)}"
|
||||||
|
headers = {"authorization": f"Bearer {access_token}", "cookie": format_cookies(cookies)}
|
||||||
|
|
||||||
with Session(
|
with Session(
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
proxy=proxy,
|
proxy=proxy,
|
||||||
impersonate="chrome",
|
impersonate="chrome",
|
||||||
headers=headers,
|
headers=headers,
|
||||||
cookies=cookies
|
cookies=cookies,
|
||||||
) as session:
|
) as session:
|
||||||
response = session.get(f"{cls.url}/")
|
response = session.get("https://copilot.microsoft.com/c/api/user")
|
||||||
raise_for_status(response)
|
raise_for_status(response)
|
||||||
|
debug.log(f"Copilot: User: {response.json().get('firstName', 'null')}")
|
||||||
if conversation is None:
|
if conversation is None:
|
||||||
response = session.post(cls.conversation_url)
|
response = session.post(cls.conversation_url)
|
||||||
raise_for_status(response)
|
raise_for_status(response)
|
||||||
@ -119,6 +123,7 @@ class Copilot(AbstractProvider):
|
|||||||
|
|
||||||
is_started = False
|
is_started = False
|
||||||
msg = None
|
msg = None
|
||||||
|
image_prompt: str = None
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
msg = wss.recv()[0]
|
msg = wss.recv()[0]
|
||||||
@ -128,7 +133,11 @@ class Copilot(AbstractProvider):
|
|||||||
if msg.get("event") == "appendText":
|
if msg.get("event") == "appendText":
|
||||||
is_started = True
|
is_started = True
|
||||||
yield msg.get("text")
|
yield msg.get("text")
|
||||||
elif msg.get("event") in ["done", "partCompleted"]:
|
elif msg.get("event") == "generatingImage":
|
||||||
|
image_prompt = msg.get("prompt")
|
||||||
|
elif msg.get("event") == "imageGenerated":
|
||||||
|
yield ImageResponse(msg.get("url"), image_prompt, {"preview": msg.get("thumbnailUrl")})
|
||||||
|
elif msg.get("event") == "done":
|
||||||
break
|
break
|
||||||
if not is_started:
|
if not is_started:
|
||||||
raise RuntimeError(f"Last message: {msg}")
|
raise RuntimeError(f"Last message: {msg}")
|
||||||
@ -152,7 +161,7 @@ class Copilot(AbstractProvider):
|
|||||||
})()
|
})()
|
||||||
""")
|
""")
|
||||||
if access_token is None:
|
if access_token is None:
|
||||||
asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
cookies = {}
|
cookies = {}
|
||||||
for c in await page.send(nodriver.cdp.network.get_cookies([cls.url])):
|
for c in await page.send(nodriver.cdp.network.get_cookies([cls.url])):
|
||||||
cookies[c.name] = c.value
|
cookies[c.name] = c.value
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import aiohttp
|
|
||||||
import random
|
import random
|
||||||
import string
|
import string
|
||||||
import json
|
import json
|
||||||
@ -11,34 +10,24 @@ from aiohttp import ClientSession
|
|||||||
|
|
||||||
from ..typing import AsyncResult, Messages
|
from ..typing import AsyncResult, Messages
|
||||||
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||||
from .helper import format_prompt
|
from ..requests.raise_for_status import raise_for_status
|
||||||
|
|
||||||
|
|
||||||
class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
|
class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
|
||||||
label = "Rubiks AI"
|
label = "Rubiks AI"
|
||||||
url = "https://rubiks.ai"
|
url = "https://rubiks.ai"
|
||||||
api_endpoint = "https://rubiks.ai/search/api.php"
|
api_endpoint = "https://rubiks.ai/search/api/"
|
||||||
working = True
|
working = True
|
||||||
supports_stream = True
|
supports_stream = True
|
||||||
supports_system_message = True
|
supports_system_message = True
|
||||||
supports_message_history = True
|
supports_message_history = True
|
||||||
|
|
||||||
default_model = 'llama-3.1-70b-versatile'
|
default_model = 'gpt-4o-mini'
|
||||||
models = [default_model, 'gpt-4o-mini']
|
models = [default_model, 'gpt-4o', 'o1-mini', 'claude-3.5-sonnet', 'grok-beta', 'gemini-1.5-pro', 'nova-pro']
|
||||||
|
|
||||||
model_aliases = {
|
model_aliases = {
|
||||||
"llama-3.1-70b": "llama-3.1-70b-versatile",
|
"llama-3.1-70b": "llama-3.1-70b-versatile",
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_model(cls, model: str) -> str:
|
|
||||||
if model in cls.models:
|
|
||||||
return model
|
|
||||||
elif model in cls.model_aliases:
|
|
||||||
return cls.model_aliases[model]
|
|
||||||
else:
|
|
||||||
return cls.default_model
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate_mid() -> str:
|
def generate_mid() -> str:
|
||||||
"""
|
"""
|
||||||
@ -70,7 +59,8 @@ class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
|
|||||||
model: str,
|
model: str,
|
||||||
messages: Messages,
|
messages: Messages,
|
||||||
proxy: str = None,
|
proxy: str = None,
|
||||||
websearch: bool = False,
|
web_search: bool = False,
|
||||||
|
temperature: float = 0.6,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> AsyncResult:
|
) -> AsyncResult:
|
||||||
"""
|
"""
|
||||||
@ -80,20 +70,18 @@ class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
|
|||||||
- model (str): The model to use in the request.
|
- model (str): The model to use in the request.
|
||||||
- messages (Messages): The messages to send as a prompt.
|
- messages (Messages): The messages to send as a prompt.
|
||||||
- proxy (str, optional): Proxy URL, if needed.
|
- proxy (str, optional): Proxy URL, if needed.
|
||||||
- websearch (bool, optional): Indicates whether to include search sources in the response. Defaults to False.
|
- web_search (bool, optional): Indicates whether to include search sources in the response. Defaults to False.
|
||||||
"""
|
"""
|
||||||
model = cls.get_model(model)
|
model = cls.get_model(model)
|
||||||
prompt = format_prompt(messages)
|
|
||||||
q_value = prompt
|
|
||||||
mid_value = cls.generate_mid()
|
mid_value = cls.generate_mid()
|
||||||
referer = cls.create_referer(q=q_value, mid=mid_value, model=model)
|
referer = cls.create_referer(q=messages[-1]["content"], mid=mid_value, model=model)
|
||||||
|
|
||||||
url = cls.api_endpoint
|
data = {
|
||||||
params = {
|
"messages": messages,
|
||||||
'q': q_value,
|
"model": model,
|
||||||
'model': model,
|
"search": web_search,
|
||||||
'id': '',
|
"stream": True,
|
||||||
'mid': mid_value
|
"temperature": temperature
|
||||||
}
|
}
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
@ -111,52 +99,34 @@ class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
|
|||||||
'sec-ch-ua-mobile': '?0',
|
'sec-ch-ua-mobile': '?0',
|
||||||
'sec-ch-ua-platform': '"Linux"'
|
'sec-ch-ua-platform': '"Linux"'
|
||||||
}
|
}
|
||||||
|
async with ClientSession() as session:
|
||||||
|
async with session.post(cls.api_endpoint, headers=headers, json=data, proxy=proxy) as response:
|
||||||
|
await raise_for_status(response)
|
||||||
|
|
||||||
try:
|
sources = []
|
||||||
timeout = aiohttp.ClientTimeout(total=None)
|
async for line in response.content:
|
||||||
async with ClientSession(timeout=timeout) as session:
|
decoded_line = line.decode('utf-8').strip()
|
||||||
async with session.get(url, headers=headers, params=params, proxy=proxy) as response:
|
if not decoded_line.startswith('data: '):
|
||||||
if response.status != 200:
|
continue
|
||||||
yield f"Request ended with status code {response.status}"
|
data = decoded_line[6:]
|
||||||
return
|
if data in ('[DONE]', '{"done": ""}'):
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
json_data = json.loads(data)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
assistant_text = ''
|
if 'url' in json_data and 'title' in json_data:
|
||||||
sources = []
|
if web_search:
|
||||||
|
sources.append({'title': json_data['title'], 'url': json_data['url']})
|
||||||
|
|
||||||
async for line in response.content:
|
elif 'choices' in json_data:
|
||||||
decoded_line = line.decode('utf-8').strip()
|
for choice in json_data['choices']:
|
||||||
if not decoded_line.startswith('data: '):
|
delta = choice.get('delta', {})
|
||||||
continue
|
content = delta.get('content', '')
|
||||||
data = decoded_line[6:]
|
if content:
|
||||||
if data in ('[DONE]', '{"done": ""}'):
|
yield content
|
||||||
break
|
|
||||||
try:
|
|
||||||
json_data = json.loads(data)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if 'url' in json_data and 'title' in json_data:
|
if web_search and sources:
|
||||||
if websearch:
|
sources_text = '\n'.join([f"{i+1}. [{s['title']}]: {s['url']}" for i, s in enumerate(sources)])
|
||||||
sources.append({'title': json_data['title'], 'url': json_data['url']})
|
yield f"\n\n**Source:**\n{sources_text}"
|
||||||
|
|
||||||
elif 'choices' in json_data:
|
|
||||||
for choice in json_data['choices']:
|
|
||||||
delta = choice.get('delta', {})
|
|
||||||
content = delta.get('content', '')
|
|
||||||
role = delta.get('role', '')
|
|
||||||
if role == 'assistant':
|
|
||||||
continue
|
|
||||||
assistant_text += content
|
|
||||||
|
|
||||||
if websearch and sources:
|
|
||||||
sources_text = '\n'.join([f"{i+1}. [{s['title']}]: {s['url']}" for i, s in enumerate(sources)])
|
|
||||||
assistant_text += f"\n\n**Source:**\n{sources_text}"
|
|
||||||
|
|
||||||
yield assistant_text
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
yield "The request was cancelled."
|
|
||||||
except aiohttp.ClientError as e:
|
|
||||||
yield f"An error occurred during the request: {e}"
|
|
||||||
except Exception as e:
|
|
||||||
yield f"An unexpected error occurred: {e}"
|
|
65
g4f/Provider/needs_auth/Cerebras.py
Normal file
65
g4f/Provider/needs_auth/Cerebras.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from aiohttp import ClientSession
|
||||||
|
|
||||||
|
from .OpenaiAPI import OpenaiAPI
|
||||||
|
from ...typing import AsyncResult, Messages, Cookies
|
||||||
|
from ...requests.raise_for_status import raise_for_status
|
||||||
|
from ...cookies import get_cookies
|
||||||
|
|
||||||
|
class Cerebras(OpenaiAPI):
|
||||||
|
label = "Cerebras Inference"
|
||||||
|
url = "https://inference.cerebras.ai/"
|
||||||
|
working = True
|
||||||
|
default_model = "llama3.1-70b"
|
||||||
|
fallback_models = [
|
||||||
|
"llama3.1-70b",
|
||||||
|
"llama3.1-8b",
|
||||||
|
]
|
||||||
|
model_aliases = {"llama-3.1-70b": "llama3.1-70b", "llama-3.1-8b": "llama3.1-8b"}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_models(cls, api_key: str = None):
|
||||||
|
if not cls.models:
|
||||||
|
try:
|
||||||
|
headers = {}
|
||||||
|
if api_key:
|
||||||
|
headers["authorization"] = f"Bearer ${api_key}"
|
||||||
|
response = requests.get(f"https://api.cerebras.ai/v1/models", headers=headers)
|
||||||
|
raise_for_status(response)
|
||||||
|
data = response.json()
|
||||||
|
cls.models = [model.get("model") for model in data.get("models")]
|
||||||
|
except Exception:
|
||||||
|
cls.models = cls.fallback_models
|
||||||
|
return cls.models
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create_async_generator(
|
||||||
|
cls,
|
||||||
|
model: str,
|
||||||
|
messages: Messages,
|
||||||
|
api_base: str = "https://api.cerebras.ai/v1",
|
||||||
|
api_key: str = None,
|
||||||
|
cookies: Cookies = None,
|
||||||
|
**kwargs
|
||||||
|
) -> AsyncResult:
|
||||||
|
if api_key is None and cookies is None:
|
||||||
|
cookies = get_cookies(".cerebras.ai")
|
||||||
|
async with ClientSession(cookies=cookies) as session:
|
||||||
|
async with session.get("https://inference.cerebras.ai/api/auth/session") as response:
|
||||||
|
raise_for_status(response)
|
||||||
|
data = await response.json()
|
||||||
|
if data:
|
||||||
|
api_key = data.get("user", {}).get("demoApiKey")
|
||||||
|
async for chunk in super().create_async_generator(
|
||||||
|
model, messages,
|
||||||
|
api_base=api_base,
|
||||||
|
impersonate="chrome",
|
||||||
|
api_key=api_key,
|
||||||
|
headers={
|
||||||
|
"User-Agent": "ex/JS 1.5.0",
|
||||||
|
},
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
yield chunk
|
@ -1,9 +1,12 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from ..base_provider import ProviderModelMixin
|
||||||
from ..Copilot import Copilot
|
from ..Copilot import Copilot
|
||||||
|
|
||||||
class CopilotAccount(Copilot):
|
class CopilotAccount(Copilot, ProviderModelMixin):
|
||||||
needs_auth = True
|
needs_auth = True
|
||||||
parent = "Copilot"
|
parent = "Copilot"
|
||||||
default_model = "Copilot"
|
default_model = "Copilot"
|
||||||
default_vision_model = default_model
|
default_vision_model = default_model
|
||||||
|
models = [default_model]
|
||||||
|
image_models = models
|
28
g4f/Provider/needs_auth/HuggingFace2.py
Normal file
28
g4f/Provider/needs_auth/HuggingFace2.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from .OpenaiAPI import OpenaiAPI
|
||||||
|
from ..HuggingChat import HuggingChat
|
||||||
|
from ...typing import AsyncResult, Messages
|
||||||
|
|
||||||
|
class HuggingFace2(OpenaiAPI):
|
||||||
|
label = "HuggingFace (Inference API)"
|
||||||
|
url = "https://huggingface.co"
|
||||||
|
working = True
|
||||||
|
default_model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
|
||||||
|
default_vision_model = default_model
|
||||||
|
models = [
|
||||||
|
*HuggingChat.models
|
||||||
|
]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_async_generator(
|
||||||
|
cls,
|
||||||
|
model: str,
|
||||||
|
messages: Messages,
|
||||||
|
api_base: str = "https://api-inference.huggingface.co/v1",
|
||||||
|
max_tokens: int = 500,
|
||||||
|
**kwargs
|
||||||
|
) -> AsyncResult:
|
||||||
|
return super().create_async_generator(
|
||||||
|
model, messages, api_base=api_base, max_tokens=max_tokens, **kwargs
|
||||||
|
)
|
@ -34,6 +34,7 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
|
|||||||
stop: Union[str, list[str]] = None,
|
stop: Union[str, list[str]] = None,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
headers: dict = None,
|
headers: dict = None,
|
||||||
|
impersonate: str = None,
|
||||||
extra_data: dict = {},
|
extra_data: dict = {},
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> AsyncResult:
|
) -> AsyncResult:
|
||||||
@ -55,7 +56,8 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
|
|||||||
async with StreamSession(
|
async with StreamSession(
|
||||||
proxies={"all": proxy},
|
proxies={"all": proxy},
|
||||||
headers=cls.get_headers(stream, api_key, headers),
|
headers=cls.get_headers(stream, api_key, headers),
|
||||||
timeout=timeout
|
timeout=timeout,
|
||||||
|
impersonate=impersonate,
|
||||||
) as session:
|
) as session:
|
||||||
data = filter_none(
|
data = filter_none(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from .gigachat import *
|
from .gigachat import *
|
||||||
|
|
||||||
from .BingCreateImages import BingCreateImages
|
from .BingCreateImages import BingCreateImages
|
||||||
|
from .Cerebras import Cerebras
|
||||||
from .CopilotAccount import CopilotAccount
|
from .CopilotAccount import CopilotAccount
|
||||||
from .DeepInfra import DeepInfra
|
from .DeepInfra import DeepInfra
|
||||||
from .DeepInfraImage import DeepInfraImage
|
from .DeepInfraImage import DeepInfraImage
|
||||||
@ -8,6 +9,7 @@ from .Gemini import Gemini
|
|||||||
from .GeminiPro import GeminiPro
|
from .GeminiPro import GeminiPro
|
||||||
from .Groq import Groq
|
from .Groq import Groq
|
||||||
from .HuggingFace import HuggingFace
|
from .HuggingFace import HuggingFace
|
||||||
|
from .HuggingFace2 import HuggingFace2
|
||||||
from .MetaAI import MetaAI
|
from .MetaAI import MetaAI
|
||||||
from .MetaAIAccount import MetaAIAccount
|
from .MetaAIAccount import MetaAIAccount
|
||||||
from .OpenaiAPI import OpenaiAPI
|
from .OpenaiAPI import OpenaiAPI
|
||||||
|
@ -128,6 +128,10 @@
|
|||||||
<label for="BingCreateImages-api_key" class="label" title="">Microsoft Designer in Bing:</label>
|
<label for="BingCreateImages-api_key" class="label" title="">Microsoft Designer in Bing:</label>
|
||||||
<textarea id="BingCreateImages-api_key" name="BingCreateImages[api_key]" placeholder=""_U" cookie"></textarea>
|
<textarea id="BingCreateImages-api_key" name="BingCreateImages[api_key]" placeholder=""_U" cookie"></textarea>
|
||||||
</div>
|
</div>
|
||||||
|
<div class="field box">
|
||||||
|
<label for="Cerebras-api_key" class="label" title="">Cerebras Inference:</label>
|
||||||
|
<textarea id="Cerebras-api_key" name="Cerebras[api_key]" placeholder="api_key"></textarea>
|
||||||
|
</div>
|
||||||
<div class="field box">
|
<div class="field box">
|
||||||
<label for="DeepInfra-api_key" class="label" title="">DeepInfra:</label>
|
<label for="DeepInfra-api_key" class="label" title="">DeepInfra:</label>
|
||||||
<textarea id="DeepInfra-api_key" name="DeepInfra[api_key]" class="DeepInfraImage-api_key" placeholder="api_key"></textarea>
|
<textarea id="DeepInfra-api_key" name="DeepInfra[api_key]" class="DeepInfraImage-api_key" placeholder="api_key"></textarea>
|
||||||
@ -142,7 +146,7 @@
|
|||||||
</div>
|
</div>
|
||||||
<div class="field box">
|
<div class="field box">
|
||||||
<label for="HuggingFace-api_key" class="label" title="">HuggingFace:</label>
|
<label for="HuggingFace-api_key" class="label" title="">HuggingFace:</label>
|
||||||
<textarea id="HuggingFace-api_key" name="HuggingFace[api_key]" placeholder="api_key"></textarea>
|
<textarea id="HuggingFace-api_key" name="HuggingFace[api_key]" class="HuggingFace2-api_key" placeholder="api_key"></textarea>
|
||||||
</div>
|
</div>
|
||||||
<div class="field box">
|
<div class="field box">
|
||||||
<label for="Openai-api_key" class="label" title="">OpenAI API:</label>
|
<label for="Openai-api_key" class="label" title="">OpenAI API:</label>
|
||||||
@ -192,7 +196,7 @@
|
|||||||
<div class="stop_generating stop_generating-hidden">
|
<div class="stop_generating stop_generating-hidden">
|
||||||
<button id="cancelButton">
|
<button id="cancelButton">
|
||||||
<span>Stop Generating</span>
|
<span>Stop Generating</span>
|
||||||
<i class="fa-regular fa-stop"></i>
|
<i class="fa-solid fa-stop"></i>
|
||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
<div class="regenerate">
|
<div class="regenerate">
|
||||||
|
@ -512,9 +512,7 @@ body {
|
|||||||
|
|
||||||
@media only screen and (min-width: 40em) {
|
@media only screen and (min-width: 40em) {
|
||||||
.stop_generating {
|
.stop_generating {
|
||||||
left: 50%;
|
right: 4px;
|
||||||
transform: translateX(-50%);
|
|
||||||
right: auto;
|
|
||||||
}
|
}
|
||||||
.toolbar .regenerate span {
|
.toolbar .regenerate span {
|
||||||
display: block;
|
display: block;
|
||||||
|
@ -215,7 +215,6 @@ const register_message_buttons = async () => {
|
|||||||
const message_el = el.parentElement.parentElement.parentElement;
|
const message_el = el.parentElement.parentElement.parentElement;
|
||||||
el.classList.add("clicked");
|
el.classList.add("clicked");
|
||||||
setTimeout(() => el.classList.remove("clicked"), 1000);
|
setTimeout(() => el.classList.remove("clicked"), 1000);
|
||||||
await hide_message(window.conversation_id, message_el.dataset.index);
|
|
||||||
await ask_gpt(message_el.dataset.index, get_message_id());
|
await ask_gpt(message_el.dataset.index, get_message_id());
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -317,6 +316,7 @@ async function remove_cancel_button() {
|
|||||||
|
|
||||||
regenerate.addEventListener("click", async () => {
|
regenerate.addEventListener("click", async () => {
|
||||||
regenerate.classList.add("regenerate-hidden");
|
regenerate.classList.add("regenerate-hidden");
|
||||||
|
setTimeout(()=>regenerate.classList.remove("regenerate-hidden"), 3000);
|
||||||
stop_generating.classList.remove("stop_generating-hidden");
|
stop_generating.classList.remove("stop_generating-hidden");
|
||||||
await hide_message(window.conversation_id);
|
await hide_message(window.conversation_id);
|
||||||
await ask_gpt(-1, get_message_id());
|
await ask_gpt(-1, get_message_id());
|
||||||
@ -383,12 +383,12 @@ const prepare_messages = (messages, message_index = -1) => {
|
|||||||
return new_messages;
|
return new_messages;
|
||||||
}
|
}
|
||||||
|
|
||||||
async function add_message_chunk(message, message_index) {
|
async function add_message_chunk(message, message_id) {
|
||||||
content_map = content_storage[message_index];
|
content_map = content_storage[message_id];
|
||||||
if (message.type == "conversation") {
|
if (message.type == "conversation") {
|
||||||
console.info("Conversation used:", message.conversation)
|
console.info("Conversation used:", message.conversation)
|
||||||
} else if (message.type == "provider") {
|
} else if (message.type == "provider") {
|
||||||
provider_storage[message_index] = message.provider;
|
provider_storage[message_id] = message.provider;
|
||||||
content_map.content.querySelector('.provider').innerHTML = `
|
content_map.content.querySelector('.provider').innerHTML = `
|
||||||
<a href="${message.provider.url}" target="_blank">
|
<a href="${message.provider.url}" target="_blank">
|
||||||
${message.provider.label ? message.provider.label : message.provider.name}
|
${message.provider.label ? message.provider.label : message.provider.name}
|
||||||
@ -398,7 +398,7 @@ async function add_message_chunk(message, message_index) {
|
|||||||
} else if (message.type == "message") {
|
} else if (message.type == "message") {
|
||||||
console.error(message.message)
|
console.error(message.message)
|
||||||
} else if (message.type == "error") {
|
} else if (message.type == "error") {
|
||||||
error_storage[message_index] = message.error
|
error_storage[message_id] = message.error
|
||||||
console.error(message.error);
|
console.error(message.error);
|
||||||
content_map.inner.innerHTML += `<p><strong>An error occured:</strong> ${message.error}</p>`;
|
content_map.inner.innerHTML += `<p><strong>An error occured:</strong> ${message.error}</p>`;
|
||||||
let p = document.createElement("p");
|
let p = document.createElement("p");
|
||||||
@ -407,8 +407,8 @@ async function add_message_chunk(message, message_index) {
|
|||||||
} else if (message.type == "preview") {
|
} else if (message.type == "preview") {
|
||||||
content_map.inner.innerHTML = markdown_render(message.preview);
|
content_map.inner.innerHTML = markdown_render(message.preview);
|
||||||
} else if (message.type == "content") {
|
} else if (message.type == "content") {
|
||||||
message_storage[message_index] += message.content;
|
message_storage[message_id] += message.content;
|
||||||
html = markdown_render(message_storage[message_index]);
|
html = markdown_render(message_storage[message_id]);
|
||||||
let lastElement, lastIndex = null;
|
let lastElement, lastIndex = null;
|
||||||
for (element of ['</p>', '</code></pre>', '</p>\n</li>\n</ol>', '</li>\n</ol>', '</li>\n</ul>']) {
|
for (element of ['</p>', '</code></pre>', '</p>\n</li>\n</ol>', '</li>\n</ol>', '</li>\n</ul>']) {
|
||||||
const index = html.lastIndexOf(element)
|
const index = html.lastIndexOf(element)
|
||||||
@ -421,7 +421,7 @@ async function add_message_chunk(message, message_index) {
|
|||||||
html = html.substring(0, lastIndex) + '<span class="cursor"></span>' + lastElement;
|
html = html.substring(0, lastIndex) + '<span class="cursor"></span>' + lastElement;
|
||||||
}
|
}
|
||||||
content_map.inner.innerHTML = html;
|
content_map.inner.innerHTML = html;
|
||||||
content_map.count.innerText = count_words_and_tokens(message_storage[message_index], provider_storage[message_index]?.model);
|
content_map.count.innerText = count_words_and_tokens(message_storage[message_id], provider_storage[message_id]?.model);
|
||||||
highlight(content_map.inner);
|
highlight(content_map.inner);
|
||||||
} else if (message.type == "log") {
|
} else if (message.type == "log") {
|
||||||
let p = document.createElement("p");
|
let p = document.createElement("p");
|
||||||
@ -453,7 +453,7 @@ const ask_gpt = async (message_index = -1, message_id) => {
|
|||||||
let total_messages = messages.length;
|
let total_messages = messages.length;
|
||||||
messages = prepare_messages(messages, message_index);
|
messages = prepare_messages(messages, message_index);
|
||||||
message_index = total_messages
|
message_index = total_messages
|
||||||
message_storage[message_index] = "";
|
message_storage[message_id] = "";
|
||||||
stop_generating.classList.remove(".stop_generating-hidden");
|
stop_generating.classList.remove(".stop_generating-hidden");
|
||||||
|
|
||||||
message_box.scrollTop = message_box.scrollHeight;
|
message_box.scrollTop = message_box.scrollHeight;
|
||||||
@ -477,10 +477,10 @@ const ask_gpt = async (message_index = -1, message_id) => {
|
|||||||
</div>
|
</div>
|
||||||
`;
|
`;
|
||||||
|
|
||||||
controller_storage[message_index] = new AbortController();
|
controller_storage[message_id] = new AbortController();
|
||||||
|
|
||||||
let content_el = document.getElementById(`gpt_${message_id}`)
|
let content_el = document.getElementById(`gpt_${message_id}`)
|
||||||
let content_map = content_storage[message_index] = {
|
let content_map = content_storage[message_id] = {
|
||||||
content: content_el,
|
content: content_el,
|
||||||
inner: content_el.querySelector('.content_inner'),
|
inner: content_el.querySelector('.content_inner'),
|
||||||
count: content_el.querySelector('.count'),
|
count: content_el.querySelector('.count'),
|
||||||
@ -492,12 +492,7 @@ const ask_gpt = async (message_index = -1, message_id) => {
|
|||||||
const file = input && input.files.length > 0 ? input.files[0] : null;
|
const file = input && input.files.length > 0 ? input.files[0] : null;
|
||||||
const provider = providerSelect.options[providerSelect.selectedIndex].value;
|
const provider = providerSelect.options[providerSelect.selectedIndex].value;
|
||||||
const auto_continue = document.getElementById("auto_continue")?.checked;
|
const auto_continue = document.getElementById("auto_continue")?.checked;
|
||||||
let api_key = null;
|
let api_key = get_api_key_by_provider(provider);
|
||||||
if (provider) {
|
|
||||||
api_key = document.getElementById(`${provider}-api_key`)?.value || null;
|
|
||||||
if (api_key == null)
|
|
||||||
api_key = document.querySelector(`.${provider}-api_key`)?.value || null;
|
|
||||||
}
|
|
||||||
await api("conversation", {
|
await api("conversation", {
|
||||||
id: message_id,
|
id: message_id,
|
||||||
conversation_id: window.conversation_id,
|
conversation_id: window.conversation_id,
|
||||||
@ -506,10 +501,10 @@ const ask_gpt = async (message_index = -1, message_id) => {
|
|||||||
provider: provider,
|
provider: provider,
|
||||||
messages: messages,
|
messages: messages,
|
||||||
auto_continue: auto_continue,
|
auto_continue: auto_continue,
|
||||||
api_key: api_key
|
api_key: api_key,
|
||||||
}, file, message_index);
|
}, file, message_id);
|
||||||
if (!error_storage[message_index]) {
|
if (!error_storage[message_id]) {
|
||||||
html = markdown_render(message_storage[message_index]);
|
html = markdown_render(message_storage[message_id]);
|
||||||
content_map.inner.innerHTML = html;
|
content_map.inner.innerHTML = html;
|
||||||
highlight(content_map.inner);
|
highlight(content_map.inner);
|
||||||
|
|
||||||
@ -520,14 +515,14 @@ const ask_gpt = async (message_index = -1, message_id) => {
|
|||||||
} catch (e) {
|
} catch (e) {
|
||||||
console.error(e);
|
console.error(e);
|
||||||
if (e.name != "AbortError") {
|
if (e.name != "AbortError") {
|
||||||
error_storage[message_index] = true;
|
error_storage[message_id] = true;
|
||||||
content_map.inner.innerHTML += `<p><strong>An error occured:</strong> ${e}</p>`;
|
content_map.inner.innerHTML += `<p><strong>An error occured:</strong> ${e}</p>`;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
delete controller_storage[message_index];
|
delete controller_storage[message_id];
|
||||||
if (!error_storage[message_index] && message_storage[message_index]) {
|
if (!error_storage[message_id] && message_storage[message_id]) {
|
||||||
const message_provider = message_index in provider_storage ? provider_storage[message_index] : null;
|
const message_provider = message_id in provider_storage ? provider_storage[message_id] : null;
|
||||||
await add_message(window.conversation_id, "assistant", message_storage[message_index], message_provider);
|
await add_message(window.conversation_id, "assistant", message_storage[message_id], message_provider);
|
||||||
await safe_load_conversation(window.conversation_id);
|
await safe_load_conversation(window.conversation_id);
|
||||||
} else {
|
} else {
|
||||||
let cursorDiv = message_box.querySelector(".cursor");
|
let cursorDiv = message_box.querySelector(".cursor");
|
||||||
@ -1156,7 +1151,7 @@ async function on_api() {
|
|||||||
evt.preventDefault();
|
evt.preventDefault();
|
||||||
console.log("pressed enter");
|
console.log("pressed enter");
|
||||||
prompt_lock = true;
|
prompt_lock = true;
|
||||||
setTimeout(()=>prompt_lock=false, 3);
|
setTimeout(()=>prompt_lock=false, 3000);
|
||||||
await handle_ask();
|
await handle_ask();
|
||||||
} else {
|
} else {
|
||||||
messageInput.style.removeProperty("height");
|
messageInput.style.removeProperty("height");
|
||||||
@ -1167,7 +1162,7 @@ async function on_api() {
|
|||||||
console.log("clicked send");
|
console.log("clicked send");
|
||||||
if (prompt_lock) return;
|
if (prompt_lock) return;
|
||||||
prompt_lock = true;
|
prompt_lock = true;
|
||||||
setTimeout(()=>prompt_lock=false, 3);
|
setTimeout(()=>prompt_lock=false, 3000);
|
||||||
await handle_ask();
|
await handle_ask();
|
||||||
});
|
});
|
||||||
messageInput.focus();
|
messageInput.focus();
|
||||||
@ -1189,8 +1184,8 @@ async function on_api() {
|
|||||||
providerSelect.appendChild(option);
|
providerSelect.appendChild(option);
|
||||||
})
|
})
|
||||||
|
|
||||||
await load_provider_models(appStorage.getItem("provider"));
|
|
||||||
await load_settings_storage()
|
await load_settings_storage()
|
||||||
|
await load_provider_models(appStorage.getItem("provider"));
|
||||||
|
|
||||||
const hide_systemPrompt = document.getElementById("hide-systemPrompt")
|
const hide_systemPrompt = document.getElementById("hide-systemPrompt")
|
||||||
const slide_systemPrompt_icon = document.querySelector(".slide-systemPrompt i");
|
const slide_systemPrompt_icon = document.querySelector(".slide-systemPrompt i");
|
||||||
@ -1316,7 +1311,7 @@ function get_selected_model() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async function api(ressource, args=null, file=null, message_index=null) {
|
async function api(ressource, args=null, file=null, message_id=null) {
|
||||||
if (window?.pywebview) {
|
if (window?.pywebview) {
|
||||||
if (args !== null) {
|
if (args !== null) {
|
||||||
if (ressource == "models") {
|
if (ressource == "models") {
|
||||||
@ -1326,15 +1321,19 @@ async function api(ressource, args=null, file=null, message_index=null) {
|
|||||||
}
|
}
|
||||||
return pywebview.api[`get_${ressource}`]();
|
return pywebview.api[`get_${ressource}`]();
|
||||||
}
|
}
|
||||||
|
let api_key;
|
||||||
if (ressource == "models" && args) {
|
if (ressource == "models" && args) {
|
||||||
|
api_key = get_api_key_by_provider(args);
|
||||||
ressource = `${ressource}/${args}`;
|
ressource = `${ressource}/${args}`;
|
||||||
}
|
}
|
||||||
const url = `/backend-api/v2/${ressource}`;
|
const url = `/backend-api/v2/${ressource}`;
|
||||||
|
const headers = {};
|
||||||
|
if (api_key) {
|
||||||
|
headers.authorization = `Bearer ${api_key}`;
|
||||||
|
}
|
||||||
if (ressource == "conversation") {
|
if (ressource == "conversation") {
|
||||||
let body = JSON.stringify(args);
|
let body = JSON.stringify(args);
|
||||||
const headers = {
|
headers.accept = 'text/event-stream';
|
||||||
accept: 'text/event-stream'
|
|
||||||
}
|
|
||||||
if (file !== null) {
|
if (file !== null) {
|
||||||
const formData = new FormData();
|
const formData = new FormData();
|
||||||
formData.append('file', file);
|
formData.append('file', file);
|
||||||
@ -1345,17 +1344,17 @@ async function api(ressource, args=null, file=null, message_index=null) {
|
|||||||
}
|
}
|
||||||
response = await fetch(url, {
|
response = await fetch(url, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
signal: controller_storage[message_index].signal,
|
signal: controller_storage[message_id].signal,
|
||||||
headers: headers,
|
headers: headers,
|
||||||
body: body
|
body: body,
|
||||||
});
|
});
|
||||||
return read_response(response, message_index);
|
return read_response(response, message_id);
|
||||||
}
|
}
|
||||||
response = await fetch(url);
|
response = await fetch(url, {headers: headers});
|
||||||
return await response.json();
|
return await response.json();
|
||||||
}
|
}
|
||||||
|
|
||||||
async function read_response(response, message_index) {
|
async function read_response(response, message_id) {
|
||||||
const reader = response.body.pipeThrough(new TextDecoderStream()).getReader();
|
const reader = response.body.pipeThrough(new TextDecoderStream()).getReader();
|
||||||
let buffer = ""
|
let buffer = ""
|
||||||
while (true) {
|
while (true) {
|
||||||
@ -1368,7 +1367,7 @@ async function read_response(response, message_index) {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
add_message_chunk(JSON.parse(buffer + line), message_index);
|
add_message_chunk(JSON.parse(buffer + line), message_id);
|
||||||
buffer = "";
|
buffer = "";
|
||||||
} catch {
|
} catch {
|
||||||
buffer += line
|
buffer += line
|
||||||
@ -1377,6 +1376,16 @@ async function read_response(response, message_index) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function get_api_key_by_provider(provider) {
|
||||||
|
let api_key = null;
|
||||||
|
if (provider) {
|
||||||
|
api_key = document.getElementById(`${provider}-api_key`)?.value || null;
|
||||||
|
if (api_key == null)
|
||||||
|
api_key = document.querySelector(`.${provider}-api_key`)?.value || null;
|
||||||
|
}
|
||||||
|
return api_key;
|
||||||
|
}
|
||||||
|
|
||||||
async function load_provider_models(providerIndex=null) {
|
async function load_provider_models(providerIndex=null) {
|
||||||
if (!providerIndex) {
|
if (!providerIndex) {
|
||||||
providerIndex = providerSelect.selectedIndex;
|
providerIndex = providerSelect.selectedIndex;
|
||||||
|
@ -38,10 +38,11 @@ class Api:
|
|||||||
return models._all_models
|
return models._all_models
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_provider_models(provider: str) -> list[dict]:
|
def get_provider_models(provider: str, api_key: str = None) -> list[dict]:
|
||||||
if provider in __map__:
|
if provider in __map__:
|
||||||
provider: ProviderType = __map__[provider]
|
provider: ProviderType = __map__[provider]
|
||||||
if issubclass(provider, ProviderModelMixin):
|
if issubclass(provider, ProviderModelMixin):
|
||||||
|
models = provider.get_models() if api_key is None else provider.get_models(api_key=api_key)
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
"model": model,
|
"model": model,
|
||||||
@ -49,7 +50,7 @@ class Api:
|
|||||||
"vision": getattr(provider, "default_vision_model", None) == model or model in getattr(provider, "vision_models", []),
|
"vision": getattr(provider, "default_vision_model", None) == model or model in getattr(provider, "vision_models", []),
|
||||||
"image": model in getattr(provider, "image_models", []),
|
"image": model in getattr(provider, "image_models", []),
|
||||||
}
|
}
|
||||||
for model in provider.get_models()
|
for model in models
|
||||||
]
|
]
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
@ -94,7 +94,8 @@ class Backend_Api(Api):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_provider_models(self, provider: str):
|
def get_provider_models(self, provider: str):
|
||||||
models = super().get_provider_models(provider)
|
api_key = None if request.authorization is None else request.authorization.token
|
||||||
|
models = super().get_provider_models(provider, api_key)
|
||||||
if models is None:
|
if models is None:
|
||||||
return 404, "Provider not found"
|
return 404, "Provider not found"
|
||||||
return models
|
return models
|
||||||
|
@ -11,7 +11,9 @@ class CloudflareError(ResponseStatusError):
|
|||||||
...
|
...
|
||||||
|
|
||||||
def is_cloudflare(text: str) -> bool:
|
def is_cloudflare(text: str) -> bool:
|
||||||
if "<title>Attention Required! | Cloudflare</title>" in text or 'id="cf-cloudflare-status"' in text:
|
if "Generated by cloudfront" in text:
|
||||||
|
return True
|
||||||
|
elif "<title>Attention Required! | Cloudflare</title>" in text or 'id="cf-cloudflare-status"' in text:
|
||||||
return True
|
return True
|
||||||
return '<div id="cf-please-wait">' in text or "<title>Just a moment...</title>" in text
|
return '<div id="cf-please-wait">' in text or "<title>Just a moment...</title>" in text
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user