import numpy as np
from tensorflow.keras.models import load_model
from tensorflow.keras.metrics import MeanAbsoluteError, Accuracy, Precision, Recall, AUC # MeanSquaredError
import tensorflow as tf
from tensorflow.keras.metrics import Metric
# from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, confusion_matrix, roc_auc_score, roc_curve, auc
from abc import ABC, abstractmethod
import yaml
from gaitmod.utils.utils import load_config
[docs]
class BaseModel(ABC):
    """Shared scaffolding for model wrappers: config, persistence, fold training, evaluation.

    This class does not define ``fit`` or ``predict``; :meth:`train` calls
    ``self.fit`` and ``self.predict``, so concrete subclasses must supply them.
    """

    def __init__(self, config_path=None, **kwargs):
        """Read the configuration and initialise wrapper state.

        Args:
            config_path: Path handed to ``load_config``. When falsy, an empty
                configuration is used.
            **kwargs: ``model_type`` (optional) overrides the value stored
                under ``config['model']['model_type']``.

        Raises:
            KeyError: If ``model_type`` is neither passed as a keyword nor
                present in the configuration.
        """
        self.config = (load_config(config_path) or {}) if config_path else {}
        model_section = self.config.get("model") or {}

        # Only consult the config when no explicit override was given, so the
        # constructor also works without a config file.
        if "model_type" in kwargs:
            self.model_type = kwargs["model_type"]
        elif "model_type" in model_section:
            self.model_type = model_section["model_type"].lower()
        else:
            raise KeyError(
                "model_type must be passed as a keyword argument or set under "
                "config['model']['model_type']"
            )

        self.model_type_suffix = model_section.get("model_type_suffix", "")
        self.metrics = self.initialize_metrics()
        self.model = None

    def save(self, model_path):
        """Save the underlying Keras model.

        Raises:
            ValueError: If no model has been assigned to ``self.model`` yet.
        """
        if self.model is None:
            raise ValueError("The model is not built or trained yet.")
        self.model.save(model_path)

    def train(self, X, y, train_idx, test_idx, callbacks=None):
        """Fit on the training split of one fold and predict on its test split.

        Args:
            X: Input features, indexable by ``train_idx`` / ``test_idx``.
            y: Target values, indexable the same way.
            train_idx: Training indices.
            test_idx: Testing indices.
            callbacks: Callbacks forwarded to ``fit``; only used when
                ``self.model_type == 'lstm'``.

        Returns:
            Dictionary with the fold's true values (``'y_test'``) and
            predictions (``'y_pred'``).
        """
        X_train, X_test = X[train_idx], X[test_idx]
        y_train, y_test = y[train_idx], y[test_idx]

        # Only the LSTM wrapper receives callbacks (passed positionally).
        if self.model_type == 'lstm':
            self.fit(X_train, y_train, callbacks)
        else:
            self.fit(X_train, y_train)

        y_pred = self.predict(X_test)
        return {'y_test': y_test, 'y_pred': y_pred}

    @staticmethod
    def load(model_path, model_type, config_path, **kwargs):
        """Load a Keras model and wrap it in a concrete subclass of BaseModel.

        Args:
            model_path: Location passed to ``load_model``.
            model_type: Wrapper to build; only ``"lstm"`` is supported.
            config_path: Configuration path forwarded to the wrapper.
            **kwargs: Extra keyword arguments forwarded to the wrapper.

        Raises:
            ValueError: If ``model_type`` is not supported.
        """
        # Reject unsupported types before paying for model deserialisation.
        if model_type != "lstm":
            raise ValueError(f"Unsupported model type: {model_type}")

        loaded_keras_model = load_model(model_path)

        # Import the model type dynamically to avoid circular imports
        from gaitmod.models.classification_models import ClassificationLSTMModel  # Delayed import

        base_model = ClassificationLSTMModel(model_type=model_type, config_path=config_path, **kwargs)
        base_model.model = loaded_keras_model
        return base_model

    def initialize_metrics(self):
        """Build the list of metric objects for this model.

        The metric selection logic is currently disabled (kept below as
        comments), so this always returns an empty list.
        """
        # Fetch metrics configuration from the config file
        metric_config = (self.config.get("evaluation") or {}).get("metrics", {})

        # Ensure metric_config is a dictionary
        if not isinstance(metric_config, dict):
            metric_config = {}

        # List to store actual metric objects
        metrics_list = []

        # if "regression" in self.model_type_suffix.lower():
        #     if metric_config.get("mse", True):
        #         metrics_list.append(self.MyMeanSquaredError(name="mse"))
        #     if metric_config.get("mae", True):
        #         metrics_list.append(self.MyMeanAbsoluteError(name="mae"))
        #     if metric_config.get("r2_score", True):
        #         metrics_list.append(self.R2Score(name="r2_score"))
        # elif "classification" in self.model_type_suffix.lower():
        #     # Metrics for classification tasks
        #     if metric_config.get("accuracy", True):
        #         metrics_list.append(Accuracy(name="accuracy"))
        #     if metric_config.get("precision", True):
        #         metrics_list.append(Precision(name="precision"))
        #     if metric_config.get("recall", True):
        #         metrics_list.append(Recall(name="recall"))
        #     if metric_config.get("auc", True):
        #         metrics_list.append(AUC(name="auc"))
        # else:
        #     raise ValueError(f"Unsupported model type suffix: {self.model_type_suffix}")

        return metrics_list

    def evaluate(self, results, fold=None):
        """Compute MSE, MAE and R^2 for one fold and store them in ``results``.

        Parameters:
        - results: Dictionary containing the true values under ``'y_test'``
          and the predictions under ``'y_pred'``. It is updated in place with
          the keys ``'mse'``, ``'mae'`` and ``'r2'``.
        - fold: Fold number to print in the logs. If None, nothing is printed.

        Returns:
            The same ``results`` dictionary, now including the metrics.
        """
        # TODO: improve hardcoding of metrics
        y_true = np.asarray(results['y_test'])
        y_pred = np.asarray(results['y_pred'])

        # Align e.g. (N, 1) predictions with (N,) targets; otherwise the
        # subtraction would broadcast to (N, N) and corrupt every metric.
        if y_pred.shape != y_true.shape and y_pred.size == y_true.size:
            y_pred = y_pred.reshape(y_true.shape)

        residuals = y_true - y_pred
        squared_residuals = residuals ** 2

        # Residual and total sums of squares for R^2
        ss_res = np.sum(squared_residuals)
        ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)

        # A constant target gives ss_tot == 0; report 0.0 instead of dividing.
        r2_score = 1 - (ss_res / ss_tot) if ss_tot != 0 else 0.0

        results.update({
            "mse": np.mean(squared_residuals),
            "mae": np.mean(np.abs(residuals)),
            "r2": r2_score,
        })

        # Log metrics for the fold (fold 0 is a valid fold number)
        if fold is not None:
            # self.metrics may hold plain names or metric objects with .name
            configured = [getattr(metric, "name", metric) for metric in (self.metrics or [])]
            names = [name for name in configured if name in results] or ['mse', 'mae', 'r2']
            metrics_str = ", ".join(
                f"{name.upper()}: {results[name]:.4f}"
                for name in names
            )
            print(f"Fold {fold} | {metrics_str}")

        return results
