Skip to content

Commit d59d522

Browse files
committed
Add training and evaluation code.
1 parent 451b47f commit d59d522

11 files changed

+1255
-1
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
*.pyc
22
.ipynb_checkpoints
33
test.ipynb
4+
run.sh
45
feature/
56
data/
67
checkpoints/

LICENSE

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2018 Yin Cui
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

README.md

+20-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ Notice that we used a mini validation set (./inat_minival.txt) contains 9,697 im
2020

2121

2222
## Preparation:
23-
+ Clone the repo with recursive
23+
+ Clone the repo with the `--recursive` flag:
2424
```bash
2525
git clone --recursive https://github.com/richardaecn/cvpr18-inaturalist-transfer.git
2626
```
@@ -94,6 +94,25 @@ DomainSimilarityDemo.ipynb
9494
```
9595

9696

97+
## Training and Evaluation
98+
+ Convert the dataset into '.tfrecord' files:
99+
```
100+
python convert_dataset.py --dataset_name=cub_200 --num_shards=10
101+
```
102+
+ Train (fine-tune) the model on 1 GPU:
103+
```
104+
CUDA_VISIBLE_DEVICES=0 ./train.sh
105+
```
106+
+ Evaluate the model on another GPU simultaneously:
107+
```
108+
CUDA_VISIBLE_DEVICES=1 ./eval.sh
109+
```
110+
+ Run TensorBoard for visualization:
111+
```
112+
tensorboard --logdir=./checkpoints/cub_200/ --port=6006
113+
```
114+
115+
97116
## Citation
98117
If you find our work helpful in your research, please cite it as:
99118
```latex

convert_dataset.py

+161
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
"""Converts data to TFRecords of TF-Example protos.
2+
python convert_dataset.py --dataset_name=cub_200 --num_shards=10
3+
"""
4+
5+
from __future__ import absolute_import
6+
from __future__ import division
7+
from __future__ import print_function
8+
9+
import math
10+
import os
11+
import random
12+
import sys
13+
14+
import tensorflow as tf
15+
16+
sys.path.insert(0, './slim/')
17+
from datasets import dataset_utils
18+
19+
# Command-line flags for this script; values are read through FLAGS below.
tf.app.flags.DEFINE_string(
    'dataset_name', None,
    'The name of the dataset to convert, one of "ILSVRC2012", "inat2017", '
    '"aircraft", "cub_200", "flower_102", "food_101", "nabirds", '
    '"stanford_cars", "stanford_dogs"')

tf.app.flags.DEFINE_integer(
    'num_shards', 10, 'The number of shards per dataset split.')

# Shared handle used by the functions in this module to read flag values.
FLAGS = tf.app.flags.FLAGS
30+
31+
32+
class ImageReader(object):
    """Wraps a small TensorFlow graph fragment that decodes JPEG bytes.

    The decode op is built once in the constructor and then evaluated through
    a caller-supplied session, so many images can reuse the same op.
    """

    def __init__(self):
        # String placeholder fed with the raw encoded bytes of one image.
        self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
        # Decode op that always yields 3 channels.
        self._decode_jpeg = tf.image.decode_jpeg(
            self._decode_jpeg_data, channels=3)

    def read_image_dims(self, sess, image_data):
        """Decodes `image_data` and returns its (height, width).

        Args:
          sess: Session used to evaluate the decode op.
          image_data: Encoded image bytes.

        Returns:
          A 2-tuple holding the first two entries of the decoded shape.
        """
        decoded_shape = self.decode_jpeg(sess, image_data).shape
        height = decoded_shape[0]
        width = decoded_shape[1]
        return height, width

    def decode_jpeg(self, sess, image_data):
        """Runs the decode op on `image_data` and returns the decoded array.

        Args:
          sess: Session used to evaluate the decode op.
          image_data: Encoded image bytes fed to the placeholder.

        Returns:
          The decoded image; asserted to be rank 3 with 3 channels.
        """
        feed = {self._decode_jpeg_data: image_data}
        decoded = sess.run(self._decode_jpeg, feed_dict=feed)
        rank = len(decoded.shape)
        assert rank == 3
        channels = decoded.shape[2]
        assert channels == 3
        return decoded
50+
51+
52+
def _get_filenames_and_labels(dataset_dir):
53+
train_filenames = []
54+
val_filenames = []
55+
train_labels = []
56+
val_labels = []
57+
for line in open(os.path.join(dataset_dir, 'train.txt'), 'r'):
58+
line_list = line.strip().split(': ')
59+
train_filenames.append(os.path.join(dataset_dir, line_list[0]))
60+
train_labels.append(int(line_list[1]))
61+
for line in open(os.path.join(dataset_dir, 'val.txt'), 'r'):
62+
line_list = line.strip().split(': ')
63+
val_filenames.append(os.path.join(dataset_dir, line_list[0]))
64+
val_labels.append(int(line_list[1]))
65+
return train_filenames, val_filenames, train_labels, val_labels
66+
67+
68+
def _get_dataset_filename(dataset_dir, split_name, shard_id):
    """Builds the path of one TFRecord shard inside `dataset_dir`.

    Args:
      dataset_dir: Directory that holds the shard files.
      split_name: Split prefix used in the file name.
      shard_id: Index of the shard, zero-padded to five digits in the name.

    Returns:
      `dataset_dir` joined with `<split>_<shard>-of-<num_shards>.tfrecord`,
      where the total comes from `FLAGS.num_shards`.
    """
    name_parts = (split_name, shard_id, FLAGS.num_shards)
    shard_basename = '%s_%05d-of-%05d.tfrecord' % name_parts
    return os.path.join(dataset_dir, shard_basename)
