Add ReplicateImage Provider, Fix BingCreateImages Provider

This commit is contained in:
Heiner Lohaus 2024-04-11 02:40:30 +02:00
parent 13a09033fb
commit 009a67239a
7 changed files with 254 additions and 60 deletions

View File

@ -7,16 +7,33 @@ from typing import Iterator, Union
from ..cookies import get_cookies
from ..image import ImageResponse
from ..errors import MissingRequirementsError, MissingAuthError
from ..typing import Cookies
from ..typing import AsyncResult, Messages, Cookies
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .bing.create_images import create_images, create_session, get_cookies_from_browser
class BingCreateImages:
"""A class for creating images using Bing."""
class BingCreateImages(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://www.bing.com/images/create"
working = True
def __init__(self, cookies: Cookies = None, proxy: str = None) -> None:
self.cookies: Cookies = cookies
self.proxy: str = proxy
@classmethod
async def create_async_generator(
cls,
model: str,
messages: Messages,
api_key: str = None,
cookies: Cookies = None,
proxy: str = None,
**kwargs
) -> AsyncResult:
if api_key is not None:
cookies = {"_U": api_key}
session = BingCreateImages(cookies, proxy)
yield await session.create_async(messages[-1]["content"])
def create(self, prompt: str) -> Iterator[Union[ImageResponse, str]]:
"""
Generator for creating imagecompletion based on a prompt.

View File

@ -0,0 +1,96 @@
from __future__ import annotations
import random
import asyncio
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..typing import AsyncResult, Messages
from ..requests import StreamSession, raise_for_status
from ..image import ImageResponse
from ..errors import ResponseError
class ReplicateImage(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://replicate.com"
working = True
default_model = 'stability-ai/sdxl'
default_versions = [
"39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b",
"2b017d9b67edd2ee1401238df49d75da53c523f36e363881e057f5dc3ed3c5b2"
]
@classmethod
async def create_async_generator(
cls,
model: str,
messages: Messages,
**kwargs
) -> AsyncResult:
yield await cls.create_async(messages[-1]["content"], model, **kwargs)
@classmethod
async def create_async(
cls,
prompt: str,
model: str,
api_key: str = None,
proxy: str = None,
timeout: int = 180,
version: str = None,
extra_data: dict = {},
**kwargs
) -> ImageResponse:
headers = {
'Accept-Encoding': 'gzip, deflate, br',
'Accept-Language': 'en-US',
'Connection': 'keep-alive',
'Origin': cls.url,
'Referer': f'{cls.url}/',
'Sec-Fetch-Dest': 'empty',
'Sec-Fetch-Mode': 'cors',
'Sec-Fetch-Site': 'same-site',
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36',
'sec-ch-ua': '"Google Chrome";v="119", "Chromium";v="119", "Not?A_Brand";v="24"',
'sec-ch-ua-mobile': '?0',
'sec-ch-ua-platform': '"macOS"',
}
if version is None:
version = random.choice(cls.default_versions)
if api_key is not None:
headers["Authorization"] = f"Bearer {api_key}"
async with StreamSession(
proxies={"all": proxy},
headers=headers,
timeout=timeout
) as session:
data = {
"input": {
"prompt": prompt,
**extra_data
},
"version": version
}
if api_key is None:
data["model"] = cls.get_model(model)
url = "https://homepage.replicate.com/api/prediction"
else:
url = "https://api.replicate.com/v1/predictions"
async with session.post(url, json=data) as response:
await raise_for_status(response)
result = await response.json()
if "id" not in result:
raise ResponseError(f"Invalid response: {result}")
while True:
if api_key is None:
url = f"https://homepage.replicate.com/api/poll?id={result['id']}"
else:
url = f"https://api.replicate.com/v1/predictions/{result['id']}"
async with session.get(url) as response:
await raise_for_status(response)
result = await response.json()
if "status" not in result:
raise ResponseError(f"Invalid response: {result}")
if result["status"] == "succeeded":
images = result['output']
images = images[0] if len(images) == 1 else images
return ImageResponse(images, prompt)
await asyncio.sleep(0.5)

View File

@ -151,7 +151,7 @@ async def create_images(session: ClientSession, prompt: str, proxy: str = None,
if response.status != 200:
raise RuntimeError(f"Polling images faild. Code: {response.status}")
text = await response.text()
if not text:
if not text or "GenerativeImagesStatusPage" in text:
await asyncio.sleep(1)
else:
break

View File

@ -0,0 +1,78 @@
from __future__ import annotations
import asyncio
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_prompt, filter_none
from ...typing import AsyncResult, Messages
from ...requests import StreamSession, raise_for_status
from ...image import ImageResponse
from ...errors import ResponseError, MissingAuthError
class Replicate(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://replicate.com"
working = True
default_model = "mistralai/mixtral-8x7b-instruct-v0.1"
api_base = "https://api.replicate.com/v1/models/"
@classmethod
async def create_async_generator(
cls,
model: str,
messages: Messages,
api_key: str = None,
proxy: str = None,
timeout: int = 180,
system_prompt: str = None,
max_new_tokens: int = None,
temperature: float = None,
top_p: float = None,
top_k: float = None,
stop: list = None,
extra_data: dict = {},
headers: dict = {},
**kwargs
) -> AsyncResult:
model = cls.get_model(model)
if api_key is None:
raise MissingAuthError("api_key is missing")
headers["Authorization"] = f"Bearer {api_key}"
async with StreamSession(
proxies={"all": proxy},
headers=headers,
timeout=timeout
) as session:
data = {
"stream": True,
"input": {
"prompt": format_prompt(messages),
**filter_none(
system_prompt=system_prompt,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stop_sequences=",".join(stop) if stop else None
),
**extra_data
},
}
url = f"{cls.api_base.rstrip('/')}/{model}/predictions"
async with session.post(url, json=data) as response:
await raise_for_status(response)
result = await response.json()
if "id" not in result:
raise ResponseError(f"Invalid response: {result}")
async with session.get(result["urls"]["stream"], headers={"Accept": "text/event-stream"}) as response:
await raise_for_status(response)
event = None
async for line in response.iter_lines():
if line.startswith(b"event: "):
event = line[7:]
elif event == b"output":
if line.startswith(b"data: "):
yield line[6:].decode()
elif not line.startswith(b"id: "):
continue#yield "+"+line.decode()
elif event == b"done":
break

View File

@ -58,6 +58,12 @@
</button>
</div>
<div class="bottom_buttons">
<!--
<button onclick="open_album();">
<i class="fa-solid fa-toolbox"></i>
<span>Images Album</span>
</button>
-->
<button onclick="open_settings();">
<i class="fa-solid fa-toolbox"></i>
<span>Open Settings</span>
@ -77,6 +83,9 @@
<span id="version_text" class="convo-title"></span>
</div>
</div>
</div>
<div class="images hidden">
</div>
<div class="settings hidden">
<div class="paper">
@ -101,7 +110,7 @@
<label for="auto_continue" class="toogle" title="Continue large responses in OpenaiChat"></label>
</div>
<div class="field box">
<label for="message-input-height" class="label" title="">Input max. grow height</label>
<label for="message-input-height" class="label" title="">Input max. height</label>
<input type="number" id="message-input-height" value="200"/>
</div>
<div class="field box">
@ -109,40 +118,40 @@
<input type="text" id="recognition-language" value="" placeholder="navigator.language"/>
</div>
<div class="field box">
<label for="OpenaiChat-api_key" class="label" title="">OpenaiChat: api_key</label>
<textarea id="OpenaiChat-api_key" name="OpenaiChat[api_key]" placeholder="..."></textarea>
<label for="Bing-api_key" class="label" title="">Bing:</label>
<textarea id="Bing-api_key" name="Bing[api_key]" class="BingCreateImages-api_key" placeholder="&quot;_U&quot; cookie"></textarea>
</div>
<div class="field box">
<label for="Bing-api_key" class="label" title="">Bing: "_U" cookie</label>
<textarea id="Bing-api_key" name="Bing[api_key]" placeholder="..."></textarea>
<label for="DeepInfra-api_key" class="label" title="">DeepInfra:</label>
<textarea id="DeepInfra-api_key" name="DeepInfra[api_key]" class="DeepInfraImage-api_key" placeholder="api_key"></textarea>
</div>
<div class="field box">
<label for="Gemini-api_key" class="label" title="">Gemini: Cookies</label>
<textarea id="Gemini-api_key" name="Gemini[api_key]" placeholder="..."></textarea>
<label for="Gemini-api_key" class="label" title="">Gemini:</label>
<textarea id="Gemini-api_key" name="Gemini[api_key]" placeholder="Cookies"></textarea>
</div>
<div class="field box">
<label for="Openai-api_key" class="label" title="">Openai: api_key</label>
<textarea id="Openai-api_key" name="Openai[api_key]" placeholder="..."></textarea>
<label for="GeminiPro-api_key" class="label" title="">GeminiPro:</label>
<textarea id="GeminiPro-api_key" name="GeminiPro[api_key]" placeholder="api_key"></textarea>
</div>
<div class="field box">
<label for="Groq-api_key" class="label" title="">Groq: api_key</label>
<textarea id="Groq-api_key" name="Groq[api_key]" placeholder="..."></textarea>
<label for="Groq-api_key" class="label" title="">Groq:</label>
<textarea id="Groq-api_key" name="Groq[api_key]" placeholder="api_key"></textarea>
</div>
<div class="field box">
<label for="GeminiPro-api_key" class="label" title="">GeminiPro: api_key</label>
<textarea id="GeminiPro-api_key" name="GeminiPro[api_key]" placeholder="..."></textarea>
<label for="HuggingFace-api_key" class="label" title="">HuggingFace:</label>
<textarea id="HuggingFace-api_key" name="HuggingFace[api_key]" placeholder="api_key"></textarea>
</div>
<div class="field box">
<label for="OpenRouter-api_key" class="label" title="">OpenRouter: api_key</label>
<textarea id="OpenRouter-api_key" name="OpenRouter[api_key]" placeholder="..."></textarea>
<label for="Openai-api_key" class="label" title="">Openai:</label>
<textarea id="Openai-api_key" name="Openai[api_key]" placeholder="api_key"></textarea>
</div>
<div class="field box">
<label for="HuggingFace-api_key" class="label" title="">HuggingFace: api_key</label>
<textarea id="HuggingFace-api_key" name="HuggingFace[api_key]" placeholder="..."></textarea>
<label for="OpenaiChat-api_key" class="label" title="">OpenaiChat:</label>
<textarea id="OpenaiChat-api_key" name="OpenaiChat[api_key]" placeholder="api_key"></textarea>
</div>
<div class="field box">
<label for="DeepInfra-api_key" class="label" title="">DeepInfra: api_key</label>
<textarea id="DeepInfra-api_key" name="DeepInfra[api_key]" placeholder="..."></textarea>
<label for="OpenRouter-api_key" class="label" title="">OpenRouter:</label>
<textarea id="OpenRouter-api_key" name="OpenRouter[api_key]" placeholder="api_key"></textarea>
</div>
</div>
<div class="bottom_buttons">

View File

@ -653,7 +653,7 @@ select {
font-size: 15px;
width: 100%;
color: var(--colour-3);
min-height: 50px;
min-height: 49px;
height: 59px;
outline: none;
padding: var(--inner-gap) var(--section-gap);
@ -809,7 +809,7 @@ ul {
}
.mobile-sidebar {
display: none !important;
display: none;
position: absolute;
z-index: 100000;
top: 0;
@ -850,12 +850,8 @@ ul {
gap: 15px;
}
.field {
width: fit-content;
}
.mobile-sidebar {
display: flex !important;
display: flex;
}
#systemPrompt {
@ -1090,7 +1086,13 @@ a:-webkit-any-link {
}
.settings textarea {
height: 51px;
height: 19px;
min-height: 19px;
padding: 0;
}
.settings .field.box {
padding: var(--inner-gap) var(--inner-gap) var(--inner-gap) 0;
}
.settings, .images {
@ -1112,7 +1114,6 @@ a:-webkit-any-link {
.settings textarea {
background-color: transparent;
border: none;
padding: var(--inner-gap) 0;
}
.settings input {
@ -1130,10 +1131,7 @@ a:-webkit-any-link {
.settings .label {
font-size: 15px;
padding: var(--inner-gap) 0;
width: fit-content;
min-width: 190px;
margin-left: var(--section-gap);
margin-left: var(--inner-gap);
white-space:nowrap;
}

View File

@ -179,12 +179,14 @@ const register_message_buttons = async () => {
}
const delete_conversations = async () => {
const remove_keys = [];
for (let i = 0; i < appStorage.length; i++){
let key = appStorage.key(i);
if (key.startsWith("conversation:")) {
appStorage.removeItem(key);
remove_keys.push(key);
}
}
remove_keys.forEach((key)=>appStorage.removeItem(key));
hide_sidebar();
await new_conversation();
};
@ -274,31 +276,21 @@ const prepare_messages = (messages, filter_last_message=true) => {
}
let new_messages = [];
if (messages) {
for (i in messages) {
new_message = messages[i];
// Remove generated images from history
new_message.content = new_message.content.replaceAll(
/<!-- generated images start -->[\s\S]+<!-- generated images end -->/gm,
""
)
delete new_message["provider"];
// Remove regenerated messages
if (!new_message.regenerate) {
new_messages.push(new_message)
}
}
}
// Add system message
system_content = systemPrompt?.value;
if (system_content) {
new_messages.unshift({
if (systemPrompt?.value) {
new_messages.push({
"role": "system",
"content": system_content
"content": systemPrompt.value
});
}
messages.forEach((new_message) => {
// Include only not regenerated messages
if (!new_message.regenerate) {
// Remove generated images from history
new_message.content = filter_message(new_message.content);
delete new_message.provider;
new_messages.push(new_message)
}
});
return new_messages;
}
@ -413,8 +405,11 @@ const ask_gpt = async () => {
if (file && !provider)
provider = "Bing";
let api_key = null;
if (provider)
if (provider) {
api_key = document.getElementById(`${provider}-api_key`)?.value || null;
if (api_key == null)
api_key = document.querySelector(`.${provider}-api_key`)?.value || null;
}
await api("conversation", {
id: window.token,
conversation_id: window.conversation_id,
@ -949,6 +944,7 @@ function count_chars(text) {
}
function count_words_and_tokens(text, model) {
text = filter_message(text);
return `(${count_words(text)} words, ${count_chars(text)} chars, ${count_tokens(model, text)} tokens)`;
}