mzbsuite module
Functions and docstrings
mzb_classification_dataloader
- class mzbsuite.classification.mzb_classification_dataloader.Denormalize(mean, std)
De-normalizes an image given its mean and standard deviation.
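De-normalization inverts the usual per-channel normalization x_norm = (x - mean) / std by computing x = x_norm * std + mean. The sketch below shows only that arithmetic. It assumes the image is a list of channels, each a flat list of floats, although the class itself presumably works on tensors. `denormalize` is a hypothetical helper, not part of mzbsuite. The statistics are the common ImageNet values, used for illustration only:

```python
from typing import List, Sequence


def denormalize(
    channels: List[List[float]], mean: Sequence[float], std: Sequence[float]
) -> List[List[float]]:
    """Invert per-channel normalization: x = x_norm * std + mean (hypothetical helper)."""
    if not (len(channels) == len(mean) == len(std)):
        raise ValueError("expected one mean and one std value per channel")
    return [[v * s + m for v in ch] for ch, m, s in zip(channels, mean, std)]


# Illustrative ImageNet statistics; the values mzbsuite uses are not stated here
normalized = [[0.0, 1.0], [0.0, -1.0], [0.0, 2.0]]
restored = denormalize(normalized, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
```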
- class mzbsuite.classification.mzb_classification_dataloader.MZBLoader(dir_dict, ls_inds=[], learning_set='all', transforms=None, glob_pattern='*_rgb.*')
Class definition for the dataloader for the macrozoobenthos dataset.
- Parameters:
dir_dict (dict) – dictionary containing the paths to the folders of the dataset
ls_inds (list) – indices of images to be used for the learning set, optional
learning_set (str) – type of learning set to be used, optional, default: ‘all’
transforms (torchvision.transforms) – list of transformations to apply to blobs. Optional, default: None
glob_pattern (str) – glob pattern to use for finding images. Optional, default: ‘*_rgb.*’
- static prepare_data(dir_dict: dict, ls_inds: list = [], glob_pattern: str = '*_rgb.*') -> tuple
Prepare data for training and testing, returns image paths, labels and indices
- Parameters:
dir_dict (dict) – dictionary with keys as class names and values as paths to images
ls_inds (list) – list of indices to be used for training or testing
glob_pattern (str) – glob pattern to use for finding images
- Returns:
img_paths, labels, indices – list of paths to images, their labels, and the selected indices
- Return type:
tuple
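The parameter descriptions above suggest one approach: glob each class folder in dir_dict and label every match with the position of its class. `prepare_data_sketch` is hypothetical. The label encoding, the sorting, and the handling of ls_inds are assumptions, not the method's confirmed behaviour:

```python
import tempfile
from pathlib import Path
from typing import Dict, List, Sequence, Tuple


def prepare_data_sketch(
    dir_dict: Dict[str, Path],
    ls_inds: Sequence[int] = (),
    glob_pattern: str = "*_rgb.*",
) -> Tuple[List[Path], List[int], List[int]]:
    """Hypothetical sketch: glob each class folder, label files by class position,
    and keep only the entries listed in ls_inds (all entries if it is empty)."""
    img_paths: List[Path] = []
    labels: List[int] = []
    for label, folder in enumerate(dir_dict.values()):
        for path in sorted(Path(folder).glob(glob_pattern)):
            img_paths.append(path)
            labels.append(label)
    indices = list(ls_inds) if ls_inds else list(range(len(img_paths)))
    return [img_paths[i] for i in indices], [labels[i] for i in indices], indices


# Usage with two made-up class folders in a temporary directory
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    dirs = {}
    for cls in ("class_a", "class_b"):
        (root / cls).mkdir()
        (root / cls / f"{cls}_01_rgb.png").touch()
        (root / cls / f"{cls}_01_mask.png").touch()  # not matched by "*_rgb.*"
        dirs[cls] = root / cls
    names = [p.name for p in prepare_data_sketch(dirs)[0]]
    subset_labels = prepare_data_sketch(dirs, ls_inds=[1])[1]
```

The sketch uses an immutable default (an empty tuple) for ls_inds, which sidesteps Python's shared-mutable-default pitfall.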
mzb_classification_pilmodel
- class mzbsuite.classification.mzb_classification_pilmodel.MZBModel(data_dir='data/learning_sets/', pretrained_network='resnet50', learning_rate=0.0001, batch_size=32, weight_decay=1e-08, num_workers_loader=4, step_size_decay=5, num_classes=8)
PyTorch Lightning class definition and model setup.
- Parameters:
data_dir (str) – path to the directory containing the training and validation sets
pretrained_network (str) – name of the pretrained network to use
learning_rate (float) – learning rate for the optimizer
batch_size (int) – batch size for the training and validation dataloaders
weight_decay (float) – weight decay for the optimizer
num_workers_loader (int) – number of workers for the dataloaders
step_size_decay (int) – number of epochs after which the learning rate is decayed
num_classes (int) – number of classes to classify
- predict_step(batch, batch_idx, dataloader_idx: int | None = None)
Step function called during predict(). By default, it calls forward(). Override to add any processing logic.
The predict_step() is used to scale inference on multiple devices.
To prevent an OOM error, it is possible to use the BasePredictionWriter callback to write the predictions to disk or a database after each batch or at the end of the epoch.
The BasePredictionWriter should be used with a spawn-based accelerator. This happens for Trainer(strategy="ddp_spawn") or training on 8 TPU cores with Trainer(accelerator="tpu", devices=8), as predictions won't be returned.
Example:

    from pytorch_lightning import LightningModule, Trainer

    class MyModel(LightningModule):
        def predict_step(self, batch, batch_idx, dataloader_idx=0):
            return self(batch)

    dm = ...
    model = MyModel()
    trainer = Trainer(accelerator="gpu", devices=2)
    predictions = trainer.predict(model, dm)
- Parameters:
batch – Current batch.
batch_idx – Index of current batch.
dataloader_idx – Index of the current dataloader.
- Returns:
Predicted output
mzb_skeletons_dataloader
- class mzbsuite.skeletons.mzb_skeletons_dataloader.MZBLoader_skels(im_folder, bo_folder, he_folder, ls_inds=[], learning_set='all', transforms=None)
Class definition for the dataloader for the macrozoobenthos skeletons dataset.
- Parameters:
im_folder (Path) – folder path of input images
bo_folder (Path) – folder path of body length skeleton masks
he_folder (Path) – folder path of head length skeleton masks
ls_inds (list) – indices of images to be used for the learning set, optional
learning_set (str) – type of learning set to be used, optional, default: ‘all’
transforms (torchvision.transforms) – transformations to apply to the images, optional, default: None
mzb_skeletons_pilmodel
- class mzbsuite.skeletons.mzb_skeletons_pilmodel.MZBModel_skels(data_dir='data/skel_segm/', pretrained_network='efficientnet-b2', learning_rate=0.0001, batch_size=32, weight_decay=1e-08, num_workers_loader=4, step_size_decay=5, num_classes=2)
PyTorch Lightning module for training the skeleton segmentation model.
- Parameters:
data_dir (str) – Path to the directory where the data is stored.
pretrained_network (str) – Name of the pretrained network to use.
learning_rate (float) – Learning rate for the optimizer.
batch_size (int) – Batch size for the dataloader.
weight_decay (float) – Weight decay for the optimizer.
num_workers_loader (int) – Number of workers for the dataloader.
step_size_decay (int) – Number of epochs after which the learning rate is decayed.
num_classes (int) – Number of classes to predict.
- predict_step(batch, batch_idx, dataloader_idx: int | None = None)
Custom prediction step applied to each batch; returns probabilities and labels.
mzb_skeletons_helpers
- class mzbsuite.skeletons.mzb_skeletons_helpers.Denormalize(mean, std)
Denormalize a tensor image with mean and standard deviation, for plotting purposes.
- mzbsuite.skeletons.mzb_skeletons_helpers.get_endpoints(skeleton: ndarray) -> List[Tuple[int, int]]
Given a skeletonised image, returns the coordinates of the endpoints of the skeleton.
- Parameters:
skeleton (numpy.ndarray) – The skeletonised image to detect the endpoints of
- Returns:
endpoints – List of 2-tuples (x,y) containing the endpoint coordinates
- Return type:
list
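A common way to find skeleton endpoints is to keep the foreground pixels that have exactly one foreground pixel among their 8 neighbours. The sketch below implements that rule on a nested-list binary image and returns (row, col) pairs. It illustrates the technique and is not necessarily how get_endpoints is implemented. `endpoints_sketch` and `count_foreground_neighbours` are hypothetical names:

```python
from typing import List, Sequence, Tuple


def count_foreground_neighbours(skel: Sequence[Sequence[int]], r: int, c: int) -> int:
    """Count non-zero pixels in the 8-neighbourhood of (r, c), ignoring out-of-bounds cells."""
    rows, cols = len(skel), len(skel[0])
    total = 0
    for dr in (-1, 0, 1):
        for dc in (-1, 0, 1):
            nr, nc = r + dr, c + dc
            if (dr or dc) and 0 <= nr < rows and 0 <= nc < cols and skel[nr][nc]:
                total += 1
    return total


def endpoints_sketch(skel: Sequence[Sequence[int]]) -> List[Tuple[int, int]]:
    """Foreground pixels with exactly one foreground neighbour are endpoints."""
    return [
        (r, c)
        for r, row in enumerate(skel)
        for c, value in enumerate(row)
        if value and count_foreground_neighbours(skel, r, c) == 1
    ]


# A short stroke: one diagonal step, then horizontal
skel = [
    [1, 0, 0, 0],
    [0, 1, 1, 1],
    [0, 0, 0, 0],
]
ends = endpoints_sketch(skel)
```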
- mzbsuite.skeletons.mzb_skeletons_helpers.get_intersections(skeleton: ndarray) -> List[Tuple[int, int]]
Given a skeletonised image, returns the coordinates of the intersections of the skeleton.
- Parameters:
skeleton (np.ndarray) – Binary image of the skeleton
- Returns:
intersections – List of 2-tuples (x,y) containing the intersection coordinates
- Return type:
list
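One standard way to locate junctions is the crossing number. Walk the 8 neighbours of a skeleton pixel in clockwise order and count the 0-to-1 transitions. A value of 3 or more marks a branching point, and a value of 1 marks an endpoint. In the "T" below, unlike a plain neighbour count, it does not flag the bar pixels that touch the stem diagonally. The sketch applies this swapped-in technique to a nested-list image and returns (row, col) pairs. `intersections_sketch` and `crossing_number` are hypothetical, and get_intersections may use a different rule:

```python
from typing import List, Sequence, Tuple

# Clockwise ring starting north: N, NE, E, SE, S, SW, W, NW as (row offset, col offset)
RING = [(-1, 0), (-1, 1), (0, 1), (1, 1), (1, 0), (1, -1), (0, -1), (-1, -1)]


def crossing_number(skel: Sequence[Sequence[int]], r: int, c: int) -> int:
    """Number of 0->1 transitions around (r, c); out-of-bounds neighbours count as 0."""
    rows, cols = len(skel), len(skel[0])
    ring = [
        1 if 0 <= r + dr < rows and 0 <= c + dc < cols and skel[r + dr][c + dc] else 0
        for dr, dc in RING
    ]
    return sum(1 for i in range(8) if ring[i] == 0 and ring[(i + 1) % 8] == 1)


def intersections_sketch(skel: Sequence[Sequence[int]]) -> List[Tuple[int, int]]:
    """Foreground pixels whose crossing number is at least 3."""
    return [
        (r, c)
        for r, row in enumerate(skel)
        for c, value in enumerate(row)
        if value and crossing_number(skel, r, c) >= 3
    ]


# A "T": the horizontal bar meets the vertical stem at (1, 2)
skel = [
    [0, 0, 0, 0, 0],
    [1, 1, 1, 1, 1],
    [0, 0, 1, 0, 0],
    [0, 0, 1, 0, 0],
]
junctions = intersections_sketch(skel)
```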
- mzbsuite.skeletons.mzb_skeletons_helpers.neighbours(x, y, image)
Returns the 8-neighbours of image point P1(x, y) in clockwise order.
- Parameters:
x (int) – x-coordinate of the point
y (int) – y-coordinate of the point
image (numpy.ndarray) – The image to find the neighbours of
- Returns:
_ – List of 8-neighbours of the point in the image
- Return type:
list
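The clockwise ordering can be sketched as follows. This version treats x as the row index and y as the column index. It starts at the pixel directly above P1, following the usual P2...P9 labelling from thinning algorithms. Both conventions are assumptions, and `neighbours_sketch` is hypothetical. It also returns 0 for neighbours outside the image, because with NumPy arrays an index of -1 silently wraps around to the opposite edge:

```python
from typing import List, Sequence

# Clockwise from north: P2 (N), P3 (NE), P4 (E), P5 (SE), P6 (S), P7 (SW), P8 (W), P9 (NW)
OFFSETS = [(-1, 0), (-1, 1), (0, 1), (1, 1), (1, 0), (1, -1), (0, -1), (-1, -1)]


def neighbours_sketch(x: int, y: int, image: Sequence[Sequence[int]]) -> List[int]:
    """8-neighbours of P1 = image[x][y] in clockwise order; out-of-bounds cells read as 0."""
    rows, cols = len(image), len(image[0])
    out = []
    for dx, dy in OFFSETS:
        i, j = x + dx, y + dy
        out.append(image[i][j] if 0 <= i < rows and 0 <= j < cols else 0)
    return out


# Values chosen so the clockwise order is easy to read off
image = [
    [1, 2, 3],
    [8, 0, 4],
    [7, 6, 5],
]
ring = neighbours_sketch(1, 1, image)
corner = neighbours_sketch(0, 0, image)
```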
- mzbsuite.skeletons.mzb_skeletons_helpers.paint_image(image: ndarray, mask: array, color: List[float]) -> ndarray
Given an input image, a binary mask indicating where to paint, and a color to use, returns a new image where the pixels within the mask are colored with the specified color.
- Parameters:
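image (np.ndarray) – Input image to paint.
mask (np.array) – Binary mask indicating where to paint.
color (List[float]) – List of 3 floats representing the RGB color to use.
- Returns:
rgb_fi – New image with painted-in mask.
- Return type:
np.ndarray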
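Painting amounts to replacing every pixel under the mask with the given color while leaving the input untouched. The following is a dependency-free sketch on nested lists of RGB triples, whereas paint_image itself takes NumPy arrays. `paint_image_sketch` is hypothetical:

```python
from typing import List, Sequence, Tuple

Pixel = Tuple[float, float, float]


def paint_image_sketch(
    image: Sequence[Sequence[Pixel]], mask: Sequence[Sequence[int]], color: Sequence[float]
) -> List[List[Pixel]]:
    """Return a new image where pixels with a truthy mask value take `color`."""
    if len(color) != 3:
        raise ValueError("color must have exactly 3 components (R, G, B)")
    paint = (float(color[0]), float(color[1]), float(color[2]))
    return [
        [paint if m else px for px, m in zip(img_row, mask_row)]
        for img_row, mask_row in zip(image, mask)
    ]


image = [[(0.0, 0.0, 0.0), (0.5, 0.5, 0.5)]]
mask = [[0, 1]]
painted = paint_image_sketch(image, mask, [1.0, 0.0, 0.0])
```

For an H x W x 3 NumPy image and an H x W mask, the same effect is usually written as out = image.copy(); out[mask.astype(bool)] = color.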
- mzbsuite.skeletons.mzb_skeletons_helpers.paint_image_tensor(image: Tensor, masks: Tensor, color: List[float]) -> Tensor
Given an input image, a binary mask indicating where to paint, and a color to use, returns a new image where the pixels within the mask are colored with the specified color.
- Parameters:
image (torch.Tensor) – Input image to paint.
masks (torch.Tensor) – Binary masks indicating where to paint.
color (List[float]) – List of 3 floats representing the RGB color to use.
- Returns:
rgb_body – New image with painted pixels.
- Return type:
torch.Tensor
- mzbsuite.skeletons.mzb_skeletons_helpers.segment_skel(skeleton, inter, conn=1)
Custom function to segment a skeletonised image into individual branches. Each branch gets a unique ID.
- Parameters:
skeleton (numpy.ndarray) – The skeletonised image to segment
inter (list) – List of 2-tuples (x,y) containing the intersection coordinates, as returned by the function get_intersections
conn (int) – Connectivity of the skeleton. 1 for 4-connectivity, 2 for 8-connectivity
- Returns:
skel_labels (numpy.ndarray) – The labelled skeleton image
edge_attributes (dict) – Dictionary containing the attributes of each edge (branch) (for now, its size in pixels)
skprops (dict) – Dictionary containing the skimage.regionprops of each branch
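A simple way to split a skeleton into branches is to delete the intersection pixels, then label the remaining pixels as connected components and record each component's size. The sketch below does this with a breadth-first flood fill on a nested-list image, using (row, col) intersection coordinates. `segment_skel_sketch` is hypothetical and is not necessarily how segment_skel works. It returns only the label image and a {branch_id: size_in_pixels} dict, with no regionprops:

```python
from collections import deque
from typing import Dict, List, Sequence, Tuple


def segment_skel_sketch(
    skeleton: Sequence[Sequence[int]], inter: Sequence[Tuple[int, int]], conn: int = 1
) -> Tuple[List[List[int]], Dict[int, int]]:
    """Label skeleton branches after deleting intersection pixels.

    conn=1 uses 4-connectivity and conn=2 uses 8-connectivity. Returns the label
    image (0 = background or removed) and a {branch_id: size_in_pixels} dict.
    """
    rows, cols = len(skeleton), len(skeleton[0])
    removed = set(inter)
    if conn == 1:
        offsets = [(-1, 0), (1, 0), (0, -1), (0, 1)]
    else:
        offsets = [(dr, dc) for dr in (-1, 0, 1) for dc in (-1, 0, 1) if dr or dc]

    def is_free(r: int, c: int) -> bool:
        # Inside the image, foreground, not an intersection, and not yet labelled
        return (
            0 <= r < rows and 0 <= c < cols
            and bool(skeleton[r][c]) and (r, c) not in removed and labels[r][c] == 0
        )

    labels = [[0] * cols for _ in range(rows)]
    sizes: Dict[int, int] = {}
    next_id = 0
    for r in range(rows):
        for c in range(cols):
            if not is_free(r, c):
                continue
            next_id += 1
            labels[r][c] = next_id
            queue = deque([(r, c)])
            size = 0
            while queue:  # breadth-first flood fill of one branch
                cr, cc = queue.popleft()
                size += 1
                for dr, dc in offsets:
                    nr, nc = cr + dr, cc + dc
                    if is_free(nr, nc):
                        labels[nr][nc] = next_id
                        queue.append((nr, nc))
            sizes[next_id] = size
    return labels, sizes


# A "T" whose junction pixel is (1, 2)
skel = [
    [0, 0, 0, 0, 0],
    [1, 1, 1, 1, 1],
    [0, 0, 1, 0, 0],
    [0, 0, 1, 0, 0],
]
labels, sizes = segment_skel_sketch(skel, inter=[(1, 2)], conn=1)
```

The example uses 4-connectivity for a reason. With conn=2, branches that touch diagonally right next to a junction stay merged after a single pixel is removed, as the bar and stem of this "T" do. Removing a small neighbourhood around each intersection is a common workaround.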
- mzbsuite.skeletons.mzb_skeletons_helpers.traverse_graph(graph: dict, init: int, end_nodes: List[int], debug: bool = False) -> List[List[int]]
Function to traverse a graph from a starting node to a list of end nodes, and return all possible paths as a list of lists.
- Parameters:
graph (dict) – The graph to traverse
init (int) – The starting node ID
end_nodes (list) – List of end nodes
debug (bool) – Whether to print debug information
- Returns:
all_paths (list) – List of lists containing all possible paths from init to end_nodes
TODO (maybe):
* Make it work for graphs with multiple paths between nodes, and ensure that a subset of paths can be visited multiple times.
* Write a test for it.
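Enumerating every path from init to the end nodes can be done with a depth-first search. The search extends partial paths and never revisits a node already on the current path. The sketch assumes that `graph` maps each node ID to a list of adjacent node IDs and that a path stops at the first end node it reaches. `traverse_graph_sketch` is hypothetical. Because it tracks visited nodes per path rather than globally, it returns several distinct paths when more than one route joins two nodes:

```python
from typing import Dict, List


def traverse_graph_sketch(
    graph: Dict[int, List[int]], init: int, end_nodes: List[int]
) -> List[List[int]]:
    """Return every simple path from `init` that stops at the first end node reached."""
    targets = set(end_nodes)
    all_paths: List[List[int]] = []
    stack = [[init]]  # each entry is a partial path; its last element is the current node
    while stack:
        path = stack.pop()
        node = path[-1]
        if node in targets and node != init:
            all_paths.append(path)
            continue
        for nxt in graph.get(node, []):
            if nxt not in path:  # never revisit a node already on this path
                stack.append(path + [nxt])
    return all_paths


# Two routes (via 3 or via 4) lead from node 1 to end node 6
graph = {1: [2], 2: [1, 3, 4], 3: [2, 5], 4: [2, 5], 5: [3, 4, 6], 6: [5]}
paths = traverse_graph_sketch(graph, init=1, end_nodes=[6])
```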
mzbsuite.utils
Module containing utility functions for mzbsuite. © 2023, M. Volpi, Swiss Data Science Center
- class mzbsuite.utils.SaveLogCallback(model_folder)
Callback to save the log of the training
TODO: will need to be updated to save the log of the training in more detail and in a more structured way
- class mzbsuite.utils.cfg_to_arguments(args)
Converts a dictionary to an object that can extend or stand in for the argparse namespace. The __init__ method iterates over the dictionary and adds each key as an attribute of the object. The input is a dictionary; the output is an object that mimics the argparse object.
Example
    cfg = {"a": 1, "b": 2}
    args = cfg_to_arguments(cfg)
    print(args.a)  # 1
    print(args.b)  # 2
cfg can come from a YAML file, a JSON file, or a plain Python dictionary, whichever you prefer.
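The dictionary-to-object conversion described above takes only a few lines. The standard library's argparse.Namespace and types.SimpleNamespace also accept keyword arguments and give the same attribute access. `CfgToArgumentsSketch` is a hypothetical stand-in for the class, and the JSON string stands in for a configuration file:

```python
import argparse
import json
import types


class CfgToArgumentsSketch:
    """Expose each key of a config dictionary as an attribute (hypothetical stand-in)."""

    def __init__(self, args: dict):
        for key, value in args.items():
            setattr(self, key, value)


cfg = json.loads('{"a": 1, "b": 2}')  # e.g. the contents of a JSON config file
args = CfgToArgumentsSketch(cfg)
ns = argparse.Namespace(**cfg)  # the kind of object ArgumentParser.parse_args() returns
simple = types.SimpleNamespace(**cfg)
```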
- mzbsuite.utils.find_checkpoints(dirs=PosixPath('lightning_logs'), version=None, log='val')
Find the checkpoints for a given log
- Parameters:
dirs (Path (default: Path("lightning_logs"))) – path to the lightning_logs folder
version (str (default: None)) – version of the log to use
- Returns:
chkp – list of paths to checkpoints
- Return type:
str
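With Lightning's default logger and checkpoint callback, checkpoints end up under lightning_logs/version_<N>/checkpoints/*.ckpt. The following sketches a lookup over that layout. It rests on guesses, not documented behaviour: that version names a version_<N> folder, that the newest version is used when version is None, and that log filters checkpoint filenames by substring. `find_checkpoints_sketch` and the filenames are hypothetical:

```python
import tempfile
from pathlib import Path
from typing import List, Optional


def find_checkpoints_sketch(
    dirs: Path = Path("lightning_logs"), version: Optional[str] = None, log: str = "val"
) -> List[Path]:
    """Hypothetical lookup of *.ckpt files whose name contains `log`."""
    if version is None:
        # Assumption: pick the highest-numbered version_<N> folder
        numbered = [p for p in dirs.glob("version_*") if p.name.split("_")[-1].isdigit()]
        if not numbered:
            return []
        version_dir = max(numbered, key=lambda p: int(p.name.split("_")[-1]))
    else:
        version_dir = dirs / version
    return sorted(p for p in (version_dir / "checkpoints").glob("*.ckpt") if log in p.name)


# Usage on a made-up directory tree
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp) / "lightning_logs"
    for v in (0, 1):
        ckpt_dir = root / f"version_{v}" / "checkpoints"
        ckpt_dir.mkdir(parents=True)
        (ckpt_dir / "epoch=4-val_loss=0.31.ckpt").touch()
        (ckpt_dir / "last.ckpt").touch()
    found = [p.relative_to(root).as_posix() for p in find_checkpoints_sketch(root)]
    pinned = [p.name for p in find_checkpoints_sketch(root, version="version_0", log="last")]
```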
- mzbsuite.utils.noneparse(value)
Helper function to parse None values from YAML files
- Parameters:
value (string) – string to be parsed
- Returns:
value – parsed string
- Return type:
string or None
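In YAML, an unquoted None is read as the string "None" rather than as a null, since YAML spells null as null or ~. Command-line arguments likewise always arrive as strings. The sketch below assumes that only the case-insensitive string "none" should map to None; the real helper may accept other spellings. `noneparse_sketch` is hypothetical. Passing it as an argparse `type` applies it to each parsed value:

```python
import argparse
from typing import Optional


def noneparse_sketch(value: str) -> Optional[str]:
    """Map the string 'None' (any case, surrounding whitespace ignored) to None."""
    return None if value.strip().lower() == "none" else value


parser = argparse.ArgumentParser()
parser.add_argument("--version", type=noneparse_sketch, default=None)
parsed_none = parser.parse_args(["--version", "None"]).version
parsed_value = parser.parse_args(["--version", "version_3"]).version
```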
- mzbsuite.utils.read_pretrained_model(architecture, n_class)
Helper function to load models compactly from the PyTorch model zoo and prepare them for Hummingbird fine-tuning.
- Parameters:
architecture (str) – name of the model to load
n_class (int) – number of classes to finetune the model for
- Returns:
model – model with last layer replaced with a linear layer with n_class outputs
- Return type:
PyTorch model
- mzbsuite.utils.regression_report(y_true, y_pred, PRINT=True)
Helper function to print regression metrics. Taken and adapted from https://github.com/scikit-learn/scikit-learn/issues/18454#issue-708338254
- Parameters:
y_true (np.array) – ground truth values
y_pred (np.array) – predicted values
PRINT (bool) – whether to print the metrics or not
- Returns:
metrics – list of tuples with the name of the metric and its value
- Return type:
list
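The following is a dependency-free sketch of such a report. It computes a few standard metrics: mean absolute error, mean squared error and its root, maximum absolute error, and the coefficient of determination R². This documentation does not state which metrics the actual helper reports, so this selection is an assumption, as is the name `regression_report_sketch`. Like the documented function, it returns a list of (name, value) tuples:

```python
import math
from typing import List, Sequence, Tuple


def regression_report_sketch(
    y_true: Sequence[float], y_pred: Sequence[float], PRINT: bool = True
) -> List[Tuple[str, float]]:
    """Compute a few common regression metrics and optionally print them."""
    n = len(y_true)
    if n == 0 or n != len(y_pred):
        raise ValueError("y_true and y_pred must be non-empty and of equal length")
    errors = [t - p for t, p in zip(y_true, y_pred)]
    ss_res = sum(e * e for e in errors)
    mean_true = sum(y_true) / n
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    metrics = [
        ("mean absolute error", sum(abs(e) for e in errors) / n),
        ("mean squared error", ss_res / n),
        ("root mean squared error", math.sqrt(ss_res / n)),
        ("max error", max(abs(e) for e in errors)),
        # R² is undefined when y_true is constant
        ("r2", 1 - ss_res / ss_tot if ss_tot > 0 else float("nan")),
    ]
    if PRINT:
        for name, value in metrics:
            print(f"{name:>24}: {value:.4f}")
    return metrics


metrics = regression_report_sketch([3, -0.5, 2, 7], [2.5, 0.0, 2, 8], PRINT=False)
```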