"""
Module containing utility functions for mzbsuite
C 2023, M. Volpi, Swiss Data Science Center
"""
from torchvision import models
from pathlib import Path
# import torch
import torch.nn as nn
import numpy as np
from sklearn.metrics import (
mean_absolute_error,
mean_squared_error,
max_error,
median_absolute_error,
r2_score,
explained_variance_score,
)


def noneparse(value):
    """
    Helper function to parse None values from YAML files

    Parameters
    ----------
    value: string
        string to be parsed

    Returns
    -------
    value: string or None
        parsed string
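
    Example
    -------
    noneparse("None")  # -> None
    noneparse("0.01")  # -> "0.01" (any non-"none" string passes through unchanged)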
"""
if value.lower() == "none":
return None
return value


class cfg_to_arguments(object):
    """
    Convert a dictionary into an object that mimics an argparse namespace,
    so configs can be passed around like parsed command-line arguments.

    In the __init__ method, we iterate over the dictionary and add each key
    as an attribute of the object.

    Example
    -------
    cfg = {'a': 1, 'b': 2}
    args = cfg_to_arguments(cfg)
    print(args.a)  # 1
    print(args.b)  # 2

    cfg can come from a YAML file, a JSON file, or a plain dictionary,
    whatever you prefer.
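
    For instance, loading from YAML (a minimal sketch; "config.yaml" is a
    hypothetical filename):
    args = cfg_to_arguments(yaml.safe_load(open("config.yaml")))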
"""
def __init__(self, args):
"""
Parameters
----------
args: dict
dictionary of arguments
"""
for key in args:
setattr(self, key, args[key])
def __str__(self):
"""Prints the object as a string"""
return self.__dict__.__str__()


def regression_report(y_true, y_pred, PRINT=True):
    """
    Helper function to print regression metrics. Taken and adapted from
    https://github.com/scikit-learn/scikit-learn/issues/18454#issue-708338254

    Parameters
    ----------
    y_true: np.array
        ground truth values
    y_pred: np.array
        predicted values
    PRINT: bool
        whether to print the metrics or not

    Returns
    -------
    metrics: list
        list of (metric name, value) tuples
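
    Example
    -------
    # illustrative values:
    y_true = np.array([1.0, 2.0, 3.0])
    y_pred = np.array([1.1, 1.9, 3.2])
    metrics = regression_report(y_true, y_pred, PRINT=False)
    # metrics[0] -> ("mean absolute error", 0.133...)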
"""
error = y_true - y_pred
percentile = [5, 25, 50, 75, 95]
percentile_value = np.percentile(error, percentile)
metrics = [
("mean absolute error", mean_absolute_error(y_true, y_pred)),
("median absolute error", median_absolute_error(y_true, y_pred)),
("mean squared error", mean_squared_error(y_true, y_pred)),
("max error", max_error(y_true, y_pred)),
("r2 score", r2_score(y_true, y_pred)),
("explained variance score", explained_variance_score(y_true, y_pred)),
]
if PRINT:
print("Metrics for regression:")
for metric_name, metric_value in metrics:
print(f"{metric_name:>25s}: {metric_value: >20.3f}")
print("\nPercentiles:")
for p, pv in zip(percentile, percentile_value):
print(f"{p: 25d}: {pv:>20.3f}")
return metrics


def read_pretrained_model(architecture, n_class):
    """
    Helper function to load models compactly from the PyTorch model zoo and
    prepare them for Hummingbird finetuning.

    Parameters
    ----------
    architecture: str
        name of the model to load
    n_class: int
        number of classes to finetune the model for

    Returns
    -------
    model : pytorch model
        model with its last layer replaced by a linear layer with n_class outputs
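
    Example
    -------
    # e.g., a 10-class finetuning setup; for resnet18 the replaced head is
    # nn.Linear(512, 10) and all layers stay trainable:
    model = read_pretrained_model("resnet18", n_class=10)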
"""
architecture = architecture.lower()
if architecture == "vgg":
model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
in_feat = model.classifier[-1].in_features
model.classifier[-1] = nn.Linear(
in_features=in_feat, out_features=n_class, bias=True
)
for param in model.features.parameters():
param.requires_grad = False
for param in model.classifier.parameters():
if np.any([a == 2 for a in param.shape]):
pass
else:
param.requires_grad = False
elif architecture == "resnet18":
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(
in_features=model.fc.in_features, out_features=n_class, bias=True
)
# Freeze base feature extraction trunk:
for param in model.parameters():
param.requires_grad = True
for param in model.fc.parameters():
param.requires_grad = True
elif architecture == "resnet50":
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(
in_features=model.fc.in_features, out_features=n_class, bias=True
)
# Freeze base feature extraction trunk:
for param in model.parameters():
param.requires_grad = True
# for param in model.fc.parameters():
# param.requires_grad = True
elif architecture == "densenet161":
model = models.densenet161(weights=models.DenseNet161_Weights.IMAGENET1K_V1)
model.classifier = nn.Linear(
in_features=model.classifier.in_features, out_features=n_class, bias=True
)
for param in model.parameters():
param.requires_grad = True
for param in model.classifier.parameters():
param.requires_grad = True
elif architecture == "mobilenet":
model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
model.classifier[1] = nn.Linear(
in_features=model.classifier[1].in_features,
out_features=n_class,
bias=True,
)
for param in model.parameters():
param.requires_grad = False
for param in model.classifier[1].parameters():
param.requires_grad = True
elif architecture == "efficientnet-b2":
model = models.efficientnet_b2(
weights=models.EfficientNet_B2_Weights.IMAGENET1K_V1
)
model.classifier[1] = nn.Linear(
in_features=model.classifier[1].in_features,
out_features=n_class,
bias=True,
)
for param in model.parameters():
param.requires_grad = False
for param in model.classifier[1].parameters():
param.requires_grad = True
elif architecture == "efficientnet-b1":
model = models.efficientnet_b1(
weights=models.EfficientNet_B1_Weights.IMAGENET1K_V1
)
model.classifier[1] = nn.Linear(
in_features=model.classifier[1].in_features,
out_features=n_class,
bias=True,
)
for param in model.parameters():
param.requires_grad = False
for param in model.classifier[1].parameters():
param.requires_grad = True
elif architecture == "vit16":
model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)
model.heads.head = nn.Linear(
in_features=model.heads.head.in_features,
out_features=n_class,
bias=True,
)
for param in model.parameters():
param.requires_grad = False
for param in model.heads.head.parameters():
param.requires_grad = True
elif architecture == "convnext-small":
model = models.convnext_small(
weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1
)
model.classifier[2] = nn.Linear(
in_features=model.classifier[2].in_features, out_features=n_class, bias=True
)
for param in model.parameters():
param.requires_grad = False
for param in model.classifier[2].parameters():
param.requires_grad = True
else:
raise OSError("Model not found")
return model


def find_checkpoints(dirs=Path("lightning_logs"), version=None, log="val"):
    """
    Find the checkpoints for a given log

    Parameters
    ----------
    dirs: Path (default: Path("lightning_logs"))
        path to the lightning_logs folder
    version: str (default: None)
        version of the log to use; if None, the most recent version is used
    log: str (default: "val")
        substring that the checkpoint filename must contain (e.g. "val")

    Returns
    -------
    chkp: list
        list of paths to matching checkpoints
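
    Example
    -------
    # illustrative; actual checkpoint names depend on the ModelCheckpoint config:
    chkp = find_checkpoints(Path("lightning_logs"), version="version_0", log="val")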
"""
if version:
ch_sf = list(dirs.glob(f"{version}/checkpoints/*.ckpt"))
else: # pick last
ch_sp = [a.parents[1] for a in dirs.glob("**/*.ckpt")]
ch_sp.sort()
ch_sf = list(ch_sp[-1].glob("**/*.ckpt"))
chkp = [a for a in ch_sf if log in str(a.name)]
return chkp


class SaveLogCallback(Callback):
    """
    Callback to save the log of the training.

    TODO: will need to be updated to save the training log in more detail and
    in a more structured way.
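
    Example
    -------
    (illustrative; model_folder should point at the run's "checkpoints"
    directory, since on_train_end writes YAML files next to its parents)
    trainer = Trainer(callbacks=[SaveLogCallback(model_folder)])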
"""
def __init__(self, model_folder):
# super().__init__()
self.model_folder = model_folder
# def on_train_start(self, trainer, pl_module):
# self.model_folder = self.model_folder / "checkpoints"
# # store locally some meta info, if file exists, append to it
# # this in each model folder
# flog = self.model_folder.parents[0] / "trn_date.yaml"
# flag = "a" if flog.is_file() else "w"
# with open(self.model_folder.parents[0] / "trn_date.yaml", flag) as f:
# yaml.safe_dump(
# {"train-date-start": datetime.now().strftime("%Y-%m-%d %H:%M:%S")}, f
# )
# # this is a global file containing all the training dates for all models
# flog = self.model_folder.parents[1] / "all_trn_date.yaml"
# flag = "a" if flog.is_file() else "w"
# with open(self.model_folder.parents[1] / "all_trn_date.yaml", flag) as f:
# yaml.safe_dump(
# {
# f"{self.model_folder.parents[1].name}": {
# "start": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# }
# },
# f,
# )

    def on_train_end(self, trainer, pl_module):
        """
        Save the end date of the training.
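
        Each call appends a line like (illustrative timestamp):
            train-date-end: '2023-01-01 12:00:00'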
"""
# store locally some meta info, if file exists, append to it
# this in each model folder
flog = self.model_folder.parents[0] / "trn_date.yaml"
flag = "a" if flog.is_file() else "w"
with open(self.model_folder.parents[0] / "trn_date.yaml", flag) as f:
yaml.safe_dump(
{"train-date-end": datetime.now().strftime("%Y-%m-%d %H:%M:%S")}, f
)
# this is a global file containing all the training dates for all models
flog = self.model_folder.parents[1] / "all_trn_date.yaml"
flag = "a" if flog.is_file() else "w"
with open(self.model_folder.parents[1] / "all_trn_date.yaml", flag) as f:
yaml.safe_dump(
{
f"{self.model_folder.parents[0].name}": datetime.now().strftime(
"%Y-%m-%d %H:%M:%S"
)
},
f,
)