Mirror of https://github.com/Sygil-Dev/sygil-webui.git

Commit bbedcc8e84: Merge remote-tracking branch 'origin/dev' into dev
.gitignore (vendored)
@@ -64,3 +64,4 @@ condaenv.*.requirements.txt
 /gfpgan/*
 /models/*
 z_version_env.tmp
+/user_data/*
@@ -294,7 +294,7 @@ img2img:
     write_info_files: True

 img2txt:
-    batch_size: 100
+    batch_size: 420
     blip_image_eval_size: 512

 concepts_library:
@@ -24,7 +24,21 @@ SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
 cd $SCRIPT_DIR
 export PYTHONPATH=$SCRIPT_DIR

-MODEL_DIR="${SCRIPT_DIR}/model_cache"
+if [[ $PUBLIC_KEY ]]
+then
+    mkdir -p ~/.ssh
+    chmod 700 ~/.ssh
+    cd ~/.ssh
+    echo $PUBLIC_KEY >> authorized_keys
+    chmod 700 -R ~/.ssh
+    cd /
+    service ssh start
+    echo "SSH Service Started"
+fi
+
+
+MODEL_DIR="${SCRIPT_DIR}/user_data/model_cache"
+mkdir -p $MODEL_DIR
 # Array of model files to pre-download
 # local filename
 # local path in container (no trailing slash)
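For context (not part of the diff): the new PUBLIC_KEY block only runs when that variable is set in the container environment, and MODEL_DIR now lives under user_data, which this commit also adds to .gitignore. A minimal sketch of how both might be wired up at run time; the image name, ports, and container-side paths are illustrative assumptions, not part of this commit:

# Illustrative only: image name, ports, and mount paths are assumptions.
docker run -d \
  -e PUBLIC_KEY="$(cat ~/.ssh/id_ed25519.pub)" \
  -p 2222:22 -p 8501:8501 \
  -v "$(pwd)/user_data:/sd/user_data" \
  sygil-webui:dev
# PUBLIC_KEY enables the SSH block above; the volume keeps user_data/model_cache
# (downloaded weights and caches) across container restarts.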
@@ -37,6 +51,17 @@ MODEL_FILES=(
 'RealESRGAN_x4plus_anime_6B.pth src/realesrgan/experiments/pretrained_models https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth f872d837d3c90ed2e05227bed711af5671a6fd1c9f7d7e91c911a61f155e99da'
 'project.yaml src/latent-diffusion/experiments/pretrained_models https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1 9d6ad53c5dafeb07200fb712db14b813b527edd262bc80ea136777bdb41be2ba'
 'model.ckpt src/latent-diffusion/experiments/pretrained_models https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1 c209caecac2f97b4bb8f4d726b70ac2ac9b35904b7fc99801e1f5e61f9210c13'
+'waifu-diffusion.ckpt models/custom https://huggingface.co/crumb/pruned-waifu-diffusion/resolve/main/model-pruned.ckpt 9b31355f90fea9933847175d4731a033f49f861395addc7e153f480551a24c25'
+'trinart.ckpt models/custom https://huggingface.co/naclbit/trinart_stable_diffusion_v2/resolve/main/trinart2_step95000.ckpt c1799d22a355ba25c9ceeb6e3c91fc61788c8e274b73508ae8a15877c5dbcf63'
+'model__base_caption.pth models/blip https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth 96ac8749bd0a568c274ebe302b3a3748ab9be614c737f3d8c529697139174086'
+'pytorch_model.bin models/clip-vit-large-patch14 https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/pytorch_model.bin f1a17cdbe0f36fec524f5cafb1c261ea3bbbc13e346e0f74fc9eb0460dedd0d3'
+'config.json models/clip-vit-large-patch14 https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/config.json 8a09b467700c58138c29d53c605b34ebc69beaadd13274a8a2af8ad2c2f4032a'
+'merges.txt models/clip-vit-large-patch14 https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/merges.txt 9fd691f7c8039210e0fced15865466c65820d09b63988b0174bfe25de299051a'
+'preprocessor_config.json models/clip-vit-large-patch14 https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/preprocessor_config.json 910e70b3956ac9879ebc90b22fb3bc8a75b6a0677814500101a4c072bd7857bd'
+'special_tokens_map.json models/clip-vit-large-patch14 https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/special_tokens_map.json f8c0d6c39aee3f8431078ef6646567b0aba7f2246e9c54b8b99d55c22b707cbf'
+'tokenizer.json models/clip-vit-large-patch14 https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/tokenizer.json a83e0809aa4c3af7208b2df632a7a69668c6d48775b3c3fe4e1b1199d1f8b8f4'
+'tokenizer_config.json models/clip-vit-large-patch14 https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/tokenizer_config.json deef455e52fa5e8151e339add0582e4235f066009601360999d3a9cda83b1129'
+'vocab.json models/clip-vit-large-patch14 https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/vocab.json 3f0c4f7d2086b61b38487075278ea9ed04edb53a03cbb045b86c27190fa8fb69'
 )


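Each MODEL_FILES entry packs four space-separated fields: local filename, target directory, download URL, and a sha256 checksum. The validateDownloadModel helper that consumes them is not shown in this hunk; a rough sketch of the idea, not the actual implementation:

# Hedged sketch only; the real validateDownloadModel in entrypoint.sh may differ.
for entry in "${MODEL_FILES[@]}"; do
    read -r file path url sha <<< "$entry"
    mkdir -p "$path"
    if ! echo "${sha}  ${path}/${file}" | sha256sum -c --status 2>/dev/null; then
        wget -O "${path}/${file}" "$url"
    fi
done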
@@ -83,33 +108,30 @@ else
            validateDownloadModel ${model[0]} ${model[1]} ${model[2]} ${model[3]}
        fi
    done
+   mkdir -p ${MODEL_DIR}/stable-diffusion-v1-4
+   mkdir -p ${MODEL_DIR}/waifu-diffusion
+
+   ln -fs ${SCRIPT_DIR}/models/clip-vit-large-patch14/ ${MODEL_DIR}/stable-diffusion-v1-4/tokenizer
+   ln -fs ${SCRIPT_DIR}/models/clip-vit-large-patch14/ ${MODEL_DIR}/waifu-diffusion/tokenizer
 fi

-# Determine which webserver interface to launch (Streamlit vs Default: Gradio)
-if [[ ! -z $WEBUI_SCRIPT && $WEBUI_SCRIPT == "webui_streamlit.py" ]]; then
-    launch_command="streamlit run scripts/${WEBUI_SCRIPT:-webui.py} $WEBUI_ARGS"
+if [[ -e "${MODEL_DIR}/sd-concepts-library" ]]; then
+    cd ${MODEL_DIR}/sd-concepts-library
+    git pull
 else
-    launch_command="python scripts/${WEBUI_SCRIPT:-webui.py} $WEBUI_ARGS"
+    cd ${MODEL_DIR}
+    git clone https://github.com/sd-webui/sd-concepts-library
 fi
+mkdir -p ${SCRIPT_DIR}/models/custom
+ln -fs ${MODEL_DIR}/sd-concepts-library/sd-concepts-library ${SCRIPT_DIR}/models/custom

-# Start webserver interface
-launch_message="Starting Stable Diffusion WebUI... ${launch_command}..."
-if [[ -z $WEBUI_RELAUNCH || $WEBUI_RELAUNCH == "true" ]]; then
-    n=0
-    while true; do
-        echo $launch_message
-
-        if (( $n > 0 )); then
-            echo "Relaunch count: ${n}"
-        fi
-
-        $launch_command
-
-        echo "entrypoint.sh: Process is ending. Relaunching in 0.5s..."
-        ((n++))
-        sleep 0.5
-    done
-else
-    echo $launch_message
-    $launch_command
-fi
+echo "export HF_HOME=${MODEL_DIR}" >> ~/.bashrc
+echo "export XDG_CACHE_HOME=${MODEL_DIR}" >> ~/.bashrc
+echo "export TRANSFORMERS_CACHE=${MODEL_DIR}" >> ~/.bashrc
+source ~/.bashrc
+cd $SCRIPT_DIR
+launch_command="streamlit run ${SCRIPT_DIR}/scripts/webui_streamlit.py"
+
+$launch_command
+
+sleep infinity
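With this change the old relaunch loop (WEBUI_RELAUNCH, launch_message, the retry counter) is gone: the container starts Streamlit once and then blocks on sleep infinity, so it stays up even if the UI process exits. A hedged example of restarting the UI inside a running container without recreating it; the container name and working directory are assumptions:

# Illustrative only: container name and in-container path are assumptions.
docker exec -it sygil-webui bash -c 'cd /sd && streamlit run scripts/webui_streamlit.py'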
frontend/js/index.js (new file, 196 lines)
window.SD = (() => {
  /*
   * Painterro is made a field of the SD global object
   * To provide convenience when using w() method in css_and_js.py
   */
  class PainterroClass {
    static isOpen = false;
    static async init ({ x, toId }) {
      console.log(x)

      const originalImage = x[2] === 'Mask' ? x[1]?.image : x[0];

      if (window.Painterro === undefined) {
        try {
          await this.load();
        } catch (e) {
          SDClass.error(e);

          return this.fallback(originalImage);
        }
      }

      if (this.isOpen) {
        return this.fallback(originalImage);
      }
      this.isOpen = true;

      let resolveResult;
      const paintClient = Painterro({
        hiddenTools: ['arrow'],
        onHide: () => {
          resolveResult?.(null);
        },
        saveHandler: (image, done) => {
          const data = image.asDataURL();

          // ensures stable performance even
          // when the editor is in interactive mode
          SD.clearImageInput(SD.el.get(`#${toId}`));

          resolveResult(data);

          done(true);
          paintClient.hide();
        },
      });

      const result = await new Promise((resolve) => {
        resolveResult = resolve;
        paintClient.show(originalImage);
      });
      this.isOpen = false;

      return result ? this.success(result) : this.fallback(originalImage);
    }
    static success (result) { return [result, { image: result, mask: result }] };
    static fallback (image) { return [image, { image: image, mask: image }] };
    static load () {
      return new Promise((resolve, reject) => {
        const scriptId = '__painterro-script';
        if (document.getElementById(scriptId)) {
          reject(new Error('Tried to load painterro script, but script tag already exists.'));
          return;
        }

        const styleId = '__painterro-css-override';
        if (!document.getElementById(styleId)) {
          /* Ensure Painterro window is always on top */
          const style = document.createElement('style');
          style.id = styleId;
          style.setAttribute('type', 'text/css');
          style.appendChild(document.createTextNode(`
            .ptro-holder-wrapper {
                z-index: 100;
            }
          `));
          document.head.appendChild(style);
        }

        const script = document.createElement('script');
        script.id = scriptId;
        script.src = 'https://unpkg.com/painterro@1.2.78/build/painterro.min.js';
        script.onload = () => resolve(true);
        script.onerror = (e) => {
          // remove self on error to enable reattempting load
          document.head.removeChild(script);
          reject(e);
        };
        document.head.appendChild(script);
      });
    }
  }

  /*
   * Turns out caching elements doesn't actually work in gradio
   * As elements in tabs might get recreated
   */
  class ElementCache {
    #el;
    constructor () {
      this.root = document.querySelector('gradio-app').shadowRoot;
    }
    get (selector) {
      return this.root.querySelector(selector);
    }
  }

  /*
   * The main helper class to encapsulate functions
   * that change gradio ui functionality
   */
  class SDClass {
    el = new ElementCache();
    Painterro = PainterroClass;
    moveImageFromGallery ({ x, fromId, toId }) {
      x = x[0];
      if (!Array.isArray(x) || x.length === 0) return;

      this.clearImageInput(this.el.get(`#${toId}`));

      const i = this.#getGallerySelectedIndex(this.el.get(`#${fromId}`));

      return [x[i].replace('data:;','data:image/png;')];
    }
    async copyImageFromGalleryToClipboard ({ x, fromId }) {
      x = x[0];
      if (!Array.isArray(x) || x.length === 0) return;

      const i = this.#getGallerySelectedIndex(this.el.get(`#${fromId}`));

      const data = x[i];
      const blob = await (await fetch(data.replace('data:;','data:image/png;'))).blob();
      const item = new ClipboardItem({'image/png': blob});

      await this.copyToClipboard([item]);
    }
    async copyFullOutput ({ fromId }) {
      const textField = this.el.get(`#${fromId} .textfield`);
      if (!textField) {
        SDClass.error(new Error(`Can't find textfield with the output!`));
      }

      const value = textField.textContent.replace(/\s+/g,' ').replace(/: /g,':');

      await this.copyToClipboard(value)
    }
    clickFirstVisibleButton({ rowId }) {
      const generateButtons = this.el.get(`#${rowId}`).querySelectorAll('.gr-button-primary');

      if (!generateButtons) return;

      for (let i = 0, arr = [...generateButtons]; i < arr.length; i++) {
        const cs = window.getComputedStyle(arr[i]);

        if (cs.display !== 'none' && cs.visibility !== 'hidden') {
          console.log(arr[i]);

          arr[i].click();
          break;
        }
      }
    }
    async gradioInputToClipboard ({ x }) { return this.copyToClipboard(x[0]); }
    async copyToClipboard (value) {
      if (!value || typeof value === 'boolean') return;
      try {
        if (Array.isArray(value) &&
            value.length &&
            value[0] instanceof ClipboardItem) {
          await navigator.clipboard.write(value);
        } else {
          await navigator.clipboard.writeText(value);
        }
      } catch (e) {
        SDClass.error(e);
      }
    }
    static error (e) {
      console.error(e);
      if (typeof e === 'string') {
        alert(e);
      } else if (typeof e === 'object' && Object.hasOwn(e, 'message')) {
        alert(e.message);
      }
    }
    clearImageInput (imageEditor) {
      imageEditor?.querySelector('.modify-upload button:last-child')?.click();
    }
    #getGallerySelectedIndex (gallery) {
      const selected = gallery.querySelector(`.\\!ring-2`);
      return selected ? [...selected.parentNode.children].indexOf(selected) : 0;
    }
  }

  return new SDClass();
})();
@@ -5,7 +5,7 @@ import clip
 from einops import rearrange, repeat
 from transformers import CLIPTokenizer, CLIPTextModel
 import kornia
+import os

 from ldm.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test


@@ -138,8 +138,12 @@ class FrozenCLIPEmbedder(AbstractEncoder):
     """Uses the CLIP transformer encoder for text (from Hugging Face)"""
     def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
         super().__init__()
-        self.tokenizer = CLIPTokenizer.from_pretrained(version)
-        self.transformer = CLIPTextModel.from_pretrained(version)
+        if os.path.exists("models/clip-vit-large-patch14"):
+            self.tokenizer = CLIPTokenizer.from_pretrained("models/clip-vit-large-patch14")
+            self.transformer = CLIPTextModel.from_pretrained("models/clip-vit-large-patch14")
+        else:
+            self.tokenizer = CLIPTokenizer.from_pretrained(version)
+            self.transformer = CLIPTextModel.from_pretrained(version)
         self.device = device
         self.max_length = max_length
         self.freeze()
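The FrozenCLIPEmbedder change falls back to a local models/clip-vit-large-patch14 copy when it exists, which is exactly what the new MODEL_FILES entries pre-download inside the container. Outside the container the same offline path can be primed by hand; a sketch reusing the URLs already listed in MODEL_FILES above:

# Sketch: fetches the same CLIP files that entrypoint.sh pre-downloads.
mkdir -p models/clip-vit-large-patch14
for f in pytorch_model.bin config.json merges.txt preprocessor_config.json \
         special_tokens_map.json tokenizer.json tokenizer_config.json vocab.json; do
    wget -nc -P models/clip-vit-large-patch14 \
        "https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/${f}"
done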
@@ -61,7 +61,8 @@ from ldm.models.blip import blip_decoder
 device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
 blip_image_eval_size = 512
 #blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'
+server_state["clip_models"] = {}
+server_state["preprocesses"] = {}

 def load_blip_model():
     print("Loading BLIP Model")
@@ -219,61 +220,60 @@ def interrogate(image, models):
     #print (st.session_state["log_message"])

     for model_name in models:
-        print(f"Interrogating with {model_name}...")
-        st.session_state["log_message"].code(f"Interrogating with {model_name}...", language='')
+        with torch.no_grad(), torch.autocast('cuda', dtype=torch.float16):
+            print(f"Interrogating with {model_name}...")
+            st.session_state["log_message"].code(f"Interrogating with {model_name}...", language='')

-        if "clip_model" not in server_state:
-            #with server_state_lock[server_state["clip_model"]]:
-            if model_name == 'ViT-H-14':
-                server_state["clip_model"], _, server_state["preprocess"] = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s32b_b79k')
-            elif model_name == 'ViT-g-14':
-                server_state["clip_model"], _, server_state["preprocess"] = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s12b_b42k')
-            else:
-                server_state["clip_model"], server_state["preprocess"] = clip.load(model_name, device=device)
-
-            server_state["clip_model"] = server_state["clip_model"].cuda().eval()
-
-        images = server_state["preprocess"](image).unsqueeze(0).cuda()
-
-        with torch.no_grad():
-            image_features = server_state["clip_model"].encode_image(images).float()
-
-        image_features /= image_features.norm(dim=-1, keepdim=True)
+            if model_name not in server_state["clip_models"]:
+                if model_name == 'ViT-H-14':
+                    server_state["clip_models"][model_name], _, server_state["preprocesses"][model_name] = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s32b_b79k', cache_dir='user_data/model_cache/clip')
+                elif model_name == 'ViT-g-14':
+                    server_state["clip_models"][model_name], _, server_state["preprocesses"][model_name] = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s12b_b42k', cache_dir='user_data/model_cache/clip')
+                else:
+                    server_state["clip_models"][model_name], server_state["preprocesses"][model_name] = clip.load(model_name, device=device, download_root='user_data/model_cache/clip')
+                server_state["clip_models"][model_name] = server_state["clip_models"][model_name].cuda().eval()
+
+            images = server_state["preprocesses"][model_name](image).unsqueeze(0).cuda()
+
+            image_features = server_state["clip_models"][model_name].encode_image(images).float()
+
+            image_features /= image_features.norm(dim=-1, keepdim=True)

-        if st.session_state["defaults"].general.optimized:
-            clear_cuda()
+            if st.session_state["defaults"].general.optimized:
+                clear_cuda()

-        ranks = []
-        ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["mediums"]))
-        ranks.append(batch_rank(server_state["clip_model"], image_features, ["by "+artist for artist in server_state["artists"]]))
-        ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["trending_list"]))
-        ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["movements"]))
-        ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["flavors"]))
-        # ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["genres"]))
-        # ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["styles"]))
-        # ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["techniques"]))
-        # ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["subjects"]))
-        # ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["colors"]))
-        # ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["moods"]))
-        # ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["themes"]))
-        # ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["keywords"]))
+            ranks = []
+            ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["mediums"]))
+            ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, ["by "+artist for artist in server_state["artists"]]))
+            ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["trending_list"]))
+            ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["movements"]))
+            ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["flavors"]))
+            # ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["genres"]))
+            # ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["styles"]))
+            # ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["techniques"]))
+            # ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["subjects"]))
+            # ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["colors"]))
+            # ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["moods"]))
+            # ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["themes"]))
+            # ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["keywords"]))

-        for i in range(len(ranks)):
-            confidence_sum = 0
-            for ci in range(len(ranks[i])):
-                confidence_sum += ranks[i][ci][1]
-            if confidence_sum > sum(bests[i][t][1] for t in range(len(bests[i]))):
-                bests[i] = ranks[i]
-
-        row = [model_name]
-        for r in ranks:
-            row.append(', '.join([f"{x[0]} ({x[1]:0.1f}%)" for x in r]))
-
-        table.append(row)
-
-        if st.session_state["defaults"].general.optimized:
-            del server_state["clip_model"]
-            gc.collect()
+            for i in range(len(ranks)):
+                confidence_sum = 0
+                for ci in range(len(ranks[i])):
+                    confidence_sum += ranks[i][ci][1]
+                if confidence_sum > sum(bests[i][t][1] for t in range(len(bests[i]))):
+                    bests[i] = ranks[i]
+
+            row = [model_name]
+            for r in ranks:
+                row.append(', '.join([f"{x[0]} ({x[1]:0.1f}%)" for x in r]))
+
+            table.append(row)
+
+            if st.session_state["defaults"].general.optimized:
+                del server_state["clip_models"][model_name]
+                gc.collect()

     # for i in range(len(st.session_state["uploaded_image"])):
     st.session_state["prediction_table"][st.session_state["processed_image_count"]].dataframe(pd.DataFrame(
@@ -230,46 +230,33 @@ def load_diffusers_model(weights_path,torch_device):

     try:
         with server_state_lock["pipe"]:
-            try:
-                if not "pipe" in st.session_state or st.session_state["weights_path"] != weights_path:
-                    if st.session_state["weights_path"] != weights_path:
-                        del st.session_state["weights_path"]
-
-                    st.session_state["weights_path"] = weights_path
-                    server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
-                        weights_path,
-                        use_local_file=True,
-                        use_auth_token=st.session_state["defaults"].general.huggingface_token,
-                        torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
-                        revision="fp16" if not st.session_state['defaults'].general.no_half else None
-                    )
-
-                    server_state["pipe"].unet.to(torch_device)
-                    server_state["pipe"].vae.to(torch_device)
-                    server_state["pipe"].text_encoder.to(torch_device)
-
-                    if st.session_state.defaults.general.enable_attention_slicing:
-                        server_state["pipe"].enable_attention_slicing()
-
-                    if st.session_state.defaults.general.enable_minimal_memory_usage:
-                        server_state["pipe"].enable_minimal_memory_usage()
-
-                    print("Tx2Vid Model Loaded")
-                else:
-                    print("Tx2Vid Model already Loaded")
-
-            except:
-                #del st.session_state["weights_path"]
-                #del server_state["pipe"]
-
+            if not "pipe" in st.session_state or st.session_state["weights_path"] != weights_path:
+                if ("weights_path" in st.session_state) and st.session_state["weights_path"] != weights_path:
+                    del st.session_state["weights_path"]
+
                 st.session_state["weights_path"] = weights_path
-                server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
-                    weights_path,
-                    use_local_file=True,
-                    use_auth_token=st.session_state["defaults"].general.huggingface_token,
-                    torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
-                    revision="fp16" if not st.session_state['defaults'].general.no_half else None
-                )
+                # if folder "user_data/model_cache/stable-diffusion-v1-4" exists, load the model from there
+                if weights_path == "CompVis/stable-diffusion-v1-4":
+                    model_path = os.path.join("user_data", "model_cache", "stable-diffusion-v1-4")
+                elif weights_path == "hakurei/waifu-diffusion":
+                    model_path = os.path.join("user_data", "model_cache", "waifu-diffusion")
+
+                if not os.path.exists(model_path + "/model_index.json"):
+                    server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
+                        weights_path,
+                        use_local_file=True,
+                        use_auth_token=st.session_state["defaults"].general.huggingface_token,
+                        torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
+                        revision="fp16" if not st.session_state['defaults'].general.no_half else None
+                    )
+                    StableDiffusionPipeline.save_pretrained(server_state["pipe"], model_path)
+                else:
+                    server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
+                        model_path,
+                        use_local_file=True,
+                        torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
+                        revision="fp16" if not st.session_state['defaults'].general.no_half else None
+                    )

                 server_state["pipe"].unet.to(torch_device)
                 server_state["pipe"].vae.to(torch_device)
@@ -282,6 +269,8 @@ def load_diffusers_model(weights_path,torch_device):
                     server_state["pipe"].enable_minimal_memory_usage()

                 print("Tx2Vid Model Loaded")
+            else:
+                print("Tx2Vid Model already Loaded")
     except (EnvironmentError, OSError):
         st.session_state["progress_bar_text"].error(
             "You need a huggingface token in order to use the Text to Video tab. Use the Settings page from the sidebar on the left to add your token."