Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
F
fast-rcnn
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
maxmzkr
fast-rcnn
Commits
2572d763
Commit
2572d763
authored
9 years ago
by
Ross Girshick
Browse files
Options
Downloads
Patches
Plain Diff
reduce prefetch process overhead
parent
1b12f840
No related branches found
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
lib/roi_data_layer/layer.py
+55
-25
55 additions, 25 deletions
lib/roi_data_layer/layer.py
with
55 additions
and
25 deletions
lib/roi_data_layer/layer.py
+
55
−
25
View file @
2572d763
...
...
@@ -15,7 +15,7 @@ from fast_rcnn.config import cfg
from
roi_data_layer.minibatch
import
get_minibatch
import
numpy
as
np
import
yaml
from
multiprocessing
import
Process
,
q
ueue
s
from
multiprocessing
import
Process
,
Q
ueue
class
RoIDataLayer
(
caffe
.
Layer
):
"""
Fast R-CNN data layer used for training.
"""
...
...
@@ -34,28 +34,17 @@ class RoIDataLayer(caffe.Layer):
self
.
_cur
+=
cfg
.
TRAIN
.
IMS_PER_BATCH
return
db_inds
@staticmethod
def
_prefetch
(
minibatch_db
,
num_classes
,
output_queue
):
"""
Prefetch minibatch blobs (if enabled cfg.TRAIN.USE_PREFETCH).
"""
blobs
=
get_minibatch
(
minibatch_db
,
num_classes
)
output_queue
.
put
(
blobs
)
def
_get_next_minibatch
(
self
):
"""
Return the blobs to be used for the next minibatch.
If cfg.TRAIN.USE_PREFETCH is True, then blobs will be computed in a
separate process and made available through the self._prefetch_queue
queue.
separate process and made available through self._blob_queue.
"""
db_inds
=
self
.
_get_next_minibatch_inds
()
minibatch_db
=
[
self
.
_roidb
[
i
]
for
i
in
db_inds
]
if
cfg
.
TRAIN
.
USE_PREFETCH
:
self
.
_prefetch_process
=
Process
(
target
=
RoIDataLayer
.
_prefetch
,
args
=
(
minibatch_db
,
self
.
_num_classes
,
self
.
_prefetch_queue
))
self
.
_prefetch_process
.
start
()
return
self
.
_blob_queue
.
get
()
else
:
db_inds
=
self
.
_get_next_minibatch_inds
()
minibatch_db
=
[
self
.
_roidb
[
i
]
for
i
in
db_inds
]
return
get_minibatch
(
minibatch_db
,
self
.
_num_classes
)
def
set_roidb
(
self
,
roidb
):
...
...
@@ -63,13 +52,21 @@ class RoIDataLayer(caffe.Layer):
self
.
_roidb
=
roidb
self
.
_shuffle_roidb_inds
()
if
cfg
.
TRAIN
.
USE_PREFETCH
:
self
.
_get_next_minibatch
()
self
.
_blob_queue
=
Queue
(
10
)
self
.
_prefetch_process
=
BlobFetcher
(
self
.
_blob_queue
,
self
.
_roidb
,
self
.
_num_classes
)
self
.
_prefetch_process
.
start
()
# Terminate the child process when the parent exists
def
cleanup
():
print
'
Terminating BlobFetcher
'
self
.
_prefetch_process
.
terminate
()
self
.
_prefetch_process
.
join
()
import
atexit
atexit
.
register
(
cleanup
)
def
setup
(
self
,
bottom
,
top
):
"""
Setup the RoIDataLayer.
"""
if
cfg
.
TRAIN
.
USE_PREFETCH
:
self
.
_prefetch_process
=
None
self
.
_prefetch_queue
=
queues
.
SimpleQueue
()
# parse the layer parameter string, which must be valid YAML
layer_params
=
yaml
.
load
(
self
.
param_str_
)
...
...
@@ -106,11 +103,7 @@ class RoIDataLayer(caffe.Layer):
def
forward
(
self
,
bottom
,
top
):
"""
Get blobs and copy them into this layer
'
s top blob vector.
"""
if
cfg
.
TRAIN
.
USE_PREFETCH
:
blobs
=
self
.
_prefetch_queue
.
get
()
self
.
_get_next_minibatch
()
else
:
blobs
=
self
.
_get_next_minibatch
()
blobs
=
self
.
_get_next_minibatch
()
for
blob_name
,
blob
in
blobs
.
iteritems
():
top_ind
=
self
.
_name_to_top_map
[
blob_name
]
...
...
@@ -126,3 +119,40 @@ class RoIDataLayer(caffe.Layer):
def
reshape
(
self
,
bottom
,
top
):
"""
Reshaping happens during the call to forward.
"""
pass
class
BlobFetcher
(
Process
):
"""
Experimental class for prefetching blobs in a separate process.
"""
def
__init__
(
self
,
queue
,
roidb
,
num_classes
):
super
(
BlobFetcher
,
self
).
__init__
()
self
.
_queue
=
queue
self
.
_roidb
=
roidb
self
.
_num_classes
=
num_classes
self
.
_perm
=
None
self
.
_cur
=
0
self
.
_shuffle_roidb_inds
()
# fix the random seed for reproducibility
np
.
random
.
seed
(
cfg
.
RNG_SEED
)
def
_shuffle_roidb_inds
(
self
):
"""
Randomly permute the training roidb.
"""
# TODO(rbg): remove duplicated code
self
.
_perm
=
np
.
random
.
permutation
(
np
.
arange
(
len
(
self
.
_roidb
)))
self
.
_cur
=
0
def
_get_next_minibatch_inds
(
self
):
"""
Return the roidb indices for the next minibatch.
"""
# TODO(rbg): remove duplicated code
if
self
.
_cur
+
cfg
.
TRAIN
.
IMS_PER_BATCH
>=
len
(
self
.
_roidb
):
self
.
_shuffle_roidb_inds
()
db_inds
=
self
.
_perm
[
self
.
_cur
:
self
.
_cur
+
cfg
.
TRAIN
.
IMS_PER_BATCH
]
self
.
_cur
+=
cfg
.
TRAIN
.
IMS_PER_BATCH
return
db_inds
def
run
(
self
):
print
'
BlobFetcher started
'
while
True
:
db_inds
=
self
.
_get_next_minibatch_inds
()
minibatch_db
=
[
self
.
_roidb
[
i
]
for
i
in
db_inds
]
blobs
=
get_minibatch
(
minibatch_db
,
self
.
_num_classes
)
self
.
_queue
.
put
(
blobs
)
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment