{ "nbformat": 4, "nbformat_minor": 2, "metadata": { "colab": { "name": "lesson_3-R.ipynb", "provenance": [], "collapsed_sections": [], "toc_visible": true }, "kernelspec": { "name": "ir", "display_name": "R" }, "language_info": { "name": "R" }, "coopTranslator": { "original_hash": "5015d65d61ba75a223bfc56c273aa174", "translation_date": "2025-09-03T19:26:15+00:00", "source_file": "2-Regression/3-Linear/solution/R/lesson_3-R.ipynb", "language_code": "zh" } }, "cells": [ { "cell_type": "markdown", "source": [ "# 构建回归模型:线性回归和多项式回归模型\n" ], "metadata": { "id": "EgQw8osnsUV-" } }, { "cell_type": "markdown", "source": [ "## 南瓜定价的线性回归和多项式回归 - 第三课\n", "

\n", " \n", "

信息图作者:Dasani Madipalli
\n", "\n", "\n", "\n", "\n", "#### 介绍\n", "\n", "到目前为止,你已经通过南瓜定价数据集的样本数据了解了什么是回归,并将在整个课程中使用该数据集。你还使用了 `ggplot2` 进行了可视化。💪\n", "\n", "现在你已经准备好深入学习机器学习中的回归。在本课中,你将进一步了解两种回归类型:*基本线性回归* 和 *多项式回归*,以及这些技术背后的一些数学原理。\n", "\n", "> 在整个课程中,我们假设学生的数学知识较少,并努力使内容对来自其他领域的学生更易理解,因此请注意笔记、🧮 数学提示、图表以及其他学习工具,这些都将帮助你更好地理解。\n", "\n", "#### 准备工作\n", "\n", "提醒一下,你正在加载这些数据以便对其进行分析。\n", "\n", "- 什么时候是购买南瓜的最佳时间?\n", "\n", "- 一箱迷你南瓜的价格大概是多少?\n", "\n", "- 我应该选择半蒲式耳篮子还是 1 1/9 蒲式耳箱来购买?让我们继续深入挖掘这些数据。\n", "\n", "在上一课中,你创建了一个 `tibble`(数据框的一种现代化形式),并用原始数据集的一部分填充它,同时将价格标准化为以蒲式耳为单位。然而,通过这种方式,你只能收集到大约 400 个数据点,并且仅限于秋季月份。也许通过进一步清理数据,我们可以获得更多细节?我们拭目以待... 🕵️‍♀️\n", "\n", "完成此任务需要以下包:\n", "\n", "- `tidyverse`: [tidyverse](https://www.tidyverse.org/) 是一个 [R 包集合](https://www.tidyverse.org/packages),旨在让数据科学更快、更简单、更有趣!\n", "\n", "- `tidymodels`: [tidymodels](https://www.tidymodels.org/) 框架是一个 [包集合](https://www.tidymodels.org/packages),用于建模和机器学习。\n", "\n", "- `janitor`: [janitor 包](https://github.com/sfirke/janitor) 提供了一些简单的小工具,用于检查和清理脏数据。\n", "\n", "- `corrplot`: [corrplot 包](https://cran.r-project.org/web/packages/corrplot/vignettes/corrplot-intro.html) 提供了一个可视化的相关矩阵探索工具,支持自动变量重新排序,以帮助发现变量之间隐藏的模式。\n", "\n", "你可以通过以下命令安装这些包:\n", "\n", "`install.packages(c(\"tidyverse\", \"tidymodels\", \"janitor\", \"corrplot\"))`\n", "\n", "下面的脚本会检查你是否安装了完成本模块所需的包,并在缺少时为你安装它们。\n" ], "metadata": { "id": "WqQPS1OAsg3H" } }, { "cell_type": "code", "execution_count": null, "source": [ "suppressWarnings(if (!require(\"pacman\")) install.packages(\"pacman\"))\n", "\n", "pacman::p_load(tidyverse, tidymodels, janitor, corrplot)" ], "outputs": [], "metadata": { "id": "tA4C2WN3skCf", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "c06cd805-5534-4edc-f72b-d0d1dab96ac0" } }, { "cell_type": "markdown", "source": [ "我们稍后会加载这些很棒的包,并将它们在当前的 R 会话中可用。(这只是为了说明,`pacman::p_load()` 已经帮你完成了这一步)\n", "\n", "## 1. 线性回归线\n", "\n", "正如你在第一课中学到的,线性回归的目标是能够绘制一条*最佳拟合线*,以便:\n", "\n", "- **展示变量关系**。展示变量之间的关系\n", "\n", "- **进行预测**。准确预测新数据点在这条线上的位置\n", "\n", "为了绘制这种类型的线,我们使用一种统计技术,称为**最小二乘回归**。`最小二乘`的意思是回归线周围的所有数据点的误差平方后相加。理想情况下,这个最终的总和应该尽可能小,因为我们希望误差数量较低,也就是`最小二乘`。因此,最佳拟合线就是使误差平方和最小的那条线——这就是*最小二乘回归*的名称由来。\n", "\n", "我们这样做是因为我们希望拟合一条与所有数据点的累计距离最小的线。在加总之前,我们会对误差进行平方,因为我们关心的是误差的大小,而不是方向。\n", "\n", "> **🧮 数学公式**\n", ">\n", "> 这条线,称为*最佳拟合线*,可以用[一个公式](https://en.wikipedia.org/wiki/Simple_linear_regression)表示:\n", ">\n", "> Y = a + bX\n", ">\n", "> `X` 是`解释变量`或`预测变量`,`Y` 是`因变量`或`结果变量`。线的斜率是 `b`,而 `a` 是 y 截距,指的是当 `X = 0` 时 `Y` 的值。\n", ">\n", "\n", "> ![](../../../../../../2-Regression/3-Linear/solution/images/slope.png \"slope = $y/x$\")\n", " 信息图由 Jen Looper 制作\n", ">\n", "> 首先,计算斜率 `b`。\n", ">\n", "> 换句话说,参考我们的南瓜数据的原始问题:“按月份预测每蒲式耳南瓜的价格”,`X` 表示价格,`Y` 表示销售月份。\n", ">\n", "> ![](../../../../../../translated_images/calculation.989aa7822020d9d0ba9fc781f1ab5192f3421be86ebb88026528aef33c37b0d8.zh.png)\n", " 信息图由 Jen Looper 制作\n", "> \n", "> 计算 Y 的值。如果你支付大约 \\$4,那一定是四月!\n", ">\n", "> 计算这条线的数学公式必须展示线的斜率,这也取决于截距,即当 `X = 0` 时 `Y` 的位置。\n", ">\n", "> 你可以在 [Math is Fun](https://www.mathsisfun.com/data/least-squares-regression.html) 网站上观察这些值的计算方法。也可以访问[这个最小二乘计算器](https://www.mathsisfun.com/data/least-squares-calculator.html),看看数值如何影响这条线。\n", "\n", "是不是没那么可怕?🤓\n", "\n", "#### 相关性\n", "\n", "还有一个需要理解的术语是给定 X 和 Y 变量之间的**相关系数**。使用散点图,你可以快速可视化这个系数。数据点整齐排列成一条线的图表具有高相关性,而数据点在 X 和 Y 之间随意分布的图表则相关性较低。\n", "\n", "一个好的线性回归模型应该是使用最小二乘回归方法和回归线时,相关系数较高(接近 1 而不是 0)的模型。\n" ], "metadata": { "id": "cdX5FRpvsoP5" } }, { "cell_type": "markdown", "source": [ "## **2. 与数据共舞:创建用于建模的数据框**\n", "\n", "

\n", " \n", "

插画作者:@allison_horst
\n", "\n", "\n", "\n" ], "metadata": { "id": "WdUKXk7Bs8-V" } }, { "cell_type": "markdown", "source": [ "加载所需的库和数据集。将数据转换为包含数据子集的数据框:\n", "\n", "- 仅获取按蒲式耳定价的南瓜数据\n", "\n", "- 将日期转换为月份\n", "\n", "- 计算价格为高价和低价的平均值\n", "\n", "- 将价格转换为反映按蒲式耳数量定价的形式\n", "\n", "> 我们在[上一课](https://github.com/microsoft/ML-For-Beginners/blob/main/2-Regression/2-Data/solution/lesson_2-R.ipynb)中已经涵盖了这些步骤。\n" ], "metadata": { "id": "fMCtu2G2s-p8" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Load the core Tidyverse packages\n", "library(tidyverse)\n", "library(lubridate)\n", "\n", "# Import the pumpkins data\n", "pumpkins <- read_csv(file = \"https://raw.githubusercontent.com/microsoft/ML-For-Beginners/main/2-Regression/data/US-pumpkins.csv\")\n", "\n", "\n", "# Get a glimpse and dimensions of the data\n", "glimpse(pumpkins)\n", "\n", "\n", "# Print the first 50 rows of the data set\n", "pumpkins %>% \n", " slice_head(n = 5)" ], "outputs": [], "metadata": { "id": "ryMVZEEPtERn" } }, { "cell_type": "markdown", "source": [ "出于纯粹冒险的精神,让我们探索 [`janitor package`](../../../../../../2-Regression/3-Linear/solution/R/github.com/sfirke/janitor),它提供了简单的函数来检查和清理脏数据。例如,让我们看看我们数据的列名:\n" ], "metadata": { "id": "xcNxM70EtJjb" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Return column names\n", "pumpkins %>% \n", " names()" ], "outputs": [], "metadata": { "id": "5XtpaIigtPfW" } }, { "cell_type": "markdown", "source": [ "🤔 我们可以做得更好。让我们通过使用 `janitor::clean_names` 将这些列名转换为 [snake_case](https://en.wikipedia.org/wiki/Snake_case) 约定来使它们成为 `friendR`。要了解有关此函数的更多信息:`?clean_names`\n" ], "metadata": { "id": "IbIqrMINtSHe" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Clean names to the snake_case convention\n", "pumpkins <- pumpkins %>% \n", " clean_names(case = \"snake\")\n", "\n", "# Return column names\n", "pumpkins %>% \n", " names()" ], "outputs": [], "metadata": { "id": "a2uYvclYtWvX" } }, { "cell_type": "markdown", "source": [ "非常整洁 🧹!现在,像上一节课一样,用 `dplyr` 来与数据共舞吧!💃\n" ], "metadata": { "id": "HfhnuzDDtaDd" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Select desired columns\n", "pumpkins <- pumpkins %>% \n", " select(variety, city_name, package, low_price, high_price, date)\n", "\n", "\n", "\n", "# Extract the month from the dates to a new column\n", "pumpkins <- pumpkins %>%\n", " mutate(date = mdy(date),\n", " month = month(date)) %>% \n", " select(-date)\n", "\n", "\n", "\n", "# Create a new column for average Price\n", "pumpkins <- pumpkins %>% \n", " mutate(price = (low_price + high_price)/2)\n", "\n", "\n", "# Retain only pumpkins with the string \"bushel\"\n", "new_pumpkins <- pumpkins %>% \n", " filter(str_detect(string = package, pattern = \"bushel\"))\n", "\n", "\n", "# Normalize the pricing so that you show the pricing per bushel, not per 1 1/9 or 1/2 bushel\n", "new_pumpkins <- new_pumpkins %>% \n", " mutate(price = case_when(\n", " str_detect(package, \"1 1/9\") ~ price/(1.1),\n", " str_detect(package, \"1/2\") ~ price*2,\n", " TRUE ~ price))\n", "\n", "# Relocate column positions\n", "new_pumpkins <- new_pumpkins %>% \n", " relocate(month, .before = variety)\n", "\n", "\n", "# Display the first 5 rows\n", "new_pumpkins %>% \n", " slice_head(n = 5)" ], "outputs": [], "metadata": { "id": "X0wU3gQvtd9f" } }, { "cell_type": "markdown", "source": [ "干得好!👌 你现在拥有一个干净整洁的数据集,可以用来构建新的回归模型!\n", "\n", "画个散点图怎么样?\n" ], "metadata": { "id": "UpaIwaxqth82" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Set theme\n", "theme_set(theme_light())\n", "\n", "# Make a scatter plot of month and price\n", "new_pumpkins %>% \n", " ggplot(mapping = aes(x = month, y = price)) +\n", " geom_point(size = 1.6)\n" ], "outputs": [], "metadata": { "id": "DXgU-j37tl5K" } }, { "cell_type": "markdown", "source": [ "散点图提醒我们,我们只有从八月到十二月的月度数据。我们可能需要更多的数据才能以线性方式得出结论。\n", "\n", "让我们再看看我们的建模数据:\n" ], "metadata": { "id": "Ve64wVbwtobI" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Display first 5 rows\n", "new_pumpkins %>% \n", " slice_head(n = 5)" ], "outputs": [], "metadata": { "id": "HFQX2ng1tuSJ" } }, { "cell_type": "markdown", "source": [ "如果我们想根据`city`或`package`列(它们是字符类型)来预测南瓜的`price`,该怎么办?或者更简单地说,我们如何找到`package`和`price`之间的相关性(这要求两个输入都为数值类型)呢?🤷🤷\n", "\n", "机器学习模型在处理数值特征时效果最佳,而不是文本值,因此通常需要将分类特征转换为数值表示。\n", "\n", "这意味着我们需要找到一种方法来重新格式化我们的预测变量,使其更容易被模型有效利用,这个过程被称为`特征工程`。\n" ], "metadata": { "id": "7hsHoxsStyjJ" } }, { "cell_type": "markdown", "source": [ "## 3. 为建模预处理数据,使用 recipes 👩‍🍳👨‍🍳\n", "\n", "将预测变量重新格式化以便模型更有效使用的活动被称为`特征工程`。\n", "\n", "不同的模型对数据预处理有不同的要求。例如,最小二乘法需要对`分类变量`(如月份、品种和城市名称)进行`编码`。这通常涉及将包含`分类值`的列`转换`为一个或多个`数值列`,以替代原始列。\n", "\n", "例如,假设你的数据包含以下分类特征:\n", "\n", "| city |\n", "|:-------:|\n", "| Denver |\n", "| Nairobi |\n", "| Tokyo |\n", "\n", "你可以应用*序数编码*,为每个类别替换一个唯一的整数值,如下所示:\n", "\n", "| city |\n", "|:----:|\n", "| 0 |\n", "| 1 |\n", "| 2 |\n", "\n", "这就是我们将对数据进行的操作!\n", "\n", "在本节中,我们将探索另一个令人惊叹的 Tidymodels 包:[recipes](https://tidymodels.github.io/recipes/) - 它专为在训练模型**之前**帮助你预处理数据而设计。本质上,recipe 是一个对象,用于定义对数据集应用哪些步骤,以使其为建模做好准备。\n", "\n", "现在,让我们创建一个 recipe,通过为预测变量列中的所有观测值替换唯一整数,来为建模准备数据:\n" ], "metadata": { "id": "AD5kQbcvt3Xl" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Specify a recipe\n", "pumpkins_recipe <- recipe(price ~ ., data = new_pumpkins) %>% \n", " step_integer(all_predictors(), zero_based = TRUE)\n", "\n", "\n", "# Print out the recipe\n", "pumpkins_recipe" ], "outputs": [], "metadata": { "id": "BNaFKXfRt9TU" } }, { "cell_type": "markdown", "source": [ "太棒了!👏 我们刚刚创建了第一个配方,它指定了一个结果(价格)及其对应的预测变量,并且所有预测变量列都被编码为一组整数 🙌!让我们快速分解一下:\n", "\n", "- 调用 `recipe()` 并使用公式告诉配方变量的*角色*,以 `new_pumpkins` 数据作为参考。例如,`price` 列被分配了 `outcome` 角色,而其余列被分配了 `predictor` 角色。\n", "\n", "- `step_integer(all_predictors(), zero_based = TRUE)` 指定所有预测变量都应转换为一组整数,编号从 0 开始。\n", "\n", "我们相信你可能会有这样的想法:“这太酷了!!但如果我需要确认这些配方确实按照我的预期在工作怎么办?🤔”\n", "\n", "这是一个很棒的想法!你看,一旦定义了配方,你可以估算实际预处理数据所需的参数,然后提取处理后的数据。通常在使用 Tidymodels 时不需要这样做(我们稍后会看到常规方法 -> `workflows`),但当你想进行某种合理性检查以确认配方是否按预期工作时,这会非常有用。\n", "\n", "为此,你需要两个额外的动词:`prep()` 和 `bake()`。一如既往,我们的小 R 朋友由 [`Allison Horst`](https://github.com/allisonhorst/stats-illustrations) 创作,帮助你更好地理解这一点!\n", "\n", "

\n", " \n", "

插图作者:@allison_horst
\n" ], "metadata": { "id": "KEiO0v7kuC9O" } }, { "cell_type": "markdown", "source": [ "[`prep()`](https://recipes.tidymodels.org/reference/prep.html):从训练集估算所需参数,这些参数可以稍后应用于其他数据集。例如,对于给定的预测变量列,哪个观测值会被分配为整数 0、1、2 等。\n", "\n", "[`bake()`](https://recipes.tidymodels.org/reference/bake.html):使用已准备好的配方并将操作应用于任何数据集。\n", "\n", "话虽如此,让我们准备并应用配方,真正确认在底层,预测变量列会先被编码,然后再拟合模型。\n" ], "metadata": { "id": "Q1xtzebuuTCP" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Prep the recipe\n", "pumpkins_prep <- prep(pumpkins_recipe)\n", "\n", "# Bake the recipe to extract a preprocessed new_pumpkins data\n", "baked_pumpkins <- bake(pumpkins_prep, new_data = NULL)\n", "\n", "# Print out the baked data set\n", "baked_pumpkins %>% \n", " slice_head(n = 10)" ], "outputs": [], "metadata": { "id": "FGBbJbP_uUUn" } }, { "cell_type": "markdown", "source": [ "哇哦!🥳 处理后的数据 `baked_pumpkins` 的所有预测变量都已编码,这确认了我们定义的预处理步骤(作为我们的配方)确实可以如预期般工作。这虽然让数据更难阅读,但对 Tidymodels 来说却更加易于理解!花点时间找出哪些观测值已被映射到对应的整数。\n", "\n", "另外值得一提的是,`baked_pumpkins` 是一个数据框,我们可以在其上进行计算。\n", "\n", "例如,我们可以尝试在数据中的两个点之间找到一个良好的相关性,以便可能构建一个优秀的预测模型。我们将使用函数 `cor()` 来完成此操作。输入 `?cor()` 以了解更多关于该函数的信息。\n" ], "metadata": { "id": "1dvP0LBUueAW" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Find the correlation between the city_name and the price\n", "cor(baked_pumpkins$city_name, baked_pumpkins$price)\n", "\n", "# Find the correlation between the package and the price\n", "cor(baked_pumpkins$package, baked_pumpkins$price)\n" ], "outputs": [], "metadata": { "id": "3bQzXCjFuiSV" } }, { "cell_type": "markdown", "source": [ "事实证明,城市和价格之间的相关性较弱。然而,套餐和价格之间的相关性稍强一些。这很合理,对吧?通常来说,生产箱越大,价格越高。\n", "\n", "既然如此,我们也可以尝试使用 `corrplot` 包来可视化所有列的相关性矩阵。\n" ], "metadata": { "id": "BToPWbgjuoZw" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Load the corrplot package\n", "library(corrplot)\n", "\n", "# Obtain correlation matrix\n", "corr_mat <- cor(baked_pumpkins %>% \n", " # Drop columns that are not really informative\n", " select(-c(low_price, high_price)))\n", "\n", "# Make a correlation plot between the variables\n", "corrplot(corr_mat, method = \"shade\", shade.col = NA, tl.col = \"black\", tl.srt = 45, addCoef.col = \"black\", cl.pos = \"n\", order = \"original\")" ], "outputs": [], "metadata": { "id": "ZwAL3ksmutVR" } }, { "cell_type": "markdown", "source": [ "🤩🤩 好得多。\n", "\n", "现在可以问这个数据的一个好问题是:'`给定一个南瓜包,我可以预期它的价格是多少?`' 让我们直接开始吧!\n", "\n", "> 注意:当你使用 **`new_data = NULL`** 对预处理过的配方 **`pumpkins_prep`** 进行 **`bake()`** 时,你会提取处理过的(即编码后的)训练数据。如果你有另一个数据集,例如测试集,并希望查看配方如何对其进行预处理,你只需使用 **`new_data = test_set`** 对 **`pumpkins_prep`** 进行 bake。\n", "\n", "## 4. 构建线性回归模型\n", "\n", "

\n", " \n", "

Dasani Madipalli 制作的信息图
\n", "\n", "\n", "\n" ], "metadata": { "id": "YqXjLuWavNxW" } }, { "cell_type": "markdown", "source": [ "现在我们已经构建了一个配方,并确认数据将被适当预处理,接下来让我们构建一个回归模型来回答这个问题:`我可以预期某个南瓜包装的价格是多少?`\n", "\n", "#### 使用训练集训练线性回归模型\n", "\n", "正如你可能已经猜到的,*price* 列是 `结果` 变量,而 *package* 列是 `预测` 变量。\n", "\n", "为此,我们首先将数据分割为训练集(占80%)和测试集(占20%),然后定义一个配方,将预测变量列编码为一组整数,接着构建一个模型规范。我们不会准备和烘焙配方,因为我们已经知道它会按预期预处理数据。\n" ], "metadata": { "id": "Pq0bSzCevW-h" } }, { "cell_type": "code", "execution_count": null, "source": [ "set.seed(2056)\n", "# Split the data into training and test sets\n", "pumpkins_split <- new_pumpkins %>% \n", " initial_split(prop = 0.8)\n", "\n", "\n", "# Extract training and test data\n", "pumpkins_train <- training(pumpkins_split)\n", "pumpkins_test <- testing(pumpkins_split)\n", "\n", "\n", "\n", "# Create a recipe for preprocessing the data\n", "lm_pumpkins_recipe <- recipe(price ~ package, data = pumpkins_train) %>% \n", " step_integer(all_predictors(), zero_based = TRUE)\n", "\n", "\n", "\n", "# Create a linear model specification\n", "lm_spec <- linear_reg() %>% \n", " set_engine(\"lm\") %>% \n", " set_mode(\"regression\")" ], "outputs": [], "metadata": { "id": "CyoEh_wuvcLv" } }, { "cell_type": "markdown", "source": [ "干得好!现在我们已经有了一个配方和模型规范,我们需要找到一种方法将它们打包成一个对象,该对象将首先对数据进行预处理(幕后完成prep+bake),然后在预处理后的数据上拟合模型,同时还支持潜在的后处理活动。这是不是让你更安心了!🤩\n", "\n", "在Tidymodels中,这个方便的对象叫做[`workflow`](https://workflows.tidymodels.org/),它可以方便地包含你的建模组件!这在*Python*中我们称之为*管道*。\n", "\n", "那么,让我们把所有东西打包到一个workflow中吧!📦\n" ], "metadata": { "id": "G3zF_3DqviFJ" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Hold modelling components in a workflow\n", "lm_wf <- workflow() %>% \n", " add_recipe(lm_pumpkins_recipe) %>% \n", " add_model(lm_spec)\n", "\n", "# Print out the workflow\n", "lm_wf" ], "outputs": [], "metadata": { "id": "T3olroU3v-WX" } }, { "cell_type": "markdown", "source": [ "顺便提一下,工作流程可以像模型一样进行适配或训练。\n" ], "metadata": { "id": "zd1A5tgOwEPX" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Train the model\n", "lm_wf_fit <- lm_wf %>% \n", " fit(data = pumpkins_train)\n", "\n", "# Print the model coefficients learned \n", "lm_wf_fit" ], "outputs": [], "metadata": { "id": "NhJagFumwFHf" } }, { "cell_type": "markdown", "source": [ "从模型输出中,我们可以看到训练过程中学习到的系数。它们表示最佳拟合线的系数,该线使实际变量与预测变量之间的总体误差最小化。\n", "\n", "#### 使用测试集评估模型性能\n", "\n", "是时候看看模型的表现了 📏!我们该怎么做呢?\n", "\n", "现在我们已经训练了模型,可以使用 `parsnip::predict()` 对测试集进行预测。然后,我们可以将这些预测值与实际标签值进行比较,以评估模型的效果(好或不好)。\n", "\n", "让我们从对测试集进行预测开始,然后将预测结果与测试集绑定在一起。\n" ], "metadata": { "id": "_4QkGtBTwItF" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Make predictions for the test set\n", "predictions <- lm_wf_fit %>% \n", " predict(new_data = pumpkins_test)\n", "\n", "\n", "# Bind predictions to the test set\n", "lm_results <- pumpkins_test %>% \n", " select(c(package, price)) %>% \n", " bind_cols(predictions)\n", "\n", "\n", "# Print the first ten rows of the tibble\n", "lm_results %>% \n", " slice_head(n = 10)" ], "outputs": [], "metadata": { "id": "UFZzTG0gwTs9" } }, { "cell_type": "markdown", "source": [ "是的,你刚刚训练了一个模型并用它进行了预测!🔮 它表现如何呢?让我们来评估模型的性能吧!\n", "\n", "在Tidymodels中,我们使用 `yardstick::metrics()` 来完成这一任务!对于线性回归,我们重点关注以下指标:\n", "\n", "- `均方根误差 (RMSE)`:即[均方误差 (MSE)](https://en.wikipedia.org/wiki/Mean_squared_error)的平方根。它提供了一个绝对指标,单位与标签一致(在这个例子中是南瓜的价格)。值越小,模型越好(简单来说,它表示预测值平均偏差的价格范围)。\n", "\n", "- `决定系数(通常称为R平方或R2)`:一个相对指标,值越高,模型拟合效果越好。实际上,这个指标表示模型能够解释预测值与实际标签值之间方差的程度。\n" ], "metadata": { "id": "0A5MjzM7wW9M" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Evaluate performance of linear regression\n", "metrics(data = lm_results,\n", " truth = price,\n", " estimate = .pred)" ], "outputs": [], "metadata": { "id": "reJ0UIhQwcEH" } }, { "cell_type": "markdown", "source": [ "模型性能下降了。让我们通过可视化包裹和价格的散点图来看看是否能获得更好的指示,然后使用预测结果叠加一条最佳拟合线。\n", "\n", "这意味着我们需要准备并处理测试集,以便对包裹列进行编码,然后将其与模型生成的预测结果绑定在一起。\n" ], "metadata": { "id": "fdgjzjkBwfWt" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Encode package column\n", "package_encode <- lm_pumpkins_recipe %>% \n", " prep() %>% \n", " bake(new_data = pumpkins_test) %>% \n", " select(package)\n", "\n", "\n", "# Bind encoded package column to the results\n", "lm_results <- lm_results %>% \n", " bind_cols(package_encode %>% \n", " rename(package_integer = package)) %>% \n", " relocate(package_integer, .after = package)\n", "\n", "\n", "# Print new results data frame\n", "lm_results %>% \n", " slice_head(n = 5)\n", "\n", "\n", "# Make a scatter plot\n", "lm_results %>% \n", " ggplot(mapping = aes(x = package_integer, y = price)) +\n", " geom_point(size = 1.6) +\n", " # Overlay a line of best fit\n", " geom_line(aes(y = .pred), color = \"orange\", size = 1.2) +\n", " xlab(\"package\")\n", " \n" ], "outputs": [], "metadata": { "id": "R0nw719lwkHE" } }, { "cell_type": "markdown", "source": [ "很棒!正如你所看到的,线性回归模型并不能很好地概括包裹与其对应价格之间的关系。\n", "\n", "🎃 恭喜你,你刚刚创建了一个可以帮助预测几种南瓜价格的模型。你的节日南瓜田会非常漂亮。但你可能可以创建一个更好的模型!\n", "\n", "## 5. 构建一个多项式回归模型\n", "\n", "

\n", " \n", "

信息图由 Dasani Madipalli 制作
\n", "\n", "\n", "\n" ], "metadata": { "id": "HOCqJXLTwtWI" } }, { "cell_type": "markdown", "source": [ "有时候,我们的数据可能并不存在线性关系,但我们仍然希望预测结果。这时,多项式回归可以帮助我们对更复杂的非线性关系进行预测。\n", "\n", "以我们的南瓜数据集中的包装和价格关系为例。虽然有时变量之间存在线性关系——比如南瓜的体积越大,价格越高——但有时这些关系无法用一个平面或直线来表示。\n", "\n", "> ✅ 这里有[更多使用多项式回归的数据示例](https://online.stat.psu.edu/stat501/lesson/9/9.8)\n", ">\n", "> 再次看看之前图中品种与价格的关系。这个散点图看起来是否一定应该用一条直线来分析?可能并不是。在这种情况下,你可以尝试使用多项式回归。\n", ">\n", "> ✅ 多项式是可能包含一个或多个变量和系数的数学表达式\n", "\n", "#### 使用训练集训练一个多项式回归模型\n", "\n", "多项式回归会创建一条*曲线*,以更好地拟合非线性数据。\n", "\n", "让我们看看多项式模型是否能在预测中表现得更好。我们将遵循与之前类似的步骤:\n", "\n", "- 创建一个配方,指定对数据进行建模前需要执行的预处理步骤,例如:对预测变量进行编码并计算次数为 *n* 的多项式\n", "\n", "- 构建一个模型规范\n", "\n", "- 将配方和模型规范打包到一个工作流中\n", "\n", "- 通过拟合工作流来创建模型\n", "\n", "- 评估模型在测试数据上的表现\n", "\n", "让我们开始吧!\n" ], "metadata": { "id": "VcEIpRV9wzYr" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Specify a recipe\r\n", "poly_pumpkins_recipe <-\r\n", " recipe(price ~ package, data = pumpkins_train) %>%\r\n", " step_integer(all_predictors(), zero_based = TRUE) %>% \r\n", " step_poly(all_predictors(), degree = 4)\r\n", "\r\n", "\r\n", "# Create a model specification\r\n", "poly_spec <- linear_reg() %>% \r\n", " set_engine(\"lm\") %>% \r\n", " set_mode(\"regression\")\r\n", "\r\n", "\r\n", "# Bundle recipe and model spec into a workflow\r\n", "poly_wf <- workflow() %>% \r\n", " add_recipe(poly_pumpkins_recipe) %>% \r\n", " add_model(poly_spec)\r\n", "\r\n", "\r\n", "# Create a model\r\n", "poly_wf_fit <- poly_wf %>% \r\n", " fit(data = pumpkins_train)\r\n", "\r\n", "\r\n", "# Print learned model coefficients\r\n", "poly_wf_fit\r\n", "\r\n", " " ], "outputs": [], "metadata": { "id": "63n_YyRXw3CC" } }, { "cell_type": "markdown", "source": [ "#### 评估模型性能\n", "\n", "👏👏你已经构建了一个多项式模型,现在让我们在测试集上进行预测吧!\n" ], "metadata": { "id": "-LHZtztSxDP0" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Make price predictions on test data\r\n", "poly_results <- poly_wf_fit %>% predict(new_data = pumpkins_test) %>% \r\n", " bind_cols(pumpkins_test %>% select(c(package, price))) %>% \r\n", " relocate(.pred, .after = last_col())\r\n", "\r\n", "\r\n", "# Print the results\r\n", "poly_results %>% \r\n", " slice_head(n = 10)" ], "outputs": [], "metadata": { "id": "YUFpQ_dKxJGx" } }, { "cell_type": "markdown", "source": [ "Woo-hoo,让我们使用 `yardstick::metrics()` 来评估模型在 test_set 上的表现。\n" ], "metadata": { "id": "qxdyj86bxNGZ" } }, { "cell_type": "code", "execution_count": null, "source": [ "metrics(data = poly_results, truth = price, estimate = .pred)" ], "outputs": [], "metadata": { "id": "8AW5ltkBxXDm" } }, { "cell_type": "markdown", "source": [ "🤩🤩 表现更出色。\n", "\n", "`rmse` 从大约 7 降至大约 3,这表明实际价格与预测价格之间的误差减少了。你可以*粗略地*理解为平均而言,错误预测的误差大约为 \\$3。`rsq` 从大约 0.4 增加到 0.8。\n", "\n", "所有这些指标都表明多项式模型的表现远优于线性模型。干得好!\n", "\n", "让我们看看是否可以将其可视化!\n" ], "metadata": { "id": "6gLHNZDwxYaS" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Bind encoded package column to the results\r\n", "poly_results <- poly_results %>% \r\n", " bind_cols(package_encode %>% \r\n", " rename(package_integer = package)) %>% \r\n", " relocate(package_integer, .after = package)\r\n", "\r\n", "\r\n", "# Print new results data frame\r\n", "poly_results %>% \r\n", " slice_head(n = 5)\r\n", "\r\n", "\r\n", "# Make a scatter plot\r\n", "poly_results %>% \r\n", " ggplot(mapping = aes(x = package_integer, y = price)) +\r\n", " geom_point(size = 1.6) +\r\n", " # Overlay a line of best fit\r\n", " geom_line(aes(y = .pred), color = \"midnightblue\", size = 1.2) +\r\n", " xlab(\"package\")\r\n" ], "outputs": [], "metadata": { "id": "A83U16frxdF1" } }, { "cell_type": "markdown", "source": [ "您可以看到一条更符合您数据的曲线!🤩\n", "\n", "您可以通过向 `geom_smooth` 传递一个多项式公式,使其更加平滑,如下所示:\n" ], "metadata": { "id": "4U-7aHOVxlGU" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Make a scatter plot\r\n", "poly_results %>% \r\n", " ggplot(mapping = aes(x = package_integer, y = price)) +\r\n", " geom_point(size = 1.6) +\r\n", " # Overlay a line of best fit\r\n", " geom_smooth(method = lm, formula = y ~ poly(x, degree = 4), color = \"midnightblue\", size = 1.2, se = FALSE) +\r\n", " xlab(\"package\")" ], "outputs": [], "metadata": { "id": "5vzNT0Uexm-w" } }, { "cell_type": "markdown", "source": [ "就像一条平滑的曲线!🤩\n", "\n", "以下是如何进行新的预测:\n" ], "metadata": { "id": "v9u-wwyLxq4G" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Make a hypothetical data frame\r\n", "hypo_tibble <- tibble(package = \"bushel baskets\")\r\n", "\r\n", "# Make predictions using linear model\r\n", "lm_pred <- lm_wf_fit %>% predict(new_data = hypo_tibble)\r\n", "\r\n", "# Make predictions using polynomial model\r\n", "poly_pred <- poly_wf_fit %>% predict(new_data = hypo_tibble)\r\n", "\r\n", "# Return predictions in a list\r\n", "list(\"linear model prediction\" = lm_pred, \r\n", " \"polynomial model prediction\" = poly_pred)\r\n" ], "outputs": [], "metadata": { "id": "jRPSyfQGxuQv" } }, { "cell_type": "markdown", "source": [ "`多项式模型`的预测是合理的,结合`价格`和`包装`的散点图来看确实如此!而且,如果这个模型比之前的模型更好,那么根据相同的数据,你需要为这些更贵的南瓜做好预算!\n", "\n", "🏆 干得好!你在一节课中创建了两个回归模型。在回归的最后一部分,你将学习逻辑回归以确定类别。\n", "\n", "## **🚀挑战**\n", "\n", "在这个笔记本中测试几个不同的变量,看看相关性如何影响模型的准确性。\n", "\n", "## [**课后测验**](https://gray-sand-07a10f403.1.azurestaticapps.net/quiz/14/)\n", "\n", "## **复习与自学**\n", "\n", "在本课中我们学习了线性回归。还有其他重要的回归类型。阅读关于逐步回归、岭回归、套索回归和弹性网络技术的内容。一个很好的课程是[斯坦福统计学习课程](https://online.stanford.edu/courses/sohs-ystatslearning-statistical-learning)。\n", "\n", "如果你想了解更多关于如何使用出色的Tidymodels框架,请查看以下资源:\n", "\n", "- Tidymodels官网:[Tidymodels入门](https://www.tidymodels.org/start/)\n", "\n", "- Max Kuhn 和 Julia Silge,[*Tidy Modeling with R*](https://www.tmwr.org/)*.*\n", "\n", "###### **特别感谢:**\n", "\n", "[Allison Horst](https://twitter.com/allison_horst?lang=en) 创作了令人惊叹的插图,使R语言更加友好和吸引人。可以在她的[画廊](https://www.google.com/url?q=https://github.com/allisonhorst/stats-illustrations&sa=D&source=editors&ust=1626380772530000&usg=AOvVaw3zcfyCizFQZpkSLzxiiQEM)中找到更多插图。\n" ], "metadata": { "id": "8zOLOWqMxzk5" } }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n---\n\n**免责声明**: \n本文档使用AI翻译服务[Co-op Translator](https://github.com/Azure/co-op-translator)进行翻译。尽管我们努力确保翻译的准确性,但请注意,自动翻译可能包含错误或不准确之处。原始语言的文档应被视为权威来源。对于关键信息,建议使用专业人工翻译。我们对因使用此翻译而产生的任何误解或误读不承担责任。\n" ] } ] }