马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。
您需要 登录 才可以下载或查看,没有账号?立即注册
x
项目拆分
src/data/load_data.py
- # -*- coding: utf-8 -*-
- import sys
- import io
# Force UTF-8 on standard output so the Chinese status messages can be
# printed regardless of the console's default encoding. reconfigure()
# changes the existing stream in place. Rebinding sys.stdout to a new
# io.TextIOWrapper around sys.stdout.buffer instead leaves the previous
# wrapper to be garbage-collected, and finalizing it closes the shared
# buffer. So if another module does the same thing later, every
# subsequent print fails with "ValueError: I/O operation on closed file".
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(encoding='utf-8')
- import pandas as pd
-
def load_data(file_path: str) -> pd.DataFrame:
    """Read a CSV file into a DataFrame.

    Args:
        file_path: Path of the CSV file to read.

    Returns:
        The DataFrame produced by ``pd.read_csv`` with its default options.
    """
    frame = pd.read_csv(file_path)
    return frame
-
if __name__ == "__main__":
    # Manual smoke test: read the raw heart.csv (the path is relative to the
    # current working directory) and report success.
    csv_path = "testDay31/data/raw/heart.csv"
    data = load_data(csv_path)
    print("数据读取完成!")
src/data/preprocessing.py
- # -*- coding: utf-8 -*-
- import sys
- import io
- import os
# Force UTF-8 on standard output so the Chinese status messages can be
# printed regardless of the console's default encoding. reconfigure()
# changes the existing stream in place. Rebinding sys.stdout to a new
# io.TextIOWrapper around sys.stdout.buffer instead leaves the previous
# wrapper to be garbage-collected, and finalizing it closes the shared
# buffer. So when the importing script has already wrapped stdout, every
# later print fails with "ValueError: I/O operation on closed file".
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(encoding='utf-8')
-
- import pandas as pd
- import numpy as np
- from typing import Tuple, Dict
-
-
def load_data(file_path: str) -> pd.DataFrame:
    """Load the CSV at ``file_path`` and return it as a DataFrame.

    Args:
        file_path: Location of the CSV file.

    Returns:
        The parsed DataFrame (``pd.read_csv`` defaults).
    """
    loaded = pd.read_csv(file_path)
    return loaded
-
# Only missing-value handling is shown here as an example preprocessing step.
def handle_missing_values(data: pd.DataFrame) -> pd.DataFrame:
    """Fill missing values in numeric columns with each column's mode.

    Only numeric columns are imputed; missing values in non-numeric
    (e.g. object/string) columns are left untouched. Columns that are
    entirely missing are also left as-is, since there is no value to
    impute from. The input frame is not modified.

    Args:
        data: DataFrame that may contain missing values.

    Returns:
        A copy of ``data`` in which NaNs in numeric columns are replaced by
        the first mode reported by ``Series.mode()``.
    """
    data_clean = data.copy()
    # 'number' covers every numeric dtype (float32, int32, ...), not just
    # int64/float64.
    numeric_features = data_clean.select_dtypes(include='number').columns

    for feature in numeric_features:
        column = data_clean[feature]
        if not column.isna().any():
            continue  # nothing to fill
        modes = column.mode()  # NaNs are ignored when computing the mode
        if modes.empty:
            continue  # all values missing: mode() is empty, nothing to impute
        # Assign the result back rather than calling
        # data_clean[feature].fillna(..., inplace=True): that chained
        # in-place call does not update data_clean under pandas
        # Copy-on-Write (pandas 2.2 already warns about it).
        data_clean[feature] = column.fillna(modes.iloc[0])

    return data_clean
-
if __name__ == "__main__":
    # Manual smoke test: load the raw heart.csv (relative to the current
    # working directory) and run the missing-value step on it.
    raw = load_data("testDay31/data/raw/heart.csv")
    data_clean = handle_missing_values(raw)
    print("数据预处理完成!")
src/models/train.py
- # -*- coding: utf-8 -*-
- import sys
- import os
- import io
# Force UTF-8 on standard output so the Chinese status messages can be
# printed. reconfigure() changes the existing stream in place. Replacing
# sys.stdout with a new io.TextIOWrapper around sys.stdout.buffer would
# leave the previous wrapper to be garbage-collected, and finalizing it
# closes the shared buffer. So when a module imported below does the same
# thing, every later print fails with
# "ValueError: I/O operation on closed file".
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(encoding='utf-8')

