Image segmentation

In this notebook we illustrate how to use the script scripts/image_parsing/main_raw_to_clips.py to segment (i.e. extract) clips containing a single organisms from large-pane images containing multiple organisms.

As a first, step, import the necessary packages, including the custom functions of this repository msbsuite.utils (if you have trouble importing this package, refer back to the Installation section of the documentation).

[ ]:
import argparse
import sys
from pathlib import Path
import cv2
import numpy as np
import pandas as pd
import yaml
from matplotlib import pyplot as plt
from scipy import ndimage
from skimage import feature, measure, morphology, segmentation
from tqdm import tqdm

from mzbsuite.utils import cfg_to_arguments

from notebook.services.config import ConfigManager
cm = ConfigManager().update('notebook', {'limit_output': 1000})

We now need to declare the parameters to tell the script where to find the files and where to save its outputs. In this notebook, we pass these arguments as a dictionary to Python, rather than variables in a shell (.sh) script.

You need to have downloaded the example dataset in order for this cell to compile properly. Alternatively you can change the file paths to the locations of folders of your own dataset on the arguments = {} block; the path is relative to where this notebook is located.

[ ]:
ROOT_DIR = Path("/data/shared/mzb-workflow/docs")

arguments = {
    "input_dir": ROOT_DIR.parent.absolute() / "data/mzb_example_data/raw_img/",
    "output_dir": ROOT_DIR.parent.absolute() / "data/derived/mzb_example_data/",
    "save_full_mask_dir": ROOT_DIR.parent.absolute() / "data/derived/mzb_example_data/full_image_masks/",
    "config_file": ROOT_DIR.parent.absolute() / "configs/configuration_flume_datasets.yaml",
}

with open(str(arguments["config_file"]), "r") as f:
    cfg = yaml.load(f, Loader=yaml.FullLoader)

cfg["trcl_gpu_ids"] = None # this sets the number of available GPUs to zero, since this script doesn't benefit from GPU compute.
cfg

The cfg variable should display a portion of the list of parameters in the configuration file.

Now we use custom function cfg_to_arguments to parse the parameters we have just supplied and the parameters in the configuration file:

[ ]:
args = cfg_to_arguments(arguments)
cfg = cfg_to_arguments(cfg)
print(str(cfg))

Now we check whether the output directories already exist, and if not create them.

[ ]:
# define paths
main_root = Path(args.input_dir)
outdir = Path(args.output_dir)
outdir.mkdir(parents=True, exist_ok=True)

if args.save_full_mask_dir is not None:
        args.save_full_mask_dir = Path(args.save_full_mask_dir)

Parse the contents of the input folder and standardise filenames, and print how many images are going to be processed.

[ ]:
# get list of files to process
files_proc = list(main_root.glob(f"**/*.{cfg.impa_image_format}"))
# make sure weird capitalization doesn't cause issues
files_proc.extend(list(main_root.glob(f"**/*.{cfg.impa_image_format.upper()}")))
files_proc = [a for a in files_proc if "mask" not in str(a)]
files_proc.sort()

print(f"Parsing {len(files_proc)} files")

If a clip area is defined, for instance if there is a reference scale in the same spot in all the images, this area is earmarked for exclusion in later processing.

[ ]:
if cfg.impa_clip_areas is not None:
    location_cutout = [int(a) for a in cfg.impa_clip_areas]

