{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Image segmentation\n", "==================\n", "\n", "In this notebook we illustrate how to use the script `scripts/image_parsing/main_raw_to_clips.py` to segment (i.e. extract) clips containing a single organism from large-pane images containing multiple organisms. \n", "\n", "As a first step, import the necessary packages, including the custom functions in this repository's `mzbsuite.utils` module (if you have trouble importing this package, refer back to the `Installation` section of the documentation). " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import argparse\n", "import sys\n", "from pathlib import Path\n", "import cv2\n", "import numpy as np\n", "import pandas as pd\n", "import yaml\n", "from IPython.display import clear_output, display  # used by the main loop\n", "from matplotlib import pyplot as plt\n", "from scipy import ndimage\n", "from skimage import feature, measure, morphology, segmentation\n", "from tqdm import tqdm\n", "\n", "from mzbsuite.utils import cfg_to_arguments\n", "\n", "from notebook.services.config import ConfigManager\n", "cm = ConfigManager().update('notebook', {'limit_output': 1000})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We now need to declare the parameters to tell the script where to find the files and where to save its outputs. In this notebook, we pass these arguments as a dictionary to Python, rather than variables in a shell (`.sh`) script. \n", "\n", "You need to have downloaded the example dataset for the next cell to run properly. Alternatively, you can change the file paths in the `arguments = {}` block to the locations of the folders of your own dataset; the paths are built from `ROOT_DIR`, which should point to the folder where this notebook is located. 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ROOT_DIR = Path(\"/data/shared/mzb-workflow/docs\")\n", "\n", "arguments = {\n", " \"input_dir\": ROOT_DIR.parent.absolute() / \"data/mzb_example_data/raw_img/\", \n", " \"output_dir\": ROOT_DIR.parent.absolute() / \"data/derived/mzb_example_data/\", \n", " \"save_full_mask_dir\": ROOT_DIR.parent.absolute() / \"data/derived/mzb_example_data/full_image_masks/\", \n", " \"config_file\": ROOT_DIR.parent.absolute() / \"configs/configuration_flume_datasets.yaml\", \n", "}\n", " \n", "with open(str(arguments[\"config_file\"]), \"r\") as f:\n", " cfg = yaml.load(f, Loader=yaml.FullLoader)\n", "\n", "cfg[\"trcl_gpu_ids\"] = None # this sets the number of available GPUs to zero, since this script doesn't benefit from GPU compute. \n", "cfg" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `cfg` variable should display a portion of the list of parameters in the configuration file. \n", "\n", "Now we use custom function `cfg_to_arguments` to parse the parameters we have just supplied and the parameters in the configuration file: " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "args = cfg_to_arguments(arguments)\n", "cfg = cfg_to_arguments(cfg)\n", "print(str(cfg))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we check whether the output directories already exist, and if not create them. 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# define paths\n", "main_root = Path(args.input_dir)\n", "outdir = Path(args.output_dir)\n", "outdir.mkdir(parents=True, exist_ok=True)\n", "\n", "if args.save_full_mask_dir is not None:\n", " args.save_full_mask_dir = Path(args.save_full_mask_dir)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Parse the contents of the input folder and standardise filenames, and print how many images are going to be processed. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# get list of files to process\n", "files_proc = list(main_root.glob(f\"**/*.{cfg.impa_image_format}\"))\n", "# make sure weird capitalization doesn't cause issues\n", "files_proc.extend(list(main_root.glob(f\"**/*.{cfg.impa_image_format.upper()}\")))\n", "files_proc = [a for a in files_proc if \"mask\" not in str(a)]\n", "files_proc.sort()\n", "\n", "print(f\"Parsing {len(files_proc)} files\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If a clip area is defined, for instance if there is a reference scale in the same spot in all the images, this area is earmarked for exclusion in later processing. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if cfg.impa_clip_areas is not None:\n", " location_cutout = [int(a) for a in cfg.impa_clip_areas]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If the `PLOTS` variable is True, then the script will print out a summary plot for each image and for each individual clip being generated. If you don't want plots being generated, change the value to False. If you would like to save each plot as a file, you can uncomment (i.e. remove `# `) the lines `plt.savefig()` in the loop below. 
\n", "\n", "> ⚠️ WARNING: generating plots can be computationally intensive and can potentially crash the notebook if a large number of outputs is generated! " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "PLOTS = True" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Define a min-max normalisation function that rescales the pixel values of an image to the range [0, 1] (this helps with downstream processing). " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# define quick normalization function\n", "norm = lambda x: (x - np.min(x)) / (np.max(x) - np.min(x))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Below is the main loop that processes the images into clips, and will also produce figures if `PLOTS = True`. If `PLOTS = False`, the script will instead save a `.csv` file with one row per clip, recording the image it was generated from together with information such as bounding box coordinates, pixel area of the mask, etc. \n", "\n", "For further details about the logic of this script please refer to the explanation in the section `Segmentation` under `Processing scripts` in the documentation. \n", "\n", "> ⚠️ WARNING: depending on the number of images and how many organisms are present, the processing time of the loop can be considerable. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# ### ATTEMPT AT UPDATING THE FIGURE IN-PLACE, INSTEAD OF GENERATING NEW FIGURES ALL THE TIME... 
\n", "\n", "# import numpy as np\n", "# import matplotlib.pyplot as plt\n", "# from IPython.display import display, clear_output\n", "\n", "# fig = plt.figure()\n", "# ax = fig.add_subplot(1, 1, 1) \n", "\n", "# for i in range(21):\n", "#     ax.set_xlim(0, 20)\n", "\n", "#     ax.plot(i, 1, marker='x')\n", "#     display(fig)\n", "\n", "#     clear_output(wait = True)\n", "#     plt.pause(0.5)\n", "\n", "# # ?display" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "### EXPERIMENTING WITH clear_output... \n", "\n", "from random import uniform\n", "import time\n", "from IPython.display import display, clear_output\n", "\n", "def black_box():\n", "    i = 1\n", "    while i <= 5:\n", "        clear_output(wait=True)\n", "        display('Iteration '+str(i)+' Score: '+str(uniform(0, 1)))\n", "        time.sleep(1)\n", "        i += 1\n", "\n", "black_box()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# collect the properties of all clips across all images\n", "# (initialised once, outside the loop, so the final .csv covers every image)\n", "mask_props = []\n", "\n", "iterator = tqdm(files_proc, total=len(files_proc))\n", "for i, fo in enumerate(iterator):\n", "\n", "    # get image path\n", "    raw_image_in = fo\n", "    full_path_raw_image_in = fo.resolve()\n", "\n", "    # read image (OpenCV loads BGR) and reorder the channels to RGB\n", "    img = cv2.imread(str(full_path_raw_image_in))[:, :, [2, 1, 0]]\n", "\n", "    # convert to HSV and average the first two channels (hue, saturation)\n", "    # into a single normalised 8-bit image\n", "    # note: img is RGB here, whereas COLOR_BGR2HSV expects BGR input\n", "    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)\n", "    im_t = (255 * norm(np.mean(hsv[:, :, :2], axis=2))).astype(np.uint8)\n", "\n", "    # filter image with some iterations of gaussian blur\n", "    for _ in range(cfg.impa_gaussian_blur_passes):\n", "        im_t = cv2.GaussianBlur(im_t, tuple(cfg.impa_gaussian_blur), 0)\n", "\n", "    # prepare for morphological reconstruction\n", "    seed = np.copy(im_t)\n", "    seed[1:-1, 1:-1] = im_t.min()\n", "    mask = np.copy(im_t)\n", "\n", "    # remove the background\n", "    dil = morphology.reconstruction(seed, im_t, method=\"dilation\")\n", "    im_t = (im_t - dil).astype(np.uint8)\n", "\n", "    # adaptive local thresholding of 
foreground vs background\n", " # weighted cross correlation with gaussian filter\n", " ad_thresh = cv2.adaptiveThreshold(\n", " im_t,\n", " 255,\n", " cv2.ADAPTIVE_THRESH_GAUSSIAN_C,\n", " cv2.THRESH_BINARY,\n", " cfg.impa_adaptive_threshold_block_size,\n", " -2,\n", " )\n", " # additional global threhsold to remove foreground vs background\n", " t, thresh = cv2.threshold(im_t, 0, 255, cv2.THRESH_OTSU)\n", "\n", " # merge thresholds to globally get foreground masks\n", " # thresh = thresh | ad_thresh\n", " thresh = thresh + ad_thresh > 0\n", "\n", " # postprocess masking to remove small objects and fill holes\n", " kernel = np.ones(cfg.impa_mask_postprocess_kernel, np.uint8)\n", " for _ in range(cfg.impa_mask_postprocess_passes):\n", " thresh = cv2.morphologyEx(\n", " (255 * thresh).astype(np.uint8), cv2.MORPH_CLOSE, kernel\n", " )\n", " thresh = cv2.morphologyEx(\n", " (255 * thresh).astype(np.uint8), cv2.MORPH_OPEN, kernel\n", " )\n", " thresh = ndimage.binary_fill_holes(thresh)\n", "\n", " # cut out area related to measurement/color calibration widget\n", " if \"project_portable_flume\" in str(main_root):\n", " thresh[location_cutout[0] :, location_cutout[1] :] = 0\n", "\n", " # get labels of connected components\n", " labels = measure.label(thresh, connectivity=2, background=0)\n", "\n", " if PLOTS: \n", " full_image_thresh_fig, full_image_thresh_ax = plt.subplots(1, 4, figsize=(21, 9))\n", " full_image_thresh_ax[0].imshow(thresh)\n", " full_image_thresh_ax[0].title.set_text('global threshold')\n", " full_image_thresh_ax[1].imshow(ad_thresh)\n", " full_image_thresh_ax[1].title.set_text('adaptive threshold')\n", " full_image_thresh_ax[2].imshow(img)\n", " full_image_thresh_ax[2].title.set_text('original rgb')\n", " full_image_thresh_ax[3].imshow(labels)\n", " full_image_thresh_ax[3].title.set_text('labels')\n", " plt.show() \n", " # plt.savefig(\"test.png\")\n", "\n", " # Save the labels as a jpg for the full image\n", " if args.save_full_mask_dir is not 
None:\n", " args.save_full_mask_dir.mkdir(parents=True, exist_ok=True)\n", " cv2.imwrite(\n", " str(args.save_full_mask_dir / f\"labels_{fo.stem}.jpg\").lower(),\n", " (labels).astype(np.uint8),\n", " )\n", " if not cfg.impa_save_clips_plus_features:\n", " if args.verbose:\n", " print(\"skipping clip generation\")\n", " continue\n", "\n", " # get region properties\n", " rprop = measure.regionprops(labels)\n", " mask = np.ones(thresh.shape, dtype=\"uint8\")\n", "\n", " # init some stuff\n", " sub_df = pd.DataFrame([])\n", " c = 1\n", " # loop through identified regions and get some properties\n", " for label in range(len(rprop)): # np.unique(labels):\n", " \n", " clear_output(wait=True)\n", " \n", " reg_pro = rprop[label]\n", "\n", " # skip background\n", " if reg_pro.label == 0:\n", " continue\n", "\n", " # skip small objects\n", " if reg_pro.area < cfg.impa_area_threshold: # 5000 defauilt\n", " continue\n", "\n", " # get mask for current region of interest\n", " current_mask = np.zeros(thresh.shape)\n", " current_mask[labels == reg_pro.label] = 1\n", "\n", " # coordinates of bounding box corners for current region of interest\n", " (\n", " min_row,\n", " min_col,\n", " max_row,\n", " max_col,\n", " ) = reg_pro.bbox # cv2.boundingRect(approx)\n", " (x, y, w, h) = (min_col, min_row, max_col - min_col, max_row - min_row)\n", "\n", " # get the bounding box with some buffer\n", " (x_e, y_e, w_e, h_e) = (\n", " np.max((x - cfg.impa_bounding_box_buffer, 0)),\n", " np.max((y - cfg.impa_bounding_box_buffer, 0)),\n", " w + 2 * cfg.impa_bounding_box_buffer,\n", " h + 2 * cfg.impa_bounding_box_buffer,\n", " )\n", "\n", " if PLOTS: \n", " clip_crop_fig, clip_crop_ax = plt.subplots(1, 1, figsize=(10, 6))\n", " clip_crop_ax.imshow(img[:, :, [0, 1, 2]], aspect=\"auto\")\n", " rect = plt.Rectangle(\n", " (x_e, y_e), w_e, h_e, fc=\"none\", ec=\"black\", linewidth=2\n", " )\n", " clip_crop_ax.add_patch(rect)\n", " \n", " # clear_output(wait = True)\n", " 
display(full_image_thresh_fig)\n", "\n", "            plt.show()\n", "            # plt.savefig(f\"test_mask{c}.png\")\n", "            # exit()\n", "\n", "        # get the crop of the image and the mask\n", "        # (channels are flipped back to BGR, the order cv2.imwrite expects)\n", "        crop = img[y_e : y_e + h_e, x_e : x_e + w_e, [2, 1, 0]]\n", "        crop_hsv = hsv[y_e : y_e + h_e, x_e : x_e + w_e, :]\n", "        crop_mask = current_mask[y_e : y_e + h_e, x_e : x_e + w_e]\n", "        crop_im_t = im_t[y_e : y_e + h_e, x_e : x_e + w_e]\n", "\n", "        # pixel values inside the mask\n", "        # (only used by the commented-out colour statistics below)\n", "        im_crop_m = crop.reshape(-1, 3)[\n", "            crop_mask.reshape(\n", "                -1,\n", "            ).astype(bool),\n", "            :,\n", "        ]\n", "        hsv_crop_m = crop_hsv.reshape(-1, 3)[\n", "            crop_mask.reshape(\n", "                -1,\n", "            ).astype(bool),\n", "            :,\n", "        ]\n", "\n", "        # save actual image and mask crops\n", "        # crop_mask only contains 0 and 1, so scaling by 255 gives a 0/255 mask\n", "        # without dividing by zero\n", "        cv2.imwrite(\n", "            str(outdir / (f\"{fo.stem}_{c}_mask.{cfg.impa_image_format}\").lower()),\n", "            (255 * crop_mask).astype(np.uint8),\n", "            [cv2.IMWRITE_JPEG_QUALITY, 100],\n", "        )\n", "\n", "        cv2.imwrite(\n", "            str(outdir / (f\"{fo.stem}_{c}_rgb.{cfg.impa_image_format}\").lower()),\n", "            crop,\n", "            [cv2.IMWRITE_JPEG_QUALITY, 100],\n", "        )\n", "        # get average color of the crop\n", "        # (not needed at the moment, kept for reference)\n", "        # im_crop_cmean = str(np.mean(im_crop_m, axis=0))\n", "        # hsv_crop_cmean = str(np.mean(hsv_crop_m, axis=0))\n", "\n", "        # im_crop_std = str(np.std(im_crop_m, axis=0))\n", "        # hsv_crop_std = str(np.std(hsv_crop_m, axis=0))\n", "\n", "        mask = mask + current_mask * c\n", "\n", "        if PLOTS:\n", "            clip_fig, clip_ax = plt.subplots(1, 4, figsize=(10, 6))\n", "            # crop is BGR, so flip the channels back to RGB for display\n", "            clip_ax[0].imshow(crop[:, :, ::-1])\n", "            clip_ax[0].title.set_text('crop')\n", "            clip_ax[1].imshow(reg_pro.image) # crop_mask)\n", "            clip_ax[1].title.set_text('binary mask')\n", "            clip_ax[2].imshow(\n", "                (\n", "                    crop[:, :, ::-1] * np.transpose(np.tile(crop_mask, (3, 1, 1)), (1, 2, 0))\n", "                ).astype(np.uint8)\n", "            )\n", 
" clip_ax[2].title.set_text('mask HSV')\n", " im_t_crop_m = crop_im_t.reshape(-1, 1)[\n", " crop_mask.reshape(\n", " -1,\n", " ).astype(bool),\n", " :,\n", " ]\n", " clip_ax[3].hist(im_t_crop_m, bins=50)\n", " clip_ax[3].title.set_text('colour histogram')\n", " # plt.pause(1)\n", " plt.show() \n", "\n", " sub_df = {}\n", " sub_df[\"input_file\"] = raw_image_in\n", " sub_df[\"species\"] = raw_image_in.name.split(\".\")[0]\n", " sub_df[\"png_mask_id\"] = c\n", " sub_df[\"reg_lab\"] = reg_pro.label\n", " sub_df[\"squareness\"] = w / float(h)\n", " # sub_df[\"average_color\"] = im_crop_cmean\n", " # sub_df[\"average_color_std\"] = im_crop_std\n", " # sub_df[\"average_hsv\"] = hsv_crop_cmean\n", " # sub_df[\"average_hsv_std\"] = hsv_crop_std\n", " sub_df[\"tight_bb\"] = f\"({x}, {y}, {w}, {h})\"\n", " sub_df[\"large_bb\"] = f\"({x_e}, {y_e}, {w_e}, {h_e})\"\n", " sub_df[\"ell_minor_axis\"] = reg_pro.minor_axis_length\n", " sub_df[\"ell_major_axis\"] = reg_pro.major_axis_length\n", " sub_df[\"bbox_area\"] = reg_pro.bbox_area\n", " sub_df[\"area_px\"] = reg_pro.area\n", " sub_df[\"mask_centroid\"] = str(reg_pro.centroid)\n", " sub_df = pd.DataFrame(data=sub_df, index=[0])\n", "\n", " mask_props.append(sub_df)\n", " c += 1\n", " \n", "if not PLOTS:\n", " if mask_props:\n", " mask_props = pd.concat(mask_props).reset_index().drop(columns=[\"index\"])\n", " mask_props.to_csv(outdir / \"_mask_properties.csv\")" ] } ], "metadata": { "kernelspec": { "display_name": "mzbfull2", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }