Skip to content
Snippets Groups Projects
Commit 04e5e1b8 authored by Ross Girshick's avatar Ross Girshick
Browse files

change where nets and results are output; add EXP_DIR to config

parent c4199cb7
No related branches found
No related tags found
No related merge requests found
......@@ -122,7 +122,17 @@ __C.EPS = 1e-14
# Root directory of project
__C.ROOT_DIR = osp.abspath(osp.join(osp.dirname(__file__), '..'))
def merge_a_into_b(a, b):
# Place outputs under an experiments directory
__C.EXP_DIR = 'default'
def get_output_path(imdb, net):
path = os.path.join(__C.ROOT_DIR, 'output', __C.EXP_DIR, imdb.name)
if net is None:
return path
else:
return os.path.join(path, net.name)
def _merge_a_into_b(a, b):
"""
Merge config dictionary a into config dictionary b, clobbering the options
in b whenever they are also specified in a.
......@@ -142,7 +152,7 @@ def merge_a_into_b(a, b):
# recursively merge dicts
if type(v) is edict:
try:
merge_a_into_b(a[k], b[k])
_merge_a_into_b(a[k], b[k])
except:
print('Error under config key: {}'.format(k))
raise
......@@ -159,4 +169,4 @@ def cfg_from_file(filename):
with open(filename, 'r') as f:
yaml_cfg = edict(yaml.load(f))
merge_a_into_b(yaml_cfg, __C)
_merge_a_into_b(yaml_cfg, __C)
......@@ -5,7 +5,7 @@
# Written by Ross Girshick
# --------------------------------------------------------
from fast_rcnn_config import cfg
from fast_rcnn_config import cfg, get_output_path
import argparse
from utils.timer import Timer
import numpy as np
......@@ -213,9 +213,7 @@ def test_net(net, imdb):
all_boxes = [[[] for _ in xrange(num_images)]
for _ in xrange(imdb.num_classes)]
# Output directory will be something like:
# output/vgg16_fast_rcnn_iter_40000/voc_2007_test/
output_dir = os.path.join(cfg.ROOT_DIR, 'output', net.name, imdb.name)
output_dir = get_output_path(imdb, net)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
......
......@@ -5,20 +5,22 @@
# Written by Ross Girshick
# --------------------------------------------------------
from fast_rcnn_config import cfg
from fast_rcnn_config import cfg, get_output_path
import numpy as np
import cv2
import caffe
import finetuning
import bbox_regression_targets
import os
from caffe.proto import caffe_pb2
import google.protobuf as pb2
class SolverWrapper(object):
def __init__(self, solver_prototxt, pretrained_model=None):
def __init__(self, solver_prototxt, imdb, pretrained_model=None):
self.bbox_means = None
self.bbox_stds = None
self.imdb = imdb
self.solver = caffe.SGDSolver(solver_prototxt)
if pretrained_model is not None:
......@@ -47,10 +49,16 @@ class SolverWrapper(object):
self.solver.net.params['bbox_pred'][1].data[...] = \
self.solver.net.params['bbox_pred'][1].data + means
output_dir = get_output_path(self.imdb, None)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
infix = ('_' + cfg.TRAIN.SNAPSHOT_INFIX
if cfg.TRAIN.SNAPSHOT_INFIX != '' else '')
filename = self.solver_param.snapshot_prefix + infix + \
'_iter_{:d}'.format(self.solver.iter) + '.caffemodel'
'_iter_{:d}'.format(self.solver.iter) + '.caffemodel'
filename = os.path.join(output_dir, filename)
self.solver.net.save(str(filename))
print 'Wrote snapshot to: {:s}'.format(filename)
......@@ -169,7 +177,7 @@ def train_net(solver_prototxt, imdb, pretrained_model=None, max_iters=40000):
bbox_regression_targets.append_bbox_regression_targets(roidb)
print 'done'
sw = SolverWrapper(solver_prototxt, pretrained_model=pretrained_model)
sw = SolverWrapper(solver_prototxt, imdb, pretrained_model=pretrained_model)
sw.bbox_means = means
sw.bbox_stds = stds
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment