import os
import logging
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from PIL import Image
import numpy as np
import torch.utils.data as data
from dataset_fine_02 import dataset
from model.resnet_p2p_02 import ResNet50
from eval_model_02 import eval_turn
from ZYJ_utils.utils_02 import calc_map_k, label2onehot, calc_train_codes, batch_augment, generate_heatmap, aug_augment
from ZYJ_utils.utils_1 import NCESoftmaxLoss
from utils import show_cam_on_image
# settings#################################################
# CUDA_VISIBLE_DEVICES must be set before CUDA is queried (including by
# torch.cuda.is_available()); CUDA reads it when it initializes, so setting it
# afterwards may have no effect. NOTE(review): this also assumes none of the
# imports above initializes CUDA at import time — confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
if not torch.cuda.is_available():
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    raise RuntimeError('this script requires a CUDA-capable GPU')
device = torch.device("cuda")
torch.backends.cudnn.benchmark = True
savepath = './visualize/'  # output directory for the visualizations
weights_path = './training_checkpoint_bit_class/weights_0.pth'
dataset_name = 'cub_bird'  # one of: cub_bird, stanford_dog, aircraft, vegfru
img_path = "bird3.jpg"
# Label tensor handed to batch_augment; presumably a placeholder since a single
# arbitrary image is visualized — TODO confirm batch_augment's use of it.
labels = torch.tensor([0])
batch_size = 1
os.makedirs(savepath, exist_ok=True)
ToPILImage = transforms.ToPILImage()
# Per-channel normalization statistics (the standard torchvision ImageNet
# values), viewed as (1, 3, 1, 1) so they broadcast over (N, 3, H, W) batches.
# Presumably they match the normalization applied inside aug_augment — confirm.
MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
###########################################################
# Supported datasets: name -> (number of classes, dataset root directory).
_DATASET_CONFIGS = {
    'cub_bird': (200, './datasets/cub_bird/'),
    'stanford_dog': (120, './datasets/stanford_dog/'),
    'aircraft': (100, './datasets/aircraft/'),
    'vegfru': (292, './datasets/vegfru/'),
}


def _dataset_config(name):
    """Return ``(num_classes, data_dir)`` for the dataset called *name*.

    Raises:
        ValueError: if *name* is not a key of ``_DATASET_CONFIGS``.
    """
    try:
        return _DATASET_CONFIGS[name]
    except KeyError:
        raise ValueError(
            f"undefined dataset {name!r}; expected one of {sorted(_DATASET_CONFIGS)}"
        ) from None


def _load_model(num_classes):
    """Build ``ResNet50(32, num_classes, 0.7)`` and load weights from ``weights_path``.

    The model is moved to ``device``, wrapped in ``nn.DataParallel`` when more
    than one GPU is visible, and switched to eval mode before being returned.
    """
    net = ResNet50(32, num_classes, 0.7)
    state_dict = torch.load(weights_path, map_location=device)
    # strict=False tolerates key mismatches silently; log them so a checkpoint
    # that does not match the architecture is visible instead of quietly
    # leaving some parameters at their initial values.
    incompatible = net.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        logging.warning('missing keys in %s: %s', weights_path, incompatible.missing_keys)
    if incompatible.unexpected_keys:
        logging.warning('unexpected keys in %s: %s', weights_path, incompatible.unexpected_keys)
    net.to(device)
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
    net.eval()
    return net


def _to_pil(batch):
    """Denormalize the first image of a normalized ``(N, 3, H, W)`` batch and return it as PIL."""
    image = (batch.detach().cpu() * STD + MEAN)[0]
    # ToPILImage multiplies float tensors by 255 and casts to uint8 without
    # clamping, so values slightly outside [0, 1] would wrap around.
    return ToPILImage(image.clamp(0.0, 1.0))


def main():
    """Save raw, augmented and attention drop/crop/zoom views of ``img_path``.

    Loads the model configured for ``dataset_name``, runs one augmented view of
    the image through it, applies ``batch_augment`` in 'drop', 'crop' and
    'zoom' modes, and writes all resulting images into ``savepath``.

    Raises:
        ValueError: if ``dataset_name`` is not a supported dataset.
        FileNotFoundError: if ``img_path`` does not exist.
    """
    logging.basicConfig(
        format='%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
        level=logging.INFO)
    warnings.filterwarnings("ignore")

    # Dataset: only used to pick the class count and as a sanity check that the
    # dataset root loads; no batches are drawn from it here.
    classes, data_dir = _dataset_config(dataset_name)
    base_set = dataset(dataset_name, root_dir=data_dir, train=True)
    print('basa_set', len(base_set))

    net = _load_model(classes)

    # Load the image and build its views.
    if not os.path.exists(img_path):
        raise FileNotFoundError(f"file: '{img_path}' does not exist.")
    with Image.open(img_path) as raw:
        img = raw.convert('RGB')
    # aug_augment returns three [C, H, W] tensors followed by five values that
    # are not used here (presumably the crop box i, j, h, w and a flip flag).
    img_tensor, aug_sample, sample, _i, _j, _h, _w, _hor_flip = aug_augment(img)
    # [C, H, W] -> [1, C, H, W] on the target device.
    X = img_tensor.unsqueeze(0).to(device)
    aug_X = aug_sample.unsqueeze(0).to(device)
    sample_X = sample.unsqueeze(0).to(device)

    # Inference only: no autograd graph is needed for the forward pass or for
    # the attention-based augmentations.
    with torch.no_grad():
        feature_map, _y3, _code_q, _c3, _alpha1, _alpha2 = net(X)
        drop_images = batch_augment(X, labels, feature_map, mode='drop')
        crop_images, _attention_maps = batch_augment(X, labels, feature_map, mode='crop')
        zoom_images = batch_augment(X, labels, feature_map, mode='zoom')

    # Use only the file name so an img_path containing directories does not
    # yield paths into non-existent sub-directories of savepath.
    name = os.path.basename(img_path)
    outputs = (
        ('_raw' + name, X),
        ('ogimg' + name, sample_X),
        ('_augrimg' + name, aug_X),
        (name + '_raw_atten_cropp.jpg', crop_images),
        (name + '_raw_atten_zoom.jpg', zoom_images),
        (name + '_raw_atten_drop.jpg', drop_images),
    )
    for filename, batch in outputs:
        _to_pil(batch).save(os.path.join(savepath, filename))
    print('ok')
if __name__ == "__main__":
    # Run the visualization only when executed as a script, not on import.
    main()