mirror of
https://github.com/xtekky/gpt4free.git
synced 2024-12-23 11:02:40 +03:00
Improved ignored providers support,
Add get_models to OpenaiAPI, HuggingFace and Groq Add xAI provider
This commit is contained in:
parent
3b7b79f5ba
commit
ff66df1486
@ -34,6 +34,7 @@ class Blackbox2(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
model: str,
|
||||
messages: Messages,
|
||||
proxy: str = None,
|
||||
prompt: str = None,
|
||||
max_retries: int = 3,
|
||||
delay: int = 1,
|
||||
**kwargs
|
||||
|
@ -123,11 +123,12 @@ class Copilot(AbstractProvider, ProviderModelMixin):
|
||||
prompt = format_prompt(messages)
|
||||
if len(prompt) > 10000:
|
||||
if len(messages) > 6:
|
||||
prompt = format_prompt(messages[:3]+messages[-3:])
|
||||
elif len(messages) > 2:
|
||||
prompt = format_prompt(messages[:2]+messages[-1:])
|
||||
prompt = format_prompt(messages[:3] + messages[-3:])
|
||||
if len(prompt) > 10000:
|
||||
prompt = messages[-1]["content"]
|
||||
if len(messages) > 2:
|
||||
prompt = format_prompt(messages[:2] + messages[-1:])
|
||||
if len(prompt) > 10000:
|
||||
prompt = messages[-1]["content"]
|
||||
debug.log(f"Copilot: Trim messages to: {len(prompt)}")
|
||||
debug.log(f"Copilot: Created conversation: {conversation_id}")
|
||||
else:
|
||||
|
@ -4,38 +4,26 @@ import json
|
||||
from aiohttp import ClientSession
|
||||
|
||||
from ..typing import AsyncResult, Messages
|
||||
from ..requests.raise_for_status import raise_for_status
|
||||
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from .helper import format_prompt
|
||||
|
||||
|
||||
class DarkAI(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
url = "https://darkai.foundation/chat"
|
||||
api_endpoint = "https://darkai.foundation/chat"
|
||||
working = True
|
||||
supports_stream = True
|
||||
supports_system_message = True
|
||||
supports_message_history = True
|
||||
|
||||
|
||||
default_model = 'llama-3-70b'
|
||||
models = [
|
||||
'gpt-4o', # Uncensored
|
||||
'gpt-3.5-turbo', # Uncensored
|
||||
default_model,
|
||||
]
|
||||
|
||||
model_aliases = {
|
||||
"llama-3.1-70b": "llama-3-70b",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_model(cls, model: str) -> str:
|
||||
if model in cls.models:
|
||||
return model
|
||||
elif model in cls.model_aliases:
|
||||
return cls.model_aliases[model]
|
||||
else:
|
||||
return cls.default_model
|
||||
|
||||
@classmethod
|
||||
async def create_async_generator(
|
||||
cls,
|
||||
@ -45,7 +33,7 @@ class DarkAI(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
model = cls.get_model(model)
|
||||
|
||||
|
||||
headers = {
|
||||
"accept": "text/event-stream",
|
||||
"content-type": "application/json",
|
||||
@ -58,24 +46,24 @@ class DarkAI(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
"model": model,
|
||||
}
|
||||
async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response:
|
||||
response.raise_for_status()
|
||||
full_text = ""
|
||||
async for chunk in response.content:
|
||||
if chunk:
|
||||
await raise_for_status(response)
|
||||
first = True
|
||||
async for line in response.content:
|
||||
if line:
|
||||
try:
|
||||
chunk_str = chunk.decode().strip()
|
||||
if chunk_str.startswith('data: '):
|
||||
chunk_data = json.loads(chunk_str[6:])
|
||||
line_str = line.decode().strip()
|
||||
if line_str.startswith('data: '):
|
||||
chunk_data = json.loads(line_str[6:])
|
||||
if chunk_data['event'] == 'text-chunk':
|
||||
full_text += chunk_data['data']['text']
|
||||
chunk = chunk_data['data']['text']
|
||||
if first:
|
||||
chunk = chunk.lstrip()
|
||||
if chunk:
|
||||
first = False
|
||||
yield chunk
|
||||
elif chunk_data['event'] == 'stream-end':
|
||||
if full_text:
|
||||
yield full_text.strip()
|
||||
return
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if full_text:
|
||||
yield full_text.strip()
|
||||
pass
|
@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import requests
|
||||
from aiohttp import ClientSession
|
||||
|
||||
from .OpenaiAPI import OpenaiAPI
|
||||
@ -11,35 +10,21 @@ from ...cookies import get_cookies
|
||||
class Cerebras(OpenaiAPI):
|
||||
label = "Cerebras Inference"
|
||||
url = "https://inference.cerebras.ai/"
|
||||
api_base = "https://api.cerebras.ai/v1"
|
||||
working = True
|
||||
default_model = "llama3.1-70b"
|
||||
fallback_models = [
|
||||
models = [
|
||||
"llama3.1-70b",
|
||||
"llama3.1-8b",
|
||||
]
|
||||
model_aliases = {"llama-3.1-70b": "llama3.1-70b", "llama-3.1-8b": "llama3.1-8b"}
|
||||
|
||||
@classmethod
|
||||
def get_models(cls, api_key: str = None):
|
||||
if not cls.models:
|
||||
try:
|
||||
headers = {}
|
||||
if api_key:
|
||||
headers["authorization"] = f"Bearer ${api_key}"
|
||||
response = requests.get(f"https://api.cerebras.ai/v1/models", headers=headers)
|
||||
raise_for_status(response)
|
||||
data = response.json()
|
||||
cls.models = [model.get("model") for model in data.get("models")]
|
||||
except Exception:
|
||||
cls.models = cls.fallback_models
|
||||
return cls.models
|
||||
|
||||
@classmethod
|
||||
async def create_async_generator(
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
api_base: str = "https://api.cerebras.ai/v1",
|
||||
api_base: str = api_base,
|
||||
api_key: str = None,
|
||||
cookies: Cookies = None,
|
||||
**kwargs
|
||||
@ -62,4 +47,4 @@ class Cerebras(OpenaiAPI):
|
||||
},
|
||||
**kwargs
|
||||
):
|
||||
yield chunk
|
||||
yield chunk
|
@ -6,9 +6,10 @@ from ...typing import AsyncResult, Messages
|
||||
class Groq(OpenaiAPI):
|
||||
label = "Groq"
|
||||
url = "https://console.groq.com/playground"
|
||||
api_base = "https://api.groq.com/openai/v1"
|
||||
working = True
|
||||
default_model = "mixtral-8x7b-32768"
|
||||
models = [
|
||||
fallback_models = [
|
||||
"distil-whisper-large-v3-en",
|
||||
"gemma2-9b-it",
|
||||
"gemma-7b-it",
|
||||
@ -35,9 +36,9 @@ class Groq(OpenaiAPI):
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
api_base: str = "https://api.groq.com/openai/v1",
|
||||
api_base: str = api_base,
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
return super().create_async_generator(
|
||||
model, messages, api_base=api_base, **kwargs
|
||||
)
|
||||
)
|
@ -6,8 +6,8 @@ import random
|
||||
import requests
|
||||
|
||||
from ...typing import AsyncResult, Messages
|
||||
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
|
||||
from ...errors import ModelNotFoundError, ModelNotSupportedError
|
||||
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, format_prompt
|
||||
from ...errors import ModelNotFoundError, ModelNotSupportedError, ResponseError
|
||||
from ...requests import StreamSession, raise_for_status
|
||||
from ...image import ImageResponse
|
||||
|
||||
@ -28,9 +28,11 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
cls.models = [model["id"] for model in requests.get(url).json()]
|
||||
cls.models.append("meta-llama/Llama-3.2-11B-Vision-Instruct")
|
||||
cls.models.append("nvidia/Llama-3.1-Nemotron-70B-Instruct-HF")
|
||||
cls.models.sort()
|
||||
if not cls.image_models:
|
||||
url = "https://huggingface.co/api/models?pipeline_tag=text-to-image"
|
||||
cls.image_models = [model["id"] for model in requests.get(url).json() if model["trendingScore"] >= 20]
|
||||
cls.image_models.sort()
|
||||
cls.models.extend(cls.image_models)
|
||||
return cls.models
|
||||
|
||||
@ -89,19 +91,27 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
) as session:
|
||||
if payload is None:
|
||||
async with session.get(f"https://huggingface.co/api/models/{model}") as response:
|
||||
await raise_for_status(response)
|
||||
model_data = await response.json()
|
||||
if "config" in model_data and "tokenizer_config" in model_data["config"] and "eos_token" in model_data["config"]["tokenizer_config"]:
|
||||
model_type = None
|
||||
if "config" in model_data and "model_type" in model_data["config"]:
|
||||
model_type = model_data["config"]["model_type"]
|
||||
if model_type in ("gpt2", "gpt_neo", "gemma", "gemma2"):
|
||||
inputs = format_prompt(messages)
|
||||
elif "config" in model_data and "tokenizer_config" in model_data["config"] and "eos_token" in model_data["config"]["tokenizer_config"]:
|
||||
eos_token = model_data["config"]["tokenizer_config"]["eos_token"]
|
||||
if eos_token == "</s>":
|
||||
inputs = format_prompt_mistral(messages)
|
||||
if eos_token in ("<|endoftext|>", "<eos>", "</s>"):
|
||||
inputs = format_prompt_custom(messages, eos_token)
|
||||
elif eos_token == "<|im_end|>":
|
||||
inputs = format_prompt_qwen(messages)
|
||||
elif eos_token == "<|eot_id|>":
|
||||
inputs = format_prompt_llama(messages)
|
||||
else:
|
||||
inputs = format_prompt(messages)
|
||||
inputs = format_prompt_default(messages)
|
||||
else:
|
||||
inputs = format_prompt(messages)
|
||||
inputs = format_prompt_default(messages)
|
||||
if model_type == "gpt2" and max_new_tokens >= 1024:
|
||||
params["max_new_tokens"] = 512
|
||||
payload = {"inputs": inputs, "parameters": params, "stream": stream}
|
||||
|
||||
async with session.post(f"{api_base.rstrip('/')}/models/{model}", json=payload) as response:
|
||||
@ -113,6 +123,8 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
async for line in response.iter_lines():
|
||||
if line.startswith(b"data:"):
|
||||
data = json.loads(line[5:])
|
||||
if "error" in data:
|
||||
raise ResponseError(data["error"])
|
||||
if not data["token"]["special"]:
|
||||
chunk = data["token"]["text"]
|
||||
if first:
|
||||
@ -128,7 +140,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
else:
|
||||
yield (await response.json())[0]["generated_text"].strip()
|
||||
|
||||
def format_prompt(messages: Messages) -> str:
|
||||
def format_prompt_default(messages: Messages) -> str:
|
||||
system_messages = [message["content"] for message in messages if message["role"] == "system"]
|
||||
question = " ".join([messages[-1]["content"], *system_messages])
|
||||
history = "".join([
|
||||
@ -146,9 +158,9 @@ def format_prompt_qwen(messages: Messages) -> str:
|
||||
def format_prompt_llama(messages: Messages) -> str:
|
||||
return "<|begin_of_text|>" + "".join([
|
||||
f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n{message['content']}\n<|eot_id|>\n" for message in messages
|
||||
]) + "<|start_header_id|>assistant<|end_header_id|>\\n\\n"
|
||||
|
||||
def format_prompt_mistral(messages: Messages) -> str:
|
||||
]) + "<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
|
||||
def format_prompt_custom(messages: Messages, end_token: str = "</s>") -> str:
|
||||
return "".join([
|
||||
f"<|{message['role']}|>\n{message['content']}'</s>\n" for message in messages
|
||||
f"<|{message['role']}|>\n{message['content']}{end_token}\n" for message in messages
|
||||
]) + "<|assistant|>\n"
|
@ -7,6 +7,7 @@ from ...typing import AsyncResult, Messages
|
||||
class HuggingFaceAPI(OpenaiAPI):
|
||||
label = "HuggingFace (Inference API)"
|
||||
url = "https://api-inference.huggingface.co"
|
||||
api_base = "https://api-inference.huggingface.co/v1"
|
||||
working = True
|
||||
default_model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
|
||||
default_vision_model = default_model
|
||||
@ -19,7 +20,7 @@ class HuggingFaceAPI(OpenaiAPI):
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
api_base: str = "https://api-inference.huggingface.co/v1",
|
||||
api_base: str = api_base,
|
||||
max_tokens: int = 500,
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
|
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import requests
|
||||
|
||||
from ..helper import filter_none
|
||||
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, FinishReason
|
||||
@ -8,15 +9,35 @@ from ...typing import Union, Optional, AsyncResult, Messages, ImagesType
|
||||
from ...requests import StreamSession, raise_for_status
|
||||
from ...errors import MissingAuthError, ResponseError
|
||||
from ...image import to_data_uri
|
||||
from ... import debug
|
||||
|
||||
class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
label = "OpenAI API"
|
||||
url = "https://platform.openai.com"
|
||||
api_base = "https://api.openai.com/v1"
|
||||
working = True
|
||||
needs_auth = True
|
||||
supports_message_history = True
|
||||
supports_system_message = True
|
||||
default_model = ""
|
||||
fallback_models = []
|
||||
|
||||
@classmethod
|
||||
def get_models(cls, api_key: str = None):
|
||||
if not cls.models:
|
||||
try:
|
||||
headers = {}
|
||||
if api_key is not None:
|
||||
headers["authorization"] = f"Bearer {api_key}"
|
||||
response = requests.get(f"{cls.api_base}/models", headers=headers)
|
||||
raise_for_status(response)
|
||||
data = response.json()
|
||||
cls.models = [model.get("id") for model in data.get("data")]
|
||||
cls.models.sort()
|
||||
except Exception as e:
|
||||
debug.log(e)
|
||||
cls.models = cls.fallback_models
|
||||
return cls.models
|
||||
|
||||
@classmethod
|
||||
async def create_async_generator(
|
||||
@ -27,7 +48,7 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
timeout: int = 120,
|
||||
images: ImagesType = None,
|
||||
api_key: str = None,
|
||||
api_base: str = "https://api.openai.com/v1",
|
||||
api_base: str = api_base,
|
||||
temperature: float = None,
|
||||
max_tokens: int = None,
|
||||
top_p: float = None,
|
||||
@ -47,14 +68,14 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
*[{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": to_data_uri(image)}
|
||||
} for image, image_name in images],
|
||||
} for image, _ in images],
|
||||
{
|
||||
"type": "text",
|
||||
"text": messages[-1]["content"]
|
||||
}
|
||||
]
|
||||
async with StreamSession(
|
||||
proxies={"all": proxy},
|
||||
proxy=proxy,
|
||||
headers=cls.get_headers(stream, api_key, headers),
|
||||
timeout=timeout,
|
||||
impersonate=impersonate,
|
||||
@ -111,7 +132,10 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
|
||||
if "error_message" in data:
|
||||
raise ResponseError(data["error_message"])
|
||||
elif "error" in data:
|
||||
raise ResponseError(f'Error {data["error"]["code"]}: {data["error"]["message"]}')
|
||||
if "code" in data["error"]:
|
||||
raise ResponseError(f'Error {data["error"]["code"]}: {data["error"]["message"]}')
|
||||
else:
|
||||
raise ResponseError(data["error"]["message"])
|
||||
|
||||
@classmethod
|
||||
def get_headers(cls, stream: bool, api_key: str = None, headers: dict = None) -> dict:
|
||||
|
@ -26,3 +26,4 @@ from .Replicate import Replicate
|
||||
from .Theb import Theb
|
||||
from .ThebApi import ThebApi
|
||||
from .WhiteRabbitNeo import WhiteRabbitNeo
|
||||
from .xAI import xAI
|
22
g4f/Provider/needs_auth/xAI.py
Normal file
22
g4f/Provider/needs_auth/xAI.py
Normal file
@ -0,0 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .OpenaiAPI import OpenaiAPI
|
||||
from ...typing import AsyncResult, Messages
|
||||
|
||||
class xAI(OpenaiAPI):
|
||||
label = "xAI"
|
||||
url = "https://console.x.ai"
|
||||
api_base = "https://api.x.ai/v1"
|
||||
working = True
|
||||
|
||||
@classmethod
|
||||
def create_async_generator(
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
api_base: str = api_base,
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
return super().create_async_generator(
|
||||
model, messages, api_base=api_base, **kwargs
|
||||
)
|
@ -2,16 +2,15 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Union, Optional
|
||||
from typing import Union, Optional, Coroutine
|
||||
|
||||
from . import debug, version
|
||||
from .models import Model
|
||||
from .client import Client, AsyncClient
|
||||
from .typing import Messages, CreateResult, AsyncResult, ImageType
|
||||
from .errors import StreamNotSupportedError, ModelNotAllowedError
|
||||
from .errors import StreamNotSupportedError
|
||||
from .cookies import get_cookies, set_cookies
|
||||
from .providers.types import ProviderType
|
||||
from .providers.base_provider import AsyncGeneratorProvider
|
||||
from .client.service import get_model_and_provider, get_last_provider
|
||||
|
||||
#Configure "g4f" logger
|
||||
@ -30,14 +29,13 @@ class ChatCompletion:
|
||||
stream : bool = False,
|
||||
image : ImageType = None,
|
||||
image_name: Optional[str] = None,
|
||||
ignored: list[str] = None,
|
||||
ignore_working: bool = False,
|
||||
ignore_stream: bool = False,
|
||||
**kwargs) -> Union[CreateResult, str]:
|
||||
model, provider = get_model_and_provider(
|
||||
model, provider, stream,
|
||||
ignored, ignore_working,
|
||||
ignore_stream or kwargs.get("ignore_stream_and_auth")
|
||||
ignore_working,
|
||||
ignore_stream
|
||||
)
|
||||
if image is not None:
|
||||
kwargs["images"] = [(image, image_name)]
|
||||
@ -55,10 +53,9 @@ class ChatCompletion:
|
||||
messages : Messages,
|
||||
provider : Union[ProviderType, str, None] = None,
|
||||
stream : bool = False,
|
||||
ignored : list[str] = None,
|
||||
ignore_working: bool = False,
|
||||
**kwargs) -> Union[AsyncResult, str]:
|
||||
model, provider = get_model_and_provider(model, provider, False, ignored, ignore_working)
|
||||
**kwargs) -> Union[AsyncResult, Coroutine[str]]:
|
||||
model, provider = get_model_and_provider(model, provider, False, ignore_working)
|
||||
|
||||
if stream:
|
||||
if hasattr(provider, "create_async_generator"):
|
||||
|
@ -84,6 +84,11 @@ def create_app():
|
||||
if not AppConfig.ignore_cookie_files:
|
||||
read_cookie_files()
|
||||
|
||||
if AppConfig.ignored_providers:
|
||||
for provider in AppConfig.ignored_providers:
|
||||
if provider in ProviderUtils.convert:
|
||||
ProviderUtils.convert[provider].working = False
|
||||
|
||||
return app
|
||||
|
||||
def create_app_debug():
|
||||
@ -151,7 +156,7 @@ class ErrorResponseMessageModel(BaseModel):
|
||||
|
||||
class FileResponseModel(BaseModel):
|
||||
filename: str
|
||||
|
||||
|
||||
class ErrorResponse(Response):
|
||||
media_type = "application/json"
|
||||
|
||||
@ -183,12 +188,6 @@ class AppConfig:
|
||||
for key, value in data.items():
|
||||
setattr(cls, key, value)
|
||||
|
||||
list_ignored_providers: list[str] = None
|
||||
|
||||
def set_list_ignored_providers(ignored: list[str]):
|
||||
global list_ignored_providers
|
||||
list_ignored_providers = ignored
|
||||
|
||||
class Api:
|
||||
def __init__(self, app: FastAPI) -> None:
|
||||
self.app = app
|
||||
|
@ -24,7 +24,6 @@ def convert_to_provider(provider: str) -> ProviderType:
|
||||
def get_model_and_provider(model : Union[Model, str],
|
||||
provider : Union[ProviderType, str, None],
|
||||
stream : bool,
|
||||
ignored : list[str] = None,
|
||||
ignore_working: bool = False,
|
||||
ignore_stream: bool = False,
|
||||
logging: bool = True) -> tuple[str, ProviderType]:
|
||||
@ -58,7 +57,7 @@ def get_model_and_provider(model : Union[Model, str],
|
||||
if isinstance(model, str):
|
||||
if model in ModelUtils.convert:
|
||||
model = ModelUtils.convert[model]
|
||||
|
||||
|
||||
if not provider:
|
||||
if not model:
|
||||
model = default
|
||||
@ -66,7 +65,7 @@ def get_model_and_provider(model : Union[Model, str],
|
||||
elif isinstance(model, str):
|
||||
if model in ProviderUtils.convert:
|
||||
provider = ProviderUtils.convert[model]
|
||||
model = provider.default_model if hasattr(provider, "default_model") else ""
|
||||
model = getattr(provider, "default_model", "")
|
||||
else:
|
||||
raise ModelNotFoundError(f'Model not found: {model}')
|
||||
elif isinstance(model, Model):
|
||||
@ -87,8 +86,6 @@ def get_model_and_provider(model : Union[Model, str],
|
||||
if isinstance(provider, BaseRetryProvider):
|
||||
if not ignore_working:
|
||||
provider.providers = [p for p in provider.providers if p.working]
|
||||
if ignored:
|
||||
provider.providers = [p for p in provider.providers if p.__name__ not in ignored]
|
||||
|
||||
if not ignore_stream and not provider.supports_stream and stream:
|
||||
raise StreamNotSupportedError(f'{provider_name} does not support "stream" argument')
|
||||
|
@ -16,7 +16,7 @@ try:
|
||||
_LinuxPasswordManager, BrowserCookieError
|
||||
)
|
||||
|
||||
def _g4f(domain_name: str) -> list:
|
||||
def g4f(domain_name: str) -> list:
|
||||
"""
|
||||
Load cookies from the 'g4f' browser (if exists).
|
||||
|
||||
@ -33,7 +33,7 @@ try:
|
||||
return [] if not os.path.exists(cookie_file) else chrome(cookie_file, domain_name)
|
||||
|
||||
browsers = [
|
||||
_g4f,
|
||||
g4f,
|
||||
chrome, chromium, firefox, opera, opera_gx,
|
||||
brave, edge, vivaldi,
|
||||
]
|
||||
@ -57,7 +57,8 @@ DOMAINS = [
|
||||
"www.whiterabbitneo.com",
|
||||
"huggingface.co",
|
||||
"chat.reka.ai",
|
||||
"chatgpt.com"
|
||||
"chatgpt.com",
|
||||
".cerebras.ai",
|
||||
]
|
||||
|
||||
if has_browser_cookie3 and os.environ.get('DBUS_SESSION_BUS_ADDRESS') == "/dev/null":
|
||||
@ -126,6 +127,11 @@ def get_cookies_dir() -> str:
|
||||
return CookiesConfig.cookies_dir
|
||||
|
||||
def read_cookie_files(dirPath: str = None):
|
||||
dirPath = CookiesConfig.cookies_dir if dirPath is None else dirPath
|
||||
if not os.access(dirPath, os.R_OK):
|
||||
debug.log(f"Read cookies: {dirPath} dir is not readable")
|
||||
return
|
||||
|
||||
def get_domain(v: dict) -> str:
|
||||
host = [h["value"] for h in v['request']['headers'] if h["name"].lower() in ("host", ":authority")]
|
||||
if not host:
|
||||
@ -137,7 +143,7 @@ def read_cookie_files(dirPath: str = None):
|
||||
|
||||
harFiles = []
|
||||
cookieFiles = []
|
||||
for root, _, files in os.walk(CookiesConfig.cookies_dir if dirPath is None else dirPath):
|
||||
for root, _, files in os.walk(dirPath):
|
||||
for file in files:
|
||||
if file.endswith(".har"):
|
||||
harFiles.append(os.path.join(root, file))
|
||||
@ -152,8 +158,7 @@ def read_cookie_files(dirPath: str = None):
|
||||
except json.JSONDecodeError:
|
||||
# Error: not a HAR file!
|
||||
continue
|
||||
if debug.logging:
|
||||
print("Read .har file:", path)
|
||||
debug.log(f"Read .har file: {path}")
|
||||
new_cookies = {}
|
||||
for v in harFile['log']['entries']:
|
||||
domain = get_domain(v)
|
||||
@ -165,9 +170,8 @@ def read_cookie_files(dirPath: str = None):
|
||||
if len(v_cookies) > 0:
|
||||
CookiesConfig.cookies[domain] = v_cookies
|
||||
new_cookies[domain] = len(v_cookies)
|
||||
if debug.logging:
|
||||
for domain, new_values in new_cookies.items():
|
||||
print(f"Cookies added: {new_values} from {domain}")
|
||||
for domain, new_values in new_cookies.items():
|
||||
debug.log(f"Cookies added: {new_values} from {domain}")
|
||||
for path in cookieFiles:
|
||||
with open(path, 'rb') as file:
|
||||
try:
|
||||
@ -177,8 +181,7 @@ def read_cookie_files(dirPath: str = None):
|
||||
continue
|
||||
if not isinstance(cookieFile, list):
|
||||
continue
|
||||
if debug.logging:
|
||||
print("Read cookie file:", path)
|
||||
debug.log(f"Read cookie file: {path}")
|
||||
new_cookies = {}
|
||||
for c in cookieFile:
|
||||
if isinstance(c, dict) and "domain" in c:
|
||||
@ -186,6 +189,5 @@ def read_cookie_files(dirPath: str = None):
|
||||
new_cookies[c["domain"]] = {}
|
||||
new_cookies[c["domain"]][c["name"]] = c["value"]
|
||||
for domain, new_values in new_cookies.items():
|
||||
if debug.logging:
|
||||
print(f"Cookies added: {len(new_values)} from {domain}")
|
||||
debug.log(f"Cookies added: {len(new_values)} from {domain}")
|
||||
CookiesConfig.cookies[domain] = new_values
|
@ -172,7 +172,7 @@
|
||||
</div>
|
||||
<div class="field box">
|
||||
<label for="HuggingFace-api_key" class="label" title="">HuggingFace:</label>
|
||||
<textarea id="HuggingFace-api_key" name="HuggingFace[api_key]" class="HuggingFace2-api_key" placeholder="api_key"></textarea>
|
||||
<textarea id="HuggingFace-api_key" name="HuggingFace[api_key]" class="HuggingFaceAPI-api_key" placeholder="api_key"></textarea>
|
||||
</div>
|
||||
<div class="field box">
|
||||
<label for="Openai-api_key" class="label" title="">OpenAI API:</label>
|
||||
@ -190,6 +190,10 @@
|
||||
<label for="Replicate-api_key" class="label" title="">Replicate:</label>
|
||||
<textarea id="Replicate-api_key" name="Replicate[api_key]" class="ReplicateImage-api_key" placeholder="api_key"></textarea>
|
||||
</div>
|
||||
<div class="field box">
|
||||
<label for="xAI-api_key" class="label" title="">xAI:</label>
|
||||
<textarea id="xAI-api_key" name="xAI[api_key]" placeholder="api_key"></textarea>
|
||||
</div>
|
||||
</div>
|
||||
<div class="bottom_buttons">
|
||||
<button onclick="delete_conversations()">
|
||||
|
@ -34,6 +34,7 @@ class IterListProvider(BaseRetryProvider):
|
||||
model: str,
|
||||
messages: Messages,
|
||||
stream: bool = False,
|
||||
ignored: list[str] = [],
|
||||
**kwargs,
|
||||
) -> CreateResult:
|
||||
"""
|
||||
@ -50,7 +51,7 @@ class IterListProvider(BaseRetryProvider):
|
||||
exceptions = {}
|
||||
started: bool = False
|
||||
|
||||
for provider in self.get_providers(stream):
|
||||
for provider in self.get_providers(stream, ignored):
|
||||
self.last_provider = provider
|
||||
debug.log(f"Using {provider.__name__} provider")
|
||||
try:
|
||||
@ -62,8 +63,7 @@ class IterListProvider(BaseRetryProvider):
|
||||
return
|
||||
except Exception as e:
|
||||
exceptions[provider.__name__] = e
|
||||
if debug.logging:
|
||||
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
|
||||
debug.log(f"{provider.__name__}: {e.__class__.__name__}: {e}")
|
||||
if started:
|
||||
raise e
|
||||
|
||||
@ -73,6 +73,7 @@ class IterListProvider(BaseRetryProvider):
|
||||
self,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
ignored: list[str] = [],
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
@ -87,7 +88,7 @@ class IterListProvider(BaseRetryProvider):
|
||||
"""
|
||||
exceptions = {}
|
||||
|
||||
for provider in self.get_providers(False):
|
||||
for provider in self.get_providers(False, ignored):
|
||||
self.last_provider = provider
|
||||
debug.log(f"Using {provider.__name__} provider")
|
||||
try:
|
||||
@ -99,28 +100,22 @@ class IterListProvider(BaseRetryProvider):
|
||||
return chunk
|
||||
except Exception as e:
|
||||
exceptions[provider.__name__] = e
|
||||
if debug.logging:
|
||||
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
|
||||
debug.log(f"{provider.__name__}: {e.__class__.__name__}: {e}")
|
||||
|
||||
raise_exceptions(exceptions)
|
||||
|
||||
def get_providers(self, stream: bool) -> list[ProviderType]:
|
||||
providers = [p for p in self.providers if p.supports_stream] if stream else self.providers
|
||||
if self.shuffle:
|
||||
random.shuffle(providers)
|
||||
return providers
|
||||
|
||||
async def create_async_generator(
|
||||
self,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
stream: bool = True,
|
||||
ignored: list[str] = [],
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
exceptions = {}
|
||||
started: bool = False
|
||||
|
||||
for provider in self.get_providers(stream):
|
||||
for provider in self.get_providers(stream, ignored):
|
||||
self.last_provider = provider
|
||||
debug.log(f"Using {provider.__name__} provider")
|
||||
try:
|
||||
@ -151,6 +146,12 @@ class IterListProvider(BaseRetryProvider):
|
||||
|
||||
raise_exceptions(exceptions)
|
||||
|
||||
def get_providers(self, stream: bool, ignored: list[str]) -> list[ProviderType]:
|
||||
providers = [p for p in self.providers if (p.supports_stream or not stream) and p.__name__ not in ignored]
|
||||
if self.shuffle:
|
||||
random.shuffle(providers)
|
||||
return providers
|
||||
|
||||
class RetryProvider(IterListProvider):
|
||||
def __init__(
|
||||
self,
|
||||
@ -304,7 +305,7 @@ def raise_exceptions(exceptions: dict) -> None:
|
||||
"""
|
||||
if exceptions:
|
||||
raise RetryProviderError("RetryProvider failed:\n" + "\n".join([
|
||||
f"{p}: {exception.__class__.__name__}: {exception}" for p, exception in exceptions.items()
|
||||
f"{p}: {type(exception).__name__}: {exception}" for p, exception in exceptions.items()
|
||||
]))
|
||||
|
||||
raise RetryNoProviderError("No provider found")
|
Loading…
Reference in New Issue
Block a user