Source code for scripts.skeletons.main_unsupervised_skeleton_estimation

# %% test skimage skeletonize
import copy
import sys
from pathlib import Path
from datetime import datetime
import argparse

import cv2
import yaml

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from scipy.spatial import distance_matrix
from skimage.measure import label, regionprops
from skimage.morphology import dilation, disk, medial_axis, thin
from tqdm import tqdm

from mzbsuite.skeletons.mzb_skeletons_helpers import (
    get_endpoints,
    get_intersections,
    paint_image,
    segment_skel,
    traverse_graph,
)

from mzbsuite.utils import cfg_to_arguments, noneparse


def _mask_filename(clip_name, image_format):
    """Map a clip file name to the lower-cased file name of its mask."""
    return ("_".join(clip_name.split("_")[:-1]) + f"_mask.{image_format}").lower()


def _thinning_for_area(area, area_class):
    """Return the thinning threshold of the area bin containing ``area``, or None."""
    for params in area_class.values():
        low, high = params["area"]
        # Half-open bins: an area exactly on a boundary belongs to the upper bin
        # instead of matching no bin at all.
        if low <= area < high:
            return params["thinning"]
    return None


def _labels_near(skel_labels, coord, radius=4):
    """Return the non-zero segment labels found in a square window around ``coord``."""
    col, row = coord[0], coord[1]
    # Clamp the lower bounds: a negative start index would wrap around and
    # yield an empty window for points close to the top/left image border.
    window = skel_labels[
        max(row - radius, 0) : row + radius + 1,
        max(col - radius, 0) : col + radius + 1,
    ]
    return list(np.unique(window[window != 0]))


