import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import math
import mne
import os
from typing import Dict, List, Optional
from mne.time_frequency import psd_array_multitaper
[docs]
class Visualise:
[docs]
@staticmethod
def plot_session_counts(df_session_counts: pd.DataFrame, save_path: str, fig_name: str) -> None:
    """
    Plot a bar chart of the number of recording sessions per patient.

    Parameters:
        df_session_counts (pd.DataFrame): One row per patient with a ``"patient_id"`` column
            and a session-count column. The count is read from ``"n_sessions"`` when that
            column exists, otherwise from the legacy misspelled ``"n_essions"`` column.
        save_path (str): Directory where the plot will be saved; created if missing.
            If empty/None the figure is only shown.
        fig_name (str): File name of the saved plot (without the ``.png`` extension).

    Returns:
        None
    """
    # Prefer the correctly spelled column but stay compatible with callers that
    # still build the DataFrame with the historical "n_essions" name.
    count_col = "n_sessions" if "n_sessions" in df_session_counts.columns else "n_essions"
    session_counts = df_session_counts[count_col]

    plt.figure(figsize=(12, 8))
    plt.bar(
        df_session_counts["patient_id"],
        session_counts,
        color="skyblue",
        edgecolor="black",
        linewidth=1.2,
    )
    plt.xlabel("Patient ID", fontsize=14)
    plt.ylabel("Number of Recording Sessions", fontsize=14)
    plt.title("Number of Recording Sessions per Patient", fontsize=16)
    plt.xticks(rotation=0, fontsize=12)
    plt.yticks(fontsize=12)
    # One unit of headroom above the tallest bar.
    plt.ylim(bottom=0, top=session_counts.max() + 1)
    plt.grid(axis="y", linestyle="--", alpha=0.7)
    plt.tight_layout()
    if save_path:
        # save_path is the output directory itself. os.path.dirname(save_path) would
        # create only its parent (or call os.makedirs("") for a bare relative name).
        os.makedirs(save_path, exist_ok=True)
        plt.savefig(os.path.join(save_path, fig_name + '.png'), dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()
[docs]
@staticmethod
def plot_trial_counts(df: pd.DataFrame, save_path: str, fig_name: str) -> None:
    """
    Plot a bar chart of the number of trials per subject.

    Parameters:
        df (pd.DataFrame): DataFrame with a ``'subject_id'`` column and an integer
            ``'n_trials'`` column.
        save_path (str): Directory in which to save the plot; created if missing.
            If empty/None the figure is only shown.
        fig_name (str): File name of the saved plot (without the ``.png`` extension).

    Returns:
        None
    """
    plt.figure(figsize=(15, 9))
    plt.bar(df['subject_id'], df['n_trials'], color='skyblue', edgecolor='black')
    plt.xlabel("Subject ID", fontsize=12)
    plt.ylabel("Number of Trials", fontsize=12)
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    # One y tick per integer trial count; int() lets range() accept float-typed columns.
    plt.yticks(range(0, int(df['n_trials'].max()) + 1), fontsize=8, rotation=0)
    plt.xticks(fontsize=11, rotation=0)
    plt.title("Number of Trials per Subject", fontsize=14)
    plt.tight_layout()
    if save_path:
        # Create the output directory itself, not its parent.
        os.makedirs(save_path, exist_ok=True)
        plt.savefig(os.path.join(save_path, fig_name + '.png'), dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()
[docs]
@staticmethod
def plot_trial_lengths_per_subject_distr(subjects_lfp_data_dict: Dict[str, List[np.ndarray]],
                                         lfp_sfreq: float,
                                         save_path: str,
                                         fig_name: str) -> None:
    """
    Plots the distribution and boxplot of trial lengths from LFP data for each subject.

    One column per subject: a histogram of trial lengths (in seconds) in the top row and a
    horizontal boxplot of the same lengths in the bottom row. X-axes are shared by all
    panels; y-axes are shared within each row.

    Parameters:
    - subjects_lfp_data_dict (dict): Subject ID -> list of trials. Each trial is a 2D array
      whose second axis is time (``trial.shape[1]`` is used as the number of samples).
    - lfp_sfreq (float): Sampling frequency of the LFP data, used to convert samples to seconds.
    - save_path (str): Directory where the plot image will be saved; created if missing.
      If empty/None nothing is saved.
    - fig_name (str): File name of the saved plot (without the ``.png`` extension).

    Returns:
    - None
    """
    num_subjects = len(subjects_lfp_data_dict)
    # squeeze=False always yields a (2, num_subjects) grid. The previous
    # np.atleast_2d() turned the single-subject (2,) result into shape (1, 2),
    # so axes[0] was an array of Axes and .hist() raised AttributeError.
    fig, axes = plt.subplots(2, num_subjects, figsize=(5 * num_subjects, 10),
                             sharex=True, sharey='row', squeeze=False)
    for idx, (subject, trials) in enumerate(subjects_lfp_data_dict.items()):
        trial_lengths_sec = [trial.shape[1] / lfp_sfreq for trial in trials]

        # --- Histogram (top row) ---
        ax_hist = axes[0, idx]
        ax_hist.hist(trial_lengths_sec, bins=30, color='skyblue', edgecolor='black', alpha=0.75)
        ax_hist.set_ylabel('Frequency', fontsize=16)
        ax_hist.set_title(f'{subject}', fontsize=18, fontweight='bold')
        ax_hist.grid(True, which='major', linestyle='--', linewidth=0.7, alpha=0.6)
        ax_hist.grid(True, which='minor', linestyle=':', linewidth=0.5, alpha=0.4)
        ax_hist.minorticks_on()

        # --- Boxplot (bottom row) ---
        ax_box = axes[1, idx]
        ax_box.boxplot(trial_lengths_sec, vert=False, patch_artist=True,
                       boxprops=dict(facecolor='skyblue', color='black', linewidth=1.5),
                       medianprops=dict(color='red', linewidth=2),
                       whiskerprops=dict(color='black', linewidth=1.5, linestyle='--'),
                       capprops=dict(color='black', linewidth=1.5))
        ax_box.set_xlabel('Trial Length (seconds)', fontsize=16)
        ax_box.grid(True, which='major', linestyle='--', linewidth=0.7, alpha=0.6)
        ax_box.grid(True, which='minor', linestyle=':', linewidth=0.5, alpha=0.4)
        ax_box.minorticks_on()

    fig.suptitle('Distribution and Boxplot of Trial Lengths per Subject', fontsize=20, fontweight='bold')
    # Use a single layout engine: tight_layout, reserving the top 4% for the suptitle.
    # Combining it with constrained_layout=True only made Matplotlib warn and switch engines.
    fig.tight_layout(rect=[0, 0, 1, 0.96])
    if save_path:
        # Create the output directory itself, not its parent.
        os.makedirs(save_path, exist_ok=True)
        fig.savefig(os.path.join(save_path, fig_name + '.png'), dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)
[docs]
@staticmethod
def plot_trial_lengths_per_subject_boxplot(subjects_lfp_data_dict: Dict[str, List[np.ndarray]],
                                           lfp_sfreq: float,
                                           save_path: Optional[str] = None,
                                           fig_name: Optional[str] = None) -> None:
    """
    Plots a boxplot of trial lengths for each subject with filled colors and a secondary axis for time in seconds.

    The left y-axis shows trial length in samples. The right y-axis mirrors the same range
    converted to seconds via ``lfp_sfreq``.

    Parameters:
    - subjects_lfp_data_dict (dict): Subject ID -> list of 2D NumPy arrays. ``trial.shape[1]``
      is used as the trial length in samples.
    - lfp_sfreq (float): Sampling frequency of the LFP data. Used to convert trial lengths to seconds.
    - save_path (str, optional): If given, the figure is saved in this directory (created if missing).
    - fig_name (str, optional): File name of the saved plot (without the ``.png`` extension).
      Defaults to ``"trial_lengths_per_subject_boxplot"`` when saving without a name.

    Returns:
    - None
    """
    trial_lengths_dict = {
        subject: [trial.shape[1] for trial in trials]
        for subject, trials in subjects_lfp_data_dict.items()
    }
    # Subjects have different trial counts. from_dict(orient="index").T gives one column
    # per subject, padding the shorter columns with NaN.
    df_trial_lengths = pd.DataFrame.from_dict(trial_lengths_dict, orient="index").T

    fig, ax1 = plt.subplots(figsize=(12, 8))
    bp = df_trial_lengths.boxplot(
        patch_artist=True,  # draw boxes as fillable patches
        medianprops=dict(color='red', linewidth=2),
        whiskerprops=dict(color='black', linewidth=1.5),
        capprops=dict(color='black', linewidth=1.5),
        # markeredgecolor is required: Matplotlib's default flier style sets an explicit
        # black marker edge, so color='red' alone did not recolour the outlier markers.
        flierprops=dict(marker='o', color='red', markeredgecolor='red', alpha=0.6, markersize=6),
        ax=ax1,
        return_type='dict',  # gives direct access to the box patches below
    )
    # The box patches are in bp['boxes']. Axes.artists excludes Patch objects, so the
    # old loop over ax1.artists coloured nothing. Cycle the palette so every subject
    # gets a colour even when there are more subjects than palette entries.
    palette = plt.cm.Paired.colors
    for i, box_patch in enumerate(bp['boxes']):
        box_patch.set_facecolor(palette[i % len(palette)])
        box_patch.set_edgecolor("black")  # black edges for contrast

    ax1.set_xlabel("Patient ID", fontsize=14, fontweight="bold")
    ax1.set_ylabel("Trial Length (samples)", fontsize=14, fontweight="bold")
    ax1.set_title("Boxplot of Trial Lengths for Each Patient", fontsize=16, fontweight="bold")

    # Twin axis showing the sample range converted to seconds.
    ax2 = ax1.twinx()
    ax2.set_ylabel("Trial Length (seconds)", fontsize=14, fontweight="bold")
    y_low, y_high = ax1.get_ylim()
    ax2.set_ylim(y_low / lfp_sfreq, y_high / lfp_sfreq)

    ax1.tick_params(axis='x', labelrotation=45, labelsize=12)
    ax1.tick_params(axis='y', labelsize=12)
    ax2.tick_params(axis='y', labelsize=12)
    ax1.grid(axis='y', linestyle="--", alpha=0.7)
    fig.tight_layout()
    if save_path:
        # Create the output directory itself, not its parent.
        os.makedirs(save_path, exist_ok=True)
        # fig_name is optional in the signature, so fall back instead of failing on None + '.png'.
        file_stem = fig_name or "trial_lengths_per_subject_boxplot"
        fig.savefig(os.path.join(save_path, file_stem + '.png'), dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)
[docs]
@staticmethod
def plot_all_trial_lengths(subjects_lfp_data_dict: Dict[str, List[np.ndarray]], lfp_sfreq: float, save_path: str, fig_name: str) -> None:
    """
    Plots the distribution and boxplot of trial lengths pooled over every subject.

    Left panel: histogram of all trial lengths (seconds). Right panel: horizontal boxplot
    of the same values.

    Parameters:
        subjects_lfp_data_dict (dict): Subject ID -> list of trials. Each trial is a 2D array
            whose second axis is time (``trial.shape[1]`` is the number of samples).
        lfp_sfreq (float): Sampling frequency of the LFP data.
        save_path (str): Directory where the plot image will be saved; created if missing.
            If empty/None nothing is saved.
        fig_name (str): Name of the figure file to be saved (without extension).

    Returns:
        None
    """
    # Lengths in seconds (samples / sampling rate), pooled across subjects.
    trial_lengths_sec = [
        trial.shape[1] / lfp_sfreq
        for trials in subjects_lfp_data_dict.values()
        for trial in trials
    ]

    fig, axes = plt.subplots(1, 2, figsize=(20, 6))

    # Histogram of trial lengths
    axes[0].hist(trial_lengths_sec, bins=30, color='skyblue', edgecolor='black')
    axes[0].set_xlabel('Trial Length (seconds)', fontsize=12)
    axes[0].set_ylabel('Frequency (Number of Trials)', fontsize=12)
    axes[0].set_title('Distribution of Trial Lengths Across Entire Dataset', fontsize=14)
    axes[0].grid(axis='y', linestyle='--', alpha=0.7)

    # Boxplot of the same lengths
    axes[1].boxplot(trial_lengths_sec, vert=False, patch_artist=True,
                    boxprops=dict(facecolor='skyblue', color='black'), medianprops=dict(color='red'))
    axes[1].set_xlabel('Trial Length (seconds)', fontsize=12)
    axes[1].set_title('Boxplot of Trial Lengths Across Entire Dataset', fontsize=14)
    axes[1].grid(axis='x', linestyle='--', alpha=0.7)

    fig.tight_layout()
    if save_path:
        # Create the output directory itself, not its parent.
        os.makedirs(save_path, exist_ok=True)
        fig.savefig(os.path.join(save_path, fig_name + '.png'), dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)
[docs]
@staticmethod
def plot_individual_trial_counts(patients_epochs: Dict, save_path: str, fig_name: str):
    """
    Plots the per-trial label counts for every patient.

    One subplot row per patient. Each trial gets a pair of bars: the number of epochs
    labelled 0 ("Normal walking") and labelled 1 ("Modulation").

    Args:
        patients_epochs (dict): Patient name -> MNE Epochs object. ``epochs.events[:, 1]`` is
            read as the trial index and ``epochs.events[:, 2]`` as the class label; labels are
            used directly as list indices, so they are expected to be 0 or 1.
        save_path (str): Directory where the plot image will be saved; created if missing.
            If empty/None nothing is saved.
        fig_name (str): Name of the figure file to be saved (without extension).
    """
    num_patients = len(patients_epochs)
    # sharex=False: shared x-axes share one tick locator/formatter, so each patient's
    # set_xticklabels overwrote the others and every row showed the last patient's
    # trial numbers. squeeze=False keeps `axes` 2D even for a single patient.
    fig, axes = plt.subplots(num_patients, 1, figsize=(20, 5 * num_patients),
                             sharex=False, sharey=True, squeeze=False)
    bar_width = 0.35
    for ax, (patient, epochs) in zip(axes[:, 0], patients_epochs.items()):
        trial_indices = epochs.events[:, 1]
        labels = epochs.events[:, 2]

        # Tally [label-0 count, label-1 count] per trial.
        trial_label_counts = {}
        for trial_idx, label in zip(trial_indices, labels):
            trial_label_counts.setdefault(trial_idx, [0, 0])[label] += 1

        sorted_trials = sorted(trial_label_counts)
        normal_counts = [trial_label_counts[t][0] for t in sorted_trials]
        modulation_counts = [trial_label_counts[t][1] for t in sorted_trials]

        # Side-by-side bars, one pair per trial.
        index = np.arange(len(sorted_trials))
        ax.bar(index, normal_counts, bar_width, label='Normal walking', color='#1f77b4')
        ax.bar(index + bar_width, modulation_counts, bar_width, label='Modulation', color='#ff7f0e')
        ax.set_xlabel('Trial Index', fontsize=16)
        ax.set_ylabel('Count', fontsize=16)
        ax.set_title(f'Label Counts per Trial ({patient})', fontsize=17)
        ax.legend(fontsize=12)
        ax.grid(axis='y', linestyle='--')
        ax.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
        # One tick per trial, centred under its bar pair (set_xticks installs a FixedLocator).
        ax.set_xticks(index + bar_width / 2)
        ax.set_xticklabels(sorted_trials, rotation=45, ha="right", fontsize=12)

    if save_path:
        # Create the output directory itself, not its parent.
        os.makedirs(save_path, exist_ok=True)
        fig.savefig(os.path.join(save_path, fig_name + '.png'), dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)
[docs]
@staticmethod
def plot_total_label_counts(patients_epochs: Dict, save_path: str, fig_name: str):
    """
    Plots the total label counts across all trials for all patients.

    One subplot per patient with two bars: epochs labelled 0 ("Normal walking") and
    epochs labelled 1 ("Modulation"). All subplots share one y-range.

    Args:
        patients_epochs (dict): Patient name -> MNE Epochs object. ``epochs.events[:, 2]`` is
            read as the class label and used directly as a list index, so labels are
            expected to be 0 or 1.
        save_path (str): Directory where the plot image will be saved; created if missing.
            If empty/None nothing is saved.
        fig_name (str): Name of the figure file to be saved (without extension).
    """
    class_names = ['Normal walking', 'Modulation']
    class_colors = ['#1f77b4', '#ff7f0e']

    # Total [label-0, label-1] counts per patient. Each epoch belongs to exactly one
    # trial, so this equals summing per-trial tallies.
    patient_counts = {}
    for patient, epochs in patients_epochs.items():
        counts = [0, 0]
        for label in epochs.events[:, 2]:
            counts[label] += 1
        patient_counts[patient] = counts

    # With sharey=True all subplots share a single y-range. Derive it from the largest
    # count of any patient; setting it per patient let the last patient's limit win and
    # clipped taller bars elsewhere. max(..., 1) avoids a degenerate (0, 0) range.
    largest = max((max(counts) for counts in patient_counts.values()), default=0)
    y_top = max(largest, 1) * 1.1

    num_patients = len(patients_epochs)
    fig, axes = plt.subplots(1, num_patients, figsize=(6 * num_patients, 10),
                             sharey=True, squeeze=False)
    for ax, (patient, counts) in zip(axes[0], patient_counts.items()):
        bars = ax.bar(class_names, counts, color=class_colors)
        ax.set_xlabel('Label', fontsize=18)
        ax.set_ylabel('Total Count', fontsize=18)
        ax.set_title(f'{patient}', fontsize=20)
        ax.grid(axis='y', linestyle='--')
        ax.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
        ax.set_ylim(0, y_top)
        ax.legend(bars, class_names, fontsize=14)

    fig.suptitle('Total Label Counts Across All Trials', fontsize=22)
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    if save_path:
        # Create the output directory itself, not its parent.
        os.makedirs(save_path, exist_ok=True)
        fig.savefig(os.path.join(save_path, fig_name + '.png'), dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)
[docs]
@staticmethod
def plot_label_distribution_boxplot_all_patients(patients_epochs: Dict, save_path: str, fig_name:str):
    """
    Creates a boxplot of the per-trial label counts of each class for all patients.

    For every patient, the number of epochs labelled 0 ("Normal") and 1 ("Modulation") is
    counted per trial. The two resulting distributions are drawn as a pair of boxes
    centred on that patient's x tick.

    Args:
        patients_epochs (dict): Patient name -> MNE Epochs object. ``epochs.events[:, 1]`` is
            read as the trial index and ``epochs.events[:, 2]`` as the class label; labels are
            used directly as list indices, so they are expected to be 0 or 1.
        save_path (str): Directory where the plot image will be saved; created if missing.
            If empty/None nothing is saved.
        fig_name (str): Name of the figure file to be saved (without extension).
    """
    fig, ax = plt.subplots(figsize=(15, 10))
    box_data = []       # interleaved: [p0 normal, p0 modulation, p1 normal, ...]
    patient_labels = []
    for patient, epochs in patients_epochs.items():
        trial_label_counts = {}
        for trial_idx, label in zip(epochs.events[:, 1], epochs.events[:, 2]):
            trial_label_counts.setdefault(trial_idx, [0, 0])[label] += 1
        sorted_trials = sorted(trial_label_counts)
        box_data.append([trial_label_counts[t][0] for t in sorted_trials])
        box_data.append([trial_label_counts[t][1] for t in sorted_trials])
        patient_labels.append(patient)

    # Patient i's pair sits at 2i+1 and 2i+1.5: close together, with a gap between patients.
    positions = [pos for i in range(len(patient_labels)) for pos in (i * 2 + 1, i * 2 + 1.5)]
    box = ax.boxplot(box_data, patch_artist=True, positions=positions)

    # Alternate fill colours: normal, modulation, normal, modulation, ...
    pair_colors = ['lightblue', 'salmon']
    for patch, color in zip(box['boxes'], pair_colors * len(patient_labels)):
        patch.set_facecolor(color)

    ax.set_ylabel("Label Count", fontsize=14)
    ax.set_title("Label Distribution across Trials for All Patients", fontsize=16)
    ax.grid(axis='y', linestyle='--')
    ax.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
    # X ticks centred between the two boxes of each pair (2i + 1.25).
    xticks = np.arange(1.25, 2 * len(patient_labels), 2)
    ax.set_xticks(xticks)
    ax.set_xticklabels(patient_labels, rotation=45)
    ax.set_xlabel("Patient", fontsize=14)
    if patient_labels:
        # The first pair of boxes serves as the legend swatches.
        ax.legend([box["boxes"][0], box["boxes"][1]], ["Normal", "Modulation"], loc="upper right", fontsize=12)
    fig.tight_layout()
    if save_path:
        # Create the output directory itself, not its parent.
        os.makedirs(save_path, exist_ok=True)
        fig.savefig(os.path.join(save_path, fig_name + '.png'), dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)
[docs]
@staticmethod
def plot_all_patients_trials(subjects_lfp_data_dict: Dict[str, List[np.ndarray]],
                             patients_epochs: Dict[str, mne.Epochs],
                             sfreq: float,
                             save_path: str,
                             fig_name: str,
                             subjects_session_trial_mapping: Dict[str, Dict[str, List[int]]] = None,
                             max_display_trials: int = None,
                             window_size: float = 0.5,
                             overlap: float = 0.5,
                             expand_transition: float = 0.0,
                             sharex: bool = True,
                             sharey: bool = True):
    """
    Plots LFP data for all patients and trials with epoch-level temporal coloring.

    This function creates a grid of subplots where each column represents a patient and each
    row represents a trial. Each subplot shows the LFP data of one trial of one patient, with
    channels stacked at vertical offsets of 40 units. The span covered by each epoch is coloured
    by that epoch's class label; consecutive overlapping/adjacent epochs with the same label are
    drawn as one continuous line. Trials belonging to the same session get a tinted background
    and a coloured border whose colour matches the figure-level session legend.

    Parameters:
    -----------
    subjects_lfp_data_dict : Dict[str, List[np.ndarray]]
        Patient name -> list of numpy arrays representing trials.
        Each numpy array should have shape (num_channels, num_samples).
    patients_epochs : Dict[str, mne.Epochs]
        Patient name -> MNE Epochs object. ``epochs.events`` is read as
        [start_time_samples, trial_id, class_label] per epoch. Start times are used as sample
        offsets into the trial array. A trial id is matched to the trial at that position in
        ``subjects_lfp_data_dict[patient]``. Channel names come from the first patient.
    sfreq : float
        Sampling frequency of the data; used for the top "Time (s)" axis and the epoch length.
    save_path : str
        Directory to save the figure in (created if missing). If empty, the figure is not saved.
    fig_name : str
        Name of the figure file to be saved (without extension; saved as ``.pdf``).
    subjects_session_trial_mapping : Dict[str, Dict[str, List[int]]], optional
        Dictionary mapping patients to sessions and their corresponding trial indices.
        Format: {patient_id: {session_name: [trial_indices]}}
    max_display_trials : int, optional
        Maximum number of trials to display per patient. If None (default), all trials are displayed.
    window_size : float, optional
        Epoch length in seconds; sets how many samples each epoch colours. Default is 0.5.
    overlap : float, optional
        Only reported in the figure title.
    expand_transition : float, optional
        Only reported in the figure title.
    sharex : bool, optional
        Whether to share the x-axis among subplots. Default is True.
    sharey : bool, optional
        Whether to share the y-axis among subplots. Default is True.

    Returns:
    --------
    None
    """
    # --- Epoch timing per trial, used for temporal colouring ---
    epoch_info = {}
    for patient_name, epochs in patients_epochs.items():
        events = epochs.events
        trial_indices = events[:, 1]   # column 1: trial index
        epoch_times = events[:, 0]     # column 0: epoch start (samples)
        epoch_labels = events[:, -1]   # last column: class label (0=normal, 1=modulation)
        unique_trials = np.unique(trial_indices)
        trial_epochs = {}
        for trial_idx in unique_trials:
            mask = trial_indices == trial_idx
            trial_epochs[int(trial_idx)] = {
                'times': epoch_times[mask],
                'labels': epoch_labels[mask],
            }
        epoch_info[patient_name] = trial_epochs
        print(f"Patient {patient_name}: {len(unique_trials)} trials with epoch-level timing extracted")

    num_patients = len(subjects_lfp_data_dict)
    max_trials = max(len(trials) for trials in subjects_lfp_data_dict.values())
    if max_display_trials is None:
        actual_display_trials = max_trials
        print(f"Displaying all {max_trials} trials per patient")
    else:
        actual_display_trials = min(max_trials, max_display_trials)
        print(f"Displaying first {actual_display_trials} trials per patient (out of {max_trials} total)")

    class_colors = {0: 'blue', 1: 'red'}  # 0: normal walking, 1: gait modulation
    class_names = {0: 'Normal Walking', 1: 'Gait Modulation'}
    # Samples covered by one epoch, taken from the window_size argument. It was previously
    # hard-coded to 0.5 s, which mis-coloured any other window length.
    window_size_samples = int(window_size * sfreq)

    # Channel names from the first patient; assumed identical for every patient.
    channel_names = patients_epochs[next(iter(patients_epochs))].ch_names
    num_channels = len(channel_names)
    channel_spacing = 40  # vertical offset between stacked channels
    y_positions = [ch_idx * channel_spacing for ch_idx in range(num_channels)]

    if max_display_trials is None:
        # All trials: conservative sizing so the figure does not grow without bound.
        fig_width = min(num_patients * 3, 20)
        fig_height = min(actual_display_trials * 2, 100)
    else:
        fig_width = min(num_patients * 4, 16)
        fig_height = min(actual_display_trials * 3, 15)
    # squeeze=False: axes is always a 2D (trials, patients) array.
    fig, axes = plt.subplots(actual_display_trials, num_patients, figsize=(fig_width, fig_height),
                             sharex=sharex, sharey=sharey, squeeze=False)

    session_colors = ['steelblue', 'forestgreen', 'crimson', 'goldenrod', 'mediumvioletred',
                      'darkslategray', 'darkorange', 'darkseagreen', 'midnightblue', 'darkviolet']
    # One colour per session name, assigned in sorted order over all patients. The subplot
    # tint and the figure legend therefore agree. Previously each patient re-numbered its
    # own sessions, so tints could differ from the legend.
    session_mapping = subjects_session_trial_mapping or {}
    all_sessions = set()
    for patient_sessions in session_mapping.values():
        all_sessions.update(patient_sessions.keys())
    session_to_color = {
        session_name: session_colors[session_idx % len(session_colors)]
        for session_idx, session_name in enumerate(sorted(all_sessions))
    }

    for col_idx, (patient_name, trials) in enumerate(subjects_lfp_data_dict.items()):
        patient_epoch_info = epoch_info.get(patient_name, {})
        trial_to_session = {}
        for session_name, session_trials in session_mapping.get(patient_name, {}).items():
            for trial_idx in session_trials:
                trial_to_session[trial_idx] = (session_name, session_to_color[session_name])

        for row_idx in range(min(len(trials), actual_display_trials)):
            trial_data = trials[row_idx]
            ax = axes[row_idx, col_idx]
            num_samples = trial_data.shape[1]

            trial_epochs = patient_epoch_info.get(row_idx, {'times': [], 'labels': []})
            epoch_times = np.asarray(trial_epochs['times'])
            epoch_labels = np.asarray(trial_epochs['labels'])
            normal_count = np.sum(epoch_labels == 0)
            modulation_count = np.sum(epoch_labels == 1)

            # Build colour runs once per trial (shared by every channel). Epochs are walked
            # in start order and merged into [label, start, stop) runs while they touch or
            # overlap and keep the same label. The old per-channel index lists repeated
            # overlapping samples, which drew the trace backwards at every overlap.
            segments = []
            order = np.argsort(epoch_times, kind='stable')
            for epoch_start, epoch_label in zip(epoch_times[order], epoch_labels[order]):
                start = max(0, int(epoch_start))
                stop = min(int(epoch_start) + window_size_samples, num_samples)
                if stop <= start:
                    continue  # epoch lies entirely outside this trial
                if segments and segments[-1][0] == epoch_label and start <= segments[-1][2]:
                    segments[-1][2] = max(segments[-1][2], stop)
                else:
                    segments.append([epoch_label, start, stop])

            for channel_idx in range(num_channels):
                channel_signal = trial_data[channel_idx, :] + y_positions[channel_idx]
                if segments:
                    for label, start, stop in segments:
                        ax.plot(np.arange(start, stop), channel_signal[start:stop],
                                color=class_colors[label], alpha=0.7, linewidth=0.5)
                else:
                    # No usable epoch info: draw the whole signal in a neutral colour.
                    ax.plot(channel_signal, color='gray', alpha=0.7, linewidth=0.5)

            # Session grouping: subtle tinted background plus a coloured border.
            if row_idx in trial_to_session:
                _, session_color = trial_to_session[row_idx]
                ax.patch.set_facecolor(session_color)
                ax.patch.set_alpha(0.1)
                for spine in ax.spines.values():
                    spine.set_edgecolor(session_color)
                    spine.set_linewidth(2)

            if row_idx == 0:
                ax.set_title(f"Patient {patient_name}", fontsize=10, pad=8)
            ax.set_xlabel('Samples', fontsize=7)
            ax.tick_params(axis='x', labelbottom=True, labelsize=7)
            ax.set_yticks(y_positions)
            ax.set_yticklabels(channel_names, fontsize=7)
            ax.tick_params(axis='y', labelsize=7)

            # Per-subplot legend with epoch counts per class.
            if normal_count > 0 or modulation_count > 0:
                legend_handles = []
                legend_labels = []
                if normal_count > 0:
                    legend_handles.append(plt.Line2D([0], [0], color=class_colors[0], alpha=0.7))
                    legend_labels.append(f'{class_names[0]} ({normal_count})')
                if modulation_count > 0:
                    legend_handles.append(plt.Line2D([0], [0], color=class_colors[1], alpha=0.7))
                    legend_labels.append(f'{class_names[1]} ({modulation_count})')
                ax.legend(legend_handles, legend_labels, loc='upper right', fontsize=6,
                          framealpha=0.8, borderpad=0.2, columnspacing=0.3)
            else:
                ax.text(0.95, 0.95, 'No epochs', transform=ax.transAxes,
                        verticalalignment='top', horizontalalignment='right',
                        bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8),
                        fontsize=6)

    for i in range(actual_display_trials):
        axes[i, 0].set_ylabel(f"Trial {i + 1}", fontsize=7)

    # Secondary x-axis on top: samples -> seconds.
    for ax in axes.flat:
        secax = ax.secondary_xaxis('top', functions=(lambda x: x / sfreq, lambda x: x * sfreq))
        secax.set_xlabel('Time (s)', fontsize=7)
        secax.tick_params(labelsize=6)

    if subjects_session_trial_mapping:
        legend_elements = [
            plt.Rectangle((0, 0), 1, 1, facecolor=color, alpha=0.3,
                          edgecolor=color, linewidth=2, label=session_name)
            for session_name, color in session_to_color.items()
        ]
        # Figure-level legend placed in the right margin.
        fig.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(0.99, 0.92),
                   fontsize=8, title="Sessions", title_fontsize=9)

    fig.suptitle(f"LFP Data for All Patients and Trials (window_size={window_size}, overlap={overlap}, expand_transition={expand_transition})", fontsize=12, y=0.93)
    fig.subplots_adjust(
        left=0.08,    # left margin
        right=0.90,   # right margin (space for the session legend)
        bottom=0.08,  # bottom margin
        top=0.92,     # top margin (space for the title)
        hspace=0.9,   # vertical space between subplots
        wspace=0.2,   # horizontal space between subplots
    )
    if save_path:
        # Create the output directory itself, not its parent.
        os.makedirs(save_path, exist_ok=True)
        fig.savefig(os.path.join(save_path, fig_name + '.pdf'), dpi=300, bbox_inches='tight')
    plt.show()
    # Close (not merely clear) the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
[docs]
@staticmethod
def plot_event_class_histogram(events: np.ndarray,
                               event_dict: Dict[int, str],
                               n_sessions: int,
                               show_fig: bool = True,
                               save_fig: bool = True,
                               file_name: str = 'event_class_histogram.png') -> None:
    """
    Creates a grid of bar charts of the number of onsets of each event class per session,
    with event IDs mapped to descriptive labels.

    Sessions are laid out 4 per row. Sessions ``0 .. n_sessions - 1`` are selected by
    ``events[:, 1]``; sessions without events are shown as an empty, axis-less panel. All
    panels share the same y-limit so counts are comparable.

    Parameters:
        events: np.ndarray - Array with at least three columns: [time, session_id, event_id].
        event_dict: Dict[int, str] - Maps event IDs to descriptive labels; unknown IDs are
            shown as their string form.
        n_sessions: int - The number of sessions to plot.
        show_fig: bool - Flag to show the figure or not.
        save_fig: bool - Flag to save the figure or not.
        file_name: str - The filename for saving the figure.
    """
    n_cols = 4
    n_rows = math.ceil(n_sessions / n_cols)
    # squeeze=False keeps axes 2D even when everything fits in one row.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 3 * n_rows), squeeze=False)

    # Count each event class per session once; reused for the shared y-limit and the bars.
    session_counts = {}
    for s in range(n_sessions):
        session_data = events[events[:, 1] == s]
        if len(session_data) > 0:
            session_counts[s] = np.unique(session_data[:, 2], return_counts=True)
    max_count = max((counts.max() for _, counts in session_counts.values()), default=0)

    for s, ax in zip(range(n_sessions), axes.ravel()):
        if s not in session_counts:
            ax.set_title(f'Session {s}')
            ax.axis('off')
            continue
        unique_classes, counts = session_counts[s]
        class_labels = [event_dict.get(cls, str(cls)) for cls in unique_classes]

        bars = ax.bar(class_labels, counts, color=['blue', 'orange'], edgecolor='black')
        # Annotate each bar with its count.
        for bar in bars:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width() / 2.0, height, f'{int(height)}',
                    ha='center', va='bottom', fontsize=10, color='black')

        ax.set_title(f'Session {s}', fontsize=12)
        ax.grid(axis='y', linestyle='--', alpha=0.7)
        # Shared y-limit with ~10% headroom. Without it, the count label on the tallest
        # bar sits above the axes and collides with the panel title.
        ax.set_ylim(0, max_count * 1.1)

    fig.supxlabel('Event Class', fontsize=12)
    fig.supylabel('Occurrences', fontsize=12)

    # Hide the unused panels of the last row.
    for ax in axes.ravel()[n_sessions:]:
        ax.axis('off')

    fig.tight_layout()
    if save_fig:
        fig.savefig(file_name)
        print(f"Plot saved as {file_name}")
    if show_fig:
        plt.show()
    plt.close(fig)
[docs]
@staticmethod
def plot_epochs_with_events(patients_epochs: Dict[str, mne.EpochsArray],
                            subject_id: str,
                            window_size: float,
                            sfreq: int,
                            event_names: list[str],
                            subjects_event_idx_dict: Dict[str, Dict[int, List[int]]],
                            show_fig: bool = True,
                            save_path: Optional[str] = None,
                            fig_name: str = 'epochs_with_events{subject_id}.png') -> None:
    """
    Plot the epochs of one subject (one subplot per trial) with mod_start/mod_end markers.

    Every epoch is drawn as a horizontal span from its onset sample to
    ``onset + int(window_size * sfreq)`` samples. Spans are stacked by the epoch's
    position within the trial and coloured by its label (matplotlib colour
    ``C<label>``). The trial's ``mod_start`` (red) and ``mod_end`` (blue) markers
    are drawn as dashed vertical lines, with the region between them shaded grey.
    The x-axis is in samples, with one tick per second labelled in seconds.

    Parameters
    ----------
    patients_epochs : Dict[str, mne.EpochsArray]
        Mapping of subject IDs to epochs. In ``.events``, the columns are used as
        follows:
        - column 0: the onset sample;
        - column 1: the trial ID;
        - column 2: the label (0 is plotted as "Normal walking", any other value
          as "Modulation").
    subject_id : str
        The subject to plot.
    window_size : float
        Epoch length in seconds.
    sfreq : int
        Sampling frequency in Hz, used to convert between seconds and samples.
    event_names : list[str]
        Event names in the order used by ``subjects_event_idx_dict``. Must contain
        ``'mod_start'`` and ``'mod_end'``.
    subjects_event_idx_dict : Dict[str, Dict[int, List[int]]]
        ``{subject_id: {trial_id: [event index, ...]}}``. The entries are ordered
        like ``event_names`` and plotted on the sample axis.
    show_fig : bool, optional
        Display the figure. Default True.
    save_path : Optional[str], optional
        Directory to save the figure into (created if missing). Default None,
        meaning the figure is not saved.
    fig_name : str, optional
        File name of the saved figure. A literal ``{subject_id}`` placeholder is
        replaced by ``subject_id``. ``.png`` is appended unless the name already
        ends with it.

    Returns
    -------
    None
    """
    epochs_events = patients_epochs[subject_id].events
    trial_event_idx = subjects_event_idx_dict[subject_id]
    mod_start_pos = event_names.index('mod_start')
    mod_end_pos = event_names.index('mod_end')
    # Epoch length in samples
    window_samples = int(window_size * sfreq)

    trial_ids = np.unique(epochs_events[:, 1])
    n_trials = len(trial_ids)
    # Shared extents so every subplot uses the same axes
    max_end_time = np.max(epochs_events[:, 0]) + window_samples
    max_num_epochs = max(int(np.sum(epochs_events[:, 1] == trial_id)) for trial_id in trial_ids)

    # Grid layout: at most 3 columns, 6 x 4 inches per subplot; squeeze=False
    # always returns a 2-D array of axes.
    n_cols = min(3, n_trials)
    n_rows = int(np.ceil(n_trials / n_cols))
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 6, n_rows * 4),
                             sharex=False, sharey=True, squeeze=False)

    # Tick positions are in samples (one per second) and labelled in seconds.
    # Deriving the labels from the ticks keeps both sequences the same length.
    x_ticks = np.arange(0, max_end_time, step=sfreq)
    x_tick_labels = x_ticks / sfreq

    for plot_idx, (ax, trial_id) in enumerate(zip(axes.flat, trial_ids)):
        trial_events = epochs_events[epochs_events[:, 1] == trial_id]
        num_epochs = len(trial_events)
        normal_count = np.sum(trial_events[:, 2] == 0)
        modulation_count = np.sum(trial_events[:, 2] == 1)

        # The modulation markers depend only on the trial, so they are drawn once
        # per subplot. This also keeps the translucent shading from darkening with
        # the number of epochs.
        mod_start = trial_event_idx[trial_id][mod_start_pos]
        mod_end = trial_event_idx[trial_id][mod_end_pos]
        ax.axvline(x=mod_start, color='r', linestyle='--', label='Mod Start')
        ax.axvline(x=mod_end, color='b', linestyle='--', label='Mod End')
        ax.axvspan(mod_start, mod_end, color='gray', alpha=0.15, zorder=0)

        labelled = set()
        for j, (onset, _, label) in enumerate(trial_events):
            # NOTE(review): each span is shifted left by the subplot index (in
            # samples). The intent is unclear; kept as-is — confirm.
            start = onset - plot_idx
            end = onset + window_samples - plot_idx
            # Vertical slot of this epoch, as an axes fraction
            ymin = j / max_num_epochs
            ymax = (j + 1) / max_num_epochs
            span_kwargs = {'ymin': ymin, 'ymax': ymax, 'color': f'C{label}', 'alpha': 0.7}
            if label not in labelled:
                # Only the first span of each label gets a legend entry
                span_kwargs['label'] = (f'Normal walking ({normal_count})' if label == 0
                                        else f'Modulation ({modulation_count})')
                labelled.add(label)
            ax.axvspan(start, end, **span_kwargs)
            # Epoch index written in the middle of its span
            ax.text((start + end) / 2, (ymin + ymax) / 2, f'{j}',
                    ha='center', va='center', fontsize=6, color='black')

        ax.legend(loc='upper left', fontsize=8, framealpha=0.9)
        ax.set_title(f'Trial {trial_id} ({num_epochs} epochs)')
        # The x-axis is in samples, so show everything up to the last epoch end.
        ax.set_xlim(0, max_end_time)
        ax.set_xticks(x_ticks)
        ax.set_xticklabels(x_tick_labels)
        # Spans are placed as axes fractions, so pin the data range to [0, 1].
        # This keeps the y ticks and the index text aligned with the spans.
        ax.set_ylim(0, 1)
        ax.set_yticks(np.linspace(0.5 / max_num_epochs, 1 - 0.5 / max_num_epochs, max_num_epochs))
        ax.set_yticklabels([str(k) for k in range(max_num_epochs)])
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Epochs')
        ax.tick_params(axis='x', labelbottom=True)  # keep x labels on every subplot
        ax.grid(which='both', linestyle='--', linewidth=0.5)

    # Hide grid cells that have no trial
    for ax in axes.flat[n_trials:]:
        ax.axis('off')

    fig.suptitle(f'Epochs with Events for Subject {subject_id} ({n_trials} trials)', fontsize=16, y=0.97)
    fig.tight_layout(rect=[0, 0, 1, 0.96])
    if save_path:
        # save_path is the target directory: create it, not only its parent
        os.makedirs(save_path, exist_ok=True)
        file_name = fig_name.replace('{subject_id}', str(subject_id))
        if not file_name.endswith('.png'):
            file_name += '.png'
        fig.savefig(os.path.join(save_path, file_name), dpi=300, bbox_inches='tight')
    if show_fig:
        plt.show()
    plt.close(fig)
# TODO: enhance this function `plot_event_occurrence`
[docs]
@staticmethod
def plot_event_occurrence(events: np.ndarray,
                          epoch_sample_length: int,
                          lfp_sfreq: float,
                          event_dict: Dict[str, int],
                          n_sessions: int,
                          show_fig: bool = True,
                          save_fig: bool = True,
                          file_name: str = 'event_occurrence.png') -> None:
    """
    Creates a horizontal bar plot of event occurrences for each session with different colors for each event type.

    Each session gets one subplot on a 4-column grid. Every event is drawn as a
    bar on its event type's row. The bar is ``epoch_sample_length / lfp_sfreq``
    seconds long and ends 0.7 s before the event onset. Colours:
    - orange for ``event_dict['mod_start']`` (default ID 1);
    - blue for ``event_dict['normal_walking']`` (default ID -1);
    - gray for any other event.
    A faint dashed vertical line marks every onset. Each y tick label shows the
    event name and its count in that session.

    Parameters
    ----------
    events : np.ndarray
        Event array. Column 0 is the onset sample, column 1 the session index
        (matched against ``range(n_sessions)``) and column 2 the event ID.
    epoch_sample_length : int
        Epoch length in samples, which sets the bar length.
    lfp_sfreq : float
        Sampling frequency used to convert samples to seconds.
    event_dict : Dict[str, int]
        Mapping of event names to event IDs. IDs that are missing from it are
        labelled with their string form.
    n_sessions : int
        Number of sessions (subplots). Sessions without events are blanked.
    show_fig : bool, optional
        Display the figure. Default True.
    save_fig : bool, optional
        Save the figure to ``file_name``. Default True.
    file_name : str, optional
        Output path passed to ``savefig``.
    """
    n_cols = 4
    n_rows = math.ceil(n_sessions / n_cols)
    # squeeze=False always returns a 2-D array of axes, even for < 4 sessions
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 4 * n_rows), squeeze=False)

    mod_start_event_id = event_dict.get('mod_start', 1)
    normal_walking_event_id = event_dict.get('normal_walking', -1)
    # Event ID -> name, used for the categorical y-axis and the tick labels
    inv_event_dict = {v: k for k, v in event_dict.items()}

    def event_name(event_id):
        # Unknown IDs fall back to their string form. Mixing raw integers with
        # name strings on a categorical axis makes matplotlib raise.
        return inv_event_dict.get(event_id, str(event_id))

    def event_color(event_id):
        if event_id == mod_start_event_id:
            return 'orange'
        if event_id == normal_walking_event_id:
            return 'blue'
        return 'gray'

    epoch_duration = epoch_sample_length / lfp_sfreq
    for s, ax in zip(range(n_sessions), axes.flat):
        session_data = events[events[:, 1] == s]
        if len(session_data) == 0:
            ax.set_title(f'Session {s}')
            ax.axis('off')
            continue
        # Sorting by event ID fixes the order in which categories appear on the y-axis
        session_data = session_data[np.argsort(session_data[:, 2])]
        events_time = session_data[:, 0] / lfp_sfreq
        event_ids = session_data[:, 2]
        unique_event_ids, counts = np.unique(event_ids, return_counts=True)
        event_counts = dict(zip(unique_event_ids, counts))

        # One barh call per event type, since all bars of a type share a row and a colour.
        # NOTE(review): the extra 0.7 s left shift is a hard-coded offset whose
        # purpose is not evident here — confirm.
        for event_id in unique_event_ids:
            ax.barh(event_name(event_id), width=epoch_duration,
                    left=events_time[event_ids == event_id] - epoch_duration - 0.7,
                    color=event_color(event_id), edgecolor='black')
        for onset in events_time:
            ax.axvline(x=onset, color='black', linestyle='--', linewidth=1, alpha=0.2)
        ax.set_title(f'Session {s}', fontsize=13)

        # Tick every known event type (count 0 if absent from this session) plus
        # any unknown IDs that occurred here.
        tick_ids = list(inv_event_dict) + [eid for eid in unique_event_ids if eid not in inv_event_dict]
        ax.set_yticks([event_name(eid) for eid in tick_ids])
        ax.set_yticklabels([f"{event_name(eid)} ({event_counts.get(eid, 0)})" for eid in tick_ids],
                           va='center', rotation=90, fontsize=10)

    fig.supxlabel('Time (s)', fontsize=15)
    fig.supylabel('Event class', fontsize=15)
    # Hide grid cells beyond the number of sessions
    for ax in axes.flat[n_sessions:]:
        ax.axis('off')
    plt.subplots_adjust(left=0.05, bottom=0.05)  # Adjust margins as needed
    if save_fig:
        fig.savefig(file_name)
        print(f"Plot saved as {file_name}")
    if show_fig:
        plt.show()
    plt.close(fig)
[docs]
@staticmethod
def plot_raw_data_with_annotations(lfp_raw_list, scaling=5e1, folder_path='images'):
    """
    Plot each raw LFP recording over its full duration and save it to disk.

    Each recording is plotted with ``Raw.plot`` (``show=False``), titled
    ``Session <i>`` and saved as ``<folder_path>/session<i>.png``.

    Parameters:
    lfp_raw_list : list of mne.io.Raw
        Raw LFP recordings, one per session. The list index is used as the
        session number.
    scaling : float
        Scaling passed to ``Raw.plot`` for the ``dbs`` channel type.
    folder_path : str
        Directory where the plots are saved. It is created if it does not exist.
    """
    if folder_path:
        os.makedirs(folder_path, exist_ok=True)
    for s, lfp_raw in enumerate(lfp_raw_list):
        fig = lfp_raw.plot(start=0, duration=np.inf, scalings=dict(dbs=scaling), show=False)
        fig.suptitle(f'Session {s}', fontsize=16)
        plt.tight_layout()
        plt.savefig(os.path.join(folder_path, f'session{s}.png'))
        plt.close(fig)
[docs]
@staticmethod
def plot_single_patient_trial_psd(
    subjects_lfp_data_dict: Dict[str, List[np.ndarray]],
    patient_name: str,
    trial_idx: int,
    sfreq: float,
    fix_chs_names: Optional[List[str]] = None,
    save_path: Optional[str] = None,
    fig_name: Optional[str] = None,
    subjects_session_trial_mapping: Optional[Dict[str, Dict[str, List[int]]]] = None,
    fmin: float = 0.1,
    fmax: float = 100,
    figsize: tuple = (18, 8),
    show_fig: bool = True
) -> None:
    """
    Plots Power Spectral Density (PSD) for a single patient's specific trial continuous LFP signal.

    The PSD of every channel of the trial is computed with
    ``psd_array_multitaper`` (adaptive weighting, full normalization) and
    converted to dB (``10 * log10(psd)``). All channels are drawn on the same
    axes. Canonical frequency bands are shaded and labelled:
    Delta, Theta, Alpha, Low/High Beta and Gamma. Only bands that lie fully
    inside ``[fmin, fmax]`` are shown. The session name is added to the title
    when it can be found.

    Parameters:
    -----------
    subjects_lfp_data_dict : Dict[str, List[np.ndarray]]
        Patient name -> list of trials. Each trial has shape
        (num_channels, num_samples).
    patient_name : str
        Patient to plot. Must be a key of ``subjects_lfp_data_dict``.
    trial_idx : int
        0-based index of the trial to plot.
    sfreq : float
        Sampling frequency of the data in Hz.
    fix_chs_names : List[str], optional
        Channel names for the legend. Channels without a name (or all channels
        when None) are labelled ``Ch<index>``.
    save_path : str, optional
        Directory to save the figure into (created if missing). If None, the
        figure is not saved.
    fig_name : str, optional
        File name of the saved figure. ``.png`` is appended unless the name
        already has a format matplotlib can write. If None, the name
        ``'{patient_name}_trial{trial_idx}_psd.png'`` is used.
    subjects_session_trial_mapping : Dict[str, Dict[str, List[int]]], optional
        ``{patient_id: {session_name: [trial_indices]}}``, used to show the
        session in the title.
    fmin, fmax : float, optional
        Frequency range (Hz) of the PSD computation and the x-axis.
        Defaults are 0.1 and 100.
    figsize : tuple, optional
        Figure size (width, height) in inches. Default (18, 8).
    show_fig : bool, optional
        Display the figure. Default True.

    Returns:
    --------
    None

    Raises:
    -------
    ValueError
        If ``patient_name`` is not in ``subjects_lfp_data_dict``, or
        ``trial_idx`` is not below the number of trials for that patient.
    """
    # Validate inputs
    if patient_name not in subjects_lfp_data_dict:
        raise ValueError(f"Patient '{patient_name}' not found in data dictionary")
    patient_trials = subjects_lfp_data_dict[patient_name]
    if trial_idx >= len(patient_trials):
        raise ValueError(f"Trial index {trial_idx} exceeds available trials ({len(patient_trials)}) for patient {patient_name}")

    trial_data = patient_trials[trial_idx]  # (num_channels, num_samples)
    n_channels, n_samples = trial_data.shape

    # Session containing this trial, if a mapping was provided
    session_info = ""
    if subjects_session_trial_mapping and patient_name in subjects_session_trial_mapping:
        for sess_name, trial_indices in subjects_session_trial_mapping[patient_name].items():
            if trial_idx in trial_indices:
                session_info = f" (Session: {sess_name})"
                break

    # Compute the PSD for all channels in one call, before any figure exists
    print(f"Computing PSD for {patient_name}, Trial {trial_idx}...")
    psd, freqs = psd_array_multitaper(trial_data, sfreq, fmin=fmin, fmax=fmax,
                                      adaptive=True, normalization='full', verbose=False)
    psd_db = 10 * np.log10(psd)

    fig, ax = plt.subplots(1, 1, figsize=figsize)
    # Leave room on the right for the channel legend
    fig.subplots_adjust(left=0.08, right=0.82, top=0.90, bottom=0.12)

    # Up to 10 distinct tab10 colours, cycled when there are more channels
    channel_colors = plt.cm.tab10(np.linspace(0, 1, min(10, n_channels)))
    for ch_idx in range(n_channels):
        ch_name = (fix_chs_names[ch_idx]
                   if fix_chs_names is not None and ch_idx < len(fix_chs_names)
                   else f'Ch{ch_idx}')
        ax.plot(freqs, psd_db[ch_idx], color=channel_colors[ch_idx % len(channel_colors)],
                linewidth=1.5, alpha=0.8, label=ch_name)

    ax.set_xlabel('Frequency (Hz)', fontsize=12)
    ax.set_ylabel('PSD (dB)', fontsize=12)
    ax.grid(True, alpha=0.3)
    ax.set_xlim(fmin, fmax)

    freq_bands = [(0.1, 3, 'Delta'), (4, 7, 'Theta'), (8, 12, 'Alpha'),
                  (12, 16, 'Low Beta'), (16, 30, 'High Beta'), (30, 100, 'Gamma')]
    # Labels use x in data units (Hz) and y as an axes fraction, so they stay
    # near the top whatever the sign of the dB values. Scaling a negative data
    # y-limit by 0.95 would push them above the axes.
    band_label_transform = ax.get_xaxis_transform()
    for band_lo, band_hi, band_name in freq_bands:
        if band_lo >= fmin and band_hi <= fmax:
            ax.axvspan(band_lo, band_hi, alpha=0.1, color='gray')
            ax.text((band_lo + band_hi) / 2, 0.95, band_name, transform=band_label_transform,
                    ha='center', va='top', fontsize=9, alpha=0.7, rotation=90)

    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize=10)

    duration = n_samples / sfreq
    title = f'{patient_name} - Trial {trial_idx}{session_info}\n'
    title += f'Continuous Signal PSD (Duration: {duration:.1f}s, Fs: {sfreq}Hz, Samples: {n_samples:,})'
    fig.suptitle(title, fontsize=14, fontweight='bold')

    # Save once: create the target directory first, then pick the file name
    if save_path:
        os.makedirs(save_path, exist_ok=True)
        if fig_name:
            file_name = fig_name
            # Append .png unless the name already carries a format matplotlib can write
            if os.path.splitext(file_name)[1].lstrip('.').lower() not in fig.canvas.get_supported_filetypes():
                file_name += '.png'
        else:
            safe_patient = patient_name.replace('/', '_').replace(' ', '_')
            file_name = f'{safe_patient}_trial{trial_idx}_psd.png'
        full_path = os.path.join(save_path, file_name)
        fig.savefig(full_path, dpi=300, bbox_inches='tight')
        print(f"Figure saved to: {full_path}")
    if show_fig:
        plt.show()
    plt.close(fig)
[docs]
@staticmethod
def plot_single_patient_trial_spectrogram(
    subjects_lfp_data_dict: Dict[str, List[np.ndarray]],
    patient_name: str,
    trial_idx: int,
    sfreq: float,
    fix_chs_names: Optional[List[str]] = None,
    save_path: Optional[str] = None,
    fig_name: Optional[str] = None,
    subjects_session_trial_mapping: Optional[Dict[str, Dict[str, List[int]]]] = None,
    fmin: float = 0.1,
    fmax: float = 100,
    channel_idx: int = 0,
    figsize: tuple = (16, 10),
    show_fig: bool = True
) -> None:
    """
    Plots the raw signal and spectrogram of one channel of a patient's trial.

    The top panel shows the raw trace of ``channel_idx``. The bottom panel shows
    its ``scipy.signal.spectrogram``, computed with:
    - a window of ``min(sfreq, n_samples // 4)`` samples;
    - 50 % overlap;
    - density scaling.
    The spectrogram is restricted to ``[fmin, fmax]`` and converted to dB.
    Dashed white lines mark the canonical frequency bands that lie fully inside
    that range. A short summary (session, duration, resolution) is printed.

    Parameters:
    -----------
    subjects_lfp_data_dict : Dict[str, List[np.ndarray]]
        Patient name -> list of trials. Each trial has shape
        (num_channels, num_samples).
    patient_name : str
        Patient to plot.
    trial_idx : int
        0-based index of the trial to plot.
    sfreq : float
        Sampling frequency of the data in Hz.
    fix_chs_names : List[str], optional
        Channel names. A channel without a name (or any channel when None) is
        labelled ``Ch<index>``.
    save_path : str, optional
        Directory to save the figure into (created if missing). If None, the
        figure is not saved.
    fig_name : str, optional
        File name of the saved figure. ``.png`` is appended unless the name
        already has a format matplotlib can write. If None, the name
        ``'{patient}_trial{trial_idx}_ch{channel_idx}_spectrogram.png'`` is used.
    subjects_session_trial_mapping : Dict[str, Dict[str, List[int]]], optional
        ``{patient_id: {session_name: [trial_indices]}}``, used for the title
        and summary.
    fmin, fmax : float, optional
        Frequency range (Hz) shown. Defaults are 0.1 and 100.
    channel_idx : int, optional
        Channel to plot. Default 0.
    figsize : tuple, optional
        Figure size (width, height). Default (16, 10).
    show_fig : bool, optional
        Display the figure. Default True.

    Returns:
    --------
    None

    Raises:
    -------
    ValueError
        If ``patient_name`` is unknown, or ``trial_idx`` / ``channel_idx`` is not
        below the number of trials / channels.
    """
    # Validate inputs
    if patient_name not in subjects_lfp_data_dict:
        raise ValueError(f"Patient '{patient_name}' not found in data dictionary")
    patient_trials = subjects_lfp_data_dict[patient_name]
    if trial_idx >= len(patient_trials):
        raise ValueError(f"Trial index {trial_idx} exceeds available trials ({len(patient_trials)}) for patient {patient_name}")
    trial_data = patient_trials[trial_idx]  # (num_channels, num_samples)
    n_channels, n_samples = trial_data.shape
    if channel_idx >= n_channels:
        raise ValueError(f"Channel index {channel_idx} exceeds available channels ({n_channels})")
    channel_data = trial_data[channel_idx, :]
    ch_name = (fix_chs_names[channel_idx]
               if fix_chs_names is not None and channel_idx < len(fix_chs_names)
               else f'Ch{channel_idx}')

    # Session containing this trial, if a mapping was provided
    session_info = ""
    session_name = "Unknown"
    if subjects_session_trial_mapping and patient_name in subjects_session_trial_mapping:
        for sess_name, trial_indices in subjects_session_trial_mapping[patient_name].items():
            if trial_idx in trial_indices:
                session_info = f" (Session: {sess_name})"
                session_name = sess_name
                break

    print(f"Computing spectrogram for {patient_name}, Trial {trial_idx}, Channel {channel_idx}...")
    from scipy import signal

    # 1 s window (shorter for short trials) with 50 % overlap
    nperseg = int(min(sfreq, n_samples // 4))
    noverlap = nperseg // 2
    frequencies, times, Sxx = signal.spectrogram(
        channel_data,
        fs=sfreq,
        nperseg=nperseg,
        noverlap=noverlap,
        scaling='density'
    )
    freq_mask = (frequencies >= fmin) & (frequencies <= fmax)
    frequencies_filtered = frequencies[freq_mask]
    Sxx_db = 10 * np.log10(Sxx[freq_mask, :] + 1e-10)  # offset avoids log10(0)

    # Raw signal on top, spectrogram below
    fig, (ax_raw, ax_spec) = plt.subplots(2, 1, figsize=figsize, height_ratios=[1, 3])
    fig.subplots_adjust(left=0.08, right=0.95, top=0.92, bottom=0.08, hspace=0.3)

    time_vector = np.arange(n_samples) / sfreq
    ax_raw.plot(time_vector, channel_data, color='steelblue', linewidth=0.8, alpha=0.8)
    ax_raw.set_ylabel(f'{ch_name}\nAmplitude (µV)', fontsize=10)
    ax_raw.grid(True, alpha=0.3)
    ax_raw.set_xlim(0, time_vector[-1])
    ax_raw.set_title(f'Raw Signal - {ch_name}', fontsize=12, pad=10)

    im = ax_spec.pcolormesh(times, frequencies_filtered, Sxx_db,
                            shading='gouraud', cmap='viridis')
    cbar = fig.colorbar(im, ax=ax_spec, shrink=0.8, aspect=30)
    cbar.set_label('Power Spectral Density (dB/Hz)', fontsize=11)
    ax_spec.set_ylabel('Frequency (Hz)', fontsize=12)
    ax_spec.set_xlabel('Time (s)', fontsize=12)
    ax_spec.set_ylim(fmin, fmax)
    ax_spec.set_xlim(0, times[-1])
    ax_spec.grid(True, alpha=0.3)

    # Frequency-band boundaries as horizontal lines, with labels just right of the axes
    freq_bands = [(0.1, 3, 'Delta'), (4, 7, 'Theta'), (8, 12, 'Alpha'),
                  (12, 16, 'Low Beta'), (16, 30, 'High Beta'), (30, 100, 'Gamma')]
    for band_lo, band_hi, band_name in freq_bands:
        if band_lo >= fmin and band_hi <= fmax:
            ax_spec.axhline(y=band_lo, color='white', alpha=0.6, linewidth=0.8, linestyle='--')
            ax_spec.axhline(y=band_hi, color='white', alpha=0.6, linewidth=0.8, linestyle='--')
            ax_spec.text(times[-1] * 1.02, (band_lo + band_hi) / 2, band_name,
                         rotation=90, ha='left', va='center',
                         fontsize=9, alpha=0.8, color='white',
                         bbox=dict(boxstyle='round,pad=0.2', facecolor='black', alpha=0.7))

    duration = n_samples / sfreq
    title = f'{patient_name} - Trial {trial_idx}{session_info}\n'
    title += f'Spectrogram: {ch_name} (Duration: {duration:.1f}s, Fs: {sfreq}Hz)'
    fig.suptitle(title, fontsize=14, fontweight='bold')

    info_text = f'Channel: {ch_name}\nWindow: {nperseg/sfreq:.2f}s\nOverlap: {noverlap/nperseg*100:.0f}%\nFreq Range: {fmin}-{fmax} Hz'
    fig.text(0.02, 0.02, info_text, fontsize=9,
             bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray", alpha=0.8))

    # Finalise the layout before saving so the file matches what is shown
    fig.tight_layout()
    if save_path:
        os.makedirs(save_path, exist_ok=True)
        if fig_name:
            file_name = fig_name
            # Append .png unless the name already carries a format matplotlib can write
            if os.path.splitext(file_name)[1].lstrip('.').lower() not in fig.canvas.get_supported_filetypes():
                file_name += '.png'
        else:
            safe_patient = patient_name.replace('/', '_').replace(' ', '_')
            file_name = f'{safe_patient}_trial{trial_idx}_ch{channel_idx}_spectrogram.png'
        full_path = os.path.join(save_path, file_name)
        fig.savefig(full_path, dpi=300, bbox_inches='tight')
        print(f"Figure saved to: {full_path}")

    print(f"\nSpectrogram Summary for {patient_name}, Trial {trial_idx}, Channel {ch_name}:")
    print(f"Session: {session_name}")
    print(f"Duration: {duration:.1f} seconds")
    print(f"Frequency range: {fmin:.1f} - {fmax:.1f} Hz")
    # Resolutions come from the STFT parameters (hop and bin spacing), so a short
    # output cannot raise an IndexError.
    print(f"Time resolution: {(nperseg - noverlap) / sfreq:.3f} seconds")
    print(f"Frequency resolution: {sfreq / nperseg:.3f} Hz")
    if show_fig:
        plt.show()
    plt.close(fig)
[docs]
@staticmethod
def plot_epoch_spectrogram_and_psd(
    signal_orig,
    signal_filt,
    patient_name='Patient',
    epoch_idx=0,
    channel_idx=0,
    fmin=0.1,
    fmax=50,
    psd_fmax=100,
    nperseg=None,
    noverlap=None,
    figsize=(16, 12),
    save_path=None,
    fig_name=None,
    show_fig=True
):
    """
    Compare one epoch/channel before and after filtering in time, time-frequency and frequency.

    Layout (3 rows, plus a colorbar column):
    - row 1: original and filtered traces, plus their difference with +/- RMS lines;
    - row 2: spectrograms of the original and the filtered signal (shared colour
      scale), plus their dB difference;
    - row 3: multitaper PSDs (0.1 Hz to ``psd_fmax``) of both signals, plus their
      dB difference.

    Parameters
    ----------
    signal_orig, signal_filt : mne.Epochs-like
        Objects exposing ``info['sfreq']`` and ``get_data()``. The returned array
        is indexed ``[epoch, channel, sample]``. The sampling rate is read from
        ``signal_orig``.
    patient_name : str
        Shown in the raw-signal title.
    epoch_idx, channel_idx : int
        0-based indices of the epoch and channel to plot (shown 1-based in titles).
    fmin, fmax : float
        Frequency range (Hz) of the spectrogram panels.
    psd_fmax : float
        Upper frequency (Hz) of the PSD computation and of the PSD x-axis.
    nperseg, noverlap : int, optional
        Spectrogram segment length and overlap, in samples. Defaults are
        ``min(sfreq, n_samples // 4)`` and half of ``nperseg``.
    figsize : tuple
        Figure size in inches.
    save_path : str, optional
        Directory to save the figure into (created if missing).
    fig_name : str, optional
        File name. ``.png`` is appended unless the name already has a format
        matplotlib can write. Defaults to
        ``<patient_name>_epoch<epoch_idx>_ch<channel_idx>_spectrogram_psd``.
    show_fig : bool
        Display the figure.
    """
    from scipy import signal as sp_signal
    import matplotlib.gridspec as gridspec

    # Read the sampling rate before reducing the objects to 1-D arrays
    sfreq = signal_orig.info['sfreq']
    orig = signal_orig.get_data()[epoch_idx, channel_idx, :]
    filt = signal_filt.get_data()[epoch_idx, channel_idx, :]
    time_vector = np.arange(orig.size) / sfreq  # seconds
    signal_diff = orig - filt
    rms_diff = np.sqrt(np.mean(signal_diff ** 2))

    if nperseg is None:
        nperseg = int(min(sfreq, orig.size // 4))
    if noverlap is None:
        noverlap = nperseg // 2

    def _spectrogram_db(x):
        # Spectrogram of x restricted to [fmin, fmax], in dB (offset avoids log10(0))
        f_all, t_seg, sxx = sp_signal.spectrogram(x, fs=sfreq, nperseg=nperseg,
                                                  noverlap=noverlap, scaling='density')
        keep = (f_all >= fmin) & (f_all <= fmax)
        return f_all[keep], t_seg, 10 * np.log10(sxx[keep, :] + 1e-10)

    def _psd_db(x):
        # Multitaper PSD of x from 0.1 Hz to psd_fmax, in dB
        psd, f_psd = psd_array_multitaper(x[np.newaxis, :], sfreq, fmin=0.1, fmax=psd_fmax,
                                          adaptive=True, normalization='full', verbose=False)
        return f_psd, 10 * np.log10(psd[0])

    frequencies_filtered, times, Sxx_orig_db = _spectrogram_db(orig)
    _, _, Sxx_filt_db = _spectrogram_db(filt)
    Sxx_diff_db = Sxx_orig_db - Sxx_filt_db
    freqs, psd_db_orig = _psd_db(orig)
    _, psd_db_filt = _psd_db(filt)
    psd_db_diff = psd_db_orig - psd_db_filt

    fig = plt.figure(figsize=figsize)
    gs = gridspec.GridSpec(3, 4, width_ratios=[1, 1, 1, 0.05], height_ratios=[1, 1, 1])

    # Row 1: raw signals and their difference
    ax_raw = fig.add_subplot(gs[0, 0:2])
    ax_raw.plot(time_vector, orig, color='steelblue', label='Original')
    ax_raw.plot(time_vector, filt, color='darkred', label='Filtered', alpha=0.7)
    ax_raw.set_title(f'Raw Signal - {patient_name} Epoch {epoch_idx+1} Ch{channel_idx+1}')
    ax_raw.set_xlabel('Time (s)')
    ax_raw.set_ylabel('Amplitude (µV)')
    ax_raw.grid(True, alpha=0.3)
    ax_raw.legend()

    ax_diff = fig.add_subplot(gs[0, 2])
    ax_diff.plot(time_vector, signal_diff, color='purple')
    ax_diff.axhline(rms_diff, color='orange', linestyle='--', label=f'RMS = {rms_diff:.2f}')
    ax_diff.axhline(-rms_diff, color='orange', linestyle='--')
    ax_diff.set_title('Signal Difference\n(Original - Filtered)')
    ax_diff.set_xlabel('Time (s)')
    ax_diff.set_ylabel('Amplitude Diff (µV)')
    ax_diff.grid(True, alpha=0.3)
    ax_diff.legend()

    # Row 2: spectrograms. Original and filtered share the original's colour scale;
    # the difference uses a symmetric diverging scale.
    spec_vmin, spec_vmax = np.min(Sxx_orig_db), np.max(Sxx_orig_db)
    diff_lim = np.max(np.abs(Sxx_diff_db))
    ax1 = fig.add_subplot(gs[1, 0])
    ax2 = fig.add_subplot(gs[1, 1], sharey=ax1)
    ax3 = fig.add_subplot(gs[1, 2], sharey=ax1)
    pcm1 = ax1.pcolormesh(times, frequencies_filtered, Sxx_orig_db, shading='gouraud',
                          cmap='viridis', vmin=spec_vmin, vmax=spec_vmax)
    ax2.pcolormesh(times, frequencies_filtered, Sxx_filt_db, shading='gouraud',
                   cmap='viridis', vmin=spec_vmin, vmax=spec_vmax)
    ax3.pcolormesh(times, frequencies_filtered, Sxx_diff_db, shading='gouraud',
                   cmap='bwr', vmin=-diff_lim, vmax=diff_lim)
    ax1.set_ylabel('Frequency (Hz)')
    ax1.set_xlabel('Time (s)')
    ax1.set_ylim(fmin, fmax)
    ax1.set_title('Spectrogram (Original)')
    ax2.set_xlabel('Time (s)')
    ax2.set_title('Spectrogram (Filtered)')
    plt.setp(ax2.get_yticklabels(), visible=False)
    ax3.set_xlabel('Time (s)')
    ax3.set_title('Spectrogram Difference (dB)')
    plt.setp(ax3.get_yticklabels(), visible=False)

    # Shared colorbar for the original/filtered spectrograms
    cax = fig.add_subplot(gs[1, 3])
    cbar = fig.colorbar(pcm1, cax=cax)
    cbar.set_label('PSD (dB/Hz)')

    # Row 3: PSDs and their difference
    ax_psd = fig.add_subplot(gs[2, 0:2])
    ax_psd.plot(freqs, psd_db_orig, color='steelblue', label='Original')
    ax_psd.plot(freqs, psd_db_filt, color='darkred', label='Filtered', alpha=0.7)
    ax_psd.set_title(f'PSD - Original & Filtered\nCh{channel_idx+1}')
    ax_psd.set_xlabel('Frequency (Hz)')
    ax_psd.set_ylabel('PSD (dB)')
    ax_psd.grid(True, alpha=0.3)
    ax_psd.set_xlim(0, psd_fmax)
    ax_psd.legend()

    ax_psd_diff = fig.add_subplot(gs[2, 2])
    ax_psd_diff.plot(freqs, psd_db_diff, color='purple')
    ax_psd_diff.set_title('PSD Diff (Orig - Filt)')
    ax_psd_diff.set_xlabel('Frequency (Hz)')
    ax_psd_diff.set_ylabel('PSD Diff (dB)')
    ax_psd_diff.grid(True, alpha=0.3)
    ax_psd_diff.set_xlim(0, psd_fmax)

    fig.tight_layout()
    if save_path:
        # save_path is the target directory: create it, not only its parent
        os.makedirs(save_path, exist_ok=True)
        if fig_name is None:
            safe_patient = str(patient_name).replace('/', '_').replace(' ', '_')
            fig_name = f'{safe_patient}_epoch{epoch_idx}_ch{channel_idx}_spectrogram_psd'
        file_name = fig_name
        # Append .png unless the name already carries a format matplotlib can write
        if os.path.splitext(file_name)[1].lstrip('.').lower() not in fig.canvas.get_supported_filetypes():
            file_name += '.png'
        fig.savefig(os.path.join(save_path, file_name), dpi=300, bbox_inches='tight')
    if show_fig:
        plt.show()
    plt.close(fig)
[docs]
@staticmethod
def plot_epoch_psd_comparison(epochs_orig, epochs_filt, epoch_idx, fix_chs_names=None, save_path=None, fig_name='epoch_psd_comparison', show_fig=True):
    """
    Plot PSD comparison for a single epoch (original vs filtered) for all channels.

    The left panel shows the multitaper PSD (0.1-100 Hz, in dB) of every channel
    of the original epoch. The right panel shows the same for the filtered epoch.
    Both panels share the y-axis.

    Parameters
    ----------
    epochs_orig, epochs_filt : mne.Epochs-like
        Objects exposing ``info['sfreq']`` and ``get_data()``. The returned array
        is indexed ``[epoch, channel, sample]``. The sampling rate is read from
        ``epochs_orig``.
    epoch_idx : int
        0-based epoch index (shown 1-based in the titles).
    fix_chs_names : list of str, optional
        Channel names for the legend. Unnamed channels are labelled ``Ch<n>``,
        with n 1-based.
    save_path : str, optional
        Directory to save the figure into (created if missing).
    fig_name : str, optional
        File name. ``.png`` is appended unless the name already has a format
        matplotlib can write.
    show_fig : bool, optional
        Display the figure.
    """
    sfreq = epochs_orig.info['sfreq']
    epoch_data_orig = epochs_orig.get_data()[epoch_idx]
    epoch_data_filt = epochs_filt.get_data()[epoch_idx]
    psd_orig, freqs = psd_array_multitaper(epoch_data_orig, sfreq, fmin=0.1, fmax=100, adaptive=True, normalization='full', verbose=False)
    psd_filt, _ = psd_array_multitaper(epoch_data_filt, sfreq, fmin=0.1, fmax=100, adaptive=True, normalization='full', verbose=False)
    psd_db_orig = 10 * np.log10(psd_orig)
    psd_db_filt = 10 * np.log10(psd_filt)

    fig, axes = plt.subplots(1, 2, figsize=(16, 6), sharey=True)
    for ch in range(epoch_data_orig.shape[0]):
        ch_name = (fix_chs_names[ch]
                   if fix_chs_names is not None and ch < len(fix_chs_names)
                   else f'Ch{ch+1}')
        axes[0].plot(freqs, psd_db_orig[ch], label=ch_name)
        axes[1].plot(freqs, psd_db_filt[ch], label=ch_name)
    axes[0].set_title(f'Original PSD\nEpoch {epoch_idx+1}')
    axes[1].set_title(f'Filtered PSD\nEpoch {epoch_idx+1}')
    for ax in axes:
        ax.set_xlabel('Frequency (Hz)')
        ax.set_ylabel('PSD (dB)')
        ax.set_xlim(0, 100)
        ax.grid(True, alpha=0.3)
    axes[1].legend()
    fig.tight_layout()

    if save_path:
        # save_path is the target directory: create it, not only its parent
        os.makedirs(save_path, exist_ok=True)
        file_name = fig_name
        # Append .png unless the name already carries a format matplotlib can write
        if os.path.splitext(file_name)[1].lstrip('.').lower() not in fig.canvas.get_supported_filetypes():
            file_name += '.png'
        fig.savefig(os.path.join(save_path, file_name), dpi=300, bbox_inches='tight')
    if show_fig:
        plt.show()
    plt.close(fig)