Source code for gaitmod.models.feature_extraction
import numpy as np
import scipy.stats
import antropy as ant
[docs]
class FeatureExtractor2:
    """
    Extracts time- and frequency-domain features from epoched signals.

    The extractor is driven by a nested configuration dictionary, e.g.::

        {
            "time_features": {"mean": True, "std": True, "rms": False, ...},
            "freq_features": {"psd_band_mean": True, "spectral_entropy": True, ...},
        }

    After each call to ``extract_features`` the attribute ``feature_idx_map``
    maps every extracted feature key to a ``(start, end)`` slice of the
    returned feature matrix; ``select_feature`` uses it to pull a single
    feature back out.
    """

    # Strategies accepted by extract_features / select_feature.
    _FEATURE_HANDLING_MODES = ("flatten_chs", "average_chs", "separate_chs")

    def __init__(self, sfreq, features_config):
        """
        Initializes the FeatureExtractor with the specified feature extraction options.

        Parameters:
        - sfreq (float): The sampling frequency.
        - features_config (dict): Configuration dictionary specifying which features to extract.
          Keys read by this class are "time_features" (dict of feature name -> bool)
          and "freq_features" (dict of feature name -> bool).
        """
        self.sfreq = sfreq
        self.features_config = features_config
        # Rebuilt by every extract_features call: feature key -> (start, end) slice.
        self.feature_idx_map = {}
        # Not populated anywhere in this class.
        self.feature_names = None

    def extract_features(self, epochs, feature_handling="flatten_chs"):
        """
        Computes every feature enabled in ``features_config`` for each epoch.

        Parameters:
        - epochs (mne.Epochs): The epochs to process. Only ``get_data(copy=True)``
          and ``compute_psd(...).get_data(return_freqs=True)`` are used.
        - feature_handling (str): How multi-channel features are combined:
            * "flatten_chs": shape (n_epochs, n_features * n_channels), laid out
              feature-major (Feat1_ch1, Feat1_ch2, ..., Feat2_ch1, ...).
            * "average_chs": shape (n_epochs, n_features), averaged over channels.
            * "separate_chs": shape (n_epochs, n_channels, n_features).

        Returns:
        - feature_matrix (np.ndarray): The features, shaped as described above.
        - feature_idx_map (dict): Feature key -> (start, end) slice along the last
          axis of ``feature_matrix``. In "separate_chs" mode one key per channel
          (suffix ``_ch{i}``) is stored; they all share the same slice because
          channels live on their own axis.

        Raises:
        - ValueError: If ``feature_handling`` is not a supported mode.
        """
        # Validate up front so an invalid mode fails before any expensive PSD work.
        if feature_handling not in self._FEATURE_HANDLING_MODES:
            raise ValueError(f"Invalid feature_handling mode: {feature_handling}")

        data = epochs.get_data(copy=True)
        n_epochs, n_channels, _ = data.shape

        # Each entry has shape (n_epochs, n_channels, k); joined along axis 2.
        feature_blocks = []
        current_index = 0
        ### RESET FEATURE INDEX MAP
        self.feature_idx_map = {}

        ### TIME-DOMAIN FEATURES (one value per channel and epoch)
        time_funcs = {
            "mean": lambda x: np.mean(x, axis=2, keepdims=True),
            "std": lambda x: np.std(x, axis=2, keepdims=True),
            "median": lambda x: np.median(x, axis=2, keepdims=True),
            "skew": lambda x: np.expand_dims(scipy.stats.skew(x, axis=2), axis=2),
            "kurtosis": lambda x: np.expand_dims(scipy.stats.kurtosis(x, axis=2), axis=2),
            "rms": lambda x: np.sqrt(np.mean(x ** 2, axis=2, keepdims=True)),
        }
        time_cfg = self.features_config.get("time_features", {})
        for feature_name, func in time_funcs.items():
            if not time_cfg.get(feature_name, False):
                continue
            feat = func(data)
            feature_blocks.append(feat)
            current_index = self._register_feature(
                f"time_features_{feature_name}", feat.shape[2],
                n_channels, feature_handling, current_index,
            )

        ### FREQUENCY-DOMAIN FEATURES
        freq_cfg = self.features_config.get("freq_features", False)
        if freq_cfg:
            freq_bands = {
                # "delta": (0.5, 4),
                # "theta": (4, 8),
                # "alpha": (8, 12),
                # "beta": (20, 30),
                # "gamma": (30, 100),
                "all": (0.5, 100),
            }
            # One PSD over the union of all bands; each band is sliced from it below.
            psd, freqs = epochs.compute_psd(
                method="multitaper",
                fmin=min(low for low, _ in freq_bands.values()),
                fmax=max(high for _, high in freq_bands.values()),
                verbose="WARNING",
            ).get_data(return_freqs=True)

            for feature_name in ("psd_raw", "psd_band_mean", "psd_band_std", "spectral_entropy"):
                if not freq_cfg.get(feature_name, False):
                    continue
                for band_name, (fmin, fmax) in freq_bands.items():
                    # Half-open [fmin, fmax) so adjacent bands never share a bin.
                    band_mask = (freqs >= fmin) & (freqs < fmax)
                    band_psd = psd[:, :, band_mask]
                    if feature_name == "psd_raw":
                        feat = band_psd
                    elif feature_name == "psd_band_mean":
                        feat = np.mean(band_psd, axis=2, keepdims=True)
                    elif feature_name == "psd_band_std":
                        feat = np.std(band_psd, axis=2, keepdims=True)
                    else:  # "spectral_entropy"
                        feat = self._spectral_entropy(band_psd)
                    feature_blocks.append(feat)
                    current_index = self._register_feature(
                        f"freq_features_{band_name}_{feature_name}", feat.shape[2],
                        n_channels, feature_handling, current_index,
                    )

        ### CONCATENATE ALL FEATURES -> (n_epochs, n_channels, n_features_total)
        if feature_blocks:
            all_features = np.concatenate(feature_blocks, axis=2)
        else:
            all_features = np.empty((n_epochs, n_channels, 0))
        n_features_total = all_features.shape[2]

        ### HANDLE MULTIPLE CHANNELS BASED ON SELECTED STRATEGY
        if feature_handling == "flatten_chs":
            # Move channels to the last axis before flattening so the columns are
            # feature-major, matching the slices recorded in feature_idx_map.
            # An explicit column count (not -1) also works when n_epochs == 0.
            feature_matrix = all_features.transpose(0, 2, 1).reshape(
                n_epochs, n_features_total * n_channels
            )
        elif feature_handling == "average_chs":
            feature_matrix = np.mean(all_features, axis=1)
        else:  # "separate_chs"
            feature_matrix = all_features
        return feature_matrix, self.feature_idx_map

    def _register_feature(self, key, n_features, n_channels, feature_handling, start):
        """
        Records the slice occupied by one feature in ``feature_idx_map``.

        Parameters:
        - key (str): Feature key, e.g. "time_features_mean".
        - n_features (int): Number of values per channel produced by the feature.
        - n_channels (int): Number of channels.
        - feature_handling (str): One of the supported modes.
        - start (int): First free index along the output's last axis.

        Returns:
        - int: The next free index after this feature.
        """
        if feature_handling == "flatten_chs":
            # Feature-major flattening: this feature's block spans every channel.
            end = start + n_features * n_channels
            self.feature_idx_map[key] = (start, end)
            return end
        end = start + n_features
        if feature_handling == "average_chs":
            self.feature_idx_map[key] = (start, end)
        else:  # "separate_chs": channels are axis 1, so all channels share one slice.
            for ch in range(n_channels):
                self.feature_idx_map[f"{key}_ch{ch}"] = (start, end)
        return end

    @staticmethod
    def _spectral_entropy(band_psd):
        """
        Shannon spectral entropy (in bits) of a PSD along its last axis.

        The PSD is normalised to a probability distribution over frequency bins,
        and H = -sum(p * log2(p)) is returned with a trailing axis of size 1.
        Bands with zero total power (or no bins) yield 0.
        """
        total = np.sum(band_psd, axis=-1, keepdims=True)
        prob = np.divide(band_psd, total, out=np.zeros(band_psd.shape), where=total > 0)
        # log2(0) is undefined; 0 * log(0) is taken as 0 by convention.
        log_prob = np.log2(prob, out=np.zeros(prob.shape), where=prob > 0)
        return -np.sum(prob * log_prob, axis=-1, keepdims=True)

    def extract_features_with_labels(self, epochs, feature_handling="flatten_chs"):
        """
        Extracts features and labels from an MNE Epochs object.

        Parameters:
        - epochs (mne.Epochs): The epochs object containing LFP data and labels.
        - feature_handling (str): The strategy to handle multi-channel features
          (see ``extract_features``).

        Returns:
        - X (np.ndarray): Extracted features; (n_epochs, n_features) for
          "flatten_chs"/"average_chs", (n_epochs, n_channels, n_features) for
          "separate_chs".
        - y (np.ndarray): Labels of shape (n_epochs,), taken from the last column
          of ``epochs.events``.
        - feature_idx_map (dict): Feature key -> (start, end) slice into X.
        """
        # Extract features using the existing method
        X, feature_idx_map = self.extract_features(epochs, feature_handling)
        # Extract labels from the MNE Epochs object
        y = epochs.events[:, -1]
        return X, y, feature_idx_map

    def select_feature(self, feature_matrix, feature_name, feature_handling="flatten_chs"):
        """
        Extracts a specific feature slice from the feature matrix using the feature name.

        Parameters:
        - feature_matrix (np.ndarray): The full feature matrix returned by extract_features.
        - feature_name (str): The name of the feature to extract (a key of feature_idx_map).
        - feature_handling (str): How the features are handled across channels; must
          match the mode used in extract_features.
          Options: 'flatten_chs', 'average_chs', 'separate_chs'.

        Returns:
        - np.ndarray: The sliced feature matrix for the specified feature name.

        Raises:
        - ValueError: If the feature name is unknown or the strategy is invalid.
        """
        if feature_name not in self.feature_idx_map:
            raise ValueError(f"Feature '{feature_name}' not found in the feature index map.")
        start_idx, end_idx = self.feature_idx_map[feature_name]
        if feature_handling in ("flatten_chs", "average_chs"):
            # flatten_chs: (n_epochs, n_features * n_channels), feature-major:
            #   Feat1_ch1, Feat1_ch2, ..., Feat2_ch1, Feat2_ch2, ...
            # average_chs: (n_epochs, n_features)
            return feature_matrix[:, start_idx:end_idx]
        if feature_handling == "separate_chs":
            # (n_epochs, n_channels, n_features): slice the feature axis, keep all channels.
            return feature_matrix[:, :, start_idx:end_idx]
        raise ValueError(f"Unknown feature_handling strategy: '{feature_handling}'")