{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Skeletonization unsupervised\n",
"========================\n",
"\n",
"In this notebook we use the script `skeletons/main_unsupervised_skeleton_estimation.py` to automatically extract the length of the organisms from clips. \n",
"\n",
"First, import the necessary libraries: "
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import copy\n",
"import sys\n",
"from pathlib import Path\n",
"from datetime import datetime\n",
"import argparse\n",
"\n",
"import cv2\n",
"import yaml\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"from matplotlib import pyplot as plt\n",
"from scipy.spatial import distance_matrix\n",
"from skimage.measure import label, regionprops\n",
"from skimage.morphology import dilation, disk, medial_axis, thin\n",
"from tqdm import tqdm\n",
"\n",
"from mzbsuite.skeletons.mzb_skeletons_helpers import (\n",
" get_endpoints,\n",
" get_intersections,\n",
" paint_image,\n",
" segment_skel,\n",
" traverse_graph,\n",
")\n",
"\n",
"from mzbsuite.utils import cfg_to_arguments#, noneparse"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then we set up some running parameters for the scipt to know where to find input images and where to write outputs. "
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"ename": "FileNotFoundError",
"evalue": "[Errno 2] No such file or directory: '/data/shared/configs/configuration_flume_datasets.yaml'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m/data/shared/mzb-workflow/docs/source/files/examples/skeletonization_unsupervised.ipynb Cell 4\u001b[0m line \u001b[0;36m1\n\u001b[1;32m 1\u001b[0m ROOT_DIR \u001b[39m=\u001b[39m Path(\u001b[39m\"\u001b[39m\u001b[39m/data/shared/mzb-workflow\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 3\u001b[0m arguments \u001b[39m=\u001b[39m {\n\u001b[1;32m 4\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mconfig_file\u001b[39m\u001b[39m\"\u001b[39m: ROOT_DIR\u001b[39m.\u001b[39mparent\u001b[39m.\u001b[39mabsolute() \u001b[39m/\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mconfigs/configuration_flume_datasets.yaml\u001b[39m\u001b[39m\"\u001b[39m, \n\u001b[1;32m 5\u001b[0m \u001b[39m\"\u001b[39m\u001b[39minput_dir\u001b[39m\u001b[39m\"\u001b[39m: ROOT_DIR\u001b[39m.\u001b[39mparent\u001b[39m.\u001b[39mabsolute() \u001b[39m/\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mdata/derived/mzb_example_data/\u001b[39m\u001b[39m\"\u001b[39m, \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mlist_of_files\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39mNone\u001b[39;00m\n\u001b[1;32m 9\u001b[0m }\n\u001b[0;32m---> 11\u001b[0m \u001b[39mwith\u001b[39;00m \u001b[39mopen\u001b[39;49m(\u001b[39mstr\u001b[39;49m(arguments[\u001b[39m\"\u001b[39;49m\u001b[39mconfig_file\u001b[39;49m\u001b[39m\"\u001b[39;49m]), \u001b[39m\"\u001b[39;49m\u001b[39mr\u001b[39;49m\u001b[39m\"\u001b[39;49m) \u001b[39mas\u001b[39;00m f:\n\u001b[1;32m 12\u001b[0m cfg \u001b[39m=\u001b[39m yaml\u001b[39m.\u001b[39mload(f, Loader\u001b[39m=\u001b[39myaml\u001b[39m.\u001b[39mFullLoader)\n\u001b[1;32m 14\u001b[0m cfg[\u001b[39m\"\u001b[39m\u001b[39mtrcl_gpu_ids\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m \u001b[39m# this sets the number of available GPUs to zero, since this script doesn't benefit from GPU compute. \u001b[39;00m\n",
"File \u001b[0;32m~/mambaforge/envs/mzbsuite/lib/python3.10/site-packages/IPython/core/interactiveshell.py:284\u001b[0m, in \u001b[0;36m_modified_open\u001b[0;34m(file, *args, **kwargs)\u001b[0m\n\u001b[1;32m 277\u001b[0m \u001b[39mif\u001b[39;00m file \u001b[39min\u001b[39;00m {\u001b[39m0\u001b[39m, \u001b[39m1\u001b[39m, \u001b[39m2\u001b[39m}:\n\u001b[1;32m 278\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\n\u001b[1;32m 279\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mIPython won\u001b[39m\u001b[39m'\u001b[39m\u001b[39mt let you open fd=\u001b[39m\u001b[39m{\u001b[39;00mfile\u001b[39m}\u001b[39;00m\u001b[39m by default \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 280\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mas it is likely to crash IPython. If you know what you are doing, \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 281\u001b[0m \u001b[39m\"\u001b[39m\u001b[39myou can use builtins\u001b[39m\u001b[39m'\u001b[39m\u001b[39m open.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 282\u001b[0m )\n\u001b[0;32m--> 284\u001b[0m \u001b[39mreturn\u001b[39;00m io_open(file, \u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n",
"\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '/data/shared/configs/configuration_flume_datasets.yaml'"
]
}
],
"source": [
"ROOT_DIR = Path(\"/data/shared/mzb-workflow/docs\")\n",
"\n",
"arguments = {\n",
" \"config_file\": ROOT_DIR.parent.absolute() / \"configs/configuration_flume_datasets.yaml\", \n",
" \"input_dir\": ROOT_DIR.parent.absolute() / \"data/derived/mzb_example_data/\", \n",
" \"output_dir\": ROOT_DIR.parent.absolute() / \"results/mzb_example_data/skeletons/skeletons_unsupervised/\", \n",
" \"save_masks\": ROOT_DIR.parent.absolute() / \"data/derived/mzb_example_data/skeletons/skeletons_unsupervised\", \n",
" \"list_of_files\": None\n",
" }\n",
" \n",
"with open(str(arguments[\"config_file\"]), \"r\") as f:\n",
" cfg = yaml.load(f, Loader=yaml.FullLoader)\n",
"\n",
"cfg[\"trcl_gpu_ids\"] = None # this sets the number of available GPUs to zero, since this script doesn't benefit from GPU compute. \n",
"cfg"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Convert to dictionary for Python: "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Transforms configurations dicts to argparse arguments\n",
"args = cfg_to_arguments(arguments)\n",
"cfg = cfg_to_arguments(cfg)\n",
"print(str(cfg))"
]
},
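{
"cell_type": "markdown",
"metadata": {},
"source": [
"`cfg_to_arguments` turns these plain dictionaries into objects whose entries can be read as attributes (e.g. `cfg.skel_conv_rate` instead of `cfg[\"skel_conv_rate\"]`). A minimal sketch of the equivalent behaviour using `argparse.Namespace` (an illustration only, not the actual implementation in `mzbsuite.utils`): \n",
"\n",
"```python\n",
"from argparse import Namespace\n",
"\n",
"# hypothetical configuration dictionary\n",
"cfg_dict = {\"skel_conv_rate\": 20.0, \"impa_image_format\": \"png\"}\n",
"cfg_ns = Namespace(**cfg_dict)\n",
"print(cfg_ns.skel_conv_rate)  # attribute access instead of key lookup\n",
"```"
]
},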
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With the toggle `PLOTS` we can make the script produce graphical output as it processes the clips (if `PLOTS = True`) or not; this is useful to debug potential issues with parametrization. \n",
"\n",
"We also check whether the output folders exist already, otherwise create them. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"PLOTS = True\n",
"\n",
"if args.save_masks is not None:\n",
" args.save_masks = Path(f\"{args.save_masks}\")\n",
" args.save_masks.mkdir(parents=True, exist_ok=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We define some area bins to differentiate morphological operations downstream: "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# setup some area-specific parameters for filtering\n",
"area_class = {\n",
" 0: {\"area\": [0, 10000], \"thinning\": 1, \"lmode\": \"skeleton\"},\n",
" 2: {\"area\": [10000, 15000], \"thinning\": 9, \"lmode\": \"skeleton\"},\n",
" 3: {\"area\": [15000, 20000], \"thinning\": 11, \"lmode\": \"skeleton\"},\n",
" 4: {\"area\": [20000, 50000], \"thinning\": 11, \"lmode\": \"skeleton\"},\n",
" 5: {\"area\": [50000, 100000], \"thinning\": 15, \"lmode\": \"skeleton\"},\n",
" 6: {\"area\": [100000, np.inf], \"thinning\": 20, \"lmode\": \"skeleton\"},\n",
"}\n"
]
},
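{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each bin maps a mask-area range (in pixels) to a `thinning` value, which the main loop below uses as a threshold on the medial-axis distance map: larger organisms get a stronger erosion before thinning. A minimal sketch of how a clip's area selects its parameters (the area value is made up for illustration): \n",
"\n",
"```python\n",
"# hypothetical mask area in pixels; in the pipeline this is np.sum(mask_)\n",
"area = 12500\n",
"dpar = next(\n",
"    par[\"thinning\"]\n",
"    for par in area_class.values()\n",
"    if par[\"area\"][0] <= area < par[\"area\"][1]\n",
")\n",
"print(dpar)  # -> 9, the distance threshold used for the 10000-15000 px bin\n",
"```"
]
},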
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load in clips, excluding those predicted to be `error` by the DL model. \n",
"\n",
"> ⚠️ \n",
"> the path to the folder with the clips classified as error is currently hardcoded in the script! "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"# Load in all masks in the input directory\n",
"mask_list = sorted(\n",
" list(Path(args.input_dir).glob(f\"*_mask.{cfg.impa_image_format}\"))\n",
")\n",
"\n",
"if args.list_of_files is not None:\n",
" # select all files that are not predicted as \"error\" by the classification model\n",
" predictions = (\n",
" pd.read_csv(args.list_of_files).set_index(\"file\").sort_values(\"file\")\n",
" )\n",
" exclude = predictions[\n",
" predictions[cfg.skel_class_exclude] > 1.0 / cfg.infe_num_classes\n",
" ].index.to_list()\n",
" exclude = [\n",
" (\"_\".join(a.split(\"_\")[:-1]) + f\"_mask.{cfg.impa_image_format}\").lower()\n",
" for a in exclude\n",
" ]\n",
"else:\n",
" exclude = []\n",
"\n",
"# load in file names that are classified as error by our CNN\n",
"err_filenames = sorted(\n",
" list(\n",
" Path(\n",
" f\"{cfg.glob_root_folder}/data/learning_sets/project_portable_flume/curated_learning_sets/errors\"\n",
" ).glob(\"*.png\")\n",
" )\n",
")\n",
"exclude += [\n",
" (\"_\".join(a.name.split(\"_\")[:-1]) + f\"_mask.{cfg.impa_image_format}\").lower()\n",
" for a in err_filenames\n",
"]"
]
},
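{
"cell_type": "markdown",
"metadata": {},
"source": [
"The exclusion entries are derived from the prediction filenames by dropping the last underscore-separated token and appending `_mask.<image format>`. For example (hypothetical filename, assuming `impa_image_format` is `png`): \n",
"\n",
"```python\n",
"name = \"1_ob_mixed_clip_32_rgb.png\"  # hypothetical clip filename\n",
"mask_name = (\"_\".join(name.split(\"_\")[:-1]) + \"_mask.png\").lower()\n",
"print(mask_name)  # -> '1_ob_mixed_clip_32_mask.png'\n",
"```"
]
},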
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Set up output directories and files: "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"files_to_skel = [a for a in mask_list if a.name.lower() not in exclude]\n",
"\n",
"# %%\n",
"out_dir = (\n",
" args.output_dir\n",
" / f\"{args.input_dir.name}_unsupervised_{datetime.now().strftime('%Y%m%d_%H%M')}\"\n",
")\n",
"out_dir.mkdir(parents=True, exist_ok=True)\n",
"\n",
"# %%\n",
"growing_df = []\n",
"# Load the image\n",
"# PLOTS = True\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following code block is the implementation of the skeletonization algorithm. Please refer to the documentation for details of the functioning. \n",
"\n",
"In a nutshell, it sues the configuration parameters provided before to apply a series of morphological operations on the binary mask of each organism's clip, subsequently thinning it into segment(s), eventually connecting and calculating the longest path through them, thus producing the skeleton, which should approximate well the length of the organism. "
]
},
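{
"cell_type": "markdown",
"metadata": {},
"source": [
"The key step is the final selection: the labelled skeleton segments are treated as nodes of a graph, candidate end-to-end paths are collected with `traverse_graph`, and the path with the largest summed pixel length is kept as the skeleton. A toy illustration of that selection step, with made-up segment labels and lengths: \n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# toy candidate paths (lists of segment labels) and made-up per-segment pixel lengths,\n",
"# mirroring skel_cand and edge_attributes in the loop below\n",
"edge_attributes = {1: 120, 2: 45, 3: 80, 4: 30}\n",
"skel_cand = [[1, 2, 3], [1, 2, 4], [3, 4]]\n",
"\n",
"sk_l = [sum(edge_attributes[i] for i in path) for path in skel_cand]\n",
"print(skel_cand[np.argmax(sk_l)], max(sk_l))  # -> [1, 2, 3] 245: kept as the skeleton\n",
"```"
]
},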
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"iterator = tqdm(files_to_skel, total=len(files_to_skel))\n",
"# iterator = tqdm([args.input_dir / \"1_ob_mixed_difficutly_clip_32_mask.jpg\"])\n",
"for fo in iterator:\n",
" iterator.set_description(fo.name)\n",
" # read in mask and rgb, rgb only for plotting\n",
" mask_ = (cv2.imread(str(fo))[:, :, 0] / 255).astype(float)\n",
"\n",
" # Get needed filter size based on area\n",
" for aa in area_class:\n",
" if area_class[aa][\"area\"][0] < np.sum(mask_) < area_class[aa][\"area\"][1]:\n",
" dpar = area_class[aa][\"thinning\"]\n",
"\n",
" # Find the medial axis, threshold it and clean if multiple regions, keep largest\n",
" _, distance = medial_axis(mask_, return_distance=True)\n",
" mask_dist = distance > dpar\n",
" regs = label(mask_dist)\n",
" props = regionprops(regs)\n",
"\n",
" # keep only the largest region of the eroded mask\n",
" mask = regs == np.argmax([p.area for p in props if p.label > 0]) + 1\n",
"\n",
" # compute general skeleton by thinning the masks\n",
" skeleton = thin(mask, max_num_iter=None)\n",
"\n",
" # get coordinates of point that intersect or are ends of the skeleton segments\n",
" inter = get_intersections(skeleton=skeleton.astype(np.uint8))\n",
" endpo = get_endpoints(skeleton=skeleton.astype(np.uint8))\n",
"\n",
" if args.save_masks:\n",
" # save the skeletonized mask\n",
" cv2.imwrite(\n",
" str(args.save_masks / f\"{''.join(fo.name.split('.')[:-1])}_skel.jpg\"),\n",
" (255 * skeleton / np.max(skeleton)).astype(np.uint8),\n",
" )\n",
"\n",
" if PLOTS:\n",
" rgb_ = cv2.imread(str(fo)[:-8] + \"rgb.jpg\")[:, :, [2, 1, 0]].astype(\n",
" np.uint8\n",
" )\n",
" rgb_fi = paint_image(rgb_, skeleton, color=[255, 0, 0])\n",
" rgb_ma = paint_image(rgb_, mask, color=[255, 0, 255])\n",
"\n",
" if inter:\n",
" # then, deduplicate the intersections\n",
" skel_labels, edge_attributes, skprop = segment_skel(skeleton, inter, conn=1)\n",
" ds = distance_matrix(inter, inter) + 100 * np.eye(len(inter))\n",
" duplicates = np.where(ds < 3)[0]\n",
" try:\n",
" inter = [a for a in inter if a != inter[duplicates[0]]]\n",
" except:\n",
" pass\n",
" else:\n",
" skel_labels = []\n",
"\n",
" # case for which there are no segments (ie, only one)\n",
" if len(np.unique(skel_labels)) < 3:\n",
" sub_df = pd.DataFrame(\n",
" data={\n",
" \"clip_filename\": fo.name,\n",
" \"conv_rate_mm_px\": [cfg.skel_conv_rate],\n",
" \"skel_length\": [np.sum(skeleton)],\n",
" \"skel_length_mm\": [np.sum(skeleton) / cfg.skel_conv_rate],\n",
" \"segms\": [[0]],\n",
" \"area\": np.sum(mask_),\n",
" }\n",
" )\n",
" growing_df.append(sub_df)\n",
"\n",
" if PLOTS:\n",
" f, a = plt.subplots(1, 2)\n",
" a[0].imshow(rgb_fi)\n",
" a[1].imshow(rgb_ma)\n",
" plt.title(f\"Area: {np.sum(mask_)}\")\n",
"\n",
" else:\n",
" # remove nodes that are too close (less than 3px) and treat them as only one node\n",
" # skel_labels, edge_attributes, skprop = segment_skel(skeleton, inter, conn=1)\n",
" # ds = distance_matrix(inter, inter) + 100 * np.eye(len(inter))\n",
" # duplicates = np.where(ds < 3)[0]\n",
" # try:\n",
" # inter = [a for a in inter if a != inter[duplicates[0]]]\n",
" # except:\n",
" # pass\n",
"\n",
" if args.save_masks:\n",
" skel_masks_path = Path(args.save_masks)\n",
" skel_masks_path.mkdir(parents=True, exist_ok=True)\n",
" # save the skeletonized mask\n",
" cv2.imwrite(\n",
" str(\n",
" args.save_masks / f\"{''.join(fo.name.split('.')[:-1])}_skel.jpg\"\n",
" ),\n",
" (255 * skel_labels / np.max(skel_labels)).astype(np.uint8),\n",
" )\n",
"\n",
" # get the segments that touch each intersection, and make them neighbors\n",
" intersection_nodes = []\n",
" for coord in inter:\n",
" local_cut = skel_labels[\n",
" (coord[1] - 4) : (coord[1] + 5), (coord[0] - 4) : (coord[0] + 5)\n",
" ]\n",
" nodes_touch = np.unique(local_cut[local_cut != 0])\n",
" intersection_nodes.append(list(nodes_touch))\n",
"\n",
" # remove duplicates\n",
" k = sorted(intersection_nodes)\n",
" dedup = [k[i] for i in range(len(k)) if i == 0 or k[i] != k[i - 1]]\n",
" intersection_nodes = dedup\n",
"\n",
" # get the segments that touch each endpoint\n",
" dead_ends = []\n",
" for coord in endpo:\n",
" ends = skel_labels[\n",
" (coord[1] - 4) : (coord[1] + 5), (coord[0] - 4) : (coord[0] + 5)\n",
" ]\n",
" end_node = np.unique(ends[ends != 0])\n",
" dead_ends.append(list(end_node))\n",
" dead_ends = sorted(dead_ends)\n",
"\n",
" # build the graph of segments of the skeleton\n",
" graph = {}\n",
" for nod in np.unique(skel_labels[skel_labels > 0]):\n",
" nei = [a for a in intersection_nodes if nod in a]\n",
" nei = [item for sublist in nei for item in sublist]\n",
" graph[nod] = list(set(nei).difference([nod]))\n",
"\n",
" end_nodes = copy.deepcopy(dead_ends)\n",
"\n",
" # tf is this\n",
" end_nodes = [i for a in end_nodes for i in a]\n",
" all_paths = []\n",
" c = 0\n",
"\n",
" # traverse the graph for all end_nodes and get paths, append them to all_paths\n",
" for init in end_nodes[:1]:\n",
" p_i = traverse_graph(graph, init, end_nodes, debug=False)\n",
" all_paths.extend(p_i)\n",
"\n",
" # remove doubles\n",
" skel_cand = []\n",
" for sk in all_paths:\n",
" if sorted(sk) not in skel_cand:\n",
" skel_cand.append(sorted(sk))\n",
"\n",
" # measure path lengths and keep max one, that is the skel for you\n",
" sk_l = []\n",
" for sk in skel_cand:\n",
" cus = 0\n",
" for i in sk:\n",
" cus += edge_attributes[i]\n",
" sk_l.append(cus)\n",
"\n",
" # append to dataframe, some properties\n",
" sub_df = pd.DataFrame(\n",
" data={\n",
" \"clip_filename\": fo.name,\n",
" \"conv_rate_mm_px\": [cfg.skel_conv_rate],\n",
" \"skel_length\": [sk_l[np.argmax(sk_l)]],\n",
" \"skel_length_mm\": [sk_l[np.argmax(sk_l)] / cfg.skel_conv_rate],\n",
" \"segms\": [skel_cand[np.argmax(sk_l)]],\n",
" \"area\": np.sum(mask_),\n",
" }\n",
" )\n",
" growing_df.append(sub_df)\n",
"\n",
" if PLOTS:\n",
" f, a = plt.subplots(1, 3, figsize=(12, 12))\n",
" a[0].imshow(\n",
" paint_image(\n",
" skel_labels * 255,\n",
" dilation(skel_labels > 0, disk(3)),\n",
" [255, 0, 255],\n",
" )\n",
" )\n",
"\n",
" a[0].scatter(np.array(inter)[:, 0], np.array(inter)[:, 1])\n",
" a[0].scatter(np.array(endpo)[:, 0], np.array(endpo)[:, 1], marker=\"s\")\n",
" for i in np.unique(skel_labels[skel_labels > 0]):\n",
" a[0].text(\n",
" x=skprop[i - 1].centroid[1],\n",
" y=skprop[i - 1].centroid[0],\n",
" s=f\"{i}\",\n",
" color=\"white\",\n",
" )\n",
"\n",
" sel_skel = np.zeros_like(skel_labels)\n",
" for i in np.unique(skel_labels[skel_labels > 0]):\n",
" if i in skel_cand[np.argmax(sk_l)]:\n",
" sel_skel += dilation(skel_labels == i, disk(3))\n",
" sel_skel = sel_skel > 0\n",
"\n",
" a[1].imshow(paint_image(rgb_fi, sel_skel, [255, 0, 0]))\n",
" a[2].imshow(rgb_ma)\n",
"\n",
" a[0].title.set_text(f\"Area: {np.sum(mask_)}\")\n",
" a[1].title.set_text(f\"Sel Segm: {skel_cand[np.argmax(sk_l)]}\")\n",
" a[2].title.set_text(f\"Skel_lenght_px {sk_l[np.argmax(sk_l)]}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, save the `.csv` file. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"full_df = pd.concat(growing_df)\n",
"full_df.to_csv(out_dir / \"skeleton_attributes.csv\", index=False)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}