Object Oriented Dataset with Python and PyTorch - Part 1: Raw Data and Dataset

gguasti
Xilinx Employee

A very common problem in Machine Learning is deciding how best to interface with data.

In this article we present an elegant method to interface with, organize, and then eventually transform the data (preprocessing). We will then cover how to properly feed the model during training procedures.

The PyTorch framework will assist us with this goal. We will also write a few classes from scratch. PyTorch natively provides more complete classes but creating our own class helps us to learn faster.

Credit should be given to the books and articles I have listed in the Reference section of Part Three, in particular to Deep Learning with PyTorch.

Part 1: Raw Data and Dataset

I call all samples that are not yet organized "raw data". I define a "dataset" as something ready to be used: the raw data together with labels, plus the basic functions and code that make the raw data information easy to use.

Here we want to interface with a simple example of raw data: a folder containing images, and a file of labels. However, this method can be extended to samples of any nature (pictures, sound recordings, videos, etc.).

Each row in the label file describes one sample and related label using the format below:


file_sample_1 label1
file_sample_2 label2
file_sample_3 label3
(...)

When we are ready to make some basic queries (how many samples do we have, return sample number n, preprocess each sample, etc.) we can claim that we have created a dataset from the raw data.

This method is based on Object Oriented programming and the creation of a "class" dedicated to data handling.

It might seem like overkill for a simple set of images and labels (in fact, this use case is often handled by creating separate folders for training, validation, and testing). However, if we choose a standard interface method, the same method can be reused for many different cases, saving us time in the future.

 

Handling data in Python

In Python everything is an object: an integer, a list, a dictionary.

There are several reasons to build a "dataset" object with standard properties and methods. From my perspective the elegance of the code alone would justify this choice, but I understand that this is a matter of taste. Portability, speed, and code modularity are probably the most important reasons.

I have found other interesting features and advantages of Object Oriented coding, and of classes in particular, in many examples and coding books, which I have summarized in the list below:

• Classes provide for inheritance
• Inheritance provides for reuse
• Inheritance provides for extension of a data type
• Inheritance allows for polymorphism
• Inheritance is a unique feature of object orientation
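A toy example illustrates these points (the class names here are hypothetical, not part of our dataset code):

```python
class Transform:
    """Base data type."""
    def apply(self, x):
        return x

class Doubler(Transform):      # inheritance: reuse and extension of Transform
    def apply(self, x):        # polymorphism: same method name, new behavior
        return 2 * x

# polymorphism in action: the caller does not care about the concrete class
for t in (Transform(), Doubler()):
    print(t.apply(10))
# 10
# 20
```

Our MyDataset class below follows the same idea: a dedicated data type whose methods can later be extended or overridden.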

In [1]:
import torch
from torchvision import transforms
to_tensor = transforms.ToTensor()

from collections import namedtuple
import functools
import copy
import csv
from PIL import Image
from matplotlib import pyplot as plt
import numpy as np
import os
import datetime
import torch.optim as optim
 

In our example, all raw samples are stored in a folder. The folder's address is declared in the variable raw_data_path.

In [2]:
raw_data_path = './raw_data/data_images'
 

Building blocks

There are a few functions and classes needed for our dataset interface. The dataset itself is an Object so we will create the MyDataset class with all of the important methods and variables.

We first need to read the label file; then we might want to apply some transformations to the samples, both in their original format (in this case a PIL image) and in the final tensor format.

The following function reads the label file once and builds a list of named tuples, each containing a sample name and its related label.

In-memory caching improves performance, but if the label file changes, remember to clear the cached content.

In [113]:
DataInfoTuple = namedtuple('Sample','SampleName, SampleLabel')
# sort key: order the sample list by label
def myFunc(e):
    return e.SampleLabel

# in memory caching decorator: ref https://dbader.org/blog/python-memoization
@functools.lru_cache(1)
def getSampleInfoList(raw_data_path):
    sample_list = []
    with open(str(raw_data_path) + '/labels.txt', mode = 'r') as f:
        reader = csv.reader(f, delimiter = ' ')
        for i, row in enumerate(reader):
            imgname = row[0]
            label = int(row[1])
            sample_list.append(DataInfoTuple(imgname, label))
    sample_list.sort(reverse=False, key=myFunc)
    # print("DataInfoTouple: samples list length = {}".format(len(sample_list)))
    return sample_list
 
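Because of the lru_cache decorator, repeated calls with the same path return the cached list without re-reading the file; if labels.txt changes, the cache can be invalidated explicitly with cache_clear(). A small self-contained sketch of this behavior (the function below is a stand-in for getSampleInfoList, not the real file reader):

```python
import functools

calls = []

@functools.lru_cache(1)
def read_labels(path):
    calls.append(path)          # track real "reads" (stand-in for file I/O)
    return ['sample list for ' + path]

read_labels('./raw_data')       # real read
read_labels('./raw_data')       # served from cache, no new read
read_labels.cache_clear()       # invalidate after labels.txt changes
read_labels('./raw_data')       # real read again
print(len(calls))               # 2
```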

The following class can be useful if we need to transform the PIL image directly.

This class currently has only one method: resize, which changes the original size of the PIL image and resamples it. If we need additional preprocessing (flip, crop, rotate, etc.) we will add methods to this class.

Once the PIL image has been pre-processed, it is then converted to a tensor. A further processing step can happen to the tensor.

In the following example we will see both transformations:

In [4]:
class PilTransform():
    """generic transformation of a pil image"""
    def resize(self, img, **kwargs):
        img = img.resize(( kwargs.get('width'), kwargs.get('height')), resample=Image.NEAREST)
        return img
# create the object pil_transform, an instance of the class PilTransform
pil_transform = PilTransform()
 

Below is an example of the class PilTransform in action:

In [5]:
path = raw_data_path + "/img_00000600.JPEG"
print(path)
im1 = Image.open(path, mode='r')  
plt.imshow(im1)
 
./raw_data/data_images/img_00000600.JPEG
Out[5]:
<matplotlib.image.AxesImage at 0x121046f5588>
 
[Image: gguasti_0-1600262347734.png — the original image]

 

In [6]:
im2 = pil_transform.resize(im1, width=128, height=128)
# im2.show()
plt.imshow(im2)
Out[6]:
<matplotlib.image.AxesImage at 0x12104b36358>
 
[Image: gguasti_1-1600262347730.png — the resized 128x128 image]

Finally, we define the class that implements the interface to the raw data.

The class MyDataset mainly provides two methods:

  • __len__, which returns the number of raw samples.
  • __getitem__, which makes the object indexable (and hence iterable) and returns the requested sample, preprocessed and in tensor form.
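These two dunder methods are what make an object behave like a Python sequence; a minimal toy class shows the mechanism before we build the full version:

```python
class TinyDataset:
    def __init__(self, samples):
        self.samples = samples
    def __len__(self):
        return len(self.samples)          # called by len(ds)
    def __getitem__(self, ndx):
        return self.samples[ndx]          # called by ds[ndx]; preprocessing would go here

ds = TinyDataset(['a', 'b', 'c'])
print(len(ds))        # 3 -- via __len__
print(ds[1])          # 'b' -- via __getitem__
for s in ds:          # __getitem__ also makes the object iterable
    print(s)
```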

__getitem__ performs these steps:

1) open the sample from its file.
2) preprocess the sample in its original format.
3) transform the sample into a tensor.
4) preprocess the sample in tensor format.

Preprocessing is added here just as an example.

This class is also ready to perform tensor normalization (shifting by the mean and dividing by the standard deviation), which will be useful for a faster training process.
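The class below calls a normalize function that is not defined in this part; a minimal sketch of what it might look like, assuming fixed per-channel statistics (the values used here are the common ImageNet ones, placeholders only — in practice they would be computed over the whole training set):

```python
import torch

# hypothetical per-channel statistics, shaped for broadcasting over a CHW tensor
channel_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
channel_std  = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

def normalize(img_t):
    """Shift to zero mean and unit standard deviation, per channel."""
    return (img_t - channel_mean) / channel_std
```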

Please note that a PIL image is made up of integers in the range 0-255, while the tensor produced by ToTensor contains floating point numbers in the range 0-1.


This class returns a tuple of two elements: in position [0] the tensor, and in position [1] the sample label (the SampleLabel field of the named tuple).

In [109]:
class MyDataset():
    """Interface class to raw data, providing the total number of samples in the dataset and a preprocessed item"""
    def __init__(self,
                 isValSet_bool = None,
                 raw_data_path = './',
                 SampleInfoList = DataInfoTuple,norm = False,
                 resize = False,
                 newsize = (32, 32)
                ):
        self.raw_data_path = raw_data_path
        self.SampleInfoList = copy.copy(getSampleInfoList(self.raw_data_path))
        self.isValSet_bool = isValSet_bool
        self.norm = norm
        self.resize = resize
        self.newsize = newsize
        
    def __str__(self):
        return 'Path of raw data is ' + self.raw_data_path + '/' + '<raw samples>'
    def __len__(self):
        return len(self.SampleInfoList)
    def __getitem__(self, ndx):
        SampleInfoList_tup = self.SampleInfoList[ndx]
        filepath = self.raw_data_path + '/' + str(SampleInfoList_tup.SampleName)
        if os.path.exists(filepath):
            img = Image.open(filepath)
            # PIL image preprocess (examples)
            #resize
            if self.resize:
                width, height = img.size
                # match the target orientation to the image orientation
                if (width >= height) == (self.newsize[0] >= self.newsize[1]):
                    img = pil_transform.resize(img, width=self.newsize[0], height=self.newsize[1])
                else:
                    img = pil_transform.resize(img, width=self.newsize[1], height=self.newsize[0])

            # from pil image to tensor
            img_t = to_tensor(img)     
            
            # tensor preprocess (examples)
            #rotation    
            ratio = img_t.shape[1]/img_t.shape[2]
            if ratio > 1:
                img_t = torch.rot90(img_t, 1, [1, 2])
            #normalization requires statistics over the whole dataset; a normalize()
            #function must be defined before setting norm = True
            if self.norm:
                img_t = normalize(img_t)
        
            #return img_t, SampleInfoList_tup  
            return img_t, SampleInfoList_tup.SampleLabel  
        else:
            print('[WARNING] file {} does not exist'.format(str(SampleInfoList_tup.SampleName)))
            return None

For the next steps, see Part Two of this series.