@@ -210,25 +210,29 @@ def forward(self, x):
210210 h_ = torch .zeros_like (k , device = q .device )
211211
212212 device_type = 'mps' if q .device .type == 'mps' else 'cuda'
213-
214- if device_type == 'mps' :
215- mem_free_total = psutil .virtual_memory ().available
216- else :
213+ if device_type == 'cuda' :
217214 stats = torch .cuda .memory_stats (q .device )
218215 mem_active = stats ['active_bytes.all.current' ]
219216 mem_reserved = stats ['reserved_bytes.all.current' ]
220217 mem_free_cuda , _ = torch .cuda .mem_get_info (torch .cuda .current_device ())
221218 mem_free_torch = mem_reserved - mem_active
222219 mem_free_total = mem_free_cuda + mem_free_torch
223220
224- tensor_size = q .shape [0 ] * q .shape [1 ] * k .shape [2 ] * 4
225- mem_required = tensor_size * 2.5
226- steps = 1
221+ tensor_size = q .shape [0 ] * q .shape [1 ] * k .shape [2 ] * 4
222+ mem_required = tensor_size * 2.5
223+ steps = 1
227224
228- if mem_required > mem_free_total :
229- steps = 2 ** (math .ceil (math .log (mem_required / mem_free_total , 2 )))
225+ if mem_required > mem_free_total :
226+ steps = 2 ** (math .ceil (math .log (mem_required / mem_free_total , 2 )))
227+
228+ slice_size = q .shape [1 ] // steps if (q .shape [1 ] % steps ) == 0 else q .shape [1 ]
230229
231- slice_size = q .shape [1 ] // steps if (q .shape [1 ] % steps ) == 0 else q .shape [1 ]
230+ else :
231+ if psutil .virtual_memory ().available / (1024 ** 3 ) < 12 :
232+ slice_size = 1
233+ else :
234+ slice_size = min (q .shape [1 ], math .floor (2 ** 30 / (q .shape [0 ] * q .shape [1 ])))
235+
232236 for i in range (0 , q .shape [1 ], slice_size ):
233237 end = i + slice_size
234238
0 commit comments