From 04e5e1b8c18d07ea8646205284402d6c27a780ff Mon Sep 17 00:00:00 2001 From: Ross Girshick <ross.girshick@gmail.com> Date: Thu, 2 Apr 2015 11:12:49 -0700 Subject: [PATCH] change where nets and results are output; add EXP_DIR to config --- src/fast_rcnn_config.py | 16 +++++++++++++--- src/fast_rcnn_test.py | 6 ++---- src/fast_rcnn_train.py | 16 ++++++++++++---- 3 files changed, 27 insertions(+), 11 deletions(-) diff --git a/src/fast_rcnn_config.py b/src/fast_rcnn_config.py index c4e03ad..4a21c02 100644 --- a/src/fast_rcnn_config.py +++ b/src/fast_rcnn_config.py @@ -122,7 +122,17 @@ __C.EPS = 1e-14 # Root directory of project __C.ROOT_DIR = osp.abspath(osp.join(osp.dirname(__file__), '..')) -def merge_a_into_b(a, b): +# Place outputs under an experiments directory +__C.EXP_DIR = 'default' + +def get_output_path(imdb, net): + path = os.path.join(__C.ROOT_DIR, 'output', __C.EXP_DIR, imdb.name) + if net is None: + return path + else: + return os.path.join(path, net.name) + +def _merge_a_into_b(a, b): """ Merge config dictionary a into config dictionary b, clobbering the options in b whenever they are also specified in a. @@ -142,7 +152,7 @@ def merge_a_into_b(a, b): # recursively merge dicts if type(v) is edict: try: - merge_a_into_b(a[k], b[k]) + _merge_a_into_b(a[k], b[k]) except: print('Error under config key: {}'.format(k)) raise @@ -159,4 +169,4 @@ def cfg_from_file(filename): with open(filename, 'r') as f: yaml_cfg = edict(yaml.load(f)) - merge_a_into_b(yaml_cfg, __C) + _merge_a_into_b(yaml_cfg, __C) diff --git a/src/fast_rcnn_test.py b/src/fast_rcnn_test.py index 85ddd7c..f995b1c 100644 --- a/src/fast_rcnn_test.py +++ b/src/fast_rcnn_test.py @@ -5,7 +5,7 @@ # Written by Ross Girshick # -------------------------------------------------------- -from fast_rcnn_config import cfg +from fast_rcnn_config import cfg, get_output_path import argparse from utils.timer import Timer import numpy as np @@ -213,9 +213,7 @@ def test_net(net, imdb): all_boxes = [[[] for _ in xrange(num_images)] for _ in xrange(imdb.num_classes)] - # Output directory will be something like: - # output/vgg16_fast_rcnn_iter_40000/voc_2007_test/ - output_dir = os.path.join(cfg.ROOT_DIR, 'output', net.name, imdb.name) + output_dir = get_output_path(imdb, net) if not os.path.exists(output_dir): os.makedirs(output_dir) diff --git a/src/fast_rcnn_train.py b/src/fast_rcnn_train.py index 5dc494d..622657a 100644 --- a/src/fast_rcnn_train.py +++ b/src/fast_rcnn_train.py @@ -5,20 +5,22 @@ # Written by Ross Girshick # -------------------------------------------------------- -from fast_rcnn_config import cfg +from fast_rcnn_config import cfg, get_output_path import numpy as np import cv2 import caffe import finetuning import bbox_regression_targets +import os from caffe.proto import caffe_pb2 import google.protobuf as pb2 class SolverWrapper(object): - def __init__(self, solver_prototxt, pretrained_model=None): + def __init__(self, solver_prototxt, imdb, pretrained_model=None): self.bbox_means = None self.bbox_stds = None + self.imdb = imdb self.solver = caffe.SGDSolver(solver_prototxt) if pretrained_model is not None: @@ -47,10 +49,16 @@ class SolverWrapper(object): self.solver.net.params['bbox_pred'][1].data[...] = \ self.solver.net.params['bbox_pred'][1].data + means + output_dir = get_output_path(self.imdb, None) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + infix = ('_' + cfg.TRAIN.SNAPSHOT_INFIX if cfg.TRAIN.SNAPSHOT_INFIX != '' else '') filename = self.solver_param.snapshot_prefix + infix + \ - '_iter_{:d}'.format(self.solver.iter) + '.caffemodel' + '_iter_{:d}'.format(self.solver.iter) + '.caffemodel' + filename = os.path.join(output_dir, filename) + self.solver.net.save(str(filename)) print 'Wrote snapshot to: {:s}'.format(filename) @@ -169,7 +177,7 @@ def train_net(solver_prototxt, imdb, pretrained_model=None, max_iters=40000): bbox_regression_targets.append_bbox_regression_targets(roidb) print 'done' - sw = SolverWrapper(solver_prototxt, pretrained_model=pretrained_model) + sw = SolverWrapper(solver_prototxt, imdb, pretrained_model=pretrained_model) sw.bbox_means = means sw.bbox_stds = stds -- GitLab