Classification: finetune
In this notebook we illustrate how to retrain the models on the user's own data. Specifically, we remap the last layer of the model to the desired classes without modifying the model's internal weights; this operation is called finetuning and is not as computationally intensive as retraining the full model. Nevertheless, this module greatly benefits from GPU compute, provided the GPU(s) support CUDA and nvidia-smi is configured correctly.
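To make the finetuning idea concrete, here is a minimal sketch in plain PyTorch of what remapping the last layer amounts to. It is purely illustrative: it assumes a torchvision ConvNeXt backbone, whereas the actual model handling in this module is done by MZBModel (used later in this notebook).
# Illustrative sketch only (not part of the pipeline): freeze a pretrained
# backbone and replace its final classification layer with a new head for the
# desired number of classes. Assumes a recent torchvision.
import torch.nn as nn
import torchvision.models as models

backbone = models.convnext_small(weights="DEFAULT")
for param in backbone.parameters():
    param.requires_grad = False  # internal weights stay untouched

num_classes = 8  # e.g. the number of taxonomic classes in the learning set
in_features = backbone.classifier[2].in_features
backbone.classifier[2] = nn.Linear(in_features, num_classes)  # new, trainable head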
This module uses two scripts, which must be executed in this order: classification/main_prepare_learning_sets.py, which prepares the data for training, and classification/main_classification_finetune.py, which performs the finetuning.
The first step is to import the necessary libraries for main_prepare_learning_sets.py:
[6]:
import argparse
import shutil
import sys
import os
from pathlib import Path
import numpy as np
import pandas as pd
import yaml
from mzbsuite.utils import cfg_to_arguments
Next, we declare the run parameters for the script:
[7]:
ROOT_DIR = Path.cwd()
MODEL = "convnext-small-vtest-1"
LSET_FOLD = f"{ROOT_DIR}/data/mzb_example_data"

arguments = {
    "input_dir": "/data/shared/mzb-workflow/data/learning_sets/project_portable_flume/curated_learning_sets",
    "taxonomy_file": ROOT_DIR.parent.absolute() / "data/MZB_taxonomy.csv",
    "output_dir": ROOT_DIR.parent.absolute() / "data/mzb_example_data/aggregated_learning_sets",
    "save_model": ROOT_DIR.parent.absolute() / f"models/mzb-classification-models/{MODEL}",
    "config_file": ROOT_DIR.parent.absolute() / "configs/configuration_flume_datasets.yaml",
}
with open(str(arguments["config_file"]), "r") as f:
    cfg = yaml.load(f, Loader=yaml.FullLoader)

cfg["trcl_gpu_ids"] = None  # disable GPU use here: this data-preparation step does not benefit from GPU compute
cfg
[7]:
{'glob_random_seed': 222,
'glob_root_folder': '/data/users/luca/mzb-workflow/mzb-workflow/',
'glob_blobs_folder': '/data/users/luca/mzb-workflow/mzb-workflow/data/derived/blobs/',
'glob_local_format': 'pdf',
'model_logger': 'wandb',
'impa_image_format': 'jpg',
'impa_clip_areas': [2750, 4900],
'impa_area_threshold': 5000,
'impa_gaussian_blur': [21, 21],
'impa_gaussian_blur_passes': 3,
'impa_adaptive_threshold_block_size': 351,
'impa_mask_postprocess_kernel': [11, 11],
'impa_mask_postprocess_passes': 5,
'impa_bounding_box_buffer': 200,
'impa_save_clips_plus_features': True,
'lset_class_cut': 'order',
'lset_val_size': 0.1,
'lset_taxonomy': '/data/users/luca/mzb-workflow/data/MZB_taxonomy.csv',
'trcl_learning_rate': 0.001,
'trcl_batch_size': 16,
'trcl_weight_decay': 0,
'trcl_step_size_decay': 5,
'trcl_number_epochs': 10,
'trcl_save_topk': 1,
'trcl_num_classes': 8,
'trcl_model_pretrarch': 'efficientnet-b2',
'trcl_num_workers': 16,
'trcl_wandb_project_name': 'mzb-classifiers',
'trcl_logger': 'wandb',
'trsk_learning_rate': 0.001,
'trsk_batch_size': 32,
'trsk_weight_decay': 0,
'trsk_step_size_decay': 25,
'trsk_number_epochs': 400,
'trsk_save_topk': 1,
'trsk_num_classes': 2,
'trsk_model_pretrarch': 'mit_b2',
'trsk_num_workers': 16,
'trsk_wandb_project_name': 'mzb-skeletons',
'trsk_logger': 'wandb',
'infe_model_ckpt': 'last',
'infe_num_classes': 8,
'infe_image_glob': '*_rgb.jpg',
'skel_class_exclude': 'errors',
'skel_conv_rate': 131.6625,
'skel_label_thickness': 3,
'skel_label_buffer_on_preds': 25,
'skel_label_clip_with_mask': False,
'trcl_gpu_ids': None}
Convert these parameter dictionaries into argparse-style namespaces, so that their entries can be accessed as attributes:
[8]:
# Transform the configuration dicts into argparse-style argument namespaces
args = cfg_to_arguments(arguments)
cfg = cfg_to_arguments(cfg)
print(str(cfg))
{'glob_random_seed': 222, 'glob_root_folder': '/data/users/luca/mzb-workflow/mzb-workflow/', 'glob_blobs_folder': '/data/users/luca/mzb-workflow/mzb-workflow/data/derived/blobs/', 'glob_local_format': 'pdf', 'model_logger': 'wandb', 'impa_image_format': 'jpg', 'impa_clip_areas': [2750, 4900], 'impa_area_threshold': 5000, 'impa_gaussian_blur': [21, 21], 'impa_gaussian_blur_passes': 3, 'impa_adaptive_threshold_block_size': 351, 'impa_mask_postprocess_kernel': [11, 11], 'impa_mask_postprocess_passes': 5, 'impa_bounding_box_buffer': 200, 'impa_save_clips_plus_features': True, 'lset_class_cut': 'order', 'lset_val_size': 0.1, 'lset_taxonomy': '/data/users/luca/mzb-workflow/data/MZB_taxonomy.csv', 'trcl_learning_rate': 0.001, 'trcl_batch_size': 16, 'trcl_weight_decay': 0, 'trcl_step_size_decay': 5, 'trcl_number_epochs': 10, 'trcl_save_topk': 1, 'trcl_num_classes': 8, 'trcl_model_pretrarch': 'efficientnet-b2', 'trcl_num_workers': 16, 'trcl_wandb_project_name': 'mzb-classifiers', 'trcl_logger': 'wandb', 'trsk_learning_rate': 0.001, 'trsk_batch_size': 32, 'trsk_weight_decay': 0, 'trsk_step_size_decay': 25, 'trsk_number_epochs': 400, 'trsk_save_topk': 1, 'trsk_num_classes': 2, 'trsk_model_pretrarch': 'mit_b2', 'trsk_num_workers': 16, 'trsk_wandb_project_name': 'mzb-skeletons', 'trsk_logger': 'wandb', 'infe_model_ckpt': 'last', 'infe_num_classes': 8, 'infe_image_glob': '*_rgb.jpg', 'skel_class_exclude': 'errors', 'skel_conv_rate': 131.6625, 'skel_label_thickness': 3, 'skel_label_buffer_on_preds': 25, 'skel_label_clip_with_mask': False, 'trcl_gpu_ids': None}
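For reference, cfg_to_arguments simply turns these plain dictionaries into objects whose entries can be read as attributes (e.g. cfg.glob_random_seed below). The following is a minimal sketch of that behaviour, assuming it works like an argparse.Namespace; see mzbsuite.utils for the actual implementation.
# Sketch of the dict-to-attribute conversion, assuming Namespace-like behaviour
from argparse import Namespace

example = Namespace(**{"trcl_batch_size": 16, "trcl_number_epochs": 10})
print(example.trcl_batch_size)  # 16, accessed as an attribute rather than a key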
Next we create the output directory (if needed) and check that the training and validation subfolders do not already exist; if they do, the script stops to avoid overwriting an existing learning set:
[9]:
np.random.seed(cfg.glob_random_seed)

# root of raw clip data
root_data = Path(args.input_dir)
outdir = Path(args.output_dir)
outdir.mkdir(parents=True, exist_ok=True)

# target folders definition
target_trn = outdir / "trn_set/"
target_val = outdir / "val_set/"

# check if trn_set and val_set subfolders exist. If so, then interrupt the script.
# This is to make sure that no overwriting happens; prompt the user that they need to specify a different output directory.
if target_trn.exists() or target_val.exists():
    raise ValueError(
        # print in red and back to normal
        f"\033[91m Output directory {outdir} already exists. Please specify a different output directory.\033[0m"
    )
We now use the taxonomic rank specified by the lset_class_cut parameter in the configuration file to cut the provided phylogenetic tree, and reorganize the images into directories corresponding to this rank. See the documentation for further details.
[10]:
# make dictionary to recode: key is current classification, value is target reclassification.
# forward fill to get last valid entry and subset to desired column
mzb_taxonomy = pd.read_csv(Path(args.taxonomy_file))
if "Unnamed: 0" in mzb_taxonomy.columns:
    mzb_taxonomy = mzb_taxonomy.drop(columns=["Unnamed: 0"])
mzb_taxonomy = mzb_taxonomy.ffill(axis=1)

recode_order = dict(
    zip(mzb_taxonomy["query"], mzb_taxonomy[cfg.lset_class_cut].str.lower())
)

print(f"Cutting phylogenetic tree at: {cfg.lset_class_cut}")
Cutting phylogenetic tree at: order
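To see what the recoding dictionary looks like, the toy example below uses hypothetical taxonomy entries: after forward-filling along each row, every curated folder name in the query column is mapped to its lower-cased class at the chosen rank.
# Toy example with hypothetical entries; the real taxonomy comes from MZB_taxonomy.csv
toy = pd.DataFrame(
    {
        "query": ["baetidae", "leuctridae"],
        "order": ["Ephemeroptera", "Plecoptera"],
        "family": ["Baetidae", None],
    }
)
toy = toy.ffill(axis=1)  # fill missing ranks with the last valid entry to the left
print(dict(zip(toy["query"], toy["order"].str.lower())))
# {'baetidae': 'ephemeroptera', 'leuctridae': 'plecoptera'}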
Now we copy the images over into the new folder structure according to the taxonomy:
[11]:
# Copy files to target folders for all files in the curated learning set
for s_fo in recode_order:
    target_folder = target_trn / recode_order[s_fo]
    target_folder.mkdir(exist_ok=True, parents=True)
    for file in list((root_data / s_fo).glob("*")):
        shutil.copy(file, target_folder)

# move out the validation set: lset_val_size of the images, or at least 1 file per class
size = cfg.lset_val_size
trn_folds = [a.name for a in sorted(list(target_trn.glob("*")))]
Next, we split off the validation set based on the proportion of total images specified by the lset_val_size parameter in the configuration file. We recommend reserving at least 10% of the images of each class.
[12]:
for s_fo in trn_folds:
    target_folder = target_val / s_fo
    target_folder.mkdir(exist_ok=True, parents=True)
    list_class = list((target_trn / s_fo).glob("*"))
    # at least one validation file per class, otherwise the configured proportion
    n_val_sam = np.max((1, np.ceil(size * len(list_class))))
    # sample without replacement so the same file is not drawn twice
    val_files = np.random.choice(list_class, int(n_val_sam), replace=False)
    for file in val_files:
        try:
            shutil.move(str(file), target_folder)
        except Exception:
            # report files that could not be moved
            print(f"{str(file)} into {target_folder}")
/data/users/luca/mzb-workflow/data/mzb_example_data/aggregated_learning_sets/trn_set/errors/32_hf2_plecoptera_01_clip_8_rgb.jpg into /data/users/luca/mzb-workflow/data/mzb_example_data/aggregated_learning_sets/val_set/errors
/data/users/luca/mzb-workflow/data/mzb_example_data/aggregated_learning_sets/trn_set/errors/32_ob_leuctridae_01_clip_4_rgb.jpg into /data/users/luca/mzb-workflow/data/mzb_example_data/aggregated_learning_sets/val_set/errors
/data/users/luca/mzb-workflow/data/mzb_example_data/aggregated_learning_sets/trn_set/plecoptera/32_bd_plecoptera_01_clip_2_rgb.jpg into /data/users/luca/mzb-workflow/data/mzb_example_data/aggregated_learning_sets/val_set/plecoptera
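As an optional sanity check (not part of the script), we can count the images per class in both sets and verify that roughly lset_val_size of each class ended up in the validation set:
# Count images per class in the training and validation sets (optional check)
for split in (target_trn, target_val):
    counts = {d.name: len(list(d.glob("*"))) for d in sorted(split.glob("*")) if d.is_dir()}
    print(split.name, counts)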
The data are now ready for model training, with a training set and a validation set containing the same classes.
We move on to the model finetuning, using the script classification/main_classification_finetune.py. First, we import some additional libraries from PyTorch and PyTorch Lightning:
[13]:
import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
from pytorch_lightning.strategies.ddp import DDPStrategy

from mzbsuite.classification.mzb_classification_pilmodel import MZBModel
from mzbsuite.utils import cfg_to_arguments, SaveLogCallback

# Set the thread layer used by MKL to GNU OpenMP; this avoids threading-library
# conflicts between MKL and PyTorch (it is unrelated to GPU usage).
os.environ["MKL_THREADING_LAYER"] = "GNU"
Before we can launch the training, we need to define a few callbacks for checkpointing the model and monitoring the training progress over time, and then instantiate the model from the configuration:
[14]:
# Define checkpoint callbacks
# best model on validation
best_val_cb = pl.callbacks.ModelCheckpoint(
    dirpath=args.save_model,
    filename="best-val-{epoch}-{step}-{val_loss:.1f}",
    monitor="val_loss",
    mode="min",
    save_top_k=cfg.trcl_save_topk,
)

# latest model in training
last_mod_cb = pl.callbacks.ModelCheckpoint(
    dirpath=args.save_model,
    filename="last-{step}",
    every_n_train_steps=50,
    save_top_k=cfg.trcl_save_topk,
)

# Define progress bar callback
pbar_cb = pl.callbacks.progress.TQDMProgressBar(refresh_rate=5)

# Define logger callback to log the training date
trdatelog = SaveLogCallback(model_folder=args.save_model)

# Define model from config
model = MZBModel(
    data_dir=args.input_dir,
    pretrained_network=cfg.trcl_model_pretrarch,
    learning_rate=cfg.trcl_learning_rate,
    batch_size=cfg.trcl_batch_size,
    weight_decay=cfg.trcl_weight_decay,
    num_workers_loader=cfg.trcl_num_workers,
    step_size_decay=cfg.trcl_step_size_decay,
    num_classes=cfg.trcl_num_classes,
)
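Since finetuning is meant to update only the classification head, an optional check is to compare trainable versus total parameters. This is a generic sketch that works for any torch.nn.Module, including MZBModel; how many parameters are actually trainable depends on how MZBModel configures the backbone.
# Optional: compare trainable vs. total parameters of the instantiated model
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"{trainable:,} trainable / {total:,} total parameters")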
We now check whether a previously trained model is available and, if so, load its weights. We then set up the logger; note that logging training progress requires either a Weights & Biases account (for the wandb option) or a local TensorBoard setup (for the tensorboard option). See the documentation for more details.
[15]:
# Check if there is a model to load; if there is, load it and train from there
if args.save_model.is_dir():
    # `verbose` is not defined in the notebook's arguments dict, so default to False
    if getattr(args, "verbose", False):
        print(f"Loading model from {args.save_model}")
    try:
        fmodel = list(args.save_model.glob("last-*.ckpt"))[0]
    except IndexError:
        print("No last-* model in folder, loading best model")
        fmodel = list(
            args.save_model.glob("best-val-epoch=*-step=*-val_loss=*.*.ckpt")
        )[-1]
    model = model.load_from_checkpoint(fmodel)
# Define name of run
name_run = f"classifier-{cfg.trcl_model_pretrarch}"  # f"{model.pretrained_network}"
cbacks = [pbar_cb, best_val_cb, last_mod_cb, trdatelog]

# Define logger, using either wandb or tensorboard
if cfg.trcl_logger == "wandb":
    logger = WandbLogger(
        project=cfg.trcl_wandb_project_name, name=name_run if name_run else None
    )
    logger.watch(model, log="all")
elif cfg.trcl_logger == "tensorboard":
    logger = TensorBoardLogger(
        save_dir=args.save_model,
        name=name_run if name_run else None,
        log_graph=True,
    )
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: lpego (biodetect). Use `wandb login --relogin` to force relogin
./wandb/run-20231111_161213-1u2u0o5h
wandb: logging graph, to disable use `wandb.watch(log_graph=False)`
We are now finally ready to train our model!
[16]:
# instantiate trainer and train
trainer = pl.Trainer(
    accelerator="auto",  # cfg.trcl_num_gpus outdated
    max_epochs=cfg.trcl_number_epochs,
    strategy=DDPStrategy(
        find_unused_parameters=False
    ),  # TODO: check how to use in notebook
    precision=16,
    callbacks=cbacks,
    logger=logger,
    log_every_n_steps=1
    # profiler="simple",
)

trainer.fit(model)
/home/luca/mambaforge/envs/mzbsuite/lib/python3.10/site-packages/lightning_fabric/connector.py:555: UserWarning: 16 is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
rank_zero_warn(
Using 16bit Automatic Mixed Precision (AMP)
---------------------------------------------------------------------------
MisconfigurationException Traceback (most recent call last)
/data/users/luca/mzb-workflow/notebooks/classification_finetune.ipynb Cell 22 line 2
      1 # instantiate trainer and train
----> 2 trainer = pl.Trainer(
      3     accelerator="auto",  # cfg.trcl_num_gpus outdated
      4     max_epochs=cfg.trcl_number_epochs,
      5     strategy=DDPStrategy(
      6         find_unused_parameters=False
      7     ),  # TODO: check how to use in notebook
      8     precision=16,
      9     callbacks=cbacks,
     10     logger=logger,
     11     log_every_n_steps=1
     12     # profiler="simple",
     13 )
     15 trainer.fit(model)
File ~/mambaforge/envs/mzbsuite/lib/python3.10/site-packages/pytorch_lightning/utilities/argparse.py:69, in _defaults_from_env_vars.<locals>.insert_env_defaults(self, *args, **kwargs)
66 kwargs = dict(list(env_variables.items()) + list(kwargs.items()))
68 # all args were already moved to kwargs
---> 69 return fn(self, **kwargs)
File ~/mambaforge/envs/mzbsuite/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:398, in Trainer.__init__(self, accelerator, strategy, devices, num_nodes, precision, logger, callbacks, fast_dev_run, max_epochs, min_epochs, max_steps, min_steps, max_time, limit_train_batches, limit_val_batches, limit_test_batches, limit_predict_batches, overfit_batches, val_check_interval, check_val_every_n_epoch, num_sanity_val_steps, log_every_n_steps, enable_checkpointing, enable_progress_bar, enable_model_summary, accumulate_grad_batches, gradient_clip_val, gradient_clip_algorithm, deterministic, benchmark, inference_mode, use_distributed_sampler, profiler, detect_anomaly, barebones, plugins, sync_batchnorm, reload_dataloaders_every_n_epochs, default_root_dir)
395 # init connectors
396 self._data_connector = _DataConnector(self)
--> 398 self._accelerator_connector = _AcceleratorConnector(
399 devices=devices,
400 accelerator=accelerator,
401 strategy=strategy,
402 num_nodes=num_nodes,
403 sync_batchnorm=sync_batchnorm,
404 benchmark=benchmark,
405 use_distributed_sampler=use_distributed_sampler,
406 deterministic=deterministic,
407 precision=precision,
408 plugins=plugins,
409 )
410 self._logger_connector = _LoggerConnector(self)
411 self._callback_connector = _CallbackConnector(self)
File ~/mambaforge/envs/mzbsuite/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:173, in _AcceleratorConnector.__init__(self, devices, num_nodes, accelerator, strategy, plugins, precision, sync_batchnorm, benchmark, use_distributed_sampler, deterministic)
170 self.precision_plugin = self._check_and_init_precision()
172 # 6. Instantiate Strategy - Part 2
--> 173 self._lazy_init_strategy()
File ~/mambaforge/envs/mzbsuite/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:577, in _AcceleratorConnector._lazy_init_strategy(self)
574 self.strategy._configure_launcher()
576 if _IS_INTERACTIVE and self.strategy.launcher and not self.strategy.launcher.is_interactive_compatible:
--> 577 raise MisconfigurationException(
578 f"`Trainer(strategy={self._strategy_flag!r})` is not compatible with an interactive"
579 " environment. Run your code as a script, or choose one of the compatible strategies:"
580 f" `Fabric(strategy='dp'|'ddp_notebook')`."
581 " In case you are spawning processes yourself, make sure to include the Trainer"
582 " creation inside the worker function."
583 )
585 # TODO: should be moved to _check_strategy_and_fallback().
586 # Current test check precision first, so keep this check here to meet error order
587 if isinstance(self.accelerator, TPUAccelerator) and not isinstance(
588 self.strategy, (SingleTPUStrategy, XLAStrategy)
589 ):
MisconfigurationException: `Trainer(strategy=<pytorch_lightning.strategies.ddp.DDPStrategy object at 0x7f16508311b0>)` is not compatible with an interactive environment. Run your code as a script, or choose one of the compatible strategies: `Fabric(strategy='dp'|'ddp_notebook')`. In case you are spawning processes yourself, make sure to include the Trainer creation inside the worker function.
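As the exception message itself suggests, the DDP strategy used by the script cannot spawn worker processes from an interactive session; the cell above works as intended when classification/main_classification_finetune.py is run from the command line. For completeness, a possible notebook-friendly variant (a sketch following the message's suggestion, not part of the original script) is:
# Possible notebook-friendly Trainer setup (sketch): use the notebook-safe DDP
# launcher, or simply omit `strategy` to let Lightning pick a compatible default.
trainer = pl.Trainer(
    accelerator="auto",
    strategy="ddp_notebook",
    max_epochs=cfg.trcl_number_epochs,
    precision=16,
    callbacks=cbacks,
    logger=logger,
    log_every_n_steps=1,
)
trainer.fit(model)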