{
"nbformat": 4,
"nbformat_minor": 2,
"metadata": {
"colab": {
"name": "lesson_3-R.ipynb",
"provenance": [],
"collapsed_sections": [],
"toc_visible": true
},
"kernelspec": {
"name": "ir",
"display_name": "R"
},
"language_info": {
"name": "R"
},
"coopTranslator": {
"original_hash": "5015d65d61ba75a223bfc56c273aa174",
"translation_date": "2025-09-03T19:29:48+00:00",
"source_file": "2-Regression/3-Linear/solution/R/lesson_3-R.ipynb",
"language_code": "tw"
}
},
"cells": [
{
"cell_type": "markdown",
"source": [],
"metadata": {
"id": "EgQw8osnsUV-"
}
},
{
"cell_type": "markdown",
"source": [
"## Linear and Polynomial Regression for Pumpkin Pricing - Lesson 3\n",
"<p >\n",
" <img src=\"../../images/linear-polynomial.png\"\n",
" width=\"800\"/>\n",
" <figcaption>Infographic by Dasani Madipalli</figcaption>\n",
"\n",
"\n",
"<!--![Infographic by Dasani Madipalli](../../../../../../translated_images/linear-polynomial.5523c7cb6576ccab0fecbd0e3505986eb2d191d9378e785f82befcf3a578a6e7.tw.png){width=\"800\"}-->\n",
"\n",
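"#### Introduction\n",
"\n",
"So far, you have explored what regression is, using sample data from the pumpkin pricing dataset that we will use throughout this lesson. You have also visualized it using `ggplot2`. 💪\n",
"\n",
"Now you are ready to dive deeper into regression for machine learning. In this lesson, you will learn more about two types of regression: *basic linear regression* and *polynomial regression*, along with some of the math underlying these techniques.\n",
"\n",
"> Throughout this curriculum, we assume minimal knowledge of math and aim to make the content accessible to students coming from other fields, so watch for notes, 🧮 callouts, diagrams, and other learning tools that aid comprehension.\n",
"\n",
"#### Preparation\n",
"\n",
"As a reminder, you are loading this data so that you can ask questions of it.\n",
"\n",
"- When is the best time to buy pumpkins?\n",
"\n",
"- What price can I expect for a case of miniature pumpkins?\n",
"\n",
"- Should I buy them in half-bushel baskets or in 1 1/9 bushel boxes? Let's keep digging into this data.\n",
"\n",
"In the previous lesson, you created a `tibble` (a modern reimagining of the data frame) and populated it with part of the original dataset, standardizing the pricing by the bushel. By doing that, however, you were only able to gather about 400 data points, and only for the fall months. Maybe we can get a little more detail about the nature of the data by cleaning it more? We'll see... 🕵️‍♀️\n",
"\n",
"For this task, we'll require the following packages:\n",
"\n",
"- `tidyverse`: The [tidyverse](https://www.tidyverse.org/) is a [collection of R packages](https://www.tidyverse.org/packages) designed to make data science faster, easier and more fun!\n",
"\n",
"- `tidymodels`: The [tidymodels](https://www.tidymodels.org/) framework is a [collection of packages](https://www.tidymodels.org/packages) for modeling and machine learning.\n",
"\n",
"- `janitor`: The [janitor package](https://github.com/sfirke/janitor) provides simple little tools for examining and cleaning dirty data.\n",
"\n",
"- `corrplot`: The [corrplot package](https://cran.r-project.org/web/packages/corrplot/vignettes/corrplot-intro.html) provides a visual exploratory tool for correlation matrices that supports automatic variable reordering to help detect hidden patterns among variables.\n",
"\n",
"You can install them as follows:\n",
"\n",
"`install.packages(c(\"tidyverse\", \"tidymodels\", \"janitor\", \"corrplot\"))`\n",
"\n",
"The script below checks whether you have the packages required to complete this module and installs any that are missing.\n"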
],
"metadata": {
"id": "WqQPS1OAsg3H"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"suppressWarnings(if (!require(\"pacman\")) install.packages(\"pacman\"))\n",
"\n",
"pacman::p_load(tidyverse, tidymodels, janitor, corrplot)"
],
"outputs": [],
"metadata": {
"id": "tA4C2WN3skCf",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "c06cd805-5534-4edc-f72b-d0d1dab96ac0"
}
},
{
"cell_type": "markdown",
"source": [
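"We'll load these awesome packages later and make them available in our current R session. (This is just for illustration; `pacman::p_load()` has already done that for you.)\n",
"\n",
"## 1. A linear regression line\n",
"\n",
"As you learned in Lesson 1, the goal of a linear regression exercise is to be able to plot a *line of best fit* to:\n",
"\n",
"- **Show variable relationships**. Show the relationship between variables.\n",
"\n",
"- **Make predictions**. Make accurate predictions on where a new data point would fall in relation to that line.\n",
"\n",
"To draw this type of line, we use a statistical technique called **Least-Squares Regression**. The term `least-squares` means that the errors (residuals) of all the data points around the regression line, that is, their vertical distances from the line, are squared and then added up. Ideally, that final sum is as small as possible, because we want a low total error, or `least-squares`. The line of best fit is therefore the line that gives the lowest sum of squared errors, hence the name *least-squares regression*.\n",
"\n",
"We do so because we want to model a line that has the least cumulative (squared) distance from all of our data points. We square the errors before adding them up because we care about their magnitude rather than their direction; otherwise, positive and negative errors would cancel each other out.\n",
"\n",
"> **🧮 Show me the math**\n",
">\n",
"> This line, called the *line of best fit*, can be expressed by [an equation](https://en.wikipedia.org/wiki/Simple_linear_regression):\n",
">\n",
"> Y = a + bX\n",
">\n",
"> `X` is the `explanatory variable` or `predictor`. `Y` is the `dependent variable` or `outcome`. The slope of the line is `b`, and `a` is the y-intercept, which is the value of `Y` when `X = 0`.\n",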
">\n",
"\n",
"> ![](../../../../../../2-Regression/3-Linear/solution/images/slope.png \"slope = $y/x$\")\n",
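" Infographic by Jen Looper\n",
">\n",
"> First, calculate the slope `b`.\n",
">\n",
"> In other words, referring to our pumpkin data's original question, \"predict the price of a pumpkin per bushel by month\", `X` would refer to the month of sale and `Y` would refer to the price.\n",
">\n",
"> ![](../../../../../../translated_images/calculation.989aa7822020d9d0ba9fc781f1ab5192f3421be86ebb88026528aef33c37b0d8.tw.png)\n",
" Infographic by Jen Looper\n",
"> \n",
"> Calculate the value of `Y` for a given `X`. In the illustration, a price of around \\$4 corresponds to April!\n",
">\n",
"> The math that calculates the line must demonstrate the slope of the line, which also depends on the intercept, that is, where `Y` is situated when `X = 0`.\n",
">\n",
"> You can observe the method of calculation for these values on the [Math is Fun](https://www.mathsisfun.com/data/least-squares-regression.html) website. Also visit [this least-squares calculator](https://www.mathsisfun.com/data/least-squares-calculator.html) to watch how the values of the numbers affect the line.\n",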
"\n",
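"For simple linear regression, the least-squares line also has a closed-form solution: $b = \\frac{\\sum_i (x_i - \\bar{x})(y_i - \\bar{y})}{\\sum_i (x_i - \\bar{x})^2}$ and $a = \\bar{y} - b\\bar{x}$. As a minimal sketch, assuming five hypothetical `(x, y)` points that are not part of the pumpkin dataset, the base-R code below computes `b`, `a`, and the sum of squared errors by hand, then checks the result against `lm()`, which minimizes the same sum:\n",
"\n",
"```r\n",
"# Hypothetical toy data (not from the pumpkin dataset)\n",
"x <- c(1, 2, 3, 4, 5)\n",
"y <- c(2, 4, 5, 4, 5)\n",
"\n",
"# Slope: how x and y vary together, divided by how much x varies\n",
"b <- sum((x - mean(x)) * (y - mean(y))) / sum((x - mean(x))^2)\n",
"\n",
"# Intercept: the least-squares line passes through (mean(x), mean(y))\n",
"a <- mean(y) - b * mean(x)\n",
"\n",
"# Sum of squared errors (residuals) for this line\n",
"sse <- sum((y - (a + b * x))^2)\n",
"\n",
"c(intercept = a, slope = b, sse = sse)\n",
"\n",
"# lm() minimizes the same quantity, so its coefficients should match\n",
"coef(lm(y ~ x))\n",
"```\n",
"\n",
"For these points, both approaches give an intercept of 2.2 and a slope of 0.6, with a sum of squared errors of 2.4; any other straight line through the same points would give a larger sum.\n",
"\n",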
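"Not so scary, right? 🤓\n",
"\n",
"#### Correlation\n",
"\n",
"One more term to understand is the **Correlation Coefficient** between given X and Y variables. Using a scatter plot, you can quickly visualize this coefficient. A plot with data points lined up neatly along a line has high correlation, but a plot with data points scattered everywhere between X and Y has low correlation.\n",
"\n",
"A good linear regression model will be one that has a high Correlation Coefficient (closer to 1 than to 0 in absolute value, since a strong negative relationship gives a coefficient near -1), fitted using the Least-Squares Regression method with a line of regression.\n"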
],
"metadata": {
"id": "cdX5FRpvsoP5"
}
},
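{
"cell_type": "markdown",
"source": [
"To make the correlation coefficient concrete, here is a minimal base-R sketch on hypothetical toy data (not the pumpkin dataset). It computes Pearson's correlation coefficient from its definition, checks it against `cor()` (which uses the Pearson method by default), and contrasts points lying exactly on a line with points scattered around one:\n",
"\n",
"```r\n",
"# Hypothetical toy data (not from the pumpkin dataset)\n",
"x <- c(1, 2, 3, 4, 5)\n",
"y_noisy <- c(2, 4, 5, 4, 5)  # points scattered around a line\n",
"y_line <- 2 * x + 1          # points exactly on a line\n",
"\n",
"# Pearson correlation from its definition\n",
"pearson_r <- function(x, y) {\n",
"  dx <- x - mean(x)\n",
"  dy <- y - mean(y)\n",
"  sum(dx * dy) / sqrt(sum(dx^2) * sum(dy^2))\n",
"}\n",
"\n",
"r_noisy <- pearson_r(x, y_noisy)  # about 0.77\n",
"r_line <- pearson_r(x, y_line)    # exactly 1\n",
"\n",
"# Compare with the built-in cor()\n",
"c(manual = r_noisy, builtin = cor(x, y_noisy))\n",
"c(manual = r_line, builtin = cor(x, y_line))\n",
"```\n",
"\n",
"The perfectly linear points give a coefficient of 1, while the scattered points give about 0.77: still a fairly strong positive relationship, but one that a straight line will not capture exactly.\n"
],
"metadata": {}
},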
{
"cell_type": "markdown",
"source": [
"## **2. A dance with data: creating a data frame that will be used for modelling**\n",
"\n",
"<p >\n",
" <img src=\"../../images/janitor.jpg\"\n",
" width=\"700\"/>\n",
" <figcaption>Artwork by @allison_horst</figcaption>\n",
"\n",
"\n",
"<!--![Artwork by \\@allison_horst](../../../../../../translated_images/janitor.e4a77dd3d3e6a32e25327090b8a9c00dc7cf459c44fa9f184c5ecb0d48ce3794.tw.jpg){width=\"700\"}-->\n"
],
"metadata": {
"id": "WdUKXk7Bs8-V"
}
},
{
"cell_type": "markdown",
"source": [
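"Load the required libraries and dataset. Convert the data to a data frame containing a subset of the data:\n",
"\n",
"- Only get pumpkins priced by the bushel\n",
"\n",
"- Convert the date to a month\n",
"\n",
"- Calculate the price as the average of the high and low prices\n",
"\n",
"- Convert the price to reflect the pricing by bushel quantity\n",
"\n",
"> We covered these steps in the [previous lesson](https://github.com/microsoft/ML-For-Beginners/blob/main/2-Regression/2-Data/solution/lesson_2-R.ipynb).\n"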
],
"metadata": {
"id": "fMCtu2G2s-p8"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Load the core Tidyverse packages\n",
"library(tidyverse)\n",
"library(lubridate)\n",
"\n",
"# Import the pumpkins data\n",
"pumpkins <- read_csv(file = \"https://raw.githubusercontent.com/microsoft/ML-For-Beginners/main/2-Regression/data/US-pumpkins.csv\")\n",
"\n",
"\n",
"# Get a glimpse and dimensions of the data\n",
"glimpse(pumpkins)\n",
"\n",
"\n",
"# Print the first 5 rows of the data set\n",
"pumpkins %>% \n",
" slice_head(n = 5)"
],
"outputs": [],
"metadata": {
"id": "ryMVZEEPtERn"
}
},
{
"cell_type": "markdown",
"source": [
"In the spirit of sheer adventure, let's explore the [`janitor package`](https://github.com/sfirke/janitor), which provides simple functions for examining and cleaning dirty data. For instance, let's take a look at the column names of our data:\n"
],
"metadata": {
"id": "xcNxM70EtJjb"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Return column names\n",
"pumpkins %>% \n",
" names()"
],
"outputs": [],
"metadata": {
"id": "5XtpaIigtPfW"
}
},
{
"cell_type": "markdown",
"source": [
"🤔 We can do better. Let's convert these column names to the R-friendly [snake_case](https://en.wikipedia.org/wiki/Snake_case) convention using `janitor::clean_names`. To find out more about this function: `?clean_names`\n"
],
"metadata": {
"id": "IbIqrMINtSHe"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Clean names to the snake_case convention\n",
"pumpkins <- pumpkins %>% \n",
" clean_names(case = \"snake\")\n",
"\n",
"# Return column names\n",
"pumpkins %>% \n",
" names()"
],
"outputs": [],
"metadata": {
"id": "a2uYvclYtWvX"
}
},
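{
"cell_type": "markdown",
"source": [
"If you want to see what `clean_names()` does in isolation, here is a minimal sketch on a hypothetical one-row tibble whose column names (`City Name`, `Low Price`) imitate the style of the raw CSV headers:\n",
"\n",
"```r\n",
"library(tibble)\n",
"library(janitor)\n",
"\n",
"# Hypothetical raw headers containing spaces and capital letters\n",
"raw <- tibble(`City Name` = \"BOSTON\", `Low Price` = 15)\n",
"\n",
"# Convert every column name to snake_case\n",
"cleaned <- clean_names(raw, case = \"snake\")\n",
"\n",
"names(raw)      # \"City Name\" \"Low Price\"\n",
"names(cleaned)  # \"city_name\" \"low_price\"\n",
"```\n",
"\n",
"Only the column names change; the values themselves are left untouched.\n"
],
"metadata": {}
},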
{
"cell_type": "markdown",
"source": [
"Much tidyR 🧹! Now, a dance with the data using `dplyr`, as in the previous lesson! 💃\n"
],
"metadata": {
"id": "HfhnuzDDtaDd"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Select desired columns\n",
"pumpkins <- pumpkins %>% \n",
" select(variety, city_name, package, low_price, high_price, date)\n",
"\n",
"\n",
"\n",
"# Extract the month from the dates to a new column\n",
"pumpkins <- pumpkins %>%\n",
" mutate(date = mdy(date),\n",
" month = month(date)) %>% \n",
" select(-date)\n",
"\n",
"\n",
"\n",
"# Create a new column for average Price\n",
"pumpkins <- pumpkins %>% \n",
" mutate(price = (low_price + high_price)/2)\n",
"\n",
"\n",
"# Retain only pumpkins with the string \"bushel\"\n",
"new_pumpkins <- pumpkins %>% \n",
" filter(str_detect(string = package, pattern = \"bushel\"))\n",
"\n",
"\n",
"# Normalize the pricing so that you show the pricing per bushel, not per 1 1/9 or 1/2 bushel\n",
"new_pumpkins <- new_pumpkins %>% \n",
" mutate(price = case_when(\n",
" str_detect(package, \"1 1/9\") ~ price/(1 + 1/9),\n",
" str_detect(package, \"1/2\") ~ price*2,\n",
" TRUE ~ price))\n",
"\n",
"# Relocate column positions\n",
"new_pumpkins <- new_pumpkins %>% \n",
" relocate(month, .before = variety)\n",
"\n",
"\n",
"# Display the first 5 rows\n",
"new_pumpkins %>% \n",
" slice_head(n = 5)"
],
"outputs": [],
"metadata": {
"id": "X0wU3gQvtd9f"
}
},
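{
"cell_type": "markdown",
"source": [
"The `case_when()` normalization step is easier to check on a few rows. As a minimal sketch, assuming three hypothetical rows whose `package` labels follow the style of the real column but whose prices are invented, the code below divides 1 1/9 bushel prices by exactly `1 + 1/9` and doubles half-bushel prices. `case_when()` uses the first condition that matches, and the final `TRUE ~ price` keeps full-bushel prices unchanged:\n",
"\n",
"```r\n",
"library(dplyr)\n",
"library(stringr)\n",
"\n",
"# Hypothetical rows; the prices are invented for illustration\n",
"toy <- tibble(\n",
"  package = c(\"1 1/9 bushel cartons\", \"1/2 bushel cartons\", \"bushel baskets\"),\n",
"  price = c(20, 10, 15)\n",
")\n",
"\n",
"toy_per_bushel <- toy %>%\n",
"  mutate(price = case_when(\n",
"    str_detect(package, \"1 1/9\") ~ price / (1 + 1/9),  # 20 / (10/9) = 18\n",
"    str_detect(package, \"1/2\") ~ price * 2,            # 10 * 2 = 20\n",
"    TRUE ~ price                                       # already per bushel\n",
"  ))\n",
"\n",
"toy_per_bushel\n",
"```\n",
"\n",
"The normalized prices come out as 18, 20 and 15 per bushel.\n"
],
"metadata": {}
},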
{
"cell_type": "markdown",
"source": [
"Good job! 👌 You now have a clean, tidy dataset on which you can build your new regression model!\n",
"\n",
"Mind a scatter plot?\n"
],
"metadata": {
"id": "UpaIwaxqth82"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Set theme\n",
"theme_set(theme_light())\n",
"\n",
"# Make a scatter plot of month and price\n",
"new_pumpkins %>% \n",
" ggplot(mapping = aes(x = month, y = price)) +\n",
" geom_point(size = 1.6)\n"
],
"outputs": [],
"metadata": {
"id": "DXgU-j37tl5K"
}
},
{
"cell_type": "markdown",
"source": [
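"The scatter plot reminds us that we only have month data from August through December. We probably need more data to be able to draw conclusions in a linear fashion.\n",
"\n",
"Let's take a look at our modelling data again:\n"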
],
"metadata": {
"id": "Ve64wVbwtobI"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Display first 5 rows\n",
"new_pumpkins %>% \n",
" slice_head(n = 5)"
],
"outputs": [],
"metadata": {
"id": "HFQX2ng1tuSJ"
}
},
{
"cell_type": "markdown",
"source": [
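"What if we wanted to predict the `price` of a pumpkin based on the `city_name` or `package` columns, which are of type character? Or, even more simply, how could we find the correlation (which requires both of its inputs to be numeric) between, say, `package` and `price`? 🤷🤷\n",
"\n",
"Machine learning models work best with numeric features rather than text values, so you generally need to convert categorical features into numeric representations.\n",
"\n",
"This means that we have to find a way to reformat our predictors to make them easier for a model to use effectively, a process known as `feature engineering`.\n"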
],
"metadata": {
"id": "7hsHoxsStyjJ"
}
},
{
"cell_type": "markdown",
"source": [
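"## 3. Preprocessing data for modelling with recipes 👩‍🍳👨‍🍳\n",
"\n",
"Activities that reformat predictor values to make them easier for a model to use effectively are termed `feature engineering`.\n",
"\n",
"Different models have different preprocessing requirements. For instance, least squares requires `encoding categorical variables` such as month, variety and city_name. This simply involves `translating` a column with `categorical values` into one or more `numeric columns` that take the place of the original.\n",
"\n",
"For example, suppose your data includes the following categorical feature:\n",
"\n",
"| city |\n",
"|:-------:|\n",
"| Denver |\n",
"| Nairobi |\n",
"| Tokyo |\n",
"\n",
"You can apply *ordinal encoding* to substitute a unique integer value for each category, like this:\n",
"\n",
"| city |\n",
"|:----:|\n",
"| 0 |\n",
"| 1 |\n",
"| 2 |\n",
"\n",
"And that's what we'll do to our data!\n",
"\n",
"In this section, we'll explore another amazing Tidymodels package: [recipes](https://tidymodels.github.io/recipes/), which is designed to help you preprocess your data before training your model. At its core, a recipe is an object that defines what steps should be applied to a data set in order to get it ready for modelling.\n",
"\n",
"Now, let's create a recipe that prepares our data for modelling by substituting a unique integer for all the observations in the predictor columns:\n"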
],
"metadata": {
"id": "AD5kQbcvt3Xl"
}
},
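{
"cell_type": "markdown",
"source": [
"The idea behind ordinal encoding can be sketched in base R before turning to `recipes`. This is an illustration only, not a description of how `step_integer()` works internally: converting the cities from the table above (with Denver repeated) to a factor sorts the unique values alphabetically, and each level's position minus one gives a zero-based integer code.\n",
"\n",
"```r\n",
"# The cities from the table above, with Denver repeated\n",
"city <- c(\"Denver\", \"Nairobi\", \"Tokyo\", \"Denver\")\n",
"\n",
"# Factor levels are the sorted unique values: Denver, Nairobi, Tokyo\n",
"city_factor <- factor(city)\n",
"\n",
"# Zero-based integer code for each observation\n",
"city_code <- as.integer(city_factor) - 1L\n",
"\n",
"data.frame(city, city_code)  # Denver -> 0, Nairobi -> 1, Tokyo -> 2\n",
"```\n",
"\n",
"Repeated values always receive the same code, which is what lets a model treat the codes consistently.\n"
],
"metadata": {}
},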
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Specify a recipe\n",
"pumpkins_recipe <- recipe(price ~ ., data = new_pumpkins) %>% \n",
" step_integer(all_predictors(), zero_based = TRUE)\n",
"\n",
"\n",
"# Print out the recipe\n",
"pumpkins_recipe"
],
"outputs": [],
"metadata": {
"id": "BNaFKXfRt9TU"
}
},
{
"cell_type": "markdown",
"source": [
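"Awesome! 👏 We just created our first recipe, which specifies an outcome (price) and its corresponding predictors and encodes all the predictor columns into a set of integers 🙌! Let's quickly break it down:\n",
"\n",
"- The call to `recipe()` with a formula tells the recipe the *roles* of the variables, using the `new_pumpkins` data as the reference. For instance, the `price` column has been assigned an `outcome` role, while the rest of the columns have been assigned a `predictor` role.\n",
"\n",
"- `step_integer(all_predictors(), zero_based = TRUE)` specifies that all the predictors should be converted into a set of integers, with the numbering starting at 0.\n",
"\n",
"We are sure you may be having thoughts such as: \"This is so cool!! But what if I needed to confirm that the recipes are doing exactly what I expect them to do? 🤔\"\n",
"\n",
"That's an awesome thought! You see, once your recipe is defined, you can estimate the parameters required to actually preprocess the data, and then extract the processed data. You typically don't need to do this when you use Tidymodels (we'll see the normal convention, `workflows`, in just a minute), but it can come in handy when you want to do some kind of sanity check to confirm that recipes are doing what you expect.\n",
"\n",
"For that, you'll need two more verbs: `prep()` and `bake()`. And as always, our little R friends by [`Allison Horst`](https://github.com/allisonhorst/stats-illustrations) help you understand this better!\n",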
"\n",
"<p >\n",
" <img src=\"../../images/recipes.png\"\n",
" width=\"550\"/>\n",
" <figcaption>Artwork by @allison_horst</figcaption>\n",
"\n",
"\n",
"<!--![Artwork by \\@allison_horst](../../../../../../translated_images/recipes.9ad10d8a4056bf89413fc33644924e0bd29d7c12fb2154e03a1ca3d2d6ea9323.tw.png){width=\"550\"}-->\n"
],
"metadata": {
"id": "KEiO0v7kuC9O"
}
},
{
"cell_type": "markdown",
"source": [
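"[`prep()`](https://recipes.tidymodels.org/reference/prep.html): estimates the required parameters from a training set, which can later be applied to other data sets. For instance, for a given predictor column, which observation will be assigned integer 0, 1, 2, and so on.\n",
"\n",
"[`bake()`](https://recipes.tidymodels.org/reference/bake.html): takes a prepped recipe and applies its operations to any data set.\n",
"\n",
"With that said, let's prep and bake our recipe to really confirm that, under the hood, the predictor columns will be encoded before a model is fit.\n"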
],
"metadata": {
"id": "Q1xtzebuuTCP"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Prep the recipe\n",
"pumpkins_prep <- prep(pumpkins_recipe)\n",
"\n",
"# Bake the recipe to extract a preprocessed new_pumpkins data\n",
"baked_pumpkins <- bake(pumpkins_prep, new_data = NULL)\n",
"\n",
"# Print out the baked data set\n",
"baked_pumpkins %>% \n",
" slice_head(n = 10)"
],
"outputs": [],
"metadata": {
"id": "FGBbJbP_uUUn"
}
},
{
"cell_type": "markdown",
"source": [
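"Woo-hoo! 🥳 The processed data `baked_pumpkins` has all of its predictors encoded, confirming that the preprocessing steps defined in our recipe work as expected. This makes it harder for you to read, but much more intelligible for Tidymodels! Take some time to find out which observation has been mapped to which integer.\n",
"\n",
"It is also worth mentioning that `baked_pumpkins` is a data frame that we can perform computations on.\n",
"\n",
"For instance, let's try to find a good correlation between two columns of your data to potentially build a good predictive model. We'll use the function `cor()` to do this. Type `?cor()` to find out more about the function.\n"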
],
"metadata": {
"id": "1dvP0LBUueAW"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Find the correlation between the city_name and the price\n",
"cor(baked_pumpkins$city_name, baked_pumpkins$price)\n",
"\n",
"# Find the correlation between the package and the price\n",
"cor(baked_pumpkins$package, baked_pumpkins$price)\n"
],
"outputs": [],
"metadata": {
"id": "3bQzXCjFuiSV"
}
},
{
"cell_type": "markdown",
"source": [
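"As it turns out, there's only a weak correlation between city and price. However, there's a somewhat better correlation between the package and its price. That makes sense, right? Normally, the bigger the produce box, the higher the price.\n",
"\n",
"While we are at it, let's also try to visualize a correlation matrix of all the columns using the `corrplot` package.\n"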
],
"metadata": {
"id": "BToPWbgjuoZw"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Load the corrplot package\n",
"library(corrplot)\n",
"\n",
"# Obtain correlation matrix\n",
"corr_mat <- cor(baked_pumpkins %>% \n",
" # Drop columns that are not really informative\n",
" select(-c(low_price, high_price)))\n",
"\n",
"# Make a correlation plot between the variables\n",
"corrplot(corr_mat, method = \"shade\", shade.col = NA, tl.col = \"black\", tl.srt = 45, addCoef.col = \"black\", cl.pos = \"n\", order = \"original\")"
],
"outputs": [],
"metadata": {
"id": "ZwAL3ksmutVR"
}
},
{
"cell_type": "markdown",
"source": [
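"🤩🤩 Much better.\n",
"\n",
"A good question to now ask of this data is: '`What price can I expect for a given pumpkin package?`' Let's get right into it!\n",
"\n",
"> Note: When you **`bake()`** the prepped recipe **`pumpkins_prep`** with **`new_data = NULL`**, you extract the processed (i.e. encoded) training data. If you had another data set, for example a test set, and wanted to see how the recipe would preprocess it, you would simply bake **`pumpkins_prep`** with **`new_data = test_set`**.\n",
"\n",
"## 4. Build a linear regression model\n",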
"\n",
"<p >\n",
" <img src=\"../../images/linear-polynomial.png\"\n",
" width=\"800\"/>\n",
" <figcaption>Infographic by Dasani Madipalli</figcaption>\n",
"\n",
"\n",
"<!--![Infographic by Dasani Madipalli](../../../../../../translated_images/linear-polynomial.5523c7cb6576ccab0fecbd0e3505986eb2d191d9378e785f82befcf3a578a6e7.tw.png){width=\"800\"}-->\n"
],
"metadata": {
"id": "YqXjLuWavNxW"
}
},
{
"cell_type": "markdown",
"source": [
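"Now that we have built a recipe and confirmed that the data will be preprocessed appropriately, let's build a regression model to answer the question: `What price can I expect for a given pumpkin package?`\n",
"\n",
"#### Train a linear regression model using the training set\n",
"\n",
"As you may have already figured out, the *price* column is the `outcome` variable while the *package* column is the `predictor` variable.\n",
"\n",
"To do this, we'll first split the data so that 80% goes into the training set and 20% into the test set, then define a recipe that will encode the predictor column into a set of integers, and then build a model specification. We won't prep and bake our recipe, since we already know it will preprocess the data as expected.\n"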
],
"metadata": {
"id": "Pq0bSzCevW-h"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"set.seed(2056)\n",
"# Split the data into training and test sets\n",
"pumpkins_split <- new_pumpkins %>% \n",
" initial_split(prop = 0.8)\n",
"\n",
"\n",
"# Extract training and test data\n",
"pumpkins_train <- training(pumpkins_split)\n",
"pumpkins_test <- testing(pumpkins_split)\n",
"\n",
"\n",
"\n",
"# Create a recipe for preprocessing the data\n",
"lm_pumpkins_recipe <- recipe(price ~ package, data = pumpkins_train) %>% \n",
" step_integer(all_predictors(), zero_based = TRUE)\n",
"\n",
"\n",
"\n",
"# Create a linear model specification\n",
"lm_spec <- linear_reg() %>% \n",
" set_engine(\"lm\") %>% \n",
" set_mode(\"regression\")"
],
"outputs": [],
"metadata": {
"id": "CyoEh_wuvcLv"
}
},
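{
"cell_type": "markdown",
"source": [
"To see what `prop = 0.8` means in practice, here is a minimal sketch on a hypothetical 100-row tibble (`toy_rows` is invented for illustration): `initial_split()` randomly assigns 80 rows to training and the remaining 20 to testing, with no overlap, and `set.seed()` makes the random assignment reproducible.\n",
"\n",
"```r\n",
"library(tidymodels)\n",
"\n",
"# Fix the random seed so the split is reproducible\n",
"set.seed(2056)\n",
"\n",
"# Hypothetical data: 100 rows with an id column\n",
"toy_rows <- tibble(id = 1:100, value = rnorm(100))\n",
"\n",
"# Randomly assign 80% of the rows to training and the rest to testing\n",
"toy_split <- initial_split(toy_rows, prop = 0.8)\n",
"toy_train <- training(toy_split)\n",
"toy_test <- testing(toy_split)\n",
"\n",
"c(train = nrow(toy_train), test = nrow(toy_test))  # 80 and 20\n",
"\n",
"# Every row lands in exactly one of the two sets\n",
"length(intersect(toy_train$id, toy_test$id))  # 0\n",
"```\n"
],
"metadata": {}
},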
{
"cell_type": "markdown",
"source": [
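"Good job! Now that we have a recipe and a model specification, we need a way of bundling them together into an object that will first preprocess the data (prep + bake behind the scenes), then fit the model on the preprocessed data, and also allow for potential post-processing activities. How's that for your peace of mind! 🤩\n",
"\n",
"In Tidymodels, this convenient object is called a [`workflow`](https://workflows.tidymodels.org/), and it conveniently holds your modelling components! It is similar to what we'd call a *pipeline* in *Python*.\n",
"\n",
"So let's bundle everything up into a workflow! 📦\n"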
],
"metadata": {
"id": "G3zF_3DqviFJ"
}
},
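{
"cell_type": "markdown",
"source": [
"Before bundling the pumpkin components, here is a minimal sketch of the same recipe + model specification + workflow pattern on R's built-in `mtcars` data, used here only because it ships with R. Because this recipe has no preprocessing steps, the fitted workflow should give the same coefficients as calling `lm()` directly, which is a quick way to confirm what the workflow does behind the scenes. `extract_fit_engine()` (from the workflows package) returns the underlying `lm` object.\n",
"\n",
"```r\n",
"library(tidymodels)\n",
"\n",
"# Recipe with no preprocessing steps: mpg is the outcome, wt the predictor\n",
"car_recipe <- recipe(mpg ~ wt, data = mtcars)\n",
"\n",
"# Linear regression specification using the base-R lm engine\n",
"car_spec <- linear_reg() %>%\n",
"  set_engine(\"lm\") %>%\n",
"  set_mode(\"regression\")\n",
"\n",
"# Bundle recipe and model into a workflow, then fit it in one call\n",
"car_wf_fit <- workflow() %>%\n",
"  add_recipe(car_recipe) %>%\n",
"  add_model(car_spec) %>%\n",
"  fit(data = mtcars)\n",
"\n",
"# Coefficients of the underlying lm object inside the fitted workflow\n",
"wf_coefs <- coef(extract_fit_engine(car_wf_fit))\n",
"\n",
"# The same model fitted directly with lm()\n",
"lm_coefs <- coef(lm(mpg ~ wt, data = mtcars))\n",
"\n",
"# The two sets of coefficients should be identical (up to names)\n",
"rbind(workflow = unname(wf_coefs), lm = unname(lm_coefs))\n",
"\n",
"# Predictions come back as a tibble with a .pred column\n",
"predict(car_wf_fit, new_data = head(mtcars, 3))\n",
"```\n"
],
"metadata": {}
},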
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Hold modelling components in a workflow\n",
"lm_wf <- workflow() %>% \n",
" add_recipe(lm_pumpkins_recipe) %>% \n",
" add_model(lm_spec)\n",
"\n",
"# Print out the workflow\n",
"lm_wf"
],
"outputs": [],
"metadata": {
"id": "T3olroU3v-WX"
}
},
{
"cell_type": "markdown",
"source": [
"As a matter of fact, a workflow can be fit/trained in much the same way a model can.\n"
],
"metadata": {
"id": "zd1A5tgOwEPX"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Train the model\n",
"lm_wf_fit <- lm_wf %>% \n",
" fit(data = pumpkins_train)\n",
"\n",
"# Print the model coefficients learned \n",
"lm_wf_fit"
],
"outputs": [],
"metadata": {
"id": "NhJagFumwFHf"
}
},
{
"cell_type": "markdown",
"source": [
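"From the model output, we can see the coefficients learned during training. They represent the coefficients of the line of best fit, which gives us the lowest overall error between the actual and predicted values.\n",
"\n",
"#### Evaluate model performance using the test set\n",
"\n",
"It's time to see how the model performed 📏! How do we do this?\n",
"\n",
"Now that we've trained the model, we can use it with `predict()` to make predictions for the test set. Then we can compare these predictions to the actual label values to evaluate how well (or not!) the model is working.\n",
"\n",
"Let's start by making predictions for the test set, then bind these predictions to the columns of the test set.\n"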
],
"metadata": {
"id": "_4QkGtBTwItF"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Make predictions for the test set\n",
"predictions <- lm_wf_fit %>% \n",
" predict(new_data = pumpkins_test)\n",
"\n",
"\n",
"# Bind predictions to the test set\n",
"lm_results <- pumpkins_test %>% \n",
" select(c(package, price)) %>% \n",
" bind_cols(predictions)\n",
"\n",
"\n",
"# Print the first ten rows of the tibble\n",
"lm_results %>% \n",
" slice_head(n = 10)"
],
"outputs": [],
"metadata": {
"id": "UFZzTG0gwTs9"
}
},
{
"cell_type": "markdown",
"source": [
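"Yes, you have just trained a model and used it to make predictions! 🔮 Is it any good? Let's evaluate the model's performance!\n",
"\n",
"In Tidymodels, we do this using `yardstick::metrics()`! For linear regression, let's focus on the following metrics:\n",
"\n",
"- `Root Mean Square Error (RMSE)`: The square root of the [MSE](https://en.wikipedia.org/wiki/Mean_squared_error). This yields an absolute metric in the same unit as the label (in this case, the price of a pumpkin). The smaller the value, the better the model (in a simplistic sense, it roughly represents the average amount by which the predicted prices are wrong!).\n",
"\n",
"- `Coefficient of Determination (usually known as R-squared or R2)`: A relative metric in which the higher the value, the better the fit of the model. In essence, this metric represents how much of the variance in the actual label values the model is able to explain.\n"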
],
"metadata": {
"id": "0A5MjzM7wW9M"
}
},
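{
"cell_type": "markdown",
"source": [
"The two metrics can be computed by hand to see what `yardstick` is doing. As a minimal sketch, assuming five hypothetical observed prices and predictions, the code below compares manual formulas with `yardstick`'s vector functions. Note that `rsq()`/`rsq_vec()` in yardstick compute R-squared as the squared correlation between truth and prediction, while the traditional 1 - SSE/SST version is available as `rsq_trad()`/`rsq_trad_vec()`; the two can differ, especially on test data.\n",
"\n",
"```r\n",
"library(yardstick)\n",
"\n",
"# Hypothetical observed prices and model predictions\n",
"truth <- c(10, 12, 15, 20, 18)\n",
"estimate <- c(11, 12, 14, 18, 19)\n",
"\n",
"# RMSE: square root of the mean squared error, in the units of price\n",
"rmse_manual <- sqrt(mean((truth - estimate)^2))\n",
"\n",
"# R-squared as yardstick::rsq() defines it: squared Pearson correlation\n",
"rsq_manual <- cor(truth, estimate)^2\n",
"\n",
"# Traditional R-squared: 1 - (residual sum of squares / total sum of squares)\n",
"rsq_trad_manual <- 1 - sum((truth - estimate)^2) / sum((truth - mean(truth))^2)\n",
"\n",
"rbind(\n",
"  manual = c(rmse = rmse_manual, rsq = rsq_manual, rsq_trad = rsq_trad_manual),\n",
"  yardstick = c(rmse = rmse_vec(truth, estimate),\n",
"                rsq = rsq_vec(truth, estimate),\n",
"                rsq_trad = rsq_trad_vec(truth, estimate))\n",
")\n",
"```\n",
"\n",
"Here the squared errors are 1, 0, 1, 4 and 1, so their mean is 1.4 and the RMSE is its square root, about 1.18.\n"
],
"metadata": {}
},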
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Evaluate performance of linear regression\n",
"metrics(data = lm_results,\n",
" truth = price,\n",
" estimate = .pred)"
],
"outputs": [],
"metadata": {
"id": "reJ0UIhQwcEH"
}
},
{
"cell_type": "markdown",
"source": [
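"There goes the model performance. Let's see if we can get a better sense of what is going on by visualizing a scatter plot of package and price, and then overlaying a line of best fit using the model's predictions.\n",
"\n",
"This means we'll have to prep and bake the test set in order to encode the package column, and then bind it to the predictions made by our model.\n"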
],
"metadata": {
"id": "fdgjzjkBwfWt"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Encode package column\n",
"package_encode <- lm_pumpkins_recipe %>% \n",
" prep() %>% \n",
" bake(new_data = pumpkins_test) %>% \n",
" select(package)\n",
"\n",
"\n",
"# Bind encoded package column to the results\n",
"lm_results <- lm_results %>% \n",
" bind_cols(package_encode %>% \n",
" rename(package_integer = package)) %>% \n",
" relocate(package_integer, .after = package)\n",
"\n",
"\n",
"# Print new results data frame\n",
"lm_results %>% \n",
" slice_head(n = 5)\n",
"\n",
"\n",
"# Make a scatter plot\n",
"lm_results %>% \n",
" ggplot(mapping = aes(x = package_integer, y = price)) +\n",
" geom_point(size = 1.6) +\n",
" # Overlay a line of best fit\n",
" geom_line(aes(y = .pred), color = \"orange\", linewidth = 1.2) +\n",
" xlab(\"package\")\n",
" \n"
],
"outputs": [],
"metadata": {
"id": "R0nw719lwkHE"
}
},
{
"cell_type": "markdown",
"source": [
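"Great! As you can see, the linear regression model does not generalize the relationship between a package and its corresponding price very well.\n",
"\n",
"🎃 Congratulations, you just created a model that can help predict the price of a few varieties of pumpkins. Your holiday pumpkin patch will be beautiful. But you can probably create a better model!\n",
"\n",
"## 5. Build a polynomial regression model\n",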
"\n",
"<p >\n",
" <img src=\"../../images/linear-polynomial.png\"\n",
" width=\"800\"/>\n",
" <figcaption>Infographic by Dasani Madipalli</figcaption>\n",
"\n",
"\n",
"<!--![Infographic by Dasani Madipalli](../../../../../../translated_images/linear-polynomial.5523c7cb6576ccab0fecbd0e3505986eb2d191d9378e785f82befcf3a578a6e7.tw.png){width=\"800\"}-->\n"
],
"metadata": {
"id": "HOCqJXLTwtWI"
}
},
{
"cell_type": "markdown",
"source": [
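"Sometimes our data may not have a linear relationship, but we still want to predict an outcome. Polynomial regression can help us make predictions for more complex non-linear relationships.\n",
"\n",
"Take, for instance, the relationship between the package and price in our pumpkins dataset. While sometimes there's a linear relationship between variables (the bigger the pumpkin in volume, the higher the price), sometimes these relationships can't be plotted as a plane or a straight line.\n",
"\n",
"> ✅ Here are [some more examples](https://online.stat.psu.edu/stat501/lesson/9/9.8) of data that could use polynomial regression\n",
">\n",
"> Take another look at the relationship between package and price in the previous plot. Does this scatter plot seem like it should necessarily be analyzed by a straight line? Perhaps not. In this case, you can try polynomial regression.\n",
">\n",
"> ✅ Polynomials are mathematical expressions that might consist of one or more variables and coefficients\n",
"\n",
"#### Train a polynomial regression model using the training set\n",
"\n",
"Polynomial regression creates a *curved line* to better fit nonlinear data.\n",
"\n",
"Let's see whether a polynomial model will perform better in making predictions. We'll follow a procedure similar to the one we used before:\n",
"\n",
"- Create a recipe that specifies the preprocessing steps that should be carried out on our data to get it ready for modelling, i.e.: encoding predictors and computing polynomials of degree *n*\n",
"\n",
"- Build a model specification\n",
"\n",
"- Bundle the recipe and model specification into a workflow\n",
"\n",
"- Create a model by fitting the workflow\n",
"\n",
"- Evaluate how well the model performs on the test data\n",
"\n",
"Let's get right into it!\n"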
],
"metadata": {
"id": "VcEIpRV9wzYr"
}
},
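{
"cell_type": "markdown",
"source": [
"To see why a curve can help, here is a minimal base-R sketch on hypothetical data generated from an exact quadratic, $y = 3 + 2x - 0.5x^2$ (chosen for illustration, not taken from the pumpkin data). A straight line leaves large residuals, while a degree-2 polynomial recovers the curve. `poly(x, 2, raw = TRUE)` uses plain powers of `x`, so the coefficients are easy to read; the default `raw = FALSE` uses orthogonal polynomials, which give the same fitted values but different-looking coefficients. As far as we know, `step_poly()` in recipes also uses orthogonal polynomials by default.\n",
"\n",
"```r\n",
"# Hypothetical data generated from an exact quadratic curve\n",
"x <- 0:9\n",
"y <- 3 + 2 * x - 0.5 * x^2\n",
"\n",
"# A straight line and degree-2 polynomials, all fitted by least squares\n",
"fit_line <- lm(y ~ x)\n",
"fit_poly_raw <- lm(y ~ poly(x, degree = 2, raw = TRUE))\n",
"fit_poly_orth <- lm(y ~ poly(x, degree = 2))\n",
"\n",
"# Root mean squared error of a fitted model on its own training points\n",
"rmse_of <- function(fit) sqrt(mean(residuals(fit)^2))\n",
"c(line = rmse_of(fit_line), polynomial = rmse_of(fit_poly_raw))\n",
"\n",
"# The raw-polynomial coefficients recover 3, 2 and -0.5\n",
"coef(fit_poly_raw)\n",
"\n",
"# Orthogonal and raw bases produce the same fitted curve\n",
"all.equal(fitted(fit_poly_raw), fitted(fit_poly_orth))\n",
"```\n",
"\n",
"Here the straight line has an RMSE of about 3.6, while the quadratic fits the points essentially exactly. On real, noisy data such as the pumpkin prices, a higher degree always fits the training data at least as well but may generalize worse, which is why the model is evaluated on the test set.\n"
],
"metadata": {}
},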
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Specify a recipe\r\n",
"poly_pumpkins_recipe <-\r\n",
" recipe(price ~ package, data = pumpkins_train) %>%\r\n",
" step_integer(all_predictors(), zero_based = TRUE) %>% \r\n",
" step_poly(all_predictors(), degree = 4)\r\n",
"\r\n",
"\r\n",
"# Create a model specification\r\n",
"poly_spec <- linear_reg() %>% \r\n",
" set_engine(\"lm\") %>% \r\n",
" set_mode(\"regression\")\r\n",
"\r\n",
"\r\n",
"# Bundle recipe and model spec into a workflow\r\n",
"poly_wf <- workflow() %>% \r\n",
" add_recipe(poly_pumpkins_recipe) %>% \r\n",
" add_model(poly_spec)\r\n",
"\r\n",
"\r\n",
"# Create a model\r\n",
"poly_wf_fit <- poly_wf %>% \r\n",
" fit(data = pumpkins_train)\r\n",
"\r\n",
"\r\n",
"# Print learned model coefficients\r\n",
"poly_wf_fit\r\n",
"\r\n",
" "
],
"outputs": [],
"metadata": {
"id": "63n_YyRXw3CC"
}
},
{
"cell_type": "markdown",
"source": [
"#### Evaluate model performance\n",
"\n",
"👏👏 You've built a polynomial model; now let's make predictions on the test set!\n"
],
"metadata": {
"id": "-LHZtztSxDP0"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Make price predictions on test data\r\n",
"poly_results <- poly_wf_fit %>% predict(new_data = pumpkins_test) %>% \r\n",
" bind_cols(pumpkins_test %>% select(c(package, price))) %>% \r\n",
" relocate(.pred, .after = last_col())\r\n",
"\r\n",
"\r\n",
"# Print the results\r\n",
"poly_results %>% \r\n",
" slice_head(n = 10)"
],
"outputs": [],
"metadata": {
"id": "YUFpQ_dKxJGx"
}
},
{
"cell_type": "markdown",
"source": [
"Woo-hoo, let's evaluate how the model performed on the test_set using `yardstick::metrics()`.\n"
],
"metadata": {
"id": "qxdyj86bxNGZ"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"metrics(data = poly_results, truth = price, estimate = .pred)"
],
"outputs": [],
"metadata": {
"id": "8AW5ltkBxXDm"
}
},
{
"cell_type": "markdown",
"source": [
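"🤩🤩 Much better performance.\n",
"\n",
"The `rmse` decreased from about 7 to about 3, an indication of a reduced error between the actual price and the predicted price. You can *loosely* interpret this as meaning that, on average, incorrect predictions are wrong by around \\$3. The `rsq` increased from about 0.4 to 0.8.\n",
"\n",
"All these metrics indicate that the polynomial model performs far better than the linear model. Good job!\n",
"\n",
"Let's see if we can visualize this!\n"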
],
"metadata": {
"id": "6gLHNZDwxYaS"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Bind encoded package column to the results\r\n",
"poly_results <- poly_results %>% \r\n",
" bind_cols(package_encode %>% \r\n",
" rename(package_integer = package)) %>% \r\n",
" relocate(package_integer, .after = package)\r\n",
"\r\n",
"\r\n",
"# Print new results data frame\r\n",
"poly_results %>% \r\n",
" slice_head(n = 5)\r\n",
"\r\n",
"\r\n",
"# Make a scatter plot\r\n",
"poly_results %>% \r\n",
" ggplot(mapping = aes(x = package_integer, y = price)) +\r\n",
" geom_point(size = 1.6) +\r\n",
" # Overlay a line of best fit\r\n",
" geom_line(aes(y = .pred), color = \"midnightblue\", linewidth = 1.2) +\r\n",
" xlab(\"package\")\r\n"
],
"outputs": [],
"metadata": {
"id": "A83U16frxdF1"
}
},
{
"cell_type": "markdown",
"source": [
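"You can see a curved line that fits your data better! 🤩\n",
"\n",
"You can make this smoother by passing a polynomial formula to `geom_smooth`, like this. Note that `geom_smooth()` fits its own degree-4 polynomial to the plotted test points, so this curve is drawn from the test data rather than from the trained workflow's predictions:\n"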
],
"metadata": {
"id": "4U-7aHOVxlGU"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Make a scatter plot\r\n",
"poly_results %>% \r\n",
" ggplot(mapping = aes(x = package_integer, y = price)) +\r\n",
" geom_point(size = 1.6) +\r\n",
" # Overlay a line of best fit\r\n",
" geom_smooth(method = lm, formula = y ~ poly(x, degree = 4), color = \"midnightblue\", linewidth = 1.2, se = FALSE) +\r\n",
" xlab(\"package\")"
],
"outputs": [],
"metadata": {
"id": "5vzNT0Uexm-w"
}
},
{
"cell_type": "markdown",
"source": [
"Much like a smooth curve! 🤩\n",
"\n",
"Here's how you would make a new prediction:\n"
],
"metadata": {
"id": "v9u-wwyLxq4G"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Make a hypothetical data frame\r\n",
"hypo_tibble <- tibble(package = \"bushel baskets\")\r\n",
"\r\n",
"# Make predictions using linear model\r\n",
"lm_pred <- lm_wf_fit %>% predict(new_data = hypo_tibble)\r\n",
"\r\n",
"# Make predictions using polynomial model\r\n",
"poly_pred <- poly_wf_fit %>% predict(new_data = hypo_tibble)\r\n",
"\r\n",
"# Return predictions in a list\r\n",
"list(\"linear model prediction\" = lm_pred, \r\n",
" \"polynomial model prediction\" = poly_pred)\r\n"
],
"outputs": [],
"metadata": {
"id": "jRPSyfQGxuQv"
}
},
{
"cell_type": "markdown",
"source": [
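"The `polynomial model` prediction does make sense, given the scatter plots of `price` and `package`! And, if this is a better model than the previous one, looking at the same data, you need to budget for these more expensive pumpkins!\n",
"\n",
"🏆 Well done! You created two regression models in one lesson. In the final section on regression, you will learn about logistic regression to determine categories.\n",
"\n",
"## **🚀Challenge**\n",
"\n",
"Test several different variables in this notebook to see how correlation corresponds to model accuracy.\n",
"\n",
"## [**Post-lecture quiz**](https://gray-sand-07a10f403.1.azurestaticapps.net/quiz/14/)\n",
"\n",
"## **Review & Self Study**\n",
"\n",
"In this lesson we learned about Linear Regression. There are other important types of Regression. Read about Stepwise, Ridge, Lasso and Elasticnet techniques. A good course to study to learn more is the [Stanford Statistical Learning course](https://online.stanford.edu/courses/sohs-ystatslearning-statistical-learning).\n",
"\n",
"If you want to learn more about how to use the amazing Tidymodels framework, please check out the following resources:\n",
"\n",
"- Tidymodels website: [Get started with Tidymodels](https://www.tidymodels.org/start/)\n",
"\n",
"- Max Kuhn and Julia Silge, [*Tidy Modeling with R*](https://www.tmwr.org/).\n",
"\n",
"###### **THANK YOU TO:**\n",
"\n",
"[Allison Horst](https://twitter.com/allison_horst?lang=en) for creating the amazing illustrations that make R more welcoming and engaging. Find more illustrations at her [gallery](https://www.google.com/url?q=https://github.com/allisonhorst/stats-illustrations&sa=D&source=editors&ust=1626380772530000&usg=AOvVaw3zcfyCizFQZpkSLzxiiQEM).\n"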
],
"metadata": {
"id": "8zOLOWqMxzk5"
}
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
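"\n---\n\n**Disclaimer**:  \nThis document has been translated using the AI translation service [Co-op Translator](https://github.com/Azure/co-op-translator). While we strive for accuracy, please be aware that automated translations may contain errors or inaccuracies. The original document in its native language should be considered the authoritative source. For critical information, professional human translation is recommended. We are not liable for any misunderstandings or misinterpretations arising from the use of this translation.\n"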
]
}
]
}