实战使用

戴琼海老师组3D-Unet无监督去噪

data_process.py

  1. import numpy as np
  2. import os
  3. import tifffile as tiff
  4. import random
  5. import math
  6. import torch
  7. from torch.utils.data import Dataset
  8. class trainset(Dataset):
  9. def __init__(self,name_list,coordinate_list,noise_img_all,stack_index):
  10. self.name_list = name_list
  11. self.coordinate_list=coordinate_list
  12. self.noise_img_all = noise_img_all
  13. self.stack_index = stack_index
  14. def __getitem__(self, index):
  15. #fn = self.images[index]
  16. stack_index = self.stack_index[index]
  17. noise_img = self.noise_img_all[stack_index]
  18. single_coordinate = self.coordinate_list[self.name_list[index]]
  19. init_h = single_coordinate['init_h']
  20. end_h = single_coordinate['end_h']
  21. init_w = single_coordinate['init_w']
  22. end_w = single_coordinate['end_w']
  23. init_s = single_coordinate['init_s']
  24. end_s = single_coordinate['end_s']
  25. input = noise_img[init_s:end_s:2, init_h:end_h, init_w:end_w]
  26. target = noise_img[init_s + 1:end_s:2, init_h:end_h, init_w:end_w]
  27. input=torch.from_numpy(np.expand_dims(input, 0))
  28. target=torch.from_numpy(np.expand_dims(target, 0))
  29. return input, target
  30. def __len__(self):
  31. return len(self.name_list)
  32. class testset(Dataset):
  33. def __init__(self,name_list,coordinate_list,noise_img):
  34. self.name_list = name_list
  35. self.coordinate_list=coordinate_list
  36. self.noise_img = noise_img
  37. def __getitem__(self, index):
  38. #fn = self.images[index]
  39. single_coordinate = self.coordinate_list[self.name_list[index]]
  40. init_h = single_coordinate['init_h']
  41. end_h = single_coordinate['end_h']
  42. init_w = single_coordinate['init_w']
  43. end_w = single_coordinate['end_w']
  44. init_s = single_coordinate['init_s']
  45. end_s = single_coordinate['end_s']
  46. noise_patch = self.noise_img[init_s:end_s, init_h:end_h, init_w:end_w]
  47. noise_patch=torch.from_numpy(np.expand_dims(noise_patch, 0))
  48. #target = self.target[index]
  49. return noise_patch,single_coordinate
  50. def __len__(self):
  51. return len(self.name_list)
  52. def train_preprocess_lessMemoryMulStacks(args):
  53. img_h = args.img_h
  54. img_w = args.img_w
  55. img_s2 = args.img_s*2
  56. gap_h = args.gap_h
  57. gap_w = args.gap_w
  58. gap_s2 = args.gap_s*2
  59. im_folder = args.datasets_path + '//' + args.datasets_folder
  60. name_list = []
  61. coordinate_list={}
  62. stack_index = []
  63. noise_im_all = []
  64. ind = 0;
  65. print('\033[1;31mImage list for training -----> \033[0m')
  66. stack_num = len(list(os.walk(im_folder, topdown=False))[-1][-1])
  67. print('Total number -----> ', stack_num)
  68. for im_name in list(os.walk(im_folder, topdown=False))[-1][-1]:
  69. print(im_name)
  70. im_dir = im_folder+ '//' + im_name
  71. noise_im = tiff.imread(im_dir)
  72. if noise_im.shape[0]>args.select_img_num:
  73. noise_im = noise_im[0:args.select_img_num,:,:]
  74. gap_s2 = get_gap_s(args, noise_im, stack_num)
  75. # print('noise_im shape -----> ',noise_im.shape)
  76. # print('noise_im max -----> ',noise_im.max())
  77. # print('noise_im min -----> ',noise_im.min())
  78. noise_im = (noise_im-noise_im.min()).astype(np.float32)/args.normalize_factor
  79. noise_im_all.append(noise_im)
  80. whole_w = noise_im.shape[2]
  81. whole_h = noise_im.shape[1]
  82. whole_s = noise_im.shape[0]
  83. # print('int((whole_h-img_h+gap_h)/gap_h) -----> ',int((whole_h-img_h+gap_h)/gap_h))
  84. # print('int((whole_w-img_w+gap_w)/gap_w) -----> ',int((whole_w-img_w+gap_w)/gap_w))
  85. # print('int((whole_s-img_s2+gap_s2)/gap_s2) -----> ',int((whole_s-img_s2+gap_s2)/gap_s2))
  86. for x in range(0,int((whole_h-img_h+gap_h)/gap_h)):
  87. for y in range(0,int((whole_w-img_w+gap_w)/gap_w)):
  88. for z in range(0,int((whole_s-img_s2+gap_s2)/gap_s2)):
  89. single_coordinate={'init_h':0, 'end_h':0, 'init_w':0, 'end_w':0, 'init_s':0, 'end_s':0}
  90. init_h = gap_h*x
  91. end_h = gap_h*x + img_h
  92. init_w = gap_w*y
  93. end_w = gap_w*y + img_w
  94. init_s = gap_s2*z
  95. end_s = gap_s2*z + img_s2
  96. single_coordinate['init_h'] = init_h
  97. single_coordinate['end_h'] = end_h
  98. single_coordinate['init_w'] = init_w
  99. single_coordinate['end_w'] = end_w
  100. single_coordinate['init_s'] = init_s
  101. single_coordinate['end_s'] = end_s
  102. # noise_patch1 = noise_im[init_s:end_s,init_h:end_h,init_w:end_w]
  103. patch_name = args.datasets_folder+'_'+im_name.replace('.tif','')+'_x'+str(x)+'_y'+str(y)+'_z'+str(z)
  104. # train_raw.append(noise_patch1.transpose(1,2,0))
  105. name_list.append(patch_name)
  106. # print(' single_coordinate -----> ',single_coordinate)
  107. coordinate_list[patch_name] = single_coordinate
  108. stack_index.append(ind)
  109. ind = ind + 1;
  110. return name_list, noise_im_all, coordinate_list, stack_index

