Pytorch 构建自定义数据集,Dataset和Dataloader代码模板

自定义数据集,torch.utils.data.Dataset和Dataloader使用案例。

Pytorch 构建自定义数据集,Dataset和Dataloader代码模板


扫码关注:torchnlp,一起精进成长 ...


Pytorch 构建自定义数据集,Dataset和Dataloader代码模板
通过本文可以了解:Dataset和DataLoader构建自定义数据集和迭代器,有代码模板示例。
作 者丨程旭源
学习笔记
Pytorch 构建自定义数据集,Dataset和Dataloader代码模板


我们做深度学习训练时,加载和处理数据集往往要花大量的时间和精力,本文将结合torch.utils.data中的Dataset和DataLoader介绍如何自定义一个数据集加载模块,并给出代码模板。

  • Dataset:负责根据index读取相应数据并执行预处理(负责处理索引index到样本sample映射的一个类);
  • Dataloader:最顶层的抽象,通过index找出一条数据出来。

深度学习中使用Dataset和Dataloader类的流程:

  • 定义Dataset并实例化;
  • 使用Dataloader加载数据;
  • 循环迭代使用Dataloader加载的数据进行训练或者验证;

Part1Dataset构建数据集

自定义Dataset的基本模板:

from torch.utils.data import Dataset, DataLoader

 
class ExampleDataset(Dataset):

    def __init__(self, flag='train'):
        assert flag in ['train''test''valid']
        self.flag = flag
        # 也可以把数据作为一个参数传递给类,__init__(self, data);
        # self.data = data
        self.data = self.__load_data__()
    
    def __getitem__(self, index):
        # 根据索引返回数据
        # data = self.preprocess(self.data[index]) # 如果需要预处理数据的话
        return self.data[index]
    
    def __len__(self):
        # 返回数据的长度
        return len(self.data)
    
    # 以下可以不在此定义。
    
    # 如果不是直接传入数据data,这里定义一个加载数据的方法
    def __load_data__(self, csv_paths: list):
        # 假如从 csv_paths 中加载数据,可能要遍历文件夹读取文件等,这里忽略
        # 可以拆分训练和验证集并返回train_X, train_Y, valid_X, valid_Y
        pass
      
    def preprocess(self, data):
        # 将data 做一些预处理
        pass


在我们继承Dataset构建自定义数据集时,一般要有这三个函魔法函数,并根据自己的数据集做响应的修改即可:

  • def _ _ init_ _ :

初始化,把数据作为一个参数传给类

  • def _ _ getitem_ _:

根据索引获取样本对(x,y) 索引为(0,len(dataset)-1),根据数据集长度从0开始的索引序列;模型通过这个函数获取一对样本对

  • def _ _ len_ _:

表示数据集的长度,最终训练时用到的数据集的样本个数。


构建好了自定义数据集,如何加载数据做训练和验证呢?

Part2创建迭代器

我们要使用 DataLoader 创建迭代器,同时支持单进程或多进程加载映射样式(map-style)和迭代式(iterable-style)的数据集。可以设置分批读取数据,设置batch_size和是否打乱shuffle等参数。
DataLoader类有几个重要参数简单说明:
1、dataset:要读取的数据集,要么是torch.utils.data.Dataset类的对象,要么是继承自torch.utils.data.Dataset类的自定义类的对象。
2、batch_size:根据具体情况设置即可,比如16/32/64等。
3、shuffle:一般在训练数据中会采用。
4、collate_fn:是用来处理不同情况下的输入dataset的封装,自定义的数据读取。
5、batch_sampler:它batch_size、shuffle等参数是互斥的,一般采用默认。把batch_size个RandomSampler类对象封装成一个,这样就实现了随机选取一个batch的目的。
6、sampler:和shuffle是互斥的,一般默认即可。
7、num_workers:这个参数必须大于等于0,0的话表示数据导入在主进程中进行,大于0的数表示通过多个进程来导入数据,可以加快数据导入速度。
8、pin_memory:(bool, optional),If True, the data loader will copy tensors into CUDA pinned memory before returning them. 处理数据拷贝到GPU的问题。
9、timeout,用来设置数据读取的超时时间,超过这个时间还没读取到数据就会报错。

Dataloader 代码示例

上面我们创建了自定义数据集 ExampleDataset,并根据flag设定训练集、测试集、验证集,这里我们创建DataLoader迭代器,并加载数据:

# 数据集实例
train_dataset = ExampleDataset(flag='train')
valid_dataset = ExampleDataset(flag='valid')

# 数据迭代器
train_dataloader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=64, shuffle=True)

加载数据后,数据和标签是tuple元组的形式,我们可以使用enumerate访问可遍历的数组对象即可:

# 使用enumerate访问可遍历的数组对象
for step, (input, target) in enumerate(train_dataloader):
    print('step is :', step)
    # data, label = input, target
    print('data is {}, label is {}'.format(data, label))

# 或者这样使用
for idx, item in enumerate(train_dataloader):
    print('idx:', idx)
    data, label = item
    print('data:', data)
    print('label:', label)

在使用Pytorch做深度学习项目的时候,必然要使用到Dataset和Dataloader这两个类处理数据,这里给了一般性的数据处理模板和流程,希望对你有用!

THE END
Pytorch 构建自定义数据集,Dataset和Dataloader代码模板
Pytorch 构建自定义数据集,Dataset和Dataloader代码模板

Python高级工程师竟然这样写代码?优雅、简洁、易读!

2023-02-27

Pytorch 构建自定义数据集,Dataset和Dataloader代码模板

深度学习项目,代码结构、风格和习惯,让自己的代码更Pythonic!

2022-12-07

Pytorch 构建自定义数据集,Dataset和Dataloader代码模板

如何搭建一个智能对话机器人?行业应用和问答技术梳理

2023-02-02

Pytorch 构建自定义数据集,Dataset和Dataloader代码模板

开源对话机器人:Rasa3安装和基础入门

2023-02-27

Pytorch 构建自定义数据集,Dataset和Dataloader代码模板


Pytorch 构建自定义数据集,Dataset和Dataloader代码模板星标公众号精彩不错过

Pytorch 构建自定义数据集,Dataset和Dataloader代码模板


Pytorch 构建自定义数据集,Dataset和Dataloader代码模板

ID:torchnlp
◆◆◆◆◆◆◆◆◆◆◆
善利万物而不争

Pytorch 构建自定义数据集,Dataset和Dataloader代码模板

点赞”是喜欢,“在看、分享”是真爱Pytorch 构建自定义数据集,Dataset和Dataloader代码模板

<

原创文章。转载请注明: 作者:meixi 网址: https://www.icnma.com
Like (0)
meixi管理
Previous 18/01/2023 19:54
Next 07/03/2023 21:53

猜你想看