1、下载数据、解压数据

  1. !curl -O https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip
  2. !unzip -q kagglecatsanddogs_3367a.zip
  3. !ls
  4. !ls PetImages

2、定义模型[无 fine-tune]

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


# Residual downsampling block built from separable convolutions (the layout
# resembles Xception's entry-flow blocks). It is NOT a VGG19 block.
def make_block(x, size):
    """Apply one residual downsampling block to ``x``.

    Two ``ReLU -> SeparableConv2D(size, 3) -> BatchNormalization`` stages are
    followed by a 3x3, stride-2 max-pool. In parallel, a 1x1 stride-2
    ``Conv2D`` projects the block input to ``size`` channels and the same
    spatial size, and the two branches are summed.

    Args:
        x: Keras tensor entering the block; must be a 4-D feature map since it
            feeds 2-D conv/pool layers (channel ordering presumably
            channels-last -- TODO confirm).
        size: Number of filters for the separable convolutions and for the
            1x1 shortcut projection.

    Returns:
        Keras tensor with ``size`` channels and the spatial dimensions halved
        (rounded up, because of stride 2 with ``padding="same"``).
    """
    shortcut_input = x  # keep the block input for the residual branch

    for _ in range(2):
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

    x = layers.MaxPool2D(3, strides=2, padding="same")(x)
    # The 1x1 stride-2 conv matches both the channel count and the pooled
    # spatial shape, so the two branches can be added element-wise.
    residual = layers.Conv2D(size, 1, strides=2, padding="same")(shortcut_input)

    return layers.add([x, residual])


# Small classifier assembled from ``make_block`` residual blocks (an
# Xception-like layout). It is NOT a VGG19 network.
def make_model(input_shape, num_classes, *, block_sizes=(128, 256, 512, 728),
               dropout_rate=0.5):
    """Build the image classifier as a functional ``keras.Model``.

    Layout:
      * stem: Conv2D(32, 3, stride 2) -> BN -> ReLU -> Conv2D(64, 3) -> BN -> ReLU
      * one ``make_block`` per entry of ``block_sizes``
      * SeparableConv2D(1024, 3) -> BN -> ReLU
      * GlobalAveragePooling2D -> Dropout -> Dense(num_classes, softmax)

    Args:
        input_shape: Shape of a single input image without the batch
            dimension, e.g. ``(180, 180, 3)``.
        num_classes: Number of units in the softmax output layer.
        block_sizes: Filter counts of the successive residual blocks; each
            block halves the spatial resolution. Defaults to the original
            ``(128, 256, 512, 728)``.
        dropout_rate: Dropout rate applied before the output layer
            (default 0.5, as before).

    Returns:
        An uncompiled ``keras.Model`` mapping images to softmax class
        probabilities.
    """
    inputs = keras.Input(shape=input_shape)

    # NOTE(review): no rescaling or data augmentation is applied, so the
    # network sees pixel values exactly as the input pipeline delivers them
    # (presumably 0-255 from image_dataset_from_directory -- confirm whether
    # a Rescaling(1.0 / 255) layer is wanted here).
    x = layers.Conv2D(32, 3, strides=2, padding="same")(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    x = layers.Conv2D(64, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    for size in block_sizes:
        x = make_block(x, size)

    x = layers.SeparableConv2D(1024, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(dropout_rate)(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)

    return keras.Model(inputs, outputs)


# Create the model: 180x180 images with 3 colour channels and 2 output
# classes, then draw its architecture with layer shapes.
image_size = (180, 180)
model = make_model(
    num_classes=2,
    input_shape=(*image_size, 3),
)
keras.utils.plot_model(model, show_shapes=True)

3、数据预处理

# Remove images that cannot be parsed. Heuristic: any file whose header does
# not contain the "JFIF" marker is treated as corrupt and deleted (this also
# deletes any non-JPEG files found in the class folders).
import os

num_skipped = 0
for folder_name in ("Cat", "Dog"):
    folder_path = os.path.join("PetImages", folder_name)
    for fname in os.listdir(folder_path):
        fpath = os.path.join(folder_path, fname)
        # ``with`` closes the handle even on error. The former try/finally
        # referenced ``fobj`` even when ``open()`` itself failed, which masked
        # the real OSError with a NameError on the first file.
        with open(fpath, "rb") as fobj:
            # peek() does not advance the file position and may return more
            # than the 10 bytes requested (whatever is already buffered).
            is_jfif = b"JFIF" in fobj.peek(10)

        if not is_jfif:
            num_skipped += 1
            # Delete the corrupted image (the file is already closed here).
            os.remove(fpath)

print(f"Deleted {num_skipped} images")

4、构建数据集

# Dataset construction (API reference: https://keras.io/api/preprocessing/image/).
# The datasets are built straight from the original images, without any
# augmentation, to see how well the model trains on the raw data.

# ---------- hyperparameters ----------
batch_size = 32
image_size = (180, 180)  # keep equal to the size the model was built with
# ---------- end of hyperparameters ----------
# Read every image under a folder, using each sub-folder name as the class label.
def make_dataset_from_directory(image_path, dataset_name, *, validation_split=0.2,
                                seed=1337):
    """Load one split of an image-folder dataset as a batched ``tf.data.Dataset``.

    Each sub-directory of ``image_path`` is one class; labels are inferred
    from the directory structure (sub-directories sorted alphanumerically map
    to labels 0, 1, ...). Per the Keras docs the supported formats are jpeg,
    png, bmp and gif (animated gifs are truncated to the first frame).

    Images are resized to the module-level ``image_size`` and batched with the
    module-level ``batch_size``; both globals are read at call time.

    Args:
        image_path: Root directory containing one sub-directory per class.
        dataset_name: Split to return, forwarded as ``subset``:
            ``"training"`` or ``"validation"``.
        validation_split: Fraction of images reserved for validation; the
            default 0.2 gives an 80/20 train/validation split.
        seed: Shuffle seed used to compute the split. Use the same seed and
            ``validation_split`` for the training and validation calls so the
            two subsets partition the same shuffled file list.

    Returns:
        The ``tf.data.Dataset`` returned by ``image_dataset_from_directory``.
    """
    return tf.keras.preprocessing.image_dataset_from_directory(
        image_path,
        validation_split=validation_split,
        subset=dataset_name,
        seed=seed,
        image_size=image_size,
        batch_size=batch_size,
    )

# Build the training and validation splits of PetImages/ and let tf.data
# prefetch up to 32 elements (batches) ahead of the consumer.
train_ds, val_ds = (
    make_dataset_from_directory("PetImages", subset).prefetch(buffer_size=32)
    for subset in ("training", "validation")
)

5、训练网络

# Train the model on the prepared datasets.
epochs = 5
# With ModelCheckpoint's defaults, a checkpoint file save_at_<epoch>.h5 is
# written at the end of every epoch.
callbacks = [
    keras.callbacks.ModelCheckpoint("save_at_{epoch}.h5"),
]

# The network is bound to ``model`` where it is created; the name
# ``basic_model`` was never defined, so compile/fit raised NameError.
model.compile(
    # NOTE(review): 1e-2 is a fairly high Adam learning rate; consider 1e-3
    # if training diverges.
    optimizer=keras.optimizers.Adam(1e-2),
    # Softmax outputs + (presumably integer) labels from
    # image_dataset_from_directory -> sparse categorical cross-entropy.
    loss=keras.losses.sparse_categorical_crossentropy,
    metrics=["accuracy"],
)

model.fit(
    train_ds,
    epochs=epochs,
    callbacks=callbacks,
    validation_data=val_ds,
)