import numpy as np
from typing import Dict, Any, Tuple, List, Optional, Union
import mne
import pandas as pd
from scipy import signal as sp_signal
from scipy import signal as sp_signal
import numpy as np
[docs]
class DataProcessor:
[docs]
@staticmethod
def process_lfp_data(data, n_sessions, lfp_sfreq, event_of_interest, mod_start_event_id, normal_walking_event_id, gap_sample_length, epoch_sample_length, epoch_tmin, epoch_tmax, epoch_duration, event_dict, info, reject_criteria, config, verbose=False):
    """Epoch the LFP recording of every subject/session and concatenate the result.

    For each session the LFP array is trimmed, wrapped in an
    ``mne.io.RawArray``, the events matching ``event_of_interest`` are
    relabelled to ``mod_start_event_id``, normal-walking events are derived
    from them, and the resulting epochs are cleaned with ``reject_criteria``.
    All per-session epochs are finally concatenated and given a montage.

    Args:
        data: Nested mapping ``{subject: {session: session_data}}``. Each
            ``session_data`` must provide ``'data_LFP'`` and ``'events_KIN'``.
        n_sessions: Unused; kept so existing callers keep working.
        lfp_sfreq: Sampling frequency, used to convert samples to seconds.
        event_of_interest: Key into the event dictionary returned by
            ``create_events_array`` that selects the modulation-start events.
        mod_start_event_id: Event code written to the selected events.
        normal_walking_event_id: Event code forwarded to
            ``define_normal_walking_events``.
        gap_sample_length: Forwarded to ``define_normal_walking_events``.
        epoch_sample_length: Epoch length in samples; also subtracted from the
            event onsets when building annotations.
        epoch_tmin: Forwarded to ``create_epochs_with_events``.
        epoch_tmax: Forwarded to ``create_epochs_with_events``.
        epoch_duration: Duration (seconds) of every annotation.
        event_dict: Mapping of class name to event code.
        info: ``mne.Info`` used to build each ``RawArray``.
        reject_criteria: Passed to ``epochs.drop_bad(reject=...)``.
        config: Unused; kept so existing callers keep working.
        verbose: When True, print the epoch count per class for each session.

    Returns:
        tuple: ``(epochs, events, lfp_raw_list, all_lfp_data)`` where
        ``epochs`` is the concatenation of all session epochs, ``events`` is
        the stacked event array sorted by its first column, ``lfp_raw_list``
        holds one annotated ``RawArray`` per session and ``all_lfp_data``
        holds the trimmed LFP array of every session.

    Raises:
        ValueError: If ``data`` contains no session at all.
    """
    lfp_raw_list = []
    epochs_list = []
    events_list = []
    all_lfp_data = []

    for subject_idx, subject in enumerate(data.keys()):
        print(f'subject {subject_idx}: {subject}')
        for session in data[subject].keys():
            session_data = data[subject][session]

            # Extract events and LFP data of the subject/session
            lfp_data = session_data['data_LFP']
            events_KIN = DataProcessor.np_to_dict(session_data['events_KIN'])
            events_before_trim, event_dict_before_trim = DataProcessor.create_events_array(events_KIN, lfp_sfreq)

            # Trim the data and adjust the event onsets accordingly
            lfp_data, events_after_trim = DataProcessor.trim_data(lfp_data, events_before_trim, lfp_sfreq)

            # Read the sample count straight from the array: converting to a
            # duration and back can truncate to n - 1 through float rounding.
            n_samples = lfp_data.shape[1]
            all_lfp_data.append(lfp_data)

            # Build the raw object from the trimmed data
            lfp_raw = mne.io.RawArray(lfp_data, info, verbose=40)

            events_mod_start = events_after_trim[events_after_trim[:, 2] == event_dict_before_trim[event_of_interest]]
            events_mod_start[:, 1] = subject_idx  # mark the subject index
            events_mod_start[:, 2] = mod_start_event_id  # rename gait modulation events

            # Define normal walking events
            normal_walking_events = DataProcessor.define_normal_walking_events(
                normal_walking_event_id, events_mod_start,
                gap_sample_length, epoch_sample_length, n_samples
            )
            # NOTE(review): the original comments call this the "session nr",
            # but the value written is the subject index — confirm intent.
            # Re-applied after the helper call, as in the original code.
            events_mod_start[:, 1] = subject_idx
            normal_walking_events[:, 1] = subject_idx

            # Combine events and create epochs
            events, epochs = DataProcessor.create_epochs_with_events(
                lfp_raw,
                events_mod_start,
                normal_walking_events,
                mod_start_event_id,
                normal_walking_event_id,
                epoch_tmin,
                epoch_tmax,
                event_dict
            )

            if verbose:
                print(f"Total epochs: {len(epochs)}")
                for cls in event_dict.keys():
                    print(f"{cls}: {len(epochs[cls])} epochs", end='; ')

            epochs.events[:, 1] = subject_idx  # mark the subject index

            # Remove bad epochs
            epochs.drop_bad(reject=reject_criteria)

            my_annot = mne.Annotations(
                onset=(events[:, 0] - epoch_sample_length) / lfp_sfreq,  # in seconds
                duration=len(events) * [epoch_duration],  # in seconds, too
                description=events[:, 2]
            )
            lfp_raw.set_annotations(my_annot)

            lfp_raw_list.append(lfp_raw)
            epochs_list.append(epochs)
            events_list.append(events)
            print("\n==========================================================")

    if not epochs_list:
        # Without this guard the code below fails on an unbound `lfp_raw`.
        raise ValueError("No sessions found in data; nothing to epoch.")

    epochs = mne.concatenate_epochs(epochs_list, verbose=40)
    events = np.vstack(events_list)
    events = events[np.argsort(events[:, 0])]  # Sort by onset time

    # Generate the channel locations from the last session's channel names
    ch_locs = DataProcessor.generate_ch_locs(ch_names=lfp_raw_list[-1].ch_names)
    montage = mne.channels.make_dig_montage(ch_pos=ch_locs)
    epochs.set_montage(montage)
    return epochs, events, lfp_raw_list, all_lfp_data
[docs]
@staticmethod
def check_and_fix_lfp_chs_name(data: Dict[str, Any], fix_chs_name: Optional[Union[bool, List[str]]] = False) -> None:
    """
    Report the LFP channel labels of every session and optionally replace bad ones.

    A session whose labels are a NumPy array is reported with its channel
    count (highlighted when it is not 6). A session whose labels are a plain
    string is highlighted and, when ``fix_chs_name`` is truthy, overwritten
    in place with ``np.array(fix_chs_name)``.

    Args:
        data (Dict[str, Any]): Mapping ``{patient_id: {session_id: session_data}}``
            where ``session_data['hdr_LFP']['labels']`` holds the channel labels.
        fix_chs_name (Optional[Union[bool, List[str]]]): Replacement labels.
            Defaults to False (report only).
            NOTE(review): passing ``True`` stores ``np.array(True)``; no
            default label list exists in this method — confirm intended use.
    """
    for patient_id in data:
        print(f"Patient ID: {patient_id}")
        for session_id, session_data in data[patient_id].items():
            header = session_data['hdr_LFP']
            labels = header['labels']
            if isinstance(labels, str):
                # A bare string instead of a label array marks a broken header.
                print(f"\033[93m{session_id}: {labels}\033[0m")
                if fix_chs_name:
                    new_labels = np.array(fix_chs_name)
                    header['labels'] = new_labels
                    print(f"Fixed: {new_labels}")
            elif isinstance(labels, np.ndarray):
                length = len(labels)
                if length == 6:
                    print(f"{session_id}: {length} channels found")
                else:
                    print(f"\033[93m{session_id}: {length} channels found\033[0m")
        print('-' * 50)
    print('=' * 70)
[docs]
@staticmethod
def rename_keys(data, key_map):
    """Recursively rename dictionary keys in place.

    Keys are matched case-insensitively: a key is renamed when its
    lower-cased form is present in ``key_map``. Nested dictionaries and
    lists are walked; any other value is left untouched.

    Each container is visited exactly once, so a chained map such as
    ``{'a': 'b', 'b': 'c'}`` is applied a single time at every level.
    Non-string keys are skipped instead of raising.

    Args:
        data (dict, list, or any data structure): The data to process.
        key_map (dict): Maps lower-cased old key names to new key names.

    Returns:
        None: Modifies the input data in place.
    """
    def _walk(node):
        if isinstance(node, dict):
            # Snapshot the matching keys first; the dict is mutated below.
            matches = [
                key for key in node
                if isinstance(key, str) and key.lower() in key_map
            ]
            for key in matches:
                node[key_map[key.lower()]] = node.pop(key)
            # Single descent over every value, renamed or not.
            for value in list(node.values()):
                _walk(value)
        elif isinstance(node, list):
            for item in node:
                _walk(item)

    _walk(data)
[docs]
@staticmethod
def sort_and_filter_events(data: Dict[str, Any], new_order: List[str]) -> None:
    """
    Reorder (and thereby filter) the kinematic events of every session in place.

    For each session, ``events_KIN['labels']`` and ``events_KIN['times']`` are
    re-indexed along their first axis so that they follow ``new_order``.
    Labels absent from ``new_order`` are dropped; entries of ``new_order``
    that do not occur in a session are ignored.

    Parameters:
        data (Dict[str, Dict[str, dict]]): Mapping ``{subject: {session: session_data}}``.
        new_order (List[str]): The desired order of event labels.

    Returns:
        None: ``data`` is modified in place. The reordered labels and times
        are stored as NumPy arrays.

    Raises:
        KeyError: If 'events_KIN' is missing from any session.
    """
    for sessions in data.values():
        for session_data in sessions.values():
            if 'events_KIN' not in session_data:
                raise KeyError("events_KIN not found in data_type.")
            events_kin = session_data['events_KIN']

            # Coerce so that plain lists are accepted as well as arrays.
            labels = np.asarray(events_kin['labels'])
            times = np.asarray(events_kin['times'])

            # Map labels to their row index for fast lookup
            label_to_index = {label: idx for idx, label in enumerate(labels)}

            # Integer dtype keeps the selection valid even when it is empty.
            order_indices = np.array(
                [label_to_index[event] for event in new_order if event in label_to_index],
                dtype=np.intp,
            )

            events_kin['labels'] = labels[order_indices]
            events_kin['times'] = times[order_indices]
[docs]
@staticmethod
def remove_nan_events(data: Dict[str, Any], sfreq: float) -> None:
    """
    Drop, in place, every trial whose event times contain a NaN.

    For each session, ``events_KIN['times']`` is indexed as
    ``times[:, keep]``: a column is kept only when none of its entries is
    NaN. The filtered array replaces the original one.

    Parameters:
        data (Dict[str, Any]): Mapping ``{subject: {session: session_data}}``
            where ``session_data['events_KIN']['times']`` is a 2D array.
        sfreq (float): Unused; kept so existing callers keep working.

    Returns:
        None: ``data`` is modified in place.
    """
    for sessions in data.values():
        for session_data in sessions.values():
            events_kin = session_data['events_KIN']
            times = events_kin['times']

            # A column survives only if it holds no NaN at all.
            keep = ~np.isnan(times).any(axis=0)
            events_kin['times'] = times[:, keep]
# TODO: The commented-out variant below is not used in the current implementation because it operates on a single patient/session rather than on the entire data.
# def sort_and_filter_events(events_KIN, new_order: List[str]):
# """
# Reorder and filter events in the 'events_KIN' dictionary based on a new order.
# events_KIN (dict): The input dictionary containing 'label' and 'times' keys, where 'label' is a list of event labels and 'times' is a list of corresponding event times.
# Returns:
# dict: The updated 'events_KIN' dictionary with events reordered and filtered according to 'new_order'.
# """
# # Map labels to their indices for fast lookup
# label_to_index = {label: idx for idx, label in enumerate(events_KIN['label'])}
# # Get indices for sorting (only for labels that exist)
# order_indices = [label_to_index[event] for event in new_order if event in label_to_index]
# # Apply sorted order
# events_KIN['label'] = events_KIN['label'][order_indices]
# events_KIN['times'] = events_KIN['times'][order_indices]
# return events_KIN
[docs]
@staticmethod
def clean_data(
    data: Dict[str, Any],
    data_type: str,
    key_map: Dict[str, str],
    new_events_order: list,
    sfreq: float,
    fix_chs_names: Optional[Union[bool, List[str]]] = False,
    verbose: bool = False
):
    """Run the in-place cleaning steps on ``data``, in a fixed order.

    The steps are: channel-name check/fix, key renaming, event
    sorting/filtering, and NaN-trial removal.

    Args:
        data (dict): The dataset to clean; modified in place.
        data_type (str): Only used in the verbose start message.
        key_map (dict): Passed to ``rename_keys``.
        new_events_order (list): Passed to ``sort_and_filter_events``.
        sfreq (float): Passed to ``remove_nan_events``.
        fix_chs_names (bool or list of str, optional): Passed to
            ``check_and_fix_lfp_chs_name``. Defaults to False.
        verbose (bool, optional): Print a start and an end message.
            Defaults to False.
    """
    if verbose:
        print(f"Cleaning data of type: {data_type}...")

    # Each entry is (step, extra positional arguments); order matters.
    cleaning_steps = (
        (DataProcessor.check_and_fix_lfp_chs_name, (fix_chs_names,)),
        (DataProcessor.rename_keys, (key_map,)),
        (DataProcessor.sort_and_filter_events, (new_events_order,)),
        (DataProcessor.remove_nan_events, (sfreq,)),
    )
    for step, extra_args in cleaning_steps:
        step(data, *extra_args)

    if verbose:
        print("Data cleaning completed.")
[docs]
@staticmethod
def process_trials_and_events(data: Dict[str, Dict[str, dict]],
                              data_type: str,
                              sfreq: float,
                              config: dict,
                              verbose: bool = False) -> Tuple[Dict[str, List[np.ndarray]], Dict[str, np.ndarray], Dict[str, Dict[str, List[int]]]]:
    """
    Cut every session into trials and collect per-trial event indices.

    Event times are converted to sample indices (``times * sfreq`` cast to
    int), trials are extracted with ``create_trials``, and for every trial
    that was not skipped the event order is checked and the trial-relative
    event indices are gathered.

    Args:
        data (Dict[str, Dict[str, dict]]): Mapping ``{subject: {session: session_data}}``.
            Each ``session_data`` must provide ``'events_KIN'`` (with
            ``'times'`` and ``'labels'``), ``'pt'``, ``'session'`` and the
            entry named by ``data_type``.
        data_type (str): Key of the signal to cut into trials, e.g. 'data_LFP'.
        sfreq (float): Sampling frequency of that signal (in Hz).
        config (dict): Forwarded unchanged to ``create_trials``.
        verbose (bool, optional): Print progress per subject, session and
            skipped trial. Defaults to False.

    Returns:
        Tuple of three dictionaries keyed by subject:
            - trial arrays (list of 2D arrays) accumulated over all sessions,
            - event indices of the kept trials as an int64 array,
            - ``{session: [trial indices]}`` locating each session's trials
              in the subject's trial list.
    """
    subjects_data_dict = {}
    subjects_event_sample_idx_dict = {}
    subjects_session_trial_mapping = {}

    for subject_idx, subject in enumerate(data.keys()):
        if verbose:
            print(f'subject {subject_idx}: {subject}')

        subject_trials = []
        subject_event_rows = []
        session_to_trials = {}
        subjects_data_dict[subject] = subject_trials
        subjects_event_sample_idx_dict[subject] = subject_event_rows
        subjects_session_trial_mapping[subject] = session_to_trials

        next_trial_index = 0  # running index across all sessions of the subject
        for session_idx, (session, session_data) in enumerate(data[subject].items()):
            if verbose:
                print(f' - session {session_idx}: {session}')

            # Event times (seconds) -> sample indices
            event_samples = (session_data['events_KIN']['times'] * sfreq).astype(int)

            session_trials, kept_mask = DataProcessor.create_trials(
                event_samples,
                session_data[data_type],
                sfreq,
                config,
                session_data['pt'],
                session_data['session']
            )

            # Consecutive indices for the trials this session contributed
            first_index = next_trial_index
            next_trial_index += len(session_trials)
            session_to_trials[session] = list(range(first_index, next_trial_index))
            subject_trials.extend(session_trials)

            for trial_col in range(event_samples.shape[1]):
                # Trials skipped by create_trials get no event row either, so
                # trial list and event rows stay aligned.
                if not kept_mask[trial_col]:
                    if verbose:
                        print(f" Skipping event processing for trial {trial_col} (was skipped in create_trials)")
                    continue

                DataProcessor.check_event_order(
                    event_samples,
                    session_data['events_KIN']['labels'],
                    trial_idx=trial_col,
                    exclude_label_names=['VA_cross', 'min_vel', 'min_dist'])
                subject_event_rows.append(
                    DataProcessor.get_trial_event_indices(
                        event_samples,
                        session_data['events_KIN']['labels'],
                        trial_idx=trial_col
                    )
                )

        # Freeze the collected rows into one integer array per subject
        subjects_event_sample_idx_dict[subject] = np.array(subject_event_rows, dtype=np.int64)

    return subjects_data_dict, subjects_event_sample_idx_dict, subjects_session_trial_mapping
# TODO: Transpose the shape of events_kin_samples to (n_trials, n_events) before passing it to this method for a more intuitive API.
[docs]
@staticmethod
def create_trials(events_kin_samples: np.ndarray,
                  data: np.ndarray,
                  sfreq: float,
                  config: Dict[str, Any],
                  subject_id: str,
                  session_id: str) -> Tuple[List[np.ndarray], List[bool]]:
    """
    Slice one trial out of ``data`` for every column of ``events_kin_samples``.

    A trial spans from the sample in the first event row to the sample in
    the last event row (end exclusive). Trials whose start or stop index is
    negative or beyond the number of samples are skipped with a warning.
    Negative indices are rejected because they would otherwise wrap around
    in the slice and return unrelated or empty data.

    Args:
        events_kin_samples (np.ndarray): Event sample indices with shape (n_events, n_trials).
        data (np.ndarray): Measurement data with shape (n_channels, n_times).
        sfreq (float): Unused; kept so existing callers keep working.
        config (Dict[str, Any]): Unused; kept so existing callers keep working.
        subject_id (str): Subject identifier, only used in the warning message.
        session_id (str): Session identifier, only used in the warning message.

    Returns:
        Tuple[List[np.ndarray], List[bool]]:
            - trials_data: One (n_channels, n_trial_samples) array per kept trial.
            - valid_trials_mask: One boolean per input trial; True when the
              trial was kept, False when it was skipped.
    """
    n_events, n_trials = events_kin_samples.shape[0], events_kin_samples.shape[1]
    n_times = data.shape[1]
    trials_data = []
    valid_trials_mask = []

    for trial in range(n_trials):
        trial_start_idx = events_kin_samples[0, trial]
        trial_stop_idx = events_kin_samples[n_events - 1, trial]

        out_of_range = (
            trial_start_idx < 0 or trial_stop_idx < 0
            or trial_start_idx > n_times or trial_stop_idx > n_times
        )
        if out_of_range:
            print(f"\033[93mSkipping trial {trial}: [{trial_start_idx}, {trial_stop_idx}] is outside LFP data range (0, {data.shape[1]}) -- Subject {subject_id}, session {session_id}.\033[0m")
            valid_trials_mask.append(False)
            continue

        trials_data.append(data[:, trial_start_idx:trial_stop_idx])
        valid_trials_mask.append(True)

    return trials_data, valid_trials_mask
[docs]
@staticmethod
def check_event_order(
events_kin_samples: np.ndarray,
events_kin_labels: np.ndarray | List,
trial_idx: int,
exclude_label_names: List[str] = []):
"""
Check if the values in each column of events_kin_samples are in ascending order for the given trial index.
Args:
events_kin_samples (np.ndarray): A 2D array of shape (n_events, n_trials) containing the sample indices of each event.
events_kin_labels (np.ndarray | List): A 1D array of shape (n_events,) or a list containing the event names.
trial_idx (int): The index of the trial to check.
Raises:
ValueError: If the values in the column are not in ascending order.
"""
n_events = events_kin_samples.shape[0]
exclude_labels_idx = np.where(np.isin(events_kin_labels, exclude_label_names))[0]
# Exclude values from event_idx_subsequent_diff with exclude_labels_idx
events_kin_samples = np.delete(events_kin_samples, exclude_labels_idx, axis=0)
# Check if the values in the column are in ascending order
event_idx_subsequent_diff = np.diff(events_kin_samples[:, trial_idx]) >= 0
if not np.all(event_idx_subsequent_diff):
print(f"\033[93mWarning: Values in trial {trial_idx} are not in ascending order.\033[0m")
[docs]
@staticmethod
def get_trial_event_indices(
events_kin_samples: np.ndarray,
events_kin_labels: np.ndarray | List,
trial_idx: int
) -> np.ndarray:
"""
Computes the sample indices of event occurrences relative to the start of a given trial.
Args:
events_kin_samples (np.ndarray): A 2D array of shape (n_events, n_trials) containing the sample indices of each event.
events_kin_labels (np.ndarray | List): A 1D array of shape (n_events,) or a list containing the event names.
trial_idx (int): The index of the trial for which to compute event indices.
Returns:
np.ndarray (int): Event indices relative to the start of the trial.
Notes:
- If an event occurs outside the trial boundaries, a random valid index
within the trial range is assigned.
- The returned indices are relative to the trial's start index.
"""
n_events = events_kin_samples.shape[0]
trial_start_idx = events_kin_samples[0, trial_idx]
trial_end_idx = events_kin_samples[n_events - 1, trial_idx]
trial_event_indices = []
for event_name_idx, event_name in enumerate(events_kin_labels):
event_idx = events_kin_samples[event_name_idx, trial_idx]
# Handle cases where the event index is outside trial boundaries
if event_idx < trial_start_idx or event_idx > trial_end_idx:
print(f"Event '{event_name}' at index {event_idx} is outside trial boundaries " f"[{trial_start_idx}-{trial_end_idx}]. Assigning a random valid index.")
# Compute event index relative to trial start
event_idx_relative_to_trial_start = event_idx - trial_start_idx
trial_event_indices.append(event_idx_relative_to_trial_start)
return np.array(trial_event_indices)
# lfp_raw_list = []
# epochs_list = []
# events_list = []
# all_lfp_data = []
# events_before_trim, event_dict_before_trim = DataProcessor.create_events_array(events_KIN, lfp_sfreq)
# # Trim the data and adjust the event onsets accordingly
# sessoin_lfp_data, events_after_trim = DataProcessor.trim_data(sessoin_lfp_data, events_before_trim, lfp_sfreq)
# lfp_duration = lfp_sessoin_lfp_datadata.shape[1] / lfp_sfreq
# n_samples = int(lfp_duration * lfp_sfreq)
# all_lfp_data.append(sessoin_lfp_data)
# all_lfp_data_dict[subject][session] = sessoin_lfp_data
# # Update raw data after trimming
# lfp_raw = mne.io.RawArray(sessoin_lfp_data, info, verbose=40)
# events_mod_start = events_after_trim[events_after_trim[:, 2] == event_dict_before_trim[event_of_interest]]
# events_mod_start[:, 1] = subject_idx # mark the subject index
# # Rename Gait Modulation Events
# events_mod_start[:, 2] = mod_start_event_id
# # Define normal walking events
# normal_walking_events = DataProcessor.define_normal_walking_events(
# normal_walking_event_id, events_mod_start,
# gap_sample_length, epoch_sample_length, n_samples
# )
# events_mod_start[:, 1] = subject_idx # mark the session nr
# normal_walking_events[:, 1] = subject_idx # mark the session nr
# # Combine events and create epochs
# events, epochs = DataProcessor.create_epochs_with_events(
# lfp_raw,
# events_mod_start,
# normal_walking_events,
# mod_start_event_id,
# normal_walking_event_id,
# epoch_tmin,
# epoch_tmax,
# event_dict
# )
# if verbose:
# print(f"Total epochs: {len(epochs)}")
# for cls in event_dict.keys():
# print(f"{cls}: {len(epochs[cls])} epochs", end='; ')
# epochs.events[:, 1] = subject_idx # mark the subject index
# # Remove bad epochs
# epochs.drop_bad(reject=reject_criteria)
# # my_annot = mne.annotations_from_events(epochs.events, lfp_sfreq)
# my_annot = mne.Annotations(
# onset=(events[:, 0] - epoch_sample_length) / lfp_sfreq, # in seconds
# duration=len(events) * [epoch_duration], # in seconds, too
# description=events[:, 2]
# )
# lfp_raw.set_annotations(my_annot)
# # lfp_raw.add_events(epochs.events)
# lfp_raw_list.append(lfp_raw)
# epochs_list.append(epochs)
# events_list.append(events)
# print("\n==========================================================")
# epochs = mne.concatenate_epochs(epochs_list, verbose=40)
# events = np.vstack(events_list)
# events = events[np.argsort(events[:, 0])] # Sort by onset time
# # Generate the channel locations
# ch_locs = DataProcessor.generate_ch_locs(ch_names=lfp_raw.ch_names)
# montage = mne.channels.make_dig_montage(ch_pos=ch_locs)
# epochs.set_montage(montage)
# return epochs, events, lfp_raw_list, all_lfp_data
[docs]
@staticmethod
def segment_and_label_trials(
        trials: list[np.ndarray],
        subjects_event_idx_dict: list[np.ndarray],
        ch_names: list[str],
        sfreq: int = 250,
        window_size: float = 0.5,
        overlap: float = 0.5,
        expand_transition: float = 0.0,
        discard_ambiguous: bool = False,
        mod_start_idx: int = 2,
        mod_end_idx: int = 6,
        event_dict: Optional[dict[str, int]] = None) -> mne.epochs.EpochsArray:
    """
    Cut trials into overlapping windows, label each window and return MNE epochs.

    A window is labelled with the higher of the two ``event_dict`` values
    when more than half of it lies inside the (optionally expanded)
    modulation interval, and with the lower value otherwise.

    Args:
        trials (list of np.ndarray): Trials, each with shape (n_channels, n_samples).
        subjects_event_idx_dict (list of np.ndarray): Event indices per trial, aligned with ``trials``.
        ch_names (list of str): Channel names; passed through ``rename_lfp_channels``.
        sfreq (int): Sampling frequency of the signals.
        window_size (float): Window length in seconds.
        overlap (float): Fraction of overlap between consecutive windows.
        expand_transition (float): Seconds by which the modulation interval
            is widened on both sides (clipped to the trial).
        discard_ambiguous (bool): Drop windows whose modulation share is
            strictly between 0 and 0.5.
            NOTE(review): windows with a share between 0.5 and 1 also overlap
            both states but are kept — confirm this is intended.
        mod_start_idx (int): Position of the modulation-start event in the event indices.
        mod_end_idx (int): Position of the modulation-end event in the event indices.
        event_dict (dict[str, int]): Maps the two class names to their labels.

    Returns:
        mne.epochs.EpochsArray: Epochs holding the windows. Each event row is
        ``[window_start + trial_idx, trial_idx, label]``.

    Raises:
        ValueError: If ``event_dict`` is missing or does not hold exactly two
            classes, if the window or step size is not at least one sample
            (e.g. ``overlap >= 1``, which previously looped forever), or if
            no window could be produced.
    """
    # Validate event_dict
    if event_dict is None:
        raise ValueError("event_dict must be provided. Example: {'normal_walking': 0, 'modulation': 1}")
    if len(event_dict) != 2:
        raise ValueError(f"event_dict must contain exactly 2 classes, but got {len(event_dict)} classes")

    window_samples = int(window_size * sfreq)  # window length in samples
    step_size = int(window_samples * (1 - overlap))  # hop between window starts
    if window_samples <= 0:
        raise ValueError(f"window_size * sfreq must be at least one sample, got {window_samples}")
    if step_size <= 0:
        # A non-positive step would never advance the window.
        raise ValueError(f"overlap={overlap} leaves no step between windows (step_size={step_size})")

    new_ch_names = DataProcessor.rename_lfp_channels(ch_names)
    info = mne.create_info(ch_names=new_ch_names, sfreq=sfreq, ch_types="dbs")

    # The lower value is treated as normal walking, the higher as modulation.
    normal_label, modulation_label = sorted(event_dict.values())
    margin = int(expand_transition * sfreq)

    epochs_data = []
    events_list = []
    for trial_idx, trial_data in enumerate(trials):
        n_samples = trial_data.shape[1]

        # Modulation interval of this trial, expanded and clipped to the trial
        event_indices = subjects_event_idx_dict[trial_idx]
        mod_start = max(0, event_indices[mod_start_idx] - margin)
        mod_end = min(n_samples, event_indices[mod_end_idx] + margin)

        for window_start in range(0, n_samples - window_samples + 1, step_size):
            window_end = window_start + window_samples

            # Share of the window that lies inside the modulation interval
            mod_overlap = max(0, min(window_end, mod_end) - max(window_start, mod_start))
            mod_ratio = mod_overlap / window_samples

            if discard_ambiguous and 0 < mod_ratio < 0.5:
                continue

            label = modulation_label if mod_ratio > 0.5 else normal_label
            epochs_data.append(trial_data[:, window_start:window_end])
            events_list.append([window_start, trial_idx, label])

    if not epochs_data:
        raise ValueError("No windows were produced: every trial is shorter than the window or was discarded.")

    epochs_data = np.array(epochs_data)
    events_array = np.array(events_list)
    # shift the event times by the trial index
    events_array[:, 0] = events_array[:, 0] + events_array[:, 1]

    epochs = mne.EpochsArray(
        epochs_data,
        info,
        events=events_array,
        event_id=event_dict,
        on_missing='raise',
    )

    # Attach channel locations
    ch_locs = DataProcessor.generate_ch_locs(ch_names=new_ch_names)
    montage = mne.channels.make_dig_montage(ch_pos=ch_locs)
    epochs.set_montage(montage)
    # TODO: add annotations to the epochs object
    return epochs
[docs]
@staticmethod
def generate_ch_locs(ch_names):
    """Build placeholder 3D positions for left/right LFP channels.

    A channel whose name contains ``'LFP_L'`` gets x = -1, one containing
    ``'LFP_R'`` gets x = 1 (the left test wins if both match). The y and z
    coordinates are derived from the channel's position ``i`` in
    ``ch_names`` as ``i % 3`` and ``i // 3``. Other channels are omitted.

    Args:
        ch_names: Iterable of channel names.

    Returns:
        dict: Maps each matching channel name to ``[x, y, z]``.
    """
    locations = {}
    for position, name in enumerate(ch_names):
        if 'LFP_L' in name:
            x_coord = -1
        elif 'LFP_R' in name:
            x_coord = 1
        else:
            continue
        locations[name] = [x_coord, position % 3, position // 3]
    return locations
[docs]
@staticmethod
def process_events_kin(events_kin: Dict[str, Any]) -> None:
    """Print the labels and times stored in a kinematic-events mapping.

    Args:
        events_kin (Dict[str, Any]): Mapping that may hold 'labels' and
            'times'; a missing entry is printed as an empty list. When None,
            a notice is printed instead.
    """
    if events_kin is not None:
        labels, times = [events_kin.get(field, []) for field in ('labels', 'times')]
        print(f"Event labels: {labels}")
        print(f"Event times: {times}")
    else:
        print("No events_KIN data found.")
[docs]
@staticmethod
def process_events_steps(events_steps: Dict[str, Any]) -> None:
    """Print the labels and times stored in a step-events mapping.

    Args:
        events_steps (Dict[str, Any]): Mapping that may hold 'labels' and
            'times'; a missing entry is printed as an empty list. When None,
            a notice is printed instead.
    """
    if events_steps is not None:
        labels, times = [events_steps.get(field, []) for field in ('labels', 'times')]
        print(f"Step event labels: {labels}")
        print(f"Step event times: {times}")
    else:
        print("No events_steps data found.")
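# NOTE: np_to_dict is only referenced from process_lfp_data, where it is called through the class (DataProcessor.np_to_dict); the decorator below is intentionally left commented out.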
# @staticmethod
[docs]
def np_to_dict(data_structure: np.ndarray) -> Dict[str, Any]:
    """
    Turn a structured NumPy array into a plain dictionary.

    Parameters:
    -----------
    data_structure : np.ndarray
        Structured array; every field is read at position ``[0, 0]``, so
        each field must be indexable with two indices.

    Returns:
    --------
    dict
        Maps each field name in ``data_structure.dtype.names`` to the value
        found at ``[0, 0]`` of that field.
    """
    extracted_data = {}
    for field_name in data_structure.dtype.names:
        extracted_data[field_name] = data_structure[field_name][0, 0]
    return extracted_data
[docs]
@staticmethod
def convert_lfp_label(labels: np.ndarray) -> list:
    """
    Shorten long-form LFP labels such as ``'ZERO_THREE_LEFT'`` to ``'LFP_L03'``.

    Each entry of ``labels`` is unwrapped twice (``entry[0][0]``) to reach
    the label string. The string is split on ``'_'``: the first two parts
    are number words (ZERO..THREE) and the third selects the side ('LEFT'
    gives 'L', anything else gives 'R'). Each original and converted name is
    printed.

    Args:
        labels (np.ndarray): Doubly nested sequence of long-form LFP labels.

    Returns:
        list: Converted labels in the form ``'LFP_<side><digit><digit>'``.
    """
    text_to_num = {'ZERO': '0', 'ONE': '1', 'TWO': '2', 'THREE': '3'}

    converted_labels = []
    for nested_label in labels:
        original_label_name = nested_label[0][0]
        print(f"Original Label Name: {original_label_name}")

        parts = original_label_name.split('_')
        side = 'L' if parts[2] == 'LEFT' else 'R'
        digits = text_to_num[parts[0]] + text_to_num[parts[1]]
        converted_label = f"LFP_{side}{digits}"

        print(f" Before renaming: {original_label_name}")
        print(f" After renaming : {converted_label}")
        converted_labels.append(converted_label)
    return converted_labels
[docs]
@staticmethod
def rename_lfp_channels(labels: np.ndarray) -> list:
    """
    Shorten long-form LFP channel names and make the results unique.

    Every label is split on ``'_'``: the first two parts are number words
    (ZERO..THREE) and the third selects the side ('LEFT' gives 'L', anything
    else gives 'R'), e.g. ``'ZERO_THREE_LEFT'`` becomes ``'LFP_L0-3'``.
    When the same short name occurs again, the occurrence count is appended
    (``'LFP_L0-3_2'``, ``'LFP_L0-3_3'``, ...).

    Args:
        labels (np.ndarray): Sequence of long-form LFP label strings.

    Returns:
        list: Unique short channel names, in input order.
    """
    text_to_num = {'ZERO': '0', 'ONE': '1', 'TWO': '2', 'THREE': '3'}

    occurrences = {}
    converted_labels = []
    for label in labels:
        parts = label.split('_')
        side = 'L' if parts[2] == 'LEFT' else 'R'
        base_label = f"LFP_{side}{text_to_num[parts[0]]}-{text_to_num[parts[1]]}"

        # Count occurrences; only repeats receive a numeric suffix.
        seen_so_far = occurrences.get(base_label, 0) + 1
        occurrences[base_label] = seen_so_far
        if seen_so_far == 1:
            converted_labels.append(base_label)
        else:
            converted_labels.append(f"{base_label}_{seen_so_far}")
    return converted_labels
[docs]
@staticmethod
def remove_trials_with_short_labels(
    patients_epochs: Dict[str, mne.EpochsArray],
    subjects_lfp_data_dict: Dict[str, List[np.ndarray]] = None,
    subjects_event_idx_dict: Dict[str, List[np.ndarray]] = None,
    min_epochs_per_class: int = 1,
    verbose: bool = True
) -> Tuple[Dict[str, mne.EpochsArray], Optional[Dict[str, List[np.ndarray]]], Optional[Dict[str, List[np.ndarray]]]]:
    """
    Removes trials that lack epochs of one of the two classes from ALL related data structures.

    For every patient, epochs are grouped by the trial id stored in column 1 of
    ``epochs.events``. A trial is kept only if it contains at least
    ``min_epochs_per_class`` epochs with class label 0 and at least that many with
    class label 1 (column 2 of the events array). Kept trials are re-numbered
    0..k-1, and the same trials are kept in the optional per-subject lists, which
    are indexed by the ORIGINAL trial id.

    Parameters:
    -----------
    patients_epochs : Dict[str, mne.EpochsArray]
        Dictionary of patient EpochsArray objects containing segmented trials
    subjects_lfp_data_dict : Dict[str, List[np.ndarray]], optional
        Dictionary of raw LFP trial data per subject (will be filtered if provided)
    subjects_event_idx_dict : Dict[str, List[np.ndarray]], optional
        Dictionary of event indices per trial per subject (will be filtered if provided)
    min_epochs_per_class : int, optional
        Minimum number of epochs required for each class within a trial (default: 1)
    verbose : bool, optional
        Whether to print filtering details (default: True)

    Returns:
    --------
    Tuple containing:
        - Dict[str, mne.EpochsArray]: Filtered patients_epochs (new EpochsArray objects)
        - Dict[str, List[np.ndarray]] or None: Filtered shallow copy of subjects_lfp_data_dict
          (None if the argument was None or empty)
        - Dict[str, List[np.ndarray]] or None: Filtered shallow copy of subjects_event_idx_dict
          (None if the argument was None or empty)

    Raises:
    -------
    ValueError
        If a patient has no trial that satisfies the criterion.
    """
    filtered_patients_epochs = {}
    # Shallow copies: per-patient lists are replaced, never mutated in place,
    # so the caller's dictionaries stay untouched.
    filtered_subjects_lfp_data_dict = subjects_lfp_data_dict.copy() if subjects_lfp_data_dict else None
    filtered_subjects_event_idx_dict = subjects_event_idx_dict.copy() if subjects_event_idx_dict else None
    companion_dicts = (
        ("subjects_lfp_data_dict", filtered_subjects_lfp_data_dict),
        ("subjects_event_idx_dict", filtered_subjects_event_idx_dict),
    )

    overall_stats = {
        'total_patients': len(patients_epochs),
        'total_trials_removed': 0,
        'total_epochs_removed': 0,
        'patients_affected': 0
    }

    for patient_name, epochs in patients_epochs.items():
        if verbose:
            print(f"\nProcessing patient: {patient_name}")
            print("=" * 50)

        events = epochs.events.copy()  # Shape: (n_epochs, 3) - [onset, trial_id, class_label]
        trial_ids = np.unique(events[:, 1])

        trials_to_keep = []
        filtered_trials = []
        total_epochs_removed = 0
        for trial_id in trial_ids:
            trial_labels = events[events[:, 1] == trial_id, 2]
            normal_count = np.sum(trial_labels == 0)
            modulation_count = np.sum(trial_labels == 1)
            total_epochs_in_trial = len(trial_labels)

            # Keep the trial only if both classes are sufficiently represented
            if (normal_count >= min_epochs_per_class and
                    modulation_count >= min_epochs_per_class):
                trials_to_keep.append(trial_id)
                continue

            filtered_trials.append((trial_id, normal_count, modulation_count))
            total_epochs_removed += total_epochs_in_trial
            if verbose:
                print(f"\033[91mFiltered out trial {trial_id}: normal_walking={normal_count}, modulation={modulation_count} epochs (total {total_epochs_in_trial} epochs removed)\033[0m")

        if len(trials_to_keep) == 0:
            raise ValueError(f"No balanced trials found for patient {patient_name} with the current criteria")

        # Per-patient summary
        if filtered_trials:
            overall_stats['patients_affected'] += 1
            if verbose:
                print(f"\033[91mPatient {patient_name}: {len(filtered_trials)} trials filtered out of {len(trial_ids)}\033[0m")
                print(f"\033[91mPatient {patient_name}: {total_epochs_removed} epochs removed\033[0m")
        elif verbose:
            print(f"Patient {patient_name}: All {len(trial_ids)} trials kept (no filtering needed)")

        overall_stats['total_trials_removed'] += len(filtered_trials)
        overall_stats['total_epochs_removed'] += total_epochs_removed

        # Boolean mask over epochs that belong to a kept trial
        epochs_to_keep = np.isin(events[:, 1], trials_to_keep)
        filtered_data = epochs.get_data()[epochs_to_keep]
        filtered_events = events[epochs_to_keep]

        # Re-map trial IDs to 0..k-1: the inverse index returned by np.unique
        # is exactly the rank of each id among the sorted kept ids.
        _, sequential_ids = np.unique(filtered_events[:, 1], return_inverse=True)
        filtered_events[:, 1] = sequential_ids

        filtered_epochs = mne.EpochsArray(
            filtered_data,
            epochs.info.copy(),
            events=filtered_events,
            # Preserve the time axis of the source epochs; EpochsArray would
            # otherwise default to tmin=0 and shift every epoch.
            tmin=epochs.tmin,
            event_id=epochs.event_id.copy(),
            verbose=False
        )

        # Copy metadata if it exists
        if hasattr(epochs, 'metadata') and epochs.metadata is not None:
            filtered_epochs.metadata = epochs.metadata.iloc[epochs_to_keep].reset_index(drop=True)

        filtered_patients_epochs[patient_name] = filtered_epochs

        # Apply the same trial selection to the companion per-subject lists.
        # The ORIGINAL trial ids are used as list indices.
        for dict_name, per_subject in companion_dicts:
            if not per_subject or patient_name not in per_subject:
                continue
            original_entries = per_subject[patient_name]
            if len(original_entries) < len(trial_ids):
                # Fewer list entries than trials: indices cannot be trusted, so
                # the list is left as is, but say so instead of staying silent.
                if verbose:
                    print(f"\033[93mWarning: {dict_name} for {patient_name} has {len(original_entries)} entries but {len(trial_ids)} trials - left unfiltered\033[0m")
                continue
            kept_entries = [original_entries[i] for i in trials_to_keep if i < len(original_entries)]
            per_subject[patient_name] = kept_entries
            if verbose and len(filtered_trials) > 0:
                print(f"Updated {dict_name} for {patient_name}: {len(original_entries)} -> {len(kept_entries)} trials")

    # Print overall summary
    if verbose:
        print(f"\n{'='*60}")
        print("OVERALL FILTERING SUMMARY")
        print(f"{'='*60}")
        print(f"Total patients processed: {overall_stats['total_patients']}")
        print(f"Patients affected by filtering: {overall_stats['patients_affected']}")
        print(f"Total trials removed across all patients: {overall_stats['total_trials_removed']}")
        print(f"Total epochs removed across all patients: {overall_stats['total_epochs_removed']}")
        if overall_stats['total_trials_removed'] == 0:
            print("\033[92mNo trials needed to be removed - all data is already balanced!\033[0m")
        else:
            print("\033[93mData structures have been synchronized - epochs, raw trials, and event indices are consistent\033[0m")

    return filtered_patients_epochs, filtered_subjects_lfp_data_dict, filtered_subjects_event_idx_dict
# Vectorized version of create_events_array
[docs]
@staticmethod
def create_events_array(events_kin: Dict[str, Any], sfreq: float) -> Tuple[np.ndarray, Dict[str, int]]:
    """
    Create an MNE-compatible events array from events_KIN data, handling NaN values.

    Entries of ``events_kin['times']`` that are NaN are skipped. For every
    remaining entry, the index along the first axis of the times array selects
    the event label, so row ``i`` of the times array is taken to belong to
    ``events_kin['label'][0][i]``.

    Parameters:
    -----------
    events_kin : Dict[str, Any]
        Dictionary containing events_KIN data with 'times' and 'label' keys.
        'times' must be a numeric numpy array (seconds are assumed, since the
        values are multiplied by ``sfreq``).
    sfreq : float
        Sampling frequency of the LFP data.

    Returns:
    --------
    mne_events : np.ndarray
        Integer events array (shape: [n_events, 3]).
        Each row contains [sample_index, 0, event_id]. The array has shape
        (0, 3) and an integer dtype when no valid time is present.
    event_id : Dict[str, int]
        Dictionary mapping event labels to unique event IDs (1-based, in label order).
    """
    events_kin_times = events_kin['times']
    event_labels = [label[0] for label in events_kin['label'][0]]

    # Map event labels to unique integer IDs
    event_id = {label: idx + 1 for idx, label in enumerate(event_labels)}

    # Drop NaN entries; boolean indexing flattens in row-major order, which is
    # the same order np.nonzero reports the row indices in.
    valid_mask = ~np.isnan(events_kin_times)
    time_samples = (events_kin_times[valid_mask] * sfreq).astype(int)
    event_rows = np.nonzero(valid_mask)[0]

    # Look the ids up through an integer table so the id column keeps an integer
    # dtype even when there are no valid events (np.array([]) would be float).
    ids_by_row = np.array([event_id[label] for label in event_labels], dtype=int)

    mne_events = np.column_stack([
        time_samples,
        np.zeros_like(time_samples),
        ids_by_row[event_rows]
    ])
    return mne_events, event_id
[docs]
@staticmethod
def crop_lfp_to_event_times(lfp_data, events_KIN, lfp_sfreq):
    """
    Crop the LFP data to match the time span of the provided event times.

    The crop runs from the first entry of ``events_KIN['times']`` to its last
    entry, both converted to sample indices with ``lfp_sfreq``. If the stop index
    cannot be computed, or is not after the start index, a message is printed and
    the data is returned uncropped.

    Parameters:
        lfp_data (np.ndarray): 2D array of LFP data with shape (channels, samples).
        events_KIN (dict): Dictionary containing event times under the key 'times'.
        lfp_sfreq (float): Sampling frequency of the LFP data.

    Returns:
        tuple: Tuple containing:
            - first_non_zero_indices (np.ndarray): Per channel, the index of the first
              non-zero sample in the UNCROPPED data, or -1 for an all-zero channel.
            - cropped_lfp_data (np.ndarray): Cropped LFP data array.
    """
    # Locate the first non-zero sample of every channel
    nonzero_mask = lfp_data != 0
    first_non_zero_indices = np.argmax(nonzero_mask, axis=1)
    # argmax returns 0 for rows without any True, so flag those rows explicitly
    first_non_zero_indices[~nonzero_mask.any(axis=1)] = -1

    # Map from event time stamp[s] into the sample index of the LFP signals
    ts_session_start = int(events_KIN['times'][0][0] * lfp_sfreq)
    try:
        ts_session_stop = int(events_KIN['times'][-1][-1] * lfp_sfreq)
    except (KeyError, IndexError, TypeError, ValueError, OverflowError) as e:
        # Best effort: e.g. a NaN last event time cannot be turned into a sample index
        print(f"Error during cropping LFP data: {e}")
    else:
        if ts_session_start < ts_session_stop:
            lfp_data = lfp_data[:, ts_session_start:ts_session_stop]
        else:
            print("Session start time is after session stop time. Check the events data.")
    return first_non_zero_indices, lfp_data
[docs]
@staticmethod
def trim_data(
    lfp_data: np.ndarray,
    events: np.ndarray,
    sfreq: float,
    threshold: float = 1e-6
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Trims the beginning of the LFP data by removing the leading samples in which no
    channel exceeds ``threshold`` in absolute value (zeros and NaNs never exceed it),
    and shifts the event onsets accordingly. Nothing is trimmed at the end.

    Args:
        lfp_data (np.ndarray): 2D array of LFP data with shape (n_channels, n_samples).
        events (np.ndarray): 2D array of events with shape (n_events, 3), where the FIRST column holds the onsets in samples.
        sfreq (float): Sampling frequency of the LFP data.
        threshold (float, optional): Absolute value at or below which a sample counts as "no recorded data". Defaults to 1e-6.

    Returns:
        Tuple[np.ndarray, np.ndarray]: Tuple containing:
            - Trimmed LFP data (2D array with shape (n_channels, trimmed_n_samples)).
            - Adjusted events (2D array with shape (n_events, 3)). Onsets of events that
              lie inside the removed segment become negative; they are not dropped.
          When nothing is trimmed, the input objects themselves are returned.

    Prints:
        - Number of samples removed.
        - Number of seconds removed.
    """
    # True for every sample at which at least one channel carries signal
    has_signal = np.any(np.abs(lfp_data) > threshold, axis=0)

    # argmax cannot tell "signal starts at sample 0" from "no signal at all"
    # (and fails on an empty array), so handle the no-signal case explicitly.
    if not has_signal.any():
        print("No trimming performed: no sample exceeds the threshold.")
        return lfp_data, events

    start_index = int(np.argmax(has_signal))
    if start_index == 0:
        print("No trimming needed as the beginning of signal is not flat.")
        return lfp_data, events

    trimmed_lfp_data = lfp_data[:, start_index:]
    seconds_removed = start_index / sfreq

    # Shift the onsets on a copy so the caller's events array is not modified
    adjusted_events = events.copy()
    adjusted_events[:, 0] -= start_index

    print(f"Number of samples removed: {start_index}")
    print(f"Number of seconds removed: {seconds_removed:.2f} seconds")
    return trimmed_lfp_data, adjusted_events
[docs]
@staticmethod
def define_normal_walking_events(normal_walking_event_id: int,
                                 events_mod_start: np.ndarray,
                                 gap_sample_length: int,
                                 epoch_sample_length: int,
                                 n_samples: int) -> np.ndarray:
    """
    Defines normal walking events by creating intervals between modulation events
    and constructing an array of event onsets for normal walking periods.

    A gap of ``gap_sample_length`` samples is reserved on both sides of every
    modulation onset. Normal walking onsets are then placed every
    ``epoch_sample_length`` samples in the ranges between the gaps, starting at
    ``epoch_sample_length`` for the first range and ending at ``n_samples`` for
    the last one.

    Args:
        normal_walking_event_id (int): The event ID assigned to normal walking events.
        events_mod_start (np.ndarray): Array of modulation events; column 0 holds the sample
                                       onset. Rows are assumed to be sorted by onset. May be empty.
        gap_sample_length (int): Length of the gap in samples to create before and after each modulation event.
        epoch_sample_length (int): The length of the epochs in samples for normal walking events.
        n_samples (int): Total number of samples in the signal.

    Returns:
        np.ndarray: Integer array of shape (n_events, 3) containing the normal walking events
                    (shape (0, 3) when no range is available). Each row contains three values:
                    - The onset of the normal walking event in samples.
                    - A dummy value (always zero).
                    - The event ID (as provided in `normal_walking_event_id`).
    """
    if len(events_mod_start) == 0:
        # Without modulation events the whole recording is one walking range
        normal_walking_ranges = np.array([[epoch_sample_length, n_samples]])
    else:
        mod_onsets = events_mod_start[:, 0]
        gap_starts = mod_onsets - gap_sample_length
        gap_stops = mod_onsets + gap_sample_length
        normal_walking_ranges = np.vstack((
            [epoch_sample_length, gap_starts[0]],                 # before the first gap
            np.column_stack((gap_stops[:-1], gap_starts[1:])),    # between consecutive gaps
            [gap_stops[-1], n_samples],                           # after the last gap
        ))

    # Overlapping gaps (or a gap reaching the signal border) give start > stop; drop those
    is_valid_range = normal_walking_ranges[:, 0] <= normal_walking_ranges[:, 1]
    normal_walking_ranges = normal_walking_ranges[is_valid_range]

    # NOTE(review): the arange stop is range_stop + epoch_sample_length, so the last
    # onset of a range can lie at or after range_stop; kept as in the original
    # implementation - confirm that this is intended.
    latest_onset = n_samples - epoch_sample_length
    onset_chunks = [
        np.arange(range_start,
                  min(range_stop, latest_onset) + epoch_sample_length,
                  epoch_sample_length)
        for range_start, range_stop in normal_walking_ranges
    ]
    # np.concatenate rejects an empty list, which happens when no range is valid
    walking_onsets = np.concatenate(onset_chunks) if onset_chunks else np.empty(0, dtype=int)

    # Keep only onsets whose full epoch fits into the signal
    walking_onsets = walking_onsets[walking_onsets + epoch_sample_length <= n_samples].astype(int)

    normal_walking_events = np.column_stack((
        walking_onsets,
        np.zeros_like(walking_onsets),
        np.full_like(walking_onsets, normal_walking_event_id)
    ))
    return normal_walking_events
[docs]
@staticmethod
def create_epochs_with_events(lfp_raw: mne.io.Raw,
                              events_mod_start: np.ndarray,
                              normal_walking_events: np.ndarray,
                              gait_modulation_event_id: int,
                              normal_walking_event_id: int,
                              epoch_tmin: float,
                              epoch_tmax: float,
                              event_dict: Dict[str, int]) -> Tuple[np.ndarray, mne.Epochs]:
    """
    Combines modulation and normal walking events, sorts them by onset time, and creates MNE epochs.

    Args:
        lfp_raw (mne.io.Raw): The raw LFP signal data.
        events_mod_start (np.ndarray): Array of modulation events, shape (n, 3).
        normal_walking_events (np.ndarray): Array of normal walking events, shape (m, 3).
        gait_modulation_event_id (int): Event ID for gait modulation events (currently unused;
                                        the IDs are taken from the event arrays and `event_dict`).
        normal_walking_event_id (int): Event ID for normal walking events (currently unused).
        epoch_tmin (float): Start time for epochs (in seconds).
        epoch_tmax (float): End time for epochs (in seconds).
        event_dict (Dict[str, int]): Dictionary mapping event class names to event IDs.

    Returns:
        Tuple[np.ndarray, mne.Epochs]:
            - Array containing the combined and sorted events.
            - The MNE Epochs object (preloaded, no baseline correction) created from those events.
    """
    # Combine modulation and normal walking events
    events = np.vstack((events_mod_start, normal_walking_events))
    # Stable sort: events sharing an onset keep their stacking order
    # (modulation before normal walking), so the result is deterministic.
    events = events[np.argsort(events[:, 0], kind='stable')]

    # Create MNE Epochs object
    epochs = mne.Epochs(
        lfp_raw,
        events,
        event_dict,
        tmin=epoch_tmin,
        tmax=epoch_tmax,
        baseline=None,
        preload=True,
        verbose=40
    )
    return events, epochs
[docs]
@staticmethod
def pad_data(trials, max_length, padding_value=0, position="end"):
    """
    Pad every trial along its second axis up to ``max_length`` samples.

    Trials that are already ``max_length`` or longer are left at their own
    length. The padding is appended when ``position`` is "end" and prepended
    for any other value. The padded trials are stacked into one array.
    """
    def _pad_widths(trial):
        # Only the second axis is ever padded
        missing = max(0, max_length - trial.shape[1])
        second_axis = (0, missing) if position == "end" else (missing, 0)
        return ((0, 0), second_axis)

    return np.array([
        np.pad(trial, _pad_widths(trial), mode='constant', constant_values=padding_value)
        for trial in trials
    ])
[docs]
@staticmethod
def truncate_data(trials, target_length, position="end"):
    """
    Cut every trial that is longer than ``target_length`` down to that length.

    With ``position`` equal to "end" the first ``target_length`` samples are
    kept; for any other value the last ``target_length`` samples are kept.
    Shorter trials pass through unchanged. The trials are stacked into one array.
    """
    def _clip(trial):
        if trial.shape[1] <= target_length:
            return trial
        if position == "end":
            return trial[:, :target_length]
        return trial[:, -target_length:]

    return np.array([_clip(trial) for trial in trials])
[docs]
@staticmethod
def pad_or_truncate(trials: np.ndarray, config: Dict[str, Any], target_length: Optional[Union[int, float, str]] = None) -> List[np.ndarray]:
    """
    Pads and/or truncates the given trials according to ``config['data_preprocessing']``.

    Padding (if enabled) runs first, then truncation (if enabled). A target length that
    was passed in, or that was resolved during the padding step, is reused by the
    truncation step; the truncation section's own ``target_length`` is only read when
    no target length is set by then.

    Parameters:
        trials (List[np.ndarray]): List of trial data arrays to be padded or truncated.
        target_length (Optional[Union[int, float, str]]): The target length to pad or truncate the trials to.
            "max" (padding) and "min" (truncation) select the longest / shortest trial length.
            If None, the target length is taken from the config.
        config (Dict[str, Any]): Configuration dictionary containing padding and truncation settings.

    Returns:
        List[np.ndarray]: The trials after padding/truncation. This is the stacked array returned by
            pad_data / truncate_data when at least one step is enabled, otherwise the input unchanged.
    """
    preprocessing = config['data_preprocessing']

    # Step 1: padding
    padding = preprocessing['padding']
    if padding['enabled']:
        if not target_length:
            target_length = padding['target_length']
        if target_length == "max":
            target_length = max(trial.shape[1] for trial in trials)
        trials = DataProcessor.pad_data(
            trials,
            max_length=target_length,
            padding_value=padding['padding_value'],
            position=padding['padding_position']
        )

    # Step 2: truncation
    truncation = preprocessing['truncation']
    if truncation['enabled']:
        if not target_length:
            target_length = truncation['target_length']
        if target_length == "min":
            target_length = min(trial.shape[1] for trial in trials)
        trials = DataProcessor.truncate_data(
            trials,
            target_length=target_length,
            position=truncation['truncation_position']
        )

    return trials
[docs]
@staticmethod
def determine_adaptive_filter_params(patients_epochs: Dict[str, mne.EpochsArray],
                                     notch_freq: float = 50.0,
                                     verbose: bool = True) -> Dict[str, Any]:
    """
    Determine filtering parameters from the epoch length and sampling rate.

    Only the first patient's epochs are inspected; the resulting parameters are
    meant to be applied to all patients, which assumes that every patient shares
    the same epoch length and sampling rate.

    Parameters:
    -----------
    patients_epochs : Dict[str, mne.EpochsArray]
        Dictionary of patient EpochsArray objects to analyze (must not be empty)
    notch_freq : float, optional
        Notch filter center frequency in Hz (default: 50.0). Only used for the
        printed notch range.
    verbose : bool, optional
        Whether to print parameter determination details (default: True)

    Returns:
    --------
    Dict[str, Any]
        Dictionary containing determined filter parameters:
        - 'filter_order': int (4 / 3 / 2 for epochs >= 2 s / >= 1 s / shorter)
        - 'highpass_freq': float (0.5 / 1.0 / 2.0 Hz for the same duration classes)
        - 'notch_width': float (2 / 4 / 6 Hz for resolution <= 1 Hz / <= 2 Hz / coarser)
        - 'epoch_duration': float
        - 'freq_resolution': float
        - 'sfreq': float
        - 'nyquist': float
        - 'warnings': List[str]
        - 'duration_category': str ("long", "medium" or "short")
        - 'resolution_category': str ("high", "moderate" or "low")

    Raises:
    -------
    ValueError
        If `patients_epochs` is empty.
    """
    if not patients_epochs:
        # next(iter(...)) would otherwise leak a bare StopIteration
        raise ValueError("patients_epochs is empty; cannot determine filter parameters")

    # Get parameters from first patient (should be same for all)
    first_patient = next(iter(patients_epochs.values()))
    sfreq = first_patient.info['sfreq']
    _, _, n_samples = first_patient.get_data().shape  # (n_epochs, n_channels, n_samples)

    # Calculate epoch characteristics
    epoch_duration = n_samples / sfreq
    freq_resolution = sfreq / n_samples
    nyquist = sfreq / 2

    if verbose:
        print(f"\n{'='*60}")
        print("DETERMINING ADAPTIVE FILTER PARAMETERS")
        print(f"{'='*60}")
        print("EPOCH CHARACTERISTICS (consistent across all patients):")
        print(f" - Duration: {epoch_duration:.2f} s ({n_samples} samples)")
        print(f" - Sampling rate: {sfreq} Hz")
        print(f" - Frequency resolution: {freq_resolution:.2f} Hz")
        print(f" - Nyquist frequency: {nyquist} Hz")

    # Filter order and high-pass cutoff both depend on the epoch duration:
    # shorter epochs get a lower order and a higher cutoff.
    if epoch_duration >= 2.0:
        filter_order, highpass_freq, duration_category = 4, 0.5, "long"
    elif epoch_duration >= 1.0:
        filter_order, highpass_freq, duration_category = 3, 1.0, "medium"
    else:
        filter_order, highpass_freq, duration_category = 2, 2.0, "short"

    # The notch gets wider as the frequency resolution gets coarser
    if freq_resolution <= 1.0:
        notch_width, resolution_category = 2.0, "high"
    elif freq_resolution <= 2.0:
        notch_width, resolution_category = 4.0, "moderate"
    else:
        notch_width, resolution_category = 6.0, "low"

    # Check for potential issues and generate warnings
    warnings = []
    if epoch_duration < 1.0:
        warnings.append(f"Short epochs ({epoch_duration:.2f}s) may cause edge artifacts")
    if freq_resolution > 2.0:
        warnings.append(f"Poor frequency resolution ({freq_resolution:.2f} Hz) for precise notch filtering")
    if highpass_freq * 2 > freq_resolution * 3:
        warnings.append(f"High-pass frequency ({highpass_freq} Hz) may be too high for epoch length")

    if verbose:
        if warnings:
            print("\nANALYSIS WARNINGS:")
            for warning in warnings:
                print(f" - {warning}")
        print("\nDETERMINED FILTER PARAMETERS:")
        print(f" - Filter order: {filter_order} (optimized for {duration_category} epochs)")
        print(f" - High-pass frequency: {highpass_freq} Hz")
        print(f" - Notch width: {notch_width} Hz (optimized for {resolution_category} resolution)")
        print(f" - Notch range: {notch_freq-notch_width/2:.1f}-{notch_freq+notch_width/2:.1f} Hz")
        # Leading/trailing empty items reproduce the blank lines around the block
        reasoning = "\n".join([
            "",
            "PARAMETER SELECTION REASONING:",
            f"- Epoch duration ({epoch_duration:.2f}s) → filter_order={filter_order}",
            f"- Frequency resolution ({freq_resolution:.2f} Hz) → notch_width={notch_width} Hz",
            f"- Stability considerations → highpass_freq={highpass_freq} Hz",
            "",
        ])
        print(reasoning)

    return {
        'filter_order': filter_order,
        'highpass_freq': highpass_freq,
        'notch_width': notch_width,
        'epoch_duration': epoch_duration,
        'freq_resolution': freq_resolution,
        'sfreq': sfreq,
        'nyquist': nyquist,
        'warnings': warnings,
        'duration_category': duration_category,
        'resolution_category': resolution_category
    }
@staticmethod
def _apply_single_filter(signal: np.ndarray,
                         sos: np.ndarray,
                         zero_phase: bool = True,
                         filter_name: str = "filter") -> np.ndarray:
    """
    Run one second-order-sections filter over a signal, best effort.

    Parameters:
    -----------
    signal : np.ndarray
        Input signal to filter
    sos : np.ndarray
        Second-order sections filter coefficients
    zero_phase : bool, optional
        True (default) filters forward and backward with sosfiltfilt;
        False filters once with sosfilt
    filter_name : str, optional
        Name used in the warning message when filtering fails (default: "filter")

    Returns:
    --------
    np.ndarray
        Filtered signal. If the filter call raises, a warning is printed and the
        unmodified input signal is returned instead; no exception propagates.
    """
    try:
        # sosfiltfilt: forward-backward pass, no phase distortion;
        # sosfilt: single pass, introduces phase delay
        run_filter = sp_signal.sosfiltfilt if zero_phase else sp_signal.sosfilt
        return run_filter(sos, signal)
    except Exception as e:
        # Deliberate fallback: keep the pipeline going with the raw signal
        print(f"Warning: {filter_name} failed - {e!s}")
        return signal
@staticmethod
def _apply_notch_filter(signal: np.ndarray,
                        notch_freq: float,
                        notch_width: float,
                        filter_order: int,
                        nyquist: float,
                        zero_phase: bool = True,
                        verbose: bool = False) -> np.ndarray:
    """
    Apply a Butterworth band-stop ("notch") filter around notch_freq.

    The stop band is notch_freq +/- notch_width/2, normalised by nyquist and
    clamped to [0.01, 0.99]. If the clamped band is empty, or notch_freq is not
    below nyquist, the signal is returned unfiltered (with a warning when
    verbose is set).
    Phase behaviour is controlled by zero_phase (see _apply_single_filter).
    """
    half_width = notch_width / 2
    band_low = max((notch_freq - half_width) / nyquist, 0.01)
    band_high = min((notch_freq + half_width) / nyquist, 0.99)

    band_is_usable = band_low < band_high and notch_freq < nyquist
    if not band_is_usable:
        if verbose:
            print("\nWarning: Notch filter skipped (invalid frequency range)")
        return signal

    # Band-stop design: removes the frequencies between band_low and band_high
    sos = sp_signal.butter(filter_order, [band_low, band_high],
                           btype='bandstop', output='sos')
    return DataProcessor._apply_single_filter(signal, sos, zero_phase, "Notch filter")
@staticmethod
def _apply_highpass_filter(signal: np.ndarray,
                           highpass_freq: float,
                           filter_order: int,
                           nyquist: float,
                           zero_phase: bool = True,
                           verbose: bool = False) -> np.ndarray:
    """
    Apply a Butterworth high-pass filter (drift and DC removal).

    Frequencies below highpass_freq are attenuated. If highpass_freq / nyquist is
    not below 0.99 the signal is returned unfiltered (with a warning when verbose
    is set). Phase behaviour is controlled by zero_phase (see _apply_single_filter).
    """
    cutoff = highpass_freq / nyquist
    if not cutoff < 0.99:
        if verbose:
            print("\nWarning: High-pass filter skipped (frequency too high)")
        return signal

    # High-pass design on the normalised cutoff
    sos = sp_signal.butter(filter_order, cutoff,
                           btype='highpass', output='sos')
    return DataProcessor._apply_single_filter(signal, sos, zero_phase, "High-pass filter")
@staticmethod
def _apply_lowpass_filter(signal: np.ndarray,
                          lowpass_freq: float,
                          filter_order: int,
                          nyquist: float,
                          zero_phase: bool = True,
                          verbose: bool = False) -> np.ndarray:
    """
    Apply a Butterworth low-pass filter (high-frequency noise removal).

    Frequencies above lowpass_freq are attenuated. If lowpass_freq / nyquist is
    not below 0.99 the signal is returned unfiltered (with a warning when verbose
    is set). Phase behaviour is controlled by zero_phase (see _apply_single_filter).
    """
    cutoff = lowpass_freq / nyquist
    if not cutoff < 0.99:
        if verbose:
            print("\nWarning: Low-pass filter skipped (frequency too high)")
        return signal

    # Low-pass design on the normalised cutoff
    sos = sp_signal.butter(filter_order, cutoff,
                           btype='lowpass', output='sos')
    return DataProcessor._apply_single_filter(signal, sos, zero_phase, "Low-pass filter")
[docs]
@staticmethod
def apply_adaptive_filtering_to_epochs(patients_epochs: Dict[str, mne.EpochsArray],
                                       apply_notch: bool = True,
                                       apply_highpass: bool = True,
                                       apply_lowpass: bool = False,
                                       zero_phase: bool = True,
                                       notch_freq: float = 50.0,
                                       notch_width: Optional[float] = None,
                                       highpass_freq: Optional[float] = None,
                                       lowpass_freq: float = 100.0,
                                       filter_order: Optional[int] = None,
                                       padding_method: str = 'reflect',
                                       verbose: bool = True) -> Dict[str, mne.EpochsArray]:
    """
    Apply filtering to STN LFP epochs with explicit separation of frequency and phase filtering.
    Accepts pre-determined parameters or determines them automatically if None.
    Use determine_adaptive_filter_params() to get optimal parameters first.

    Each (epoch, channel) trace is optionally edge-padded, passed through the
    enabled filters in the order notch -> high-pass -> low-pass, un-padded, and
    written into a copy of the data that is wrapped in a new EpochsArray.

    Parameters:
    -----------
    patients_epochs : Dict[str, mne.EpochsArray]
        Dictionary of patient EpochsArray objects to be filtered
    apply_notch : bool, optional
        Whether to apply notch filter for power line noise (default: True)
    apply_highpass : bool, optional
        Whether to apply high-pass filter for drift removal (default: True)
    apply_lowpass : bool, optional
        Whether to apply low-pass filter (default: False, not recommended for HCTSA)
    zero_phase : bool, optional
        Whether to apply zero-phase filtering (forward-backward) vs standard filtering.
        True: Uses sosfiltfilt() - preserves timing, no phase distortion (default: True)
        False: Uses sosfilt() - faster but introduces phase delays
    notch_freq : float, optional
        Notch filter center frequency in Hz (default: 50.0)
    notch_width : float, optional
        Notch filter width in Hz. If None, will be determined automatically
    highpass_freq : float, optional
        High-pass filter frequency in Hz. If None, will be determined automatically
    lowpass_freq : float, optional
        Low-pass filter frequency in Hz (default: 100.0)
    filter_order : int, optional
        Filter order. If None, will be determined automatically
    padding_method : str, optional
        Edge handling before filtering (default: 'reflect'):
        'reflect' mirrors the signal at both ends, 'constant' repeats the edge
        sample, 'zero' applies no padding at all
    verbose : bool, optional
        Whether to print filtering information (default: True)

    Returns:
    --------
    Dict[str, mne.EpochsArray]
        Dictionary of filtered EpochsArray objects (empty if patients_epochs is empty)

    Raises:
    -------
    ValueError
        If padding_method is not one of 'reflect', 'zero' or 'constant'
    """
    # Reject unknown padding methods up front: further down they would leave the
    # padded signal unassigned (or silently reuse the previous channel's data).
    if padding_method not in ('reflect', 'zero', 'constant'):
        raise ValueError(
            f"Unknown padding_method {padding_method!r}; "
            "expected 'reflect', 'zero' or 'constant'"
        )

    # Nothing to filter; also avoids a bare StopIteration from next(iter(...)) below.
    if not patients_epochs:
        return {}

    # Determine parameters if not provided
    if filter_order is None or highpass_freq is None or notch_width is None:
        params = DataProcessor.determine_adaptive_filter_params(patients_epochs, notch_freq, verbose=False)
        # Use determined parameters for any None values
        if filter_order is None:
            filter_order = params['filter_order']
        if highpass_freq is None:
            highpass_freq = params['highpass_freq']
        if notch_width is None:
            notch_width = params['notch_width']
        nyquist = params['nyquist']
    else:
        # Use provided parameters; the sampling rate is read from the first patient
        first_patient = next(iter(patients_epochs.values()))
        nyquist = first_patient.info['sfreq'] / 2

    if verbose:
        print(f"Filtering {len(patients_epochs)} patients...")
        print("FILTERING CONFIGURATION:")
        print(" FREQUENCY DOMAIN FILTERS:")
        if apply_notch:
            print(f" - Notch filter: {notch_freq-notch_width/2:.1f}-{notch_freq+notch_width/2:.1f} Hz (removes power line noise)")
        if apply_highpass:
            print(f" - High-pass filter: >{highpass_freq} Hz (removes DC/drift/movement artifacts)")
        if apply_lowpass:
            print(f" - Low-pass filter: <{lowpass_freq} Hz (removes high-frequency noise)")
        print(" TIME DOMAIN PROCESSING:")
        print(f" - Phase filtering: {'Zero-phase (sosfiltfilt)' if zero_phase else 'Standard (sosfilt)'}")
        print(f" - Filter order: {filter_order}")
        print(f" - Edge padding: {padding_method}")
        print()

    # 'constant' is implemented by repeating the edge sample (np.pad mode='edge').
    pad_mode = 'reflect' if padding_method == 'reflect' else 'edge'

    filtered_epochs = {}
    for i, (patient_name, epochs) in enumerate(patients_epochs.items()):
        # Get epoch parameters
        data = epochs.get_data()  # Shape: (n_epochs, n_channels, n_samples)
        info = epochs.info.copy()
        n_epochs, n_channels, n_samples = data.shape

        # Show progress for each patient
        if verbose:
            print(f" {patient_name}: {n_epochs} epochs", end=" ")

        filtered_data = data.copy()

        # Padding length: 5x the filter order, capped at a quarter of the epoch
        if padding_method != 'zero':
            pad_length = min(filter_order * 5, n_samples // 4)
        else:
            pad_length = 0

        for epoch_idx in range(n_epochs):
            for ch_idx in range(n_channels):
                trace = data[epoch_idx, ch_idx, :]

                # Skip warnings depend only on the filter settings, so emit them
                # once (first trace of the first patient) rather than per trace.
                warn = verbose and i == 0 and epoch_idx == 0 and ch_idx == 0

                # Apply padding if requested
                if pad_length > 0:
                    trace = np.pad(trace, pad_length, mode=pad_mode)

                # STEP 1: Apply notch filter for power line noise (FREQUENCY DOMAIN)
                # Removes narrow frequency band around notch_freq (e.g., 50 Hz power line)
                if apply_notch:
                    trace = DataProcessor._apply_notch_filter(
                        trace, notch_freq, notch_width, filter_order,
                        nyquist, zero_phase, warn
                    )

                # STEP 2: Apply high-pass filter for drift removal (FREQUENCY DOMAIN)
                # Removes frequencies below highpass_freq (DC offset, slow drifts, movement artifacts)
                if apply_highpass:
                    trace = DataProcessor._apply_highpass_filter(
                        trace, highpass_freq, filter_order,
                        nyquist, zero_phase, warn
                    )

                # STEP 3: Apply low-pass filter for noise removal (FREQUENCY DOMAIN, optional)
                # Removes frequencies above lowpass_freq (high-frequency noise, aliasing)
                if apply_lowpass:
                    trace = DataProcessor._apply_lowpass_filter(
                        trace, lowpass_freq, filter_order,
                        nyquist, zero_phase, warn
                    )

                # Remove padding
                if pad_length > 0:
                    trace = trace[pad_length:-pad_length]

                filtered_data[epoch_idx, ch_idx, :] = trace

        # Create new EpochsArray with filtered data. tmin is carried over so the
        # filtered epochs keep the original time axis instead of restarting at 0.
        filtered_epochs[patient_name] = mne.EpochsArray(
            filtered_data,
            info,
            events=epochs.events.copy() if hasattr(epochs, 'events') else None,
            tmin=getattr(epochs, 'tmin', 0.0),
            event_id=epochs.event_id.copy() if hasattr(epochs, 'event_id') else None,
            verbose=False
        )

    if verbose:
        total_epochs = sum(len(patient_epochs) for patient_epochs in filtered_epochs.values())
        print(f"Filtering complete: {total_epochs:,} epochs processed")

    return filtered_epochs