实现基于深度学习ECG心电信号分类,用多个数据集(MIT-BIH心率不齐数据库、 ...

诗林  论坛元老 | 2025-3-12 20:57:25 | 显示全部楼层 | 阅读模式
打印 上一主题 下一主题

主题 1331|帖子 1331|积分 3993

基于深度学习的ECG心信号分类

对人体的心电信号进行分类,判断出被测试者心跳是否正常,或患有什么样的心脏疾病,终极实现心电数据的分类。此中包罗CNN,LSTM,GRU等模子对比。

数据集利用的是以下四个数据集的合并:

  • MIT-BIH心率不齐数据库
  • MIT-BIH ST变革数据库
  • 欧盟ST-T心电数据库
  • 心脏性猝死动态心电数据库

    实现基于深度学习的ECG心电信号分类,我们可以利用多个数据集(MIT-BIH心率不齐数据库、MIT-BIH ST变革数据库、欧盟ST-T心电数据库和心脏性猝死动态心电数据库)来练习和评估模子。我们将对比不同的模子,如CNN、LSTM和GRU,以确定哪种模子在心电信号分类任务上体现最佳。
项目概述

以下是项目标详细步调:

  • 数据网络与预处理

    • 下载并合并四个数据集。
    • 对数据进行清洗和预处理,包罗去噪、归一化等。

  • 特性提取

    • 将原始心电信号转换为得当模子输入的形式。

  • 模子构建

    • 构建CNN、LSTM和GRU模子。
    • 练习并评估每个模子的体现。

  • 结果分析

    • 比力不同模子的性能指标,如准确率、精确率、召回率、F1分数等。

  • 可视化

    • 可视化练习过程中的损失和准确率曲线。
    • 可视化肴杂矩阵。

  • 摆设

    • 创建一个简朴的GUI界面来进行及时猜测。

数据集下载与合并

首先,我们需要下载并合并四个数据集。这里假设你已经下载了这些数据集,并将它们存储在一个文件夹中。
数据集路径配置

  1. % Configuration
  2. data_folder = 'path/to/data'; % Path to the folder containing datasets
  3. output_folder = 'path/to/output'; % Path to save preprocessed data and models
复制代码
数据预处理

加载和预处理数据

  1. [<title="Data Preprocessing for ECG Classification">]
  2. function [X_train, y_train, X_val, y_val, X_test, y_test] = preprocess_ecg_data(data_folder)
  3.     % Load datasets
  4.     mitbih_arrhythmia = load(fullfile(data_folder, 'mitbih_arrhythmia.mat'));
  5.     mitbih_st_change = load(fullfile(data_folder, 'mitbih_st_change.mat'));
  6.     eu_stt = load(fullfile(data_folder, 'eu_stt.mat'));
  7.     sudden_cardiac_death = load(fullfile(data_folder, 'sudden_cardiac_death.mat'));
  8.     % Extract signals and labels
  9.     signals = {};
  10.     labels = {};
  11.     % MIT-BIH Arrhythmia Database
  12.     if isfield(mitbih_arrhythmia, 'signals') && isfield(mitbih_arrhythmia, 'labels')
  13.         signals{end+1} = mitbih_arrhythmia.signals;
  14.         labels{end+1} = mitbih_arrhythmia.labels;
  15.     end
  16.     % MIT-BIH ST Change Database
  17.     if isfield(mitbih_st_change, 'signals') && isfield(mitbih_st_change, 'labels')
  18.         signals{end+1} = mitbih_st_change.signals;
  19.         labels{end+1} = mitbih_st_change.labels;
  20.     end
  21.     % EU ST-T Database
  22.     if isfield(eu_stt, 'signals') && isfield(eu_stt, 'labels')
  23.         signals{end+1} = eu_stt.signals;
  24.         labels{end+1} = eu_stt.labels;
  25.     end
  26.     % Sudden Cardiac Death Database
  27.     if isfield(sudden_cardiac_death, 'signals') && isfield(sudden_cardiac_death, 'labels')
  28.         signals{end+1} = sudden_cardiac_death.signals;
  29.         labels{end+1} = sudden_cardiac_death.labels;
  30.     end
  31.     % Concatenate all signals and labels
  32.     all_signals = vertcat(signals{:});
  33.     all_labels = vertcat(labels{:});
  34.     % Normalize signals
  35.     all_signals = zscore(all_signals);
  36.     % Split data into train, validation, and test sets
  37.     cv = cvpartition(size(all_signals, 1), 'HoldOut', 0.2);
  38.     idx_train = training(cv);
  39.     idx_test = test(cv);
  40.     X_train = all_signals(idx_train, :);
  41.     y_train = all_labels(idx_train);
  42.     X_test = all_signals(idx_test, :);
  43.     y_test = all_labels(idx_test);
  44.     % Further split training set into training and validation sets
  45.     cv_inner = cvpartition(sum(idx_train), 'HoldOut', 0.2);
  46.     idx_train_inner = training(cv_inner);
  47.     idx_val_inner = test(cv_inner);
  48.     X_val = X_train(idx_val_inner, :);
  49.     y_val = y_train(idx_val_inner);
  50.     X_train = X_train(idx_train_inner, :);
  51.     y_train = y_train(idx_train_inner);
  52. end
复制代码
模子构建与练习

我们将构建CNN、LSTM和GRU模子,并比力它们的性能。
CNN模子

  1. [<title="CNN Model for ECG Classification">]
  2. function model_cnn = build_cnn_model(input_shape, num_classes)
  3.     layers = [
  4.         inputLayer(input_shape)
  5.         convolution2dLayer([1 16], 16, 'Padding', 'same')
  6.         batchNormalizationLayer
  7.         reluLayer
  8.         maxPooling2dLayer(2, 'Stride', 2)
  9.         
  10.         convolution2dLayer([1 32], 32, 'Padding', 'same')
  11.         batchNormalizationLayer
  12.         reluLayer
  13.         maxPooling2dLayer(2, 'Stride', 2)
  14.         
  15.         fullyConnectedLayer(num_classes)
  16.         softmaxLayer
  17.         classificationLayer];
  18.    
  19.     options = trainingOptions('adam', ...
  20.         'MaxEpochs', 20, ...
  21.         'MiniBatchSize', 128, ...
  22.         'InitialLearnRate', 0.001, ...
  23.         'Plots', 'training-progress', ...
  24.         'Verbose', false);
  25.    
  26.     model_cnn = trainNetwork(X_train, categorical(y_train), layers, options);
  27. end
复制代码
LSTM模子

  1. [<title="LSTM Model for ECG Classification">]
  2. function model_lstm = build_lstm_model(input_shape, num_classes)
  3.     layers = [
  4.         sequenceInputLayer(input_shape(2))
  5.         lstmLayer(128, 'OutputMode', 'last')
  6.         dropoutLayer(0.5)
  7.         fullyConnectedLayer(num_classes)
  8.         softmaxLayer
  9.         classificationLayer];
  10.    
  11.     options = trainingOptions('adam', ...
  12.         'MaxEpochs', 20, ...
  13.         'GradientThreshold', 1, ...
  14.         'InitialLearnRate', 0.001, ...
  15.         'SequenceLength', 'longest', ...
  16.         'Plots', 'training-progress', ...
  17.         'Verbose', false);
  18.    
  19.     model_lstm = trainNetwork(X_train, categorical(y_train), layers, options);
  20. end
复制代码
GRU模子

  1. [<title="GRU Model for ECG Classification">]
  2. function model_gru = build_gru_model(input_shape, num_classes)
  3.     layers = [
  4.         sequenceInputLayer(input_shape(2))
  5.         gruLayer(128, 'OutputMode', 'last')
  6.         dropoutLayer(0.5)
  7.         fullyConnectedLayer(num_classes)
  8.         softmaxLayer
  9.         classificationLayer];
  10.    
  11.     options = trainingOptions('adam', ...
  12.         'MaxEpochs', 20, ...
  13.         'GradientThreshold', 1, ...
  14.         'InitialLearnRate', 0.001, ...
  15.         'SequenceLength', 'longest', ...
  16.         'Plots', 'training-progress', ...
  17.         'Verbose', false);
  18.    
  19.     model_gru = trainNetwork(X_train, categorical(y_train), layers, options);
  20. end
复制代码
模子评估与结果分析

评估每个模子并在图表中展示的结果。
评估函数

  1. [<title="Model Evaluation Function">]
  2. function evaluate_models(model_cnn, model_lstm, model_gru, X_val, y_val)
  3.     % Evaluate CNN model
  4.     YPred_cnn = classify(model_cnn, X_val);
  5.     accuracy_cnn = sum(YPred_cnn == y_val) / numel(y_val);
  6.     disp(['CNN Accuracy: ', num2str(accuracy_cnn)]);
  7.    
  8.     % Evaluate LSTM model
  9.     YPred_lstm = classify(model_lstm, X_val);
  10.     accuracy_lstm = sum(YPred_lstm == y_val) / numel(y_val);
  11.     disp(['LSTM Accuracy: ', num2str(accuracy_lstm)]);
  12.    
  13.     % Evaluate GRU model
  14.     YPred_gru = classify(model_gru, X_val);
  15.     accuracy_gru = sum(YPred_gru == y_val) / numel(y_val);
  16.     disp(['GRU Accuracy: ', num2str(accuracy_gru)]);
  17.    
  18.     % Plot confusion matrices
  19.     figure;
  20.     subplot(1, 3, 1);
  21.     cm_cnn = confusionchart(categorical(y_val), YPred_cnn);
  22.     title('Confusion Matrix (CNN)');
  23.    
  24.     subplot(1, 3, 2);
  25.     cm_lstm = confusionchart(categorical(y_val), YPred_lstm);
  26.     title('Confusion Matrix (LSTM)');
  27.    
  28.     subplot(1, 3, 3);
  29.     cm_gru = confusionchart(categorical(y_val), YPred_gru);
  30.     title('Confusion Matrix (GRU)');
  31. end
复制代码
主脚本 main_script.m

将全部步调整合到主脚本中。
  1. [<title="Main Script for ECG Classification">]% Main Script for ECG Classification% This script preprocesses the ECG data, builds and trains CNN, LSTM, and GRU models,% evaluates their performance, and visualizes the results.clear;clc;% Configuration
  2. data_folder = 'path/to/data'; % Path to the folder containing datasets
  3. output_folder = 'path/to/output'; % Path to save preprocessed data and models
  4. % Preprocess data[X_train, y_train, X_val, y_val, X_test, y_test] = preprocess_ecg_data(data_folder);% Reshape data for CNNinput_shape_cnn = [1, size(X_train, 2)];X_train_cnn = permute(X_train, [2, 1, 3]);X_val_cnn = permute(X_val, [2, 1, 3]);% Build and train CNN modelmodel_cnn = build_cnn_model(input_shape_cnn, length(unique(y_train)));% Build and train LSTM modelinput_shape_rnn = size(X_train, 2);model_lstm = build_lstm_model(input_shape_rnn, length(unique(y_train)));% Build and train GRU modelmodel_gru = build_gru_model(input_shape_rnn, length(unique(y_train)));% Evaluate modelsevaluate_models(model_cnn, model_lstm, model_gru, X_val_cnn, y_val);
复制代码
利用说明


  • 配置路径

    • 将 data_folder 设置为存放数据集的目录路径。
    • 将 output_folder 设置为生存预处理数据和模子的目标目录路径。

  • 运行脚本

    • 在 MATLAB 命令窗口中运行 main_script.m。
    • 脚本会自动读取 data_folder 中的数据集,对数据进行预处理,构建并练习CNN、LSTM和GRU模子,并评估其性能。

  • 注意事项

    • 确保全部须要的工具箱已安装,特别是 Deep Learning Toolbox 和 Signal Processing Toolbox。
    • 根据需要调整参数,如 MaxEpochs 和 MiniBatchSize。

示例

假设你的数据文件夹结构如下:
  1. data/
  2. ├── mitbih_arrhythmia.mat
  3. ├── mitbih_st_change.mat
  4. ├── eu_stt.mat
  5. └── sudden_cardiac_death.mat
复制代码
而且每个 .mat 文件中都有 signals 和 labels 变量。运行 main_script.m 后,MATLAB 将体现每个模子的准确性,并天生肴杂矩阵图表。
总结

通过上述 MATLAB 代码,你可以轻松地对心电信号进行分类,并对比不同模子的性能。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

诗林

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表