mirror of
https://github.com/openvinotoolkit/stable-diffusion-webui.git
synced 2024-12-15 07:03:06 +03:00
torch.empty
can create issues; use torch.zeros
For MPS, using a tensor created with `torch.empty()` can cause `torch.baddbmm()` to include NaNs in the tensor it returns, even though `beta=0`. However, with a tensor of shape [1,1,1], there should be a negligible performance difference between `torch.empty()` and `torch.zeros()` anyway, so it's better to just use `torch.zeros()` for this and avoid unnecessarily creating issues.
This commit is contained in:
parent
87dd685224
commit
2489252099
@ -58,7 +58,7 @@ def _summarize_chunk(
|
|||||||
scale: float,
|
scale: float,
|
||||||
) -> AttnChunk:
|
) -> AttnChunk:
|
||||||
attn_weights = torch.baddbmm(
|
attn_weights = torch.baddbmm(
|
||||||
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
|
torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
|
||||||
query,
|
query,
|
||||||
key.transpose(1,2),
|
key.transpose(1,2),
|
||||||
alpha=scale,
|
alpha=scale,
|
||||||
@ -121,7 +121,7 @@ def _get_attention_scores_no_kv_chunking(
|
|||||||
scale: float,
|
scale: float,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
attn_scores = torch.baddbmm(
|
attn_scores = torch.baddbmm(
|
||||||
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
|
torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
|
||||||
query,
|
query,
|
||||||
key.transpose(1,2),
|
key.transpose(1,2),
|
||||||
alpha=scale,
|
alpha=scale,
|
||||||
|
Loading…
Reference in New Issue
Block a user