Skeletonization: supervised, inference

In this notebook we use the supervised module to extract body length and head width with a model trained on manually annotated data. We will use the script skeletons/main_supervised_skeletons_inference.py to extract skeletons from the clips.

We first import the necessary libraries:

[1]:
import argparse
import os
import sys
import torch
import cv2

from datetime import datetime
from pathlib import Path
from PIL import Image
from matplotlib import pyplot as plt
from skimage.morphology import thin
from torchvision import transforms
from tqdm import tqdm

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import yaml

from mzbsuite.skeletons.mzb_skeletons_pilmodel import MZBModel_skels
from mzbsuite.skeletons.mzb_skeletons_helpers import paint_image_tensor, Denormalize
from mzbsuite.utils import cfg_to_arguments, find_checkpoints

# Set the thread layer used by MKL
os.environ["MKL_THREADING_LAYER"] = "GNU"

We also need to set up the running parameters for the script:

[27]:
ROOT_DIR = Path("/data/shared/mzb-workflow/docs")
MODEL = "mit-b2-v1"

arguments = {
    "config_file": ROOT_DIR.parent.absolute() / "configs/configuration_flume_datasets.yaml",
    "input_dir": ROOT_DIR.parent.absolute() / "data/mzb_example_data/training_dataset/val_set/ephemeroptera",
    "input_type": "external",
    "input_model": ROOT_DIR.parent.absolute() / f"models/mzb-skeleton-models/{MODEL}",
    "output_dir": ROOT_DIR.parent.absolute() / "results/mzb_example_data/skeletons/skeletons_supervised",
    "save_masks": ROOT_DIR.parent.absolute() / "data/derived/skeletons/skeletons_supervised/",
}

with open(str(arguments["config_file"]), "r") as f:
    cfg = yaml.load(f, Loader=yaml.FullLoader)

# cfg["trcl_gpu_ids"] = None
print(arguments)
{'config_file': PosixPath('/data/shared/mzb-workflow/configs/configuration_flume_datasets.yaml'), 'input_dir': PosixPath('/data/shared/mzb-workflow/data/mzb_example_data/training_dataset/val_set/ephemeroptera'), 'input_type': 'external', 'input_model': PosixPath('/data/shared/mzb-workflow/models/mzb-skeleton-models/mit-b2-v1'), 'output_dir': PosixPath('/data/shared/mzb-workflow/results/mzb_example_data/skeletons/skeletons_supervised'), 'save_masks': PosixPath('/data/shared/mzb-workflow/data/derived/skeletons/skeletons_supervised')}

Convert the configuration dictionaries into argument namespaces for the scripts to parse.

[28]:
# Transforms configurations dicts to argparse arguments
args = cfg_to_arguments(arguments)
cfg = cfg_to_arguments(cfg)
print(str(cfg))
{'glob_random_seed': 222, 'glob_root_folder': '/work/mzb-workflow/', 'glob_blobs_folder': '/work/mzb-workflow/data/derived/blobs/', 'glob_local_format': 'pdf', 'model_logger': 'wandb', 'impa_image_format': 'jpg', 'impa_clip_areas': [2750, 4900], 'impa_area_threshold': 5000, 'impa_gaussian_blur': [21, 21], 'impa_gaussian_blur_passes': 3, 'impa_adaptive_threshold_block_size': 351, 'impa_mask_postprocess_kernel': [11, 11], 'impa_mask_postprocess_passes': 5, 'impa_bounding_box_buffer': 200, 'impa_save_clips_plus_features': True, 'lset_class_cut': 'order', 'lset_val_size': 0.1, 'lset_taxonomy': '/work/mzb-workflow/data/MZB_taxonomy.csv', 'trcl_learning_rate': 0.001, 'trcl_batch_size': 16, 'trcl_weight_decay': 0, 'trcl_step_size_decay': 5, 'trcl_number_epochs': 10, 'trcl_save_topk': 1, 'trcl_num_classes': 8, 'trcl_model_pretrarch': 'convnext-small', 'trcl_num_workers': 16, 'trcl_wandb_project_name': 'mzb-classifiers', 'trcl_logger': 'wandb', 'trsk_learning_rate': 0.001, 'trsk_batch_size': 32, 'trsk_weight_decay': 0, 'trsk_step_size_decay': 25, 'trsk_number_epochs': 400, 'trsk_save_topk': 1, 'trsk_num_classes': 2, 'trsk_model_pretrarch': 'mit_b2', 'trsk_num_workers': 16, 'trsk_wandb_project_name': 'mzb-skeletons', 'trsk_logger': 'wandb', 'infe_model_ckpt': 'last', 'infe_num_classes': 8, 'infe_image_glob': '*_rgb.jpg', 'skel_class_exclude': 'errors', 'skel_conv_rate': 131.6625, 'skel_label_thickness': 3, 'skel_label_buffer_on_preds': 25, 'skel_label_clip_with_mask': False}
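
For intuition, cfg_to_arguments exposes each key of the dictionary as an attribute, so downstream code can write cfg.trsk_batch_size rather than cfg["trsk_batch_size"]. Below is a minimal sketch of the idea using argparse.Namespace; the actual helper in mzbsuite.utils may differ in detail.

[ ]:
from argparse import Namespace

def dict_to_arguments(d: dict) -> Namespace:
    # Hypothetical equivalent of cfg_to_arguments: expose every
    # key/value pair as an attribute on a namespace object.
    return Namespace(**d)

sketch = dict_to_arguments({"trsk_batch_size": 32, "trsk_num_classes": 2})
print(sketch.trsk_batch_size)  # -> 32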

If a trained model is available, we load its weights and set up the model directories.

[29]:
dirs = find_checkpoints(
    Path(args.input_model).parents[0],
    version=Path(args.input_model).name,
    log=cfg.infe_model_ckpt,
)

mod_path = dirs[0]

model = MZBModel_skels()
# load_from_checkpoint returns a fully restored module; we attach it
# as the inner model of the wrapper instance
model.model = model.load_from_checkpoint(
    checkpoint_path=mod_path,
)

model.data_dir = Path(args.input_dir)
# args.input_dir already points at the ephemeroptera clips folder,
# so it is also the image folder
model.im_folder = model.data_dir
# model.bo_folder = model.data_dir / "sk_body"
# model.he_folder = model.data_dir / "sk_head"

print(model.im_folder)
/data/shared/mzb-workflow/data/mzb_example_data/training_dataset/val_set/ephemeroptera
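
The find_checkpoints helper resolves which checkpoint file to load; with infe_model_ckpt set to "last" it returns the most recent checkpoint for the requested model version. A rough sketch of that behaviour, assuming checkpoints are stored as *.ckpt files below a version folder (the real helper may differ):

[ ]:
from pathlib import Path

def find_last_checkpoint(model_root: Path, version: str) -> Path:
    # Hypothetical stand-in for find_checkpoints(..., log="last"):
    # glob all .ckpt files under the version folder and return the
    # most recently modified one.
    ckpts = sorted(
        (model_root / version).rglob("*.ckpt"),
        key=lambda p: p.stat().st_mtime,
    )
    if not ckpts:
        raise FileNotFoundError(f"no checkpoints in {model_root / version}")
    return ckpts[-1]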

Set up additional parameters for model inference:

[30]:
# reindex trn/val split (only needed when evaluating on the training
# data; kept commented here for reference)
np.random.seed(12)
N = len(list(model.im_folder.glob("*.jpg")))
# model.trn_inds = sorted(
#     list(np.random.choice(np.arange(N), size=int(0.8 * N), replace=False))
# )
# model.val_inds = sorted(list(set(np.arange(N)).difference(set(model.trn_inds))))
model.eval()
model.freeze()

args.input_type = "external"
dataloader = model.external_dataloader(args.input_dir)
dataset_name = "external"

im_fi = dataloader.dataset.img_paths

pbar_cb = pl.callbacks.progress.TQDMProgressBar(refresh_rate=5)
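
For reference, an "external" dataset of this kind essentially globs the input images (the configuration uses the pattern *_rgb.jpg) and applies the model's test-time transform. The sketch below is a hypothetical stand-in for MZBLoader_skels in external mode; the real class in mzbsuite.skeletons.mzb_skeletons_dataloader may differ.

[ ]:
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset

class ExternalImages(Dataset):
    # Hypothetical stand-in for MZBLoader_skels in "external" mode.
    def __init__(self, root, transform, pattern="*_rgb.jpg"):
        self.img_paths = sorted(Path(root).glob(pattern))
        self.transform = transform

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        im = Image.open(self.img_paths[idx]).convert("RGB")
        return self.transform(im)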

Now we can run the inference using the trained model.

[31]:
print(im_fi)
print(N)

print(dataloader.dataset)
[PosixPath('/data/shared/mzb-workflow/data/mzb_example_data/training_dataset/val_set/ephemeroptera/31_b1_ephemeroptera_01_clip_1_rgb.jpg')
 PosixPath('/data/shared/mzb-workflow/data/mzb_example_data/training_dataset/val_set/ephemeroptera/31_b2_baetis_01_clip_1_rgb.jpg')
 PosixPath('/data/shared/mzb-workflow/data/mzb_example_data/training_dataset/val_set/ephemeroptera/31_hf2_baetidae_01_clip_4_rgb.jpg')
 PosixPath('/data/shared/mzb-workflow/data/mzb_example_data/training_dataset/val_set/ephemeroptera/31_ob_ephemeroptera_01_clip_6_rgb.jpg')
 PosixPath('/data/shared/mzb-workflow/data/mzb_example_data/training_dataset/val_set/ephemeroptera/32_b2_baetis_01_clip_1_rgb.jpg')
 PosixPath('/data/shared/mzb-workflow/data/mzb_example_data/training_dataset/val_set/ephemeroptera/32_bd_baetidae_01_clip_4_rgb.jpg')
 PosixPath('/data/shared/mzb-workflow/data/mzb_example_data/training_dataset/val_set/ephemeroptera/32_hf2_baetidae_01_clip_5_rgb.jpg')]
7
<mzbsuite.skeletons.mzb_skeletons_dataloader.MZBLoader_skels object at 0x7f6a9d0c6b00>
[32]:
trainer = pl.Trainer(
    precision=32,
    max_epochs=1,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1 if torch.cuda.is_available() else None,
    callbacks=[pbar_cb],
    enable_checkpointing=False,
    logger=False,
)

outs = trainer.predict(
    model=model, dataloaders=[dataloader], return_predictions=True
)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Predicting DataLoader 0: 100%|██████████| 1/1 [00:06<00:00,  6.52s/it]
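
Each element of outs corresponds to one batch from the dataloader; the aggregation in the next cell indexes each entry as out[0] (predictions) and out[1] (targets). A quick way to sanity-check that structure:

[ ]:
# One entry per batch; each entry indexable as (predictions, targets)
# tensors, matching how the aggregation cell below unpacks it.
for out in outs:
    print(type(out), out[0].shape, out[1].shape)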

We now aggregate the predictions and refine the skeletons produced.

[34]:
# aggregate predictions
p = []
gt = []
for out in outs:
    p.append(out[0].numpy())
    gt.append(out[1].numpy())
pc = np.concatenate(p)
gc = np.concatenate(gt)

# %%
# nn body preds
preds_size = []

print("Neural network predictions done, refining and saving skeletons...")

for i, ti in tqdm(enumerate(im_fi), total=len(im_fi)):
    im = Image.open(ti).convert("RGB")

    # get original size of image for resizing predictions
    o_size = im.size

    # get predictions
    x = model.transform_ts(im)
    x = x[np.newaxis, ...]
    with torch.set_grad_enabled(False):
        p = torch.sigmoid(model(x)).cpu().numpy().squeeze()

    # append an empty third channel so the 2-channel prediction can be
    # handled as an RGB image
    refined_skel = np.concatenate((p, np.zeros_like(p[0:1, ...])), axis=0)
    refined_skel = Image.fromarray(
        (255 * np.transpose(refined_skel, (1, 2, 0))).astype(np.uint8)
    )

    refined_skel = transforms.Resize(
        (o_size[1], o_size[0]),
        interpolation=transforms.InterpolationMode.BILINEAR,
    )(refined_skel)
    refined_skel = np.transpose(np.asarray(refined_skel), (2, 0, 1))

    # mask out the edges of the image
    if (cfg.skel_label_buffer_on_preds > 0) and (not cfg.skel_label_clip_with_mask):
        mask = np.ones_like(x[0, 0, ...])
        mask[-cfg.skel_label_buffer_on_preds :, :] = 0
        mask[: cfg.skel_label_buffer_on_preds, :] = 0
        mask[:, : cfg.skel_label_buffer_on_preds] = 0
        mask[:, -cfg.skel_label_buffer_on_preds :] = 0

        mask = Image.fromarray(mask)
        mask = np.array(
            transforms.Resize(
                (o_size[1], o_size[0]),
                interpolation=transforms.InterpolationMode.BILINEAR,
            )(mask)
        )
        refined_skel = [
            (thin(a) > 0).astype(float) * mask for a in refined_skel[0:2, ...] > 50
        ]
    elif cfg.skel_label_clip_with_mask:
        # load the insect blob mask; glob_blobs_folder is a string in
        # the config, so wrap it in Path before joining
        mask_insect = Image.open(
            Path(cfg.glob_blobs_folder) / (ti.name[:-4] + "_mask.jpg")
        ).convert("RGB")
        mask_insect = np.array(mask_insect)[:, :, 0] > 0
        mask_insect = Image.fromarray(mask_insect)
        mask_insect = np.array(
            transforms.Resize(
                (o_size[1], o_size[0]),
                interpolation=transforms.InterpolationMode.BILINEAR,
            )(mask_insect)
        )
        refined_skel = [
            (thin(a) > 0).astype(float) * mask_insect
            for a in refined_skel[0:2, ...] > 50
        ]

    else:
        # Refine the predicted skeleton image
        refined_skel = [
            (thin(a) > 0).astype(float) for a in refined_skel[0:2, ...] > 50
        ]

    refined_skel = [(255 * s).astype(np.uint8) for s in refined_skel]

    if args.save_masks:
        name = "_".join(ti.name.split("_")[:-1])
        cv2.imwrite(
            str(args.save_masks / f"{name}_body.jpg"),
            refined_skel[0],
            [cv2.IMWRITE_JPEG_QUALITY, 100],
        )
        cv2.imwrite(
            str(args.save_masks / f"{name}_head.jpg"),
            refined_skel[1],
            [cv2.IMWRITE_JPEG_QUALITY, 100],
        )

    preds_size.append(
        pd.DataFrame(
            {
                "clip_name": "_".join(ti.name.split(".")[0].split("_")[:-1]),
                "nn_pred_body": [np.sum(refined_skel[0] > 0)],
                "nn_pred_head": [np.sum(refined_skel[1] > 0)],
            }
        )
    )

preds_size = pd.concat(preds_size)
# out_dir = Path(
#     f"{args.output_dir}_{dataset_name}_{datetime.now().strftime('%Y%m%d_%H%M')}"
# )
out_dir = Path(f"{args.output_dir}")

out_dir.mkdir(exist_ok=True, parents=True)

preds_size.to_csv(out_dir / "size_skel_supervised_model.csv", index=False)
  0%|          | 0/7 [00:00<?, ?it/s]100%|██████████| 7/7 [00:19<00:00,  2.84s/it]
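
The saved CSV holds raw pixel counts along the thinned skeletons. The configuration's skel_conv_rate (131.6625 here) is the conversion factor used downstream to turn those counts into physical lengths. As an illustration only, and assuming the rate is expressed in pixels per unit length (check the configuration documentation for the actual unit):

[ ]:
# Illustration only: convert skeleton pixel counts into physical
# lengths using the configured conversion rate. The unit of
# skel_conv_rate (pixels per mm, per cm, ...) is an assumption here.
preds_size["body_length"] = preds_size["nn_pred_body"] / cfg.skel_conv_rate
preds_size["head_width"] = preds_size["nn_pred_head"] / cfg.skel_conv_rate
print(preds_size.head())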