import torch
from torch import nn, optim, autograd
import numpy as np
import visdom
from matplotlib import pyplot as plt
import random

h_dim = 400
batchsz = 512
viz = visdom.Visdom()
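
# Assumes a CUDA device (tensors below are moved with .cuda()) and a running
# visdom server, e.g. `python -m visdom.server`, for the live plots.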
class Generator(nn.Module):

    def __init__(self):
        super(Generator, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 2),
        )

    def forward(self, z):
        output = self.net(z)
        return output
class Discriminator(nn.Module):

    def __init__(self):
        super(Discriminator, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        output = self.net(x)
        return output.view(-1)
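
# Note: a canonical WGAN-GP critic ends in a plain linear layer (an unbounded
# score); the Sigmoid above bounds D's output to (0, 1). The listing keeps it
# as written; the -mean()/+mean() losses below still push real scores up and
# fake scores down, just with a saturating output.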
def data_generator():
    """
    Endless generator over an 8-Gaussian mixture: tiny Gaussian blobs
    centered on 8 points of a circle of radius `scale`.
    """
    scale = 2.
    centers = [
        (1, 0),
        (-1, 0),
        (0, 1),
        (0, -1),
        (1. / np.sqrt(2), 1. / np.sqrt(2)),
        (1. / np.sqrt(2), -1. / np.sqrt(2)),
        (-1. / np.sqrt(2), 1. / np.sqrt(2)),
        (-1. / np.sqrt(2), -1. / np.sqrt(2))
    ]
    centers = [(scale * x, scale * y) for x, y in centers]

    while True:
        dataset = []
        for i in range(batchsz):
            point = np.random.randn(2) * .02
            center = random.choice(centers)
            point[0] += center[0]
            point[1] += center[1]
            dataset.append(point)
        dataset = np.array(dataset, dtype='float32')
        dataset /= 1.414  # rescale by sqrt(2) toward unit stdev
        yield dataset
    # Alternative (disabled): a fixed 5x5 grid dataset, shuffled and served
    # in batches instead of being sampled endlessly.
    # for i in range(100000//25):
    #     for x in range(-2, 3):
    #         for y in range(-2, 3):
    #             point = np.random.randn(2).astype(np.float32) * 0.05
    #             point[0] += 2 * x
    #             point[1] += 2 * y
    #             dataset.append(point)
    #
    # dataset = np.array(dataset)
    # print('dataset:', dataset.shape)
    # viz.scatter(dataset, win='dataset', opts=dict(title='dataset', webgl=True))
    #
    # while True:
    #     np.random.shuffle(dataset)
    #
    #     for i in range(len(dataset)//batchsz):
    #         yield dataset[i*batchsz : (i+1)*batchsz]
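
# Minimal usage sketch for the generator above (a sketch, assuming the
# default batchsz=512 set at the top of this file):
#
#   it = data_generator()
#   xb = next(it)   # np.ndarray of shape (512, 2), dtype float32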
def generate_image(D, G, xr, epoch):
    """
    Plots the critic's level sets together with the real and generated
    samples, and pushes the figure to visdom.
    """
    N_POINTS = 128
    RANGE = 3
    plt.clf()

    points = np.zeros((N_POINTS, N_POINTS, 2), dtype='float32')
    points[:, :, 0] = np.linspace(-RANGE, RANGE, N_POINTS)[:, None]
    points[:, :, 1] = np.linspace(-RANGE, RANGE, N_POINTS)[None, :]
    points = points.reshape((-1, 2))  # (16384, 2)

    # draw a contour of the critic's output over the grid
    with torch.no_grad():
        points = torch.Tensor(points).cuda()  # [16384, 2]
        disc_map = D(points).cpu().numpy()  # [16384]
    x = y = np.linspace(-RANGE, RANGE, N_POINTS)
    cs = plt.contour(x, y, disc_map.reshape((len(x), len(y))).transpose())
    plt.clabel(cs, inline=1, fontsize=10)
    # plt.colorbar()

    # draw real and generated samples
    with torch.no_grad():
        z = torch.randn(batchsz, 2).cuda()  # [b, 2]
        samples = G(z).cpu().numpy()  # [b, 2]
    # xr arrives as a CUDA tensor from the training loop; matplotlib needs
    # host-side arrays
    xr = xr.cpu().numpy()
    plt.scatter(xr[:, 0], xr[:, 1], c='orange', marker='.')
    plt.scatter(samples[:, 0], samples[:, 1], c='green', marker='+')

    viz.matplot(plt, win='contour', opts=dict(title='p(x):%d' % epoch))
def weights_init(m):
    if isinstance(m, nn.Linear):
        # m.weight.data.normal_(0.0, 0.02)
        nn.init.kaiming_normal_(m.weight)
        m.bias.data.fill_(0)
def gradient_penalty(D, xr, xf):
    """
    :param D: the critic
    :param xr: real samples, [b, 2]
    :param xf: generated samples, [b, 2]
    :return: the scaled gradient-penalty term
    """
    LAMBDA = 0.3

    # the penalty only constrains the Discriminator,
    # so cut both inputs out of the Generator's graph
    xf = xf.detach()
    xr = xr.detach()

    # per-row interpolation weight: [b, 1] => [b, 2]
    alpha = torch.rand(batchsz, 1).cuda()
    alpha = alpha.expand_as(xr)

    interpolates = alpha * xr + ((1 - alpha) * xf)
    interpolates.requires_grad_()

    disc_interpolates = D(interpolates)

    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones_like(disc_interpolates),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]

    gp = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
    return gp
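
# The term above is the standard WGAN-GP penalty (Gulrajani et al., 2017):
#     gp = LAMBDA * E[(||grad_xhat D(xhat)||_2 - 1)^2],
# where xhat is sampled uniformly on segments between real and fake points.
# Note this listing uses LAMBDA = 0.3 rather than the paper's 10.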
def main():
    torch.manual_seed(23)
    np.random.seed(23)

    G = Generator().cuda()
    D = Discriminator().cuda()
    G.apply(weights_init)
    D.apply(weights_init)
    optim_G = optim.Adam(G.parameters(), lr=1e-3, betas=(0.5, 0.9))
    optim_D = optim.Adam(D.parameters(), lr=1e-3, betas=(0.5, 0.9))

    data_iter = data_generator()
    print('batch:', next(data_iter).shape)

    viz.line([[0, 0]], [0], win='loss', opts=dict(title='loss',
                                                  legend=['D', 'G']))

    for epoch in range(50000):

        # 1. train the discriminator for 5 steps per generator step
        for _ in range(5):
            x = next(data_iter)
            xr = torch.from_numpy(x).cuda()

            # [b]
            predr = D(xr)
            # maximize predr on real data, i.e. minimize -predr
            lossr = -predr.mean()

            # [b, 2]
            z = torch.randn(batchsz, 2).cuda()
            # stop gradient on G
            # [b, 2]
            xf = G(z).detach()
            # [b]
            predf = D(xf)
            # minimize predf on fake data
            lossf = predf.mean()

            # gradient penalty
            gp = gradient_penalty(D, xr, xf)

            loss_D = lossr + lossf + gp
            optim_D.zero_grad()
            loss_D.backward()
            # for p in D.parameters():
            #     print(p.grad.norm())
            optim_D.step()

        # 2. train the Generator
        z = torch.randn(batchsz, 2).cuda()
        xf = G(z)
        predf = D(xf)
        # maximize predf, i.e. minimize -predf
        loss_G = -predf.mean()
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()

        if epoch % 100 == 0:
            viz.line([[loss_D.item(), loss_G.item()]], [epoch], win='loss', update='append')
            generate_image(D, G, xr, epoch)
            print(loss_D.item(), loss_G.item())


if __name__ == '__main__':
    main()