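Some syntax may not render correctly here, which hurts readability; please read this post on the blog instead~
GitPage address of this post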

Transfer_Learning

```python
#!/usr/local/bin/python3.6
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

num_classes = 2
# Local ImageNet weights for ResNet50 without its classification head ("notop")
resnet_weights_path = 'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'

my_new_model = Sequential()
# include_top=False drops the 1000-class ImageNet head; pooling='avg' turns
# the last feature map into a single 2048-d vector per image.
my_new_model.add(ResNet50(include_top=False, pooling='avg', weights=resnet_weights_path))
# New classification head: one softmax output per class
my_new_model.add(Dense(num_classes, activation='softmax'))
# Freeze the pretrained ResNet50 base so only the new Dense layer is trained
my_new_model.layers[0].trainable = False
my_new_model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
```
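The new head pairs a softmax activation with the `categorical_crossentropy` loss. As a rough illustration, here is that arithmetic written out by hand in plain Python for one two-class example; the logits and the one-hot label are made-up numbers, and this is not the code Keras runs internally (Keras computes the same formula on tensors).

```python
import math

def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def categorical_crossentropy(one_hot, probs):
    """-sum(y_i * log(p_i)); only the true class's probability contributes."""
    return -sum(y * math.log(p) for y, p in zip(one_hot, probs) if y > 0)

logits = [2.0, 0.5]  # hypothetical Dense(2) outputs before softmax
label = [1, 0]       # hypothetical one-hot label: the image is class 0
probs = softmax(logits)
loss = categorical_crossentropy(label, probs)
print([round(p, 4) for p in probs], round(loss, 4))
```

Here the model gives class 0 a probability of about 0.82, so the loss is small (about 0.20); a confident wrong prediction would push the loss up sharply.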
```python
# ---- Fit Model ----
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.preprocessing.image import ImageDataGenerator

image_size = 224  # ResNet50's default input resolution (224 x 224)
# Apply ResNet50's own preprocessing to every image the generator yields
data_generator = ImageDataGenerator(preprocessing_function=preprocess_input)
```
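For ResNet50, Keras's `preprocess_input` follows the "caffe" convention: it reorders channels from RGB to BGR and subtracts the per-channel ImageNet mean, without rescaling to [0, 1]. The helper below is a hypothetical, single-pixel sketch of that arithmetic in plain Python, not the library function itself (which works on whole arrays).

```python
# ImageNet channel means in BGR order, as used by caffe-style preprocessing
IMAGENET_MEAN_BGR = (103.939, 116.779, 123.68)

def caffe_preprocess_pixel(rgb):
    """Hypothetical helper: convert one RGB pixel (0-255) to BGR and zero-center it."""
    r, g, b = rgb
    bgr = (b, g, r)  # swap channel order: RGB -> BGR
    return tuple(c - m for c, m in zip(bgr, IMAGENET_MEAN_BGR))

print(caffe_preprocess_pixel((255, 128, 0)))
```

This is why images fed to this model should always go through `preprocess_input`: the pretrained weights expect inputs centered the same way as during ImageNet training.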
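```python
# Each subfolder of ./train and ./val is treated as one class
train_generator = data_generator.flow_from_directory(
    './train',
    target_size=(image_size, image_size),  # resize every image to 224 x 224
    batch_size=24,
    class_mode='categorical')  # one-hot labels, matching categorical_crossentropy

validation_generator = data_generator.flow_from_directory(
    './val',
    target_size=(image_size, image_size),
    class_mode='categorical')  # batch_size not set, so the default of 32 is used
```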
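`flow_from_directory` expects one subfolder per class (for example `./train/cat/*.jpg` and `./train/dog/*.jpg`; the class names here are only placeholders) and assigns label indices in alphanumeric order of the folder names; the real mapping is available afterwards as `train_generator.class_indices`. The sketch below imitates that inference with the standard library only; `infer_class_indices` is a hypothetical helper, not part of Keras.

```python
import os
import tempfile

def infer_class_indices(directory):
    """Hypothetical helper: map each subfolder name to a label index, alphanumerically."""
    classes = sorted(
        name for name in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, name))
    )
    return {name: index for index, name in enumerate(classes)}

# Build a throwaway ./train-style tree with two hypothetical classes
with tempfile.TemporaryDirectory() as root:
    for name in ("dog", "cat"):
        os.makedirs(os.path.join(root, name))
    open(os.path.join(root, "notes.txt"), "w").close()  # loose files are not classes
    class_indices = infer_class_indices(root)

print(class_indices)  # {'cat': 0, 'dog': 1}
```

Because the order is alphanumeric rather than creation order, `cat` gets index 0 even though `dog` was created first.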
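```python
# fit_generator() is deprecated in TensorFlow 2.x; fit() accepts generators directly
my_new_model.fit(
    train_generator,
    steps_per_epoch=3,       # 3 batches of 24 images per epoch
    validation_data=validation_generator,
    validation_steps=1)      # evaluate on a single validation batch
```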
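With `steps_per_epoch=3` and `batch_size=24`, each epoch only draws 3 × 24 = 72 training images, which may be far fewer than the folder holds. To see every image once per epoch, the usual choice is ceil(n / batch_size). A small sketch, where the 500-image training set is a made-up example and `steps_for_full_pass` is a hypothetical helper:

```python
import math

def steps_for_full_pass(num_images, batch_size):
    """Batches needed to see every image once (the last batch may be partial)."""
    return math.ceil(num_images / batch_size)

batch_size = 24
steps_per_epoch = 3
print(steps_per_epoch * batch_size)          # images seen per epoch in the script above
print(steps_for_full_pass(500, batch_size))  # hypothetical 500-image training set
```

With a Keras directory iterator, `len(train_generator)` should already return this ceiling value, so it can be passed as `steps_per_epoch` instead of a hard-coded number.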

Enjoy~

This post is updated automatically from GitHub/Yuque by a Python script.


GitHub: Karobben
Blog: Karobben
BiliBili: 史上最不正經的生物狗