PYTHON训练营DAY31

打印 上一主题 下一主题

主题 1509|帖子 1509|积分 4527

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

x
项目拆分

src/data/load_data.py

  1. # -*- coding: utf-8 -*-
  2. import sys
  3. import io
  4. # 设置标准输出为 UTF-8 编码
  5. sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
  6. import pandas as pd
  7. def load_data(file_path: str) -> pd.DataFrame:
  8.     """加载数据文件
  9.     Args:
  10.         file_path: 数据文件路径
  11.     Returns:
  12.         加载的数据框
  13.     """
  14.     return pd.read_csv(file_path)
  15. if __name__ == "__main__":
  16.     # 测试代码
  17.     data = load_data("testDay31/data/raw/heart.csv")
  18.     print("数据读取完成!")
复制代码
src/data/preprocessing.py 

  1. # -*- coding: utf-8 -*-
  2. import sys
  3. import io
  4. import os
  5. # 设置标准输出为 UTF-8 编码
  6. sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
  7. import pandas as pd
  8. import numpy as np
  9. from typing import Tuple, Dict
  10. def load_data(file_path: str) -> pd.DataFrame:
  11.     """加载数据文件
  12.     Args:
  13.         file_path: 数据文件路径
  14.     Returns:
  15.         加载的数据框
  16.     """
  17.     return pd.read_csv(file_path)
  18. # 仅以处理缺失值为例
  19. def handle_missing_values(data: pd.DataFrame) -> pd.DataFrame:
  20.     """处理缺失值
  21.     Args:
  22.         data: 包含缺失值的数据框
  23.     Returns:
  24.         处理后的数据框
  25.     """
  26.     data_clean = data.copy()
  27.     continuous_features = data.select_dtypes(include=['int64', 'float64']).columns.tolist()
  28.    
  29.     for feature in continuous_features:
  30.         mode_value = data[feature].mode()[0]
  31.         data_clean[feature].fillna(mode_value, inplace=True)
  32.    
  33.     return data_clean
  34. if __name__ == "__main__":
  35.     # 测试代码
  36.     data = load_data("testDay31/data/raw/heart.csv")
  37.     # data_encoded, mappings = encode_categorical_features(data)
  38.     data_clean = handle_missing_values(data)
  39.     print("数据预处理完成!")
复制代码
models/train.py

  1. # -*- coding: utf-8 -*-
  2. import sys
  3. import os
  4. import io
  5. # 设置标准输出为 UTF-8 编码
  6. sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
  7. sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  8. from sklearn.ensemble import RandomForestClassifier
  9. from sklearn.model_selection import train_test_split
  10. from sklearn.metrics import classification_report, confusion_matrix
  11. import time
  12. import joblib # 用于保存模型
  13. from typing import Tuple # 用于类型注解
  14. from data.preprocessing import  load_data,handle_missing_values
  15. # from data.load_data import load_data
  16. def prepare_data() -> Tuple:
  17.     """准备训练数据
  18.     Returns:
  19.         训练集和测试集的特征和标签
  20.     """
  21.     # 加载和预处理数据
  22.     data = load_data("testDay31/data/raw/heart.csv")
  23.     data_clean = handle_missing_values(data)
  24.    
  25.     # 分离特征和标签
  26.     X = data_clean.drop(['target'], axis=1)
  27.     y = data_clean['target']
  28.    
  29.     # 划分训练集和测试集
  30.     X_train, X_test, y_train, y_test = train_test_split(
  31.         X, y, test_size=0.2, random_state=42
  32.     )
  33.    
  34.     return X_train, X_test, y_train, y_test
  35. def train_model(X_train, y_train, model_params=None) -> RandomForestClassifier:
  36.     """训练随机森林模型
  37.     Args:
  38.         X_train: 训练特征
  39.         y_train: 训练标签
  40.         model_params: 模型参数字典
  41.     Returns:
  42.         训练好的模型
  43.     """
  44.     if model_params is None:
  45.         model_params = {'random_state': 42}
  46.    
  47.     model = RandomForestClassifier(**model_params)
  48.     model.fit(X_train, y_train)
  49.     return model
  50. def evaluate_model(model, X_test, y_test) -> None:
  51.     """评估模型性能
  52.     Args:
  53.         model: 训练好的模型
  54.         X_test: 测试特征
  55.         y_test: 测试标签
  56.     """
  57.     y_pred = model.predict(X_test)
  58.     print("\n分类报告:")
  59.     print(classification_report(y_test, y_pred))
  60.     print("\n混淆矩阵:")
  61.     print(confusion_matrix(y_test, y_pred))
  62. def save_model(model, model_path: str) -> None:
  63.     """保存模型
  64.     Args:
  65.         model: 训练好的模型
  66.         model_path: 模型保存路径
  67.     """
  68.     os.makedirs(os.path.dirname(model_path), exist_ok=True)
  69.     joblib.dump(model, model_path)
  70.     print(f"\n模型已保存至: {model_path}")
  71. if __name__ == "__main__":
  72.     # 准备数据
  73.     X_train, X_test, y_train, y_test = prepare_data()
  74.    
  75.     # 记录开始时间
  76.     start_time = time.time()
  77.    
  78.     # 训练模型
  79.     model = train_model(X_train, y_train)
  80.    
  81.     # 记录结束时间
  82.     end_time = time.time()
  83.     print(f"\n训练耗时: {end_time - start_time:.4f} 秒")
  84.    
  85.     # 评估模型
  86.     evaluate_model(model, X_test, y_test)
  87.    
  88.     # 保存模型
  89.     save_model(model, "testDay31/models/random_forest_model.joblib")
复制代码
@浙大疏锦行

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

光之使者

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表