nnUNet V2代码——图像增强(三)
本文阅读的nnU-Net V2图像增强有亮度调整、对比度调整、低分辨率调整
各个类内的各个函数的调用关系见前文nnUNet V2代码——图像增强(一)的BasicTransform类
安装batchgeneratorsv2,nnU-Net V2关于图像增强的代码都在这个库中,点击链接,将其clone到本地后,在命令行进入文件夹内,pip install -e . 即可(注意-e后有个点)。
本文目录
- 一 MultiplicativeBrightnessTransform
- 1. __init__函数
- 2. get_parameters函数
- 3. _apply_to_image函数
- 二 ContrastTransform
- 1. __init__函数
- 2. get_parameters函数
- 3. _apply_to_image函数
- 三 SimulateLowResolutionTransform类
- 1. __init__函数
- 2. get_parameters函数
- 3. _apply_to_image函数
一 MultiplicativeBrightnessTransform
该类包含亮度调整,继承自ImageOnlyTransform类,只对image施加,seg不施加。
代码在batchgeneratorsv2 \ transforms \ intensity \ brightness.py文件中
MultiplicativeBrightnessTransform代码比SpatialTransform类、GaussianNoiseTransform类、GaussianBlurTransform类简洁,但代码逻辑一致
1. __init__函数
定义必要的类内变量,代码清晰,不做粘贴,变量在用到时再介绍作用
2. get_parameters函数
def get_parameters(self, **data_dict) -> dict:## 获取image大小shape = data_dict['image'].shape## 确定哪些通道要施加亮度调整### self.p_per_channel = 1,nnU-Net V2对每个通道都施加亮度调整apply_to_channel = torch.where(torch.rand(shape[0]) < self.p_per_channel)[0]## 各通道同步施加相同的亮度调整if self.synchronize_channels:multipliers = torch.Tensor([sample_scalar(self.multiplier_range, image=data_dict['image'], channel=None)] * len(apply_to_channel))## 各通道各自施加自己的亮度调整else: ### self.synchronize_channels = False,nnU-Net V2不同步施加multipliers = torch.Tensor([sample_scalar(self.multiplier_range, image=data_dict['image'], channel=c) for c in apply_to_channel])## 收集参数后返回return {'apply_to_channel': apply_to_channel,'multipliers': multipliers}
sample_scalar函数见nnUNet V2代码——图像增强(一)的其余函数
3. _apply_to_image函数
def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:## 没有要调整的通道,直接返回if len(params['apply_to_channel']) == 0:return img## 遍历施加亮度调整for c, m in zip(params['apply_to_channel'], params['multipliers']):img[c] *= mreturn img
二 ContrastTransform
该类负责对比度调整,继承自ImageOnlyTransform类,只对image施加,seg不施加。
代码在batchgeneratorsv2 \ transforms \ intensity \ contrast.py文件中
1. __init__函数
定义必要的类内变量,代码清晰,不做粘贴,变量在用到时再介绍作用
2. get_parameters函数
和MultiplicativeBrightnessTransform的get_parameters函数一致
3. _apply_to_image函数
def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:if len(params['apply_to_channel']) == 0:return img## 遍历通道for i in range(len(params['apply_to_channel'])):c = params['apply_to_channel'][i]## 获取图像某一通道的平均值mean = img[c].mean()## 是否保留数值范围,nnU-Net V2设置self.preserve_range = Trueif self.preserve_range:minm = img[c].min()maxm = img[c].max()## 对比度调整img[c] -= meanimg[c] *= params['multipliers'][i]img[c] += mean## 是否保留数值范围if self.preserve_range:img[c].clamp_(minm, maxm)return img
三 SimulateLowResolutionTransform类
该类负责施加低分辨率,继承自ImageOnlyTransform类,只对image施加,seg不施加。
代码在batchgeneratorsv2 \ transforms \ spatial\ low_resolution.py文件中
1. __init__函数
变量名称 | 含义 |
---|---|
self.scale | 图像放缩范围 |
self.synchronize_channels | 各通道是否施加相同的低分辨率处理 |
self.synchronize_axes | 各轴是否施加相同的低分辨率处理 |
self.ignore_axes | 某轴不能施加低分辨率处理,与nnUNet V2代码——图像增强(一)的Convert3DTo2DTransform和Convert2DTo3DTransform有关 |
self.allowed_channels | 可能会施加低分辨率处理的通道 |
self.p_per_channel | 某通道施加低分辨率处理的概率 |
self.upmodes | 各维度采样方法 |
self.scale = scale
self.synchronize_channels = synchronize_channels
self.synchronize_axes = synchronize_axes
self.ignore_axes = ignore_axes
self.allowed_channels = allowed_channels
self.p_per_channel = p_per_channelself.upmodes = {1: 'linear',2: 'bilinear',3: 'trilinear'
}
2. get_parameters函数
def get_parameters(self, **data_dict) -> dict:shape = data_dict['image'].shape## nnU-Net V2设置为None,所有通道按概率施加低分辨率处理if self.allowed_channels is None:apply_to_channel = torch.where(torch.rand(shape[0]) < self.p_per_channel)[0]else:apply_to_channel = [i for i in self.allowed_channels if torch.rand(1) < self.p_per_channel]## nnU-Net V2设置为Falseif self.synchronize_channels:## nnU-Net V2设置为Trueif self.synchronize_axes:## 各通道、各轴施加相同的scales = torch.Tensor([[sample_scalar(self.scale, image=data_dict['image'], channel=None, dim=None)] * (len(shape) - 1)] * len(apply_to_channel))else:## 各通道施加相同的,各轴施加各自的scales = torch.Tensor([[sample_scalar(self.scale, image=data_dict['image'], channel=None, dim=d) for d in range(len(shape) - 1)]] * len(apply_to_channel))else:if self.synchronize_axes:## 各轴施加相同的,各通道施加各自的scales = torch.Tensor([[sample_scalar(self.scale, image=data_dict['image'], channel=c, dim=None)] * (len(shape) - 1) for c in apply_to_channel])else:## 各通道、各轴施加各自的scales = torch.Tensor([[sample_scalar(self.scale, image=data_dict['image'], channel=c, dim=d) for d in range(len(shape) - 1)] for c in apply_to_channel])## 对忽略的轴单独处理if len(scales) > 0:scales[:, self.ignore_axes] = 1return {'apply_to_channel': apply_to_channel,'scales': scales}
3. _apply_to_image函数
def _apply_to_image(self, img: torch.Tensor, **params) -> torch.Tensor:orig_shape = img.shape[1:]# 注释机翻:我们无法对这些内容进行批处理,因为每个通道的下采样 SHAP 值会有所不同。for c, s in zip(params['apply_to_channel'], params['scales']):## 按照放缩尺度确定下采样后的图像大小new_shape = [round(i * j.item()) for i, j in zip(orig_shape, s)]## 使用某一种最近邻插值进行采样downsampled = interpolate(img[c][None, None], new_shape, mode='nearest-exact')## 还原图像大小img[c] = interpolate(downsampled, orig_shape, mode=self.upmodes[img.ndim - 1])[0, 0]return img