1、下载数据、解压数据
!curl -O https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip
!unzip -q kagglecatsanddogs_3367a.zip
!ls
!ls PetImages
2、定义模型[无 fine-tune]
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
"""
vgg 19 block
"""
def make_block(x,size):
previous_block_activation = x #set aside residual
for i in range(2):
x = layers.Activation("relu")(x)
x = layers.SeparableConv2D(size,3,padding="same" )(x)
x = layers.BatchNormalization()(x)
x = layers.MaxPool2D(3,strides=2,padding = "same")(x)
residual = layers.Conv2D(size,1,strides=2,padding = "same")(previous_block_activation)
x = layers.add([x,residual])
return x
"""
构建 vgg19
"""
def make_model(input_shape,num_classes):
inputs = keras.Input(shape = input_shape)
# x = data_augmentation(inputs) # 数据预处理
# x = layers.experimental.preprocessing.Rescaling(1.0/255)(x)
x = layers.Conv2D(32,3,strides=2,padding = "same")(inputs)
x = layers.BatchNormalization()(x)
x = layers.Activation("relu")(x)
x = layers.Conv2D(64,3,padding = "same")(x)
x = layers.BatchNormalization()(x)
x = layers.Activation("relu")(x)
for size in [128,256,512,728]:
x = make_block(x,size)
x = layers.SeparableConv2D(1024,3,padding = "same")(x)
x = layers.BatchNormalization()(x)
x = layers.Activation("relu")(x)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(num_classes,activation="softmax")(x)
return keras.Model(inputs,outputs)
"""
create model
"""
image_size = (180, 180)
model = make_model(input_shape= image_size+(3,),num_classes=2)
keras.utils.plot_model(model,show_shapes=True)
3、数据预处理
# Delete images that cannot be parsed: JPEG files whose header lacks
# the "JFIF" marker are treated as corrupted and removed from disk.
import os

num_skipped = 0
for folder_name in ("Cat", "Dog"):
    folder_path = os.path.join("PetImages", folder_name)
    for fname in os.listdir(folder_path):
        fpath = os.path.join(folder_path, fname)
        # BUG FIX: the original used open()/finally; if open() itself
        # raised, `fobj` was unbound and finally raised NameError.
        # `with` closes the handle safely in every case.
        with open(fpath, "rb") as fobj:
            # peek() is available because open(..., "rb") returns a
            # BufferedReader.  b"JFIF" is the exact value the original
            # tf.compat.as_bytes("JFIF") produced, without needing TF.
            is_jfif = b"JFIF" in fobj.peek(10)
        if not is_jfif:
            num_skipped += 1
            # Delete the corrupted image.
            os.remove(fpath)
print("Deleted %d images" % num_skipped)
4、构建数据集
# https://keras.io/api/preprocessing/image/
# 创建数据集,并使用原始数据集进行训练,查看效果
#------ hyperparameter definitions: begin -------
# NOTE: re-binds image_size (already set to the same value above).
image_size = (180,180)
batch_size = 32
#------ hyperparameter definitions: end -------
# Read every image under a directory, using each sub-folder name as its class label.
def make_dataset_from_directory(image_path, dataset_name):
    """Build a batched image dataset from a directory tree.

    Calling image_dataset_from_directory(main_directory,
    labels='inferred') returns a tf.data.Dataset that yields batches of
    images from the sub-directories class_a and class_b together with
    labels 0 and 1 (0 for class_a, 1 for class_b).  Supported image
    formats: jpeg, png, bmp, gif; animated gifs are truncated to the
    first frame.  The split here is 80/20 (validation_split=0.2), and
    `dataset_name` selects the "training" or "validation" subset.
    """
    return tf.keras.preprocessing.image_dataset_from_directory(
        image_path,
        validation_split=0.2,
        subset=dataset_name,
        seed=1337,
        image_size=image_size,
        batch_size=batch_size,
    )
# Materialize the training and validation splits; the shared seed
# (1337, inside the helper) keeps the two subsets disjoint.  Prefetch
# overlaps input-pipeline work with model execution.
train_ds = make_dataset_from_directory("PetImages","training")
val_ds = make_dataset_from_directory("PetImages","validation")
train_ds = train_ds.prefetch(buffer_size=32)
val_ds = val_ds.prefetch(buffer_size=32)
5、训练网络
# Train the network.  BUG FIX: the original referenced an undefined
# name `basic_model`; the network built above is bound to `model`.
epochs = 5

callbacks = [
    # Save a checkpoint after every epoch.
    keras.callbacks.ModelCheckpoint("save_at_{epoch}.h5"),
]

model.compile(
    # NOTE(review): 1e-2 is an aggressive learning rate for Adam —
    # confirm; Adam's usual default is 1e-3.
    optimizer=keras.optimizers.Adam(1e-2),
    # Integer labels + softmax output -> sparse categorical cross-entropy.
    loss=keras.losses.sparse_categorical_crossentropy,
    metrics=["accuracy"],
)

model.fit(
    train_ds,
    epochs=epochs,
    callbacks=callbacks,
    validation_data=val_ds,
)