import sys,os
    
    sys.path.append(os.path.dirname(os.path.dirname(__file__)))
    # sys.path.append(os.path.join(os.path.dirname(os.path.dirname(__file__)),'..'))
    # sys.path.append(os.path.join(os.path.split(__file__)[0], '..'))
    # sys.path.append(os.path.join(os.path.split(__file__)[0], '..'))
    
    import cv2
    import torch,os,glob
    import numpy as np
    import time,yaml
    import pandas as pd
    # from my_u2net import U2NETP as model
    # from my_u2net import U2Net_50M as model
    # from my_u2net import U2Net_20M as model
    from torch.utils.data import DataLoader
    from my_dataloader import MyFolder
    from torchvision.transforms import transforms
    from torch import nn
    import torch.nn.functional as F
    import torch.utils.data as data
    
    from utils.my_RandomFlip import MyHorizontalFlip
    from utils.my_RandomFlip import MyVerticalFlip
    
    from torchvision.transforms import transforms
    
    import os,glob
    import os.path
    import random
    import numpy as np
    from PIL import Image
    import imageio
    from collections import namedtuple
    
    # One record per segmentation class: display name, integer class id and
    # the RGB colour used when rendering that class in a mask.
    Cls = namedtuple('cls', 'name id color')
    Clss = [
        Cls(name='backgroud', id=0, color=(0, 0, 0)),
        Cls(name='crop', id=1, color=(255, 255, 255)),
        Cls(name='weed', id=2, color=(216, 67, 82)),
    ]
    
    def gray_to_color(color_dict, gt):
        '''
        Turn a class-index map into a colour image by palette lookup.

        :param color_dict: mapping of integer class id -> (r, g, b) colour
        :param gt: index map of shape (h, w) whose values are class ids
                   (the original note describes it as a LongTensor of indices)
        :return: uint8 array of shape (h, w, 3); ids that are inside the
                 palette range but absent from color_dict come out black
        '''
        # Size the palette by the largest id, not by the number of entries, so
        # that non-contiguous ids (e.g. {0, 5}) do not index out of bounds.
        palette = np.zeros([max(color_dict, default=-1) + 1, 3], 'uint8')
        for cls_id, color in color_dict.items():
            palette[cls_id, :] = list(color)
        ims = palette[gt, :]
        return ims.reshape([gt.shape[0], gt.shape[1], 3])
    
    def get_color_dic(nt=Clss):
        '''
        Build a ``{class id: colour}`` lookup from a sequence of class records.

        :param nt: iterable of records exposing ``id`` and ``color`` attributes
                   (defaults to the module-level ``Clss`` table)
        :return: color_dict mapping each record's id to its colour; if an id
                 occurs more than once the last record wins
        '''
        return {entry.id: entry.color for entry in nt}
    
    
    def pred_directly(img_path,crop_output_save_path,scale=0.5):
        '''
        Run the model over overlapping tiles of one image and save each tile's
        raw output as a ``.npy`` file.

        The image is resized by ``scale`` and covered with 512x512 windows on a
        256-pixel stride. A window that would end within 256 pixels of (or past)
        the right/bottom edge is shifted back so it ends exactly on that edge.
        Each output is permuted to channel-last, stripped of its batch axis and
        written to ``<crop_output_save_path>/<image name>/<y_start>_<x_start>.npy``.

        Relies on the module-level ``model`` and on a CUDA device (``.cuda()``).

        :param img_path: path of the image to predict
        :param crop_output_save_path: directory receiving one sub-directory per image
        :param scale: resize factor applied before tiling; defaults to 0.5, the
                      value that used to be hard-coded here
        '''
        patch_size = 512
        stride = 256

        img = Image.open(img_path)
        img_name = os.path.splitext(os.path.basename(img_path))[0]
        img_width = int(scale*img.size[0])
        img_height = int(scale*img.size[1])
        resized_img = np.array(img.resize((img_width,img_height)))

        all_yx = []
        seen = set()
        for j in range(0, img_height, stride):
            # Too little image left below this row. This must be measured
            # against the HEIGHT; it was previously compared with the width.
            if img_height - j < stride:
                continue
            for k in range(0, img_width, stride):
                if img_width - k < stride:
                    continue
                x_start, y_start = k, j
                x_end, y_end = k + patch_size, j + patch_size

                # Pull the last column/row of windows back onto the image edge.
                if (img_width - x_end <= stride) and (k != 0):
                    x_end = img_width
                    x_start = img_width - patch_size

                if (img_height - y_end <= stride) and (j != 0):
                    y_end = img_height
                    y_start = img_height - patch_size
                    print('y_start:{},x_start:{}'.format(y_start,x_start))

                # Clamping can yield the same window twice; the file name is
                # derived from the start corner, so a repeat would only redo
                # the inference and overwrite the same file.
                if (y_start, x_start) in seen:
                    continue
                seen.add((y_start, x_start))
                all_yx.append([y_start,x_start,y_end,x_end])

        save_path = crop_output_save_path + '/' + img_name
        os.makedirs(save_path, exist_ok=True)

        with torch.no_grad():
            for y_start, x_start, y_end, x_end in all_yx:
                little_patch = resized_img[y_start:y_end, x_start:x_end]
                little_patch = np.expand_dims(little_patch,0)
                little_patch = torch.tensor(little_patch, dtype=torch.float32).cuda() / 255
                # channel-first for the model; keep only the first three channels
                little_patch = little_patch.permute(0, 3, 1, 2)[:, :3]
                d0 = model(little_patch)
                # back to channel-last and drop the batch axis before saving
                d0 = d0.permute(0,2,3,1)[0,:,:,:]
                d0 = d0.cpu().detach().numpy()

                np.save(save_path + '/{}_{}.npy'.format(y_start,x_start),d0)
                print('{}_{}_{}'.format(img_name,y_start,x_start))
    
    def merge_hot_pic(img_path,hot_pic_path,mask_save_path,scale):
        '''
        Stitch the per-tile ``.npy`` outputs of one image into a colour mask.

        Tiles found under ``<hot_pic_path>/<image name>/`` are summed onto a
        canvas at the (y, x) offset encoded in their file names and divided by
        the number of tiles covering each pixel. The result is resized to the
        source image size, reduced with an arg-max over the last axis, coloured
        with the ``Clss`` table and written to ``<mask_save_path>/<image name>.png``.

        :param img_path: path of the source image (read only for its size)
        :param hot_pic_path: directory holding one sub-directory of tiles per image
        :param mask_save_path: directory that receives the merged mask
        :param scale: resize factor that was applied before tiling
        :raises FileNotFoundError: if the source image cannot be read, or if the
                                   Roseau reference image needed to undo the
                                   padding is not found
        '''
        img_name = os.path.splitext(os.path.basename(img_path))[0]
        hot_imgs = glob.glob(hot_pic_path + '/' + img_name + '/*.npy')
        ori_img = cv2.imread(img_path)
        # cv2.imread signals failure by returning None rather than raising.
        if ori_img is None:
            raise FileNotFoundError('cannot read image: {}'.format(img_path))
        ori_height,ori_width = ori_img.shape[0],ori_img.shape[1]
        hot_pic_height = int(ori_height * scale)
        hot_pic_width = int(ori_width * scale)
        img_size = 512
        # The canvas carries an img_size margin on every side; it is cut off
        # again once all tiles have been accumulated.
        canvas_shape = (hot_pic_height + img_size * 2, hot_pic_width + img_size * 2, 3)
        to_img = np.zeros(canvas_shape)
        # Starts slightly above zero so uncovered pixels do not divide by zero.
        to_img_mask = np.zeros(canvas_shape) + 1e-8
        for each_img_path in hot_imgs:
            # Tile files are named "<y_start>_<x_start>.npy".
            parts = os.path.splitext(os.path.basename(each_img_path))[0].split('_')
            img_y, img_x = int(parts[0]), int(parts[1])
            from_img = np.load(each_img_path)
            # Trim tiles when the scaled image is smaller than one tile.
            if hot_pic_height < img_size:
                from_img = from_img[:hot_pic_height,:]
            if hot_pic_width < img_size:
                from_img = from_img[:,:hot_pic_width]
            top = img_y + img_size
            left = img_x + img_size
            bottom = top + from_img.shape[0]
            right = left + from_img.shape[1]
            to_img[top:bottom, left:right] += from_img
            to_img_mask[top:bottom, left:right] += 1
        # Average the overlapping predictions.
        to_img = to_img / to_img_mask
        new_to_img = to_img[img_size:hot_pic_height + img_size, img_size:hot_pic_width + img_size]
        new_to_img = cv2.resize(new_to_img,(ori_width,ori_height))
        # Part of the test set was padded before being scaled and cropped, so
        # the padding has to be undone here: keep only the rows that correspond
        # to the reference image's height, taken from the vertical centre.
        if 'Roseau' in img_name and ori_height == 1024:
            ori_scale_img1 = glob.glob('/media/totem_disk/totem/guozunhu/Competition/acre/Development_Dataset/ori_Roseau/test_Roseau/Mais/Images' + '/{}.png'.format(img_name))
            ori_scale_img2 = glob.glob('/media/totem_disk/totem/guozunhu/Competition/acre/Development_Dataset/ori_Roseau/test_Roseau/Haricot/Images' + '/{}.png'.format(img_name))
            candidates = ori_scale_img1 + ori_scale_img2
            if not candidates:
                raise FileNotFoundError('no un-padded Roseau reference image found for {}'.format(img_name))
            ori_scale_img = Image.open(candidates[0])
            unpadded_height = ori_scale_img.size[1]
            first_row = (1024 - unpadded_height) // 2
            new_to_img = new_to_img[first_row:first_row + unpadded_height,:,:]
        index_label = np.argmax(new_to_img,axis=2)
        color_dict = get_color_dic(Clss)
        img = gray_to_color(color_dict, index_label)
        # gray_to_color fills the palette in RGB order; cv2.imwrite expects BGR.
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        hotpic_save_path = mask_save_path + '/' + img_name + '.png'
        cv2.imwrite(hotpic_save_path,img)

        print(img_name)
    
    
    if __name__ == '__main__':

        # True:  predict with test-time augmentation (ttach, d4 transforms, mean merge).
        # False: plain prediction and merging.
        # This flag replaces two copy-pasted branches that were toggled by
        # editing `if False:` / `if True:` literals.
        use_tta = True

        if use_tta:
            os.environ['CUDA_VISIBLE_DEVICES'] = '0'
            model_name = 'Unetres34_bce_scale0.5_pad_fold0_l_softmax_finetune_tta'
        else:
            os.environ['CUDA_VISIBLE_DEVICES'] = '3'
            model_name = 'Unetres34_bce_scale0.5_pad_fold0_l_softmax_finetune'

        # NOTE(review): smp is not referenced by name below; presumably it is
        # imported so the pickled model's classes resolve - confirm before removing.
        import segmentation_models_pytorch as smp

        # pred_directly is called without a scale and resizes by 0.5 itself;
        # keep this value in sync with it.
        scale = 0.5
        weight = '/media/totem_disk/totem/guozunhu/Competition/acre/weight/fold_0/UnetRes34_best_iou.pth'
        # NOTE(review): torch.load unpickles a whole model object here, which can
        # execute arbitrary code - only load weight files from a trusted source.
        # `model` must stay a module-level name: pred_directly reads it as a global.
        model = torch.load(weight)
        if use_tta:
            import ttach as tta
            model = tta.SegmentationTTAWrapper(model,tta.aliases.d4_transform(),merge_mode='mean')
        model.eval()

        dataset_dirs = glob.glob('Development_Dataset/Test_Dev/*')
        all_img_paths = []
        for dataset_dir in dataset_dirs:
            all_img_paths += glob.glob(dataset_dir + '/Haricot/Images/*')
            all_img_paths += glob.glob(dataset_dir + '/Mais/Images/*')

        # Where the per-tile prediction outputs are saved.
        crop_output_save_path = 'predict/{}/patch_hotpic'.format(model_name)
        # Where the merged full-size masks are saved.
        mask_save_path = 'predict/{}/merge_hotpic'.format(model_name)
        os.makedirs(crop_output_save_path,exist_ok=True)
        os.makedirs(mask_save_path,exist_ok=True)

        for img_path in all_img_paths:
            pred_directly(img_path,crop_output_save_path)
            merge_hot_pic(img_path,crop_output_save_path,mask_save_path,scale)