Modules
- class gaitmod.BaseModel(config_path=None, **kwargs)
Bases:
ABC
- class MyMeanAbsoluteError(name='my_mae', **kwargs)
Bases:
Metric
- reset_state()
Reset all of the metric state variables.
This function is called between epochs/steps, when a metric is evaluated during training.
- class MyMeanSquaredError(name='my_mse', **kwargs)
Bases:
Metric
- reset_state()
Reset all of the metric state variables.
This function is called between epochs/steps, when a metric is evaluated during training.
- class R2Score(name='r2_score', **kwargs)
Bases:
Metric
- reset_state()
Reset all of the metric state variables.
This function is called between epochs/steps, when a metric is evaluated during training.
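The metric classes above follow the Keras stateful-metric pattern: running totals accumulate over batches and are cleared by reset_state(). The following is a framework-free sketch of that pattern for R², assuming the score is computed from running sums. The class StreamingR2 is hypothetical and is not the gaitmod R2Score implementation or a Keras API.

```python
import numpy as np


class StreamingR2:
    """Hypothetical stateful R² accumulator (not the gaitmod R2Score class)."""

    def __init__(self):
        self.reset_state()

    def reset_state(self):
        # Clear all running totals, e.g. at the start of each epoch
        self.n = 0
        self.sum_y = 0.0
        self.sum_y2 = 0.0
        self.sse = 0.0

    def update_state(self, y_true, y_pred):
        # Accumulate the sums needed for R² without storing every batch
        y_true = np.asarray(y_true, dtype=float).ravel()
        y_pred = np.asarray(y_pred, dtype=float).ravel()
        self.n += y_true.size
        self.sum_y += y_true.sum()
        self.sum_y2 += np.square(y_true).sum()
        self.sse += np.square(y_true - y_pred).sum()

    def result(self):
        # Total sum of squares from running sums: sum(y^2) - (sum(y))^2 / n
        sst = self.sum_y2 - self.sum_y ** 2 / self.n
        return 1.0 - self.sse / sst
```

Keeping sums rather than raw values means the result over several batches equals the R² over the concatenated data, which is why reset_state() must run between evaluation periods.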
- evaluate(results, fold=None)
Evaluate metrics based on true labels and predictions.
- Parameters:
results – Dictionary containing the predictions (y_pred) and true values (y_test or y_true) for the fold, where y_true holds the ground-truth labels (usually y_test) and y_pred the predicted labels or values.
fold – Fold number to print in the logs. If None, the fold number is not printed.
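A minimal sketch of evaluating one fold's results dictionary, assuming the keys described above (y_pred plus either y_test or y_true) and computing MAE, MSE and R² with NumPy. The function name evaluate_fold and the returned dictionary layout are hypothetical, not the gaitmod implementation.

```python
import numpy as np


def evaluate_fold(results, fold=None):
    """Hypothetical stand-in for BaseModel.evaluate on a results dictionary."""
    # Accept either key for the ground truth, as described for `results`
    y_true = results["y_test"] if "y_test" in results else results["y_true"]
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(results["y_pred"], dtype=float)

    err = y_true - y_pred
    mae = float(np.mean(np.abs(err)))
    mse = float(np.mean(err ** 2))
    sst = float(np.sum((y_true - y_true.mean()) ** 2))
    r2 = 1.0 - float(np.sum(err ** 2)) / sst

    # Only mention the fold in the log line when one is given
    prefix = f"Fold {fold}: " if fold is not None else ""
    print(f"{prefix}MAE={mae:.3f} MSE={mse:.3f} R2={r2:.3f}")
    return {"mae": mae, "mse": mse, "r2": r2}
```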
- static load(model_path, model_type, config_path, **kwargs)
Load a Keras model and wrap it in a concrete subclass of BaseModel.
- train(X, y, train_idx, test_idx, callbacks=None)
Trains a model for a specific fold.
- Parameters:
X – Input features.
y – Target values.
train_idx – Training indices.
test_idx – Testing indices.
callbacks – Callbacks for LSTM training (optional).
- Returns:
Dictionary containing predictions and true values for the fold.
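The fold-wise training described for train() can be sketched with scikit-learn: fit on X[train_idx], predict on X[test_idx], and return a dictionary of predictions and true values like the one evaluate() consumes. This assumes a scikit-learn-style estimator (LinearRegression here) and synthetic data; train_fold is hypothetical and omits the Keras callbacks path.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold


def train_fold(model, X, y, train_idx, test_idx):
    """Hypothetical fold trainer returning predictions and true values."""
    model.fit(X[train_idx], y[train_idx])
    y_pred = model.predict(X[test_idx])
    return {"y_pred": y_pred, "y_test": y[test_idx]}


# Synthetic regression data for illustration only
rng = np.random.default_rng(0)
X = rng.normal(size=(60, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.normal(size=60)

all_results = []
for fold, (train_idx, test_idx) in enumerate(KFold(n_splits=5).split(X)):
    # A fresh estimator per fold so no state leaks between folds
    all_results.append(train_fold(LinearRegression(), X, y, train_idx, test_idx))
```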
- class gaitmod.ClassificationLSTMModel(model_type='lstm', config_path=None)
Bases:
ClassificationModels
- class gaitmod.CustomGridSearchCV(estimator, param_grid, *, scoring=None, n_jobs=None, refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs', error_score=nan, return_train_score=False)
Bases:
GridSearchCV
Not used for now.
- fit(X, y=None, groups=None, **fit_params)
Run fit with all sets of parameters.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Training vector, where n_samples is the number of samples and n_features is the number of features.
y (array-like of shape (n_samples, n_output) or (n_samples,), default=None) – Target relative to X for classification or regression; None for unsupervised learning.
groups (array-like of shape (n_samples,), default=None) – Group labels for the samples used while splitting the dataset into train/test set. Only used in conjunction with a “Group” cv instance (e.g., GroupKFold).
**fit_params (dict of str -> object) – Parameters passed to the fit method of the estimator.
If a fit parameter is an array-like whose length is equal to num_samples then it will be split across CV groups along with X and y. For example, the sample_weight parameter is split because len(sample_weights) = len(X).
- Returns:
self – Instance of fitted estimator.
- Return type:
object
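CustomGridSearchCV subclasses scikit-learn's GridSearchCV, so fit() accepts the same groups argument. Below is a short sketch using the stock GridSearchCV with GroupKFold, so that samples from one group (for example, one patient) never appear in both the training and validation split of a fold. The data, group layout and parameter grid are made up for illustration.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, GroupKFold

# Synthetic binary classification data for illustration only
rng = np.random.default_rng(0)
X = rng.normal(size=(90, 4))
y = (X[:, 0] + 0.5 * rng.normal(size=90) > 0).astype(int)
groups = np.repeat(np.arange(6), 15)  # six hypothetical patients, 15 samples each

search = GridSearchCV(
    LogisticRegression(),
    param_grid={"C": [0.1, 1.0, 10.0]},
    cv=GroupKFold(n_splits=3),
)
# groups are forwarded to GroupKFold.split() so each group stays in one fold
search.fit(X, y, groups=groups)
print(search.best_params_)
```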
- class gaitmod.CustomTrainingLogger(fold=0)
Bases:
Callback
- on_epoch_begin(epoch, logs=None)
Called at the start of an epoch.
Subclasses should override for any actions to run. This function should only be called during TRAIN mode.
- Parameters:
epoch – Integer, index of epoch.
logs – Dict. Currently no data is passed to this argument for this method but that may change in the future.
- on_epoch_end(epoch, logs=None)
Called at the end of an epoch.
Subclasses should override for any actions to run. This function should only be called during TRAIN mode.
- Parameters:
epoch – Integer, index of epoch.
logs – Dict, metric results for this training epoch, and for the validation epoch if validation is performed. Validation result keys are prefixed with val_. For training epoch, the values of the Model’s metrics are returned. Example: {‘loss’: 0.2, ‘accuracy’: 0.7}.
- class gaitmod.DataProcessor
Bases:
object
- static apply_adaptive_filtering_to_epochs(patients_epochs, apply_notch=True, apply_highpass=True, apply_lowpass=False, zero_phase=True, notch_freq=50.0, notch_width=None, highpass_freq=None, lowpass_freq=100.0, filter_order=None, padding_method='reflect', verbose=True)
Apply filtering to STN LFP epochs with explicit separation of frequency and phase filtering.
Accepts pre-determined filter parameters; any parameter left as None is determined automatically. To obtain optimal parameters first, call determine_adaptive_filter_params().
- Return type:
Dict[str, mne.EpochsArray]
Parameters:
- patients_epochs : Dict[str, mne.EpochsArray]
Dictionary of patient EpochsArray objects to be filtered.
- apply_notch : bool, optional
Whether to apply a notch filter for power-line noise (default: True).
- apply_highpass : bool, optional
Whether to apply a high-pass filter for drift removal (default: True).
- apply_lowpass : bool, optional
Whether to apply a low-pass filter (default: False; not recommended for HCTSA).
- zero_phase : bool, optional
Whether to apply zero-phase (forward-backward) filtering instead of standard causal filtering (default: True). True uses sosfiltfilt(), which preserves timing and introduces no phase distortion. False uses sosfilt(), which is faster but introduces phase delays.
- notch_freq : float, optional
Notch filter center frequency in Hz (default: 50.0).
- notch_width : float, optional
Notch filter width in Hz. If None, it is determined automatically.
- highpass_freq : float, optional
High-pass filter cutoff frequency in Hz. If None, it is determined automatically.
- lowpass_freq : float, optional
Low-pass filter cutoff frequency in Hz (default: 100.0).
- filter_order : int, optional
Filter order. If None, it is determined automatically.
- padding_method : str, optional
Method for edge padding: ‘reflect’, ‘zero’, or ‘constant’ (default: ‘reflect’).
- verbose : bool, optional
Whether to print filtering information (default: True).
Returns:
- Dict[str, mne.EpochsArray]
Dictionary of filtered EpochsArray objects.
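A minimal sketch of the notch plus high-pass stage on a plain NumPy array, assuming data shaped (n_epochs, n_channels, n_times) and fixed parameters instead of the adaptively determined ones. The function filter_epochs_sketch is hypothetical; it relies on SciPy's default odd-extension padding rather than the padding_method options above. It illustrates the zero_phase trade-off: sosfiltfilt() runs the filter forward and backward, while sosfilt() runs a single causal pass.

```python
import numpy as np
from scipy import signal


def filter_epochs_sketch(data, sfreq, notch_freq=50.0, notch_width=2.0,
                         highpass_freq=1.0, filter_order=4, zero_phase=True):
    """Hypothetical notch + high-pass filter along the last (time) axis."""
    # High-pass Butterworth as second-order sections for numerical stability
    sos_hp = signal.butter(filter_order, highpass_freq, btype="highpass",
                           fs=sfreq, output="sos")
    # IIR notch: quality factor Q = centre frequency / bandwidth
    b, a = signal.iirnotch(notch_freq, notch_freq / notch_width, fs=sfreq)
    sos = np.vstack([sos_hp, signal.tf2sos(b, a)])
    if zero_phase:
        # Forward-backward pass: no phase distortion, squared magnitude response
        return signal.sosfiltfilt(sos, data, axis=-1)
    # Single forward pass: causal, but introduces a phase delay
    return signal.sosfilt(sos, data, axis=-1)
```

For example, a signal containing 10 Hz and 50 Hz sinusoids sampled at 250 Hz keeps its 10 Hz component while the 50 Hz component is strongly attenuated.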
- static check_and_fix_lfp_chs_name(data, fix_chs_name=False)
Check the LFP channel names in the data and optionally fix them.
- static check_event_order(events_kin_samples, events_kin_labels, trial_idx, exclude_label_names=[])
Check if the values in each column of events_kin_samples are in ascending order for the given trial index.
- Parameters:
events_kin_samples (np.ndarray) – A 2D array of shape (n_events, n_trials) containing the sample indices of each event.
events_kin_labels (np.ndarray | List) – A 1D array of shape (n_events,) or a list containing the event names.
trial_idx (int) – The index of the trial to check.
- Raises:
ValueError – If the values in the column are not in ascending order.
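The order check above can be sketched as follows, assuming "ascending" means non-decreasing and that labels listed in exclude_label_names are skipped before comparing. The exclusion behaviour is an assumption, since that parameter is not documented above, and check_event_order_sketch is a hypothetical name.

```python
import numpy as np


def check_event_order_sketch(events_kin_samples, events_kin_labels, trial_idx,
                             exclude_label_names=()):
    """Hypothetical check: raise if a trial's event samples are out of order."""
    labels = list(events_kin_labels)
    # Keep only the rows whose label is not excluded
    keep = [i for i, name in enumerate(labels) if name not in exclude_label_names]
    column = np.asarray(events_kin_samples)[keep, trial_idx]
    # Non-strict check: equal consecutive samples are accepted (an assumption)
    if np.any(np.diff(column) < 0):
        kept_labels = [labels[i] for i in keep]
        raise ValueError(f"Events in trial {trial_idx} are not in ascending order: "
                         f"{dict(zip(kept_labels, column.tolist()))}")
```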
- static clean_data(data, data_type, key_map, new_events_order, sfreq, fix_chs_names=False, verbose=False)
Cleans data by applying renaming, sorting/filtering, and removing NaNs.
- Parameters:
data (dict) – The dataset to clean.
data_type (str) – The type of data being processed.
key_map (dict) – Mapping of old keys to new keys.
new_events_order (list) – Ordered list of events to keep.
sfreq (float) – Sampling frequency for processing.
verbose (bool, optional) – Whether to print debug information. Defaults to False.
- static convert_lfp_label(labels)
Convert long-form LFP labels to a more descriptive format.
- Parameters:
labels (np.ndarray) – Array of long-form LFP labels.
- Returns:
List of converted labels in the format ‘LFP_SIDE_DIGITDIGIT’, where SIDE is ‘L’ or ‘R’ and the digits are the numeric part extracted from each original label.
- Return type:
List[str]
- create_epochs_with_events(events_mod_start, normal_walking_events, gait_modulation_event_id, normal_walking_event_id, epoch_tmin, epoch_tmax, event_dict)
Combines modulation and normal walking events, sorts them by onset time, and creates MNE epochs.
- Parameters:
lfp_raw (mne.io.Raw) – The raw LFP signal data.
events_mod_start (np.ndarray) – Array of modulation event start times.
normal_walking_events (np.ndarray) – Array of normal walking events.
gait_modulation_event_id (int) – Event ID for gait modulation events.
normal_walking_event_id (int) – Event ID for normal walking events.
epoch_tmin (float) – Start time for epochs (in seconds).
epoch_tmax (float) – End time for epochs (in seconds).
event_dict (Dict[str, int]) – Dictionary mapping event class names to event IDs.
- Returns:
Array containing the combined and sorted events.
The MNE Epochs object containing the epoched data for both modulation and normal walking events.
- Return type:
Tuple[np.ndarray, mne.Epochs]
- static create_events_array(events_kin, sfreq)
Create an MNE-compatible events array from events_KIN data, handling NaN values.
Parameters:
- events_kin : Dict[str, Any]
Dictionary containing events_KIN data with ‘times’ and ‘label’ keys.
- sfreq : float
Sampling frequency of the LFP data.
Returns:
- mne_events : np.ndarray
MNE-compatible events array of shape (n_events, 3). Each row contains [sample_index, 0, event_id].
- event_id : Dict[str, int]
Dictionary mapping event labels to unique event IDs.
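A minimal sketch of building such an events array, assuming one event time (in seconds) per label, NaN times dropped, IDs assigned 1..K in order of first appearance, and rows sorted by onset. The ID scheme and the helper name create_events_array_sketch are assumptions, not the gaitmod implementation.

```python
import numpy as np


def create_events_array_sketch(events_kin, sfreq):
    """Hypothetical helper: times (s) + labels -> MNE-style events array."""
    times = np.asarray(events_kin["times"], dtype=float).ravel()
    labels = np.asarray(events_kin["label"]).ravel()
    # Assign IDs 1..K in order of first appearance of each label
    event_id = {}
    for name in labels:
        event_id.setdefault(str(name), len(event_id) + 1)
    valid = ~np.isnan(times)  # drop events whose time is NaN
    samples = np.round(times[valid] * sfreq).astype(int)
    ids = np.array([event_id[str(name)] for name in labels[valid]], dtype=int)
    order = np.argsort(samples, kind="stable")  # MNE expects chronological events
    mne_events = np.column_stack([samples, np.zeros_like(samples), ids])[order]
    return mne_events, event_id
```

For example, times [0.5, NaN, 1.0, 0.2] with labels ["heel", "toe", "toe", "heel"] at 250 Hz give the rows [50, 0, 1], [125, 0, 1] and [250, 0, 2].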
- static create_trials(events_kin_samples, data, sfreq, config, subject_id, session_id)
Extracts trial data based on kinematic event samples.
- Parameters:
events_kin_samples (np.ndarray) – A 2D array of kinematic event samples with shape (n_events, n_trials).
data (np.ndarray) – A 2D array of measurement data (LFP, EEG, IMU, etc.) with shape (n_channels, n_times).
sfreq (float) – Sampling frequency of the LFP data (Hz).
config (Dict[str, Any]) – Configuration settings for padding and truncating the data.
subject_id (str) – Subject identifier for debugging messages.
session_id (str) – Session identifier for debugging messages.
- Returns:
- A tuple containing:
trials_data: List of 2D numpy arrays containing trial data with shape (n_channels, n_times).
valid_trials_mask: List of booleans indicating which trials were successfully processed (True) vs skipped (False).
- Return type:
Tuple[List[np.ndarray], List[bool]]
- static crop_lfp_to_event_times(lfp_data, events_KIN, lfp_sfreq)
Crop the LFP data to match the time span of the provided event times.
- Parameters:
lfp_data (np.ndarray) – 2D array of LFP data with shape (channels, samples).
events_KIN (dict) – Dictionary containing event times.
lfp_sfreq (float) – Sampling frequency of the LFP data.
- Returns:
A tuple containing:
first_non_zero_indices (np.ndarray): Indices of the first non-zero values in each row.
cropped_lfp_data (np.ndarray): Cropped LFP data array.
- define_normal_walking_events(events_mod_start, gap_sample_length, epoch_sample_length, n_samples)
Defines normal walking events by creating intervals between modulation events and constructing an array of event onsets for normal walking periods.
- Parameters:
normal_walking_event_id (int) – The event ID assigned to normal walking events.
events_mod_start (np.ndarray) – Array of modulation event start times, where each row contains the sample onset of a modulation event.
gap_sample_length (int) – Length of the gap in samples to create before and after each modulation event.
epoch_sample_length (int) – The length of the epochs in samples for normal walking events.
n_samples (int) – Total number of samples in the signal.
- Returns:
- Array containing the normal walking events. Each row contains three values:
The onset of the normal walking event in samples.
A dummy value (always zero).
The event ID (as provided in normal_walking_event_id).
- Return type:
np.ndarray
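One way to realise the interval construction described above is sketched below. It assumes that a gap of gap_sample_length samples is excluded on both sides of every modulation onset and that the remaining stretches are tiled with back-to-back, non-overlapping epochs. Both the tiling rule and the helper name define_normal_walking_events_sketch are assumptions for illustration.

```python
import numpy as np


def define_normal_walking_events_sketch(events_mod_start, gap_sample_length,
                                        epoch_sample_length, n_samples,
                                        normal_walking_event_id):
    """Hypothetical: tile epochs in the gaps between modulation onsets."""
    onsets = np.sort(np.asarray(events_mod_start)[:, 0])
    # Allowed intervals run from (previous onset + gap) to (next onset - gap)
    starts = np.concatenate([[0], onsets + gap_sample_length])
    stops = np.concatenate([onsets - gap_sample_length, [n_samples]])
    normal_onsets = []
    for start, stop in zip(starts, stops):
        # Place back-to-back epochs that fit completely inside [start, stop)
        for onset in range(int(start), int(stop) - epoch_sample_length + 1,
                           epoch_sample_length):
            normal_onsets.append(onset)
    normal_onsets = np.array(normal_onsets, dtype=int)
    # Rows follow the MNE layout: [onset, 0, event_id]
    return np.column_stack([normal_onsets,
                            np.zeros_like(normal_onsets),
                            np.full_like(normal_onsets, normal_walking_event_id)])
```

With modulation onsets at samples 1000 and 3000, a 200-sample gap, 500-sample epochs and 5000 samples in total, this yields onsets 0, 1200, 1700, 2200, 3200, 3700 and 4200.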
- static determine_adaptive_filter_params(patients_epochs, notch_freq=50.0, verbose=True)
Determine optimal filtering parameters based on epoch characteristics.
This function analyzes the first patient’s epochs to determine optimal filter parameters that will be applied consistently across all patients.
Parameters:
- patients_epochs : Dict[str, mne.EpochsArray]
Dictionary of patient EpochsArray objects to analyze.
- notch_freq : float, optional
Notch filter center frequency in Hz (default: 50.0).
- verbose : bool, optional
Whether to print parameter determination details (default: True).
Returns:
- Dict[str, Any]
Dictionary containing the determined filter parameters: ‘filter_order’ (int), ‘highpass_freq’ (float), ‘notch_width’ (float), ‘epoch_duration’ (float), ‘freq_resolution’ (float), ‘sfreq’ (float), ‘nyquist’ (float), and ‘warnings’ (List[str]).
- static get_trial_event_indices(events_kin_samples, events_kin_labels, trial_idx)
Computes the sample indices of event occurrences relative to the start of a given trial.
- Parameters:
events_kin_samples (np.ndarray) – A 2D array of shape (n_events, n_trials) containing the sample indices of each event.
events_kin_labels (np.ndarray | List) – A 1D array of shape (n_events,) or a list containing the event names.
trial_idx (int) – The index of the trial for which to compute event indices.
- Returns:
Event indices relative to the start of the trial.
- Return type:
np.ndarray (int)
Notes
If an event occurs outside the trial boundaries, a random valid index within the trial range is assigned.
The returned indices are relative to the trial’s start index.
- np_to_dict()
Converts a numpy array-based data structure to a dictionary-like structure.
Parameters:
- data_structure : np.ndarray
Numpy array-based data structure from which to extract metadata and events.
Returns:
- dict
A dictionary-like structure containing the data extracted from the input numpy array-based structure.
- static pad_data(trials, max_length, padding_value=0, position='end')
Pad all trials to the specified max_length with the given padding value.
- static pad_or_truncate(trials, config, target_length=None)
Pads or truncates the given trials to the target length based on the provided configuration.
- Parameters:
trials (List[np.ndarray]) – List of trial data arrays to be padded or truncated.
target_length (Optional[Union[int, float, str]]) – The target length to pad or truncate the trials to. If None, the target length will be determined from the config.
config (Dict[str, Any]) – Configuration dictionary containing padding and truncation settings.
- Returns:
List of trial data arrays after padding or truncation.
- Return type:
List[np.ndarray]
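A minimal sketch of padding or truncating (n_channels, n_times) trials to a fixed length, assuming target_length is an integer number of samples. The config-driven strategies (float or string targets) are not modelled, truncation keeps the first samples, and pad_or_truncate_sketch is a hypothetical name.

```python
import numpy as np


def pad_or_truncate_sketch(trials, target_length, padding_value=0.0,
                           position="end"):
    """Hypothetical: force every (n_channels, n_times) trial to target_length samples."""
    out = []
    for trial in trials:
        n_times = trial.shape[-1]
        if n_times >= target_length:
            # Truncate by keeping the first target_length samples (an assumption)
            out.append(trial[..., :target_length])
            continue
        missing = target_length - n_times
        # Pad only the time axis, either before or after the signal
        pad = (missing, 0) if position == "start" else (0, missing)
        out.append(np.pad(trial, ((0, 0), pad), constant_values=padding_value))
    return out
```

Equal lengths are what allow the trials to be stacked into a single (n_trials, n_channels, n_times) array afterwards.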
- static process_lfp_data(data, n_sessions, lfp_sfreq, event_of_interest, mod_start_event_id, normal_walking_event_id, gap_sample_length, epoch_sample_length, epoch_tmin, epoch_tmax, epoch_duration, event_dict, info, reject_criteria, config, verbose=False)
- static process_trials_and_events(data, data_type, sfreq, config, verbose=False)
Prepares data and event indices for each subject across multiple sessions.
- This function processes data such as LFP, IMU, and EEG by extracting trials from kinematic events and storing them in three dictionaries:
subjects_data_dict: A dictionary where each subject’s trials (as 2D NumPy arrays) are stored. Each trial array has the shape (n_channels x n_times), where n_channels is the number of channels or sensors.
subjects_event_sample_idx_dict: A dictionary where each subject’s event indices across trials are stored. The array shape is (n_trials x n_events). Each row corresponds to a specific trial index and each column to a specific event label’s index.
subjects_session_trial_mapping: A dictionary where each subject’s trials are mapped to their originating sessions. Format: {patient_id: {session_name: [trial_indices]}}
- Parameters:
data (Dict[str, Dict[str, dict]]) – Nested dictionary where the outer key is the subject name, and the inner dictionary contains session data for each subject.
data_type (str) – Type of data to be processed. Must be one of the following: ‘data_acc’, ‘data_EEG’, ‘data_EMG’, ‘data_giro’, ‘data_LFP’.
sfreq (float) – Sampling frequency of the specified data (in Hz).
config (dict) – Configuration dictionary used in the data processing.
verbose (bool, optional) – Whether to print processing details for each subject and session. Defaults to False.
- Returns:
subjects_data_dict: Dictionary of subjects with trials as 2D NumPy arrays.
subjects_event_sample_idx_dict: Dictionary of subjects with event indices across trials.
subjects_session_trial_mapping: Dictionary mapping each subject’s trials to their originating sessions.
- Return type:
Tuple[Dict[str, List[np.ndarray]], Dict[str, np.ndarray], Dict[str, Dict[str, List[int]]]]
- static remove_nan_events(data, sfreq)
Remove NaN values from event times and generate the corresponding sample indices.
- Parameters:
data (Dict[str, Any]) – Dictionary containing event times and labels. Expected keys are ‘times’ (numpy array of event times) and ‘label’ (list of event labels).
sfreq (float) – Sampling frequency.
- Returns:
Tuple[np.ndarray, np.ndarray] containing:
events_kin_times_valid (np.ndarray): Array of valid event times with NaN values removed.
events_kin_samples_valid (np.ndarray): Array of valid event times converted to sample indices.
- static remove_trials_with_short_labels(patients_epochs, subjects_lfp_data_dict=None, subjects_event_idx_dict=None, min_epochs_per_class=1, verbose=True)
Removes trials that contain epochs from only one class (unbalanced trials) from ALL related data structures.
This method keeps patients_epochs, subjects_lfp_data_dict, and subjects_event_idx_dict consistent by removing the same trials from all of them simultaneously.
- Return type:
Tuple[Dict[str, EpochsArray], Optional[Dict[str, List[ndarray]]], Optional[Dict[str, List[ndarray]]]]
Parameters:
- patients_epochs : Dict[str, mne.EpochsArray]
Dictionary of patient EpochsArray objects containing segmented trials.
- subjects_lfp_data_dict : Dict[str, List[np.ndarray]], optional
Dictionary of raw LFP trial data per subject (filtered as well if provided).
- subjects_event_idx_dict : Dict[str, List[np.ndarray]], optional
Dictionary of event indices per trial per subject (filtered as well if provided).
- min_epochs_per_class : int, optional
Minimum number of epochs required for each class within a trial (default: 1).
- verbose : bool, optional
Whether to print filtering details (default: True).
Returns:
Tuple containing:
Dict[str, mne.EpochsArray]: Filtered patients_epochs
Dict[str, List[np.ndarray]]: Filtered subjects_lfp_data_dict (or original if None provided)
Dict[str, List[np.ndarray]]: Filtered subjects_event_idx_dict (or original if None provided)
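The selection rule can be sketched on an epochs.events-style array, assuming each row is [start_sample, trial_id, class_label] (the layout described for plot_all_patients_trials) and two classes labelled 0 and 1. The helper balanced_trial_ids is hypothetical; the returned trial IDs are what would then be used to filter the epochs and the per-trial lists in the same way.

```python
import numpy as np


def balanced_trial_ids(events, classes=(0, 1), min_epochs_per_class=1):
    """Hypothetical: IDs of trials whose epochs include every class often enough."""
    trial_ids = np.unique(events[:, 1])
    keep = []
    for trial in trial_ids:
        labels = events[events[:, 1] == trial, 2]
        # Count epochs per class within this trial
        counts = [np.sum(labels == c) for c in classes]
        if min(counts) >= min_epochs_per_class:
            keep.append(int(trial))
    return keep


events = np.array([
    [0, 0, 0], [125, 0, 1],  # trial 0: both classes
    [0, 1, 0], [125, 1, 0],  # trial 1: normal walking only
    [0, 2, 1], [125, 2, 0],  # trial 2: both classes
])
kept = balanced_trial_ids(events)  # trial 1 is dropped
```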
- static rename_keys(data, key_map)
Recursively renames keys in a dictionary or list in place.
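A minimal sketch of recursive in-place key renaming that descends into nested dictionaries and lists. The helper name rename_keys_sketch is hypothetical, and renamed keys move to the end of their dictionary's insertion order.

```python
def rename_keys_sketch(data, key_map):
    """Hypothetical: recursively rename dict keys in place, descending into lists."""
    if isinstance(data, dict):
        # Snapshot the keys, since the dict is mutated while iterating
        for old_key in list(data.keys()):
            value = data[old_key]
            if old_key in key_map:
                data[key_map[old_key]] = data.pop(old_key)
            rename_keys_sketch(value, key_map)
    elif isinstance(data, list):
        for item in data:
            rename_keys_sketch(item, key_map)
```

Because the nested value object is the same before and after the rename, recursing into it after the move still updates the renamed entry.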
- static rename_lfp_channels(labels)
Convert long-form LFP channels to a more descriptive format and ensure uniqueness.
- Parameters:
labels (np.ndarray) – Array of long-form LFP labels.
- Returns:
List of converted labels in the format ‘LFP_SIDE_DIGITDIGIT’, where SIDE is ‘L’ or ‘R’ and the digits are the numeric part extracted from each original channel.
- Return type:
List[str]
- static segment_and_label_trials(trials, subjects_event_idx_dict, ch_names, sfreq=250, window_size=0.5, overlap=0.5, expand_transition=0.0, discard_ambiguous=False, mod_start_idx=2, mod_end_idx=6, event_dict=None)
Segments LFP trials into overlapping windows, assigns labels (0: normal, 1: modulation), and stores the results in an MNE Epochs object.
- Parameters:
trials (list of np.ndarray) – List of LFP trials, each with shape (n_channels, n_samples).
subjects_event_idx_dict (list of np.ndarray) – List of event indices for each trial.
sfreq (int) – Sampling frequency of the signals.
window_size (float) – Size of each segment in seconds.
overlap (float) – Overlap fraction between consecutive windows.
expand_transition (float) – Amount of time (seconds) to expand mod_start/mod_end.
discard_ambiguous (bool) – Whether to remove windows that overlap both states.
mod_start_idx (int) – Index for modulation start event.
mod_end_idx (int) – Index for modulation end event.
event_dict (dict[str, int]) – Dictionary mapping class names to labels. Must contain exactly 2 classes for normal walking and modulation phases.
- Returns:
MNE Epochs object containing the segmented data.
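A sketch of the windowing for a single trial, assuming mod_start and mod_end are already sample indices (in the real method they come from subjects_event_idx_dict at mod_start_idx and mod_end_idx), a hop of window_size * (1 - overlap) seconds, and a majority-vote label (1 if at least half the window lies in the modulation phase). expand_transition is not modelled, arrays are returned instead of MNE Epochs, and segment_trial_sketch is a hypothetical name.

```python
import numpy as np


def segment_trial_sketch(trial, mod_start, mod_end, sfreq=250, window_size=0.5,
                         overlap=0.5, discard_ambiguous=False):
    """Hypothetical: sliding windows over one trial, labelled 1 inside [mod_start, mod_end)."""
    win = int(round(window_size * sfreq))
    step = max(1, int(round(win * (1 - overlap))))  # overlap=0.5 -> half-window hop
    n_times = trial.shape[-1]
    windows, labels, onsets = [], [], []
    for start in range(0, n_times - win + 1, step):
        stop = start + win
        # Fraction of the window's samples that fall inside the modulation phase
        inside = max(0, min(stop, mod_end) - max(start, mod_start)) / win
        if 0 < inside < 1 and discard_ambiguous:
            continue  # window straddles a normal/modulation transition
        windows.append(trial[:, start:stop])
        labels.append(int(inside >= 0.5))  # majority vote (an assumption)
        onsets.append(start)
    return np.stack(windows), np.array(labels), np.array(onsets)
```

With discard_ambiguous=True only windows lying entirely in one state remain, which trades fewer epochs for cleaner labels.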
- static sort_and_filter_events(data, new_order)
Reorder events in the ‘events_KIN’ key of the provided data based on a new order. It can also be used to filter out events: any event not in the new order is excluded.
- Parameters:
data – Dataset whose ‘events_KIN’ entries are reordered, organized by data type.
new_order – Ordered list of events to keep.
- Raises:
KeyError – If ‘events_KIN’ is not found in any data_type.
- trim_data(events, sfreq, threshold=1e-06)
Trims the beginning of the LFP data by removing leading segments where the signal contains only zero or NaN values, and adjusts the events’ onsets accordingly. Only the beginning of the recording is trimmed, and only where it contains no signal data.
- Parameters:
lfp_data (np.ndarray) – 2D array of LFP data with shape (n_channels, n_samples).
events (np.ndarray) – 2D array of events with shape (n_events, 3), where the first column represents onsets.
sfreq (float) – Sampling frequency of the LFP data.
threshold (float, optional) – Value below which the data is considered as “no recorded data”. Defaults to 1e-6.
- Returns:
- Tuple containing:
Trimmed LFP data (2D array with shape (n_channels, trimmed_n_samples)).
Adjusted events (2D array with shape (n_events, 3)).
- Return type:
Tuple[np.ndarray, np.ndarray]
- Prints:
Number of samples removed.
Number of seconds removed.
Number of samples shifted for the onsets.
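A minimal sketch of the trimming step, assuming a sample counts as recorded when any channel is finite and above threshold in magnitude. It shifts the first column of the events array, which holds onsets in the MNE events convention; adjust the column index if your events array stores onsets elsewhere. The helper trim_leading_silence is hypothetical and does not print the summary described above.

```python
import numpy as np


def trim_leading_silence(lfp_data, events, threshold=1e-6):
    """Hypothetical: drop leading samples where every channel is ~0 or NaN."""
    # A sample counts as recorded if any channel is finite and above threshold
    recorded = np.any(np.isfinite(lfp_data) & (np.abs(lfp_data) > threshold), axis=0)
    if not recorded.any():
        return lfp_data, events  # nothing recorded; leave untouched
    first = int(np.argmax(recorded))  # index of the first recorded sample
    adjusted = events.copy()  # do not modify the caller's array
    adjusted[:, 0] -= first  # shift onsets (first column, MNE convention)
    return lfp_data[:, first:], adjusted
```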
- class gaitmod.FeatureExtractor
Bases:
object
- static extract_band_power(epochs, freq_bands)
Extracts band power features from each channel of the given MNE epochs without averaging across channels.
Parameters:
- epochs: mne.Epochs object containing the LFP data.
- freq_bands: Dictionary where keys are band names and values are (low_freq, high_freq) tuples.
Returns:
- band_power_dict: A dictionary where each key corresponds to a frequency band and each value is a list of features (one per epoch and channel) for that band.
- static extract_band_psd(epochs, freq_bands)
Extracts PSD features for specified frequency bands from the given MNE epochs without averaging across channels or frequencies.
Parameters:
- epochs: mne.Epochs object containing the LFP data.
- freq_bands: Dictionary where keys are band names and values are (low_freq, high_freq) tuples.
Returns:
- psd_dict: A dictionary where each key corresponds to a frequency band and each value is an array of shape (n_epochs, n_channels, n_frequencies) holding the raw PSD values for that band.
- static extract_psd_and_band_power(epochs, freq_bands, fmin, fmax)
Extracts power spectral density (PSD) and band power features from each channel of the given MNE epochs, without averaging across channels or epochs.
Parameters:
- epochs: mne.Epochs object containing the LFP data.
- freq_bands: Dictionary where keys are band names and values are (low_freq, high_freq) tuples.
Returns:
- psd_array: A NumPy array of shape (n_epochs, n_channels, n_frequencies) containing the raw PSD values.
- band_power_array: A NumPy array of shape (n_epochs, n_channels, n_bands) containing the mean band power for each frequency band, per epoch and channel.
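The PSD and band-power shapes above can be sketched on a plain array of shape (n_epochs, n_channels, n_times). This sketch swaps in SciPy's Welch estimator, which may differ from the estimator the MNE-based methods use, and takes band power as the mean PSD within [low, high). The helper psd_and_band_power is hypothetical.

```python
import numpy as np
from scipy.signal import welch


def psd_and_band_power(data, sfreq, freq_bands, fmin=1.0, fmax=100.0):
    """Hypothetical: Welch PSD + mean band power per epoch and channel.

    data (n_epochs, n_channels, n_times) -> psd (n_epochs, n_channels, n_freqs)
    and band_power (n_epochs, n_channels, n_bands).
    """
    freqs, psd = welch(data, fs=sfreq, nperseg=min(256, data.shape[-1]), axis=-1)
    keep = (freqs >= fmin) & (freqs <= fmax)
    freqs, psd = freqs[keep], psd[..., keep]
    # Average the PSD bins inside each band; bands stack on the last axis
    band_power = np.stack(
        [psd[..., (freqs >= lo) & (freqs < hi)].mean(axis=-1)
         for lo, hi in freq_bands.values()],
        axis=-1,
    )
    return freqs, psd, band_power
```

For a pure 20 Hz sinusoid, the beta band (13 to 30 Hz) power is far larger than the theta band (4 to 8 Hz) power.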
- static extract_windowed_stat_features(trials, methods=['mean'], window_size=100, step_size=50, verbose=False)
- static flatten_features(psds, band_power)
Flatten the features and combine PSD and band power.
Parameters:
- psds: 3D array of PSD values.
- band_power: 3D array of band power values.
Returns:
- Combined flattened features as a 2D array.
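A minimal sketch of the flattening, assuming epochs lie on the first axis of both inputs and that PSD features come before band-power features in each row. That column ordering is an assumption, and flatten_and_combine is a hypothetical name.

```python
import numpy as np


def flatten_and_combine(psds, band_power):
    """Hypothetical: (n_epochs, n_channels, k) arrays -> one (n_epochs, n_features) matrix."""
    n_epochs = psds.shape[0]
    # Collapse the channel and frequency/band axes per epoch, then concatenate
    return np.hstack([psds.reshape(n_epochs, -1),
                      band_power.reshape(n_epochs, -1)])
```

With psds of shape (4, 2, 10) and band_power of shape (4, 2, 3), the result has shape (4, 26).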
- static reshape_lfp_data(lfp_data, mode='flat_time')
Reshape the LFP data according to the specified mode.
Parameters:
- lfp_data : np.ndarray
The input LFP data with shape (trials, channels, times).
- mode : str
The reshaping mode. Options: “flat_time” reshapes to (trials * times, channels); “flat_channel” reshapes to (trials, times * channels).
Returns:
- reshaped_data : np.ndarray
The reshaped data based on the selected mode.
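The two modes map directly onto NumPy transposes and reshapes. Below is a sketch assuming input of shape (trials, channels, times); the time-major ordering of each “flat_channel” row (all channels at t=0, then t=1, and so on) is an assumption, and reshape_lfp_sketch is a hypothetical name.

```python
import numpy as np


def reshape_lfp_sketch(lfp_data, mode="flat_time"):
    """Hypothetical reshaping of (trials, channels, times) LFP data."""
    n_trials, n_channels, n_times = lfp_data.shape
    if mode == "flat_time":
        # Every time point becomes one row of channel values
        return lfp_data.transpose(0, 2, 1).reshape(n_trials * n_times, n_channels)
    if mode == "flat_channel":
        # Each trial becomes one row; samples are ordered time-major
        return lfp_data.transpose(0, 2, 1).reshape(n_trials, n_times * n_channels)
    raise ValueError(f"Unknown mode: {mode!r}")
```

“flat_time” suits per-sample models that treat channels as features, while “flat_channel” gives one feature vector per trial.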
- class gaitmod.FeatureExtractor2(sfreq, features_config)
Bases:
object
- extract_features_with_labels(epochs, feature_handling='flatten_chs')
Extracts features and labels from an MNE Epochs object.
Parameters:
- epochs (mne.Epochs): The epochs object containing LFP data and labels.
- feature_handling (str): The strategy for handling multi-channel features.
Returns:
- X (np.ndarray): Extracted features of shape (n_epochs, n_features).
- y (np.ndarray): Corresponding labels of shape (n_epochs,).
- select_feature(feature_matrix, feature_name, feature_handling='flatten_chs')
Extracts a specific feature slice from the feature matrix using the feature name.
Parameters:
- feature_matrix (np.ndarray): The full feature matrix returned by extract_features.
- feature_name (str): The name of the feature to extract.
- feature_handling (str): How the features are handled across channels. Options: ‘flatten_chs’, ‘average_chs’, ‘separate_chs’.
Returns:
- np.ndarray: The sliced feature matrix for the specified feature name.
- class gaitmod.LSTMClassifier(input_shape, hidden_dims=[50], activations=['tanh'], recurrent_activations=['sigmoid'], dropout=0.2, dense_units=1, dense_activation='sigmoid', optimizer='adam', lr=0.001, patience=5, epochs=10, batch_size=32, threshold=0.5, loss='binary_crossentropy', callbacks=None, mask_vals=(0.0, 2))
Bases:
BaseEstimator, ClassifierMixin
- class gaitmod.LinearRegressionModel(config_path=None, **kwargs)
Bases:
RegressionModels
- class gaitmod.RegressionLSTMModel(config_path=None, **kwargs)
Bases:
RegressionModels
- class gaitmod.Visualise
Bases:
object
- static plot_all_patients_trials(subjects_lfp_data_dict, patients_epochs, sfreq, save_path, fig_name, subjects_session_trial_mapping=None, max_display_trials=None, window_size=0.5, overlap=0.5, expand_transition=0.0, sharex=True, sharey=True)
Plots LFP data for all patients and trials with epoch-level temporal coloring. This function creates a grid of subplots where each column represents a patient and each row represents a trial. Each subplot displays the LFP data for a specific trial of a specific patient, with channels offset for clarity. Different time segments (epochs) within each continuous signal are colored based on their individual class labels. Trials belonging to the same session are visually grouped with colored boxes.
Parameters:
- subjects_lfp_data_dict : Dict[str, List[np.ndarray]]
A dictionary where keys are patient names and values are lists of numpy arrays representing trials. Each numpy array should have shape (num_channels, num_samples).
- patients_epochs : Dict[str, mne.Epochs]
A dictionary where keys are patient names and values are MNE Epochs objects containing epoch-level data. The epochs.events array contains [start_time_samples, trial_id, class_label] for each epoch.
- sfreq : float
Sampling frequency of the data.
- save_path : str
Path to save the figure. If empty, the figure is not saved.
- fig_name : str
Name of the figure file to be saved.
- subjects_session_trial_mapping : Dict[str, Dict[str, List[int]]], optional
Dictionary mapping patients to sessions and their corresponding trial indices. Format: {patient_id: {session_name: [trial_indices]}}
- max_display_trials : int, optional
Maximum number of trials to display per patient. If None (default), all trials are displayed.
- sharex : bool, optional
Whether to share the x-axis among subplots. Default is True.
- sharey : bool, optional
Whether to share the y-axis among subplots. Default is True.
Returns:
- None
- static plot_all_trial_lengths(subjects_lfp_data_dict, lfp_sfreq, save_path, fig_name)
Plots the distribution and boxplot of trial lengths from LFP data.
Parameters:
- subjects_lfp_data_dict (dict): Dictionary containing LFP data for multiple subjects. Each subject’s data is a list of trials, where each trial is a 2D array.
- lfp_sfreq (float): Sampling frequency of the LFP data.
- save_path (str): Directory path where the plot image will be saved.
- fig_name (str): Name of the figure file to be saved (without extension).
Returns: None
- static plot_epoch_psd_comparison(epochs_orig, epochs_filt, epoch_idx, fix_chs_names=None, save_path=None, fig_name='epoch_psd_comparison', show_fig=True)
Plot PSD comparison for a single epoch (original vs filtered) for all channels.
- static plot_epoch_spectrogram_and_psd(signal_orig, signal_filt, patient_name='Patient', epoch_idx=0, channel_idx=0, fmin=0.1, fmax=50, psd_fmax=100, nperseg=None, noverlap=None, figsize=(16, 12), save_path=None, fig_name=None, show_fig=True)
- static plot_epochs_with_events(patients_epochs, subject_id, window_size, sfreq, event_names, subjects_event_idx_dict, show_fig=True, save_path=None, fig_name='epochs_with_events{subject_id}.png')
Plot the epochs for a given subject with event markers for mod_start and mod_end.
- Parameters:
patients_epochs (Dict[str, mne.EpochsArray]) – A dictionary of patient IDs mapped to MNE EpochsArray objects.
subject_id (str) – The ID of the subject to plot.
window_size (float) – The size of the window in seconds.
sfreq (int) – The sampling frequency of the data.
event_names (list[str]) – The list of event names in the correct order.
subjects_event_idx_dict (Dict[str, Dict[int, List[int]]]) – A dictionary mapping subject IDs to dictionaries of trial IDs mapped to event indices.
show_fig (bool, optional) – Flag to display the plot, by default True.
save_path (Optional[str], optional) – Path to save the plot, by default None.
fig_name (str, optional) – Name of the file to save the plot, by default ‘epochs_with_events{subject_id}.png’.
- Return type:
None
- static plot_event_class_histogram(events, event_dict, n_sessions, show_fig=True, save_fig=True, file_name='event_class_histogram.png')
Creates a histogram of the number of onsets for each event class, with event IDs mapped to descriptive labels.
Parameters:
- events (np.ndarray): Array containing event data with at least three columns: [time, session_id, event_id].
- event_dict (Dict[int, str]): A dictionary mapping event IDs to descriptive labels.
- n_sessions (int): The number of sessions to plot.
- show_fig (bool): Whether to show the figure.
- save_fig (bool): Whether to save the figure.
- file_name (str): The filename for saving the figure.
- static plot_event_occurrence(events, epoch_sample_length, lfp_sfreq, event_dict, n_sessions, show_fig=True, save_fig=True, file_name='event_occurrence.png')
Creates a horizontal bar plot of event occurrences for each session, with a different color for each event type. Event IDs are mapped to descriptive labels using the provided event_dict.
- static plot_individual_trial_counts(patients_epochs, save_path, fig_name)
Plots the label counts per trial for all patients.
- static plot_label_distribution_boxplot_all_patients(patients_epochs, save_path, fig_name)
Creates a boxplot showing the distribution of the number of labels of each class per trial for all patients.
- static plot_raw_data_with_annotations(lfp_raw_list, scaling=50.0, folder_path='images')
Plot the raw LFP data with annotations for each session.
Parameters:
- lfp_raw_list : list of mne.io.Raw
List of raw LFP data for each session.
- folder_path : str
Folder where the plots will be saved.
- static plot_session_counts(df_session_counts, save_path, fig_name)
Plots a histogram of the number of recording sessions per patient.
Parameters:
- df_session_counts (pd.DataFrame): DataFrame containing patient IDs and the number of recording sessions.
- save_path (str): Directory where the plot will be saved.
- fig_name (str): Name of the saved plot file.
Returns: None
- static plot_single_patient_trial_psd(subjects_lfp_data_dict, patient_name, trial_idx, sfreq, fix_chs_names=None, save_path=None, fig_name=None, subjects_session_trial_mapping=None, fmin=0.1, fmax=100, figsize=(18, 8), show_fig=True)
Plots the power spectral density (PSD) of the continuous LFP signal from a specific trial of a single patient.
This function creates a frequency-domain visualization of how power is distributed across frequencies for all channels of the LFP data from a single trial. The PSD is computed with the multitaper method, which averages spectra obtained with several orthogonal tapers and therefore has lower variance than a plain periodogram.
The visualization includes all channels plotted on the same axes with different colors for easy comparison, frequency band annotations (Delta, Theta, Alpha, Beta, Gamma) to aid in neurophysiological interpretation, and session information when available. This is particularly useful for analyzing brain oscillations relevant to gait and movement disorders in Parkinson’s disease research.
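The band annotations can also be summarized numerically. The sketch below averages a linear-unit PSD inside each named band; the band edges are common textbook values and are an assumption, since this section does not state the limits the plot uses. `band_mean_power` is a hypothetical helper.

```python
import numpy as np

# Conventional band edges in Hz (assumed; the plot's actual limits are not documented).
BANDS = {
    "Delta": (1.0, 4.0),
    "Theta": (4.0, 8.0),
    "Alpha": (8.0, 13.0),
    "Beta": (13.0, 30.0),
    "Gamma": (30.0, 100.0),
}


def band_mean_power(freqs, psd, bands=BANDS):
    """Mean of a 1-D linear-unit PSD within each half-open band [lo, hi)."""
    freqs = np.asarray(freqs)
    psd = np.asarray(psd)
    out = {}
    for name, (lo, hi) in bands.items():
        mask = (freqs >= lo) & (freqs < hi)
        # NaN signals that no frequency bin falls inside the band.
        out[name] = float(psd[mask].mean()) if mask.any() else float("nan")
    return out
```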
- Return type: None
Parameters:
- subjects_lfp_data_dict : Dict[str, List[np.ndarray]]
A dictionary where keys are patient names and values are lists of numpy arrays representing trials. Each numpy array should have shape (num_channels, num_samples). This is the main data structure containing all LFP recordings.
- patient_name : str
Name of the patient to plot. Must be a valid key in subjects_lfp_data_dict. Typically follows a naming convention such as ‘PW_HK59’ or ‘PW_SN61’.
- trial_idx : int
Index of the trial to plot (0-based). Must be within the range of available trials for the specified patient. Each trial is a separate continuous recording.
- sfreq : float
Sampling frequency of the data in Hz. Used for PSD computation and frequency axis scaling. Typically 250 Hz for LFP recordings.
- fix_chs_names : List[str], optional
List of channel names corresponding to the channels in the LFP data.
- save_path : str, optional
Path to save the figure. If None, the figure will not be saved. The directory is created if it does not exist. Default is None.
- fig_name : str, optional
Name of the figure file to be saved. If None and save_path is provided, an auto-generated filename following the pattern ‘{patient_name}_trial{trial_idx}_psd.png’ is used. Default is None.
- subjects_session_trial_mapping : Dict[str, Dict[str, List[int]]], optional
Dictionary mapping patients to sessions and their corresponding trial indices. Format: {patient_id: {session_name: [trial_indices]}}. Used to display session information in the plot title. Default is None.
- fmin : float, optional
Minimum frequency for PSD computation in Hz. Lower bound of the frequency range to analyze. Default is 0.1 Hz.
- fmax : float, optional
Maximum frequency for PSD computation in Hz. Upper bound of the frequency range to analyze. Default is 100 Hz.
- figsize : tuple, optional
Figure size (width, height) in inches. Controls the overall plot dimensions. Default is (18, 8).
- show_fig : bool, optional
Whether to display the figure. Default is True.
Returns: None
The function displays the plot when show_fig is True and optionally saves it to disk.
Raises:
- ValueError
If patient_name is not found in subjects_lfp_data_dict, or if trial_idx exceeds the number of available trials for the specified patient.
Technical Details:
- Uses the multitaper method with adaptive weighting for robust spectral estimation.
- Full normalization ensures proper power spectral density units.
- Frequency resolution depends on the signal length and windowing parameters.
- dB conversion: PSD_dB = 10 * log10(PSD_linear)
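The multitaper estimate and dB conversion listed above can be sketched with SciPy's DPSS (Slepian) tapers. This is an illustrative stand-in, not the function's implementation: it weights all tapers equally instead of adaptively, and the time-half-bandwidth product `nw=4` is an assumed default. The wording of the Technical Details is consistent with MNE's `psd_array_multitaper` (`adaptive=True`, `normalization='full'`), but this section does not confirm which routine is used. `multitaper_psd` and `to_db` are hypothetical names.

```python
import numpy as np
from scipy.signal.windows import dpss


def multitaper_psd(x, sfreq, nw=4.0, fmin=0.1, fmax=100.0):
    """One-sided multitaper PSD of a 1-D signal with equally weighted tapers."""
    x = np.asarray(x, dtype=float)
    x = x - x.mean()  # remove DC before tapering
    n = x.size
    k = int(2 * nw) - 1  # standard choice of K = 2*NW - 1 tapers
    tapers = dpss(n, nw, Kmax=k)  # shape (k, n); each row has unit L2 norm
    spectra = np.fft.rfft(tapers * x, axis=-1)
    psd = (np.abs(spectra) ** 2).mean(axis=0) / sfreq  # two-sided density
    # Fold negative frequencies in; DC and (for even n) Nyquist are not doubled.
    if n % 2 == 0:
        psd[1:-1] *= 2
    else:
        psd[1:] *= 2
    freqs = np.fft.rfftfreq(n, d=1.0 / sfreq)
    keep = (freqs >= fmin) & (freqs <= fmax)
    return freqs[keep], psd[keep]


def to_db(psd):
    """PSD_dB = 10 * log10(PSD_linear)."""
    return 10 * np.log10(psd)
```

With unit-energy tapers and one-sided folding, summing the PSD times the bin width approximately recovers the signal variance, which is a quick check that the output is a true density.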
- static plot_single_patient_trial_spectrogram(subjects_lfp_data_dict, patient_name, trial_idx, sfreq, fix_chs_names=None, save_path=None, fig_name=None, subjects_session_trial_mapping=None, fmin=0.1, fmax=100, channel_idx=0, figsize=(16, 10), show_fig=True)[source]
Plots a spectrogram of the continuous LFP signal from a specific trial of a single patient. The spectrogram is a time-frequency representation showing how power at each frequency changes over time.
- Return type: None
Parameters:
- subjects_lfp_data_dict : Dict[str, List[np.ndarray]]
A dictionary where keys are patient names and values are lists of numpy arrays representing trials. Each numpy array should have shape (num_channels, num_samples).
- patient_name : str
Name of the patient to plot.
- trial_idx : int
Index of the trial to plot (0-based).
- sfreq : float
Sampling frequency of the data in Hz.
- fix_chs_names : List[str], optional
List of channel names corresponding to the channels in the LFP data.
- save_path : str, optional
Path to save the figure. If None, the figure will not be saved.
- fig_name : str, optional
Name of the figure file to be saved. If None, a name is generated automatically.
- subjects_session_trial_mapping : Dict[str, Dict[str, List[int]]], optional
Dictionary mapping patients to sessions and their corresponding trial indices. Format: {patient_id: {session_name: [trial_indices]}}.
- fmin : float, optional
Minimum frequency for spectrogram computation. Default is 0.1 Hz.
- fmax : float, optional
Maximum frequency for spectrogram computation. Default is 100 Hz.
- channel_idx : int, optional
Index of the channel for which the spectrogram is plotted. Default is 0 (first channel).
- figsize : tuple, optional
Figure size (width, height) in inches. Default is (16, 10).
- show_fig : bool, optional
Whether to display the figure. Default is True.
Returns: None
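This section does not say how the spectrogram is computed. As one plausible stand-in, the sketch below applies a short-time Fourier transform via `scipy.signal.spectrogram` to the selected channel and keeps only the fmin to fmax range. The 1-second Hann window with 50% overlap is an assumption, and `trial_spectrogram` is a hypothetical helper name.

```python
import numpy as np
from scipy.signal import spectrogram


def trial_spectrogram(trial, sfreq, channel_idx=0, fmin=0.1, fmax=100.0, nperseg=None):
    """STFT spectrogram (dB) of one channel of a (n_channels, n_samples) trial."""
    x = np.asarray(trial, dtype=float)[channel_idx]
    if nperseg is None:
        nperseg = int(sfreq)  # assumed 1-second windows, i.e. 1 Hz resolution
    freqs, times, sxx = spectrogram(
        x, fs=sfreq, window="hann", nperseg=nperseg,
        noverlap=nperseg // 2, scaling="density",
    )
    keep = (freqs >= fmin) & (freqs <= fmax)
    # Tiny offset avoids log10(0) for empty bins.
    sxx_db = 10 * np.log10(sxx[keep] + np.finfo(float).tiny)
    return freqs[keep], times, sxx_db  # sxx_db has shape (n_freqs, n_times)
```

Longer windows sharpen frequency resolution at the cost of time resolution, so `nperseg` is the main knob to tune.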
- static plot_total_label_counts(patients_epochs, save_path, fig_name)[source]
Plots the total label counts across all trials for all patients.
- static plot_trial_counts(df, save_path, fig_name)[source]
Plots a histogram of the number of trials per subject.
Parameters:
- df (pd.DataFrame): DataFrame containing subject IDs and the number of trials per subject.
- save_path (str): Directory path where the plot will be saved.
- fig_name (str): Name of the file the plot is saved to.
Returns: None
- Return type: None
- static plot_trial_lengths_per_subject_boxplot(subjects_lfp_data_dict, lfp_sfreq, save_path=None, fig_name=None)[source]
Plots a boxplot of trial lengths for each subject with filled colors and a secondary axis for time in seconds.
Parameters:
- subjects_lfp_data_dict (dict): Dictionary with subject IDs as keys and lists of 2D NumPy arrays as values.
- lfp_sfreq (float): Sampling frequency of the LFP data in Hz. Used to convert trial lengths to seconds.
- save_path (str, optional): If specified, the figure is saved to this path.
- fig_name (str, optional): Name of the saved figure file.
Returns: None
- Return type: None
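The seconds axis follows from dividing sample counts by the sampling rate. A minimal sketch, assuming each trial is shaped (n_channels, n_samples) as documented above; `trial_lengths` is a hypothetical name.

```python
import numpy as np


def trial_lengths(subjects_lfp_data_dict, lfp_sfreq):
    """Return {subject: (lengths_in_samples, lengths_in_seconds)}."""
    lengths = {}
    for subject, trials in subjects_lfp_data_dict.items():
        # The sample count is the size of the last (time) axis of each trial.
        samples = np.array([np.asarray(t).shape[-1] for t in trials])
        lengths[subject] = (samples, samples / lfp_sfreq)
    return lengths
```

For the secondary axis, matplotlib's `Axes.secondary_yaxis` accepts a `functions=(forward, inverse)` pair such as `lambda s: s / lfp_sfreq` and `lambda sec: sec * lfp_sfreq`, assuming trial lengths are drawn on the y-axis.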
- static plot_trial_lengths_per_subject_distr(subjects_lfp_data_dict, lfp_sfreq, save_path, fig_name)[source]
Plots the distribution and boxplot of trial lengths from LFP data for each subject.
- Return type: None
Parameters:
- subjects_lfp_data_dict (dict): Dictionary containing LFP data for multiple subjects. Each subject’s data is a list of trials, where each trial is a 2D array.
- lfp_sfreq (float): Sampling frequency of the LFP data in Hz.
- save_path (str): Directory path where the plot images will be saved.
- fig_name (str): Name of the saved figure file.
Returns: None
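The box part of such a plot summarizes the quartiles of trial durations. The sketch below computes them per subject with NumPy's default linear percentile interpolation. `trial_length_summary` is a hypothetical helper, and reporting min/max rather than whisker limits (often placed at 1.5 times the IQR) is a simplification.

```python
import numpy as np


def trial_length_summary(subjects_lfp_data_dict, lfp_sfreq):
    """Per-subject min, quartiles and max of trial durations in seconds."""
    summary = {}
    for subject, trials in subjects_lfp_data_dict.items():
        secs = np.array([np.asarray(t).shape[-1] for t in trials]) / lfp_sfreq
        q1, med, q3 = np.percentile(secs, [25, 50, 75])
        summary[subject] = {
            "min": float(secs.min()),
            "q1": float(q1),
            "median": float(med),
            "q3": float(q3),
            "max": float(secs.max()),
        }
    return summary
```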