Source code for gaitmod.utils.feature_extractor

import numpy as np
from typing import Dict, Any

[docs] class FeatureExtractor:
[docs] @staticmethod def extract_band_psd(epochs, freq_bands): """ Extracts PSD features for specified frequency bands from the given MNE epochs without averaging across channels or frequencies. Parameters: - epochs: mne.Epochs object containing the LFP data. - freq_bands: Dictionary where keys are the band names, and values are tuples with (low_freq, high_freq). Returns: - psd_dict: A dictionary where each key corresponds to a frequency band, and values are arrays of shape (n_epochs, n_channels, n_frequencies) representing the raw PSD values for each band. """ psd_dict = {band: [] for band in freq_bands} # Initialize the dictionary for each band # Compute PSD using the `compute_psd` function from MNE's Epochs object psds, freqs = epochs.compute_psd(fmin=min([f[0] for f in freq_bands.values()]), fmax=max([f[1] for f in freq_bands.values()])).get_data(return_freqs=True) # Extract PSD for each band and channel for band, (low, high) in freq_bands.items(): idx_band = np.logical_and(freqs >= low, freqs <= high) # Find frequency indices within this band # Extract raw PSD values for each frequency in the band for each epoch and channel psd_dict[band] = psds[:, :, idx_band] # No mean over frequencies, retain raw PSD values return psd_dict # (n_epochs, n_channels, n_frequencies)
[docs] @staticmethod def extract_band_power(epochs, freq_bands): """ Extracts band power features from each channel of the given MNE epochs without averaging across channels. Parameters: - epochs: mne.Epochs object containing the LFP data. - freq_bands: Dictionary where keys are the band names, and values are tuples with (low_freq, high_freq). Returns: - band_power_dict: A dictionary where each key corresponds to a frequency band, and values are lists of features (1 for each epoch and channel) for that specific band. """ band_power_dict = {band: [] for band in freq_bands} # Initialize the dictionary for each band # Compute PSD using the `compute_psd` function from MNE's Epochs object psds, freqs = epochs.compute_psd(fmin=min([f[0] for f in freq_bands.values()]), fmax=max([f[1] for f in freq_bands.values()])).get_data(return_freqs=True) # Convert power spectral density (psd) to decibels psds_db = 10 * np.log10(psds) # Extract band power for each band and channel for band, (low, high) in freq_bands.items(): idx_band = np.logical_and(freqs >= low, freqs <= high) # Find frequency indices within this band # For each epoch, and for each channel, calculate the mean power for this band band_power_dict[band] = psds_db[:, :, idx_band].mean(axis=-1) # Mean over frequency range (axis=-1) return band_power_dict
[docs] @staticmethod def extract_psd_and_band_power(epochs, freq_bands, fmin, fmax): """ Extracts power spectral density (PSD) and band power features from each channel of the given MNE epochs (without averaging across channels or epochs). Parameters: - epochs: mne.Epochs object containing the LFP data. - freq_bands: Dictionary where keys are the band names, and values are tuples with (low_freq, high_freq). Returns: - psd_array: A NumPy array of shape (n_epochs, n_channels, n_frequencies) containing the raw PSD values. - band_power_array: A NumPy array of shape (n_epochs, n_channels, n_bands) containing the mean band power features for each frequency band across epochs and channels. """ n_epochs = epochs.get_data(copy=True).shape[0] n_channels = epochs.get_data(copy=True).shape[1] n_bands = len(freq_bands) # Initialize a NumPy array to store the band power values band_power_array = np.zeros((n_epochs, n_channels, n_bands)) # Compute PSD using the `compute_psd` function from MNE's Epochs object psds, freqs = epochs.compute_psd(method='welch', fmin=min([f[0] for f in freq_bands.values()]), fmax=max([f[1] for f in freq_bands.values()])).get_data(return_freqs=True) # Convert power spectral density (psd) to decibels psds_db = 10 * np.log10(psds) # Extract band power for each band and channel for i, (band, (low, high)) in enumerate(freq_bands.items()): idx_band = np.logical_and(freqs >= low, freqs <= high) # Find frequency indices within this band # For each epoch and each channel, calculate the mean power for this band band_power_array[:, :, i] = psds_db[:, :, idx_band].mean(axis=-1) # Mean over frequency range (axis=-1) return psds, freqs, band_power_array
[docs] @staticmethod def flatten_features(psds, band_power): """ Flatten the features and combine PSD and band power. Parameters: - psds: 3D array of PSD values - band_power: 3D array of band power values Returns: - Combined flattened features as a 2D array """ psds_flat = psds.reshape(psds.shape[0], -1) band_power_flat = band_power.reshape(band_power.shape[0], -1) return np.concatenate((psds_flat, band_power_flat), axis=1)
[docs] @staticmethod def extract_windowed_stat_features(trials, methods=['mean'], window_size=100, step_size=50, verbose=False): features_dict = {method: [] for method in methods} for method in methods: trial_features = [] for trial_idx, trial in enumerate(trials): trial_windows = [] for start_idx in range(0, trial.shape[1] - window_size + 1, step_size): window = trial[:, start_idx:start_idx + window_size] # Compute the required statistic based on method if method == 'mean': window_stat = np.mean(window, axis=1) elif method == 'std': window_stat = np.std(window, axis=1) elif method == 'median': window_stat = np.median(window, axis=1) trial_windows.append(window_stat) if verbose: print(f"Method: {method}, Trial {trial_idx}, Number of windows: {len(trial_windows)}") trial_features.append(np.array(trial_windows).T) # Add computed features for this method to the dictionary features_dict[method] = np.array(trial_features) return features_dict
[docs] @staticmethod def extract_epoched_stat_features(epochs, methods=['mean', 'std', 'median']): statistics_dict = { 'epochs': {}, 'channels': {}, 'times': {} } data = epochs.get_data(copy=False) # shape (n_epochs, n_channels, n_times) # Compute statistics along each axis for axis, axis_name in zip([0, 1, 2], ['epochs', 'channels', 'times']): for method in methods: if method == 'mean': stat = np.mean(data, axis=axis) elif method == 'std': stat = np.std(data, axis=axis) elif method == 'median': stat = np.median(data, axis=axis) statistics_dict[axis_name][method] = stat return statistics_dict
[docs] @staticmethod def reshape_lfp_data(lfp_data, mode="flat_time"): """ Reform the LFP data based on the specified mode. Parameters: ---------- lfp_data : np.ndarray The input LFP data with shape (trials, channels, times). mode : str The reshaping mode. Options: - "flat_time": Reshape to (trials * times, channels). - "flat_channel": Reshape to (trials, times * channels). Returns: ------- reshaped_data : np.ndarray The reshaped data based on the selected mode. """ n_trials, n_channels, n_times = lfp_data.shape if mode == "flat_time": # Reshape to (trials * times, channels) reshaped_data = lfp_data.transpose(0, 2, 1).reshape(-1, n_channels) # Flatten time dimension elif mode == "flat_channel": # Reshape to (trials, times * channels) reshaped_data = lfp_data.transpose(0, 2, 1).reshape(n_trials, -1) # Flatten channel dimension else: raise ValueError("Invalid mode. Use 'flat_time' or 'flat_channel'.") return reshaped_data
# @staticmethod # def compute_overall_psd(epochs, fmin=1, fmax=50): # """ # Computes the overall PSD across all channels and epochs from the MNE epochs object. # Parameters: # - epochs: mne.Epochs object containing the LFP data. # - fmin: Minimum frequency for PSD computation (default is 1 Hz). # - fmax: Maximum frequency for PSD computation (default is 50 Hz). # Returns: # - psds_db: PSD values in decibels. # - freqs: Corresponding frequency values. # """ # # Compute PSD across all frequencies from fmin to fmax # psds, freqs = epochs.compute_psd(fmin=fmin, fmax=fmax).get_data(return_freqs=True) # # Convert PSD to decibels (optional) # psds_db = 10 * np.log10(psds) # return psds_db, freqs # Return both the psds and the corresponding frequencies