This is part two of the Object Oriented Dataset with Python and PyTorch blog series. For Part One, see here.
We defined the MyDataset class in Part One; now let's instantiate a MyDataset object.
This iterable object is the interface to the raw data and will be very useful throughout the training process.
mydataset = MyDataset(isValSet_bool = None, raw_data_path = raw_data_path, norm = False, resize = True, newsize = (64, 64))
Below are some examples of usage of this object:
# Example of object operation.
# This invokes the method __getitem__ and gets the label of the sample at index 6
mydataset[6][1]
0
# This prints the docstring defined just after the class declaration
MyDataset.__doc__
'Interface class to raw data, providing the total number of samples in the dataset and a preprocessed item'
# This invokes the method __len__
len(mydataset)
49100
# This triggers the method __str__
print(mydataset)
Path of raw data is ./raw_data/data_images/<raw samples>
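To recap which special methods power each of the calls above, here is a minimal sketch of the MyDataset interface (the real class, with its preprocessing logic, was defined in Part One; this skeleton only mirrors the dunder methods exercised above):

```python
class MyDataset:
    'Interface class to raw data, providing the total number of samples in the dataset and a preprocessed item'

    def __init__(self, isValSet_bool=None, raw_data_path='.', norm=False,
                 resize=False, newsize=(64, 64)):
        self.raw_data_path = raw_data_path
        self.samples = []  # in the real class, filled from the labels file

    def __len__(self):
        # len(mydataset) -> total number of samples
        return len(self.samples)

    def __getitem__(self, ndx):
        # mydataset[ndx] -> (preprocessed tensor, DataInfoTuple)
        return self.samples[ndx]

    def __str__(self):
        # print(mydataset) -> human-readable summary
        return 'Path of raw data is ' + self.raw_data_path + '/<raw samples>'
```

Because `__len__` and `__getitem__` are defined, the object supports `len()`, indexing, and iteration out of the box.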
During training, batches of samples are presented to the model. The iterable mydataset is key to keeping the code light and high-level.
Two examples of usage of the iterable object are presented below.
Example one:
We can simply get the tensor of the sample at index 3:
mydataset.__getitem__(3)[0].shape
torch.Size([3, 64, 64])
Which is the same as
mydataset[3][0].shape
torch.Size([3, 64, 64])
Example two:
We can parse the images in the folder and remove those that are in black and white:
# example of access to the dataset: create a new file of labels, removing those images in black&white
if os.path.exists(raw_data_path + '/' + "labels_new.txt"):
    os.remove(raw_data_path + '/' + "labels_new.txt")

with open(raw_data_path + '/' + "labels_new.txt", "a") as myfile:
    for item, info in mydataset:
        if item is not None:
            if item.shape[0] == 1:
                # os.remove(raw_data_path + '/' + info.SampleName)
                print('C = {}; H = {}; W = {}; info = {}'.format(item.shape[0], item.shape[1], item.shape[2], info))
            else:
                # print(info.SampleName + ' ' + str(info.SampleLabel))
                myfile.write(info.SampleName + ' ' + str(info.SampleLabel) + '\n')
# find samples with unexpected format (nothing is written here, so no file needs to be opened)
for item, info in mydataset:
    if item is not None:
        if item.shape[0] != 3:
            # os.remove(raw_data_path + '/' + info.SampleName)
            print('C = {}; H = {}; W = {}; info = {}'.format(item.shape[0], item.shape[1], item.shape[2], info))
Remember to update the cache once the label files are modified:
if os.path.exists(raw_data_path + '/' + "labels_new.txt"):
    os.rename(raw_data_path + '/' + "labels.txt", raw_data_path + '/' + "labels_orig.txt")
    os.rename(raw_data_path + '/' + "labels_new.txt", raw_data_path + '/' + "labels.txt")
@functools.lru_cache(1)
def getSampleInfoList(raw_data_path):
    sample_list = []
    with open(str(raw_data_path) + '/labels.txt', "r") as f:
        reader = csv.reader(f, delimiter=' ')
        for i, row in enumerate(reader):
            imgname = row[0]
            label = int(row[1])
            sample_list.append(DataInfoTuple(imgname, label))
    sample_list.sort(reverse=False, key=myFunc)
    return sample_list
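Since getSampleInfoList is memoized with functools.lru_cache, renaming the label files alone is not enough: the memoized result must also be dropped, or the old sample list will keep being returned. A small standalone illustration (read_labels is a stand-in for getSampleInfoList, with a hard-coded return value instead of file I/O):

```python
import functools

@functools.lru_cache(1)
def read_labels(path):
    # stand-in for getSampleInfoList: pretend to parse the labels file
    print('reading', path)
    return ['img_00001.JPEG 0', 'img_00002.JPEG 1']

read_labels('labels.txt')    # first call: the function body runs
read_labels('labels.txt')    # second call: served from the cache, no "reading" printed
read_labels.cache_clear()    # invalidate after editing labels.txt
read_labels('labels.txt')    # the function body runs again
```

Calling `getSampleInfoList.cache_clear()` after swapping the label files achieves the same effect in the code above.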
del mydataset
mydataset = MyDataset(isValSet_bool = None, raw_data_path = '../../raw_data/data_images', norm = False)
len(mydataset)
You can read more about iterable-style datasets in the PyTorch documentation.
The mean value and standard deviation should be calculated from all sample tensors.
If the dataset is small, we can manipulate it directly in memory: torch.stack builds a single tensor out of all the sample tensors.
The iterable object mydataset allows for an elegant and concise code.
"view" preserves the three channels R, G, B and merges all remaining dimensions into one.
"mean" calculates the mean value for each channel along dimension 1.
See the note in the Appendix about dim usage.
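To make the dim argument concrete before applying it to the image stack, here is a tiny standalone example (illustrative values only):

```python
import torch

# Three "channels" with two values each; mean(dim=1) collapses dimension 1,
# producing one mean per row (i.e. per channel)
t = torch.tensor([[1., 2.],
                  [3., 4.],
                  [5., 6.]])
print(t.mean(dim=1))  # -> tensor([1.5000, 3.5000, 5.5000])
```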
imgs = torch.stack([img_t for img_t, _ in mydataset], dim = 3)
#im_mean = imgs.view(3, -1).mean(dim=1).tolist()
im_mean = imgs.view(3, -1).mean(dim=1)
im_mean
tensor([0.4735, 0.4502, 0.4002])
im_std = imgs.view(3, -1).std(dim=1).tolist()
im_std
[0.28131285309791565, 0.27447444200515747, 0.2874436378479004]
normalize = transforms.Normalize(mean=[0.4735, 0.4502, 0.4002], std=[0.28131, 0.27447, 0.28744])
# free memory
del imgs
Below we will build the dataset object again, this time with normalization:
mydataset = MyDataset(isValSet_bool = None, raw_data_path = raw_data_path, norm = True, resize = True, newsize = (64, 64))
Because of normalization, some tensor values are shifted outside of the range 0..1 and will be clipped by imshow:
original = Image.open('../../raw_data/data_images/img_00009111.JPEG')
fig, axs = plt.subplots(1, 2, figsize=(10, 3))
axs[0].set_title('clipped tensor')
axs[0].imshow(mydataset[5][0].permute(1,2,0))
axs[1].set_title('original PIL image')
axs[1].imshow(original)
plt.show()
Matplotlib warns: Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
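If we want to display a normalized sample faithfully rather than let imshow clip it, we can undo the normalization first. A minimal sketch, reusing the channel statistics computed above (denormalize is a helper name introduced here, not part of MyDataset):

```python
import torch

# per-channel statistics computed earlier, reshaped for broadcasting over C x H x W
mean = torch.tensor([0.4735, 0.4502, 0.4002]).view(3, 1, 1)
std = torch.tensor([0.28131, 0.27447, 0.28744]).view(3, 1, 1)

def denormalize(img_t):
    # invert Normalize: x * std + mean, then clamp for safe display
    return (img_t * std + mean).clamp(0.0, 1.0)
```

A denormalized tensor can then be passed to `imshow` via `denormalize(mydataset[5][0]).permute(1, 2, 0)` without triggering the clipping warning.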
Now that we have created our own transformation functions and objects (this was an exercise to speed up our learning curve), I would suggest using the torchvision.transforms module:
"this module defines a set of composable function-like objects that can be passed as an argument to a dataset such as torchvision.CIFAR10, and that perform transformations on the data after it is loaded but before it is returned by __getitem__".
The possible transformations are listed below:
from torchvision import transforms
dir(transforms)
['CenterCrop', 'ColorJitter', 'Compose', 'FiveCrop', 'Grayscale', 'Lambda', 'LinearTransformation', 'Normalize', 'Pad', 'RandomAffine', 'RandomApply', 'RandomChoice', 'RandomCrop', 'RandomErasing', 'RandomGrayscale', 'RandomHorizontalFlip', 'RandomOrder', 'RandomPerspective', 'RandomResizedCrop', 'RandomRotation', 'RandomSizedCrop', 'RandomVerticalFlip', 'Resize', 'Scale', 'TenCrop', 'ToPILImage', 'ToTensor', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', 'functional', 'transforms']
In this example we use the transform to do the following:
1) ToTensor - Convert from a PIL image to tensor and define the output layout as CxHxW
2) Normalize - Normalize the tensor