Pytorch数据读取(Dataset, DataLoader, DataLoaderIter)

Pytorch数据读取(Dataset, DataLoader, DataLoaderIter)

的Pytorch的数据读取非常方便, 可以很容易地实现多线程数据预读. 我个人认为编程难度比TF小很多,而且灵活性也更高. (TF需要把文件名封装成list, 传入string_input_producer, 这样可以得到一个queue; 然后把这个queue给一个WholeFileReader.read(); 再把read()回来的value用decode_jpeg()解码; 然后再用一系列处理去clip, flip等等...)


Pytorch的数据读取主要包含三个类:

  1. Dataset
  2. DataLoader
  3. DataLoaderIter

这三者大致是一个依次封装的关系: 1.被装进2., 2.被装进3.


一. torch.utils.data.Dataset

是一个抽象类, 自定义的Dataset需要继承它并且实现两个成员方法:

  1. __getitem__()
  2. __len__()

第一个最为重要, 即每次怎么读数据. 以图片为例:

    def __getitem__(self, index):
        img_path, label = self.data[index].img_path, self.data[index].label
        img = Image.open(img_path)

        return img, label

值得一提的是, pytorch还提供了很多常用的transform, 在torchvision.transforms 里面, 本文中不多介绍, 我常用的有Resize , RandomCrop , Normalize , ToTensor (这个极为重要, 可以把一个PIL或numpy图片转为torch.Tensor, 但是好像对numpy数组的转换比较受限, 所以这里建议在__getitem__()里面用PIL来读图片, 而不是用skimage.io).

第二个比较简单, 就是返回整个数据集的长度:

    def __len__(self):
        return len(self.data)

二. torch.utils.data.DataLoader

类定义为:

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)

可以看到, 主要参数有这么几个:

  1. dataset : 即上面自定义的dataset.
  2. collate_fn: 这个函数用来打包batch, 后面详细讲.
  3. num_worker: 非常简单的多线程方法, 只要设置为>=1, 就可以多线程预读数据啦.

这个类其实就是下面将要讲的DataLoaderIter的一个框架, 一共干了两件事: 1.定义了一堆成员变量, 到时候赋给DataLoaderIter, 2.然后有一个__iter__() 函数, 把自己 "装进" DataLoaderIter 里面.

def __iter__(self):
        return DataLoaderIter(self)


三. torch.utils.data.dataloader.DataLoaderIter

上面提到, DataLoaderIter就是DataLoaderIter的一个框架, 用来传给DataLoaderIter 一堆参数, 并把自己装进DataLoaderIter 里.

其实到这里就可以满足大多数训练的需求了, 比如

class CustomDataset(Dataset):
   # 自定义自己的dataset

dataset = CustomDataset()
dataloader = Dataloader(dataset, ...)

for data in dataloader:
   # training...

在for 循环里, 总共有三点操作:

  1. 调用了dataloader__iter__() 方法, 产生了一个DataLoaderIter
  2. 反复调用DataLoaderIter__next__()来得到batch, 具体操作就是, 多次调用dataset的__getitem__()方法 (如果num_worker>0就多线程调用), 然后用collate_fn来把它们打包成batch. 中间还会涉及到shuffle , 以及sample 的方法等, 这里就不多说了.
  3. 当数据读完后, __next__()抛出一个StopIteration异常, for循环结束, dataloader 失效.

四. 又一层封装...

其实上面三个类已经可以搞定了, 但是我觉得这还不太符合我的需求, 就又写了一个类, 仅供参考

class DataProvider:
    def __init__(self, batch_size, is_cuda):
        self.batch_size = batch_size
        self.dataset = Dataset_triple(self.batch_size,
                                      transform_=transforms.Compose(
                                     [transforms.Scale([224, 224]),
                                      transforms.ToTensor(),
                                      transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                           std=[0.229, 0.224, 0.225])]),
                                      )
        self.is_cuda = is_cuda  # 是否将batch放到gpu上
        self.dataiter = None
        self.iteration = 0  # 当前epoch的batch数
        self.epoch = 0  # 统计训练了多少个epoch

    def build(self):
        dataloader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True, num_workers=0, drop_last=True)
        self.dataiter = DataLoaderIter(dataloader)

    def next(self):
        if self.dataiter is None:
            self.build()
        try:
            batch = self.dataiter.next()
            self.iteration += 1

            if self.is_cuda:
                batch = [batch[0].cuda(), batch[1].cuda(), batch[2].cuda()]
            return batch

        except StopIteration:  # 一个epoch结束后reload
            self.epoch += 1
            self.build()
            self.iteration = 1  # reset and return the 1st batch

            batch = self.dataiter.next()
            if self.is_cuda:
                batch = [batch[0].cuda(), batch[1].cuda(), batch[2].cuda()]
            return batch
编辑于 2017-12-09