Skip to content

Commit e2b567d

Browse files
authored
Fix undefined sem_masks error and incorrect proto unwrap (ultralytics#23197)
1 parent eee8259 commit e2b567d

1 file changed

Lines changed: 2 additions & 3 deletions

File tree

ultralytics/utils/loss.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,7 @@ def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -
481481
"""Calculate and return the combined loss for detection and segmentation."""
482482
pred_masks, proto = preds["mask_coefficient"].permute(0, 2, 1).contiguous(), preds["proto"]
483483
loss = torch.zeros(5, device=self.device) # box, seg, cls, dfl
484-
if len(proto) == 2:
484+
if isinstance(proto, tuple) and len(proto) == 2:
485485
proto, pred_semseg = proto
486486
else:
487487
pred_semseg = None
@@ -490,6 +490,7 @@ def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -
490490
loss[0], loss[2], loss[3] = det_loss[0], det_loss[1], det_loss[2]
491491

492492
batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width
493+
sem_masks = batch["sem_masks"].to(self.device) # NxHxW
493494
if fg_mask.sum():
494495
# Masks loss
495496
masks = batch["masks"].to(self.device).float()
@@ -511,7 +512,6 @@ def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -
511512
imgsz,
512513
)
513514
if pred_semseg is not None:
514-
sem_masks = batch["sem_masks"].to(self.device) # NxHxW
515515
mask_zero = sem_masks == 0 # NxHxW
516516
sem_masks = F.one_hot(sem_masks.long(), num_classes=self.nc).permute(0, 3, 1, 2).float() # NxCxHxW
517517
sem_masks[mask_zero.unsqueeze(1).expand_as(sem_masks)] = 0
@@ -522,7 +522,6 @@ def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -
522522
else:
523523
loss[1] += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss
524524
loss[4] += (pred_semseg * 0).sum() + (sem_masks * 0).sum()
525-
526525
loss[1] *= self.hyp.box # seg gain
527526
return loss * batch_size, loss.detach() # loss(box, cls, dfl)
528527

0 commit comments

Comments
 (0)