【pytorch】图像基本操作

【pytorch】图像基本操作

为了方便进行数据的操作,pytorch团队提供了一个torchvision.transforms包,我们可以用transforms进行以下操作:
  • PIL.Image/numpy.ndarray与Tensor的相互转化;
  • 归一化;
  • 对PIL.Image进行裁剪、缩放等操作。

通常,在使用torchvision.transforms,我们通常使用transforms.Compose将transforms组合在一起。


PIL.Image/numpy.ndarray与Tensor的相互转换

PIL.Image/numpy.ndarray转化为Tensor,常常用在训练模型阶段的数据读取,而Tensor转化为PIL.Image/numpy.ndarray则用在验证模型阶段的数据输出。

我们可以使用 transforms.ToTensor() 将 PIL.Image/numpy.ndarray 数据进转化为torch.FloadTensor,并归一化到[0, 1.0]:

  • 取值范围为[0, 255]的PIL.Image,转换成形状为[C, H, W],取值范围是[0, 1.0]的torch.FloadTensor;
  • 形状为[H, W, C]的numpy.ndarray,转换成形状为[C, H, W],取值范围是[0, 1.0]的torch.FloadTensor。
而transforms.ToPILImage则是将Tensor转化为PIL.Image。如果,我们要将Tensor转化为numpy,只需要使用 .numpy() 即可。如下:
img_path = "./data/img_37.jpg"

# transforms.ToTensor()
transform1 = transforms.Compose([
    transforms.ToTensor(), # range [0, 255] -> [0.0,1.0]
    ]
)

##numpy.ndarray
img = cv2.imread(img_path)# 读取图像
img1 = transform1(img) # 归一化到 [0.0,1.0]
print("img1 = ",img1)
# 转化为numpy.ndarray并显示
img_1 = img1.numpy()*255
img_1 = img_1.astype('uint8')
img_1 = np.transpose(img_1, (1,2,0))
cv2.imshow('img_1', img_1)
cv2.waitKey()

##PIL
img = Image.open(img_path).convert('RGB') # 读取图像
img2 = transform1(img) # 归一化到 [0.0,1.0]
print("img2 = ",img2)
#转化为PILImage并显示
img_2 = transforms.ToPILImage()(img2).convert('RGB')
print("img_2 = ",img_2)
img_2.show() 

归一化

归一化对神经网络的训练是非常重要的,那么我们如何归一化到[-1.0, -1.0]呢?只需要将上面的transform1改为如下所示:

transform2 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))
    ]
)

(1)transforms.Compose就是将transforms组合在一起;

(2)transforms.Normalize使用如下公式进行归一化:

channel=(channel-mean)/std

这样一来,我们的数据中的每个值就变成了[-1,1]的数了。


PIL.Image的缩放裁剪等操作

此外,transforms还提供了裁剪,缩放等操作,以便进行数据增强。下面就看一个随机裁剪的例子,这个例子中,仍然使用 Compose 将 transforms 组合在一起,如下:

# transforms.RandomCrop()
transform4 = transforms.Compose([
    transforms.ToTensor(), 
    transforms.ToPILImage(),
    transforms.RandomCrop((300,300)),
    ]
)

img = Image.open(img_path).convert('RGB')
img3 = transform4(img)
img3.show()

相关代码可以查看:tfygg/pytorch-tutorials



最后,安利一下pytorch中文文档

在各方小伙伴的努力和支持下,pytorch中文文档第一版终于上线啦!!!(鼓掌)文档还有很多小瑕疵,但是大体可以放心使用了~我们遵循快速迭代的原则,所以赶紧上线第一版来接受广大开源社区的意见和建议。欢迎加入我们!

github项目地址:awfssv/pytorch-cn

中文翻译组QQ群:628478868

还有pytorch项目交流群:613523596

欢迎关注!

编辑于 2017-06-18

文章被以下专栏收录