实战使用
data_process.py
import numpy as np
import os
import tifffile as tiff
import random
import math
import torch
from torch.utils.data import Dataset


def _patch_bounds(coordinate):
    """Unpack a patch-coordinate dict into (init_s, end_s, init_h, end_h, init_w, end_w)."""
    return (coordinate['init_s'], coordinate['end_s'],
            coordinate['init_h'], coordinate['end_h'],
            coordinate['init_w'], coordinate['end_w'])


class trainset(Dataset):
    """Training dataset of (input, target) patch pairs cut from several stacks.

    For each patch, the frames at even offsets from ``init_s`` along axis 0 form
    the input and the frames at odd offsets form the target. Both therefore cover
    the same spatial window but are disjoint, interleaved sets of frames.

    Args:
        name_list: patch names; ``name_list[i]`` is the key of item ``i`` in
            ``coordinate_list``.
        coordinate_list: dict mapping patch name -> dict with keys ``init_h``,
            ``end_h``, ``init_w``, ``end_w``, ``init_s``, ``end_s``.
        noise_img_all: sequence of NumPy arrays indexed as ``[s, h, w]``.
        stack_index: ``stack_index[i]`` is the index into ``noise_img_all`` of
            the stack that item ``i`` is cut from.
    """

    def __init__(self, name_list, coordinate_list, noise_img_all, stack_index):
        self.name_list = name_list
        self.coordinate_list = coordinate_list
        self.noise_img_all = noise_img_all
        self.stack_index = stack_index

    def __getitem__(self, index):
        noise_img = self.noise_img_all[self.stack_index[index]]
        init_s, end_s, init_h, end_h, init_w, end_w = _patch_bounds(
            self.coordinate_list[self.name_list[index]])
        # Even frame offsets -> network input, odd frame offsets -> target.
        input_patch = noise_img[init_s:end_s:2, init_h:end_h, init_w:end_w]
        target_patch = noise_img[init_s + 1:end_s:2, init_h:end_h, init_w:end_w]
        # Prepend a size-1 channel axis expected by the network.
        input_patch = torch.from_numpy(np.expand_dims(input_patch, 0))
        target_patch = torch.from_numpy(np.expand_dims(target_patch, 0))
        return input_patch, target_patch

    def __len__(self):
        return len(self.name_list)


class testset(Dataset):
    """Inference dataset yielding full (non-interleaved) patches of one stack.

    Args:
        name_list: patch names; ``name_list[i]`` is the key of item ``i`` in
            ``coordinate_list``.
        coordinate_list: dict mapping patch name -> coordinate dict (same keys
            as in ``trainset``).
        noise_img: NumPy array indexed as ``[s, h, w]``.

    Each item is ``(noise_patch, single_coordinate)``. ``noise_patch`` is the
    patch with a size-1 leading axis prepended, and ``single_coordinate`` is the
    patch's coordinate dict, returned so the caller can place the output back.
    """

    def __init__(self, name_list, coordinate_list, noise_img):
        self.name_list = name_list
        self.coordinate_list = coordinate_list
        self.noise_img = noise_img

    def __getitem__(self, index):
        single_coordinate = self.coordinate_list[self.name_list[index]]
        init_s, end_s, init_h, end_h, init_w, end_w = _patch_bounds(single_coordinate)
        noise_patch = self.noise_img[init_s:end_s, init_h:end_h, init_w:end_w]
        noise_patch = torch.from_numpy(np.expand_dims(noise_patch, 0))
        return noise_patch, single_coordinate

    def __len__(self):
        return len(self.name_list)


def _num_patches(whole, patch, gap):
    """Return how many windows of length ``patch``, spaced ``gap`` apart, fit in ``whole``.

    The first window starts at offset 0.

    - Returns 0 when the axis is shorter than one window. This also avoids the
      negative start offsets a negative ``gap`` would otherwise produce.
    - A non-positive ``gap`` on a long-enough axis yields a single window at
      offset 0 instead of a ZeroDivisionError.
    """
    if whole < patch:
        return 0
    if gap <= 0:
        return 1
    return (whole - patch) // gap + 1