72+
73+
74+
def _convert_dataset(split_name, filenames, labels, dataset_dir):
    """Converts the given image files to a sharded TFRecord dataset.

    The examples are divided into `FLAGS.num_shards` contiguous index ranges
    and each range is written to its own file, named by
    `_get_dataset_filename`. Progress is reported on stdout.

    Args:
      split_name: The name of the split, either 'train' or 'validation'.
      filenames: A sequence of image file paths. Each file is decoded with the
        JPEG decoder to obtain its dimensions and stored with format 'jpg'.
      labels: A sequence of integer class ids (starting at 0), aligned with
        `filenames`.
      dataset_dir: The directory where the converted datasets are stored.
    """
    assert split_name in ['train', 'validation']

    # Ceiling division so the final shard absorbs any remainder.
    num_per_shard = int(math.ceil(len(filenames) / float(FLAGS.num_shards)))

    with tf.Graph().as_default():
        image_reader = ImageReader()

        with tf.Session('') as sess:

            for shard_id in range(FLAGS.num_shards):
                output_filename = _get_dataset_filename(
                    dataset_dir, split_name, shard_id)

                with tf.python_io.TFRecordWriter(
                        output_filename) as tfrecord_writer:
                    start_ndx = shard_id * num_per_shard
                    end_ndx = min(
                        (shard_id + 1) * num_per_shard, len(filenames))
                    for i in range(start_ndx, end_ndx):
                        sys.stdout.write(
                            '\r>> Converting %s image %d/%d shard %d' % (
                                split_name, i + 1, len(filenames), shard_id))
                        sys.stdout.flush()

                        # Read the encoded bytes inside a context manager so
                        # the handle is closed right away instead of leaking
                        # one open file per image.
                        with tf.gfile.FastGFile(
                                filenames[i], 'rb') as image_file:
                            image_data = image_file.read()
                        height, width = image_reader.read_image_dims(
                            sess, image_data)
                        class_id = labels[i]

                        example = dataset_utils.image_to_tfexample(
                            image_data, b'jpg', height, width, class_id)
                        tfrecord_writer.write(example.SerializeToString())

    # Terminate the '\r'-overwritten progress line.
    sys.stdout.write('\n')
    sys.stdout.flush()
115+
116+
117+
def _dataset_exists(dataset_dir):
    """Reports whether every expected shard file is already present.

    Args:
      dataset_dir: Directory in which the shard files are looked up.

    Returns:
      True only if all `FLAGS.num_shards` files exist for both the 'train'
      and the 'validation' split; False as soon as one is missing.
    """
    expected_paths = (
        _get_dataset_filename(dataset_dir, split, shard)
        for split in ['train', 'validation']
        for shard in range(FLAGS.num_shards))
    # all() stops at the first missing file, like an early `return False`.
    return all(tf.gfile.Exists(path) for path in expected_paths)
125+
126+
127+
def run(dataset_dir):
    """Runs the conversion operation.

    Creates `dataset_dir` if it is missing, returns early when all shard files
    already exist, and otherwise reads the image lists, shuffles the training
    examples, and writes the 'train' and 'validation' shards.

    Args:
      dataset_dir: The dataset directory where the dataset is stored.
    """
    if not tf.gfile.Exists(dataset_dir):
        tf.gfile.MakeDirs(dataset_dir)

    if _dataset_exists(dataset_dir):
        print('Dataset files already exist. Exiting without re-creating them.')
        return

    (train_filenames, val_filenames,
     train_labels, val_labels) = _get_filenames_and_labels(dataset_dir)

    # Shuffle filenames and labels together so each pair stays aligned; only
    # the training split is shuffled.
    shuffled_pairs = list(zip(train_filenames, train_labels))
    random.shuffle(shuffled_pairs)
    train_filenames, train_labels = zip(*shuffled_pairs)

    splits = (
        ('train', train_filenames, train_labels),
        ('validation', val_filenames, val_labels),
    )
    for split_name, split_filenames, split_labels in splits:
        _convert_dataset(
            split_name, split_filenames, split_labels, dataset_dir)

    print('\nFinished converting the dataset!')
151+
152+
153+
def main(_):
    """Entry point invoked by `tf.app.run()`.

    Converts the dataset found under `./data/<dataset_name>`.

    Raises:
      ValueError: If the --dataset_name flag was not supplied.
    """
    dataset_name = FLAGS.dataset_name
    if not dataset_name:
        raise ValueError('You must supply the dataset name with --dataset_name')

    dataset_dir = os.path.join('./data', dataset_name)
    run(dataset_dir)


if __name__ == '__main__':
    tf.app.run()

0 commit comments

Comments
 (0)