# Source code for mzbsuite.skeletons.mzb_skeletons_helpers

# Module containing helper functions for the skeletonization scripts.

from pathlib import Path
from typing import List, Tuple, Union
import torch

import numpy as np
from matplotlib import pyplot as plt
from skimage.measure import label, regionprops


# %%
def paint_image(image: np.ndarray, mask: np.array, color: List[float]) -> np.ndarray:
    """
    Paint the masked pixels of an image with a given color.

    The input image is never modified; a painted copy is returned.

    Parameters
    ----------
    image: np.ndarray
        Input image to paint. A 2-D (single channel) image is expanded to
        three identical channels before painting.
    mask: np.ndarray
        Mask with the same height and width as ``image``. Pixels whose mask
        value is above 0.75 are painted.
    color: List[float]
        The three color components to paint with. Each component is
        multiplied by the mask value of the pixel it is written to.

    Returns
    -------
    rgb_fi: np.ndarray
        Copy of the image with the masked pixels painted. If the maximum of
        the (expanded) image is <= 1 the copy is rescaled by 255 and cast to
        ``uint8``; otherwise the input dtype is kept.
    """
    painted = image.copy()

    # Promote a single-channel image to three identical channels.
    if painted.ndim == 2:
        expanded = painted[:, :, np.newaxis]
        painted = np.concatenate([expanded] * 3, axis=2)

    # Images whose values do not exceed 1 are mapped to the 0-255 uint8 range.
    if np.max(painted) <= 1:
        painted = (painted * 255).astype(np.uint8)

    # Select the pixels to paint once and reuse the selection.
    selected = mask > 0.75
    weights = mask[selected]
    components = (color[0], color[1], color[2])

    # One row per selected pixel, one column per color component.
    painted[selected] = np.asarray([c * weights for c in components]).T

    return painted
# TODO: this function should probably be merged with the one above; when doing so, make sure to handle both torch tensors and numpy arrays.
def paint_image_tensor(
    image: torch.Tensor, masks: torch.Tensor, color: List[List[float]]
) -> torch.Tensor:
    """
    Paint several masks onto a copy of an image tensor, one color per mask.

    Parameters
    ----------
    image: torch.Tensor
        Input image to paint. It is indexed with each 2-D mask and receives
        three values per selected pixel, so it is expected to be laid out as
        (H, W, 3). The input is not modified.
    masks: torch.Tensor
        Iterable of masks (e.g. a tensor of shape (N, H, W)). For every mask,
        pixels with a value above 0.75 are painted.
    color: List[List[float]]
        One color per mask: ``color[i]`` holds the three components used for
        ``masks[i]``. Each component is multiplied by the mask value of the
        pixel it is written to.

    Returns
    -------
    rgb_body: torch.Tensor
        New image with painted pixels. Later masks overwrite earlier ones
        where they overlap.

    Raises
    ------
    IndexError
        If ``color`` has fewer entries than there are masks.
    """
    # Work on a copy so the caller's image is left untouched.
    rgb_body = image.clone()

    for idx, mask in enumerate(masks):
        selected = mask > 0.75
        weights = mask[selected]
        rgb = color[idx]

        # Stack the three weighted channels into an (n_selected, 3) tensor.
        # torch.stack keeps the mask's device and works for any number of
        # selected pixels, which building a tensor from a Python list of
        # tensors does not.
        painted = torch.stack(
            [
                weights * float(rgb[0]),
                weights * float(rgb[1]),
                weights * float(rgb[2]),
            ],
            dim=1,
        )

        # Indexed assignment needs matching dtypes.
        rgb_body[selected] = painted.to(rgb_body.dtype)

    return rgb_body