def train_preprocess_lessMemoryMulStacks(args):
    """Load every TIFF stack of the training folder and index its training patches.

    Patches are recorded as coordinates only; pixel data is sliced lazily by
    ``trainset``.

    Args:
        args: namespace providing ``img_h``, ``img_w``, ``img_s``, ``gap_h``,
            ``gap_w``, ``datasets_path``, ``datasets_folder``,
            ``select_img_num`` and ``normalize_factor``. It also provides
            whatever ``get_gap_s`` reads.

    Returns:
        ``(name_list, noise_im_all, coordinate_list, stack_index)``:

        - ``name_list``: the patch names.
        - ``noise_im_all``: one normalized float32 array per stack, in file-name
          order.
        - ``coordinate_list``: maps each patch name to its bounds.
        - ``stack_index``: gives, per patch, the index of its stack in
          ``noise_im_all``.

    Raises:
        FileNotFoundError: if the training folder does not exist.
    """
    img_h = args.img_h
    img_w = args.img_w
    # A training patch spans 2*img_s frames: img_s for the input, img_s for the target.
    img_s2 = args.img_s * 2
    gap_h = args.gap_h
    gap_w = args.gap_w
    im_folder = os.path.join(args.datasets_path, args.datasets_folder)
    if not os.path.isdir(im_folder):
        raise FileNotFoundError('Training folder not found: {}'.format(im_folder))
    # Files directly inside im_folder, sorted so stack indices are reproducible.
    # Non-TIFF files are skipped instead of crashing tiff.imread.
    im_names = sorted(
        name for name in next(os.walk(im_folder))[2]
        if name.lower().endswith(('.tif', '.tiff'))
    )

    name_list = []
    coordinate_list = {}
    stack_index = []
    noise_im_all = []
    print('\033[1;31mImage list for training -----> \033[0m')
    stack_num = len(im_names)
    print('Total number -----> ', stack_num)
    for ind, im_name in enumerate(im_names):
        print(im_name)
        noise_im = tiff.imread(os.path.join(im_folder, im_name))
        if noise_im.shape[0] > args.select_img_num:
            noise_im = noise_im[0:args.select_img_num, :, :]
        # Axis-0 stride between neighbouring patches, computed per stack.
        # NOTE(review): get_gap_s is not part of this listing; presumably it is
        # defined elsewhere in data_process.py — confirm.
        gap_s2 = get_gap_s(args, noise_im, stack_num)
        # Shift so the minimum becomes 0, then scale by the user-supplied factor.
        noise_im = (noise_im - noise_im.min()).astype(np.float32) / args.normalize_factor
        noise_im_all.append(noise_im)
        whole_s, whole_h, whole_w = noise_im.shape[0], noise_im.shape[1], noise_im.shape[2]

        n_h = _num_patches(whole_h, img_h, gap_h)
        n_w = _num_patches(whole_w, img_w, gap_w)
        n_s = _num_patches(whole_s, img_s2, gap_s2)
        if n_h * n_w * n_s == 0:
            print('Warning -----> {} is smaller than one training patch, no patches taken from it'.format(im_name))

        stem = im_name.replace('.tif', '')
        for x in range(n_h):
            for y in range(n_w):
                for z in range(n_s):
                    patch_name = (args.datasets_folder + '_' + stem
                                  + '_x' + str(x) + '_y' + str(y) + '_z' + str(z))
                    name_list.append(patch_name)
                    coordinate_list[patch_name] = {
                        'init_h': gap_h * x, 'end_h': gap_h * x + img_h,
                        'init_w': gap_w * y, 'end_w': gap_w * y + img_w,
                        'init_s': gap_s2 * z, 'end_s': gap_s2 * z + img_s2,
                    }
                    stack_index.append(ind)
    return name_list, noise_im_all, coordinate_list, stack_index
train.py
实例化 `trainset` 类得到 `train_data`，再用 `DataLoader` 将其包装为按批次加载数据的迭代器（训练循环节选如下）：
# start trainingfor epoch in range(0, opt.n_epochs):train_data = trainset(name_list, coordinate_list, noise_img_all,stack_index)trainloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)for iteration, (input, target) in enumerate(trainloader):input=input.cuda()target = target.cuda()real_A=inputreal_B=targetreal_A = Variable(real_A)#print('real_A shape -----> ', real_A.shape)#print('real_B shape -----> ',real_B.shape)fake_B = denoise_generator(real_A)L1_loss = L1_pixelwise(fake_B, real_B)L2_loss = L2_pixelwise(fake_B, real_B)
import os
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
import argparse
import time
import datetime
from network import Network_3D_Unet
from data_process import train_preprocess_lessMemoryMulStacks, trainset
from utils import save_yaml


