取消
显示结果 
显示  仅  | 搜索替代 
您的意思是: 

利用 Python 和 PyTorch 处理面向对象的数据集 - 第 2 部分:创建数据集对象

yolanda
Moderator
Moderator
0 0 136

BY Giovanni Guasti

注意:本论坛博客所有内容皆来源于Xilinx工程师,如需转载,请写明出处作者及赛灵思论坛链接并发邮件至cncrc@xilinx.com,未经Xilinx及著作权人许可,禁止用作商业用途 


本篇是利用 Python 和 PyTorch 处理面向对象的数据集系列博客的第 2 篇。如需阅读第 1 篇,请参阅此处。

我们在第 1 部分中已定义 MyDataset 类,现在,让我们来例化 MyDataset 对象。

此可迭代对象是与原始数据交互的接口,在整个训练过程中都有巨大作用。

输入 [9]:

 

mydataset = MyDataset(isValSet_bool = None, raw_data_path = raw_data_path, norm = False, resize = True, newsize = (64, 64))

 

以下是该对象的一些使用示例:

输入 [10]:

 

# 对象操作示例。
# 此操作用于调用 method __getitem__ 并从第 6 个样本获取标签
mydataset[6][1]

 

输出 [10]:

 

0

 

输入 [11]:

 

# 此操作用于在类声明后打印注释
MyDataset.__doc__

 

输出 [11]:

 

'Interface class to raw data, providing the total number of samples in the dataset and a preprocessed item'

 

输入 [12]:

 

# 此操作用于调用 method __len__
len(mydataset)

 

输出 [12]:

 

49100

 

输入 [13]:

 

# This triggers the method __str__
print(mydataset)
Path of raw data is ./raw_data/data_images/<raw samples>

 

 

可迭代对象的重要性

训练期间,将向模型提供多批次样本。可迭代的 mydataset 是获得高级轻量代码的关键。

以下提供了可迭代对象的 2 个使用示例。

示例 1:

我们可以直接获取第 3 个样本张量:

输入 [14]:

 

mydataset.__getitem__(3)[0].shape

 

输出 [14]:

 

torch.Size([3, 64, 64])

 

 

与以下操作作用相同

输入 [15]:

 

mydataset[3][0].shape

 

输出 [15]:

 

torch.Size([3, 64, 64])

 

 

示例 2:

我们可以对文件夹中的图像进行解析,并移除黑白图像:

输入 [ ]:

 

# 数据集访问示例:创建 1 个包含标签的新文件,移除黑白图像

if os.path.exists(raw_data_path + '/'+ "labels_new.txt"):
    os.remove(raw_data_path + '/'+ "labels_new.txt")

with open(raw_data_path + '/'+ "labels_new.txt", "a") as myfile:
    for item, info in mydataset:
        if item != None:
            if item.shape[0]==1:
                # os.remove(raw_data_path + '/' + info.SampleName)
                print('C = {}; H = {}; W = {}; info = {}'.format(item.shape[0], item.shape[1], item.shape[2], info))
            else:
                #print(info.SampleName + ' ' + str(info.SampleLabel))
                myfile.write(info.SampleName + ' ' + str(info.SampleLabel) + '\n')   

 

输入 [ ]:

 

# 查找具有非期望格式的样本
with open(raw_data_path + '/'+ "labels.txt", "a") as myfile:
    for item, info in mydataset:
        if item != None:
            if item.shape[0]!=3:
                # os.remove(raw_data_path + '/' + info.SampleName)
                print('C = {}; H = {}; W = {}; info = {}'.format(item.shape[0], item.shape[1], item.shape[2], info))

 

修改标签文件后,请务必更新缓存:

输入 [ ]:

 

if os.path.exists(raw_data_path + '/'+ "labels_new.txt"):
    os.rename(raw_data_path + '/'+ "labels.txt", raw_data_path + '/'+ "labels_orig.txt")
    os.rename(raw_data_path + '/'+ "labels_new.txt", raw_data_path + '/'+ "labels.txt")

@functools.lru_cache(1)
def getSampleInfoList(raw_data_path):
    sample_list = []
    with open(str(raw_data_path) + '/labels.txt', "r") as f:
        reader = csv.reader(f, delimiter = ' ')
        for i, row in enumerate(reader):
            imgname = row[0]
            label = int(row[1])
            sample_list.append(DataInfoTuple(imgname, label))
    sample_list.sort(reverse=False, key=myFunc)
    return sample_list

