【pytorch笔记】数据集的读取与使用

热门标签

,
GentleCP

发表文章数:46

前言

本文主要总结一下使用pytorch过程中几种对数据集的读取与使用,主要领域涉及CVNLP

CV

pytorch中的CV数据集主要包含torchvision提供的预先处理好的数据集,例如MNIST,cifar10和现实生活中的图像数据。这里分别用MNISTkaggle猫狗分类数据为例。

MNIST

torchvision预处理好的数据集都在tochvision.dataset包中,要导入MNIST,使用以下代码:

from torchvision import datasets

下面的代码展示了导入MNIST按照训练集、测试集划分,并按照指定的batch_size返回数据

from torch.utils.data import DataLoader
from torchvision import datasets, transforms
def load_data(data_dir, batch_size):

    train_dataset = datasets.MNIST(root=data_dir,  # MNIST存储路径
                                   train=True,  # 是否是训练集
                                   transform=transforms.ToTensor(),  # 转换图像通道顺序,并标准化(将数据压缩到0~1)
                                   download=True)  # 如果目录下没有数据,则下载

    test_dataset = datasets.MNIST(root=data_dir,
                                  train=False,
                                  transform=transforms.ToTensor())

    # get data loader 
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)

    test_loader = DataLoader(dataset=test_dataset,
                            batch_size=batch_size,
                            shuffle=False)

    return train_loader, test_loader

train_loader, test_loader = load_data('./data', 64)
for data, target in train_loader:
    print(data.size(),target.size())
    break

执行后程序会自动将MNIST原数据集下载到指定的data_dir下,返回一个dataset对象,打印这个对象信息得到如下图:
【pytorch笔记】数据集的读取与使用
接下来用DataLoader去读取该数据,用于返回每次提供batch_size的数据加载器。打印一个batch的数据结果如下:
【pytorch笔记】数据集的读取与使用

猫狗数据集

上面的数据集是pytorhc为我们处理好的常用数据集,但现实情况下我们通常要对文件(.jpg)形式的数据进行读取与使用,kaggle的猫狗分类数据集就是类似这种情况。

以猫狗分类数据中的train目录下数据为例,文件存储类似如下形式:
【pytorch笔记】数据集的读取与使用

对真实图片数据的读取主要借助torchvision.datasetsImageFolder函数,下面的代码展示了数据的读取:

假设train目录下为所有的图片数据,我先对所有数据按照训练集和验证集进行了划分,最后的目录结构是train/traintrain/val

def load_data(data_dir, batch_size, image_size):
    transform = transforms.Compose([
        transforms.RandomResizedCrop(image_size),  # 对图片大小归一化,
        transforms.ToTensor(),  # 将数据压缩到0-1
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # 将数据归一化到-1,1
    ])

    train_dataset = datasets.ImageFolder(root=os.path.join(data_dir, 'train'),
                                         transform=transform)
    val_dataset = datasets.ImageFolder(root=os.path.join(data_dir, 'val'),
                                       transform=transform)
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=batch_size,
                              shuffle=True)
    val_loader = DataLoader(dataset=val_dataset,
                            batch_size=batch_size,
                            shuffle=False)
    return train_loader, val_loader

其中transform主要是对图片数据的一系列操作,具体的含义百度即可找到,这里不多赘述。

NLP

NLP数据的读取较CV就要麻烦一些,因为自然语言文字不像像素那样可以直接转换成数字供计算机读取,需要先建立word2idid2word这样的词与id之间的映射关系,通常将语料库中所有的词取出,一一进行编号(不重复),得到类似下面的映射关系:

word2id: {'数据':1,'今天':2,'睡梦':3,...},id2word则是反向的映射,这个主要是为了快速查询,不一定要有。

除此之外,NLP的数据通常数据长度大小不一,有的段落较长,有的较短,类似不同形状的图片,需要进行截长补短

电影评论情感分类数据集

