diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 0e810eec..b3e71270 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,6 +1,7 @@ from __future__ import annotations import math import psutil +import platform import torch from torch import einsum @@ -427,7 +428,10 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_ qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens if chunk_threshold is None: - chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7) + if q.device.type == 'mps': + chunk_threshold_bytes = 268435456 * (2 if platform.processor() == 'i386' else bytes_per_token) + else: + chunk_threshold_bytes = int(get_available_vram() * 0.7) elif chunk_threshold == 0: chunk_threshold_bytes = None else: