素材巴巴 > 程序开发 >

pytorch读取csv数据集

程序开发 2023-09-21 11:16:30

在pytorch中已经包含了许多的内置数据集,我们可以很简单的调用其内置的,但是在现实的过程之中我们往往会使用自己的数据集。这就使得读取自己的数据集并进行训练会有很大的问题。

因此对于csv格式的数据集合,以下图为例

 

每一份csv文件为一个样本,对应的标签数据也是使用csv格式进行存储。

对于这种情况,我们可以使用重写dataset类来解决这个问题,利用迭代的方式依次读取对应的data和label。

代码如下:

class myDataSet(Dataset):def __init__(self, data_dir, label_dir, transform=None):""":param data_dir: 数据文件路径:param label_dir: 标签文件路径:param transform: transform操作"""self.transform = transform# 读文件夹下每个数据文件名称#os.listdir读取文件夹内的文件名称self.file_name = os.listdir(data_dir)# 读标签文件夹下的数据名称self.label_name = os.listdir(label_dir)self.data_path = []self.label_path = []#让每一个文件的路径拼接起来for index in range(len(self.file_name)):self.data_path.append(os.path.join(data_dir,self.file_name[index]))self.label_path.append(os.path.join(label_dir, self.label_name[index]))def __len__(self):# 返回数据集长度return len(self.file_name)def __getitem__(self, index):# 获取每一个数据#读取数据data = pd.read_csv(self.data_path[index],header=None)#读取标签label = pd.read_csv(self.label_path[index],header=None)if self.transform :data = self.transform(data)label = self.transform(label)#转成张量data = torch.tensor(data.values)label = torch.tensor(label.values)return data, label  # 返回数据和标签

重构dataset类之后,读取数据并使用dataloader进行数据的加载 

    data_dir = r"./data/Circle/BV/"label_dir = r"./data/Circle/DDL/"#读取数据集train_dataset = myDataSet(data_dir = data_dir,label_dir = label_dir,)#加载数据集train_iter = DataLoader(train_dataset)
 

成功加载数据集之后就可以构建自己的网络来进行训练。

ps:学生新手,如果有不足之处还希望大家多多批评指正。


标签:

上一篇: Maven学习—Maven环境搭建 下一篇:
素材巴巴 Copyright © 2013-2021 http://www.sucaibaba.com/. Some Rights Reserved. 备案号:备案中。