train.py

每个 epoch 开始时先实例化 trainset 类得到 train_data，再用 DataLoader 按批次（可打乱顺序）加载训练数据：

# Excerpt (repeated in full in the train.py listing below): the per-epoch
# training loop — build the dataset/loader, then run the denoiser on each
# input half-stack and compare against its paired target half-stack.
# start training
for epoch in range(0, opt.n_epochs):
    train_data = trainset(name_list, coordinate_list, noise_img_all,stack_index)
    trainloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)
    for iteration, (input, target) in enumerate(trainloader):
        input=input.cuda()
        target = target.cuda()
        real_A=input
        real_B=target
        real_A = Variable(real_A)
        #print('real_A shape -----> ', real_A.shape)
        #print('real_B shape -----> ',real_B.shape)
        fake_B = denoise_generator(real_A)
        # L1 and L2 losses between the prediction and the paired noisy target.
        L1_loss = L1_pixelwise(fake_B, real_B)
        L2_loss = L2_pixelwise(fake_B, real_B)
  1. import os
  2. import torch
  3. import torch.nn as nn
  4. from torch.autograd import Variable
  5. from torch.utils.data import DataLoader
  6. import argparse
  7. import time
  8. import datetime
  9. from network import Network_3D_Unet
  10. from data_process import train_preprocess_lessMemoryMulStacks, trainset
  11. from utils import save_yaml
#############################################################################################################################################
# --- Command-line configuration --------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=40, help="number of training epochs")
parser.add_argument('--GPU', type=str, default='0,1', help="the index of GPU you will use for computation")
parser.add_argument('--batch_size', type=int, default=2, help="batch size")
parser.add_argument('--img_w', type=int, default=150, help="the width of image patch")
parser.add_argument('--img_h', type=int, default=150, help="the height of image patch")
parser.add_argument('--img_s', type=int, default=150, help="the length of image patch")
parser.add_argument('--lr', type=float, default=0.00005, help='initial learning rate')
parser.add_argument("--b1", type=float, default=0.5, help="Adam: bata1")
parser.add_argument("--b2", type=float, default=0.999, help="Adam: bata2")
parser.add_argument('--normalize_factor', type=int, default=1, help='normalize factor')
parser.add_argument('--fmap', type=int, default=16, help='number of feature maps')
parser.add_argument('--output_dir', type=str, default='./results', help="output directory")
parser.add_argument('--datasets_folder', type=str, default='train', help="A folder containing files for training")
parser.add_argument('--datasets_path', type=str, default='datasets', help="dataset root path")
parser.add_argument('--pth_path', type=str, default='pth', help="pth file root path")
parser.add_argument('--select_img_num', type=int, default=100000, help='select the number of images used for training')
parser.add_argument('--train_datasets_size', type=int, default=4000, help='datasets size for training')
opt = parser.parse_args()
# default image gap is 0.5*image_dim
# opt.gap_s (image gap) is the distance between two adjacent patches
opt.gap_s=int(opt.img_s*0.5)
opt.gap_w=int(opt.img_w*0.5)
opt.gap_h=int(opt.img_h*0.5)
# Number of GPUs = number of comma-separated ids in --GPU.
opt.ngpu=str(opt.GPU).count(',')+1
print('\033[1;31mTraining parameters -----> \033[0m')
print(opt)
########################################################################################################################
# --- Output / checkpoint directories ---------------------------------------
if not os.path.exists(opt.output_dir):
    os.mkdir(opt.output_dir)
