Skip to content
Snippets Groups Projects
Commit f59a7466 authored by Ross Girshick's avatar Ross Girshick
Browse files

start of big refactoring; python layer for roi data

parent 2ef01a9c
No related branches found
No related tags found
No related merge requests found
......@@ -85,7 +85,7 @@ class imdb(object):
"""
raise NotImplementedError
def append_flipped_roidb(self):
def append_flipped_images(self):
num_images = self.num_images
widths = [PIL.Image.open(self.image_path_at(i)).size[0]
for i in xrange(num_images)]
......
......@@ -13,7 +13,7 @@
# and use cfg_from_file(yaml_file) to load it and override the default options.
#
# - See tools/{train,test}_net.py for example code that uses cfg_from_file().
# - See examples/multiscale.yml for an example YAML config override file.
# - See experiments/cfgs/*.yml for example YAML config override files.
#
import os
......@@ -125,7 +125,7 @@ __C.ROOT_DIR = osp.abspath(osp.join(osp.dirname(__file__), '..', '..'))
# Place outputs under an experiments directory
__C.EXP_DIR = 'default'
def get_output_path(imdb, net):
def get_output_dir(imdb, net):
path = osp.abspath(osp.join(__C.ROOT_DIR, 'output', __C.EXP_DIR, imdb.name))
if net is None:
return path
......
......@@ -5,7 +5,7 @@
# Written by Ross Girshick
# --------------------------------------------------------
from fast_rcnn.config import cfg, get_output_path
from fast_rcnn.config import cfg, get_output_dir
import argparse
from utils.timer import Timer
import numpy as np
......@@ -212,7 +212,7 @@ def test_net(net, imdb):
all_boxes = [[[] for _ in xrange(num_images)]
for _ in xrange(imdb.num_classes)]
output_dir = get_output_path(imdb, net)
output_dir = get_output_dir(imdb, net)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
......
......@@ -5,154 +5,100 @@
# Written by Ross Girshick
# --------------------------------------------------------
from fast_rcnn.config import cfg, get_output_path
import numpy as np
import caffe
import fast_rcnn.finetuning as finetuning
import fast_rcnn.bbox_regression_targets as bbox_regression_targets
from fast_rcnn.config import cfg
import roi_data_layer.roidb as rdl_roidb
import numpy as np
import os
from caffe.proto import caffe_pb2
import google.protobuf as pb2
class SolverWrapper(object):
def __init__(self, solver_prototxt, imdb, pretrained_model=None):
self.bbox_means = None
self.bbox_stds = None
self.imdb = imdb
def __init__(self, solver_prototxt, roidb, output_dir,
pretrained_model=None):
self.output_dir = output_dir
print 'Computing bounding-box regression targets...'
self.bbox_means, self.bbox_stds = \
rdl_roidb.add_bbox_regression_targets(roidb)
print 'done'
self.solver = caffe.SGDSolver(solver_prototxt)
if pretrained_model is not None:
print 'Loading pretrained model weights from {:s}' \
.format(pretrained_model)
print ('Loading pretrained model '
'weights from {:s}').format(pretrained_model)
self.solver.net.copy_from(pretrained_model)
self.solver_param = caffe_pb2.SolverParameter()
with open(solver_prototxt, 'rt') as f:
pb2.text_format.Merge(f.read(), self.solver_param)
self.solver.net.layers[0].set_roidb(roidb)
def snapshot(self):
if cfg.TRAIN.BBOX_REG:
assert self.bbox_stds is not None
assert self.bbox_means is not None
net = self.solver.net
if cfg.TRAIN.BBOX_REG:
# save original values
orig_0 = self.solver.net.params['bbox_pred'][0].data.copy()
orig_1 = self.solver.net.params['bbox_pred'][1].data.copy()
orig_0 = net.params['bbox_pred'][0].data.copy()
orig_1 = net.params['bbox_pred'][1].data.copy()
# scale and shift with bbox reg unnormalization; then save snapshot
self.solver.net.params['bbox_pred'][0].data[...] = \
(self.solver.net.params['bbox_pred'][0].data *
net.params['bbox_pred'][0].data[...] = \
(net.params['bbox_pred'][0].data *
self.bbox_stds[:, np.newaxis])
self.solver.net.params['bbox_pred'][1].data[...] = \
(self.solver.net.params['bbox_pred'][1].data *
net.params['bbox_pred'][1].data[...] = \
(net.params['bbox_pred'][1].data *
self.bbox_stds + self.bbox_means)
output_dir = get_output_path(self.imdb, None)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir)
infix = ('_' + cfg.TRAIN.SNAPSHOT_INFIX
if cfg.TRAIN.SNAPSHOT_INFIX != '' else '')
filename = self.solver_param.snapshot_prefix + infix + \
'_iter_{:d}'.format(self.solver.iter) + '.caffemodel'
filename = os.path.join(output_dir, filename)
filename = (self.solver_param.snapshot_prefix + infix +
'_iter_{:d}'.format(self.solver.iter) + '.caffemodel')
filename = os.path.join(self.output_dir, filename)
self.solver.net.save(str(filename))
net.save(str(filename))
print 'Wrote snapshot to: {:s}'.format(filename)
if cfg.TRAIN.BBOX_REG:
# restore net to original state
self.solver.net.params['bbox_pred'][0].data[...] = orig_0
self.solver.net.params['bbox_pred'][1].data[...] = orig_1
net.params['bbox_pred'][0].data[...] = orig_0
net.params['bbox_pred'][1].data[...] = orig_1
def train_model(self, roidb, max_iters):
def train_model(self, max_iters):
last_snapshot_iter = -1
while self.solver.iter < max_iters:
shuffled_inds = np.random.permutation(np.arange(len(roidb)))
lim = (len(shuffled_inds) / cfg.TRAIN.IMS_PER_BATCH) * \
cfg.TRAIN.IMS_PER_BATCH
shuffled_inds = shuffled_inds[0:lim]
for shuffled_i in xrange(0, len(shuffled_inds),
cfg.TRAIN.IMS_PER_BATCH):
db_inds = shuffled_inds[shuffled_i:shuffled_i +
cfg.TRAIN.IMS_PER_BATCH]
minibatch_db = [roidb[i] for i in db_inds]
blobs = finetuning.get_minibatch(minibatch_db)
net = self.solver.net
for blob_name, blob in blobs.iteritems():
# Reshape net's input blobs
net.blobs[blob_name].reshape(*(blob.shape))
# Copy data into net's input blobs
net.blobs[blob_name].data[...] = blob.astype(np.float32,
copy=False)
# Make one SGD update
self.solver.step(1)
if self.solver.iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
last_snapshot_iter = self.solver.iter
self.snapshot()
if self.solver.iter >= max_iters:
break
# Make one SGD update
self.solver.step(1)
if self.solver.iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
last_snapshot_iter = self.solver.iter
self.snapshot()
if last_snapshot_iter != self.solver.iter:
self.snapshot()
def prepare_training_roidb(imdb):
"""
Enrich the imdb's roidb by adding some derived quantities that
are useful for training. This function precomputes the maximum
overlap, taken over ground-truth boxes, between each ROI and
each ground-truth box. The class with maximum overlap is also
recorded.
"""
roidb = imdb.roidb
for i in xrange(len(imdb.image_index)):
roidb[i]['image'] = imdb.image_path_at(i)
# need gt_overlaps as a dense array for argmax
gt_overlaps = roidb[i]['gt_overlaps'].toarray()
# max overlap with gt over classes (columns)
max_overlaps = gt_overlaps.max(axis=1)
# gt class that had the max overlap
max_classes = gt_overlaps.argmax(axis=1)
roidb[i]['max_classes'] = max_classes
roidb[i]['max_overlaps'] = max_overlaps
# sanity checks
# max overlap of 0 => class should be zero (background)
zero_inds = np.where(max_overlaps == 0)[0]
assert all(max_classes[zero_inds] == 0)
# max overlap > 0 => class should not be zero (must be a fg class)
nonzero_inds = np.where(max_overlaps > 0)[0]
assert all(max_classes[nonzero_inds] != 0)
return roidb
def train_net(solver_prototxt, imdb, pretrained_model=None, max_iters=40000):
# enhance roidb to contain flipped examples
def get_training_roidb(imdb):
if cfg.TRAIN.USE_FLIPPED:
print 'Appending horizontally-flipped training examples...'
imdb.append_flipped_roidb()
imdb.append_flipped_images()
print 'done'
# enhance roidb to contain some useful derived quanties
print 'Preparing training data...'
roidb = prepare_training_roidb(imdb)
rdl_roidb.prepare_roidb(imdb)
print 'done'
# enhance roidb to contain bounding-box regression targets
print 'Computing bounding-box regression targets...'
means, stds = \
bbox_regression_targets.append_bbox_regression_targets(roidb)
print 'done'
return imdb.roidb
sw = SolverWrapper(solver_prototxt, imdb, pretrained_model=pretrained_model)
sw.bbox_means = means
sw.bbox_stds = stds
def train_net(solver_prototxt, roidb, output_dir,
pretrained_model=None, max_iters=40000):
sw = SolverWrapper(solver_prototxt, roidb, output_dir,
pretrained_model=pretrained_model)
print 'Solving...'
sw.train_model(roidb, max_iters=max_iters)
sw.train_model(max_iters)
print 'done solving'
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------
import caffe
from fast_rcnn.config import cfg
from roi_data_layer.minibatch import get_minibatch
import numpy as np
import yaml
class DataLayer(caffe.Layer):
"""Fast R-CNN data layer."""
def _shuffle_roidb_inds(self):
self._perm = np.random.permutation(np.arange(len(self._roidb)))
self._cur = 0
def _get_next_minibatch_inds(self):
if self._cur + cfg.TRAIN.IMS_PER_BATCH >= len(self._roidb):
self._shuffle_roidb_inds()
db_inds = self._perm[self._cur:self._cur + cfg.TRAIN.IMS_PER_BATCH]
self._cur += cfg.TRAIN.IMS_PER_BATCH
return db_inds
def _set_next_minibatch(self):
db_inds = self._get_next_minibatch_inds()
minibatch_db = [self._roidb[i] for i in db_inds]
self._blobs = get_minibatch(minibatch_db, self._num_classes)
def set_roidb(self, roidb):
self._roidb = roidb
self._shuffle_roidb_inds()
def setup(self, bottom, top):
layer_params = yaml.load(self.param_str_)
self._num_classes = layer_params['num_classes']
self._name_to_top_map = {
'data': 0,
'rois': 1,
'labels': 2,
'bbox_targets': 3,
'bbox_loss_weights': 4}
# data
top[0].reshape(1, 3, 1, 1)
# rois
top[1].reshape(1, 5)
# labels
top[2].reshape(1)
# bbox_targets
top[3].reshape(1, self._num_classes * 4)
# bbox_loss_weights
top[4].reshape(1, self._num_classes * 4)
# TODO(rbg):
# Start a prefetch thread that calls self._get_next_minibatch()
def forward(self, bottom, top):
# TODO(rbg):
# wait for prefetch thread to finish
self._set_next_minibatch()
for blob_name, blob in self._blobs.iteritems():
top_ind = self._name_to_top_map[blob_name]
# Reshape net's input blobs
top[top_ind].reshape(*(blob.shape))
# Copy data into net's input blobs
top[top_ind].data[...] = blob.astype(np.float32, copy=False)
# TODO(rbg):
# start next prefetch thread
def backward(self, top, propagate_down, bottom):
"""This layer does not propagate gradients."""
pass
def reshape(self, bottom, top):
"""Reshaping happens during the call to forward."""
pass
......@@ -8,17 +8,14 @@
import numpy as np
import numpy.random as npr
import cv2
import matplotlib.pyplot as plt
from fast_rcnn.config import cfg
from utils.blob import prep_im_for_blob, im_list_to_blob
def get_minibatch(roidb):
def get_minibatch(roidb, num_classes):
"""
Given a roidb, construct a minibatch sampled from it.
"""
num_images = len(roidb)
# Infer number of classes from the number of columns in gt_overlaps
num_classes = roidb[0]['gt_overlaps'].shape[1]
# Sample random scales to use for each image in this batch
random_scale_inds = npr.randint(0, high=len(cfg.TRAIN.SCALES),
size=num_images)
......@@ -39,7 +36,8 @@ def get_minibatch(roidb):
# all_overlaps = []
for im_i in xrange(num_images):
labels, overlaps, im_rois, bbox_targets, bbox_loss \
= _sample_rois(roidb[im_i], fg_rois_per_image, rois_per_image)
= _sample_rois(roidb[im_i], fg_rois_per_image, rois_per_image,
num_classes)
# Add to RoIs blob
rois = _scale_im_rois(im_rois, im_scales[im_i])
......@@ -66,7 +64,7 @@ def get_minibatch(roidb):
return blobs
def _sample_rois(roidb, fg_rois_per_image, rois_per_image):
def _sample_rois(roidb, fg_rois_per_image, rois_per_image, num_classes):
"""
Generate a random sample of RoIs comprising foreground and background
examples.
......@@ -108,8 +106,6 @@ def _sample_rois(roidb, fg_rois_per_image, rois_per_image):
overlaps = overlaps[keep_inds]
rois = rois[keep_inds]
# Infer number of classes from the number of columns in gt_overlaps
num_classes = roidb['gt_overlaps'].shape[1]
bbox_targets, bbox_loss_weights = \
_get_bbox_regression_labels(roidb['bbox_targets'][keep_inds, :],
num_classes)
......@@ -167,6 +163,7 @@ def _get_bbox_regression_labels(bbox_target_data, num_classes):
def _vis_minibatch(im_blob, rois_blob, labels_blob, overlaps):
"""Visualize a mini-batch for debugging."""
import matplotlib.pyplot as plt
for i in xrange(rois_blob.shape[0]):
rois = rois_blob[i, :]
im_ind = rois[0]
......
......@@ -9,6 +9,76 @@ import numpy as np
from fast_rcnn.config import cfg
import utils.cython_bbox
def prepare_roidb(imdb):
"""
Enrich the imdb's roidb by adding some derived quantities that
are useful for training. This function precomputes the maximum
overlap, taken over ground-truth boxes, between each ROI and
each ground-truth box. The class with maximum overlap is also
recorded.
"""
roidb = imdb.roidb
for i in xrange(len(imdb.image_index)):
roidb[i]['image'] = imdb.image_path_at(i)
# need gt_overlaps as a dense array for argmax
gt_overlaps = roidb[i]['gt_overlaps'].toarray()
# max overlap with gt over classes (columns)
max_overlaps = gt_overlaps.max(axis=1)
# gt class that had the max overlap
max_classes = gt_overlaps.argmax(axis=1)
roidb[i]['max_classes'] = max_classes
roidb[i]['max_overlaps'] = max_overlaps
# sanity checks
# max overlap of 0 => class should be zero (background)
zero_inds = np.where(max_overlaps == 0)[0]
assert all(max_classes[zero_inds] == 0)
# max overlap > 0 => class should not be zero (must be a fg class)
nonzero_inds = np.where(max_overlaps > 0)[0]
assert all(max_classes[nonzero_inds] != 0)
def add_bbox_regression_targets(roidb):
assert len(roidb) > 0
assert 'max_classes' in roidb[0], 'Did you call prepare_roidb first?'
num_images = len(roidb)
# Infer number of classes from the number of columns in gt_overlaps
num_classes = roidb[0]['gt_overlaps'].shape[1]
for im_i in xrange(num_images):
rois = roidb[im_i]['boxes']
max_overlaps = roidb[im_i]['max_overlaps']
max_classes = roidb[im_i]['max_classes']
roidb[im_i]['bbox_targets'] = \
_compute_targets(rois, max_overlaps, max_classes)
# Compute values needed for means and stds
# var(x) = E(x^2) - E(x)^2
class_counts = np.zeros((num_classes, 1)) + cfg.EPS
sums = np.zeros((num_classes, 4))
squared_sums = np.zeros((num_classes, 4))
for im_i in xrange(num_images):
targets = roidb[im_i]['bbox_targets']
for cls in xrange(1, num_classes):
cls_inds = np.where(targets[:, 0] == cls)[0]
if cls_inds.size > 0:
class_counts[cls] += cls_inds.size
sums[cls, :] += targets[cls_inds, 1:].sum(axis=0)
squared_sums[cls, :] += (targets[cls_inds, 1:] ** 2).sum(axis=0)
means = sums / class_counts
stds = np.sqrt(squared_sums / class_counts - means ** 2)
# Normalize targets
for im_i in xrange(num_images):
targets = roidb[im_i]['bbox_targets']
for cls in xrange(1, num_classes):
cls_inds = np.where(targets[:, 0] == cls)[0]
roidb[im_i]['bbox_targets'][cls_inds, 1:] -= means[cls, :]
roidb[im_i]['bbox_targets'][cls_inds, 1:] /= stds[cls, :]
# These values will be needed for making predictions
# (the predicts will need to be unnormalized and uncentered)
return means.ravel(), stds.ravel()
def _compute_targets(rois, overlaps, labels):
# Ensure ROIs are floats
rois = rois.astype(np.float, copy=False)
......@@ -50,45 +120,3 @@ def _compute_targets(rois, overlaps, labels):
targets[ex_inds, 3] = targets_dw
targets[ex_inds, 4] = targets_dh
return targets
def append_bbox_regression_targets(roidb):
num_images = len(roidb)
# Infer number of classes from the number of columns in gt_overlaps
num_classes = roidb[0]['gt_overlaps'].shape[1]
for im_i in xrange(num_images):
rois = roidb[im_i]['boxes']
max_overlaps = roidb[im_i]['max_overlaps']
max_classes = roidb[im_i]['max_classes']
roidb[im_i]['bbox_targets'] = \
_compute_targets(rois, max_overlaps, max_classes)
# Compute values needed for means and stds
# var(x) = E(x^2) - E(x)^2
class_counts = np.zeros((num_classes, 1)) + cfg.EPS
sums = np.zeros((num_classes, 4))
squared_sums = np.zeros((num_classes, 4))
for im_i in xrange(num_images):
targets = roidb[im_i]['bbox_targets']
for cls in xrange(1, num_classes):
cls_inds = np.where(targets[:, 0] == cls)[0]
if cls_inds.size > 0:
class_counts[cls] += cls_inds.size
sums[cls, :] += targets[cls_inds, 1:].sum(axis=0)
squared_sums[cls, :] += (targets[cls_inds, 1:] ** 2).sum(axis=0)
means = sums / class_counts
stds = np.sqrt(squared_sums / class_counts - means ** 2)
# Normalize targets
for im_i in xrange(num_images):
targets = roidb[im_i]['bbox_targets']
for cls in xrange(1, num_classes):
cls_inds = np.where(targets[:, 0] == cls)[0]
roidb[im_i]['bbox_targets'][cls_inds, 1:] \
-= means[cls, :]
roidb[im_i]['bbox_targets'][cls_inds, 1:] \
/= stds[cls, :]
# These values will be needed for making predictions
# (the predicts will need to be unnormalized and uncentered)
return means.ravel(), stds.ravel()
......@@ -8,8 +8,8 @@
# --------------------------------------------------------
import _init_paths
import fast_rcnn as frc
from fast_rcnn.config import cfg, cfg_from_file
from fast_rcnn.train import get_training_roidb, train_net
from fast_rcnn.config import cfg, cfg_from_file, get_output_dir
from datasets.factory import get_imdb
import caffe
import argparse
......@@ -67,7 +67,11 @@ if __name__ == '__main__':
imdb_train = get_imdb(args.imdb_name)
print 'Loaded dataset `{:s}` for training'.format(imdb_train.name)
roidb = get_training_roidb(imdb_train)
frc.train.train_net(args.solver, imdb_train,
pretrained_model=args.pretrained_model,
max_iters=args.max_iters)
output_dir = get_output_dir(imdb_train, None)
print 'Output will be saved to `{:s}`'.format(output_dir)
train_net(args.solver, roidb, output_dir,
pretrained_model=args.pretrained_model,
max_iters=args.max_iters)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment