ToB企服应用市场:ToB评测及商务社交产业平台

标题: 第100+17步 ChatGPT学习:R实现Catboost分类 [打印本页]

作者: 泉缘泉    时间: 2024-7-27 15:03
标题: 第100+17步 ChatGPT学习:R实现Catboost分类
基于R 4.2.2版本演示


一、写在前面

有不少大佬问做机器学习分类能不能用R语言,不想学Python咯。

答曰:可!用GPT或者Kimi转一下就得了呗。

加上最近也没啥内容写了,就帮各位搬运一下吧。


二、R代码实现Catboost分类

(1)导入数据
我习惯用RStudio自带的导入功能:



(2)创建Catboost模型(默认参数)
  # Load necessary libraries
  library(caret)
  library(pROC)
  library(ggplot2)
  library(catboost)

  # Assumes `data` is a data frame (e.g. imported through RStudio) holding the
  # predictors plus a binary outcome column `X`.
  # set.seed() fixes the caret train/validation split below.
  set.seed(123)

  # Encode the outcome as 0/1 once and reuse it everywhere (CatBoost labels,
  # pROC, confusion matrices). factor() sorts the levels, so for a 0/1 column
  # the class "1" is the positive class. `data` itself is left untouched.
  outcome <- factor(data$X)
  stopifnot(nlevels(outcome) == 2)
  y <- as.numeric(outcome) - 1

  # Split data into training and validation sets (80% / 20%).
  # Stratify on the factor: given a numeric 0/1 vector, createDataPartition()
  # bins by quantiles, which collapses to a single bin (no stratification).
  trainIndex <- as.vector(createDataPartition(outcome, p = 0.8, list = FALSE))
  trainData <- data[trainIndex, ]
  validData <- data[-trainIndex, ]
  trainY <- y[trainIndex]
  validY <- y[-trainIndex]

  # Prepare pools for CatBoost: every column except `X` is a predictor.
  # drop = FALSE keeps a data frame even if only one predictor is left.
  featureCols <- setdiff(names(data), "X")
  trainPool <- catboost.load_pool(data = trainData[, featureCols, drop = FALSE], label = trainY)
  validPool <- catboost.load_pool(data = validData[, featureCols, drop = FALSE], label = validY)

  # Define parameters for CatBoost
  params <- list(
    iterations = 250,
    depth = 6,
    learning_rate = 0.1,
    l2_leaf_reg = 10,
    loss_function = "Logloss",
    eval_metric = "AUC"
  )

  # Train the CatBoost model
  model <- catboost.train(learn_pool = trainPool, params = params)

  # Predicted probability of the positive class (1)
  trainPredict <- catboost.predict(model, trainPool, prediction_type = "Probability")
  validPredict <- catboost.predict(model, validPool, prediction_type = "Probability")

  # Convert probabilities to 0/1 classes. The same cutoff (>= threshold)
  # drives the confusion matrices and every threshold-based metric below.
  threshold <- 0.5
  trainPredictBinary <- as.numeric(trainPredict >= threshold)
  validPredictBinary <- as.numeric(validPredict >= threshold)

  # ROC objects. levels/direction are fixed explicitly: pROC's automatic
  # direction would silently flip a worse-than-random model to AUC > 0.5.
  trainRoc <- roc(response = trainY, predictor = trainPredict, levels = c(0, 1), direction = "<")
  validRoc <- roc(response = validY, predictor = validPredict, levels = c(0, 1), direction = "<")

  # Draw a ROC curve with the area under it shaded, from a pROC::roc object.
  # pROC stores the points in threshold order, i.e. already traced
  # monotonically between (0, 0) and (1, 1). geom_path()/geom_polygon() keep
  # that order; geom_line()/geom_area() re-sort by x and scramble the
  # vertical steps of the curve.
  plot_roc <- function(rocObj, title, label, colour, labelY) {
    rocPoints <- data.frame(fpr = 1 - rocObj$specificities, tpr = rocObj$sensitivities)
    # Closing the polygon through (1, 0) encloses exactly the area under the curve
    shade <- rbind(rocPoints, data.frame(fpr = 1, tpr = 0))
    ggplot(data = rocPoints, aes(x = fpr, y = tpr)) +
      geom_polygon(data = shade, fill = colour, alpha = 0.2) +
      geom_path(color = colour) +
      geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "black") +
      ggtitle(title) +
      xlab("False Positive Rate") +
      ylab("True Positive Rate") +
      annotate("text", x = 0.5, y = labelY,
               label = paste(label, "AUC =", round(as.numeric(auc(rocObj)), 2)),
               hjust = 0.5, color = colour)
  }

  trainRocPlot <- plot_roc(trainRoc, "Training ROC Curve", "Training", "blue", 0.1)
  validRocPlot <- plot_roc(validRoc, "Validation ROC Curve", "Validation", "red", 0.2)
  # Display the plots
  print(trainRocPlot)
  print(validRocPlot)

  # Confusion matrix with rows = actual class and columns = predicted class.
  # Both are factors with levels 0/1, so the table is always 2x2, even when
  # the model never predicts one of the classes.
  make_conf_mat <- function(actual, predicted) {
    table(Actual = factor(actual, levels = c(0, 1)),
          Predicted = factor(predicted, levels = c(0, 1)))
  }
  confMatTrain <- make_conf_mat(trainY, trainPredictBinary)
  confMatValid <- make_conf_mat(validY, validPredictBinary)

  # Plot a confusion-matrix heat map.
  # pred: predicted probabilities; actual: 0/1 labels; dataset_name: used in
  # the title; threshold: probability cutoff for the positive class.
  # Prints the plot and returns it invisibly.
  plot_confusion_matrix <- function(pred, actual, dataset_name, threshold = 0.5) {
    conf_mat <- make_conf_mat(actual, as.numeric(pred >= threshold))
    # Columns come out as Actual, Predicted, Freq, named from the table's
    # dimnames, so the axes cannot be swapped by a manual rename.
    conf_mat_df <- as.data.frame(conf_mat)

    p <- ggplot(data = conf_mat_df, aes(x = Predicted, y = Actual, fill = Freq)) +
      geom_tile(color = "white") +
      geom_text(aes(label = Freq), vjust = 1.5, color = "black", size = 5) +
      scale_fill_gradient(low = "white", high = "steelblue") +
      labs(title = paste("Confusion Matrix -", dataset_name, "Set"), x = "Predicted Class", y = "Actual Class") +
      theme_minimal() +
      theme(axis.text.x = element_text(angle = 45, hjust = 1), plot.title = element_text(hjust = 0.5))

    print(p)
    invisible(p)
  }
  # Plot confusion matrices for both training and validation sets
  plot_confusion_matrix(trainPredict, trainY, "Training", threshold)
  plot_confusion_matrix(validPredict, validY, "Validation", threshold)

  # Threshold-based metrics from a 2x2 confusion matrix laid out as
  # make_conf_mat() builds it (rows = actual 0/1, columns = predicted 0/1).
  # Counts are converted to double first: table() returns integers, and the
  # MCC denominator (a product of four margins) overflows R's 32-bit integer
  # range, giving NA with a warning, once the margins reach a couple hundred.
  compute_metrics <- function(confMat) {
    tn <- as.numeric(confMat[1, 1])
    fp <- as.numeric(confMat[1, 2])
    fn <- as.numeric(confMat[2, 1])
    tp <- as.numeric(confMat[2, 2])
    acc <- (tn + tp) / (tn + fp + fn + tp)
    sen <- tp / (tp + fn)
    spe <- tn / (tn + fp)
    pre <- tp / (tp + fp)
    c(
      "Accuracy" = acc,
      "Error Rate" = 1 - acc,
      "Sensitivity" = sen,
      "Specificity" = spe,
      "Precision" = pre,
      "F1 Score" = 2 * pre * sen / (pre + sen),
      "MCC" = (tp * tn - fp * fn) / sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    )
  }

  # Print a titled "Name: value" listing of the metrics, followed by the AUC
  print_metrics <- function(title, metrics, aucValue) {
    cat(title, "\n", sep = "")
    for (metricName in names(metrics)) {
      cat(paste0(metricName, ":"), metrics[[metricName]], "\n")
    }
    cat("AUC:", aucValue, "\n")
  }

  metrics_train <- compute_metrics(confMatTrain)
  metrics_valid <- compute_metrics(confMatValid)
  auc_train <- as.numeric(auc(trainRoc))
  auc_valid <- as.numeric(auc(validRoc))

  # Print Metrics
  print_metrics("Training Metrics", metrics_train, auc_train)
  cat("\n")
  print_metrics("Validation Metrics", metrics_valid, auc_valid)
复制代码
在R语言中,Catboost模型得单独安装(catboost 包不在 CRAN 上,需按官方说明从 GitHub 发布页安装),下面是一些可以调整的关键参数:

①学习率 (learning_rate):控制每步模型更新的幅度。较小的学习率可以提高模型的训练稳定性和准确性,但可能需要更多的时间和更多的树来收敛。

②树的深度 (depth):决定了每棵树的最大深度。较深的树可以更好地捕捉数据中的复杂关系,但也可能导致过拟合。

③树的数量 (iterations):模型中树的总数。更多的树可以增加模型的复杂度和能力,但同样可能导致过拟合。

④L2 正则化系数 (l2_leaf_reg):在模型的损失函数中增加一个正则项,以减少模型复杂度和过拟合风险。

⑤边界计数 (border_count):用于数值特征分箱的边界数量,影响模型在连续特征上的决策边界。

⑥类别特征 (cat_features):CatBoost 优化了对类别特征的处理,可以指定模型中哪些列作为类别特征处理。

⑦子采样 (subsample):指定每棵树训练时从训练数据集中随机抽取的比例,有助于防止模型过拟合(需配合 bootstrap_type 为 Bernoulli、MVS 或 Poisson 使用)。

⑧列采样 (colsample_bylevel,即 rsm):控制每次选择分裂时随机使用的特征比例,可以增加模型的多样性,降低过拟合风险。(colsample_bytree 是 XGBoost/LightGBM 的参数,CatBoost 中没有。)

⑨叶节点最小样本数 (min_data_in_leaf):叶节点必须包含的最小样本数量,增大这个参数的值可以防止模型学习过于具体的模式,从而降低过拟合风险(仅在 grow_policy 为 Depthwise 或 Lossguide 时生效)。

⑩评估指标 (eval_metric):用于训练过程中模型评估的性能指标。

结果输出(随便挑的):






从AUC来看,Catboost随便一跑,就跑出过拟合了,跟Xgboost差不多。


三、Catboost调参

随便设置了一下,效果不明显,给各位自行嗨皮:

  # Load necessary libraries
  library(caret)
  library(pROC)
  library(ggplot2)
  library(catboost)

  # Assumes `data` is a data frame (e.g. imported through RStudio) holding the
  # predictors plus a binary outcome column `X`.
  # set.seed() fixes the caret train/validation split below.
  set.seed(123)

  # Convert the target variable to 0/1 (overwrites data$X). factor() sorts
  # the levels, so for a 0/1 column the class "1" is the positive class.
  outcome <- factor(data$X)
  stopifnot(nlevels(outcome) == 2)
  data$X <- as.numeric(outcome) - 1

  # Split data into training and validation sets (80% / 20%).
  # Stratify on the factor: given a numeric 0/1 vector, createDataPartition()
  # bins by quantiles, which collapses to a single bin (no stratification).
  trainIndex <- as.vector(createDataPartition(outcome, p = 0.8, list = FALSE))
  trainData <- data[trainIndex, ]
  validData <- data[-trainIndex, ]

  # Prepare CatBoost pools: every column except `X` is a predictor.
  # drop = FALSE keeps a data frame even if only one predictor is left.
  featureCols <- setdiff(names(data), "X")
  trainPool <- catboost.load_pool(data = trainData[, featureCols, drop = FALSE], label = trainData$X)
  validPool <- catboost.load_pool(data = validData[, featureCols, drop = FALSE], label = validData$X)

  # Define parameter grid
  depths <- c(2, 4, 6)  # Reduced maximum depth
  l2_leaf_regs <- c(1, 3, 5, 10, 20, 25)  # Increased maximum regularization
  iterations <- c(500, 1000)  # Added higher iteration count for lower learning rates
  learning_rates <- c(0.05, 0.1)  # Lower maximum learning rate
  # To subsample rows per tree, add e.g. `subsample = 0.8` to the params
  # below (needs bootstrap_type "Bernoulli", "MVS" or "Poisson").

  # All combinations. expand.grid() varies its first column fastest, so this
  # enumerates in the same order as nested loops over depth > l2 > iter > lr.
  grid <- expand.grid(
    learning_rate = learning_rates,
    iterations = iterations,
    l2_leaf_reg = l2_leaf_regs,
    depth = depths
  )

  best_auc <- 0
  best_params <- list()
  # NOTE: the parameters are selected on the validation set, so the
  # validation metrics reported below are optimistically biased; use
  # cross-validation on trainData or a separate test set for an unbiased
  # estimate.
  for (i in seq_len(nrow(grid))) {
    params <- list(
      iterations = grid$iterations[i],
      depth = grid$depth[i],
      learning_rate = grid$learning_rate[i],
      l2_leaf_reg = grid$l2_leaf_reg[i],
      loss_function = "Logloss",
      eval_metric = "AUC"
    )

    # Train on the training pool only. Passing validPool as test_pool would
    # let CatBoost's default use_best_model truncate the model at its best
    # validation iteration, so the final refit below (no test pool) would
    # not be the model that was scored here.
    model <- catboost.train(learn_pool = trainPool, params = params)

    # AUC must be computed from scores, not 0/1 labels. The default
    # prediction_type is raw log-odds, so request probabilities explicitly.
    validPredict <- catboost.predict(model, validPool, prediction_type = "Probability")
    validRoc <- roc(response = validData$X, predictor = validPredict, levels = c(0, 1), direction = "<")
    auc_score <- as.numeric(auc(validRoc))

    # Update best model if current AUC is better
    if (auc_score > best_auc) {
      best_auc <- auc_score
      best_params <- params
    }
  }

  # Print the best AUC and corresponding parameters
  print(paste("Best AUC:", best_auc))
  print("Best Parameters:")
  print(best_params)

  # After parameter tuning, train the model with best parameters
  model <- catboost.train(learn_pool = trainPool, params = best_params)

  # Predicted probability of the positive class (1)
  trainPredict <- catboost.predict(model, trainPool, prediction_type = "Probability")
  validPredict <- catboost.predict(model, validPool, prediction_type = "Probability")

  # Convert probabilities to 0/1 classes. The same cutoff (>= threshold)
  # drives the confusion matrices and every threshold-based metric below.
  threshold <- 0.5
  trainPredictBinary <- as.numeric(trainPredict >= threshold)
  validPredictBinary <- as.numeric(validPredict >= threshold)

  # ROC objects. levels/direction are fixed explicitly: pROC's automatic
  # direction would silently flip a worse-than-random model to AUC > 0.5.
  trainRoc <- roc(response = trainData$X, predictor = trainPredict, levels = c(0, 1), direction = "<")
  validRoc <- roc(response = validData$X, predictor = validPredict, levels = c(0, 1), direction = "<")

  # Draw a ROC curve with the area under it shaded, from a pROC::roc object.
  # pROC stores the points in threshold order, i.e. already traced
  # monotonically between (0, 0) and (1, 1). geom_path()/geom_polygon() keep
  # that order; geom_line()/geom_area() re-sort by x and scramble the
  # vertical steps of the curve.
  plot_roc <- function(rocObj, title, label, colour, labelY) {
    rocPoints <- data.frame(fpr = 1 - rocObj$specificities, tpr = rocObj$sensitivities)
    # Closing the polygon through (1, 0) encloses exactly the area under the curve
    shade <- rbind(rocPoints, data.frame(fpr = 1, tpr = 0))
    ggplot(data = rocPoints, aes(x = fpr, y = tpr)) +
      geom_polygon(data = shade, fill = colour, alpha = 0.2) +
      geom_path(color = colour) +
      geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "black") +
      ggtitle(title) +
      xlab("False Positive Rate") +
      ylab("True Positive Rate") +
      annotate("text", x = 0.5, y = labelY,
               label = paste(label, "AUC =", round(as.numeric(auc(rocObj)), 2)),
               hjust = 0.5, color = colour)
  }

  trainRocPlot <- plot_roc(trainRoc, "Training ROC Curve", "Training", "blue", 0.1)
  validRocPlot <- plot_roc(validRoc, "Validation ROC Curve", "Validation", "red", 0.2)
  # Display the plots
  print(trainRocPlot)
  print(validRocPlot)

  # Confusion matrices with rows = actual class and columns = predicted class.
  # Both are factors with levels 0/1, so each table is always 2x2, even when
  # the model never predicts one of the classes.
  confMatTrain <- table(Actual = factor(trainData$X, levels = c(0, 1)),
                        Predicted = factor(trainPredictBinary, levels = c(0, 1)))
  confMatValid <- table(Actual = factor(validData$X, levels = c(0, 1)),
                        Predicted = factor(validPredictBinary, levels = c(0, 1)))

  # Plot a confusion-matrix heat map from a 2-D table whose first dimension
  # is the actual class and second the predicted class.
  # Prints the plot and returns it invisibly.
  plot_confusion_matrix <- function(conf_mat, dataset_name) {
    conf_mat_df <- as.data.frame(as.table(conf_mat))
    colnames(conf_mat_df) <- c("Actual", "Predicted", "Freq")

    p <- ggplot(data = conf_mat_df, aes(x = Predicted, y = Actual, fill = Freq)) +
      geom_tile(color = "white") +
      geom_text(aes(label = Freq), vjust = 1.5, color = "black", size = 5) +
      scale_fill_gradient(low = "white", high = "steelblue") +
      labs(title = paste("Confusion Matrix -", dataset_name, "Set"), x = "Predicted Class", y = "Actual Class") +
      theme_minimal() +
      theme(axis.text.x = element_text(angle = 45, hjust = 1), plot.title = element_text(hjust = 0.5))

    print(p)
    invisible(p)
  }
  # Now call the function to plot and display the confusion matrices
  plot_confusion_matrix(confMatTrain, "Training")
  plot_confusion_matrix(confMatValid, "Validation")

  # Threshold-based metrics from a 2x2 confusion matrix (rows = actual 0/1,
  # columns = predicted 0/1). Counts are converted to double first: table()
  # returns integers, and the MCC denominator (a product of four margins)
  # overflows R's 32-bit integer range, giving NA with a warning, once the
  # margins reach a couple hundred.
  compute_metrics <- function(confMat) {
    tn <- as.numeric(confMat[1, 1])
    fp <- as.numeric(confMat[1, 2])
    fn <- as.numeric(confMat[2, 1])
    tp <- as.numeric(confMat[2, 2])
    acc <- (tn + tp) / (tn + fp + fn + tp)
    sen <- tp / (tp + fn)
    spe <- tn / (tn + fp)
    pre <- tp / (tp + fp)
    c(
      "Accuracy" = acc,
      "Error Rate" = 1 - acc,
      "Sensitivity" = sen,
      "Specificity" = spe,
      "Precision" = pre,
      "F1 Score" = 2 * pre * sen / (pre + sen),
      "MCC" = (tp * tn - fp * fn) / sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    )
  }

  # Print a titled "Name: value" listing of the metrics, followed by the AUC
  print_metrics <- function(title, metrics, aucValue) {
    cat(title, "\n", sep = "")
    for (metricName in names(metrics)) {
      cat(paste0(metricName, ":"), metrics[[metricName]], "\n")
    }
    cat("AUC:", aucValue, "\n")
  }

  metrics_train <- compute_metrics(confMatTrain)
  metrics_valid <- compute_metrics(confMatValid)
  auc_train <- as.numeric(auc(trainRoc))
  auc_valid <- as.numeric(auc(validRoc))

  # Print Metrics
  print_metrics("Training Metrics", metrics_train, auc_train)
  cat("\n")
  print_metrics("Validation Metrics", metrics_valid, auc_valid)
复制代码
结果输出:


提供个示例代码吧,我不调了。


四、最后

至于怎么安装,自学了哈。
数据嘛:
链接:https://pan.baidu.com/s/1rEf6JZyzA1ia5exoq5OF7g?pwd=x8xm
提取码:x8xm

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息请访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。




欢迎光临 ToB企服应用市场:ToB评测及商务社交产业平台 (https://dis.qidao123.com/) Powered by Discuz! X3.4