1. 图片数据集的原生构建方法
(1)预处理图像的function
import tensorflow as tffrom tensorflow.keras import datasets,layers,modelsBATCH_SIZE = 100def load_image(img_path,size = (32,32)):label = tf.constant(1,tf.int8) if tf.strings.regex_full_match(img_path,".*automobile.*") \else tf.constant(0,tf.int8)img = tf.io.read_file(img_path)img = tf.image.decode_jpeg(img) #注意此处为jpeg格式img = tf.image.resize(img,size)/255.0return(img,label)
(2)tf.data
#使用并行化预处理num_parallel_calls 和预存数据prefetch来提升性能ds_train = tf.data.Dataset.list_files("./data/cifar2/train/*/*.jpg") \.map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE) \.shuffle(buffer_size = 1000).batch(BATCH_SIZE) \.prefetch(tf.data.experimental.AUTOTUNE)ds_test = tf.data.Dataset.list_files("./data/cifar2/test/*/*.jpg") \.map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE) \.batch(BATCH_SIZE) \.prefetch(tf.data.experimental.AUTOTUNE)
tf.data.prefetch()#提前从数据集中取出若干数据放到内存中,这样可以使在gpu计算时,CPU执行数据预处理相关的指令
(3)抽取部分数据进行可视化
%matplotlib inline%config InlineBackend.figure_format = 'svg'#在默认设置的matplotlib中图片分辨率不是很高,可以通过设置矢量图的方式来提高图片显示质量#查看部分样本from matplotlib import pyplot as pltplt.figure(figsize=(8,8))for i,(img,label) in enumerate(ds_train.unbatch().take(9)):ax=plt.subplot(3,3,i+1)ax.imshow(img.numpy())ax.set_title("label = %d"%label)ax.set_xticks([])ax.set_yticks([])plt.show()
使用fit方法构建训练流程
history = model.fit(ds_train,epochs= 10,validation_data=ds_test,workers = 4)
(4)抽取数据进行推理
a.整个测试集
model.predict(ds_test)
b.从一个batch中抽取前20张图片
for x,y in ds_test.take(1):print(model.predict_on_batch(x[0:20]))
