Skip to content

Commit 2fe53f7

Browse files
committed
update
1 parent 6a55390 commit 2fe53f7

9 files changed

Lines changed: 161 additions & 84 deletions

File tree

cal4od/cal4od_helper.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def intersect(boxes1, boxes2):
268268
import matplotlib.pyplot as plt
269269

270270

271-
def draw_PIL_image(image, boxes, labels, scores, name, no=None, color='greenyellow'):
271+
def draw_PIL_image(image, boxes, labels, scores, name):
272272
if type(image) != PIL.Image.Image:
273273
image = F.to_pil_image(image)
274274
plt.imshow(image)
@@ -279,20 +279,15 @@ def draw_PIL_image(image, boxes, labels, scores, name, no=None, color='greenyell
279279
plt.margins(0, 0)
280280
plt.gca().xaxis.set_major_locator(plt.NullLocator())
281281
plt.gca().yaxis.set_major_locator(plt.NullLocator())
282-
i = 0
283-
if no is not None:
284-
for n in no:
285-
if i >= 1:
286-
color = 'greenyellow'
287-
else:
288-
color = 'red'
289-
i += 1
290-
x, y = boxes[n][0], boxes[n][1]
291-
w, h = boxes[n][2] - boxes[n][0], boxes[n][3] - boxes[n][1]
292-
plt.gca().add_patch(plt.Rectangle((x, y), w, h, fill=False, edgecolor=color, linewidth=2.5))
293-
plt.text(x, y, '{}={}'.format(voc_labels[labels[n]], scores[n]), color='color', verticalalignment='bottom',
294-
fontsize=4)
295-
plt.savefig('vis/{}.jpg'.format(name), dpi=128, bbox_inches='tight', pad_inches=0)
282+
# for i in range(len(boxes)):
283+
# x, y = boxes[i][0], boxes[i][1]
284+
# w, h = boxes[i][2] - boxes[i][0], boxes[i][3] - boxes[i][1]
285+
# plt.gca().add_patch(
286+
# plt.Rectangle((x, y), w, h, fill=False, edgecolor=label_color_map[rev_label_map[labels[i].item()]],
287+
# linewidth=2.5))
288+
# # plt.text(x, y, '{}={}'.format(voc_labels[labels[n]], scores[n]), color='color', verticalalignment='bottom',
289+
# # fontsize=4)
290+
plt.savefig('vis/{}.png'.format(name), dpi=256, bbox_inches='tight', pad_inches=0)
296291
# plt.show()
297292
plt.cla()
298293

cal4od_train.py

Lines changed: 47 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
from detection.frcnn_la import fasterrcnn_resnet50_fpn_feature
3737
from detection.retinanet_cal import retinanet_mobilenet, retinanet_resnet50_fpn_cal
3838

39+
A = 1
40+
3941

4042
def train_one_epoch(task_model, task_optimizer, data_loader, device, cycle, epoch, print_freq):
4143
task_model.train()
@@ -53,6 +55,10 @@ def train_one_epoch(task_model, task_optimizer, data_loader, device, cycle, epoc
5355
for images, targets in metric_logger.log_every(data_loader, print_freq, header):
5456
images = list(image.to(device) for image in images)
5557
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
58+
# global A
59+
# for image, target in zip(images, targets):
60+
# draw_PIL_image(image, target['boxes'], target['labels'], None, "unlabeled_{}".format(A))
61+
# A += 1
5662
task_loss_dict = task_model(images, targets)
5763
task_losses = sum(loss for loss in task_loss_dict.values())
5864
# reduce losses over all GPUs for logging purposes
@@ -203,8 +209,8 @@ def get_uncertainty(task_model, unlabeled_loader, augs, num_cls):
203209
for output, aug_box, aug_image in zip(outputs, aug_boxes, aug_images):
204210
consistency_img = 1.0
205211
mean_img = []
206-
boxes, scores_cls, pm, labels = output['boxes'], output['scores_cls'], output['prob_max'], output[
207-
'labels']
212+
boxes, scores_cls, pm, labels, scores = output['boxes'], output['scores_cls'], output['prob_max'], \
213+
output['labels'], output['scores']
208214
cls_corr = [0] * (num_cls - 1)
209215
for p, l in zip(pm, labels):
210216
# if p.item() > 0.4:
@@ -217,7 +223,7 @@ def get_uncertainty(task_model, unlabeled_loader, augs, num_cls):
217223
continue
218224
j = 0
219225
no = []
220-
for ab, ref_score_cls, ref_pm in zip(aug_box, ref_scores_cls, prob_max):
226+
for ab, ref_score_cls, ref_pm, ref_score in zip(aug_box, ref_scores_cls, prob_max, ref_scores):
221227
width = torch.min(ab[2], boxes[:, 2]) - torch.max(ab[0], boxes[:, 0])
222228
height = torch.min(ab[3], boxes[:, 3]) - torch.max(ab[1], boxes[:, 1])
223229
Aarea = (ab[2] - ab[0]) * (ab[3] - ab[1])
@@ -232,13 +238,13 @@ def get_uncertainty(task_model, unlabeled_loader, augs, num_cls):
232238
js = 0.5 * scipy.stats.entropy(p, m) + 0.5 * scipy.stats.entropy(q, m)
233239
if js < 0:
234240
js = 0
235-
# if 0.6 > torch.max(iou) > 0.4 and 0.4 < 0.5 * (1 - js) * (
236-
# ref_pm + pm[torch.argmax(iou)]) < 0.6 and (ab[2] - ab[0]) > 200 and \
237-
# torch.max(ref_score_cls) < 0.6 and pm[torch.argmax(iou)] < 0.6:
241+
# if 0.7 < torch.max(iou) and 0.4 > 0.5 * (1 - js) * (
242+
# ref_pm + pm[torch.argmax(iou)]) and (ab[2] - ab[0]) > 100 and scores[
243+
# torch.argmax(iou)] > 0.4 and ref_score > 0.4:
238244
# # draw_PIL_image(image, boxes, ref_labels, i, no=[ab], color='greenyellow')
239245
# no = [j, torch.argmax(iou).item()]
240246
# draw_PIL_image_2(aug_image.cpu(), aug_box, boxes, ref_labels, labels, ref_scores, pm,
241-
# 'bad', no=no, color='red')
247+
# 'bad_2', no=no, color='red')
242248
# print(1 / 0)
243249
j += 1
244250
consistency_img = min(consistency_img, torch.abs(
@@ -379,7 +385,7 @@ def main(args):
379385
init_num = 500
380386
budget_num = 500
381387
if 'retina' in args.model:
382-
init_num = 500
388+
init_num = 1000
383389
budget_num = 500
384390
else:
385391
init_num = 5000
@@ -441,14 +447,23 @@ def main(args):
441447
print("Getting stability")
442448
random.shuffle(unlabeled_set)
443449
if 'coco' in args.dataset:
444-
subset = unlabeled_set[:10000]
450+
subset = unlabeled_set[:5000]
445451
else:
446452
subset = unlabeled_set
447453
if args.mutual:
448454
unlabeled_loader = DataLoader(dataset_aug, batch_size=1, sampler=SubsetSequentialSampler(subset),
449455
num_workers=args.workers, pin_memory=True, collate_fn=utils.collate_fn)
450456
uncertainty, _cls_corrs = get_uncertainty(task_model, unlabeled_loader, augs, num_classes)
457+
# labeled_loader = DataLoader(dataset_aug, batch_size=1, sampler=SubsetSequentialSampler(labeled_set),
458+
# num_workers=args.workers, pin_memory=True, collate_fn=utils.collate_fn)
459+
# u, _ = get_uncertainty(task_model, labeled_loader, augs, num_classes)
460+
# with open("vis/cal_labeled_metric_{}_{}_{}.pkl".format(args.model, args.dataset, cycle),
461+
# "wb") as fp: # Pickling
462+
# pickle.dump(u, fp)
451463
arg = np.argsort(np.array(uncertainty))
464+
# with open("vis/cal_unlabeled_metric_{}_{}_{}.pkl".format(args.model, args.dataset, cycle),
465+
# "wb") as fp: # Pickling
466+
# pickle.dump(torch.tensor(uncertainty)[arg][:budget_num].numpy(), fp)
452467
cls_corrs_set = arg[:int(args.mr * budget_num)]
453468
cls_corrs = [_cls_corrs[i] for i in cls_corrs_set]
454469
labeled_loader = DataLoader(dataset_aug, batch_size=1, sampler=SubsetSequentialSampler(labeled_set),
@@ -457,8 +472,6 @@ def main(args):
457472
# Update the labeled dataset and the unlabeled dataset, respectively
458473
tobe_labeled_set = list(torch.tensor(subset)[arg][tobe_labeled_set].numpy())
459474
labeled_set += tobe_labeled_set
460-
with open("vis/cal_{}_{}_{}.txt".format(args.model, args.dataset, cycle), "wb") as fp: # Pickling
461-
pickle.dump(labeled_set, fp)
462475
unlabeled_set = list(set(indices) - set(labeled_set))
463476
else:
464477
unlabeled_loader = DataLoader(dataset_aug, batch_size=1, sampler=SubsetSequentialSampler(subset),
@@ -467,6 +480,7 @@ def main(args):
467480
arg = np.argsort(np.array(uncertainty))
468481
# Update the labeled dataset and the unlabeled dataset, respectively
469482
labeled_set += list(torch.tensor(subset)[arg][:budget_num].numpy())
483+
labeled_set = list(set(labeled_set))
470484
unlabeled_set = list(set(indices) - set(labeled_set))
471485

472486
# Create a new dataloader for the updated labeled dataset
@@ -494,26 +508,35 @@ def main(args):
494508
coco_evaluate(task_model, data_loader_test)
495509
elif 'voc' in args.dataset:
496510
voc_evaluate(task_model, data_loader_test, args.dataset, False, path=args.results_path)
497-
if not args.skip and cycle == 0:
498-
if 'faster' in args.model:
499-
utils.save_on_master({
500-
'model': task_model.state_dict(), 'args': args},
501-
os.path.join(args.first_checkpoint_path, '{}_frcnn_1st.pth'.format(args.dataset)))
502-
elif 'retina' in args.model:
503-
utils.save_on_master({
504-
'model': task_model.state_dict(), 'args': args},
505-
os.path.join(args.first_checkpoint_path, '{}_retinanet_1st.pth'.format(args.dataset)))
511+
# if not args.skip and cycle == 0:
512+
# if 'faster' in args.model:
513+
# utils.save_on_master({
514+
# 'model': task_model.state_dict(), 'args': args},
515+
# os.path.join(args.first_checkpoint_path, '{}_frcnn_1st.pth'.format(args.dataset)))
516+
# elif 'retina' in args.model:
517+
# utils.save_on_master({
518+
# 'model': task_model.state_dict(), 'args': args},
519+
# os.path.join(args.first_checkpoint_path, '{}_retinanet_1st.pth'.format(args.dataset)))
506520
random.shuffle(unlabeled_set)
507521
if 'coco' in args.dataset:
508-
subset = unlabeled_set[:10000]
522+
subset = unlabeled_set[:5000]
509523
else:
510524
subset = unlabeled_set
511525
print("Getting stability")
512526
if args.mutual:
513527
unlabeled_loader = DataLoader(dataset_aug, batch_size=1, sampler=SubsetSequentialSampler(subset),
514528
num_workers=args.workers, pin_memory=True, collate_fn=utils.collate_fn)
515529
uncertainty, _cls_corrs = get_uncertainty(task_model, unlabeled_loader, augs, num_classes)
530+
labeled_loader = DataLoader(dataset_aug, batch_size=1, sampler=SubsetSequentialSampler(labeled_set),
531+
num_workers=args.workers, pin_memory=True, collate_fn=utils.collate_fn)
532+
# u, _ = get_uncertainty(task_model, labeled_loader, augs, num_classes)
533+
# with open("vis/cal_labeled_metric_{}_{}_{}.pkl".format(args.model, args.dataset, cycle),
534+
# "wb") as fp: # Pickling
535+
# pickle.dump(u, fp)
516536
arg = np.argsort(np.array(uncertainty))
537+
# with open("vis/cal_unlabeled_metric_{}_{}_{}.pkl".format(args.model, args.dataset, cycle),
538+
# "wb") as fp: # Pickling
539+
# pickle.dump(torch.tensor(uncertainty)[arg][:budget_num].numpy(), fp)
517540
cls_corrs_set = arg[:int(args.mr * budget_num)]
518541
cls_corrs = [_cls_corrs[i] for i in cls_corrs_set]
519542
labeled_loader = DataLoader(dataset_aug, batch_size=1, sampler=SubsetSequentialSampler(labeled_set),
@@ -522,8 +545,8 @@ def main(args):
522545
# Update the labeled dataset and the unlabeled dataset, respectively
523546
tobe_labeled_set = list(torch.tensor(subset)[arg][tobe_labeled_set].numpy())
524547
labeled_set += tobe_labeled_set
525-
with open("vis/cal_{}_{}_{}.txt".format(args.model, args.dataset, cycle), "wb") as fp: # Pickling
526-
pickle.dump(labeled_set, fp)
548+
# with open("vis/cal_{}_{}_{}.txt".format(args.model, args.dataset, cycle), "wb") as fp: # Pickling
549+
# pickle.dump(labeled_set, fp)
527550
unlabeled_set = list(set(indices) - set(labeled_set))
528551
else:
529552
unlabeled_loader = DataLoader(dataset_aug, batch_size=1, sampler=SubsetSequentialSampler(subset),
@@ -532,6 +555,7 @@ def main(args):
532555
arg = np.argsort(np.array(uncertainty))
533556
# Update the labeled dataset and the unlabeled dataset, respectively
534557
labeled_set += list(torch.tensor(subset)[arg][:budget_num].numpy())
558+
labeled_set = list(set(labeled_set))
535559
unlabeled_set = list(set(indices) - set(labeled_set))
536560
# Create a new dataloader for the updated labeled dataset
537561
train_sampler = SubsetRandomSampler(labeled_set)

detection/retina_ssm.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -522,27 +522,34 @@ def ssm_postprocess_detections(self, head_outputs, anchors, image_shapes):
522522
if torch.max(scores_per_image) < CONF_THRESH:
523523
# print(scores)
524524
al_idx = 1
525+
detections.append({
526+
"boxes": all_boxes,
527+
"labels": all_labels,
528+
"scores": all_scores,
529+
'al': al_idx,
530+
})
525531
continue
526532
for class_index in range(num_classes):
527533
# remove low scoring boxes
528534
inds = torch.gt(scores_per_image[:, class_index], self.score_thresh)
529-
boxes_per_class, scores_per_class, labels_per_class = \
530-
boxes_per_image[inds], scores_per_image[inds, class_index], labels_per_image[inds, class_index]
535+
boxes_per_class, scores_per_class, scores_all_class, labels_per_class = \
536+
boxes_per_image[inds], scores_per_image[inds, class_index], scores_per_image[inds], \
537+
labels_per_image[inds, class_index]
531538
other_outputs_per_class = [(k, v[inds]) for k, v in other_outputs_per_image]
532539

533540
# remove empty boxes
534541
keep = box_ops.remove_small_boxes(boxes_per_class, min_size=1e-2)
535-
boxes_per_class, scores_per_class, labels_per_class = \
536-
boxes_per_class[keep], scores_per_class[keep], labels_per_class[keep]
542+
boxes_per_class, scores_per_class, scores_all_class, labels_per_class = \
543+
boxes_per_class[keep], scores_per_class[keep], scores_all_class[keep], labels_per_class[keep]
537544
other_outputs_per_class = [(k, v[keep]) for k, v in other_outputs_per_class]
538545

539546
# non-maximum suppression, independently done per class
540547
keep = box_ops.nms(boxes_per_class, scores_per_class, self.nms_thresh)
541548

542549
# keep only topk scoring predictions
543550
keep = keep[:self.detections_per_img]
544-
boxes_per_class, scores_per_class, labels_per_class = \
545-
boxes_per_class[keep], scores_per_class[keep], labels_per_class[keep]
551+
boxes_per_class, scores_per_class, scores_all_class, labels_per_class = \
552+
boxes_per_class[keep], scores_per_class[keep], scores_all_class[keep], labels_per_class[keep]
546553
other_outputs_per_class = [(k, v[keep]) for k, v in other_outputs_per_class]
547554

548555
image_boxes.append(boxes_per_class)
@@ -553,19 +560,20 @@ def ssm_postprocess_detections(self, head_outputs, anchors, image_shapes):
553560
if k not in image_other_outputs:
554561
image_other_outputs[k] = []
555562
image_other_outputs[k].append(v)
556-
for i in len(boxes_per_class):
563+
564+
for i in range(len(boxes_per_class)):
557565
all_boxes = torch.cat((all_boxes, boxes_per_class[i].unsqueeze(0)), 0)
558566
all_scores = torch.cat((all_scores, scores_per_class[i].unsqueeze(0)), 0)
559-
all_labels.append(judge_y(scores_per_class[i]))
560-
567+
all_labels.append(judge_y(scores_all_class[i][1:]))
568+
detections.append({
569+
"boxes": all_boxes,
570+
"labels": all_labels,
571+
"scores": all_scores,
572+
'al': al_idx,
573+
})
561574
for k, v in image_other_outputs.items():
562575
detections[-1].update({k: torch.cat(v, dim=0)})
563-
detections.append({
564-
"boxes": all_boxes,
565-
"labels": all_labels,
566-
"scores": all_scores,
567-
'al': al_idx,
568-
})
576+
569577
return detections
570578

571579
def forward(self, images, targets=None):
@@ -646,7 +654,6 @@ def forward(self, images, targets=None):
646654
# print(self.ssm)
647655
if self.ssm:
648656
detections = self.ssm_postprocess_detections(head_outputs, anchors, images.image_sizes)
649-
print(detections)
650657
else:
651658
detections = self.postprocess_detections(head_outputs, anchors, images.image_sizes)
652659
detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)

0 commit comments

Comments
 (0)