
Object Oriented Dataset with Python and PyTorch - Part 2: Creation of Dataset Object

Xilinx Employee

This is part two of the Object Oriented Dataset with Python and PyTorch blog series. For Part One, see here.

We defined the MyDataset class in Part One; now let's instantiate a MyDataset object.

This iterable object is the interface with raw data and will be very useful throughout the training process.
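For reference, the interface such a class exposes follows the standard PyTorch Dataset pattern. The sketch below is illustrative (it is not Part One's MyDataset; the class name and in-memory sample list are stand-ins), showing the three dunder methods used in the examples that follow:

```python
import torch
from torch.utils.data import Dataset

class ToyDataset(Dataset):
    """Minimal sketch of the interface MyDataset implements (illustrative only)."""
    def __init__(self, samples):
        # samples: list of (tensor, label) pairs standing in for raw data on disk
        self.samples = samples

    def __len__(self):
        # total number of samples in the dataset
        return len(self.samples)

    def __getitem__(self, idx):
        # return one preprocessed (tensor, label) item
        return self.samples[idx]

    def __str__(self):
        return 'Toy dataset with {} samples'.format(len(self))

toy = ToyDataset([(torch.zeros(3, 64, 64), i % 2) for i in range(6)])
print(len(toy))       # 6
print(toy[5][1])      # label of the 6-th sample: 1
```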

In [9]:
mydataset = MyDataset(isValSet_bool = None, raw_data_path = raw_data_path, norm = False, resize = True, newsize = (64, 64))

Below are some examples of usage of this object:

In [10]:
# Example of object operation.
# This invokes the method __getitem__ and gets the label from the 6-th sample
mydataset[5][1]
In [11]:
# This prints the docstring after the class declaration
mydataset.__doc__
'Interface class to raw data, providing the total number of samples in the dataset and a preprocessed item'
In [12]:
# This invokes the method __len__
len(mydataset)
In [13]:
# This triggers the method __str__
print(mydataset)
Path of raw data is ./raw_data/data_images/<raw samples>

Importance of having an iterable object

During training, batches of samples are presented to the model. The iterable mydataset is the key to keeping this code light and high-level.
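Because such a dataset supports len() and indexing, it can be handed directly to torch.utils.data.DataLoader, which takes care of the batching. A minimal sketch with a synthetic stand-in dataset (the class name and sizes are illustrative, not the real mydataset):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class FakeImages(Dataset):
    # stand-in for mydataset: 10 RGB 64x64 tensors with integer labels
    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return torch.zeros(3, 64, 64), idx % 2

loader = DataLoader(FakeImages(), batch_size=4, shuffle=False)
for batch, labels in loader:
    # batches of 4, 4, and the remaining 2 samples
    print(batch.shape, labels.shape)
```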

Two examples of usage of the iterable object are presented below.

Example one:

We can simply get the third sample tensor:

In [14]:
mydataset[2][0].shape
torch.Size([3, 64, 64])

Which is the same as:

In [15]:
mydataset.__getitem__(2)[0].shape
torch.Size([3, 64, 64])

Example two:

We can parse the images in the folder and remove those that are in black and white:

In [ ]:
# example of access to the dataset: create a new file of labels, removing those images in black&white

if os.path.exists(raw_data_path + '/'+ "labels_new.txt"):
    os.remove(raw_data_path + '/'+ "labels_new.txt")

with open(raw_data_path + '/'+ "labels_new.txt", "a") as myfile:
    for item, info in mydataset:
        if item is not None:
            if item.shape[0] == 1:
                # black&white image (single channel): report it and skip its label
                # os.remove(raw_data_path + '/' + info.SampleName)
                print('C = {}; H = {}; W = {}; info = {}'.format(item.shape[0], item.shape[1], item.shape[2], info))
            else:
                # keep the label line for colour images
                myfile.write(info.SampleName + ' ' + str(info.SampleLabel) + '\n')
In [ ]:
# find samples with unexpected format (tensors without exactly 3 channels)
for item, info in mydataset:
    if item is not None:
        if item.shape[0] != 3:
            # os.remove(raw_data_path + '/' + info.SampleName)
            print('C = {}; H = {}; W = {}; info = {}'.format(item.shape[0], item.shape[1], item.shape[2], info))

Remember to update the cache once the label files are modified:

In [ ]:
if os.path.exists(raw_data_path + '/'+ "labels_new.txt"):
    os.rename(raw_data_path + '/'+ "labels.txt", raw_data_path + '/'+ "labels_orig.txt")
    os.rename(raw_data_path + '/'+ "labels_new.txt", raw_data_path + '/'+ "labels.txt")

def getSampleInfoList(raw_data_path):
    sample_list = []
    with open(str(raw_data_path) + '/labels.txt', "r") as f:
        reader = csv.reader(f, delimiter = ' ')
        for i, row in enumerate(reader):
            imgname = row[0]
            label = int(row[1])
            sample_list.append(DataInfoTuple(imgname, label))
    sample_list.sort(reverse=False, key=myFunc)
    return sample_list

del mydataset
mydataset = MyDataset(isValSet_bool = None, raw_data_path = '../../raw_data/data_images', norm = False)

You can read more about iterable datasets in the PyTorch documentation for torch.utils.data.


The mean value and standard deviation should be calculated over all sample tensors.

If the dataset is small, we can manipulate it directly in memory: torch.stack creates a stack of all sample tensors.

The iterable object mydataset allows for elegant and concise code:
"view" preserves the three channels R, G, B and merges all remaining dimensions into one.
"mean" calculates the mean value for each channel along dimension 1.

See the note in the Appendix about dim usage.
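As a toy illustration of the stack/view/mean pattern (synthetic values, not the real dataset):

```python
import torch

# two fake 3-channel 2x2 "images"
a = torch.ones(3, 2, 2)
b = torch.ones(3, 2, 2) * 3.0

imgs = torch.stack([a, b], dim=3)   # shape: [3, 2, 2, 2]
flat = imgs.view(3, -1)             # one row per channel, shape: [3, 8]
print(flat.mean(dim=1))             # per-channel mean: tensor([2., 2., 2.])
print(flat.std(dim=1))              # per-channel standard deviation
```

Each channel holds four 1s and four 3s, so the per-channel mean is 2.0, exactly as "view" followed by "mean" computes it on the real stacked dataset.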

In [16]:
imgs = torch.stack([img_t for img_t, _ in mydataset], dim = 3) 
In [17]:
#im_mean = imgs.view(3, -1).mean(dim=1).tolist()
im_mean = imgs.view(3, -1).mean(dim=1)
tensor([0.4735, 0.4502, 0.4002])
In [18]:
im_std = imgs.view(3, -1).std(dim=1).tolist()
[0.28131285309791565, 0.27447444200515747, 0.2874436378479004]
In [19]:
normalize = transforms.Normalize(mean=[0.4735, 0.4502, 0.4002], std=[0.28131, 0.27447, 0.28744])

# free memory
del imgs

Below we will build the dataset object again, this time with normalization:

In [21]:
mydataset = MyDataset(isValSet_bool = None, raw_data_path = raw_data_path, norm = True, resize = True, newsize = (64, 64))

Because of normalization, tensor values are shifted outside of the range 0..1 and will be clipped.

In [22]:
img_t, _ = mydataset[0]  # a sample tensor; the index here is illustrative
original = Image.open('../../raw_data/data_images/img_00009111.JPEG')

fig, axs = plt.subplots(1, 2, figsize=(10, 3))
axs[0].set_title('clipped tensor')
axs[0].imshow(img_t.permute(1, 2, 0))  # CxHxW -> HxWxC for imshow
axs[1].set_title('original PIL image')
axs[1].imshow(original)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
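The clipping that matplotlib warns about can be made explicit. The sketch below (with a synthetic tensor standing in for a normalized sample) shows the two operations involved: reordering the axes for imshow and clamping values back into the displayable range:

```python
import torch

img_t = torch.randn(3, 64, 64)     # synthetic normalized CxHxW tensor
img_hwc = img_t.permute(1, 2, 0)   # matplotlib expects HxWxC
clipped = img_hwc.clamp(0.0, 1.0)  # explicit version of imshow's clipping
print(clipped.min() >= 0, clipped.max() <= 1)
```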


Preprocessing with torchvision.transforms

Now that we have created our own transformation functions or objects (this was an exercise to speed up our learning curve), I would suggest using the Torch module torchvision.transforms:

"this module defines a set of composable function-like objects that can be passed as an argument to a dataset such as torchvision.CIFAR10, and that perform transformations on the data after it is loaded but before it is returned by __getitem__".

The possible transformations are listed below:

In [23]:
from torchvision import transforms
dir(transforms)

In this example we use the transform to do the following:

1) ToTensor - Convert from a PIL image to tensor and define the output layout as CxHxW
2) Normalize - Normalize the tensor

For the next steps, see Part Three of this series.