If the PLOTS variable is True, then the script will print out a summary plot for each image and for each individual clip being generated. If you don’t want plots being generated, change the value to False. If you would like to save each plot as a file, you can uncomment (i.e. remove #) the lines plt.savefig() in the loop below.

⚠️ WARNING: this can be computationally intensive and can potentially crash the notebook if a large number number of outputs is generated!

[ ]:
PLOTS = True

Define a normalisation function to flatten the pixel values of images (this helps with downstream processing).

[ ]:
# define quick normalization function
norm = lambda x: (x - np.min(x)) / (np.max(x) - np.min(x))

Below is the main loop that processes the images into clips, and will also produce figures if PLOTS = True. If PLOTS = False, the script will save a .csv file with information about each image and clips generated from it, as well as other information such as bounding box coordinates, pixel areas of the mask, etc.

For further details about the logic fo this script please refer to the explanation in the section Segmentation under Processing scripts in the documentation.

⚠️ WARNING: depending on the number of images and how many organisms are present, the processing time of the loop can be considerable.

[ ]:
# ### ATTEMPT AT UPDATING THE FIGURE IN-PLACE, INSTEAD OF GENERATING NEW FIGURES ALL THE TIME...

# import numpy as np
# import matplotlib.pyplot as plt
# from IPython.display import display, clear_output

# fig = plt.figure()
# ax = fig.add_subplot(1, 1, 1)

# for i in range(21):
#     ax.set_xlim(0, 20)

#     ax.plot(i, 1, marker='x')
#     display(fig)

#     clear_output(wait = True)
#     plt.pause(0.5)

# # ?display
[ ]:
### EXPERIMENTING WITH clear_output...

from random import uniform
import time
from IPython.display import display, clear_output

def black_box():
    i = 1
    while i <= 5:
        clear_output(wait=True)
        display('Iteration '+str(i)+' Score: '+str(uniform(0, 1)))
        time.sleep(1)
        i += 1

black_box()
[ ]:

iterator = tqdm(files_proc, total=len(files_proc)) for i, fo in enumerate(iterator): mask_props = [] # get image path raw_image_in = fo full_path_raw_image_in = fo.resolve() # read image and convert to HSV img = cv2.imread(str(full_path_raw_image_in))[:, :, [2, 1, 0]] hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) im_t = hsv[:, :, 0].copy() im_t = (255 * norm(np.mean(hsv[:, :, :2], axis=2))).astype(np.uint8) # filter image with some iterations of gaussian blur for _ in range(cfg.impa_gaussian_blur_passes): im_t = cv2.GaussianBlur(im_t, tuple(cfg.impa_gaussian_blur), 0) # prepare for morphological reconstruction seed = np.copy(im_t) seed[1:-1, 1:-1] = im_t.min() mask = np.copy(im_t) # remove the background dil = morphology.reconstruction(seed, im_t, method="dilation") im_t = (im_t - dil).astype(np.uint8) # adaptive local thresholding of foreground vs background # weighted cross correlation with gaussian filter ad_thresh = cv2.adaptiveThreshold( im_t, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, cfg.impa_adaptive_threshold_block_size, -2, ) # additional global threhsold to remove foreground vs background t, thresh = cv2.threshold(im_t, 0, 255, cv2.THRESH_OTSU) # merge thresholds to globally get foreground masks # thresh = thresh | ad_thresh thresh = thresh + ad_thresh > 0 # postprocess masking to remove small objects and fill holes kernel = np.ones(cfg.impa_mask_postprocess_kernel, np.uint8) for _ in range(cfg.impa_mask_postprocess_passes): thresh = cv2.morphologyEx( (255 * thresh).astype(np.uint8), cv2.MORPH_CLOSE, kernel ) thresh = cv2.morphologyEx( (255 * thresh).astype(np.uint8), cv2.MORPH_OPEN, kernel ) thresh = ndimage.binary_fill_holes(thresh) # cut out area related to measurement/color calibration widget if "project_portable_flume" in str(main_root): thresh[location_cutout[0] :, location_cutout[1] :] = 0 # get labels of connected components labels = measure.label(thresh, connectivity=2, background=0) if PLOTS: full_image_thresh_fig, full_image_thresh_ax = plt.subplots(1, 4, figsize=(21, 9)) full_image_thresh_ax[0].imshow(thresh) full_image_thresh_ax[0].title.set_text('global threshold') full_image_thresh_ax[1].imshow(ad_thresh) full_image_thresh_ax[1].title.set_text('adaptive threshold') full_image_thresh_ax[2].imshow(img) full_image_thresh_ax[2].title.set_text('original rgb') full_image_thresh_ax[3].imshow(labels) full_image_thresh_ax[3].title.set_text('labels') plt.show() # plt.savefig("test.png") # Save the labels as a jpg for the full image if args.save_full_mask_dir is not None: args.save_full_mask_dir.mkdir(parents=True, exist_ok=True) cv2.imwrite( str(args.save_full_mask_dir / f"labels_{fo.stem}.jpg").lower(), (labels).astype(np.uint8), ) if not cfg.impa_save_clips_plus_features: if args.verbose: print("skipping clip generation") continue # get region properties rprop = measure.regionprops(labels) mask = np.ones(thresh.shape, dtype="uint8") # init some stuff sub_df = pd.DataFrame([]) c = 1 # loop through identified regions and get some properties for label in range(len(rprop)): # np.unique(labels): clear_output(wait=True) reg_pro = rprop[label] # skip background if reg_pro.label == 0: continue # skip small objects if reg_pro.area < cfg.impa_area_threshold: # 5000 defauilt continue # get mask for current region of interest current_mask = np.zeros(thresh.shape) current_mask[labels == reg_pro.label] = 1 # coordinates of bounding box corners for current region of interest ( min_row, min_col, max_row, max_col, ) = reg_pro.bbox # cv2.boundingRect(approx) (x, y, w, h) = (min_col, min_row, max_col - min_col, max_row - min_row) # get the bounding box with some buffer (x_e, y_e, w_e, h_e) = ( np.max((x - cfg.impa_bounding_box_buffer, 0)), np.max((y - cfg.impa_bounding_box_buffer, 0)), w + 2 * cfg.impa_bounding_box_buffer, h + 2 * cfg.impa_bounding_box_buffer, ) if PLOTS: clip_crop_fig, clip_crop_ax = plt.subplots(1, 1, figsize=(10, 6)) clip_crop_ax.imshow(img[:, :, [0, 1, 2]], aspect="auto") rect = plt.Rectangle( (x_e, y_e), w_e, h_e, fc="none", ec="black", linewidth=2 ) clip_crop_ax.add_patch(rect) # clear_output(wait = True) display(full_image_thresh_fig) plt.show() # plt.savefig(f"test_mask{c}.png") # exit() # get the crop of the image and the mask crop = img[y_e : y_e + h_e, x_e : x_e + w_e, [2, 1, 0]] crop_hsv = hsv[y_e : y_e + h_e, x_e : x_e + w_e, :] crop_mask = current_mask[y_e : y_e + h_e, x_e : x_e + w_e] crop_im_t = im_t[y_e : y_e + h_e, x_e : x_e + w_e] im_crop_m = crop.reshape(-1, 3)[ crop_mask.reshape( -1, ).astype(bool), :, ] hsv_crop_m = crop_hsv.reshape(-1, 3)[ crop_mask.reshape( -1, ).astype(bool), :, ] # save actual image and mask crops # Avoid "invalid value encountered in true_divide" warning np.seterr(divide="ignore", invalid="ignore") cv2.imwrite( str(outdir / (f"{fo.stem}_{c}_mask.{cfg.impa_image_format}").lower()), (255 * crop_mask / crop_mask).astype(np.uint8), [cv2.IMWRITE_JPEG_QUALITY, 100], ) # reactivate warnings np.seterr(divide="warn", invalid="warn") cv2.imwrite( str(outdir / (f"{fo.stem}_{c}_rgb.{cfg.impa_image_format}").lower()), crop, [cv2.IMWRITE_JPEG_QUALITY, 100], ) # get average color of the crop # not really needed, aren't they # im_crop_cmean = str(np.mean(im_crop_m, axis=0)) # hsv_crop_cmean = str(np.mean(hsv_crop_m, axis=0)) # im_crop_std = str(np.std(im_crop_m, axis=0)) # hsv_crop_std = str(np.std(hsv_crop_m, axis=0)) mask = mask + current_mask * c if PLOTS: clip_fig, clip_ax = plt.subplots(1, 4, figsize=(10, 6)) clip_ax[0].imshow(crop) clip_ax[0].title.set_text('crop') clip_ax[1].imshow(reg_pro.image) # crop_mask) clip_ax[1].title.set_text('binary mask') clip_ax[2].imshow( ( crop * np.transpose(np.tile(crop_mask, (3, 1, 1)), (1, 2, 0)) ).astype(np.uint8) ) clip_ax[2].title.set_text('mask HSV') im_t_crop_m = crop_im_t.reshape(-1, 1)[ crop_mask.reshape( -1, ).astype(bool), :, ] clip_ax[3].hist(im_t_crop_m, bins=50) clip_ax[3].title.set_text('colour histogram') # plt.pause(1) plt.show() sub_df = {} sub_df["input_file"] = raw_image_in sub_df["species"] = raw_image_in.name.split(".")[0] sub_df["png_mask_id"] = c sub_df["reg_lab"] = reg_pro.label sub_df["squareness"] = w / float(h) # sub_df["average_color"] = im_crop_cmean # sub_df["average_color_std"] = im_crop_std # sub_df["average_hsv"] = hsv_crop_cmean # sub_df["average_hsv_std"] = hsv_crop_std sub_df["tight_bb"] = f"({x}, {y}, {w}, {h})" sub_df["large_bb"] = f"({x_e}, {y_e}, {w_e}, {h_e})" sub_df["ell_minor_axis"] = reg_pro.minor_axis_length sub_df["ell_major_axis"] = reg_pro.major_axis_length sub_df["bbox_area"] = reg_pro.bbox_area sub_df["area_px"] = reg_pro.area sub_df["mask_centroid"] = str(reg_pro.centroid) sub_df = pd.DataFrame(data=sub_df, index=[0]) mask_props.append(sub_df) c += 1 if not PLOTS: if mask_props: mask_props = pd.concat(mask_props).reset_index().drop(columns=["index"]) mask_props.to_csv(outdir / "_mask_properties.csv")