https://www.tensorflow.org/tutorials/generative/cyclegan#loss_functions

The official TensorFlow documentation explains CycleGAN clearly.

    import tensorflow as tf
    from tensorflow_examples.models.pix2pix import pix2pix

    OUTPUT_CHANNELS = 3

    # two generators (G: X -> Y, F: Y -> X) and one discriminator per domain
    generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
    generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')

    discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
    discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)

Loss

Adversarial loss

    LAMBDA = 10
    loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)

BinaryCrossentropy() computes the cross-entropy loss between true labels and predicted labels:

    bce = tf.keras.losses.BinaryCrossentropy()
    loss = bce([0., 0., 1., 1.], [1., 1., 1., 0.])
    print('Loss: ', loss.numpy())  # Loss: 11.522857

Its call signature:

    __call__(
        y_true, y_pred, sample_weight=None
    )
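Note that the demo above feeds probabilities, whereas loss_obj is constructed with from_logits=True because the pix2pix discriminators output raw, un-activated scores. A minimal sketch (the values are made up) showing that from_logits=True simply applies the sigmoid internally:

    import tensorflow as tf

    logits = tf.constant([2.0, -1.0, 3.0, -2.0])  # hypothetical raw scores
    labels = tf.constant([1., 0., 1., 0.])

    bce_logits = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    bce_probs = tf.keras.losses.BinaryCrossentropy()

    # identical up to floating-point error; the logits version is more stable
    print(bce_logits(labels, logits).numpy())
    print(bce_probs(labels, tf.sigmoid(logits)).numpy())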


    def discriminator_loss(real, generated):
        # real images should be classified as 1, generated images as 0
        real_loss = loss_obj(tf.ones_like(real), real)
        generated_loss = loss_obj(tf.zeros_like(generated), generated)
        total_disc_loss = real_loss + generated_loss
        # halving the objective slows the discriminator relative to the generator
        return total_disc_loss * 0.5

    def generator_loss(generated):
        # the generator wins when the discriminator labels its output as 1
        return loss_obj(tf.ones_like(generated), generated)
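A quick sanity check of the two losses on hypothetical logits: a discriminator that is confidently right gets a loss near zero, while the generator it defeats gets a large loss (discriminator_loss and generator_loss as defined above):

    import tensorflow as tf

    real = tf.constant([[5.0], [4.0]])         # discriminator is sure these are real
    generated = tf.constant([[-5.0], [-4.0]])  # and sure these are fake

    print(discriminator_loss(real, generated).numpy())  # near 0: discriminator wins
    print(generator_loss(generated).numpy())            # large: generator is losing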

Cycle-consistency loss

    def calc_cycle_loss(real_image, cycled_image):
        # L1 distance between an image and its round-trip reconstruction
        loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))
        return LAMBDA * loss1
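A minimal numeric check, assuming calc_cycle_loss as defined above and a made-up "reconstruction" that is off by 0.1 everywhere:

    import tensorflow as tf

    real = tf.ones([1, 256, 256, 3])  # 256x256 RGB, as in the tutorial
    cycled = real * 0.9               # hypothetical round-trip reconstruction

    # mean absolute error is 0.1, scaled by LAMBDA = 10
    print(calc_cycle_loss(real, cycled).numpy())  # ~1.0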

Identity loss

    def identity_loss(real_image, same_image):
        # a generator fed an image from its own target domain
        # should return it (almost) unchanged
        loss = tf.reduce_mean(tf.abs(real_image - same_image))
        return LAMBDA * 0.5 * loss
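This encourages each generator to behave as the identity on images that already come from its output domain (in the CycleGAN paper this helps preserve color composition). A tiny check, assuming identity_loss as defined above and a made-up input:

    import tensorflow as tf

    real_y = tf.random.uniform([1, 256, 256, 3])

    print(identity_loss(real_y, real_y).numpy())        # 0.0: perfect identity mapping
    print(identity_loss(real_y, real_y * 0.8).numpy())  # grows as G changes the image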

Optimizers

    generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

    discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

Training

    # Generator G translates X -> Y
    # Generator F translates Y -> X

Horse and zebra images: real_x, real_y

Feed the data through the generators:

fake_y = generator_g(real_x)
cycled_x = generator_f(fake_y)

fake_x = generator_f(real_y)
cycled_y = generator_g(fake_x)

same_y = generator_g(real_y)
same_x = generator_f(real_x)

Feed the data through the discriminators:

disc_real_x = discriminator_x(real_x)
disc_real_y = discriminator_y(real_y)

disc_fake_x = discriminator_x(fake_x)
disc_fake_y = discriminator_y(fake_y)

Adversarial loss

Generator

The generators want to fool the discriminators: the higher the probability a discriminator assigns to a generated image being real, the better. In other words, each generator wants to maximize log(D(G(z))).

gen_g_loss = generator_loss( disc_fake_y )
gen_f_loss = generator_loss( disc_fake_x )
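Because loss_obj was built with from_logits=True, generator_loss(d) works out to -mean(log(sigmoid(d))), so minimizing it is exactly maximizing log(D(G(z))). A quick check (d_out is a made-up batch of discriminator logits):

    import tensorflow as tf

    d_out = tf.constant([0.5, -1.0, 2.0])  # hypothetical discriminator outputs
    a = generator_loss(d_out)
    b = -tf.reduce_mean(tf.math.log(tf.sigmoid(d_out)))
    print(a.numpy(), b.numpy())  # the two agree up to floating-point error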

Discriminator

The discriminators want to flag every generator output as fake while still classifying all real inputs as real.

We want to maximize log(D(x)) + log(1 - D(G(z))).

disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)

Cycle loss

total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)

Total generator loss

total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

    EPOCHS = 40

    @tf.function
    def train_step(real_x, real_y):
        # persistent is set to True because the tape is used more than
        # once to calculate the gradients.
        with tf.GradientTape(persistent=True) as tape:
            # Generator G translates X -> Y
            # Generator F translates Y -> X.
            fake_y = generator_g(real_x, training=True)
            cycled_x = generator_f(fake_y, training=True)

            fake_x = generator_f(real_y, training=True)
            cycled_y = generator_g(fake_x, training=True)

            # same_x and same_y are used for identity loss.
            same_x = generator_f(real_x, training=True)
            same_y = generator_g(real_y, training=True)

            disc_real_x = discriminator_x(real_x, training=True)
            disc_real_y = discriminator_y(real_y, training=True)

            disc_fake_x = discriminator_x(fake_x, training=True)
            disc_fake_y = discriminator_y(fake_y, training=True)

            # calculate the loss
            gen_g_loss = generator_loss(disc_fake_y)
            gen_f_loss = generator_loss(disc_fake_x)

            total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)

            # Total generator loss = adversarial loss + cycle loss + identity loss
            total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
            total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

            disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
            disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)

        # Calculate the gradients for generator and discriminator
        generator_g_gradients = tape.gradient(total_gen_g_loss,
                                              generator_g.trainable_variables)
        generator_f_gradients = tape.gradient(total_gen_f_loss,
                                              generator_f.trainable_variables)

        discriminator_x_gradients = tape.gradient(disc_x_loss,
                                                  discriminator_x.trainable_variables)
        discriminator_y_gradients = tape.gradient(disc_y_loss,
                                                  discriminator_y.trainable_variables)

        # Apply the gradients to the optimizer
        generator_g_optimizer.apply_gradients(zip(generator_g_gradients,
                                                  generator_g.trainable_variables))
        generator_f_optimizer.apply_gradients(zip(generator_f_gradients,
                                                  generator_f.trainable_variables))

        discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
                                                      discriminator_x.trainable_variables))
        discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
                                                      discriminator_y.trainable_variables))
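The tutorial drives train_step with a plain epoch loop over the zipped horse/zebra datasets; a minimal sketch, assuming train_horses and train_zebras are the batched tf.data.Dataset objects prepared earlier in the tutorial:

    for epoch in range(EPOCHS):
        for image_x, image_y in tf.data.Dataset.zip((train_horses, train_zebras)):
            train_step(image_x, image_y)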