1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
| from __future__ import absolute_import from __future__ import division from __future__ import print_function
import hashlib import io import logging import os
from lxml import etree import PIL.Image import tensorflow as tf
from object_detection.utils import dataset_util from object_detection.utils import label_map_util
# Command-line flags for this script; the parsed values are read via FLAGS.
flags = tf.app.flags
flags.DEFINE_string(
    'data_dir', '', 'Root directory to raw PASCAL VOC dataset.')
flags.DEFINE_string(
    'set', 'train', 'Convert training set, validation set or merged set.')
flags.DEFINE_string(
    'annotations_dir', 'Annotations',
    '(Relative) path to annotations directory.')
flags.DEFINE_string(
    'year', 'VOC2007', 'Desired challenge year.')
flags.DEFINE_string(
    'output_path', '', 'Path to output TFRecord')
flags.DEFINE_string(
    'label_map_path', 'data/pascal_label_map.pbtxt',
    'Path to label map proto')
flags.DEFINE_boolean(
    'ignore_difficult_instances', False,
    'Whether to ignore difficult instances')
FLAGS = flags.FLAGS
# Values accepted for the 'set' flag (checked in main).
SETS = [
    'train',
    'val',
    'trainval',
    'test',
]
# Values accepted for the 'year' flag (checked in main); 'merged' makes main
# read both VOC2007 and VOC2012.
YEARS = [
    'VOC2007',
    'VOC2012',
    'merged',
]
def dict_to_tf_example(data,
                       dataset_directory,
                       label_map_dict,
                       ignore_difficult_instances=False,
                       image_subdirectory='JPEGImages'):
    """Build a tf.train.Example from one parsed PASCAL VOC annotation.

    The image file is read from
    ``dataset_directory/data['folder']/image_subdirectory/data['filename']``
    and stored as-is in the example. Bounding-box coordinates are divided by
    the width/height given in ``data['size']`` before being stored.

    Args:
      data: Mapping holding the annotation fields. The keys read are
        'folder', 'filename', 'size' (with 'width' and 'height') and,
        when present, 'object': a list of per-object mappings with 'name',
        'pose', 'truncated', 'difficult' and 'bndbox' (with 'xmin', 'ymin',
        'xmax', 'ymax').
      dataset_directory: Root path that data['folder'] is joined onto.
      label_map_dict: Mapping from class name to the label id stored in the
        'image/object/class/label' feature.
      ignore_difficult_instances: If true, objects whose 'difficult' field is
        non-zero are left out of the example.
      image_subdirectory: Directory below data['folder'] containing the
        image files.

    Returns:
      The populated tf.train.Example.

    Raises:
      ValueError: If PIL reports an image format other than 'JPEG'.
      KeyError: If an object's name has no entry in label_map_dict (assuming
        a plain dict).
    """
    img_path = os.path.join(
        data['folder'], image_subdirectory, data['filename'])
    full_path = os.path.join(dataset_directory, img_path)
    with tf.gfile.GFile(full_path, 'rb') as fid:
        encoded_jpg = fid.read()

    # Decode only to verify the container format; the raw bytes are what get
    # written to the example.
    image = PIL.Image.open(io.BytesIO(encoded_jpg))
    if image.format != 'JPEG':
        raise ValueError('Image format not JPEG')
    key = hashlib.sha256(encoded_jpg).hexdigest()

    width = int(data['size']['width'])
    height = int(data['size']['height'])

    xmin = []
    ymin = []
    xmax = []
    ymax = []
    classes = []
    classes_text = []
    truncated = []
    poses = []
    difficult_obj = []

    # An annotation without any objects may lack the 'object' key entirely
    # (depends on how the XML was parsed into `data` -- confirm against
    # recursive_parse_xml_to_dict). Treat that as "zero objects" rather than
    # failing with KeyError, so the image still yields an example.
    for obj in data.get('object', []):
        difficult = bool(int(obj['difficult']))
        if ignore_difficult_instances and difficult:
            continue

        difficult_obj.append(int(difficult))

        bndbox = obj['bndbox']
        # float() first so the division is true division regardless of
        # whether the module enables `from __future__ import division`.
        xmin.append(float(bndbox['xmin']) / width)
        ymin.append(float(bndbox['ymin']) / height)
        xmax.append(float(bndbox['xmax']) / width)
        ymax.append(float(bndbox['ymax']) / height)
        classes_text.append(obj['name'].encode('utf8'))
        classes.append(label_map_dict[obj['name']])
        truncated.append(int(obj['truncated']))
        poses.append(obj['pose'].encode('utf8'))

    filename_bytes = data['filename'].encode('utf8')
    example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename_bytes),
        'image/source_id': dataset_util.bytes_feature(filename_bytes),
        'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
        'image/object/class/text':
            dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label':
            dataset_util.int64_list_feature(classes),
        'image/object/difficult':
            dataset_util.int64_list_feature(difficult_obj),
        'image/object/truncated':
            dataset_util.int64_list_feature(truncated),
        'image/object/view': dataset_util.bytes_list_feature(poses),
    }))
    return example
