Use fixed size for sub-quadratic chunking on MPS

Even if this causes chunks to be much smaller, performance isn't significantly impacted. This will usually reduce memory usage, and it should also help avoid the poor performance that occurs when free memory is low.
brkirch 2023-05-08 18:16:01 -04:00
parent 3163d1269a
commit abfa4ad8bc


@@ -1,6 +1,7 @@
 from __future__ import annotations
 import math
 import psutil
+import platform
 import torch
 from torch import einsum
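The new platform import exists only to distinguish Intel Macs from Apple Silicon in the MPS branch below. As a rough illustration (not part of the commit): on macOS, platform.processor() typically reports 'i386' on Intel-based Macs and 'arm' on Apple Silicon, which is the distinction the second hunk keys off.

import platform

# On macOS, platform.processor() typically returns 'i386' on Intel Macs
# and 'arm' on Apple Silicon; the MPS branch below uses the 'i386' case
# to pick a different fixed chunk threshold.
if platform.processor() == 'i386':
    print("Intel Mac detected")
else:
    print("Assuming Apple Silicon (or a non-Mac processor)")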
@@ -427,7 +428,10 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
     qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
     if chunk_threshold is None:
-        chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
+        if q.device.type == 'mps':
+            chunk_threshold_bytes = 268435456 * (2 if platform.processor() == 'i386' else bytes_per_token)
+        else:
+            chunk_threshold_bytes = int(get_available_vram() * 0.7)
     elif chunk_threshold == 0:
         chunk_threshold_bytes = None
     else:
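
For context, a rough sketch (not from the commit) of what the new MPS branch evaluates to. 268435456 is 256 MiB, and bytes_per_token is presumably the element size of q's dtype (2 for float16, 4 for float32), so Apple Silicon gets a fixed threshold of 512 MiB for fp16 and 1 GiB for fp32, while Intel Macs always get 512 MiB, replacing the previous 90% of available VRAM. The helper name below is hypothetical and only mirrors the expression in the diff.

import platform
import torch

def mps_chunk_threshold_bytes(dtype: torch.dtype) -> int:
    # Mirrors the new branch above: a fixed base of 256 MiB scaled by the
    # element size of the attention tensors, except on Intel Macs ('i386')
    # where the multiplier is pinned to 2 (512 MiB regardless of dtype).
    bytes_per_token = torch.finfo(dtype).bits // 8
    return 268435456 * (2 if platform.processor() == 'i386' else bytes_per_token)

# Example values on Apple Silicon:
#   float16 -> 268435456 * 2 = 536870912 bytes (512 MiB)
#   float32 -> 268435456 * 4 = 1073741824 bytes (1 GiB)
print(mps_chunk_threshold_bytes(torch.float16))
print(mps_chunk_threshold_bytes(torch.float32))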