{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Classification: finetune\n", "========================\n", "\n", "In this notebook we illustrate how to re-train the models on user's data. Specifically, we remap the last layer of the model to the desired classes, without modifying the model's internal weights; this operation is called finetuning and is not as computationally intensive as re-training the full model. \n", "Regardless, this module greatly benefits from GPU compute, as long as the GPU(s) support CUDA and `nvidia-smi` is configured correctly. \n", "\n", "This module uses two scripts: `classification/main_prepare_learning_sets.py` for preparing the data for training, and `classification/main_classification_finetune.py`, that need to be executed in that order. \n", "\n", "The first step is to import the necessary libraries for `main_prepare_learning_sets.py`: " ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "import argparse\n", "import shutil\n", "import sys\n", "import os\n", "from pathlib import Path\n", "\n", "import numpy as np\n", "import pandas as pd\n", "import yaml\n", "\n", "from mzbsuite.utils import cfg_to_arguments" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We need to declare the running parameters for the script, " ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'glob_random_seed': 222,\n", " 'glob_root_folder': '/data/users/luca/mzb-workflow/mzb-workflow/',\n", " 'glob_blobs_folder': '/data/users/luca/mzb-workflow/mzb-workflow/data/derived/blobs/',\n", " 'glob_local_format': 'pdf',\n", " 'model_logger': 'wandb',\n", " 'impa_image_format': 'jpg',\n", " 'impa_clip_areas': [2750, 4900],\n", " 'impa_area_threshold': 5000,\n", " 'impa_gaussian_blur': [21, 21],\n", " 'impa_gaussian_blur_passes': 3,\n", " 'impa_adaptive_threshold_block_size': 351,\n", " 'impa_mask_postprocess_kernel': [11, 11],\n", " 'impa_mask_postprocess_passes': 5,\n", " 'impa_bounding_box_buffer': 200,\n", " 'impa_save_clips_plus_features': True,\n", " 'lset_class_cut': 'order',\n", " 'lset_val_size': 0.1,\n", " 'lset_taxonomy': '/data/users/luca/mzb-workflow/data/MZB_taxonomy.csv',\n", " 'trcl_learning_rate': 0.001,\n", " 'trcl_batch_size': 16,\n", " 'trcl_weight_decay': 0,\n", " 'trcl_step_size_decay': 5,\n", " 'trcl_number_epochs': 10,\n", " 'trcl_save_topk': 1,\n", " 'trcl_num_classes': 8,\n", " 'trcl_model_pretrarch': 'efficientnet-b2',\n", " 'trcl_num_workers': 16,\n", " 'trcl_wandb_project_name': 'mzb-classifiers',\n", " 'trcl_logger': 'wandb',\n", " 'trsk_learning_rate': 0.001,\n", " 'trsk_batch_size': 32,\n", " 'trsk_weight_decay': 0,\n", " 'trsk_step_size_decay': 25,\n", " 'trsk_number_epochs': 400,\n", " 'trsk_save_topk': 1,\n", " 'trsk_num_classes': 2,\n", " 'trsk_model_pretrarch': 'mit_b2',\n", " 'trsk_num_workers': 16,\n", " 'trsk_wandb_project_name': 'mzb-skeletons',\n", " 'trsk_logger': 'wandb',\n", " 'infe_model_ckpt': 'last',\n", " 'infe_num_classes': 8,\n", " 'infe_image_glob': '*_rgb.jpg',\n", " 'skel_class_exclude': 'errors',\n", " 'skel_conv_rate': 131.6625,\n", " 'skel_label_thickness': 3,\n", " 'skel_label_buffer_on_preds': 25,\n", " 'skel_label_clip_with_mask': False,\n", " 'trcl_gpu_ids': None}" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ROOT_DIR = Path.cwd()\n", "MODEL=\"convnext-small-vtest-1\"\n", "LSET_FOLD=f\"{ROOT_DIR}/data/mzb_example_data\"\n", "\n", "arguments = {\n", " 
\"input_dir\": \"/data/shared/mzb-workflow/data/learning_sets/project_portable_flume/curated_learning_sets\", \n", " \"taxonomy_file\": ROOT_DIR.parent.absolute() / \"data/MZB_taxonomy.csv\", \n", " \"output_dir\": ROOT_DIR.parent.absolute() / \"data/mzb_example_data/aggregated_learning_sets\", \n", " \"save_model\": ROOT_DIR.parent.absolute() / f\"models/mzb-classification-models/{MODEL}\", \n", " \"config_file\": ROOT_DIR.parent.absolute() / \"configs/configuration_flume_datasets.yaml\"\n", "}\n", "\n", "with open(str(arguments[\"config_file\"]), \"r\") as f:\n", " cfg = yaml.load(f, Loader=yaml.FullLoader)\n", "\n", "cfg[\"trcl_gpu_ids\"] = None # this sets the number of available GPUs to zero, since this part of the module doesn't benefit from GPU compute. \n", "cfg" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Convert these parameters to a dictionary: " ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'glob_random_seed': 222, 'glob_root_folder': '/data/users/luca/mzb-workflow/mzb-workflow/', 'glob_blobs_folder': '/data/users/luca/mzb-workflow/mzb-workflow/data/derived/blobs/', 'glob_local_format': 'pdf', 'model_logger': 'wandb', 'impa_image_format': 'jpg', 'impa_clip_areas': [2750, 4900], 'impa_area_threshold': 5000, 'impa_gaussian_blur': [21, 21], 'impa_gaussian_blur_passes': 3, 'impa_adaptive_threshold_block_size': 351, 'impa_mask_postprocess_kernel': [11, 11], 'impa_mask_postprocess_passes': 5, 'impa_bounding_box_buffer': 200, 'impa_save_clips_plus_features': True, 'lset_class_cut': 'order', 'lset_val_size': 0.1, 'lset_taxonomy': '/data/users/luca/mzb-workflow/data/MZB_taxonomy.csv', 'trcl_learning_rate': 0.001, 'trcl_batch_size': 16, 'trcl_weight_decay': 0, 'trcl_step_size_decay': 5, 'trcl_number_epochs': 10, 'trcl_save_topk': 1, 'trcl_num_classes': 8, 'trcl_model_pretrarch': 'efficientnet-b2', 'trcl_num_workers': 16, 'trcl_wandb_project_name': 'mzb-classifiers', 'trcl_logger': 'wandb', 'trsk_learning_rate': 0.001, 'trsk_batch_size': 32, 'trsk_weight_decay': 0, 'trsk_step_size_decay': 25, 'trsk_number_epochs': 400, 'trsk_save_topk': 1, 'trsk_num_classes': 2, 'trsk_model_pretrarch': 'mit_b2', 'trsk_num_workers': 16, 'trsk_wandb_project_name': 'mzb-skeletons', 'trsk_logger': 'wandb', 'infe_model_ckpt': 'last', 'infe_num_classes': 8, 'infe_image_glob': '*_rgb.jpg', 'skel_class_exclude': 'errors', 'skel_conv_rate': 131.6625, 'skel_label_thickness': 3, 'skel_label_buffer_on_preds': 25, 'skel_label_clip_with_mask': False, 'trcl_gpu_ids': None}\n" ] } ], "source": [ "# Transforms configurations dicts to argparse arguments\n", "args = cfg_to_arguments(arguments)\n", "cfg = cfg_to_arguments(cfg)\n", "print(str(cfg))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We next check whether the target directories already exist, and if not create them: " ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "np.random.seed(cfg.glob_random_seed)\n", "\n", "# root of raw clip data\n", "root_data = Path(args.input_dir)\n", "outdir = Path(args.output_dir)\n", "outdir.mkdir(parents=True, exist_ok=True)\n", "\n", "# target folders definition\n", "target_trn = outdir / \"trn_set/\"\n", "target_val = outdir / \"val_set/\"\n", "\n", "# check if trn_set and val_set subfolders exist. 
"# Check whether the trn_set and val_set subfolders already exist; if so, stop the script\n", "# so that nothing is overwritten, and ask the user to specify a different output directory.\n", "if target_trn.exists() or target_val.exists():\n", "    raise ValueError(\n", "        # print in red and back to normal\n", "        f\"\\033[91m Output directory {outdir} already exists. Please specify a different output directory.\\033[0m\"\n", "    )" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "We now use the taxonomic rank specified by the `lset_class_cut` parameter in the configuration file to cut the provided phylogenetic tree, and reorganize the images into directories corresponding to this rank. \n", "See the documentation for further details. " ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Cutting phylogenetic tree at: order\n" ] } ], "source": [ "\n", "# make dictionary to recode: key is current classification, value is target reclassification.\n", "# forward fill to get last valid entry and subset to desired column\n", "mzb_taxonomy = pd.read_csv(Path(args.taxonomy_file))\n", "if \"Unnamed: 0\" in mzb_taxonomy.columns:\n", "    mzb_taxonomy = mzb_taxonomy.drop(columns=[\"Unnamed: 0\"])\n", "mzb_taxonomy = mzb_taxonomy.ffill(axis=1)\n", "recode_order = dict(\n", "    zip(mzb_taxonomy[\"query\"], mzb_taxonomy[cfg.lset_class_cut].str.lower())\n", ")\n", "\n", "print(f\"Cutting phylogenetic tree at: {cfg.lset_class_cut}\")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Now we copy the images over into the new folder structure according to the taxonomy: " ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# Copy files to target folders for all files in the curated learning set\n", "for s_fo in recode_order:\n", "    target_folder = target_trn / recode_order[s_fo]\n", "    target_folder.mkdir(exist_ok=True, parents=True)\n", "    \n", "    for file in list((root_data / s_fo).glob(\"*\")):\n", "        shutil.copy(file, target_folder)\n", "\n", "# move out the validation set:\n", "# a fraction lset_val_size of the images, or at least 1 file per class\n", "size = cfg.lset_val_size\n", "trn_folds = [a.name for a in sorted(list(target_trn.glob(\"*\")))]\n" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Next, we split off a validation set based on the proportion of total images specified by the `lset_val_size` parameter in the configuration file. We recommend using at least 10% of the images of each class. " ] },
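{ "cell_type": "markdown", "metadata": {}, "source": [ "Before splitting, a quick optional sanity check (not part of the original scripts) counts how many images ended up in each class folder; classes with very few images will contribute only a single validation image: " ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Optional sanity check: number of images per class folder in the training set\n", "for class_dir in sorted(target_trn.glob(\"*\")):\n", "    print(class_dir.name, len(list(class_dir.glob(\"*\"))))" ] },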
" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/data/users/luca/mzb-workflow/data/mzb_example_data/aggregated_learning_sets/trn_set/errors/32_hf2_plecoptera_01_clip_8_rgb.jpg into /data/users/luca/mzb-workflow/data/mzb_example_data/aggregated_learning_sets/val_set/errors\n", "/data/users/luca/mzb-workflow/data/mzb_example_data/aggregated_learning_sets/trn_set/errors/32_ob_leuctridae_01_clip_4_rgb.jpg into /data/users/luca/mzb-workflow/data/mzb_example_data/aggregated_learning_sets/val_set/errors\n", "/data/users/luca/mzb-workflow/data/mzb_example_data/aggregated_learning_sets/trn_set/plecoptera/32_bd_plecoptera_01_clip_2_rgb.jpg into /data/users/luca/mzb-workflow/data/mzb_example_data/aggregated_learning_sets/val_set/plecoptera\n" ] } ], "source": [ "\n", "for s_fo in trn_folds:\n", " target_folder = target_val / s_fo\n", " target_folder.mkdir(exist_ok=True, parents=True)\n", "\n", " list_class = list((target_trn / s_fo).glob(\"*\"))\n", " n_val_sam = np.max((1, np.ceil(0.1 * len(list_class))))\n", "\n", " val_files = np.random.choice(list_class, int(n_val_sam))\n", "\n", " for file in val_files:\n", " try:\n", " shutil.move(str(file), target_folder)\n", " except:\n", " print(f\"{str(file)} into {target_folder}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we have the training dataset ready for model training, with a training set and a validation set containing the same classes. \n", "\n", "We move on to the model finetuning, using the script `classification/main_classification_finetune.py`. First we import some additional libraries from PyTorch; " ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import pytorch_lightning as pl\n", "from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger\n", "from pytorch_lightning.strategies.ddp import DDPStrategy\n", "\n", "from mzbsuite.classification.mzb_classification_pilmodel import MZBModel\n", "from mzbsuite.utils import cfg_to_arguments, SaveLogCallback\n", "\n", "# Set the thread layer used by MKL\n", "os.environ[\"MKL_THREADING_LAYER\"] = \"GNU\" # this time we set the GPU computing layer to active" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before we can launch the training, we need to define a few special parameters, relating to finding the specified monitoring the model training progress over time: " ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "# Define checkpoints callbacks\n", "# best model on validation\n", "best_val_cb = pl.callbacks.ModelCheckpoint(\n", " dirpath=args.save_model,\n", " filename=\"best-val-{epoch}-{step}-{val_loss:.1f}\",\n", " monitor=\"val_loss\",\n", " mode=\"min\",\n", " save_top_k=cfg.trcl_save_topk,\n", ")\n", "\n", "# latest model in training\n", "last_mod_cb = pl.callbacks.ModelCheckpoint(\n", " dirpath=args.save_model,\n", " filename=\"last-{step}\",\n", " every_n_train_steps=50,\n", " save_top_k=cfg.trcl_save_topk,\n", ")\n", "\n", "# Define progress bar callback\n", "pbar_cb = pl.callbacks.progress.TQDMProgressBar(refresh_rate=5)\n", "\n", "# Define logger callback to log training date\n", "trdatelog = SaveLogCallback(model_folder=args.save_model)\n", "\n", "# Define model from config\n", "model = MZBModel(\n", " data_dir=args.input_dir,\n", " pretrained_network=cfg.trcl_model_pretrarch,\n", " learning_rate=cfg.trcl_learning_rate,\n", " 
"    batch_size=cfg.trcl_batch_size,\n", "    weight_decay=cfg.trcl_weight_decay,\n", "    num_workers_loader=cfg.trcl_num_workers,\n", "    step_size_decay=cfg.trcl_step_size_decay,\n", "    num_classes=cfg.trcl_num_classes,\n", ")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "We now check whether a previously trained checkpoint is available and, if there is one, load the weights and resume training from it. Note that logging model progress uses either [Weights & Biases](https://wandb.ai/), which requires an account, or [TensorBoard](https://www.tensorflow.org/tensorboard), which logs locally. See the documentation for more details. " ] },
{ "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n", "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mlpego\u001b[0m (\u001b[33mbiodetect\u001b[0m). Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" ] }, { "data": { "text/html": [ "wandb version 0.16.0 is available! To upgrade, please run:\n", "            $ pip install wandb --upgrade" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Tracking run with wandb version 0.15.4" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Run data is saved locally in ./wandb/run-20231111_161213-1u2u0o5h" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Syncing run classifier-efficientnet-b2 to Weights & Biases (docs)
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View project at https://wandb.ai/biodetect/mzb-classifiers" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run at https://wandb.ai/biodetect/mzb-classifiers/runs/1u2u0o5h" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: logging graph, to disable use `wandb.watch(log_graph=False)`\n" ] } ], "source": [ "# Check if there is a model to load, if there is, load it and train from there\n", "if args.save_model.is_dir():\n", " if args.verbose:\n", " print(f\"Loading model from {args.save_model}\")\n", " try:\n", " fmodel = list(args.save_model.glob(\"last-*.ckpt\"))[0]\n", " except:\n", " print(\"No last-* model in folder, loading best model\")\n", " fmodel = list(\n", " args.save_model.glob(\"best-val-epoch=*-step=*-val_loss=*.*.ckpt\")\n", " )[-1]\n", "\n", " model = model.load_from_checkpoint(fmodel)\n", "\n", "# Define logger and name of run\n", "name_run = f\"classifier-{cfg.trcl_model_pretrarch}\" # f\"{model.pretrained_network}\"\n", "cbacks = [pbar_cb, best_val_cb, last_mod_cb, trdatelog]\n", "\n", "# Define logger, and use either wandb or tensorboard\n", "if cfg.trcl_logger == \"wandb\":\n", " logger = WandbLogger(\n", " project=cfg.trcl_wandb_project_name, name=name_run if name_run else None\n", " )\n", " logger.watch(model, log=\"all\")\n", "\n", "elif cfg.trcl_logger == \"tensorboard\":\n", " logger = TensorBoardLogger(\n", " save_dir=args.save_model,\n", " name=name_run if name_run else None,\n", " log_graph=True,\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We are now finally ready to train our model! " ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/luca/mambaforge/envs/mzbsuite/lib/python3.10/site-packages/lightning_fabric/connector.py:555: UserWarning: 16 is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!\n", " rank_zero_warn(\n", "Using 16bit Automatic Mixed Precision (AMP)\n" ] }, { "ename": "MisconfigurationException", "evalue": "`Trainer(strategy=)` is not compatible with an interactive environment. Run your code as a script, or choose one of the compatible strategies: `Fabric(strategy='dp'|'ddp_notebook')`. 
In case you are spawning processes yourself, make sure to include the Trainer creation inside the worker function.", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mMisconfigurationException\u001b[0m Traceback (most recent call last)", "\u001b[1;32m/data/users/luca/mzb-workflow/notebooks/classification_finetune.ipynb Cell 22\u001b[0m line \u001b[0;36m2\n\u001b[1;32m 1\u001b[0m \u001b[39m# instantiate trainer and train\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m trainer \u001b[39m=\u001b[39m pl\u001b[39m.\u001b[39;49mTrainer(\n\u001b[1;32m 3\u001b[0m accelerator\u001b[39m=\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39mauto\u001b[39;49m\u001b[39m\"\u001b[39;49m, \u001b[39m# cfg.trcl_num_gpus outdated\u001b[39;49;00m\n\u001b[1;32m 4\u001b[0m max_epochs\u001b[39m=\u001b[39;49mcfg\u001b[39m.\u001b[39;49mtrcl_number_epochs,\n\u001b[1;32m 5\u001b[0m strategy\u001b[39m=\u001b[39;49mDDPStrategy(\n\u001b[1;32m 6\u001b[0m find_unused_parameters\u001b[39m=\u001b[39;49m\u001b[39mFalse\u001b[39;49;00m\n\u001b[1;32m 7\u001b[0m ), \u001b[39m# TODO: check how to use in notebook\u001b[39;49;00m\n\u001b[1;32m 8\u001b[0m precision\u001b[39m=\u001b[39;49m\u001b[39m16\u001b[39;49m,\n\u001b[1;32m 9\u001b[0m callbacks\u001b[39m=\u001b[39;49mcbacks,\n\u001b[1;32m 10\u001b[0m logger\u001b[39m=\u001b[39;49mlogger,\n\u001b[1;32m 11\u001b[0m log_every_n_steps\u001b[39m=\u001b[39;49m\u001b[39m1\u001b[39;49m\n\u001b[1;32m 12\u001b[0m \u001b[39m# profiler=\"simple\",\u001b[39;49;00m\n\u001b[1;32m 13\u001b[0m )\n\u001b[1;32m 15\u001b[0m trainer\u001b[39m.\u001b[39mfit(model)\n", "File \u001b[0;32m~/mambaforge/envs/mzbsuite/lib/python3.10/site-packages/pytorch_lightning/utilities/argparse.py:69\u001b[0m, in \u001b[0;36m_defaults_from_env_vars..insert_env_defaults\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 66\u001b[0m kwargs \u001b[39m=\u001b[39m \u001b[39mdict\u001b[39m(\u001b[39mlist\u001b[39m(env_variables\u001b[39m.\u001b[39mitems()) \u001b[39m+\u001b[39m \u001b[39mlist\u001b[39m(kwargs\u001b[39m.\u001b[39mitems()))\n\u001b[1;32m 68\u001b[0m \u001b[39m# all args were already moved to kwargs\u001b[39;00m\n\u001b[0;32m---> 69\u001b[0m \u001b[39mreturn\u001b[39;00m fn(\u001b[39mself\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", "File \u001b[0;32m~/mambaforge/envs/mzbsuite/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:398\u001b[0m, in \u001b[0;36mTrainer.__init__\u001b[0;34m(self, accelerator, strategy, devices, num_nodes, precision, logger, callbacks, fast_dev_run, max_epochs, min_epochs, max_steps, min_steps, max_time, limit_train_batches, limit_val_batches, limit_test_batches, limit_predict_batches, overfit_batches, val_check_interval, check_val_every_n_epoch, num_sanity_val_steps, log_every_n_steps, enable_checkpointing, enable_progress_bar, enable_model_summary, accumulate_grad_batches, gradient_clip_val, gradient_clip_algorithm, deterministic, benchmark, inference_mode, use_distributed_sampler, profiler, detect_anomaly, barebones, plugins, sync_batchnorm, reload_dataloaders_every_n_epochs, default_root_dir)\u001b[0m\n\u001b[1;32m 395\u001b[0m \u001b[39m# init connectors\u001b[39;00m\n\u001b[1;32m 396\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_data_connector \u001b[39m=\u001b[39m _DataConnector(\u001b[39mself\u001b[39m)\n\u001b[0;32m--> 398\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_accelerator_connector \u001b[39m=\u001b[39m 
_AcceleratorConnector(\n\u001b[1;32m 399\u001b[0m devices\u001b[39m=\u001b[39;49mdevices,\n\u001b[1;32m 400\u001b[0m accelerator\u001b[39m=\u001b[39;49maccelerator,\n\u001b[1;32m 401\u001b[0m strategy\u001b[39m=\u001b[39;49mstrategy,\n\u001b[1;32m 402\u001b[0m num_nodes\u001b[39m=\u001b[39;49mnum_nodes,\n\u001b[1;32m 403\u001b[0m sync_batchnorm\u001b[39m=\u001b[39;49msync_batchnorm,\n\u001b[1;32m 404\u001b[0m benchmark\u001b[39m=\u001b[39;49mbenchmark,\n\u001b[1;32m 405\u001b[0m use_distributed_sampler\u001b[39m=\u001b[39;49muse_distributed_sampler,\n\u001b[1;32m 406\u001b[0m deterministic\u001b[39m=\u001b[39;49mdeterministic,\n\u001b[1;32m 407\u001b[0m precision\u001b[39m=\u001b[39;49mprecision,\n\u001b[1;32m 408\u001b[0m plugins\u001b[39m=\u001b[39;49mplugins,\n\u001b[1;32m 409\u001b[0m )\n\u001b[1;32m 410\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_logger_connector \u001b[39m=\u001b[39m _LoggerConnector(\u001b[39mself\u001b[39m)\n\u001b[1;32m 411\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_callback_connector \u001b[39m=\u001b[39m _CallbackConnector(\u001b[39mself\u001b[39m)\n", "File \u001b[0;32m~/mambaforge/envs/mzbsuite/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:173\u001b[0m, in \u001b[0;36m_AcceleratorConnector.__init__\u001b[0;34m(self, devices, num_nodes, accelerator, strategy, plugins, precision, sync_batchnorm, benchmark, use_distributed_sampler, deterministic)\u001b[0m\n\u001b[1;32m 170\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mprecision_plugin \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_check_and_init_precision()\n\u001b[1;32m 172\u001b[0m \u001b[39m# 6. Instantiate Strategy - Part 2\u001b[39;00m\n\u001b[0;32m--> 173\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_lazy_init_strategy()\n", "File \u001b[0;32m~/mambaforge/envs/mzbsuite/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:577\u001b[0m, in \u001b[0;36m_AcceleratorConnector._lazy_init_strategy\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 574\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstrategy\u001b[39m.\u001b[39m_configure_launcher()\n\u001b[1;32m 576\u001b[0m \u001b[39mif\u001b[39;00m _IS_INTERACTIVE \u001b[39mand\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstrategy\u001b[39m.\u001b[39mlauncher \u001b[39mand\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstrategy\u001b[39m.\u001b[39mlauncher\u001b[39m.\u001b[39mis_interactive_compatible:\n\u001b[0;32m--> 577\u001b[0m \u001b[39mraise\u001b[39;00m MisconfigurationException(\n\u001b[1;32m 578\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m`Trainer(strategy=\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_strategy_flag\u001b[39m!r}\u001b[39;00m\u001b[39m)` is not compatible with an interactive\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 579\u001b[0m \u001b[39m\"\u001b[39m\u001b[39m environment. 
Run your code as a script, or choose one of the compatible strategies:\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 580\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m `Fabric(strategy=\u001b[39m\u001b[39m'\u001b[39m\u001b[39mdp\u001b[39m\u001b[39m'\u001b[39m\u001b[39m|\u001b[39m\u001b[39m'\u001b[39m\u001b[39mddp_notebook\u001b[39m\u001b[39m'\u001b[39m\u001b[39m)`.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 581\u001b[0m \u001b[39m\"\u001b[39m\u001b[39m In case you are spawning processes yourself, make sure to include the Trainer\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 582\u001b[0m \u001b[39m\"\u001b[39m\u001b[39m creation inside the worker function.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 583\u001b[0m )\n\u001b[1;32m 585\u001b[0m \u001b[39m# TODO: should be moved to _check_strategy_and_fallback().\u001b[39;00m\n\u001b[1;32m 586\u001b[0m \u001b[39m# Current test check precision first, so keep this check here to meet error order\u001b[39;00m\n\u001b[1;32m 587\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39maccelerator, TPUAccelerator) \u001b[39mand\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(\n\u001b[1;32m 588\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstrategy, (SingleTPUStrategy, XLAStrategy)\n\u001b[1;32m 589\u001b[0m ):\n", "\u001b[0;31mMisconfigurationException\u001b[0m: `Trainer(strategy=)` is not compatible with an interactive environment. Run your code as a script, or choose one of the compatible strategies: `Fabric(strategy='dp'|'ddp_notebook')`. In case you are spawning processes yourself, make sure to include the Trainer creation inside the worker function." ] } ], "source": [ "\n", "# instantiate trainer and train\n", "trainer = pl.Trainer(\n", "    accelerator=\"auto\",  # automatically use GPU(s) when available (cfg.trcl_num_gpus is outdated)\n", "    max_epochs=cfg.trcl_number_epochs,\n", "    strategy=DDPStrategy(\n", "        find_unused_parameters=False\n", "    ),  # NOTE: DDP spawns worker processes and is not supported in notebooks, as the error in this cell's output shows;\n", "    # run the module as a script, or switch to an interactive-friendly strategy such as \"ddp_notebook\"\n", "    precision=16,  # newer Lightning versions prefer precision=\"16-mixed\" (see the warning in the output)\n", "    callbacks=cbacks,\n", "    logger=logger,\n", "    log_every_n_steps=1,\n", "    # profiler=\"simple\",\n", ")\n", "\n", "trainer.fit(model)" ] } ], "metadata": { "kernelspec": { "display_name": "mzbfull", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 }