import sys,os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
# sys.path.append(os.path.join(os.path.dirname(os.path.dirname(__file__)),'..'))
# sys.path.append(os.path.join(os.path.split(__file__)[0], '..'))
# sys.path.append(os.path.join(os.path.split(__file__)[0], '..'))
import cv2
import torch,os,glob
import numpy as np
import time,yaml
import pandas as pd
# from my_u2net import U2NETP as model
# from my_u2net import U2Net_50M as model
# from my_u2net import U2Net_20M as model
from torch.utils.data import DataLoader
from my_dataloader import MyFolder
from torchvision.transforms import transforms
from torch import nn
import torch.nn.functional as F
import torch.utils.data as data
from utils.my_RandomFlip import MyHorizontalFlip
from utils.my_RandomFlip import MyVerticalFlip
from torchvision.transforms import transforms
import os,glob
import os.path
import random
import numpy as np
from PIL import Image
import imageio
from collections import namedtuple
# Per-class metadata: human-readable name, integer index used in the label
# masks, and the RGB color used when rendering a prediction as an image.
Cls = namedtuple('cls', ['name', 'id', 'color'])
Clss = [
Cls('backgroud', 0, (0, 0, 0)),  # NOTE(review): 'backgroud' typo kept as-is — it is a runtime string
Cls('crop', 1, (255, 255, 255)),
Cls('weed', 2, (216, 67, 82)),
]
def gray_to_color(color_dict, gt):
    '''
    Colorize an index mask.

    gt: (h, w) array of integer class indices (LongTensor-style values,
    at most 255 classes). color_dict maps class index -> (r, g, b).
    Returns an (h, w, 3) uint8 RGB image.
    '''
    # Build a palette where row i holds the RGB color of class i,
    # then look every pixel up in it via fancy indexing.
    palette = np.zeros([len(color_dict), 3], 'uint8')
    for idx, rgb in color_dict.items():
        palette[idx, :] = list(rgb)
    colored = palette[gt, :]
    return colored.reshape([gt.shape[0], gt.shape[1], 3])
def get_color_dic(nt=Clss):
    '''
    Convert a sequence of Cls namedtuples into a color lookup dict.

    :param nt: iterable of namedtuples exposing .id and .color
    :return: dict mapping class id -> RGB color tuple
    '''
    return {entry.id: entry.color for entry in nt}
def pred_directly(img_path,crop_output_save_path):
    '''
    Run tiled inference on a single image and save per-patch outputs.

    The image is resized to half size, split into 512x512 patches on a
    256-pixel stride (the last patch of each row/column is shifted back so
    it ends exactly at the image border), and every patch is pushed through
    the global `model` on GPU.  The raw (h, w, c) network output of each
    patch is saved as <crop_output_save_path>/<image name>/<y>_<x>.npy.

    :param img_path: path of the input image
    :param crop_output_save_path: directory receiving one sub-folder of
        .npy patch outputs per image
    '''
    img = Image.open(img_path)
    img_name = os.path.splitext(os.path.basename(img_path))[0]
    img_width = int(0.5*img.size[0])
    img_height = int(0.5*img.size[1])
    resized_img = np.array(img.resize((img_width,img_height)))
    # First collect the [y_start, x_start, y_end, x_end] corners of every patch.
    all_yx = []
    for j in range(0, img_height, 256):
        for k in range(0, img_width, 256):
            x_start, y_start = k, j
            x_end, y_end = k + 512, j + 512
            if img_width - x_start < 256:
                continue
            # BUG FIX: the original compared y_start against img_width,
            # which mishandles the patch grid on non-square images.
            if img_height - y_start < 256:
                continue
            # Shift the final column/row of patches so they end exactly at
            # the image border instead of running past it.
            if (img_width - x_end <= 256) and (k != 0):
                x_end = img_width
                x_start = img_width - 512
            if (img_height - y_end <= 256) and (j != 0):
                y_end = img_height
                y_start = img_height - 512
            print('y_start:{},x_start:{}'.format(y_start,x_start))
            all_yx.append([y_start,x_start,y_end,x_end])
    # Create the per-image output folder once, outside the patch loop.
    save_path = crop_output_save_path + '/' + img_name
    os.makedirs(save_path, exist_ok=True)
    for yx in all_yx:
        with torch.no_grad():
            little_patch = resized_img[yx[0]:yx[2],yx[1]:yx[3]]
            little_patch = np.expand_dims(little_patch,0)
            little_patch = torch.tensor(little_patch, dtype=torch.float32).cuda() / 255
            # NHWC -> NCHW, dropping a possible alpha channel.
            little_patch = little_patch.permute(0, 3, 1, 2)[:, :3]
            d0 = model(little_patch)
            # Back to HWC on the CPU for np.save.
            d0 = d0.permute(0,2,3,1)[0,:,:,:].cpu().detach().numpy()
            np.save(save_path + '/{}_{}.npy'.format(yx[0],yx[1]),d0)
            print('{}_{}_{}'.format(img_name,yx[0],yx[1]))
def merge_hot_pic(img_path,hot_pic_path,mask_save_path,scale):
    '''
    Stitch the per-patch .npy predictions of one image back into a
    full-size color mask and write it to mask_save_path/<image name>.png.

    Overlapping patch predictions are summed on a padded canvas and then
    averaged by a per-pixel contribution count before the argmax.

    :param img_path: path of the original (un-resized) input image
    :param hot_pic_path: directory holding <image name>/<y>_<x>.npy patches
        produced by pred_directly
    :param mask_save_path: output directory for the merged color mask
    :param scale: resize factor used at prediction time (e.g. 0.5)
    '''
    img_name = os.path.splitext(img_path)[0].split('/')[-1]
    hot_imgs = glob.glob(hot_pic_path + '/' + img_name + '/*.npy')
    # hot_imgs = glob.glob(hot_pic_path + '/' + img_name + '/*.png')
    ori_img = cv2.imread(img_path)
    hot_pic_height = int(ori_img.shape[0] * scale)
    hot_pic_width = int(ori_img.shape[1] * scale)
    img_size = 512
    # Accumulator canvas padded by one patch size on every side; the mask
    # counts how many patches touched each pixel (the epsilon avoids a
    # division by zero where no patch was written).
    to_img = np.zeros((hot_pic_height + 512 * 2, hot_pic_width + 512 * 2, 3))
    to_img_mask = np.zeros((hot_pic_height + 512 * 2, hot_pic_width + 512 * 2, 3)) + 1e-8
    for i, each_img_path in enumerate(hot_imgs):
        # Patch coordinates are encoded in the file name as "<y>_<x>.npy".
        img_y, img_x = int(os.path.splitext(each_img_path)[0].split('/')[-1].split('_')[0]), int(
            os.path.splitext(each_img_path)[0].split('/')[-1].split('_')[1])
        from_img = np.load(each_img_path)
        # from_img = cv2.imread(each_img_path)
        # from_img = cv2.cvtColor(from_img,cv2.COLOR_BGR2RGB)
        # If the whole image is smaller than one patch, crop the patch down.
        if hot_pic_height < 512:
            from_img = from_img[:hot_pic_height,:]
        if hot_pic_width < 512:
            from_img = from_img[:,:hot_pic_width]
        from_img_height,from_img_width = from_img.shape[0],from_img.shape[1]
        # from_img = from_img.astype(np.uint16)
        # Accumulate this patch on top of whatever is already on the canvas.
        roi = to_img[img_y + img_size:(img_y + img_size + from_img_height), img_x + img_size:(img_x + img_size + from_img_width)]
        # little_merge_img = np.maximum(from_img,roi)
        little_merge_img = from_img + roi
        to_img_mask[img_y + img_size:img_y + img_size + from_img_height, img_x + img_size:img_x + img_size + from_img_width] += 1
        to_img[img_y + img_size:img_y + img_size + from_img_height, img_x + img_size:img_x + img_size + from_img_width] = little_merge_img
    # Average overlapping predictions, then crop away the padding border
    # and resize back to the original image resolution.
    to_img = to_img / to_img_mask
    new_to_img = to_img[img_size:hot_pic_height + img_size, img_size:hot_pic_width + img_size]
    ori_height,ori_width = ori_img.shape[0],ori_img.shape[1]
    new_to_img = cv2.resize(new_to_img,(ori_width,ori_height))
    # Part of the test set ("Roseau", padded to 1024 high) was padded before
    # the scale/crop step at prediction time, so undo that padding here.
    if 'Roseau' in img_name and ori_height == 1024:
        ori_scale_img1 = glob.glob('/media/totem_disk/totem/guozunhu/Competition/acre/Development_Dataset/ori_Roseau/test_Roseau/Mais/Images' + '/{}.png'.format(img_name))
        ori_scale_img2 = glob.glob('/media/totem_disk/totem/guozunhu/Competition/acre/Development_Dataset/ori_Roseau/test_Roseau/Haricot/Images' + '/{}.png'.format(img_name))
        ori_scale_img_path = (ori_scale_img1 + ori_scale_img2)[0]
        ori_scale_img = Image.open(ori_scale_img_path)
        new_to_img = new_to_img[(1024-ori_scale_img.size[1])//2:(1024-ori_scale_img.size[1])//2 + ori_scale_img.size[1],:,:]
    # Per-pixel argmax over the channel axis yields the class index map,
    # which is colorized and written as BGR for cv2.imwrite.
    index_label = np.argmax(new_to_img,axis=2)
    color_dict = get_color_dic(Clss)
    img = gray_to_color(color_dict, index_label)
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    hotpic_save_path = mask_save_path + '/' + img_name + '.png'
    cv2.imwrite(hotpic_save_path,img)
    print(img_name)
if __name__ == '__main__':
    root_path = os.path.dirname(os.path.dirname(__file__))
    # Plain prediction + stitching (currently disabled).
    if False:
        os.environ['CUDA_VISIBLE_DEVICES'] = '3'
        import segmentation_models_pytorch as smp
        model_name = 'Unetres34_bce_scale0.5_pad_fold0_l_softmax_finetune'
        scale = 0.5
        # model = smp.Unet(classes=3, activation='softmax2d')
        weight = '/media/totem_disk/totem/guozunhu/Competition/acre/weight/fold_0/UnetRes34_best_iou.pth'
        # model.load_state_dict(torch.load(weight))
        # NOTE(review): loads a whole pickled model object, not a state_dict;
        # the checkpoint file must contain the full module.
        model = torch.load(weight)
        model.eval()
        root_path = glob.glob('Development_Dataset/Test_Dev/*')
        all_img_paths = []
        for _ in root_path:
            all_img_paths += glob.glob(_ + '/Haricot/Images/*')
            all_img_paths += glob.glob(_ + '/Mais/Images/*')
        # Save path for the per-patch prediction outputs.
        crop_output_save_path = 'predict/{}/patch_hotpic'.format(model_name)
        mask_save_path = 'predict/{}/merge_hotpic'.format(model_name)
        os.makedirs(crop_output_save_path,exist_ok=True)
        os.makedirs(mask_save_path,exist_ok=True)
        for img_path in all_img_paths:
            pred_directly(img_path,crop_output_save_path)
            merge_hot_pic(img_path,crop_output_save_path,mask_save_path,scale)
    # Same pipeline with ttach test-time augmentation (D4: flips + rotations,
    # predictions averaged) wrapped around the model.
    if True:
        import ttach as tta
        os.environ['CUDA_VISIBLE_DEVICES'] = '0'
        import segmentation_models_pytorch as smp
        model_name = 'Unetres34_bce_scale0.5_pad_fold0_l_softmax_finetune_tta'
        scale = 0.5
        weight = '/media/totem_disk/totem/guozunhu/Competition/acre/weight/fold_0/UnetRes34_best_iou.pth'
        model = torch.load(weight)
        model = tta.SegmentationTTAWrapper(model,tta.aliases.d4_transform(),merge_mode='mean')
        model.eval()
        root_path = glob.glob('Development_Dataset/Test_Dev/*')
        all_img_paths = []
        for _ in root_path:
            all_img_paths += glob.glob(_ + '/Haricot/Images/*')
            all_img_paths += glob.glob(_ + '/Mais/Images/*')
        # Save path for the per-patch prediction outputs.
        crop_output_save_path = 'predict/{}/patch_hotpic'.format(model_name)
        mask_save_path = 'predict/{}/merge_hotpic'.format(model_name)
        os.makedirs(crop_output_save_path,exist_ok=True)
        os.makedirs(mask_save_path,exist_ok=True)
        for img_path in all_img_paths:
            pred_directly(img_path,crop_output_save_path)
            merge_hot_pic(img_path,crop_output_save_path,mask_save_path,scale)