Skip to content

Commit 25d9ccc

Browse files
Any-Winter-4079lstein
authored andcommitted
Update model.py
1 parent 9cdf3ac commit 25d9ccc

1 file changed

Lines changed: 14 additions & 10 deletions

File tree

ldm/modules/diffusionmodules/model.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -210,25 +210,29 @@ def forward(self, x):
210210
h_ = torch.zeros_like(k, device=q.device)
211211

212212
device_type = 'mps' if q.device.type == 'mps' else 'cuda'
213-
214-
if device_type == 'mps':
215-
mem_free_total = psutil.virtual_memory().available
216-
else:
213+
if device_type == 'cuda':
217214
stats = torch.cuda.memory_stats(q.device)
218215
mem_active = stats['active_bytes.all.current']
219216
mem_reserved = stats['reserved_bytes.all.current']
220217
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
221218
mem_free_torch = mem_reserved - mem_active
222219
mem_free_total = mem_free_cuda + mem_free_torch
223220

224-
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * 4
225-
mem_required = tensor_size * 2.5
226-
steps = 1
221+
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * 4
222+
mem_required = tensor_size * 2.5
223+
steps = 1
227224

228-
if mem_required > mem_free_total:
229-
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
225+
if mem_required > mem_free_total:
226+
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
227+
228+
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
230229

231-
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
230+
else:
231+
if psutil.virtual_memory().available / (1024**3) < 12:
232+
slice_size = 1
233+
else:
234+
slice_size = min(q.shape[1], math.floor(2**30 / (q.shape[0] * q.shape[1])))
235+
232236
for i in range(0, q.shape[1], slice_size):
233237
end = i + slice_size
234238

0 commit comments

Comments
 (0)