"""Save a single image together with its feature-map-driven augmentations.

Loads a ``ResNet50`` checkpoint, runs one image through it and writes the
raw, sampled and augmented views plus the ``batch_augment`` crop / zoom /
drop results to ``savepath``.
"""
import os
import logging
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from PIL import Image
import numpy as np
import torch.utils.data as data
from dataset_fine_02 import dataset
from model.resnet_p2p_02 import ResNet50
from eval_model_02 import eval_turn
from ZYJ_utils.utils_02 import calc_map_k, label2onehot, calc_train_codes, batch_augment, generate_heatmap, aug_augment
from ZYJ_utils.utils_1 import NCESoftmaxLoss
from utils import show_cam_on_image

# settings #################################################
# CUDA_VISIBLE_DEVICES is only honoured if it is set before the CUDA runtime
# is initialised, and torch.cuda.is_available() may perform that
# initialisation -- so the variable must be set first.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
if not torch.cuda.is_available():
    raise RuntimeError('This script requires a CUDA-capable GPU.')
device = torch.device("cuda")
torch.backends.cudnn.benchmark = True

savepath = './visualize/'
weights_path = './training_checkpoint_bit_class/weights_0.pth'
dataset_name = 'cub_bird'
img_path = "bird3.jpg"
# NOTE(review): labels stays on the CPU while the images are on the GPU;
# confirm batch_augment does not need them on the same device.
labels = torch.tensor([0])
batch_size = 1
os.makedirs(savepath, exist_ok=True)

ToPILImage = transforms.ToPILImage()
# Per-channel mean/std used to undo input normalisation, shaped to broadcast
# over (N, C, H, W). These are the usual ImageNet values; presumably the same
# ones aug_augment normalises with -- confirm.
MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

# (number of classes, data directory) for each supported dataset.
_DATASETS = {
    'cub_bird': (200, './datasets/cub_bird/'),
    'stanford_dog': (120, './datasets/stanford_dog/'),
    'aircraft': (100, './datasets/aircraft/'),
    'vegfru': (292, './datasets/vegfru/'),
}
###########################################################


def _dataset_config(name):
    """Return ``(num_classes, data_dir)`` for dataset *name*.

    Raises:
        ValueError: if *name* is not one of the supported datasets, instead
            of continuing with undefined settings.
    """
    try:
        return _DATASETS[name]
    except KeyError:
        raise ValueError(
            'undefined dataset: {!r} (expected one of {})'.format(
                name, ', '.join(sorted(_DATASETS)))) from None


def _to_pil(normalized):
    """Denormalise the first image of a (N, C, H, W) batch and return it as a PIL image."""
    image = normalized.cpu() * STD + MEAN
    # ToPILImage only accepts 2-D/3-D tensors, so drop the batch dimension.
    # It converts float tensors to uint8 by scaling with 255, which does not
    # clip out-of-range values -- clamp to [0, 1] first.
    return ToPILImage(image[0].clamp(0, 1))


def main():
    """Run ``img_path`` through the model and save its raw and augmented views."""
    logging.basicConfig(
        format='%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
        level=logging.INFO)
    warnings.filterwarnings("ignore")

    # Dataset ##########################################################
    classes, data_dir = _dataset_config(dataset_name)
    base_set = dataset(dataset_name, root_dir=data_dir, train=True)
    print('basa_set', len(base_set))

    # Model ############################################################
    net = ResNet50(32, classes, 0.7)
    state_dict = torch.load(weights_path, map_location=device)
    # strict=False tolerates key mismatches; report them so that loading the
    # wrong checkpoint is not silently ignored.
    incompatible = net.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        logging.warning('load_state_dict: missing keys %s; unexpected keys %s',
                        incompatible.missing_keys, incompatible.unexpected_keys)
    net.to(device)
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
    net.eval()

    # Image ############################################################
    if not os.path.exists(img_path):
        raise FileNotFoundError("file: '{}' does not exist.".format(img_path))
    img = Image.open(img_path).convert('RGB')

    # aug_augment returns [C, H, W] tensors; the remaining values (presumably
    # a crop box and a flip flag -- confirm) are not used here.
    img_tensor, aug_sample, sample, i, j, h, w, hor_flip = aug_augment(img)

    # Add the batch dimension: [C, H, W] -> [1, C, H, W].
    X = img_tensor.unsqueeze(0).to(device)
    aug_X = aug_sample.unsqueeze(0).to(device)
    sample_X = sample.unsqueeze(0).to(device)

    # Inference only: no autograd graph is needed for the forward pass or
    # the augmentations.
    with torch.no_grad():
        feature_map, y3, code_q, c3, alpha1, alpha2 = net(X)
        drop_images = batch_augment(X, labels, feature_map, mode='drop')
        crop_images, attention_maps = batch_augment(X, labels, feature_map, mode='crop')
        zoom_images = batch_augment(X, labels, feature_map, mode='zoom')

    # Use only the file name so an img_path containing directories still
    # produces valid output paths inside savepath.
    name = os.path.basename(img_path)
    outputs = {
        '_raw' + name: X,
        'ogimg' + name: sample_X,
        '_augrimg' + name: aug_X,
        name + '_raw_atten_cropp.jpg': crop_images,
        name + '_raw_atten_zoom.jpg': zoom_images,
        name + '_raw_atten_drop.jpg': drop_images,
    }
    for filename, batch in outputs.items():
        _to_pil(batch).save(os.path.join(savepath, filename))
    print('ok')


if __name__ == '__main__':
    main()