1818from torch .jit .annotations import Optional , List , Dict , Tuple
1919from torchvision .ops import boxes as box_ops
2020from torchvision .models .utils import load_state_dict_from_url
21+ import torch
2122
2223model_urls = {
2324 'fasterrcnn_resnet50_fpn_coco' :
@@ -63,8 +64,81 @@ def _fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
6364 return classification_loss , box_loss
6465
6566
67+ def judge_y (score ):
68+ '''return :
69+ y:np.array len(score)
70+ '''
71+ y = []
72+ for s in score :
73+ if s == 1 or torch .log (s ) > torch .log (1 - s ):
74+ y .append (1 )
75+ else :
76+ y .append (- 1 )
77+ return y
78+
79+
6680class RoIHeads (_RoIHeads ):
67- def forward (self , features , proposals , image_shapes , targets = None ):
81+
82+ def ssm_postprocess_detections (self , class_logits , box_regression , proposals , image_shapes ):
83+ device = class_logits .device
84+ num_classes = class_logits .shape [- 1 ]
85+
86+ boxes_per_image = [len (boxes_in_image ) for boxes_in_image in proposals ]
87+ pred_boxes = self .box_coder .decode (box_regression , proposals )
88+
89+ pred_scores = F .softmax (class_logits , - 1 )
90+
91+ # split boxes and scores per image
92+ pred_boxes = pred_boxes .split (boxes_per_image , 0 )
93+ pred_scores = pred_scores .split (boxes_per_image , 0 )
94+ al_idx = 0
95+ all_boxes = torch .empty ([0 , 4 ]).cuda ()
96+ all_scores = torch .tensor ([]).cuda ()
97+ all_labels = []
98+ CONF_THRESH = 0.1 # bigger leads more active learning samples
99+ for boxes , scores , image_shape in zip (pred_boxes , pred_scores , image_shapes ):
100+ boxes = box_ops .clip_boxes_to_image (boxes , image_shape )
101+ # create labels for each prediction
102+ labels = torch .arange (num_classes , device = device )
103+ labels = labels .view (1 , - 1 ).expand_as (scores )
104+
105+ # remove predictions with the background label
106+ boxes = boxes [:, 1 :]
107+ scores = scores [:, 1 :]
108+ labels = labels [:, 1 :]
109+ if torch .max (scores ) < CONF_THRESH :
110+ al_idx = 1
111+ continue
112+ for cls_ind in range (num_classes - 1 ):
113+ cls_boxes = boxes [:, cls_ind ]
114+ cls_scores = scores [:, cls_ind ]
115+ cls_labels = labels [:, cls_ind ]
116+ # batch everything, by making every class prediction be a separate instance
117+ cls_boxes = cls_boxes .reshape (- 1 , 4 )
118+ cls_scores = cls_scores .flatten ()
119+ cls_labels = cls_labels .flatten ()
120+
121+ # remove low scoring boxes
122+
123+ # non-maximum suppression, independently done per class
124+ keep = box_ops .batched_nms (cls_boxes , cls_scores , cls_labels , self .nms_thresh )
125+ # keep only topk scoring predictions
126+ keep = keep [:self .detections_per_img ]
127+ cls_boxes , cls_scores , cls_labels = cls_boxes [keep ], cls_scores [keep ], cls_labels [keep ]
128+ inds = torch .nonzero (cls_scores > self .score_thresh ).squeeze (1 )
129+ if len (inds ) == 0 :
130+ continue
131+ for j in inds :
132+ # boxes, scores, labels = boxes[inds], scores[inds], labels[inds]
133+
134+ all_boxes = torch .cat ((all_boxes , cls_boxes [j ].unsqueeze (0 )), 0 )
135+ k = keep [j ]
136+ all_scores = torch .cat ((all_scores , scores [k ].unsqueeze (0 )), 0 )
137+ all_labels .append (judge_y (scores [k ]))
138+ # all_scores = [torch.cat(all_scores, 1)]
139+ return [all_boxes ], [all_scores ], [all_labels ], al_idx
140+
141+ def forward (self , features , proposals , image_shapes , ssm , targets = None ):
68142 # type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]], Optional[List[Dict[str, Tensor]]])
69143 """
70144 Arguments:
@@ -101,106 +175,31 @@ def forward(self, features, proposals, image_shapes, targets=None):
101175 "loss_box_reg" : loss_box_reg
102176 }
103177 else :
104- boxes , scores , labels = self .postprocess_detections (class_logits , box_regression , proposals , image_shapes )
105- num_images = len (boxes )
106- for i in range (num_images ):
107- result .append (
108- {
109- "boxes" : boxes [i ],
110- "labels" : labels [i ],
111- "scores" : scores [i ],
112- }
113- )
114-
115- if self .has_mask ():
116- mask_proposals = [p ["boxes" ] for p in result ]
117- if self .training :
118- assert matched_idxs is not None
119- # during training, only focus on positive boxes
120- num_images = len (proposals )
121- mask_proposals = []
122- pos_matched_idxs = []
123- for img_id in range (num_images ):
124- pos = torch .nonzero (labels [img_id ] > 0 ).squeeze (1 )
125- mask_proposals .append (proposals [img_id ][pos ])
126- pos_matched_idxs .append (matched_idxs [img_id ][pos ])
127- else :
128- pos_matched_idxs = None
129-
130- if self .mask_roi_pool is not None :
131- mask_features = self .mask_roi_pool (features , mask_proposals , image_shapes )
132- mask_features = self .mask_head (mask_features )
133- mask_logits = self .mask_predictor (mask_features )
134- else :
135- mask_logits = torch .tensor (0 )
136- raise Exception ("Expected mask_roi_pool to be not None" )
137-
138- loss_mask = {}
139- if self .training :
140- assert targets is not None
141- assert pos_matched_idxs is not None
142- assert mask_logits is not None
143-
144- gt_masks = [t ["masks" ] for t in targets ]
145- gt_labels = [t ["labels" ] for t in targets ]
146- rcnn_loss_mask = maskrcnn_loss (
147- mask_logits , mask_proposals ,
148- gt_masks , gt_labels , pos_matched_idxs )
149- loss_mask = {
150- "loss_mask" : rcnn_loss_mask
151- }
178+ if ssm :
179+ boxes , scores , labels , al = self .ssm_postprocess_detections (class_logits , box_regression ,
180+ proposals , image_shapes )
181+ num_images = len (boxes )
182+ for i in range (num_images ):
183+ result .append (
184+ {
185+ "boxes" : boxes [i ],
186+ "labels" : labels [i ],
187+ "scores" : scores [i ],
188+ 'al' : al ,
189+ }
190+ )
152191 else :
153- labels = [r ["labels" ] for r in result ]
154- masks_probs = maskrcnn_inference (mask_logits , labels )
155- for mask_prob , r in zip (masks_probs , result ):
156- r ["masks" ] = mask_prob
157-
158- losses .update (loss_mask )
159-
160- # keep none checks in if conditional so torchscript will conditionally
161- # compile each branch
162- if self .keypoint_roi_pool is not None and self .keypoint_head is not None \
163- and self .keypoint_predictor is not None :
164- keypoint_proposals = [p ["boxes" ] for p in result ]
165- if self .training :
166- # during training, only focus on positive boxes
167- num_images = len (proposals )
168- keypoint_proposals = []
169- pos_matched_idxs = []
170- assert matched_idxs is not None
171- for img_id in range (num_images ):
172- pos = torch .nonzero (labels [img_id ] > 0 ).squeeze (1 )
173- keypoint_proposals .append (proposals [img_id ][pos ])
174- pos_matched_idxs .append (matched_idxs [img_id ][pos ])
175- else :
176- pos_matched_idxs = None
177-
178- keypoint_features = self .keypoint_roi_pool (features , keypoint_proposals , image_shapes )
179- keypoint_features = self .keypoint_head (keypoint_features )
180- keypoint_logits = self .keypoint_predictor (keypoint_features )
181-
182- loss_keypoint = {}
183- if self .training :
184- assert targets is not None
185- assert pos_matched_idxs is not None
186-
187- gt_keypoints = [t ["keypoints" ] for t in targets ]
188- rcnn_loss_keypoint = keypointrcnn_loss (
189- keypoint_logits , keypoint_proposals ,
190- gt_keypoints , pos_matched_idxs )
191- loss_keypoint = {
192- "loss_keypoint" : rcnn_loss_keypoint
193- }
194- else :
195- assert keypoint_logits is not None
196- assert keypoint_proposals is not None
197-
198- keypoints_probs , kp_scores = keypointrcnn_inference (keypoint_logits , keypoint_proposals )
199- for keypoint_prob , kps , r in zip (keypoints_probs , kp_scores , result ):
200- r ["keypoints" ] = keypoint_prob
201- r ["keypoints_scores" ] = kps
202-
203- losses .update (loss_keypoint )
192+ boxes , scores , labels = self .postprocess_detections (class_logits , box_regression ,
193+ proposals , image_shapes )
194+ num_images = len (boxes )
195+ for i in range (num_images ):
196+ result .append (
197+ {
198+ "boxes" : boxes [i ],
199+ "labels" : labels [i ],
200+ "scores" : scores [i ],
201+ }
202+ )
204203
205204 return result , losses
206205
@@ -561,9 +560,7 @@ def __init__(self, backbone, num_classes=None,
561560 box_predictor = FastRCNNPredictor (
562561 representation_size ,
563562 num_classes )
564-
565- roi_heads = RoIHeads (
566- # Box
563+ roi_heads = RoIHeads ( # Box
567564 box_roi_pool , box_head , box_predictor ,
568565 box_fg_iou_thresh , box_bg_iou_thresh ,
569566 box_batch_size_per_image , box_positive_fraction ,
@@ -575,9 +572,12 @@ def __init__(self, backbone, num_classes=None,
575572 if image_std is None :
576573 image_std = [0.229 , 0.224 , 0.225 ]
577574 transform = GeneralizedRCNNTransform (min_size , max_size , image_mean , image_std )
578-
575+ self . ssm = False
579576 super (FasterRCNN , self ).__init__ (backbone , rpn , roi_heads , transform )
580577
578+ def ssm_mode (self , ssm ):
579+ self .ssm = ssm
580+
581581 def forward (self , images , targets = None ):
582582 # type: (List[Tensor], Optional[List[Dict[str, Tensor]]])
583583 """
@@ -603,7 +603,8 @@ def forward(self, images, targets=None):
603603 if isinstance (features , torch .Tensor ):
604604 features = OrderedDict ([('0' , features )])
605605 proposals , proposal_losses = self .rpn (images , features , targets )
606- detections , detector_losses = self .roi_heads (features , proposals , images .image_sizes , targets )
606+ detections , detector_losses = self .roi_heads (features , proposals , images .image_sizes , self .ssm , targets )
607+ # if not len(detections) == 0:
607608 detections = self .transform .postprocess (detections , images .image_sizes , original_image_sizes )
608609
609610 losses = {}
@@ -627,7 +628,6 @@ def fasterrcnn_resnet50_fpn_feature(pretrained=False, progress=True,
627628 backbone = resnet_fpn_backbone ('resnet50' , pretrained_backbone )
628629 model = FRCNN_Feature (backbone , num_classes , ** kwargs )
629630 if pretrained :
630- print (model_urls .keys ())
631631 state_dict = load_state_dict_from_url (model_urls ['fasterrcnn_resnet50_fpn_coco' ],
632632 progress = progress )
633633 model .load_state_dict (state_dict )
0 commit comments