# %%
def get_intersections(skeleton: np.ndarray) -> List[Tuple[int, int]]:
    """
    Given a skeletonised image, it will give the coordinates of the
    intersections of the skeleton.

    A skeleton pixel (value equal to 1) is reported as an intersection when
    its 8-neighbourhood, as returned by ``neighbours``, matches one of a
    fixed set of branching patterns. Pixels on the image border are not
    examined.

    Nearby intersections are NOT merged here; filter points that are very
    close together as an independent step if needed.

    Parameters
    ----------
    skeleton: np.ndarray
        Binary image of the skeleton

    Returns
    -------
    intersections: list
        List of 2-tuples (x, y) containing the intersection coordinates,
        where x is the column index and y the row index. The order of the
        list is arbitrary.
    """
    # Neighbourhood patterns that identify an intersection. Each pattern is
    # compared against the list returned by `neighbours` for the pixel.
    # The original comment describes the layout as
    #     2 3 4
    #     1 C 5
    #     8 7 6
    # NOTE(review): `neighbours` starts at (x - 1, y) - confirm that the
    # table is meant to be read in that order.
    valid_intersections = frozenset(
        (
            (0, 1, 0, 1, 0, 0, 1, 0),
            (0, 0, 1, 0, 1, 0, 0, 1),
            (1, 0, 0, 1, 0, 1, 0, 0),
            (0, 1, 0, 0, 1, 0, 1, 0),
            (0, 0, 1, 0, 0, 1, 0, 1),
            (1, 0, 0, 1, 0, 0, 1, 0),
            (0, 1, 0, 0, 1, 0, 0, 1),
            (1, 0, 1, 0, 0, 1, 0, 0),
            (0, 1, 0, 0, 0, 1, 0, 1),
            (0, 1, 0, 1, 0, 0, 0, 1),
            (0, 1, 0, 1, 0, 1, 0, 0),
            (0, 0, 0, 1, 0, 1, 0, 1),
            (1, 0, 1, 0, 0, 0, 1, 0),
            (1, 0, 1, 0, 1, 0, 0, 0),
            (0, 0, 1, 0, 1, 0, 1, 0),
            (1, 0, 0, 0, 1, 0, 1, 0),
            (1, 0, 0, 1, 1, 1, 0, 0),
            (0, 0, 1, 0, 0, 1, 1, 1),
            (1, 1, 0, 0, 1, 0, 0, 1),
            (0, 1, 1, 1, 0, 0, 1, 0),
            (1, 0, 1, 1, 0, 0, 1, 0),
            (1, 0, 1, 0, 0, 1, 1, 0),
            (1, 0, 1, 1, 0, 1, 1, 0),
            (0, 1, 1, 0, 1, 0, 1, 1),
            (1, 1, 0, 1, 1, 0, 1, 0),
            (1, 1, 0, 0, 1, 0, 1, 0),
            (0, 1, 1, 0, 1, 0, 1, 0),
            (0, 0, 1, 0, 1, 0, 1, 1),
            (1, 0, 0, 1, 1, 0, 1, 0),
            (1, 0, 1, 0, 1, 1, 0, 1),
            (1, 0, 1, 0, 1, 1, 0, 0),
            (1, 0, 1, 0, 1, 0, 0, 1),
            (0, 1, 0, 0, 1, 0, 1, 1),
            (0, 1, 1, 0, 1, 0, 0, 1),
            (1, 1, 0, 1, 0, 0, 1, 0),
            (0, 1, 0, 1, 1, 0, 1, 0),
            (0, 0, 1, 0, 1, 1, 0, 1),
            (1, 0, 1, 0, 0, 1, 0, 1),
            (1, 0, 0, 1, 0, 1, 1, 0),
            (1, 0, 1, 1, 0, 1, 0, 0),
        )
    )

    intersections = []
    # The skeleton is only read, so no copy is needed. The outermost ring of
    # pixels is skipped because those pixels lack a full 8-neighbourhood.
    for row in range(1, len(skeleton) - 1):
        for col in range(1, len(skeleton[row]) - 1):
            # Only skeleton (white) pixels can be intersections.
            if skeleton[row][col] != 1:
                continue
            pattern = tuple(neighbours(row, col, skeleton))
            if pattern in valid_intersections:
                # Report as (x, y) = (column, row).
                intersections.append((col, row))

    # Remove duplicates
    return list(set(intersections))
# %%
def get_endpoints(skeleton: np.ndarray) -> List[Tuple[int, int]]:
    """
    Given a skeletonised image, it will give the coordinates of the endpoints
    of the skeleton.

    A skeleton pixel (value equal to 1) is an endpoint when exactly one of its
    8 neighbours equals 1 and all the others equal 0. Pixels on the image
    border are not examined.

    Parameters
    ----------
    skeleton: numpy.ndarray
        The skeletonised image to detect the endpoints of

    Returns
    -------
    endpoints: list
        List of 2-tuples (x, y) containing the endpoint coordinates, where x
        is the column index and y the row index. The order of the list is
        arbitrary.
    """
    # The eight valid neighbourhoods: a single neighbour set, in each of the
    # eight possible positions.
    single_neighbour_patterns = [
        [int(position == slot) for slot in range(8)] for position in range(8)
    ]

    grid = skeleton.copy()
    found = []

    for row in range(1, len(grid) - 1):
        for col in range(1, len(grid[row]) - 1):
            # Only skeleton (white) pixels can be endpoints.
            if grid[row][col] != 1:
                continue
            if neighbours(row, col, grid) in single_neighbour_patterns:
                # Report as (x, y) = (column, row).
                found.append((col, row))

    # Remove duplicates if any
    return list(set(found))
