Generating TFRecords data and reading it with the Dataset API

TFRecords is TensorFlow's built-in binary file format. It has the following advantages:

  • It provides one unified format for different kinds of input data
  • It is compact and easy to copy and move: samples are serialized with protocol buffers into binary files, which can optionally be GZIP- or ZLIB-compressed
  • Binary data (such as images) and labels (the training category labels) can be stored together in the same file
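To make "binary file" concrete: each record in an uncompressed TFRecords file is framed as an 8-byte little-endian length, a 4-byte masked CRC-32C of that length, the payload (one serialized Example), and a 4-byte masked CRC-32C of the payload. The stdlib-only sketch below writes and reads that framing purely for illustration. In practice TFRecordWriter and TFRecordDataset handle this for you, and the helper names (crc32c, masked_crc, write_record, read_records) are hypothetical.

```python
import io
import struct

def crc32c(data: bytes) -> int:
    """Bit-by-bit CRC-32C (Castagnoli polynomial, reflected form 0x82F63B78)."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc ^= byte
        for _ in range(8):
            crc = (crc >> 1) ^ (0x82F63B78 if crc & 1 else 0)
    return crc ^ 0xFFFFFFFF

def masked_crc(data: bytes) -> int:
    """Rotate the CRC right by 15 bits and add a constant, as the TFRecord framing does."""
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF

def write_record(stream, payload: bytes) -> None:
    """Append one record: length, CRC of length, payload, CRC of payload."""
    header = struct.pack("<Q", len(payload))  # uint64, little-endian
    stream.write(header)
    stream.write(struct.pack("<I", masked_crc(header)))
    stream.write(payload)
    stream.write(struct.pack("<I", masked_crc(payload)))

def read_records(stream):
    """Yield each payload in order, verifying both CRCs."""
    while True:
        header = stream.read(8)
        if not header:
            return  # clean end of file
        (length,) = struct.unpack("<Q", header)
        (length_crc,) = struct.unpack("<I", stream.read(4))
        if length_crc != masked_crc(header):
            raise ValueError("corrupted length field")
        payload = stream.read(length)
        (payload_crc,) = struct.unpack("<I", stream.read(4))
        if payload_crc != masked_crc(payload):
            raise ValueError("corrupted record payload")
        yield payload

# Round trip through an in-memory buffer instead of a file on disk.
# Placeholder payloads; in practice each one would be example.SerializeToString().
buffer = io.BytesIO()
for payload in (b"first serialized Example", b"second serialized Example"):
    write_record(buffer, payload)
buffer.seek(0)
records = list(read_records(buffer))
```

Each record thus costs 16 bytes of framing on top of its payload, and the two checksums let a reader detect a truncated or corrupted file.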

1. Storing data as a TFRecords file takes two steps:

Create a TFRecord writer

In TensorFlow 1.x, a TFRecord writer is created with the following statement (TensorFlow 2.x provides the same class as tf.io.TFRecordWriter):

tf.python_io.TFRecordWriter(path)

path: path to the created TFRecords file

Methods:

  • write(record): writes one string record (i.e., one sample) to the file
  • close(): closes the writer after all records have been written

Note: the string here is a serialized Example, produced by Example.SerializeToString(). It encodes the Example's feature map in the compact binary protocol buffer format rather than as text, which saves a lot of space.
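SerializeToString does not run a general-purpose compressor. It emits the protocol buffer wire format, which is compact largely because integers are written as variable-length "varints". As an illustration, the stdlib-only sketch below hand-encodes an Int64List, whose value field is a packed repeated int64 (field number 1). The helper names are hypothetical. With TensorFlow installed, tf.train.Int64List(value=[1, 300]).SerializeToString() should produce the same five bytes.

```python
def encode_varint(value: int) -> bytes:
    """Protobuf base-128 varint: 7 value bits per byte, high bit set while more bytes follow."""
    value &= 0xFFFFFFFFFFFFFFFF  # a negative int64 is encoded as its 64-bit two's complement
    out = bytearray()
    while True:
        low7 = value & 0x7F
        value >>= 7
        if value:
            out.append(low7 | 0x80)  # continuation bit: more bytes follow
        else:
            out.append(low7)
            return bytes(out)

def encode_int64_list(values) -> bytes:
    """Hand-encode Int64List{value: values}: field 1, packed, wire type 2 (length-delimited)."""
    payload = b"".join(encode_varint(v) for v in values)
    tag = (1 << 3) | 2  # (field number << 3) | wire type
    return bytes([tag]) + encode_varint(len(payload)) + payload

encoded = encode_int64_list([1, 300])  # 5 bytes: 0a 03 01 ac 02
```

Small non-negative integers take one or two bytes instead of a fixed eight. Values in a FloatList, by contrast, are stored as fixed 4-byte floats, and bytes in a BytesList are stored as-is.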

Construct an Example message for each sample

The Example message is defined in protocol buffer syntax as follows:

message Example {
  Features features = 1;
}

message Features {
  map<string, Feature> feature = 1;
}

message Feature {
  oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
}

As you can see, each Feature holds exactly one of three list types: bytes (binary strings), float32 values, or int64 values. Each type is a list, so a single feature can also hold several values.
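Because Feature is a oneof, every feature holds exactly one of the three list kinds, and the values must already have the matching Python type. For example, text read from a file must be converted with float() before it can go into a FloatList. The stdlib-only sketch below mirrors that rule with plain dicts. to_feature is a hypothetical helper, not part of TensorFlow, where you would construct tf.train.Feature(float_list=tf.train.FloatList(value=[...])) and similar objects instead.

```python
def to_feature(values):
    """Return a one-key dict naming the list kind, mirroring Feature's `oneof kind`."""
    if not isinstance(values, (list, tuple)):
        values = [values]  # every kind holds a list, even for a single value
    values = list(values)
    if all(isinstance(v, bytes) for v in values):
        return {"bytes_list": values}
    if all(isinstance(v, int) and not isinstance(v, bool) for v in values):
        return {"int64_list": values}
    if all(isinstance(v, (int, float)) and not isinstance(v, bool) for v in values):
        return {"float_list": [float(v) for v in values]}
    raise TypeError("values must be bytes, ints, or floats; convert text with float() or .encode() first")

# Hypothetical sample mirroring the Example structure: features -> feature map -> Feature
sample = {"features": {"feature": {
    "label": to_feature(1.0),
    "img_raw": to_feature(bytes(12 * 12 * 3)),
    "x1_offset": to_feature(-0.05),
}}}
```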

Features are saved in the form of key-value pairs. The sample code is as follows:

# FloatList accepts only numbers, so fields read as text are converted with float()
# img_raw must be bytes holding the raw uint8 pixels (pares_tf below expects 12*12*3 = 432 bytes)
example = tf.train.Example(
            features=tf.train.Features(feature={
                'label': tf.train.Feature(float_list=tf.train.FloatList(value=[float(string[1])])),
                'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
                'x1_offset': tf.train.Feature(float_list=tf.train.FloatList(value=[float(string[2])])),
                'y1_offset': tf.train.Feature(float_list=tf.train.FloatList(value=[float(string[3])])),
                'x2_offset': tf.train.Feature(float_list=tf.train.FloatList(value=[float(string[4])])),
                'y2_offset': tf.train.Feature(float_list=tf.train.FloatList(value=[float(string[5])])),
                'beta_det': tf.train.Feature(float_list=tf.train.FloatList(value=[float(string[6])])),
                'beta_bbox': tf.train.Feature(float_list=tf.train.FloatList(value=[float(string[7])]))
            }))

After constructing the Example message, we can write the sample to the file:

writer.write(example.SerializeToString())

Don’t forget to call writer.close() after all records have been written.

2. Once the TFRecords file has been created, it can be used for training. TensorFlow's Dataset API (tf.data) makes reading TFRecords files convenient.

First, define a function that parses one serialized Example from the file into tensors. The sample code is as follows:

def pares_tf(example_proto):
    #Define the parsed dictionary
    dics = {
        'label': tf.FixedLenFeature([], tf.float32),
        'img_raw': tf.FixedLenFeature([], tf.string),
        'x1_offset': tf.FixedLenFeature([], tf.float32),
        'y1_offset': tf.FixedLenFeature([], tf.float32),
        'x2_offset': tf.FixedLenFeature([], tf.float32),
        'y2_offset': tf.FixedLenFeature([], tf.float32),
        'beta_det': tf.FixedLenFeature([], tf.float32),
        'beta_bbox': tf.FixedLenFeature([], tf.float32)}
    #Call the interface to parse a line of samples
    parsed_example = tf.parse_single_example(serialized=example_proto,features=dics)
    image = tf.decode_raw(parsed_example['img_raw'],out_type=tf.uint8)
    image = tf.reshape(image,shape=[12,12,3])
    #Normalize the image data here
    image = (tf.cast(image,tf.float32)/255.0)
    label = parsed_example['label']
    label=tf.reshape(label,shape=[1])
    label = tf.cast(label,tf.float32)
    x1_offset=parsed_example['x1_offset']
    x1_offset = tf.reshape(x1_offset, shape=[1])
    y1_offset=parsed_example['y1_offset']
    y1_offset = tf.reshape(y1_offset, shape=[1])
    x2_offset=parsed_example['x2_offset']
    x2_offset = tf.reshape(x2_offset, shape=[1])
    y2_offset=parsed_example['y2_offset']
    y2_offset = tf.reshape(y2_offset, shape=[1])
    beta_det=parsed_example['beta_det']
    beta_det=tf.reshape(beta_det,shape=[1])
    beta_bbox=parsed_example['beta_bbox']
    beta_bbox=tf.reshape(beta_bbox,shape=[1])

    return image,label,x1_offset,y1_offset,x2_offset,y2_offset,beta_det,beta_bbox
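In pares_tf, tf.decode_raw reinterprets img_raw as a flat sequence of uint8 values, and tf.reshape then arranges them in row-major order into [12, 12, 3]. Every stored image must therefore be exactly 12 * 12 * 3 = 432 bytes. The stdlib-only sketch below reproduces the decode, reshape, and /255 normalization for one record to make that constraint visible. decode_image is a hypothetical helper, not a TensorFlow function.

```python
HEIGHT, WIDTH, CHANNELS = 12, 12, 3  # the shape pares_tf reshapes to

def decode_image(img_raw: bytes):
    """Mimic decode_raw(uint8) -> reshape([12, 12, 3]) -> cast to float / 255 for one sample."""
    expected = HEIGHT * WIDTH * CHANNELS
    if len(img_raw) != expected:
        # tf.reshape would fail at run time in the same situation
        raise ValueError(f"img_raw has {len(img_raw)} bytes, expected {expected}")
    pixels = list(img_raw)  # iterating over bytes yields ints in 0..255, like uint8
    return [
        [
            [pixels[(row * WIDTH + col) * CHANNELS + ch] / 255.0 for ch in range(CHANNELS)]
            for col in range(WIDTH)
        ]
        for row in range(HEIGHT)
    ]

# A made-up image: bytes 0..255 followed by zero padding up to 432 bytes.
image = decode_image(bytes(range(256)) + bytes(176))
```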

Next, read the TFRecords file with tf.data.TFRecordDataset(filenames). Each element of this Dataset is one serialized Example.

A transformation turns one Dataset into a new Dataset. Transformations are typically used to parse, shuffle, and batch the data and to repeat it over several epochs.

Commonly used Transformations include map, batch, shuffle, and repeat.

map:

map takes a function and applies it to every element of the Dataset; the return values become the elements of the new Dataset. This is where the parsing function is applied.

batch:

batch(n) combines n consecutive elements into one batch. By default, if the number of elements is not a multiple of n, the last batch is smaller.

repeat:

repeat repeats the entire sequence several times and is mainly used to handle epochs in machine learning: if the original data is one epoch, repeat(5) turns it into 5 epochs. Calling repeat() without an argument repeats indefinitely.

shuffle:

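shuffle randomly reorders the elements of the dataset. Its buffer_size parameter sets how many elements are kept in a buffer from which the next element is drawn at random: a larger buffer mixes the data more thoroughly, and a buffer at least as large as the dataset gives a full shuffle.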
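The buffer idea can be sketched in plain Python: fill a buffer with buffer_size elements, emit a uniformly chosen one, refill from the input, and drain the rest randomly at the end. buffered_shuffle below is a hypothetical helper built on Python's random module; it follows the buffer-sampling idea, not TensorFlow's exact random stream.

```python
import random

def buffered_shuffle(elements, buffer_size, seed=None):
    """Shuffle a stream using a bounded buffer of `buffer_size` elements."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    rng = random.Random(seed)
    buffer = []
    for element in elements:
        buffer.append(element)
        if len(buffer) == buffer_size:
            # Buffer is full: emit one uniformly chosen element to make room.
            index = rng.randrange(len(buffer))
            buffer[index], buffer[-1] = buffer[-1], buffer[index]
            yield buffer.pop()
    # Input exhausted: emit whatever is left in random order.
    rng.shuffle(buffer)
    yield from buffer

shuffled = list(buffered_shuffle(range(100), buffer_size=10, seed=0))
```

With a buffer of 10, the element at output position i can come from at most input position i + 9, so a small buffer only mixes locally. With buffer_size=1 the order is unchanged.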

Sample code:

dataset = tf.data.TFRecordDataset(filenames=[filename])
dataset = dataset.map(pares_tf)
dataset = dataset.batch(16).repeat(1)  # 16 samples per batch; repeat(1) goes through the data once (one epoch)
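The semantics of map, batch, and repeat, and why their order matters, can be imitated with plain Python lists. The sketch below is only an analogy: tf.data works lazily on tensors, and the ds_* helpers are hypothetical. It shows that batch(...).repeat(...), as in the sample code, keeps every batch inside one epoch, while repeat(...).batch(...) lets a batch straddle two epochs.

```python
def ds_map(elements, fn):
    """Like Dataset.map: apply fn to every element."""
    return [fn(element) for element in elements]

def ds_batch(elements, batch_size):
    """Like Dataset.batch: group consecutive elements; the last batch may be smaller."""
    return [elements[i:i + batch_size] for i in range(0, len(elements), batch_size)]

def ds_repeat(elements, count):
    """Like Dataset.repeat(count): replay the whole sequence count times (count epochs)."""
    return list(elements) * count

samples = ds_map(range(5), lambda i: i * 10)  # stand-in for dataset.map(pares_tf)
batch_then_repeat = ds_repeat(ds_batch(samples, 2), 2)
repeat_then_batch = ds_batch(ds_repeat(samples, 2), 2)
```

Here batch_then_repeat ends each epoch with a short batch [40], whereas repeat_then_batch produces the mixed batch [40, 0].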

With the batches defined, how do we take them out for training? The answer is an iterator, created in TensorFlow 1.x as follows:

iterator = dataset.make_one_shot_iterator()

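"One-shot" means the iterator goes through the dataset only once, from beginning to end, and cannot be re-initialized. How do we get a different batch in each training step? The iterator's get_next() method returns the next element. Note that get_next() returns symbolic tensors (here a tuple matching the return values of pares_tf), not concrete values. To use the values during training, evaluate them with sess.run() in a session. Call get_next() once when building the graph; each sess.run() on its result then advances the iterator by one batch, and once the data is exhausted it raises tf.errors.OutOfRangeError.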

The complete code for reading tfrecords files using dataset is as follows:

import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.parse_single_example, ...)

def pares_tf(example_proto):
    #Define the parsed dictionary
    dics = {
        'label': tf.FixedLenFeature([], tf.float32),
        'img_raw': tf.FixedLenFeature([], tf.string),
        'x1_offset': tf.FixedLenFeature([], tf.float32),
        'y1_offset': tf.FixedLenFeature([], tf.float32),
        'x2_offset': tf.FixedLenFeature([], tf.float32),
        'y2_offset': tf.FixedLenFeature([], tf.float32),
        'beta_det': tf.FixedLenFeature([], tf.float32),
        'beta_bbox': tf.FixedLenFeature([], tf.float32)}
    #Call the interface to parse a line of samples
    parsed_example = tf.parse_single_example(serialized=example_proto,features=dics)
    image = tf.decode_raw(parsed_example['img_raw'],out_type=tf.uint8)
    image = tf.reshape(image,shape=[12,12,3])
    #Normalize the image data here
    image = (tf.cast(image,tf.float32)/255.0)
    label = parsed_example['label']
    label=tf.reshape(label,shape=[1])
    label = tf.cast(label,tf.float32)
    x1_offset=parsed_example['x1_offset']
    x1_offset = tf.reshape(x1_offset, shape=[1])
    y1_offset=parsed_example['y1_offset']
    y1_offset = tf.reshape(y1_offset, shape=[1])
    x2_offset=parsed_example['x2_offset']
    x2_offset = tf.reshape(x2_offset, shape=[1])
    y2_offset=parsed_example['y2_offset']
    y2_offset = tf.reshape(y2_offset, shape=[1])
    beta_det=parsed_example['beta_det']
    beta_det=tf.reshape(beta_det,shape=[1])
    beta_bbox=parsed_example['beta_bbox']
    beta_bbox=tf.reshape(beta_bbox,shape=[1])

    return image,label,x1_offset,y1_offset,x2_offset,y2_offset,beta_det,beta_bbox

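# filename: path of the TFRecords file created in step 1
dataset = tf.data.TFRecordDataset(filenames=[filename])
dataset = dataset.map(pares_tf)
dataset = dataset.batch(16).repeat(1)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    try:
        while True:
            # Each run() call returns the next batch of (up to) 16 samples as NumPy arrays
            img, label, x1_offset, y1_offset, x2_offset, y2_offset, beta_det, beta_bbox = sess.run(fetches=next_element)
    except tf.errors.OutOfRangeError:
        # The one-shot iterator has gone through the whole dataset once
        pass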