{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Skeletonization: supervised, inference\n", "======================================\n", "\n", "In this notebook we use the supervised module to extract length and head width using a model trained on manually annotated data. We follow the steps of the script `skeletons/main_supervised_skeletons_inference.py` to extract skeletons from the clips. \n", "\n", "We first import the necessary libraries: " ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/luca/mambaforge/envs/mzbsuite/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import argparse\n", "import os\n", "import sys\n", "import torch\n", "import cv2\n", "\n", "from datetime import datetime\n", "from pathlib import Path\n", "from PIL import Image\n", "from matplotlib import pyplot as plt\n", "from PIL import Image\n", "from skimage.morphology import thin\n", "from torchvision import transforms\n", "from tqdm import tqdm\n", "\n", "import numpy as np\n", "import pandas as pd\n", "import pytorch_lightning as pl\n", "import yaml\n", "\n", "from mzbsuite.skeletons.mzb_skeletons_pilmodel import MZBModel_skels\n", "from mzbsuite.skeletons.mzb_skeletons_helpers import paint_image_tensor, Denormalize\n", "from mzbsuite.utils import cfg_to_arguments, find_checkpoints\n", "\n", "# Set the thread layer used by MKL\n", "os.environ[\"MKL_THREADING_LAYER\"] = \"GNU\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We need to set up some running parameters for the script too: " ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'config_file': 
PosixPath('/data/shared/mzb-workflow/configs/configuration_flume_datasets.yaml'), 'input_dir': PosixPath('/data/shared/mzb-workflow/data/mzb_example_data/training_dataset/val_set/ephemeroptera'), 'input_type': 'external', 'input_model': PosixPath('/data/shared/mzb-workflow/models/mzb-skeleton-models/mit-b2-v1'), 'output_dir': PosixPath('/data/shared/mzb-workflow/results/mzb_example_data/skeletons/skeletons_supervised'), 'save_masks': PosixPath('/data/shared/mzb-workflow/data/derived/skeletons/skeletons_supervised')}\n" ] } ], "source": [
 "# ROOT_DIR is the `docs` folder; every path below is resolved against its\n",
 "# parent directory. MODEL selects the sub-folder of the skeleton model to load.\n",
 "ROOT_DIR = Path(\"/data/shared/mzb-workflow/docs\")\n",
 "MODEL = \"mit-b2-v1\"\n",
 "\n",
 "repo_root = ROOT_DIR.parent.absolute()\n",
 "\n",
 "# Run-time arguments of the inference: configuration file, folder with the\n",
 "# clips, model folder and the two output locations.\n",
 "arguments = {\n",
 "    \"config_file\": repo_root / \"configs/configuration_flume_datasets.yaml\",\n",
 "    \"input_dir\": repo_root / \"data/mzb_example_data/training_dataset/val_set/ephemeroptera\",\n",
 "    \"input_type\": \"external\",\n",
 "    \"input_model\": repo_root / f\"models/mzb-skeleton-models/{MODEL}\",\n",
 "    \"output_dir\": repo_root / \"results/mzb_example_data/skeletons/skeletons_supervised\",\n",
 "    \"save_masks\": repo_root / \"data/derived/skeletons/skeletons_supervised/\",\n",
 "}\n",
 "\n",
 "# NOTE(review): FullLoader is kept as in the original cell; yaml.safe_load would\n",
 "# be the stricter choice if the configuration file cannot be trusted.\n",
 "with open(arguments[\"config_file\"], \"r\") as config_stream:\n",
 "    cfg = yaml.load(config_stream, Loader=yaml.FullLoader)\n",
 "\n",
 "# cfg[\"trcl_gpu_ids\"] = None\n",
 "print(arguments)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Convert to a dictionary for the scripts to parse. 
" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'glob_random_seed': 222, 'glob_root_folder': '/work/mzb-workflow/', 'glob_blobs_folder': '/work/mzb-workflow/data/derived/blobs/', 'glob_local_format': 'pdf', 'model_logger': 'wandb', 'impa_image_format': 'jpg', 'impa_clip_areas': [2750, 4900], 'impa_area_threshold': 5000, 'impa_gaussian_blur': [21, 21], 'impa_gaussian_blur_passes': 3, 'impa_adaptive_threshold_block_size': 351, 'impa_mask_postprocess_kernel': [11, 11], 'impa_mask_postprocess_passes': 5, 'impa_bounding_box_buffer': 200, 'impa_save_clips_plus_features': True, 'lset_class_cut': 'order', 'lset_val_size': 0.1, 'lset_taxonomy': '/work/mzb-workflow/data/MZB_taxonomy.csv', 'trcl_learning_rate': 0.001, 'trcl_batch_size': 16, 'trcl_weight_decay': 0, 'trcl_step_size_decay': 5, 'trcl_number_epochs': 10, 'trcl_save_topk': 1, 'trcl_num_classes': 8, 'trcl_model_pretrarch': 'convnext-small', 'trcl_num_workers': 16, 'trcl_wandb_project_name': 'mzb-classifiers', 'trcl_logger': 'wandb', 'trsk_learning_rate': 0.001, 'trsk_batch_size': 32, 'trsk_weight_decay': 0, 'trsk_step_size_decay': 25, 'trsk_number_epochs': 400, 'trsk_save_topk': 1, 'trsk_num_classes': 2, 'trsk_model_pretrarch': 'mit_b2', 'trsk_num_workers': 16, 'trsk_wandb_project_name': 'mzb-skeletons', 'trsk_logger': 'wandb', 'infe_model_ckpt': 'last', 'infe_num_classes': 8, 'infe_image_glob': '*_rgb.jpg', 'skel_class_exclude': 'errors', 'skel_conv_rate': 131.6625, 'skel_label_thickness': 3, 'skel_label_buffer_on_preds': 25, 'skel_label_clip_with_mask': False}\n" ] } ], "source": [ "# Transforms configurations dicts to argparse arguments\n", "args = cfg_to_arguments(arguments)\n", "cfg = cfg_to_arguments(cfg)\n", "print(str(cfg))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If there is a trained model available, load those weights, and set up model directories. 
" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/data/shared/mzb-workflow/data/mzb_example_data/training_dataset/val_set/ephemeroptera/ephemeroptera\n" ] } ], "source": [ "dirs = find_checkpoints(\n", " Path(args.input_model).parents[0],\n", " version=Path(args.input_model).name,\n", " log=cfg.infe_model_ckpt,\n", ")\n", "\n", "mod_path = dirs[0]\n", "\n", "model = MZBModel_skels()\n", "model.model = model.load_from_checkpoint(\n", " checkpoint_path=mod_path,\n", ")\n", "\n", "model.data_dir = Path(args.input_dir)\n", "model.im_folder = model.data_dir / \"ephemeroptera\"\n", "# model.bo_folder = model.data_dir / \"sk_body\"\n", "# model.he_folder = model.data_dir / \"sk_head\"\n", "\n", "print(model.im_folder)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Set up additional parameters for model inference: " ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "# reindex trn/val split (this step is unfortunately necessary to get the model to work) \n", "np.random.seed(12)\n", "N = len(list(model.im_folder.glob(\"*.jpg\")))\n", "# model.trn_inds = sorted(\n", "# list(np.random.choice(np.arange(N), size=int(0.8 * N), replace=False))\n", "# )\n", "# model.val_inds = sorted(list(set(np.arange(N)).difference(set(model.trn_inds))))\n", "model.eval()\n", "model.freeze()\n", "\n", "args.input_type = \"external\"\n", "dataloader = model.external_dataloader(args.input_dir)\n", "dataset_name = \"external\"\n", "\n", "im_fi = dataloader.dataset.img_paths\n", "\n", "pbar_cb = pl.callbacks.progress.TQDMProgressBar(refresh_rate=5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can run the inference using the trained model. 
" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[PosixPath('/data/shared/mzb-workflow/data/mzb_example_data/training_dataset/val_set/ephemeroptera/31_b1_ephemeroptera_01_clip_1_rgb.jpg')\n", " PosixPath('/data/shared/mzb-workflow/data/mzb_example_data/training_dataset/val_set/ephemeroptera/31_b2_baetis_01_clip_1_rgb.jpg')\n", " PosixPath('/data/shared/mzb-workflow/data/mzb_example_data/training_dataset/val_set/ephemeroptera/31_hf2_baetidae_01_clip_4_rgb.jpg')\n", " PosixPath('/data/shared/mzb-workflow/data/mzb_example_data/training_dataset/val_set/ephemeroptera/31_ob_ephemeroptera_01_clip_6_rgb.jpg')\n", " PosixPath('/data/shared/mzb-workflow/data/mzb_example_data/training_dataset/val_set/ephemeroptera/32_b2_baetis_01_clip_1_rgb.jpg')\n", " PosixPath('/data/shared/mzb-workflow/data/mzb_example_data/training_dataset/val_set/ephemeroptera/32_bd_baetidae_01_clip_4_rgb.jpg')\n", " PosixPath('/data/shared/mzb-workflow/data/mzb_example_data/training_dataset/val_set/ephemeroptera/32_hf2_baetidae_01_clip_5_rgb.jpg')]\n", "0\n", "\n" ] } ], "source": [ "print(im_fi)\n", "print(N)\n", "\n", "print(dataloader.dataset)" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", "IPU available: False, using: 0 IPUs\n", "HPU available: False, using: 0 HPUs\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Predicting DataLoader 0: 100%|██████████| 1/1 [00:06<00:00, 6.52s/it]\n" ] } ], "source": [ "trainer = pl.Trainer(\n", " precision=32,\n", " max_epochs=1,\n", " accelerator=\"gpu\" if torch.cuda.is_available() else \"cpu\",\n", " devices=1 if torch.cuda.is_available() else None,\n", " callbacks=[pbar_cb],\n", " enable_checkpointing=False,\n", " logger=False,\n", 
")\n", "\n", "outs = trainer.predict(\n", " model=model, dataloaders=[dataloader], return_predictions=True\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We now aggregate the predictions and refine the skeletons produced. " ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 0%| | 0/7 [00:00 0) and (not cfg.skel_label_clip_with_mask):\n", " mask = np.ones_like(x[0, 0, ...])\n", " mask[-cfg.skel_label_buffer_on_preds :, :] = 0\n", " mask[: cfg.skel_label_buffer_on_preds, :] = 0\n", " mask[:, : cfg.skel_label_buffer_on_preds] = 0\n", " mask[:, -cfg.skel_label_buffer_on_preds :] = 0\n", "\n", " mask = Image.fromarray(mask)\n", " mask = np.array(\n", " transforms.Resize(\n", " (o_size[1], o_size[0]),\n", " interpolation=transforms.InterpolationMode.BILINEAR,\n", " )(mask)\n", " )\n", " refined_skel = [\n", " (thin(a) > 0).astype(float) * mask for a in refined_skel[0:2, ...] > 50\n", " ]\n", " elif cfg.skel_label_clip_with_mask:\n", " # load mask\n", " mask_insect = Image.open(\n", " cfg.glob_blobs_folder / ti.name[:-4] + \"_mask.jpg\"\n", " ).convert(\"RGB\")\n", " mask_insect = np.array(mask_insect)[:, :, 0] > 0\n", " mask_insect = Image.fromarray(mask_insect)\n", " mask_insect = np.array(\n", " transforms.Resize(\n", " (o_size[1], o_size[0]),\n", " interpolation=transforms.InterpolationMode.BILINEAR,\n", " )(mask_insect)\n", " )\n", " refined_skel = [\n", " (thin(a) > 0).astype(float) * mask_insect\n", " for a in refined_skel[0:2, ...] > 50\n", " ]\n", "\n", " else:\n", " # Refine the predicted skeleton image\n", " refined_skel = [\n", " (thin(a) > 0).astype(float) for a in refined_skel[0:2, ...] 
> 50\n", " ]\n", "\n", " refined_skel = [(255 * s).astype(np.uint8) for s in refined_skel]\n", "\n", " if args.save_masks:\n", " name = \"_\".join(ti.name.split(\"_\")[:-1])\n", " cv2.imwrite(\n", " str(args.save_masks / f\"{name}_body.jpg\"),\n", " refined_skel[0],\n", " [cv2.IMWRITE_JPEG_QUALITY, 100],\n", " )\n", " cv2.imwrite(\n", " str(args.save_masks / f\"{name}_head.jpg\"),\n", " refined_skel[1],\n", " [cv2.IMWRITE_JPEG_QUALITY, 100],\n", " )\n", "\n", " preds_size.append(\n", " pd.DataFrame(\n", " {\n", " \"clip_name\": \"_\".join(ti.name.split(\".\")[0].split(\"_\")[:-1]),\n", " \"nn_pred_body\": [np.sum(refined_skel[0] > 0)],\n", " \"nn_pred_head\": [np.sum(refined_skel[1] > 0)],\n", " }\n", " )\n", " )\n", "\n", "preds_size = pd.concat(preds_size)\n", "# out_dir = Path(\n", "# f\"{args.output_dir}_{dataset_name}_{datetime.now().strftime('%Y%m%d_%H%M')}\"\n", "# )\n", "out_dir = Path(f\"{args.output_dir}\")\n", "\n", "out_dir.mkdir(exist_ok=True, parents=True)\n", "\n", "preds_size.to_csv(out_dir / f\"size_skel_supervised_model.csv\", index=False)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "None\n" ] } ], "source": [ "print(outs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 }