@@ -481,7 +481,7 @@ def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -
481481 """Calculate and return the combined loss for detection and segmentation."""
482482 pred_masks , proto = preds ["mask_coefficient" ].permute (0 , 2 , 1 ).contiguous (), preds ["proto" ]
483483 loss = torch .zeros (5 , device = self .device ) # box, seg, cls, dfl
484- if len (proto ) == 2 :
484+ if isinstance ( proto , tuple ) and len (proto ) == 2 :
485485 proto , pred_semseg = proto
486486 else :
487487 pred_semseg = None
@@ -490,6 +490,7 @@ def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -
490490 loss [0 ], loss [2 ], loss [3 ] = det_loss [0 ], det_loss [1 ], det_loss [2 ]
491491
492492 batch_size , _ , mask_h , mask_w = proto .shape # batch size, number of masks, mask height, mask width
493+ sem_masks = batch ["sem_masks" ].to (self .device ) # NxHxW
493494 if fg_mask .sum ():
494495 # Masks loss
495496 masks = batch ["masks" ].to (self .device ).float ()
@@ -511,7 +512,6 @@ def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -
511512 imgsz ,
512513 )
513514 if pred_semseg is not None :
514- sem_masks = batch ["sem_masks" ].to (self .device ) # NxHxW
515515 mask_zero = sem_masks == 0 # NxHxW
516516 sem_masks = F .one_hot (sem_masks .long (), num_classes = self .nc ).permute (0 , 3 , 1 , 2 ).float () # NxCxHxW
517517 sem_masks [mask_zero .unsqueeze (1 ).expand_as (sem_masks )] = 0
@@ -522,7 +522,6 @@ def loss(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -
522522 else :
523523 loss [1 ] += (proto * 0 ).sum () + (pred_masks * 0 ).sum () # inf sums may lead to nan loss
524524 loss [4 ] += (pred_semseg * 0 ).sum () + (sem_masks * 0 ).sum ()
525-
526525 loss [1 ] *= self .hyp .box # seg gain
527526 return loss * batch_size , loss .detach () # loss(box, cls, dfl)
528527
0 commit comments