1. import os
    2. import logging
    3. import warnings
    4. import torch
    5. import torch.nn as nn
    6. import torch.nn.functional as F
    7. from torchvision import transforms
    8. from torch.utils.data import DataLoader
    9. from tqdm import tqdm
    10. from PIL import Image
    11. import numpy as np
    12. import torch.utils.data as data
    13. from dataset_fine_02 import dataset
    14. from model.resnet_p2p_02 import ResNet50
    15. from eval_model_02 import eval_turn
    16. from ZYJ_utils.utils_02 import calc_map_k, label2onehot, calc_train_codes, batch_augment, generate_heatmap, aug_augment
    17. from ZYJ_utils.utils_1 import NCESoftmaxLoss
    18. from utils import show_cam_on_image
    19. # settings#################################################
    20. assert torch.cuda.is_available()
    21. os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    22. device = torch.device("cuda")
    23. torch.backends.cudnn.benchmark = True
    24. savepath = './visualize/'
    25. weights_path = './training_checkpoint_bit_class/weights_0.pth'
    26. dataset_name = 'cub_bird'
    27. img_path = "bird3.jpg"
    28. labels = torch.tensor([0])
    29. batch_size = 1
    30. os.makedirs(savepath, exist_ok=True)
    31. ToPILImage = transforms.ToPILImage()
    32. MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    33. STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    34. ###########################################################
    35. def main():
    36. logging.basicConfig(
    37. format='%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
    38. level=logging.INFO)
    39. warnings.filterwarnings("ignore")
    40. # Dataset for testing#################################
    41. if dataset_name == 'cub_bird':
    42. classes = 200
    43. data_dir = './datasets/cub_bird/'
    44. elif dataset_name == 'stanford_dog':
    45. classes = 120
    46. data_dir = './datasets/stanford_dog/'
    47. elif dataset_name == 'aircraft':
    48. classes = 100
    49. data_dir = './datasets/aircraft/'
    50. elif dataset_name == 'vegfru':
    51. classes = 292
    52. data_dir = './datasets/vegfru/'
    53. else:
    54. print('undefined dataset ! ')
    55. base_set = dataset(dataset_name, root_dir=data_dir, train=True)
    56. print('basa_set', len(base_set))
    57. dataloader = {}
    58. dataloader['base'] = data.DataLoader(base_set, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)
    59. #########################################################
    60. ##################################
    61. # Initialize model
    62. ##################################
    63. net = ResNet50(32, classes, 0.7)
    64. net.load_state_dict(torch.load(weights_path, map_location=device), strict=False)
    65. net.to(device)
    66. if torch.cuda.device_count() > 1:
    67. net = nn.DataParallel(net)
    68. net.eval()
    69. # load_image########################################################
    70. assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    71. img = Image.open(img_path).convert('RGB')
    72. img_np = transforms.RandomResizedCrop(224)(img)
    73. img_np = np.array(img_np, dtype=np.uint8) # (224, 224, 3)
    74. # [C, H, W]
    75. img_tensor, aug_sample, sample, i, j, h, w, hor_flip = aug_augment(img)
    76. # expand batch dimension
    77. # [C, H, W] -> [N, C, H, W]
    78. input_tensor = torch.unsqueeze(img_tensor, dim=0)
    79. aug_input_tensor = torch.unsqueeze(aug_sample, dim=0)
    80. sample_input_tensor = torch.unsqueeze(sample, dim=0)
    81. X = input_tensor.cuda() # torch.Size([1, 3, 223, 320])
    82. aug_X = aug_input_tensor.cuda() # torch.Size([1, 3, 223, 320])
    83. sample_X = sample_input_tensor.cuda()
    84. ##########################################################
    85. feature_map, y3, code_q, c3, alpha1, alpha2 = net(X)
    86. with torch.no_grad():
    87. drop_images = batch_augment(X, labels, feature_map, mode='drop')
    88. with torch.no_grad():
    89. crop_images, attention_maps = batch_augment(X, labels, feature_map, mode='crop')
    90. with torch.no_grad():
    91. zoom_images = batch_augment(X, labels, feature_map, mode='zoom')
    92. raw_image = X.cpu() * STD + MEAN
    93. sample_X_image = sample_X.cpu() * STD + MEAN
    94. aug_raw_image = aug_X.cpu() * STD + MEAN
    95. crop_images = crop_images.cpu() * STD + MEAN
    96. zoom_images = zoom_images.cpu() * STD + MEAN
    97. drop_images = drop_images.cpu() * STD + MEAN
    98. rimg = ToPILImage(raw_image[0]) # pic should be 2/3 dimensional. Got 4 dimensions.
    99. sampleimg = ToPILImage(sample_X_image[0])
    100. augrimg = ToPILImage(aug_raw_image[0])
    101. crimg = ToPILImage(crop_images[0])
    102. zoimg = ToPILImage(zoom_images[0])
    103. drimg = ToPILImage(drop_images[0])
    104. rimg.save(os.path.join(savepath, '_raw'+img_path))
    105. sampleimg.save(os.path.join(savepath, 'ogimg'+img_path))
    106. augrimg.save(os.path.join(savepath, '_augrimg'+img_path))
    107. crimg.save(os.path.join(savepath, img_path+'_raw_atten_cropp.jpg'))
    108. zoimg.save(os.path.join(savepath, img_path+'_raw_atten_zoom.jpg'))
    109. drimg.save(os.path.join(savepath, img_path+'_raw_atten_drop.jpg'))
    110. print('ok')
# Script entry point: run the visualization only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()