这里以电影评论情感分类的语料数据为例:

  • 训练集(train.txt)。包含2W条左右中文电影评论,其中正负向评论各1W条左右。

  • 验证集(validation.txt)。包含6K条左右中文电影评论,其中正负向评论各3K条左右。

  • 测试集(test.txt)。包含360条左右中文电影评论,其中正负向评论各180条左右

    文件内容如下图所示,这里的数据已经做了分词,实际情况中中文若没有分词需先进行中文分词。

    【pytorch笔记】数据集的读取与使用

  • word2id:下面的代码实现了对给定txt文件读取内容,形成word2id映射关系

    def get_word2id(word2id_path, data_paths, rebuild_word2id=False):
        """
        获得word2id,
        :param word2id_path: word2id文件存储路径
        :param data_paths: 数据集(训练、验证)集路径
        :param rebuild_word2id: 是否重建word2id
        :return:
        """
        print('loading word2id...')
        if rebuild_word2id or not os.path.exists(word2id_path) or os.path.getsize(word2id_path) == 0:
            word2id = {'_PAD_': 0}
    
            for path in data_paths:
                with open(path, encoding='utf-8') as f:
                    for line in tqdm(f):   # 1  死囚 爱 刽子手 女贼 爱 衙役 我们 爱 你们 难
                        words = line.strip().split()
                        for word in words[1:]:  # words[0] = 1(label)
                            if word not in word2id.keys():
                                word2id[word] = len(word2id)  # 给词编号
                with open(word2id_path, 'wb') as f:
                    pickle.dump(word2id, f)
    
        else:
            with open(word2id_path, 'rb') as f:
                word2id = pickle.load(f)
    
        return word2id
  • word2vec:光有每一个词对应的编号还不行,每个词需要表示成向量的形式,这部分内容需要先对词的数值表示有所理解,查询one-hot编码和word2vec相关资料即可,这里不多赘述,下面代码实现了利用预训练词向量模型实现词到向量的转化。

    预训练的词向量模型即大厂或一些公司提供的大型语料库下每个词的模型表示,例如开心:[0.3, 0.1,...,0.7]。这里使用了wiki_word2vec_50.bin(将单个词表示成size 50的向量)

    def get_word2vec(pretrain_word2vec_path, word2vec_path, word2id, rebuild_word2vec=False):
    """
    利用预训练词向量获得语料库中每个词对应向量
    :param word2vec_path:
    :param word2id:
    :param rebuild_word2vec:
    :return:word_vecs: np.array
    """
        print('loading word2vec...(it may takes long time at first time)')
        if rebuild_word2vec or not os.path.exists(word2vec_path) or os.path.getsize(word2vec_path) == 0:
            num_words = max(word2id.values()) + 1
            model = gensim.models.KeyedVectors.load_word2vec_format(pretrain_word2vec_path, binary=True)
            word2vec = np.array(np.random.uniform(-1., 1., [num_words, model.vector_size]))  # vec size:50
            for word in tqdm(word2id.keys()):
                # 查询每个单词在预训练模型中的向量
                try:
                    word2vec[word2id[word]] = model[word]   #  根据id获取向量,98:[0.35,0.34,...]
                except KeyError:
                    pass
    
                with open(word2vec_path, 'wb') as f:
                    pickle.dump(word2vec,f)
        else:
            with open(word2vec_path,'rb') as f:
                word2vec = pickle.load(f)
    
        return word2vec
  • 获取语料数据:即从原始的txt文件中读取数据,利用word2id将每个词转换成id形式,然后拼凑成一个列表,截长补短。此时返回的相当于类原始dataset

    得到的data类似这样:[5,90,28,...,0] ,由于我们word2vec的size是50,所以该列表包含词的数量也是50(0是补足的词)。

    def get_corpus(file_path, word2id, max_len = 50):
    """
    获取文本数据和标签
    :param file_path:
    :param word2id:
    :param max_len: 寻取语句的最大长度,取前max_len个词
    :return:
    """
    # 'pos':0,  # 正面评论
    # 'neg':1  # 负面评论
        print('loading corpus...')
        contents, labels = [],[]
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in tqdm(f):
                words  = line.strip().split()
                if not words:
                    # 空行忽略
                    continue
                label = int(words[0])  # 第一个是标签
                content = [word2id.get(word, 0) for word in words[1:]]
                # 多的截断,少的补足0
                if len(content) < max_len:
                    content.extend([0] * (max_len-len(content)))
                else:
                    content = content[:max_len]
                contents.append(content)
                labels.append(label)
    
        print('{} :\n length of contents:{}, classes:{}'.format(file_path, len(labels), Counter(labels)))
    
        contents = np.asarray(contents)
        labels = np.asarray(labels)
    
        return contents,labels
  • 生成data_loader:下面跟图像处理的部分就很像了,将上面生成的类datasets转换成tensor形式得到真正的dataset,用到了TensorDataset这个类。

    from torch.utils.data import TensorDataset
    def load_data(contents, labels, batch_size, shuffle = False):
    """
    构造数据集,返回 data_loader
    :param contents: 文本的word2vec表示
    :param labels:
    :return:
    """
        dataset = TensorDataset(torch.from_numpy(contents).type(torch.float),
                                torch.from_numpy(labels).type(torch.long))
        data_loader = DataLoader(dataset= dataset,
                                 batch_size=batch_size,
                                 shuffle=shuffle)
        return data_loader

到此就实现了NLP中数据的读取和使用。

总结

本文主要是笔记的形式,因此有些地方不会进行过的的解释,百度均可查到相应的用法,只是方便日后查询使用,毕竟网络上搜到的各个数据读取方法千奇百怪,还需要各种甄别,目前仅更新了我暂时用到的数据集读取方法,后续随着使用次数的增加会不断更新更多类型数据集的读取与使用方法。

标签:

未经本人允许不得转载!作者:GentleCP, 转载或复制请以 超链接形式 并注明出处 求索
原文地址:《【pytorch笔记】数据集的读取与使用》 发布于2020-06-16

分享到:
赞(2)

评论 抢沙发

评论前必须登录!

  注册



Vieu4.5主题
专业打造轻量级个人企业风格博客主题!专注于前端开发,全站响应式布局自适应模板。
切换注册

登录

忘记密码 ?

您也可以使用第三方帐号快捷登录

Q Q 登 录
微 博 登 录
切换登录

注册