Skip to content
Snippets Groups Projects
Commit 2572d763 authored by Ross Girshick
Browse files

reduce prefetch process overhead

parent 1b12f840
No related branches found
No related tags found
No related merge requests found
......@@ -15,7 +15,7 @@ from fast_rcnn.config import cfg
from roi_data_layer.minibatch import get_minibatch
import numpy as np
import yaml
from multiprocessing import Process, queues
from multiprocessing import Process, Queue
class RoIDataLayer(caffe.Layer):
"""Fast R-CNN data layer used for training."""
......@@ -34,28 +34,17 @@ class RoIDataLayer(caffe.Layer):
self._cur += cfg.TRAIN.IMS_PER_BATCH
return db_inds
@staticmethod
def _prefetch(minibatch_db, num_classes, output_queue):
    """Build one minibatch and deliver it through a queue.

    Used for prefetching when cfg.TRAIN.USE_PREFETCH is enabled.

    Args:
        minibatch_db: roidb entries for the minibatch, forwarded unchanged to
            ``get_minibatch``.
        num_classes: forwarded unchanged to ``get_minibatch``.
        output_queue: object whose ``put`` method receives the blobs.
    """
    # Build the blobs and pass them straight to the consumer.
    output_queue.put(get_minibatch(minibatch_db, num_classes))
# NOTE(review): this method is shown as rendered diff text with the +/- markers
# and the indentation stripped, so lines from the pre-commit and the
# post-commit version of the body appear interleaved below (note the two
# alternative docstring endings and the two different queue attributes).
# The removed/added labels in the comments below are inferred from the text
# itself and should be confirmed against the real file before relying on them.
def _get_next_minibatch(self):
"""Return the blobs to be used for the next minibatch.
If cfg.TRAIN.USE_PREFETCH is True, then blobs will be computed in a
separate process and made available through the self._prefetch_queue
queue.
separate process and made available through self._blob_queue.
"""
# Presumably the removed version: roidb indices are drawn up front and, when
# prefetching, a one-shot Process running RoIDataLayer._prefetch is started
# to put the blobs on self._prefetch_queue.
db_inds = self._get_next_minibatch_inds()
minibatch_db = [self._roidb[i] for i in db_inds]
if cfg.TRAIN.USE_PREFETCH:
self._prefetch_process = Process(target=RoIDataLayer._prefetch,
args=(minibatch_db,
self._num_classes,
self._prefetch_queue))
self._prefetch_process.start()
# Presumably the added version: when prefetching, return the next blobs taken
# from self._blob_queue.
return self._blob_queue.get()
else:
# Without prefetching, pick the next roidb entries and build the blobs in this
# process via get_minibatch.
db_inds = self._get_next_minibatch_inds()
minibatch_db = [self._roidb[i] for i in db_inds]
return get_minibatch(minibatch_db, self._num_classes)
def set_roidb(self, roidb):
......@@ -63,13 +52,21 @@ class RoIDataLayer(caffe.Layer):
self._roidb = roidb
self._shuffle_roidb_inds()
if cfg.TRAIN.USE_PREFETCH:
self._get_next_minibatch()
self._blob_queue = Queue(10)
self._prefetch_process = BlobFetcher(self._blob_queue,
self._roidb,
self._num_classes)
self._prefetch_process.start()
# Terminate the child process when the parent exits
def cleanup():
print 'Terminating BlobFetcher'
self._prefetch_process.terminate()
self._prefetch_process.join()
import atexit
atexit.register(cleanup)
def setup(self, bottom, top):
"""Setup the RoIDataLayer."""
if cfg.TRAIN.USE_PREFETCH:
self._prefetch_process = None
self._prefetch_queue = queues.SimpleQueue()
# parse the layer parameter string, which must be valid YAML
layer_params = yaml.load(self.param_str_)
......@@ -106,11 +103,7 @@ class RoIDataLayer(caffe.Layer):
def forward(self, bottom, top):
"""Get blobs and copy them into this layer's top blob vector."""
if cfg.TRAIN.USE_PREFETCH:
blobs = self._prefetch_queue.get()
self._get_next_minibatch()
else:
blobs = self._get_next_minibatch()
blobs = self._get_next_minibatch()
for blob_name, blob in blobs.iteritems():
top_ind = self._name_to_top_map[blob_name]
......@@ -126,3 +119,40 @@ class RoIDataLayer(caffe.Layer):
def reshape(self, bottom, top):
    """Do nothing here; reshaping happens during the call to forward.

    Args:
        bottom: unused.
        top: unused.
    """
    return None
class BlobFetcher(Process):
    """Experimental class for prefetching blobs in a separate process.

    The fetcher walks the roidb in a random order, builds minibatch blobs with
    ``get_minibatch`` and puts them on the supplied queue. ``run`` never
    returns on its own, so the process has to be stopped from outside.
    """

    def __init__(self, queue, roidb, num_classes):
        """Store the inputs and draw the first random ordering of the roidb.

        Args:
            queue: object whose ``put`` method receives each blobs result
                produced by ``run``.
            roidb: sized, indexable collection of roidb entries.
            num_classes: forwarded unchanged to ``get_minibatch``.
        """
        super(BlobFetcher, self).__init__()
        self._queue = queue
        self._roidb = roidb
        self._num_classes = num_classes
        self._perm = None
        self._cur = 0
        # Fix the random seed for reproducibility. This must happen *before*
        # the first shuffle; otherwise the initial ordering depends on whatever
        # state NumPy's global RNG was left in by earlier code.
        # NOTE: this reseeds NumPy's global RNG in the constructing process.
        np.random.seed(cfg.RNG_SEED)
        self._shuffle_roidb_inds()

    def _shuffle_roidb_inds(self):
        """Randomly permute the training roidb and rewind the cursor."""
        # TODO(rbg): remove duplicated code
        self._perm = np.random.permutation(len(self._roidb))
        self._cur = 0

    def _get_next_minibatch_inds(self):
        """Return the roidb indices for the next minibatch.

        Returns:
            A slice of the current permutation holding at most
            cfg.TRAIN.IMS_PER_BATCH indices.
        """
        # TODO(rbg): remove duplicated code
        # The comparison is ``>=``, so a batch that would end exactly at the
        # end of the roidb also triggers a reshuffle.
        if self._cur + cfg.TRAIN.IMS_PER_BATCH >= len(self._roidb):
            self._shuffle_roidb_inds()

        db_inds = self._perm[self._cur:self._cur + cfg.TRAIN.IMS_PER_BATCH]
        self._cur += cfg.TRAIN.IMS_PER_BATCH
        return db_inds

    def run(self):
        """Build minibatch blobs in an endless loop and queue each one."""
        # Parenthesised form prints the same text on Python 2 and also parses
        # on Python 3.
        print('BlobFetcher started')
        while True:
            db_inds = self._get_next_minibatch_inds()
            minibatch_db = [self._roidb[i] for i in db_inds]
            blobs = get_minibatch(minibatch_db, self._num_classes)
            self._queue.put(blobs)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment