Overall architecture

The approach in the figure below is simply three networks with decreasing depth and decreasing input resolution. Their outputs are concatenated at the end and passed through one final layer, and then we look at the results. There is no interleaving between the branches in the middle.
(figures: Triple-UNet.jpg, Triple-UNet2.jpg)

What happened?

  1. How do we split the different inputs into patches, reassemble them at the end, and then concat the features from the other two networks?
    a. Feed one volume at a time, and make sure the patches can still be reassembled after all the intermediate operations.
    To understand this, I studied MONAI's slide_window_inference:
    fall_back_tuple(): typically used when user_provided is a window-size tuple given by the user and default is derived from the data; it returns an updated user_provided with its non-positive components replaced by the corresponding components from default.
    dense_patch_slices(): stores all slices in a list, i.e. returns all the slice(0:x) objects across the 3 spatial dimensions.
    unravel_slice: returns the indices for the current sw_batch_size windows; the first part of each index is the position inside the current batch, the second is None (the channel axis), and the third through fifth are the slices for the corresponding spatial dimensions.
    importance_map defaults to all ones; every computed position accumulates one copy of the map.
  2. A new problem: MONAI's UNet outputs 2 channels, so how do we concat the three results at the end?
    a. A MONAI UNet needs at least two channel entries (a, b).
  3. Input plan: crop every volume to (512, 512, 128). Then the shallowest network only downsamples once (in MONAI's UNet design, the bottom block neither lowers nor raises resolution).
  4. MONAI UNet design: every block applies two convolutions.
    a. Encoder: first a stride-2 convolution that downsamples and doubles the channel count, then a stride-1 convolution that keeps the channel count unchanged.
    b. Decoder (different from the figure below): concat first, then ConvTranspose, then a regular convolution.
    c. The original paper downsamples before every block except the first; here every block downsamples except the last.
  5. S path: (16, 32); M path: (16, 32, 64); L path: (16, 32, 64, 128, 256).
  6. Look at how other models handle the multi-input problem.
    a. Haven't looked yet, but solved it my own way. Still worth checking other approaches.
  7. Multi-supervision: deeply supervise the outputs of the other networks and give the final layer a larger weight (jiadong).
    a. The results are not good so far.
  8. The validation score currently fluctuates a lot. How to solve this?
    a. Cosine LR: already in place; try lowering the floor.
    b. Shuffle the data: training is indeed shuffled.
    c. batch_size is too small: use GN (GroupNorm) instead.
    d. The constraint that patch_size must be a power of 2.
    e. Add BCE + DiceLoss.
    zx's suggestions:
    1. Use GN everywhere except the final aggregation layer.
    2. Concat at the bottom level.
    3. Force-resize the third dimension to a uniform size; crop out the labeled region and then resize (check whether the crop is the bug).
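The requirement in item 1 (patches that can be reassembled exactly) can be sketched in plain Python. This is a simplified stand-in for MONAI's dense_patch_slices with overlap=0, not the real implementation; it assumes every spatial dimension is divisible by the window size.

```python
from itertools import product

def nonoverlap_patch_slices(image_size, roi_size):
    """All non-overlapping roi_size windows covering image_size (assumes divisibility)."""
    starts = [range(0, dim, roi) for dim, roi in zip(image_size, roi_size)]
    return [tuple(slice(s, s + r) for s, r in zip(corner, roi_size))
            for corner in product(*starts)]

image_size, roi_size = (8, 8, 4), (4, 4, 2)   # ratio 1/2 along every axis
slices = nonoverlap_patch_slices(image_size, roi_size)
# With overlap=0 the windows tile the volume: every voxel is covered exactly
# once, so patch outputs can be stitched back losslessly.
covered = {(x, y, z)
           for sl in slices
           for x in range(sl[0].start, sl[0].stop)
           for y in range(sl[1].start, sl[1].stop)
           for z in range(sl[2].start, sl[2].stop)}
```

Because the windows are disjoint, the count_map used by the real sliding-window code is uniformly 1 here, which is why simple addition suffices for reassembly.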

Results part

1. number of parameters

| Model | Parameters |
| --- | --- |
| UNet | 4.8 M |
| VNet (standard) | 45.6 M |
| TripleUNet (2, 3, 5) | 5.1 M |
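As a sanity check on parameter counts, the parameters of a single 3D convolution can be counted by hand (a rough hypothetical helper; the totals in the table also include normalization and residual-unit parameters, so this is not an exact recount):

```python
# Parameters of one 3D conv: kernel volume * in_channels * out_channels, plus bias.
def conv3d_params(c_in, c_out, k=3, bias=True):
    return k ** 3 * c_in * c_out + (c_out if bias else 0)

first_layer = conv3d_params(1, 16)   # the 1 -> 16 stem conv with k=3
```

Summing such terms over the channel ladder (16, 32, 64, 128, 256) shows why the deepest stages dominate the count, and why the shallow S and M branches add so little on top of the L branch.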


2. The experiments result record

1. UNet

| id | max epoch | best val dice(epoch) | test dice | source |
| --- | --- | --- | --- | --- |
| 1 | 500 | 0.8284(312) | 0.8255 | |
| 2 | (with stable lr) | 0.6917 | | |
```python
roi_size = (256, 256, 64)  # window size of validation
sw_batch_size = 4  # the batch size to run window slices
val_num = 30
data = PulVessel_lightning(in_dir, batch_size=1, num_workers=4, val_num=val_num,
                           cache=True, cache_rate=1, predict_num=1)
max_epochs = 500

self._model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
)
self.loss_function = DiceLoss(to_onehot_y=True, sigmoid=True, squared_pred=False, include_background=False)
self.post_pred = AsDiscrete(argmax=True, to_onehot=2)
self.post_label = AsDiscrete(to_onehot=2)
self.dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)

def configure_optimizers(self):
    optimizer = torch.optim.Adam(self._model.parameters(), 0.0001)
    lr_scheduler = CosineAnnealingLR(optimizer, eta_min=0.0001 / 100, T_max=500,
                                     last_epoch=-1, verbose=True)
    return [optimizer], [lr_scheduler]

self.train_transforms = Compose(
    [
        LoadImaged(keys=["vol", "seg"]),
        EnsureChannelFirstd(keys=["vol", "seg"]),
        ScaleIntensityRanged(keys=["vol"], a_min=-900.0, a_max=200,
                             b_min=0, b_max=1, clip=True),
        # RandSpatialCropd(keys=["vol", "seg"], roi_size=(1000, 1000, 48), random_size=False),
        RandCropByPosNegLabeld(keys=["vol", "seg"], label_key="seg", spatial_size=[256, 256, 64], pos=1,
                               neg=1, num_samples=1),
        RandFlipd(keys=["vol", "seg"], prob=0.5, spatial_axis=1),
        RandFlipd(keys=["vol", "seg"], prob=0.5, spatial_axis=2),
        ToTensord(keys=["vol", "seg"]),
    ]
)
self.val_transforms = Compose(
    [
        LoadImaged(keys=["vol", "seg"]),
        EnsureChannelFirstd(keys=["vol", "seg"]),
        ScaleIntensityRanged(keys=["vol"], a_min=-900.0, a_max=200,
                             b_min=0, b_max=1, clip=True),
        ToTensord(keys=["vol", "seg"]),
    ]
)
self.test_transforms = Compose(  # identical to val_transforms
    [
        LoadImaged(keys=["vol", "seg"]),
        EnsureChannelFirstd(keys=["vol", "seg"]),
        ScaleIntensityRanged(keys=["vol"], a_min=-900.0, a_max=200,
                             b_min=0, b_max=1, clip=True),
        ToTensord(keys=["vol", "seg"]),
    ]
)
```

current epoch: 499, current mean dice: 0.8048; best mean dice: 0.8284 at epoch 312
test: the mean prediction loss is 0.17454605525539768, the best performance is 0.1278761625289917
event folder: /hpc/data/home/bme/liujy3/code/PulVessel_Lightning/UNet/logs/default/version_0

2. VNet

| id | max epoch | best val dice(epoch) | test dice | source |
| --- | --- | --- | --- | --- |
| 1 | 500 | 0.8027(294) | 0.78 | VNet1 |
```python
self._model = VNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
)  # standard VNet
self.loss_function = DiceLoss(to_onehot_y=True, sigmoid=True, squared_pred=False, include_background=False)
self.post_pred = AsDiscrete(argmax=True, to_onehot=2)
self.post_label = AsDiscrete(to_onehot=2)
self.dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)

def configure_optimizers(self):
    optimizer = torch.optim.Adam(self._model.parameters(), 0.0001)
    lr_scheduler = CosineAnnealingLR(optimizer, eta_min=0.0001 / 100, T_max=500,
                                     last_epoch=-1, verbose=True)
    return [optimizer], [lr_scheduler]

roi_size = (256, 256, 64)  # window size of validation
sw_batch_size = 4
val_num = 30
data = PulVessel_lightning(in_dir, batch_size=1, num_workers=4, val_num=val_num,
                           cache=True, cache_rate=1, predict_num=1)  # for bme cluster

self.train_transforms = Compose(
    [
        LoadImaged(keys=["vol", "seg"]),
        EnsureChannelFirstd(keys=["vol", "seg"]),
        ScaleIntensityRanged(keys=["vol"], a_min=-900.0, a_max=200,
                             b_min=0, b_max=1, clip=True),
        RandCropByPosNegLabeld(keys=["vol", "seg"], label_key="seg", spatial_size=[256, 256, 64], pos=1,
                               neg=1, num_samples=1),
        RandFlipd(keys=["vol", "seg"], prob=0.5, spatial_axis=1),
        RandFlipd(keys=["vol", "seg"], prob=0.5, spatial_axis=2),
        ToTensord(keys=["vol", "seg"]),
    ]
)
self.val_transforms = Compose(
    [
        LoadImaged(keys=["vol", "seg"]),
        EnsureChannelFirstd(keys=["vol", "seg"]),
        ScaleIntensityRanged(keys=["vol"], a_min=-900.0, a_max=200,
                             b_min=0, b_max=1, clip=True),
        ToTensord(keys=["vol", "seg"]),
    ]
)
self.test_transforms = Compose(  # identical to val_transforms
    [
        LoadImaged(keys=["vol", "seg"]),
        EnsureChannelFirstd(keys=["vol", "seg"]),
        ScaleIntensityRanged(keys=["vol"], a_min=-900.0, a_max=200,
                             b_min=0, b_max=1, clip=True),
        ToTensord(keys=["vol", "seg"]),
    ]
)
```

Best loss is 0.2252 at epoch 294, dice: 0.8027
test: the mean prediction loss is 0.22441603291419246, the best performance is 0.15201616287231445
event folder: /hpc/data/home/bme/liujy3/code/PulVessel_Lightning/VNet/logs/default/version_6

3. TripleUNet( sdata->lpath, ldata->spath)

| id | max epoch | best val dice(epoch) | test dice | source |
| --- | --- | --- | --- | --- |
| 1 | 350 | 0.5070(284) | 0.57 | TripleU |
| 2 | 293 | 0.6308(272) | 0.63 (0.72 best) | |

```python
roi_size = (-1, -1, 64)  # window size of validation
sw_batch_size = 4  # the batch size to run window slices
val_num = 30  # the number of validation volumes
data = PulVessel_lightning(in_dir, batch_size=1, num_workers=4, val_num=val_num,
                           cache=True, cache_rate=1, predict_num=1)  # for bme cluster

# model
self._modelS = UNet(spatial_dims=3, in_channels=1, out_channels=2,
                    channels=(16, 32), strides=(2,), num_res_units=2, norm=Norm.BATCH)
self._modelM = UNet(spatial_dims=3, in_channels=1, out_channels=2,
                    channels=(16, 32, 64), strides=(2, 2), num_res_units=2, norm=Norm.BATCH)
self._modelL = UNet(spatial_dims=3, in_channels=1, out_channels=2,
                    channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2), num_res_units=2, norm=Norm.BATCH)
self.final_layer = Convolution(spatial_dims=3, in_channels=6, out_channels=2, strides=1,
                               kernel_size=3, act=Act.PRELU, norm=Norm.BATCH, dropout=0.0,
                               bias=True, conv_only=True, is_transposed=True)

# data flow
output_s = self._modelS(x)
output_m = fake_slide_window(x, 1 / 2, self.sw_batch_size, self._modelM)
output_l = fake_slide_window(x, 1 / 4, self.sw_batch_size, self._modelL)

self.train_transforms = Compose(
    [
        LoadImaged(keys=["vol", "seg"]),
        EnsureChannelFirstd(keys=["vol", "seg"]),
        ScaleIntensityRanged(keys=["vol"], a_min=-900.0, a_max=200,
                             b_min=0, b_max=1, clip=True),
        RandSpatialCropd(keys=["vol", "seg"], roi_size=(512, 512, 128), random_size=False),
        # RandCropByPosNegLabeld(keys=["vol", "seg"], label_key="seg", spatial_size=[512, 512, 128], pos=1,
        #                        neg=1, num_samples=4),
        RandFlipd(keys=["vol", "seg"], prob=0.5, spatial_axis=1),
        RandFlipd(keys=["vol", "seg"], prob=0.5, spatial_axis=2),
        ToTensord(keys=["vol", "seg"]),
    ]
)
self.val_transforms = Compose(
    [
        LoadImaged(keys=["vol", "seg"]),
        EnsureChannelFirstd(keys=["vol", "seg"]),
        ScaleIntensityRanged(keys=["vol"], a_min=-900.0, a_max=200,
                             b_min=0, b_max=1, clip=True),
        ToTensord(keys=["vol", "seg"]),
    ]
)
self.test_transforms = Compose(  # identical to val_transforms
    [
        LoadImaged(keys=["vol", "seg"]),
        EnsureChannelFirstd(keys=["vol", "seg"]),
        ScaleIntensityRanged(keys=["vol"], a_min=-900.0, a_max=200,
                             b_min=0, b_max=1, clip=True),
        ToTensord(keys=["vol", "seg"]),
    ]
)
```

  1. ![image.png](https://cdn.nlark.com/yuque/0/2022/png/1699458/1647849175596-f373525d-84ad-4a63-9dda-275213824fc1.png)
  2. Added multi-supervision: loss = 0.4 * loss_t + 0.2 * loss_l + 0.2 * loss_m + 0.2 * loss_s
  3. dir: /hpc/data/home/bme/liujy3/code/TripNewUnet/TripleTest/logs/default/version2 (and version1)
     ![image.png](https://cdn.nlark.com/yuque/0/2022/png/1699458/1648199016065-267bcc4f-48e5-4397-80dd-2dc4717a6386.png)
4. TripleUNet

| id | max epoch | best val dice(epoch) | test dice | source |
| --- | --- | --- | --- | --- |
| 1 | 263 | 0.8639(216) | 0.8528 | [TripleU](https://www.yuque.com/tuoyu-hy28g/gar920/clgl92?inner=ASMqk) |
| 2 | 237 | 0.8678(184) | 0.8513 | |
| 3 | | 0.8572(245) | | |
| 4 | | 0.8679(216) | | |
1.

```python
s_channel = (16, 32)
m_channel = (16, 32, 64)
l_channel = (16, 32, 64, 128, 256)
kernel_size: Union[Sequence[int], int] = 3
up_kernel_size: Union[Sequence[int], int] = 3
num_res_units: int = 2

# model
# s_network
self.down_s_1 = self._get_down_layer(1, self.s_channel[0], 2, norm=(Norm.GROUP, {'num_groups': 4}))
self.bottom_s = self._get_bottom_layer(self.s_channel[0], self.s_channel[1], norm=(Norm.GROUP, {'num_groups': 8}))
self.up_s_1 = self._get_up_layer(self.s_channel[0] + self.s_channel[1], 2, 2, norm=(Norm.GROUP, {'num_groups': 1}),
                                 is_top=True)
# m_network
self.down_m_1 = self._get_down_layer(1, self.m_channel[0], 2, norm=(Norm.GROUP, {'num_groups': 4}))  # 1 -> 16
self.down_m_2 = self._get_down_layer(self.m_channel[0], self.m_channel[1], 2,
                                     norm=(Norm.GROUP, {'num_groups': 8}))  # 16 -> 32
self.bottom_m = self._get_bottom_layer(self.m_channel[1], self.m_channel[2],
                                       norm=(Norm.GROUP, {'num_groups': 16}))  # 32 -> 64
self.up_m_1 = self._get_up_layer(self.m_channel[1] + self.m_channel[2], self.m_channel[0], 2,
                                 norm=(Norm.GROUP, {'num_groups': 8}))  # 32+64 -> 16
self.up_m_2 = self._get_up_layer(self.m_channel[0] * 2, 2, 2, norm=(Norm.GROUP, {'num_groups': 1}),
                                 is_top=True)  # 16+16 -> 2
# l_network
self.down_l_1 = self._get_down_layer(1, self.l_channel[0], 2, norm=(Norm.GROUP, {'num_groups': 4}))  # 1 -> 16
self.down_l_2 = self._get_down_layer(self.l_channel[0], self.l_channel[1], 2,
                                     norm=(Norm.GROUP, {'num_groups': 8}))  # 16 -> 32
self.down_l_3 = self._get_down_layer(self.l_channel[1], self.l_channel[2], 2,
                                     norm=(Norm.GROUP, {'num_groups': 16}))  # 32 -> 64
self.down_l_4 = self._get_down_layer(self.l_channel[2], self.l_channel[3], 2,
                                     norm=(Norm.GROUP, {'num_groups': 32}))  # 64 -> 128
self.bottom_l = self._get_bottom_layer(self.l_channel[3], self.l_channel[4],
                                       norm=(Norm.GROUP, {'num_groups': 32}))  # 128 -> 256
self.up_l_1 = self._get_up_layer(self.l_channel[3] + self.l_channel[4], self.l_channel[2], 2,
                                 norm=(Norm.GROUP, {'num_groups': 16}))  # 256+128 -> 64
self.up_l_2 = self._get_up_layer(self.l_channel[2] * 2, self.l_channel[1], 2,
                                 norm=(Norm.GROUP, {'num_groups': 8}))  # 64+64 -> 32
self.up_l_3 = self._get_up_layer(self.l_channel[1] * 2, self.l_channel[0], 2,
                                 norm=(Norm.GROUP, {'num_groups': 4}))  # 32+32 -> 16
self.up_l_4 = self._get_up_layer(self.l_channel[0] * 2, 2, 2, norm=(Norm.GROUP, {'num_groups': 1}),
                                 is_top=True)  # 16+16 -> 2
self.final_layer = Convolution(
    spatial_dims=3,
    in_channels=6,
    out_channels=2,
    strides=1,
    kernel_size=3,
    act=Act.PRELU,
    norm=Norm.INSTANCE,
    dropout=0.0,
    bias=True,
    conv_only=True,
    is_transposed=True,
)
```

loss = 0.4 * loss_t + 0.2 * loss_l + 0.2 * loss_m + 0.2 * loss_s
Everything else is the same as before; the only change is adding GN, and the results improved a lot. Next experiment: to confirm that GN is the cause, try BN again on the new code.
/hpc/data/home/bme/liujy3/code/mTU/mTU1/logs/default/version_0
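The deep-supervision weighting used here, sketched with made-up per-head loss values (only the weights come from the notes; the loss numbers are hypothetical):

```python
# Fused head loss_t gets the largest weight (0.4); the three branch heads get 0.2 each.
weights = {"t": 0.4, "l": 0.2, "m": 0.2, "s": 0.2}
losses = {"t": 0.30, "l": 0.45, "m": 0.40, "s": 0.50}  # hypothetical per-head Dice losses
total_loss = sum(weights[k] * losses[k] for k in weights)
assert abs(sum(weights.values()) - 1.0) < 1e-9  # weights form a convex combination
```

Keeping the weights summing to 1 keeps the loss scale comparable to the single-head runs, which matters when comparing curves across experiments.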

2.
Changed the normalization in 1 to BN to test whether the improvement really came from GN. It did not: the results are just as good as above. Now looking for the real cause.

3.
To check whether the problem is the batch size of 1 inside fake_slide_window, I already changed it once, but the result stayed the same.
I also instantiated my own network and passed it into fake_slide_window, but the result was still the same.

```python
import torch
import torch.nn as nn
from typing import Any, Sequence, Tuple, Union, Optional

from monai.inferers.utils import _get_scan_interval
from monai.data.utils import dense_patch_slices, compute_importance_map, get_valid_patch_size
from monai.utils import BlendMode
from monai.networks.blocks.convolutions import Convolution, ResidualUnit
from monai.networks.layers.factories import Act, Norm


class mTriUNet(torch.nn.Module):
    def __init__(
        self,
        spatial_dims: int = 3,
        s_channel=(16, 32),
        m_channel=(16, 32, 64),
        l_channel=(16, 32, 64, 128, 256),
        kernel_size: Union[Sequence[int], int] = 3,
        up_kernel_size: Union[Sequence[int], int] = 3,
        num_res_units: int = 2,
        act: Union[Tuple, str] = Act.PRELU,
        norm: Union[Tuple, str] = Norm.BATCH,
        dropout: float = 0.0,
        bias: bool = True,
        dimensions: Optional[int] = None,
    ):
        super().__init__()
        self.dimensions = spatial_dims
        self.kernel_size = kernel_size
        self.up_kernel_size = up_kernel_size
        self.num_res_units = num_res_units
        self.act = act
        self.norm = norm
        self.dropout = dropout
        self.bias = bias
        self.s_channel = s_channel
        self.m_channel = m_channel
        self.l_channel = l_channel
        # s_network
        self.down_s_1 = self._get_down_layer(1, self.s_channel[0], 2, norm=self.norm)
        self.bottom_s = self._get_bottom_layer(self.s_channel[0], self.s_channel[1], norm=self.norm)
        self.up_s_1 = self._get_up_layer(self.s_channel[0] + self.s_channel[1], 2, 2, norm=self.norm,
                                         is_top=True)
        # m_network
        self.down_m_1 = self._get_down_layer(1, self.m_channel[0], 2, norm=self.norm)  # 1 -> 16
        self.down_m_2 = self._get_down_layer(self.m_channel[0], self.m_channel[1], 2,
                                             norm=self.norm)  # 16 -> 32
        self.bottom_m = self._get_bottom_layer(self.m_channel[1], self.m_channel[2],
                                               norm=self.norm)  # 32 -> 64
        self.up_m_1 = self._get_up_layer(self.m_channel[1] + self.m_channel[2], self.m_channel[0], 2,
                                         norm=self.norm)  # 32+64 -> 16
        self.up_m_2 = self._get_up_layer(self.m_channel[0] * 2, 2, 2, norm=self.norm,
                                         is_top=True)  # 16+16 -> 2
        # l_network
        self.down_l_1 = self._get_down_layer(1, self.l_channel[0], 2, norm=self.norm)  # 1 -> 16
        self.down_l_2 = self._get_down_layer(self.l_channel[0], self.l_channel[1], 2,
                                             norm=self.norm)  # 16 -> 32
        self.down_l_3 = self._get_down_layer(self.l_channel[1], self.l_channel[2], 2,
                                             norm=self.norm)  # 32 -> 64
        self.down_l_4 = self._get_down_layer(self.l_channel[2], self.l_channel[3], 2,
                                             norm=self.norm)  # 64 -> 128
        self.bottom_l = self._get_bottom_layer(self.l_channel[3], self.l_channel[4],
                                               norm=self.norm)  # 128 -> 256
        self.up_l_1 = self._get_up_layer(self.l_channel[3] + self.l_channel[4], self.l_channel[2], 2,
                                         norm=self.norm)  # 256+128 -> 64
        self.up_l_2 = self._get_up_layer(self.l_channel[2] * 2, self.l_channel[1], 2,
                                         norm=self.norm)  # 64+64 -> 32
        self.up_l_3 = self._get_up_layer(self.l_channel[1] * 2, self.l_channel[0], 2,
                                         norm=self.norm)  # 32+32 -> 16
        self.up_l_4 = self._get_up_layer(self.l_channel[0] * 2, 2, 2, norm=self.norm,
                                         is_top=True)  # 16+16 -> 2
        self.final_layer = Convolution(
            spatial_dims=3,
            in_channels=6,
            out_channels=2,
            strides=1,
            kernel_size=3,
            act=Act.PRELU,
            norm=Norm.INSTANCE,
            dropout=0.0,
            bias=True,
            conv_only=True,
            is_transposed=True,
        )

    def _get_down_layer(
        self,
        in_channels: int,
        out_channels: int,
        strides: int,
        norm: Union[Tuple, str] = Norm.INSTANCE,
        is_top=False,
    ) -> nn.Module:
        """
        Args:
            in_channels: number of input channels.
            out_channels: number of output channels.
            strides: convolution stride.
            is_top: True if this is the top block.
        """
        mod: nn.Module
        if self.num_res_units > 0:
            mod = ResidualUnit(
                self.dimensions,
                in_channels,
                out_channels,
                strides=strides,
                kernel_size=self.kernel_size,
                subunits=self.num_res_units,
                act=self.act,
                norm=norm,
                dropout=self.dropout,
                bias=self.bias,
            )
            return mod
        mod = Convolution(
            self.dimensions,
            in_channels,
            out_channels,
            strides=strides,
            kernel_size=self.kernel_size,
            act=self.act,
            norm=norm,
            dropout=self.dropout,
            bias=self.bias,
        )
        return mod

    def _get_bottom_layer(
        self,
        in_channels: int,
        out_channels: int,
        norm: Union[Tuple, str] = Norm.INSTANCE,
    ) -> nn.Module:
        """
        Args:
            in_channels: number of input channels.
            out_channels: number of output channels.
        """
        return self._get_down_layer(in_channels, out_channels, 1, norm, False)

    def _get_up_layer(
        self,
        in_channels: int,
        out_channels: int,
        strides: int,
        norm: Union[Tuple, str] = Norm.INSTANCE,
        is_top: bool = False,
    ) -> nn.Module:
        """
        Args:
            in_channels: number of input channels.
            out_channels: number of output channels.
            strides: convolution stride.
            is_top: True if this is the top block.
        """
        conv: Union[Convolution, nn.Sequential]
        conv = Convolution(
            self.dimensions,
            in_channels,
            out_channels,
            strides=strides,
            kernel_size=self.up_kernel_size,
            act=self.act,
            norm=norm,
            dropout=self.dropout,
            bias=self.bias,
            conv_only=is_top and self.num_res_units == 0,
            is_transposed=True,
        )
        if self.num_res_units > 0:
            ru = ResidualUnit(
                self.dimensions,
                out_channels,
                out_channels,
                strides=1,
                kernel_size=self.kernel_size,
                subunits=1,
                act=self.act,
                norm=norm,
                dropout=self.dropout,
                bias=self.bias,
                last_conv_only=is_top,
            )
            conv = nn.Sequential(conv, ru)
        return conv

    def fake_slide_window(
        self,
        inputs: torch.Tensor,
        ratio,
        sw_batch_size: int,
        mode: Union[BlendMode, str] = BlendMode.CONSTANT,
        sigma_scale: Union[Sequence[float], float] = 0.125,
        device: Union[torch.device, str, None] = None,
        sw_device: Union[torch.device, str, None] = None,
        *args: Any,
        **kwargs: Any,
    ) -> torch.Tensor:
        num_spatial_dims = len(inputs.shape) - 2
        if device is None:
            device = inputs.device
        if sw_device is None:
            sw_device = inputs.device
        image_size = list(inputs.shape[2:])
        batch_size = inputs.shape[0]
        roi_size = [int(i * ratio) for i in image_size]
        overlap = 0  # always 0 here?
        scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)
        # store all slices in a list
        slices = dense_patch_slices(image_size, roi_size, scan_interval)
        num_win = len(slices)  # number of windows per image
        total_slices = num_win * batch_size  # total number of windows
        # create window-level importance map
        importance_map = compute_importance_map(
            get_valid_patch_size(image_size, roi_size), mode=mode, sigma_scale=sigma_scale, device=device
        )
        # perform predictions
        output_image, count_map = torch.tensor(0.0, device=device), torch.tensor(0.0, device=device)
        _initialized = False
        for slice_g in range(0, total_slices, sw_batch_size):
            slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices))
            unravel_slice = [
                [slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)] + list(slices[idx % num_win])
                for idx in slice_range
            ]
            window_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device)
            if ratio == 1 / 2:
                seg_prob = self.m_network(window_data).to(device)  # batched patch segmentation
            elif ratio == 1 / 4:
                seg_prob = self.s_network(window_data).to(device)
            if not _initialized:  # init. buffer at the first iteration
                output_classes = seg_prob.shape[1]
                output_shape = [batch_size, output_classes] + list(image_size)
                # allocate memory to store the full output and the count for overlapping parts
                output_image = torch.zeros(output_shape, dtype=torch.float32, device=device)
                count_map = torch.zeros(output_shape, dtype=torch.float32, device=device)
                _initialized = True
            # store the result in the proper location of the full output; apply weights from importance map
            for idx, original_idx in zip(slice_range, unravel_slice):
                output_image[original_idx] += importance_map * seg_prob[idx - slice_g]
                count_map[original_idx] += importance_map
        # account for any overlapping sections
        output_image = output_image / count_map
        return output_image

    def s_network(self, x_s):
        down_out_s_1 = self.down_s_1(x_s)
        bottom_out_s = self.bottom_s(down_out_s_1)
        up_out_s_1 = self.up_s_1(torch.cat([down_out_s_1, bottom_out_s], dim=1))
        return up_out_s_1

    def m_network(self, x_m):
        down_out_m_1 = self.down_m_1(x_m)
        down_out_m_2 = self.down_m_2(down_out_m_1)
        bottom_out_m = self.bottom_m(down_out_m_2)
        up_out_m_1 = self.up_m_1(torch.cat([down_out_m_2, bottom_out_m], dim=1))
        up_out_m_2 = self.up_m_2(torch.cat([down_out_m_1, up_out_m_1], dim=1))
        return up_out_m_2

    def l_network(self, x_l):
        down_out_l_1 = self.down_l_1(x_l)
        down_out_l_2 = self.down_l_2(down_out_l_1)
        down_out_l_3 = self.down_l_3(down_out_l_2)
        down_out_l_4 = self.down_l_4(down_out_l_3)
        bottom_out_l = self.bottom_l(down_out_l_4)
        up_out_l_1 = self.up_l_1(torch.cat([down_out_l_4, bottom_out_l], dim=1))
        up_out_l_2 = self.up_l_2(torch.cat([down_out_l_3, up_out_l_1], dim=1))
        up_out_l_3 = self.up_l_3(torch.cat([down_out_l_2, up_out_l_2], dim=1))
        up_out_l_4 = self.up_l_4(torch.cat([down_out_l_1, up_out_l_3], dim=1))
        return up_out_l_4

    def forward(self, x):
        # three paths over the same input
        # s (16, 32): quarter-size windows
        output_s = self.fake_slide_window(x, 1 / 4, 1)
        # m (16, 32, 64): half-size windows
        output_m = self.fake_slide_window(x, 1 / 2, 1)
        # l (16, 32, 64, 128, 256): full volume
        output_l = self.l_network(x)
        # concat the three 2-channel outputs -> 6 channels
        output = torch.cat((output_l, output_m, output_s), 1)
        # final layer (+ softmax?)
        output = self.final_layer(output)
        return output, output_l, output_m, output_s
```

4.
I now suspect my network is the problem, so I used the original UNet with BN together with my new patch-splitting method, but damn it, the result is still the same. Bullshit.

```python
import torch
import torch.nn as nn
import pytorch_lightning as pl
from typing import Any, Callable, List, Sequence, Tuple, Union

from monai.networks.layers.factories import Act
from monai.utils import BlendMode
from monai.inferers.utils import _get_scan_interval
from monai.data.utils import dense_patch_slices, compute_importance_map, get_valid_patch_size
from monai.networks.layers import Norm
from monai.networks.blocks.convolutions import Convolution, ResidualUnit
from monai.networks.nets import UNet


def recover_divisible_patch(inputs, image_size, slice_range, unravel_slice):
    # only for non-overlapping windows
    device = inputs.device
    batch_size = 1
    output_classes = inputs.shape[1]
    output_shape = [batch_size, output_classes] + list(image_size)
    # allocate memory to store the full output
    output_image = torch.zeros(output_shape, dtype=torch.float32, device=device)
    # store each window in the proper location of the full output
    for idx, original_idx in zip(slice_range, unravel_slice):
        output_image[original_idx] += inputs[idx]
    return output_image


def divisible_patch(
    inputs: torch.Tensor,
    ratio,
    device: Union[torch.device, str, None] = None,
    *args: Any,
    **kwargs: Any,
):
    num_spatial_dims = len(inputs.shape) - 2
    if device is None:
        device = inputs.device
    image_size = list(inputs.shape[2:])
    batch_size = inputs.shape[0]
    roi_size = [int(i * ratio) for i in image_size]
    overlap = 0
    scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)
    # store all slices in a list
    slices = dense_patch_slices(image_size, roi_size, scan_interval)
    num_win = len(slices)  # number of windows per image
    total_slices = num_win * batch_size  # total number of windows
    slice_range = range(0, total_slices)
    unravel_slice = [
        [slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)] + list(slices[idx % num_win])
        for idx in slice_range
    ]
    # this is the end of cutting the data
    window_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(device)
    return window_data, image_size, slice_range, unravel_slice


class mswTripleUNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._modelS = UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=2,
            channels=(16, 32),
            strides=(2,),
            num_res_units=2,
            norm=Norm.BATCH,
        )
        self._modelM = UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=2,
            channels=(16, 32, 64),
            strides=(2, 2),
            num_res_units=2,
            norm=Norm.BATCH,
        )
        self._modelL = UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=2,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm=Norm.BATCH,
        )
        self.final_layer = Convolution(
            spatial_dims=3,
            in_channels=6,
            out_channels=2,
            strides=1,
            kernel_size=3,
            act=Act.PRELU,
            norm=Norm.BATCH,
            dropout=0.0,
            bias=True,
            conv_only=True,
            is_transposed=True,
        )
        # the batch size to run window slices
        self.sw_batch_size = 1

    def forward(self, x):
        # input data: full volume for L, non-overlapping windows for M and S
        x_l = x
        x_m, image_size_m, slice_range_m, unravel_slice_m = divisible_patch(x, 1 / 2)
        x_s, image_size_s, slice_range_s, unravel_slice_s = divisible_patch(x, 1 / 4)
        # run the three paths
        output_s = self._modelS(x_s)
        output_s = recover_divisible_patch(output_s, image_size_s, slice_range_s, unravel_slice_s)
        output_m = self._modelM(x_m)
        output_m = recover_divisible_patch(output_m, image_size_m, slice_range_m, unravel_slice_m)
        output_l = self._modelL(x_l)
        # concat the three 2-channel outputs -> 6 channels
        output = torch.cat((output_s, output_m, output_l), 1)
        # final layer (+ softmax?)
        output = self.final_layer(output)
        return output, output_l, output_m, output_s
```

5. TripleUNet (5, 4, 3)

| id | different | max epoch | best val dice(epoch) | test dice | source |
| --- | --- | --- | --- | --- | --- |
| 1 | 5 4 3 | 228 | 0.8613(213) | 0.8507 | |
| 2 | no summed loss at the end | still running | 0.8680(335) | | |
| 3 | mid-level fusion (add + BN), summed loss | 231 | 0.8583(213) | | |
| 4 | mid-level fusion (add + GN), summed loss | 250 | 0.8674(237) | 0.8552 | |
| 5 | mid-level fusion (cat + GN), summed loss (TU_ccatgn_new) | 253 | 0.8714(187) | 0.8586 | |
| 6 | mid-level fusion (cat_gn + GN), summed loss | 227 | 0.8619(216) | | |

For row 5: don't forget to add normalization after the concat; the inspiration comes from attention.

6. For other datasets

| id | dataset | method | max epoch | best val dice(epoch) | test dice | source |
| --- | --- | --- | --- | --- | --- | --- |
| 1 | Cornary | mid-level fusion (cat + GN) | 337 | 0.7605(334) | | |
| 2 | roi_jy | mid-level fusion (cat + GN) | 243 | 0.8736(89) | | |

  1. cornary
  2. jingyang uses an ROI-selection method; the ROI sizes are listed below

the mean of s1 is 368.1923076923077, the min of s1 is 299, the max of s1 is 434
the mean of s2 is 256.46153846153845, the min of s2 is 199, the max of s2 is 322
the mean of s3 is 251.6769230769231, the min of s3 is 157, the max of s3 is 320
256 128 128
128 64 64
64 32 32
32 16 16
16 8 8
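The size ladder above is just repeated stride-2 halving over 5 levels; a small sketch that reproduces it:

```python
# Each stride-2 stage halves every spatial dimension (assumes even sizes at each level).
def size_pyramid(size, levels):
    out = [size]
    for _ in range(levels - 1):
        size = tuple(s // 2 for s in size)
        out.append(size)
    return out

pyr = size_pyramid((256, 128, 128), 5)
```

This is also a quick way to check that a proposed crop size survives all downsampling stages without producing odd dimensions.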

7. Recombining at the bottom

Shrink the corresponding feature sizes of the middle S and M networks by another factor of two so that they can be combined.

  1. Shrinking
    1. Give the bottom block a stride (the previous bottom_layer never had one).
    2. Double the stride of the second-to-last layer (currently using this because it is simple; the stride of the corresponding up layer must be doubled too).
  2. Recombining
    1. One open issue is supporting batch_size > 1 (for now just run with batch_size = 1).

| id | dataset | method | max epoch | best val dice(epoch) | test dice | source |
| --- | --- | --- | --- | --- | --- | --- |
| 1 | roi_jy | 2a | 260 | 0.8709(255) | | |

The problem is that the blocks in the last dimension are too small, and with method 1b the final stride is so large that much of the information is never captured.
todo: crop the roi data a bit larger.
end: increased the margin to 16, but the results still did not improve; val_dice is only 0.54, so I strongly suspect a bug in my code. (TU_16Roi-52869)

A possible problem: the last layer is a PReLU. Consider making it a sigmoid instead, so the DiceLoss no longer needs to apply sigmoid itself. As for computing the Dice metric, it is normally computed after the argmax.
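A toy illustration of that point (a hypothetical helper, not MONAI's DiceMetric): the sigmoid/softmax belongs inside the training loss, while the reported Dice is computed on the already-discretized (argmax-ed) prediction.

```python
def binary_dice(pred, target):
    """Dice on already-discretized 0/1 labels: 2 * TP / (|pred positives| + |target positives|)."""
    inter = sum(p == 1 and t == 1 for p, t in zip(pred, target))
    denom = sum(pred) + sum(target)
    return 2 * inter / denom if denom else 1.0

score = binary_dice([1, 1, 0, 0], [1, 0, 0, 1])  # one true positive out of 2+2 positives
```

On continuous probabilities the same formula would give a "soft" Dice, which is what the loss uses; the discrepancy between soft and hard Dice is one reason validation curves can look noisier than the loss.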

8. Multi-Scale (jm)

  1. TU_multiscale. The input sizes are (256, 192, 192), (214, 160, 160) and (192, 128, 128), with the large network feeding the medium one and the medium feeding the small one. All volumes are cut into patches ahead of time, quite a few per volume. (Question: does every patch get used in each pass, or are they sampled randomly? It must be random sampling; going through all of them is infeasible.)
    1. With S as the center.
    v1: three paths, each supervised separately.
    v2: three paths, crop, simple convolution, finally supervised by the center part.

| id | dataset | method | max epoch | patch strategy or number | best val dice(epoch) | test dice | source |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 1 | roi_jy | v1 | 250 | randsample | 0.7756(149) | | |
| 2 | roi_jy | v2 | 125 | | 0.7345(107) | | |
9. Tu_trans

| id | dataset | method | max epoch | patch strategy or number | best val dice(epoch) | test dice | source |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 1 | roi_jy | S2M2L | 432 | | 0.7612(387) | | |
| 2 | roi_jy | L2M2S | doing (125) | | | | |
| 3 | roi_jy | Lonly2S | 149 | | 0.6607(90) | | |
| 4 | roi_jy | mTUV2L2S | 500 | | 0.82 | 0.77 | |

10. Skip connections with concat

| id | dataset | method | max epoch | patch strategy or number | best val dice(epoch) | test dice | source |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 1 | roi_jy | v1 | 500 | | 0.8410(325) | | |
| 2 | roi_jy | v2 (with background) | 500 | 0 | 0.8437 | | |

If we first train the large network and then the others: after loading the large network's weights, should they keep being updated? If they are not updated, should the final loss still be weighted? (I think they should be updated, because the goal is to add multi-scale content.)
Freezing is also an option: freeze part of the parameters first (they need initial values), then continue training.