How to get batch size back from a TensorFlow dataset
Hotshot TensorFlow is here! In this article, we learn how to get the batch size back from the input dataset or the iterator.
Getting the batch size back
Let’s brush up on a few TensorFlow concepts before starting the tutorial:
- Batch: A dataset is divided sequentially into smaller sets called batches, which are fed to the model one at a time so that each computation step stays manageable.
- Iterator: Gives access to the individual elements of a dataset by iterating through it. TensorFlow 1.x offers four kinds of iterators: one-shot, initializable, reinitializable and feedable. We will use the initializable iterator, which must be explicitly initialized before use and can be re-initialized later, for example to feed different data through placeholders.
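To make the two ideas concrete without TensorFlow, here is a plain NumPy/Python analogy: a hypothetical generator "batches" slices an array along its 0th dimension and hands out one batch per step, much like a batched dataset consumed through an iterator. It is only a conceptual sketch, not how tf.data is implemented.

```python
import numpy as np

def batches(array, batch_size):
    """Hypothetical helper: yield consecutive slices of `array` along axis 0."""
    for start in range(0, len(array), batch_size):
        yield array[start:start + batch_size]

data = np.arange(10).reshape(10, 1)  # 10 elements with one value each
it = batches(data, 4)                # nothing is sliced until we iterate
first = next(it)                     # like fetching the next element from an iterator
print(first.shape)                   # (4, 1): the 0th dimension is the batch size

# Drain the rest; the final batch is smaller because 10 is not a multiple of 4.
sizes = [len(first)] + [len(b) for b in it]
print(sizes)                         # [4, 4, 2]
```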
```python
import tensorflow as tf
import numpy as np
```
Step 1: Import the TensorFlow and NumPy libraries. I am using TensorFlow 1.x here; note that tf.data is part of the core API only from TensorFlow 1.4 onwards.
```python
feature = np.random.sample((100, 2))
label = np.random.sample((100, 1))
data = tf.data.Dataset.from_tensor_slices((feature, label))
```
Step 2: Initialize NumPy arrays of random values with shape (100,2) for the features and (100,1) for the labels, then create a TensorFlow dataset object “data” with Dataset.from_tensor_slices. (Note: if you pass multiple arrays/tensors to this method, make sure they have the same 0th dimension, here 100. The dataset is sliced along that dimension, so each element is one (feature row, label row) pair.)
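The slicing that from_tensor_slices performs can be imitated in NumPy to see why the 0th dimensions must agree. This sketch only mirrors the pairing of rows; it is an illustration, not TensorFlow's implementation.

```python
import numpy as np

feature = np.random.sample((100, 2))
label = np.random.sample((100, 1))

# Element i of the dataset corresponds to (feature[i], label[i]),
# which is only well defined when both arrays have the same length on axis 0.
assert feature.shape[0] == label.shape[0]
elements = list(zip(feature, label))

print(len(elements))         # 100 elements in the dataset
print(elements[0][0].shape)  # (2,): one feature row
print(elements[0][1].shape)  # (1,): one label row
```

Before batching, a single element therefore has shapes (2,) and (1,); batching adds the leading batch dimension.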
```python
batch_size = 4
data = data.batch(batch_size)
```
Step 3: Choose a batch size, which is the number of elements in each batch; 4 is an arbitrary choice here. “data.batch” groups consecutive elements of the dataset into batches of 4, so our 100 elements become 25 batches.
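The batch arithmetic is easy to check by hand. The hypothetical helper below lists the size of each batch that batching would produce; the drop_remainder flag mirrors the argument of the same name that Dataset.batch accepts in later TensorFlow releases, which discards a short final batch.

```python
def batch_sizes(num_elements, batch_size, drop_remainder=False):
    """Hypothetical helper: sizes of the batches produced from num_elements items."""
    full, rest = divmod(num_elements, batch_size)
    sizes = [batch_size] * full
    if rest and not drop_remainder:
        sizes.append(rest)  # the last, smaller batch
    return sizes

print(len(batch_sizes(100, 4)))                 # 25: 100 divides evenly by 4
print(batch_sizes(10, 4))                       # [4, 4, 2]
print(batch_sizes(10, 4, drop_remainder=True))  # [4, 4]
```

So reading the batch size from the first batch is only safe when the dataset length is a multiple of the batch size or the remainder is dropped.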
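```python
Iterator = data.make_initializable_iterator()
batchdata = Iterator.get_next()  # a (features, labels) tuple of tensors

with tf.Session() as sess:
    sess.run(Iterator.initializer)
    features, labels = sess.run(batchdata)  # fetch the first batch
    print(np.shape(features)[0])  # 4: the 0th dimension is the batch size
```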
Step 4: make_initializable_iterator returns an uninitialized iterator that enumerates the elements of the dataset. get_next creates an operation in the TensorFlow graph that, when run in a session, returns the next element of the dataset, which here is the next batch.
After defining the iterator, we open a session and run “Iterator.initializer” to initialize the iterator, which is then ready for use. Running “batchdata” returns the first batch as a (features, labels) tuple, and the 0th dimension of each array in that tuple is, guess what? Our batch size!
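When the dataset length is not a multiple of the batch size, the last batch is smaller and the static shape of the features tensor has an unknown (None) 0th dimension. One way to handle this is to read the batch size inside the graph with tf.shape. The sketch below assumes TensorFlow 1.15 or 2.x, so that the tf.compat.v1 names exist, and uses 10 elements instead of 100 so the short final batch shows up.

```python
import numpy as np
import tensorflow as tf

# Graph mode is the default in TF 1.x; in TF 2.x it must be switched on explicitly.
tf.compat.v1.disable_eager_execution()

feature = np.random.sample((10, 2))
label = np.random.sample((10, 1))
data = tf.data.Dataset.from_tensor_slices((feature, label)).batch(4)

iterator = tf.compat.v1.data.make_initializable_iterator(data)
features, labels = iterator.get_next()

static_shape = features.shape.as_list()  # [None, 2]: batch size unknown statically
batch_size_t = tf.shape(features)[0]     # dynamic batch size, evaluated per batch

sizes = []
with tf.compat.v1.Session() as sess:
    sess.run(iterator.initializer)
    while True:
        try:
            sizes.append(int(sess.run(batch_size_t)))
        except tf.errors.OutOfRangeError:  # raised once the dataset is exhausted
            break

print(static_shape)  # [None, 2]
print(sizes)         # [4, 4, 2]
```

Because batch_size_t is evaluated for every batch, it reports the true size of the final, smaller batch instead of assuming the configured value.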
Please feel free to comment and give your feedback.