nnUNet V2代码——图像增强(四)
本文阅读的nnU-Net V2图像增强有伽马校正、镜像
各个类内的各个函数的调用关系见前文nnUNet V2代码——图像增强(一)的BasicTransform类
安装batchgeneratorsv2,nnU-Net V2关于图像增强的代码都在这个库中,点击链接,将其clone到本地后,在命令行进入文件夹内,pip install -e . 即可(注意-e后有个点)。
本文目录
- 一 GammaTransform
- 1. __init__函数
- 2. get_parameters函数
- 3. _apply_to_image函数
- 二 MirrorTransform
- 1. __init__函数
- 2. get_parameters函数
- 3. _apply_to_image函数
一 GammaTransform
该类负责伽马校正,继承自ImageOnlyTransform类,只对image施加,seg不施加。
代码在batchgeneratorsv2 \ transforms \ intensity \ intensity.py文件中
1. __init__函数
变量名称 | 含义 |
---|---|
self.gamma | 伽马调整范围 |
self.p_invert_image | 反转图像 |
self.synchronize_channels | 各通道是否施加相同的低分辨率处理 |
self.p_per_channel | 各通道施加伽马调整的概率 |
self.p_retain_stats | 伽马校正前后图像的均值、方差是否保持一致 |
2. get_parameters函数
def get_parameters(self, **data_dict) -> dict:## 获取图像大小shape = data_dict['image'].shape## 确定哪些通道需要伽马校正apply_to_channel = torch.where(torch.rand(shape[0]) < self.p_per_channel)[0]## 是否保持伽马校正前后图像的均值、方差是否保持一致,nnU-Net V2设置为一致retain_stats = torch.rand(len(apply_to_channel)) < self.p_retain_stats## 是否选择反转图像进行伽马校正,nnU-Net V2会进行两次伽马校正,一次反转,一次不反转invert_image = torch.rand(len(apply_to_channel)) < self.p_invert_image## 各通道同步应用相同的伽马校正if self.synchronize_channels:gamma = torch.Tensor([sample_scalar(self.gamma, image=data_dict['image'], channel=None)] * len(apply_to_channel))## 各通道应用各自的伽马校正else:gamma = torch.Tensor([sample_scalar(self.gamma, image=data_dict['image'], channel=c) for c in apply_to_channel])return {'apply_to_channel': apply_to_channel,'retain_stats': retain_stats,'invert_image': invert_image,'gamma': gamma}
sample_scalar函数见nnUNet V2代码——图像增强(一)的其余函数
3. _apply_to_image函数
def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:## 遍历所有需要应用伽马矫正的通道for c, r, i, g in zip(params['apply_to_channel'], params['retain_stats'], params['invert_image'], params['gamma']):## 如果本次反转图像if i:img[c] *= -1## 如果本次保持校正前后图像均值、方差一致if r:mean = torch.mean(img[c])std = torch.std(img[c])## 记录最值minm = torch.min(img[c])rnge = torch.max(img[c]) - minm## 伽马校正img[c] = torch.pow(((img[c] - minm) / torch.clamp(rnge, min=1e-7)), g) * rnge + minmif r:mn_here = torch.mean(img[c])std_here = torch.std(img[c])img[c] -= mn_hereimg[c] *= (std / torch.clamp(std_here, min=1e-7))img[c] += mean## 反转回来if i:img[c] *= -1return img
二 MirrorTransform
该类负责图像镜像
代码在batchgeneratorsv2 \ transforms \ spatial\ mirroring.py文件中
1. __init__函数
## 允许镜像的轴
self.allowed_axes = allowed_axes
2. get_parameters函数
def get_parameters(self, **data_dict) -> dict:## 判断允许镜像的轴是否真的应用镜像,概率固定为0.5axes = [i for i in self.allowed_axes if torch.rand(1) < 0.5]return {'axes': axes}
3. _apply_to_image函数
def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:## 没有镜像的轴,直接返回if len(params['axes']) == 0:return img## 这里+1是为了匹配具有通道维度的img## 计算允许镜像的轴时,维度是从0开始的## 而应用镜像的img的维度是从1开始的## img的0维度是通道维度axes = [i + 1 for i in params['axes']]## 调用flip镜像return torch.flip(img, axes)
_apply_to_segmentation函数、_apply_to_regr_target函数和_apply_to_image函数流程一致