Skip to content
Snippets Groups Projects
Commit 0946565d authored by Ross Girshick's avatar Ross Girshick
Browse files

allow for training and testing without bbox regression (bonus: code cleanup)

parent f422e466
No related branches found
No related tags found
No related merge requests found
......@@ -10,8 +10,6 @@ input: "rois"
input_shape {
dim: 1 # to be changed on-the-fly to num ROIs
dim: 5 # [batch ind, x1, y1, x2, y2] zero-based indexing
dim: 1
dim: 1
}
layer {
name: "conv1"
......
......@@ -10,8 +10,6 @@ input: "rois"
input_shape {
dim: 1 # to be changed on-the-fly to num ROIs
dim: 5 # [batch ind, x1, y1, x2, y2] zero-based indexing
dim: 1
dim: 1
}
layer {
name: "conv1"
......
......@@ -10,29 +10,20 @@ input: "rois"
input_shape {
dim: 1 # to be changed on-the-fly to num ROIs
dim: 5 # [batch ind, x1, y1, x2, y2] zero-based indexing
dim: 1
dim: 1
}
input: "labels"
input_shape {
dim: 1 # to be changed on-the-fly to match num ROIs
dim: 1
dim: 1
dim: 1
}
input: "bbox_targets"
input_shape {
dim: 1 # to be changed on-the-fly to match num ROIs
dim: 84 # 4 * K (=21) classes
dim: 1
dim: 1
}
input: "bbox_loss_weights"
input_shape {
dim: 1 # to be changed on-the-fly to match num ROIs
dim: 84 # 4 * K (=21) classes
dim: 1
dim: 1
}
layer {
name: "conv1"
......
......@@ -12,8 +12,6 @@ input: "rois"
input_shape {
dim: 1 # to be changed on-the-fly to num ROIs
dim: 5 # [batch ind, x1, y1, x2, y2] zero-based indexing
dim: 1
dim: 1
}
layer {
......
......@@ -12,8 +12,6 @@ input: "rois"
input_shape {
dim: 1 # to be changed on-the-fly to num ROIs
dim: 5 # [batch ind, x1, y1, x2, y2] zero-based indexing
dim: 1
dim: 1
}
layer {
......
......@@ -12,32 +12,23 @@ input: "rois"
input_shape {
dim: 1 # to be changed on-the-fly to num ROIs
dim: 5 # [batch ind, x1, y1, x2, y2] zero-based indexing
dim: 1
dim: 1
}
input: "labels"
input_shape {
dim: 1 # to be changed on-the-fly to match num ROIs
dim: 1
dim: 1
dim: 1
}
input: "bbox_targets"
input_shape {
dim: 1 # to be changed on-the-fly to match num ROIs
dim: 84 # 4 * K (=21) classes
dim: 1
dim: 1
}
input: "bbox_loss_weights"
input_shape {
dim: 1 # to be changed on-the-fly to match num ROIs
dim: 84 # 4 * K (=21) classes
dim: 1
dim: 1
}
layer {
......
......@@ -10,8 +10,6 @@ input: "rois"
input_shape {
dim: 1 # to be changed on-the-fly to num ROIs
dim: 5 # [batch ind, x1, y1, x2, y2] zero-based indexing
dim: 1
dim: 1
}
layer {
name: "conv1"
......
......@@ -10,8 +10,6 @@ input: "rois"
input_shape {
dim: 1 # to be changed on-the-fly to num ROIs
dim: 5 # [batch ind, x1, y1, x2, y2] zero-based indexing
dim: 1
dim: 1
}
layer {
name: "conv1"
......
......@@ -10,29 +10,20 @@ input: "rois"
input_shape {
dim: 1 # to be changed on-the-fly to num ROIs
dim: 5 # [batch ind, x1, y1, x2, y2] zero-based indexing
dim: 1
dim: 1
}
input: "labels"
input_shape {
dim: 1 # to be changed on-the-fly to match num ROIs
dim: 1
dim: 1
dim: 1
}
input: "bbox_targets"
input_shape {
dim: 1 # to be changed on-the-fly to match num ROIs
dim: 84 # 4 * K (=21) classes
dim: 1
dim: 1
}
input: "bbox_loss_weights"
input_shape {
dim: 1 # to be changed on-the-fly to match num ROIs
dim: 84 # 4 * K (=21) classes
dim: 1
dim: 1
}
layer {
name: "conv1"
......
......@@ -66,6 +66,9 @@ __C.TRAIN.BG_THRESH_LO = 0.1
# Use horizontally-flipped images during training?
__C.TRAIN.USE_FLIPPED = True
# Train bounding-box regressors
__C.TRAIN.BBOX_REG = True
# Overlap required between a ROI and ground-truth box in order for that ROI to
# be used as a bounding-box regression training example
__C.TRAIN.BBOX_THRESH = 0.5
......@@ -98,6 +101,9 @@ __C.TEST.NMS = 0.3
# scores when testing
__C.TEST.BINARY = False
# Test using bounding-box regressors
__C.TEST.BBOX_REG = True
#
# MISC
#
......
......@@ -132,11 +132,8 @@ def im_detect(net, im, boxes):
boxes = boxes[index, :]
# reshape network inputs
base_shape = blobs['data'].shape
num_rois = blobs['rois'].shape[0]
net.blobs['data'].reshape(base_shape[0], base_shape[1],
base_shape[2], base_shape[3])
net.blobs['rois'].reshape(num_rois, 5, 1, 1)
net.blobs['data'].reshape(*(blobs['data'].shape))
net.blobs['rois'].reshape(*(blobs['rois'].shape))
blobs_out = net.forward(data=blobs['data'].astype(np.float32, copy=False),
rois=blobs['rois'].astype(np.float32, copy=False))
if cfg.TEST.BINARY:
......@@ -148,10 +145,14 @@ def im_detect(net, im, boxes):
# use softmax estimated probabilities
scores = blobs_out['cls_prob']
# Apply bounding-box regression deltas
box_deltas = blobs_out['bbox_pred']
pred_boxes = _bbox_pred(boxes, box_deltas)
pred_boxes = _clip_boxes(pred_boxes, im.shape)
if cfg.TEST.BBOX_REG:
# Apply bounding-box regression deltas
box_deltas = blobs_out['bbox_pred']
pred_boxes = _bbox_pred(boxes, box_deltas)
pred_boxes = _clip_boxes(pred_boxes, im.shape)
else:
# Simply repeat the boxes, once for each class
pred_boxes = np.tile(boxes, (1, scores.shape[1]))
if cfg.DEDUP_BOXES > 0:
# Map scores and predictions back to the original set of boxes
......
......@@ -33,20 +33,21 @@ class SolverWrapper(object):
pb2.text_format.Merge(f.read(), self.solver_param)
def snapshot(self):
assert self.bbox_stds is not None
assert self.bbox_means is not None
# save original values
orig_0 = self.solver.net.params['bbox_pred'][0].data.copy()
orig_1 = self.solver.net.params['bbox_pred'][1].data.copy()
# scale and shift with bbox reg unnormalization; then save snapshot
self.solver.net.params['bbox_pred'][0].data[...] = \
(self.solver.net.params['bbox_pred'][0].data *
self.bbox_stds[:, np.newaxis])
self.solver.net.params['bbox_pred'][1].data[...] = \
(self.solver.net.params['bbox_pred'][1].data *
self.bbox_stds + self.bbox_means)
if cfg.TRAIN.BBOX_REG:
assert self.bbox_stds is not None
assert self.bbox_means is not None
# save original values
orig_0 = self.solver.net.params['bbox_pred'][0].data.copy()
orig_1 = self.solver.net.params['bbox_pred'][1].data.copy()
# scale and shift with bbox reg unnormalization; then save snapshot
self.solver.net.params['bbox_pred'][0].data[...] = \
(self.solver.net.params['bbox_pred'][0].data *
self.bbox_stds[:, np.newaxis])
self.solver.net.params['bbox_pred'][1].data[...] = \
(self.solver.net.params['bbox_pred'][1].data *
self.bbox_stds + self.bbox_means)
output_dir = get_output_path(self.imdb, None)
if not os.path.exists(output_dir):
......@@ -61,9 +62,10 @@ class SolverWrapper(object):
self.solver.net.save(str(filename))
print 'Wrote snapshot to: {:s}'.format(filename)
# restore net to original state
self.solver.net.params['bbox_pred'][0].data[...] = orig_0
self.solver.net.params['bbox_pred'][1].data[...] = orig_1
if cfg.TRAIN.BBOX_REG:
# restore net to original state
self.solver.net.params['bbox_pred'][0].data[...] = orig_0
self.solver.net.params['bbox_pred'][1].data[...] = orig_1
def train_model(self, roidb, max_iters):
last_snapshot_iter = -1
......@@ -77,44 +79,16 @@ class SolverWrapper(object):
db_inds = shuffled_inds[shuffled_i:shuffled_i +
cfg.TRAIN.IMS_PER_BATCH]
minibatch_db = [roidb[i] for i in db_inds]
im_blob, rois_blob, labels_blob, \
bbox_targets_blob, bbox_loss_weights_blob = \
finetuning.get_minibatch(minibatch_db)
blobs = finetuning.get_minibatch(minibatch_db)
# Reshape net's input blobs
net = self.solver.net
base_shape = im_blob.shape
num_rois = rois_blob.shape[0]
bbox_shape = bbox_targets_blob.shape[1]
net.blobs['data'].reshape(base_shape[0], base_shape[1],
base_shape[2], base_shape[3])
net.blobs['rois'].reshape(num_rois, 5, 1, 1)
net.blobs['labels'].reshape(num_rois, 1, 1, 1)
net.blobs['bbox_targets'].reshape(num_rois, bbox_shape, 1, 1)
net.blobs['bbox_loss_weights'].reshape(num_rois, bbox_shape,
1, 1)
# Copy data into net's input blobs
net.blobs['data'].data[...] = \
im_blob.astype(np.float32, copy=False)
net.blobs['rois'].data[...] = \
rois_blob[:, :, np.newaxis, np.newaxis] \
.astype(np.float32, copy=False)
net.blobs['labels'].data[...] = \
labels_blob[:, np.newaxis, np.newaxis, np.newaxis] \
.astype(np.float32, copy=False)
net.blobs['bbox_targets'].data[...] = \
bbox_targets_blob[:, :, np.newaxis, np.newaxis] \
.astype(np.float32, copy=False)
net.blobs['bbox_loss_weights'].data[...] = \
bbox_loss_weights_blob[:, :, np.newaxis, np.newaxis] \
.astype(np.float32, copy=False)
for blob_name, blob in blobs.iteritems():
# Reshape net's input blobs
net.blobs[blob_name].reshape(*(blob.shape))
# Copy data into net's input blobs
net.blobs[blob_name].data[...] = blob.astype(np.float32,
copy=False)
# Make one SGD update
self.solver.step(1)
......
......@@ -55,7 +55,15 @@ def get_minibatch(roidb):
# For debug visualizations
# _vis_minibatch(im_blob, rois_blob, labels_blob, all_overlaps)
return im_blob, rois_blob, labels_blob, bbox_targets_blob, bbox_loss_blob
blobs = {'data': im_blob,
'rois': rois_blob,
'labels': labels_blob}
if cfg.TRAIN.BBOX_REG:
blobs['bbox_targets'] = bbox_targets_blob
blobs['bbox_loss_weights'] = bbox_loss_blob
return blobs
def _sample_rois(roidb, fg_rois_per_image, rois_per_image):
"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment