Skip to content

Commit 52c5bd3

Browse files
committed
ssm
1 parent a4d24dc commit 52c5bd3

8 files changed

Lines changed: 567 additions & 125 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ data/*
2222
*/**/**/*.pyc
2323
*/**/**/**/*.pyc
2424
*/**/**/**/**/*.pyc
25+
vis/
2526
aten/build/
2627
aten/src/ATen/Config.h
2728
aten/src/ATen/cuda/CUDAConfig.h

detection/frcnn_feature.py

Lines changed: 106 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from torch.jit.annotations import Optional, List, Dict, Tuple
1919
from torchvision.ops import boxes as box_ops
2020
from torchvision.models.utils import load_state_dict_from_url
21+
import torch
2122

2223
model_urls = {
2324
'fasterrcnn_resnet50_fpn_coco':
@@ -63,8 +64,81 @@ def _fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
6364
return classification_loss, box_loss
6465

6566

67+
def judge_y(score):
68+
'''return :
69+
y:np.array len(score)
70+
'''
71+
y = []
72+
for s in score:
73+
if s == 1 or torch.log(s) > torch.log(1 - s):
74+
y.append(1)
75+
else:
76+
y.append(-1)
77+
return y
78+
79+
6680
class RoIHeads(_RoIHeads):
67-
def forward(self, features, proposals, image_shapes, targets=None):
81+
82+
def ssm_postprocess_detections(self, class_logits, box_regression, proposals, image_shapes):
83+
device = class_logits.device
84+
num_classes = class_logits.shape[-1]
85+
86+
boxes_per_image = [len(boxes_in_image) for boxes_in_image in proposals]
87+
pred_boxes = self.box_coder.decode(box_regression, proposals)
88+
89+
pred_scores = F.softmax(class_logits, -1)
90+
91+
# split boxes and scores per image
92+
pred_boxes = pred_boxes.split(boxes_per_image, 0)
93+
pred_scores = pred_scores.split(boxes_per_image, 0)
94+
al_idx = 0
95+
all_boxes = torch.empty([0, 4]).cuda()
96+
all_scores = torch.tensor([]).cuda()
97+
all_labels = []
98+
CONF_THRESH = 0.1 # bigger leads more active learning samples
99+
for boxes, scores, image_shape in zip(pred_boxes, pred_scores, image_shapes):
100+
boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
101+
# create labels for each prediction
102+
labels = torch.arange(num_classes, device=device)
103+
labels = labels.view(1, -1).expand_as(scores)
104+
105+
# remove predictions with the background label
106+
boxes = boxes[:, 1:]
107+
scores = scores[:, 1:]
108+
labels = labels[:, 1:]
109+
if torch.max(scores) < CONF_THRESH:
110+
al_idx = 1
111+
continue
112+
for cls_ind in range(num_classes - 1):
113+
cls_boxes = boxes[:, cls_ind]
114+
cls_scores = scores[:, cls_ind]
115+
cls_labels = labels[:, cls_ind]
116+
# batch everything, by making every class prediction be a separate instance
117+
cls_boxes = cls_boxes.reshape(-1, 4)
118+
cls_scores = cls_scores.flatten()
119+
cls_labels = cls_labels.flatten()
120+
121+
# remove low scoring boxes
122+
123+
# non-maximum suppression, independently done per class
124+
keep = box_ops.batched_nms(cls_boxes, cls_scores, cls_labels, self.nms_thresh)
125+
# keep only topk scoring predictions
126+
keep = keep[:self.detections_per_img]
127+
cls_boxes, cls_scores, cls_labels = cls_boxes[keep], cls_scores[keep], cls_labels[keep]
128+
inds = torch.nonzero(cls_scores > self.score_thresh).squeeze(1)
129+
if len(inds) == 0:
130+
continue
131+
for j in inds:
132+
# boxes, scores, labels = boxes[inds], scores[inds], labels[inds]
133+
134+
all_boxes = torch.cat((all_boxes, cls_boxes[j].unsqueeze(0)), 0)
135+
k = keep[j]
136+
all_scores = torch.cat((all_scores, scores[k].unsqueeze(0)), 0)
137+
all_labels.append(judge_y(scores[k]))
138+
# all_scores = [torch.cat(all_scores, 1)]
139+
return [all_boxes], [all_scores], [all_labels], al_idx
140+
141+
def forward(self, features, proposals, image_shapes, ssm, targets=None):
68142
# type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]], Optional[List[Dict[str, Tensor]]])
69143
"""
70144
Arguments:
@@ -101,106 +175,31 @@ def forward(self, features, proposals, image_shapes, targets=None):
101175
"loss_box_reg": loss_box_reg
102176
}
103177
else:
104-
boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
105-
num_images = len(boxes)
106-
for i in range(num_images):
107-
result.append(
108-
{
109-
"boxes": boxes[i],
110-
"labels": labels[i],
111-
"scores": scores[i],
112-
}
113-
)
114-
115-
if self.has_mask():
116-
mask_proposals = [p["boxes"] for p in result]
117-
if self.training:
118-
assert matched_idxs is not None
119-
# during training, only focus on positive boxes
120-
num_images = len(proposals)
121-
mask_proposals = []
122-
pos_matched_idxs = []
123-
for img_id in range(num_images):
124-
pos = torch.nonzero(labels[img_id] > 0).squeeze(1)
125-
mask_proposals.append(proposals[img_id][pos])
126-
pos_matched_idxs.append(matched_idxs[img_id][pos])
127-
else:
128-
pos_matched_idxs = None
129-
130-
if self.mask_roi_pool is not None:
131-
mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
132-
mask_features = self.mask_head(mask_features)
133-
mask_logits = self.mask_predictor(mask_features)
134-
else:
135-
mask_logits = torch.tensor(0)
136-
raise Exception("Expected mask_roi_pool to be not None")
137-
138-
loss_mask = {}
139-
if self.training:
140-
assert targets is not None
141-
assert pos_matched_idxs is not None
142-
assert mask_logits is not None
143-
144-
gt_masks = [t["masks"] for t in targets]
145-
gt_labels = [t["labels"] for t in targets]
146-
rcnn_loss_mask = maskrcnn_loss(
147-
mask_logits, mask_proposals,
148-
gt_masks, gt_labels, pos_matched_idxs)
149-
loss_mask = {
150-
"loss_mask": rcnn_loss_mask
151-
}
178+
if ssm:
179+
boxes, scores, labels, al = self.ssm_postprocess_detections(class_logits, box_regression,
180+
proposals, image_shapes)
181+
num_images = len(boxes)
182+
for i in range(num_images):
183+
result.append(
184+
{
185+
"boxes": boxes[i],
186+
"labels": labels[i],
187+
"scores": scores[i],
188+
'al': al,
189+
}
190+
)
152191
else:
153-
labels = [r["labels"] for r in result]
154-
masks_probs = maskrcnn_inference(mask_logits, labels)
155-
for mask_prob, r in zip(masks_probs, result):
156-
r["masks"] = mask_prob
157-
158-
losses.update(loss_mask)
159-
160-
# keep none checks in if conditional so torchscript will conditionally
161-
# compile each branch
162-
if self.keypoint_roi_pool is not None and self.keypoint_head is not None \
163-
and self.keypoint_predictor is not None:
164-
keypoint_proposals = [p["boxes"] for p in result]
165-
if self.training:
166-
# during training, only focus on positive boxes
167-
num_images = len(proposals)
168-
keypoint_proposals = []
169-
pos_matched_idxs = []
170-
assert matched_idxs is not None
171-
for img_id in range(num_images):
172-
pos = torch.nonzero(labels[img_id] > 0).squeeze(1)
173-
keypoint_proposals.append(proposals[img_id][pos])
174-
pos_matched_idxs.append(matched_idxs[img_id][pos])
175-
else:
176-
pos_matched_idxs = None
177-
178-
keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
179-
keypoint_features = self.keypoint_head(keypoint_features)
180-
keypoint_logits = self.keypoint_predictor(keypoint_features)
181-
182-
loss_keypoint = {}
183-
if self.training:
184-
assert targets is not None
185-
assert pos_matched_idxs is not None
186-
187-
gt_keypoints = [t["keypoints"] for t in targets]
188-
rcnn_loss_keypoint = keypointrcnn_loss(
189-
keypoint_logits, keypoint_proposals,
190-
gt_keypoints, pos_matched_idxs)
191-
loss_keypoint = {
192-
"loss_keypoint": rcnn_loss_keypoint
193-
}
194-
else:
195-
assert keypoint_logits is not None
196-
assert keypoint_proposals is not None
197-
198-
keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
199-
for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
200-
r["keypoints"] = keypoint_prob
201-
r["keypoints_scores"] = kps
202-
203-
losses.update(loss_keypoint)
192+
boxes, scores, labels = self.postprocess_detections(class_logits, box_regression,
193+
proposals, image_shapes)
194+
num_images = len(boxes)
195+
for i in range(num_images):
196+
result.append(
197+
{
198+
"boxes": boxes[i],
199+
"labels": labels[i],
200+
"scores": scores[i],
201+
}
202+
)
204203

205204
return result, losses
206205

@@ -561,9 +560,7 @@ def __init__(self, backbone, num_classes=None,
561560
box_predictor = FastRCNNPredictor(
562561
representation_size,
563562
num_classes)
564-
565-
roi_heads = RoIHeads(
566-
# Box
563+
roi_heads = RoIHeads( # Box
567564
box_roi_pool, box_head, box_predictor,
568565
box_fg_iou_thresh, box_bg_iou_thresh,
569566
box_batch_size_per_image, box_positive_fraction,
@@ -575,9 +572,12 @@ def __init__(self, backbone, num_classes=None,
575572
if image_std is None:
576573
image_std = [0.229, 0.224, 0.225]
577574
transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
578-
575+
self.ssm = False
579576
super(FasterRCNN, self).__init__(backbone, rpn, roi_heads, transform)
580577

578+
def ssm_mode(self, ssm):
579+
self.ssm = ssm
580+
581581
def forward(self, images, targets=None):
582582
# type: (List[Tensor], Optional[List[Dict[str, Tensor]]])
583583
"""
@@ -603,7 +603,8 @@ def forward(self, images, targets=None):
603603
if isinstance(features, torch.Tensor):
604604
features = OrderedDict([('0', features)])
605605
proposals, proposal_losses = self.rpn(images, features, targets)
606-
detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
606+
detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, self.ssm, targets)
607+
# if not len(detections) == 0:
607608
detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
608609

609610
losses = {}
@@ -627,7 +628,6 @@ def fasterrcnn_resnet50_fpn_feature(pretrained=False, progress=True,
627628
backbone = resnet_fpn_backbone('resnet50', pretrained_backbone)
628629
model = FRCNN_Feature(backbone, num_classes, **kwargs)
629630
if pretrained:
630-
print(model_urls.keys())
631631
state_dict = load_state_dict_from_url(model_urls['fasterrcnn_resnet50_fpn_coco'],
632632
progress=progress)
633633
model.load_state_dict(state_dict)

detection/group_by_aspect_ratio.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,6 @@ def create_aspect_ratio_groups(dataset, k=0):
190190
# count number of elements per group
191191
counts = np.unique(groups, return_counts=True)[1]
192192
fbins = [0] + bins + [np.inf]
193-
print("Using {} as bins for aspect ratio quantization".format(fbins))
194-
print("Count of instances per bin: {}".format(counts))
193+
# print("Using {} as bins for aspect ratio quantization".format(fbins))
194+
# print("Count of instances per bin: {}".format(counts))
195195
return groups

ll4al/main.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,10 @@
5252
T.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
5353
# T.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)) # CIFAR-100
5454
])
55-
56-
cifar10_train = CIFAR10('/home/omnisky/ywp/data', train=True, download=True, transform=train_transform)
57-
cifar10_unlabeled = CIFAR10('/home/omnisky/ywp/data', train=True, download=True, transform=test_transform)
58-
cifar10_test = CIFAR10('/home/omnisky/ywp/data', train=False, download=True, transform=test_transform)
55+
#
56+
# cifar10_train = CIFAR10('/home/omnisky/ywp/data', train=True, download=True, transform=train_transform)
57+
# cifar10_unlabeled = CIFAR10('/home/omnisky/ywp/data', train=True, download=True, transform=test_transform)
58+
# cifar10_test = CIFAR10('/home/omnisky/ywp/data', train=False, download=True, transform=test_transform)
5959

6060

6161
##

ll_train.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -216,15 +216,14 @@ def main(args):
216216

217217
# Update the labeled dataset and the unlabeled dataset, respectively
218218
labeled_set += list(torch.tensor(subset)[arg][int(-0.05 * num_images):].numpy())
219-
unlabeled_set = list(torch.tensor(subset)[arg][:int(-0.05 * num_images)].numpy()) + \
220-
unlabeled_set[int(0.2 * num_images):]
219+
unlabeled_set = list(torch.tensor(subset)[arg][:int(-0.05 * num_images)].numpy()) + unlabeled_set
221220

222221
# Create a new dataloader for the updated labeled dataset
223222
train_sampler = SubsetRandomSampler(labeled_set)
224223

225-
total_time = time.time() - start_time
226-
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
227-
print('Training time {}'.format(total_time_str))
224+
total_time = time.time() - start_time
225+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
226+
print('Training time {}'.format(total_time_str))
228227

229228

230229
if __name__ == "__main__":

random_train.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,6 @@ def main(args):
119119
indices = list(range(num_images))
120120
random.shuffle(indices)
121121
labeled_set = indices[:int(num_images * 0.1)]
122-
print(labeled_set)
123122
unlabeled_set = indices[int(num_images * 0.1):]
124123
train_sampler = SubsetRandomSampler(labeled_set)
125124
test_sampler = torch.utils.data.SequentialSampler(dataset_test)
@@ -158,15 +157,15 @@ def main(args):
158157
train_one_epoch(task_model, task_optimizer, data_loader, device, cycle, epoch, args.print_freq)
159158
task_lr_scheduler.step()
160159
# evaluate after pre-set epoch
161-
if (epoch + 1) == args.task_epochs or (epoch + 1) == args.total_epochs:
160+
if (epoch + 1) == args.total_epochs:
162161
if 'coco' in args.dataset:
163162
coco_evaluate(task_model, data_loader_test)
164163
elif 'voc' in args.dataset:
165164
voc_evaluate(task_model, data_loader_test)
166165
random.shuffle(unlabeled_set)
167166
# Update the labeled dataset and the unlabeled dataset, respectively
168-
labeled_set += unlabeled_set[int(0.05 * num_images):]
169-
unlabeled_set = unlabeled_set[:int(0.05 * num_images)]
167+
labeled_set += unlabeled_set[:int(0.05 * num_images)]
168+
unlabeled_set = unlabeled_set[int(0.05 * num_images):]
170169

171170
# Create a new dataloader for the updated labeled dataset
172171
train_sampler = SubsetRandomSampler(labeled_set)
@@ -188,15 +187,15 @@ def main(args):
188187
parser.add_argument('--device', default='cuda', help='device')
189188
parser.add_argument('-b', '--batch-size', default=2, type=int,
190189
help='images per gpu, the total batch size is $NGPU x batch_size')
191-
parser.add_argument('--task_epochs', default=15, type=int, metavar='N',
190+
parser.add_argument('--task_epochs', default=20, type=int, metavar='N',
192191
help='number of total epochs to run')
193-
parser.add_argument('--total_epochs', default=15, type=int, metavar='N',
192+
parser.add_argument('--total_epochs', default=20, type=int, metavar='N',
194193
help='number of total epochs to run')
195194
parser.add_argument('--cycles', default=7, type=int, metavar='N',
196195
help='number of cycles epochs to run')
197196
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
198197
help='number of data loading workers (default: 4)')
199-
parser.add_argument('--lr', default=0.005, type=float,
198+
parser.add_argument('--lr', default=0.0025, type=float,
200199
help='initial learning rate, 0.02 is the default value for training '
201200
'on 8 gpus and 2 images_per_gpu')
202201
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
@@ -205,7 +204,7 @@ def main(args):
205204
metavar='W', help='weight decay (default: 1e-4)',
206205
dest='weight_decay')
207206
parser.add_argument('--lr-step-size', default=8, type=int, help='decrease lr every step-size epochs')
208-
parser.add_argument('--lr-steps', default=[10], nargs='+', type=int, help='decrease lr every step-size epochs')
207+
parser.add_argument('--lr-steps', default=[16, 19], nargs='+', type=int, help='decrease lr every step-size epochs')
209208
parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
210209
parser.add_argument('--print-freq', default=1000, type=int, help='print frequency')
211210
parser.add_argument('--output-dir', default=None, help='path where to save')

0 commit comments

Comments
 (0)