# #TODO: this method does not support serialization!
# def evaluate(self, y_true, y_pred):
# """
# Evaluate metrics based on true labels and predictions.
# Parameters:
# - y_true: Ground truth labels
# - y_pred: Predicted labels or values
# """
# results = {}
# for metric in self.metrics:
# metric.update_state(y_true, y_pred)
# result = metric.result().numpy()
# results[metric.name] = result
# metric.reset_state()
# return results
[docs]
class R2Score(Metric):
    """Streaming coefficient of determination (R^2).

    Accumulates the weighted residual sum of squares together with the
    weighted sums of ``y``, ``y**2`` and the weights. The total sum of squares
    is then taken around the mean of *all* samples seen since the last reset,
    not around each batch's own mean, so the result does not depend on how the
    data was split into batches.
    """

    def __init__(self, name="r2_score", **kwargs):
        super().__init__(name=name, **kwargs)
        # Residual sum of squares: sum(w * (y_true - y_pred)^2)
        self.ssr = self.add_weight(name="ssr", initializer="zeros")
        # Sufficient statistics for the total sum of squares
        self.sum_w = self.add_weight(name="sum_w", initializer="zeros")
        self.sum_wy = self.add_weight(name="sum_wy", initializer="zeros")
        self.sum_wy2 = self.add_weight(name="sum_wy2", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Fold one batch into the running sums.

        Args:
            y_true: Ground-truth values.
            y_pred: Predicted values.
            sample_weight: Optional weights, multiplied element-wise (with
                broadcasting) against the per-element terms.
        """
        y_true = tf.cast(y_true, self.dtype)
        y_pred = tf.cast(y_pred, self.dtype)
        squared_residuals = tf.square(y_true - y_pred)

        # Unit weights when none are given, so both paths share one formula.
        weights = tf.ones_like(squared_residuals)
        if sample_weight is not None:
            weights = tf.cast(sample_weight, self.dtype) * weights

        self.ssr.assign_add(tf.reduce_sum(weights * squared_residuals))
        self.sum_w.assign_add(tf.reduce_sum(weights))
        self.sum_wy.assign_add(tf.reduce_sum(weights * y_true))
        self.sum_wy2.assign_add(tf.reduce_sum(weights * tf.square(y_true)))

    def result(self):
        """Return ``1 - SSR / SST`` over everything seen since the last reset."""
        # SST = sum(w*y^2) - (sum(w*y))^2 / sum(w); clamp at zero because the
        # subtraction can go slightly negative through floating-point rounding.
        mean_term = tf.math.divide_no_nan(self.sum_wy * self.sum_wy, self.sum_w)
        sst = tf.maximum(self.sum_wy2 - mean_term, 0.0)
        return 1.0 - (self.ssr / (sst + tf.keras.backend.epsilon()))

    def reset_state(self):
        """Zero all running sums."""
        self.ssr.assign(0.0)
        self.sum_w.assign(0.0)
        self.sum_wy.assign(0.0)
        self.sum_wy2.assign(0.0)
[docs]
class MyMeanSquaredError(Metric):
    """Streaming mean squared error.

    State is held in metric weights (running sum of squared errors and running
    weight total) rather than in a plain Python attribute, so the value
    accumulates across ``update_state`` calls and ``result()`` is defined even
    before the first update.
    """

    def __init__(self, name="my_mse", **kwargs):
        super().__init__(name=name, **kwargs)
        self.total = self.add_weight(name="total", initializer="zeros")
        self.count = self.add_weight(name="count", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Add one batch of squared errors to the running sums.

        Args:
            y_true: Ground-truth values.
            y_pred: Predicted values.
            sample_weight: Optional weights, multiplied element-wise (with
                broadcasting) against the squared errors.
        """
        y_true = tf.cast(y_true, self.dtype)
        y_pred = tf.cast(y_pred, self.dtype)
        squared_error = tf.square(y_true - y_pred)

        # Unit weights when none are given, so count equals the element count.
        weights = tf.ones_like(squared_error)
        if sample_weight is not None:
            weights = tf.cast(sample_weight, self.dtype) * weights

        self.total.assign_add(tf.reduce_sum(weights * squared_error))
        self.count.assign_add(tf.reduce_sum(weights))

    def result(self):
        """Return the mean squared error so far (0 when nothing was seen)."""
        return tf.math.divide_no_nan(self.total, self.count)

    def reset_state(self):
        """Reset the metric state."""
        self.total.assign(0.0)
        self.count.assign(0.0)
[docs]
class MyMeanAbsoluteError(Metric):
    """Streaming mean absolute error.

    State is held in metric weights (running sum of absolute errors and
    running weight total) rather than in a plain Python attribute, so the
    value accumulates across ``update_state`` calls and ``result()`` is
    defined even before the first update.
    """

    def __init__(self, name="my_mae", **kwargs):
        super().__init__(name=name, **kwargs)
        self.total = self.add_weight(name="total", initializer="zeros")
        self.count = self.add_weight(name="count", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Add one batch of absolute errors to the running sums.

        Args:
            y_true: Ground-truth values.
            y_pred: Predicted values.
            sample_weight: Optional weights, multiplied element-wise (with
                broadcasting) against the absolute errors.
        """
        y_true = tf.cast(y_true, self.dtype)
        y_pred = tf.cast(y_pred, self.dtype)
        absolute_error = tf.abs(y_true - y_pred)

        # Unit weights when none are given, so count equals the element count.
        weights = tf.ones_like(absolute_error)
        if sample_weight is not None:
            weights = tf.cast(sample_weight, self.dtype) * weights

        self.total.assign_add(tf.reduce_sum(weights * absolute_error))
        self.count.assign_add(tf.reduce_sum(weights))

    def result(self):
        """Return the mean absolute error so far (0 when nothing was seen)."""
        return tf.math.divide_no_nan(self.total, self.count)

    def reset_state(self):
        """Reset the metric state."""
        self.total.assign(0.0)
        self.count.assign(0.0)