def _parse_args():
    """Parse command-line options and derive the patch gaps and the GPU count."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_epochs", type=int, default=40, help="number of training epochs")
    parser.add_argument('--GPU', type=str, default='0,1', help="the index of GPU you will use for computation")
    parser.add_argument('--batch_size', type=int, default=2, help="batch size")
    parser.add_argument('--img_w', type=int, default=150, help="the width of image patch")
    parser.add_argument('--img_h', type=int, default=150, help="the height of image patch")
    parser.add_argument('--img_s', type=int, default=150, help="the length of image patch")
    parser.add_argument('--lr', type=float, default=0.00005, help='initial learning rate')
    parser.add_argument("--b1", type=float, default=0.5, help="Adam: bata1")
    parser.add_argument("--b2", type=float, default=0.999, help="Adam: bata2")
    parser.add_argument('--normalize_factor', type=int, default=1, help='normalize factor')
    parser.add_argument('--fmap', type=int, default=16, help='number of feature maps')
    parser.add_argument('--output_dir', type=str, default='./results', help="output directory")
    parser.add_argument('--datasets_folder', type=str, default='train', help="A folder containing files for training")
    parser.add_argument('--datasets_path', type=str, default='datasets', help="dataset root path")
    parser.add_argument('--pth_path', type=str, default='pth', help="pth file root path")
    parser.add_argument('--select_img_num', type=int, default=100000, help='select the number of images used for training')
    parser.add_argument('--train_datasets_size', type=int, default=4000, help='datasets size for training')
    opt = parser.parse_args()
    # Default patch gap is half the patch size along each axis;
    # opt.gap_s is the distance between two adjacent patches along axis 0.
    opt.gap_s = int(opt.img_s * 0.5)
    opt.gap_w = int(opt.img_w * 0.5)
    opt.gap_h = int(opt.img_h * 0.5)
    opt.ngpu = str(opt.GPU).count(',') + 1
    return opt


def main():
    """Train the 3D U-Net denoiser and save a checkpoint after every epoch."""
    opt = _parse_args()
    # Must happen before the first CUDA call so only the selected GPUs are visible.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.GPU)
    print('\033[1;31mTraining parameters -----> \033[0m')
    print(opt)

    os.makedirs(opt.output_dir, exist_ok=True)
    current_time = opt.datasets_folder + '_' + datetime.datetime.now().strftime("%Y%m%d%H%M")
    # Honour --pth_path (previously hard-coded to 'pth'); makedirs also creates the root.
    pth_path = os.path.join(opt.pth_path, current_time)
    os.makedirs(pth_path, exist_ok=True)
    save_yaml(opt, os.path.join(pth_path, 'para.yaml'))

    batch_size = opt.batch_size
    name_list, noise_img_all, coordinate_list, stack_index = train_preprocess_lessMemoryMulStacks(opt)
    if not name_list:
        raise RuntimeError('No training patches found in '
                           + os.path.join(opt.datasets_path, opt.datasets_folder))

    L1_pixelwise = torch.nn.L1Loss()
    L2_pixelwise = torch.nn.MSELoss()
    denoise_generator = Network_3D_Unet(in_channels=1, out_channels=1, f_maps=opt.fmap, final_sigmoid=True)

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    if use_cuda:
        denoise_generator = denoise_generator.cuda()
        # Never request more devices than are actually visible.
        n_devices = max(1, min(opt.ngpu, torch.cuda.device_count()))
        denoise_generator = nn.DataParallel(denoise_generator, device_ids=list(range(n_devices)))
        print('\033[1;31mUsing {} GPU for training -----> \033[0m'.format(n_devices))
        L2_pixelwise.cuda()
        L1_pixelwise.cuda()

    optimizer_G = torch.optim.Adam(denoise_generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

    # shuffle=True reshuffles on every pass, so one DataLoader serves all epochs.
    train_data = trainset(name_list, coordinate_list, noise_img_all, stack_index)
    trainloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)
    n_batches = len(trainloader)

    prev_time = time.time()
    time_start = time.time()
    for epoch in range(opt.n_epochs):
        for iteration, (real_A, real_B) in enumerate(trainloader):
            real_A = real_A.to(device)
            real_B = real_B.to(device)

            fake_B = denoise_generator(real_A)
            L1_loss = L1_pixelwise(fake_B, real_B)
            L2_loss = L2_pixelwise(fake_B, real_B)

            optimizer_G.zero_grad()
            Total_loss = 0.5 * L1_loss + 0.5 * L2_loss
            Total_loss.backward()
            optimizer_G.step()

            # ETA extrapolates the duration of the last batch over the remaining batches.
            batches_done = epoch * n_batches + iteration
            batches_left = opt.n_epochs * n_batches - batches_done
            time_left = datetime.timedelta(seconds=int(batches_left * (time.time() - prev_time)))
            prev_time = time.time()
            time_end = time.time()
            print('\r[Epoch %d/%d] [Batch %d/%d] [Total loss: %.2f, L1 Loss: %.2f, L2 Loss: %.2f] [ETA: %s] [Time cost: %.2d s] '
                  % (epoch + 1, opt.n_epochs, iteration + 1, n_batches, Total_loss.item(),
                     L1_loss.item(), L2_loss.item(), time_left, time_end - time_start), end=' ')
        print('\n', end=' ')

        # Save once per epoch. Unwrapping DataParallel avoids its 'module.' key prefix.
        model_save_name = os.path.join(
            pth_path, 'E_' + str(epoch + 1).zfill(2) + '_Iter_' + str(n_batches).zfill(4) + '.pth')
        if isinstance(denoise_generator, nn.DataParallel):
            torch.save(denoise_generator.module.state_dict(), model_save_name)  # parallel
        else:
            torch.save(denoise_generator.state_dict(), model_save_name)  # not parallel


# The guard is required because DataLoader workers (num_workers=4) re-import this
# module under the 'spawn' start method (Windows/macOS).
if __name__ == '__main__':
    main()
