Merge remote-tracking branch 'origin/dev' into dev

ZeroCool940711 committed 2022-10-01 19:19:04 -07:00
commit bbedcc8e84
7 changed files with 334 additions and 122 deletions

.gitignore (vendored), 1 addition
View File

@@ -64,3 +64,4 @@ condaenv.*.requirements.txt
 /gfpgan/*
 /models/*
 z_version_env.tmp
+/user_data/*

View File

@@ -294,7 +294,7 @@ img2img:
     write_info_files: True
 img2txt:
-    batch_size: 100
+    batch_size: 420
     blip_image_eval_size: 512
 concepts_library:

View File

@@ -24,7 +24,21 @@ SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
 cd $SCRIPT_DIR
 export PYTHONPATH=$SCRIPT_DIR
-MODEL_DIR="${SCRIPT_DIR}/model_cache"
+if [[ $PUBLIC_KEY ]]
+then
+    mkdir -p ~/.ssh
+    chmod 700 ~/.ssh
+    cd ~/.ssh
+    echo $PUBLIC_KEY >> authorized_keys
+    chmod 700 -R ~/.ssh
+    cd /
+    service ssh start
+    echo "SSH Service Started"
+fi
+
+MODEL_DIR="${SCRIPT_DIR}/user_data/model_cache"
+mkdir -p $MODEL_DIR

 # Array of model files to pre-download
 # local filename
 # local path in container (no trailing slash)
@@ -37,6 +51,17 @@ MODEL_FILES=(
     'RealESRGAN_x4plus_anime_6B.pth src/realesrgan/experiments/pretrained_models https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth f872d837d3c90ed2e05227bed711af5671a6fd1c9f7d7e91c911a61f155e99da'
     'project.yaml src/latent-diffusion/experiments/pretrained_models https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1 9d6ad53c5dafeb07200fb712db14b813b527edd262bc80ea136777bdb41be2ba'
     'model.ckpt src/latent-diffusion/experiments/pretrained_models https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1 c209caecac2f97b4bb8f4d726b70ac2ac9b35904b7fc99801e1f5e61f9210c13'
+    'waifu-diffusion.ckpt models/custom https://huggingface.co/crumb/pruned-waifu-diffusion/resolve/main/model-pruned.ckpt 9b31355f90fea9933847175d4731a033f49f861395addc7e153f480551a24c25'
+    'trinart.ckpt models/custom https://huggingface.co/naclbit/trinart_stable_diffusion_v2/resolve/main/trinart2_step95000.ckpt c1799d22a355ba25c9ceeb6e3c91fc61788c8e274b73508ae8a15877c5dbcf63'
+    'model__base_caption.pth models/blip https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth 96ac8749bd0a568c274ebe302b3a3748ab9be614c737f3d8c529697139174086'
+    'pytorch_model.bin models/clip-vit-large-patch14 https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/pytorch_model.bin f1a17cdbe0f36fec524f5cafb1c261ea3bbbc13e346e0f74fc9eb0460dedd0d3'
+    'config.json models/clip-vit-large-patch14 https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/config.json 8a09b467700c58138c29d53c605b34ebc69beaadd13274a8a2af8ad2c2f4032a'
+    'merges.txt models/clip-vit-large-patch14 https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/merges.txt 9fd691f7c8039210e0fced15865466c65820d09b63988b0174bfe25de299051a'
+    'preprocessor_config.json models/clip-vit-large-patch14 https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/preprocessor_config.json 910e70b3956ac9879ebc90b22fb3bc8a75b6a0677814500101a4c072bd7857bd'
+    'special_tokens_map.json models/clip-vit-large-patch14 https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/special_tokens_map.json f8c0d6c39aee3f8431078ef6646567b0aba7f2246e9c54b8b99d55c22b707cbf'
+    'tokenizer.json models/clip-vit-large-patch14 https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/tokenizer.json a83e0809aa4c3af7208b2df632a7a69668c6d48775b3c3fe4e1b1199d1f8b8f4'
+    'tokenizer_config.json models/clip-vit-large-patch14 https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/tokenizer_config.json deef455e52fa5e8151e339add0582e4235f066009601360999d3a9cda83b1129'
+    'vocab.json models/clip-vit-large-patch14 https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/vocab.json 3f0c4f7d2086b61b38487075278ea9ed04edb53a03cbb045b86c27190fa8fb69'
 )
@@ -83,33 +108,30 @@ else
         validateDownloadModel ${model[0]} ${model[1]} ${model[2]} ${model[3]}
     fi
   done
+  mkdir -p ${MODEL_DIR}/stable-diffusion-v1-4
+  mkdir -p ${MODEL_DIR}/waifu-diffusion
+  ln -fs ${SCRIPT_DIR}/models/clip-vit-large-patch14/ ${MODEL_DIR}/stable-diffusion-v1-4/tokenizer
+  ln -fs ${SCRIPT_DIR}/models/clip-vit-large-patch14/ ${MODEL_DIR}/waifu-diffusion/tokenizer
 fi

-# Determine which webserver interface to launch (Streamlit vs Default: Gradio)
-if [[ ! -z $WEBUI_SCRIPT && $WEBUI_SCRIPT == "webui_streamlit.py" ]]; then
-    launch_command="streamlit run scripts/${WEBUI_SCRIPT:-webui.py} $WEBUI_ARGS"
+if [[ -e "${MODEL_DIR}/sd-concepts-library" ]]; then
+    cd ${MODEL_DIR}/sd-concepts-library
+    git pull
 else
-    launch_command="python scripts/${WEBUI_SCRIPT:-webui.py} $WEBUI_ARGS"
+    cd ${MODEL_DIR}
+    git clone https://github.com/sd-webui/sd-concepts-library
 fi
+mkdir -p ${SCRIPT_DIR}/models/custom
+ln -fs ${MODEL_DIR}/sd-concepts-library/sd-concepts-library ${SCRIPT_DIR}/models/custom

-# Start webserver interface
-launch_message="Starting Stable Diffusion WebUI... ${launch_command}..."
-if [[ -z $WEBUI_RELAUNCH || $WEBUI_RELAUNCH == "true" ]]; then
-    n=0
-    while true; do
-        echo $launch_message
-        if (( $n > 0 )); then
-            echo "Relaunch count: ${n}"
-        fi
-        $launch_command
-        echo "entrypoint.sh: Process is ending. Relaunching in 0.5s..."
-        ((n++))
-        sleep 0.5
-    done
-else
-    echo $launch_message
-    $launch_command
-fi
+echo "export HF_HOME=${MODEL_DIR}" >> ~/.bashrc
+echo "export XDG_CACHE_HOME=${MODEL_DIR}" >> ~/.bashrc
+echo "export TRANSFORMERS_CACHE=${MODEL_DIR}" >> ~/.bashrc
+source ~/.bashrc
+
+cd $SCRIPT_DIR
+launch_command="streamlit run ${SCRIPT_DIR}/scripts/webui_streamlit.py"
+
+$launch_command
+
+sleep infinity
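
Each MODEL_FILES entry above packs four space-separated fields: a local filename, a target directory, a download URL, and a SHA-256 checksum. The validateDownloadModel helper that consumes these entries is a bash function in entrypoint.sh and is not shown in this hunk, so the Python sketch below is only a hypothetical illustration of the download-and-verify step those fields imply (the function name and layout are assumptions, not code from this commit):

# Hypothetical sketch of the download-and-verify step implied by a MODEL_FILES entry
# ('<filename> <target dir> <url> <sha256>'); the real validateDownloadModel is not
# part of this diff.
import hashlib
import os
import urllib.request

def fetch_model(entry: str, root: str = ".") -> str:
    filename, target_dir, url, sha256 = entry.split()
    path = os.path.join(root, target_dir, filename)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    if not os.path.exists(path):
        urllib.request.urlretrieve(url, path)  # download only when missing
    digest = hashlib.sha256(open(path, "rb").read()).hexdigest()
    if digest != sha256:  # reject truncated or corrupted files
        raise ValueError(f"Checksum mismatch for {filename}: {digest}")
    return path

Verifying the digest before use is what lets the entrypoint skip re-downloading multi-gigabyte checkpoints on every container start while still catching partial downloads.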

frontend/js/index.js (new file), 196 additions
View File

@@ -0,0 +1,196 @@
window.SD = (() => {
/*
* Painterro is made a field of the SD global object
* To provide convenience when using the w() method in css_and_js.py
*/
class PainterroClass {
static isOpen = false;
static async init ({ x, toId }) {
console.log(x)
const originalImage = x[2] === 'Mask' ? x[1]?.image : x[0];
if (window.Painterro === undefined) {
try {
await this.load();
} catch (e) {
SDClass.error(e);
return this.fallback(originalImage);
}
}
if (this.isOpen) {
return this.fallback(originalImage);
}
this.isOpen = true;
let resolveResult;
const paintClient = Painterro({
hiddenTools: ['arrow'],
onHide: () => {
resolveResult?.(null);
},
saveHandler: (image, done) => {
const data = image.asDataURL();
// ensures stable performance even
// when the editor is in interactive mode
SD.clearImageInput(SD.el.get(`#${toId}`));
resolveResult(data);
done(true);
paintClient.hide();
},
});
const result = await new Promise((resolve) => {
resolveResult = resolve;
paintClient.show(originalImage);
});
this.isOpen = false;
return result ? this.success(result) : this.fallback(originalImage);
}
static success (result) { return [result, { image: result, mask: result }] };
static fallback (image) { return [image, { image: image, mask: image }] };
static load () {
return new Promise((resolve, reject) => {
const scriptId = '__painterro-script';
if (document.getElementById(scriptId)) {
reject(new Error('Tried to load painterro script, but script tag already exists.'));
return;
}
const styleId = '__painterro-css-override';
if (!document.getElementById(styleId)) {
/* Ensure Painterro window is always on top */
const style = document.createElement('style');
style.id = styleId;
style.setAttribute('type', 'text/css');
style.appendChild(document.createTextNode(`
.ptro-holder-wrapper {
z-index: 100;
}
`));
document.head.appendChild(style);
}
const script = document.createElement('script');
script.id = scriptId;
script.src = 'https://unpkg.com/painterro@1.2.78/build/painterro.min.js';
script.onload = () => resolve(true);
script.onerror = (e) => {
// remove self on error to enable reattempting load
document.head.removeChild(script);
reject(e);
};
document.head.appendChild(script);
});
}
}
/*
* Turns out caching elements doesn't actually work in gradio
* As elements in tabs might get recreated
*/
class ElementCache {
#el;
constructor () {
this.root = document.querySelector('gradio-app').shadowRoot;
}
get (selector) {
return this.root.querySelector(selector);
}
}
/*
* The main helper class to encapsulate functions
* that change gradio ui functionality
*/
class SDClass {
el = new ElementCache();
Painterro = PainterroClass;
moveImageFromGallery ({ x, fromId, toId }) {
x = x[0];
if (!Array.isArray(x) || x.length === 0) return;
this.clearImageInput(this.el.get(`#${toId}`));
const i = this.#getGallerySelectedIndex(this.el.get(`#${fromId}`));
return [x[i].replace('data:;','data:image/png;')];
}
async copyImageFromGalleryToClipboard ({ x, fromId }) {
x = x[0];
if (!Array.isArray(x) || x.length === 0) return;
const i = this.#getGallerySelectedIndex(this.el.get(`#${fromId}`));
const data = x[i];
const blob = await (await fetch(data.replace('data:;','data:image/png;'))).blob();
const item = new ClipboardItem({'image/png': blob});
await this.copyToClipboard([item]);
}
async copyFullOutput ({ fromId }) {
const textField = this.el.get(`#${fromId} .textfield`);
if (!textField) {
SDClass.error(new Error(`Can't find textfield with the output!`));
return;
}
const value = textField.textContent.replace(/\s+/g,' ').replace(/: /g,':');
await this.copyToClipboard(value)
}
clickFirstVisibleButton({ rowId }) {
const generateButtons = this.el.get(`#${rowId}`).querySelectorAll('.gr-button-primary');
if (!generateButtons) return;
for (let i = 0, arr = [...generateButtons]; i < arr.length; i++) {
const cs = window.getComputedStyle(arr[i]);
if (cs.display !== 'none' && cs.visibility !== 'hidden') {
console.log(arr[i]);
arr[i].click();
break;
}
}
}
async gradioInputToClipboard ({ x }) { return this.copyToClipboard(x[0]); }
async copyToClipboard (value) {
if (!value || typeof value === 'boolean') return;
try {
if (Array.isArray(value) &&
value.length &&
value[0] instanceof ClipboardItem) {
await navigator.clipboard.write(value);
} else {
await navigator.clipboard.writeText(value);
}
} catch (e) {
SDClass.error(e);
}
}
static error (e) {
console.error(e);
if (typeof e === 'string') {
alert(e);
} else if(typeof e === 'object' && Object.hasOwn(e, 'message')) {
alert(e.message);
}
}
clearImageInput (imageEditor) {
imageEditor?.querySelector('.modify-upload button:last-child')?.click();
}
#getGallerySelectedIndex (gallery) {
const selected = gallery.querySelector(`.\\!ring-2`);
return selected ? [...selected.parentNode.children].indexOf(selected) : 0;
}
}
return new SDClass();
})();

View File

@@ -5,7 +5,7 @@ import clip
 from einops import rearrange, repeat
 from transformers import CLIPTokenizer, CLIPTextModel
 import kornia
+import os
 from ldm.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
@@ -138,8 +138,12 @@ class FrozenCLIPEmbedder(AbstractEncoder):
     """Uses the CLIP transformer encoder for text (from Hugging Face)"""
     def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
         super().__init__()
-        self.tokenizer = CLIPTokenizer.from_pretrained(version)
-        self.transformer = CLIPTextModel.from_pretrained(version)
+        if os.path.exists("models/clip-vit-large-patch14"):
+            self.tokenizer = CLIPTokenizer.from_pretrained("models/clip-vit-large-patch14")
+            self.transformer = CLIPTextModel.from_pretrained("models/clip-vit-large-patch14")
+        else:
+            self.tokenizer = CLIPTokenizer.from_pretrained(version)
+            self.transformer = CLIPTextModel.from_pretrained(version)
         self.device = device
         self.max_length = max_length
         self.freeze()
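
This FrozenCLIPEmbedder change pairs with the entrypoint.sh hunk earlier in the commit, which pre-downloads the clip-vit-large-patch14 tokenizer and text-encoder files into models/clip-vit-large-patch14. A condensed sketch of the same local-first pattern, using an illustrative helper name that is not part of the diff:

# Condensed, illustrative version of the local-first loading added above: use the
# pre-downloaded copy in models/clip-vit-large-patch14 when present, otherwise fall
# back to fetching `version` from the Hugging Face Hub.
import os
from transformers import CLIPTextModel, CLIPTokenizer

def load_clip_text(version="openai/clip-vit-large-patch14",
                   local_dir="models/clip-vit-large-patch14"):
    source = local_dir if os.path.exists(local_dir) else version
    return CLIPTokenizer.from_pretrained(source), CLIPTextModel.from_pretrained(source)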

View File

@@ -61,7 +61,8 @@ from ldm.models.blip import blip_decoder
 device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
 blip_image_eval_size = 512
 #blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'
+server_state["clip_models"] = {}
+server_state["preprocesses"] = {}

 def load_blip_model():
     print("Loading BLIP Model")
@@ -219,61 +220,60 @@ def interrogate(image, models):
     #print (st.session_state["log_message"])
     for model_name in models:
-        print(f"Interrogating with {model_name}...")
-        st.session_state["log_message"].code(f"Interrogating with {model_name}...", language='')
-        if "clip_model" not in server_state:
-            #with server_state_lock[server_state["clip_model"]]:
-            if model_name == 'ViT-H-14':
-                server_state["clip_model"], _, server_state["preprocess"] = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s32b_b79k')
-            elif model_name == 'ViT-g-14':
-                server_state["clip_model"], _, server_state["preprocess"] = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s12b_b42k')
-            else:
-                server_state["clip_model"], server_state["preprocess"] = clip.load(model_name, device=device)
-            server_state["clip_model"] = server_state["clip_model"].cuda().eval()
-        images = server_state["preprocess"](image).unsqueeze(0).cuda()
-        with torch.no_grad():
-            image_features = server_state["clip_model"].encode_image(images).float()
-        image_features /= image_features.norm(dim=-1, keepdim=True)
-        if st.session_state["defaults"].general.optimized:
-            clear_cuda()
-        ranks = []
-        ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["mediums"]))
-        ranks.append(batch_rank(server_state["clip_model"], image_features, ["by "+artist for artist in server_state["artists"]]))
-        ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["trending_list"]))
-        ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["movements"]))
-        ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["flavors"]))
-        # ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["genres"]))
-        # ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["styles"]))
-        # ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["techniques"]))
-        # ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["subjects"]))
-        # ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["colors"]))
-        # ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["moods"]))
-        # ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["themes"]))
-        # ranks.append(batch_rank(server_state["clip_model"], image_features, server_state["keywords"]))
-        for i in range(len(ranks)):
-            confidence_sum = 0
-            for ci in range(len(ranks[i])):
-                confidence_sum += ranks[i][ci][1]
-            if confidence_sum > sum(bests[i][t][1] for t in range(len(bests[i]))):
-                bests[i] = ranks[i]
-        row = [model_name]
-        for r in ranks:
-            row.append(', '.join([f"{x[0]} ({x[1]:0.1f}%)" for x in r]))
-        table.append(row)
-        if st.session_state["defaults"].general.optimized:
-            del server_state["clip_model"]
-            gc.collect()
+        with torch.no_grad(), torch.autocast('cuda', dtype=torch.float16):
+            print(f"Interrogating with {model_name}...")
+            st.session_state["log_message"].code(f"Interrogating with {model_name}...", language='')
+            if model_name not in server_state["clip_models"]:
+                if model_name == 'ViT-H-14':
+                    server_state["clip_models"][model_name], _, server_state["preprocesses"][model_name] = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s32b_b79k', cache_dir='user_data/model_cache/clip')
+                elif model_name == 'ViT-g-14':
+                    server_state["clip_models"][model_name], _, server_state["preprocesses"][model_name] = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s12b_b42k', cache_dir='user_data/model_cache/clip')
+                else:
+                    server_state["clip_models"][model_name], server_state["preprocesses"][model_name] = clip.load(model_name, device=device, download_root='user_data/model_cache/clip')
+                server_state["clip_models"][model_name] = server_state["clip_models"][model_name].cuda().eval()
+            images = server_state["preprocesses"][model_name](image).unsqueeze(0).cuda()
+            image_features = server_state["clip_models"][model_name].encode_image(images).float()
+            image_features /= image_features.norm(dim=-1, keepdim=True)
+            if st.session_state["defaults"].general.optimized:
+                clear_cuda()
+            ranks = []
+            ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["mediums"]))
+            ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, ["by "+artist for artist in server_state["artists"]]))
+            ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["trending_list"]))
+            ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["movements"]))
+            ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["flavors"]))
+            # ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["genres"]))
+            # ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["styles"]))
+            # ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["techniques"]))
+            # ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["subjects"]))
+            # ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["colors"]))
+            # ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["moods"]))
+            # ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["themes"]))
+            # ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["keywords"]))
+            for i in range(len(ranks)):
+                confidence_sum = 0
+                for ci in range(len(ranks[i])):
+                    confidence_sum += ranks[i][ci][1]
+                if confidence_sum > sum(bests[i][t][1] for t in range(len(bests[i]))):
+                    bests[i] = ranks[i]
+            row = [model_name]
+            for r in ranks:
+                row.append(', '.join([f"{x[0]} ({x[1]:0.1f}%)" for x in r]))
+            table.append(row)
+            if st.session_state["defaults"].general.optimized:
+                del server_state["clip_models"][model_name]
+                gc.collect()
     # for i in range(len(st.session_state["uploaded_image"])):
     st.session_state["prediction_table"][st.session_state["processed_image_count"]].dataframe(pd.DataFrame(

View File

@@ -230,58 +230,47 @@ def load_diffusers_model(weights_path,torch_device):
     try:
         with server_state_lock["pipe"]:
-            try:
-                if not "pipe" in st.session_state or st.session_state["weights_path"] != weights_path:
-                    if st.session_state["weights_path"] != weights_path:
-                        del st.session_state["weights_path"]
-                    st.session_state["weights_path"] = weights_path
-                    server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
-                        weights_path,
-                        use_local_file=True,
-                        use_auth_token=st.session_state["defaults"].general.huggingface_token,
-                        torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
-                        revision="fp16" if not st.session_state['defaults'].general.no_half else None
-                    )
-                    server_state["pipe"].unet.to(torch_device)
-                    server_state["pipe"].vae.to(torch_device)
-                    server_state["pipe"].text_encoder.to(torch_device)
-                    if st.session_state.defaults.general.enable_attention_slicing:
-                        server_state["pipe"].enable_attention_slicing()
-                    if st.session_state.defaults.general.enable_minimal_memory_usage:
-                        server_state["pipe"].enable_minimal_memory_usage()
-                    print("Tx2Vid Model Loaded")
-                else:
-                    print("Tx2Vid Model already Loaded")
-            except:
-                #del st.session_state["weights_path"]
-                #del server_state["pipe"]
+            if not "pipe" in st.session_state or st.session_state["weights_path"] != weights_path:
+                if ("weights_path" in st.session_state) and st.session_state["weights_path"] != weights_path:
+                    del st.session_state["weights_path"]
                 st.session_state["weights_path"] = weights_path
-                server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
-                    weights_path,
-                    use_local_file=True,
-                    use_auth_token=st.session_state["defaults"].general.huggingface_token,
-                    torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
-                    revision="fp16" if not st.session_state['defaults'].general.no_half else None
-                )
+                # if folder "user_data/model_cache/stable-diffusion-v1-4" exists, load the model from there
+                if weights_path == "CompVis/stable-diffusion-v1-4":
+                    model_path = os.path.join("user_data", "model_cache", "stable-diffusion-v1-4")
+                elif weights_path == "hakurei/waifu-diffusion":
+                    model_path = os.path.join("user_data", "model_cache", "waifu-diffusion")
+
+                if not os.path.exists(model_path + "/model_index.json"):
+                    server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
+                        weights_path,
+                        use_local_file=True,
+                        use_auth_token=st.session_state["defaults"].general.huggingface_token,
+                        torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
+                        revision="fp16" if not st.session_state['defaults'].general.no_half else None
+                    )
+                    StableDiffusionPipeline.save_pretrained(server_state["pipe"], model_path)
+                else:
+                    server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
+                        model_path,
+                        use_local_file=True,
+                        torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
+                        revision="fp16" if not st.session_state['defaults'].general.no_half else None
+                    )
                 server_state["pipe"].unet.to(torch_device)
                 server_state["pipe"].vae.to(torch_device)
                 server_state["pipe"].text_encoder.to(torch_device)
                 if st.session_state.defaults.general.enable_attention_slicing:
                     server_state["pipe"].enable_attention_slicing()
                 if st.session_state.defaults.general.enable_minimal_memory_usage:
                     server_state["pipe"].enable_minimal_memory_usage()
                 print("Tx2Vid Model Loaded")
+            else:
+                print("Tx2Vid Model already Loaded")
     except (EnvironmentError, OSError):
         st.session_state["progress_bar_text"].error(
             "You need a huggingface token in order to use the Text to Video tab. Use the Settings page from the sidebar on the left to add your token."