Skip to content
Snippets Groups Projects
Commit c68656df authored by Ross Girshick
Browse files

switch to pure python nms

parent bd90015c
No related branches found
No related tags found
No related merge requests found
......@@ -6,7 +6,7 @@ import numpy as np
import matplotlib.pyplot as plt
import cv2
import caffe
import utils.cython_nms
import utils.nms
import cPickle
import heapq
import utils.blob
......@@ -180,12 +180,13 @@ def _apply_nms(all_boxes, thresh):
dets = all_boxes[cls_ind][im_ind]
if dets == []:
continue
keep = utils.cython_nms.nms(dets, thresh)
keep = utils.nms.nms(dets, thresh)
if len(keep) == 0:
continue
nms_boxes[cls_ind][im_ind] = dets[keep, :].copy()
return nms_boxes
def test_net(net, imdb):
num_images = len(imdb.image_index)
# heuristic: keep an average of 40 detections per class per images prior
......
......@@ -8,10 +8,6 @@ ext_modules = [
Extension(
"utils.cython_bbox",
["utils/bbox.pyx"]
),
Extension(
"utils.cython_nms",
["utils/nms.pyx"]
)
]
cmdclass.update({'build_ext': build_ext})
......
"""
Fast code for dealing with image windows.
Written by Sergey Karayev. See bbox.pyx.license.txt.
Written by Sergey Karayev and used with permission.
"""
cimport cython
import numpy as np
......
Copyright (c) 2014, Sergey Karayev
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import numpy as np
def nms(dets, thresh):
    """Greedy non-maximum suppression on scored bounding boxes.

    Boxes are visited in order of decreasing score. Each visited box is
    kept, and every remaining box whose intersection-over-union with it is
    strictly greater than ``thresh`` is discarded.

    Parameters
    ----------
    dets : array-like, shape (N, 5)
        One row per detection: ``[x1, y1, x2, y2, score]``. Anything
        accepted by ``np.asarray`` works (ndarray, list of lists, ...).
    thresh : float
        Overlap (IoU) threshold above which a lower-scoring box is
        suppressed.

    Returns
    -------
    keep : list of int
        Row indices into ``dets`` of the surviving boxes, highest score
        first. Empty list when ``dets`` contains no detections.
    """
    dets = np.asarray(dets)
    # An empty input (e.g. ``[]`` or a (0, 5) array) has nothing to keep;
    # bail out before the 2-D column indexing below, which would fail on a
    # 1-D empty array.
    if dets.size == 0:
        return []

    x1 = dets[:, 0]
    y1 = dets[:, 1]
    x2 = dets[:, 2]
    y2 = dets[:, 3]
    scores = dets[:, 4]

    # The +1 makes width/height count both end coordinates.
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)

        # Candidates still alive, excluding the box just kept.
        rest = order[1:]

        # Intersection rectangle of box i with every remaining candidate.
        xx1 = np.maximum(x1[i], x1[rest])
        yy1 = np.maximum(y1[i], y1[rest])
        xx2 = np.minimum(x2[i], x2[rest])
        yy2 = np.minimum(y2[i], y2[rest])

        # Clamp at zero so disjoint boxes contribute no intersection.
        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        ovr = inter / (areas[i] + areas[rest] - inter)

        # ``inds`` indexes into ``rest``; shift by one to index ``order``.
        inds = np.where(ovr <= thresh)[0]
        order = order[inds + 1]

    return keep
"""
Fast Cython code for NMS of detection bounding boxes.
Originally written by Kai Wang for the plex project:
https://github.com/shiaokai/plex
See nms.pyx.license.txt.
"""
import numpy as np
cimport numpy as np
cimport cython
# Element type used for all box arrays in this module, as a Python-level
# dtype object and as the matching C-level typedef.
DTYPE = np.float32
ctypedef np.float32_t DTYPE_t


# Small C-level min/max helpers; ties return the first argument.
cdef inline int int_max(int a, int b):
    if a >= b:
        return a
    return b


cdef inline int int_min(int a, int b):
    if a <= b:
        return a
    return b


cdef inline float float_max(float a, float b):
    if a >= b:
        return a
    return b


cdef inline float float_min(float a, float b):
    if a <= b:
        return a
    return b