# Put the parent of this script's directory on sys.path so the sibling
# `data` package (data.preprocessing) can be imported.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-
- from sklearn.ensemble import RandomForestClassifier
- from sklearn.model_selection import train_test_split
- from sklearn.metrics import classification_report, confusion_matrix
- import time
- import joblib # 用于保存模型
- from typing import Tuple # 用于类型注解
-
- from data.preprocessing import load_data,handle_missing_values
- # from data.load_data import load_data
-
def prepare_data(
    file_path: str = "testDay31/data/raw/heart.csv",
    target_column: str = 'target',
    test_size: float = 0.2,
    random_state: int = 42,
) -> Tuple:
    """Load the dataset, impute missing values and split it for training.

    Every parameter defaults to the value this script was written for, so
    ``prepare_data()`` behaves exactly as before.

    Args:
        file_path: CSV file to load (a relative path is resolved against the
            current working directory).
        target_column: Name of the label column.
        test_size: Passed to ``train_test_split``; a float is the fraction
            of rows held out for the test set.
        random_state: Seed for ``train_test_split`` so the split is
            reproducible.

    Returns:
        ``(X_train, X_test, y_train, y_test)`` from ``train_test_split``:
        the feature frames (label column removed) and the matching labels.
    """
    # Load and preprocess the data.
    data = load_data(file_path)
    data_clean = handle_missing_values(data)

    # Separate features from the label.
    X = data_clean.drop(columns=[target_column])
    y = data_clean[target_column]

    # Split into training and test sets.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )

    return X_train, X_test, y_train, y_test
-
def train_model(X_train, y_train, model_params=None) -> RandomForestClassifier:
    """Fit a random forest classifier on the training data.

    Args:
        X_train: Training features.
        y_train: Training labels.
        model_params: Keyword arguments for ``RandomForestClassifier``.
            ``None`` means ``{'random_state': 42}``; a dict that is passed
            in is used as-is (no seed is added to it).

    Returns:
        The fitted ``RandomForestClassifier``.
    """
    params = {'random_state': 42} if model_params is None else model_params
    classifier = RandomForestClassifier(**params)
    classifier.fit(X_train, y_train)
    return classifier
-
def evaluate_model(model, X_test, y_test) -> None:
    """Print the classification report and confusion matrix on the test set.

    Args:
        model: Fitted estimator exposing ``predict``.
        X_test: Test features.
        y_test: True test labels.
    """
    predictions = model.predict(X_test)
    # Each section is a (heading, metric function) pair, printed in order.
    sections = (
        ("\n分类报告:", classification_report),
        ("\n混淆矩阵:", confusion_matrix),
    )
    for heading, metric in sections:
        print(heading)
        print(metric(y_test, predictions))
-
def save_model(model, model_path: str) -> None:
    """Serialize a trained model to disk with joblib.

    Missing parent directories of ``model_path`` are created first.

    Args:
        model: Trained model to save.
        model_path: Destination file path.
    """
    # os.path.dirname() returns '' for a bare file name such as
    # "model.joblib", and os.makedirs('') raises FileNotFoundError, so only
    # create a directory when the path actually contains one.
    directory = os.path.dirname(model_path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    joblib.dump(model, model_path)
    print(f"\n模型已保存至: {model_path}")
- print(f"\n模型已保存至: {model_path}")
-
if __name__ == "__main__":
    # Load, impute and split the data.
    X_train, X_test, y_train, y_test = prepare_data()

    # time.perf_counter() is the high-resolution clock meant for measuring
    # durations; time.time() follows the wall clock, which can be adjusted
    # while training runs.
    start_time = time.perf_counter()
    model = train_model(X_train, y_train)
    end_time = time.perf_counter()
    print(f"\n训练耗时: {end_time - start_time:.4f} 秒")

    # Report metrics on the held-out test set.
    evaluate_model(model, X_test, y_test)

    # Persist the fitted model (path relative to the current working directory).
    save_model(model, "testDay31/models/random_forest_model.joblib")
@浙大疏锦行
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息请访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。