mirror of
https://github.com/openvinotoolkit/stable-diffusion-webui.git
synced 2024-12-14 06:28:12 +03:00
add NV option for Random number generator source setting, which allows to generate same pictures on CPU/AMD/Mac as on NVidia videocards.
This commit is contained in:
parent
ccb9233934
commit
84b6fcd02c
@ -3,7 +3,7 @@ import contextlib
|
||||
from functools import lru_cache
|
||||
|
||||
import torch
|
||||
from modules import errors
|
||||
from modules import errors, rng_philox
|
||||
|
||||
if sys.platform == "darwin":
|
||||
from modules import mac_specific
|
||||
@ -90,23 +90,58 @@ def cond_cast_float(input):
|
||||
return input.float() if unet_needs_upcast else input
|
||||
|
||||
|
||||
nv_rng = None
|
||||
|
||||
|
||||
def randn(seed, shape):
|
||||
from modules.shared import opts
|
||||
|
||||
torch.manual_seed(seed)
|
||||
manual_seed(seed)
|
||||
|
||||
if opts.randn_source == "NV":
|
||||
return torch.asarray(nv_rng.randn(shape), device=device)
|
||||
|
||||
if opts.randn_source == "CPU" or device.type == 'mps':
|
||||
return torch.randn(shape, device=cpu).to(device)
|
||||
|
||||
return torch.randn(shape, device=device)
|
||||
|
||||
|
||||
def randn_like(x):
|
||||
from modules.shared import opts
|
||||
|
||||
if opts.randn_source == "NV":
|
||||
return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)
|
||||
|
||||
if opts.randn_source == "CPU" or x.device.type == 'mps':
|
||||
return torch.randn_like(x, device=cpu).to(x.device)
|
||||
|
||||
return torch.randn_like(x)
|
||||
|
||||
|
||||
def randn_without_seed(shape):
|
||||
from modules.shared import opts
|
||||
|
||||
if opts.randn_source == "NV":
|
||||
return torch.asarray(nv_rng.randn(shape), device=device)
|
||||
|
||||
if opts.randn_source == "CPU" or device.type == 'mps':
|
||||
return torch.randn(shape, device=cpu).to(device)
|
||||
|
||||
return torch.randn(shape, device=device)
|
||||
|
||||
|
||||
def manual_seed(seed):
|
||||
from modules.shared import opts
|
||||
|
||||
if opts.randn_source == "NV":
|
||||
global nv_rng
|
||||
nv_rng = rng_philox.Generator(seed)
|
||||
return
|
||||
|
||||
torch.manual_seed(seed)
|
||||
|
||||
|
||||
def autocast(disable=False):
|
||||
from modules import shared
|
||||
|
||||
|
@ -492,7 +492,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
|
||||
noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
|
||||
|
||||
subnoise = None
|
||||
if subseeds is not None:
|
||||
if subseeds is not None and subseed_strength != 0:
|
||||
subseed = 0 if i >= len(subseeds) else subseeds[i]
|
||||
|
||||
subnoise = devices.randn(subseed, noise_shape)
|
||||
@ -524,7 +524,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
|
||||
cnt = p.sampler.number_of_needed_noises(p)
|
||||
|
||||
if eta_noise_seed_delta > 0:
|
||||
torch.manual_seed(seed + eta_noise_seed_delta)
|
||||
devices.manual_seed(seed + eta_noise_seed_delta)
|
||||
|
||||
for j in range(cnt):
|
||||
sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
|
||||
@ -636,7 +636,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
||||
"Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
|
||||
"Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
|
||||
"Init image hash": getattr(p, 'init_img_hash', None),
|
||||
"RNG": opts.randn_source if opts.randn_source != "GPU" else None,
|
||||
"RNG": opts.randn_source if opts.randn_source != "GPU" and opts.randn_source != "NV" else None,
|
||||
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
|
||||
**p.extra_generation_params,
|
||||
"Version": program_version() if opts.add_version_to_infotext else None,
|
||||
|
100
modules/rng_philox.py
Normal file
100
modules/rng_philox.py
Normal file
@ -0,0 +1,100 @@
|
||||
"""RNG imitiating torch cuda randn on CPU. You are welcome.
|
||||
|
||||
Usage:
|
||||
|
||||
```
|
||||
g = Generator(seed=0)
|
||||
print(g.randn(shape=(3, 4)))
|
||||
```
|
||||
|
||||
Expected output:
|
||||
```
|
||||
[[-0.92466259 -0.42534415 -2.6438457 0.14518388]
|
||||
[-0.12086647 -0.57972564 -0.62285122 -0.32838709]
|
||||
[-1.07454231 -0.36314407 -1.67105067 2.26550497]]
|
||||
```
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
philox_m = [0xD2511F53, 0xCD9E8D57]
|
||||
philox_w = [0x9E3779B9, 0xBB67AE85]
|
||||
|
||||
two_pow32_inv = np.array([2.3283064e-10], dtype=np.float32)
|
||||
two_pow32_inv_2pi = np.array([2.3283064e-10 * 6.2831855], dtype=np.float32)
|
||||
|
||||
|
||||
def uint32(x):
|
||||
"""Converts (N,) np.uint64 array into (2, N) np.unit32 array."""
|
||||
return np.moveaxis(x.view(np.uint32).reshape(-1, 2), 0, 1)
|
||||
|
||||
|
||||
def philox4_round(counter, key):
|
||||
"""A single round of the Philox 4x32 random number generator."""
|
||||
|
||||
v1 = uint32(counter[0].astype(np.uint64) * philox_m[0])
|
||||
v2 = uint32(counter[2].astype(np.uint64) * philox_m[1])
|
||||
|
||||
counter[0] = v2[1] ^ counter[1] ^ key[0]
|
||||
counter[1] = v2[0]
|
||||
counter[2] = v1[1] ^ counter[3] ^ key[1]
|
||||
counter[3] = v1[0]
|
||||
|
||||
|
||||
def philox4_32(counter, key, rounds=10):
|
||||
"""Generates 32-bit random numbers using the Philox 4x32 random number generator.
|
||||
|
||||
Parameters:
|
||||
counter (numpy.ndarray): A 4xN array of 32-bit integers representing the counter values (offset into generation).
|
||||
key (numpy.ndarray): A 2xN array of 32-bit integers representing the key values (seed).
|
||||
rounds (int): The number of rounds to perform.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: A 4xN array of 32-bit integers containing the generated random numbers.
|
||||
"""
|
||||
|
||||
for _ in range(rounds - 1):
|
||||
philox4_round(counter, key)
|
||||
|
||||
key[0] = key[0] + philox_w[0]
|
||||
key[1] = key[1] + philox_w[1]
|
||||
|
||||
philox4_round(counter, key)
|
||||
return counter
|
||||
|
||||
|
||||
def box_muller(x, y):
|
||||
"""Returns just the first out of two numbers generated by Box–Muller transform algorithm."""
|
||||
u = x.astype(np.float32) * two_pow32_inv + two_pow32_inv / 2
|
||||
v = y.astype(np.float32) * two_pow32_inv_2pi + two_pow32_inv_2pi / 2
|
||||
|
||||
s = np.sqrt(-2.0 * np.log(u))
|
||||
|
||||
r1 = s * np.sin(v)
|
||||
return r1.astype(np.float32)
|
||||
|
||||
|
||||
class Generator:
|
||||
"""RNG that produces same outputs as torch.randn(..., device='cuda') on CPU"""
|
||||
|
||||
def __init__(self, seed):
|
||||
self.seed = seed
|
||||
self.offset = 0
|
||||
|
||||
def randn(self, shape):
|
||||
"""Generate a sequence of n standard normal random variables using the Philox 4x32 random number generator and the Box-Muller transform."""
|
||||
|
||||
n = 1
|
||||
for x in shape:
|
||||
n *= x
|
||||
|
||||
counter = np.zeros((4, n), dtype=np.uint32)
|
||||
counter[0] = self.offset
|
||||
counter[2] = np.arange(n, dtype=np.uint32) # up to 2^32 numbers can be generated - if you want more you'd need to spill into counter[3]
|
||||
self.offset += 1
|
||||
|
||||
key = uint32(np.array([[self.seed] * n], dtype=np.uint64))
|
||||
|
||||
g = philox4_32(counter, key)
|
||||
|
||||
return box_muller(g[0], g[1]).reshape(shape) # discard g[2] and g[3]
|
@ -260,10 +260,7 @@ class TorchHijack:
|
||||
if noise.shape == x.shape:
|
||||
return noise
|
||||
|
||||
if opts.randn_source == "CPU" or x.device.type == 'mps':
|
||||
return torch.randn_like(x, device=devices.cpu).to(x.device)
|
||||
else:
|
||||
return torch.randn_like(x)
|
||||
return devices.randn_like(x)
|
||||
|
||||
|
||||
class KDiffusionSampler:
|
||||
|
@ -428,7 +428,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
|
||||
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
|
||||
"auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
|
||||
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"),
|
||||
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
|
||||
|
Loading…
Reference in New Issue
Block a user