def main(args, cfg):
    """
    Main function for skeleton estimation (body size) in the unsupervised setting.

    For every mask found in ``args.input_dir`` the mask is eroded through a
    threshold on its medial-axis distance, the largest remaining region is
    thinned to a skeleton, the skeleton is split into segments at its
    intersections and the longest candidate path of segments is stored as the
    skeleton length.

    Parameters
    ----------
    args : argparse.Namespace
        Arguments parsed from command line. Namely:

        - input_dir: directory containing the ``*_mask.<format>`` files
        - output_dir: directory in which a time-stamped results folder is created
        - save_masks: directory where skeleton images are saved as jpg, or None
          to skip saving. When given it is converted to a ``Path`` in place.
        - list_of_files: csv file with the classification predictions, or None
        - verbose: currently not used by this function

    cfg : argparse.Namespace
        Arguments parsed from the configuration file. The attributes read here are
        ``impa_image_format``, ``skel_class_exclude``, ``infe_num_classes``,
        ``glob_root_folder`` and ``skel_conv_rate``.

    Returns
    -------
    None. ``skeleton_attributes.csv`` is written to the results folder. Masks
    that cannot be processed are skipped with a message instead of aborting
    the whole run.
    """
    PLOTS = False  # switch to True to draw debugging figures for every mask

    if args.save_masks is not None:
        args.save_masks = Path(f"{args.save_masks}")
        args.save_masks.mkdir(parents=True, exist_ok=True)

    # Accept both str and Path for the directories.
    input_dir = Path(args.input_dir)
    output_dir = Path(args.output_dir)

    # setup some area-specific parameters for filtering
    area_class = {
        0: {"area": [0, 10000], "thinning": 1, "lmode": "skeleton"},
        2: {"area": [10000, 15000], "thinning": 9, "lmode": "skeleton"},
        3: {"area": [15000, 20000], "thinning": 11, "lmode": "skeleton"},
        4: {"area": [20000, 50000], "thinning": 11, "lmode": "skeleton"},
        5: {"area": [50000, 100000], "thinning": 15, "lmode": "skeleton"},
        6: {"area": [100000, np.inf], "thinning": 20, "lmode": "skeleton"},
    }
    columns = [
        "clip_filename",
        "conv_rate_mm_px",
        "skel_length",
        "skel_length_mm",
        "segms",
        "area",
    ]

    # Load in all masks in the input directory
    mask_list = sorted(input_dir.glob(f"*_mask.{cfg.impa_image_format}"))

    # Names of the masks to leave out; a set keeps the membership test cheap.
    exclude = set()
    if args.list_of_files is not None:
        # exclude the clips whose score for the excluded class is above chance level
        predictions = (
            pd.read_csv(args.list_of_files).set_index("file").sort_values("file")
        )
        flagged = predictions[
            predictions[cfg.skel_class_exclude] > 1.0 / cfg.infe_num_classes
        ].index.to_list()
        exclude.update(_mask_filename(a, cfg.impa_image_format) for a in flagged)

    # load in file names found in the curated "errors" learning set
    err_filenames = sorted(
        Path(
            f"{cfg.glob_root_folder}/data/learning_sets/project_portable_flume/curated_learning_sets/errors"
        ).glob("*.png")
    )
    exclude.update(
        _mask_filename(a.name, cfg.impa_image_format) for a in err_filenames
    )

    files_to_skel = [a for a in mask_list if a.name.lower() not in exclude]

    # %%
    out_dir = (
        output_dir
        / f"{input_dir.name}_unsupervised_{datetime.now().strftime('%Y%m%d_%H%M')}"
    )
    out_dir.mkdir(parents=True, exist_ok=True)

    # %%
    growing_df = []

    iterator = tqdm(files_to_skel, total=len(files_to_skel))
    for fo in iterator:
        iterator.set_description(fo.name)

        # read in the mask; only the first channel is used
        image = cv2.imread(str(fo))
        if image is None:
            tqdm.write(f"Skipping {fo.name}: file could not be read as an image")
            continue
        mask_ = (image[:, :, 0] / 255).astype(float)
        area = np.sum(mask_)

        # Get needed filter size based on area
        dpar = _thinning_for_area(area, area_class)
        if dpar is None:
            tqdm.write(f"Skipping {fo.name}: no area class matches area {area}")
            continue

        # Find the medial axis, threshold it and clean if multiple regions, keep largest
        _, distance = medial_axis(mask_, return_distance=True)
        regs = label(distance > dpar)
        props = regionprops(regs)
        if not props:
            tqdm.write(f"Skipping {fo.name}: nothing left after thresholding at {dpar}")
            continue

        # keep only the largest region of the eroded mask
        largest = max(props, key=lambda p: p.area)
        mask = regs == largest.label

        # compute general skeleton by thinning the mask
        skeleton = thin(mask, max_num_iter=None)

        # get coordinates of point that intersect or are ends of the skeleton segments
        inter = get_intersections(skeleton=skeleton.astype(np.uint8))
        endpo = get_endpoints(skeleton=skeleton.astype(np.uint8))

        if args.save_masks:
            # save the skeletonized mask
            cv2.imwrite(
                str(args.save_masks / f"{''.join(fo.name.split('.')[:-1])}_skel.jpg"),
                (255 * skeleton / np.max(skeleton)).astype(np.uint8),
            )

        if PLOTS:
            rgb_ = cv2.imread(str(fo)[:-8] + "rgb.jpg")[:, :, [2, 1, 0]].astype(
                np.uint8
            )
            rgb_fi = paint_image(rgb_, skeleton, color=[255, 0, 0])
            rgb_ma = paint_image(rgb_, mask, color=[255, 0, 255])

        if inter:
            # segment the skeleton first, then deduplicate the intersections:
            # the first node closer than 3px to another one is dropped
            skel_labels, edge_attributes, skprop = segment_skel(skeleton, inter, conn=1)
            ds = distance_matrix(inter, inter) + 100 * np.eye(len(inter))
            duplicates = np.where(ds < 3)[0]
            if duplicates.size:
                dup = tuple(inter[duplicates[0]])
                inter = [a for a in inter if tuple(a) != dup]
        else:
            skel_labels = []

        # case for which there are no segments (ie, only one)
        if len(np.unique(skel_labels)) < 3:
            # NOTE(review): the length in mm is obtained by dividing by
            # cfg.skel_conv_rate -- confirm the unit convention of that setting.
            sub_df = pd.DataFrame(
                data={
                    "clip_filename": fo.name,
                    "conv_rate_mm_px": [cfg.skel_conv_rate],
                    "skel_length": [np.sum(skeleton)],
                    "skel_length_mm": [np.sum(skeleton) / cfg.skel_conv_rate],
                    "segms": [[0]],
                    "area": area,
                }
            )
            growing_df.append(sub_df)

            if PLOTS:
                f, a = plt.subplots(1, 2)
                a[0].imshow(rgb_fi)
                a[1].imshow(rgb_ma)
                plt.title(f"Area: {area}")
            continue

        if args.save_masks:
            # overwrite the binary skeleton saved above with the labelled segments
            cv2.imwrite(
                str(args.save_masks / f"{''.join(fo.name.split('.')[:-1])}_skel.jpg"),
                (255 * skel_labels / np.max(skel_labels)).astype(np.uint8),
            )

        # get the segments that touch each intersection, and make them neighbors
        touching = sorted(_labels_near(skel_labels, coord) for coord in inter)
        # remove duplicates (equal entries are adjacent after sorting)
        intersection_nodes = [
            nodes
            for i, nodes in enumerate(touching)
            if i == 0 or nodes != touching[i - 1]
        ]

        # get the segments that touch each endpoint, flattened to one list
        dead_ends = sorted(_labels_near(skel_labels, coord) for coord in endpo)
        end_nodes = [node for nodes in dead_ends for node in nodes]

        # build the graph of segments of the skeleton
        graph = {}
        for nod in np.unique(skel_labels[skel_labels > 0]):
            nei = [a for a in intersection_nodes if nod in a]
            nei = [item for sublist in nei for item in sublist]
            graph[nod] = list(set(nei).difference([nod]))

        # traverse the graph and collect the paths between end nodes
        # NOTE(review): only the first end node is used as starting point
        # (end_nodes[:1]); confirm that this is intended.
        all_paths = []
        for init in end_nodes[:1]:
            all_paths.extend(traverse_graph(graph, init, end_nodes, debug=False))

        # remove doubles
        skel_cand = []
        for sk in all_paths:
            candidate = sorted(sk)
            if candidate not in skel_cand:
                skel_cand.append(candidate)

        if not skel_cand:
            tqdm.write(f"Skipping {fo.name}: no path found between skeleton end points")
            continue

        # measure path lengths and keep the longest one as the skeleton
        sk_l = [sum(edge_attributes[i] for i in sk) for sk in skel_cand]
        best = int(np.argmax(sk_l))

        # append to dataframe, some properties
        sub_df = pd.DataFrame(
            data={
                "clip_filename": fo.name,
                "conv_rate_mm_px": [cfg.skel_conv_rate],
                "skel_length": [sk_l[best]],
                "skel_length_mm": [sk_l[best] / cfg.skel_conv_rate],
                "segms": [skel_cand[best]],
                "area": area,
            }
        )
        growing_df.append(sub_df)

        if PLOTS:
            f, a = plt.subplots(1, 3, figsize=(12, 12))
            a[0].imshow(
                paint_image(
                    skel_labels * 255,
                    dilation(skel_labels > 0, disk(3)),
                    [255, 0, 255],
                )
            )
            a[0].scatter(np.array(inter)[:, 0], np.array(inter)[:, 1])
            a[0].scatter(np.array(endpo)[:, 0], np.array(endpo)[:, 1], marker="s")
            for i in np.unique(skel_labels[skel_labels > 0]):
                a[0].text(
                    x=skprop[i - 1].centroid[1],
                    y=skprop[i - 1].centroid[0],
                    s=f"{i}",
                    color="white",
                )

            sel_skel = np.zeros_like(skel_labels)
            for i in np.unique(skel_labels[skel_labels > 0]):
                if i in skel_cand[best]:
                    sel_skel += dilation(skel_labels == i, disk(3))
            sel_skel = sel_skel > 0

            a[1].imshow(paint_image(rgb_fi, sel_skel, [255, 0, 0]))
            a[2].imshow(rgb_ma)
            a[0].title.set_text(f"Area: {area}")
            a[1].title.set_text(f"Sel Segm: {skel_cand[best]}")
            a[2].title.set_text(f"Skel_lenght_px {sk_l[best]}")

    # pd.concat raises on an empty list, so write a header-only table in that case
    if growing_df:
        full_df = pd.concat(growing_df)
    else:
        full_df = pd.DataFrame(columns=columns)
    full_df.to_csv(out_dir / "skeleton_attributes.csv", index=False)
