mzbsuite module


Functions and docstrings

mzb_classification_dataloader

class mzbsuite.classification.mzb_classification_dataloader.Denormalize(mean, std)

Reverts the normalization of an image, given the mean and standard deviation that were used to normalize it.
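De-normalization inverts x_norm = (x - mean) / std channel by channel, which is the operation torchvision's Normalize applies. The sketch below is a minimal NumPy version assuming a channel-first (C, H, W) array and per-channel statistics; the function name `denormalize` and the ImageNet-style mean/std values are illustrative assumptions, not this class's actual implementation.

```python
import numpy as np


def denormalize(img: np.ndarray, mean, std) -> np.ndarray:
    """Invert x_norm = (x - mean) / std for a (C, H, W) array, channel by channel."""
    # Reshape to (C, 1, 1) so each channel's statistics broadcast over H and W.
    mean = np.asarray(mean, dtype=img.dtype).reshape(-1, 1, 1)
    std = np.asarray(std, dtype=img.dtype).reshape(-1, 1, 1)
    return img * std + mean


# Example statistics (ImageNet values, used here only as an illustration).
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
original = np.random.rand(3, 4, 4)
normalized = (original - np.reshape(mean, (3, 1, 1))) / np.reshape(std, (3, 1, 1))
restored = denormalize(normalized, mean, std)
print(np.allclose(restored, original))  # True
```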

class mzbsuite.classification.mzb_classification_dataloader.MZBLoader(dir_dict, ls_inds=[], learning_set='all', transforms=None, glob_pattern='*_rgb.*')

Class definition for the dataloader for the macrozoobenthos dataset.

Parameters:
  • dir_dict (dict) – dictionary containing the paths to the folders of the dataset

  • ls_inds (list) – indices of images to be used for the learning set, optional

  • learning_set (str) – type of learning set to be used, optional, default: ‘all’

  • transforms (torchvision.transforms) – list of transformations to apply to blobs. Optional, default: None

  • glob_pattern (str) – glob pattern to use for finding images. Optional, default: ‘*_rgb.*’

static prepare_data(dir_dict: dict, ls_inds: list = [], glob_pattern: str = '*_rgb.*') → tuple

Prepares data for training and testing; returns image paths, labels and indices.

Parameters:
  • dir_dict (dict) – dictionary with keys as class names and values as paths to images

  • ls_inds (list) – list of indices to be used for training or testing

  • glob_pattern (str) – glob pattern to use for finding images

Returns:

img_paths – list of paths to images

Return type:

list
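The collection step described for prepare_data can be sketched with pathlib, assuming (as the parameter description says) that dir_dict maps class names to image folders. The helper name `collect_images`, the sorted-class labelling scheme, and the omission of ls_inds subsetting are assumptions for illustration only.

```python
from pathlib import Path
from typing import Dict, List, Tuple


def collect_images(dir_dict: Dict[str, str], glob_pattern: str = "*_rgb.*") -> Tuple[List[Path], List[int]]:
    """Gather image paths per class folder and assign an integer label to each."""
    img_paths: List[Path] = []
    labels: List[int] = []
    # Sort class names so label indices are reproducible across runs.
    for label, class_name in enumerate(sorted(dir_dict)):
        matches = sorted(Path(dir_dict[class_name]).glob(glob_pattern))
        img_paths.extend(matches)
        labels.extend([label] * len(matches))
    return img_paths, labels
```

With the default pattern, a file such as `x_1_rgb.png` is picked up while a companion `x_1_mask.png` in the same folder is ignored.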

mzb_classification_pilmodel

class mzbsuite.classification.mzb_classification_pilmodel.MZBModel(data_dir='data/learning_sets/', pretrained_network='resnet50', learning_rate=0.0001, batch_size=32, weight_decay=1e-08, num_workers_loader=4, step_size_decay=5, num_classes=8)

PyTorch Lightning class definition and model setup.

Parameters:
  • data_dir (str) – path to the directory containing the training and validation sets

  • pretrained_network (str) – name of the pretrained network to use

  • learning_rate (float) – learning rate for the optimizer

  • batch_size (int) – batch size for the training and validation dataloaders

  • weight_decay (float) – weight decay for the optimizer

  • num_workers_loader (int) – number of workers for the dataloaders

  • step_size_decay (int) – number of epochs after which the learning rate is decayed

  • num_classes (int) – number of classes to classify

configure_optimizers()

Configures the optimizer and the learning-rate scheduler.
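step_size_decay is described as the number of epochs after which the learning rate is decayed. Assuming the scheduler behaves like torch.optim.lr_scheduler.StepLR (the decay factor is not documented here, so gamma=0.1, StepLR's default, is an assumption), the learning rate at a given epoch has the closed form sketched below; `step_decay_lr` is a hypothetical helper.

```python
def step_decay_lr(base_lr: float, epoch: int, step_size: int = 5, gamma: float = 0.1) -> float:
    """Learning rate after `epoch` epochs under a step-decay schedule."""
    # The rate is multiplied by gamma once every `step_size` completed epochs.
    return base_lr * gamma ** (epoch // step_size)


for epoch in (0, 4, 5, 10):
    print(epoch, step_decay_lr(1e-4, epoch))
```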

external_dataloader(data_dir, glob_pattern='*_rgb.*')

Dataloader for images in an external directory (data_dir) whose file names match glob_pattern.

forward(x)

Forward pass; returns unnormalized logits. Normalize them (e.g., with a softmax) when probabilities are needed.
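Because forward returns unnormalized logits, probabilities have to be computed afterwards. A numerically stable softmax is sketched below in NumPy, assuming logits of shape (batch, num_classes); in PyTorch the equivalent is torch.softmax(logits, dim=1). The helper name `softmax` here is illustrative.

```python
import numpy as np


def softmax(logits: np.ndarray, axis: int = -1) -> np.ndarray:
    """Convert unnormalized logits into class probabilities along `axis`."""
    # Subtracting the max before exp() avoids overflow without changing the result.
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


logits = np.array([[1.0, 2.0, 3.0], [1000.0, 1000.0, 1000.0]])
probs = softmax(logits, axis=1)
preds = probs.argmax(axis=1)  # predicted class index per sample
```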

predict_step(batch, batch_idx, dataloader_idx: int | None = None)

Step function called during predict(). By default, it calls forward(). Override to add any processing logic.

The predict_step() is used to scale inference on multi-devices.

To prevent an OOM error, it is possible to use BasePredictionWriter callback to write the predictions to disk or database after each batch or on epoch end.

The BasePredictionWriter should be used while using a spawn based accelerator. This happens for Trainer(strategy="ddp_spawn") or training on 8 TPU cores with Trainer(accelerator="tpu", devices=8) as predictions won’t be returned.

Example

class MyModel(LightningModule):

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        return self(batch)

dm = ...
model = MyModel()
trainer = Trainer(accelerator="gpu", devices=2)
predictions = trainer.predict(model, dm)

Parameters:
  • batch – Current batch.

  • batch_idx – Index of current batch.

  • dataloader_idx – Index of the current dataloader.

Returns:

Predicted output

test_step(batch, batch_idx, print_log: str = 'tst')

Test iteration, per batch. Reuses the validation step and returns its result.

train_dataloader(shuffle=True)

Training dataloader.

training_step(batch, batch_idx)

Training iteration, per batch.

val_dataloader()

Validation dataloader.

validation_step(batch, batch_idx, print_log: str = 'val')

Validation iteration, per batch.

mzb_skeletons_dataloader

class mzbsuite.skeletons.mzb_skeletons_dataloader.MZBLoader_skels(im_folder, bo_folder, he_folder, ls_inds=[], learning_set='all', transforms=None)

Class definition for the dataloader for the macrozoobenthos skeletons dataset.

Parameters:
  • im_folder (Path) – folder path of input images

  • bo_folder (Path) – folder path of body length skeleton masks

  • he_folder (Path) – folder path of head length skeleton masks

  • ls_inds (list) – indices of images to be used for the learning set, optional

  • learning_set (str) – type of learning set to be used, optional, default: ‘all’

  • transforms (torchvision.transforms) – optional, default: None

static prepare_data(im_folder, bo_folder=None, he_folder=None, ls_inds=[])

Prepares the data for the dataloader, loads it and returns it as numpy arrays.

mzb_skeletons_pilmodel

class mzbsuite.skeletons.mzb_skeletons_pilmodel.MZBModel_skels(data_dir='data/skel_segm/', pretrained_network='efficientnet-b2', learning_rate=0.0001, batch_size=32, weight_decay=1e-08, num_workers_loader=4, step_size_decay=5, num_classes=2)

PyTorch Lightning module for training the skeleton segmentation model.

Parameters:
  • data_dir (str) – Path to the directory where the data is stored.

  • pretrained_network (str) – Name of the pretrained network to use.

  • learning_rate (float) – Learning rate for the optimizer.

  • batch_size (int) – Batch size for the dataloader.

  • weight_decay (float) – Weight decay for the optimizer.

  • num_workers_loader (int) – Number of workers for the dataloader.

  • step_size_decay (int) – Number of epochs after which the learning rate is decayed.

  • num_classes (int) – Number of classes to predict.

configure_optimizers()

Defines the optimizer and the learning-rate scheduler.

external_dataloader(data_dir)

Defines a custom test dataloader for the images in data_dir.

forward(x)

Forward pass of the model; returns logits.

predict_step(batch, batch_idx, dataloader_idx: int | None = None)

Custom prediction iteration, per batch; returns probabilities and labels.
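Turning segmentation logits into per-pixel probabilities and labels can be sketched as follows, assuming the model emits (N, C, H, W) logits for mutually exclusive classes and that a softmax over the class axis is used. Both points are assumptions: the actual predict_step may use a sigmoid or a threshold instead, and `probs_and_labels` is a hypothetical name.

```python
import numpy as np


def probs_and_labels(logits: np.ndarray):
    """Per-pixel class probabilities and hard labels from (N, C, H, W) logits."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # stabilize exp()
    exp = np.exp(shifted)
    probs = exp / exp.sum(axis=1, keepdims=True)
    # For every pixel, keep the index of the most probable class.
    labels = probs.argmax(axis=1)
    return probs, labels


logits = np.zeros((1, 2, 2, 2))
logits[0, 1, 0, 0] = 5.0  # strong evidence for class 1 at pixel (0, 0)
probs, labels = probs_and_labels(logits)
```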

set_learning_splits()

Sets the learning splits for training and validation.

test_step(batch, batch_idx, print_log: str = 'tst')

Test iteration, per batch.

train_dataloader(shuffle=True)

Defines the training dataloader.

train_ts_augm_dataloader()

Defines a dataloader over the training data that applies test-time data augmentation.
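A common form of test-time augmentation averages predictions over flipped copies of the input, flipping each prediction back so pixels line up with the original. The NumPy sketch below shows that idea with `predict` standing in for the model; which augmentations train_ts_augm_dataloader actually applies is not stated here, so the flip set and the helper name `predict_with_tta` are assumptions.

```python
from typing import Callable

import numpy as np


def predict_with_tta(predict: Callable[[np.ndarray], np.ndarray], image: np.ndarray) -> np.ndarray:
    """Average per-pixel predictions over identity, vertical, horizontal and double flips."""
    outputs = []
    for axes in [(), (0,), (1,), (0, 1)]:
        augmented = np.flip(image, axis=axes) if axes else image
        pred = predict(augmented)
        # Undo the flip on the prediction so it is aligned with the input image.
        outputs.append(np.flip(pred, axis=axes) if axes else pred)
    return np.mean(outputs, axis=0)


image = np.arange(12.0).reshape(3, 4)
result = predict_with_tta(lambda x: 2 * x, image)
```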

training_step(batch, batch_idx)

Training iteration, per batch.

tst_dataloader()

Defines a custom test dataloader.

val_dataloader()

Defines a custom validation dataloader.

validation_step(batch, batch_idx)

Validation iteration, per batch.

mzb_skeletons_helpers

class mzbsuite.skeletons.mzb_skeletons_helpers.Denormalize(mean, std)

Denormalize a tensor image with mean and standard deviation, for plotting purposes.

mzbsuite.skeletons.mzb_skeletons_helpers.get_endpoints(skeleton: ndarray) → List[Tuple[int, int]]

Given a skeletonised image, returns the coordinates of the endpoints of the skeleton.

Parameters:

skeleton (numpy.ndarray) – The skeletonised image to detect the endpoints of

Returns:

endpoints – List of 2-tuples (x,y) containing the endpoint coordinates

Return type:

list
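One standard way to find skeleton endpoints is to count each skeleton pixel's 8-connected neighbours and keep those with exactly one. The NumPy sketch below does this with a padded 3x3 window sum; it is an illustration of the technique, not necessarily the rule get_endpoints uses, and it returns (row, col) pairs, whereas the library's (x,y) axis convention may differ. `skeleton_endpoints` is a hypothetical name.

```python
from typing import List, Tuple

import numpy as np


def skeleton_endpoints(skeleton: np.ndarray) -> List[Tuple[int, int]]:
    """Return (row, col) of skeleton pixels with exactly one 8-connected neighbour."""
    skel = (np.asarray(skeleton) > 0).astype(int)
    h, w = skel.shape
    padded = np.pad(skel, 1)  # zero border so edge pixels get a full 3x3 window
    # Sum of the 3x3 window around every pixel, minus the pixel itself.
    counts = sum(
        padded[1 + dr : 1 + dr + h, 1 + dc : 1 + dc + w]
        for dr in (-1, 0, 1)
        for dc in (-1, 0, 1)
    ) - skel
    rows, cols = np.nonzero((skel == 1) & (counts == 1))
    return list(zip(rows.tolist(), cols.tolist()))


line = np.zeros((5, 5), dtype=int)
line[2, 1:4] = 1  # a horizontal 3-pixel segment
```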

mzbsuite.skeletons.mzb_skeletons_helpers.get_intersections(skeleton: ndarray) → List[Tuple[int, int]]

Given a skeletonised image, returns the coordinates of the intersections of the skeleton.

Parameters:

skeleton (np.ndarray) – Binary image of the skeleton

Returns:

intersections – List of 2-tuples (x,y) containing the intersection coordinates

Return type:

list
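A simplified heuristic for intersections marks skeleton pixels with three or more 8-connected neighbours. This is a sketch under that assumption, not the rule get_intersections necessarily implements: at crossings made of 4-connected strokes, several adjacent pixels can exceed the threshold, which is why template-based detectors are often preferred. `skeleton_junctions` is a hypothetical name and the output is (row, col).

```python
from typing import List, Tuple

import numpy as np


def skeleton_junctions(skeleton: np.ndarray) -> List[Tuple[int, int]]:
    """Return (row, col) of skeleton pixels with at least three 8-connected neighbours."""
    skel = (np.asarray(skeleton) > 0).astype(int)
    h, w = skel.shape
    padded = np.pad(skel, 1)
    # Neighbour count = 3x3 window sum minus the centre pixel.
    counts = sum(
        padded[1 + dr : 1 + dr + h, 1 + dc : 1 + dc + w]
        for dr in (-1, 0, 1)
        for dc in (-1, 0, 1)
    ) - skel
    rows, cols = np.nonzero((skel == 1) & (counts >= 3))
    return list(zip(rows.tolist(), cols.tolist()))


# A "Y": one vertical branch and two diagonal branches meeting at (3, 3).
y_shape = np.zeros((7, 7), dtype=int)
for r, c in [(0, 3), (1, 3), (2, 3), (3, 3), (4, 2), (5, 1), (4, 4), (5, 5)]:
    y_shape[r, c] = 1
```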

mzbsuite.skeletons.mzb_skeletons_helpers.neighbours(x, y, image)

Returns the 8-neighbours of image point P1(x,y), in clockwise order.

Parameters:
  • x (int) – x-coordinate of the point

  • y (int) – y-coordinate of the point

  • image (numpy.ndarray) – The image to find the neighbours of

Returns:

_ – List of 8-neighbours of the point in the image

Return type:

list
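The P1/8-neighbour wording matches the convention of the Zhang-Suen thinning algorithm, where the neighbours P2..P9 are listed clockwise starting from the pixel above P1. The sketch below assumes that convention, with x as the row index and y as the column index; both the ordering and the axis interpretation are assumptions, and `clockwise_neighbours` is a hypothetical name. It does no bounds checking, so P1 must not lie on the image border.

```python
from typing import List

import numpy as np


def clockwise_neighbours(x: int, y: int, image: np.ndarray) -> List[int]:
    """Neighbours P2..P9 of P1 = image[x][y], clockwise from the pixel above."""
    img = image
    x_1, y_1, x1, y1 = x - 1, y - 1, x + 1, y + 1
    return [
        img[x_1][y], img[x_1][y1], img[x][y1], img[x1][y1],  # P2 (up), P3, P4 (right), P5
        img[x1][y], img[x1][y_1], img[x][y_1], img[x_1][y_1],  # P6 (down), P7, P8 (left), P9
    ]


grid = np.arange(9).reshape(3, 3)
ring = clockwise_neighbours(1, 1, grid)
```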

mzbsuite.skeletons.mzb_skeletons_helpers.paint_image(image: ndarray, mask: array, color: List[float]) → ndarray

Given an input image, a binary mask indicating where to paint, and a color to use, returns a new image where the pixels within the mask are colored with the specified color.

Parameters:
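  • image (np.ndarray) – Input image to paint.

  • mask (np.array) – Binary mask indicating where to paint.

  • color (List[float]) – Color to use for the pixels within the mask.

Returns:

rgb_fi – New image with painted-in mask.

Return type:

np.ndarray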
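Painting a mask into an image amounts to boolean indexing on a copy of the array. The sketch assumes an (H, W, 3) image, an (H, W) mask and plain overwriting of the masked pixels; the real paint_image might blend colours instead, and `paint` is a hypothetical name.

```python
from typing import List

import numpy as np


def paint(image: np.ndarray, mask: np.ndarray, color: List[float]) -> np.ndarray:
    """Return a copy of an (H, W, 3) image with the masked pixels set to `color`."""
    painted = image.copy()  # leave the caller's image untouched
    # A 2D boolean mask selects whole pixels; the 3-element color broadcasts over them.
    painted[mask.astype(bool)] = color
    return painted


img = np.zeros((2, 2, 3))
mask = np.array([[1, 0], [0, 0]])
out = paint(img, mask, [1.0, 0.0, 0.0])
```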

mzbsuite.skeletons.mzb_skeletons_helpers.paint_image_tensor(image: Tensor, masks: Tensor, color: List[float]) → Tensor

Given an input image, a binary mask indicating where to paint, and a color to use, returns a new image where the pixels within the mask are colored with the specified color.

Parameters:
  • image (torch.Tensor) – Input image to paint.

  • masks (torch.Tensor) – Binary mask indicating where to paint.

  • color (List[float]) – List of 3 floats representing the RGB color to use.

Returns:

rgb_body – New image with painted pixels.

Return type:

torch.Tensor

mzbsuite.skeletons.mzb_skeletons_helpers.segment_skel(skeleton, inter, conn=1)

Custom function to segment a skeletonised image into individual branches. Each branch gets a unique ID.

Parameters:
  • skeleton (numpy.ndarray) – The skeletonised image to segment

  • inter (list) – List of 2-tuples (x,y) containing the intersection coordinates, as returned by the function get_intersections

  • conn (int) – Connectivity of the skeleton. 1 for 4-connectivity, 2 for 8-connectivity

Returns:

  • skel_labels (numpy.ndarray) – The labelled skeleton image

  • edge_attributes (dict) – Dictionary containing the attributes of each edge (branch) (for now, its size in pixels)

  • skprops (dict) – Dictionary containing the skimage.regionprops of each branch
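The core of branch segmentation is: remove the intersection pixels, then label the remaining connected components so each branch gets its own ID. The sketch below does this with a breadth-first flood fill and returns branch sizes in pixels; it omits the regionprops output, treats inter as (row, col) pairs (an assumption), and `segment_branches` is a hypothetical name. skimage.measure.label with connectivity=conn would be the usual library alternative for the labelling step.

```python
from collections import deque
from typing import Dict, List, Tuple

import numpy as np


def segment_branches(skeleton: np.ndarray, inter: List[Tuple[int, int]], conn: int = 1):
    """Label skeleton branches after cutting at intersections (conn=1: 4-, conn=2: 8-connectivity)."""
    skel = np.asarray(skeleton) > 0  # new array, so the caller's skeleton is not modified
    for r, c in inter:
        skel[r, c] = False  # removing junction pixels splits the skeleton into branches
    if conn == 1:
        offsets = [(-1, 0), (1, 0), (0, -1), (0, 1)]
    else:
        offsets = [(dr, dc) for dr in (-1, 0, 1) for dc in (-1, 0, 1) if (dr, dc) != (0, 0)]
    h, w = skel.shape
    labels = np.zeros((h, w), dtype=int)
    sizes: Dict[int, int] = {}
    next_id = 0
    for r0, c0 in zip(*np.nonzero(skel)):
        if labels[r0, c0]:
            continue  # already part of a labelled branch
        next_id += 1
        labels[r0, c0] = next_id
        queue = deque([(r0, c0)])
        size = 0
        while queue:
            r, c = queue.popleft()
            size += 1
            for dr, dc in offsets:
                rr, cc = r + dr, c + dc
                if 0 <= rr < h and 0 <= cc < w and skel[rr, cc] and not labels[rr, cc]:
                    labels[rr, cc] = next_id
                    queue.append((rr, cc))
        sizes[next_id] = size
    return labels, sizes


y_shape = np.zeros((7, 7), dtype=int)
for r, c in [(0, 3), (1, 3), (2, 3), (3, 3), (4, 2), (5, 1), (4, 4), (5, 5)]:
    y_shape[r, c] = 1
labels, sizes = segment_branches(y_shape, [(3, 3)], conn=2)
```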

mzbsuite.skeletons.mzb_skeletons_helpers.traverse_graph(graph: dict, init: int, end_nodes: List[int], debug: bool = False) → List[List[int]]

Function to traverse a graph from a starting node to a list of end nodes, and return all possible paths as a list of lists.

Parameters:
  • graph (dict) – The graph to traverse

  • init (int) – The starting node ID

  • end_nodes (list) – List of end nodes

  • debug (bool) – Whether to print debug information

Returns:

  • all_paths (list) – List of lists containing all possible paths from init to end_nodes

TODO (maybe):

  • Make it work for graphs with multiple paths between nodes and ensure that a subset of paths can be visited multiple times

  • Make a test for it
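Enumerating every path from a start node to a set of end nodes is typically done with depth-first search and backtracking. The sketch below assumes the graph is an adjacency dict mapping each node ID to a list of neighbouring node IDs, and it only returns simple paths (no node repeated), so it also handles several distinct paths between the same nodes. It is an illustration of the technique, not traverse_graph's implementation; `all_simple_paths` is a hypothetical name.

```python
from typing import Dict, List


def all_simple_paths(graph: Dict[int, List[int]], init: int, end_nodes: List[int]) -> List[List[int]]:
    """Enumerate simple paths from init to any node in end_nodes via DFS with backtracking."""
    targets = set(end_nodes)
    paths: List[List[int]] = []

    def dfs(node: int, path: List[int]) -> None:
        if node in targets and node != init:
            paths.append(list(path))  # store a copy; `path` keeps changing
        for nxt in graph.get(node, []):
            if nxt not in path:  # skip nodes already on the path to avoid cycles
                path.append(nxt)
                dfs(nxt, path)
                path.pop()

    dfs(init, [init])
    return paths


# A square: two distinct routes from node 0 to node 3.
graph = {0: [1, 2], 1: [0, 3], 2: [0, 3], 3: [1, 2]}
found = all_simple_paths(graph, 0, [3])
```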

mzbsuite.utils

Module containing utility functions for mzbsuite. © 2023, M. Volpi, Swiss Data Science Center

class mzbsuite.utils.SaveLogCallback(model_folder)

Callback to save the log of the training

TODO: will need to be updated to save the log of the training in more detail and in a more structured way

on_train_end(trainer, pl_module)

Save the end date of the training

class mzbsuite.utils.cfg_to_arguments(args)

This class converts a dictionary to an object, extending the argparser. In the __init__ method, each key of the dictionary is added as an attribute of the object. The input is a dictionary; the output is an object that mimics the argparse.Namespace object.

Example

from mzbsuite.utils import cfg_to_arguments

cfg = {'a': 1, 'b': 2}
args = cfg_to_arguments(cfg)
print(args.a)  # 1
print(args.b)  # 2

cfg can be loaded from a YAML file or a JSON file, or defined directly as a dictionary, whatever you prefer.
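The dictionary-to-attributes pattern can be sketched in a few lines with setattr; here the config comes from JSON (for YAML, PyYAML's yaml.safe_load returns a dict in the same way). The class name `DictToArgs` is hypothetical, and argparse.Namespace(**cfg) or types.SimpleNamespace(**cfg) are standard-library alternatives that give the same attribute access.

```python
import json


class DictToArgs:
    """Expose dictionary keys as attributes, mimicking an argparse.Namespace."""

    def __init__(self, cfg: dict):
        for key, value in cfg.items():
            setattr(self, key, value)  # cfg["a"] becomes args.a


cfg = json.loads('{"a": 1, "b": 2, "glob_pattern": "*_rgb.*"}')
args = DictToArgs(cfg)
print(args.a, args.b, args.glob_pattern)  # 1 2 *_rgb.*
```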

mzbsuite.utils.find_checkpoints(dirs=PosixPath('lightning_logs'), version=None, log='val')

Find the checkpoints for a given log

Parameters:
  • dirs (Path (default: Path("lightning_logs"))) – path to the lightning_logs folder

  • version (str (default: None)) – version of the log to use

Returns:

chkp – list of paths to checkpoints

Return type:

str
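With PyTorch Lightning's default logger, checkpoints land in lightning_logs/version_&lt;N&gt;/checkpoints/*.ckpt. The sketch below searches that layout with pathlib, assuming that version is a folder name such as "version_3" and that log is matched as a substring of the checkpoint file name; both are assumptions about find_checkpoints, and `find_ckpts` is a hypothetical name.

```python
from pathlib import Path
from typing import List, Optional


def find_ckpts(dirs: Path = Path("lightning_logs"), version: Optional[str] = None, log: str = "val") -> List[Path]:
    """Collect *.ckpt files under a Lightning log folder whose names contain `log`."""
    # Search one version folder if given, otherwise every version_* folder.
    pattern = f"{version}/checkpoints/*.ckpt" if version else "version_*/checkpoints/*.ckpt"
    return sorted(p for p in Path(dirs).glob(pattern) if log in p.name)
```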

mzbsuite.utils.noneparse(value)

Helper function to parse None values from YAML files

Parameters:

value (string) – string to be parsed

Returns:

value – parsed string

Return type:

string or None
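In PyYAML only null, ~ and an empty value load as None; a bare None in a YAML file, like any command-line value, arrives as the string "None". A helper of this kind is typically passed as an argparse type= converter. The sketch below assumes a case-insensitive match on "none"; `parse_none` is a hypothetical name, and the exact strings noneparse accepts may differ.

```python
import argparse
from typing import Optional


def parse_none(value: str) -> Optional[str]:
    """Return None for the literal string 'None' (any case), else the string unchanged."""
    return None if value.strip().lower() == "none" else value


parser = argparse.ArgumentParser()
# argparse calls parse_none on the raw command-line string.
parser.add_argument("--version", type=parse_none, default=None)
args = parser.parse_args(["--version", "None"])
print(args.version)  # None
```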

mzbsuite.utils.read_pretrained_model(architecture, n_class)

Helper function to compactly load models from the PyTorch model zoo and prepare them for Hummingbird finetuning.

Parameters:
  • architecture (str) – name of the model to load

  • n_class (int) – number of classes to finetune the model for

Returns:

model – model with last layer replaced with a linear layer with n_class outputs

Return type:

PyTorch model

mzbsuite.utils.regression_report(y_true, y_pred, PRINT=True)

Helper function to print regression metrics. Taken and adapted from https://github.com/scikit-learn/scikit-learn/issues/18454#issue-708338254

Parameters:
  • y_true (np.array) – ground truth values

  • y_pred (np.array) – predicted values

  • PRINT (bool) – whether to print the metrics or not

Returns:

metrics – list of tuples with the name of the metric and its value

Return type:

list
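A report that returns a list of (name, value) tuples can be sketched with NumPy. The choice of metrics here (MAE, MSE, RMSE and R²) is an assumption, since the exact set regression_report prints is not listed; `regression_metrics` is a hypothetical name, and R² is undefined when y_true is constant.

```python
import numpy as np


def regression_metrics(y_true, y_pred):
    """Return [(name, value), ...] for MAE, MSE, RMSE and R^2."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    err = y_pred - y_true
    mse = float(np.mean(err ** 2))
    ss_res = float(np.sum(err ** 2))  # residual sum of squares
    ss_tot = float(np.sum((y_true - y_true.mean()) ** 2))  # total sum of squares
    return [
        ("mae", float(np.mean(np.abs(err)))),
        ("mse", mse),
        ("rmse", mse ** 0.5),
        ("r2", 1.0 - ss_res / ss_tot),
    ]


for name, value in regression_metrics([1, 2, 3, 4], [1, 2, 3, 5]):
    print(f"{name}: {value:.3f}")
```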