Skip to content
Snippets Groups Projects
Commit 2de94e3a authored by Ross Girshick's avatar Ross Girshick
Browse files

replace FEAT_STRIDE with spatial_scale

parent 3f383eb5
No related branches found
No related tags found
No related merge requests found
......@@ -194,6 +194,7 @@ layer {
roi_pooling_param {
pooled_w: 6
pooled_h: 6
spatial_scale: 0.0625 # 1/16
}
}
layer {
......
......@@ -194,6 +194,7 @@ layer {
roi_pooling_param {
pooled_w: 6
pooled_h: 6
spatial_scale: 0.0625 # 1/16
}
}
layer {
......
......@@ -215,6 +215,7 @@ layer {
roi_pooling_param {
pooled_w: 6
pooled_h: 6
spatial_scale: 0.0625 # 1/16
}
}
layer {
......
name: "VGG_ILSVRC_16_layers"
input: "data"
input_shape {
dim: 1
dim: 3
dim: 224
dim: 224
}
input: "rois"
input_shape {
dim: 1 # to be changed on-the-fly to num ROIs
dim: 5 # [batch ind, x1, y1, x2, y2] zero-based indexing
dim: 1
dim: 1
}
layer {
name: "conv1_1"
type: "Convolution"
bottom: "data"
top: "conv1_1"
param {
lr_mult: 0
decay_mult: 0
}
param {
lr_mult: 0
decay_mult: 0
}
convolution_param {
num_output: 64
pad: 1
kernel_size: 3
}
}
layer {
name: "relu1_1"
type: "ReLU"
bottom: "conv1_1"
top: "conv1_1"
}
layer {
name: "conv1_2"
type: "Convolution"
bottom: "conv1_1"
top: "conv1_2"
param {
lr_mult: 0
decay_mult: 0
}
param {
lr_mult: 0
decay_mult: 0
}
convolution_param {
num_output: 64
pad: 1
kernel_size: 3
}
}
layer {
name: "relu1_2"
type: "ReLU"
bottom: "conv1_2"
top: "conv1_2"
}
layer {
name: "pool1"
type: "Pooling"
bottom: "conv1_2"
top: "pool1"
pooling_param {
pool: MAX
kernel_size: 2
stride: 2
}
}
layer {
name: "conv2_1"
type: "Convolution"
bottom: "pool1"
top: "conv2_1"
param {
lr_mult: 0
decay_mult: 0
}
param {
lr_mult: 0
decay_mult: 0
}
convolution_param {
num_output: 128
pad: 1
kernel_size: 3
}
}
layer {
name: "relu2_1"
type: "ReLU"
bottom: "conv2_1"
top: "conv2_1"
}
layer {
name: "conv2_2"
type: "Convolution"
bottom: "conv2_1"
top: "conv2_2"
param {
lr_mult: 0
decay_mult: 0
}
param {
lr_mult: 0
decay_mult: 0
}
convolution_param {
num_output: 128
pad: 1
kernel_size: 3
}
}
layer {
name: "relu2_2"
type: "ReLU"
bottom: "conv2_2"
top: "conv2_2"
}
layer {
name: "pool2"
type: "Pooling"
bottom: "conv2_2"
top: "pool2"
pooling_param {
pool: MAX
kernel_size: 2
stride: 2
}
}
layer {
name: "conv3_1"
type: "Convolution"
bottom: "pool2"
top: "conv3_1"
param {
lr_mult: 1
decay_mult: 1
}
param {
lr_mult: 2
decay_mult: 0
}
convolution_param {
num_output: 256
pad: 1
kernel_size: 3
}
}
layer {
name: "relu3_1"
type: "ReLU"
bottom: "conv3_1"
top: "conv3_1"
}
layer {
name: "conv3_2"
type: "Convolution"
bottom: "conv3_1"
top: "conv3_2"
param {
lr_mult: 1
decay_mult: 1
}
param {
lr_mult: 2
decay_mult: 0
}
convolution_param {
num_output: 256
pad: 1
kernel_size: 3
}
}
layer {
name: "relu3_2"
type: "ReLU"
bottom: "conv3_2"
top: "conv3_2"
}
layer {
name: "conv3_3"
type: "Convolution"
bottom: "conv3_2"
top: "conv3_3"
param {
lr_mult: 1
decay_mult: 1
}
param {
lr_mult: 2
decay_mult: 0
}
convolution_param {
num_output: 256
pad: 1
kernel_size: 3
}
}
layer {
name: "relu3_3"
type: "ReLU"
bottom: "conv3_3"
top: "conv3_3"
}
layer {
name: "pool3"
type: "Pooling"
bottom: "conv3_3"
top: "pool3"
pooling_param {
pool: MAX
kernel_size: 2
stride: 2
}
}
layer {
name: "conv4_1"
type: "Convolution"
bottom: "pool3"
top: "conv4_1"
param {
lr_mult: 1
decay_mult: 1
}
param {
lr_mult: 2
decay_mult: 0
}
convolution_param {
num_output: 512
pad: 1
kernel_size: 3
}
}
layer {
name: "relu4_1"
type: "ReLU"
bottom: "conv4_1"
top: "conv4_1"
}
layer {
name: "conv4_2"
type: "Convolution"
bottom: "conv4_1"
top: "conv4_2"
param {
lr_mult: 1
decay_mult: 1
}
param {
lr_mult: 2
decay_mult: 0
}
convolution_param {
num_output: 512
pad: 1
kernel_size: 3
}
}
layer {
name: "relu4_2"
type: "ReLU"
bottom: "conv4_2"
top: "conv4_2"
}
layer {
name: "conv4_3"
type: "Convolution"
bottom: "conv4_2"
top: "conv4_3"
param {
lr_mult: 1
decay_mult: 1
}
param {
lr_mult: 2
decay_mult: 0
}
convolution_param {
num_output: 512
pad: 1
kernel_size: 3
}
}
layer {
name: "relu4_3"
type: "ReLU"
bottom: "conv4_3"
top: "conv4_3"
}
layer {
name: "pool4"
type: "Pooling"
bottom: "conv4_3"
top: "pool4"
pooling_param {
pool: MAX
kernel_size: 2
stride: 2
}
}
layer {
name: "conv5_1"
type: "Convolution"
bottom: "pool4"
top: "conv5_1"
param {
lr_mult: 1
decay_mult: 1
}
param {
lr_mult: 2
decay_mult: 0
}
convolution_param {
num_output: 512
pad: 1
kernel_size: 3
}
}
layer {
name: "relu5_1"
type: "ReLU"
bottom: "conv5_1"
top: "conv5_1"
}
layer {
name: "conv5_2"
type: "Convolution"
bottom: "conv5_1"
top: "conv5_2"
param {
lr_mult: 1
decay_mult: 1
}
param {
lr_mult: 2
decay_mult: 0
}
convolution_param {
num_output: 512
pad: 1
kernel_size: 3
}
}
layer {
name: "relu5_2"
type: "ReLU"
bottom: "conv5_2"
top: "conv5_2"
}
layer {
name: "conv5_3"
type: "Convolution"
bottom: "conv5_2"
top: "conv5_3"
param {
lr_mult: 1
decay_mult: 1
}
param {
lr_mult: 2
decay_mult: 0
}
convolution_param {
num_output: 512
pad: 1
kernel_size: 3
}
}
layer {
name: "relu5_3"
type: "ReLU"
bottom: "conv5_3"
top: "conv5_3"
}
layer {
name: "roi_pool5"
type: "ROIPooling"
bottom: "conv5_3"
bottom: "rois"
top: "pool5"
roi_pooling_param {
pooled_w: 7
pooled_h: 7
spatial_scale: 0.0625 # 1/16
}
}
layer {
name: "fc6_L"
type: "InnerProduct"
bottom: "pool5"
top: "fc6_L"
param {
lr_mult: 1
decay_mult: 1
}
inner_product_param {
num_output: 1024
bias_term: false
}
}
layer {
name: "fc6_U"
type: "InnerProduct"
bottom: "fc6_L"
top: "fc6_U"
param {
lr_mult: 1
decay_mult: 1
}
param {
lr_mult: 2
decay_mult: 0
}
inner_product_param {
num_output: 4096
}
}
layer {
name: "relu6"
type: "ReLU"
bottom: "fc6_U"
top: "fc6_U"
}
layer {
name: "drop6"
type: "Dropout"
bottom: "fc6_U"
top: "fc6_U"
dropout_param {
dropout_ratio: 0.5
}
}
layer {
name: "fc7_L"
type: "InnerProduct"
bottom: "fc6_U"
top: "fc7_L"
param {
lr_mult: 1
decay_mult: 1
}
inner_product_param {
num_output: 256
bias_term: false
}
}
layer {
name: "fc7_U"
type: "InnerProduct"
bottom: "fc7_L"
top: "fc7_U"
param {
lr_mult: 1
decay_mult: 1
}
param {
lr_mult: 2
decay_mult: 0
}
inner_product_param {
num_output: 4096
}
}
layer {
name: "relu7"
type: "ReLU"
bottom: "fc7_U"
top: "fc7_U"
}
layer {
name: "drop7"
type: "Dropout"
bottom: "fc7_U"
top: "fc7_U"
dropout_param {
dropout_ratio: 0.5
}
}
layer {
name: "cls_score"
type: "InnerProduct"
bottom: "fc7_U"
top: "cls_score"
param {
lr_mult: 1
decay_mult: 1
}
param {
lr_mult: 2
decay_mult: 0
}
inner_product_param {
num_output: 21
weight_filler {
type: "gaussian"
std: 0.01
}
bias_filler {
type: "constant"
value: 0
}
}
}
layer {
name: "bbox_pred"
type: "InnerProduct"
bottom: "fc7_U"
top: "bbox_pred"
param {
lr_mult: 1
decay_mult: 1
}
param {
lr_mult: 2
decay_mult: 0
}
inner_product_param {
num_output: 84
weight_filler {
type: "gaussian"
std: 0.001
}
bias_filler {
type: "constant"
value: 0
}
}
}
layer {
name: "cls_prob"
type: "Softmax"
bottom: "cls_score"
top: "cls_prob"
}
......@@ -394,6 +394,7 @@ layer {
roi_pooling_param {
pooled_w: 7
pooled_h: 7
spatial_scale: 0.0625 # 1/16
}
}
layer {
......
......@@ -418,6 +418,7 @@ layer {
roi_pooling_param {
pooled_w: 7
pooled_h: 7
spatial_scale: 0.0625 # 1/16
}
}
layer {
......
......@@ -194,6 +194,7 @@ layer {
roi_pooling_param {
pooled_w: 6
pooled_h: 6
spatial_scale: 0.0625 # 1/16
}
}
layer {
......
......@@ -194,6 +194,7 @@ layer {
roi_pooling_param {
pooled_w: 6
pooled_h: 6
spatial_scale: 0.0625 # 1/16
}
}
layer {
......
......@@ -215,6 +215,7 @@ layer {
roi_pooling_param {
pooled_w: 6
pooled_h: 6
spatial_scale: 0.0625 # 1/16
}
}
layer {
......
......@@ -17,14 +17,15 @@
#
import os
import os.path as osp
import sys
import numpy as np
# `pip install easydict` if you don't have it
from easydict import EasyDict as edict
# Add caffe to PYTHONPATH
caffe_path = os.path.abspath(os.path.join(os.path.dirname(__file__),
'..', 'caffe-fast-rcnn', 'python'))
caffe_path = osp.abspath(osp.join(osp.dirname(__file__), '..',
'caffe-fast-rcnn', 'python'))
sys.path.insert(0, caffe_path)
__C = edict()
......@@ -101,22 +102,26 @@ __C.TEST.BINARY = False
# MISC
#
# The mapping from image coordinates to feature map coordinates might cause
# some boxes that are distinct in image space to become identical in feature
# coordinates. If DEDUP_BOXES > 0, then DEDUP_BOXES is used as the scale factor
# for identifying duplicate boxes.
# 1/16 is correct for {Alex,Caffe}Net, VGG_CNN_M_1024, and VGG_16
__C.DEDUP_BOXES = 1./16.
# Pixel mean values (BGR order) as a (1, 1, 3) array
# These are the values originally used for training VGG_16
__C.PIXEL_MEANS = np.array([[[102.9801, 115.9465, 122.7717]]])
# Stride in input image pixels at ROI pooling level (network specific)
# 16 is true for {Alex,Caffe}Net, VGG_CNN_M_1024, and VGG16
# If your network has a different stride (e.g., VGG_CNN_S has stride 12)
# make sure to override this in a config file!)
__C.FEAT_STRIDE = 16
# For reproducibility
__C.RNG_SEED = 3
# A small number that's used many times
__C.EPS = 1e-14
# Root directory of project
__C.ROOT_DIR = osp.abspath(osp.join(osp.dirname(__file__), '..'))
def merge_a_into_b(a, b):
"""
Merge config dictionary a into config dictionary b, clobbering the options
......
......@@ -45,11 +45,11 @@ def _get_image_blob(im):
return blob, np.array(im_scale_factors)
def _get_rois_blob(im_rois, im_scale_factors):
feat_rois, levels = _map_im_rois_to_feat_rois(im_rois, im_scale_factors)
rois_blob = np.hstack((levels, feat_rois))[:, :, np.newaxis, np.newaxis]
rois, levels = _scale_im_rois(im_rois, im_scale_factors)
rois_blob = np.hstack((levels, rois))[:, :, np.newaxis, np.newaxis]
return rois_blob.astype(np.float32, copy=False)
def _map_im_rois_to_feat_rois(im_rois, scales):
def _scale_im_rois(im_rois, scales):
im_rois = im_rois.astype(np.float, copy=False)
if len(scales) > 1:
......@@ -63,9 +63,9 @@ def _map_im_rois_to_feat_rois(im_rois, scales):
else:
levels = np.zeros((im_rois.shape[0], 1), dtype=np.int)
feat_rois = np.round(im_rois * scales[levels] / cfg.FEAT_STRIDE)
rois = im_rois * scales[levels]
return feat_rois, levels
return rois, levels
def _get_blobs(im, rois):
blobs = {'data' : None, 'rois' : None}
......@@ -123,12 +123,13 @@ def im_detect(net, im, boxes):
# (some distinct image ROIs get mapped to the same feature ROI).
# Here, we identify duplicate feature ROIs, so we only compute features
# on the unique subset.
v = np.array([1, 1e3, 1e6, 1e9, 1e12])
hashes = blobs['rois'][:, :, 0, 0].dot(v.T)
_, index, inv_index = np.unique(hashes, return_index=True,
return_inverse=True)
blobs['rois'] = blobs['rois'][index, :, :, :]
boxes = boxes[index, :]
if cfg.DEDUP_BOXES > 0:
v = np.array([1, 1e3, 1e6, 1e9, 1e12])
hashes = np.round(blobs['rois'][:, :, 0, 0] * cfg.DEDUP_BOXES).dot(v)
_, index, inv_index = np.unique(hashes, return_index=True,
return_inverse=True)
blobs['rois'] = blobs['rois'][index, :, :, :]
boxes = boxes[index, :]
# reshape network inputs
base_shape = blobs['data'].shape
......@@ -152,9 +153,10 @@ def im_detect(net, im, boxes):
pred_boxes = _bbox_pred(boxes, box_deltas)
pred_boxes = _clip_boxes(pred_boxes, im.shape)
# Map scores and predictions back to the original set of boxes
scores = scores[inv_index, :]
pred_boxes = pred_boxes[inv_index, :]
if cfg.DEDUP_BOXES > 0:
# Map scores and predictions back to the original set of boxes
scores = scores[inv_index, :]
pred_boxes = pred_boxes[inv_index, :]
return scores, pred_boxes
......@@ -213,8 +215,7 @@ def test_net(net, imdb):
# Output directory will be something like:
# output/vgg16_fast_rcnn_iter_40000/voc_2007_test/
output_dir = os.path.join(os.path.dirname(__file__), 'output',
net.name, imdb.name)
output_dir = os.path.join(cfg.ROOT_DIR, 'output', net.name, imdb.name)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
......
......@@ -41,9 +41,9 @@ def get_minibatch(roidb):
= _sample_rois(roidb[im_i], fg_rois_per_image, rois_per_image)
# Add to ROIs blob
feat_rois = _map_im_rois_to_feat_rois(im_rois, im_scales[im_i])
batch_ind = im_i * np.ones((feat_rois.shape[0], 1))
rois_blob_this_image = np.hstack((batch_ind, feat_rois))
rois = _scale_im_rois(im_rois, im_scales[im_i])
batch_ind = im_i * np.ones((rois.shape[0], 1))
rois_blob_this_image = np.hstack((batch_ind, rois))
rois_blob = np.vstack((rois_blob, rois_blob_this_image))
# Add to labels, bbox targets, and bbox loss blobs
......@@ -131,12 +131,9 @@ def _get_image_blob(roidb, scale_inds):
return blob, im_scales
def _map_im_rois_to_feat_rois(im_rois, im_scale_factor):
"""
Map a ROI in image-pixel coordinates to a ROI in feature coordinates.
"""
feat_rois = np.round(im_rois * im_scale_factor / float(cfg.FEAT_STRIDE))
return feat_rois
def _scale_im_rois(im_rois, im_scale_factor):
rois = im_rois * im_scale_factor
return rois
def _get_bbox_regression_labels(bbox_target_data, num_classes):
"""
......@@ -164,7 +161,7 @@ def _vis_minibatch(im_blob, rois_blob, labels_blob, overlaps):
for i in xrange(rois_blob.shape[0]):
rois = rois_blob[i, :]
im_ind = rois[0]
roi = rois[1:] * cfg.FEAT_STRIDE
roi = rois[1:]
im = im_blob[im_ind, :, :, :].transpose((1, 2, 0)).copy()
im += cfg.PIXEL_MEANS
im = im[:, :, (2, 1, 0)]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment