# Source code for mzbsuite.skeletons.mzb_skeletons_dataloader

import torch
import numpy as np

# from torchvision import utils
from PIL import Image  # , ImageFilter, ImageDraw, ImageOps

from torch.utils.data import Dataset

from mzbsuite.skeletons.mzb_skeletons_helpers import Denormalize


class MZBLoader_skels(Dataset):
    """
    Dataset for the macrozoobenthos skeletons data.

    Each item is an RGB image together with a mask built from the body-length
    and head-length skeleton images, plus the index of the item.

    Parameters
    ----------
    im_folder : Path
        folder path of input images
    bo_folder : Path
        folder path of body length skeleton masks
    he_folder : Path
        folder path of head length skeleton masks
    ls_inds : list
        indices of images to be used for the learning set, optional.
        ``None`` or an empty sequence selects every image.
    learning_set : str
        type of learning set to be used, optional, default: 'all'.
        Only stored on the instance by this class.
    transforms : torchvision.transforms
        optional, default: None. When ``None`` the images are converted to
        plain ``(C, H, W)`` float tensors in ``[0, 1]`` without resizing,
        augmentation or normalisation.
    """

    def __init__(
        self,
        im_folder,
        bo_folder,
        he_folder,
        ls_inds=None,
        learning_set="all",
        transforms=None,
    ):
        self.transforms = transforms
        self.imsize = 224
        # ``None`` replaces the former shared mutable ``[]`` default.
        self.ls_inds = [] if ls_inds is None else ls_inds
        self.learning_set = learning_set

        self.im_folder = im_folder
        self.bo_folder = bo_folder
        self.he_folder = he_folder

        # Masks go through the same transform pipeline as the images; this
        # is applied to them afterwards, before thresholding at zero.
        self.denorm = Denormalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

        self.img_paths, self.mbo_paths, self.mhe_paths, self.inds = self.prepare_data(
            self.im_folder, self.bo_folder, self.he_folder, ls_inds=self.ls_inds
        )

    def __len__(self):
        """Return the number of images in the (sub)set."""
        return len(self.img_paths)

    @staticmethod
    def _reseed(seed):
        """Reset the global numpy and torch RNGs to ``seed``."""
        np.random.seed(seed)
        torch.manual_seed(seed)  # needed for torchvision 0.7

    def _image_to_tensor(self, pil_img, seed):
        """Turn a PIL image into a tensor, via ``self.transforms`` if set."""
        self._reseed(seed)
        if self.transforms is not None:
            return self.transforms(pil_img)
        # No pipeline configured: fall back to a bare (C, H, W) float tensor
        # instead of leaving the result undefined.
        return torch.from_numpy(np.array(pil_img)).permute(2, 0, 1).float().div(255.0)

    def _load_mask(self, path, seed):
        """Read one skeleton mask and binarise it into a ``long`` tensor."""
        mask_img = Image.open(path).convert("RGB")
        # Same seed as for the image, so random transforms line up.
        self._reseed(seed)
        if self.transforms is not None:
            mask = self.denorm(self.transforms(mask_img))
        else:
            mask = torch.from_numpy(np.array(mask_img)).permute(2, 0, 1).float()
        return torch.gt(mask, 0).long()

    def __getitem__(self, idx):
        """
        Return ``(tensor_image, tensor_mask, idx)`` for item ``idx``.

        ``tensor_mask`` is the body and head masks stacked along a new leading
        dimension (body first) when both mask sets are available; otherwise it
        is a zero tensor shaped like ``tensor_image[0, ...]``.

        Note that this reseeds the global numpy and torch random generators.
        """
        img = Image.open(self.img_paths[idx]).convert("RGB")

        # HACK: draw one seed per item and re-apply it before every transform
        # call so image and masks receive identical random augmentations.
        seed = np.random.randint(123456789)

        tensor_image = self._image_to_tensor(img, seed)

        has_body = len(self.mbo_paths) > 0
        has_head = len(self.mhe_paths) > 0

        # NOTE: a mask is still read when only one of the two mask sets
        # exists, as before, even though it is not part of the output then.
        if has_body:
            tensor_bmsk = self._load_mask(self.mbo_paths[idx], seed)
        if has_head:
            tensor_hmsk = self._load_mask(self.mhe_paths[idx], seed)

        if has_body and has_head:
            tensor_mask = torch.stack((tensor_bmsk, tensor_hmsk), dim=0)
        else:
            tensor_mask = torch.zeros_like(tensor_image[0, ...])

        return tensor_image, tensor_mask, idx

    @staticmethod
    def prepare_data(im_folder, bo_folder=None, he_folder=None, ls_inds=None):
        """
        Collect the sorted ``*.jpg`` file paths used by the dataset.

        Nothing is read from the files here; only their paths are listed.

        Parameters
        ----------
        im_folder : Path
            folder containing the input images
        bo_folder : Path
            folder containing the body skeleton masks, optional
        he_folder : Path
            folder containing the head skeleton masks, optional
        ls_inds : list
            indices selecting a subset of the sorted files, optional

        Returns
        -------
        tuple
            ``(images, mbody, mhead, ls_inds)``. ``images`` is a numpy array
            of paths. ``mbody`` / ``mhead`` are numpy arrays of paths, or an
            empty list when the folder is ``None`` or not a directory.
            ``ls_inds`` is returned as given (``[]`` when it was ``None``).
        """
        if ls_inds is None:
            ls_inds = []

        def _list_jpgs(folder):
            # A missing mask folder is allowed and simply yields no masks.
            if folder is None or not folder.is_dir():
                return []
            return np.asarray(sorted(folder.glob("*.jpg")))

        # One folder per kind of data; sorting keeps the three lists aligned
        # by file name.
        images = np.asarray(sorted(im_folder.glob("*.jpg")))
        mbody = _list_jpgs(bo_folder)
        mhead = _list_jpgs(he_folder)

        # ``len`` rather than truthiness: ``ls_inds`` may be a numpy array.
        if len(ls_inds) > 0:
            images = images[ls_inds]
            # Only subset mask lists that actually hold files; indexing the
            # empty placeholder used to raise.
            if len(mbody) > 0:
                mbody = mbody[ls_inds]
            if len(mhead) > 0:
                mhead = mhead[ls_inds]

        return images, mbody, mhead, ls_inds