transforms.Compose()
官方文档 : 点击跳转
作用
torchvision.transforms.Compose()
的作用是将多个图像转换操作组合在一起,它接受一个transforms
列表作为参数,该列表包含要组合的转换操作,使用方式类似如下:
from torchvision.transforms import transforms
from PIL import Imagemy_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])
original_img = Image.open("./image.jpg")
img = my_transform(original_img)
-
上面例子中,我们让图像依次经过 RandomResizedCrop --> RandomHorizontalFlip --> ToTensor --> Normalize 的处理
-
我们在生成
transforms.Compose()
的对象 my_transform 之后,直接调用my_transform(original_img)
即可处理图像。 而一般类的使用,是需要用对象调用方法来实现功能的,比如: obj.add(a, b) 。 而我们可以直接调用my_transform
对象,不需调用方法,直接将 参数(数据) 传给对象,就能实现图像处理的功能,这是因为transforms.Compose
的内部使用了__call__
方法,我们继续先看下面的内部实现
内部实现
torchvision.transforms.Compose()
的内部实现如下:
class Compose(object):def __init__(self, transforms):self.transforms = transformsdef __call__(self, img):for t in self.transforms:img = t(img)return img
-
还不了解
__call__
方法的使用,请查看这里 -
之后,在学习自定义 transforms的时候,是需要重写
torchvision.transforms.Compose()
的,就需要在上面的内部实现代码基础上,做一些修改。
拓展延伸
1、 为什么我们能直接调用 模型对象,并传入参数: model(input)
,就直接实现 forward 方法中的功能呢 ?为什么不需要调用 forward 方法呢 ?
在nn.Module
类中,实现了__call__方法,然后,在 __call__ 方法中调用的 forward方法
def __call__(self, *input, **kwargs):return self.forward(*input, **kwargs)
所以,我们在生使用 model(input)
的时候,调用的是 nn.Module
类中的 __call__(input) 方法,然后在 __call__(input) 方法中,又调用的我们自己写的 forward 方法
2、torch.nn.Sequential()
和 torchvision.transforms.Compose()
的内部实现 以及使用方式是很类似的,都是将输入的一连串操作一个一个的迭代出来,并且按照顺序进行使用。
torch.nn.Sequential()
的内部实现:
class Sequential(Module):def __init__(self, *args):super(Sequential, self).__init__()if len(args) == 1 and isinstance(args[0], OrderedDict):for key, module in args[0].items():self.add_module(key, module)else:for idx, module in enumerate(args):self.add_module(str(idx), module)def forward(self, input):for module in self:input = module(input)return input