del mydataset
mydataset = MyDataset(isValSet_bool = None, raw_data_path = '../../raw_data/data_images', norm = False)
len(mydataset)

 

您可通过以下链接阅读了解有关 PyTorch 中的可迭代数据库的更多信息:

https://pytorch.org/docs/stable/data.html

归一化

应对所有样本张量计算平均值和标准差。

如果数据集较小,可以尝试在内存中对其进行直接操作:使用 torch.stack 即可创建 1 个包含所有样本张量的栈。

可迭代对象 mydataset 支持简洁精美的代码。
使用“view”即可保留 R、G 和 B 这 3 个通道,并将其余所有维度合并为 1 个维度。
使用“mean”即可计算维度 1 的每个通道的平均值。

请参阅附件中有关 dim 使用的说明。

输入 [16]:

 

imgs = torch.stack([img_t for img_t, _ in mydataset], dim = 3) 

 

输入 [17]:

 

#im_mean = imgs.view(3, -1).mean(dim=1).tolist()
im_mean = imgs.view(3, -1).mean(dim=1)
im_mean

 

输出 [17]:

 

tensor([0.4735, 0.4502, 0.4002])

 

输入 [18]:

 

im_std = imgs.view(3, -1).std(dim=1).tolist()
im_std

 

输出 [18]:

 

[0.28131285309791565, 0.27447444200515747, 0.2874436378479004]

 

输入 [19]:

 

normalize = transforms.Normalize(mean=[0.4735, 0.4502, 0.4002], std=[0.28131, 0.27447, 0.28744])

# free memory
del imgs

 

下面,我们将再次构建数据集对象,但这次将对此对象进行归一化:

输入 [21]:

 

mydataset = MyDataset(isValSet_bool = None, raw_data_path = raw_data_path, norm = True, resize = True, newsize = (64, 64))

 

由于采用了归一化,因此张量值被转换至范围 0..1 之内,并进行剪切操作。

输入 [22]:

 

original = Image.open('../../raw_data/data_images/img_00009111.JPEG')

fig, axs = plt.subplots(1, 2, figsize=(10, 3))
axs[0].set_title('clipped tensor')
axs[0].imshow(mydataset[5][0].permute(1,2,0))
axs[1].set_title('original PIL image')
axs[1].imshow(original)
plt.show()
将输入数据剪切到含 RGB 数据的 imshow 的有效范围内,以 [0..1] 表示浮点值,或者以 [0..255] 表示整数值。

 

1.png

 

使用 torchvision.transforms 进行预处理

现在,我们已经创建了自己的变换函数或对象(原本用作为加速学习曲线的练习),我建议使用 Torch 模块 torchvision.transforms:

“此模块定义了一组可组合式类函数对象,这些对象可作为实参传递到数据集(如 torchvision.CIFAR10),并在加载数据后 __getitem__ 返回数据之前,对数据执行变换”。

以下列出了可能的变换:

输入 [23]:

 

from torchvision import transforms
dir(transforms)

 

输出 [23]:

 

['CenterCrop',
 'ColorJitter',
 'Compose',
 'FiveCrop',
 'Grayscale',
 'Lambda',
 'LinearTransformation',
 'Normalize',
 'Pad',
 'RandomAffine',
 'RandomApply',
 'RandomChoice',
 'RandomCrop',
 'RandomErasing',
 'RandomGrayscale',
 'RandomHorizontalFlip',
 'RandomOrder',
 'RandomPerspective',
 'RandomResizedCrop',
 'RandomRotation',
 'RandomSizedCrop',
 'RandomVerticalFlip',
 'Resize',
 'Scale',
 'TenCrop',
 'ToPILImage',
 'ToTensor',
 '__builtins__',
 '__cached__',
 '__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__path__',
 '__spec__',
 'functional',
 'transforms']

 

 

在此示例中,我们使用变换来执行了以下操作:

1) ToTensor - 从 PIL 图像转换为张量,并将输出格式定义为 CxHxW
2) Normalize - 将张量归一化

如需了解后续步骤,请参阅本系列的 3 部分