def main(_):
    """Convert the selected PASCAL VOC split(s) into one TFRecord file.

    Reads the 'set', 'year', 'data_dir', 'annotations_dir', 'label_map_path',
    'output_path' and 'ignore_difficult_instances' flags. For each selected
    year, every example listed in that year's image-set file is parsed from
    its annotation XML, converted with dict_to_tf_example and appended to the
    output file.

    Args:
      _: Unused positional argument supplied by the caller.

    Raises:
      ValueError: If the 'set' flag is not in SETS or the 'year' flag is not
        in YEARS.
    """
    if FLAGS.set not in SETS:
        raise ValueError('set must be in : {}'.format(SETS))
    if FLAGS.year not in YEARS:
        raise ValueError('year must be in : {}'.format(YEARS))

    data_dir = FLAGS.data_dir
    years = ['VOC2007', 'VOC2012']
    if FLAGS.year != 'merged':
        years = [FLAGS.year]

    # Load the label map before opening the writer so a bad label map path
    # fails without leaving a writer open on the output path.
    label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path)

    writer = tf.python_io.TFRecordWriter(FLAGS.output_path)
    try:
        for year in years:
            logging.info('Reading from PASCAL %s dataset.', year)
            # NOTE(review): only the 'aeroplane' image-set file is read;
            # presumably every per-class file lists all images of the split,
            # so any class would do -- confirm against the VOC devkit docs.
            examples_path = os.path.join(
                data_dir, year, 'ImageSets', 'Main',
                'aeroplane_' + FLAGS.set + '.txt')
            annotations_dir = os.path.join(
                data_dir, year, FLAGS.annotations_dir)
            examples_list = dataset_util.read_examples_list(examples_path)
            for idx, example in enumerate(examples_list):
                if idx % 100 == 0:
                    logging.info(
                        'On image %d of %d', idx, len(examples_list))
                path = os.path.join(annotations_dir, example + '.xml')
                with tf.gfile.GFile(path, 'r') as fid:
                    xml_str = fid.read()
                xml = etree.fromstring(xml_str)
                data = dataset_util.recursive_parse_xml_to_dict(
                    xml)['annotation']

                tf_example = dict_to_tf_example(
                    data, data_dir, label_map_dict,
                    FLAGS.ignore_difficult_instances)
                writer.write(tf_example.SerializeToString())
    finally:
        # Always release the writer, including when a read or conversion
        # step raises part-way through the dataset.
        writer.close()
# Script entry point: when executed directly, hand control to tf.app.run()
# (presumably it parses the flags defined in this file and then calls
# main -- confirm against the TensorFlow docs).
if __name__ == '__main__':
    tf.app.run()
|