1. 图片数据集的原生构建方法

(1)预处理图像的function

  1. import tensorflow as tf
  2. from tensorflow.keras import datasets,layers,models
  3. BATCH_SIZE = 100
  4. def load_image(img_path,size = (32,32)):
  5. label = tf.constant(1,tf.int8) if tf.strings.regex_full_match(img_path,".*automobile.*") \
  6. else tf.constant(0,tf.int8)
  7. img = tf.io.read_file(img_path)
  8. img = tf.image.decode_jpeg(img) #注意此处为jpeg格式
  9. img = tf.image.resize(img,size)/255.0
  10. return(img,label)

(2)tf.data

  1. #使用并行化预处理num_parallel_calls 和预存数据prefetch来提升性能
  2. ds_train = tf.data.Dataset.list_files("./data/cifar2/train/*/*.jpg") \
  3. .map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE) \
  4. .shuffle(buffer_size = 1000).batch(BATCH_SIZE) \
  5. .prefetch(tf.data.experimental.AUTOTUNE)
  6. ds_test = tf.data.Dataset.list_files("./data/cifar2/test/*/*.jpg") \
  7. .map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE) \
  8. .batch(BATCH_SIZE) \
  9. .prefetch(tf.data.experimental.AUTOTUNE)

tf.data.prefetch()#提前从数据集中取出若干数据放到内存中,这样可以使在gpu计算时,CPU执行数据预处理相关的指令

(3)抽取部分数据进行可视化

  1. %matplotlib inline
  2. %config InlineBackend.figure_format = 'svg'
  3. #在默认设置的matplotlib中图片分辨率不是很高,可以通过设置矢量图的方式来提高图片显示质量
  4. #查看部分样本
  5. from matplotlib import pyplot as plt
  6. plt.figure(figsize=(8,8))
  7. for i,(img,label) in enumerate(ds_train.unbatch().take(9)):
  8. ax=plt.subplot(3,3,i+1)
  9. ax.imshow(img.numpy())
  10. ax.set_title("label = %d"%label)
  11. ax.set_xticks([])
  12. ax.set_yticks([])
  13. plt.show()

使用fit方法构建训练流程

  1. history = model.fit(ds_train,epochs= 10,validation_data=ds_test,workers = 4)

(4)抽取数据进行推理

a.整个测试集

  1. model.predict(ds_test)

b.从一个batch中抽取前20张图片

  1. for x,y in ds_test.take(1):
  2. print(model.predict_on_batch(x[0:20]))