mirror of
https://github.com/Sygil-Dev/sygil-webui.git
synced 2024-12-15 06:21:34 +03:00
Merge remote-tracking branch 'origin/dev' into dev
This commit is contained in:
commit
bbedcc8e84
1
.gitignore
vendored
1
.gitignore
vendored
@ -64,3 +64,4 @@ condaenv.*.requirements.txt
|
||||
/gfpgan/*
|
||||
/models/*
|
||||
z_version_env.tmp
|
||||
/user_data/*
|
||||
|
@ -294,7 +294,7 @@ img2img:
|
||||
write_info_files: True
|
||||
|
||||
img2txt:
|
||||
batch_size: 100
|
||||
batch_size: 420
|
||||
blip_image_eval_size: 512
|
||||
|
||||
concepts_library:
|
||||
|
@ -24,7 +24,21 @@ SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
|
||||
cd $SCRIPT_DIR
|
||||
export PYTHONPATH=$SCRIPT_DIR
|
||||
|
||||
MODEL_DIR="${SCRIPT_DIR}/model_cache"
|
||||
if [[ $PUBLIC_KEY ]]
|
||||
then
|
||||
mkdir -p ~/.ssh
|
||||
chmod 700 ~/.ssh
|
||||
cd ~/.ssh
|
||||
echo $PUBLIC_KEY >> authorized_keys
|
||||
chmod 700 -R ~/.ssh
|
||||
cd /
|
||||
service ssh start
|
||||
echo "SSH Service Started"
|
||||
fi
|
||||
|
||||
|
||||
MODEL_DIR="${SCRIPT_DIR}/user_data/model_cache"
|
||||
mkdir -p $MODEL_DIR
|
||||
# Array of model files to pre-download
|
||||
# local filename
|
||||
# local path in container (no trailing slash)
|
||||
@ -37,6 +51,17 @@ MODEL_FILES=(
|
||||
'RealESRGAN_x4plus_anime_6B.pth src/realesrgan/experiments/pretrained_models https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth f872d837d3c90ed2e05227bed711af5671a6fd1c9f7d7e91c911a61f155e99da'
|
||||
'project.yaml src/latent-diffusion/experiments/pretrained_models https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1 9d6ad53c5dafeb07200fb712db14b813b527edd262bc80ea136777bdb41be2ba'
|
||||
'model.ckpt src/latent-diffusion/experiments/pretrained_models https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1 c209caecac2f97b4bb8f4d726b70ac2ac9b35904b7fc99801e1f5e61f9210c13'
|
||||
'waifu-diffusion.ckpt models/custom https://huggingface.co/crumb/pruned-waifu-diffusion/resolve/main/model-pruned.ckpt 9b31355f90fea9933847175d4731a033f49f861395addc7e153f480551a24c25'
|
||||
'trinart.ckpt models/custom https://huggingface.co/naclbit/trinart_stable_diffusion_v2/resolve/main/trinart2_step95000.ckpt c1799d22a355ba25c9ceeb6e3c91fc61788c8e274b73508ae8a15877c5dbcf63'
|
||||
'model__base_caption.pth models/blip https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth 96ac8749bd0a568c274ebe302b3a3748ab9be614c737f3d8c529697139174086'
|
||||
'pytorch_model.bin models/clip-vit-large-patch14 https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/pytorch_model.bin f1a17cdbe0f36fec524f5cafb1c261ea3bbbc13e346e0f74fc9eb0460dedd0d3'
|
||||
'config.json models/clip-vit-large-patch14 https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/config.json 8a09b467700c58138c29d53c605b34ebc69beaadd13274a8a2af8ad2c2f4032a'
|
||||
'merges.txt models/clip-vit-large-patch14 https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/merges.txt 9fd691f7c8039210e0fced15865466c65820d09b63988b0174bfe25de299051a'
|
||||
'preprocessor_config.json models/clip-vit-large-patch14 https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/preprocessor_config.json 910e70b3956ac9879ebc90b22fb3bc8a75b6a0677814500101a4c072bd7857bd'
|
||||
'special_tokens_map.json models/clip-vit-large-patch14 https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/special_tokens_map.json f8c0d6c39aee3f8431078ef6646567b0aba7f2246e9c54b8b99d55c22b707cbf'
|
||||
'tokenizer.json models/clip-vit-large-patch14 https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/tokenizer.json a83e0809aa4c3af7208b2df632a7a69668c6d48775b3c3fe4e1b1199d1f8b8f4'
|
||||
'tokenizer_config.json models/clip-vit-large-patch14 https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/tokenizer_config.json deef455e52fa5e8151e339add0582e4235f066009601360999d3a9cda83b1129'
|
||||
'vocab.json models/clip-vit-large-patch14 https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/vocab.json 3f0c4f7d2086b61b38487075278ea9ed04edb53a03cbb045b86c27190fa8fb69'
|
||||
)
|
||||
|
||||
|
||||
@ -83,33 +108,30 @@ else
|
||||
validateDownloadModel ${model[0]} ${model[1]} ${model[2]} ${model[3]}
|
||||
fi
|
||||
done
|
||||
mkdir -p ${MODEL_DIR}/stable-diffusion-v1-4
|
||||
mkdir -p ${MODEL_DIR}/waifu-diffusion
|
||||
|
||||
ln -fs ${SCRIPT_DIR}/models/clip-vit-large-patch14/ ${MODEL_DIR}/stable-diffusion-v1-4/tokenizer
|
||||
ln -fs ${SCRIPT_DIR}/models/clip-vit-large-patch14/ ${MODEL_DIR}/waifu-diffusion/tokenizer
|
||||
fi
|
||||
|
||||
# Determine which webserver interface to launch (Streamlit vs Default: Gradio)
|
||||
if [[ ! -z $WEBUI_SCRIPT && $WEBUI_SCRIPT == "webui_streamlit.py" ]]; then
|
||||
launch_command="streamlit run scripts/${WEBUI_SCRIPT:-webui.py} $WEBUI_ARGS"
|
||||
if [[ -e "${MODEL_DIR}/sd-concepts-library" ]]; then
|
||||
cd ${MODEL_DIR}/sd-concepts-library
|
||||
git pull
|
||||
else
|
||||
launch_command="python scripts/${WEBUI_SCRIPT:-webui.py} $WEBUI_ARGS"
|
||||
cd ${MODEL_DIR}
|
||||
git clone https://github.com/sd-webui/sd-concepts-library
|
||||
fi
|
||||
mkdir -p ${SCRIPT_DIR}/models/custom
|
||||
ln -fs ${MODEL_DIR}/sd-concepts-library/sd-concepts-library ${SCRIPT_DIR}/models/custom
|
||||
|
||||
# Start webserver interface
|
||||
launch_message="Starting Stable Diffusion WebUI... ${launch_command}..."
|
||||
if [[ -z $WEBUI_RELAUNCH || $WEBUI_RELAUNCH == "true" ]]; then
|
||||
n=0
|
||||
while true; do
|
||||
echo $launch_message
|
||||
echo "export HF_HOME=${MODEL_DIR}" >> ~/.bashrc
|
||||
echo "export XDG_CACHE_HOME=${MODEL_DIR}" >> ~/.bashrc
|
||||
echo "export TRANSFORMERS_CACHE=${MODEL_DIR}" >> ~/.bashrc
|
||||
source ~/.bashrc
|
||||
cd $SCRIPT_DIR
|
||||
launch_command="streamlit run ${SCRIPT_DIR}/scripts/webui_streamlit.py"
|
||||
|
||||
if (( $n > 0 )); then
|
||||
echo "Relaunch count: ${n}"
|
||||
fi
|
||||
$launch_command
|
||||
|
||||
$launch_command
|
||||
|
||||
echo "entrypoint.sh: Process is ending. Relaunching in 0.5s..."
|
||||
((n++))
|
||||
sleep 0.5
|
||||
done
|
||||
else
|
||||
echo $launch_message
|
||||
$launch_command
|
||||
fi
|
||||
sleep infinity
|
||||
|
196
frontend/js/index.js
Normal file
196
frontend/js/index.js
Normal file
@ -0,0 +1,196 @@
|
||||
window.SD = (() => {
|
||||
/*
|
||||
* Painterro is made a field of the SD global object
|
||||
* To provide convinience when using w() method in css_and_js.py
|
||||
*/
|
||||
class PainterroClass {
|
||||
static isOpen = false;
|
||||
static async init ({ x, toId }) {
|
||||
console.log(x)
|
||||
|
||||
const originalImage = x[2] === 'Mask' ? x[1]?.image : x[0];
|
||||
|
||||
if (window.Painterro === undefined) {
|
||||
try {
|
||||
await this.load();
|
||||
} catch (e) {
|
||||
SDClass.error(e);
|
||||
|
||||
return this.fallback(originalImage);
|
||||
}
|
||||
}
|
||||
|
||||
if (this.isOpen) {
|
||||
return this.fallback(originalImage);
|
||||
}
|
||||
this.isOpen = true;
|
||||
|
||||
let resolveResult;
|
||||
const paintClient = Painterro({
|
||||
hiddenTools: ['arrow'],
|
||||
onHide: () => {
|
||||
resolveResult?.(null);
|
||||
},
|
||||
saveHandler: (image, done) => {
|
||||
const data = image.asDataURL();
|
||||
|
||||
// ensures stable performance even
|
||||
// when the editor is in interactive mode
|
||||
SD.clearImageInput(SD.el.get(`#${toId}`));
|
||||
|
||||
resolveResult(data);
|
||||
|
||||
done(true);
|
||||
paintClient.hide();
|
||||
},
|
||||
});
|
||||
|
||||
const result = await new Promise((resolve) => {
|
||||
resolveResult = resolve;
|
||||
paintClient.show(originalImage);
|
||||
});
|
||||
this.isOpen = false;
|
||||
|
||||
return result ? this.success(result) : this.fallback(originalImage);
|
||||
}
|
||||
static success (result) { return [result, { image: result, mask: result }] };
|
||||
static fallback (image) { return [image, { image: image, mask: image }] };
|
||||
static load () {
|
||||
return new Promise((resolve, reject) => {
|
||||
const scriptId = '__painterro-script';
|
||||
if (document.getElementById(scriptId)) {
|
||||
reject(new Error('Tried to load painterro script, but script tag already exists.'));
|
||||
return;
|
||||
}
|
||||
|
||||
const styleId = '__painterro-css-override';
|
||||
if (!document.getElementById(styleId)) {
|
||||
/* Ensure Painterro window is always on top */
|
||||
const style = document.createElement('style');
|
||||
style.id = styleId;
|
||||
style.setAttribute('type', 'text/css');
|
||||
style.appendChild(document.createTextNode(`
|
||||
.ptro-holder-wrapper {
|
||||
z-index: 100;
|
||||
}
|
||||
`));
|
||||
document.head.appendChild(style);
|
||||
}
|
||||
|
||||
const script = document.createElement('script');
|
||||
script.id = scriptId;
|
||||
script.src = 'https://unpkg.com/painterro@1.2.78/build/painterro.min.js';
|
||||
script.onload = () => resolve(true);
|
||||
script.onerror = (e) => {
|
||||
// remove self on error to enable reattempting load
|
||||
document.head.removeChild(script);
|
||||
reject(e);
|
||||
};
|
||||
document.head.appendChild(script);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Turns out caching elements doesn't actually work in gradio
|
||||
* As elements in tabs might get recreated
|
||||
*/
|
||||
class ElementCache {
|
||||
#el;
|
||||
constructor () {
|
||||
this.root = document.querySelector('gradio-app').shadowRoot;
|
||||
}
|
||||
get (selector) {
|
||||
return this.root.querySelector(selector);
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* The main helper class to incapsulate functions
|
||||
* that change gradio ui functionality
|
||||
*/
|
||||
class SDClass {
|
||||
el = new ElementCache();
|
||||
Painterro = PainterroClass;
|
||||
moveImageFromGallery ({ x, fromId, toId }) {
|
||||
x = x[0];
|
||||
if (!Array.isArray(x) || x.length === 0) return;
|
||||
|
||||
this.clearImageInput(this.el.get(`#${toId}`));
|
||||
|
||||
const i = this.#getGallerySelectedIndex(this.el.get(`#${fromId}`));
|
||||
|
||||
return [x[i].replace('data:;','data:image/png;')];
|
||||
}
|
||||
async copyImageFromGalleryToClipboard ({ x, fromId }) {
|
||||
x = x[0];
|
||||
if (!Array.isArray(x) || x.length === 0) return;
|
||||
|
||||
const i = this.#getGallerySelectedIndex(this.el.get(`#${fromId}`));
|
||||
|
||||
const data = x[i];
|
||||
const blob = await (await fetch(data.replace('data:;','data:image/png;'))).blob();
|
||||
const item = new ClipboardItem({'image/png': blob});
|
||||
|
||||
await this.copyToClipboard([item]);
|
||||
}
|
||||
async copyFullOutput ({ fromId }) {
|
||||
const textField = this.el.get(`#${fromId} .textfield`);
|
||||
if (!textField) {
|
||||
SDclass.error(new Error(`Can't find textfield with the output!`));
|
||||
}
|
||||
|
||||
const value = textField.textContent.replace(/\s+/g,' ').replace(/: /g,':');
|
||||
|
||||
await this.copyToClipboard(value)
|
||||
}
|
||||
clickFirstVisibleButton({ rowId }) {
|
||||
const generateButtons = this.el.get(`#${rowId}`).querySelectorAll('.gr-button-primary');
|
||||
|
||||
if (!generateButtons) return;
|
||||
|
||||
for (let i = 0, arr = [...generateButtons]; i < arr.length; i++) {
|
||||
const cs = window.getComputedStyle(arr[i]);
|
||||
|
||||
if (cs.display !== 'none' && cs.visibility !== 'hidden') {
|
||||
console.log(arr[i]);
|
||||
|
||||
arr[i].click();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
async gradioInputToClipboard ({ x }) { return this.copyToClipboard(x[0]); }
|
||||
async copyToClipboard (value) {
|
||||
if (!value || typeof value === 'boolean') return;
|
||||
try {
|
||||
if (Array.isArray(value) &&
|
||||
value.length &&
|
||||
value[0] instanceof ClipboardItem) {
|
||||
await navigator.clipboard.write(value);
|
||||
} else {
|
||||
await navigator.clipboard.writeText(value);
|
||||
}
|
||||
} catch (e) {
|
||||
SDClass.error(e);
|
||||
}
|
||||
}
|
||||
static error (e) {
|
||||
console.error(e);
|
||||
if (typeof e === 'string') {
|
||||
alert(e);
|
||||
} else if(typeof e === 'object' && Object.hasOwn(e, 'message')) {
|
||||
alert(e.message);
|
||||
}
|
||||
}
|
||||
clearImageInput (imageEditor) {
|
||||
imageEditor?.querySelector('.modify-upload button:last-child')?.click();
|
||||
}
|
||||
#getGallerySelectedIndex (gallery) {
|
||||
const selected = gallery.querySelector(`.\\!ring-2`);
|
||||
return selected ? [...selected.parentNode.children].indexOf(selected) : 0;
|
||||
}
|
||||
}
|
||||
|
||||
return new SDClass();
|
||||
})();
|
@ -5,7 +5,7 @@ import clip
|
||||
from einops import rearrange, repeat
|
||||
from transformers import CLIPTokenizer, CLIPTextModel
|
||||
import kornia
|
||||
|
||||
import os
|
||||
from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
|
||||
|
||||
|
||||
@ -138,8 +138,12 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
|
||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
|
||||
super().__init__()
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(version)
|
||||
self.transformer = CLIPTextModel.from_pretrained(version)
|
||||
if os.path.exists("models/clip-vit-large-patch14"):
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained("models/clip-vit-large-patch14")
|
||||
self.transformer = CLIPTextModel.from_pretrained("models/clip-vit-large-patch14")
|
||||
else:
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(version)
|
||||
self.transformer = CLIPTextModel.from_pretrained(version)
|
||||
self.device = device
|
||||
self.max_length = max_length
|
||||
self.freeze()
|
||||
|
@ -61,7 +61,8 @@ from ldm.models.blip import blip_decoder
|
||||
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
||||
blip_image_eval_size = 512
|
||||
#blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'
|
||||
|
||||
server_state["clip_models"] = {}
|
||||
server_state["preprocesses"] = {}
|
||||
|
||||
def load_blip_model():
|
||||
print("Loading BLIP Model")
|
||||
@ -219,61 +220,60 @@ def interrogate(image, models):
|
||||
#print (st.session_state["log_message"])
|
||||
|
||||
for model_name in models:
|
||||
print(f"Interrogating with {model_name}...")
|
||||
st.session_state["log_message"].code(f"Interrogating with {model_name}...", language='')
|
||||
|
||||
if "clip_model" not in server_state:
|
||||
#with server_state_lock[server_state["clip_model"]]:
|
||||
if model_name == 'ViT-H-14':
|
||||
server_state["clip_model"], _, server_state["preprocess"] = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s32b_b79k')
|
||||
elif model_name == 'ViT-g-14':
|
||||
server_state["clip_model"], _, server_state["preprocess"] = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s12b_b42k')
|
||||
else:
|
||||
server_state["clip_model"], server_state["preprocess"] = clip.load(model_name, device=device)
|
||||
|
||||
server_state["clip_model"] = server_state["clip_model"].cuda().eval()
|
||||
|
||||
images = server_state["preprocess"](image).unsqueeze(0).cuda()
|
||||
with torch.no_grad(), torch.autocast('cuda', dtype=torch.float16):
|
||||
print(f"Interrogating with {model_name}...")
|
||||
st.session_state["log_message"].code(f"Interrogating with {model_name}...", language='')
|
||||
|
||||
if model_name not in server_state["clip_models"]:
|
||||
if model_name == 'ViT-H-14':
|
||||
server_state["clip_models"][model_name], _, server_state["preprocesses"][model_name] = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s32b_b79k', cache_dir='user_data/model_cache/clip')
|
||||
elif model_name == 'ViT-g-14':
|
||||
server_state["clip_models"][model_name], _, server_state["preprocesses"][model_name] = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s12b_b42k', cache_dir='user_data/model_cache/clip')
|
||||
else:
|
||||
server_state["clip_models"][model_name], server_state["preprocesses"][model_name] = clip.load(model_name, device=device, download_root='user_data/model_cache/clip')
|
||||
server_state["clip_models"][model_name] = server_state["clip_models"][model_name].cuda().eval()
|
||||
|
||||
images = server_state["preprocesses"][model_name](image).unsqueeze(0).cuda()
|
||||
|
||||
|
||||
image_features = server_state["clip_models"][model_name].encode_image(images).float()
|
||||
|
||||
with torch.no_grad():
|
||||
image_features = server_state["clip_model"].encode_image(images).float()
|
||||
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
if st.session_state["defaults"].general.optimized:
|
||||
clear_cuda()
|
||||
|
||||
ranks = []
|
||||
ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["mediums"]))
|
||||
ranks.append(batch_rank(server_state["clip_model"], image_features, ["by "+artist for artist in server_state["artists"]]))
|
||||
ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["trending_list"]))
|
||||
ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["movements"]))
|
||||
ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["flavors"]))
|
||||
# ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["genres"]))
|
||||
# ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["styles"]))
|
||||
# ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["techniques"]))
|
||||
# ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["subjects"]))
|
||||
# ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["colors"]))
|
||||
# ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["moods"]))
|
||||
# ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["themes"]))
|
||||
# ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["keywords"]))
|
||||
if st.session_state["defaults"].general.optimized:
|
||||
clear_cuda()
|
||||
|
||||
ranks = []
|
||||
ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["mediums"]))
|
||||
ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, ["by "+artist for artist in server_state["artists"]]))
|
||||
ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["trending_list"]))
|
||||
ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["movements"]))
|
||||
ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["flavors"]))
|
||||
# ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["genres"]))
|
||||
# ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["styles"]))
|
||||
# ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["techniques"]))
|
||||
# ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["subjects"]))
|
||||
# ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["colors"]))
|
||||
# ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["moods"]))
|
||||
# ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["themes"]))
|
||||
# ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["keywords"]))
|
||||
|
||||
for i in range(len(ranks)):
|
||||
confidence_sum = 0
|
||||
for ci in range(len(ranks[i])):
|
||||
confidence_sum += ranks[i][ci][1]
|
||||
if confidence_sum > sum(bests[i][t][1] for t in range(len(bests[i]))):
|
||||
bests[i] = ranks[i]
|
||||
for i in range(len(ranks)):
|
||||
confidence_sum = 0
|
||||
for ci in range(len(ranks[i])):
|
||||
confidence_sum += ranks[i][ci][1]
|
||||
if confidence_sum > sum(bests[i][t][1] for t in range(len(bests[i]))):
|
||||
bests[i] = ranks[i]
|
||||
|
||||
row = [model_name]
|
||||
for r in ranks:
|
||||
row.append(', '.join([f"{x[0]} ({x[1]:0.1f}%)" for x in r]))
|
||||
row = [model_name]
|
||||
for r in ranks:
|
||||
row.append(', '.join([f"{x[0]} ({x[1]:0.1f}%)" for x in r]))
|
||||
|
||||
table.append(row)
|
||||
table.append(row)
|
||||
|
||||
if st.session_state["defaults"].general.optimized:
|
||||
del server_state["clip_model"]
|
||||
gc.collect()
|
||||
if st.session_state["defaults"].general.optimized:
|
||||
del server_state["clip_models"][model_name]
|
||||
gc.collect()
|
||||
|
||||
# for i in range(len(st.session_state["uploaded_image"])):
|
||||
st.session_state["prediction_table"][st.session_state["processed_image_count"]].dataframe(pd.DataFrame(
|
||||
|
@ -230,58 +230,47 @@ def load_diffusers_model(weights_path,torch_device):
|
||||
|
||||
try:
|
||||
with server_state_lock["pipe"]:
|
||||
try:
|
||||
if not "pipe" in st.session_state or st.session_state["weights_path"] != weights_path:
|
||||
if st.session_state["weights_path"] != weights_path:
|
||||
del st.session_state["weights_path"]
|
||||
|
||||
st.session_state["weights_path"] = weights_path
|
||||
server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
|
||||
weights_path,
|
||||
use_local_file=True,
|
||||
use_auth_token=st.session_state["defaults"].general.huggingface_token,
|
||||
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
|
||||
revision="fp16" if not st.session_state['defaults'].general.no_half else None
|
||||
)
|
||||
|
||||
server_state["pipe"].unet.to(torch_device)
|
||||
server_state["pipe"].vae.to(torch_device)
|
||||
server_state["pipe"].text_encoder.to(torch_device)
|
||||
|
||||
if st.session_state.defaults.general.enable_attention_slicing:
|
||||
server_state["pipe"].enable_attention_slicing()
|
||||
|
||||
if st.session_state.defaults.general.enable_minimal_memory_usage:
|
||||
server_state["pipe"].enable_minimal_memory_usage()
|
||||
|
||||
print("Tx2Vid Model Loaded")
|
||||
else:
|
||||
print("Tx2Vid Model already Loaded")
|
||||
if not "pipe" in st.session_state or st.session_state["weights_path"] != weights_path:
|
||||
if ("weights_path" in st.session_state) and st.session_state["weights_path"] != weights_path:
|
||||
del st.session_state["weights_path"]
|
||||
|
||||
except:
|
||||
#del st.session_state["weights_path"]
|
||||
#del server_state["pipe"]
|
||||
|
||||
st.session_state["weights_path"] = weights_path
|
||||
server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
|
||||
weights_path,
|
||||
use_local_file=True,
|
||||
use_auth_token=st.session_state["defaults"].general.huggingface_token,
|
||||
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
|
||||
revision="fp16" if not st.session_state['defaults'].general.no_half else None
|
||||
)
|
||||
|
||||
# if folder "user_data/model_cache/stable-diffusion-v1-4" exists, load the model from there
|
||||
if weights_path == "CompVis/stable-diffusion-v1-4":
|
||||
model_path = os.path.join("user_data", "model_cache", "stable-diffusion-v1-4")
|
||||
elif weights_path == "hakurei/waifu-diffusion":
|
||||
model_path = os.path.join("user_data", "model_cache", "waifu-diffusion")
|
||||
|
||||
if not os.path.exists(model_path + "/model_index.json"):
|
||||
server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
|
||||
weights_path,
|
||||
use_local_file=True,
|
||||
use_auth_token=st.session_state["defaults"].general.huggingface_token,
|
||||
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
|
||||
revision="fp16" if not st.session_state['defaults'].general.no_half else None
|
||||
)
|
||||
StableDiffusionPipeline.save_pretrained(server_state["pipe"], model_path)
|
||||
else:
|
||||
server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
|
||||
model_path,
|
||||
use_local_file=True,
|
||||
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
|
||||
revision="fp16" if not st.session_state['defaults'].general.no_half else None
|
||||
)
|
||||
|
||||
server_state["pipe"].unet.to(torch_device)
|
||||
server_state["pipe"].vae.to(torch_device)
|
||||
server_state["pipe"].text_encoder.to(torch_device)
|
||||
|
||||
|
||||
if st.session_state.defaults.general.enable_attention_slicing:
|
||||
server_state["pipe"].enable_attention_slicing()
|
||||
|
||||
if st.session_state.defaults.general.enable_minimal_memory_usage:
|
||||
|
||||
if st.session_state.defaults.general.enable_minimal_memory_usage:
|
||||
server_state["pipe"].enable_minimal_memory_usage()
|
||||
|
||||
print("Tx2Vid Model Loaded")
|
||||
|
||||
print("Tx2Vid Model Loaded")
|
||||
else:
|
||||
print("Tx2Vid Model already Loaded")
|
||||
except (EnvironmentError, OSError):
|
||||
st.session_state["progress_bar_text"].error(
|
||||
"You need a huggingface token in order to use the Text to Video tab. Use the Settings page from the sidebar on the left to add your token."
|
||||
|
Loading…
Reference in New Issue
Block a user