if __name__ == "__main__":
    # Script entry point: read the command-line options, load the YAML
    # configuration and hand both over to main().
    cli = argparse.ArgumentParser()
    for required_opt in ("--config_file", "--input_dir"):
        cli.add_argument(required_opt, type=str, required=True)
    cli.add_argument("--list_of_files", type=noneparse, required=False, default=None)
    for required_opt in ("--output_dir", "--save_masks"):
        cli.add_argument(required_opt, type=str, required=True)
    cli.add_argument("--verbose", "-v", action="store_true")
    args = cli.parse_args()

    with open(args.config_file, "r") as cfg_stream:
        cfg = cfg_to_arguments(yaml.load(cfg_stream, Loader=yaml.FullLoader))

    # main() works with Path objects for both directories
    args.input_dir, args.output_dir = Path(args.input_dir), Path(args.output_dir)

    exit_code = main(args, cfg)
    sys.exit(exit_code)

# %% some visualizations for debugging
# Disabled cell: it refers to names (fo, skeleton, mask, skel_labels, ...) that
# only exist when the body of main() has been run interactively.
if 0:
    rgb_path = str(fo)[:-8] + "rgb.png"
    rgb_ = cv2.imread(rgb_path)[:, :, [2, 1, 0]].astype(np.uint8)
    rgb_fi = paint_image(rgb_, skeleton, color=[255, 0, 0])
    rgb_ma = paint_image(rgb_, mask, color=[255, 0, 255])

    # one figure per skeleton segment, painted over the rgb image
    labs = np.unique(skel_labels[skel_labels > 0])
    for i in labs:
        segment = dilation(skel_labels == i, disk(3))
        plt.figure()
        plt.imshow(paint_image(rgb_, segment, [i / len(labs) * 255, 0, 255]))
        plt.title(i)

    # all segments together with intersections, end points and segment ids
    plt.figure()
    # plt.imshow(paint_image(rgb_, dilation(skel_labels, disk(3)), [255, 0, 255]))
    plt.imshow(dilation(skel_labels, disk(3)))
    plt.scatter(np.array(inter)[:, 0], np.array(inter)[:, 1])
    plt.scatter(np.array(endpo)[:, 0], np.array(endpo)[:, 1], marker="s")
    for i in np.unique(skel_labels[skel_labels > 0]):
        centroid = skprop[i - 1].centroid
        plt.text(x=centroid[1], y=centroid[0], s=f"{i}", color="white")

    # three-panel overview: segments, selected skeleton, mask
    f, a = plt.subplots(1, 3, figsize=(12, 12))
    thick_skeleton = dilation(skel_labels > 0, disk(3))
    a[0].imshow(paint_image(skel_labels * 255, thick_skeleton, [255, 0, 255]))
    a[0].scatter(np.array(inter)[:, 0], np.array(inter)[:, 1])
    a[0].scatter(np.array(endpo)[:, 0], np.array(endpo)[:, 1], marker="s")
    for i in np.unique(skel_labels[skel_labels > 0]):
        centroid = skprop[i - 1].centroid
        a[0].text(x=centroid[1], y=centroid[0], s=f"{i}", color="white")

    sel_skel = np.zeros_like(skel_labels)
    for i in np.unique(skel_labels[skel_labels > 0]):
        # if i in skel_cand[np.argmax(sk_l)]:
        sel_skel += dilation(skel_labels == i, disk(3))
    sel_skel = sel_skel > 0

    a[1].imshow(paint_image(rgb_fi, sel_skel, [255, 0, 0]))
    a[2].imshow(rgb_ma)
    a[0].title.set_text(f"Area: {np.sum(mask_)}")
    a[1].title.set_text(f"Sel Segm: {skel_cand[np.argmax(sk_l)]}")
    a[2].title.set_text(f"Skel_lenght_px {sk_l[np.argmax(sk_l)]}")