diff --git a/.eslintignore b/.eslintignore
new file mode 100644
index 00000000..1cfd9487
--- /dev/null
+++ b/.eslintignore
@@ -0,0 +1,4 @@
+extensions
+extensions-disabled
+repositories
+venv
\ No newline at end of file
diff --git a/.eslintrc.js b/.eslintrc.js
new file mode 100644
index 00000000..78275554
--- /dev/null
+++ b/.eslintrc.js
@@ -0,0 +1,89 @@
+module.exports = {
+ env: {
+ browser: true,
+ es2021: true,
+ },
+ extends: "eslint:recommended",
+ parserOptions: {
+ ecmaVersion: "latest",
+ },
+ rules: {
+ "arrow-spacing": "error",
+ "block-spacing": "error",
+ "brace-style": "error",
+ "comma-dangle": ["error", "only-multiline"],
+ "comma-spacing": "error",
+ "comma-style": ["error", "last"],
+ "curly": ["error", "multi-line", "consistent"],
+ "eol-last": "error",
+ "func-call-spacing": "error",
+ "function-call-argument-newline": ["error", "consistent"],
+ "function-paren-newline": ["error", "consistent"],
+ "indent": ["error", 4],
+ "key-spacing": "error",
+ "keyword-spacing": "error",
+ "linebreak-style": ["error", "unix"],
+ "no-extra-semi": "error",
+ "no-mixed-spaces-and-tabs": "error",
+ "no-trailing-spaces": "error",
+ "no-whitespace-before-property": "error",
+ "object-curly-newline": ["error", {consistent: true, multiline: true}],
+ "quote-props": ["error", "consistent-as-needed"],
+ "semi": ["error", "always"],
+ "semi-spacing": "error",
+ "semi-style": ["error", "last"],
+ "space-before-blocks": "error",
+ "space-before-function-paren": ["error", "never"],
+ "space-in-parens": ["error", "never"],
+ "space-infix-ops": "error",
+ "space-unary-ops": "error",
+ "switch-colon-spacing": "error",
+ "template-curly-spacing": ["error", "never"],
+ "unicode-bom": "error",
+ "no-multi-spaces": "error",
+ "object-curly-spacing": ["error", "never"],
+ "operator-linebreak": ["error", "after"],
+ "no-unused-vars": "off",
+ "no-redeclare": "off",
+ },
+ globals: {
+ // this file
+ module: "writable",
+ //script.js
+ gradioApp: "writable",
+ onUiLoaded: "writable",
+ onUiUpdate: "writable",
+ onOptionsChanged: "writable",
+ uiCurrentTab: "writable",
+ uiElementIsVisible: "writable",
+ executeCallbacks: "writable",
+ //ui.js
+ opts: "writable",
+ all_gallery_buttons: "writable",
+ selected_gallery_button: "writable",
+ selected_gallery_index: "writable",
+ args_to_array: "writable",
+ switch_to_txt2img: "writable",
+ switch_to_img2img_tab: "writable",
+ switch_to_img2img: "writable",
+ switch_to_sketch: "writable",
+ switch_to_inpaint: "writable",
+ switch_to_inpaint_sketch: "writable",
+ switch_to_extras: "writable",
+ get_tab_index: "writable",
+ create_submit_args: "writable",
+ restart_reload: "writable",
+ updateInput: "writable",
+ //extraNetworks.js
+ requestGet: "writable",
+ popup: "writable",
+ // from python
+ localization: "writable",
+ // progrssbar.js
+ randomId: "writable",
+ requestProgress: "writable",
+ // imageviewer.js
+ modalPrevImage: "writable",
+ modalNextImage: "writable",
+ }
+};
diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml
index 7d435297..3a8b9953 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.yml
+++ b/.github/ISSUE_TEMPLATE/bug_report.yml
@@ -47,6 +47,15 @@ body:
description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)
validations:
required: true
+ - type: dropdown
+ id: py-version
+ attributes:
+ label: What Python version are you running on ?
+ multiple: false
+ options:
+ - Python 3.10.x
+ - Python 3.11.x (above, no supported yet)
+ - Python 3.9.x (below, no recommended)
- type: dropdown
id: platforms
attributes:
@@ -59,6 +68,18 @@ body:
- iOS
- Android
- Other/Cloud
+ - type: dropdown
+ id: device
+ attributes:
+ label: What device are you running WebUI on?
+ multiple: true
+ options:
+ - Nvidia GPUs (RTX 20 above)
+ - Nvidia GPUs (GTX 16 below)
+ - AMD GPUs (RX 6000 above)
+ - AMD GPUs (RX 5000 below)
+ - CPU
+ - Other GPUs
- type: dropdown
id: browsers
attributes:
diff --git a/.github/workflows/on_pull_request.yaml b/.github/workflows/on_pull_request.yaml
index a168be5b..7b7219fd 100644
--- a/.github/workflows/on_pull_request.yaml
+++ b/.github/workflows/on_pull_request.yaml
@@ -1,39 +1,34 @@
-# See https://github.com/actions/starter-workflows/blob/1067f16ad8a1eac328834e4b0ae24f7d206f810d/ci/pylint.yml for original reference file
name: Run Linting/Formatting on Pull Requests
on:
- push
- pull_request
- # See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#onpull_requestpull_request_targetbranchesbranches-ignore for syntax docs
- # if you want to filter out branches, delete the `- pull_request` and uncomment these lines :
- # pull_request:
- # branches:
- # - master
- # branches-ignore:
- # - development
jobs:
- lint:
+ lint-python:
runs-on: ubuntu-latest
steps:
- name: Checkout Code
uses: actions/checkout@v3
- - name: Set up Python 3.10
- uses: actions/setup-python@v4
+ - uses: actions/setup-python@v4
with:
- python-version: 3.10.6
- cache: pip
- cache-dependency-path: |
- **/requirements*txt
- - name: Install PyLint
- run: |
- python -m pip install --upgrade pip
- pip install pylint
- # This lets PyLint check to see if it can resolve imports
- - name: Install dependencies
- run: |
- export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit"
- python launch.py
- - name: Analysing the code with pylint
- run: |
- pylint $(git ls-files '*.py')
+ python-version: 3.11
+ # NB: there's no cache: pip here since we're not installing anything
+ # from the requirements.txt file(s) in the repository; it's faster
+ # not to have GHA download an (at the time of writing) 4 GB cache
+ # of PyTorch and other dependencies.
+ - name: Install Ruff
+ run: pip install ruff==0.0.265
+ - name: Run Ruff
+ run: ruff .
+ lint-js:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout Code
+ uses: actions/checkout@v3
+ - name: Install Node.js
+ uses: actions/setup-node@v3
+ with:
+ node-version: 18
+ - run: npm i --ci
+ - run: npm run lint
diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml
index 9a0b8d22..0708398b 100644
--- a/.github/workflows/run_tests.yaml
+++ b/.github/workflows/run_tests.yaml
@@ -17,8 +17,14 @@ jobs:
cache: pip
cache-dependency-path: |
**/requirements*txt
+ launch.py
- name: Run tests
run: python launch.py --tests test --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
+ env:
+ PIP_DISABLE_PIP_VERSION_CHECK: "1"
+ PIP_PROGRESS_BAR: "off"
+ TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu
+ WEBUI_LAUNCH_LIVE_OUTPUT: "1"
- name: Upload main app stdout-stderr
uses: actions/upload-artifact@v3
if: always()
diff --git a/.gitignore b/.gitignore
index 7328401f..46654d83 100644
--- a/.gitignore
+++ b/.gitignore
@@ -34,3 +34,5 @@ notification.mp3
/test/stderr.txt
/cache.json*
/config_states/
+/node_modules
+/package-lock.json
\ No newline at end of file
diff --git a/README.md b/README.md
index 67a1a83a..79089e52 100644
--- a/README.md
+++ b/README.md
@@ -99,6 +99,12 @@ Alternatively, use online services (like Google Colab):
- [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)
+### Installation on Windows 10/11 with NVidia-GPUs using release package
+1. Download `sd.webui.zip` from [v1.0.0-pre](https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/tag/v1.0.0-pre) and extract it's contents.
+2. Run `update.bat`.
+3. Run `run.bat`.
+> For more details see [Install-and-Run-on-NVidia-GPUs](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs)
+
### Automatic Installation on Windows
1. Install [Python 3.10.6](https://www.python.org/downloads/release/python-3106/) (Newer version of Python does not support torch), checking "Add Python to PATH".
2. Install [git](https://git-scm.com/download/win).
@@ -158,5 +164,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
- Instruct pix2pix - Tim Brooks (star), Aleksander Holynski (star), Alexei A. Efros (no star) - https://github.com/timothybrooks/instruct-pix2pix
- Security advice - RyotaK
- UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
+- TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
- (You)
diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py
index bc11cc6e..7f450086 100644
--- a/extensions-builtin/LDSR/ldsr_model_arch.py
+++ b/extensions-builtin/LDSR/ldsr_model_arch.py
@@ -88,7 +88,7 @@ class LDSR:
x_t = None
logs = None
- for n in range(n_runs):
+ for _ in range(n_runs):
if custom_shape is not None:
x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
@@ -110,7 +110,6 @@ class LDSR:
diffusion_steps = int(steps)
eta = 1.0
- down_sample_method = 'Lanczos'
gc.collect()
if torch.cuda.is_available:
@@ -131,11 +130,11 @@ class LDSR:
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
else:
print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
-
+
# pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
-
+
logs = self.run(model["model"], im_padded, diffusion_steps, eta)
sample = logs["sample"]
@@ -158,7 +157,7 @@ class LDSR:
def get_cond(selected_path):
- example = dict()
+ example = {}
up_f = 4
c = selected_path.convert('RGB')
c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
@@ -196,7 +195,7 @@ def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_s
@torch.no_grad()
def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
- log = dict()
+ log = {}
z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
return_first_stage_outputs=True,
@@ -244,7 +243,7 @@ def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize
x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
log["sample_noquant"] = x_sample_noquant
log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
- except:
+ except Exception:
pass
log["sample"] = x_sample
diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py
index da19cff1..fbbe9005 100644
--- a/extensions-builtin/LDSR/scripts/ldsr_model.py
+++ b/extensions-builtin/LDSR/scripts/ldsr_model.py
@@ -7,7 +7,8 @@ from basicsr.utils.download_util import load_file_from_url
from modules.upscaler import Upscaler, UpscalerData
from ldsr_model_arch import LDSR
from modules import shared, script_callbacks
-import sd_hijack_autoencoder, sd_hijack_ddpm_v1
+import sd_hijack_autoencoder # noqa: F401
+import sd_hijack_ddpm_v1 # noqa: F401
class UpscalerLDSR(Upscaler):
diff --git a/extensions-builtin/LDSR/sd_hijack_autoencoder.py b/extensions-builtin/LDSR/sd_hijack_autoencoder.py
index 8e03c7f8..81c5101b 100644
--- a/extensions-builtin/LDSR/sd_hijack_autoencoder.py
+++ b/extensions-builtin/LDSR/sd_hijack_autoencoder.py
@@ -1,16 +1,21 @@
# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
# The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
# As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder
-
+import numpy as np
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager
+
+from torch.optim.lr_scheduler import LambdaLR
+
+from ldm.modules.ema import LitEma
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.util import instantiate_from_config
import ldm.models.autoencoder
+from packaging import version
class VQModel(pl.LightningModule):
def __init__(self,
@@ -19,7 +24,7 @@ class VQModel(pl.LightningModule):
n_embed,
embed_dim,
ckpt_path=None,
- ignore_keys=[],
+ ignore_keys=None,
image_key="image",
colorize_nlabels=None,
monitor=None,
@@ -57,7 +62,7 @@ class VQModel(pl.LightningModule):
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [])
self.scheduler_config = scheduler_config
self.lr_g_factor = lr_g_factor
@@ -76,11 +81,11 @@ class VQModel(pl.LightningModule):
if context is not None:
print(f"{context}: Restored training weights")
- def init_from_ckpt(self, path, ignore_keys=list()):
+ def init_from_ckpt(self, path, ignore_keys=None):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
- for ik in ignore_keys:
+ for ik in ignore_keys or []:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
@@ -165,7 +170,7 @@ class VQModel(pl.LightningModule):
def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
- log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
+ self._validation_step(batch, batch_idx, suffix="_ema")
return log_dict
def _validation_step(self, batch, batch_idx, suffix=""):
@@ -232,7 +237,7 @@ class VQModel(pl.LightningModule):
return self.decoder.conv_out.weight
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
- log = dict()
+ log = {}
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if only_inputs:
@@ -249,7 +254,8 @@ class VQModel(pl.LightningModule):
if plot_ema:
with self.ema_scope():
xrec_ema, _ = self(x)
- if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
+ if x.shape[1] > 3:
+ xrec_ema = self.to_rgb(xrec_ema)
log["reconstructions_ema"] = xrec_ema
return log
@@ -264,7 +270,7 @@ class VQModel(pl.LightningModule):
class VQModelInterface(VQModel):
def __init__(self, embed_dim, *args, **kwargs):
- super().__init__(embed_dim=embed_dim, *args, **kwargs)
+ super().__init__(*args, embed_dim=embed_dim, **kwargs)
self.embed_dim = embed_dim
def encode(self, x):
@@ -282,5 +288,5 @@ class VQModelInterface(VQModel):
dec = self.decoder(quant)
return dec
-setattr(ldm.models.autoencoder, "VQModel", VQModel)
-setattr(ldm.models.autoencoder, "VQModelInterface", VQModelInterface)
+ldm.models.autoencoder.VQModel = VQModel
+ldm.models.autoencoder.VQModelInterface = VQModelInterface
diff --git a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py
index 5c0488e5..631a08ef 100644
--- a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py
+++ b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py
@@ -48,7 +48,7 @@ class DDPMV1(pl.LightningModule):
beta_schedule="linear",
loss_type="l2",
ckpt_path=None,
- ignore_keys=[],
+ ignore_keys=None,
load_only_unet=False,
monitor="val/loss",
use_ema=True,
@@ -100,7 +100,7 @@ class DDPMV1(pl.LightningModule):
if monitor is not None:
self.monitor = monitor
if ckpt_path is not None:
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)
self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
@@ -182,13 +182,13 @@ class DDPMV1(pl.LightningModule):
if context is not None:
print(f"{context}: Restored training weights")
- def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+ def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()):
sd = sd["state_dict"]
keys = list(sd.keys())
for k in keys:
- for ik in ignore_keys:
+ for ik in ignore_keys or []:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
@@ -375,7 +375,7 @@ class DDPMV1(pl.LightningModule):
@torch.no_grad()
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
- log = dict()
+ log = {}
x = self.get_input(batch, self.first_stage_key)
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
@@ -383,7 +383,7 @@ class DDPMV1(pl.LightningModule):
log["inputs"] = x
# get diffusion row
- diffusion_row = list()
+ diffusion_row = []
x_start = x[:n_row]
for t in range(self.num_timesteps):
@@ -444,13 +444,13 @@ class LatentDiffusionV1(DDPMV1):
conditioning_key = None
ckpt_path = kwargs.pop("ckpt_path", None)
ignore_keys = kwargs.pop("ignore_keys", [])
- super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
+ super().__init__(*args, conditioning_key=conditioning_key, **kwargs)
self.concat_mode = concat_mode
self.cond_stage_trainable = cond_stage_trainable
self.cond_stage_key = cond_stage_key
try:
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
- except:
+ except Exception:
self.num_downs = 0
if not scale_by_std:
self.scale_factor = scale_factor
@@ -460,7 +460,7 @@ class LatentDiffusionV1(DDPMV1):
self.instantiate_cond_stage(cond_stage_config)
self.cond_stage_forward = cond_stage_forward
self.clip_denoised = False
- self.bbox_tokenizer = None
+ self.bbox_tokenizer = None
self.restarted_from_ckpt = False
if ckpt_path is not None:
@@ -792,7 +792,7 @@ class LatentDiffusionV1(DDPMV1):
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
# 2. apply model loop over last dim
- if isinstance(self.first_stage_model, VQModelInterface):
+ if isinstance(self.first_stage_model, VQModelInterface):
output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
force_not_quantize=predict_cids or force_not_quantize)
for i in range(z.shape[-1])]
@@ -877,16 +877,6 @@ class LatentDiffusionV1(DDPMV1):
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
return self.p_losses(x, c, t, *args, **kwargs)
- def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
- def rescale_bbox(bbox):
- x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
- y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
- w = min(bbox[2] / crop_coordinates[2], 1 - x0)
- h = min(bbox[3] / crop_coordinates[3], 1 - y0)
- return x0, y0, w, h
-
- return [rescale_bbox(b) for b in bboxes]
-
def apply_model(self, x_noisy, t, cond, return_ids=False):
if isinstance(cond, dict):
@@ -900,7 +890,7 @@ class LatentDiffusionV1(DDPMV1):
if hasattr(self, "split_input_params"):
assert len(cond) == 1 # todo can only deal with one conditioning atm
- assert not return_ids
+ assert not return_ids
ks = self.split_input_params["ks"] # eg. (128, 128)
stride = self.split_input_params["stride"] # eg. (64, 64)
@@ -1126,7 +1116,7 @@ class LatentDiffusionV1(DDPMV1):
if cond is not None:
if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ [x[:batch_size] for x in cond[key]] for key in cond}
else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
@@ -1157,8 +1147,10 @@ class LatentDiffusionV1(DDPMV1):
if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(x0_partial)
- if callback: callback(i)
- if img_callback: img_callback(img, i)
+ if callback:
+ callback(i)
+ if img_callback:
+ img_callback(img, i)
return img, intermediates
@torch.no_grad()
@@ -1205,8 +1197,10 @@ class LatentDiffusionV1(DDPMV1):
if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(img)
- if callback: callback(i)
- if img_callback: img_callback(img, i)
+ if callback:
+ callback(i)
+ if img_callback:
+ img_callback(img, i)
if return_intermediates:
return img, intermediates
@@ -1221,7 +1215,7 @@ class LatentDiffusionV1(DDPMV1):
if cond is not None:
if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ [x[:batch_size] for x in cond[key]] for key in cond}
else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
return self.p_sample_loop(cond,
@@ -1253,7 +1247,7 @@ class LatentDiffusionV1(DDPMV1):
use_ddim = ddim_steps is not None
- log = dict()
+ log = {}
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
return_first_stage_outputs=True,
force_c_encode=True,
@@ -1280,7 +1274,7 @@ class LatentDiffusionV1(DDPMV1):
if plot_diffusion_rows:
# get diffusion row
- diffusion_row = list()
+ diffusion_row = []
z_start = z[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
@@ -1322,7 +1316,7 @@ class LatentDiffusionV1(DDPMV1):
if inpaint:
# make a simple center square
- b, h, w = z.shape[0], z.shape[2], z.shape[3]
+ h, w = z.shape[2], z.shape[3]
mask = torch.ones(N, h, w).to(self.device)
# zeros will be filled in
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
@@ -1424,10 +1418,10 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1):
# TODO: move all layout-specific hacks to this class
def __init__(self, cond_stage_key, *args, **kwargs):
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
- super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
+ super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)
def log_images(self, batch, N=8, *args, **kwargs):
- logs = super().log_images(batch=batch, N=N, *args, **kwargs)
+ logs = super().log_images(*args, batch=batch, N=N, **kwargs)
key = 'train' if self.training else 'validation'
dset = self.trainer.datamodule.datasets[key]
@@ -1443,7 +1437,7 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1):
logs['bbox_image'] = cond_img
return logs
-setattr(ldm.models.diffusion.ddpm, "DDPMV1", DDPMV1)
-setattr(ldm.models.diffusion.ddpm, "LatentDiffusionV1", LatentDiffusionV1)
-setattr(ldm.models.diffusion.ddpm, "DiffusionWrapperV1", DiffusionWrapperV1)
-setattr(ldm.models.diffusion.ddpm, "Layout2ImgDiffusionV1", Layout2ImgDiffusionV1)
+ldm.models.diffusion.ddpm.DDPMV1 = DDPMV1
+ldm.models.diffusion.ddpm.LatentDiffusionV1 = LatentDiffusionV1
+ldm.models.diffusion.ddpm.DiffusionWrapperV1 = DiffusionWrapperV1
+ldm.models.diffusion.ddpm.Layout2ImgDiffusionV1 = Layout2ImgDiffusionV1
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index b5d0c98f..1308c48b 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -1,4 +1,3 @@
-import glob
import os
import re
import torch
@@ -177,7 +176,7 @@ def load_lora(name, filename):
else:
print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
continue
- assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
+ raise AssertionError(f"Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}")
with torch.no_grad():
module.weight.copy_(weight)
@@ -189,7 +188,7 @@ def load_lora(name, filename):
elif lora_key == "lora_down.weight":
lora_module.down = module
else:
- assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'
+ raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha")
if len(keys_failed_to_match) > 0:
print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
@@ -207,7 +206,7 @@ def load_loras(names, multipliers=None):
loaded_loras.clear()
loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
- if any([x is None for x in loras_on_disk]):
+ if any(x is None for x in loras_on_disk):
list_available_loras()
loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
@@ -314,7 +313,7 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu
print(f'failed to calculate lora weights for layer {lora_layer_name}')
- setattr(self, "lora_current_names", wanted_names)
+ self.lora_current_names = wanted_names
def lora_forward(module, input, original_forward):
@@ -348,8 +347,8 @@ def lora_forward(module, input, original_forward):
def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
- setattr(self, "lora_current_names", ())
- setattr(self, "lora_weights_backup", None)
+ self.lora_current_names = ()
+ self.lora_weights_backup = None
def lora_Linear_forward(self, input):
@@ -428,7 +427,7 @@ def infotext_pasted(infotext, params):
added = []
- for k, v in params.items():
+ for k in params:
if not k.startswith("AddNet Model "):
continue
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
index 060bda05..728e0b86 100644
--- a/extensions-builtin/Lora/scripts/lora_script.py
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -53,7 +53,7 @@ script_callbacks.on_infotext_pasted(lora.infotext_pasted)
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
- "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
+ "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None", *lora.available_loras]}, refresh=lora.list_available_loras),
"lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
}))
diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py
index c7fd5739..cc2cbc6a 100644
--- a/extensions-builtin/ScuNET/scripts/scunet_model.py
+++ b/extensions-builtin/ScuNET/scripts/scunet_model.py
@@ -10,10 +10,9 @@ from tqdm import tqdm
from basicsr.utils.download_util import load_file_from_url
import modules.upscaler
-from modules import devices, modelloader
+from modules import devices, modelloader, script_callbacks
from scunet_model_arch import SCUNet as net
from modules.shared import opts
-from modules import images
class UpscalerScuNET(modules.upscaler.Upscaler):
@@ -133,8 +132,19 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
model.load_state_dict(torch.load(filename), strict=True)
model.eval()
- for k, v in model.named_parameters():
+ for _, v in model.named_parameters():
v.requires_grad = False
model = model.to(device)
return model
+
+
+def on_ui_settings():
+ import gradio as gr
+ from modules import shared
+
+ shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")).info("0 = no tiling"))
+ shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam"))
+
+
+script_callbacks.on_ui_settings(on_ui_settings)
diff --git a/extensions-builtin/ScuNET/scunet_model_arch.py b/extensions-builtin/ScuNET/scunet_model_arch.py
index 43ca8d36..b51a8806 100644
--- a/extensions-builtin/ScuNET/scunet_model_arch.py
+++ b/extensions-builtin/ScuNET/scunet_model_arch.py
@@ -61,7 +61,9 @@ class WMSA(nn.Module):
Returns:
output: tensor shape [b h w c]
"""
- if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
+ if self.type != 'W':
+ x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
+
x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
h_windows = x.size(1)
w_windows = x.size(2)
@@ -85,8 +87,9 @@ class WMSA(nn.Module):
output = self.linear(output)
output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
- if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2),
- dims=(1, 2))
+ if self.type != 'W':
+ output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), dims=(1, 2))
+
return output
def relative_embedding(self):
@@ -262,4 +265,4 @@ class SCUNet(nn.Module):
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
\ No newline at end of file
+ nn.init.constant_(m.weight, 1.0)
diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py
index e8783bca..0ba50487 100644
--- a/extensions-builtin/SwinIR/scripts/swinir_model.py
+++ b/extensions-builtin/SwinIR/scripts/swinir_model.py
@@ -1,4 +1,3 @@
-import contextlib
import os
import numpy as np
@@ -8,7 +7,7 @@ from basicsr.utils.download_util import load_file_from_url
from tqdm import tqdm
from modules import modelloader, devices, script_callbacks, shared
-from modules.shared import cmd_opts, opts, state
+from modules.shared import opts, state
from swinir_model_arch import SwinIR as net
from swinir_model_arch_v2 import Swin2SR as net2
from modules.upscaler import Upscaler, UpscalerData
@@ -45,7 +44,7 @@ class UpscalerSwinIR(Upscaler):
img = upscale(img, model)
try:
torch.cuda.empty_cache()
- except:
+ except Exception:
pass
return img
@@ -151,7 +150,7 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
for w_idx in w_idx_list:
if state.interrupted or state.skipped:
break
-
+
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
out_patch = model(in_patch)
out_patch_mask = torch.ones_like(out_patch)
diff --git a/extensions-builtin/SwinIR/swinir_model_arch.py b/extensions-builtin/SwinIR/swinir_model_arch.py
index 863f42db..93b93274 100644
--- a/extensions-builtin/SwinIR/swinir_model_arch.py
+++ b/extensions-builtin/SwinIR/swinir_model_arch.py
@@ -644,7 +644,7 @@ class SwinIR(nn.Module):
"""
def __init__(self, img_size=64, patch_size=1, in_chans=3,
- embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
+ embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
@@ -805,7 +805,7 @@ class SwinIR(nn.Module):
def forward(self, x):
H, W = x.shape[2:]
x = self.check_image_size(x)
-
+
self.mean = self.mean.type_as(x)
x = (x - self.mean) * self.img_range
@@ -844,7 +844,7 @@ class SwinIR(nn.Module):
H, W = self.patches_resolution
flops += H * W * 3 * self.embed_dim * 9
flops += self.patch_embed.flops()
- for i, layer in enumerate(self.layers):
+ for layer in self.layers:
flops += layer.flops()
flops += H * W * 3 * self.embed_dim * self.embed_dim
flops += self.upsample.flops()
diff --git a/extensions-builtin/SwinIR/swinir_model_arch_v2.py b/extensions-builtin/SwinIR/swinir_model_arch_v2.py
index 0e28ae6e..dad22cca 100644
--- a/extensions-builtin/SwinIR/swinir_model_arch_v2.py
+++ b/extensions-builtin/SwinIR/swinir_model_arch_v2.py
@@ -74,7 +74,7 @@ class WindowAttention(nn.Module):
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
- pretrained_window_size=[0, 0]):
+ pretrained_window_size=(0, 0)):
super().__init__()
self.dim = dim
@@ -241,7 +241,7 @@ class SwinTransformerBlock(nn.Module):
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
-
+
def calculate_mask(self, x_size):
# calculate attention mask for SW-MSA
H, W = x_size
@@ -263,7 +263,7 @@ class SwinTransformerBlock(nn.Module):
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
- return attn_mask
+ return attn_mask
def forward(self, x, x_size):
H, W = x_size
@@ -288,7 +288,7 @@ class SwinTransformerBlock(nn.Module):
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
else:
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
-
+
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
@@ -369,7 +369,7 @@ class PatchMerging(nn.Module):
H, W = self.input_resolution
flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
flops += H * W * self.dim // 2
- return flops
+ return flops
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
@@ -447,7 +447,7 @@ class BasicLayer(nn.Module):
nn.init.constant_(blk.norm1.weight, 0)
nn.init.constant_(blk.norm2.bias, 0)
nn.init.constant_(blk.norm2.weight, 0)
-
+
class PatchEmbed(nn.Module):
r""" Image to Patch Embedding
Args:
@@ -492,7 +492,7 @@ class PatchEmbed(nn.Module):
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
if self.norm is not None:
flops += Ho * Wo * self.embed_dim
- return flops
+ return flops
class RSTB(nn.Module):
"""Residual Swin Transformer Block (RSTB).
@@ -531,7 +531,7 @@ class RSTB(nn.Module):
num_heads=num_heads,
window_size=window_size,
mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
+ qkv_bias=qkv_bias,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path,
norm_layer=norm_layer,
@@ -622,7 +622,7 @@ class Upsample(nn.Sequential):
else:
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
super(Upsample, self).__init__(*m)
-
+
class Upsample_hf(nn.Sequential):
"""Upsample module.
@@ -642,7 +642,7 @@ class Upsample_hf(nn.Sequential):
m.append(nn.PixelShuffle(3))
else:
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
- super(Upsample_hf, self).__init__(*m)
+ super(Upsample_hf, self).__init__(*m)
class UpsampleOneStep(nn.Sequential):
@@ -667,8 +667,8 @@ class UpsampleOneStep(nn.Sequential):
H, W = self.input_resolution
flops = H * W * self.num_feat * 3 * 9
return flops
-
-
+
+
class Swin2SR(nn.Module):
r""" Swin2SR
@@ -698,8 +698,8 @@ class Swin2SR(nn.Module):
"""
def __init__(self, img_size=64, patch_size=1, in_chans=3,
- embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
- window_size=7, mlp_ratio=4., qkv_bias=True,
+ embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
+ window_size=7, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
@@ -764,7 +764,7 @@ class Swin2SR(nn.Module):
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
- qkv_bias=qkv_bias,
+ qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
norm_layer=norm_layer,
@@ -776,7 +776,7 @@ class Swin2SR(nn.Module):
)
self.layers.append(layer)
-
+
if self.upsampler == 'pixelshuffle_hf':
self.layers_hf = nn.ModuleList()
for i_layer in range(self.num_layers):
@@ -787,7 +787,7 @@ class Swin2SR(nn.Module):
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
- qkv_bias=qkv_bias,
+ qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
norm_layer=norm_layer,
@@ -799,7 +799,7 @@ class Swin2SR(nn.Module):
)
self.layers_hf.append(layer)
-
+
self.norm = norm_layer(self.num_features)
# build the last conv layer in deep feature extraction
@@ -829,10 +829,10 @@ class Swin2SR(nn.Module):
self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.conv_after_aux = nn.Sequential(
nn.Conv2d(3, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
+ nn.LeakyReLU(inplace=True))
self.upsample = Upsample(upscale, num_feat)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
-
+
elif self.upsampler == 'pixelshuffle_hf':
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True))
@@ -846,7 +846,7 @@ class Swin2SR(nn.Module):
nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True))
self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
-
+
elif self.upsampler == 'pixelshuffledirect':
# for lightweight SR (to save parameters)
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
@@ -905,7 +905,7 @@ class Swin2SR(nn.Module):
x = self.patch_unembed(x, x_size)
return x
-
+
def forward_features_hf(self, x):
x_size = (x.shape[2], x.shape[3])
x = self.patch_embed(x)
@@ -919,7 +919,7 @@ class Swin2SR(nn.Module):
x = self.norm(x) # B L C
x = self.patch_unembed(x, x_size)
- return x
+ return x
def forward(self, x):
H, W = x.shape[2:]
@@ -951,7 +951,7 @@ class Swin2SR(nn.Module):
x = self.conv_after_body(self.forward_features(x)) + x
x_before = self.conv_before_upsample(x)
x_out = self.conv_last(self.upsample(x_before))
-
+
x_hf = self.conv_first_hf(x_before)
x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
x_hf = self.conv_before_upsample_hf(x_hf)
@@ -977,15 +977,15 @@ class Swin2SR(nn.Module):
x_first = self.conv_first(x)
res = self.conv_after_body(self.forward_features(x_first)) + x_first
x = x + self.conv_last(res)
-
+
x = x / self.img_range + self.mean
if self.upsampler == "pixelshuffle_aux":
return x[:, :, :H*self.upscale, :W*self.upscale], aux
-
+
elif self.upsampler == "pixelshuffle_hf":
x_out = x_out / self.img_range + self.mean
return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
-
+
else:
return x[:, :, :H*self.upscale, :W*self.upscale]
@@ -994,7 +994,7 @@ class Swin2SR(nn.Module):
H, W = self.patches_resolution
flops += H * W * 3 * self.embed_dim * 9
flops += self.patch_embed.flops()
- for i, layer in enumerate(self.layers):
+ for layer in self.layers:
flops += layer.flops()
flops += H * W * 3 * self.embed_dim * self.embed_dim
flops += self.upsample.flops()
@@ -1014,4 +1014,4 @@ if __name__ == '__main__':
x = torch.randn((1, 3, height, width))
x = model(x)
- print(x.shape)
\ No newline at end of file
+ print(x.shape)
diff --git a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js
index 5c7a836a..114cf94c 100644
--- a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js
+++ b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js
@@ -4,39 +4,39 @@
// If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
function checkBrackets(textArea, counterElt) {
- var counts = {};
- (textArea.value.match(/[(){}\[\]]/g) || []).forEach(bracket => {
- counts[bracket] = (counts[bracket] || 0) + 1;
- });
- var errors = [];
+ var counts = {};
+ (textArea.value.match(/[(){}[\]]/g) || []).forEach(bracket => {
+ counts[bracket] = (counts[bracket] || 0) + 1;
+ });
+ var errors = [];
- function checkPair(open, close, kind) {
- if (counts[open] !== counts[close]) {
- errors.push(
- `${open}...${close} - Detected ${counts[open] || 0} opening and ${counts[close] || 0} closing ${kind}.`
- );
+ function checkPair(open, close, kind) {
+ if (counts[open] !== counts[close]) {
+ errors.push(
+ `${open}...${close} - Detected ${counts[open] || 0} opening and ${counts[close] || 0} closing ${kind}.`
+ );
+ }
}
- }
- checkPair('(', ')', 'round brackets');
- checkPair('[', ']', 'square brackets');
- checkPair('{', '}', 'curly brackets');
- counterElt.title = errors.join('\n');
- counterElt.classList.toggle('error', errors.length !== 0);
+ checkPair('(', ')', 'round brackets');
+ checkPair('[', ']', 'square brackets');
+ checkPair('{', '}', 'curly brackets');
+ counterElt.title = errors.join('\n');
+ counterElt.classList.toggle('error', errors.length !== 0);
}
function setupBracketChecking(id_prompt, id_counter) {
- var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
- var counter = gradioApp().getElementById(id_counter)
+ var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
+ var counter = gradioApp().getElementById(id_counter);
- if (textarea && counter) {
- textarea.addEventListener("input", () => checkBrackets(textarea, counter));
- }
+ if (textarea && counter) {
+ textarea.addEventListener("input", () => checkBrackets(textarea, counter));
+ }
}
-onUiLoaded(function () {
- setupBracketChecking('txt2img_prompt', 'txt2img_token_counter');
- setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter');
- setupBracketChecking('img2img_prompt', 'img2img_token_counter');
- setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter');
+onUiLoaded(function() {
+ setupBracketChecking('txt2img_prompt', 'txt2img_token_counter');
+ setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter');
+ setupBracketChecking('img2img_prompt', 'img2img_token_counter');
+ setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter');
});
diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html
index 1d546217..6853b14f 100644
--- a/html/extra-networks-card.html
+++ b/html/extra-networks-card.html
@@ -6,7 +6,7 @@
- {search_term}
+ {search_term}
{name}
{description}
diff --git a/html/licenses.html b/html/licenses.html
index bc995aa0..ef6f2c0a 100644
--- a/html/licenses.html
+++ b/html/licenses.html
@@ -661,4 +661,30 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
+
+
+
+Tiny AutoEncoder for Stable Diffusion option for live previews
+
+MIT License
+
+Copyright (c) 2023 Ollin Boer Bohan
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/javascript/aspectRatioOverlay.js b/javascript/aspectRatioOverlay.js
index 5160081d..1c08a1a9 100644
--- a/javascript/aspectRatioOverlay.js
+++ b/javascript/aspectRatioOverlay.js
@@ -1,111 +1,113 @@
-
-let currentWidth = null;
-let currentHeight = null;
-let arFrameTimeout = setTimeout(function(){},0);
-
-function dimensionChange(e, is_width, is_height){
-
- if(is_width){
- currentWidth = e.target.value*1.0
- }
- if(is_height){
- currentHeight = e.target.value*1.0
- }
-
- var inImg2img = gradioApp().querySelector("#tab_img2img").style.display == "block";
-
- if(!inImg2img){
- return;
- }
-
- var targetElement = null;
-
- var tabIndex = get_tab_index('mode_img2img')
- if(tabIndex == 0){ // img2img
- targetElement = gradioApp().querySelector('#img2img_image div[data-testid=image] img');
- } else if(tabIndex == 1){ //Sketch
- targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img');
- } else if(tabIndex == 2){ // Inpaint
- targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img');
- } else if(tabIndex == 3){ // Inpaint sketch
- targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img');
- }
-
-
- if(targetElement){
-
- var arPreviewRect = gradioApp().querySelector('#imageARPreview');
- if(!arPreviewRect){
- arPreviewRect = document.createElement('div')
- arPreviewRect.id = "imageARPreview";
- gradioApp().appendChild(arPreviewRect)
- }
-
-
-
- var viewportOffset = targetElement.getBoundingClientRect();
-
- var viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight )
-
- var scaledx = targetElement.naturalWidth*viewportscale
- var scaledy = targetElement.naturalHeight*viewportscale
-
- var cleintRectTop = (viewportOffset.top+window.scrollY)
- var cleintRectLeft = (viewportOffset.left+window.scrollX)
- var cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2)
- var cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2)
-
- var arscale = Math.min( scaledx/currentWidth, scaledy/currentHeight )
- var arscaledx = currentWidth*arscale
- var arscaledy = currentHeight*arscale
-
- var arRectTop = cleintRectCentreY-(arscaledy/2)
- var arRectLeft = cleintRectCentreX-(arscaledx/2)
- var arRectWidth = arscaledx
- var arRectHeight = arscaledy
-
- arPreviewRect.style.top = arRectTop+'px';
- arPreviewRect.style.left = arRectLeft+'px';
- arPreviewRect.style.width = arRectWidth+'px';
- arPreviewRect.style.height = arRectHeight+'px';
-
- clearTimeout(arFrameTimeout);
- arFrameTimeout = setTimeout(function(){
- arPreviewRect.style.display = 'none';
- },2000);
-
- arPreviewRect.style.display = 'block';
-
- }
-
-}
-
-
-onUiUpdate(function(){
- var arPreviewRect = gradioApp().querySelector('#imageARPreview');
- if(arPreviewRect){
- arPreviewRect.style.display = 'none';
- }
- var tabImg2img = gradioApp().querySelector("#tab_img2img");
- if (tabImg2img) {
- var inImg2img = tabImg2img.style.display == "block";
- if(inImg2img){
- let inputs = gradioApp().querySelectorAll('input');
- inputs.forEach(function(e){
- var is_width = e.parentElement.id == "img2img_width"
- var is_height = e.parentElement.id == "img2img_height"
-
- if((is_width || is_height) && !e.classList.contains('scrollwatch')){
- e.addEventListener('input', function(e){dimensionChange(e, is_width, is_height)} )
- e.classList.add('scrollwatch')
- }
- if(is_width){
- currentWidth = e.value*1.0
- }
- if(is_height){
- currentHeight = e.value*1.0
- }
- })
- }
- }
-});
+
+let currentWidth = null;
+let currentHeight = null;
+let arFrameTimeout = setTimeout(function() {}, 0);
+
+function dimensionChange(e, is_width, is_height) {
+
+ if (is_width) {
+ currentWidth = e.target.value * 1.0;
+ }
+ if (is_height) {
+ currentHeight = e.target.value * 1.0;
+ }
+
+ var inImg2img = gradioApp().querySelector("#tab_img2img").style.display == "block";
+
+ if (!inImg2img) {
+ return;
+ }
+
+ var targetElement = null;
+
+ var tabIndex = get_tab_index('mode_img2img');
+ if (tabIndex == 0) { // img2img
+ targetElement = gradioApp().querySelector('#img2img_image div[data-testid=image] img');
+ } else if (tabIndex == 1) { //Sketch
+ targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img');
+ } else if (tabIndex == 2) { // Inpaint
+ targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img');
+ } else if (tabIndex == 3) { // Inpaint sketch
+ targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img');
+ }
+
+
+ if (targetElement) {
+
+ var arPreviewRect = gradioApp().querySelector('#imageARPreview');
+ if (!arPreviewRect) {
+ arPreviewRect = document.createElement('div');
+ arPreviewRect.id = "imageARPreview";
+ gradioApp().appendChild(arPreviewRect);
+ }
+
+
+
+ var viewportOffset = targetElement.getBoundingClientRect();
+
+ var viewportscale = Math.min(targetElement.clientWidth / targetElement.naturalWidth, targetElement.clientHeight / targetElement.naturalHeight);
+
+ var scaledx = targetElement.naturalWidth * viewportscale;
+ var scaledy = targetElement.naturalHeight * viewportscale;
+
+ var cleintRectTop = (viewportOffset.top + window.scrollY);
+ var cleintRectLeft = (viewportOffset.left + window.scrollX);
+ var cleintRectCentreY = cleintRectTop + (targetElement.clientHeight / 2);
+ var cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth / 2);
+
+ var arscale = Math.min(scaledx / currentWidth, scaledy / currentHeight);
+ var arscaledx = currentWidth * arscale;
+ var arscaledy = currentHeight * arscale;
+
+ var arRectTop = cleintRectCentreY - (arscaledy / 2);
+ var arRectLeft = cleintRectCentreX - (arscaledx / 2);
+ var arRectWidth = arscaledx;
+ var arRectHeight = arscaledy;
+
+ arPreviewRect.style.top = arRectTop + 'px';
+ arPreviewRect.style.left = arRectLeft + 'px';
+ arPreviewRect.style.width = arRectWidth + 'px';
+ arPreviewRect.style.height = arRectHeight + 'px';
+
+ clearTimeout(arFrameTimeout);
+ arFrameTimeout = setTimeout(function() {
+ arPreviewRect.style.display = 'none';
+ }, 2000);
+
+ arPreviewRect.style.display = 'block';
+
+ }
+
+}
+
+
+onUiUpdate(function() {
+ var arPreviewRect = gradioApp().querySelector('#imageARPreview');
+ if (arPreviewRect) {
+ arPreviewRect.style.display = 'none';
+ }
+ var tabImg2img = gradioApp().querySelector("#tab_img2img");
+ if (tabImg2img) {
+ var inImg2img = tabImg2img.style.display == "block";
+ if (inImg2img) {
+ let inputs = gradioApp().querySelectorAll('input');
+ inputs.forEach(function(e) {
+ var is_width = e.parentElement.id == "img2img_width";
+ var is_height = e.parentElement.id == "img2img_height";
+
+ if ((is_width || is_height) && !e.classList.contains('scrollwatch')) {
+ e.addEventListener('input', function(e) {
+ dimensionChange(e, is_width, is_height);
+ });
+ e.classList.add('scrollwatch');
+ }
+ if (is_width) {
+ currentWidth = e.value * 1.0;
+ }
+ if (is_height) {
+ currentHeight = e.value * 1.0;
+ }
+ });
+ }
+ }
+});
diff --git a/javascript/contextMenus.js b/javascript/contextMenus.js
index b2bdf053..f14af1d4 100644
--- a/javascript/contextMenus.js
+++ b/javascript/contextMenus.js
@@ -1,166 +1,172 @@
-
-contextMenuInit = function(){
- let eventListenerApplied=false;
- let menuSpecs = new Map();
-
- const uid = function(){
- return Date.now().toString(36) + Math.random().toString(36).substring(2);
- }
-
- function showContextMenu(event,element,menuEntries){
- let posx = event.clientX + document.body.scrollLeft + document.documentElement.scrollLeft;
- let posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop;
-
- let oldMenu = gradioApp().querySelector('#context-menu')
- if(oldMenu){
- oldMenu.remove()
- }
-
- let baseStyle = window.getComputedStyle(uiCurrentTab)
-
- const contextMenu = document.createElement('nav')
- contextMenu.id = "context-menu"
- contextMenu.style.background = baseStyle.background
- contextMenu.style.color = baseStyle.color
- contextMenu.style.fontFamily = baseStyle.fontFamily
- contextMenu.style.top = posy+'px'
- contextMenu.style.left = posx+'px'
-
-
-
- const contextMenuList = document.createElement('ul')
- contextMenuList.className = 'context-menu-items';
- contextMenu.append(contextMenuList);
-
- menuEntries.forEach(function(entry){
- let contextMenuEntry = document.createElement('a')
- contextMenuEntry.innerHTML = entry['name']
- contextMenuEntry.addEventListener("click", function() {
- entry['func']();
- })
- contextMenuList.append(contextMenuEntry);
-
- })
-
- gradioApp().appendChild(contextMenu)
-
- let menuWidth = contextMenu.offsetWidth + 4;
- let menuHeight = contextMenu.offsetHeight + 4;
-
- let windowWidth = window.innerWidth;
- let windowHeight = window.innerHeight;
-
- if ( (windowWidth - posx) < menuWidth ) {
- contextMenu.style.left = windowWidth - menuWidth + "px";
- }
-
- if ( (windowHeight - posy) < menuHeight ) {
- contextMenu.style.top = windowHeight - menuHeight + "px";
- }
-
- }
-
- function appendContextMenuOption(targetElementSelector,entryName,entryFunction){
-
- var currentItems = menuSpecs.get(targetElementSelector)
-
- if(!currentItems){
- currentItems = []
- menuSpecs.set(targetElementSelector,currentItems);
- }
- let newItem = {'id':targetElementSelector+'_'+uid(),
- 'name':entryName,
- 'func':entryFunction,
- 'isNew':true}
-
- currentItems.push(newItem)
- return newItem['id']
- }
-
- function removeContextMenuOption(uid){
- menuSpecs.forEach(function(v) {
- let index = -1
- v.forEach(function(e,ei){if(e['id']==uid){index=ei}})
- if(index>=0){
- v.splice(index, 1);
- }
- })
- }
-
- function addContextMenuEventListener(){
- if(eventListenerApplied){
- return;
- }
- gradioApp().addEventListener("click", function(e) {
- if(! e.isTrusted){
- return
- }
-
- let oldMenu = gradioApp().querySelector('#context-menu')
- if(oldMenu){
- oldMenu.remove()
- }
- });
- gradioApp().addEventListener("contextmenu", function(e) {
- let oldMenu = gradioApp().querySelector('#context-menu')
- if(oldMenu){
- oldMenu.remove()
- }
- menuSpecs.forEach(function(v,k) {
- if(e.composedPath()[0].matches(k)){
- showContextMenu(e,e.composedPath()[0],v)
- e.preventDefault()
- }
- })
- });
- eventListenerApplied=true
-
- }
-
- return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener]
-}
-
-initResponse = contextMenuInit();
-appendContextMenuOption = initResponse[0];
-removeContextMenuOption = initResponse[1];
-addContextMenuEventListener = initResponse[2];
-
-(function(){
- //Start example Context Menu Items
- let generateOnRepeat = function(genbuttonid,interruptbuttonid){
- let genbutton = gradioApp().querySelector(genbuttonid);
- let interruptbutton = gradioApp().querySelector(interruptbuttonid);
- if(!interruptbutton.offsetParent){
- genbutton.click();
- }
- clearInterval(window.generateOnRepeatInterval)
- window.generateOnRepeatInterval = setInterval(function(){
- if(!interruptbutton.offsetParent){
- genbutton.click();
- }
- },
- 500)
- }
-
- appendContextMenuOption('#txt2img_generate','Generate forever',function(){
- generateOnRepeat('#txt2img_generate','#txt2img_interrupt');
- })
- appendContextMenuOption('#img2img_generate','Generate forever',function(){
- generateOnRepeat('#img2img_generate','#img2img_interrupt');
- })
-
- let cancelGenerateForever = function(){
- clearInterval(window.generateOnRepeatInterval)
- }
-
- appendContextMenuOption('#txt2img_interrupt','Cancel generate forever',cancelGenerateForever)
- appendContextMenuOption('#txt2img_generate', 'Cancel generate forever',cancelGenerateForever)
- appendContextMenuOption('#img2img_interrupt','Cancel generate forever',cancelGenerateForever)
- appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever)
-
-})();
-//End example Context Menu Items
-
-onUiUpdate(function(){
- addContextMenuEventListener()
-});
+
+var contextMenuInit = function() {
+ let eventListenerApplied = false;
+ let menuSpecs = new Map();
+
+ const uid = function() {
+ return Date.now().toString(36) + Math.random().toString(36).substring(2);
+ };
+
+ function showContextMenu(event, element, menuEntries) {
+ let posx = event.clientX + document.body.scrollLeft + document.documentElement.scrollLeft;
+ let posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop;
+
+ let oldMenu = gradioApp().querySelector('#context-menu');
+ if (oldMenu) {
+ oldMenu.remove();
+ }
+
+ let baseStyle = window.getComputedStyle(uiCurrentTab);
+
+ const contextMenu = document.createElement('nav');
+ contextMenu.id = "context-menu";
+ contextMenu.style.background = baseStyle.background;
+ contextMenu.style.color = baseStyle.color;
+ contextMenu.style.fontFamily = baseStyle.fontFamily;
+ contextMenu.style.top = posy + 'px';
+ contextMenu.style.left = posx + 'px';
+
+
+
+ const contextMenuList = document.createElement('ul');
+ contextMenuList.className = 'context-menu-items';
+ contextMenu.append(contextMenuList);
+
+ menuEntries.forEach(function(entry) {
+ let contextMenuEntry = document.createElement('a');
+ contextMenuEntry.innerHTML = entry['name'];
+ contextMenuEntry.addEventListener("click", function() {
+ entry['func']();
+ });
+ contextMenuList.append(contextMenuEntry);
+
+ });
+
+ gradioApp().appendChild(contextMenu);
+
+ let menuWidth = contextMenu.offsetWidth + 4;
+ let menuHeight = contextMenu.offsetHeight + 4;
+
+ let windowWidth = window.innerWidth;
+ let windowHeight = window.innerHeight;
+
+ if ((windowWidth - posx) < menuWidth) {
+ contextMenu.style.left = windowWidth - menuWidth + "px";
+ }
+
+ if ((windowHeight - posy) < menuHeight) {
+ contextMenu.style.top = windowHeight - menuHeight + "px";
+ }
+
+ }
+
+ function appendContextMenuOption(targetElementSelector, entryName, entryFunction) {
+
+ var currentItems = menuSpecs.get(targetElementSelector);
+
+ if (!currentItems) {
+ currentItems = [];
+ menuSpecs.set(targetElementSelector, currentItems);
+ }
+ let newItem = {
+ id: targetElementSelector + '_' + uid(),
+ name: entryName,
+ func: entryFunction,
+ isNew: true
+ };
+
+ currentItems.push(newItem);
+ return newItem['id'];
+ }
+
+ function removeContextMenuOption(uid) {
+ menuSpecs.forEach(function(v) {
+ let index = -1;
+ v.forEach(function(e, ei) {
+ if (e['id'] == uid) {
+ index = ei;
+ }
+ });
+ if (index >= 0) {
+ v.splice(index, 1);
+ }
+ });
+ }
+
+ function addContextMenuEventListener() {
+ if (eventListenerApplied) {
+ return;
+ }
+ gradioApp().addEventListener("click", function(e) {
+ if (!e.isTrusted) {
+ return;
+ }
+
+ let oldMenu = gradioApp().querySelector('#context-menu');
+ if (oldMenu) {
+ oldMenu.remove();
+ }
+ });
+ gradioApp().addEventListener("contextmenu", function(e) {
+ let oldMenu = gradioApp().querySelector('#context-menu');
+ if (oldMenu) {
+ oldMenu.remove();
+ }
+ menuSpecs.forEach(function(v, k) {
+ if (e.composedPath()[0].matches(k)) {
+ showContextMenu(e, e.composedPath()[0], v);
+ e.preventDefault();
+ }
+ });
+ });
+ eventListenerApplied = true;
+
+ }
+
+ return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener];
+};
+
+var initResponse = contextMenuInit();
+var appendContextMenuOption = initResponse[0];
+var removeContextMenuOption = initResponse[1];
+var addContextMenuEventListener = initResponse[2];
+
+(function() {
+ //Start example Context Menu Items
+ let generateOnRepeat = function(genbuttonid, interruptbuttonid) {
+ let genbutton = gradioApp().querySelector(genbuttonid);
+ let interruptbutton = gradioApp().querySelector(interruptbuttonid);
+ if (!interruptbutton.offsetParent) {
+ genbutton.click();
+ }
+ clearInterval(window.generateOnRepeatInterval);
+ window.generateOnRepeatInterval = setInterval(function() {
+ if (!interruptbutton.offsetParent) {
+ genbutton.click();
+ }
+ },
+ 500);
+ };
+
+ appendContextMenuOption('#txt2img_generate', 'Generate forever', function() {
+ generateOnRepeat('#txt2img_generate', '#txt2img_interrupt');
+ });
+ appendContextMenuOption('#img2img_generate', 'Generate forever', function() {
+ generateOnRepeat('#img2img_generate', '#img2img_interrupt');
+ });
+
+ let cancelGenerateForever = function() {
+ clearInterval(window.generateOnRepeatInterval);
+ };
+
+ appendContextMenuOption('#txt2img_interrupt', 'Cancel generate forever', cancelGenerateForever);
+ appendContextMenuOption('#txt2img_generate', 'Cancel generate forever', cancelGenerateForever);
+ appendContextMenuOption('#img2img_interrupt', 'Cancel generate forever', cancelGenerateForever);
+ appendContextMenuOption('#img2img_generate', 'Cancel generate forever', cancelGenerateForever);
+
+})();
+//End example Context Menu Items
+
+onUiUpdate(function() {
+ addContextMenuEventListener();
+});
diff --git a/javascript/dragdrop.js b/javascript/dragdrop.js
index fe008924..e316a365 100644
--- a/javascript/dragdrop.js
+++ b/javascript/dragdrop.js
@@ -1,11 +1,11 @@
// allows drag-dropping files into gradio image elements, and also pasting images from clipboard
-function isValidImageList( files ) {
+function isValidImageList(files) {
return files && files?.length === 1 && ['image/png', 'image/gif', 'image/jpeg'].includes(files[0].type);
}
-function dropReplaceImage( imgWrap, files ) {
- if ( ! isValidImageList( files ) ) {
+function dropReplaceImage(imgWrap, files) {
+ if (!isValidImageList(files)) {
return;
}
@@ -14,44 +14,44 @@ function dropReplaceImage( imgWrap, files ) {
imgWrap.querySelector('.modify-upload button + button, .touch-none + div button + button')?.click();
const callback = () => {
const fileInput = imgWrap.querySelector('input[type="file"]');
- if ( fileInput ) {
- if ( files.length === 0 ) {
+ if (fileInput) {
+ if (files.length === 0) {
files = new DataTransfer();
files.items.add(tmpFile);
fileInput.files = files.files;
} else {
fileInput.files = files;
}
- fileInput.dispatchEvent(new Event('change'));
+ fileInput.dispatchEvent(new Event('change'));
}
};
-
- if ( imgWrap.closest('#pnginfo_image') ) {
+
+ if (imgWrap.closest('#pnginfo_image')) {
// special treatment for PNG Info tab, wait for fetch request to finish
const oldFetch = window.fetch;
- window.fetch = async (input, options) => {
+ window.fetch = async(input, options) => {
const response = await oldFetch(input, options);
- if ( 'api/predict/' === input ) {
+ if ('api/predict/' === input) {
const content = await response.text();
window.fetch = oldFetch;
- window.requestAnimationFrame( () => callback() );
+ window.requestAnimationFrame(() => callback());
return new Response(content, {
status: response.status,
statusText: response.statusText,
headers: response.headers
- })
+ });
}
return response;
- };
+ };
} else {
- window.requestAnimationFrame( () => callback() );
+ window.requestAnimationFrame(() => callback());
}
}
window.document.addEventListener('dragover', e => {
const target = e.composedPath()[0];
const imgWrap = target.closest('[data-testid="image"]');
- if ( !imgWrap && target.placeholder && target.placeholder.indexOf("Prompt") == -1) {
+ if (!imgWrap && target.placeholder && target.placeholder.indexOf("Prompt") == -1) {
return;
}
e.stopPropagation();
@@ -65,33 +65,34 @@ window.document.addEventListener('drop', e => {
return;
}
const imgWrap = target.closest('[data-testid="image"]');
- if ( !imgWrap ) {
+ if (!imgWrap) {
return;
}
e.stopPropagation();
e.preventDefault();
const files = e.dataTransfer.files;
- dropReplaceImage( imgWrap, files );
+ dropReplaceImage(imgWrap, files);
});
window.addEventListener('paste', e => {
const files = e.clipboardData.files;
- if ( ! isValidImageList( files ) ) {
+ if (!isValidImageList(files)) {
return;
}
const visibleImageFields = [...gradioApp().querySelectorAll('[data-testid="image"]')]
.filter(el => uiElementIsVisible(el));
- if ( ! visibleImageFields.length ) {
+ if (!visibleImageFields.length) {
return;
}
-
+
const firstFreeImageField = visibleImageFields
.filter(el => el.querySelector('input[type=file]'))?.[0];
dropReplaceImage(
firstFreeImageField ?
- firstFreeImageField :
- visibleImageFields[visibleImageFields.length - 1]
- , files );
+ firstFreeImageField :
+ visibleImageFields[visibleImageFields.length - 1]
+ , files
+ );
});
diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js
index d2c2f190..fdf00b4d 100644
--- a/javascript/edit-attention.js
+++ b/javascript/edit-attention.js
@@ -1,120 +1,120 @@
-function keyupEditAttention(event){
- let target = event.originalTarget || event.composedPath()[0];
- if (! target.matches("[id*='_toprow'] [id*='_prompt'] textarea")) return;
- if (! (event.metaKey || event.ctrlKey)) return;
-
- let isPlus = event.key == "ArrowUp"
- let isMinus = event.key == "ArrowDown"
- if (!isPlus && !isMinus) return;
-
- let selectionStart = target.selectionStart;
- let selectionEnd = target.selectionEnd;
- let text = target.value;
-
- function selectCurrentParenthesisBlock(OPEN, CLOSE){
- if (selectionStart !== selectionEnd) return false;
-
- // Find opening parenthesis around current cursor
- const before = text.substring(0, selectionStart);
- let beforeParen = before.lastIndexOf(OPEN);
- if (beforeParen == -1) return false;
- let beforeParenClose = before.lastIndexOf(CLOSE);
- while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
- beforeParen = before.lastIndexOf(OPEN, beforeParen - 1);
- beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1);
- }
-
- // Find closing parenthesis around current cursor
- const after = text.substring(selectionStart);
- let afterParen = after.indexOf(CLOSE);
- if (afterParen == -1) return false;
- let afterParenOpen = after.indexOf(OPEN);
- while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
- afterParen = after.indexOf(CLOSE, afterParen + 1);
- afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1);
- }
- if (beforeParen === -1 || afterParen === -1) return false;
-
- // Set the selection to the text between the parenthesis
- const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen);
- const lastColon = parenContent.lastIndexOf(":");
- selectionStart = beforeParen + 1;
- selectionEnd = selectionStart + lastColon;
- target.setSelectionRange(selectionStart, selectionEnd);
- return true;
- }
-
- function selectCurrentWord(){
- if (selectionStart !== selectionEnd) return false;
- const delimiters = opts.keyedit_delimiters + " \r\n\t";
-
- // seek backward until to find beggining
- while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) {
- selectionStart--;
- }
-
- // seek forward to find end
- while (!delimiters.includes(text[selectionEnd]) && selectionEnd < text.length) {
- selectionEnd++;
- }
-
- target.setSelectionRange(selectionStart, selectionEnd);
- return true;
- }
-
- // If the user hasn't selected anything, let's select their current parenthesis block or word
- if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')')) {
- selectCurrentWord();
- }
-
- event.preventDefault();
-
- var closeCharacter = ')'
- var delta = opts.keyedit_precision_attention
-
- if (selectionStart > 0 && text[selectionStart - 1] == '<'){
- closeCharacter = '>'
- delta = opts.keyedit_precision_extra
- } else if (selectionStart == 0 || text[selectionStart - 1] != "(") {
-
- // do not include spaces at the end
- while(selectionEnd > selectionStart && text[selectionEnd-1] == ' '){
- selectionEnd -= 1;
- }
- if(selectionStart == selectionEnd){
- return
- }
-
- text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd);
-
- selectionStart += 1;
- selectionEnd += 1;
- }
-
- var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
- var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
- if (isNaN(weight)) return;
-
- weight += isPlus ? delta : -delta;
- weight = parseFloat(weight.toPrecision(12));
- if(String(weight).length == 1) weight += ".0"
-
- if (closeCharacter == ')' && weight == 1) {
- text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + 5);
- selectionStart--;
- selectionEnd--;
- } else {
- text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
- }
-
- target.focus();
- target.value = text;
- target.selectionStart = selectionStart;
- target.selectionEnd = selectionEnd;
-
- updateInput(target)
-}
-
-addEventListener('keydown', (event) => {
- keyupEditAttention(event);
-});
+function keyupEditAttention(event) {
+ let target = event.originalTarget || event.composedPath()[0];
+ if (!target.matches("[id*='_toprow'] [id*='_prompt'] textarea")) return;
+ if (!(event.metaKey || event.ctrlKey)) return;
+
+ let isPlus = event.key == "ArrowUp";
+ let isMinus = event.key == "ArrowDown";
+ if (!isPlus && !isMinus) return;
+
+ let selectionStart = target.selectionStart;
+ let selectionEnd = target.selectionEnd;
+ let text = target.value;
+
+ function selectCurrentParenthesisBlock(OPEN, CLOSE) {
+ if (selectionStart !== selectionEnd) return false;
+
+ // Find opening parenthesis around current cursor
+ const before = text.substring(0, selectionStart);
+ let beforeParen = before.lastIndexOf(OPEN);
+ if (beforeParen == -1) return false;
+ let beforeParenClose = before.lastIndexOf(CLOSE);
+ while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
+ beforeParen = before.lastIndexOf(OPEN, beforeParen - 1);
+ beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1);
+ }
+
+ // Find closing parenthesis around current cursor
+ const after = text.substring(selectionStart);
+ let afterParen = after.indexOf(CLOSE);
+ if (afterParen == -1) return false;
+ let afterParenOpen = after.indexOf(OPEN);
+ while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
+ afterParen = after.indexOf(CLOSE, afterParen + 1);
+ afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1);
+ }
+ if (beforeParen === -1 || afterParen === -1) return false;
+
+ // Set the selection to the text between the parenthesis
+ const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen);
+ const lastColon = parenContent.lastIndexOf(":");
+ selectionStart = beforeParen + 1;
+ selectionEnd = selectionStart + lastColon;
+ target.setSelectionRange(selectionStart, selectionEnd);
+ return true;
+ }
+
+ function selectCurrentWord() {
+ if (selectionStart !== selectionEnd) return false;
+ const delimiters = opts.keyedit_delimiters + " \r\n\t";
+
+ // seek backward until to find beggining
+ while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) {
+ selectionStart--;
+ }
+
+ // seek forward to find end
+ while (!delimiters.includes(text[selectionEnd]) && selectionEnd < text.length) {
+ selectionEnd++;
+ }
+
+ target.setSelectionRange(selectionStart, selectionEnd);
+ return true;
+ }
+
+ // If the user hasn't selected anything, let's select their current parenthesis block or word
+ if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')')) {
+ selectCurrentWord();
+ }
+
+ event.preventDefault();
+
+ var closeCharacter = ')';
+ var delta = opts.keyedit_precision_attention;
+
+ if (selectionStart > 0 && text[selectionStart - 1] == '<') {
+ closeCharacter = '>';
+ delta = opts.keyedit_precision_extra;
+ } else if (selectionStart == 0 || text[selectionStart - 1] != "(") {
+
+ // do not include spaces at the end
+ while (selectionEnd > selectionStart && text[selectionEnd - 1] == ' ') {
+ selectionEnd -= 1;
+ }
+ if (selectionStart == selectionEnd) {
+ return;
+ }
+
+ text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd);
+
+ selectionStart += 1;
+ selectionEnd += 1;
+ }
+
+ var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
+ var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
+ if (isNaN(weight)) return;
+
+ weight += isPlus ? delta : -delta;
+ weight = parseFloat(weight.toPrecision(12));
+ if (String(weight).length == 1) weight += ".0";
+
+ if (closeCharacter == ')' && weight == 1) {
+ text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + 5);
+ selectionStart--;
+ selectionEnd--;
+ } else {
+ text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
+ }
+
+ target.focus();
+ target.value = text;
+ target.selectionStart = selectionStart;
+ target.selectionEnd = selectionEnd;
+
+ updateInput(target);
+}
+
+addEventListener('keydown', (event) => {
+ keyupEditAttention(event);
+});
diff --git a/javascript/extensions.js b/javascript/extensions.js
index 2a2d2f8e..efeaf3a5 100644
--- a/javascript/extensions.js
+++ b/javascript/extensions.js
@@ -1,71 +1,74 @@
-
-function extensions_apply(_disabled_list, _update_list, disable_all){
- var disable = []
- var update = []
-
- gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
- if(x.name.startsWith("enable_") && ! x.checked)
- disable.push(x.name.substring(7))
-
- if(x.name.startsWith("update_") && x.checked)
- update.push(x.name.substring(7))
- })
-
- restart_reload()
-
- return [JSON.stringify(disable), JSON.stringify(update), disable_all]
-}
-
-function extensions_check(){
- var disable = []
-
- gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
- if(x.name.startsWith("enable_") && ! x.checked)
- disable.push(x.name.substring(7))
- })
-
- gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
- x.innerHTML = "Loading..."
- })
-
-
- var id = randomId()
- requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, function(){
-
- })
-
- return [id, JSON.stringify(disable)]
-}
-
-function install_extension_from_index(button, url){
- button.disabled = "disabled"
- button.value = "Installing..."
-
- var textarea = gradioApp().querySelector('#extension_to_install textarea')
- textarea.value = url
- updateInput(textarea)
-
- gradioApp().querySelector('#install_extension_button').click()
-}
-
-function config_state_confirm_restore(_, config_state_name, config_restore_type) {
- if (config_state_name == "Current") {
- return [false, config_state_name, config_restore_type];
- }
- let restored = "";
- if (config_restore_type == "extensions") {
- restored = "all saved extension versions";
- } else if (config_restore_type == "webui") {
- restored = "the webui version";
- } else {
- restored = "the webui version and all saved extension versions";
- }
- let confirmed = confirm("Are you sure you want to restore from this state?\nThis will reset " + restored + ".");
- if (confirmed) {
- restart_reload();
- gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
- x.innerHTML = "Loading..."
- })
- }
- return [confirmed, config_state_name, config_restore_type];
-}
+
+function extensions_apply(_disabled_list, _update_list, disable_all) {
+ var disable = [];
+ var update = [];
+
+ gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x) {
+ if (x.name.startsWith("enable_") && !x.checked) {
+ disable.push(x.name.substring(7));
+ }
+
+ if (x.name.startsWith("update_") && x.checked) {
+ update.push(x.name.substring(7));
+ }
+ });
+
+ restart_reload();
+
+ return [JSON.stringify(disable), JSON.stringify(update), disable_all];
+}
+
+function extensions_check() {
+ var disable = [];
+
+ gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x) {
+ if (x.name.startsWith("enable_") && !x.checked) {
+ disable.push(x.name.substring(7));
+ }
+ });
+
+ gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x) {
+ x.innerHTML = "Loading...";
+ });
+
+
+ var id = randomId();
+ requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, function() {
+
+ });
+
+ return [id, JSON.stringify(disable)];
+}
+
+function install_extension_from_index(button, url) {
+ button.disabled = "disabled";
+ button.value = "Installing...";
+
+ var textarea = gradioApp().querySelector('#extension_to_install textarea');
+ textarea.value = url;
+ updateInput(textarea);
+
+ gradioApp().querySelector('#install_extension_button').click();
+}
+
+function config_state_confirm_restore(_, config_state_name, config_restore_type) {
+ if (config_state_name == "Current") {
+ return [false, config_state_name, config_restore_type];
+ }
+ let restored = "";
+ if (config_restore_type == "extensions") {
+ restored = "all saved extension versions";
+ } else if (config_restore_type == "webui") {
+ restored = "the webui version";
+ } else {
+ restored = "the webui version and all saved extension versions";
+ }
+ let confirmed = confirm("Are you sure you want to restore from this state?\nThis will reset " + restored + ".");
+ if (confirmed) {
+ restart_reload();
+ gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x) {
+ x.innerHTML = "Loading...";
+ });
+ }
+ return [confirmed, config_state_name, config_restore_type];
+}
diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js
index c85bc79a..aafe0a00 100644
--- a/javascript/extraNetworks.js
+++ b/javascript/extraNetworks.js
@@ -1,196 +1,215 @@
-function setupExtraNetworksForTab(tabname){
- gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks')
-
- var tabs = gradioApp().querySelector('#'+tabname+'_extra_tabs > div')
- var search = gradioApp().querySelector('#'+tabname+'_extra_search textarea')
- var refresh = gradioApp().getElementById(tabname+'_extra_refresh')
-
- search.classList.add('search')
- tabs.appendChild(search)
- tabs.appendChild(refresh)
-
- var applyFilter = function(){
- var searchTerm = search.value.toLowerCase()
-
- gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){
- var searchOnly = elem.querySelector('.search_only')
- var text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase()
-
- var visible = text.indexOf(searchTerm) != -1
-
- if(searchOnly && searchTerm.length < 4){
- visible = false
- }
-
- elem.style.display = visible ? "" : "none"
- })
- }
-
- search.addEventListener("input", applyFilter);
- applyFilter();
-
- extraNetworksApplyFilter[tabname] = applyFilter;
-}
-
-function applyExtraNetworkFilter(tabname){
- setTimeout(extraNetworksApplyFilter[tabname], 1);
-}
-
-var extraNetworksApplyFilter = {}
-var activePromptTextarea = {};
-
-function setupExtraNetworks(){
- setupExtraNetworksForTab('txt2img')
- setupExtraNetworksForTab('img2img')
-
- function registerPrompt(tabname, id){
- var textarea = gradioApp().querySelector("#" + id + " > label > textarea");
-
- if (! activePromptTextarea[tabname]){
- activePromptTextarea[tabname] = textarea
- }
-
- textarea.addEventListener("focus", function(){
- activePromptTextarea[tabname] = textarea;
- });
- }
-
- registerPrompt('txt2img', 'txt2img_prompt')
- registerPrompt('txt2img', 'txt2img_neg_prompt')
- registerPrompt('img2img', 'img2img_prompt')
- registerPrompt('img2img', 'img2img_neg_prompt')
-}
-
-onUiLoaded(setupExtraNetworks)
-
-var re_extranet = /<([^:]+:[^:]+):[\d\.]+>/;
-var re_extranet_g = /\s+<([^:]+:[^:]+):[\d\.]+>/g;
-
-function tryToRemoveExtraNetworkFromPrompt(textarea, text){
- var m = text.match(re_extranet)
- if(! m) return false
-
- var partToSearch = m[1]
- var replaced = false
- var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found){
- m = found.match(re_extranet);
- if(m[1] == partToSearch){
- replaced = true;
- return ""
- }
- return found;
- })
-
- if(replaced){
- textarea.value = newTextareaText
- return true;
- }
-
- return false
-}
-
-function cardClicked(tabname, textToAdd, allowNegativePrompt){
- var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")
-
- if(! tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)){
- textarea.value = textarea.value + opts.extra_networks_add_text_separator + textToAdd
- }
-
- updateInput(textarea)
-}
-
-function saveCardPreview(event, tabname, filename){
- var textarea = gradioApp().querySelector("#" + tabname + '_preview_filename > label > textarea')
- var button = gradioApp().getElementById(tabname + '_save_preview')
-
- textarea.value = filename
- updateInput(textarea)
-
- button.click()
-
- event.stopPropagation()
- event.preventDefault()
-}
-
-function extraNetworksSearchButton(tabs_id, event){
- var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea')
- var button = event.target
- var text = button.classList.contains("search-all") ? "" : button.textContent.trim()
-
- searchTextarea.value = text
- updateInput(searchTextarea)
-}
-
-var globalPopup = null;
-var globalPopupInner = null;
-function popup(contents){
- if(! globalPopup){
- globalPopup = document.createElement('div')
- globalPopup.onclick = function(){ globalPopup.style.display = "none"; };
- globalPopup.classList.add('global-popup');
-
- var close = document.createElement('div')
- close.classList.add('global-popup-close');
- close.onclick = function(){ globalPopup.style.display = "none"; };
- close.title = "Close";
- globalPopup.appendChild(close)
-
- globalPopupInner = document.createElement('div')
- globalPopupInner.onclick = function(event){ event.stopPropagation(); return false; };
- globalPopupInner.classList.add('global-popup-inner');
- globalPopup.appendChild(globalPopupInner)
-
- gradioApp().appendChild(globalPopup);
- }
-
- globalPopupInner.innerHTML = '';
- globalPopupInner.appendChild(contents);
-
- globalPopup.style.display = "flex";
-}
-
-function extraNetworksShowMetadata(text){
- var elem = document.createElement('pre')
- elem.classList.add('popup-metadata');
- elem.textContent = text;
-
- popup(elem);
-}
-
-function requestGet(url, data, handler, errorHandler){
- var xhr = new XMLHttpRequest();
- var args = Object.keys(data).map(function(k){ return encodeURIComponent(k) + '=' + encodeURIComponent(data[k]) }).join('&')
- xhr.open("GET", url + "?" + args, true);
-
- xhr.onreadystatechange = function () {
- if (xhr.readyState === 4) {
- if (xhr.status === 200) {
- try {
- var js = JSON.parse(xhr.responseText);
- handler(js)
- } catch (error) {
- console.error(error);
- errorHandler()
- }
- } else{
- errorHandler()
- }
- }
- };
- var js = JSON.stringify(data);
- xhr.send(js);
-}
-
-function extraNetworksRequestMetadata(event, extraPage, cardName){
- var showError = function(){ extraNetworksShowMetadata("there was an error getting metadata"); }
-
- requestGet("./sd_extra_networks/metadata", {"page": extraPage, "item": cardName}, function(data){
- if(data && data.metadata){
- extraNetworksShowMetadata(data.metadata)
- } else{
- showError()
- }
- }, showError)
-
- event.stopPropagation()
-}
+function setupExtraNetworksForTab(tabname) {
+ gradioApp().querySelector('#' + tabname + '_extra_tabs').classList.add('extra-networks');
+
+ var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div');
+ var search = gradioApp().querySelector('#' + tabname + '_extra_search textarea');
+ var refresh = gradioApp().getElementById(tabname + '_extra_refresh');
+
+ search.classList.add('search');
+ tabs.appendChild(search);
+ tabs.appendChild(refresh);
+
+ var applyFilter = function() {
+ var searchTerm = search.value.toLowerCase();
+
+ gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card').forEach(function(elem) {
+ var searchOnly = elem.querySelector('.search_only');
+ var text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase();
+
+ var visible = text.indexOf(searchTerm) != -1;
+
+ if (searchOnly && searchTerm.length < 4) {
+ visible = false;
+ }
+
+ elem.style.display = visible ? "" : "none";
+ });
+ };
+
+ search.addEventListener("input", applyFilter);
+ applyFilter();
+
+ extraNetworksApplyFilter[tabname] = applyFilter;
+}
+
+function applyExtraNetworkFilter(tabname) {
+ setTimeout(extraNetworksApplyFilter[tabname], 1);
+}
+
+var extraNetworksApplyFilter = {};
+var activePromptTextarea = {};
+
+function setupExtraNetworks() {
+ setupExtraNetworksForTab('txt2img');
+ setupExtraNetworksForTab('img2img');
+
+ function registerPrompt(tabname, id) {
+ var textarea = gradioApp().querySelector("#" + id + " > label > textarea");
+
+ if (!activePromptTextarea[tabname]) {
+ activePromptTextarea[tabname] = textarea;
+ }
+
+ textarea.addEventListener("focus", function() {
+ activePromptTextarea[tabname] = textarea;
+ });
+ }
+
+ registerPrompt('txt2img', 'txt2img_prompt');
+ registerPrompt('txt2img', 'txt2img_neg_prompt');
+ registerPrompt('img2img', 'img2img_prompt');
+ registerPrompt('img2img', 'img2img_neg_prompt');
+}
+
+onUiLoaded(setupExtraNetworks);
+
+var re_extranet = /<([^:]+:[^:]+):[\d.]+>/;
+var re_extranet_g = /\s+<([^:]+:[^:]+):[\d.]+>/g;
+
+function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
+ var m = text.match(re_extranet);
+ var replaced = false;
+ var newTextareaText;
+ if (m) {
+ var partToSearch = m[1];
+ newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found) {
+ m = found.match(re_extranet);
+ if (m[1] == partToSearch) {
+ replaced = true;
+ return "";
+ }
+ return found;
+ });
+ } else {
+ newTextareaText = textarea.value.replaceAll(new RegExp(text, "g"), function(found) {
+ if (found == text) {
+ replaced = true;
+ return "";
+ }
+ return found;
+ });
+ }
+
+ if (replaced) {
+ textarea.value = newTextareaText;
+ return true;
+ }
+
+ return false;
+}
+
+function cardClicked(tabname, textToAdd, allowNegativePrompt) {
+ var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea");
+
+ if (!tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)) {
+ textarea.value = textarea.value + opts.extra_networks_add_text_separator + textToAdd;
+ }
+
+ updateInput(textarea);
+}
+
+function saveCardPreview(event, tabname, filename) {
+ var textarea = gradioApp().querySelector("#" + tabname + '_preview_filename > label > textarea');
+ var button = gradioApp().getElementById(tabname + '_save_preview');
+
+ textarea.value = filename;
+ updateInput(textarea);
+
+ button.click();
+
+ event.stopPropagation();
+ event.preventDefault();
+}
+
+function extraNetworksSearchButton(tabs_id, event) {
+ var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea');
+ var button = event.target;
+ var text = button.classList.contains("search-all") ? "" : button.textContent.trim();
+
+ searchTextarea.value = text;
+ updateInput(searchTextarea);
+}
+
+var globalPopup = null;
+var globalPopupInner = null;
+function popup(contents) {
+ if (!globalPopup) {
+ globalPopup = document.createElement('div');
+ globalPopup.onclick = function() {
+ globalPopup.style.display = "none";
+ };
+ globalPopup.classList.add('global-popup');
+
+ var close = document.createElement('div');
+ close.classList.add('global-popup-close');
+ close.onclick = function() {
+ globalPopup.style.display = "none";
+ };
+ close.title = "Close";
+ globalPopup.appendChild(close);
+
+ globalPopupInner = document.createElement('div');
+ globalPopupInner.onclick = function(event) {
+ event.stopPropagation(); return false;
+ };
+ globalPopupInner.classList.add('global-popup-inner');
+ globalPopup.appendChild(globalPopupInner);
+
+ gradioApp().appendChild(globalPopup);
+ }
+
+ globalPopupInner.innerHTML = '';
+ globalPopupInner.appendChild(contents);
+
+ globalPopup.style.display = "flex";
+}
+
+function extraNetworksShowMetadata(text) {
+ var elem = document.createElement('pre');
+ elem.classList.add('popup-metadata');
+ elem.textContent = text;
+
+ popup(elem);
+}
+
+function requestGet(url, data, handler, errorHandler) {
+ var xhr = new XMLHttpRequest();
+ var args = Object.keys(data).map(function(k) {
+ return encodeURIComponent(k) + '=' + encodeURIComponent(data[k]);
+ }).join('&');
+ xhr.open("GET", url + "?" + args, true);
+
+ xhr.onreadystatechange = function() {
+ if (xhr.readyState === 4) {
+ if (xhr.status === 200) {
+ try {
+ var js = JSON.parse(xhr.responseText);
+ handler(js);
+ } catch (error) {
+ console.error(error);
+ errorHandler();
+ }
+ } else {
+ errorHandler();
+ }
+ }
+ };
+ var js = JSON.stringify(data);
+ xhr.send(js);
+}
+
+function extraNetworksRequestMetadata(event, extraPage, cardName) {
+ var showError = function() {
+ extraNetworksShowMetadata("there was an error getting metadata");
+ };
+
+ requestGet("./sd_extra_networks/metadata", {page: extraPage, item: cardName}, function(data) {
+ if (data && data.metadata) {
+ extraNetworksShowMetadata(data.metadata);
+ } else {
+ showError();
+ }
+ }, showError);
+
+ event.stopPropagation();
+}
diff --git a/javascript/generationParams.js b/javascript/generationParams.js
index ef64ee2e..a877f8a5 100644
--- a/javascript/generationParams.js
+++ b/javascript/generationParams.js
@@ -1,33 +1,35 @@
// attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes
let txt2img_gallery, img2img_gallery, modal = undefined;
-onUiUpdate(function(){
- if (!txt2img_gallery) {
- txt2img_gallery = attachGalleryListeners("txt2img")
- }
- if (!img2img_gallery) {
- img2img_gallery = attachGalleryListeners("img2img")
- }
- if (!modal) {
- modal = gradioApp().getElementById('lightboxModal')
- modalObserver.observe(modal, { attributes : true, attributeFilter : ['style'] });
- }
+onUiUpdate(function() {
+ if (!txt2img_gallery) {
+ txt2img_gallery = attachGalleryListeners("txt2img");
+ }
+ if (!img2img_gallery) {
+ img2img_gallery = attachGalleryListeners("img2img");
+ }
+ if (!modal) {
+ modal = gradioApp().getElementById('lightboxModal');
+ modalObserver.observe(modal, {attributes: true, attributeFilter: ['style']});
+ }
});
let modalObserver = new MutationObserver(function(mutations) {
- mutations.forEach(function(mutationRecord) {
- let selectedTab = gradioApp().querySelector('#tabs div button.selected')?.innerText
- if (mutationRecord.target.style.display === 'none' && (selectedTab === 'txt2img' || selectedTab === 'img2img'))
- gradioApp().getElementById(selectedTab+"_generation_info_button")?.click()
- });
+ mutations.forEach(function(mutationRecord) {
+ let selectedTab = gradioApp().querySelector('#tabs div button.selected')?.innerText;
+ if (mutationRecord.target.style.display === 'none' && (selectedTab === 'txt2img' || selectedTab === 'img2img')) {
+ gradioApp().getElementById(selectedTab + "_generation_info_button")?.click();
+ }
+ });
});
function attachGalleryListeners(tab_name) {
- var gallery = gradioApp().querySelector('#'+tab_name+'_gallery')
- gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click());
- gallery?.addEventListener('keydown', (e) => {
- if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow
- gradioApp().getElementById(tab_name+"_generation_info_button").click()
- });
- return gallery;
+ var gallery = gradioApp().querySelector('#' + tab_name + '_gallery');
+ gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name + "_generation_info_button").click());
+ gallery?.addEventListener('keydown', (e) => {
+ if (e.keyCode == 37 || e.keyCode == 39) { // left or right arrow
+ gradioApp().getElementById(tab_name + "_generation_info_button").click();
+ }
+ });
+ return gallery;
}
diff --git a/javascript/hints.js b/javascript/hints.js
index 3746df99..46f342cb 100644
--- a/javascript/hints.js
+++ b/javascript/hints.js
@@ -1,16 +1,17 @@
// mouseover tooltips for various UI elements
-titles = {
+var titles = {
"Sampling steps": "How many times to improve the generated image iteratively; higher values take longer; very low values can produce bad results",
"Sampling method": "Which algorithm to use to produce the image",
- "GFPGAN": "Restore low quality faces using GFPGAN neural network",
- "Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps higher than 30-40 does not help",
- "DDIM": "Denoising Diffusion Implicit Models - best at inpainting",
- "UniPC": "Unified Predictor-Corrector Framework for Fast Sampling of Diffusion Models",
- "DPM adaptive": "Ignores step count - uses a number of steps determined by the CFG and resolution",
+ "GFPGAN": "Restore low quality faces using GFPGAN neural network",
+ "Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps higher than 30-40 does not help",
+ "DDIM": "Denoising Diffusion Implicit Models - best at inpainting",
+ "UniPC": "Unified Predictor-Corrector Framework for Fast Sampling of Diffusion Models",
+ "DPM adaptive": "Ignores step count - uses a number of steps determined by the CFG and resolution",
- "Batch count": "How many batches of images to create (has no impact on generation performance or VRAM usage)",
- "Batch size": "How many image to create in a single batch (increases generation performance at cost of higher VRAM usage)",
+ "\u{1F4D0}": "Auto detect size from img2img",
+ "Batch count": "How many batches of images to create (has no impact on generation performance or VRAM usage)",
+ "Batch size": "How many image to create in a single batch (increases generation performance at cost of higher VRAM usage)",
"CFG Scale": "Classifier Free Guidance Scale - how strongly the image should conform to prompt - lower values produce more creative results",
"Seed": "A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result",
"\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time",
@@ -40,7 +41,7 @@ titles = {
"Inpaint at full resolution": "Upscale masked region to target resolution, do inpainting, downscale back and paste into original image",
"Denoising strength": "Determines how little respect the algorithm should have for image's content. At 0, nothing will change, and at 1 you'll get an unrelated image. With values below 1.0, processing will take less steps than the Sampling Steps slider specifies.",
-
+
"Skip": "Stop processing current image and continue processing.",
"Interrupt": "Stop processing images and return any results accumulated so far.",
"Save": "Write image to a directory (default - log/images) and generation parameters into csv file.",
@@ -66,8 +67,8 @@ titles = {
"Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
- "Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [denoising], [clip_skip], [batch_number], [generation_number], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime], [datetime], [job_timestamp], [hasprompt..]; leave empty for default.",
- "Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [denoising], [clip_skip], [batch_number], [generation_number], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime], [datetime], [job_timestamp], [hasprompt..]; leave empty for default.",
+ "Images filename pattern": "Use tags like [seed] and [date] to define how filenames for images are chosen. Leave empty for default.",
+ "Directory name pattern": "Use tags like [seed] and [date] to define how subdirectories for images and grids are chosen. Leave empty for default.",
"Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
"Loopback": "Performs img2img processing multiple times. Output images are used as input for the next loop.",
@@ -96,7 +97,7 @@ titles = {
"Add difference": "Result = A + (B - C) * M",
"No interpolation": "Result = A",
- "Initialization text": "If the number of tokens is more than the number of vectors, some may be skipped.\nLeave the textbox empty to start with zeroed out vectors",
+ "Initialization text": "If the number of tokens is more than the number of vectors, some may be skipped.\nLeave the textbox empty to start with zeroed out vectors",
"Learning rate": "How fast should training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
"Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.",
@@ -113,38 +114,55 @@ titles = {
"Discard weights with matching name": "Regular expression; if weights's name matches it, the weights is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights.",
"Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in order lsited.",
"Negative Guidance minimum sigma": "Skip negative prompt for steps where image is already mostly denoised; the higher this value, the more skips there will be; provides increased performance in exchange for minor quality reduction."
+};
+
+function updateTooltipForSpan(span) {
+ if (span.title) return; // already has a title
+
+ let tooltip = localization[titles[span.textContent]] || titles[span.textContent];
+
+ if (!tooltip) {
+ tooltip = localization[titles[span.value]] || titles[span.value];
+ }
+
+ if (!tooltip) {
+ for (const c of span.classList) {
+ if (c in titles) {
+ tooltip = localization[titles[c]] || titles[c];
+ break;
+ }
+ }
+ }
+
+ if (tooltip) {
+ span.title = tooltip;
+ }
}
+function updateTooltipForSelect(select) {
+ if (select.onchange != null) return;
-onUiUpdate(function(){
- gradioApp().querySelectorAll('span, button, select, p').forEach(function(span){
- if (span.title) return; // already has a title
+ select.onchange = function() {
+ select.title = localization[titles[select.value]] || titles[select.value] || "";
+ };
+}
- let tooltip = localization[titles[span.textContent]] || titles[span.textContent];
+var observedTooltipElements = {SPAN: 1, BUTTON: 1, SELECT: 1, P: 1};
- if(!tooltip){
- tooltip = localization[titles[span.value]] || titles[span.value];
- }
+onUiUpdate(function(m) {
+ m.forEach(function(record) {
+ record.addedNodes.forEach(function(node) {
+ if (observedTooltipElements[node.tagName]) {
+ updateTooltipForSpan(node);
+ }
+ if (node.tagName == "SELECT") {
+ updateTooltipForSelect(node);
+ }
- if(!tooltip){
- for (const c of span.classList) {
- if (c in titles) {
- tooltip = localization[titles[c]] || titles[c];
- break;
- }
- }
- }
-
- if(tooltip){
- span.title = tooltip;
- }
- })
-
- gradioApp().querySelectorAll('select').forEach(function(select){
- if (select.onchange != null) return;
-
- select.onchange = function(){
- select.title = localization[titles[select.value]] || titles[select.value] || "";
- }
- })
-})
+ if (node.querySelectorAll) {
+ node.querySelectorAll('span, button, select, p').forEach(updateTooltipForSpan);
+ node.querySelectorAll('select').forEach(updateTooltipForSelect);
+ }
+ });
+ });
+});
diff --git a/javascript/hires_fix.js b/javascript/hires_fix.js
index 48196be4..0d04ab3b 100644
--- a/javascript/hires_fix.js
+++ b/javascript/hires_fix.js
@@ -1,18 +1,18 @@
-
-function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){
- function setInactive(elem, inactive){
- elem.classList.toggle('inactive', !!inactive)
- }
-
- var hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale')
- var hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x')
- var hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y')
-
- gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : ""
-
- setInactive(hrUpscaleBy, opts.use_old_hires_fix_width_height || hr_resize_x > 0 || hr_resize_y > 0)
- setInactive(hrResizeX, opts.use_old_hires_fix_width_height || hr_resize_x == 0)
- setInactive(hrResizeY, opts.use_old_hires_fix_width_height || hr_resize_y == 0)
-
- return [enable, width, height, hr_scale, hr_resize_x, hr_resize_y]
-}
+
+function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y) {
+ function setInactive(elem, inactive) {
+ elem.classList.toggle('inactive', !!inactive);
+ }
+
+ var hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale');
+ var hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x');
+ var hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y');
+
+ gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : "";
+
+ setInactive(hrUpscaleBy, opts.use_old_hires_fix_width_height || hr_resize_x > 0 || hr_resize_y > 0);
+ setInactive(hrResizeX, opts.use_old_hires_fix_width_height || hr_resize_x == 0);
+ setInactive(hrResizeY, opts.use_old_hires_fix_width_height || hr_resize_y == 0);
+
+ return [enable, width, height, hr_scale, hr_resize_x, hr_resize_y];
+}
diff --git a/javascript/imageMaskFix.js b/javascript/imageMaskFix.js
index a612705d..3c9b8a6f 100644
--- a/javascript/imageMaskFix.js
+++ b/javascript/imageMaskFix.js
@@ -4,17 +4,16 @@
*/
function imageMaskResize() {
const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas');
- if ( ! canvases.length ) {
- canvases_fixed = false; // TODO: this is unused..?
- window.removeEventListener( 'resize', imageMaskResize );
- return;
+ if (!canvases.length) {
+ window.removeEventListener('resize', imageMaskResize);
+ return;
}
const wrapper = canvases[0].closest('.touch-none');
const previewImage = wrapper.previousElementSibling;
- if ( ! previewImage.complete ) {
- previewImage.addEventListener( 'load', imageMaskResize);
+ if (!previewImage.complete) {
+ previewImage.addEventListener('load', imageMaskResize);
return;
}
@@ -24,15 +23,15 @@ function imageMaskResize() {
const nh = previewImage.naturalHeight;
const portrait = nh > nw;
- const wW = Math.min(w, portrait ? h/nh*nw : w/nw*nw);
- const wH = Math.min(h, portrait ? h/nh*nh : w/nw*nh);
+ const wW = Math.min(w, portrait ? h / nh * nw : w / nw * nw);
+ const wH = Math.min(h, portrait ? h / nh * nh : w / nw * nh);
wrapper.style.width = `${wW}px`;
wrapper.style.height = `${wH}px`;
wrapper.style.left = `0px`;
wrapper.style.top = `0px`;
- canvases.forEach( c => {
+ canvases.forEach(c => {
c.style.width = c.style.height = '';
c.style.maxWidth = '100%';
c.style.maxHeight = '100%';
@@ -41,4 +40,4 @@ function imageMaskResize() {
}
onUiUpdate(imageMaskResize);
-window.addEventListener( 'resize', imageMaskResize);
+window.addEventListener('resize', imageMaskResize);
diff --git a/javascript/imageParams.js b/javascript/imageParams.js
index 64aee93b..057e2d39 100644
--- a/javascript/imageParams.js
+++ b/javascript/imageParams.js
@@ -1,4 +1,4 @@
-window.onload = (function(){
+window.onload = (function() {
window.addEventListener('drop', e => {
const target = e.composedPath()[0];
if (target.placeholder.indexOf("Prompt") == -1) return;
@@ -10,7 +10,7 @@ window.onload = (function(){
const imgParent = gradioApp().getElementById(prompt_target);
const files = e.dataTransfer.files;
const fileInput = imgParent.querySelector('input[type="file"]');
- if ( fileInput ) {
+ if (fileInput) {
fileInput.files = files;
fileInput.dispatchEvent(new Event('change'));
}
diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js
index 32066ab8..78e24eb9 100644
--- a/javascript/imageviewer.js
+++ b/javascript/imageviewer.js
@@ -5,24 +5,24 @@ function closeModal() {
function showModal(event) {
const source = event.target || event.srcElement;
- const modalImage = gradioApp().getElementById("modalImage")
- const lb = gradioApp().getElementById("lightboxModal")
- modalImage.src = source.src
+ const modalImage = gradioApp().getElementById("modalImage");
+ const lb = gradioApp().getElementById("lightboxModal");
+ modalImage.src = source.src;
if (modalImage.style.display === 'none') {
lb.style.setProperty('background-image', 'url(' + source.src + ')');
}
lb.style.display = "flex";
- lb.focus()
+ lb.focus();
- const tabTxt2Img = gradioApp().getElementById("tab_txt2img")
- const tabImg2Img = gradioApp().getElementById("tab_img2img")
+ const tabTxt2Img = gradioApp().getElementById("tab_txt2img");
+ const tabImg2Img = gradioApp().getElementById("tab_img2img");
// show the save button in modal only on txt2img or img2img tabs
if (tabTxt2Img.style.display != "none" || tabImg2Img.style.display != "none") {
- gradioApp().getElementById("modal_save").style.display = "inline"
+ gradioApp().getElementById("modal_save").style.display = "inline";
} else {
- gradioApp().getElementById("modal_save").style.display = "none"
+ gradioApp().getElementById("modal_save").style.display = "none";
}
- event.stopPropagation()
+ event.stopPropagation();
}
function negmod(n, m) {
@@ -30,14 +30,15 @@ function negmod(n, m) {
}
function updateOnBackgroundChange() {
- const modalImage = gradioApp().getElementById("modalImage")
+ const modalImage = gradioApp().getElementById("modalImage");
if (modalImage && modalImage.offsetParent) {
let currentButton = selected_gallery_button();
if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
modalImage.src = currentButton.children[0].src;
if (modalImage.style.display === 'none') {
- modal.style.setProperty('background-image', `url(${modalImage.src})`)
+ const modal = gradioApp().getElementById("lightboxModal");
+ modal.style.setProperty('background-image', `url(${modalImage.src})`);
}
}
}
@@ -49,108 +50,109 @@ function modalImageSwitch(offset) {
if (galleryButtons.length > 1) {
var currentButton = selected_gallery_button();
- var result = -1
+ var result = -1;
galleryButtons.forEach(function(v, i) {
if (v == currentButton) {
- result = i
+ result = i;
}
- })
+ });
if (result != -1) {
- var nextButton = galleryButtons[negmod((result + offset), galleryButtons.length)]
- nextButton.click()
+ var nextButton = galleryButtons[negmod((result + offset), galleryButtons.length)];
+ nextButton.click();
const modalImage = gradioApp().getElementById("modalImage");
const modal = gradioApp().getElementById("lightboxModal");
modalImage.src = nextButton.children[0].src;
if (modalImage.style.display === 'none') {
- modal.style.setProperty('background-image', `url(${modalImage.src})`)
+ modal.style.setProperty('background-image', `url(${modalImage.src})`);
}
setTimeout(function() {
- modal.focus()
- }, 10)
+ modal.focus();
+ }, 10);
}
}
}
-function saveImage(){
- const tabTxt2Img = gradioApp().getElementById("tab_txt2img")
- const tabImg2Img = gradioApp().getElementById("tab_img2img")
- const saveTxt2Img = "save_txt2img"
- const saveImg2Img = "save_img2img"
+function saveImage() {
+ const tabTxt2Img = gradioApp().getElementById("tab_txt2img");
+ const tabImg2Img = gradioApp().getElementById("tab_img2img");
+ const saveTxt2Img = "save_txt2img";
+ const saveImg2Img = "save_img2img";
if (tabTxt2Img.style.display != "none") {
- gradioApp().getElementById(saveTxt2Img).click()
+ gradioApp().getElementById(saveTxt2Img).click();
} else if (tabImg2Img.style.display != "none") {
- gradioApp().getElementById(saveImg2Img).click()
+ gradioApp().getElementById(saveImg2Img).click();
} else {
- console.error("missing implementation for saving modal of this type")
+ console.error("missing implementation for saving modal of this type");
}
}
function modalSaveImage(event) {
- saveImage()
- event.stopPropagation()
+ saveImage();
+ event.stopPropagation();
}
function modalNextImage(event) {
- modalImageSwitch(1)
- event.stopPropagation()
+ modalImageSwitch(1);
+ event.stopPropagation();
}
function modalPrevImage(event) {
- modalImageSwitch(-1)
- event.stopPropagation()
+ modalImageSwitch(-1);
+ event.stopPropagation();
}
function modalKeyHandler(event) {
switch (event.key) {
- case "s":
- saveImage()
- break;
- case "ArrowLeft":
- modalPrevImage(event)
- break;
- case "ArrowRight":
- modalNextImage(event)
- break;
- case "Escape":
- closeModal();
- break;
+ case "s":
+ saveImage();
+ break;
+ case "ArrowLeft":
+ modalPrevImage(event);
+ break;
+ case "ArrowRight":
+ modalNextImage(event);
+ break;
+ case "Escape":
+ closeModal();
+ break;
}
}
function setupImageForLightbox(e) {
- if (e.dataset.modded)
- return;
+ if (e.dataset.modded) {
+ return;
+ }
- e.dataset.modded = true;
- e.style.cursor='pointer'
- e.style.userSelect='none'
+ e.dataset.modded = true;
+ e.style.cursor = 'pointer';
+ e.style.userSelect = 'none';
- var isFirefox = navigator.userAgent.toLowerCase().indexOf('firefox') > -1
+ var isFirefox = navigator.userAgent.toLowerCase().indexOf('firefox') > -1;
- // For Firefox, listening on click first switched to next image then shows the lightbox.
- // If you know how to fix this without switching to mousedown event, please.
- // For other browsers the event is click to make it possiblr to drag picture.
- var event = isFirefox ? 'mousedown' : 'click'
+ // For Firefox, listening on click first switched to next image then shows the lightbox.
+ // If you know how to fix this without switching to mousedown event, please.
+ // For other browsers the event is click to make it possiblr to drag picture.
+ var event = isFirefox ? 'mousedown' : 'click';
- e.addEventListener(event, function (evt) {
- if(!opts.js_modal_lightbox || evt.button != 0) return;
+ e.addEventListener(event, function(evt) {
+ if (!opts.js_modal_lightbox || evt.button != 0) return;
- modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
- evt.preventDefault()
- showModal(evt)
- }, true);
+ modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed);
+ evt.preventDefault();
+ showModal(evt);
+ }, true);
}
function modalZoomSet(modalImage, enable) {
- if(modalImage) modalImage.classList.toggle('modalImageFullscreen', !!enable);
+ if (modalImage) modalImage.classList.toggle('modalImageFullscreen', !!enable);
}
function modalZoomToggle(event) {
var modalImage = gradioApp().getElementById("modalImage");
- modalZoomSet(modalImage, !modalImage.classList.contains('modalImageFullscreen'))
- event.stopPropagation()
+ modalZoomSet(modalImage, !modalImage.classList.contains('modalImageFullscreen'));
+ event.stopPropagation();
}
function modalTileImageToggle(event) {
@@ -159,99 +161,93 @@ function modalTileImageToggle(event) {
const isTiling = modalImage.style.display === 'none';
if (isTiling) {
modalImage.style.display = 'block';
- modal.style.setProperty('background-image', 'none')
+ modal.style.setProperty('background-image', 'none');
} else {
modalImage.style.display = 'none';
- modal.style.setProperty('background-image', `url(${modalImage.src})`)
+ modal.style.setProperty('background-image', `url(${modalImage.src})`);
}
- event.stopPropagation()
-}
-
-function galleryImageHandler(e) {
- //if (e && e.parentElement.tagName == 'BUTTON') {
- e.onclick = showGalleryImage;
- //}
+ event.stopPropagation();
}
onUiUpdate(function() {
- var fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img')
+ var fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img');
if (fullImg_preview != null) {
fullImg_preview.forEach(setupImageForLightbox);
}
updateOnBackgroundChange();
-})
+});
document.addEventListener("DOMContentLoaded", function() {
//const modalFragment = document.createDocumentFragment();
- const modal = document.createElement('div')
+ const modal = document.createElement('div');
modal.onclick = closeModal;
modal.id = "lightboxModal";
- modal.tabIndex = 0
- modal.addEventListener('keydown', modalKeyHandler, true)
+ modal.tabIndex = 0;
+ modal.addEventListener('keydown', modalKeyHandler, true);
- const modalControls = document.createElement('div')
+ const modalControls = document.createElement('div');
modalControls.className = 'modalControls gradio-container';
modal.append(modalControls);
- const modalZoom = document.createElement('span')
+ const modalZoom = document.createElement('span');
modalZoom.className = 'modalZoom cursor';
- modalZoom.innerHTML = '⤡'
- modalZoom.addEventListener('click', modalZoomToggle, true)
+ modalZoom.innerHTML = '⤡';
+ modalZoom.addEventListener('click', modalZoomToggle, true);
modalZoom.title = "Toggle zoomed view";
- modalControls.appendChild(modalZoom)
+ modalControls.appendChild(modalZoom);
- const modalTileImage = document.createElement('span')
+ const modalTileImage = document.createElement('span');
modalTileImage.className = 'modalTileImage cursor';
- modalTileImage.innerHTML = '⊞'
- modalTileImage.addEventListener('click', modalTileImageToggle, true)
+ modalTileImage.innerHTML = '⊞';
+ modalTileImage.addEventListener('click', modalTileImageToggle, true);
modalTileImage.title = "Preview tiling";
- modalControls.appendChild(modalTileImage)
+ modalControls.appendChild(modalTileImage);
- const modalSave = document.createElement("span")
- modalSave.className = "modalSave cursor"
- modalSave.id = "modal_save"
- modalSave.innerHTML = "🖫"
- modalSave.addEventListener("click", modalSaveImage, true)
- modalSave.title = "Save Image(s)"
- modalControls.appendChild(modalSave)
+ const modalSave = document.createElement("span");
+ modalSave.className = "modalSave cursor";
+ modalSave.id = "modal_save";
+ modalSave.innerHTML = "🖫";
+ modalSave.addEventListener("click", modalSaveImage, true);
+ modalSave.title = "Save Image(s)";
+ modalControls.appendChild(modalSave);
- const modalClose = document.createElement('span')
+ const modalClose = document.createElement('span');
modalClose.className = 'modalClose cursor';
- modalClose.innerHTML = '×'
+ modalClose.innerHTML = '×';
modalClose.onclick = closeModal;
modalClose.title = "Close image viewer";
- modalControls.appendChild(modalClose)
+ modalControls.appendChild(modalClose);
- const modalImage = document.createElement('img')
+ const modalImage = document.createElement('img');
modalImage.id = 'modalImage';
modalImage.onclick = closeModal;
- modalImage.tabIndex = 0
- modalImage.addEventListener('keydown', modalKeyHandler, true)
- modal.appendChild(modalImage)
+ modalImage.tabIndex = 0;
+ modalImage.addEventListener('keydown', modalKeyHandler, true);
+ modal.appendChild(modalImage);
- const modalPrev = document.createElement('a')
+ const modalPrev = document.createElement('a');
modalPrev.className = 'modalPrev';
- modalPrev.innerHTML = '❮'
- modalPrev.tabIndex = 0
+ modalPrev.innerHTML = '❮';
+ modalPrev.tabIndex = 0;
modalPrev.addEventListener('click', modalPrevImage, true);
- modalPrev.addEventListener('keydown', modalKeyHandler, true)
- modal.appendChild(modalPrev)
+ modalPrev.addEventListener('keydown', modalKeyHandler, true);
+ modal.appendChild(modalPrev);
- const modalNext = document.createElement('a')
+ const modalNext = document.createElement('a');
modalNext.className = 'modalNext';
- modalNext.innerHTML = '❯'
- modalNext.tabIndex = 0
+ modalNext.innerHTML = '❯';
+ modalNext.tabIndex = 0;
modalNext.addEventListener('click', modalNextImage, true);
- modalNext.addEventListener('keydown', modalKeyHandler, true)
+ modalNext.addEventListener('keydown', modalKeyHandler, true);
- modal.appendChild(modalNext)
+ modal.appendChild(modalNext);
try {
- gradioApp().appendChild(modal);
- } catch (e) {
- gradioApp().body.appendChild(modal);
- }
+ gradioApp().appendChild(modal);
+ } catch (e) {
+ gradioApp().body.appendChild(modal);
+ }
document.body.appendChild(modal);
diff --git a/javascript/imageviewerGamepad.js b/javascript/imageviewerGamepad.js
index 6297a12b..31d226de 100644
--- a/javascript/imageviewerGamepad.js
+++ b/javascript/imageviewerGamepad.js
@@ -1,7 +1,7 @@
window.addEventListener('gamepadconnected', (e) => {
const index = e.gamepad.index;
let isWaiting = false;
- setInterval(async () => {
+ setInterval(async() => {
if (!opts.js_modal_lightbox_gamepad || isWaiting) return;
const gamepad = navigator.getGamepads()[index];
const xValue = gamepad.axes[0];
@@ -14,7 +14,7 @@ window.addEventListener('gamepadconnected', (e) => {
}
if (isWaiting) {
await sleepUntil(() => {
- const xValue = navigator.getGamepads()[index].axes[0]
+ const xValue = navigator.getGamepads()[index].axes[0];
if (xValue < 0.3 && xValue > -0.3) {
return true;
}
diff --git a/javascript/localization.js b/javascript/localization.js
index 86e5ca67..eb22b8a7 100644
--- a/javascript/localization.js
+++ b/javascript/localization.js
@@ -1,177 +1,176 @@
-
-// localization = {} -- the dict with translations is created by the backend
-
-ignore_ids_for_localization={
- setting_sd_hypernetwork: 'OPTION',
- setting_sd_model_checkpoint: 'OPTION',
- setting_realesrgan_enabled_models: 'OPTION',
- modelmerger_primary_model_name: 'OPTION',
- modelmerger_secondary_model_name: 'OPTION',
- modelmerger_tertiary_model_name: 'OPTION',
- train_embedding: 'OPTION',
- train_hypernetwork: 'OPTION',
- txt2img_styles: 'OPTION',
- img2img_styles: 'OPTION',
- setting_random_artist_categories: 'SPAN',
- setting_face_restoration_model: 'SPAN',
- setting_realesrgan_enabled_models: 'SPAN',
- extras_upscaler_1: 'SPAN',
- extras_upscaler_2: 'SPAN',
-}
-
-re_num = /^[\.\d]+$/
-re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u
-
-original_lines = {}
-translated_lines = {}
-
-function hasLocalization() {
- return window.localization && Object.keys(window.localization).length > 0;
-}
-
-function textNodesUnder(el){
- var n, a=[], walk=document.createTreeWalker(el,NodeFilter.SHOW_TEXT,null,false);
- while(n=walk.nextNode()) a.push(n);
- return a;
-}
-
-function canBeTranslated(node, text){
- if(! text) return false;
- if(! node.parentElement) return false;
-
- var parentType = node.parentElement.nodeName
- if(parentType=='SCRIPT' || parentType=='STYLE' || parentType=='TEXTAREA') return false;
-
- if (parentType=='OPTION' || parentType=='SPAN'){
- var pnode = node
- for(var level=0; level<4; level++){
- pnode = pnode.parentElement
- if(! pnode) break;
-
- if(ignore_ids_for_localization[pnode.id] == parentType) return false;
- }
- }
-
- if(re_num.test(text)) return false;
- if(re_emoji.test(text)) return false;
- return true
-}
-
-function getTranslation(text){
- if(! text) return undefined
-
- if(translated_lines[text] === undefined){
- original_lines[text] = 1
- }
-
- tl = localization[text]
- if(tl !== undefined){
- translated_lines[tl] = 1
- }
-
- return tl
-}
-
-function processTextNode(node){
- var text = node.textContent.trim()
-
- if(! canBeTranslated(node, text)) return
-
- tl = getTranslation(text)
- if(tl !== undefined){
- node.textContent = tl
- }
-}
-
-function processNode(node){
- if(node.nodeType == 3){
- processTextNode(node)
- return
- }
-
- if(node.title){
- tl = getTranslation(node.title)
- if(tl !== undefined){
- node.title = tl
- }
- }
-
- if(node.placeholder){
- tl = getTranslation(node.placeholder)
- if(tl !== undefined){
- node.placeholder = tl
- }
- }
-
- textNodesUnder(node).forEach(function(node){
- processTextNode(node)
- })
-}
-
-function dumpTranslations(){
- if(!hasLocalization()) {
- // If we don't have any localization,
- // we will not have traversed the app to find
- // original_lines, so do that now.
- processNode(gradioApp());
- }
- var dumped = {}
- if (localization.rtl) {
- dumped.rtl = true;
- }
-
- for (const text in original_lines) {
- if(dumped[text] !== undefined) continue;
- dumped[text] = localization[text] || text;
- }
-
- return dumped;
-}
-
-function download_localization() {
- var text = JSON.stringify(dumpTranslations(), null, 4)
-
- var element = document.createElement('a');
- element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text));
- element.setAttribute('download', "localization.json");
- element.style.display = 'none';
- document.body.appendChild(element);
-
- element.click();
-
- document.body.removeChild(element);
-}
-
-document.addEventListener("DOMContentLoaded", function () {
- if (!hasLocalization()) {
- return;
- }
-
- onUiUpdate(function (m) {
- m.forEach(function (mutation) {
- mutation.addedNodes.forEach(function (node) {
- processNode(node)
- })
- });
- })
-
- processNode(gradioApp())
-
- if (localization.rtl) { // if the language is from right to left,
- (new MutationObserver((mutations, observer) => { // wait for the style to load
- mutations.forEach(mutation => {
- mutation.addedNodes.forEach(node => {
- if (node.tagName === 'STYLE') {
- observer.disconnect();
-
- for (const x of node.sheet.rules) { // find all rtl media rules
- if (Array.from(x.media || []).includes('rtl')) {
- x.media.appendMedium('all'); // enable them
- }
- }
- }
- })
- });
- })).observe(gradioApp(), { childList: true });
- }
-})
+
+// localization = {} -- the dict with translations is created by the backend
+
+var ignore_ids_for_localization = {
+ setting_sd_hypernetwork: 'OPTION',
+ setting_sd_model_checkpoint: 'OPTION',
+ modelmerger_primary_model_name: 'OPTION',
+ modelmerger_secondary_model_name: 'OPTION',
+ modelmerger_tertiary_model_name: 'OPTION',
+ train_embedding: 'OPTION',
+ train_hypernetwork: 'OPTION',
+ txt2img_styles: 'OPTION',
+ img2img_styles: 'OPTION',
+ setting_random_artist_categories: 'SPAN',
+ setting_face_restoration_model: 'SPAN',
+ setting_realesrgan_enabled_models: 'SPAN',
+ extras_upscaler_1: 'SPAN',
+ extras_upscaler_2: 'SPAN',
+};
+
+var re_num = /^[.\d]+$/;
+var re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u;
+
+var original_lines = {};
+var translated_lines = {};
+
+function hasLocalization() {
+ return window.localization && Object.keys(window.localization).length > 0;
+}
+
+function textNodesUnder(el) {
+ var n, a = [], walk = document.createTreeWalker(el, NodeFilter.SHOW_TEXT, null, false);
+ while ((n = walk.nextNode())) a.push(n);
+ return a;
+}
+
+function canBeTranslated(node, text) {
+ if (!text) return false;
+ if (!node.parentElement) return false;
+
+ var parentType = node.parentElement.nodeName;
+ if (parentType == 'SCRIPT' || parentType == 'STYLE' || parentType == 'TEXTAREA') return false;
+
+ if (parentType == 'OPTION' || parentType == 'SPAN') {
+ var pnode = node;
+ for (var level = 0; level < 4; level++) {
+ pnode = pnode.parentElement;
+ if (!pnode) break;
+
+ if (ignore_ids_for_localization[pnode.id] == parentType) return false;
+ }
+ }
+
+ if (re_num.test(text)) return false;
+ if (re_emoji.test(text)) return false;
+ return true;
+}
+
+function getTranslation(text) {
+ if (!text) return undefined;
+
+ if (translated_lines[text] === undefined) {
+ original_lines[text] = 1;
+ }
+
+ var tl = localization[text];
+ if (tl !== undefined) {
+ translated_lines[tl] = 1;
+ }
+
+ return tl;
+}
+
+function processTextNode(node) {
+ var text = node.textContent.trim();
+
+ if (!canBeTranslated(node, text)) return;
+
+ var tl = getTranslation(text);
+ if (tl !== undefined) {
+ node.textContent = tl;
+ }
+}
+
+function processNode(node) {
+ if (node.nodeType == 3) {
+ processTextNode(node);
+ return;
+ }
+
+ if (node.title) {
+ let tl = getTranslation(node.title);
+ if (tl !== undefined) {
+ node.title = tl;
+ }
+ }
+
+ if (node.placeholder) {
+ let tl = getTranslation(node.placeholder);
+ if (tl !== undefined) {
+ node.placeholder = tl;
+ }
+ }
+
+ textNodesUnder(node).forEach(function(node) {
+ processTextNode(node);
+ });
+}
+
+function dumpTranslations() {
+ if (!hasLocalization()) {
+ // If we don't have any localization,
+ // we will not have traversed the app to find
+ // original_lines, so do that now.
+ processNode(gradioApp());
+ }
+ var dumped = {};
+ if (localization.rtl) {
+ dumped.rtl = true;
+ }
+
+ for (const text in original_lines) {
+ if (dumped[text] !== undefined) continue;
+ dumped[text] = localization[text] || text;
+ }
+
+ return dumped;
+}
+
+function download_localization() {
+ var text = JSON.stringify(dumpTranslations(), null, 4);
+
+ var element = document.createElement('a');
+ element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text));
+ element.setAttribute('download', "localization.json");
+ element.style.display = 'none';
+ document.body.appendChild(element);
+
+ element.click();
+
+ document.body.removeChild(element);
+}
+
+document.addEventListener("DOMContentLoaded", function() {
+ if (!hasLocalization()) {
+ return;
+ }
+
+ onUiUpdate(function(m) {
+ m.forEach(function(mutation) {
+ mutation.addedNodes.forEach(function(node) {
+ processNode(node);
+ });
+ });
+ });
+
+ processNode(gradioApp());
+
+ if (localization.rtl) { // if the language is from right to left,
+ (new MutationObserver((mutations, observer) => { // wait for the style to load
+ mutations.forEach(mutation => {
+ mutation.addedNodes.forEach(node => {
+ if (node.tagName === 'STYLE') {
+ observer.disconnect();
+
+ for (const x of node.sheet.rules) { // find all rtl media rules
+ if (Array.from(x.media || []).includes('rtl')) {
+ x.media.appendMedium('all'); // enable them
+ }
+ }
+ }
+ });
+ });
+ })).observe(gradioApp(), {childList: true});
+ }
+});
diff --git a/javascript/notification.js b/javascript/notification.js
index 83fce1f8..a68a76f2 100644
--- a/javascript/notification.js
+++ b/javascript/notification.js
@@ -4,14 +4,14 @@ let lastHeadImg = null;
let notificationButton = null;
-onUiUpdate(function(){
- if(notificationButton == null){
- notificationButton = gradioApp().getElementById('request_notifications')
+onUiUpdate(function() {
+ if (notificationButton == null) {
+ notificationButton = gradioApp().getElementById('request_notifications');
- if(notificationButton != null){
+ if (notificationButton != null) {
notificationButton.addEventListener('click', () => {
void Notification.requestPermission();
- },true);
+ }, true);
}
}
@@ -42,7 +42,7 @@ onUiUpdate(function(){
}
);
- notification.onclick = function(_){
+ notification.onclick = function(_) {
parent.focus();
this.close();
};
diff --git a/javascript/progressbar.js b/javascript/progressbar.js
index 8d2c3492..29299787 100644
--- a/javascript/progressbar.js
+++ b/javascript/progressbar.js
@@ -1,29 +1,29 @@
// code related to showing and updating progressbar shown as the image is being made
-function rememberGallerySelection(){
+function rememberGallerySelection() {
}
-function getGallerySelectedIndex(){
+function getGallerySelectedIndex() {
}
-function request(url, data, handler, errorHandler){
+function request(url, data, handler, errorHandler) {
var xhr = new XMLHttpRequest();
xhr.open("POST", url, true);
xhr.setRequestHeader("Content-Type", "application/json");
- xhr.onreadystatechange = function () {
+ xhr.onreadystatechange = function() {
if (xhr.readyState === 4) {
if (xhr.status === 200) {
try {
var js = JSON.parse(xhr.responseText);
- handler(js)
+ handler(js);
} catch (error) {
console.error(error);
- errorHandler()
+ errorHandler();
}
- } else{
- errorHandler()
+ } else {
+ errorHandler();
}
}
};
@@ -31,147 +31,147 @@ function request(url, data, handler, errorHandler){
xhr.send(js);
}
-function pad2(x){
- return x<10 ? '0'+x : x
+function pad2(x) {
+ return x < 10 ? '0' + x : x;
}
-function formatTime(secs){
- if(secs > 3600){
- return pad2(Math.floor(secs/60/60)) + ":" + pad2(Math.floor(secs/60)%60) + ":" + pad2(Math.floor(secs)%60)
- } else if(secs > 60){
- return pad2(Math.floor(secs/60)) + ":" + pad2(Math.floor(secs)%60)
- } else{
- return Math.floor(secs) + "s"
+function formatTime(secs) {
+ if (secs > 3600) {
+ return pad2(Math.floor(secs / 60 / 60)) + ":" + pad2(Math.floor(secs / 60) % 60) + ":" + pad2(Math.floor(secs) % 60);
+ } else if (secs > 60) {
+ return pad2(Math.floor(secs / 60)) + ":" + pad2(Math.floor(secs) % 60);
+ } else {
+ return Math.floor(secs) + "s";
}
}
-function setTitle(progress){
- var title = 'Stable Diffusion'
+function setTitle(progress) {
+ var title = 'Stable Diffusion';
- if(opts.show_progress_in_title && progress){
+ if (opts.show_progress_in_title && progress) {
title = '[' + progress.trim() + '] ' + title;
}
- if(document.title != title){
- document.title = title;
+ if (document.title != title) {
+ document.title = title;
}
}
-function randomId(){
- return "task(" + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7)+")"
+function randomId() {
+ return "task(" + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + ")";
}
// starts sending progress requests to "/internal/progress" uri, creating progressbar above progressbarContainer element and
// preview inside gallery element. Cleans up all created stuff when the task is over and calls atEnd.
// calls onProgress every time there is a progress update
-function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgress, inactivityTimeout=40){
- var dateStart = new Date()
- var wasEverActive = false
- var parentProgressbar = progressbarContainer.parentNode
- var parentGallery = gallery ? gallery.parentNode : null
+function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgress, inactivityTimeout = 40) {
+ var dateStart = new Date();
+ var wasEverActive = false;
+ var parentProgressbar = progressbarContainer.parentNode;
+ var parentGallery = gallery ? gallery.parentNode : null;
- var divProgress = document.createElement('div')
- divProgress.className='progressDiv'
- divProgress.style.display = opts.show_progressbar ? "block" : "none"
- var divInner = document.createElement('div')
- divInner.className='progress'
+ var divProgress = document.createElement('div');
+ divProgress.className = 'progressDiv';
+ divProgress.style.display = opts.show_progressbar ? "block" : "none";
+ var divInner = document.createElement('div');
+ divInner.className = 'progress';
- divProgress.appendChild(divInner)
- parentProgressbar.insertBefore(divProgress, progressbarContainer)
+ divProgress.appendChild(divInner);
+ parentProgressbar.insertBefore(divProgress, progressbarContainer);
- if(parentGallery){
- var livePreview = document.createElement('div')
- livePreview.className='livePreview'
- parentGallery.insertBefore(livePreview, gallery)
+ if (parentGallery) {
+ var livePreview = document.createElement('div');
+ livePreview.className = 'livePreview';
+ parentGallery.insertBefore(livePreview, gallery);
}
- var removeProgressBar = function(){
- setTitle("")
- parentProgressbar.removeChild(divProgress)
- if(parentGallery) parentGallery.removeChild(livePreview)
- atEnd()
- }
+ var removeProgressBar = function() {
+ setTitle("");
+ parentProgressbar.removeChild(divProgress);
+ if (parentGallery) parentGallery.removeChild(livePreview);
+ atEnd();
+ };
- var fun = function(id_task, id_live_preview){
- request("./internal/progress", {"id_task": id_task, "id_live_preview": id_live_preview}, function(res){
- if(res.completed){
- removeProgressBar()
- return
+ var fun = function(id_task, id_live_preview) {
+ request("./internal/progress", {id_task: id_task, id_live_preview: id_live_preview}, function(res) {
+ if (res.completed) {
+ removeProgressBar();
+ return;
}
- var rect = progressbarContainer.getBoundingClientRect()
+ var rect = progressbarContainer.getBoundingClientRect();
- if(rect.width){
+ if (rect.width) {
divProgress.style.width = rect.width + "px";
}
- let progressText = ""
+ let progressText = "";
- divInner.style.width = ((res.progress || 0) * 100.0) + '%'
- divInner.style.background = res.progress ? "" : "transparent"
+ divInner.style.width = ((res.progress || 0) * 100.0) + '%';
+ divInner.style.background = res.progress ? "" : "transparent";
- if(res.progress > 0){
- progressText = ((res.progress || 0) * 100.0).toFixed(0) + '%'
+ if (res.progress > 0) {
+ progressText = ((res.progress || 0) * 100.0).toFixed(0) + '%';
}
- if(res.eta){
- progressText += " ETA: " + formatTime(res.eta)
+ if (res.eta) {
+ progressText += " ETA: " + formatTime(res.eta);
}
- setTitle(progressText)
+ setTitle(progressText);
- if(res.textinfo && res.textinfo.indexOf("\n") == -1){
- progressText = res.textinfo + " " + progressText
+ if (res.textinfo && res.textinfo.indexOf("\n") == -1) {
+ progressText = res.textinfo + " " + progressText;
}
- divInner.textContent = progressText
+ divInner.textContent = progressText;
- var elapsedFromStart = (new Date() - dateStart) / 1000
+ var elapsedFromStart = (new Date() - dateStart) / 1000;
- if(res.active) wasEverActive = true;
+ if (res.active) wasEverActive = true;
- if(! res.active && wasEverActive){
- removeProgressBar()
- return
+ if (!res.active && wasEverActive) {
+ removeProgressBar();
+ return;
}
- if(elapsedFromStart > inactivityTimeout && !res.queued && !res.active){
- removeProgressBar()
- return
+ if (elapsedFromStart > inactivityTimeout && !res.queued && !res.active) {
+ removeProgressBar();
+ return;
}
- if(res.live_preview && gallery){
- var rect = gallery.getBoundingClientRect()
- if(rect.width){
- livePreview.style.width = rect.width + "px"
- livePreview.style.height = rect.height + "px"
+ if (res.live_preview && gallery) {
+ rect = gallery.getBoundingClientRect();
+ if (rect.width) {
+ livePreview.style.width = rect.width + "px";
+ livePreview.style.height = rect.height + "px";
}
var img = new Image();
img.onload = function() {
- livePreview.appendChild(img)
- if(livePreview.childElementCount > 2){
- livePreview.removeChild(livePreview.firstElementChild)
+ livePreview.appendChild(img);
+ if (livePreview.childElementCount > 2) {
+ livePreview.removeChild(livePreview.firstElementChild);
}
- }
+ };
img.src = res.live_preview;
}
- if(onProgress){
- onProgress(res)
+ if (onProgress) {
+ onProgress(res);
}
setTimeout(() => {
fun(id_task, res.id_live_preview);
- }, opts.live_preview_refresh_period || 500)
- }, function(){
- removeProgressBar()
- })
- }
+ }, opts.live_preview_refresh_period || 500);
+ }, function() {
+ removeProgressBar();
+ });
+ };
- fun(id_task, 0)
+ fun(id_task, 0);
}
diff --git a/javascript/textualInversion.js b/javascript/textualInversion.js
index 0354b860..37e3d075 100644
--- a/javascript/textualInversion.js
+++ b/javascript/textualInversion.js
@@ -1,17 +1,17 @@
-
-
-
-function start_training_textual_inversion(){
- gradioApp().querySelector('#ti_error').innerHTML=''
-
- var id = randomId()
- requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), function(){}, function(progress){
- gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo
- })
-
- var res = args_to_array(arguments)
-
- res[0] = id
-
- return res
-}
+
+
+
+function start_training_textual_inversion() {
+ gradioApp().querySelector('#ti_error').innerHTML = '';
+
+ var id = randomId();
+ requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), function() {}, function(progress) {
+ gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo;
+ });
+
+ var res = args_to_array(arguments);
+
+ res[0] = id;
+
+ return res;
+}
diff --git a/javascript/ui.js b/javascript/ui.js
index ed9673d6..c7316ddb 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -1,9 +1,9 @@
// various functions for interaction with ui.py not large enough to warrant putting them in separate files
-function set_theme(theme){
- var gradioURL = window.location.href
+function set_theme(theme) {
+ var gradioURL = window.location.href;
if (!gradioURL.includes('?__theme=')) {
- window.location.replace(gradioURL + '?__theme=' + theme);
+ window.location.replace(gradioURL + '?__theme=' + theme);
}
}
@@ -14,7 +14,7 @@ function all_gallery_buttons() {
if (elem.parentElement.offsetParent) {
visibleGalleryButtons.push(elem);
}
- })
+ });
return visibleGalleryButtons;
}
@@ -25,31 +25,35 @@ function selected_gallery_button() {
if (elem.parentElement.offsetParent) {
visibleCurrentButton = elem;
}
- })
+ });
return visibleCurrentButton;
}
-function selected_gallery_index(){
+function selected_gallery_index() {
var buttons = all_gallery_buttons();
var button = selected_gallery_button();
- var result = -1
- buttons.forEach(function(v, i){ if(v==button) { result = i } })
+ var result = -1;
+ buttons.forEach(function(v, i) {
+ if (v == button) {
+ result = i;
+ }
+ });
- return result
+ return result;
}
-function extract_image_from_gallery(gallery){
- if (gallery.length == 0){
+function extract_image_from_gallery(gallery) {
+ if (gallery.length == 0) {
return [null];
}
- if (gallery.length == 1){
+ if (gallery.length == 1) {
return [gallery[0]];
}
- var index = selected_gallery_index()
+ var index = selected_gallery_index();
- if (index < 0 || index >= gallery.length){
+ if (index < 0 || index >= gallery.length) {
// Use the first image in the gallery as the default
index = 0;
}
@@ -57,249 +61,242 @@ function extract_image_from_gallery(gallery){
return [gallery[index]];
}
-function args_to_array(args){
- var res = []
- for(var i=0;i label > textarea");
- if(counter.parentElement == prompt.parentElement){
- return
+ if (counter.parentElement == prompt.parentElement) {
+ return;
}
- prompt.parentElement.insertBefore(counter, prompt)
- prompt.parentElement.style.position = "relative"
+ prompt.parentElement.insertBefore(counter, prompt);
+ prompt.parentElement.style.position = "relative";
- promptTokecountUpdateFuncs[id] = function(){ update_token_counter(id_button); }
- textarea.addEventListener("input", promptTokecountUpdateFuncs[id]);
+ promptTokecountUpdateFuncs[id] = function() {
+ update_token_counter(id_button);
+ };
+ textarea.addEventListener("input", promptTokecountUpdateFuncs[id]);
}
- registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button')
- registerTextarea('txt2img_neg_prompt', 'txt2img_negative_token_counter', 'txt2img_negative_token_button')
- registerTextarea('img2img_prompt', 'img2img_token_counter', 'img2img_token_button')
- registerTextarea('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button')
+ registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button');
+ registerTextarea('txt2img_neg_prompt', 'txt2img_negative_token_counter', 'txt2img_negative_token_button');
+ registerTextarea('img2img_prompt', 'img2img_token_counter', 'img2img_token_button');
+ registerTextarea('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button');
- var show_all_pages = gradioApp().getElementById('settings_show_all_pages')
- var settings_tabs = gradioApp().querySelector('#settings div')
- if(show_all_pages && settings_tabs){
- settings_tabs.appendChild(show_all_pages)
- show_all_pages.onclick = function(){
- gradioApp().querySelectorAll('#settings > div').forEach(function(elem){
- if(elem.id == "settings_tab_licenses")
+ var show_all_pages = gradioApp().getElementById('settings_show_all_pages');
+ var settings_tabs = gradioApp().querySelector('#settings div');
+ if (show_all_pages && settings_tabs) {
+ settings_tabs.appendChild(show_all_pages);
+ show_all_pages.onclick = function() {
+ gradioApp().querySelectorAll('#settings > div').forEach(function(elem) {
+ if (elem.id == "settings_tab_licenses") {
return;
+ }
elem.style.display = "block";
- })
- }
+ });
+ };
}
-})
+});
-onOptionsChanged(function(){
- var elem = gradioApp().getElementById('sd_checkpoint_hash')
- var sd_checkpoint_hash = opts.sd_checkpoint_hash || ""
- var shorthash = sd_checkpoint_hash.substring(0,10)
+onOptionsChanged(function() {
+ var elem = gradioApp().getElementById('sd_checkpoint_hash');
+ var sd_checkpoint_hash = opts.sd_checkpoint_hash || "";
+ var shorthash = sd_checkpoint_hash.substring(0, 10);
- if(elem && elem.textContent != shorthash){
- elem.textContent = shorthash
- elem.title = sd_checkpoint_hash
- elem.href = "https://google.com/search?q=" + sd_checkpoint_hash
- }
-})
+ if (elem && elem.textContent != shorthash) {
+ elem.textContent = shorthash;
+ elem.title = sd_checkpoint_hash;
+ elem.href = "https://google.com/search?q=" + sd_checkpoint_hash;
+ }
+});
let txt2img_textarea, img2img_textarea = undefined;
-let wait_time = 800
+let wait_time = 800;
let token_timeouts = {};
function update_txt2img_tokens(...args) {
- update_token_counter("txt2img_token_button")
- if (args.length == 2)
- return args[0]
- return args;
+ update_token_counter("txt2img_token_button");
+ if (args.length == 2) {
+ return args[0];
+ }
+ return args;
}
function update_img2img_tokens(...args) {
- update_token_counter("img2img_token_button")
- if (args.length == 2)
- return args[0]
- return args;
+ update_token_counter(
+ "img2img_token_button"
+ );
+ if (args.length == 2) {
+ return args[0];
+ }
+ return args;
}
function update_token_counter(button_id) {
- if (token_timeouts[button_id])
- clearTimeout(token_timeouts[button_id]);
- token_timeouts[button_id] = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
+ if (token_timeouts[button_id]) {
+ clearTimeout(token_timeouts[button_id]);
+ }
+ token_timeouts[button_id] = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
}
-function restart_reload(){
- document.body.innerHTML='Reloading... ';
+function restart_reload() {
+ document.body.innerHTML = 'Reloading... ';
- var requestPing = function(){
- requestGet("./internal/ping", {}, function(data){
+ var requestPing = function() {
+ requestGet("./internal/ping", {}, function(data) {
location.reload();
- }, function(){
+ }, function() {
setTimeout(requestPing, 500);
- })
- }
+ });
+ };
setTimeout(requestPing, 2000);
- return []
+ return [];
}
// Simulate an `input` DOM event for Gradio Textbox component. Needed after you edit its contents in javascript, otherwise your edits
// will only visible on web page and not sent to python.
-function updateInput(target){
- let e = new Event("input", { bubbles: true })
- Object.defineProperty(e, "target", {value: target})
- target.dispatchEvent(e);
+function updateInput(target) {
+ let e = new Event("input", {bubbles: true});
+ Object.defineProperty(e, "target", {value: target});
+ target.dispatchEvent(e);
}
var desiredCheckpointName = null;
-function selectCheckpoint(name){
+function selectCheckpoint(name) {
desiredCheckpointName = name;
- gradioApp().getElementById('change_checkpoint').click()
+ gradioApp().getElementById('change_checkpoint').click();
}
-function currentImg2imgSourceResolution(_, _, scaleBy){
- var img = gradioApp().querySelector('#mode_img2img > div[style="display: block;"] img')
- return img ? [img.naturalWidth, img.naturalHeight, scaleBy] : [0, 0, scaleBy]
+function currentImg2imgSourceResolution(w, h, scaleBy) {
+ var img = gradioApp().querySelector('#mode_img2img > div[style="display: block;"] img');
+ return img ? [img.naturalWidth, img.naturalHeight, scaleBy] : [0, 0, scaleBy];
}
-function updateImg2imgResizeToTextAfterChangingImage(){
+function updateImg2imgResizeToTextAfterChangingImage() {
// At the time this is called from gradio, the image has no yet been replaced.
// There may be a better solution, but this is simple and straightforward so I'm going with it.
setTimeout(function() {
- gradioApp().getElementById('img2img_update_resize_to').click()
+ gradioApp().getElementById('img2img_update_resize_to').click();
}, 500);
- return []
+ return [];
+
+}
+
+
+
+function setRandomSeed(elem_id) {
+ var input = gradioApp().querySelector("#" + elem_id + " input");
+ if (!input) return [];
+
+ input.value = "-1";
+ updateInput(input);
+ return [];
+}
+
+function switchWidthHeight(tabname) {
+ var width = gradioApp().querySelector("#" + tabname + "_width input[type=number]");
+ var height = gradioApp().querySelector("#" + tabname + "_height input[type=number]");
+ if (!width || !height) return [];
+
+ var tmp = width.value;
+ width.value = height.value;
+ height.value = tmp;
+
+ updateInput(width);
+ updateInput(height);
+ return [];
}
diff --git a/javascript/ui_settings_hints.js b/javascript/ui_settings_hints.js
index 87a289d3..e216852b 100644
--- a/javascript/ui_settings_hints.js
+++ b/javascript/ui_settings_hints.js
@@ -1,41 +1,62 @@
-// various hints and extra info for the settings tab
-
-onUiLoaded(function(){
- createLink = function(elem_id, text, href){
- var a = document.createElement('A')
- a.textContent = text
- a.target = '_blank';
-
- elem = gradioApp().querySelector('#'+elem_id)
- elem.insertBefore(a, elem.querySelector('label'))
-
- return a
- }
-
- createLink("setting_samples_filename_pattern", "[wiki] ").href = "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"
- createLink("setting_directories_filename_pattern", "[wiki] ").href = "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"
-
- createLink("setting_quicksettings_list", "[info] ").addEventListener("click", function(event){
- requestGet("./internal/quicksettings-hint", {}, function(data){
- var table = document.createElement('table')
- table.className = 'settings-value-table'
-
- data.forEach(function(obj){
- var tr = document.createElement('tr')
- var td = document.createElement('td')
- td.textContent = obj.name
- tr.appendChild(td)
-
- var td = document.createElement('td')
- td.textContent = obj.label
- tr.appendChild(td)
-
- table.appendChild(tr)
- })
-
- popup(table);
- })
- });
-})
-
-
+// various hints and extra info for the settings tab
+
+var settingsHintsSetup = false;
+
+onOptionsChanged(function() {
+ if (settingsHintsSetup) return;
+ settingsHintsSetup = true;
+
+ gradioApp().querySelectorAll('#settings [id^=setting_]').forEach(function(div) {
+ var name = div.id.substr(8);
+ var commentBefore = opts._comments_before[name];
+ var commentAfter = opts._comments_after[name];
+
+ if (!commentBefore && !commentAfter) return;
+
+ var span = null;
+ if (div.classList.contains('gradio-checkbox')) span = div.querySelector('label span');
+ else if (div.classList.contains('gradio-checkboxgroup')) span = div.querySelector('span').firstChild;
+ else if (div.classList.contains('gradio-radio')) span = div.querySelector('span').firstChild;
+ else span = div.querySelector('label span').firstChild;
+
+ if (!span) return;
+
+ if (commentBefore) {
+ var comment = document.createElement('DIV');
+ comment.className = 'settings-comment';
+ comment.innerHTML = commentBefore;
+ span.parentElement.insertBefore(document.createTextNode('\xa0'), span);
+ span.parentElement.insertBefore(comment, span);
+ span.parentElement.insertBefore(document.createTextNode('\xa0'), span);
+ }
+ if (commentAfter) {
+ comment = document.createElement('DIV');
+ comment.className = 'settings-comment';
+ comment.innerHTML = commentAfter;
+ span.parentElement.insertBefore(comment, span.nextSibling);
+ span.parentElement.insertBefore(document.createTextNode('\xa0'), span.nextSibling);
+ }
+ });
+});
+
+function settingsHintsShowQuicksettings() {
+ requestGet("./internal/quicksettings-hint", {}, function(data) {
+ var table = document.createElement('table');
+ table.className = 'settings-value-table';
+
+ data.forEach(function(obj) {
+ var tr = document.createElement('tr');
+ var td = document.createElement('td');
+ td.textContent = obj.name;
+ tr.appendChild(td);
+
+ td = document.createElement('td');
+ td.textContent = obj.label;
+ tr.appendChild(td);
+
+ table.appendChild(tr);
+ });
+
+ popup(table);
+ });
+}
diff --git a/launch.py b/launch.py
index cfc0cffa..a3261c0c 100644
--- a/launch.py
+++ b/launch.py
@@ -3,25 +3,23 @@ import subprocess
import os
import sys
import importlib.util
-import shlex
import platform
import json
+from functools import lru_cache
from modules import cmd_args
from modules.paths_internal import script_path, extensions_dir
-commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
-sys.argv += shlex.split(commandline_args)
-
args, _ = cmd_args.parser.parse_known_args()
python = sys.executable
git = os.environ.get('GIT', "git")
index_url = os.environ.get('INDEX_URL', "")
-stored_commit_hash = None
-stored_git_tag = None
dir_repos = "repositories"
+# Whether to default to printing command output
+default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")
+
if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
@@ -57,65 +55,52 @@ Use --skip-python-version-check to suppress this warning.
""")
+@lru_cache()
def commit_hash():
- global stored_commit_hash
-
- if stored_commit_hash is not None:
- return stored_commit_hash
-
try:
- stored_commit_hash = run(f"{git} rev-parse HEAD").strip()
+ return subprocess.check_output([git, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip()
except Exception:
- stored_commit_hash = ""
-
- return stored_commit_hash
+ return ""
+@lru_cache()
def git_tag():
- global stored_git_tag
-
- if stored_git_tag is not None:
- return stored_git_tag
-
try:
- stored_git_tag = run(f"{git} describe --tags").strip()
+ return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip()
except Exception:
- stored_git_tag = ""
-
- return stored_git_tag
+ return ""
-def run(command, desc=None, errdesc=None, custom_env=None, live=False):
+def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
if desc is not None:
print(desc)
- if live:
- result = subprocess.run(command, shell=True, env=os.environ if custom_env is None else custom_env)
- if result.returncode != 0:
- raise RuntimeError(f"""{errdesc or 'Error running command'}.
-Command: {command}
-Error code: {result.returncode}""")
+ run_kwargs = {
+ "args": command,
+ "shell": True,
+ "env": os.environ if custom_env is None else custom_env,
+ "encoding": 'utf8',
+ "errors": 'ignore',
+ }
- return ""
+ if not live:
+ run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE
- result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)
+ result = subprocess.run(**run_kwargs)
if result.returncode != 0:
+ error_bits = [
+ f"{errdesc or 'Error running command'}.",
+ f"Command: {command}",
+ f"Error code: {result.returncode}",
+ ]
+ if result.stdout:
+ error_bits.append(f"stdout: {result.stdout}")
+ if result.stderr:
+ error_bits.append(f"stderr: {result.stderr}")
+ raise RuntimeError("\n".join(error_bits))
- message = f"""{errdesc or 'Error running command'}.
-Command: {command}
-Error code: {result.returncode}
-stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else ''}
-stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr)>0 else ''}
-"""
- raise RuntimeError(message)
-
- return result.stdout.decode(encoding="utf8", errors="ignore")
-
-
-def check_run(command):
- result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
- return result.returncode == 0
+ return (result.stdout or "")
def is_installed(package):
@@ -131,11 +116,7 @@ def repo_dir(name):
return os.path.join(script_path, dir_repos, name)
-def run_python(code, desc=None, errdesc=None):
- return run(f'"{python}" -c "{code}"', desc, errdesc)
-
-
-def run_pip(command, desc=None, live=False):
+def run_pip(command, desc=None, live=default_command_live):
if args.skip_install:
return
@@ -143,8 +124,9 @@ def run_pip(command, desc=None, live=False):
return run(f'"{python}" -m pip {command} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)
-def check_run_python(code):
- return check_run(f'"{python}" -c "{code}"')
+def check_run_python(code: str) -> bool:
+ result = subprocess.run([python, "-c", code], capture_output=True, shell=False)
+ return result.returncode == 0
def git_clone(url, dir, name, commithash=None):
@@ -237,13 +219,14 @@ def run_extensions_installers(settings_file):
def prepare_environment():
- torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url https://download.pytorch.org/whl/cu118")
+ torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
+ torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17')
- gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
- clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
- openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
+ gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "https://github.com/TencentARC/GFPGAN/archive/8d2447a2d918f8eba5a4a01463fd48e45126a379.zip")
+ clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
+ openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
@@ -270,8 +253,11 @@ def prepare_environment():
if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
- if not args.skip_torch_cuda_test:
- run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")
+ if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
+ raise RuntimeError(
+ 'Torch is not able to use GPU; '
+ 'add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'
+ )
if not is_installed("gfpgan"):
run_pip(f"install {gfpgan_package}", "gfpgan")
@@ -294,8 +280,8 @@ def prepare_environment():
elif platform.system() == "Linux":
run_pip(f"install {xformers_package}", "xformers")
- if not is_installed("pyngrok") and args.ngrok:
- run_pip("install pyngrok", "ngrok")
+ if not is_installed("ngrok") and args.ngrok:
+ run_pip("install ngrok", "ngrok")
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
@@ -319,7 +305,7 @@ def prepare_environment():
if args.update_all_extensions:
git_pull_recursive(extensions_dir)
-
+
if "--exit" in sys.argv:
print("Exiting because of --exit argument")
exit(0)
diff --git a/modules/Roboto-Regular.ttf b/modules/Roboto-Regular.ttf
new file mode 100644
index 00000000..500b1045
Binary files /dev/null and b/modules/Roboto-Regular.ttf differ
diff --git a/modules/api/api.py b/modules/api/api.py
index 9bb95dfd..eee99bbb 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -15,7 +15,8 @@ from secrets import compare_digest
import modules.shared as shared
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
-from modules.api.models import *
+from modules.api import models
+from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
from modules.textual_inversion.preprocess import preprocess
@@ -25,21 +26,24 @@ from modules.sd_models import checkpoints_list, unload_model_weights, reload_mod
from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
-from typing import List
+from typing import Dict, List, Any
import piexif
import piexif.helper
+
def upscaler_to_index(name: str):
try:
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
- except:
- raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in sd_upscalers])}")
+ except Exception as e:
+ raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}") from e
+
def script_name_to_index(name, scripts):
try:
return [script.title().lower() for script in scripts].index(name.lower())
- except:
- raise HTTPException(status_code=422, detail=f"Script '{name}' not found")
+ except Exception as e:
+ raise HTTPException(status_code=422, detail=f"Script '{name}' not found") from e
+
def validate_sampler_name(name):
config = sd_samplers.all_samplers_map.get(name, None)
@@ -48,20 +52,23 @@ def validate_sampler_name(name):
return name
+
def setUpscalers(req: dict):
reqDict = vars(req)
reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
return reqDict
+
def decode_base64_to_image(encoding):
if encoding.startswith("data:image/"):
encoding = encoding.split(";")[1].split(",")[1]
try:
image = Image.open(BytesIO(base64.b64decode(encoding)))
return image
- except Exception as err:
- raise HTTPException(status_code=500, detail="Invalid encoded image")
+ except Exception as e:
+ raise HTTPException(status_code=500, detail="Invalid encoded image") from e
+
def encode_pil_to_base64(image):
with io.BytesIO() as output_bytes:
@@ -92,6 +99,7 @@ def encode_pil_to_base64(image):
return base64.b64encode(bytes_data)
+
def api_middleware(app: FastAPI):
rich_available = True
try:
@@ -99,7 +107,7 @@ def api_middleware(app: FastAPI):
import starlette # importing just so it can be placed on silent list
from rich.console import Console
console = Console()
- except:
+ except Exception:
import traceback
rich_available = False
@@ -157,7 +165,7 @@ def api_middleware(app: FastAPI):
class Api:
def __init__(self, app: FastAPI, queue_lock: Lock):
if shared.cmd_opts.api_auth:
- self.credentials = dict()
+ self.credentials = {}
for auth in shared.cmd_opts.api_auth.split(","):
user, password = auth.split(":")
self.credentials[user] = password
@@ -166,36 +174,37 @@ class Api:
self.app = app
self.queue_lock = queue_lock
api_middleware(self.app)
- self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
- self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
- self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
- self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
- self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
- self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
+ self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=models.TextToImageResponse)
+ self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=models.ImageToImageResponse)
+ self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
+ self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
+ self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse)
+ self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse)
self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
- self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
+ self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
- self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
- self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
- self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
- self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
- self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
- self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
- self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
- self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
- self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
+ self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
+ self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
+ self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
+ self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
+ self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
+ self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
+ self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
+ self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
+ self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
- self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
- self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
- self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse)
- self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
- self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
- self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
+ self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
+ self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
+ self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
+ self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
+ self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
+ self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
- self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)
+ self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
+ self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
self.default_script_arg_txt2img = []
self.default_script_arg_img2img = []
@@ -219,17 +228,25 @@ class Api:
script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
script = script_runner.selectable_scripts[script_idx]
return script, script_idx
-
- def get_scripts_list(self):
- t2ilist = [str(title.lower()) for title in scripts.scripts_txt2img.titles]
- i2ilist = [str(title.lower()) for title in scripts.scripts_img2img.titles]
- return ScriptsList(txt2img = t2ilist, img2img = i2ilist)
+ def get_scripts_list(self):
+ t2ilist = [script.name for script in scripts.scripts_txt2img.scripts if script.name is not None]
+ i2ilist = [script.name for script in scripts.scripts_img2img.scripts if script.name is not None]
+
+ return models.ScriptsList(txt2img=t2ilist, img2img=i2ilist)
+
+ def get_script_info(self):
+ res = []
+
+ for script_list in [scripts.scripts_txt2img.scripts, scripts.scripts_img2img.scripts]:
+ res += [script.api_info for script in script_list if script.api_info is not None]
+
+ return res
def get_script(self, script_name, script_runner):
if script_name is None or script_name == "":
return None, None
-
+
script_idx = script_name_to_index(script_name, script_runner.scripts)
return script_runner.scripts[script_idx]
@@ -264,11 +281,11 @@ class Api:
if request.alwayson_scripts and (len(request.alwayson_scripts) > 0):
for alwayson_script_name in request.alwayson_scripts.keys():
alwayson_script = self.get_script(alwayson_script_name, script_runner)
- if alwayson_script == None:
+ if alwayson_script is None:
raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found")
# Selectable script in always on script param check
- if alwayson_script.alwayson == False:
- raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params")
+ if alwayson_script.alwayson is False:
+ raise HTTPException(status_code=422, detail="Cannot have a selectable script in the always on scripts params")
# always on script with no arg should always run so you don't really need to add them to the requests
if "args" in request.alwayson_scripts[alwayson_script_name]:
# min between arg length in scriptrunner and arg length in the request
@@ -276,7 +293,7 @@ class Api:
script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
return script_args
- def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
+ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
script_runner = scripts.scripts_txt2img
if not script_runner.scripts:
script_runner.initialize_scripts(False)
@@ -310,7 +327,7 @@ class Api:
p.outpath_samples = opts.outdir_txt2img_samples
shared.state.begin()
- if selectable_scripts != None:
+ if selectable_scripts is not None:
p.script_args = script_args
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
else:
@@ -320,9 +337,9 @@ class Api:
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
- return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
+ return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
- def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
+ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
init_images = img2imgreq.init_images
if init_images is None:
raise HTTPException(status_code=404, detail="Init image not found")
@@ -367,7 +384,7 @@ class Api:
p.outpath_samples = opts.outdir_img2img_samples
shared.state.begin()
- if selectable_scripts != None:
+ if selectable_scripts is not None:
p.script_args = script_args
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
else:
@@ -381,9 +398,9 @@ class Api:
img2imgreq.init_images = None
img2imgreq.mask = None
- return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
+ return models.ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
- def extras_single_image_api(self, req: ExtrasSingleImageRequest):
+ def extras_single_image_api(self, req: models.ExtrasSingleImageRequest):
reqDict = setUpscalers(req)
reqDict['image'] = decode_base64_to_image(reqDict['image'])
@@ -391,9 +408,9 @@ class Api:
with self.queue_lock:
result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
- return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
+ return models.ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
- def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
+ def extras_batch_images_api(self, req: models.ExtrasBatchImagesRequest):
reqDict = setUpscalers(req)
image_list = reqDict.pop('imageList', [])
@@ -402,15 +419,15 @@ class Api:
with self.queue_lock:
result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
- return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
+ return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
- def pnginfoapi(self, req: PNGInfoRequest):
+ def pnginfoapi(self, req: models.PNGInfoRequest):
if(not req.image.strip()):
- return PNGInfoResponse(info="")
+ return models.PNGInfoResponse(info="")
image = decode_base64_to_image(req.image.strip())
if image is None:
- return PNGInfoResponse(info="")
+ return models.PNGInfoResponse(info="")
geninfo, items = images.read_info_from_image(image)
if geninfo is None:
@@ -418,13 +435,13 @@ class Api:
items = {**{'parameters': geninfo}, **items}
- return PNGInfoResponse(info=geninfo, items=items)
+ return models.PNGInfoResponse(info=geninfo, items=items)
- def progressapi(self, req: ProgressRequest = Depends()):
+ def progressapi(self, req: models.ProgressRequest = Depends()):
# copy from check_progress_call of ui.py
if shared.state.job_count == 0:
- return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
+ return models.ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
# avoid dividing zero
progress = 0.01
@@ -446,9 +463,9 @@ class Api:
if shared.state.current_image and not req.skip_current_image:
current_image = encode_pil_to_base64(shared.state.current_image)
- return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
+ return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
- def interrogateapi(self, interrogatereq: InterrogateRequest):
+ def interrogateapi(self, interrogatereq: models.InterrogateRequest):
image_b64 = interrogatereq.image
if image_b64 is None:
raise HTTPException(status_code=404, detail="Image not found")
@@ -465,7 +482,7 @@ class Api:
else:
raise HTTPException(status_code=404, detail="Model not found")
- return InterrogateResponse(caption=processed)
+ return models.InterrogateResponse(caption=processed)
def interruptapi(self):
shared.state.interrupt()
@@ -570,36 +587,36 @@ class Api:
filename = create_embedding(**args) # create empty embedding
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
shared.state.end()
- return CreateResponse(info=f"create embedding filename: {filename}")
+ return models.CreateResponse(info=f"create embedding filename: {filename}")
except AssertionError as e:
shared.state.end()
- return TrainResponse(info=f"create embedding error: {e}")
+ return models.TrainResponse(info=f"create embedding error: {e}")
def create_hypernetwork(self, args: dict):
try:
shared.state.begin()
filename = create_hypernetwork(**args) # create empty embedding
shared.state.end()
- return CreateResponse(info=f"create hypernetwork filename: {filename}")
+ return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
except AssertionError as e:
shared.state.end()
- return TrainResponse(info=f"create hypernetwork error: {e}")
+ return models.TrainResponse(info=f"create hypernetwork error: {e}")
def preprocess(self, args: dict):
try:
shared.state.begin()
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
shared.state.end()
- return PreprocessResponse(info = 'preprocess complete')
+ return models.PreprocessResponse(info = 'preprocess complete')
except KeyError as e:
shared.state.end()
- return PreprocessResponse(info=f"preprocess error: invalid token: {e}")
+ return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
except AssertionError as e:
shared.state.end()
- return PreprocessResponse(info=f"preprocess error: {e}")
+ return models.PreprocessResponse(info=f"preprocess error: {e}")
except FileNotFoundError as e:
shared.state.end()
- return PreprocessResponse(info=f'preprocess error: {e}')
+ return models.PreprocessResponse(info=f'preprocess error: {e}')
def train_embedding(self, args: dict):
try:
@@ -617,10 +634,10 @@ class Api:
if not apply_optimizations:
sd_hijack.apply_optimizations()
shared.state.end()
- return TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
+ return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
except AssertionError as msg:
shared.state.end()
- return TrainResponse(info=f"train embedding error: {msg}")
+ return models.TrainResponse(info=f"train embedding error: {msg}")
def train_hypernetwork(self, args: dict):
try:
@@ -641,14 +658,15 @@ class Api:
if not apply_optimizations:
sd_hijack.apply_optimizations()
shared.state.end()
- return TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
- except AssertionError as msg:
+ return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
+ except AssertionError:
shared.state.end()
- return TrainResponse(info=f"train embedding error: {error}")
+ return models.TrainResponse(info=f"train embedding error: {error}")
def get_memory(self):
try:
- import os, psutil
+ import os
+ import psutil
process = psutil.Process(os.getpid())
res = process.memory_info() # only rss is cross-platform guaranteed so we dont rely on other values
ram_total = 100 * res.rss / process.memory_percent() # and total memory is calculated as actual value is not cross-platform safe
@@ -675,10 +693,10 @@ class Api:
'events': warnings,
}
else:
- cuda = { 'error': 'unavailable' }
+ cuda = {'error': 'unavailable'}
except Exception as err:
- cuda = { 'error': f'{err}' }
- return MemoryResponse(ram = ram, cuda = cuda)
+ cuda = {'error': f'{err}'}
+ return models.MemoryResponse(ram=ram, cuda=cuda)
def launch(self, server_name, port):
self.app.include_router(self.router)
diff --git a/modules/api/models.py b/modules/api/models.py
index 4a70f440..1ff2fb33 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -223,8 +223,9 @@ for key in _options:
if(_options[key].dest != 'help'):
flag = _options[key]
_type = str
- if _options[key].default is not None: _type = type(_options[key].default)
- flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))})
+ if _options[key].default is not None:
+ _type = type(_options[key].default)
+ flags.update({flag.dest: (_type, Field(default=flag.default, description=flag.help))})
FlagsModel = create_model("Flags", **flags)
@@ -286,6 +287,23 @@ class MemoryResponse(BaseModel):
ram: dict = Field(title="RAM", description="System memory stats")
cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")
+
class ScriptsList(BaseModel):
- txt2img: list = Field(default=None,title="Txt2img", description="Titles of scripts (txt2img)")
- img2img: list = Field(default=None,title="Img2img", description="Titles of scripts (img2img)")
\ No newline at end of file
+ txt2img: list = Field(default=None, title="Txt2img", description="Titles of scripts (txt2img)")
+ img2img: list = Field(default=None, title="Img2img", description="Titles of scripts (img2img)")
+
+
+class ScriptArg(BaseModel):
+ label: str = Field(default=None, title="Label", description="Name of the argument in UI")
+ value: Optional[Any] = Field(default=None, title="Value", description="Default value of the argument")
+ minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argumentin UI")
+ maximum: Optional[Any] = Field(default=None, title="Minimum", description="Maximum allowed value for the argumentin UI")
+ step: Optional[Any] = Field(default=None, title="Minimum", description="Step for changing value of the argumentin UI")
+ choices: Optional[List[str]] = Field(default=None, title="Choices", description="Possible values for the argument")
+
+
+class ScriptInfo(BaseModel):
+ name: str = Field(default=None, title="Name", description="Script name")
+ is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script")
+ is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script")
+ args: List[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index d906a571..a533a454 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -1,6 +1,7 @@
import argparse
+import json
import os
-from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file
+from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file # noqa: F401
parser = argparse.ArgumentParser()
@@ -39,7 +40,8 @@ parser.add_argument("--precision", type=str, help="evaluate at this precision",
parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
-parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
+parser.add_argument("--ngrok-region", type=str, help="does not do anything.", default="")
+parser.add_argument("--ngrok-options", type=json.loads, help='The options to pass to ngrok in JSON format, e.g.: \'{"authtoken_from_env":true, "basic_auth":"user:password", "oauth_provider":"google", "oauth_allow_emails":"user@asdf.com"}\'', default=dict())
parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options")
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
@@ -102,4 +104,5 @@ parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gra
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
-parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
\ No newline at end of file
+parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
+parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
diff --git a/modules/codeformer/codeformer_arch.py b/modules/codeformer/codeformer_arch.py
index 11dcc3ee..12db6814 100644
--- a/modules/codeformer/codeformer_arch.py
+++ b/modules/codeformer/codeformer_arch.py
@@ -1,14 +1,12 @@
# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
import math
-import numpy as np
import torch
from torch import nn, Tensor
import torch.nn.functional as F
-from typing import Optional, List
+from typing import Optional
-from modules.codeformer.vqgan_arch import *
-from basicsr.utils import get_root_logger
+from modules.codeformer.vqgan_arch import VQAutoEncoder, ResBlock
from basicsr.utils.registry import ARCH_REGISTRY
def calc_mean_std(feat, eps=1e-5):
@@ -121,7 +119,7 @@ class TransformerSALayer(nn.Module):
tgt_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
-
+
# self attention
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
@@ -161,10 +159,10 @@ class Fuse_sft_block(nn.Module):
@ARCH_REGISTRY.register()
class CodeFormer(VQAutoEncoder):
- def __init__(self, dim_embd=512, n_head=8, n_layers=9,
+ def __init__(self, dim_embd=512, n_head=8, n_layers=9,
codebook_size=1024, latent_size=256,
- connect_list=['32', '64', '128', '256'],
- fix_modules=['quantize','generator']):
+ connect_list=('32', '64', '128', '256'),
+ fix_modules=('quantize', 'generator')):
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
if fix_modules is not None:
@@ -181,14 +179,14 @@ class CodeFormer(VQAutoEncoder):
self.feat_emb = nn.Linear(256, self.dim_embd)
# transformer
- self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
+ self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
for _ in range(self.n_layers)])
# logits_predict head
self.idx_pred_layer = nn.Sequential(
nn.LayerNorm(dim_embd),
nn.Linear(dim_embd, codebook_size, bias=False))
-
+
self.channels = {
'16': 512,
'32': 256,
@@ -223,7 +221,7 @@ class CodeFormer(VQAutoEncoder):
enc_feat_dict = {}
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.encoder.blocks):
- x = block(x)
+ x = block(x)
if i in out_list:
enc_feat_dict[str(x.shape[-1])] = x.clone()
@@ -268,11 +266,11 @@ class CodeFormer(VQAutoEncoder):
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.generator.blocks):
- x = block(x)
+ x = block(x)
if i in fuse_list: # fuse after i-th block
f_size = str(x.shape[-1])
if w>0:
x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
out = x
# logits doesn't need softmax before cross_entropy loss
- return out, logits, lq_feat
\ No newline at end of file
+ return out, logits, lq_feat
diff --git a/modules/codeformer/vqgan_arch.py b/modules/codeformer/vqgan_arch.py
index e7293683..09ee6660 100644
--- a/modules/codeformer/vqgan_arch.py
+++ b/modules/codeformer/vqgan_arch.py
@@ -5,17 +5,15 @@ VQGAN code, adapted from the original created by the Unleashing Transformers aut
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
'''
-import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
-import copy
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY
def normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
-
+
@torch.jit.script
def swish(x):
@@ -212,15 +210,15 @@ class AttnBlock(nn.Module):
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h*w)
- q = q.permute(0, 2, 1)
+ q = q.permute(0, 2, 1)
k = k.reshape(b, c, h*w)
- w_ = torch.bmm(q, k)
+ w_ = torch.bmm(q, k)
w_ = w_ * (int(c)**(-0.5))
w_ = F.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h*w)
- w_ = w_.permute(0, 2, 1)
+ w_ = w_.permute(0, 2, 1)
h_ = torch.bmm(v, w_)
h_ = h_.reshape(b, c, h, w)
@@ -272,18 +270,18 @@ class Encoder(nn.Module):
def forward(self, x):
for block in self.blocks:
x = block(x)
-
+
return x
class Generator(nn.Module):
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
super().__init__()
- self.nf = nf
- self.ch_mult = ch_mult
+ self.nf = nf
+ self.ch_mult = ch_mult
self.num_resolutions = len(self.ch_mult)
self.num_res_blocks = res_blocks
- self.resolution = img_size
+ self.resolution = img_size
self.attn_resolutions = attn_resolutions
self.in_channels = emb_dim
self.out_channels = 3
@@ -317,29 +315,29 @@ class Generator(nn.Module):
blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
self.blocks = nn.ModuleList(blocks)
-
+
def forward(self, x):
for block in self.blocks:
x = block(x)
-
+
return x
-
+
@ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module):
- def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
+ def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256,
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
super().__init__()
logger = get_root_logger()
- self.in_channels = 3
- self.nf = nf
- self.n_blocks = res_blocks
+ self.in_channels = 3
+ self.nf = nf
+ self.n_blocks = res_blocks
self.codebook_size = codebook_size
self.embed_dim = emb_dim
self.ch_mult = ch_mult
self.resolution = img_size
- self.attn_resolutions = attn_resolutions
+ self.attn_resolutions = attn_resolutions or [16]
self.quantizer_type = quantizer
self.encoder = Encoder(
self.in_channels,
@@ -365,11 +363,11 @@ class VQAutoEncoder(nn.Module):
self.kl_weight
)
self.generator = Generator(
- self.nf,
+ self.nf,
self.embed_dim,
- self.ch_mult,
- self.n_blocks,
- self.resolution,
+ self.ch_mult,
+ self.n_blocks,
+ self.resolution,
self.attn_resolutions
)
@@ -434,4 +432,4 @@ class VQGANDiscriminator(nn.Module):
raise ValueError('Wrong params!')
def forward(self, x):
- return self.main(x)
\ No newline at end of file
+ return self.main(x)
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
index 8d84bbc9..ececdbae 100644
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -33,11 +33,9 @@ def setup_model(dirname):
try:
from torchvision.transforms.functional import normalize
from modules.codeformer.codeformer_arch import CodeFormer
- from basicsr.utils.download_util import load_file_from_url
- from basicsr.utils import imwrite, img2tensor, tensor2img
+ from basicsr.utils import img2tensor, tensor2img
from facelib.utils.face_restoration_helper import FaceRestoreHelper
from facelib.detection.retinaface import retinaface
- from modules.shared import cmd_opts
net_class = CodeFormer
@@ -96,7 +94,7 @@ def setup_model(dirname):
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
self.face_helper.align_warp_face()
- for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
+ for cropped_face in self.face_helper.cropped_faces:
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
diff --git a/modules/config_states.py b/modules/config_states.py
index 2ea00929..db65bcdb 100644
--- a/modules/config_states.py
+++ b/modules/config_states.py
@@ -14,7 +14,7 @@ from collections import OrderedDict
import git
from modules import shared, extensions
-from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path, config_states_dir
+from modules.paths_internal import script_path, config_states_dir
all_config_states = OrderedDict()
@@ -35,7 +35,7 @@ def list_config_states():
j["filepath"] = path
config_states.append(j)
- config_states = list(sorted(config_states, key=lambda cs: cs["created_at"], reverse=True))
+ config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)
for cs in config_states:
timestamp = time.asctime(time.gmtime(cs["created_at"]))
@@ -83,6 +83,8 @@ def get_extension_config():
ext_config = {}
for ext in extensions.extensions:
+ ext.read_info_from_repo()
+
entry = {
"name": ext.name,
"path": ext.path,
diff --git a/modules/deepbooru.py b/modules/deepbooru.py
index 122fce7f..547e1b4c 100644
--- a/modules/deepbooru.py
+++ b/modules/deepbooru.py
@@ -2,7 +2,6 @@ import os
import re
import torch
-from PIL import Image
import numpy as np
from modules import modelloader, paths, deepbooru_model, devices, images, shared
@@ -79,7 +78,7 @@ class DeepDanbooru:
res = []
- filtertags = set([x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")])
+ filtertags = {x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")}
for tag in [x for x in tags if x not in filtertags]:
probability = probability_dict[tag]
diff --git a/modules/devices.py b/modules/devices.py
index c705a3cb..d8a34a0f 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -65,7 +65,7 @@ def enable_tf32():
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
- if any([torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())]):
+ if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index f4369257..a009eb42 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -6,7 +6,7 @@ from PIL import Image
from basicsr.utils.download_util import load_file_from_url
import modules.esrgan_model_arch as arch
-from modules import shared, modelloader, images, devices
+from modules import modelloader, images, devices
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts
@@ -16,9 +16,7 @@ def mod2normal(state_dict):
# this code is copied from https://github.com/victorca25/iNNfer
if 'conv_first.weight' in state_dict:
crt_net = {}
- items = []
- for k, v in state_dict.items():
- items.append(k)
+ items = list(state_dict)
crt_net['model.0.weight'] = state_dict['conv_first.weight']
crt_net['model.0.bias'] = state_dict['conv_first.bias']
@@ -52,9 +50,7 @@ def resrgan2normal(state_dict, nb=23):
if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
re8x = 0
crt_net = {}
- items = []
- for k, v in state_dict.items():
- items.append(k)
+ items = list(state_dict)
crt_net['model.0.weight'] = state_dict['conv_first.weight']
crt_net['model.0.bias'] = state_dict['conv_first.bias']
diff --git a/modules/esrgan_model_arch.py b/modules/esrgan_model_arch.py
index 6071fea7..2b9888ba 100644
--- a/modules/esrgan_model_arch.py
+++ b/modules/esrgan_model_arch.py
@@ -2,7 +2,6 @@
from collections import OrderedDict
import math
-import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -106,7 +105,7 @@ class ResidualDenseBlock_5C(nn.Module):
Modified options that can be used:
- "Partial Convolution based Padding" arXiv:1811.11718
- "Spectral normalization" arXiv:1802.05957
- - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
+ - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
{Rakotonirina} and A. {Rasoanaivo}
"""
@@ -171,7 +170,7 @@ class GaussianNoise(nn.Module):
scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
x = x + sampled_noise
- return x
+ return x
def conv1x1(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
@@ -438,9 +437,11 @@ def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=
padding = padding if pad_type == 'zero' else 0
if convtype=='PartialConv2D':
+ from torchvision.ops import PartialConv2d # this is definitely not going to work, but PartialConv2d doesn't work anyway and this shuts up static analyzer
c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, bias=bias, groups=groups)
elif convtype=='DeformConv2D':
+ from torchvision.ops import DeformConv2d # not tested
c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, bias=bias, groups=groups)
elif convtype=='Conv3D':
diff --git a/modules/extensions.py b/modules/extensions.py
index 34d9d654..359a7aa5 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -1,13 +1,12 @@
import os
import sys
+import threading
import traceback
-import time
-from datetime import datetime
import git
from modules import shared
-from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path
+from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
extensions = []
@@ -25,6 +24,8 @@ def active():
class Extension:
+ lock = threading.Lock()
+
def __init__(self, name, path, enabled=True, is_builtin=False):
self.name = name
self.path = path
@@ -43,8 +44,13 @@ class Extension:
if self.is_builtin or self.have_info_from_repo:
return
- self.have_info_from_repo = True
+ with self.lock:
+ if self.have_info_from_repo:
+ return
+ self.do_read_info_from_repo()
+
+ def do_read_info_from_repo(self):
repo = None
try:
if os.path.exists(os.path.join(self.path, ".git")):
@@ -59,18 +65,18 @@ class Extension:
try:
self.status = 'unknown'
self.remote = next(repo.remote().urls, None)
- head = repo.head.commit
self.commit_date = repo.head.commit.committed_date
- ts = time.asctime(time.gmtime(self.commit_date))
if repo.active_branch:
self.branch = repo.active_branch.name
- self.commit_hash = head.hexsha
- self.version = f'{self.commit_hash[:8]} ({ts})'
+ self.commit_hash = repo.head.commit.hexsha
+ self.version = repo.git.describe("--always", "--tags") # compared to `self.commit_hash[:8]` this takes about 30% more time total but since we run it in parallel we don't care
except Exception as ex:
print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr)
self.remote = None
+ self.have_info_from_repo = True
+
def list_files(self, subdir, extension):
from modules import scripts
diff --git a/modules/extra_networks.py b/modules/extra_networks.py
index 1978673d..f9db41bc 100644
--- a/modules/extra_networks.py
+++ b/modules/extra_networks.py
@@ -91,7 +91,7 @@ def deactivate(p, extra_network_data):
"""call deactivate for extra networks in extra_network_data in specified order, then call
deactivate for all remaining registered networks"""
- for extra_network_name, extra_network_args in extra_network_data.items():
+ for extra_network_name in extra_network_data:
extra_network = extra_network_registry.get(extra_network_name, None)
if extra_network is None:
continue
diff --git a/modules/extra_networks_hypernet.py b/modules/extra_networks_hypernet.py
index 04f27c9f..aa2a14ef 100644
--- a/modules/extra_networks_hypernet.py
+++ b/modules/extra_networks_hypernet.py
@@ -1,4 +1,4 @@
-from modules import extra_networks, shared, extra_networks
+from modules import extra_networks, shared
from modules.hypernetworks import hypernetwork
diff --git a/modules/extras.py b/modules/extras.py
index ff4e9c4e..830b53aa 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -136,14 +136,14 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
result_is_instruct_pix2pix_model = False
if theta_func2:
- shared.state.textinfo = f"Loading B"
+ shared.state.textinfo = "Loading B"
print(f"Loading {secondary_model_info.filename}...")
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
else:
theta_1 = None
if theta_func1:
- shared.state.textinfo = f"Loading C"
+ shared.state.textinfo = "Loading C"
print(f"Loading {tertiary_model_info.filename}...")
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
@@ -199,7 +199,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
result_is_inpainting_model = True
else:
theta_0[key] = theta_func2(a, b, multiplier)
-
+
theta_0[key] = to_half(theta_0[key], save_as_half)
shared.state.sampling_step += 1
@@ -242,9 +242,11 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
shared.state.textinfo = "Saving"
print(f"Saving to {output_modelname}...")
- metadata = {"format": "pt", "sd_merge_models": {}, "sd_merge_recipe": None}
+ metadata = None
if save_metadata:
+ metadata = {"format": "pt"}
+
merge_recipe = {
"type": "webui", # indicate this model was merged with webui's built-in merger
"primary_model_hash": primary_model_info.sha256,
@@ -262,15 +264,17 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
}
metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
+ sd_merge_models = {}
+
def add_model_metadata(checkpoint_info):
checkpoint_info.calculate_shorthash()
- metadata["sd_merge_models"][checkpoint_info.sha256] = {
+ sd_merge_models[checkpoint_info.sha256] = {
"name": checkpoint_info.name,
"legacy_hash": checkpoint_info.hash,
"sd_merge_recipe": checkpoint_info.metadata.get("sd_merge_recipe", None)
}
- metadata["sd_merge_models"].update(checkpoint_info.metadata.get("sd_merge_models", {}))
+ sd_merge_models.update(checkpoint_info.metadata.get("sd_merge_models", {}))
add_model_metadata(primary_model_info)
if secondary_model_info:
@@ -278,7 +282,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
if tertiary_model_info:
add_model_metadata(tertiary_model_info)
- metadata["sd_merge_models"] = json.dumps(metadata["sd_merge_models"])
+ metadata["sd_merge_models"] = json.dumps(sd_merge_models)
_, extension = os.path.splitext(output_modelname)
if extension.lower() == ".safetensors":
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index fe8b18b2..f1a2204c 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -1,15 +1,11 @@
import base64
-import html
import io
-import math
import os
import re
-from pathlib import Path
import gradio as gr
from modules.paths import data_path
from modules import shared, ui_tempdir, script_callbacks
-import tempfile
from PIL import Image
re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
@@ -23,14 +19,14 @@ registered_param_bindings = []
class ParamBinding:
- def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=[]):
+ def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
self.paste_button = paste_button
self.tabname = tabname
self.source_text_component = source_text_component
self.source_image_component = source_image_component
self.source_tabname = source_tabname
self.override_settings_component = override_settings_component
- self.paste_field_names = paste_field_names
+ self.paste_field_names = paste_field_names or []
def reset():
@@ -251,7 +247,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
lines.append(lastline)
lastline = ''
- for i, line in enumerate(lines):
+ for line in lines:
line = line.strip()
if line.startswith("Negative prompt:"):
done_with_prompt = True
@@ -312,6 +308,8 @@ infotext_to_setting_name_mapping = [
('UniPC skip type', 'uni_pc_skip_type'),
('UniPC order', 'uni_pc_order'),
('UniPC lower order final', 'uni_pc_lower_order_final'),
+ ('Token merging ratio', 'token_merging_ratio'),
+ ('Token merging ratio hr', 'token_merging_ratio_hr'),
('RNG', 'randn_source'),
('NGMS', 's_min_uncond'),
]
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index fbe6215a..0131dea4 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -78,7 +78,7 @@ def setup_model(dirname):
try:
from gfpgan import GFPGANer
- from facexlib import detection, parsing
+ from facexlib import detection, parsing # noqa: F401
global user_path
global have_gfpgan
global gfpgan_constructor
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 1fc49537..570b5603 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -1,4 +1,3 @@
-import csv
import datetime
import glob
import html
@@ -18,7 +17,7 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
-from collections import defaultdict, deque
+from collections import deque
from statistics import stdev, mean
@@ -178,34 +177,34 @@ class Hypernetwork:
def weights(self):
res = []
- for k, layers in self.layers.items():
+ for layers in self.layers.values():
for layer in layers:
res += layer.parameters()
return res
def train(self, mode=True):
- for k, layers in self.layers.items():
+ for layers in self.layers.values():
for layer in layers:
layer.train(mode=mode)
for param in layer.parameters():
param.requires_grad = mode
def to(self, device):
- for k, layers in self.layers.items():
+ for layers in self.layers.values():
for layer in layers:
layer.to(device)
return self
def set_multiplier(self, multiplier):
- for k, layers in self.layers.items():
+ for layers in self.layers.values():
for layer in layers:
layer.multiplier = multiplier
return self
def eval(self):
- for k, layers in self.layers.items():
+ for layers in self.layers.values():
for layer in layers:
layer.eval()
for param in layer.parameters():
@@ -404,7 +403,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
k = self.to_k(context_k)
v = self.to_v(context_v)
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+ q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
@@ -541,7 +540,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
-
+
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
if clip_grad:
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
@@ -594,7 +593,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
print(e)
scaler = torch.cuda.amp.GradScaler()
-
+
batch_size = ds.batch_size
gradient_step = ds.gradient_step
# n steps = batch_size * gradient_step * n image processed
@@ -620,7 +619,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
try:
sd_hijack_checkpoint.add()
- for i in range((steps-initial_step) * gradient_step):
+ for _ in range((steps-initial_step) * gradient_step):
if scheduler.finished:
break
if shared.state.interrupted:
@@ -637,7 +636,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
if clip_grad:
clip_grad_sched.step(hypernetwork.step)
-
+
with devices.autocast():
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
if use_weight:
@@ -658,14 +657,14 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
_loss_step += loss.item()
scaler.scale(loss).backward()
-
+
# go back until we reach gradient accumulation steps
if (j + 1) % gradient_step != 0:
continue
loss_logging.append(_loss_step)
if clip_grad:
clip_grad(weights, clip_grad_sched.learn_rate)
-
+
scaler.step(optimizer)
scaler.update()
hypernetwork.step += 1
@@ -675,7 +674,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
_loss_step = 0
steps_done = hypernetwork.step + 1
-
+
epoch_num = hypernetwork.step // steps_per_epoch
epoch_step = hypernetwork.step % steps_per_epoch
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index 76599f5a..8b6255e2 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -1,19 +1,17 @@
import html
-import os
-import re
import gradio as gr
import modules.hypernetworks.hypernetwork
from modules import devices, sd_hijack, shared
not_available = ["hardswish", "multiheadattention"]
-keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
+keys = [x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict if x not in not_available]
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
- return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", ""
+ return gr.Dropdown.update(choices=sorted(shared.hypernetworks)), f"Created: {filename}", ""
def train_hypernetwork(*args):
diff --git a/modules/images.py b/modules/images.py
index a41965ab..4e8cd993 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -13,17 +13,24 @@ import numpy as np
import piexif
import piexif.helper
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
-from fonts.ttf import Roboto
import string
import json
import hashlib
from modules import sd_samplers, shared, script_callbacks, errors
-from modules.shared import opts, cmd_opts
+from modules.paths_internal import roboto_ttf_file
+from modules.shared import opts
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
+def get_font(fontsize: int):
+ try:
+ return ImageFont.truetype(opts.font or roboto_ttf_file, fontsize)
+ except Exception:
+ return ImageFont.truetype(roboto_ttf_file, fontsize)
+
+
def image_grid(imgs, batch_size=1, rows=None):
if rows is None:
if opts.n_rows > 0:
@@ -142,14 +149,8 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
lines.append(word)
return lines
- def get_font(fontsize):
- try:
- return ImageFont.truetype(opts.font or Roboto, fontsize)
- except Exception:
- return ImageFont.truetype(Roboto, fontsize)
-
def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
- for i, line in enumerate(lines):
+ for line in lines:
fnt = initial_fnt
fontsize = initial_fontsize
while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
@@ -366,7 +367,7 @@ class FilenameGenerator:
self.seed = seed
self.prompt = prompt
self.image = image
-
+
def hasprompt(self, *args):
lower = self.prompt.lower()
if self.p is None or self.prompt is None:
@@ -409,13 +410,13 @@ class FilenameGenerator:
time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
try:
time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
- except pytz.exceptions.UnknownTimeZoneError as _:
+ except pytz.exceptions.UnknownTimeZoneError:
time_zone = None
time_zone_time = time_datetime.astimezone(time_zone)
try:
formatted_time = time_zone_time.strftime(time_format)
- except (ValueError, TypeError) as _:
+ except (ValueError, TypeError):
formatted_time = time_zone_time.strftime(self.default_time_format)
return sanitize_filename_part(formatted_time, replace_spaces=False)
@@ -472,15 +473,52 @@ def get_next_sequence_number(path, basename):
prefix_length = len(basename)
for p in os.listdir(path):
if p.startswith(basename):
- l = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
+ parts = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
try:
- result = max(int(l[0]), result)
+ result = max(int(parts[0]), result)
except ValueError:
pass
return result + 1
+def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None):
+ if extension is None:
+ extension = os.path.splitext(filename)[1]
+
+ image_format = Image.registered_extensions()[extension]
+
+ existing_pnginfo = existing_pnginfo or {}
+ if opts.enable_pnginfo:
+ existing_pnginfo['parameters'] = geninfo
+
+ if extension.lower() == '.png':
+ pnginfo_data = PngImagePlugin.PngInfo()
+ for k, v in (existing_pnginfo or {}).items():
+ pnginfo_data.add_text(k, str(v))
+
+ image.save(filename, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
+
+ elif extension.lower() in (".jpg", ".jpeg", ".webp"):
+ if image.mode == 'RGBA':
+ image = image.convert("RGB")
+ elif image.mode == 'I;16':
+ image = image.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")
+
+ image.save(filename, format=image_format, quality=opts.jpeg_quality, lossless=opts.webp_lossless)
+
+ if opts.enable_pnginfo and geninfo is not None:
+ exif_bytes = piexif.dump({
+ "Exif": {
+ piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(geninfo or "", encoding="unicode")
+ },
+ })
+
+ piexif.insert(exif_bytes, filename)
+ else:
+ image.save(filename, format=image_format, quality=opts.jpeg_quality)
+
+
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
"""Save an image.
@@ -565,38 +603,13 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
info = params.pnginfo.get(pnginfo_section_name, None)
def _atomically_save_image(image_to_save, filename_without_extension, extension):
- # save image with .tmp extension to avoid race condition when another process detects new image in the directory
+ """
+ save image with .tmp extension to avoid race condition when another process detects new image in the directory
+ """
temp_file_path = f"{filename_without_extension}.tmp"
- image_format = Image.registered_extensions()[extension]
- if extension.lower() == '.png':
- pnginfo_data = PngImagePlugin.PngInfo()
- if opts.enable_pnginfo:
- for k, v in params.pnginfo.items():
- pnginfo_data.add_text(k, str(v))
+ save_image_with_geninfo(image_to_save, info, temp_file_path, extension, params.pnginfo)
- image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
-
- elif extension.lower() in (".jpg", ".jpeg", ".webp"):
- if image_to_save.mode == 'RGBA':
- image_to_save = image_to_save.convert("RGB")
- elif image_to_save.mode == 'I;16':
- image_to_save = image_to_save.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")
-
- image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, lossless=opts.webp_lossless)
-
- if opts.enable_pnginfo and info is not None:
- exif_bytes = piexif.dump({
- "Exif": {
- piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(info or "", encoding="unicode")
- },
- })
-
- piexif.insert(exif_bytes, temp_file_path)
- else:
- image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
-
- # atomically rename the file with correct extension
os.replace(temp_file_path, filename_without_extension + extension)
fullfn_without_extension, extension = os.path.splitext(params.filename)
diff --git a/modules/img2img.py b/modules/img2img.py
index 9fc3a698..d704bf90 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -1,19 +1,15 @@
-import math
import os
-import sys
-import traceback
import numpy as np
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
-from modules import devices, sd_samplers
+from modules import sd_samplers
from modules.generation_parameters_copypaste import create_override_settings_dict
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state
import modules.shared as shared
import modules.processing as processing
from modules.ui import plaintext_to_html
-import modules.images as images
import modules.scripts
@@ -59,7 +55,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
# try to find corresponding mask for an image using simple filename matching
mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image))
# if not found use first one ("same mask for all images" use-case)
- if not mask_image_path in inpaint_masks:
+ if mask_image_path not in inpaint_masks:
mask_image_path = inpaint_masks[0]
mask_image = Image.open(mask_image_path)
p.image_mask = mask_image
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 9f7d657f..111b1322 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -11,7 +11,6 @@ import torch.hub
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
-import modules.shared as shared
from modules import devices, paths, shared, lowvram, modelloader, errors
blip_image_eval_size = 384
@@ -160,7 +159,7 @@ class InterrogateModels:
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
top_count = min(top_count, len(text_array))
- text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(devices.device_interrogate)
+ text_tokens = clip.tokenize(list(text_array), truncate=True).to(devices.device_interrogate)
text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
text_features /= text_features.norm(dim=-1, keepdim=True)
@@ -208,8 +207,8 @@ class InterrogateModels:
image_features /= image_features.norm(dim=-1, keepdim=True)
- for name, topn, items in self.categories():
- matches = self.rank(image_features, items, top_count=topn)
+ for cat in self.categories():
+ matches = self.rank(image_features, cat.items, top_count=cat.topn)
for match, score in matches:
if shared.opts.interrogate_return_ranks:
res += f", ({match}:{score/100:.3f})"
diff --git a/modules/mac_specific.py b/modules/mac_specific.py
index 40ce2101..d74c6b95 100644
--- a/modules/mac_specific.py
+++ b/modules/mac_specific.py
@@ -1,6 +1,5 @@
import torch
import platform
-from modules import paths
from modules.sd_hijack_utils import CondFunc
from packaging import version
@@ -43,7 +42,7 @@ if has_mps:
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
- # MPS workaround for https://github.com/pytorch/pytorch/issues/80800
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/80800
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
@@ -61,4 +60,4 @@ if has_mps:
# MPS workaround for https://github.com/pytorch/pytorch/issues/92311
if platform.processor() == 'i386':
for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
- CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps')
\ No newline at end of file
+ CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps')
diff --git a/modules/masking.py b/modules/masking.py
index a5c4d2da..be9f84c7 100644
--- a/modules/masking.py
+++ b/modules/masking.py
@@ -4,7 +4,7 @@ from PIL import Image, ImageFilter, ImageOps
def get_crop_region(mask, pad=0):
"""finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)"""
-
+
h, w = mask.shape
crop_left = 0
diff --git a/modules/modelloader.py b/modules/modelloader.py
index a70aa0e3..2a479bcb 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -1,4 +1,3 @@
-import glob
import os
import shutil
import importlib
@@ -40,7 +39,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
if os.path.islink(full_path) and not os.path.exists(full_path):
print(f"Skipping broken symlink: {full_path}")
continue
- if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
+ if ext_blacklist is not None and any(full_path.endswith(x) for x in ext_blacklist):
continue
if full_path not in output:
output.append(full_path)
@@ -108,12 +107,12 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
print(f"Moving {file} from {src_path} to {dest_path}.")
try:
shutil.move(fullpath, dest_path)
- except:
+ except Exception:
pass
if len(os.listdir(src_path)) == 0:
print(f"Removing empty folder: {src_path}")
shutil.rmtree(src_path, True)
- except:
+ except Exception:
pass
@@ -127,7 +126,7 @@ def load_upscalers():
full_model = f"modules.{model_name}_model"
try:
importlib.import_module(full_model)
- except:
+ except Exception:
pass
datas = []
diff --git a/modules/models/diffusion/ddpm_edit.py b/modules/models/diffusion/ddpm_edit.py
index f880bc3c..3fb76b65 100644
--- a/modules/models/diffusion/ddpm_edit.py
+++ b/modules/models/diffusion/ddpm_edit.py
@@ -52,7 +52,7 @@ class DDPM(pl.LightningModule):
beta_schedule="linear",
loss_type="l2",
ckpt_path=None,
- ignore_keys=[],
+ ignore_keys=None,
load_only_unet=False,
monitor="val/loss",
use_ema=True,
@@ -107,7 +107,7 @@ class DDPM(pl.LightningModule):
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)
# If initialing from EMA-only checkpoint, create EMA model after loading.
if self.use_ema and not load_ema:
@@ -194,7 +194,9 @@ class DDPM(pl.LightningModule):
if context is not None:
print(f"{context}: Restored training weights")
- def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+ def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
+ ignore_keys = ignore_keys or []
+
sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()):
sd = sd["state_dict"]
@@ -403,7 +405,7 @@ class DDPM(pl.LightningModule):
@torch.no_grad()
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
- log = dict()
+ log = {}
x = self.get_input(batch, self.first_stage_key)
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
@@ -411,7 +413,7 @@ class DDPM(pl.LightningModule):
log["inputs"] = x
# get diffusion row
- diffusion_row = list()
+ diffusion_row = []
x_start = x[:n_row]
for t in range(self.num_timesteps):
@@ -473,13 +475,13 @@ class LatentDiffusion(DDPM):
conditioning_key = None
ckpt_path = kwargs.pop("ckpt_path", None)
ignore_keys = kwargs.pop("ignore_keys", [])
- super().__init__(conditioning_key=conditioning_key, *args, load_ema=load_ema, **kwargs)
+ super().__init__(*args, conditioning_key=conditioning_key, load_ema=load_ema, **kwargs)
self.concat_mode = concat_mode
self.cond_stage_trainable = cond_stage_trainable
self.cond_stage_key = cond_stage_key
try:
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
- except:
+ except Exception:
self.num_downs = 0
if not scale_by_std:
self.scale_factor = scale_factor
@@ -891,16 +893,6 @@ class LatentDiffusion(DDPM):
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
return self.p_losses(x, c, t, *args, **kwargs)
- def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
- def rescale_bbox(bbox):
- x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
- y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
- w = min(bbox[2] / crop_coordinates[2], 1 - x0)
- h = min(bbox[3] / crop_coordinates[3], 1 - y0)
- return x0, y0, w, h
-
- return [rescale_bbox(b) for b in bboxes]
-
def apply_model(self, x_noisy, t, cond, return_ids=False):
if isinstance(cond, dict):
@@ -1140,7 +1132,7 @@ class LatentDiffusion(DDPM):
if cond is not None:
if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ [x[:batch_size] for x in cond[key]] for key in cond}
else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
@@ -1171,8 +1163,10 @@ class LatentDiffusion(DDPM):
if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(x0_partial)
- if callback: callback(i)
- if img_callback: img_callback(img, i)
+ if callback:
+ callback(i)
+ if img_callback:
+ img_callback(img, i)
return img, intermediates
@torch.no_grad()
@@ -1219,8 +1213,10 @@ class LatentDiffusion(DDPM):
if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(img)
- if callback: callback(i)
- if img_callback: img_callback(img, i)
+ if callback:
+ callback(i)
+ if img_callback:
+ img_callback(img, i)
if return_intermediates:
return img, intermediates
@@ -1235,7 +1231,7 @@ class LatentDiffusion(DDPM):
if cond is not None:
if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ [x[:batch_size] for x in cond[key]] for key in cond}
else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
return self.p_sample_loop(cond,
@@ -1267,7 +1263,7 @@ class LatentDiffusion(DDPM):
use_ddim = False
- log = dict()
+ log = {}
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
return_first_stage_outputs=True,
force_c_encode=True,
@@ -1295,7 +1291,7 @@ class LatentDiffusion(DDPM):
if plot_diffusion_rows:
# get diffusion row
- diffusion_row = list()
+ diffusion_row = []
z_start = z[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
@@ -1337,7 +1333,7 @@ class LatentDiffusion(DDPM):
if inpaint:
# make a simple center square
- b, h, w = z.shape[0], z.shape[2], z.shape[3]
+ h, w = z.shape[2], z.shape[3]
mask = torch.ones(N, h, w).to(self.device)
# zeros will be filled in
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
@@ -1439,10 +1435,10 @@ class Layout2ImgDiffusion(LatentDiffusion):
# TODO: move all layout-specific hacks to this class
def __init__(self, cond_stage_key, *args, **kwargs):
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
- super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
+ super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)
def log_images(self, batch, N=8, *args, **kwargs):
- logs = super().log_images(batch=batch, N=N, *args, **kwargs)
+ logs = super().log_images(*args, batch=batch, N=N, **kwargs)
key = 'train' if self.training else 'validation'
dset = self.trainer.datamodule.datasets[key]
diff --git a/modules/models/diffusion/uni_pc/__init__.py b/modules/models/diffusion/uni_pc/__init__.py
index e1265e3f..dbb35964 100644
--- a/modules/models/diffusion/uni_pc/__init__.py
+++ b/modules/models/diffusion/uni_pc/__init__.py
@@ -1 +1 @@
-from .sampler import UniPCSampler
+from .sampler import UniPCSampler # noqa: F401
diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py
index a241c8a7..0a9defa1 100644
--- a/modules/models/diffusion/uni_pc/sampler.py
+++ b/modules/models/diffusion/uni_pc/sampler.py
@@ -54,7 +54,8 @@ class UniPCSampler(object):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
- while isinstance(ctmp, list): ctmp = ctmp[0]
+ while isinstance(ctmp, list):
+ ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py
index 11b330bc..d257a728 100644
--- a/modules/models/diffusion/uni_pc/uni_pc.py
+++ b/modules/models/diffusion/uni_pc/uni_pc.py
@@ -1,7 +1,6 @@
import torch
-import torch.nn.functional as F
import math
-from tqdm.auto import trange
+import tqdm
class NoiseScheduleVP:
@@ -179,13 +178,13 @@ def model_wrapper(
model,
noise_schedule,
model_type="noise",
- model_kwargs={},
+ model_kwargs=None,
guidance_type="uncond",
#condition=None,
#unconditional_condition=None,
guidance_scale=1.,
classifier_fn=None,
- classifier_kwargs={},
+ classifier_kwargs=None,
):
"""Create a wrapper function for the noise prediction model.
@@ -276,6 +275,9 @@ def model_wrapper(
A noise prediction model that accepts the noised data and the continuous time as the inputs.
"""
+ model_kwargs = model_kwargs or {}
+ classifier_kwargs = classifier_kwargs or {}
+
def get_model_input_time(t_continuous):
"""
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
@@ -342,7 +344,7 @@ def model_wrapper(
t_in = torch.cat([t_continuous] * 2)
if isinstance(condition, dict):
assert isinstance(unconditional_condition, dict)
- c_in = dict()
+ c_in = {}
for k in condition:
if isinstance(condition[k], list):
c_in[k] = [torch.cat([
@@ -353,7 +355,7 @@ def model_wrapper(
unconditional_condition[k],
condition[k]])
elif isinstance(condition, list):
- c_in = list()
+ c_in = []
assert isinstance(unconditional_condition, list)
for i in range(len(condition)):
c_in.append(torch.cat([unconditional_condition[i], condition[i]]))
@@ -757,40 +759,44 @@ class UniPC:
vec_t = timesteps[0].expand((x.shape[0]))
model_prev_list = [self.model_fn(x, vec_t)]
t_prev_list = [vec_t]
- # Init the first `order` values by lower order multistep DPM-Solver.
- for init_order in range(1, order):
- vec_t = timesteps[init_order].expand(x.shape[0])
- x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
- if model_x is None:
- model_x = self.model_fn(x, vec_t)
- if self.after_update is not None:
- self.after_update(x, model_x)
- model_prev_list.append(model_x)
- t_prev_list.append(vec_t)
- for step in trange(order, steps + 1):
- vec_t = timesteps[step].expand(x.shape[0])
- if lower_order_final:
- step_order = min(order, steps + 1 - step)
- else:
- step_order = order
- #print('this step order:', step_order)
- if step == steps:
- #print('do not run corrector at the last step')
- use_corrector = False
- else:
- use_corrector = True
- x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
- if self.after_update is not None:
- self.after_update(x, model_x)
- for i in range(order - 1):
- t_prev_list[i] = t_prev_list[i + 1]
- model_prev_list[i] = model_prev_list[i + 1]
- t_prev_list[-1] = vec_t
- # We do not need to evaluate the final model value.
- if step < steps:
+ with tqdm.tqdm(total=steps) as pbar:
+ # Init the first `order` values by lower order multistep DPM-Solver.
+ for init_order in range(1, order):
+ vec_t = timesteps[init_order].expand(x.shape[0])
+ x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
if model_x is None:
model_x = self.model_fn(x, vec_t)
- model_prev_list[-1] = model_x
+ if self.after_update is not None:
+ self.after_update(x, model_x)
+ model_prev_list.append(model_x)
+ t_prev_list.append(vec_t)
+ pbar.update()
+
+ for step in range(order, steps + 1):
+ vec_t = timesteps[step].expand(x.shape[0])
+ if lower_order_final:
+ step_order = min(order, steps + 1 - step)
+ else:
+ step_order = order
+ #print('this step order:', step_order)
+ if step == steps:
+ #print('do not run corrector at the last step')
+ use_corrector = False
+ else:
+ use_corrector = True
+ x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
+ if self.after_update is not None:
+ self.after_update(x, model_x)
+ for i in range(order - 1):
+ t_prev_list[i] = t_prev_list[i + 1]
+ model_prev_list[i] = model_prev_list[i + 1]
+ t_prev_list[-1] = vec_t
+ # We do not need to evaluate the final model value.
+ if step < steps:
+ if model_x is None:
+ model_x = self.model_fn(x, vec_t)
+ model_prev_list[-1] = model_x
+ pbar.update()
else:
raise NotImplementedError()
if denoise_to_zero:
diff --git a/modules/ngrok.py b/modules/ngrok.py
index 7a7b4b26..0c713e27 100644
--- a/modules/ngrok.py
+++ b/modules/ngrok.py
@@ -1,6 +1,7 @@
-from pyngrok import ngrok, conf, exception
+import ngrok
-def connect(token, port, region):
+# Connect to ngrok for ingress
+def connect(token, port, options):
account = None
if token is None:
token = 'None'
@@ -10,28 +11,19 @@ def connect(token, port, region):
token, username, password = token.split(':', 2)
account = f"{username}:{password}"
- config = conf.PyngrokConfig(
- auth_token=token, region=region
- )
-
- # Guard for existing tunnels
- existing = ngrok.get_tunnels(pyngrok_config=config)
- if existing:
- for established in existing:
- # Extra configuration in the case that the user is also using ngrok for other tunnels
- if established.config['addr'][-4:] == str(port):
- public_url = existing[0].public_url
- print(f'ngrok has already been connected to localhost:{port}! URL: {public_url}\n'
- 'You can use this link after the launch is complete.')
- return
-
+ # For all options see: https://github.com/ngrok/ngrok-py/blob/main/examples/ngrok-connect-full.py
+ if not options.get('authtoken_from_env'):
+ options['authtoken'] = token
+ if account:
+ options['basic_auth'] = account
+ if not options.get('session_metadata'):
+ options['session_metadata'] = 'stable-diffusion-webui'
+
+
try:
- if account is None:
- public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url
- else:
- public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True, auth=account).public_url
- except exception.PyngrokNgrokError:
- print(f'Invalid ngrok authtoken, ngrok connection aborted.\n'
+ public_url = ngrok.connect(f"127.0.0.1:{port}", **options).url()
+ except Exception as e:
+ print(f'Invalid ngrok authtoken? ngrok connection aborted due to: {e}\n'
f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')
else:
print(f'ngrok connected to localhost:{port}! URL: {public_url}\n'
diff --git a/modules/paths.py b/modules/paths.py
index acf1894b..5f6474c0 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -1,8 +1,8 @@
import os
import sys
-from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir
+from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir # noqa: F401
-import modules.safe
+import modules.safe # noqa: F401
# data_path = cmd_opts_pre.data
diff --git a/modules/paths_internal.py b/modules/paths_internal.py
index 6765bafe..005a9b0a 100644
--- a/modules/paths_internal.py
+++ b/modules/paths_internal.py
@@ -2,8 +2,14 @@
import argparse
import os
+import sys
+import shlex
-script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
+sys.argv += shlex.split(commandline_args)
+
+modules_path = os.path.dirname(os.path.realpath(__file__))
+script_path = os.path.dirname(modules_path)
sd_configs_path = os.path.join(script_path, "configs")
sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
@@ -12,7 +18,7 @@ default_sd_model_file = sd_model_file
# Parse the --data-dir flag first so we can use it as a base for our other argument default values
parser_pre = argparse.ArgumentParser(add_help=False)
-parser_pre.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
+parser_pre.add_argument("--data-dir", type=str, default=os.path.dirname(modules_path), help="base path where all user data is stored", )
cmd_opts_pre = parser_pre.parse_known_args()[0]
data_path = cmd_opts_pre.data_dir
@@ -21,3 +27,5 @@ models_path = os.path.join(data_path, "models")
extensions_dir = os.path.join(data_path, "extensions")
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
config_states_dir = os.path.join(script_path, "config_states")
+
+roboto_ttf_file = os.path.join(modules_path, 'Roboto-Regular.ttf')
diff --git a/modules/processing.py b/modules/processing.py
index 1a76e552..2b8dd361 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -2,7 +2,6 @@ import json
import math
import os
import sys
-import warnings
import hashlib
import torch
@@ -11,10 +10,10 @@ from PIL import Image, ImageFilter, ImageOps
import random
import cv2
from skimage import exposure
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -31,6 +30,7 @@ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
from einops import repeat, rearrange
from blendmodes.blend import blendLayers, BlendType
+
# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
opt_f = 8
@@ -150,6 +150,8 @@ class StableDiffusionProcessing:
self.override_settings_restore_afterwards = override_settings_restore_afterwards
self.is_using_inpainting_conditioning = False
self.disable_extra_networks = False
+ self.token_merging_ratio = 0
+ self.token_merging_ratio_hr = 0
if not seed_enable_extras:
self.subseed = -1
@@ -165,7 +167,8 @@ class StableDiffusionProcessing:
self.all_subseeds = None
self.iteration = 0
self.is_hr_pass = False
-
+ self.sampler = None
+
@property
def sd_model(self):
@@ -274,6 +277,12 @@ class StableDiffusionProcessing:
def close(self):
self.sampler = None
+ def get_token_merging_ratio(self, for_hr=False):
+ if for_hr:
+ return self.token_merging_ratio_hr or opts.token_merging_ratio_hr or self.token_merging_ratio or opts.token_merging_ratio
+
+ return self.token_merging_ratio or opts.token_merging_ratio
+
class Processed:
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
@@ -303,6 +312,8 @@ class Processed:
self.styles = p.styles
self.job_timestamp = state.job_timestamp
self.clip_skip = opts.CLIP_stop_at_last_layers
+ self.token_merging_ratio = p.token_merging_ratio
+ self.token_merging_ratio_hr = p.token_merging_ratio_hr
self.eta = p.eta
self.ddim_discretize = p.ddim_discretize
@@ -310,6 +321,7 @@ class Processed:
self.s_tmin = p.s_tmin
self.s_tmax = p.s_tmax
self.s_noise = p.s_noise
+ self.s_min_uncond = p.s_min_uncond
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
@@ -360,6 +372,9 @@ class Processed:
def infotext(self, p: StableDiffusionProcessing, index):
return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
+ def get_token_merging_ratio(self, for_hr=False):
+ return self.token_merging_ratio_hr if for_hr else self.token_merging_ratio
+
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
@@ -472,6 +487,13 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
index = position_in_batch + iteration * p.batch_size
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
+ enable_hr = getattr(p, 'enable_hr', False)
+ token_merging_ratio = p.get_token_merging_ratio()
+ token_merging_ratio_hr = p.get_token_merging_ratio(for_hr=True)
+
+ uses_ensd = opts.eta_noise_seed_delta != 0
+ if uses_ensd:
+ uses_ensd = sd_samplers_common.is_sampler_using_eta_noise_seed_delta(p)
generation_params = {
"Steps": p.steps,
@@ -489,15 +511,16 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Denoising strength": getattr(p, 'denoising_strength', None),
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
"Clip skip": None if clip_skip <= 1 else clip_skip,
- "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
+ "ENSD": opts.eta_noise_seed_delta if uses_ensd else None,
+ "Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
+ "Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
"Init image hash": getattr(p, 'init_img_hash', None),
"RNG": opts.randn_source if opts.randn_source != "GPU" else None,
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
+ **p.extra_generation_params,
"Version": program_version() if opts.add_version_to_infotext else None,
}
- generation_params.update(p.extra_generation_params)
-
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
@@ -523,9 +546,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if k == 'sd_vae':
sd_vae.reload_vae_weights()
+ sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())
+
res = process_images_inner(p)
finally:
+ sd_models.apply_token_merging(p.sd_model, 0)
+
# restore opts to original state
if p.override_settings_restore_afterwards:
for k, v in stored_opts.items():
@@ -660,12 +687,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
processed = Processed(p, [], p.seed, "")
file.write(processed.infotext(p, 0))
- step_multiplier = 1
- if not shared.opts.dont_fix_second_order_samplers_schedule:
- try:
- step_multiplier = 2 if sd_samplers.all_samplers_map.get(p.sampler_name).aliases[0] in ['k_dpmpp_2s_a', 'k_dpmpp_2s_a_ka', 'k_dpmpp_sde', 'k_dpmpp_sde_ka', 'k_dpm_2', 'k_dpm_2_a', 'k_heun'] else 1
- except:
- pass
+ sampler_config = sd_samplers.find_sampler_config(p.sampler_name)
+ step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps * step_multiplier, cached_uc)
c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps * step_multiplier, cached_c)
@@ -978,8 +1001,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = None
devices.torch_gc()
+ sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
+
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
+ sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
+
self.is_hr_pass = False
return samples
@@ -1141,3 +1168,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
devices.torch_gc()
return samples
+
+ def get_token_merging_ratio(self, for_hr=False):
+ return self.token_merging_ratio or ("token_merging_ratio" in self.override_settings and opts.token_merging_ratio) or opts.token_merging_ratio_img2img or opts.token_merging_ratio
diff --git a/modules/progress.py b/modules/progress.py
index 948e6f00..f405f07f 100644
--- a/modules/progress.py
+++ b/modules/progress.py
@@ -95,9 +95,20 @@ def progressapi(req: ProgressRequest):
image = shared.state.current_image
if image is not None:
buffered = io.BytesIO()
- image.save(buffered, format="png")
+
+ if opts.live_previews_image_format == "png":
+ # using optimize for large images takes an enormous amount of time
+ if max(*image.size) <= 256:
+ save_kwargs = {"optimize": True}
+ else:
+ save_kwargs = {"optimize": False, "compress_level": 1}
+
+ else:
+ save_kwargs = {}
+
+ image.save(buffered, format=opts.live_previews_image_format, **save_kwargs)
base64_image = base64.b64encode(buffered.getvalue()).decode('ascii')
- live_preview = f"data:image/png;base64,{base64_image}"
+ live_preview = f"data:image/{opts.live_previews_image_format};base64,{base64_image}"
id_live_preview = shared.state.id_live_preview
else:
live_preview = None
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index 69665372..b4aff704 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -54,18 +54,21 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
"""
def collect_steps(steps, tree):
- l = [steps]
+ res = [steps]
+
class CollectSteps(lark.Visitor):
def scheduled(self, tree):
tree.children[-1] = float(tree.children[-1])
if tree.children[-1] < 1:
tree.children[-1] *= steps
tree.children[-1] = min(steps, int(tree.children[-1]))
- l.append(tree.children[-1])
+ res.append(tree.children[-1])
+
def alternate(self, tree):
- l.extend(range(1, steps+1))
+ res.extend(range(1, steps+1))
+
CollectSteps().visit(tree)
- return sorted(set(l))
+ return sorted(set(res))
def at_step(step, tree):
class AtStep(lark.Transformer):
@@ -92,7 +95,7 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
def get_schedule(prompt):
try:
tree = schedule_parser.parse(prompt)
- except lark.exceptions.LarkError as e:
+ except lark.exceptions.LarkError:
if 0:
import traceback
traceback.print_exc()
@@ -140,7 +143,7 @@ def get_learned_conditioning(model, prompts, steps):
conds = model.get_learned_conditioning(texts)
cond_schedule = []
- for i, (end_at_step, text) in enumerate(prompt_schedule):
+ for i, (end_at_step, _) in enumerate(prompt_schedule):
cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
cache[prompt] = cond_schedule
@@ -216,8 +219,8 @@ def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_s
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
for i, cond_schedule in enumerate(c):
target_index = 0
- for current, (end_at, cond) in enumerate(cond_schedule):
- if current_step <= end_at:
+ for current, entry in enumerate(cond_schedule):
+ if current_step <= entry.end_at_step:
target_index = current
break
res[i] = cond_schedule[target_index].cond
@@ -231,13 +234,13 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
tensors = []
conds_list = []
- for batch_no, composable_prompts in enumerate(c.batch):
+ for composable_prompts in c.batch:
conds_for_batch = []
- for cond_index, composable_prompt in enumerate(composable_prompts):
+ for composable_prompt in composable_prompts:
target_index = 0
- for current, (end_at, cond) in enumerate(composable_prompt.schedules):
- if current_step <= end_at:
+ for current, entry in enumerate(composable_prompt.schedules):
+ if current_step <= entry.end_at_step:
target_index = current
break
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index efd7fca5..c24d8dbb 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -17,9 +17,9 @@ class UpscalerRealESRGAN(Upscaler):
self.user_path = path
super().__init__()
try:
- from basicsr.archs.rrdbnet_arch import RRDBNet
- from realesrgan import RealESRGANer
- from realesrgan.archs.srvgg_arch import SRVGGNetCompact
+ from basicsr.archs.rrdbnet_arch import RRDBNet # noqa: F401
+ from realesrgan import RealESRGANer # noqa: F401
+ from realesrgan.archs.srvgg_arch import SRVGGNetCompact # noqa: F401
self.enable = True
self.scalers = []
scalers = self.load_models(path)
@@ -134,6 +134,6 @@ def get_realesrgan_models(scaler):
),
]
return models
- except Exception as e:
+ except Exception:
print("Error making Real-ESRGAN models list:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
diff --git a/modules/safe.py b/modules/safe.py
index e1a67f73..e8f50774 100644
--- a/modules/safe.py
+++ b/modules/safe.py
@@ -95,16 +95,16 @@ def check_pt(filename, extra_handler):
except zipfile.BadZipfile:
- # if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
+ # if it's not a zip file, it's an old pytorch format, with five objects written to pickle
with open(filename, "rb") as file:
unpickler = RestrictedUnpickler(file)
unpickler.extra_handler = extra_handler
- for i in range(5):
+ for _ in range(5):
unpickler.load()
def load(filename, *args, **kwargs):
- return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)
+ return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs)
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 17109732..3c21a362 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -32,27 +32,42 @@ class CFGDenoiserParams:
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
self.x = x
"""Latent image representation in the process of being denoised"""
-
+
self.image_cond = image_cond
"""Conditioning image"""
-
+
self.sigma = sigma
"""Current sigma noise step value"""
-
+
self.sampling_step = sampling_step
"""Current Sampling step number"""
-
+
self.total_sampling_steps = total_sampling_steps
"""Total number of sampling steps planned"""
-
+
self.text_cond = text_cond
""" Encoder hidden states of text conditioning from prompt"""
-
+
self.text_uncond = text_uncond
""" Encoder hidden states of text conditioning from negative prompt"""
class CFGDenoisedParams:
+ def __init__(self, x, sampling_step, total_sampling_steps, inner_model):
+ self.x = x
+ """Latent image representation in the process of being denoised"""
+
+ self.sampling_step = sampling_step
+ """Current Sampling step number"""
+
+ self.total_sampling_steps = total_sampling_steps
+ """Total number of sampling steps planned"""
+
+ self.inner_model = inner_model
+ """Inner model reference used for denoising"""
+
+
+class AfterCFGCallbackParams:
def __init__(self, x, sampling_step, total_sampling_steps):
self.x = x
"""Latent image representation in the process of being denoised"""
@@ -87,6 +102,7 @@ callback_map = dict(
callbacks_image_saved=[],
callbacks_cfg_denoiser=[],
callbacks_cfg_denoised=[],
+ callbacks_cfg_after_cfg=[],
callbacks_before_component=[],
callbacks_after_component=[],
callbacks_image_grid=[],
@@ -186,6 +202,14 @@ def cfg_denoised_callback(params: CFGDenoisedParams):
report_exception(c, 'cfg_denoised_callback')
+def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
+ for c in callback_map['callbacks_cfg_after_cfg']:
+ try:
+ c.callback(params)
+ except Exception:
+ report_exception(c, 'cfg_after_cfg_callback')
+
+
def before_component_callback(component, **kwargs):
for c in callback_map['callbacks_before_component']:
try:
@@ -240,7 +264,7 @@ def add_callback(callbacks, fun):
callbacks.append(ScriptCallback(filename, fun))
-
+
def remove_current_script_callbacks():
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
@@ -332,6 +356,14 @@ def on_cfg_denoised(callback):
add_callback(callback_map['callbacks_cfg_denoised'], callback)
+def on_cfg_after_cfg(callback):
+ """register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations are completed.
+ The callback is called with one argument:
+ - params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation.
+ """
+ add_callback(callback_map['callbacks_cfg_after_cfg'], callback)
+
+
def on_before_component(callback):
"""register a function to be called before a component is created.
The callback is called with arguments:
diff --git a/modules/script_loading.py b/modules/script_loading.py
index a7d2203f..57b15862 100644
--- a/modules/script_loading.py
+++ b/modules/script_loading.py
@@ -2,7 +2,6 @@ import os
import sys
import traceback
import importlib.util
-from types import ModuleType
def load_module(path):
diff --git a/modules/scripts.py b/modules/scripts.py
index d945b89f..e33d8c81 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -17,6 +17,9 @@ class PostprocessImageArgs:
class Script:
+ name = None
+ """script's internal name derived from title"""
+
filename = None
args_from = None
args_to = None
@@ -25,8 +28,8 @@ class Script:
is_txt2img = False
is_img2img = False
- """A gr.Group component that has all script's UI inside it"""
group = None
+ """A gr.Group component that has all script's UI inside it"""
infotext_fields = None
"""if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
@@ -38,6 +41,9 @@ class Script:
various "Send to " buttons when clicked
"""
+ api_info = None
+ """Generated value of type modules.api.models.ScriptInfo with information about the script for API"""
+
def title(self):
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
@@ -231,7 +237,7 @@ def load_scripts():
syspath = sys.path
def register_scripts_from_module(module):
- for key, script_class in module.__dict__.items():
+ for script_class in module.__dict__.values():
if type(script_class) != type:
continue
@@ -295,9 +301,9 @@ class ScriptRunner:
auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
- for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data:
- script = script_class()
- script.filename = path
+ for script_data in auto_processing_scripts + scripts_data:
+ script = script_data.script_class()
+ script.filename = script_data.path
script.is_txt2img = not is_img2img
script.is_img2img = is_img2img
@@ -313,6 +319,8 @@ class ScriptRunner:
self.selectable_scripts.append(script)
def setup_ui(self):
+ import modules.api.models as api_models
+
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
inputs = [None]
@@ -327,9 +335,28 @@ class ScriptRunner:
if controls is None:
return
+ script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()
+ api_args = []
+
for control in controls:
control.custom_script_source = os.path.basename(script.filename)
+ arg_info = api_models.ScriptArg(label=control.label or "")
+
+ for field in ("value", "minimum", "maximum", "step", "choices"):
+ v = getattr(control, field, None)
+ if v is not None:
+ setattr(arg_info, field, v)
+
+ api_args.append(arg_info)
+
+ script.api_info = api_models.ScriptInfo(
+ name=script.name,
+ is_img2img=script.is_img2img,
+ is_alwayson=script.alwayson,
+ args=api_args,
+ )
+
if script.infotext_fields is not None:
self.infotext_fields += script.infotext_fields
@@ -492,7 +519,7 @@ class ScriptRunner:
module = script_loading.load_module(script.filename)
cache[filename] = module
- for key, script_class in module.__dict__.items():
+ for script_class in module.__dict__.values():
if type(script_class) == type and issubclass(script_class, Script):
self.scripts[si] = script_class()
self.scripts[si].filename = filename
diff --git a/modules/scripts_auto_postprocessing.py b/modules/scripts_auto_postprocessing.py
index 30d6d658..d63078de 100644
--- a/modules/scripts_auto_postprocessing.py
+++ b/modules/scripts_auto_postprocessing.py
@@ -17,7 +17,7 @@ class ScriptPostprocessingForMainUI(scripts.Script):
return self.postprocessing_controls.values()
def postprocess_image(self, p, script_pp, *args):
- args_dict = {k: v for k, v in zip(self.postprocessing_controls, args)}
+ args_dict = dict(zip(self.postprocessing_controls, args))
pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
pp.info = {}
diff --git a/modules/scripts_postprocessing.py b/modules/scripts_postprocessing.py
index b11568c0..bac1335d 100644
--- a/modules/scripts_postprocessing.py
+++ b/modules/scripts_postprocessing.py
@@ -66,9 +66,9 @@ class ScriptPostprocessingRunner:
def initialize_scripts(self, scripts_data):
self.scripts = []
- for script_class, path, basedir, script_module in scripts_data:
- script: ScriptPostprocessing = script_class()
- script.filename = path
+ for script_data in scripts_data:
+ script: ScriptPostprocessing = script_data.script_class()
+ script.filename = script_data.path
if script.name == "Simple Upscale":
continue
@@ -124,7 +124,7 @@ class ScriptPostprocessingRunner:
script_args = args[script.args_from:script.args_to]
process_args = {}
- for (name, component), value in zip(script.controls.items(), script_args):
+ for (name, _component), value in zip(script.controls.items(), script_args):
process_args[name] = value
script.process(pp, **process_args)
diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py
index c4a09d15..9fc89dc6 100644
--- a/modules/sd_disable_initialization.py
+++ b/modules/sd_disable_initialization.py
@@ -61,7 +61,7 @@ class DisableInitialization:
if res is None:
res = original(url, *args, local_files_only=False, **kwargs)
return res
- except Exception as e:
+ except Exception:
return original(url, *args, local_files_only=False, **kwargs)
def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index f4bb0266..14e7f799 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -3,7 +3,7 @@ from torch.nn.functional import silu
from types import MethodType
import modules.textual_inversion.textual_inversion
-from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
+from modules import devices, sd_hijack_optimizations, shared
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
@@ -34,10 +34,10 @@ def apply_optimizations():
ldm.modules.diffusionmodules.model.nonlinearity = silu
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
-
+
optimization_method = None
- can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(getattr(torch.nn.functional, "scaled_dot_product_attention")) # not everyone has torch 2.x to use sdp
+ can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention) # not everyone has torch 2.x to use sdp
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
print("Applying xformers cross attention optimization.")
@@ -92,12 +92,12 @@ def fix_checkpoint():
def weighted_loss(sd_model, pred, target, mean=True):
#Calculate the weight normally, but ignore the mean
loss = sd_model._old_get_loss(pred, target, mean=False)
-
+
#Check if we have weights available
weight = getattr(sd_model, '_custom_loss_weight', None)
if weight is not None:
loss *= weight
-
+
#Return the loss, as mean if specified
return loss.mean() if mean else loss
@@ -105,7 +105,7 @@ def weighted_forward(sd_model, x, c, w, *args, **kwargs):
try:
#Temporarily append weights to a place accessible during loss calc
sd_model._custom_loss_weight = w
-
+
#Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
#Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
if not hasattr(sd_model, '_old_get_loss'):
@@ -118,9 +118,9 @@ def weighted_forward(sd_model, x, c, w, *args, **kwargs):
try:
#Delete temporary weights if appended
del sd_model._custom_loss_weight
- except AttributeError as e:
+ except AttributeError:
pass
-
+
#If we have an old loss function, reset the loss function to the original one
if hasattr(sd_model, '_old_get_loss'):
sd_model.get_loss = sd_model._old_get_loss
@@ -133,7 +133,7 @@ def apply_weighted_forward(sd_model):
def undo_weighted_forward(sd_model):
try:
del sd_model.weighted_forward
- except AttributeError as e:
+ except AttributeError:
pass
@@ -184,7 +184,7 @@ class StableDiffusionModelHijack:
def undo_hijack(self, m):
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
- m.cond_stage_model = m.cond_stage_model.wrapped
+ m.cond_stage_model = m.cond_stage_model.wrapped
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
m.cond_stage_model = m.cond_stage_model.wrapped
@@ -216,6 +216,9 @@ class StableDiffusionModelHijack:
self.comments = []
def get_prompt_lengths(self, text):
+ if self.clip is None:
+ return "-", "-"
+
_, token_count = self.clip.process_texts([text])
return token_count, self.clip.get_target_prompt_token_count(token_count)
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index 9fa5c5c5..cc6e8c21 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -223,7 +223,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
self.hijack.fixes = [x.fixes for x in batch_chunk]
for fixes in self.hijack.fixes:
- for position, embedding in fixes:
+ for _position, embedding in fixes:
used_embeddings[embedding.name] = embedding
z = self.process_tokens(tokens, multipliers)
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
index 55a2ce4d..c1977b19 100644
--- a/modules/sd_hijack_inpainting.py
+++ b/modules/sd_hijack_inpainting.py
@@ -1,16 +1,10 @@
-import os
import torch
-from einops import repeat
-from omegaconf import ListConfig
-
import ldm.models.diffusion.ddpm
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
-from ldm.models.diffusion.ddpm import LatentDiffusion
-from ldm.models.diffusion.plms import PLMSSampler
-from ldm.models.diffusion.ddim import DDIMSampler, noise_like
+from ldm.models.diffusion.ddim import noise_like
from ldm.models.diffusion.sampling_util import norm_thresholding
@@ -29,7 +23,7 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
if isinstance(c, dict):
assert isinstance(unconditional_conditioning, dict)
- c_in = dict()
+ c_in = {}
for k in c:
if isinstance(c[k], list):
c_in[k] = [
diff --git a/modules/sd_hijack_ip2p.py b/modules/sd_hijack_ip2p.py
index 3c727d3b..6fe6b6ff 100644
--- a/modules/sd_hijack_ip2p.py
+++ b/modules/sd_hijack_ip2p.py
@@ -1,8 +1,5 @@
-import collections
import os.path
-import sys
-import gc
-import time
+
def should_hijack_ip2p(checkpoint_info):
from modules import sd_models_config
@@ -10,4 +7,4 @@ def should_hijack_ip2p(checkpoint_info):
ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower()
- return "pix2pix" in ckpt_basename and not "pix2pix" in cfg_basename
+ return "pix2pix" in ckpt_basename and "pix2pix" not in cfg_basename
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index f10865cd..f00fe55c 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -49,7 +49,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
v_in = self.to_v(context_v)
del context, context_k, context_v, x
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
+ q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
del q_in, k_in, v_in
dtype = q.dtype
@@ -62,10 +62,10 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
end = i + 2
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
s1 *= self.scale
-
+
s2 = s1.softmax(dim=-1)
del s1
-
+
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
del s2
del q, k, v
@@ -95,43 +95,43 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
with devices.without_autocast(disable=not shared.opts.upcast_attn):
k_in = k_in * self.scale
-
+
del context, x
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
+
+ q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
del q_in, k_in, v_in
-
+
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
-
+
mem_free_total = get_available_vram()
-
+
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
steps = 1
-
+
if mem_required > mem_free_total:
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
-
+
if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
-
+
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
-
+
s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1
-
+
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
-
+
del q, k, v
r1 = r1.to(dtype)
@@ -228,8 +228,8 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
with devices.without_autocast(disable=not shared.opts.upcast_attn):
k = k * self.scale
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
r = einsum_op(q, k, v)
r = r.to(dtype)
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
@@ -296,7 +296,6 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
# the big matmul fits into our memory limit; do everything in 1 chunk,
# i.e. send it down the unchunked fast-path
- query_chunk_size = q_tokens
kv_chunk_size = k_tokens
with devices.without_autocast(disable=q.dtype == v.dtype):
@@ -335,7 +334,7 @@ def xformers_attention_forward(self, x, context=None, mask=None):
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
+ q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
del q_in, k_in, v_in
dtype = q.dtype
@@ -370,7 +369,7 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
-
+
del q_in, k_in, v_in
dtype = q.dtype
@@ -452,7 +451,7 @@ def cross_attention_attnblock_forward(self, x):
h3 += x
return h3
-
+
def xformers_attnblock_forward(self, x):
try:
h_ = x
@@ -461,7 +460,7 @@ def xformers_attnblock_forward(self, x):
k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
- q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+ q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
dtype = q.dtype
if shared.opts.upcast_attn:
q, k = q.float(), k.float()
@@ -483,7 +482,7 @@ def sdp_attnblock_forward(self, x):
k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
- q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+ q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
dtype = q.dtype
if shared.opts.upcast_attn:
q, k = q.float(), k.float()
@@ -507,7 +506,7 @@ def sub_quad_attnblock_forward(self, x):
k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
- q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+ q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
diff --git a/modules/sd_hijack_xlmr.py b/modules/sd_hijack_xlmr.py
index 4ac51c38..28528329 100644
--- a/modules/sd_hijack_xlmr.py
+++ b/modules/sd_hijack_xlmr.py
@@ -1,8 +1,6 @@
-import open_clip.tokenizer
import torch
from modules import sd_hijack_clip, devices
-from modules.shared import opts
class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 36f643e1..4bd8783e 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -15,9 +15,9 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
-from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
+import tomesd
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
@@ -87,8 +87,7 @@ class CheckpointInfo:
try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
-
- from transformers import logging, CLIPModel
+ from transformers import logging, CLIPModel # noqa: F401
logging.set_verbosity_error()
except Exception:
@@ -167,7 +166,7 @@ def model_hash(filename):
def select_checkpoint():
model_checkpoint = shared.opts.sd_model_checkpoint
-
+
checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
if checkpoint_info is not None:
return checkpoint_info
@@ -239,7 +238,7 @@ def read_metadata_from_safetensors(filename):
if isinstance(v, str) and v[0:1] == '{':
try:
res[k] = json.loads(v)
- except Exception as e:
+ except Exception:
pass
return res
@@ -374,7 +373,7 @@ def enable_midas_autodownload():
if not os.path.exists(path):
if not os.path.exists(midas_path):
mkdir(midas_path)
-
+
print(f"Downloading midas model weights for {model_type} to {path}")
request.urlretrieve(midas_urls[model_type], path)
print(f"{model_type} downloaded")
@@ -415,6 +414,9 @@ class SdModelData:
def get_sd_model(self):
if self.sd_model is None:
with self.lock:
+ if self.sd_model is not None:
+ return self.sd_model
+
try:
load_model()
except Exception as e:
@@ -467,7 +469,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
try:
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
sd_model = instantiate_from_config(sd_config.model)
- except Exception as e:
+ except Exception:
pass
if sd_model is None:
@@ -538,13 +540,12 @@ def reload_model_weights(sd_model=None, info=None):
if sd_model is None or checkpoint_config != sd_model.used_config:
del sd_model
- checkpoints_loaded.clear()
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
return model_data.sd_model
try:
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
- except Exception as e:
+ except Exception:
print("Failed to load checkpoint, restoring previous")
load_model_weights(sd_model, current_checkpoint_info, None, timer)
raise
@@ -565,7 +566,7 @@ def reload_model_weights(sd_model=None, info=None):
def unload_model_weights(sd_model=None, info=None):
- from modules import lowvram, devices, sd_hijack
+ from modules import devices, sd_hijack
timer = Timer()
if model_data.sd_model:
@@ -580,3 +581,29 @@ def unload_model_weights(sd_model=None, info=None):
print(f"Unloaded weights {timer.summary()}.")
return sd_model
+
+
+def apply_token_merging(sd_model, token_merging_ratio):
+ """
+ Applies speed and memory optimizations from tomesd.
+ """
+
+ current_token_merging_ratio = getattr(sd_model, 'applied_token_merged_ratio', 0)
+
+ if current_token_merging_ratio == token_merging_ratio:
+ return
+
+ if current_token_merging_ratio > 0:
+ tomesd.remove_patch(sd_model)
+
+ if token_merging_ratio > 0:
+ tomesd.apply_patch(
+ sd_model,
+ ratio=token_merging_ratio,
+ use_rand=False, # can cause issues with some samplers
+ merge_attn=True,
+ merge_crossattn=False,
+ merge_mlp=False
+ )
+
+ sd_model.applied_token_merged_ratio = token_merging_ratio
diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py
index 7a79925a..9bfe1237 100644
--- a/modules/sd_models_config.py
+++ b/modules/sd_models_config.py
@@ -1,4 +1,3 @@
-import re
import os
import torch
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index ff361f22..f22aad8f 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -1,7 +1,7 @@
from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
# imports for functions that previously were here and are used by other modules
-from modules.sd_samplers_common import samples_to_image_grid, sample_to_image
+from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
all_samplers = [
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
@@ -14,12 +14,18 @@ samplers_for_img2img = []
samplers_map = {}
-def create_sampler(name, model):
+def find_sampler_config(name):
if name is not None:
config = all_samplers_map.get(name, None)
else:
config = all_samplers[0]
+ return config
+
+
+def create_sampler(name, model):
+ config = find_sampler_config(name)
+
assert config is not None, f'bad sampler name: {name}'
sampler = config.constructor(model)
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index bc074238..763829f1 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -2,7 +2,7 @@ from collections import namedtuple
import numpy as np
import torch
from PIL import Image
-from modules import devices, processing, images, sd_vae_approx
+from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd
from modules.shared import opts, state
import modules.shared as shared
@@ -22,7 +22,7 @@ def setup_img2img_steps(p, steps=None):
return steps, t_enc
-approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
+approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3}
def single_sample_to_image(sample, approximation=None):
@@ -30,15 +30,19 @@ def single_sample_to_image(sample, approximation=None):
approximation = approximation_indexes.get(opts.show_progress_type, 0)
if approximation == 2:
- x_sample = sd_vae_approx.cheap_approximation(sample)
+ x_sample = sd_vae_approx.cheap_approximation(sample) * 0.5 + 0.5
elif approximation == 1:
- x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
+ x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() * 0.5 + 0.5
+ elif approximation == 3:
+ x_sample = sample * 1.5
+ x_sample = sd_vae_taesd.model()(x_sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
else:
- x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
+ x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5
- x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
+ x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
+
return Image.fromarray(x_sample)
@@ -58,6 +62,25 @@ def store_latent(decoded):
shared.state.assign_current_image(sample_to_image(decoded))
+def is_sampler_using_eta_noise_seed_delta(p):
+ """returns whether sampler from config will use eta noise seed delta for image creation"""
+
+ sampler_config = sd_samplers.find_sampler_config(p.sampler_name)
+
+ eta = p.eta
+
+ if eta is None and p.sampler is not None:
+ eta = p.sampler.eta
+
+ if eta is None and sampler_config is not None:
+ eta = 0 if sampler_config.options.get("default_eta_is_0", False) else 1.0
+
+ if eta == 0:
+ return False
+
+ return sampler_config.options.get("uses_ensd", False)
+
+
class InterruptedException(BaseException):
pass
diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py
index bfcc5574..bdae8b40 100644
--- a/modules/sd_samplers_compvis.py
+++ b/modules/sd_samplers_compvis.py
@@ -11,7 +11,7 @@ import modules.models.diffusion.uni_pc
samplers_data_compvis = [
- sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
+ sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True}),
sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {}),
]
@@ -55,7 +55,7 @@ class VanillaStableDiffusionSampler:
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning)
- res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
+ res = self.orig_p_sample_ddim(x_dec, cond, ts, *args, unconditional_conditioning=unconditional_conditioning, **kwargs)
x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res)
@@ -83,7 +83,7 @@ class VanillaStableDiffusionSampler:
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
- assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
+ assert all(len(conds) == 1 for conds in conds_list), 'composition via AND is not supported for DDIM/PLMS samplers'
cond = tensor
# for DDIM, shapes must match, we can't just process cond and uncond independently;
@@ -134,7 +134,11 @@ class VanillaStableDiffusionSampler:
self.update_step(x)
def initialize(self, p):
- self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
+ if self.is_ddim:
+ self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
+ else:
+ self.eta = 0.0
+
if self.eta != 0.0:
p.extra_generation_params["Eta DDIM"] = self.eta
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 0fc9f456..552c6c64 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -1,7 +1,6 @@
from collections import deque
import torch
import inspect
-import einops
import k_diffusion.sampling
from modules import prompt_parser, devices, sd_samplers_common
@@ -9,25 +8,26 @@ from modules.shared import opts, state
import modules.shared as shared
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
+from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
samplers_k_diffusion = [
- ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
+ ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
('Euler', 'sample_euler', ['k_euler'], {}),
('LMS', 'sample_lms', ['k_lms'], {}),
- ('Heun', 'sample_heun', ['k_heun'], {}),
+ ('Heun', 'sample_heun', ['k_heun'], {"second_order": True}),
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
- ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}),
- ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
+ ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True}),
+ ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}),
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
- ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
- ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
- ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
+ ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True}),
+ ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
+ ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
- ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
- ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
- ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
+ ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
+ ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
+ ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
- ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
+ ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True}),
]
samplers_data_k_diffusion = [
@@ -87,17 +87,17 @@ class CFGDenoiser(torch.nn.Module):
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
- assert not is_edit_model or all([len(conds) == 1 for conds in conds_list]), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
+ assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]
if shared.sd_model.model.conditioning_key == "crossattn-adm":
image_uncond = torch.zeros_like(image_cond)
- make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
+ make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
else:
image_uncond = image_cond
- make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}
+ make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}
if not is_edit_model:
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
@@ -161,7 +161,7 @@ class CFGDenoiser(torch.nn.Module):
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
- denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
+ denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
cfg_denoised_callback(denoised_params)
devices.test_for_nans(x_out, "unet")
@@ -181,6 +181,10 @@ class CFGDenoiser(torch.nn.Module):
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
+ after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
+ cfg_after_cfg_callback(after_cfg_callback_params)
+ denoised = after_cfg_callback_params.x
+
self.step += 1
return denoised
@@ -317,7 +321,7 @@ class KDiffusionSampler:
sigma_sched = sigmas[steps - t_enc - 1:]
xi = x + noise * sigma_sched[0]
-
+
extra_params_kwargs = self.initialize(p)
parameters = inspect.signature(self.func).parameters
@@ -340,9 +344,9 @@ class KDiffusionSampler:
self.model_wrap_cfg.init_latent = x
self.last_latent = x
extra_args={
- 'cond': conditioning,
- 'image_cond': image_conditioning,
- 'uncond': unconditional_conditioning,
+ 'cond': conditioning,
+ 'image_cond': image_conditioning,
+ 'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale,
's_min_uncond': self.s_min_uncond
}
@@ -375,9 +379,9 @@ class KDiffusionSampler:
self.last_latent = x
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
- 'cond': conditioning,
- 'image_cond': image_conditioning,
- 'uncond': unconditional_conditioning,
+ 'cond': conditioning,
+ 'image_cond': image_conditioning,
+ 'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale,
's_min_uncond': self.s_min_uncond
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index 521e485a..e4ff2994 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -1,8 +1,5 @@
-import torch
-import safetensors.torch
import os
import collections
-from collections import namedtuple
from modules import paths, shared, devices, script_callbacks, sd_models
import glob
from copy import deepcopy
@@ -88,10 +85,10 @@ def refresh_vae_list():
def find_vae_near_checkpoint(checkpoint_file):
- checkpoint_path = os.path.splitext(checkpoint_file)[0]
- for vae_location in [f"{checkpoint_path}.vae.pt", f"{checkpoint_path}.vae.ckpt", f"{checkpoint_path}.vae.safetensors"]:
- if os.path.isfile(vae_location):
- return vae_location
+ checkpoint_path = os.path.basename(checkpoint_file).rsplit('.', 1)[0]
+ for vae_file in vae_dict.values():
+ if os.path.basename(vae_file).startswith(checkpoint_path):
+ return vae_file
return None
diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py
new file mode 100644
index 00000000..5e8496e8
--- /dev/null
+++ b/modules/sd_vae_taesd.py
@@ -0,0 +1,88 @@
+"""
+Tiny AutoEncoder for Stable Diffusion
+(DNN for encoding / decoding SD's latent space)
+
+https://github.com/madebyollin/taesd
+"""
+import os
+import torch
+import torch.nn as nn
+
+from modules import devices, paths_internal
+
+sd_vae_taesd = None
+
+
+def conv(n_in, n_out, **kwargs):
+ return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
+
+
+class Clamp(nn.Module):
+ @staticmethod
+ def forward(x):
+ return torch.tanh(x / 3) * 3
+
+
+class Block(nn.Module):
+ def __init__(self, n_in, n_out):
+ super().__init__()
+ self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
+ self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
+ self.fuse = nn.ReLU()
+
+ def forward(self, x):
+ return self.fuse(self.conv(x) + self.skip(x))
+
+
+def decoder():
+ return nn.Sequential(
+ Clamp(), conv(4, 64), nn.ReLU(),
+ Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
+ Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
+ Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
+ Block(64, 64), conv(64, 3),
+ )
+
+
+class TAESD(nn.Module):
+ latent_magnitude = 3
+ latent_shift = 0.5
+
+ def __init__(self, decoder_path="taesd_decoder.pth"):
+ """Initialize pretrained TAESD on the given device from the given checkpoints."""
+ super().__init__()
+ self.decoder = decoder()
+ self.decoder.load_state_dict(
+ torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
+
+ @staticmethod
+ def unscale_latents(x):
+ """[0, 1] -> raw latents"""
+ return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
+
+
+def download_model(model_path):
+ model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth'
+
+ if not os.path.exists(model_path):
+ os.makedirs(os.path.dirname(model_path), exist_ok=True)
+
+ print(f'Downloading TAESD decoder to: {model_path}')
+ torch.hub.download_url_to_file(model_url, model_path)
+
+
+def model():
+ global sd_vae_taesd
+
+ if sd_vae_taesd is None:
+ model_path = os.path.join(paths_internal.models_path, "VAE-taesd", "taesd_decoder.pth")
+ download_model(model_path)
+
+ if os.path.exists(model_path):
+ sd_vae_taesd = TAESD(model_path)
+ sd_vae_taesd.eval()
+ sd_vae_taesd.to(devices.device, devices.dtype)
+ else:
+ raise FileNotFoundError('TAESD model not found')
+
+ return sd_vae_taesd.decoder
diff --git a/modules/shared.py b/modules/shared.py
index b3508883..9e9e8cd4 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -1,12 +1,10 @@
-import argparse
import datetime
import json
import os
import sys
+import threading
import time
-import requests
-from PIL import Image
import gradio as gr
import tqdm
@@ -15,7 +13,7 @@ import modules.memmon
import modules.styles
import modules.devices as devices
from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
-from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir
+from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
from ldm.models.diffusion.ddpm import LatentDiffusion
demo = None
@@ -113,8 +111,47 @@ class State:
id_live_preview = 0
textinfo = None
time_start = None
- need_restart = False
server_start = None
+ _server_command_signal = threading.Event()
+ _server_command: str | None = None
+
+ @property
+ def need_restart(self) -> bool:
+ # Compatibility getter for need_restart.
+ return self.server_command == "restart"
+
+ @need_restart.setter
+ def need_restart(self, value: bool) -> None:
+ # Compatibility setter for need_restart.
+ if value:
+ self.server_command = "restart"
+
+ @property
+ def server_command(self):
+ return self._server_command
+
+ @server_command.setter
+ def server_command(self, value: str | None) -> None:
+ """
+ Set the server command to `value` and signal that it's been set.
+ """
+ self._server_command = value
+ self._server_command_signal.set()
+
+ def wait_for_server_command(self, timeout: float | None = None) -> str | None:
+ """
+ Wait for server command to get set; return and clear the value and signal.
+ """
+ if self._server_command_signal.wait(timeout):
+ self._server_command_signal.clear()
+ req = self._server_command
+ self._server_command = None
+ return req
+ return None
+
+ def request_restart(self) -> None:
+ self.interrupt()
+ self.server_command = "restart"
def skip(self):
self.skipped = True
@@ -202,8 +239,9 @@ interrogator = modules.interrogate.InterrogateModels("interrogate")
face_restorers = []
+
class OptionInfo:
- def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None):
+ def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after=''):
self.default = default
self.label = label
self.component = component
@@ -212,9 +250,33 @@ class OptionInfo:
self.section = section
self.refresh = refresh
+ self.comment_before = comment_before
+ """HTML text that will be added after label in UI"""
+
+ self.comment_after = comment_after
+ """HTML text that will be added before label in UI"""
+
+ def link(self, label, url):
+ self.comment_before += f"[{label} ]"
+ return self
+
+ def js(self, label, js_func):
+ self.comment_before += f"[{label} ]"
+ return self
+
+ def info(self, info):
+ self.comment_after += f"({info}) "
+ return self
+
+ def needs_restart(self):
+ self.comment_after += " (requires restart) "
+ return self
+
+
+
def options_section(section_identifier, options_dict):
- for k, v in options_dict.items():
+ for v in options_dict.values():
v.section = section_identifier
return options_dict
@@ -243,7 +305,7 @@ options_templates = {}
options_templates.update(options_section(('saving-images', "Saving images/grids"), {
"samples_save": OptionInfo(True, "Always save all generated images"),
"samples_format": OptionInfo('png', 'File format for images'),
- "samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs),
+ "samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
"save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
"grid_save": OptionInfo(True, "Always save all generated image grids"),
@@ -262,10 +324,10 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"save_mask_composite": OptionInfo(False, "For inpainting, save a masked composite"),
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
"webp_lossless": OptionInfo(False, "Use lossless compression for webp images"),
- "export_for_4chan": OptionInfo(True, "If the saved image file size is above the limit, or its either width or height are above the limit, save a downscaled copy as JPG"),
+ "export_for_4chan": OptionInfo(True, "Save copy of large images as JPG").info("if the file size is above the limit, or either width or height are above the limit"),
"img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number),
"target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number),
- "img_max_size_mp": OptionInfo(200, "Maximum image size, in megapixels", gr.Number),
+ "img_max_size_mp": OptionInfo(200, "Maximum image size", gr.Number).info("in megapixels"),
"use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
@@ -293,31 +355,30 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo
"save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
"grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
- "directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs),
+ "directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
}))
options_templates.update(options_section(('upscaling', "Upscaling"), {
- "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
- "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
- "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
+ "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
+ "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
+ "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
- "SCUNET_tile": OptionInfo(256, "Tile size for SCUNET upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
- "SCUNET_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SCUNET upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}),
}))
options_templates.update(options_section(('face-restoration', "Face restoration"), {
"face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
- "code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
+ "code_former_weight": OptionInfo(0.5, "CodeFormer weight", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}).info("0 = maximum effect; 1 = minimum effect"),
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
}))
options_templates.update(options_section(('system', "System"), {
"show_warnings": OptionInfo(False, "Show warnings in console."),
- "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
+ "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
+ "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
}))
options_templates.update(options_section(('training', "Training"), {
@@ -339,20 +400,27 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
- "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list),
+ "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
- "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
+ "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"),
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}),
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
- "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
+ "enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
- "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
- "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
+ "comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
+ "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP nrtwork; 1 ignores none, 2 ignores one layer"),
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
- "randn_source": OptionInfo("GPU", "Random number generator source. Changes seeds drastically. Use CPU to produce the same picture across different vidocard vendors.", gr.Radio, {"choices": ["GPU", "CPU"]}),
+ "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different vidocard vendors"),
+}))
+
+options_templates.update(options_section(('optimizations', "Optimizations"), {
+ "s_min_uncond": OptionInfo(0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
+ "token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
+ "token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
+ "token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
}))
options_templates.update(options_section(('compatibility', "Compatibility"), {
@@ -364,30 +432,35 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
}))
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
- "interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
- "interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."),
- "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
- "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
- "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
- "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"),
+ "interrogate_keep_models_in_memory": OptionInfo(False, "Keep models in VRAM"),
+ "interrogate_return_ranks": OptionInfo(False, "Include ranks of model tags matches in results.").info("booru only"),
+ "interrogate_clip_num_beams": OptionInfo(1, "BLIP: num_beams", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
+ "interrogate_clip_min_length": OptionInfo(24, "BLIP: minimum description length", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
+ "interrogate_clip_max_length": OptionInfo(48, "BLIP: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
+ "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file").info("0 = No limit"),
"interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types),
- "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
- "deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
- "deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),
- "deepbooru_escape": OptionInfo(True, "escape (\\) brackets in deepbooru (so they are used as literal brackets and not for emphasis)"),
- "deepbooru_filter_tags": OptionInfo("", "filter out those tags from deepbooru output (separated by comma)"),
+ "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "deepbooru: score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
+ "deepbooru_sort_alpha": OptionInfo(True, "deepbooru: sort tags alphabetically").info("if not: sort by score"),
+ "deepbooru_use_spaces": OptionInfo(True, "deepbooru: use spaces in tags").info("if not: use underscores"),
+ "deepbooru_escape": OptionInfo(True, "deepbooru: escape (\\) brackets").info("so they are used as literal brackets and not for emphasis"),
+ "deepbooru_filter_tags": OptionInfo("", "deepbooru: filter out those tags").info("separate by comma"),
}))
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
+ "extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."),
+ "extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'),
"extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}),
"extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- "extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks (px)"),
- "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks (px)"),
- "extra_networks_add_text_separator": OptionInfo(" ", "Extra text to add before <...> when adding extra network to prompt"),
- "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
+ "extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"),
+ "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
+ "extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
+ "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *hypernetworks]}, refresh=reload_hypernetworks),
}))
options_templates.update(options_section(('ui', "User interface"), {
+ "localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_restart(),
+ "gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes}).needs_restart(),
+ "img2img_editor_height": OptionInfo(720, "img2img: height of image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_restart(),
"return_grid": OptionInfo(True, "Show grid in results for web"),
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
@@ -400,17 +473,16 @@ options_templates.update(options_section(('ui', "User interface"), {
"js_modal_lightbox_gamepad": OptionInfo(True, "Navigate image viewer with gamepad"),
"js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"),
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
- "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"),
- "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"),
+ "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_restart(),
+ "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_restart(),
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing ", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
- "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}),
- "hidden_tabs": OptionInfo([], "Hidden UI tabs (requires restart)", ui_components.DropdownMulti, lambda: {"choices": [x for x in tab_names]}),
+ "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_restart(),
+ "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
+ "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
"ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
- "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"),
- "localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
- "gradio_theme": OptionInfo("Default", "Gradio theme (requires restart)", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes})
+ "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_restart(),
}))
options_templates.update(options_section(('infotext', "Infotext"), {
@@ -423,27 +495,27 @@ options_templates.update(options_section(('infotext', "Infotext"), {
options_templates.update(options_section(('ui', "Live previews"), {
"show_progressbar": OptionInfo(True, "Show progressbar"),
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
+ "live_previews_image_format": OptionInfo("png", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}),
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
- "show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
- "show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
+ "show_progress_every_n_steps": OptionInfo(10, "Live preview display period", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}).info("in sampling steps - show new live preview image every N sampling steps; -1 = only show after completion of batch"),
+ "show_progress_type": OptionInfo("Approx NN", "Live preview method", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap", "TAESD"]}).info("Full = slow but pretty; Approx NN and TAESD = fast but low quality; Approx cheap = super fast but terrible otherwise"),
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
- "live_preview_refresh_period": OptionInfo(1000, "Progressbar/preview update period, in milliseconds")
+ "live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"),
}))
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
- "hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}),
- "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ "hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}).needs_restart(),
+ "eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; higher = more unperdictable results"),
+ "eta_ancestral": OptionInfo(1.0, "Eta for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; applies to Euler a and other samplers that have a in them"),
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- 's_min_uncond': OptionInfo(0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}),
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
- 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"),
+ 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}).info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
+ 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma").link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}),
'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}),
- 'uni_pc_order': OptionInfo(3, "UniPC order (must be < sampling steps)", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}),
+ 'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}).info("must be < sampling steps"),
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final"),
}))
@@ -460,6 +532,7 @@ options_templates.update(options_section((None, "Hidden options"), {
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
}))
+
options_templates.update()
@@ -571,7 +644,9 @@ class Options:
func()
def dumpjson(self):
- d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
+ d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()}
+ d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None}
+ d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None}
return json.dumps(d)
def add_option(self, key, info):
@@ -582,11 +657,11 @@ class Options:
section_ids = {}
settings_items = self.data_labels.items()
- for k, item in settings_items:
+ for _, item in settings_items:
if item.section not in section_ids:
section_ids[item.section] = len(section_ids)
- self.data_labels = {k: v for k, v in sorted(settings_items, key=lambda x: section_ids[x[1].section])}
+ self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section]))
def cast_value(self, key, value):
"""casts an arbitrary to the same type as this setting's value with key
@@ -748,11 +823,14 @@ def walk_files(path, allowed_extensions=None):
if allowed_extensions is not None:
allowed_extensions = set(allowed_extensions)
- for root, dirs, files in os.walk(path):
+ for root, _, files in os.walk(path, followlinks=True):
for filename in files:
if allowed_extensions is not None:
_, ext = os.path.splitext(filename)
if ext not in allowed_extensions:
continue
+ if not opts.list_hidden_files and ("/." in root or "\\." in root):
+ continue
+
yield os.path.join(root, filename)
diff --git a/modules/styles.py b/modules/styles.py
index 11642075..34e1b5e1 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -1,18 +1,9 @@
-# We need this so Python doesn't complain about the unknown StableDiffusionProcessing-typehint at runtime
-from __future__ import annotations
-
import csv
import os
import os.path
import typing
-import collections.abc as abc
-import tempfile
import shutil
-if typing.TYPE_CHECKING:
- # Only import this when code is being type-checked, it doesn't have any effect at runtime
- from .processing import StableDiffusionProcessing
-
class PromptStyle(typing.NamedTuple):
name: str
@@ -52,7 +43,7 @@ class StyleDatabase:
return
with open(self.path, "r", encoding="utf-8-sig", newline='') as file:
- reader = csv.DictReader(file)
+ reader = csv.DictReader(file, skipinitialspace=True)
for row in reader:
# Support loading old CSV format with "name, text"-columns
prompt = row["prompt"] if "prompt" in row else row["text"]
diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py
index 05595323..497568eb 100644
--- a/modules/sub_quadratic_attention.py
+++ b/modules/sub_quadratic_attention.py
@@ -179,7 +179,7 @@ def efficient_dot_product_attention(
chunk_idx,
min(query_chunk_size, q_tokens)
)
-
+
summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
@@ -201,14 +201,15 @@ def efficient_dot_product_attention(
key=key,
value=value,
)
-
- # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
- # and pass slices to be mutated, instead of torch.cat()ing the returned slices
- res = torch.cat([
- compute_query_chunk_attn(
+
+ res = torch.zeros_like(query)
+ for i in range(math.ceil(q_tokens / query_chunk_size)):
+ attn_scores = compute_query_chunk_attn(
query=get_query_chunk(i * query_chunk_size),
key=key,
value=value,
- ) for i in range(math.ceil(q_tokens / query_chunk_size))
- ], dim=1)
+ )
+
+ res[:, i * query_chunk_size:i * query_chunk_size + attn_scores.shape[1], :] = attn_scores
+
return res
diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py
index ba1bdcd4..8e667a4d 100644
--- a/modules/textual_inversion/autocrop.py
+++ b/modules/textual_inversion/autocrop.py
@@ -1,10 +1,8 @@
import cv2
import requests
import os
-from collections import defaultdict
-from math import log, sqrt
import numpy as np
-from PIL import Image, ImageDraw
+from PIL import ImageDraw
GREEN = "#0F0"
BLUE = "#00F"
@@ -12,63 +10,64 @@ RED = "#F00"
def crop_image(im, settings):
- """ Intelligently crop an image to the subject matter """
+ """ Intelligently crop an image to the subject matter """
- scale_by = 1
- if is_landscape(im.width, im.height):
- scale_by = settings.crop_height / im.height
- elif is_portrait(im.width, im.height):
- scale_by = settings.crop_width / im.width
- elif is_square(im.width, im.height):
- if is_square(settings.crop_width, settings.crop_height):
- scale_by = settings.crop_width / im.width
- elif is_landscape(settings.crop_width, settings.crop_height):
- scale_by = settings.crop_width / im.width
- elif is_portrait(settings.crop_width, settings.crop_height):
- scale_by = settings.crop_height / im.height
+ scale_by = 1
+ if is_landscape(im.width, im.height):
+ scale_by = settings.crop_height / im.height
+ elif is_portrait(im.width, im.height):
+ scale_by = settings.crop_width / im.width
+ elif is_square(im.width, im.height):
+ if is_square(settings.crop_width, settings.crop_height):
+ scale_by = settings.crop_width / im.width
+ elif is_landscape(settings.crop_width, settings.crop_height):
+ scale_by = settings.crop_width / im.width
+ elif is_portrait(settings.crop_width, settings.crop_height):
+ scale_by = settings.crop_height / im.height
- im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
- im_debug = im.copy()
- focus = focal_point(im_debug, settings)
+ im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
+ im_debug = im.copy()
- # take the focal point and turn it into crop coordinates that try to center over the focal
- # point but then get adjusted back into the frame
- y_half = int(settings.crop_height / 2)
- x_half = int(settings.crop_width / 2)
+ focus = focal_point(im_debug, settings)
- x1 = focus.x - x_half
- if x1 < 0:
- x1 = 0
- elif x1 + settings.crop_width > im.width:
- x1 = im.width - settings.crop_width
+ # take the focal point and turn it into crop coordinates that try to center over the focal
+ # point but then get adjusted back into the frame
+ y_half = int(settings.crop_height / 2)
+ x_half = int(settings.crop_width / 2)
- y1 = focus.y - y_half
- if y1 < 0:
- y1 = 0
- elif y1 + settings.crop_height > im.height:
- y1 = im.height - settings.crop_height
+ x1 = focus.x - x_half
+ if x1 < 0:
+ x1 = 0
+ elif x1 + settings.crop_width > im.width:
+ x1 = im.width - settings.crop_width
- x2 = x1 + settings.crop_width
- y2 = y1 + settings.crop_height
+ y1 = focus.y - y_half
+ if y1 < 0:
+ y1 = 0
+ elif y1 + settings.crop_height > im.height:
+ y1 = im.height - settings.crop_height
- crop = [x1, y1, x2, y2]
+ x2 = x1 + settings.crop_width
+ y2 = y1 + settings.crop_height
- results = []
+ crop = [x1, y1, x2, y2]
- results.append(im.crop(tuple(crop)))
+ results = []
- if settings.annotate_image:
- d = ImageDraw.Draw(im_debug)
- rect = list(crop)
- rect[2] -= 1
- rect[3] -= 1
- d.rectangle(rect, outline=GREEN)
- results.append(im_debug)
- if settings.destop_view_image:
- im_debug.show()
+ results.append(im.crop(tuple(crop)))
- return results
+ if settings.annotate_image:
+ d = ImageDraw.Draw(im_debug)
+ rect = list(crop)
+ rect[2] -= 1
+ rect[3] -= 1
+ d.rectangle(rect, outline=GREEN)
+ results.append(im_debug)
+ if settings.destop_view_image:
+ im_debug.show()
+
+ return results
def focal_point(im, settings):
corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []
@@ -88,7 +87,7 @@ def focal_point(im, settings):
corner_centroid = None
if len(corner_points) > 0:
corner_centroid = centroid(corner_points)
- corner_centroid.weight = settings.corner_points_weight / weight_pref_total
+ corner_centroid.weight = settings.corner_points_weight / weight_pref_total
pois.append(corner_centroid)
entropy_centroid = None
@@ -100,7 +99,7 @@ def focal_point(im, settings):
face_centroid = None
if len(face_points) > 0:
face_centroid = centroid(face_points)
- face_centroid.weight = settings.face_points_weight / weight_pref_total
+ face_centroid.weight = settings.face_points_weight / weight_pref_total
pois.append(face_centroid)
average_point = poi_average(pois, settings)
@@ -134,7 +133,7 @@ def focal_point(im, settings):
d.rectangle(f.bounding(4), outline=color)
d.ellipse(average_point.bounding(max_size), outline=GREEN)
-
+
return average_point
@@ -185,7 +184,7 @@ def image_face_points(im, settings):
try:
faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
- except:
+ except Exception:
continue
if len(faces) > 0:
@@ -262,10 +261,11 @@ def image_entropy(im):
hist = hist[hist > 0]
return -np.log2(hist / hist.sum()).sum()
+
def centroid(pois):
- x = [poi.x for poi in pois]
- y = [poi.y for poi in pois]
- return PointOfInterest(sum(x)/len(pois), sum(y)/len(pois))
+ x = [poi.x for poi in pois]
+ y = [poi.y for poi in pois]
+ return PointOfInterest(sum(x) / len(pois), sum(y) / len(pois))
def poi_average(pois, settings):
@@ -283,59 +283,59 @@ def poi_average(pois, settings):
def is_landscape(w, h):
- return w > h
+ return w > h
def is_portrait(w, h):
- return h > w
+ return h > w
def is_square(w, h):
- return w == h
+ return w == h
def download_and_cache_models(dirname):
- download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
- model_file_name = 'face_detection_yunet.onnx'
+ download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
+ model_file_name = 'face_detection_yunet.onnx'
- if not os.path.exists(dirname):
- os.makedirs(dirname)
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
- cache_file = os.path.join(dirname, model_file_name)
- if not os.path.exists(cache_file):
- print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
- response = requests.get(download_url)
- with open(cache_file, "wb") as f:
- f.write(response.content)
+ cache_file = os.path.join(dirname, model_file_name)
+ if not os.path.exists(cache_file):
+ print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
+ response = requests.get(download_url)
+ with open(cache_file, "wb") as f:
+ f.write(response.content)
- if os.path.exists(cache_file):
- return cache_file
- return None
+ if os.path.exists(cache_file):
+ return cache_file
+ return None
class PointOfInterest:
- def __init__(self, x, y, weight=1.0, size=10):
- self.x = x
- self.y = y
- self.weight = weight
- self.size = size
+ def __init__(self, x, y, weight=1.0, size=10):
+ self.x = x
+ self.y = y
+ self.weight = weight
+ self.size = size
- def bounding(self, size):
- return [
- self.x - size//2,
- self.y - size//2,
- self.x + size//2,
- self.y + size//2
- ]
+ def bounding(self, size):
+ return [
+ self.x - size // 2,
+ self.y - size // 2,
+ self.x + size // 2,
+ self.y + size // 2
+ ]
class Settings:
- def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
- self.crop_width = crop_width
- self.crop_height = crop_height
- self.corner_points_weight = corner_points_weight
- self.entropy_points_weight = entropy_points_weight
- self.face_points_weight = face_points_weight
- self.annotate_image = annotate_image
- self.destop_view_image = False
- self.dnn_model_path = dnn_model_path
+ def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
+ self.crop_width = crop_width
+ self.crop_height = crop_height
+ self.corner_points_weight = corner_points_weight
+ self.entropy_points_weight = entropy_points_weight
+ self.face_points_weight = face_points_weight
+ self.annotate_image = annotate_image
+ self.destop_view_image = False
+ self.dnn_model_path = dnn_model_path
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 41610e03..b9621fc9 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -118,7 +118,7 @@ class PersonalizedBase(Dataset):
weight = torch.ones(latent_sample.shape)
else:
weight = None
-
+
if latent_sampling_method == "random":
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)
else:
@@ -243,4 +243,4 @@ class BatchLoaderRandom(BatchLoader):
return self
def collate_wrapper_random(batch):
- return BatchLoaderRandom(batch)
\ No newline at end of file
+ return BatchLoaderRandom(batch)
diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py
index 5593f88c..5858a55f 100644
--- a/modules/textual_inversion/image_embedding.py
+++ b/modules/textual_inversion/image_embedding.py
@@ -2,10 +2,8 @@ import base64
import json
import numpy as np
import zlib
-from PIL import Image, PngImagePlugin, ImageDraw, ImageFont
-from fonts.ttf import Roboto
+from PIL import Image, ImageDraw, ImageFont
import torch
-from modules.shared import opts
class EmbeddingEncoder(json.JSONEncoder):
@@ -17,7 +15,7 @@ class EmbeddingEncoder(json.JSONEncoder):
class EmbeddingDecoder(json.JSONDecoder):
def __init__(self, *args, **kwargs):
- json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
+ json.JSONDecoder.__init__(self, *args, object_hook=self.object_hook, **kwargs)
def object_hook(self, d):
if 'TORCHTENSOR' in d:
@@ -136,11 +134,8 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
image = srcimage.copy()
fontsize = 32
if textfont is None:
- try:
- textfont = ImageFont.truetype(opts.font or Roboto, fontsize)
- textfont = opts.font or Roboto
- except Exception:
- textfont = Roboto
+ from modules.images import get_font
+ textfont = get_font(fontsize)
factor = 1.5
gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0))
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
index f63fc72f..c56bea45 100644
--- a/modules/textual_inversion/learn_schedule.py
+++ b/modules/textual_inversion/learn_schedule.py
@@ -12,7 +12,7 @@ class LearnScheduleIterator:
self.it = 0
self.maxit = 0
try:
- for i, pair in enumerate(pairs):
+ for pair in pairs:
if not pair.strip():
continue
tmp = pair.split(':')
@@ -32,8 +32,8 @@ class LearnScheduleIterator:
self.maxit += 1
return
assert self.rates
- except (ValueError, AssertionError):
- raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.')
+ except (ValueError, AssertionError) as e:
+ raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.') from e
def __iter__(self):
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index da0bcb26..a009d8e8 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -1,13 +1,9 @@
import os
from PIL import Image, ImageOps
import math
-import platform
-import sys
import tqdm
-import time
from modules import paths, shared, images, deepbooru
-from modules.shared import opts, cmd_opts
from modules.textual_inversion import autocrop
@@ -129,7 +125,7 @@ def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, thr
default=None
)
return wh and center_crop(image, *wh)
-
+
def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
width = process_width
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 4368eb63..d489ed1e 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -1,7 +1,6 @@
import os
import sys
import traceback
-import inspect
from collections import namedtuple
import torch
@@ -30,7 +29,7 @@ textual_inversion_templates = {}
def list_textual_inversion_templates():
textual_inversion_templates.clear()
- for root, dirs, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir):
+ for root, _, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir):
for fn in fns:
path = os.path.join(root, fn)
@@ -167,8 +166,7 @@ class EmbeddingDatabase:
# textual inversion embeddings
if 'string_to_param' in data:
param_dict = data['string_to_param']
- if hasattr(param_dict, '_parameters'):
- param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
+ param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
# diffuser concepts
@@ -199,7 +197,7 @@ class EmbeddingDatabase:
if not os.path.isdir(embdir.path):
return
- for root, dirs, fns in os.walk(embdir.path, followlinks=True):
+ for root, _, fns in os.walk(embdir.path, followlinks=True):
for fn in fns:
try:
fullfn = os.path.join(root, fn)
@@ -216,7 +214,7 @@ class EmbeddingDatabase:
def load_textual_inversion_embeddings(self, force_reload=False):
if not force_reload:
need_reload = False
- for path, embdir in self.embedding_dirs.items():
+ for embdir in self.embedding_dirs.values():
if embdir.has_changed():
need_reload = True
break
@@ -229,7 +227,7 @@ class EmbeddingDatabase:
self.skipped_embeddings.clear()
self.expected_shape = self.get_expected_shape()
- for path, embdir in self.embedding_dirs.items():
+ for embdir in self.embedding_dirs.values():
self.load_from_dir(embdir)
embdir.update()
@@ -325,16 +323,16 @@ def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epo
tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step)
def tensorboard_add_scaler(tensorboard_writer, tag, value, step):
- tensorboard_writer.add_scalar(tag=tag,
+ tensorboard_writer.add_scalar(tag=tag,
scalar_value=value, global_step=step)
def tensorboard_add_image(tensorboard_writer, tag, pil_image, step):
# Convert a pil image to a torch tensor
img_tensor = torch.as_tensor(np.array(pil_image, copy=True))
- img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0],
+ img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0],
len(pil_image.getbands()))
img_tensor = img_tensor.permute((2, 0, 1))
-
+
tensorboard_writer.add_image(tag, img_tensor, global_step=step)
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"):
@@ -404,7 +402,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
if initial_step >= steps:
shared.state.textinfo = "Model has already been trained beyond specified max steps"
return embedding, filename
-
+
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
@@ -414,7 +412,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
old_parallel_processing_allowed = shared.parallel_processing_allowed
-
+
if shared.opts.training_enable_tensorboard:
tensorboard_writer = tensorboard_setup(log_directory)
@@ -441,7 +439,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
optimizer_saved_dict = torch.load(f"{filename}.optim", map_location='cpu')
if embedding.checksum() == optimizer_saved_dict.get('hash', None):
optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
-
+
if optimizer_state_dict is not None:
optimizer.load_state_dict(optimizer_state_dict)
print("Loaded existing optimizer from checkpoint")
@@ -470,7 +468,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
try:
sd_hijack_checkpoint.add()
- for i in range((steps-initial_step) * gradient_step):
+ for _ in range((steps-initial_step) * gradient_step):
if scheduler.finished:
break
if shared.state.interrupted:
@@ -487,7 +485,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
if clip_grad:
clip_grad_sched.step(embedding.step)
-
+
with devices.autocast():
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
if use_weight:
@@ -515,7 +513,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
# go back until we reach gradient accumulation steps
if (j + 1) % gradient_step != 0:
continue
-
+
if clip_grad:
clip_grad(embedding.vec, clip_grad_sched.learn_rate)
@@ -603,7 +601,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
try:
vectorSize = list(data['string_to_param'].values())[0].shape[0]
- except Exception as e:
+ except Exception:
vectorSize = '?'
checkpoint = sd_models.select_checkpoint()
diff --git a/modules/txt2img.py b/modules/txt2img.py
index 16841d0f..f022381c 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -1,18 +1,15 @@
import modules.scripts
-from modules import sd_samplers
+from modules import sd_samplers, processing
from modules.generation_parameters_copypaste import create_override_settings_dict
-from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
- StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, cmd_opts
import modules.shared as shared
-import modules.processing as processing
from modules.ui import plaintext_to_html
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, override_settings_texts, *args):
override_settings = create_override_settings_dict(override_settings_texts)
- p = StableDiffusionProcessingTxt2Img(
+ p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
@@ -53,7 +50,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
processed = modules.scripts.scripts_txt2img.run(p, *args)
if processed is None:
- processed = process_images(p)
+ processed = processing.process_images(p)
p.close()
diff --git a/modules/ui.py b/modules/ui.py
index f07bcc41..3be5257a 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1,29 +1,23 @@
-import html
import json
-import math
import mimetypes
import os
-import platform
-import random
import sys
-import tempfile
-import time
import traceback
-from functools import partial, reduce
+from functools import reduce
import warnings
import gradio as gr
import gradio.routes
import gradio.utils
import numpy as np
-from PIL import Image, PngImagePlugin
+from PIL import Image, PngImagePlugin # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
-from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing, progress
-from modules.ui_components import FormRow, FormColumn, FormGroup, ToolButton, FormHTML
+from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave
+from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path, data_path
-from modules.shared import opts, cmd_opts, restricted_opts
+from modules.shared import opts, cmd_opts
import modules.codeformer_model
import modules.generation_parameters_copypaste as parameters_copypaste
@@ -34,7 +28,6 @@ import modules.shared as shared
import modules.styles
import modules.textual_inversion.ui
from modules import prompt_parser
-from modules.images import save_image
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
from modules.textual_inversion import textual_inversion
@@ -59,7 +52,7 @@ if cmd_opts.ngrok is not None:
ngrok.connect(
cmd_opts.ngrok,
cmd_opts.port if cmd_opts.port is not None else 7860,
- cmd_opts.ngrok_region
+ cmd_opts.ngrok_options
)
@@ -82,6 +75,7 @@ clear_prompt_symbol = '\U0001f5d1\ufe0f' # 🗑️
extra_networks_symbol = '\U0001F3B4' # 🎴
switch_values_symbol = '\U000021C5' # ⇅
restore_progress_symbol = '\U0001F300' # 🌀
+detect_image_size_symbol = '\U0001F4D0' # 📐
def plaintext_to_html(text):
@@ -93,16 +87,6 @@ def send_gradio_gallery_to_image(x):
return None
return image_from_url_text(x[0])
-def visit(x, func, path=""):
- if hasattr(x, 'children'):
- if isinstance(x, gr.Tabs) and x.elem_id is not None:
- # Tabs element can't have a label, have to use elem_id instead
- func(f"{path}/Tabs@{x.elem_id}", x)
- for c in x.children:
- visit(c, func, path)
- elif x.label is not None:
- func(f"{path}/{x.label}", x)
-
def add_style(name: str, prompt: str, negative_prompt: str):
if name is None:
@@ -206,8 +190,8 @@ def create_seed_inputs(target_interface):
seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=f"{target_interface}_seed_resize_from_w")
seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=f"{target_interface}_seed_resize_from_h")
- random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
- random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])
+ random_seed.click(fn=None, _js="function(){setRandomSeed('" + target_interface + "_seed')}", show_progress=False, inputs=[], outputs=[])
+ random_subseed.click(fn=None, _js="function(){setRandomSeed('" + target_interface + "_subseed')}", show_progress=False, inputs=[], outputs=[])
def change_visibility(show):
return {comp: gr_show(show) for comp in seed_extras}
@@ -246,7 +230,7 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
all_seeds = gen_info.get('all_seeds', [-1])
res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
- except json.decoder.JSONDecodeError as e:
+ except json.decoder.JSONDecodeError:
if gen_info_string != '':
print("Error parsing JSON generation info:", file=sys.stderr)
print(gen_info_string, file=sys.stderr)
@@ -423,7 +407,7 @@ def create_sampler_and_steps_selection(choices, tabname):
def ordered_ui_categories():
user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(shared.opts.ui_reorder.split(","))}
- for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] * 2 + 0)):
+ for _, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] * 2 + 0)):
yield category
@@ -591,7 +575,7 @@ def create_ui():
txt2img_prompt.submit(**txt2img_args)
submit.click(**txt2img_args)
- res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)
+ res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", inputs=None, outputs=None, show_progress=False)
restore_progress_button.click(
fn=progress.restore_progress,
@@ -704,19 +688,19 @@ def create_ui():
img2img_selected_tab = gr.State(0)
with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
- init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=480)
+ init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=opts.img2img_editor_height)
add_copy_image_controls('img2img', init_img)
with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
- sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480)
+ sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=opts.img2img_editor_height)
add_copy_image_controls('sketch', sketch)
with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
- init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480)
+ init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=opts.img2img_editor_height)
add_copy_image_controls('inpaint', init_img_with_mask)
with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
- inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480)
+ inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=opts.img2img_editor_height)
inpaint_color_sketch_orig = gr.State(None)
add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
@@ -736,8 +720,8 @@ def create_ui():
with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
hidden = ' Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
gr.HTML(
- f"Process images in a directory on the same machine where the server is running." +
- f" Use an empty output directory to save pictures normally instead of writing to the output directory." +
+ "
Process images in a directory on the same machine where the server is running." +
+ " Use an empty output directory to save pictures normally instead of writing to the output directory." +
f" Add inpaint batch mask directory to enable inpaint batch processing."
f"{hidden}
"
)
@@ -746,7 +730,6 @@ def create_ui():
img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]
- img2img_image_inputs = [init_img, sketch, init_img_with_mask, inpaint_color_sketch]
for i, tab in enumerate(img2img_tabs):
tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab])
@@ -790,6 +773,7 @@ def create_ui():
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
+ detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn")
with gr.Tab(label="Resize by") as tab_scale_by:
scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale")
@@ -967,7 +951,16 @@ def create_ui():
img2img_prompt.submit(**img2img_args)
submit.click(**img2img_args)
- res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)
+
+ res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('img2img')}", inputs=None, outputs=None, show_progress=False)
+
+ detect_image_size_btn.click(
+ fn=lambda w, h, _: (w or gr.update(), h or gr.update()),
+ _js="currentImg2imgSourceResolution",
+ inputs=[dummy_component, dummy_component, dummy_component],
+ outputs=[width, height],
+ show_progress=False,
+ )
restore_progress_button.click(
fn=progress.restore_progress,
@@ -1189,7 +1182,7 @@ def create_ui():
process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight")
process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight")
process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug")
-
+
with gr.Column(visible=False) as process_multicrop_col:
gr.Markdown('Each image is center-cropped with an automatically chosen width and height.')
with gr.Row():
@@ -1201,7 +1194,7 @@ def create_ui():
with gr.Row():
process_multicrop_objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="train_process_multicrop_objective")
process_multicrop_threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="train_process_multicrop_threshold")
-
+
with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
@@ -1230,7 +1223,7 @@ def create_ui():
)
def get_textual_inversion_template_names():
- return sorted([x for x in textual_inversion.textual_inversion_templates])
+ return sorted(textual_inversion.textual_inversion_templates)
with gr.Tab(label="Train", id="train"):
gr.HTML(value="Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]
")
@@ -1238,13 +1231,13 @@ def create_ui():
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
- train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
- create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
+ train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=sorted(shared.hypernetworks))
+ create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted(shared.hypernetworks)}, "refresh_train_hypernetwork_name")
with FormRow():
embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate")
-
+
with FormRow():
clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False)
@@ -1290,8 +1283,8 @@ def create_ui():
with gr.Column(elem_id='ti_gallery_container'):
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
- ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(columns=4)
- ti_progress = gr.HTML(elem_id="ti_progress", value="")
+ gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(columns=4)
+ gr.HTML(elem_id="ti_progress", value="")
ti_outcome = gr.HTML(elem_id="ti_error", value="")
create_embedding.click(
@@ -1479,6 +1472,8 @@ def create_ui():
return res
+ loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file)
+
components = []
component_dict = {}
shared.settings_components = component_dict
@@ -1566,6 +1561,9 @@ def create_ui():
current_row.__exit__()
current_tab.__exit__()
+ with gr.TabItem("Defaults", id="defaults", elem_id="settings_tab_defaults"):
+ loadsave.create_ui()
+
with gr.TabItem("Actions", id="actions", elem_id="settings_tab_actions"):
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
@@ -1578,7 +1576,7 @@ def create_ui():
gr.HTML(shared.html("licenses.html"), elem_id="licenses")
gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
-
+
def unload_sd_weights():
modules.sd_models.unload_model_weights()
@@ -1622,12 +1620,8 @@ def create_ui():
outputs=[]
)
- def request_restart():
- shared.state.interrupt()
- shared.state.need_restart = True
-
restart_gradio.click(
- fn=request_restart,
+ fn=shared.state.request_restart,
_js='restart_reload',
inputs=[],
outputs=[],
@@ -1639,7 +1633,7 @@ def create_ui():
(extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
- (train_interface, "Train", "ti"),
+ (train_interface, "Train", "train"),
]
interfaces += script_callbacks.ui_tabs_callback()
@@ -1654,21 +1648,34 @@ def create_ui():
with gr.Blocks(theme=shared.gradio_theme, analytics_enabled=False, title="Stable Diffusion") as demo:
with gr.Row(elem_id="quicksettings", variant="compact"):
- for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
+ for _i, k, _item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
component = create_setting_component(k, is_quicksettings=True)
component_dict[k] = component
parameters_copypaste.connect_paste_params_buttons()
with gr.Tabs(elem_id="tabs") as tabs:
- for interface, label, ifid in interfaces:
+ tab_order = {k: i for i, k in enumerate(opts.ui_tab_order)}
+ sorted_interfaces = sorted(interfaces, key=lambda x: tab_order.get(x[1], 9999))
+
+ for interface, label, ifid in sorted_interfaces:
if label in shared.opts.hidden_tabs:
continue
with gr.TabItem(label, id=ifid, elem_id=f"tab_{ifid}"):
interface.render()
+ for interface, _label, ifid in interfaces:
+ if ifid in ["extensions", "settings"]:
+ continue
+
+ loadsave.add_block(interface, ifid)
+
+ loadsave.add_component(f"webui/Tabs@{tabs.elem_id}", tabs)
+
+ loadsave.setup_ui()
+
if os.path.exists(os.path.join(script_path, "notification.mp3")):
- audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
+ gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
footer = shared.html("footer.html")
footer = footer.format(versions=versions_html())
@@ -1681,7 +1688,7 @@ def create_ui():
outputs=[text_settings, result],
)
- for i, k, item in quicksettings_list:
+ for _i, k, _item in quicksettings_list:
component = component_dict[k]
info = opts.data_labels[k]
@@ -1755,97 +1762,8 @@ def create_ui():
]
)
- ui_config_file = cmd_opts.ui_config_file
- ui_settings = {}
- settings_count = len(ui_settings)
- error_loading = False
-
- try:
- if os.path.exists(ui_config_file):
- with open(ui_config_file, "r", encoding="utf8") as file:
- ui_settings = json.load(file)
- except Exception:
- error_loading = True
- print("Error loading settings:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
-
- def loadsave(path, x):
- def apply_field(obj, field, condition=None, init_field=None):
- key = f"{path}/{field}"
-
- if getattr(obj, 'custom_script_source', None) is not None:
- key = f"customscript/{obj.custom_script_source}/{key}"
-
- if getattr(obj, 'do_not_save_to_config', False):
- return
-
- saved_value = ui_settings.get(key, None)
- if saved_value is None:
- ui_settings[key] = getattr(obj, field)
- elif condition and not condition(saved_value):
- pass
-
- # this warning is generally not useful;
- # print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
- else:
- setattr(obj, field, saved_value)
- if init_field is not None:
- init_field(saved_value)
-
- if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton] and x.visible:
- apply_field(x, 'visible')
-
- if type(x) == gr.Slider:
- apply_field(x, 'value')
- apply_field(x, 'minimum')
- apply_field(x, 'maximum')
- apply_field(x, 'step')
-
- if type(x) == gr.Radio:
- apply_field(x, 'value', lambda val: val in x.choices)
-
- if type(x) == gr.Checkbox:
- apply_field(x, 'value')
-
- if type(x) == gr.Textbox:
- apply_field(x, 'value')
-
- if type(x) == gr.Number:
- apply_field(x, 'value')
-
- if type(x) == gr.Dropdown:
- def check_dropdown(val):
- if getattr(x, 'multiselect', False):
- return all([value in x.choices for value in val])
- else:
- return val in x.choices
-
- apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
-
- def check_tab_id(tab_id):
- tab_items = list(filter(lambda e: isinstance(e, gr.TabItem), x.children))
- if type(tab_id) == str:
- tab_ids = [t.id for t in tab_items]
- return tab_id in tab_ids
- elif type(tab_id) == int:
- return tab_id >= 0 and tab_id < len(tab_items)
- else:
- return False
-
- if type(x) == gr.Tabs:
- apply_field(x, 'selected', check_tab_id)
-
- visit(txt2img_interface, loadsave, "txt2img")
- visit(img2img_interface, loadsave, "img2img")
- visit(extras_interface, loadsave, "extras")
- visit(modelmerger_interface, loadsave, "modelmerger")
- visit(train_interface, loadsave, "train")
-
- loadsave(f"webui/Tabs@{tabs.elem_id}", tabs)
-
- if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
- with open(ui_config_file, "w", encoding="utf8") as file:
- json.dump(ui_settings, file, indent=4)
+ loadsave.dump_defaults()
+ demo.ui_loadsave = loadsave
# Required as a workaround for change() event not triggering when loading values from ui-config.json
interp_description.value = update_interp_description(interp_method.value)
@@ -1933,15 +1851,15 @@ def versions_html():
return f"""
version: {tag}
- •
+ •
python: {python_version}
- •
+ •
torch: {getattr(torch, '__long_version__',torch.__version__)}
- •
+ •
xformers: {xformers_version}
- •
+ •
gradio: {gr.__version__}
- •
+ •
checkpoint: N/A
"""
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index d9faf85a..4ba3bdd7 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -1,6 +1,7 @@
import json
import os.path
import sys
+import threading
import time
from datetime import datetime
import traceback
@@ -51,9 +52,7 @@ def apply_and_restart(disable_list, update_list, disable_all):
shared.opts.disabled_extensions = disabled
shared.opts.disable_all_extensions = disable_all
shared.opts.save(shared.config_filename)
-
- shared.state.interrupt()
- shared.state.need_restart = True
+ shared.state.request_restart()
def save_config_state(name):
@@ -91,8 +90,7 @@ def restore_config_state(confirmed, config_state_name, restore_type):
if restore_type == "webui" or restore_type == "both":
config_states.restore_webui_config(config_state)
- shared.state.interrupt()
- shared.state.need_restart = True
+ shared.state.request_restart()
return ""
@@ -140,7 +138,9 @@ def extension_table():
Extension
URL
- Version
+ Branch
+ Version
+ Date
Update
@@ -148,6 +148,7 @@ def extension_table():
"""
for ext in extensions.extensions:
+ ext: extensions.Extension
ext.read_info_from_repo()
remote = f"""{html.escape("built-in" if ext.is_builtin else ext.remote or '')} """
@@ -169,7 +170,9 @@ def extension_table():
{html.escape(ext.name)}
{remote}
+ {ext.branch}
{version_link}
+ {time.asctime(time.gmtime(ext.commit_date))}
{ext_status}
"""
@@ -467,7 +470,7 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
{html.escape(description)}Added: {html.escape(added)}
{install_code}
-
+
"""
for tag in [x for x in extension_tags if x not in tags]:
@@ -484,13 +487,20 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
return code, list(tags)
+def preload_extensions_git_metadata():
+ for extension in extensions.extensions:
+ extension.read_info_from_repo()
+
+
def create_ui():
import modules.ui
config_states.list_config_states()
+ threading.Thread(target=preload_extensions_git_metadata).start()
+
with gr.Blocks(analytics_enabled=False) as ui:
- with gr.Tabs(elem_id="tabs_extensions") as tabs:
+ with gr.Tabs(elem_id="tabs_extensions"):
with gr.TabItem("Installed", id="installed"):
with gr.Row(elem_id="extensions_installed_top"):
@@ -508,7 +518,8 @@ def create_ui():
"""
info = gr.HTML(html)
- extensions_table = gr.HTML(lambda: extension_table())
+ extensions_table = gr.HTML('Loading...')
+ ui.load(fn=extension_table, inputs=[], outputs=[extensions_table])
apply.click(
fn=apply_and_restart,
@@ -535,9 +546,9 @@ def create_ui():
hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", ], type="index")
- with gr.Row():
+ with gr.Row():
search_extensions_text = gr.Text(label="Search").style(container=False)
-
+
install_result = gr.HTML()
available_extensions_table = gr.HTML()
@@ -579,9 +590,9 @@ def create_ui():
install_result = gr.HTML(elem_id="extension_install_result")
install_button.click(
- fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]),
+ fn=modules.ui.wrap_gradio_call(lambda *args: [gr.update(), *install_extension_from_url(*args)], extra_outputs=[gr.update(), gr.update()]),
inputs=[install_dirname, install_url, install_branch],
- outputs=[extensions_table, install_result],
+ outputs=[install_url, extensions_table, install_result],
)
with gr.TabItem("Backup/Restore"):
@@ -595,7 +606,8 @@ def create_ui():
config_save_button = gr.Button(value="Save Current Config")
config_states_info = gr.HTML("")
- config_states_table = gr.HTML(lambda: update_config_states_table("Current"))
+ config_states_table = gr.HTML("Loading...")
+ ui.load(fn=update_config_states_table, inputs=[config_states_list], outputs=[config_states_table])
config_save_button.click(fn=save_config_state, inputs=[config_save_name], outputs=[config_states_list, config_states_info])
@@ -608,4 +620,5 @@ def create_ui():
outputs=[config_states_table],
)
+
return ui
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 8c3dea56..8bd0722e 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -1,11 +1,9 @@
-import glob
import os.path
import urllib.parse
from pathlib import Path
-from PIL import PngImagePlugin
from modules import shared
-from modules.images import read_info_from_image
+from modules.images import read_info_from_image, save_image_with_geninfo
import gradio as gr
import json
import html
@@ -27,11 +25,11 @@ def register_page(page):
def fetch_file(filename: str = ""):
from starlette.responses import FileResponse
- if not any([Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs]):
+ if not any(Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs):
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
ext = os.path.splitext(filename)[1].lower()
- if ext not in (".png", ".jpg", ".webp"):
+ if ext not in (".png", ".jpg", ".jpeg", ".webp"):
raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg and webp.")
# would profit from returning 304
@@ -91,7 +89,7 @@ class ExtraNetworksPage:
subdirs = {}
for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
- for root, dirs, files in os.walk(parentdir):
+ for root, dirs, _ in os.walk(parentdir, followlinks=True):
for dirname in dirs:
x = os.path.join(root, dirname)
@@ -106,6 +104,9 @@ class ExtraNetworksPage:
if not is_empty and not subdir.endswith("/"):
subdir = subdir + "/"
+ if ("/." in subdir or subdir.startswith(".")) and not shared.opts.extra_networks_show_hidden_directories:
+ continue
+
subdirs[subdir] = 1
if subdirs:
@@ -148,6 +149,10 @@ class ExtraNetworksPage:
return []
def create_html_for_item(self, item, tabname):
+ """
+ Create HTML for card item in tab tabname; can return empty string if the item is not meant to be shown.
+ """
+
preview = item.get("preview", None)
onclick = item.get("onclick", None)
@@ -170,9 +175,15 @@ class ExtraNetworksPage:
if filename.startswith(absdir):
local_path = filename[len(absdir):]
- # if this is true, the item must not be show in the default view, and must instead only be
+ # if this is true, the item must not be shown in the default view, and must instead only be
# shown when searching for it
- serach_only = "/." in local_path or "\\." in local_path
+ if shared.opts.extra_networks_hidden_models == "Always":
+ search_only = False
+ else:
+ search_only = "/." in local_path or "\\." in local_path
+
+ if search_only and shared.opts.extra_networks_hidden_models == "Never":
+ return ""
args = {
"style": f"'display: none; {height}{width}{background_image}'",
@@ -185,7 +196,7 @@ class ExtraNetworksPage:
"save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
"search_term": item.get("search_term", ""),
"metadata_button": metadata_button,
- "serach_only": " search_only" if serach_only else "",
+ "search_only": " search_only" if search_only else "",
}
return self.card_page.format(**args)
@@ -195,7 +206,7 @@ class ExtraNetworksPage:
Find a preview PNG for a given path (without extension) and call link_preview on it.
"""
- preview_extensions = ["png", "jpg", "webp"]
+ preview_extensions = ["png", "jpg", "jpeg", "webp"]
if shared.opts.samples_format not in preview_extensions:
preview_extensions.append(shared.opts.samples_format)
@@ -263,13 +274,13 @@ def create_ui(container, button, tabname):
ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy())
ui.tabname = tabname
- with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
+ with gr.Tabs(elem_id=tabname+"_extra_tabs"):
for page in ui.stored_extra_pages:
page_id = page.title.lower().replace(" ", "_")
with gr.Tab(page.title, id=page_id):
elem_id = f"{tabname}_{page_id}_cards_html"
- page_elem = gr.HTML('', elem_id=elem_id)
+ page_elem = gr.HTML('Loading...', elem_id=elem_id)
ui.pages.append(page_elem)
page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + json.dumps(tabname) + '); return []}', inputs=[], outputs=[])
@@ -283,13 +294,24 @@ def create_ui(container, button, tabname):
def toggle_visibility(is_visible):
is_visible = not is_visible
- if is_visible and not ui.pages_contents:
+ return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary"))
+
+ def fill_tabs(is_empty):
+ """Creates HTML for extra networks' tabs when the extra networks button is clicked for the first time."""
+
+ if not ui.pages_contents:
refresh()
- return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary")), *ui.pages_contents
+ if is_empty:
+ return True, *ui.pages_contents
+
+ return True, *[gr.update() for _ in ui.pages_contents]
state_visible = gr.State(value=False)
- button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button, *ui.pages])
+ button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button], show_progress=False)
+
+ state_empty = gr.State(value=True)
+ button.click(fn=fill_tabs, inputs=[state_empty], outputs=[state_empty, *ui.pages], show_progress=False)
def refresh():
for pg in ui.stored_extra_pages:
@@ -327,18 +349,13 @@ def setup_ui(ui, gallery):
is_allowed = False
for extra_page in ui.stored_extra_pages:
- if any([path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()]):
+ if any(path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()):
is_allowed = True
break
assert is_allowed, f'writing to {filename} is not allowed'
- if geninfo:
- pnginfo_data = PngImagePlugin.PngInfo()
- pnginfo_data.add_text('parameters', geninfo)
- image.save(filename, pnginfo=pnginfo_data)
- else:
- image.save(filename)
+ save_image_with_geninfo(image, geninfo, filename)
return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
diff --git a/modules/ui_loadsave.py b/modules/ui_loadsave.py
new file mode 100644
index 00000000..728fec9e
--- /dev/null
+++ b/modules/ui_loadsave.py
@@ -0,0 +1,208 @@
+import json
+import os
+
+import gradio as gr
+
+from modules import errors
+from modules.ui_components import ToolButton
+
+
+class UiLoadsave:
+ """allows saving and restorig default values for gradio components"""
+
+ def __init__(self, filename):
+ self.filename = filename
+ self.ui_settings = {}
+ self.component_mapping = {}
+ self.error_loading = False
+ self.finalized_ui = False
+
+ self.ui_defaults_view = None
+ self.ui_defaults_apply = None
+ self.ui_defaults_review = None
+
+ try:
+ if os.path.exists(self.filename):
+ self.ui_settings = self.read_from_file()
+ except Exception as e:
+ self.error_loading = True
+ errors.display(e, "loading settings")
+
+ def add_component(self, path, x):
+ """adds component to the registry of tracked components"""
+
+ assert not self.finalized_ui
+
+ def apply_field(obj, field, condition=None, init_field=None):
+ key = f"{path}/{field}"
+
+ if getattr(obj, 'custom_script_source', None) is not None:
+ key = f"customscript/{obj.custom_script_source}/{key}"
+
+ if getattr(obj, 'do_not_save_to_config', False):
+ return
+
+ saved_value = self.ui_settings.get(key, None)
+ if saved_value is None:
+ self.ui_settings[key] = getattr(obj, field)
+ elif condition and not condition(saved_value):
+ pass
+ else:
+ setattr(obj, field, saved_value)
+ if init_field is not None:
+ init_field(saved_value)
+
+ if field == 'value' and key not in self.component_mapping:
+ self.component_mapping[key] = x
+
+ if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton] and x.visible:
+ apply_field(x, 'visible')
+
+ if type(x) == gr.Slider:
+ apply_field(x, 'value')
+ apply_field(x, 'minimum')
+ apply_field(x, 'maximum')
+ apply_field(x, 'step')
+
+ if type(x) == gr.Radio:
+ apply_field(x, 'value', lambda val: val in x.choices)
+
+ if type(x) == gr.Checkbox:
+ apply_field(x, 'value')
+
+ if type(x) == gr.Textbox:
+ apply_field(x, 'value')
+
+ if type(x) == gr.Number:
+ apply_field(x, 'value')
+
+ if type(x) == gr.Dropdown:
+ def check_dropdown(val):
+ if getattr(x, 'multiselect', False):
+ return all(value in x.choices for value in val)
+ else:
+ return val in x.choices
+
+ apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
+
+ def check_tab_id(tab_id):
+ tab_items = list(filter(lambda e: isinstance(e, gr.TabItem), x.children))
+ if type(tab_id) == str:
+ tab_ids = [t.id for t in tab_items]
+ return tab_id in tab_ids
+ elif type(tab_id) == int:
+ return 0 <= tab_id < len(tab_items)
+ else:
+ return False
+
+ if type(x) == gr.Tabs:
+ apply_field(x, 'selected', check_tab_id)
+
+ def add_block(self, x, path=""):
+ """adds all components inside a gradio block x to the registry of tracked components"""
+
+ if hasattr(x, 'children'):
+ if isinstance(x, gr.Tabs) and x.elem_id is not None:
+ # Tabs element can't have a label, have to use elem_id instead
+ self.add_component(f"{path}/Tabs@{x.elem_id}", x)
+ for c in x.children:
+ self.add_block(c, path)
+ elif x.label is not None:
+ self.add_component(f"{path}/{x.label}", x)
+
+ def read_from_file(self):
+ with open(self.filename, "r", encoding="utf8") as file:
+ return json.load(file)
+
+ def write_to_file(self, current_ui_settings):
+ with open(self.filename, "w", encoding="utf8") as file:
+ json.dump(current_ui_settings, file, indent=4)
+
+ def dump_defaults(self):
+ """saves default values to a file unless tjhe file is present and there was an error loading default values at start"""
+
+ if self.error_loading and os.path.exists(self.filename):
+ return
+
+ self.write_to_file(self.ui_settings)
+
+ def iter_changes(self, current_ui_settings, values):
+ """
+ given a dictionary with defaults from a file and current values from gradio elements, returns
+ an iterator over tuples of values that are not the same between the file and the current;
+ tuple contents are: path, old value, new value
+ """
+
+ for (path, component), new_value in zip(self.component_mapping.items(), values):
+ old_value = current_ui_settings.get(path)
+
+ choices = getattr(component, 'choices', None)
+ if isinstance(new_value, int) and choices:
+ if new_value >= len(choices):
+ continue
+
+ new_value = choices[new_value]
+
+ if new_value == old_value:
+ continue
+
+ if old_value is None and new_value == '' or new_value == []:
+ continue
+
+ yield path, old_value, new_value
+
+ def ui_view(self, *values):
+ text = ["Path Old value New value "]
+
+ for path, old_value, new_value in self.iter_changes(self.read_from_file(), values):
+ if old_value is None:
+ old_value = "None "
+
+ text.append(f"{path} {old_value} {new_value} ")
+
+ if len(text) == 1:
+ text.append("No changes ")
+
+ text.append(" ")
+ return "".join(text)
+
+ def ui_apply(self, *values):
+ num_changed = 0
+
+ current_ui_settings = self.read_from_file()
+
+ for path, _, new_value in self.iter_changes(current_ui_settings.copy(), values):
+ num_changed += 1
+ current_ui_settings[path] = new_value
+
+ if num_changed == 0:
+ return "No changes."
+
+ self.write_to_file(current_ui_settings)
+
+ return f"Wrote {num_changed} changes."
+
+ def create_ui(self):
+ """creates ui elements for editing defaults UI, without adding any logic to them"""
+
+ gr.HTML(
+ f"This page allows you to change default values in UI elements on other tabs. "
+ f"Make your changes, press 'View changes' to review the changed default values, "
+ f"then press 'Apply' to write them to {self.filename}. "
+ f"New defaults will apply after you restart the UI. "
+ )
+
+ with gr.Row():
+ self.ui_defaults_view = gr.Button(value='View changes', elem_id="ui_defaults_view", variant="secondary")
+ self.ui_defaults_apply = gr.Button(value='Apply', elem_id="ui_defaults_apply", variant="primary")
+
+ self.ui_defaults_review = gr.HTML("")
+
+ def setup_ui(self):
+ """adds logic to elements created with create_ui; all add_block class must be made before this"""
+
+ assert not self.finalized_ui
+ self.finalized_ui = True
+
+ self.ui_defaults_view.click(fn=self.ui_view, inputs=list(self.component_mapping.values()), outputs=[self.ui_defaults_review])
+ self.ui_defaults_apply.click(fn=self.ui_apply, inputs=list(self.component_mapping.values()), outputs=[self.ui_defaults_review])
diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py
index f25639e5..c7dc1154 100644
--- a/modules/ui_postprocessing.py
+++ b/modules/ui_postprocessing.py
@@ -1,5 +1,5 @@
import gradio as gr
-from modules import scripts_postprocessing, scripts, shared, gfpgan_model, codeformer_model, ui_common, postprocessing, call_queue
+from modules import scripts, shared, ui_common, postprocessing, call_queue
import modules.generation_parameters_copypaste as parameters_copypaste
diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py
index 46fa9cb0..f05049e1 100644
--- a/modules/ui_tempdir.py
+++ b/modules/ui_tempdir.py
@@ -23,7 +23,7 @@ def register_tmp_file(gradio, filename):
def check_tmp_file(gradio, filename):
if hasattr(gradio, 'temp_file_sets'):
- return any([filename in fileset for fileset in gradio.temp_file_sets])
+ return any(filename in fileset for fileset in gradio.temp_file_sets)
if hasattr(gradio, 'temp_dirs'):
return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs)
@@ -72,7 +72,7 @@ def cleanup_tmpdr():
if temp_dir == "" or not os.path.isdir(temp_dir):
return
- for root, dirs, files in os.walk(temp_dir, topdown=False):
+ for root, _, files in os.walk(temp_dir, topdown=False):
for name in files:
_, extension = os.path.splitext(name)
if extension != ".png":
diff --git a/modules/upscaler.py b/modules/upscaler.py
index e2eaa730..8acb6e96 100644
--- a/modules/upscaler.py
+++ b/modules/upscaler.py
@@ -2,8 +2,6 @@ import os
from abc import abstractmethod
import PIL
-import numpy as np
-import torch
from PIL import Image
import modules.shared
@@ -43,9 +41,9 @@ class Upscaler:
os.makedirs(self.model_path, exist_ok=True)
try:
- import cv2
+ import cv2 # noqa: F401
self.can_tile = True
- except:
+ except Exception:
pass
@abstractmethod
@@ -57,7 +55,7 @@ class Upscaler:
dest_w = int(img.width * scale)
dest_h = int(img.height * scale)
- for i in range(3):
+ for _ in range(3):
shape = (img.width, img.height)
img = self.do_upscale(img, selected_model)
diff --git a/modules/xlmr.py b/modules/xlmr.py
index beab3fdf..a407a3ca 100644
--- a/modules/xlmr.py
+++ b/modules/xlmr.py
@@ -1,4 +1,4 @@
-from transformers import BertPreTrainedModel,BertModel,BertConfig
+from transformers import BertPreTrainedModel, BertConfig
import torch.nn as nn
import torch
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
@@ -28,7 +28,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
config_class = BertSeriesConfig
def __init__(self, config=None, **kargs):
- # modify initialization for autoloading
+ # modify initialization for autoloading
if config is None:
config = XLMRobertaConfig()
config.attention_probs_dropout_prob= 0.1
@@ -74,7 +74,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
text["attention_mask"] = torch.tensor(
text['attention_mask']).to(device)
features = self(**text)
- return features['projection_state']
+ return features['projection_state']
def forward(
self,
@@ -134,4 +134,4 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
base_model_prefix = 'roberta'
- config_class= RobertaSeriesConfig
\ No newline at end of file
+ config_class= RobertaSeriesConfig
diff --git a/package.json b/package.json
new file mode 100644
index 00000000..c0ba4067
--- /dev/null
+++ b/package.json
@@ -0,0 +1,11 @@
+{
+ "name": "stable-diffusion-webui",
+ "version": "0.0.0",
+ "devDependencies": {
+ "eslint": "^8.40.0"
+ },
+ "scripts": {
+ "lint": "eslint .",
+ "fix": "eslint --fix ."
+ }
+}
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 00000000..d4a1bbf4
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,32 @@
+[tool.ruff]
+
+target-version = "py39"
+
+extend-select = [
+ "B",
+ "C",
+ "I",
+ "W",
+]
+
+exclude = [
+ "extensions",
+ "extensions-disabled",
+]
+
+ignore = [
+ "E501", # Line too long
+ "E731", # Do not assign a `lambda` expression, use a `def`
+
+ "I001", # Import block is un-sorted or un-formatted
+ "C901", # Function is too complex
+ "C408", # Rewrite as a literal
+ "W605", # invalid escape sequence, messes with some docstrings
+]
+
+[tool.ruff.per-file-ignores]
+"webui.py" = ["E402"] # Module level import not at top of file
+
+[tool.ruff.flake8-bugbear]
+# Allow default arguments like, e.g., `data: List[str] = fastapi.Query(None)`.
+extend-immutable-calls = ["fastapi.Depends", "fastapi.security.HTTPBasic"]
diff --git a/requirements.txt b/requirements.txt
index 35987a13..302b3dab 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,8 +2,6 @@ astunparse
blendmodes
accelerate
basicsr
-fonts
-font-roboto
gfpgan
gradio==3.29.0
numpy
@@ -31,3 +29,4 @@ torchsde
safetensors
psutil
rich
+tomesd
diff --git a/requirements_versions.txt b/requirements_versions.txt
index 7bce02e5..17ae9484 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -5,19 +5,17 @@ basicsr==1.4.2
gfpgan==1.3.8
gradio==3.29.0
numpy==1.23.5
-Pillow==9.4.0
+Pillow==9.5.0
realesrgan==0.3.0
torch
omegaconf==2.2.3
pytorch_lightning==1.9.4
-scikit-image==0.19.2
-fonts
-font-roboto
+scikit-image==0.20.0
timm==0.6.7
piexif==1.1.3
einops==0.4.1
jsonmerge==1.8.0
-clean-fid==0.1.29
+clean-fid==0.1.35
resize-right==0.0.2
torchdiffeq==0.2.3
kornia==0.6.7
@@ -28,3 +26,4 @@ torchsde==0.2.5
safetensors==0.3.1
httpcore<=0.15
fastapi==0.94.0
+tomesd==0.1.2
diff --git a/script.js b/script.js
index 03afe844..db4d9157 100644
--- a/script.js
+++ b/script.js
@@ -1,66 +1,72 @@
function gradioApp() {
- const elems = document.getElementsByTagName('gradio-app')
- const elem = elems.length == 0 ? document : elems[0]
+ const elems = document.getElementsByTagName('gradio-app');
+ const elem = elems.length == 0 ? document : elems[0];
- if (elem !== document) elem.getElementById = function(id){ return document.getElementById(id) }
- return elem.shadowRoot ? elem.shadowRoot : elem
+ if (elem !== document) {
+ elem.getElementById = function(id) {
+ return document.getElementById(id);
+ };
+ }
+ return elem.shadowRoot ? elem.shadowRoot : elem;
}
function get_uiCurrentTab() {
- return gradioApp().querySelector('#tabs button.selected')
+ return gradioApp().querySelector('#tabs button.selected');
}
function get_uiCurrentTabContent() {
- return gradioApp().querySelector('.tabitem[id^=tab_]:not([style*="display: none"])')
+ return gradioApp().querySelector('.tabitem[id^=tab_]:not([style*="display: none"])');
}
-uiUpdateCallbacks = []
-uiLoadedCallbacks = []
-uiTabChangeCallbacks = []
-optionsChangedCallbacks = []
-let uiCurrentTab = null
+var uiUpdateCallbacks = [];
+var uiLoadedCallbacks = [];
+var uiTabChangeCallbacks = [];
+var optionsChangedCallbacks = [];
+var uiCurrentTab = null;
-function onUiUpdate(callback){
- uiUpdateCallbacks.push(callback)
+function onUiUpdate(callback) {
+ uiUpdateCallbacks.push(callback);
}
-function onUiLoaded(callback){
- uiLoadedCallbacks.push(callback)
+function onUiLoaded(callback) {
+ uiLoadedCallbacks.push(callback);
}
-function onUiTabChange(callback){
- uiTabChangeCallbacks.push(callback)
+function onUiTabChange(callback) {
+ uiTabChangeCallbacks.push(callback);
}
-function onOptionsChanged(callback){
- optionsChangedCallbacks.push(callback)
+function onOptionsChanged(callback) {
+ optionsChangedCallbacks.push(callback);
}
-function runCallback(x, m){
+function runCallback(x, m) {
try {
- x(m)
+ x(m);
} catch (e) {
(console.error || console.log).call(console, e.message, e);
}
}
function executeCallbacks(queue, m) {
- queue.forEach(function(x){runCallback(x, m)})
+ queue.forEach(function(x) {
+ runCallback(x, m);
+ });
}
var executedOnLoaded = false;
document.addEventListener("DOMContentLoaded", function() {
- var mutationObserver = new MutationObserver(function(m){
- if(!executedOnLoaded && gradioApp().querySelector('#txt2img_prompt')){
+ var mutationObserver = new MutationObserver(function(m) {
+ if (!executedOnLoaded && gradioApp().querySelector('#txt2img_prompt')) {
executedOnLoaded = true;
executeCallbacks(uiLoadedCallbacks);
}
executeCallbacks(uiUpdateCallbacks, m);
const newTab = get_uiCurrentTab();
- if ( newTab && ( newTab !== uiCurrentTab ) ) {
+ if (newTab && (newTab !== uiCurrentTab)) {
uiCurrentTab = newTab;
executeCallbacks(uiTabChangeCallbacks);
}
});
- mutationObserver.observe( gradioApp(), { childList:true, subtree:true })
+ mutationObserver.observe(gradioApp(), {childList: true, subtree: true});
});
/**
@@ -69,33 +75,33 @@ document.addEventListener("DOMContentLoaded", function() {
document.addEventListener('keydown', function(e) {
var handled = false;
if (e.key !== undefined) {
- if((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
+ if ((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
} else if (e.keyCode !== undefined) {
- if((e.keyCode == 13 && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
+ if ((e.keyCode == 13 && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
}
if (handled) {
- button = get_uiCurrentTabContent().querySelector('button[id$=_generate]');
+ var button = get_uiCurrentTabContent().querySelector('button[id$=_generate]');
if (button) {
button.click();
}
e.preventDefault();
}
-})
+});
/**
* checks that a UI element is not in another hidden element or tab content
*/
function uiElementIsVisible(el) {
let isVisible = !el.closest('.\\!hidden');
- if ( ! isVisible ) {
+ if (!isVisible) {
return false;
}
- while( isVisible = el.closest('.tabitem')?.style.display !== 'none' ) {
- if ( ! isVisible ) {
+ while ((isVisible = el.closest('.tabitem')?.style.display) !== 'none') {
+ if (!isVisible) {
return false;
- } else if ( el.parentElement ) {
- el = el.parentElement
+ } else if (el.parentElement) {
+ el = el.parentElement;
} else {
break;
}
diff --git a/scripts/custom_code.py b/scripts/custom_code.py
index f36a3675..cc6f0d49 100644
--- a/scripts/custom_code.py
+++ b/scripts/custom_code.py
@@ -4,7 +4,7 @@ import ast
import copy
from modules.processing import Processed
-from modules.shared import opts, cmd_opts, state
+from modules.shared import cmd_opts
def convertExpr2Expression(expr):
diff --git a/scripts/img2imgalt.py b/scripts/img2imgalt.py
index bb00fb3f..1e833fa8 100644
--- a/scripts/img2imgalt.py
+++ b/scripts/img2imgalt.py
@@ -149,9 +149,9 @@ class Script(scripts.Script):
sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=False, elem_id=self.elem_id("sigma_adjustment"))
return [
- info,
+ info,
override_sampler,
- override_prompt, original_prompt, original_negative_prompt,
+ override_prompt, original_prompt, original_negative_prompt,
override_steps, st,
override_strength,
cfg, randomness, sigma_adjustment,
@@ -191,17 +191,17 @@ class Script(scripts.Script):
self.cache = Cached(rec_noise, cfg, st, lat, original_prompt, original_negative_prompt, sigma_adjustment)
rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p)
-
+
combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5)
-
+
sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model)
sigmas = sampler.model_wrap.get_sigmas(p.steps)
-
+
noise_dt = combined_noise - (p.init_latent / sigmas[0])
-
+
p.seed = p.seed + 1
-
+
return sampler.sample_img2img(p, p.init_latent, noise_dt, conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning)
p.sample = sample_extra
diff --git a/scripts/loopback.py b/scripts/loopback.py
index ad6609be..2d5feaf9 100644
--- a/scripts/loopback.py
+++ b/scripts/loopback.py
@@ -14,7 +14,7 @@ class Script(scripts.Script):
def show(self, is_img2img):
return is_img2img
- def ui(self, is_img2img):
+ def ui(self, is_img2img):
loops = gr.Slider(minimum=1, maximum=32, step=1, label='Loops', value=4, elem_id=self.elem_id("loops"))
final_denoising_strength = gr.Slider(minimum=0, maximum=1, step=0.01, label='Final denoising strength', value=0.5, elem_id=self.elem_id("final_denoising_strength"))
denoising_curve = gr.Dropdown(label="Denoising strength curve", choices=["Aggressive", "Linear", "Lazy"], value="Linear")
@@ -104,7 +104,7 @@ class Script(scripts.Script):
p.seed = processed.seed + 1
p.denoising_strength = calculate_denoising_strength(i + 1)
-
+
if state.skipped:
break
@@ -121,7 +121,7 @@ class Script(scripts.Script):
all_images.append(last_image)
p.inpainting_fill = original_inpainting_fill
-
+
if state.interrupted:
break
@@ -132,7 +132,7 @@ class Script(scripts.Script):
if opts.return_grid:
grids.append(grid)
-
+
all_images = grids + all_images
processed = Processed(p, all_images, initial_seed, initial_info)
diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py
index 670bb8ac..665dbe89 100644
--- a/scripts/outpainting_mk_2.py
+++ b/scripts/outpainting_mk_2.py
@@ -7,9 +7,9 @@ import modules.scripts as scripts
import gradio as gr
from PIL import Image, ImageDraw
-from modules import images, processing, devices
+from modules import images
from modules.processing import Processed, process_images
-from modules.shared import opts, cmd_opts, state
+from modules.shared import opts, state
# this function is taken from https://github.com/parlance-zz/g-diffuser-bot
@@ -72,7 +72,7 @@ def get_matched_noise(_np_src_image, np_mask_rgb, noise_q=1, color_variation=0.0
height = _np_src_image.shape[1]
num_channels = _np_src_image.shape[2]
- np_src_image = _np_src_image[:] * (1. - np_mask_rgb)
+ _np_src_image[:] * (1. - np_mask_rgb)
np_mask_grey = (np.sum(np_mask_rgb, axis=2) / 3.)
img_mask = np_mask_grey > 1e-6
ref_mask = np_mask_grey < 1e-3
diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py
index ddcbd2d3..ea0632b6 100644
--- a/scripts/poor_mans_outpainting.py
+++ b/scripts/poor_mans_outpainting.py
@@ -4,9 +4,9 @@ import modules.scripts as scripts
import gradio as gr
from PIL import Image, ImageDraw
-from modules import images, processing, devices
+from modules import images, devices
from modules.processing import Processed, process_images
-from modules.shared import opts, cmd_opts, state
+from modules.shared import opts, state
class Script(scripts.Script):
@@ -19,7 +19,7 @@ class Script(scripts.Script):
def ui(self, is_img2img):
if not is_img2img:
return None
-
+
pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels"))
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=self.elem_id("mask_blur"))
inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=self.elem_id("inpainting_fill"))
diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py
index ef1186ac..edb70ac0 100644
--- a/scripts/postprocessing_upscale.py
+++ b/scripts/postprocessing_upscale.py
@@ -98,13 +98,13 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
assert upscaler2 or (upscaler_2_name is None), f'could not find upscaler named {upscaler_2_name}'
upscaled_image = self.upscale(pp.image, pp.info, upscaler1, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop)
- pp.info[f"Postprocess upscaler"] = upscaler1.name
+ pp.info["Postprocess upscaler"] = upscaler1.name
if upscaler2 and upscaler_2_visibility > 0:
second_upscale = self.upscale(pp.image, pp.info, upscaler2, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop)
upscaled_image = Image.blend(upscaled_image, second_upscale, upscaler_2_visibility)
- pp.info[f"Postprocess upscaler 2"] = upscaler2.name
+ pp.info["Postprocess upscaler 2"] = upscaler2.name
pp.image = upscaled_image
@@ -134,4 +134,4 @@ class ScriptPostprocessingUpscaleSimple(ScriptPostprocessingUpscale):
assert upscaler1, f'could not find upscaler named {upscaler_name}'
pp.image = self.upscale(pp.image, pp.info, upscaler1, 0, upscale_by, 0, 0, False)
- pp.info[f"Postprocess upscaler"] = upscaler1.name
+ pp.info["Postprocess upscaler"] = upscaler1.name
diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py
index e9b11517..88324fe6 100644
--- a/scripts/prompt_matrix.py
+++ b/scripts/prompt_matrix.py
@@ -1,14 +1,11 @@
import math
-from collections import namedtuple
-from copy import copy
-import random
import modules.scripts as scripts
import gradio as gr
from modules import images
-from modules.processing import process_images, Processed
-from modules.shared import opts, cmd_opts, state
+from modules.processing import process_images
+from modules.shared import opts, state
import modules.sd_samplers
@@ -99,7 +96,7 @@ class Script(scripts.Script):
p.prompt_for_display = positive_prompt
processed = process_images(p)
- grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2))
+ grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2))
grid = images.draw_prompt_matrix(grid, processed.images[0].width, processed.images[0].height, prompt_matrix_parts, margin_size)
processed.images.insert(0, grid)
processed.index_of_first_image = 1
diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py
index f168389c..b918a764 100644
--- a/scripts/prompts_from_file.py
+++ b/scripts/prompts_from_file.py
@@ -1,6 +1,4 @@
import copy
-import math
-import os
import random
import sys
import traceback
@@ -11,8 +9,7 @@ import gradio as gr
from modules import sd_samplers
from modules.processing import Processed, process_images
-from PIL import Image
-from modules.shared import opts, cmd_opts, state
+from modules.shared import state
def process_string_tag(tag):
@@ -158,7 +155,7 @@ class Script(scripts.Script):
images = []
all_prompts = []
infotexts = []
- for n, args in enumerate(jobs):
+ for args in jobs:
state.job = f"{state.job_no + 1} out of {state.job_count}"
copy_p = copy.copy(p)
@@ -167,7 +164,7 @@ class Script(scripts.Script):
proc = process_images(copy_p)
images += proc.images
-
+
if checkbox_iterate:
p.seed = p.seed + (p.batch_size * p.n_iter)
all_prompts += proc.all_prompts
diff --git a/scripts/sd_upscale.py b/scripts/sd_upscale.py
index 332d76d9..e614c23b 100644
--- a/scripts/sd_upscale.py
+++ b/scripts/sd_upscale.py
@@ -4,9 +4,9 @@ import modules.scripts as scripts
import gradio as gr
from PIL import Image
-from modules import processing, shared, sd_samplers, images, devices
+from modules import processing, shared, images, devices
from modules.processing import Processed
-from modules.shared import opts, cmd_opts, state
+from modules.shared import opts, state
class Script(scripts.Script):
@@ -16,7 +16,7 @@ class Script(scripts.Script):
def show(self, is_img2img):
return is_img2img
- def ui(self, is_img2img):
+ def ui(self, is_img2img):
info = gr.HTML("Will upscale the image by the selected scale factor; use width and height sliders to set tile size
")
overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, elem_id=self.elem_id("overlap"))
scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0, elem_id=self.elem_id("scale_factor"))
@@ -56,7 +56,7 @@ class Script(scripts.Script):
work = []
- for y, h, row in grid.tiles:
+ for _y, _h, row in grid.tiles:
for tiledata in row:
work.append(tiledata[2])
@@ -85,7 +85,7 @@ class Script(scripts.Script):
work_results += processed.images
image_index = 0
- for y, h, row in grid.tiles:
+ for _y, _h, row in grid.tiles:
for tiledata in row:
tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height))
image_index += 1
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index a725d74a..da820b39 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -10,15 +10,13 @@ import numpy as np
import modules.scripts as scripts
import gradio as gr
-from modules import images, paths, sd_samplers, processing, sd_models, sd_vae
+from modules import images, sd_samplers, processing, sd_models, sd_vae
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
-from modules.shared import opts, cmd_opts, state
+from modules.shared import opts, state
import modules.shared as shared
import modules.sd_samplers
import modules.sd_models
import modules.sd_vae
-import glob
-import os
import re
from modules.ui_components import ToolButton
@@ -86,7 +84,7 @@ def apply_checkpoint(p, x, xs):
info = modules.sd_models.get_closet_checkpoint_match(x)
if info is None:
raise RuntimeError(f"Unknown checkpoint: {x}")
- p.override_settings['sd_model_checkpoint'] = info.hash
+ p.override_settings['sd_model_checkpoint'] = info.name
def confirm_checkpoints(p, xs):
@@ -146,6 +144,11 @@ def apply_face_restore(p, opt, x):
p.restore_faces = is_active
+def apply_override(field):
+ def fun(p, x, xs):
+ p.override_settings[field] = x
+ return fun
+
def format_value_add_label(p, opt, x):
if type(x) == float:
x = round(x, 8)
@@ -226,6 +229,8 @@ axis_options = [
AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)),
AxisOption("UniPC Order", int, apply_uni_pc_order, cost=0.5),
AxisOption("Face restore", str, apply_face_restore, format_value=format_value),
+ AxisOption("Token merging ratio", float, apply_override('token_merging_ratio')),
+ AxisOption("Token merging ratio high-res", float, apply_override('token_merging_ratio_hr')),
]
@@ -316,7 +321,7 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
return Processed(p, [])
z_count = len(zs)
- sub_grids = [None] * z_count
+
for i in range(z_count):
start_index = (i * len(xs) * len(ys)) + i
end_index = start_index + len(xs) * len(ys)
@@ -706,7 +711,7 @@ class Script(scripts.Script):
if not include_sub_grids:
# Done with sub-grids, drop all related information:
- for sg in range(z_count):
+ for _ in range(z_count):
del processed.images[1]
del processed.all_prompts[1]
del processed.all_seeds[1]
diff --git a/style.css b/style.css
index 31b2ed5a..b300dfa1 100644
--- a/style.css
+++ b/style.css
@@ -320,20 +320,14 @@ button.custom-button{
div.dimensions-tools{
min-width: 0 !important;
max-width: fit-content;
- flex-direction: row;
- align-content: center;
+ flex-direction: column;
+ place-content: center;
}
div#extras_scale_to_tab div.form{
flex-direction: row;
}
-#mode_img2img .gradio-image > div.fixed-height, #mode_img2img .gradio-image > div.fixed-height img{
- height: 480px !important;
- max-height: 480px !important;
- min-height: 480px !important;
-}
-
#img2img_sketch, #img2maskimg, #inpaint_sketch {
overflow: overlay !important;
resize: auto;
@@ -363,12 +357,10 @@ div#extras_scale_to_tab div.form{
/* settings */
#quicksettings {
- width: fit-content;
align-items: end;
}
#quicksettings > div, #quicksettings > fieldset{
- max-width: 24em;
min-width: 24em;
padding: 0;
border: none;
@@ -417,6 +409,30 @@ table.settings-value-table td{
max-width: 36em;
}
+.ui-defaults-none{
+ color: #aaa !important;
+}
+
+#settings span{
+ color: var(--body-text-color);
+}
+
+#settings .gradio-textbox, #settings .gradio-slider, #settings .gradio-number, #settings .gradio-dropdown, #settings .gradio-checkboxgroup, #settings .gradio-radio{
+ margin-top: 0.75em;
+}
+
+#settings span .settings-comment {
+ display: inline
+}
+
+.settings-comment a{
+ text-decoration: underline;
+}
+
+.settings-comment .info{
+ opacity: 0.75;
+}
+
/* live preview */
.progressDiv{
position: relative;
diff --git a/test/basic_features/utils_test.py b/test/basic_features/utils_test.py
index 0bfc28a0..d9e46b5e 100644
--- a/test/basic_features/utils_test.py
+++ b/test/basic_features/utils_test.py
@@ -1,62 +1,64 @@
import unittest
import requests
+
class UtilsTests(unittest.TestCase):
- def setUp(self):
- self.url_options = "http://localhost:7860/sdapi/v1/options"
- self.url_cmd_flags = "http://localhost:7860/sdapi/v1/cmd-flags"
- self.url_samplers = "http://localhost:7860/sdapi/v1/samplers"
- self.url_upscalers = "http://localhost:7860/sdapi/v1/upscalers"
- self.url_sd_models = "http://localhost:7860/sdapi/v1/sd-models"
- self.url_hypernetworks = "http://localhost:7860/sdapi/v1/hypernetworks"
- self.url_face_restorers = "http://localhost:7860/sdapi/v1/face-restorers"
- self.url_realesrgan_models = "http://localhost:7860/sdapi/v1/realesrgan-models"
- self.url_prompt_styles = "http://localhost:7860/sdapi/v1/prompt-styles"
- self.url_embeddings = "http://localhost:7860/sdapi/v1/embeddings"
+ def setUp(self):
+ self.url_options = "http://localhost:7860/sdapi/v1/options"
+ self.url_cmd_flags = "http://localhost:7860/sdapi/v1/cmd-flags"
+ self.url_samplers = "http://localhost:7860/sdapi/v1/samplers"
+ self.url_upscalers = "http://localhost:7860/sdapi/v1/upscalers"
+ self.url_sd_models = "http://localhost:7860/sdapi/v1/sd-models"
+ self.url_hypernetworks = "http://localhost:7860/sdapi/v1/hypernetworks"
+ self.url_face_restorers = "http://localhost:7860/sdapi/v1/face-restorers"
+ self.url_realesrgan_models = "http://localhost:7860/sdapi/v1/realesrgan-models"
+ self.url_prompt_styles = "http://localhost:7860/sdapi/v1/prompt-styles"
+ self.url_embeddings = "http://localhost:7860/sdapi/v1/embeddings"
- def test_options_get(self):
- self.assertEqual(requests.get(self.url_options).status_code, 200)
+ def test_options_get(self):
+ self.assertEqual(requests.get(self.url_options).status_code, 200)
- def test_options_write(self):
- response = requests.get(self.url_options)
- self.assertEqual(response.status_code, 200)
+ def test_options_write(self):
+ response = requests.get(self.url_options)
+ self.assertEqual(response.status_code, 200)
- pre_value = response.json()["send_seed"]
+ pre_value = response.json()["send_seed"]
- self.assertEqual(requests.post(self.url_options, json={"send_seed":not pre_value}).status_code, 200)
+ self.assertEqual(requests.post(self.url_options, json={"send_seed": not pre_value}).status_code, 200)
- response = requests.get(self.url_options)
- self.assertEqual(response.status_code, 200)
- self.assertEqual(response.json()["send_seed"], not pre_value)
+ response = requests.get(self.url_options)
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.json()["send_seed"], not pre_value)
- requests.post(self.url_options, json={"send_seed": pre_value})
+ requests.post(self.url_options, json={"send_seed": pre_value})
- def test_cmd_flags(self):
- self.assertEqual(requests.get(self.url_cmd_flags).status_code, 200)
+ def test_cmd_flags(self):
+ self.assertEqual(requests.get(self.url_cmd_flags).status_code, 200)
- def test_samplers(self):
- self.assertEqual(requests.get(self.url_samplers).status_code, 200)
+ def test_samplers(self):
+ self.assertEqual(requests.get(self.url_samplers).status_code, 200)
- def test_upscalers(self):
- self.assertEqual(requests.get(self.url_upscalers).status_code, 200)
+ def test_upscalers(self):
+ self.assertEqual(requests.get(self.url_upscalers).status_code, 200)
- def test_sd_models(self):
- self.assertEqual(requests.get(self.url_sd_models).status_code, 200)
+ def test_sd_models(self):
+ self.assertEqual(requests.get(self.url_sd_models).status_code, 200)
- def test_hypernetworks(self):
- self.assertEqual(requests.get(self.url_hypernetworks).status_code, 200)
+ def test_hypernetworks(self):
+ self.assertEqual(requests.get(self.url_hypernetworks).status_code, 200)
- def test_face_restorers(self):
- self.assertEqual(requests.get(self.url_face_restorers).status_code, 200)
-
- def test_realesrgan_models(self):
- self.assertEqual(requests.get(self.url_realesrgan_models).status_code, 200)
-
- def test_prompt_styles(self):
- self.assertEqual(requests.get(self.url_prompt_styles).status_code, 200)
+ def test_face_restorers(self):
+ self.assertEqual(requests.get(self.url_face_restorers).status_code, 200)
+
+ def test_realesrgan_models(self):
+ self.assertEqual(requests.get(self.url_realesrgan_models).status_code, 200)
+
+ def test_prompt_styles(self):
+ self.assertEqual(requests.get(self.url_prompt_styles).status_code, 200)
+
+ def test_embeddings(self):
+ self.assertEqual(requests.get(self.url_embeddings).status_code, 200)
- def test_embeddings(self):
- self.assertEqual(requests.get(self.url_embeddings).status_code, 200)
if __name__ == "__main__":
unittest.main()
diff --git a/webui-macos-env.sh b/webui-macos-env.sh
index 10ab81c9..6354e73b 100644
--- a/webui-macos-env.sh
+++ b/webui-macos-env.sh
@@ -11,7 +11,7 @@ fi
export install_dir="$HOME"
export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate"
-export TORCH_COMMAND="pip install torch torchvision"
+export TORCH_COMMAND="pip install torch==2.0.1 torchvision==0.15.2"
export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git"
export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71"
export PYTORCH_ENABLE_MPS_FALLBACK=1
diff --git a/webui.py b/webui.py
index e8f0a63d..cebfba96 100644
--- a/webui.py
+++ b/webui.py
@@ -8,7 +8,7 @@ import warnings
import json
from threading import Thread
-from fastapi import FastAPI
+from fastapi import FastAPI, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from packaging import version
@@ -16,12 +16,12 @@ from packaging import version
import logging
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
-from modules import paths, timer, import_hook, errors
+from modules import paths, timer, import_hook, errors # noqa: F401
startup_timer = timer.Timer()
import torch
-import pytorch_lightning # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
+import pytorch_lightning # noqa: F401 # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
@@ -31,19 +31,19 @@ startup_timer.record("import torch")
import gradio
startup_timer.record("import gradio")
-import ldm.modules.encoders.modules
+import ldm.modules.encoders.modules # noqa: F401
startup_timer.record("import ldm")
from modules import extra_networks, ui_extra_networks_checkpoints
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
-from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
+from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock # noqa: F401
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
if ".dev" in torch.__version__ or "+git" in torch.__version__:
torch.__long_version__ = torch.__version__
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
-from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
+from modules import shared, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
import modules.codeformer_model as codeformer
import modules.face_restoration
import modules.gfpgan_model as gfpgan
@@ -144,16 +144,11 @@ Use --skip-version-check commandline argument to disable this check.
""".strip())
-def initialize():
- fix_asyncio_event_loop_policy()
-
- check_versions()
-
- extensions.list_extensions()
- localization.list_localizations(cmd_opts.localizations_dir)
- startup_timer.record("list extensions")
-
+def restore_config_state_file():
config_state_file = shared.opts.restore_config_state_file
+ if config_state_file == "":
+ return
+
shared.opts.restore_config_state_file = ""
shared.opts.save(shared.config_filename)
@@ -166,6 +161,18 @@ def initialize():
elif config_state_file:
print(f"!!! Config state backup not found: {config_state_file}")
+
+def initialize():
+ fix_asyncio_event_loop_policy()
+
+ check_versions()
+
+ extensions.list_extensions()
+ localization.list_localizations(cmd_opts.localizations_dir)
+ startup_timer.record("list extensions")
+
+ restore_config_state_file()
+
if cmd_opts.ui_debug_mode:
shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
modules.scripts.load_scripts()
@@ -185,7 +192,7 @@ def initialize():
startup_timer.record("load scripts")
modelloader.load_upscalers()
- startup_timer.record("load upscalers") #Is this necessary? I don't know.
+ startup_timer.record("load upscalers")
modules.sd_vae.refresh_vae_list()
startup_timer.record("refresh VAE")
@@ -234,7 +241,10 @@ def initialize():
print(f'Interrupted with signal {sig} in {frame}')
os._exit(0)
- signal.signal(signal.SIGINT, sigint_handler)
+ if not os.environ.get("COVERAGE_RUN"):
+ # Don't install the immediate-quit handler when running under coverage,
+ # as then the coverage report won't be generated.
+ signal.signal(signal.SIGINT, sigint_handler)
def setup_middleware(app):
@@ -255,19 +265,6 @@ def create_api(app):
return api
-def wait_on_server(demo=None):
- while 1:
- time.sleep(0.5)
- if shared.state.need_restart:
- shared.state.need_restart = False
- time.sleep(0.5)
- demo.close()
- time.sleep(0.5)
-
- modules.script_callbacks.app_reload_callback()
- break
-
-
def api_only():
initialize()
@@ -280,6 +277,12 @@ def api_only():
print(f"Startup time: {startup_timer.summary()}.")
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
+
+def stop_route(request):
+ shared.state.server_command = "stop"
+ return Response("Stopping.")
+
+
def webui():
launch_api = cmd_opts.api
initialize()
@@ -328,6 +331,9 @@ def webui():
inbrowser=cmd_opts.autolaunch,
prevent_thread_lock=True
)
+ if cmd_opts.add_stop_route:
+ app.add_route("/_stop", stop_route, methods=["POST"])
+
# after initial launch, disable --autolaunch for subsequent restarts
cmd_opts.autolaunch = False
@@ -357,10 +363,29 @@ def webui():
if cmd_opts.subpath:
redirector = FastAPI()
redirector.get("/")
- mounted_app = gradio.mount_gradio_app(redirector, shared.demo, path=f"/{cmd_opts.subpath}")
+ gradio.mount_gradio_app(redirector, shared.demo, path=f"/{cmd_opts.subpath}")
- wait_on_server(shared.demo)
+ try:
+ while True:
+ server_command = shared.state.wait_for_server_command(timeout=5)
+ if server_command:
+ if server_command in ("stop", "restart"):
+ break
+ else:
+ print(f"Unknown server command: {server_command}")
+ except KeyboardInterrupt:
+ print('Caught KeyboardInterrupt, stopping...')
+ server_command = "stop"
+
+ if server_command == "stop":
+ print("Stopping server...")
+ # If we catch a keyboard interrupt, we want to stop the server and exit.
+ shared.demo.close()
+ break
print('Restarting UI...')
+ shared.demo.close()
+ time.sleep(0.5)
+ modules.script_callbacks.app_reload_callback()
startup_timer.reset()
@@ -370,18 +395,7 @@ def webui():
extensions.list_extensions()
startup_timer.record("list extensions")
- config_state_file = shared.opts.restore_config_state_file
- shared.opts.restore_config_state_file = ""
- shared.opts.save(shared.config_filename)
-
- if os.path.isfile(config_state_file):
- print(f"*** About to restore extension state from file: {config_state_file}")
- with open(config_state_file, "r", encoding="utf-8") as f:
- config_state = json.load(f)
- config_states.restore_extension_config(config_state)
- startup_timer.record("restore extension config")
- elif config_state_file:
- print(f"!!! Config state backup not found: {config_state_file}")
+ restore_config_state_file()
localization.list_localizations(cmd_opts.localizations_dir)
diff --git a/webui.sh b/webui.sh
index 113a8c1a..ab52ac3b 100755
--- a/webui.sh
+++ b/webui.sh
@@ -94,6 +94,14 @@ else
printf "\n%s\n" "${delimiter}"
fi
+if [[ $(getconf LONG_BIT) = 32 ]]
+then
+ printf "\n%s\n" "${delimiter}"
+ printf "\e[1m\e[31mERROR: Unsupported Running on a 32bit OS\e[0m"
+ printf "\n%s\n" "${delimiter}"
+ exit 1
+fi
+
if [[ -d .git ]]
then
printf "\n%s\n" "${delimiter}"
@@ -118,9 +126,8 @@ case "$gpu_info" in
esac
if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
then
- # Apparently now this works
export TORCH_COMMAND="pip install torch==2.0.1+rocm5.4.2 torchvision==0.15.2+rocm5.4.2 --index-url https://download.pytorch.org/whl/rocm5.4.2"
-fi
+fi
for preq in "${GIT}" "${python_cmd}"
do