Option to use CPU for random number generation.

Makes a given manual seed generate the same images across different
platforms, independently of the GPU architecture in use.

Fixes #9613.
This commit is contained in:
Deciare 2023-04-18 23:18:58 -04:00 committed by Deciare
parent 22bcc7be42
commit d40e44ade4
4 changed files with 17 additions and 3 deletions

View File

@ -92,14 +92,18 @@ def cond_cast_float(input):
def randn(seed, shape):
from modules.shared import opts
torch.manual_seed(seed)
if device.type == 'mps':
if opts.use_cpu_randn or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)
return torch.randn(shape, device=device)
def randn_without_seed(shape):
if device.type == 'mps':
from modules.shared import opts
if opts.use_cpu_randn or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)
return torch.randn(shape, device=device)

View File

@ -60,3 +60,12 @@ def store_latent(decoded):
class InterruptedException(BaseException):
pass
if opts.use_cpu_randn:
import torchsde._brownian.brownian_interval
def torchsde_randn(size, dtype, device, seed):
generator = torch.Generator(devices.cpu).manual_seed(int(seed))
return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
torchsde._brownian.brownian_interval._randn = torchsde_randn

View File

@ -190,7 +190,7 @@ class TorchHijack:
if noise.shape == x.shape:
return noise
if x.device.type == 'mps':
if opts.use_cpu_randn or x.device.type == 'mps':
return torch.randn_like(x, device=devices.cpu).to(x.device)
else:
return torch.randn_like(x)

View File

@ -331,6 +331,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
"use_cpu_randn": OptionInfo(False, "Use CPU for random number generation to make manual seeds generate the same image across platforms. This may change existing seeds."),
}))
options_templates.update(options_section(('compatibility', "Compatibility"), {