Classification: finetune

In this notebook we illustrate how to re-train the models on user’s data. Specifically, we remap the last layer of the model to the desired classes, without modifying the model’s internal weights; this operation is called finetuning and is not as computationally intensive as re-training the full model. Regardless, this module greatly benefits from GPU compute, as long as the GPU(s) support CUDA and nvidia-smi is configured correctly.

This module uses two scripts, which must be executed in this order: classification/main_prepare_learning_sets.py, which prepares the data for training, and classification/main_classification_finetune.py, which finetunes the model.

The first step is to import the necessary libraries for main_prepare_learning_sets.py:

[6]:
import argparse
import shutil
import sys
import os
from pathlib import Path

import numpy as np
import pandas as pd
import yaml

from mzbsuite.utils import cfg_to_arguments

Next, we declare the running parameters for the script and load the configuration file:

[7]:
ROOT_DIR = Path.cwd()
MODEL = "convnext-small-vtest-1"
LSET_FOLD = f"{ROOT_DIR}/data/mzb_example_data"

# Run parameters for both scripts. Repository-relative paths are resolved against
# the parent of the current working directory (the repository root when the notebook
# is started from notebooks/ — TODO confirm for your setup).
arguments = {
    # Curated learning sets: one sub-folder per entry of the taxonomy "query" column.
    # Site-specific absolute path: adapt it to where your curated images live.
    "input_dir": "/data/shared/mzb-workflow/data/learning_sets/project_portable_flume/curated_learning_sets",
    "taxonomy_file": ROOT_DIR.parent.absolute() / "data/MZB_taxonomy.csv",
    # trn_set/ and val_set/ are created inside this folder.
    "output_dir": ROOT_DIR.parent.absolute() / "data/mzb_example_data/aggregated_learning_sets",
    # Checkpoints are written to (and resumed from) this folder.
    "save_model": ROOT_DIR.parent.absolute() / f"models/mzb-classification-models/{MODEL}",
    "config_file": ROOT_DIR.parent.absolute() / "configs/configuration_flume_datasets.yaml",
}

# safe_load only builds plain YAML types (dicts, lists, scalars), which is all a
# configuration file needs, and never instantiates arbitrary Python objects.
with open(arguments["config_file"], "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

# The data-preparation step does not use the GPU, so clear the configured GPU ids.
# (This only edits the config dict; the trainer cell below chooses its own devices.)
cfg["trcl_gpu_ids"] = None
cfg
[7]:
{'glob_random_seed': 222,
 'glob_root_folder': '/data/users/luca/mzb-workflow/mzb-workflow/',
 'glob_blobs_folder': '/data/users/luca/mzb-workflow/mzb-workflow/data/derived/blobs/',
 'glob_local_format': 'pdf',
 'model_logger': 'wandb',
 'impa_image_format': 'jpg',
 'impa_clip_areas': [2750, 4900],
 'impa_area_threshold': 5000,
 'impa_gaussian_blur': [21, 21],
 'impa_gaussian_blur_passes': 3,
 'impa_adaptive_threshold_block_size': 351,
 'impa_mask_postprocess_kernel': [11, 11],
 'impa_mask_postprocess_passes': 5,
 'impa_bounding_box_buffer': 200,
 'impa_save_clips_plus_features': True,
 'lset_class_cut': 'order',
 'lset_val_size': 0.1,
 'lset_taxonomy': '/data/users/luca/mzb-workflow/data/MZB_taxonomy.csv',
 'trcl_learning_rate': 0.001,
 'trcl_batch_size': 16,
 'trcl_weight_decay': 0,
 'trcl_step_size_decay': 5,
 'trcl_number_epochs': 10,
 'trcl_save_topk': 1,
 'trcl_num_classes': 8,
 'trcl_model_pretrarch': 'efficientnet-b2',
 'trcl_num_workers': 16,
 'trcl_wandb_project_name': 'mzb-classifiers',
 'trcl_logger': 'wandb',
 'trsk_learning_rate': 0.001,
 'trsk_batch_size': 32,
 'trsk_weight_decay': 0,
 'trsk_step_size_decay': 25,
 'trsk_number_epochs': 400,
 'trsk_save_topk': 1,
 'trsk_num_classes': 2,
 'trsk_model_pretrarch': 'mit_b2',
 'trsk_num_workers': 16,
 'trsk_wandb_project_name': 'mzb-skeletons',
 'trsk_logger': 'wandb',
 'infe_model_ckpt': 'last',
 'infe_num_classes': 8,
 'infe_image_glob': '*_rgb.jpg',
 'skel_class_exclude': 'errors',
 'skel_conv_rate': 131.6625,
 'skel_label_thickness': 3,
 'skel_label_buffer_on_preds': 25,
 'skel_label_clip_with_mask': False,
 'trcl_gpu_ids': None}

Convert these dictionaries into argparse-style objects, so that their values can be accessed as attributes (e.g. `cfg.glob_random_seed`):

[8]:
# Wrap both dictionaries with cfg_to_arguments so their entries can be read as
# attributes (args.output_dir, cfg.glob_random_seed, ...), like parsed CLI arguments.
args, cfg = cfg_to_arguments(arguments), cfg_to_arguments(cfg)
print(cfg)
{'glob_random_seed': 222, 'glob_root_folder': '/data/users/luca/mzb-workflow/mzb-workflow/', 'glob_blobs_folder': '/data/users/luca/mzb-workflow/mzb-workflow/data/derived/blobs/', 'glob_local_format': 'pdf', 'model_logger': 'wandb', 'impa_image_format': 'jpg', 'impa_clip_areas': [2750, 4900], 'impa_area_threshold': 5000, 'impa_gaussian_blur': [21, 21], 'impa_gaussian_blur_passes': 3, 'impa_adaptive_threshold_block_size': 351, 'impa_mask_postprocess_kernel': [11, 11], 'impa_mask_postprocess_passes': 5, 'impa_bounding_box_buffer': 200, 'impa_save_clips_plus_features': True, 'lset_class_cut': 'order', 'lset_val_size': 0.1, 'lset_taxonomy': '/data/users/luca/mzb-workflow/data/MZB_taxonomy.csv', 'trcl_learning_rate': 0.001, 'trcl_batch_size': 16, 'trcl_weight_decay': 0, 'trcl_step_size_decay': 5, 'trcl_number_epochs': 10, 'trcl_save_topk': 1, 'trcl_num_classes': 8, 'trcl_model_pretrarch': 'efficientnet-b2', 'trcl_num_workers': 16, 'trcl_wandb_project_name': 'mzb-classifiers', 'trcl_logger': 'wandb', 'trsk_learning_rate': 0.001, 'trsk_batch_size': 32, 'trsk_weight_decay': 0, 'trsk_step_size_decay': 25, 'trsk_number_epochs': 400, 'trsk_save_topk': 1, 'trsk_num_classes': 2, 'trsk_model_pretrarch': 'mit_b2', 'trsk_num_workers': 16, 'trsk_wandb_project_name': 'mzb-skeletons', 'trsk_logger': 'wandb', 'infe_model_ckpt': 'last', 'infe_num_classes': 8, 'infe_image_glob': '*_rgb.jpg', 'skel_class_exclude': 'errors', 'skel_conv_rate': 131.6625, 'skel_label_thickness': 3, 'skel_label_buffer_on_preds': 25, 'skel_label_clip_with_mask': False, 'trcl_gpu_ids': None}

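Next, we seed the random number generator, create the output directory if needed, and stop if it already contains `trn_set` or `val_set` subfolders, so that an existing learning set is never overwritten: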

[9]:
# Seed NumPy's global RNG so the random validation split further down is repeatable.
np.random.seed(cfg.glob_random_seed)

# Curated learning sets to read from, and the folder the aggregated sets are written to.
root_data = Path(args.input_dir)
outdir = Path(args.output_dir)
outdir.mkdir(parents=True, exist_ok=True)

# Training / validation sub-folders inside the output directory.
target_trn, target_val = (outdir / sub for sub in ("trn_set/", "val_set/"))

# Refuse to run when a split already exists in the output directory, so an earlier
# learning set is never overwritten; the user has to choose another output folder.
if any(folder.exists() for folder in (target_trn, target_val)):
    # ANSI escape codes: switch to red text, then reset to the default colour.
    raise ValueError(
        f"\033[91m Output directory {outdir} already exists. Please specify a different output directory.\033[0m"
    )

We now use the taxonomic rank specified by the lset_class_cut parameter in the configuration file to cut the provided phylogenetic tree, and reorganize the images into directories corresponding to this rank. See the documentation for further details.

[10]:

# Build the recoding map used to reorganize the learning set:
#   key   = original label (taxonomy "query" column; also the curated sub-folder name)
#   value = the taxon at the rank chosen by cfg.lset_class_cut, lower-cased
mzb_taxonomy = pd.read_csv(Path(args.taxonomy_file))

# Drop the index column that pandas typically adds when a CSV was saved with its index.
if "Unnamed: 0" in mzb_taxonomy.columns:
    mzb_taxonomy = mzb_taxonomy.drop(columns=["Unnamed: 0"])

# Fail early with a readable message instead of a bare KeyError further down.
if cfg.lset_class_cut not in mzb_taxonomy.columns:
    raise KeyError(
        f"lset_class_cut={cfg.lset_class_cut!r} is not a column of {args.taxonomy_file}; "
        f"available columns: {list(mzb_taxonomy.columns)}"
    )

# Forward-fill along each row so that a missing entry takes the last valid value to
# its left (assumes rank columns run from coarse to fine — confirm with the CSV).
mzb_taxonomy = mzb_taxonomy.ffill(axis=1)

recode_order = dict(
    zip(mzb_taxonomy["query"], mzb_taxonomy[cfg.lset_class_cut].str.lower())
)
print(f"Cutting phylogenetic tree at: {cfg.lset_class_cut}")
Cutting phylogenetic tree at: order

Now we copy the images over into the new folder structure according to the taxonomy:

[11]:
# Move files to target folders for all files in the curated learning set
for s_fo in recode_order:
    target_folder = target_trn / recode_order[s_fo]
    target_folder.mkdir(exist_ok=True, parents=True)

    for file in list((root_data / s_fo).glob("*")):
        shutil.copy(file, target_folder)

# move out the validation set
# make a small val set, 10% or 1 file, what is possible...
size = cfg.lset_val_size
trn_folds = [a.name for a in sorted(list(target_trn.glob("*")))]

Next, we move a random subset of each class from the training set into the validation set. Its size is the proportion of images given by the lset_val_size parameter in the configuration file, with at least one image per class. We recommend using at least 10% of the images of each class for validation.

[12]:

# Move a random subset of each class from trn_set/ to val_set/.
# The fraction comes from cfg.lset_val_size (it used to be hard-coded to 0.1). Files are
# drawn WITHOUT replacement: sampling with replacement picked some files twice, and the
# second shutil.move of an already-moved file failed.
for s_fo in trn_folds:
    target_folder = target_val / s_fo
    target_folder.mkdir(exist_ok=True, parents=True)

    # Sorted so that, with the seeded RNG, the split does not depend on the
    # filesystem's listing order.
    list_class = sorted((target_trn / s_fo).glob("*"))
    if not list_class:
        continue

    n_files = len(list_class)
    # At least one validation file per class...
    n_val_sam = max(1, int(np.ceil(cfg.lset_val_size * n_files)))
    # ...while keeping at least one training file whenever the class has more than one.
    n_val_sam = min(n_val_sam, max(n_files - 1, 1))

    for idx in np.random.choice(n_files, n_val_sam, replace=False):
        file = list_class[idx]
        try:
            shutil.move(str(file), target_folder)
        except OSError as err:  # shutil.Error is a subclass of OSError
            print(f"Could not move {file} into {target_folder}: {err}")
/data/users/luca/mzb-workflow/data/mzb_example_data/aggregated_learning_sets/trn_set/errors/32_hf2_plecoptera_01_clip_8_rgb.jpg into /data/users/luca/mzb-workflow/data/mzb_example_data/aggregated_learning_sets/val_set/errors
/data/users/luca/mzb-workflow/data/mzb_example_data/aggregated_learning_sets/trn_set/errors/32_ob_leuctridae_01_clip_4_rgb.jpg into /data/users/luca/mzb-workflow/data/mzb_example_data/aggregated_learning_sets/val_set/errors
/data/users/luca/mzb-workflow/data/mzb_example_data/aggregated_learning_sets/trn_set/plecoptera/32_bd_plecoptera_01_clip_2_rgb.jpg into /data/users/luca/mzb-workflow/data/mzb_example_data/aggregated_learning_sets/val_set/plecoptera

Now we have the training dataset ready for model training, with a training set and a validation set containing the same classes.

We move on to the model finetuning, using the script classification/main_classification_finetune.py. First, we import some additional libraries (PyTorch and PyTorch Lightning):

[13]:
import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
from pytorch_lightning.strategies.ddp import DDPStrategy

from mzbsuite.classification.mzb_classification_pilmodel import MZBModel
from mzbsuite.utils import cfg_to_arguments, SaveLogCallback

# Select MKL's GNU (OpenMP) threading layer for this process — presumably to avoid
# threading-layer conflicts between MKL and PyTorch's OpenMP runtime (confirm). This is
# a CPU threading setting; it does not enable or disable GPU computing.
os.environ.update(MKL_THREADING_LAYER="GNU")

Before launching the training, we define the callbacks that save model checkpoints and monitor training progress over time, and we build the model from the configuration:

[14]:
# Checkpoint callbacks
# Keep the best checkpoint(s) according to the validation loss.
best_val_cb = pl.callbacks.ModelCheckpoint(
    dirpath=args.save_model,
    filename="best-val-{epoch}-{step}-{val_loss:.1f}",
    monitor="val_loss",
    mode="min",
    save_top_k=cfg.trcl_save_topk,
)

# Save the latest weights every 50 training steps.
# NOTE: without a `monitor`, Lightning only accepts save_top_k in {-1, 0, 1}.
last_mod_cb = pl.callbacks.ModelCheckpoint(
    dirpath=args.save_model,
    filename="last-{step}",
    every_n_train_steps=50,
    save_top_k=cfg.trcl_save_topk,
)

# Progress bar refreshed every 5 batches.
pbar_cb = pl.callbacks.progress.TQDMProgressBar(refresh_rate=5)

# Project callback that writes a training log into the model folder.
trdatelog = SaveLogCallback(model_folder=args.save_model)

# Model built from the configuration. It is trained on the learning sets prepared
# above (args.output_dir, which holds trn_set/ and val_set/), not on the raw curated
# folders in args.input_dir. (Presumably MZBModel reads trn_set/ and val_set/ from
# data_dir — confirm in mzb_classification_pilmodel.)
model = MZBModel(
    data_dir=args.output_dir,
    pretrained_network=cfg.trcl_model_pretrarch,
    learning_rate=cfg.trcl_learning_rate,
    batch_size=cfg.trcl_batch_size,
    weight_decay=cfg.trcl_weight_decay,
    num_workers_loader=cfg.trcl_num_workers,
    step_size_decay=cfg.trcl_step_size_decay,
    num_classes=cfg.trcl_num_classes,
)

We now check whether a previously trained checkpoint is available in the model folder and, if so, load its weights. Training progress is logged with either Weights & Biases, which requires an account, or TensorBoard, which writes logs into the model folder. See the documentation for more details.

[15]:
# Resume from a checkpoint already present in the model folder, if there is one:
# prefer the newest "last-*" checkpoint, otherwise the newest "best-val-*" one.
ckpt_dir = Path(args.save_model)
if ckpt_dir.is_dir():
    # `verbose` is not among this notebook's arguments, so default to quiet.
    if getattr(args, "verbose", False):
        print(f"Loading model from {ckpt_dir}")

    last_ckpts = list(ckpt_dir.glob("last-*.ckpt"))
    best_ckpts = list(ckpt_dir.glob("best-val-epoch=*-step=*-val_loss=*.*.ckpt"))
    if last_ckpts:
        fmodel = max(last_ckpts, key=os.path.getmtime)
    elif best_ckpts:
        print("No last-* model in folder, loading best model")
        fmodel = max(best_ckpts, key=os.path.getmtime)
    else:
        # The folder can exist without checkpoints (e.g. TensorBoard logs written there).
        fmodel = None
        print(f"No checkpoint found in {ckpt_dir}, training the freshly built model")

    if fmodel is not None:
        # load_from_checkpoint is a classmethod returning a NEW model: call it on the
        # class (recent Lightning versions raise when it is called on an instance).
        model = MZBModel.load_from_checkpoint(fmodel)

# Name of the run shown by the logger.
name_run = f"classifier-{cfg.trcl_model_pretrarch}"
cbacks = [pbar_cb, best_val_cb, last_mod_cb, trdatelog]

# Logger: Weights & Biases or TensorBoard, chosen by cfg.trcl_logger.
if cfg.trcl_logger == "wandb":
    logger = WandbLogger(project=cfg.trcl_wandb_project_name, name=name_run)
    logger.watch(model, log="all")
elif cfg.trcl_logger == "tensorboard":
    logger = TensorBoardLogger(
        save_dir=args.save_model,
        name=name_run,
        log_graph=True,
    )
else:
    # Otherwise `logger` would be undefined and the trainer cell would fail later.
    raise ValueError(
        f"Unsupported trcl_logger {cfg.trcl_logger!r}; expected 'wandb' or 'tensorboard'"
    )
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: lpego (biodetect). Use `wandb login --relogin` to force relogin
wandb version 0.16.0 is available! To upgrade, please run: $ pip install wandb --upgrade
Tracking run with wandb version 0.15.4
Run data is saved locally in ./wandb/run-20231111_161213-1u2u0o5h
wandb: logging graph, to disable use `wandb.watch(log_graph=False)`

We are now finally ready to train our model!

[16]:

# Instantiate the trainer and train.
#
# A plain DDPStrategy() uses a process launcher that Lightning rejects inside a
# notebook ("... is not compatible with an interactive environment"). With several
# GPUs, use "ddp_notebook", which forks the workers and is notebook-compatible. With
# one GPU or on CPU, no distributed strategy is needed.
# CUDAAccelerator.auto_device_count() counts GPUs without initializing CUDA, which
# must stay uninitialized for the forked DDP workers.
n_gpus = pl.accelerators.CUDAAccelerator.auto_device_count()
strategy = "ddp_notebook" if n_gpus > 1 else "auto"

trainer = pl.Trainer(
    accelerator="auto",
    max_epochs=cfg.trcl_number_epochs,
    strategy=strategy,
    # "16-mixed" replaces the deprecated precision=16. Lightning supports fp16 AMP only
    # on GPU, so fall back to full precision on CPU.
    precision="16-mixed" if n_gpus > 0 else 32,
    callbacks=cbacks,
    logger=logger,
    log_every_n_steps=1,
)
trainer.fit(model)
/home/luca/mambaforge/envs/mzbsuite/lib/python3.10/site-packages/lightning_fabric/connector.py:555: UserWarning: 16 is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
  rank_zero_warn(
Using 16bit Automatic Mixed Precision (AMP)
---------------------------------------------------------------------------
MisconfigurationException                 Traceback (most recent call last)
/data/users/luca/mzb-workflow/notebooks/classification_finetune.ipynb Cell 22 line 2
      <a href='vscode-notebook-cell://ssh-remote%2Bbiodetectgpu.datascience.ch/data/users/luca/mzb-workflow/notebooks/classification_finetune.ipynb#X30sdnNjb2RlLXJlbW90ZQ%3D%3D?line=0'>1</a> # instantiate trainer and train
----> <a href='vscode-notebook-cell://ssh-remote%2Bbiodetectgpu.datascience.ch/data/users/luca/mzb-workflow/notebooks/classification_finetune.ipynb#X30sdnNjb2RlLXJlbW90ZQ%3D%3D?line=1'>2</a> trainer = pl.Trainer(
      <a href='vscode-notebook-cell://ssh-remote%2Bbiodetectgpu.datascience.ch/data/users/luca/mzb-workflow/notebooks/classification_finetune.ipynb#X30sdnNjb2RlLXJlbW90ZQ%3D%3D?line=2'>3</a>     accelerator="auto",  # cfg.trcl_num_gpus outdated
      <a href='vscode-notebook-cell://ssh-remote%2Bbiodetectgpu.datascience.ch/data/users/luca/mzb-workflow/notebooks/classification_finetune.ipynb#X30sdnNjb2RlLXJlbW90ZQ%3D%3D?line=3'>4</a>     max_epochs=cfg.trcl_number_epochs,
      <a href='vscode-notebook-cell://ssh-remote%2Bbiodetectgpu.datascience.ch/data/users/luca/mzb-workflow/notebooks/classification_finetune.ipynb#X30sdnNjb2RlLXJlbW90ZQ%3D%3D?line=4'>5</a>     strategy=DDPStrategy(
      <a href='vscode-notebook-cell://ssh-remote%2Bbiodetectgpu.datascience.ch/data/users/luca/mzb-workflow/notebooks/classification_finetune.ipynb#X30sdnNjb2RlLXJlbW90ZQ%3D%3D?line=5'>6</a>         find_unused_parameters=False
      <a href='vscode-notebook-cell://ssh-remote%2Bbiodetectgpu.datascience.ch/data/users/luca/mzb-workflow/notebooks/classification_finetune.ipynb#X30sdnNjb2RlLXJlbW90ZQ%3D%3D?line=6'>7</a>     ),  # TODO: check how to use in notebook
      <a href='vscode-notebook-cell://ssh-remote%2Bbiodetectgpu.datascience.ch/data/users/luca/mzb-workflow/notebooks/classification_finetune.ipynb#X30sdnNjb2RlLXJlbW90ZQ%3D%3D?line=7'>8</a>     precision=16,
      <a href='vscode-notebook-cell://ssh-remote%2Bbiodetectgpu.datascience.ch/data/users/luca/mzb-workflow/notebooks/classification_finetune.ipynb#X30sdnNjb2RlLXJlbW90ZQ%3D%3D?line=8'>9</a>     callbacks=cbacks,
     <a href='vscode-notebook-cell://ssh-remote%2Bbiodetectgpu.datascience.ch/data/users/luca/mzb-workflow/notebooks/classification_finetune.ipynb#X30sdnNjb2RlLXJlbW90ZQ%3D%3D?line=9'>10</a>     logger=logger,
     <a href='vscode-notebook-cell://ssh-remote%2Bbiodetectgpu.datascience.ch/data/users/luca/mzb-workflow/notebooks/classification_finetune.ipynb#X30sdnNjb2RlLXJlbW90ZQ%3D%3D?line=10'>11</a>     log_every_n_steps=1
     <a href='vscode-notebook-cell://ssh-remote%2Bbiodetectgpu.datascience.ch/data/users/luca/mzb-workflow/notebooks/classification_finetune.ipynb#X30sdnNjb2RlLXJlbW90ZQ%3D%3D?line=11'>12</a>     # profiler="simple",
     <a href='vscode-notebook-cell://ssh-remote%2Bbiodetectgpu.datascience.ch/data/users/luca/mzb-workflow/notebooks/classification_finetune.ipynb#X30sdnNjb2RlLXJlbW90ZQ%3D%3D?line=12'>13</a> )
     <a href='vscode-notebook-cell://ssh-remote%2Bbiodetectgpu.datascience.ch/data/users/luca/mzb-workflow/notebooks/classification_finetune.ipynb#X30sdnNjb2RlLXJlbW90ZQ%3D%3D?line=14'>15</a> trainer.fit(model)

File ~/mambaforge/envs/mzbsuite/lib/python3.10/site-packages/pytorch_lightning/utilities/argparse.py:69, in _defaults_from_env_vars.<locals>.insert_env_defaults(self, *args, **kwargs)
     66 kwargs = dict(list(env_variables.items()) + list(kwargs.items()))
     68 # all args were already moved to kwargs
---> 69 return fn(self, **kwargs)

File ~/mambaforge/envs/mzbsuite/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:398, in Trainer.__init__(self, accelerator, strategy, devices, num_nodes, precision, logger, callbacks, fast_dev_run, max_epochs, min_epochs, max_steps, min_steps, max_time, limit_train_batches, limit_val_batches, limit_test_batches, limit_predict_batches, overfit_batches, val_check_interval, check_val_every_n_epoch, num_sanity_val_steps, log_every_n_steps, enable_checkpointing, enable_progress_bar, enable_model_summary, accumulate_grad_batches, gradient_clip_val, gradient_clip_algorithm, deterministic, benchmark, inference_mode, use_distributed_sampler, profiler, detect_anomaly, barebones, plugins, sync_batchnorm, reload_dataloaders_every_n_epochs, default_root_dir)
    395 # init connectors
    396 self._data_connector = _DataConnector(self)
--> 398 self._accelerator_connector = _AcceleratorConnector(
    399     devices=devices,
    400     accelerator=accelerator,
    401     strategy=strategy,
    402     num_nodes=num_nodes,
    403     sync_batchnorm=sync_batchnorm,
    404     benchmark=benchmark,
    405     use_distributed_sampler=use_distributed_sampler,
    406     deterministic=deterministic,
    407     precision=precision,
    408     plugins=plugins,
    409 )
    410 self._logger_connector = _LoggerConnector(self)
    411 self._callback_connector = _CallbackConnector(self)

File ~/mambaforge/envs/mzbsuite/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:173, in _AcceleratorConnector.__init__(self, devices, num_nodes, accelerator, strategy, plugins, precision, sync_batchnorm, benchmark, use_distributed_sampler, deterministic)
    170 self.precision_plugin = self._check_and_init_precision()
    172 # 6. Instantiate Strategy - Part 2
--> 173 self._lazy_init_strategy()

File ~/mambaforge/envs/mzbsuite/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:577, in _AcceleratorConnector._lazy_init_strategy(self)
    574 self.strategy._configure_launcher()
    576 if _IS_INTERACTIVE and self.strategy.launcher and not self.strategy.launcher.is_interactive_compatible:
--> 577     raise MisconfigurationException(
    578         f"`Trainer(strategy={self._strategy_flag!r})` is not compatible with an interactive"
    579         " environment. Run your code as a script, or choose one of the compatible strategies:"
    580         f" `Fabric(strategy='dp'|'ddp_notebook')`."
    581         " In case you are spawning processes yourself, make sure to include the Trainer"
    582         " creation inside the worker function."
    583     )
    585 # TODO: should be moved to _check_strategy_and_fallback().
    586 # Current test check precision first, so keep this check here to meet error order
    587 if isinstance(self.accelerator, TPUAccelerator) and not isinstance(
    588     self.strategy, (SingleTPUStrategy, XLAStrategy)
    589 ):

MisconfigurationException: `Trainer(strategy=<pytorch_lightning.strategies.ddp.DDPStrategy object at 0x7f16508311b0>)` is not compatible with an interactive environment. Run your code as a script, or choose one of the compatible strategies: `Fabric(strategy='dp'|'ddp_notebook')`. In case you are spawning processes yourself, make sure to include the Trainer creation inside the worker function.