# %%
def neighbours(x, y, image):
    """
    Return the 8-neighbours of image point P1(x, y), in a clockwise order.

    The walk starts at ``image[x - 1][y]`` and proceeds through
    ``image[x - 1][y + 1]``, ``image[x][y + 1]``, ``image[x + 1][y + 1]``,
    ``image[x + 1][y]``, ``image[x + 1][y - 1]``, ``image[x][y - 1]`` and
    ends at ``image[x - 1][y - 1]``.

    Parameters
    ----------
    x: int
        Index of the point along the first axis of ``image``
    y: int
        Index of the point along the second axis of ``image``
    image: numpy.ndarray
        The image to find the neighbours of

    Returns
    -------
    _: list
        List of the 8 neighbouring values of the point in the image
    """
    # (first-axis, second-axis) offsets of the neighbours, in visiting order.
    ring = (
        (-1, 0),
        (-1, 1),
        (0, 1),
        (1, 1),
        (1, 0),
        (1, -1),
        (0, -1),
        (-1, -1),
    )
    return [image[x + d_first][y + d_second] for d_first, d_second in ring]
# %%
def traverse_graph(
    graph: dict, init: int, end_nodes: List[int], debug: bool = False
) -> List[List[int]]:
    """
    Function to traverse a graph from a starting node to a list of end nodes,
    and return all possible paths as a list of lists.

    The walk always restarts from ``init``: each time an unvisited end node is
    adjacent to the current node, the path walked so far (plus that end node)
    is stored and the walk starts over. If the walk reaches a node with no
    usable neighbour, the current node is appended once more to the current
    path, that path is stored, and the function returns immediately.

    Parameters
    ----------
    graph: dict
        The graph to traverse, mapping each node ID to its neighbouring
        node IDs
    init: int
        The starting node ID
    end_nodes: list
        List of end nodes
    debug: bool
        Whether to print debug information

    Returns
    -------
    all_paths: list
        List of lists containing the paths found from init to end_nodes

    TODO: Maybe.
        * Make it work for graphs with multiple paths between nodes and
          ensure that a subset of paths can be visited multiple times
        * Make a test for it
    """
    # The start node counts as an already visited end.
    reached_ends = [init]
    seen_core = []
    found_paths = []

    current = init
    if debug:
        print(f"start {current}")
    trail = [current]

    while True:
        adjacent = graph[current]
        fresh_ends = [
            node
            for node in adjacent
            if node in end_nodes and node not in reached_ends
        ]

        if fresh_ends:
            # An unvisited end node is adjacent: close the path and restart.
            target = fresh_ends[0]
            trail.append(target)
            found_paths.append(trail)
            if debug:
                print(f"appending path {found_paths}, restart e={init}")
            reached_ends.append(target)
            seen_core = []
            current = init
            trail = [current]
        else:
            # Keep walking through inner nodes not yet used on this walk.
            candidates = [
                node
                for node in adjacent
                if not (node in seen_core or node in end_nodes)
            ]
            if debug:
                print(f"nex {candidates}")

            if not candidates:
                # Dead end: store what we have and stop.
                trail.append(current)
                found_paths.append(trail)
                return found_paths

            if len(candidates) > 1 and candidates[0] in trail:
                current = candidates[1]
            else:
                current = candidates[0]

            # track to avoid going back
            seen_core.append(current)
            trail.append(current)

            if debug:
                print(f"nex {candidates}, e {current}, core_vis {seen_core}")
                print(f"ends visited {reached_ends}")
                print(f"path so far {trail}")

        # Stop once every end node has been visited.
        if not set(end_nodes).difference(reached_ends):
            return found_paths