def nms(np.ndarray[DTYPE_t, ndim=2] bbs, overlap_thr = 0.5):
    """
    NMS detection boxes.

    Boxes are sorted by descending score; each surviving box suppresses
    every lower-scoring box whose intersection-over-union with it is
    strictly greater than ``overlap_thr``.

    Parameters
    ----------
    bbs: (N, 5) ndarray of float32
        [xmin, ymin, xmax, ymax, score]
    overlap_thr: float32

    Returns
    -------
    keep_idxs: list of int
        Indices to keep in the original bbs.
    """
    # No detections: return an empty index list, matching the documented
    # return type (previously a (0, 5) float array was returned here).
    if bbs.shape[0] == 0:
        return []

    # sort bbs by score (descending); sidx maps sorted position -> original row
    cdef np.ndarray[np.int_t, ndim=1] sidx = np.argsort(bbs[:,4])
    sidx = sidx[::-1]
    bbs = bbs[sidx,:]

    # keep[k] refers to the k-th box in sorted order
    keep = [True] * bbs.shape[0]

    cdef np.ndarray[DTYPE_t, ndim=1] bbs_start_x = bbs[:,0]
    cdef np.ndarray[DTYPE_t, ndim=1] bbs_start_y = bbs[:,1]
    cdef np.ndarray[DTYPE_t, ndim=1] bbs_end_x = bbs[:,2]
    cdef np.ndarray[DTYPE_t, ndim=1] bbs_end_y = bbs[:,3]
    # The +1 makes width/height count both end coordinates.
    cdef np.ndarray[DTYPE_t, ndim=1] bbs_areas = \
        (bbs_end_y - bbs_start_y + 1) * (bbs_end_x - bbs_start_x + 1)

    cdef int i, jj
    cdef DTYPE_t intersect_width, intersect_height
    cdef DTYPE_t intersect_area, union_area
    cdef DTYPE_t overlap
    cdef DTYPE_t bbs_end_x_i, bbs_end_y_i, bbs_areas_i
    cdef DTYPE_t bbs_start_x_i, bbs_start_y_i

    # start at highest scoring bb
    for i in range(bbs.shape[0]):
        if not(keep[i]):
            continue
        # Cache box i's fields so the inner loop does not re-index them.
        bbs_end_x_i = bbs_end_x[i]
        bbs_end_y_i = bbs_end_y[i]
        bbs_areas_i = bbs_areas[i]
        bbs_start_x_i = bbs_start_x[i]
        bbs_start_y_i = bbs_start_y[i]
        for jj in range(i+1, bbs.shape[0]):
            if not(keep[jj]):
                continue
            # mask out all lower-scoring boxes that overlap box i too much
            intersect_width = float_min(bbs_end_x_i, bbs_end_x[jj]) - \
                float_max(bbs_start_x_i, bbs_start_x[jj]) + 1
            if intersect_width <= 0:
                continue
            intersect_height = float_min(bbs_end_y_i, bbs_end_y[jj]) - \
                float_max(bbs_start_y_i, bbs_start_y[jj]) + 1
            # Boxes that do not overlap vertically cannot intersect
            # (this check previously re-tested intersect_width by mistake).
            if intersect_height <= 0:
                continue
            intersect_area = intersect_width * intersect_height
            union_area = bbs_areas_i + bbs_areas[jj] - intersect_area
            overlap = intersect_area / union_area
            # threshold and reject
            if overlap > overlap_thr:
                keep[jj] = False

    # Return original detection indices
    keep_idxs=[]
    for i in range(len(keep)):
        if keep[i]:
            keep_idxs.append(sidx[i])
    return keep_idxs
This software is Copyright (c) 2012 The Regents of the University of California. All
Rights Reserved.
Permission to use, copy, modify, and distribute this software and its documentation
for educational, research and non-profit purposes, without fee, and without a written
agreement is hereby granted, provided that the above copyright notice, this
paragraph and the following three paragraphs appear in all copies.
Permission to incorporate this software into commercial products may be obtained
by contacting:
Technology Transfer Office
9500 Gilman Drive, Mail Code 0910
University of California
La Jolla, CA 92093-0910
(858) 534-5815
invent@ucsd.edu
This software program and documentation are copyrighted by The Regents of the
University of California. The software program and documentation are supplied "as
is", without any accompanying services from The Regents. The Regents does not
warrant that the operation of the program will be uninterrupted or error-free. The
end-user understands that the program was developed for research purposes and is
advised not to rely exclusively on the program for any reason.
IN NO EVENT SHALL THE UNIVERSITY OF CALIFORNIA BE LIABLE TO
ANY PARTY FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR
CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, ARISING
OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION,
EVEN IF THE UNIVERSITY OF CALIFORNIA HAS BEEN ADVISED OF
THE POSSIBILITY OF SUCH DAMAGE. THE UNIVERSITY OF
CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES,
INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE.
THE SOFTWARE PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, AND
THE UNIVERSITY OF CALIFORNIA HAS NO OBLIGATIONS TO
PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR
MODIFICATIONS.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment