Skip to content
Snippets Groups Projects
Commit 3471eab6 authored by Ross Girshick's avatar Ross Girshick
Browse files

good working state

parent eafc81ab
No related branches found
No related tags found
No related merge requests found
*.pyc
import numpy as np
# Scales used in the SPP-net paper
SCALES = (480, 576, 688, 864, 1200)
# Minibatch size
BATCH_SIZE = 128
# Fraction of minibatch that is foreground labeled (class > 0)
FG_FRACTION = 0.25
# Overlap threshold for a ROI to be considered foreground (if >= FG_THRESH)
FG_THRESH = 0.5
# Overlap threshold for a ROI to be considered background (class = 0 if
# overlap in [0.1, 0.5))
BG_THRESH_HI = 0.5
BG_THRESH_LO = 0.1
# Pixel mean values (BGR order) as a (1, 1, 3) array
PIXEL_MEANS = np.array([[[102.9801, 115.9465, 122.7717]]])
# Stride in input image pixels at ROI pooling level
FEAT_STRIDE = 16
# Max pixel size of a scaled input image
MAX_SIZE = 2000
import numpy as np
import cv2
import matplotlib.pyplot as plt
import fast_rcnn_config as conf
from keyboard import keyboard
SCALES = (480, 576, 688, 864, 1200)
BATCH_SIZE = 128
FG_FRACTION = 0.25
FG_THRESH = 0.5
BG_THRESH_HI = 0.5
BG_THRESH_LO = 0.1
PIXEL_MEANS = np.array([[[102.9801, 115.9465, 122.7717]]])
FEAT_STRIDE = 16
MAX_SIZE = 2000
def sample_rois(labels, overlaps, rois, fg_rois_per_image, rois_per_image):
"""Generate a random sample of ROIs comprising foreground and background
examples.
......@@ -30,13 +21,13 @@ def sample_rois(labels, overlaps, rois, fg_rois_per_image, rois_per_image):
rois (2d np array)
"""
# Select foreground ROIs
fg_inds = np.where(overlaps >= FG_THRESH)[0]
fg_inds = np.where(overlaps >= conf.FG_THRESH)[0]
fg_rois_per_this_image = np.minimum(fg_rois_per_image, fg_inds.size)
fg_inds = np.random.choice(fg_inds, size=fg_rois_per_this_image,
replace=False)
# Select background ROIs
bg_inds = np.where((overlaps < BG_THRESH_HI) &
(overlaps >= BG_THRESH_LO))[0]
bg_inds = np.where((overlaps < conf.BG_THRESH_HI) &
(overlaps >= conf.BG_THRESH_LO))[0]
bg_rois_per_this_image = rois_per_image - fg_rois_per_this_image
bg_rois_per_this_image = np.minimum(bg_rois_per_this_image,
bg_inds.size)
......@@ -59,18 +50,18 @@ def get_image_blob(window_db, scale_inds, do_flip):
im_scale_factors = []
for i in xrange(num_images):
im = cv2.imread(window_db[i]['image'])
# if do_flip:
# im = im[:, ::-1, :]
if do_flip:
im = im[:, ::-1, :]
im = im.astype(np.float32, copy=False)
im -= PIXEL_MEANS
im -= conf.PIXEL_MEANS
im_shape = im.shape
im_size = np.min(im_shape[0:2])
im_size_big = np.max(im_shape[0:2])
target_size = SCALES[scale_inds[i]]
target_size = conf.SCALES[scale_inds[i]]
im_scale = float(target_size) / float(im_size)
# Prevent the biggest axis from being more than MAX_SIZE
if np.round(im_scale * im_size_big) > MAX_SIZE:
im_scale = float(MAX_SIZE) / float(im_size_big)
if np.round(im_scale * im_size_big) > conf.MAX_SIZE:
im_scale = float(conf.MAX_SIZE) / float(im_size_big)
im = cv2.resize(im, None, None, fx=im_scale, fy=im_scale,
interpolation=cv2.INTER_LINEAR)
im_scale_factors.append(im_scale)
......@@ -78,31 +69,34 @@ def get_image_blob(window_db, scale_inds, do_flip):
max_shape = np.maximum(max_shape, im.shape)
blob = np.zeros((num_images, max_shape[0],
max_shape[1], max_shape[2]))
max_shape[1], max_shape[2]), dtype=np.float32)
for i in xrange(num_images):
im = processed_ims[i]
blob[i, 0:im.shape[0], 0:im.shape[1], :] = im
# Move channels (axis 3) to axis 1
# Axis order will become: (batch elem, channel, height, width)
channel_swap = (0, 3, 1, 2)
blob = blob.transpose(channel_swap)
return blob, im_scale_factors
def map_im_rois_to_feat_rois(im_rois, im_scale_factor):
feat_rois = np.round(im_rois * im_scale_factor / FEAT_STRIDE)
feat_rois = np.round(im_rois * im_scale_factor / conf.FEAT_STRIDE)
return feat_rois
def get_minibatch(window_db, random_flip=True):
def get_minibatch(window_db, random_flip=False):
# Decide to flip the entire batch or not
# do_flip = False if not random_flip else bool(np.random.randint(0, high=2))
do_flip = False if not random_flip else bool(np.random.randint(0, high=2))
assert(not do_flip)
num_images = len(window_db)
# Sample random scales to use for each image in this batch
random_scale_inds = np.random.randint(0, high=len(SCALES), size=num_images)
assert(BATCH_SIZE % num_images == 0), 'num_images must divide BATCH_SIZE'
rois_per_image = BATCH_SIZE / num_images
fg_rois_per_image = np.round(FG_FRACTION * rois_per_image)
random_scale_inds = \
np.random.randint(0, high=len(conf.SCALES), size=num_images)
assert(conf.BATCH_SIZE % num_images == 0), \
'num_images must divide BATCH_SIZE'
rois_per_image = conf.BATCH_SIZE / num_images
fg_rois_per_image = np.round(conf.FG_FRACTION * rois_per_image)
# Get the input blob, formatted for caffe
# Takes care of random scaling and flipping
do_flip = False
im_blob, im_scale_factors = get_image_blob(window_db,
random_scale_inds, do_flip)
# Now, build the region of interest and label blobs
......@@ -110,16 +104,23 @@ def get_minibatch(window_db, random_flip=True):
labels_blob = np.zeros((0), dtype=np.float32)
all_overlaps = []
for im_i in xrange(num_images):
# (labels, overlaps, x1, y1, x2, y2)
labels = window_db[im_i]['windows'][:, 0]
overlaps = window_db[im_i]['windows'][:, 1]
im_rois = window_db[im_i]['windows'][:, 2:]
# if do_flip:
# im_rois[:, (0, 2)] = window_db[im_i]['width'] - \
# im_rois[:, (2, 0)] - 1
if do_flip:
im_rois[:, (0, 2)] = window_db[im_i]['width'] - \
im_rois[:, (2, 0)] - 1
labels, overlaps, im_rois = sample_rois(labels, overlaps, im_rois,
fg_rois_per_image,
rois_per_image)
feat_rois = map_im_rois_to_feat_rois(im_rois, im_scale_factors[im_i])
# Assert various bounds
assert((feat_rois[:, 2] >= feat_rois[:, 0]).all())
assert((feat_rois[:, 3] >= feat_rois[:, 1]).all())
assert((feat_rois >= 0).all())
assert((feat_rois < np.max(im_blob.shape[2:4]) *
im_scale_factors[im_i] / conf.FEAT_STRIDE).all())
rois_blob_this_image = \
np.append(im_i * np.ones((feat_rois.shape[0], 1)), feat_rois,
axis=1)
......@@ -134,9 +135,9 @@ def vis_minibatch(im_blob, rois_blob, labels_blob, overlaps):
for i in xrange(rois_blob.shape[0]):
rois = rois_blob[i, :]
im_ind = rois[0]
roi = rois[1:] * FEAT_STRIDE
roi = rois[1:] * conf.FEAT_STRIDE
im = im_blob[im_ind, :, :, :].transpose((1, 2, 0)).copy()
im += PIXEL_MEANS
im += conf.PIXEL_MEANS
im = im[:, :, (2, 1, 0)]
im = im.astype(np.uint8)
cls = labels_blob[i]
......
name: "CaffeNet"
input: "data"
input_dim: 7
input_dim: 3
input_dim: 1713
input_dim: 1713
input: "rois"
input_dim: 1 # to be changed on-the-fly
input_dim: 5 # [level, x1, y1, x2, y2] zero-based indexing
input_dim: 1
input_dim: 1
input: "labels"
input_dim: 1 # to be changed on-the-fly
input_dim: 1
input_dim: 1
input_dim: 1
layers {
name: "conv1"
type: CONVOLUTION
bottom: "data"
top: "conv1"
convolution_param {
num_output: 96
kernel_size: 11
stride: 4
pad: 5
}
# Learning parameters
blobs_lr: 0
blobs_lr: 0
weight_decay: 0
weight_decay: 0
}
layers {
name: "relu1"
type: RELU
bottom: "conv1"
top: "conv1"
}
layers {
name: "pool1"
type: POOLING
bottom: "conv1"
top: "pool1"
pooling_param {
pool: MAX
kernel_size: 3
stride: 2
pad: 1
}
}
layers {
name: "norm1"
type: LRN
bottom: "pool1"
top: "norm1"
lrn_param {
local_size: 5
alpha: 0.0001
beta: 0.75
}
}
layers {
name: "conv2"
type: CONVOLUTION
bottom: "norm1"
top: "conv2"
convolution_param {
num_output: 256
kernel_size: 5
pad: 2
group: 2
}
# Learning parameters
blobs_lr: 1
blobs_lr: 2
weight_decay: 1
weight_decay: 0
}
layers {
name: "relu2"
type: RELU
bottom: "conv2"
top: "conv2"
}
layers {
name: "pool2"
type: POOLING
bottom: "conv2"
top: "pool2"
pooling_param {
pool: MAX
kernel_size: 3
stride: 2
pad: 1
}
}
layers {
name: "norm2"
type: LRN
bottom: "pool2"
top: "norm2"
lrn_param {
local_size: 5
alpha: 0.0001
beta: 0.75
}
}
layers {
name: "conv3"
type: CONVOLUTION
bottom: "norm2"
top: "conv3"
convolution_param {
num_output: 384
kernel_size: 3
pad: 1
}
# Learning parameters
blobs_lr: 1
blobs_lr: 2
weight_decay: 1
weight_decay: 0
}
layers {
name: "relu3"
type: RELU
bottom: "conv3"
top: "conv3"
}
layers {
name: "conv4"
type: CONVOLUTION
bottom: "conv3"
top: "conv4"
convolution_param {
num_output: 384
kernel_size: 3
pad: 1
group: 2
}
# Learning parameters
blobs_lr: 1
blobs_lr: 2
weight_decay: 1
weight_decay: 0
}
layers {
name: "relu4"
type: RELU
bottom: "conv4"
top: "conv4"
}
layers {
name: "conv5"
type: CONVOLUTION
bottom: "conv4"
top: "conv5"
convolution_param {
num_output: 256
kernel_size: 3
pad: 1
group: 2
}
# Learning parameters
blobs_lr: 1
blobs_lr: 2
weight_decay: 1
weight_decay: 0
}
layers {
name: "relu5"
type: RELU
bottom: "conv5"
top: "conv5"
}
layers {
name: "roi_pool5"
type: ROI_POOLING
bottom: "conv5"
bottom: "rois"
top: "pool5"
roi_pooling_param {
pooled_w: 6
pooled_h: 6
}
}
layers {
name: "fc6"
type: INNER_PRODUCT
bottom: "pool5"
top: "fc6"
inner_product_param {
num_output: 4096
}
# Learning parameters
blobs_lr: 1
blobs_lr: 2
weight_decay: 1
weight_decay: 0
}
layers {
name: "relu6"
type: RELU
bottom: "fc6"
top: "fc6"
}
layers {
name: "drop6"
type: DROPOUT
bottom: "fc6"
top: "fc6"
dropout_param {
dropout_ratio: 0.5
}
}
layers {
name: "fc7"
type: INNER_PRODUCT
bottom: "fc6"
top: "fc7"
inner_product_param {
num_output: 4096
}
# Learning parameters
blobs_lr: 1
blobs_lr: 2
weight_decay: 1
weight_decay: 0
}
layers {
name: "relu7"
type: RELU
bottom: "fc7"
top: "fc7"
}
layers {
name: "drop7"
type: DROPOUT
bottom: "fc7"
top: "fc7"
dropout_param {
dropout_ratio: 0.5
}
}
layers {
name: "fc8_pascal"
type: INNER_PRODUCT
bottom: "fc7"
top: "fc8_pascal"
inner_product_param {
num_output: 21
weight_filler {
type: "gaussian"
std: 0.01
}
bias_filler {
type: "constant"
value: 0
}
}
# Learning parameters
blobs_lr: 1
blobs_lr: 2
weight_decay: 1
weight_decay: 0
}
layers {
name: "loss"
type: SOFTMAX_LOSS
bottom: "fc8_pascal"
bottom: "labels"
}
train_net: "model-defs/pyramid_fcs_only.prototxt"
train_net: "model-defs/pyramid.prototxt"
#test_iter: 100
#test_interval: 1000
base_lr: 0.001
......@@ -11,4 +11,4 @@ max_iter: 500000
momentum: 0.9
weight_decay: 0.0005
snapshot: 10000
snapshot_prefix: "snapshots/pyramid_finetune_fcs_only"
snapshot_prefix: "snapshots/pyramid_finetune"
......@@ -133,6 +133,7 @@ def train_model(solver_def_path, window_db_path, pretrained_model=None,
def train_model_random_scales(solver_def_path, window_db_path,
pretrained_model=None, GPU_ID=None):
IMAGES_PER_BATCH = 4
solver, window_db = \
load_solver_and_window_db(solver_def_path,
window_db_path,
......@@ -142,17 +143,15 @@ def train_model_random_scales(solver_def_path, window_db_path,
if GPU_ID is not None:
caffe.set_device(GPU_ID)
# TODO(rbg): fix this temp hack
window_db = window_db[0:5008]
max_epochs = 100
dp = DeepPyramid(solver.net)
for epoch in xrange(max_epochs):
# TODO(rbg): shuffle window_db
for db_i in xrange(0, len(window_db), 4):
shuffled_inds = np.random.permutation(np.arange(len(window_db)))
lim = (len(shuffled_inds) / IMAGES_PER_BATCH) * IMAGES_PER_BATCH
shuffled_inds = shuffled_inds[0:lim]
for shuffled_i in xrange(0, len(shuffled_inds), 4):
start_t = time.time()
minibatch_db = window_db[db_i:db_i + 4]
db_inds = shuffled_inds[shuffled_i:shuffled_i + 4]
minibatch_db = [window_db[i] for i in db_inds]
im_blob, rois_blob, labels_blob = \
finetuning.get_minibatch(minibatch_db)
......@@ -164,11 +163,14 @@ def train_model_random_scales(solver_def_path, window_db_path,
solver.net.blobs['rois'].reshape(num_rois, 5, 1, 1)
solver.net.blobs['labels'].reshape(num_rois, 1, 1, 1)
# Copy data into net's input blobs
solver.net.blobs['data'].data[...] = im_blob
solver.net.blobs['data'].data[...] = \
im_blob.astype(np.float32, copy=False)
solver.net.blobs['rois'].data[...] = \
rois_blob[:, :, np.newaxis, np.newaxis]
rois_blob[:, :, np.newaxis, np.newaxis] \
.astype(np.float32, copy=False)
solver.net.blobs['labels'].data[...] = \
labels_blob[:, np.newaxis, np.newaxis, np.newaxis]
labels_blob[:, np.newaxis, np.newaxis, np.newaxis] \
.astype(np.float32, copy=False)
# print 'epoch {:d} image {:d}'.format(epoch, db_i)
# print_label_stats(labels_blob)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment