"""Dataloader classes for mzbsuite.classification.mzb_classification_dataloader."""

import torch
import numpy as np

from torchvision import utils
from PIL import Image  # , ImageFilter, ImageDraw, ImageOps

from torch.utils.data import Dataset


class MZBLoader(Dataset):
    """
    Class definition for the dataloader for the macrozoobenthos dataset.

    Parameters
    ----------
    dir_dict : dict
        dictionary mapping class names to a folder (an object with a ``.glob``
        method, e.g. ``pathlib.Path``) or to a list/tuple of such folders.
        The position of a key in the dict is used as its integer label.
    ls_inds : list
        indices of images to be used for the learning set, optional. When
        empty or None, all images found are used.
    learning_set : str
        type of learning set to be used, optional, default: 'all'
        (stored on the instance; not read by this class)
    transforms : torchvision.transforms or dict
        transformation applied to each image: either a single callable, or a
        dict mapping ``str(label)`` to a per-class callable. Optional,
        default: None, in which case the image is converted to a float tensor
        of shape (3, H, W) with values in [0, 1].
    glob_pattern : str
        glob pattern to use for finding images. Optional, default: '*_rgb.*'
    """

    def __init__(
        self,
        dir_dict,
        ls_inds=None,
        learning_set="all",
        transforms=None,
        glob_pattern="*_rgb.*",
    ):
        self.transforms = transforms
        # spatial size of the placeholder returned for unreadable images
        self.imsize = 224
        # None instead of a shared mutable [] default
        self.ls_inds = [] if ls_inds is None else ls_inds
        self.learning_set = learning_set
        self.glob_pattern = glob_pattern
        # forward glob_pattern: it used to be stored but silently ignored
        self.img_paths, self.labels, self.inds = self.prepare_data(
            dir_dict, ls_inds=self.ls_inds, glob_pattern=self.glob_pattern
        )

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        """
        Load, convert to RGB and transform the image at position ``idx``.

        Returns
        -------
        tuple
            ``(image_tensor, label, idx)``. If the file cannot be read
            (``OSError``), a zero tensor of shape (3, imsize, imsize) and a
            label of -1 are returned instead, so a broken file does not abort
            an epoch.
        """
        try:
            with open(self.img_paths[idx], "rb") as f:
                # convert() forces PIL to decode while the file is still open
                img = Image.open(f).convert("RGB")
        except OSError:
            return (
                torch.zeros((3, self.imsize, self.imsize)),
                torch.LongTensor([-1]).squeeze(),
                idx,
            )

        label = self.labels[idx]
        if isinstance(self.transforms, dict):
            # per-class transforms are keyed by the label as a string
            tensor_image = self.transforms[str(label.item())](img)
        elif self.transforms is None:
            tensor_image = self._to_tensor(img)
        else:
            tensor_image = self.transforms(img)
        return tensor_image, label, idx

    @staticmethod
    def _to_tensor(img):
        """Convert an RGB PIL image to a float (3, H, W) tensor in [0, 1]."""
        arr = np.asarray(img, dtype=np.float32) / 255.0
        return torch.from_numpy(arr).permute(2, 0, 1).contiguous()

    @staticmethod
    def prepare_data(
        dir_dict: dict, ls_inds: list = None, glob_pattern: str = "*_rgb.*"
    ) -> tuple:
        """
        Prepare data for training and testing, returns image paths, labels and indices

        Parameters
        ----------
        dir_dict: dict
            dictionary with keys as class names and values as a folder (or a
            list/tuple of folders) containing the images of that class
        ls_inds: list
            list of indices to be used for training or testing; empty or None
            selects every image found
        glob_pattern: str
            glob pattern to use for finding images

        Returns
        -------
        img_paths: np.ndarray
            object array of paths to images
        labels: torch.Tensor
            int64 tensor with the class index of each image
        ls_inds: list
            the indices used for selection (``[]`` when all images are used)
        """
        if ls_inds is None:
            ls_inds = []

        # one folder (or group of folders) - one class: the enumeration index
        # of a key becomes its label
        img_paths = []
        labels = []
        for label, folders in enumerate(dir_dict.values()):
            if not isinstance(folders, (list, tuple)):
                folders = [folders]
            # NOTE: file order within a folder is whatever ``glob`` yields
            # (not sorted here); ls_inds index into that order
            found = [p for folder in folders for p in folder.glob(glob_pattern)]
            img_paths.extend(found)
            labels.extend([label] * len(found))

        img_paths = np.asarray(img_paths, dtype=object)
        # explicit int64 works for empty lists and platforms whose default
        # NumPy int is not int64
        labels = torch.tensor(labels, dtype=torch.long)

        if len(ls_inds) < 1:
            return img_paths, labels, ls_inds
        return img_paths[ls_inds], labels[ls_inds], ls_inds
class Denormalize(object):
    """De-normalizes an image given its mean and standard deviation.

    Computes ``tensor * std + mean`` per channel, the inverse of
    ``(tensor - mean) / std``.

    Parameters
    ----------
    mean : sequence of float
        per-channel mean that was used for normalization
    std : sequence of float
        per-channel standard deviation that was used for normalization
    inplace : bool
        if True (default, the historical behavior) the input tensor is
        modified in place and returned; if False a new tensor is returned and
        the input is left untouched
    """

    def __init__(self, mean, std, inplace=True):
        # as_tensor: torch.Tensor(<int>) would allocate an uninitialised
        # tensor of that size instead of wrapping the value
        self.mean = torch.as_tensor(mean, dtype=torch.float32)
        self.std = torch.as_tensor(std, dtype=torch.float32)
        self.inplace = inplace

    def _align(self, stats, tensor):
        """Reshape a per-channel statistic so it broadcasts over *tensor*'s channels."""
        stats = stats.to(tensor.device).reshape(-1)
        n = stats.numel()
        if tensor.dim() >= 3 and tensor.shape[-3] == n:
            # channel-first (..., C, H, W): a plain (C,) tensor would wrongly
            # broadcast over W, so add singleton spatial axes
            return stats.reshape(n, 1, 1)
        # otherwise broadcast over the last axis, as the previous
        # implementation did (correct for channel-last (..., H, W, C) input)
        return stats

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): normalized image of size (C, H, W), or a batch
                (N, C, H, W). Channel-last (H, W, C) input is also accepted.

        Returns:
            Tensor: de-normalized image (the input tensor itself when
            ``inplace`` is True).
        """
        std = self._align(self.std, tensor)
        mean = self._align(self.mean, tensor)
        if self.inplace:
            return tensor.mul_(std).add_(mean)
        return tensor.mul(std).add(mean)