Modules

class gaitmod.BaseModel(config_path=None, **kwargs)[source]

Bases: ABC

class MyMeanAbsoluteError(name='my_mae', **kwargs)[source]

Bases: Metric

reset_state()[source]

Reset all of the metric state variables.

This function is called between epochs/steps, when a metric is evaluated during training.

result()[source]

Compute the current metric value.

Returns:

A scalar tensor, or a dictionary of scalar tensors.

update_state(y_true, y_pred, sample_weight=None)[source]

Accumulate statistics for the metric.

class MyMeanSquaredError(name='my_mse', **kwargs)[source]

Bases: Metric

reset_state()[source]

Reset all of the metric state variables.

This function is called between epochs/steps, when a metric is evaluated during training.

result()[source]

Compute the current metric value.

Returns:

A scalar tensor, or a dictionary of scalar tensors.

update_state(y_true, y_pred, sample_weight=None)[source]

Accumulate statistics for the metric.

class R2Score(name='r2_score', **kwargs)[source]

Bases: Metric

reset_state()[source]

Reset all of the metric state variables.

This function is called between epochs/steps, when a metric is evaluated during training.

result()[source]

Compute the current metric value.

Returns:

A scalar tensor, or a dictionary of scalar tensors.

update_state(y_true, y_pred, sample_weight=None)[source]

Accumulate statistics for the metric.
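All three metric classes follow the Keras stateful-metric pattern. update_state() accumulates running sums, result() derives the value from them, and reset_state() zeroes them. The sketch below shows that pattern for R² = 1 − SS_res/SS_tot in plain Python. It is not the gaitmod implementation. The class name StreamingR2, the weighting scheme, and the 0.0 returned for constant targets are all assumptions.

```python
class StreamingR2:
    """Hypothetical plain-Python analogue of a stateful R^2 metric."""

    def __init__(self):
        self.reset_state()

    def reset_state(self):
        # Running (optionally weighted) sums; no per-sample history is kept.
        self.n = 0.0        # sum of weights
        self.sum_y = 0.0    # sum of w * y
        self.sum_y2 = 0.0   # sum of w * y^2
        self.ss_res = 0.0   # sum of w * (y - y_hat)^2

    def update_state(self, y_true, y_pred, sample_weight=None):
        if sample_weight is None:
            sample_weight = [1.0] * len(y_true)
        for y, y_hat, w in zip(y_true, y_pred, sample_weight):
            self.n += w
            self.sum_y += w * y
            self.sum_y2 += w * y * y
            self.ss_res += w * (y - y_hat) ** 2

    def result(self):
        if self.n == 0.0:
            return 0.0
        # Total sum of squares around the weighted mean: sum(w*y^2) - (sum(w*y))^2 / sum(w)
        ss_tot = self.sum_y2 - self.sum_y ** 2 / self.n
        if ss_tot == 0.0:
            return 0.0  # R^2 is undefined for constant targets; 0.0 is an arbitrary choice
        return 1.0 - self.ss_res / ss_tot


metric = StreamingR2()
metric.update_state([1.0, 2.0], [1.0, 2.0])  # batch 1: perfect predictions
metric.update_state([3.0, 4.0], [3.0, 5.0])  # batch 2: one error of 1.0
print(metric.result())
```

Because only sums are stored, the value computed over several batches equals the value computed over all samples at once.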

evaluate(results, fold=None)[source]

Evaluate metrics based on true labels and predictions.

Parameters:
  • results – Dictionary containing the predictions (y_pred) and the true values (y_test or y_true) for the fold. Here y_true holds the ground-truth labels (usually y_test) and y_pred holds the predicted labels or values.

  • fold – Fold number to print in the logs. If None, the fold number is not printed.

initialize_metrics()[source]
static load(model_path, model_type, config_path, **kwargs)[source]

Load a Keras model and wrap it in a concrete subclass of BaseModel.

save(model_path)[source]

Save the underlying Keras model.

train(X, y, train_idx, test_idx, callbacks=None)[source]

Trains a model for a specific fold.

Parameters:
  • X – Input features.

  • y – Target values.

  • train_idx – Training indices.

  • test_idx – Testing indices.

  • callbacks – Callbacks for LSTM training (optional).

Returns:

Dictionary containing predictions and true values for the fold.

class gaitmod.ClassificationLSTMModel(model_type='lstm', config_path=None)[source]

Bases: ClassificationModels

build_model(input_shape)[source]
fit(X_train, y_train, callbacks)[source]
predict(X_test)[source]
class gaitmod.ClassificationModels(model_type='logistic', **kwargs)[source]

Bases: BaseModel

class gaitmod.CustomGridSearchCV(estimator, param_grid, *, scoring=None, n_jobs=None, refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs', error_score=nan, return_train_score=False)[source]

Bases: GridSearchCV

Not used for now

fit(X, y=None, groups=None, **fit_params)[source]

Run fit with all sets of parameters.

Parameters:
  • X (array-like of shape (n_samples, n_features)) – Training vector, where n_samples is the number of samples and n_features is the number of features.

  • y (array-like of shape (n_samples, n_output) or (n_samples,), default=None) – Target relative to X for classification or regression; None for unsupervised learning.

  • groups (array-like of shape (n_samples,), default=None) – Group labels for the samples used while splitting the dataset into train/test set. Only used in conjunction with a “Group” cv instance (e.g., GroupKFold).

  • **fit_params (dict of str -> object) –

    Parameters passed to the fit method of the estimator.

    If a fit parameter is an array-like whose length is equal to num_samples, then it will be split across CV groups along with X and y. For example, the sample_weight parameter is split because len(sample_weight) = len(X).

Returns:

self – Instance of fitted estimator.

Return type:

object
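The splitting rule for **fit_params can be illustrated without scikit-learn. In the sketch below, split_fit_params is a hypothetical helper that belongs to neither gaitmod nor scikit-learn. Any parameter whose length equals the number of samples is indexed with the fold's training indices. Everything else passes through unchanged. scikit-learn applies its own, more thorough version of this rule internally.

```python
from typing import Any, Dict, Sequence


def split_fit_params(fit_params: Dict[str, Any], train_idx: Sequence[int],
                     n_samples: int) -> Dict[str, Any]:
    """Hypothetical helper: slice per-sample fit parameters for one CV fold."""
    fold_params = {}
    for name, value in fit_params.items():
        is_per_sample = (
            hasattr(value, "__len__")
            and not isinstance(value, (str, bytes))
            and len(value) == n_samples
        )
        if is_per_sample:
            # Same rows as X[train_idx] and y[train_idx].
            fold_params[name] = [value[i] for i in train_idx]
        else:
            fold_params[name] = value  # scalars and settings pass through unchanged
    return fold_params


params = {"sample_weight": [0.5, 1.0, 1.0, 2.0], "epochs": 10}
print(split_fit_params(params, train_idx=[0, 3], n_samples=4))
```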

class gaitmod.CustomTrainingLogger(fold=0)[source]

Bases: Callback

on_batch_end(batch, logs=None)[source]

A backwards compatibility alias for on_train_batch_end.

on_epoch_begin(epoch, logs=None)[source]

Called at the start of an epoch.

Subclasses should override for any actions to run. This function should only be called during TRAIN mode.

Parameters:
  • epoch – Integer, index of epoch.

  • logs – Dict. Currently no data is passed to this argument for this method but that may change in the future.

on_epoch_end(epoch, logs=None)[source]

Called at the end of an epoch.

Subclasses should override for any actions to run. This function should only be called during TRAIN mode.

Parameters:
  • epoch – Integer, index of epoch.

  • logs – Dict, metric results for this training epoch, and for the validation epoch if validation is performed. Validation result keys are prefixed with val_. For training epoch, the values of the Model’s metrics are returned. Example: {‘loss’: 0.2, ‘accuracy’: 0.7}.

safe_format(value)[source]
class gaitmod.DataProcessor[source]

Bases: object

static apply_adaptive_filtering_to_epochs(patients_epochs, apply_notch=True, apply_highpass=True, apply_lowpass=False, zero_phase=True, notch_freq=50.0, notch_width=None, highpass_freq=None, lowpass_freq=100.0, filter_order=None, padding_method='reflect', verbose=True)[source]

Apply filtering to STN LFP epochs with explicit separation of frequency and phase filtering.

Accepts pre-determined parameters, or determines any parameter left as None automatically. Call determine_adaptive_filter_params() first to obtain the optimal parameters.

Return type:

Dict[str, EpochsArray]

Parameters:

patients_epochs : Dict[str, mne.EpochsArray]

Dictionary of patient EpochsArray objects to be filtered.

apply_notch : bool, optional

Whether to apply a notch filter for power-line noise (default: True).

apply_highpass : bool, optional

Whether to apply a high-pass filter for drift removal (default: True).

apply_lowpass : bool, optional

Whether to apply a low-pass filter (default: False; not recommended for HCTSA).

zero_phase : bool, optional

Whether to apply zero-phase (forward-backward) filtering instead of standard causal filtering (default: True). True uses sosfiltfilt(), which preserves timing without phase distortion. False uses sosfilt(), which is faster but introduces phase delays.

notch_freq : float, optional

Notch filter center frequency in Hz (default: 50.0).

notch_width : float, optional

Notch filter width in Hz. If None, it is determined automatically.

highpass_freq : float, optional

High-pass cutoff frequency in Hz. If None, it is determined automatically.

lowpass_freq : float, optional

Low-pass cutoff frequency in Hz (default: 100.0).

filter_order : int, optional

Filter order. If None, it is determined automatically.

padding_method : str, optional

Method for edge padding: ‘reflect’, ‘zero’, or ‘constant’ (default: ‘reflect’).

verbose : bool, optional

Whether to print filtering information (default: True).

Returns:

Dict[str, mne.EpochsArray]

Dictionary of filtered EpochsArray objects
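The zero_phase option contrasts forward-only filtering with forward-backward filtering. The method itself uses SciPy's sosfilt() and sosfiltfilt(). The stdlib-only sketch below swaps in a hypothetical one-pole low-pass filter purely to show the effect on an impulse. The causal pass shifts the response's center of mass later in time. Running the same filter forward and then backward gives a response centered on the original impulse. Edge padding is omitted.

```python
from typing import List


def one_pole(x: List[float], a: float) -> List[float]:
    """Causal one-pole low-pass: y[n] = a*x[n] + (1 - a)*y[n-1]."""
    y, prev = [], 0.0
    for sample in x:
        prev = a * sample + (1.0 - a) * prev
        y.append(prev)
    return y


def forward_backward(x: List[float], a: float) -> List[float]:
    """Zero-phase version: filter forward, then filter the reversed output."""
    forward = one_pole(x, a)
    return one_pole(forward[::-1], a)[::-1]


def center_of_mass(y: List[float]) -> float:
    return sum(i * v for i, v in enumerate(y)) / sum(y)


impulse = [0.0] * 201
impulse[100] = 1.0
causal = one_pole(impulse, a=0.5)
zero_phase = forward_backward(impulse, a=0.5)
print(round(center_of_mass(causal), 3), round(center_of_mass(zero_phase), 3))
```

The causal output peaks at the impulse but trails to the right, so its center of mass lands one sample late. The forward-backward output is symmetric around sample 100.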

static check_and_fix_lfp_chs_name(data, fix_chs_name=False)[source]

Check the LFP channel names in the data and optionally fix them.

Parameters:
  • data (Dict[str, Any]) – The dataset containing LFP channel information.

  • fix_chs_name (Optional[Union[bool, List[str]]]) – If True, fix the channel names to a default list. If a list is provided, use it to fix the channel names. Defaults to False.

Return type:

None

static check_event_order(events_kin_samples, events_kin_labels, trial_idx, exclude_label_names=[])[source]

Check whether the event sample indices in the column of events_kin_samples for the given trial index are in ascending order.

Parameters:
  • events_kin_samples (np.ndarray) – A 2D array of shape (n_events, n_trials) containing the sample indices of each event.

  • events_kin_labels (np.ndarray | List) – A 1D array of shape (n_events,) or a list containing the event names.

  • trial_idx (int) – The index of the trial to check.

Raises:

ValueError – If the values in the column are not in ascending order.
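The ordering check can be sketched in plain Python. This is a hypothetical re-implementation that makes three assumptions:

  • events_kin_samples is indexed as [event][trial].

  • Rows whose label appears in exclude_label_names are skipped.

  • Equal consecutive samples are tolerated.

The real method may differ on any of these points, and the event names in the example are made up.

```python
from typing import Sequence


def check_event_order(events_kin_samples: Sequence[Sequence[float]],
                      events_kin_labels: Sequence[str],
                      trial_idx: int,
                      exclude_label_names: Sequence[str] = ()) -> None:
    """Hypothetical sketch: raise if one trial's event samples are out of order."""
    column = [
        (label, row[trial_idx])
        for label, row in zip(events_kin_labels, events_kin_samples)
        if label not in exclude_label_names
    ]
    for (prev_label, prev), (label, cur) in zip(column, column[1:]):
        if cur < prev:  # equal samples are tolerated in this sketch
            raise ValueError(
                f"Trial {trial_idx}: '{label}' ({cur}) occurs before "
                f"'{prev_label}' ({prev})"
            )


samples = [[10, 100], [20, 90], [30, 130]]        # shape (n_events=3, n_trials=2)
labels = ["heel_strike", "mod_start", "mod_end"]  # hypothetical event names
check_event_order(samples, labels, trial_idx=0)                           # passes
check_event_order(samples, labels, 1, exclude_label_names=["mod_start"])  # passes
```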

static clean_data(data, data_type, key_map, new_events_order, sfreq, fix_chs_names=False, verbose=False)[source]

Cleans data by applying renaming, sorting/filtering, and removing NaNs.

Parameters:
  • data (dict) – The dataset to clean.

  • data_type (str) – The type of data being processed.

  • key_map (dict) – Mapping of old keys to new keys.

  • new_events_order (list) – Ordered list of events to keep.

  • sfreq (float) – Sampling frequency for processing.

  • verbose (bool, optional) – Whether to print debug information. Defaults to False.

static convert_lfp_label(labels)[source]

Convert long-form LFP labels to a more descriptive format.

Parameters:

labels (np.ndarray) – Array of long-form LFP labels.

Returns:

List of converted labels in the format ‘LFP_SIDE_DIGITDIGIT’, where SIDE is ‘L’ or ‘R’, and DIGIT represents the numeric part extracted from each original label.

Return type:

list

create_epochs_with_events(events_mod_start, normal_walking_events, gait_modulation_event_id, normal_walking_event_id, epoch_tmin, epoch_tmax, event_dict)[source]

Combines modulation and normal walking events, sorts them by onset time, and creates MNE epochs.

Parameters:
  • lfp_raw (mne.io.Raw) – The raw LFP signal data.

  • events_mod_start (np.ndarray) – Array of modulation event start times.

  • normal_walking_events (np.ndarray) – Array of normal walking events.

  • gait_modulation_event_id (int) – Event ID for gait modulation events.

  • normal_walking_event_id (int) – Event ID for normal walking events.

  • epoch_tmin (float) – Start time for epochs (in seconds).

  • epoch_tmax (float) – End time for epochs (in seconds).

  • event_dict (Dict[str, int]) – Dictionary mapping event class names to event IDs.

Returns:

  • Array containing the combined and sorted events.

  • The MNE Epochs object containing the epoched data for both modulation and normal walking events.

Return type:

Tuple[np.ndarray, mne.Epochs]

static create_events_array(events_kin, sfreq)[source]

Create an MNE-compatible events array from events_KIN data, handling NaN values.

Return type:

Tuple[ndarray, Dict[str, int]]

Parameters:

events_kin : Dict[str, Any]

Dictionary containing events_KIN data with ‘times’ and ‘label’ keys.

sfreq : float

Sampling frequency of the LFP data.

Returns:

mne_events : np.ndarray

MNE-compatible events array (shape: [n_events, 3]). Each row contains [sample_index, 0, event_id].

event_id : Dict[str, int]

Dictionary mapping event labels to unique event IDs.
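The events-array layout described above can be sketched as follows. The function converts event times in seconds to sample indices, drops NaN entries, and assigns IDs. It rests on assumptions that do not come from the source:

  • times is a flat list aligned with label.

  • Sample indices are obtained by rounding time × sfreq.

  • IDs are numbered from 1 in order of first appearance.

```python
import math
from typing import Dict, List, Sequence, Tuple


def create_events_array(times: Sequence[float], labels: Sequence[str],
                        sfreq: float) -> Tuple[List[List[int]], Dict[str, int]]:
    """Hypothetical sketch: build [sample_index, 0, event_id] rows, skipping NaNs."""
    event_id: Dict[str, int] = {}
    events: List[List[int]] = []
    for t, label in zip(times, labels):
        if math.isnan(t):
            continue  # missing event time: no row is created
        if label not in event_id:
            event_id[label] = len(event_id) + 1  # IDs start at 1
        events.append([int(round(t * sfreq)), 0, event_id[label]])
    events.sort(key=lambda row: row[0])  # MNE expects chronologically ordered events
    return events, event_id


events, ids = create_events_array(
    times=[0.5, float("nan"), 1.0, 2.0],
    labels=["mod_start", "mod_end", "mod_end", "mod_start"],
    sfreq=250.0,
)
print(events)
print(ids)
```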

static create_trials(events_kin_samples, data, sfreq, config, subject_id, session_id)[source]

Extracts trial data based on kinematic event samples.

Parameters:
  • events_kin_samples (np.ndarray) – A 2D numpy array containing kinematic event samples with shape (n_events, n_trials).

  • data (np.ndarray) – A 2D numpy array containing measurement data (LFP, EEG, IMU, etc.) with shape (n_channels, n_times).

  • sfreq (float) – Sampling frequency of the LFP data (Hz).

  • config (Dict[str, Any]) – Configuration settings for padding and truncating the data.

  • subject_id (str) – Subject identifier for debugging messages.

  • session_id (str) – Session identifier for debugging messages.

Returns:

A tuple containing:
  • trials_data: List of 2D numpy arrays containing trial data with shape (n_channels, n_times).

  • valid_trials_mask: List of booleans indicating which trials were successfully processed (True) vs skipped (False).

Return type:

Tuple[List[np.ndarray], List[bool]]

static crop_lfp_to_event_times(lfp_data, events_KIN, lfp_sfreq)[source]

Crop the LFP data to match the time span of the provided event times.

Parameters:
  • lfp_data (np.ndarray) – 2D array of LFP data with shape (channels, samples).

  • events_KIN (dict) – Dictionary containing event times.

  • lfp_sfreq (float) – Sampling frequency of the LFP data.

Returns:

Tuple containing:

  • first_non_zero_indices (np.ndarray): Indices of the first non-zero values in each row.

  • cropped_lfp_data (np.ndarray): Cropped LFP data array.

define_normal_walking_events(events_mod_start, gap_sample_length, epoch_sample_length, n_samples)[source]

Defines normal walking events by creating intervals between modulation events and constructing an array of event onsets for normal walking periods.

Parameters:
  • normal_walking_event_id (int) – The event ID assigned to normal walking events.

  • events_mod_start (np.ndarray) – Array of modulation event start times, where each row contains the sample onset of a modulation event.

  • gap_sample_length (int) – Length of the gap in samples to create before and after each modulation event.

  • epoch_sample_length (int) – The length of the epochs in samples for normal walking events.

  • n_samples (int) – Total number of samples in the signal.

Returns:

Array containing the normal walking events. Each row contains three values:
  • The onset of the normal walking event in samples.

  • A dummy value (always zero).

  • The event ID (as provided in normal_walking_event_id).

Return type:

np.ndarray
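One way to realise the interval logic is sketched below. It assumes the following, none of which is taken from the source:

  • Modulation onsets arrive as a sorted list of sample indices.

  • Each allowed interval runs from the previous onset plus gap_sample_length to the next onset minus gap_sample_length, with 0 and n_samples as the outer bounds.

  • Epochs tile each interval without overlap, starting from its left edge.

The real method may bound or place epochs differently.

```python
from typing import List, Sequence


def define_normal_walking_events(events_mod_start: Sequence[int],
                                 gap_sample_length: int,
                                 epoch_sample_length: int,
                                 n_samples: int,
                                 normal_walking_event_id: int) -> List[List[int]]:
    """Hypothetical sketch: tile the gaps between modulation onsets with epochs."""
    # Interval edges: recording start, each onset padded by the gap, recording end.
    starts = [0] + [onset + gap_sample_length for onset in events_mod_start]
    ends = [onset - gap_sample_length for onset in events_mod_start] + [n_samples]
    events = []
    for start, end in zip(starts, ends):
        onset = start
        while onset + epoch_sample_length <= end:  # keep only epochs that fit entirely
            events.append([onset, 0, normal_walking_event_id])
            onset += epoch_sample_length
    return events


print(define_normal_walking_events([1000], gap_sample_length=200,
                                   epoch_sample_length=250, n_samples=2000,
                                   normal_walking_event_id=2))
```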

static determine_adaptive_filter_params(patients_epochs, notch_freq=50.0, verbose=True)[source]

Determine optimal filtering parameters based on epoch characteristics.

This function analyzes the first patient’s epochs to determine optimal filter parameters that will be applied consistently across all patients.

Return type:

Dict[str, Any]

Parameters:

patients_epochs : Dict[str, mne.EpochsArray]

Dictionary of patient EpochsArray objects to analyze.

notch_freq : float, optional

Notch filter center frequency in Hz (default: 50.0).

verbose : bool, optional

Whether to print parameter determination details (default: True).

Returns:

Dict[str, Any]

Dictionary containing the determined filter parameters:
  • ‘filter_order’: int

  • ‘highpass_freq’: float

  • ‘notch_width’: float

  • ‘epoch_duration’: float

  • ‘freq_resolution’: float

  • ‘sfreq’: float

  • ‘nyquist’: float

  • ‘warnings’: List[str]

static generate_ch_locs(ch_names)[source]
static get_trial_event_indices(events_kin_samples, events_kin_labels, trial_idx)[source]

Computes the sample indices of event occurrences relative to the start of a given trial.

Parameters:
  • events_kin_samples (np.ndarray) – A 2D array of shape (n_events, n_trials) containing the sample indices of each event.

  • events_kin_labels (np.ndarray | List) – A 1D array of shape (n_events,) or a list containing the event names.

  • trial_idx (int) – The index of the trial for which to compute event indices.

Returns:

Event indices relative to the start of the trial.

Return type:

np.ndarray (int)

Notes

  • If an event occurs outside the trial boundaries, a random valid index within the trial range is assigned.

  • The returned indices are relative to the trial’s start index.
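The relative-index computation and the out-of-range fallback from the Notes can be sketched as follows. The source does not say how the trial boundaries are defined. This hypothetical sketch assumes the first and last rows of the trial's column mark the trial start and end. It also omits events_kin_labels and takes an optional random.Random so that results are reproducible.

```python
import random
from typing import List, Optional, Sequence


def get_trial_event_indices(events_kin_samples: Sequence[Sequence[int]],
                            trial_idx: int,
                            rng: Optional[random.Random] = None) -> List[int]:
    """Hypothetical sketch: event samples relative to the trial start."""
    rng = rng or random.Random()
    column = [row[trial_idx] for row in events_kin_samples]
    trial_start, trial_end = column[0], column[-1]  # assumed boundary events
    n_trial = trial_end - trial_start + 1
    indices = []
    for sample in column:
        rel = sample - trial_start
        if not 0 <= rel < n_trial:
            # Outside the trial: substitute a random valid index, as in the Notes.
            rel = rng.randrange(n_trial)
        indices.append(rel)
    return indices


samples = [[1000], [1100], [5000], [1500]]  # the third event lies outside its trial
print(get_trial_event_indices(samples, trial_idx=0, rng=random.Random(0)))
```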

np_to_dict()[source]

Converts a numpy array-based data structure to a dictionary-like structure.

Return type:

Dict[str, Any]

Parameters:

data_structure : np.ndarray

Numpy array-based data structure from which to extract metadata and events.

Returns:

dict

A dictionary-like structure containing the extracted data from the input numpy array-based structure.

static pad_data(trials, max_length, padding_value=0, position='end')[source]

Pad all trials to the specified max_length with the given padding value.

static pad_or_truncate(trials, config, target_length=None)[source]

Pads or truncates the given trials to the target length based on the provided configuration.

Parameters:
  • trials (List[np.ndarray]) – List of trial data arrays to be padded or truncated.

  • target_length (Optional[Union[int, float, str]]) – The target length to pad or truncate the trials to. If None, the target length will be determined from the config.

  • config (Dict[str, Any]) – Configuration dictionary containing padding and truncation settings.

Returns:

List of trial data arrays after padding or truncation.

Return type:

List[np.ndarray]
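The padding and truncation step can be sketched with plain nested lists. The sketch treats each trial as a list of channels and each channel as a list of samples. Following the position argument of pad_data and truncate_data, position='end' pads or cuts at the end of the trial. The 'start' alternative is an assumption. Resolving target_length from config, which may be an int, float, or string, is not shown.

```python
from typing import List, Sequence

Trial = List[List[float]]  # channels x samples


def pad_or_truncate(trials: Sequence[Trial], target_length: int,
                    padding_value: float = 0.0, position: str = "end") -> List[Trial]:
    """Hypothetical sketch: bring every channel of every trial to target_length."""
    if position not in ("start", "end"):
        raise ValueError("position must be 'start' or 'end'")
    out = []
    for trial in trials:
        new_trial = []
        for channel in trial:
            n = len(channel)
            if n >= target_length:
                # Truncate: keep the first (position='end') or last target_length samples.
                kept = channel[:target_length] if position == "end" else channel[n - target_length:]
                new_trial.append(list(kept))
            else:
                pad = [padding_value] * (target_length - n)
                new_trial.append(list(channel) + pad if position == "end" else pad + list(channel))
        out.append(new_trial)
    return out


trials = [[[1.0, 2.0, 3.0, 4.0]], [[5.0, 6.0]]]  # two single-channel trials
print(pad_or_truncate(trials, target_length=3))
```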

static process_events_kin(events_kin)[source]

Process kinematic event data.

Parameters:

events_kin (Dict[str, Any]) – Dictionary containing kinematic event data.

Return type:

None

static process_events_steps(events_steps)[source]

Process step event data.

Parameters:

events_steps (Dict[str, Any]) – Dictionary containing step event data.

Return type:

None

static process_lfp_data(data, n_sessions, lfp_sfreq, event_of_interest, mod_start_event_id, normal_walking_event_id, gap_sample_length, epoch_sample_length, epoch_tmin, epoch_tmax, epoch_duration, event_dict, info, reject_criteria, config, verbose=False)[source]
static process_trials_and_events(data, data_type, sfreq, config, verbose=False)[source]

Prepares data and event indices for each subject across multiple sessions.

This function processes data such as LFP, IMU, and EEG by extracting trials from kinematic events and storing them in three dictionaries:
  • subjects_data_dict: A dictionary where each subject’s trials (as 2D NumPy arrays) are stored. Each trial array has the shape (n_channels x n_times), where n_channels is the number of channels or sensors.

  • subjects_event_sample_idx_dict: A dictionary where each subject’s event indices across trials are stored. The array shape is (n_trials x n_events). Each row corresponds to a specific trial index and each column to a specific event label’s index.

  • subjects_session_trial_mapping: A dictionary where each subject’s trials are mapped to their originating sessions. Format: {patient_id: {session_name: [trial_indices]}}

Parameters:
  • data (Dict[str, Dict[str, dict]]) – Nested dictionary where the outer key is the subject name, and the inner dictionary contains session data for each subject.

  • data_type (str) – Type of data to be processed. Must be one of the following: ‘data_acc’, ‘data_EEG’, ‘data_EMG’, ‘data_giro’, ‘data_LFP’.

  • sfreq (float) – Sampling frequency of the specified data (in Hz).

  • config (dict) – Configuration dictionary used in the data processing.

  • verbose (bool, optional) – Whether to print processing details for each subject and session. Defaults to False.

Returns:

  • subjects_data_dict: Dictionary of subjects with trials as 2D NumPy arrays.

  • subjects_event_sample_idx_dict: Dictionary of subjects with event indices across trials.

  • subjects_session_trial_mapping: Dictionary mapping each subject’s trials to their originating sessions.

Return type:

Tuple[Dict[str, List[np.ndarray]], Dict[str, np.ndarray], Dict[str, Dict[str, List[int]]]]
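The format of subjects_session_trial_mapping follows from how trials are numbered. One plausible reading is that all sessions of a subject are concatenated into one trial list, and each session records which positions it contributed. The sketch below shows that reading. It is an assumption rather than the documented algorithm, and it takes per-session trial counts (made up here) instead of handling skipped trials.

```python
from typing import Dict, List


def build_session_trial_mapping(
        trials_per_session: Dict[str, Dict[str, int]]) -> Dict[str, Dict[str, List[int]]]:
    """Hypothetical sketch: {patient_id: {session_name: [trial_indices]}}."""
    mapping: Dict[str, Dict[str, List[int]]] = {}
    for patient_id, sessions in trials_per_session.items():
        next_idx = 0  # trial indices run across all sessions of one patient
        mapping[patient_id] = {}
        for session_name, n_trials in sessions.items():
            mapping[patient_id][session_name] = list(range(next_idx, next_idx + n_trials))
            next_idx += n_trials
    return mapping


print(build_session_trial_mapping({"PW_01": {"s1": 2, "s2": 3}}))  # hypothetical IDs
```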

static remove_nan_events(data, sfreq)[source]

Remove NaN values from event times and generate corresponding sample indices.

Parameters:
  • data (Dict[str, Any]) – Dictionary containing event times and labels. Expected keys are ‘times’ (numpy array of event times) and ‘label’ (list of event labels).

  • sfreq (float) – Sampling frequency.

Returns:

Tuple[np.ndarray, np.ndarray] containing:

  • events_kin_times_valid (np.ndarray): Array of valid event times with NaN values removed.

  • events_kin_samples_valid (np.ndarray): Array of valid event times converted to sample indices.

static remove_trials_with_short_labels(patients_epochs, subjects_lfp_data_dict=None, subjects_event_idx_dict=None, min_epochs_per_class=1, verbose=True)[source]

Removes trials that contain epochs from only one class (unbalanced trials) from ALL related data structures.

This method ensures consistency across patients_epochs, subjects_lfp_data_dict, and subjects_event_idx_dict by removing the same trials from all data structures simultaneously.

Return type:

Tuple[Dict[str, EpochsArray], Optional[Dict[str, List[ndarray]]], Optional[Dict[str, List[ndarray]]]]

Parameters:

patients_epochs : Dict[str, mne.EpochsArray]

Dictionary of patient EpochsArray objects containing segmented trials.

subjects_lfp_data_dict : Dict[str, List[np.ndarray]], optional

Dictionary of raw LFP trial data per subject (filtered if provided).

subjects_event_idx_dict : Dict[str, List[np.ndarray]], optional

Dictionary of event indices per trial per subject (filtered if provided).

min_epochs_per_class : int, optional

Minimum number of epochs required for each class within a trial (default: 1).

verbose : bool, optional

Whether to print filtering details (default: True).

Returns:

Tuple containing:

  • Dict[str, mne.EpochsArray]: Filtered patients_epochs

  • Dict[str, List[np.ndarray]]: Filtered subjects_lfp_data_dict (or original if None provided)

  • Dict[str, List[np.ndarray]]: Filtered subjects_event_idx_dict (or original if None provided)
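The filtering rule and the "same trials from all structures" guarantee can be sketched with plain lists. The sketch rests on several assumptions:

  • Each epoch's trial ID and class label are available as two parallel lists. In the method they come from the EpochsArray, whose events rows plot_all_patients_trials documents as [start_time_samples, trial_id, class_label].

  • The classes are 0 (normal) and 1 (modulation).

  • Trial IDs are positions in the per-trial lists.

Dropping the corresponding epochs from the EpochsArray itself is not shown.

```python
from collections import Counter
from typing import Dict, List, Sequence, Set


def balanced_trial_ids(trial_ids: Sequence[int], labels: Sequence[int],
                       classes: Sequence[int] = (0, 1),
                       min_epochs_per_class: int = 1) -> Set[int]:
    """Hypothetical sketch: trials with >= min_epochs_per_class epochs of every class."""
    counts: Dict[int, Counter] = {}
    for trial, label in zip(trial_ids, labels):
        counts.setdefault(trial, Counter())[label] += 1
    return {
        trial for trial, c in counts.items()
        if all(c[cls] >= min_epochs_per_class for cls in classes)
    }


def filter_by_trials(items: List, keep: Set[int]) -> List:
    """Apply the same trial selection to any per-trial list (LFP data, event indices)."""
    return [item for idx, item in enumerate(items) if idx in keep]


epoch_trials = [0, 0, 1, 1, 2, 2, 2]
epoch_labels = [0, 1, 0, 0, 0, 1, 1]          # trial 1 has only class 0
keep = balanced_trial_ids(epoch_trials, epoch_labels)
lfp_trials = ["lfp_t0", "lfp_t1", "lfp_t2"]   # stand-ins for per-trial arrays
print(sorted(keep), filter_by_trials(lfp_trials, keep))
```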

static rename_keys(data, key_map)[source]

Recursively renames keys in a dictionary or list in place.

Parameters:
  • data (dict, list, or any data structure) – The data to process.

  • key_map (dict) – A dictionary mapping old key names to new key names.

Returns:

Modifies the input data in place.

Return type:

None
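A minimal sketch of recursive in-place key renaming consistent with the description above. It is illustrative, not the gaitmod source. Each dictionary is rebuilt and then refilled in place, so the same object is modified. Lists are traversed element by element. In this sketch, if two keys map to the same new name, the later one wins.

```python
from typing import Any, Dict


def rename_keys(data: Any, key_map: Dict[str, str]) -> None:
    """Hypothetical sketch: rename dict keys everywhere in a nested structure."""
    if isinstance(data, dict):
        renamed = {}
        for key, value in data.items():
            rename_keys(value, key_map)  # recurse first; only nested objects change
            renamed[key_map.get(key, key)] = value
        data.clear()          # keep the same dict object (in-place semantics)
        data.update(renamed)
    elif isinstance(data, list):
        for item in data:
            rename_keys(item, key_map)
    # Other types (numbers, strings, arrays) are left untouched.


record = {"LFP": {"sfreq": 250}, "events": [{"lbl": "mod_start"}]}
rename_keys(record, {"LFP": "data_LFP", "lbl": "label"})
print(record)
```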

static rename_lfp_channels(labels)[source]

Convert long-form LFP channels to a more descriptive format and ensure uniqueness.

Parameters:

labels (np.ndarray) – Array of long-form LFP labels.

Returns:

List of converted labels in the format ‘LFP_SIDE_DIGITDIGIT’, where SIDE is ‘L’ or ‘R’, and DIGIT represents the numeric part extracted from each original channel.

Return type:

list

static segment_and_label_trials(trials, subjects_event_idx_dict, ch_names, sfreq=250, window_size=0.5, overlap=0.5, expand_transition=0.0, discard_ambiguous=False, mod_start_idx=2, mod_end_idx=6, event_dict=None)[source]

Segments LFP trials into overlapping windows, assigns labels (0: normal, 1: modulation), and stores the results in an MNE Epochs object.

Parameters:
  • trials (list of np.ndarray) – List of LFP trials, each with shape (n_channels, n_samples).

  • subjects_event_idx_dict (list of np.ndarray) – List of event indices for each trial.

  • ch_names (list of str) – List of channel names.

  • sfreq (int) – Sampling frequency of the signals.

  • window_size (float) – Size of each segment in seconds.

  • overlap (float) – Overlap fraction between consecutive windows.

  • expand_transition (float) – Amount of time (seconds) to expand mod_start/mod_end.

  • discard_ambiguous (bool) – Whether to remove windows that overlap both states.

  • mod_start_idx (int) – Index for modulation start event.

  • mod_end_idx (int) – Index for modulation end event.

  • event_dict (dict[str, int]) – Dictionary mapping class names to labels. Must contain exactly 2 classes for normal walking and modulation phases.

Returns:

MNE Epochs object containing the segmented data.

Return type:

mne.epochs.EpochsArray
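The windowing arithmetic can be sketched for a single trial. The labeling rules below are assumptions, not stated in the source:

  • The step between windows is window_size × (1 − overlap).

  • A window lying entirely inside [mod_start − expand_transition, mod_end + expand_transition) is labeled 1, and one lying entirely outside is labeled 0.

  • A window straddling a boundary is "ambiguous". It is dropped when discard_ambiguous is True, and otherwise labeled by majority overlap.

The real method's handling of ambiguous windows may differ.

```python
from typing import List, Tuple


def segment_windows(n_samples: int, sfreq: float, window_size: float,
                    overlap: float, mod_start: int, mod_end: int,
                    expand_transition: float = 0.0,
                    discard_ambiguous: bool = False) -> List[Tuple[int, int]]:
    """Hypothetical sketch: return (window_start_sample, label) pairs for one trial."""
    win = int(round(window_size * sfreq))
    step = max(1, int(round(win * (1.0 - overlap))))
    pad = int(round(expand_transition * sfreq))
    lo, hi = mod_start - pad, mod_end + pad  # modulation interval [lo, hi)
    windows = []
    for start in range(0, n_samples - win + 1, step):
        end = start + win
        inside = max(0, min(end, hi) - max(start, lo))  # samples inside modulation
        if 0 < inside < win:  # window straddles a transition
            if discard_ambiguous:
                continue
            label = 1 if inside * 2 >= win else 0  # majority vote (assumption)
        else:
            label = 1 if inside == win else 0
        windows.append((start, label))
    return windows


print(segment_windows(n_samples=1024, sfreq=256, window_size=0.5, overlap=0.5,
                      mod_start=384, mod_end=768))
```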

static sort_and_filter_events(data, new_order)[source]

Reorder events in the ‘events_KIN’ key of the provided data based on a new order. Can also be used to filter out events by excluding those not in the new order.

Parameters:
  • data (Dict[str, Dict[str, dict]]) – The input data containing events to be reordered and filtered.

  • new_order (List[str]) – The desired order of event labels. Events not in this list will be filtered out.

Raises:

KeyError – If ‘events_KIN’ is not found in any data_type.

Return type:

Dict[str, Any]
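Reordering and filtering events by label can be sketched as below. The sketch makes these assumptions:

  • events_KIN holds a ‘label’ list and a ‘times’ table with one row per label, following the key names documented for create_events_array.

  • Labels missing from the data are skipped.

  • A new dictionary is returned rather than the data being modified.

It also works on a single events dictionary instead of the nested per-data_type structure the method accepts.

```python
from typing import Any, Dict, List


def sort_and_filter_events(events_kin: Dict[str, Any],
                           new_order: List[str]) -> Dict[str, Any]:
    """Hypothetical sketch: reorder rows of 'times' to follow new_order."""
    row_of = {label: i for i, label in enumerate(events_kin["label"])}
    kept = [label for label in new_order if label in row_of]  # drops unlisted labels
    return {
        "label": kept,
        "times": [events_kin["times"][row_of[label]] for label in kept],
    }


events = {"label": ["mod_end", "heel_strike", "mod_start"],
          "times": [[2.0, 5.5], [0.1, 3.2], [1.0, 4.0]]}  # (n_events, n_trials)
print(sort_and_filter_events(events, ["mod_start", "mod_end"]))
```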

trim_data(events, sfreq, threshold=1e-06)[source]

Trims the beginning of the LFP data by removing leading samples where the signal contains only zero or NaN values, and shifts the event onsets accordingly. Only this leading segment without recorded signal is removed; the rest of the data is kept unchanged.

Parameters:
  • lfp_data (np.ndarray) – 2D array of LFP data with shape (n_channels, n_samples).

  • events (np.ndarray) – 2D array of events with shape (n_events, 3), where the first column holds the onsets (in samples).

  • sfreq (float) – Sampling frequency of the LFP data.

  • threshold (float, optional) – Value below which the data is considered as “no recorded data”. Defaults to 1e-6.

Returns:

Tuple containing:
  • Trimmed LFP data (2D array with shape (n_channels, trimmed_n_samples)).

  • Adjusted events (2D array with shape (n_events, 3)).

Return type:

Tuple[np.ndarray, np.ndarray]

Prints:
  • Number of samples removed.

  • Number of seconds removed.

  • Number of samples shifted for the onsets.
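The trimming logic can be sketched as follows. Two assumptions apply:

  • A sample counts as recorded when at least one channel has a non-NaN value whose magnitude exceeds threshold.

  • Event onsets sit in the first column of each [onset, 0, event_id] row, as in the arrays returned by create_events_array.

```python
import math
from typing import List, Sequence, Tuple


def trim_data(lfp_data: Sequence[Sequence[float]], events: Sequence[Sequence[int]],
              sfreq: float, threshold: float = 1e-6) -> Tuple[List[List[float]], List[List[int]]]:
    """Hypothetical sketch: drop leading samples with no recorded signal."""
    n_samples = len(lfp_data[0])
    first = n_samples  # default: nothing recorded at all
    for i in range(n_samples):
        if any(not math.isnan(ch[i]) and abs(ch[i]) > threshold for ch in lfp_data):
            first = i
            break
    trimmed = [list(ch[first:]) for ch in lfp_data]
    shifted = [[onset - first, mid, eid] for onset, mid, eid in events]
    print(f"Removed {first} samples ({first / sfreq:.3f} s); onsets shifted by {first}")
    return trimmed, shifted


data = [[0.0, float("nan"), 0.0, 0.2, 0.4],
        [0.0, 0.0, 0.0, -0.1, 0.3]]
events = [[3, 0, 1], [4, 0, 2]]
trimmed, shifted = trim_data(data, events, sfreq=250.0)
print(trimmed, shifted)
```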

static truncate_data(trials, target_length, position='end')[source]

Truncate all trials to the specified target_length.

class gaitmod.FeatureExtractor[source]

Bases: object

static extract_band_power(epochs, freq_bands)[source]

Extracts band power features from each channel of the given MNE epochs without averaging across channels.

Parameters:
  • epochs – mne.Epochs object containing the LFP data.

  • freq_bands – Dictionary where keys are the band names and values are (low_freq, high_freq) tuples.

Returns:
  • band_power_dict – A dictionary where each key corresponds to a frequency band and each value is a list of features (one for each epoch and channel) for that band.
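Band power per epoch and channel can be computed by averaging the PSD over the frequency bins inside each band. The sketch below assumes the PSD is already available, for example from MNE's Epochs.compute_psd(). It also assumes the PSD is given as nested lists indexed [epoch][channel][frequency], that band edges are inclusive, and that bins are combined by their mean. Whether gaitmod averages or integrates, and how it treats band edges, is not stated.

```python
from typing import Dict, List, Sequence, Tuple


def band_power(psd: Sequence[Sequence[Sequence[float]]], freqs: Sequence[float],
               freq_bands: Dict[str, Tuple[float, float]]) -> Dict[str, List[float]]:
    """Hypothetical sketch: mean PSD per band, one value per (epoch, channel)."""
    result: Dict[str, List[float]] = {}
    for band, (low, high) in freq_bands.items():
        idx = [i for i, f in enumerate(freqs) if low <= f <= high]  # inclusive edges
        if not idx:
            raise ValueError(f"No frequency bins fall inside band {band!r}")
        values = []
        for epoch in psd:
            for channel in epoch:
                values.append(sum(channel[i] for i in idx) / len(idx))
        result[band] = values
    return result


freqs = [4.0, 8.0, 12.0, 16.0, 20.0]
psd = [[[1.0, 2.0, 3.0, 4.0, 5.0],     # epoch 0, channel 0
        [2.0, 2.0, 2.0, 2.0, 2.0]]]    # epoch 0, channel 1
print(band_power(psd, freqs, {"alpha": (8.0, 12.0), "beta": (13.0, 30.0)}))
```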

static extract_band_psd(epochs, freq_bands)[source]

Extracts PSD features for specified frequency bands from the given MNE epochs without averaging across channels or frequencies.

Parameters:
  • epochs – mne.Epochs object containing the LFP data.

  • freq_bands – Dictionary where keys are the band names and values are (low_freq, high_freq) tuples.

Returns:
  • psd_dict – A dictionary where each key corresponds to a frequency band and each value is an array of shape (n_epochs, n_channels, n_frequencies) holding the raw PSD values for that band.

static extract_epoched_stat_features(epochs, methods=['mean', 'std', 'median'])[source]
static extract_psd_and_band_power(epochs, freq_bands, fmin, fmax)[source]

Extracts power spectral density (PSD) and band power features from each channel of the given MNE epochs (without averaging across channels or epochs).

Parameters:
  • epochs – mne.Epochs object containing the LFP data.

  • freq_bands – Dictionary where keys are the band names and values are (low_freq, high_freq) tuples.

Returns:
  • psd_array – A NumPy array of shape (n_epochs, n_channels, n_frequencies) containing the raw PSD values.

  • band_power_array – A NumPy array of shape (n_epochs, n_channels, n_bands) containing the mean band power for each frequency band, per epoch and channel.

static extract_windowed_stat_features(trials, methods=['mean'], window_size=100, step_size=50, verbose=False)[source]
static flatten_features(psds, band_power)[source]

Flatten the features and combine PSD and band power.

Parameters:
  • psds – 3D array of PSD values.

  • band_power – 3D array of band power values.

Returns:

Combined flattened features as a 2D array.
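A sketch of the flattening step. It assumes both inputs are indexed [epoch][channel][...] and that each epoch's PSD values come first, followed by its band powers. The exact ordering used in gaitmod is not documented.

```python
from typing import List, Sequence


def flatten_features(psds: Sequence[Sequence[Sequence[float]]],
                     band_power: Sequence[Sequence[Sequence[float]]]) -> List[List[float]]:
    """Hypothetical sketch: one flat feature row per epoch."""
    if len(psds) != len(band_power):
        raise ValueError("psds and band_power must have the same number of epochs")
    rows = []
    for epoch_psd, epoch_bp in zip(psds, band_power):
        row = [v for channel in epoch_psd for v in channel]   # channels x freqs
        row += [v for channel in epoch_bp for v in channel]   # channels x bands
        rows.append(row)
    return rows


psds = [[[1.0, 2.0], [3.0, 4.0]]]   # 1 epoch, 2 channels, 2 frequencies
band_power = [[[10.0], [20.0]]]     # 1 epoch, 2 channels, 1 band
print(flatten_features(psds, band_power))
```

With NumPy arrays, the same layout is np.concatenate([psds.reshape(len(psds), -1), band_power.reshape(len(band_power), -1)], axis=1).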

static reshape_lfp_data(lfp_data, mode='flat_time')[source]

Reshape the LFP data based on the specified mode.

Parameters:

lfp_data : np.ndarray

The input LFP data with shape (trials, channels, times).

mode : str

The reshaping mode. Options:
  • “flat_time”: Reshape to (trials * times, channels).

  • “flat_channel”: Reshape to (trials, times * channels).

Returns:

reshaped_data : np.ndarray

The reshaped data based on the selected mode.
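The two modes can be pictured with nested lists. The docstring specifies only the output shapes, so the element order here is an assumption:

  • In "flat_time", each row is one time point of one trial, trial by trial.

  • In "flat_channel", each trial row lists all channels at time 0, then all channels at time 1, and so on.

```python
from typing import List, Sequence


def reshape_lfp_data(lfp_data: Sequence[Sequence[Sequence[float]]],
                     mode: str = "flat_time") -> List[List[float]]:
    """Hypothetical sketch on (trials, channels, times) nested lists."""
    if mode == "flat_time":
        # (trials * times, channels): one row per time point of each trial.
        return [[trial[ch][t] for ch in range(len(trial))]
                for trial in lfp_data for t in range(len(trial[0]))]
    if mode == "flat_channel":
        # (trials, times * channels): time-major concatenation within each trial.
        return [[trial[ch][t] for t in range(len(trial[0])) for ch in range(len(trial))]
                for trial in lfp_data]
    raise ValueError(f"Unknown mode: {mode!r}")


data = [[[1, 2, 3], [4, 5, 6]]]  # 1 trial, 2 channels, 3 time points
print(reshape_lfp_data(data, "flat_time"))
print(reshape_lfp_data(data, "flat_channel"))
```

With NumPy, these orders correspond to lfp_data.transpose(0, 2, 1).reshape(-1, n_channels) and lfp_data.transpose(0, 2, 1).reshape(n_trials, -1).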

class gaitmod.FeatureExtractor2(sfreq, features_config)[source]

Bases: object

extract_features(epochs, feature_handling='flatten_chs')[source]
extract_features_with_labels(epochs, feature_handling='flatten_chs')[source]

Extracts features and labels from an MNE Epochs object.

Parameters:
  • epochs (mne.Epochs) – The epochs object containing LFP data and labels.

  • feature_handling (str) – The strategy to handle multi-channel features.

Returns:
  • X (np.ndarray) – Extracted features of shape (n_epochs, n_features).

  • y (np.ndarray) – Corresponding labels of shape (n_epochs,).

select_feature(feature_matrix, feature_name, feature_handling='flatten_chs')[source]

Extracts a specific feature slice from the feature matrix using the feature name.

Parameters:
  • feature_matrix (np.ndarray) – The full feature matrix returned by extract_features.

  • feature_name (str) – The name of the feature to extract.

  • feature_handling (str) – How the features are handled across channels. Options: ‘flatten_chs’, ‘average_chs’, ‘separate_chs’.

Returns:

np.ndarray – The sliced feature matrix for the specified feature name.

class gaitmod.LSTMClassifier(input_shape, hidden_dims=[50], activations=['tanh'], recurrent_activations=['sigmoid'], dropout=0.2, dense_units=1, dense_activation='sigmoid', optimizer='adam', lr=0.001, patience=5, epochs=10, batch_size=32, threshold=0.5, loss='binary_crossentropy', callbacks=None, mask_vals=(0.0, 2))[source]

Bases: BaseEstimator, ClassifierMixin

build_model()[source]
calculate_class_weights(y)[source]
fit(X, y)[source]
static lr_schedule(epoch, lr)[source]
static masked_accuracy_score(y_true, y_pred)[source]
static masked_classification_report(y_true, y_pred, target_names=None, digits=4)[source]
static masked_confusion_matrix(y_true, y_pred)[source]
static masked_f1_score(y_true, y_pred)[source]
masked_loss_binary_crossentropy(y_true, y_pred)[source]
static masked_precision_score(y_true, y_pred)[source]
static masked_recall_score(y_true, y_pred)[source]
static masked_roc_auc_score(y_true, y_pred)[source]
predict(X)[source]
predict_proba(X)[source]
summary()[source]
class gaitmod.LinearRegressionModel(config_path=None, **kwargs)[source]

Bases: RegressionModels

fit(X_train, y_train)[source]
get_coefficients()[source]
get_intercept()[source]
predict(X_test)[source]
class gaitmod.MatFileReader(directory)[source]

Bases: object

read_data()[source]

Reads data from all .mat files in the directory.

Returns:

List of dictionaries, each containing the data from a .mat file.

Return type:

List[Dict[str, Any]]

class gaitmod.RegressionLSTMModel(config_path=None, **kwargs)[source]

Bases: RegressionModels

build_model(input_shape)[source]
data_generator(X, y, batch_size=32)[source]
fit(X_train, y_train, callbacks)[source]
predict(X_test)[source]
class gaitmod.RegressionModels(config_path=None, **kwargs)[source]

Bases: BaseModel

class gaitmod.Visualise[source]

Bases: object

static plot_all_patients_trials(subjects_lfp_data_dict, patients_epochs, sfreq, save_path, fig_name, subjects_session_trial_mapping=None, max_display_trials=None, window_size=0.5, overlap=0.5, expand_transition=0.0, sharex=True, sharey=True)[source]

Plots LFP data for all patients and trials with epoch-level temporal coloring. This function creates a grid of subplots where each column represents a patient and each row represents a trial. Each subplot displays the LFP data for a specific trial of a specific patient, with channels offset for clarity. Different time segments (epochs) within each continuous signal are colored based on their individual class labels. Trials belonging to the same session are visually grouped with colored boxes.

Parameters:

subjects_lfp_data_dictDict[str, List[np.ndarray]]

A dictionary where keys are patient names and values are lists of numpy arrays representing trials. Each numpy array should have shape (num_channels, num_samples).

patients_epochsDict[str, mne.Epochs]

A dictionary where keys are patient names and values are MNE Epochs objects containing epoch-level data. The epochs.events array contains [start_time_samples, trial_id, class_label] for each epoch.

sfreqfloat

Sampling frequency of the data.

save_pathstr

Path to save the figure. If empty, the figure will not be saved.

fig_namestr

Name of the figure file to be saved.

subjects_session_trial_mappingDict[str, Dict[str, List[int]]], optional

Dictionary mapping patients to sessions and their corresponding trial indices. Format: {patient_id: {session_name: [trial_indices]}}

max_display_trialsint, optional

Maximum number of trials to display per patient. If None (default), all trials will be displayed.

sharexbool, optional

Whether to share the x-axis among subplots. Default is True.

shareybool, optional

Whether to share the y-axis among subplots. Default is True.

Returns:

None
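A minimal sketch of the inputs in the shapes documented above, plus a hypothetical helper (`epoch_segments`, not part of `gaitmod`) showing how the rows of `epochs.events` could be turned into colored time segments for one trial. It assumes `window_size` is the epoch length in seconds; that reading of the parameter is an assumption, not documented behavior.

```python
import numpy as np

sfreq = 250.0
rng = np.random.default_rng(0)

# Toy data: one patient, two trials, each shaped (num_channels, num_samples).
subjects_lfp_data_dict = {
    "PW_HK59": [rng.standard_normal((6, 2500)), rng.standard_normal((6, 1750))],
}

# Session grouping: {patient_id: {session_name: [trial_indices]}}.
subjects_session_trial_mapping = {"PW_HK59": {"session_1": [0, 1]}}

# Rows of epochs.events as documented: [start_time_samples, trial_id, class_label].
events = np.array([
    [0, 0, 0],
    [125, 0, 1],
    [250, 0, 1],
    [0, 1, 0],
])


def epoch_segments(events, trial_id, sfreq, window_size):
    """Return (start_s, end_s, class_label) for each epoch of one trial.

    Hypothetical helper; assumes window_size is the epoch length in seconds.
    """
    rows = events[events[:, 1] == trial_id]
    return [(float(s) / sfreq, float(s) / sfreq + window_size, int(lbl))
            for s, _, lbl in rows]


segments = epoch_segments(events, trial_id=0, sfreq=sfreq, window_size=0.5)
print(segments)  # [(0.0, 0.5, 0), (0.5, 1.0, 1), (1.0, 1.5, 1)]
```

Each segment could then be shaded in a color chosen by its class label.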

static plot_all_trial_lengths(subjects_lfp_data_dict, lfp_sfreq, save_path, fig_name)[source]

Plots the distribution and boxplot of trial lengths from LFP data.

Parameters:
  • subjects_lfp_data_dict (dict) – Dictionary containing LFP data for multiple subjects. Each subject’s data is a list of trials, where each trial is a 2D array.

  • lfp_sfreq (float) – Sampling frequency of the LFP data.

  • save_path (str) – Directory path where the plot image will be saved.

  • fig_name (str) – Name of the figure file to be saved (without extension).

Returns: None

Return type:

None
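The lengths being plotted can be sketched as follows, assuming the function converts each trial's sample count to seconds with `lfp_sfreq` and pools trials across all subjects. `trial_lengths_seconds` is a hypothetical helper, not part of `gaitmod`.

```python
import numpy as np


def trial_lengths_seconds(subjects_lfp_data_dict, lfp_sfreq):
    """Pool the length in seconds of every trial across all subjects."""
    return np.array([
        trial.shape[1] / lfp_sfreq  # trials are (num_channels, num_samples)
        for trials in subjects_lfp_data_dict.values()
        for trial in trials
    ])


data = {
    "sub-01": [np.zeros((6, 5000)), np.zeros((6, 2500))],
    "sub-02": [np.zeros((6, 7500))],
}
lengths = trial_lengths_seconds(data, lfp_sfreq=250.0)
print(lengths)  # [20. 10. 30.]
print(np.median(lengths), np.percentile(lengths, [25, 75]))
```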

static plot_epoch_psd_comparison(epochs_orig, epochs_filt, epoch_idx, fix_chs_names=None, save_path=None, fig_name='epoch_psd_comparison', show_fig=True)[source]

Plot PSD comparison for a single epoch (original vs filtered) for all channels.

static plot_epoch_spectrogram_and_psd(signal_orig, signal_filt, patient_name='Patient', epoch_idx=0, channel_idx=0, fmin=0.1, fmax=50, psd_fmax=100, nperseg=None, noverlap=None, figsize=(16, 12), save_path=None, fig_name=None, show_fig=True)[source]

static plot_epochs_with_events(patients_epochs, subject_id, window_size, sfreq, event_names, subjects_event_idx_dict, show_fig=True, save_path=None, fig_name='epochs_with_events{subject_id}.png')[source]

Plot the epochs for a given subject with event markers for mod_start and mod_end.

Parameters:
  • patients_epochs (Dict[str, mne.EpochsArray]) – A dictionary of patient IDs mapped to MNE EpochsArray objects.

  • subject_id (str) – The ID of the subject to plot.

  • window_size (float) – The size of the window in seconds.

  • sfreq (int) – The sampling frequency of the data.

  • event_names (list[str]) – The list of event names in the correct order.

  • subjects_event_idx_dict (Dict[str, Dict[int, List[int]]]) – A dictionary mapping subject IDs to dictionaries of trial IDs mapped to event indices.

  • show_fig (bool, optional) – Flag to display the plot, by default True.

  • save_path (Optional[str], optional) – Path to save the plot, by default None.

  • fig_name (str, optional) – Name of the file to save the plot, by default ‘epochs_with_events{subject_id}.png’.

Return type:

None
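A sketch of the `subjects_event_idx_dict` layout, with `event_names` giving the order of the indices stored for each trial. The helper `event_times` is hypothetical and assumes the indices are sample positions within the trial, so dividing by `sfreq` gives marker times in seconds.

```python
event_names = ["mod_start", "mod_end"]

# Hypothetical layout: {subject_id: {trial_id: [indices in event_names order]}}.
subjects_event_idx_dict = {"PW_HK59": {0: [500, 1750], 1: [250, 1000]}}


def event_times(subjects_event_idx_dict, subject_id, event_names, sfreq):
    """Map each trial's event indices to {event_name: time_in_seconds}.

    Assumes the stored indices are sample positions within the trial.
    """
    out = {}
    for trial_id, idxs in subjects_event_idx_dict[subject_id].items():
        if len(idxs) != len(event_names):
            raise ValueError(f"trial {trial_id}: expected {len(event_names)} indices")
        out[trial_id] = {name: idx / sfreq for name, idx in zip(event_names, idxs)}
    return out


times = event_times(subjects_event_idx_dict, "PW_HK59", event_names, sfreq=250)
print(times[0])  # {'mod_start': 2.0, 'mod_end': 7.0}
```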

static plot_event_class_histogram(events, event_dict, n_sessions, show_fig=True, save_fig=True, file_name='event_class_histogram.png')[source]

Creates a histogram to plot the number of onsets for each class of the event array, with event IDs mapped to descriptive labels.

Parameters:
  • events (np.ndarray) – Array containing event data with at least three columns: [time, session_id, event_id].

  • event_dict (Dict[int, str]) – A dictionary mapping event IDs to descriptive labels.

  • n_sessions (int) – The number of sessions to plot.

  • show_fig (bool) – Flag to show the figure or not.

  • save_fig (bool) – Flag to save the figure or not.

  • file_name (str) – The filename for saving the figure.

Return type:

None
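The per-class counts behind the histogram can be sketched with NumPy from the documented `[time, session_id, event_id]` columns. The labels in `event_dict` are placeholders, and taking the first `n_sessions` unique session IDs is an assumption about how `n_sessions` is applied.

```python
import numpy as np

# Columns: [time, session_id, event_id].
events = np.array([
    [0, 0, 1],
    [250, 0, 2],
    [500, 0, 1],
    [0, 1, 1],
    [250, 1, 3],
])
event_dict = {1: "class_1", 2: "class_2", 3: "class_3"}  # placeholder labels


def onset_counts(events, event_dict, n_sessions):
    """Count onsets of each event ID per session, keyed by descriptive label."""
    counts = {}
    for session_id in np.unique(events[:, 1])[:n_sessions]:
        ids = events[events[:, 1] == session_id, 2]
        counts[int(session_id)] = {
            label: int(np.sum(ids == event_id)) for event_id, label in event_dict.items()
        }
    return counts


print(onset_counts(events, event_dict, n_sessions=2))
```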

static plot_event_occurrence(events, epoch_sample_length, lfp_sfreq, event_dict, n_sessions, show_fig=True, save_fig=True, file_name='event_occurrence.png')[source]

Creates a horizontal bar plot of event occurrences for each session with different colors for each event type. Maps event IDs to descriptive labels using the provided event_dict.

Return type:

None
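One way to lay out the bars, assuming each event row marks an epoch onset and each epoch spans `epoch_sample_length` samples (a hypothetical interpretation of the parameters). The resulting `(start, width)` pairs match what Matplotlib's `Axes.broken_barh` expects for one horizontal row; `occurrence_bars` is a hypothetical helper.

```python
import numpy as np


def occurrence_bars(events, epoch_sample_length, lfp_sfreq, n_sessions):
    """Group events into (start_s, width_s, event_id) bars per session.

    Assumes each row is an epoch onset lasting epoch_sample_length samples.
    """
    width = epoch_sample_length / lfp_sfreq
    bars = {}
    for session_id in np.unique(events[:, 1])[:n_sessions]:
        rows = events[events[:, 1] == session_id]
        bars[int(session_id)] = [(float(t) / lfp_sfreq, width, int(e)) for t, _, e in rows]
    return bars


# Columns: [time, session_id, event_id].
events = np.array([[0, 0, 1], [125, 0, 2], [0, 1, 1]])
bars = occurrence_bars(events, epoch_sample_length=125, lfp_sfreq=250.0, n_sessions=2)
print(bars)  # {0: [(0.0, 0.5, 1), (0.5, 0.5, 2)], 1: [(0.0, 0.5, 1)]}
```

Bars sharing an event ID would then be drawn in the same color, with the label taken from `event_dict`.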

static plot_individual_trial_counts(patients_epochs, save_path, fig_name)[source]

Plots the label counts per trial for all patients.

Parameters:
  • patients_epochs (dict) – Dictionary containing patient names as keys and MNE Epochs objects as values.

  • save_path (str) – Directory path where the plot image will be saved.

  • fig_name (str) – Name of the figure file to be saved (without extension).
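The per-trial counts can be sketched from the documented `[start_time_samples, trial_id, class_label]` layout; with an `mne.Epochs` object the same array is available as `epochs.events`. `label_counts_per_trial` is a hypothetical helper, not the function's actual implementation.

```python
from collections import Counter

import numpy as np


def label_counts_per_trial(events):
    """Count class labels per trial from an epochs.events-style array."""
    counts = {}
    for trial_id in np.unique(events[:, 1]):
        labels = events[events[:, 1] == trial_id, 2]
        counts[int(trial_id)] = dict(Counter(int(lbl) for lbl in labels))
    return counts


# Columns: [start_time_samples, trial_id, class_label].
events = np.array([[0, 0, 0], [125, 0, 1], [250, 0, 1], [0, 1, 0], [125, 1, 0]])
print(label_counts_per_trial(events))  # {0: {0: 1, 1: 2}, 1: {0: 2}}
```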

static plot_label_distribution_boxplot_all_patients(patients_epochs, save_path, fig_name)[source]

Creates a boxplot showing the distribution of the number of labels of each class per trial for all patients.

Parameters:
  • patients_epochs (dict) – Dictionary containing MNE Epochs objects for each patient.

  • save_path (str) – Directory path where the plot image will be saved.

  • fig_name (str) – Name of the figure file to be saved.
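A sketch of the data a per-class boxplot needs: for each class, one count per trial, pooled over all patients. Counting a class as 0 in trials where it never occurs is a design choice of this sketch; the actual function may instead skip such trials. The events arrays follow the `[start_time_samples, trial_id, class_label]` layout, and `per_class_trial_counts` is hypothetical.

```python
import numpy as np


def per_class_trial_counts(events_by_patient, classes):
    """For each class, collect its epoch count in every trial of every patient."""
    dist = {c: [] for c in classes}
    for events in events_by_patient.values():
        for trial_id in np.unique(events[:, 1]):
            labels = events[events[:, 1] == trial_id, 2]
            for c in classes:
                dist[c].append(int(np.sum(labels == c)))  # 0 if the class is absent
    return dist


events_by_patient = {
    "PW_HK59": np.array([[0, 0, 0], [125, 0, 1], [250, 0, 1]]),
    "PW_SN61": np.array([[0, 0, 0], [125, 0, 0], [0, 1, 1]]),
}
print(per_class_trial_counts(events_by_patient, classes=[0, 1]))
# {0: [1, 2, 0], 1: [2, 0, 1]}
```

Each list could be passed as one box, for example `ax.boxplot(list(dist.values()))`.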

static plot_raw_data_with_annotations(lfp_raw_list, scaling=50.0, folder_path='images')[source]

Plot the raw LFP data with annotations for each session.

Parameters:

lfp_raw_list : list of mne.io.Raw

List of raw LFP data for each session.

folder_path : str, optional

Folder where the plots will be saved. Default is ‘images’.

static plot_session_counts(df_session_counts, save_path, fig_name)[source]

Plots a histogram of the number of recording sessions per patient.

Parameters:
  • df_session_counts (pd.DataFrame) – DataFrame containing patient IDs and the number of recording sessions.

  • save_path (str) – Directory where the plot will be saved.

  • fig_name (str) – Name of the saved plot file.

Returns: None

Return type:

None

static plot_single_patient_trial_psd(subjects_lfp_data_dict, patient_name, trial_idx, sfreq, fix_chs_names=None, save_path=None, fig_name=None, subjects_session_trial_mapping=None, fmin=0.1, fmax=100, figsize=(18, 8), show_fig=True)[source]

Plots Power Spectral Density (PSD) for a single patient’s specific trial continuous LFP signal.

This function creates a frequency-domain view of how power is distributed across frequencies for every channel of a single trial. The PSD is computed with the multitaper method, which averages spectral estimates from several orthogonal tapers and therefore has lower variance than a single periodogram.

The visualization includes all channels plotted on the same axes with different colors for easy comparison, frequency band annotations (Delta, Theta, Alpha, Beta, Gamma) to aid in neurophysiological interpretation, and session information when available. This is particularly useful for analyzing brain oscillations relevant to gait and movement disorders in Parkinson’s disease research.

Return type:

None

Parameters:

subjects_lfp_data_dict : Dict[str, List[np.ndarray]]

A dictionary where keys are patient names and values are lists of numpy arrays representing trials. Each numpy array should have shape (num_channels, num_samples). This is the main data structure containing all LFP recordings.

patient_name : str

Name of the patient to plot. Must be a valid key in subjects_lfp_data_dict. Patient names typically follow the convention ‘PW_HK59’, ‘PW_SN61’, etc.

trial_idx : int

Index of the trial to plot (0-based). Must be within the range of available trials for the specified patient. Each trial is a separate continuous recording; several trials may belong to the same session (see subjects_session_trial_mapping).

sfreq : float

Sampling frequency of the data in Hz. Used for PSD computation and frequency axis scaling. Typically 250 Hz for LFP recordings.

fix_chs_names : List[str], optional

List of channel names corresponding to the channels in the LFP data.

save_path : str, optional

Path to save the figure. If None, the figure will not be saved. Directory will be created if it doesn’t exist. Default is None.

fig_name : str, optional

Name of the figure file to be saved. If None and save_path is provided, an auto-generated filename will be used following the pattern: ‘{patient_name}_trial{trial_idx}_psd.png’. Default is None.

subjects_session_trial_mapping : Dict[str, Dict[str, List[int]]], optional

Dictionary mapping patients to sessions and their corresponding trial indices. Format: {patient_id: {session_name: [trial_indices]}}. Used to display session information in the plot title. Default is None.

fmin : float, optional

Minimum frequency for PSD computation in Hz. Lower bound of the frequency range to analyze. Default is 0.1 Hz.

fmax : float, optional

Maximum frequency for PSD computation in Hz. Upper bound of the frequency range to analyze. Default is 100 Hz.

figsize : tuple, optional

Figure size (width, height) in inches. Controls the overall plot dimensions. Default is (18, 8).

Returns:

None

The function displays the plot when show_fig is True and saves it to disk when save_path is given.

Raises:

ValueError

If patient_name is not found in subjects_lfp_data_dict, or if trial_idx exceeds the number of available trials for the specified patient.

Technical Details:

  • Uses multitaper method with adaptive weighting for robust spectral estimation

  • Full normalization ensures proper power spectral density units

  • Frequency resolution depends on signal length and windowing parameters

  • dB conversion: PSD_dB = 10 * log10(PSD_linear)
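The dB conversion and band annotation can be illustrated with NumPy alone. To stay self-contained, this sketch swaps in a Hann-windowed periodogram instead of the multitaper estimator, and the band edges in `BANDS` are conventional values assumed here, not read from `gaitmod`. `hann_psd_db` and `band_means_db` are hypothetical helpers.

```python
import numpy as np

# Conventional band edges in Hz; an assumption, the plot may use other edges.
BANDS = {"Delta": (0.5, 4.0), "Theta": (4.0, 8.0), "Alpha": (8.0, 13.0),
         "Beta": (13.0, 30.0), "Gamma": (30.0, 100.0)}


def hann_psd_db(x, sfreq):
    """One-sided Hann-windowed periodogram of a 1-D signal, in dB.

    Simpler stand-in for multitaper; assumes an even-length signal so the
    last FFT bin is the Nyquist frequency.
    """
    win = np.hanning(len(x))
    spec = np.fft.rfft((x - x.mean()) * win)
    psd = np.abs(spec) ** 2 / (sfreq * np.sum(win ** 2))  # density, units^2/Hz
    psd[1:-1] *= 2  # fold negative frequencies into the one-sided spectrum
    freqs = np.fft.rfftfreq(len(x), d=1.0 / sfreq)
    return freqs, 10 * np.log10(psd + 1e-30)  # PSD_dB = 10 * log10(PSD_linear)


def band_means_db(freqs, psd_db):
    """Average the linear power inside each band and report it in dB."""
    out = {}
    for name, (lo, hi) in BANDS.items():
        mask = (freqs >= lo) & (freqs < hi)
        out[name] = float(10 * np.log10(np.mean(10 ** (psd_db[mask] / 10))))
    return out


sfreq = 250.0
t = np.arange(2500) / sfreq  # 10 s, even length
x = np.sin(2 * np.pi * 20 * t) + 0.5 * np.random.default_rng(0).standard_normal(t.size)
freqs, psd_db = hann_psd_db(x, sfreq)
bands = band_means_db(freqs, psd_db)
print(freqs[np.argmax(psd_db)])      # peak close to 20 Hz
print(max(bands, key=bands.get))     # band with the highest mean power
```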

static plot_single_patient_trial_spectrogram(subjects_lfp_data_dict, patient_name, trial_idx, sfreq, fix_chs_names=None, save_path=None, fig_name=None, subjects_session_trial_mapping=None, fmin=0.1, fmax=100, channel_idx=0, figsize=(16, 10), show_fig=True)[source]

Plots a spectrogram of one channel (selected by channel_idx) of a single patient’s trial, taken from the continuous LFP signal. The resulting time-frequency representation shows how power at each frequency changes over time.

Return type:

None

Parameters:

subjects_lfp_data_dict : Dict[str, List[np.ndarray]]

A dictionary where keys are patient names and values are lists of numpy arrays representing trials. Each numpy array should have shape (num_channels, num_samples).

patient_name : str

Name of the patient to plot.

trial_idx : int

Index of the trial to plot (0-based).

sfreq : float

Sampling frequency of the data.

fix_chs_names : List[str], optional

List of channel names corresponding to the channels in the LFP data.

save_path : str, optional

Path to save the figure. If None, the figure will not be saved.

fig_name : str, optional

Name of the figure file to be saved. If None, auto-generated.

subjects_session_trial_mapping : Dict[str, Dict[str, List[int]]], optional

Dictionary mapping patients to sessions and their corresponding trial indices. Format: {patient_id: {session_name: [trial_indices]}}

fmin : float, optional

Minimum frequency for spectrogram computation. Default is 0.1 Hz.

fmax : float, optional

Maximum frequency for spectrogram computation. Default is 100 Hz.

channel_idx : int, optional

Index of the channel to plot spectrogram for. Default is 0 (first channel).

figsize : tuple, optional

Figure size (width, height) in inches. Default is (16, 10).

Returns:

None
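A NumPy-only short-time Fourier transform sketch of the time-frequency map for the channel picked by `channel_idx`. The window length (`nperseg`), overlap (`noverlap`) and Hann taper are assumptions; the documentation does not say which spectrogram method the function uses, and `stft_power_db` is a hypothetical helper.

```python
import numpy as np


def stft_power_db(x, sfreq, nperseg=256, noverlap=128, fmin=0.1, fmax=100.0):
    """Hann-windowed short-time Fourier power of a 1-D signal, in dB.

    Returns (freqs, times, power_db); power_db has shape (n_freqs, n_windows)
    and times are window centres in seconds.
    """
    step = nperseg - noverlap
    win = np.hanning(nperseg)
    starts = np.arange(0, len(x) - nperseg + 1, step)
    frames = np.stack([x[s:s + nperseg] * win for s in starts])  # (n_windows, nperseg)
    power = np.abs(np.fft.rfft(frames, axis=1)) ** 2 / (sfreq * np.sum(win ** 2))
    freqs = np.fft.rfftfreq(nperseg, d=1.0 / sfreq)
    times = (starts + nperseg / 2) / sfreq
    keep = (freqs >= fmin) & (freqs <= fmax)  # restrict to [fmin, fmax]
    return freqs[keep], times, 10 * np.log10(power[:, keep].T + 1e-30)


sfreq = 250.0
t = np.arange(2000) / sfreq  # 8 s of signal
# Toy trial (num_channels, num_samples): 10 Hz for 4 s, then 40 Hz.
trial = np.vstack([np.where(t < 4, np.sin(2 * np.pi * 10 * t), np.sin(2 * np.pi * 40 * t)),
                   np.zeros_like(t)])
channel_idx = 0
freqs, times, power_db = stft_power_db(trial[channel_idx], sfreq, nperseg=250, noverlap=125)
dominant = freqs[np.argmax(power_db, axis=0)]  # strongest frequency per window
print(dominant[0], dominant[-1])  # about 10 Hz early, about 40 Hz late
```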

static plot_total_label_counts(patients_epochs, save_path, fig_name)[source]

Plots the total label counts across all trials for all patients.

Parameters:
  • patients_epochs (dict) – Dictionary containing MNE Epochs objects for each patient.

  • save_path (str) – Directory path where the plot image will be saved.

  • fig_name (str) – Name of the figure file to be saved (without extension).

static plot_trial_counts(df, save_path, fig_name)[source]

Plots a histogram of the number of trials per subject.

Parameters:
  • df (pd.DataFrame) – DataFrame containing subject IDs and the number of trials.

  • save_path (str) – Directory path to save the plot.

  • fig_name (str) – Name of the file to save the plot.

Returns: None

Return type:

None
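The counts behind this histogram can be derived from `subjects_lfp_data_dict`. The column names expected in `df` are not documented here, so this sketch stays with a plain dict; `trials_per_subject` is a hypothetical helper.

```python
import numpy as np


def trials_per_subject(subjects_lfp_data_dict):
    """Number of trials recorded for each subject."""
    return {sid: len(trials) for sid, trials in subjects_lfp_data_dict.items()}


data = {
    "sub-01": [np.zeros((6, 100))] * 3,
    "sub-02": [np.zeros((6, 100))] * 5,
    "sub-03": [np.zeros((6, 100))] * 3,
}
counts = trials_per_subject(data)
# histogram[k] = number of subjects with exactly k trials
histogram = np.bincount(list(counts.values()))
print(counts)     # {'sub-01': 3, 'sub-02': 5, 'sub-03': 3}
print(histogram)  # [0 0 0 2 0 1]
```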

static plot_trial_lengths_per_subject_boxplot(subjects_lfp_data_dict, lfp_sfreq, save_path=None, fig_name=None)[source]

Plots a boxplot of trial lengths for each subject with filled colors and a secondary axis for time in seconds.

Parameters:
  • subjects_lfp_data_dict (dict) – Dictionary with subject IDs as keys and lists of 2D NumPy arrays as values.

  • lfp_sfreq (float) – Sampling frequency of the LFP data. Used to convert trial lengths to seconds.

  • save_path (str, optional) – If specified, saves the figure to the given path.

  • fig_name (str, optional) – Name of the figure file to be saved.

Returns: None

Return type:

None
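A sketch of the per-subject boxplot input and the sample/second conversion pair for the secondary axis. It assumes the primary axis shows trial lengths in samples; the helpers below are hypothetical. Matplotlib's `Axes.secondary_yaxis(location, functions=(forward, inverse))` accepts two such functions.

```python
import numpy as np

lfp_sfreq = 250.0


def samples_to_seconds(n):
    """Forward transform for the secondary axis: samples -> seconds."""
    return np.asarray(n, dtype=float) / lfp_sfreq


def seconds_to_samples(s):
    """Inverse transform: seconds -> samples."""
    return np.asarray(s, dtype=float) * lfp_sfreq


# One list of trial lengths (in samples) per subject, ready for a boxplot.
data = {"sub-01": [np.zeros((6, 2500)), np.zeros((6, 5000))], "sub-02": [np.zeros((6, 3750))]}
lengths = {sid: [trial.shape[1] for trial in trials] for sid, trials in data.items()}
print(lengths)                                # {'sub-01': [2500, 5000], 'sub-02': [3750]}
print(samples_to_seconds(lengths["sub-01"]))  # [10. 20.]
```

For example, `ax.boxplot(list(lengths.values()))` followed by `ax.secondary_yaxis("right", functions=(samples_to_seconds, seconds_to_samples))` would add a right-hand axis in seconds.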

static plot_trial_lengths_per_subject_distr(subjects_lfp_data_dict, lfp_sfreq, save_path, fig_name)[source]

Plots the distribution and boxplot of trial lengths from LFP data for each subject.

Return type:

None

Parameters:
  • subjects_lfp_data_dict (dict) – Dictionary containing LFP data for multiple subjects. Each subject’s data is a list of trials, where each trial is a 2D array.

  • lfp_sfreq (float) – Sampling frequency of the LFP data.

  • save_path (str) – Directory path where the plot images will be saved.

  • fig_name (str) – Name of the figure file to be saved.

Returns: None