# %%
def segment_skel(skeleton, inter, conn=1):
    """
    Custom function to segment a skeletonised image into individual branches.
    Each branch gets a unique ID.

    The intersection pixels and their immediate neighbours are removed from
    the skeleton, and the remaining connected pieces are labelled.

    Parameters
    ----------
    skeleton: numpy.ndarray
        The skeletonised image to segment
    inter: list
        List of 2-tuples (x, y) containing the intersection coordinates
        (x = column index, y = row index), as returned by the function
        get_intersections. May be empty, in which case nothing is removed.
    conn: int
        Neighbourhood removed around each intersection: 1 removes the pixel
        and its 4 direct neighbours, 2 additionally removes the 4 diagonal
        neighbours. The labelling itself always uses ``connectivity=2``.

    Returns
    -------
    skel_labels: numpy.ndarray
        The labelled skeleton image
    edge_attributes: dict
        Dictionary containing the attributes of each edge (branch) (for now,
        its size in pixels), keyed by 1-based position in ``skprop``
    skprop: list
        The skimage.regionprops result for the labelled branches
    """
    junction_mask = np.zeros(skeleton.shape, dtype=float)

    # Convert once; the reshape keeps an empty list usable as a (0, 2) array.
    points = np.asarray(inter, dtype=int).reshape(-1, 2)
    cols, rows = points[:, 0], points[:, 1]

    # (row, column) offsets of the pixels cleared around each intersection.
    offsets = [(0, 0), (1, 0), (0, 1), (-1, 0), (0, -1)]
    if conn == 2:
        offsets += [(1, 1), (1, -1), (-1, 1), (-1, -1)]

    n_rows, n_cols = junction_mask.shape[0], junction_mask.shape[1]
    for d_row, d_col in offsets:
        r = rows + d_row
        c = cols + d_col
        # Skip neighbours outside the image rather than wrapping around to
        # the opposite border or raising.
        inside = (r >= 0) & (r < n_rows) & (c >= 0) & (c < n_cols)
        junction_mask[r[inside], c[inside]] = 1

    # Remove the junction pixels from the skeleton, keeping values in [0, 1].
    branches = np.clip(skeleton - junction_mask, a_min=0, a_max=1)

    skel_labels = label(branches, connectivity=2)
    skprop = regionprops(skel_labels)

    edge_attributes = {
        index: region.area for index, region in enumerate(skprop, start=1)
    }

    return skel_labels, edge_attributes, skprop
# %%
class Denormalize:
    """
    Denormalize a tensor image with mean and standard deviation, for plotting
    purposes.

    This undoes a normalization of the form ``(x - mean) / std`` by computing
    ``x * std + mean`` channel by channel.
    """

    def __init__(self, mean, std):
        """
        Parameters
        ----------
        mean: list
            List of mean values for each channel
        std: list
            List of standard deviation values for each channel
        """
        self.mean = torch.Tensor(mean)
        self.std = torch.Tensor(std)

    def __call__(self, tensor):
        """
        Apply the denormalization to an image tensor.

        The channel axis is detected as the single axis of size 3 or 1, and
        must be either the first axis (C, H, W) or the third axis (H, W, C).

        Parameters
        ----------
        tensor: torch.Tensor
            Normalized image of size (C, H, W) or (H, W, C). It is modified
            IN PLACE.

        Returns
        -------
        x_n: torch.Tensor
            The denormalized image. This is the same tensor object that was
            passed in, with all channels denormalized.

        Raises
        ------
        ValueError
            If the channel axis cannot be identified unambiguously, or is
            neither the first nor the third axis.
        """
        channel_axes = [
            axis for axis, size in enumerate(tensor.shape) if size in (1, 3)
        ]
        if len(channel_axes) != 1 or channel_axes[0] not in (0, 2):
            raise ValueError(
                "Cannot determine the channel axis of a tensor with shape "
                f"{tuple(tensor.shape)}; expected (C, H, W) or (H, W, C) "
                "with C equal to 1 or 3."
            )

        if channel_axes[0] == 2:
            # Channel-last: mean/std broadcast along the last axis.
            return tensor.mul_(self.std).add_(self.mean)

        # Channel-first: iterating yields one view per channel, so the
        # in-place update below writes into `tensor` itself.
        # The normalize code -> t.sub_(m).div_(s)
        for channel, m, s in zip(tensor, self.mean, self.std):
            channel.mul_(s).add_(m)
        # Return the whole image, not just the last channel processed.
        return tensor