Commit 8f5c4e7f authored by Ross Girshick

basic support for a prefetch process

parent f59a7466
lib/fast_rcnn/config.py
@@ -74,6 +74,10 @@ __C.TRAIN.SNAPSHOT_ITERS = 10000
 # infix to yield the path: <prefix>[_<infix>]_iters_XYZ.caffemodel
 __C.TRAIN.SNAPSHOT_INFIX = ''
 
+# Use a prefetch thread in roi_data_layer.layer
+# So far I haven't found this useful; likely more engineering work is required
+__C.TRAIN.USE_PREFETCH = False
+
 #
 # Testing options
 #
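The new flag defaults to False and is read from the global cfg at run time. A minimal sketch of turning it on, assuming the usual pattern of mutating fast_rcnn.config.cfg before the solver is constructed (a YAML override would work equally well if the repo version at hand supports config files):

    # Hedged sketch: enable the experimental prefetcher before training starts.
    from fast_rcnn.config import cfg

    cfg.TRAIN.USE_PREFETCH = True  # default is False, per the diff above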
lib/fast_rcnn/train.py
@@ -8,6 +8,7 @@
 import caffe
 from fast_rcnn.config import cfg
 import roi_data_layer.roidb as rdl_roidb
+from utils.timer import Timer
 import numpy as np
 import os
@@ -71,9 +72,14 @@ class SolverWrapper(object):
     def train_model(self, max_iters):
         last_snapshot_iter = -1
+        timer = Timer()
         while self.solver.iter < max_iters:
             # Make one SGD update
+            timer.tic()
             self.solver.step(1)
+            timer.toc()
+            if self.solver.iter % (10 * self.solver_param.display) == 0:
+                print 'speed: {:.3f}s / iter'.format(timer.average_time)
 
             if self.solver.iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                 last_snapshot_iter = self.solver.iter
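The hunk above only shows the Timer call sites: tic/toc bracket each solver step, and average_time feeds the periodic speed printout. For reference, a minimal sketch of the interface those calls assume; the repo's actual utils/timer.py may differ in details:

    import time

    class Timer(object):
        """Sketch of the tic/toc/average_time interface used above."""
        def __init__(self):
            self.total_time = 0.0
            self.calls = 0
            self.start_time = 0.0
            self.average_time = 0.0

        def tic(self):
            self.start_time = time.time()

        def toc(self):
            # Accumulate elapsed wall-clock time and keep a running mean
            self.total_time += time.time() - self.start_time
            self.calls += 1
            self.average_time = self.total_time / self.calls
            return self.average_time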
lib/roi_data_layer/layer.py
@@ -10,6 +10,7 @@ from fast_rcnn.config import cfg
 from roi_data_layer.minibatch import get_minibatch
 import numpy as np
 import yaml
+from multiprocessing import Process, queues
 
 class DataLayer(caffe.Layer):
     """Fast R-CNN data layer."""
@@ -26,16 +27,34 @@ class DataLayer(caffe.Layer):
         self._cur += cfg.TRAIN.IMS_PER_BATCH
         return db_inds
 
-    def _set_next_minibatch(self):
+    @staticmethod
+    def _prefetch(minibatch_db, num_classes, output_queue):
+        blobs = get_minibatch(minibatch_db, num_classes)
+        output_queue.put(blobs)
+
+    def _get_next_minibatch(self):
         db_inds = self._get_next_minibatch_inds()
         minibatch_db = [self._roidb[i] for i in db_inds]
-        self._blobs = get_minibatch(minibatch_db, self._num_classes)
+        if cfg.TRAIN.USE_PREFETCH:
+            self._prefetch_process = Process(target=DataLayer._prefetch,
+                                             args=(minibatch_db,
+                                                   self._num_classes,
+                                                   self._prefetch_queue))
+            self._prefetch_process.start()
+        else:
+            return get_minibatch(minibatch_db, self._num_classes)
 
     def set_roidb(self, roidb):
         self._roidb = roidb
         self._shuffle_roidb_inds()
+        if cfg.TRAIN.USE_PREFETCH:
+            self._get_next_minibatch()
 
     def setup(self, bottom, top):
+        if cfg.TRAIN.USE_PREFETCH:
+            self._prefetch_process = None
+            self._prefetch_queue = queues.SimpleQueue()
+
         layer_params = yaml.load(self.param_str_)
         self._num_classes = layer_params['num_classes']
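The producer side above runs get_minibatch in a freshly spawned child process and ships the resulting blobs back through a SimpleQueue, which pickles them across the process boundary. A self-contained sketch of that handoff, with a hypothetical blob shape (on Python 2, multiprocessing.queues.SimpleQueue takes no constructor arguments):

    import numpy as np
    from multiprocessing import Process, queues

    def producer(q):
        # Stand-in for DataLayer._prefetch / get_minibatch: build the blobs
        # in the child process and put them on the queue. put() pickles and
        # copies the arrays, which is part of the per-minibatch overhead.
        q.put({'data': np.zeros((2, 3, 600, 800), dtype=np.float32)})

    if __name__ == '__main__':
        q = queues.SimpleQueue()
        p = Process(target=producer, args=(q,))
        p.start()
        blobs = q.get()             # blocks until the child has put its result
        p.join()
        print blobs['data'].shape   # -> (2, 3, 600, 800)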
@@ -58,24 +77,20 @@ class DataLayer(caffe.Layer):
         # bbox_loss_weights
         top[4].reshape(1, self._num_classes * 4)
 
-        # TODO(rbg):
-        # Start a prefetch thread that calls self._get_next_minibatch()
-
     def forward(self, bottom, top):
-        # TODO(rbg):
-        # wait for prefetch thread to finish
-        self._set_next_minibatch()
-
-        for blob_name, blob in self._blobs.iteritems():
+        if cfg.TRAIN.USE_PREFETCH:
+            blobs = self._prefetch_queue.get()
+            self._get_next_minibatch()
+        else:
+            blobs = self._get_next_minibatch()
+
+        for blob_name, blob in blobs.iteritems():
             top_ind = self._name_to_top_map[blob_name]
             # Reshape net's input blobs
             top[top_ind].reshape(*(blob.shape))
             # Copy data into net's input blobs
             top[top_ind].data[...] = blob.astype(np.float32, copy=False)
 
-        # TODO(rbg):
-        # start next prefetch thread
-
     def backward(self, top, propagate_down, bottom):
         """This layer does not propagate gradients."""
         pass
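In the prefetch path, forward() first blocks on the queue for the minibatch that was launched one step earlier, then immediately starts the next one, so blob preparation overlaps with the network's forward/backward pass (set_roidb primes the pipeline with the first prefetch). A compact sketch of that schedule, with sleeps standing in for real work and hypothetical names throughout:

    import time
    from multiprocessing import Process, queues

    def prefetch(i, q):
        time.sleep(0.05)           # stand-in for get_minibatch (I/O, resize)
        q.put('blobs %d' % i)

    q = queues.SimpleQueue()

    def start_prefetch(i):
        # Like the layer, a brand-new process is spawned per minibatch and
        # never joined; that per-step startup cost is one likely reason the
        # feature is off by default.
        Process(target=prefetch, args=(i, q)).start()

    start_prefetch(0)              # set_roidb(): prime the pipeline
    for i in xrange(1, 4):         # each solver.step() drives one forward()
        blobs = q.get()            # wait for the in-flight minibatch
        start_prefetch(i)          # kick off the next one right away
        time.sleep(0.05)           # stand-in for the net's fwd/bwd pass
        print 'consumed', blobs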