# Create the training input pipeline.
#
# File names matching RECORDS_ROOT/train/*/*.JPEG are shuffled, repeated
# indefinitely, mapped through read_png, randomly cropped to
# configs.dataset.crop_size with 3 channels, batched and prefetched.
#
# Raises:
#   RuntimeError: if the glob matches no files.
with tf.device("/cpu:0"):
    train_glob = os.path.join(RECORDS_ROOT, "train/*/*.JPEG")
    train_files = glob.glob(train_glob)
    if not train_files:
        raise RuntimeError(
            "No training images found with glob '{}'.".format(train_glob))

    train_dataset = tf.data.Dataset.from_tensor_slices(train_files)
    # The shuffle buffer holds the whole file list, and shuffling happens
    # before repeat(), so the list is reshuffled as a unit on each pass.
    train_dataset = train_dataset.shuffle(
        buffer_size=len(train_files)).repeat()

    # NOTE(review): read_png is defined elsewhere, but the glob above matches
    # *.JPEG files -- confirm that read_png can decode JPEG data.
    train_dataset = train_dataset.map(
        read_png, num_parallel_calls=configs.dataset.NUM_PREPROCESS_THREADS)

    crop_height = configs.dataset.crop_size[0]
    crop_width = configs.dataset.crop_size[1]
    crop_shape = (crop_height, crop_width, 3)
    # NOTE(review): presumably every decoded image must be at least
    # crop_height x crop_width with 3 channels for the crop to succeed --
    # verify against the dataset.
    train_dataset = train_dataset.map(
        lambda image: tf.random_crop(image, crop_shape))

    train_dataset = train_dataset.batch(configs.dataset.batch_size)
    # prefetch() follows batch(), so the buffer holds up to 32 batches.
    train_dataset = train_dataset.prefetch(32)

# Number of pixels per batch, counted per channel (height * width * batch).
num_pixels = configs.dataset.batch_size * crop_height * crop_width

# Get training patch from dataset
x = train_dataset.make_one_shot_iterator().get_next()
TensorFlow version: 1.14.0