# Run name: dataset folder + timestamp; shared by results and checkpoints.
current_time = opt.datasets_folder+'_'+datetime.datetime.now().strftime("%Y%m%d%H%M")
output_path = opt.output_dir + '/' + current_time
pth_path = 'pth//'+ current_time
if not os.path.exists(pth_path):
    # NOTE(review): os.mkdir fails if the parent 'pth' dir is missing —
    # verify it exists (or consider os.makedirs) before first run.
    os.mkdir(pth_path)
yaml_name = pth_path+'//para.yaml'
# Persist the full configuration for this run next to the checkpoints.
save_yaml(opt, yaml_name)
os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.GPU)
batch_size = opt.batch_size
lr = opt.lr
# Cut all training stacks into overlapping 3D patches (see data_process.py).
name_list, noise_img_all, coordinate_list, stack_index = train_preprocess_lessMemoryMulStacks(opt)
# print('name_list -----> ',name_list)
########################################################################################################################
# --- Model and losses ------------------------------------------------------
L1_pixelwise = torch.nn.L1Loss()
L2_pixelwise = torch.nn.MSELoss()
denoise_generator = Network_3D_Unet(in_channels = 1,
                                    out_channels = 1,
                                    f_maps=opt.fmap,
                                    final_sigmoid = True)
if torch.cuda.is_available():
    # Wrap in DataParallel to spread batches across all requested GPUs.
    denoise_generator = denoise_generator.cuda()
    denoise_generator = nn.DataParallel(denoise_generator, device_ids=range(opt.ngpu))
    print('\033[1;31mUsing {} GPU for training -----> \033[0m'.format(torch.cuda.device_count()))
    L2_pixelwise.cuda()
    L1_pixelwise.cuda()
########################################################################################################################
optimizer_G = torch.optim.Adam(denoise_generator.parameters(),
                               lr=opt.lr, betas=(opt.b1, opt.b2))
########################################################################################################################
cuda = True if torch.cuda.is_available() else False
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
prev_time = time.time()
########################################################################################################################
time_start=time.time()
# start training
for epoch in range(0, opt.n_epochs):
    # Rebuild the dataset and loader every epoch; shuffle=True reshuffles
    # the patch order each time.
    train_data = trainset(name_list, coordinate_list, noise_img_all,stack_index)
    trainloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)
    for iteration, (input, target) in enumerate(trainloader):
        input=input.cuda()
        target = target.cuda()
        real_A=input
        real_B=target
        real_A = Variable(real_A)
        #print('real_A shape -----> ', real_A.shape)
        #print('real_B shape -----> ',real_B.shape)
        fake_B = denoise_generator(real_A)
        # L1 and L2 losses between the prediction and the paired noisy target.
        L1_loss = L1_pixelwise(fake_B, real_B)
        L2_loss = L2_pixelwise(fake_B, real_B)
        ################################################################################################################
        optimizer_G.zero_grad()
        # Total loss: equal-weight blend of L1 and L2 terms.
        Total_loss = 0.5*L1_loss + 0.5*L2_loss
        Total_loss.backward()
        optimizer_G.step()
        ################################################################################################################
        # ETA estimate: remaining batches times the duration of the last batch.
        batches_done = epoch * len(trainloader) + iteration
        batches_left = opt.n_epochs * len(trainloader) - batches_done
        time_left = datetime.timedelta(seconds=int(batches_left * (time.time() - prev_time)))
        prev_time = time.time()
        if iteration%1 == 0:  # always true: log every batch
            time_end=time.time()
            print('\r[Epoch %d/%d] [Batch %d/%d] [Total loss: %.2f, L1 Loss: %.2f, L2 Loss: %.2f] [ETA: %s] [Time cost: %.2d s] '
                % (
                    epoch+1,
                    opt.n_epochs,
                    iteration+1,
                    len(trainloader),
                    Total_loss.item(),
                    L1_loss.item(),
                    L2_loss.item(),
                    time_left,
                    time_end-time_start
                ), end=' ')
        if (iteration+1)%len(trainloader) == 0:
            print('\n', end=' ')
        ################################################################################################################
        # save model
        # Checkpoint once per epoch (on the last batch).  Unwrap DataParallel
        # first so the saved state_dict keys carry no 'module.' prefix.
        if (iteration + 1) % (len(trainloader)) == 0:
            model_save_name = pth_path + '//E_' + str(epoch+1).zfill(2) + '_Iter_' + str(iteration+1).zfill(4) + '.pth'
            if isinstance(denoise_generator, nn.DataParallel):
                torch.save(denoise_generator.module.state_dict(), model_save_name) # parallel
            else:
                torch.save(denoise_generator.state_dict(), model_save_name) # not parallel