{ "nbformat": 4, "nbformat_minor": 2, "metadata": { "colab": { "name": "lesson_3-R.ipynb", "provenance": [], "collapsed_sections": [], "toc_visible": true }, "kernelspec": { "name": "ir", "display_name": "R" }, "language_info": { "name": "R" }, "coopTranslator": { "original_hash": "5015d65d61ba75a223bfc56c273aa174", "translation_date": "2025-08-29T22:53:44+00:00", "source_file": "2-Regression/3-Linear/solution/R/lesson_3-R.ipynb", "language_code": "mo" } }, "cells": [ { "cell_type": "markdown", "source": [], "metadata": { "id": "EgQw8osnsUV-" } }, { "cell_type": "markdown", "source": [ "## 南瓜定價的線性與多項式回歸 - 第三課\n", "

\n", " \n", "

資訊圖表由 Dasani Madipalli 製作
\n", "\n", "\n", "\n", "\n", "#### 簡介\n", "\n", "到目前為止,你已經透過我們將在本課中使用的南瓜定價數據集,探索了回歸的概念,並使用 `ggplot2` 進行了可視化。💪\n", "\n", "現在,你準備更深入地了解機器學習中的回歸。在本課中,你將學習兩種類型的回歸:*基本線性回歸* 和 *多項式回歸*,以及這些技術背後的一些數學原理。\n", "\n", "> 在整個課程中,我們假設學生的數學知識有限,並努力使其對來自其他領域的學生更易於理解,因此請注意筆記、🧮 提示、圖表和其他學習工具,這些都將幫助你更好地理解內容。\n", "\n", "#### 準備工作\n", "\n", "提醒一下,你正在載入這些數據以便對其進行分析。\n", "\n", "- 什麼時候是購買南瓜的最佳時機?\n", "\n", "- 一箱迷你南瓜的價格大約是多少?\n", "\n", "- 我應該購買半蒲式耳籃還是 1 1/9 蒲式耳箱?讓我們繼續深入挖掘這些數據。\n", "\n", "在上一課中,你創建了一個 `tibble`(數據框架的現代化版本),並用部分原始數據集填充它,將價格標準化為以蒲式耳為單位。然而,這樣做的結果是,你只能收集到大約 400 個數據點,而且僅限於秋季月份。也許通過進一步清理數據,我們可以更詳細地了解數據的性質?我們拭目以待... 🕵️‍♀️\n", "\n", "為了完成這項任務,我們需要以下套件:\n", "\n", "- `tidyverse`: [tidyverse](https://www.tidyverse.org/) 是一個 [R 套件集合](https://www.tidyverse.org/packages),旨在讓數據科學更快速、更簡單、更有趣!\n", "\n", "- `tidymodels`: [tidymodels](https://www.tidymodels.org/) 框架是一個 [套件集合](https://www.tidymodels.org/packages/),用於建模和機器學習。\n", "\n", "- `janitor`: [janitor 套件](https://github.com/sfirke/janitor) 提供了一些簡單的小工具,用於檢查和清理髒數據。\n", "\n", "- `corrplot`: [corrplot 套件](https://cran.r-project.org/web/packages/corrplot/vignettes/corrplot-intro.html) 提供了一個可視化的相關矩陣探索工具,支持自動變量重新排序,以幫助檢測變量之間隱藏的模式。\n", "\n", "你可以通過以下方式安裝它們:\n", "\n", "`install.packages(c(\"tidyverse\", \"tidymodels\", \"janitor\", \"corrplot\"))`\n", "\n", "以下腳本將檢查你是否已安裝完成本模組所需的套件,並在缺少時為你安裝。\n" ], "metadata": { "id": "WqQPS1OAsg3H" } }, { "cell_type": "code", "execution_count": null, "source": [ "suppressWarnings(if (!require(\"pacman\")) install.packages(\"pacman\"))\n", "\n", "pacman::p_load(tidyverse, tidymodels, janitor, corrplot)" ], "outputs": [], "metadata": { "id": "tA4C2WN3skCf", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "c06cd805-5534-4edc-f72b-d0d1dab96ac0" } }, { "cell_type": "markdown", "source": [ "我們稍後會載入這些超棒的套件,並在目前的 R 工作環境中使用它們。(這只是示範,`pacman::p_load()` 已經幫你完成了這件事)\n", "\n", "## 1. 線性回歸線\n", "\n", "如同你在第一課中學到的,線性回歸的目標是能夠繪製一條*最佳擬合線*,以便:\n", "\n", "- **展示變數之間的關係**。顯示變數之間的關聯性。\n", "\n", "- **進行預測**。準確預測新數據點在該線上的位置。\n", "\n", "為了繪製這種類型的線,我們使用一種統計技術,稱為**最小平方回歸**。`最小平方`的意思是,所有圍繞回歸線的數據點的距離都被平方後加總。理想情況下,最終的加總值應該越小越好,因為我們希望誤差值越低越好,也就是`最小平方`。因此,最佳擬合線就是使平方誤差加總值最低的那條線——這就是*最小平方回歸*的由來。\n", "\n", "我們這樣做是因為希望建模出一條與所有數據點的累積距離最小的線。我們在加總之前先將距離平方,因為我們關注的是距離的大小,而不是方向。\n", "\n", "> **🧮 數學公式**\n", ">\n", "> 這條線,稱為*最佳擬合線*,可以用[一個公式](https://en.wikipedia.org/wiki/Simple_linear_regression)表示:\n", ">\n", "> Y = a + bX\n", ">\n", "> `X` 是`解釋變數`或`預測因子`,`Y` 是`依變數`或`結果`。線的斜率是 `b`,而 `a` 是 y 截距,表示當 `X = 0` 時 `Y` 的值。\n", ">\n", "\n", "> ![](../../../../../../2-Regression/3-Linear/solution/images/slope.png \"slope = $y/x$\")\n", " 圖解由 Jen Looper 提供\n", ">\n", "> 首先,計算斜率 `b`。\n", ">\n", "> 換句話說,參考我們南瓜數據的原始問題:「根據月份預測每蒲式耳南瓜的價格」,`X` 代表價格,`Y` 代表銷售月份。\n", ">\n", "> ![](../../../../../../translated_images/calculation.989aa7822020d9d0ba9fc781f1ab5192f3421be86ebb88026528aef33c37b0d8.mo.png)\n", " 圖解由 Jen Looper 提供\n", "> \n", "> 計算 Y 的值。如果你支付大約 \\$4,那一定是四月!\n", ">\n", "> 計算這條線的數學公式必須展示線的斜率,這也取決於截距,或者說當 `X = 0` 時 `Y` 的位置。\n", ">\n", "> 你可以在 [Math is Fun](https://www.mathsisfun.com/data/least-squares-regression.html) 網站上觀察這些值的計算方法。也可以訪問[最小平方計算器](https://www.mathsisfun.com/data/least-squares-calculator.html),看看數值如何影響這條線。\n", "\n", "看起來不那麼可怕吧?🤓\n", "\n", "#### 相關性\n", "\n", "還有一個需要理解的術語是給定 X 和 Y 變數之間的**相關係數**。使用散佈圖,你可以快速地可視化這個係數。數據點整齊排列成一條線的圖表具有高相關性,而數據點在 X 和 Y 之間隨意分佈的圖表則具有低相關性。\n", "\n", "一個好的線性回歸模型應該是使用最小平方回歸方法並具有高相關係數(接近 1 而非 0)的回歸線。\n" ], "metadata": { "id": "cdX5FRpvsoP5" } }, { "cell_type": "markdown", "source": [ "## **2. 與數據共舞:建立用於建模的數據框**\n", "\n", "

\n", " \n", "

插圖由 @allison_horst 提供
\n" ], "metadata": { "id": "WdUKXk7Bs8-V" } }, { "cell_type": "markdown", "source": [ "載入所需的函式庫和數據集。將數據轉換為包含數據子集的數據框:\n", "\n", "- 僅選取以蒲式耳定價的南瓜\n", "\n", "- 將日期轉換為月份\n", "\n", "- 計算價格為高價和低價的平均值\n", "\n", "- 將價格轉換為反映按蒲式耳數量定價的形式\n", "\n", "> 我們在[上一課](https://github.com/microsoft/ML-For-Beginners/blob/main/2-Regression/2-Data/solution/lesson_2-R.ipynb)中已經涵蓋了這些步驟。\n" ], "metadata": { "id": "fMCtu2G2s-p8" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Load the core Tidyverse packages\n", "library(tidyverse)\n", "library(lubridate)\n", "\n", "# Import the pumpkins data\n", "pumpkins <- read_csv(file = \"https://raw.githubusercontent.com/microsoft/ML-For-Beginners/main/2-Regression/data/US-pumpkins.csv\")\n", "\n", "\n", "# Get a glimpse and dimensions of the data\n", "glimpse(pumpkins)\n", "\n", "\n", "# Print the first 50 rows of the data set\n", "pumpkins %>% \n", " slice_head(n = 5)" ], "outputs": [], "metadata": { "id": "ryMVZEEPtERn" } }, { "cell_type": "markdown", "source": [ "出於純粹冒險的精神,讓我們探索 [`janitor package`](../../../../../../2-Regression/3-Linear/solution/R/github.com/sfirke/janitor),它提供了簡單的函數來檢查和清理髒數據。例如,讓我們看看我們數據的列名稱:\n" ], "metadata": { "id": "xcNxM70EtJjb" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Return column names\n", "pumpkins %>% \n", " names()" ], "outputs": [], "metadata": { "id": "5XtpaIigtPfW" } }, { "cell_type": "markdown", "source": [ "🤔 我們可以做得更好。讓我們通過使用 `janitor::clean_names` 將這些列名稱轉換為 [snake_case](https://en.wikipedia.org/wiki/Snake_case) 規範,並將它們命名為 `friendR`。要了解有關此函數的更多信息:`?clean_names`\n" ], "metadata": { "id": "IbIqrMINtSHe" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Clean names to the snake_case convention\n", "pumpkins <- pumpkins %>% \n", " clean_names(case = \"snake\")\n", "\n", "# Return column names\n", "pumpkins %>% \n", " names()" ], "outputs": [], "metadata": { "id": "a2uYvclYtWvX" } }, { "cell_type": "markdown", "source": [ "非常整潔 🧹!現在,像上一課一樣,使用 `dplyr` 與數據共舞吧!💃\n" ], "metadata": { "id": "HfhnuzDDtaDd" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Select desired columns\n", "pumpkins <- pumpkins %>% \n", " select(variety, city_name, package, low_price, high_price, date)\n", "\n", "\n", "\n", "# Extract the month from the dates to a new column\n", "pumpkins <- pumpkins %>%\n", " mutate(date = mdy(date),\n", " month = month(date)) %>% \n", " select(-date)\n", "\n", "\n", "\n", "# Create a new column for average Price\n", "pumpkins <- pumpkins %>% \n", " mutate(price = (low_price + high_price)/2)\n", "\n", "\n", "# Retain only pumpkins with the string \"bushel\"\n", "new_pumpkins <- pumpkins %>% \n", " filter(str_detect(string = package, pattern = \"bushel\"))\n", "\n", "\n", "# Normalize the pricing so that you show the pricing per bushel, not per 1 1/9 or 1/2 bushel\n", "new_pumpkins <- new_pumpkins %>% \n", " mutate(price = case_when(\n", " str_detect(package, \"1 1/9\") ~ price/(1.1),\n", " str_detect(package, \"1/2\") ~ price*2,\n", " TRUE ~ price))\n", "\n", "# Relocate column positions\n", "new_pumpkins <- new_pumpkins %>% \n", " relocate(month, .before = variety)\n", "\n", "\n", "# Display the first 5 rows\n", "new_pumpkins %>% \n", " slice_head(n = 5)" ], "outputs": [], "metadata": { "id": "X0wU3gQvtd9f" } }, { "cell_type": "markdown", "source": [ "做得好!👌 現在你擁有一個乾淨整潔的數據集,可以用來建立新的回歸模型!\n", "\n", "要來個散佈圖嗎?\n" ], "metadata": { "id": "UpaIwaxqth82" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Set theme\n", "theme_set(theme_light())\n", "\n", "# Make a scatter plot of month and price\n", "new_pumpkins %>% \n", " ggplot(mapping = aes(x = month, y = price)) +\n", " geom_point(size = 1.6)\n" ], "outputs": [], "metadata": { "id": "DXgU-j37tl5K" } }, { "cell_type": "markdown", "source": [ "散佈圖提醒我們,目前僅有從八月到十二月的月度數據。我們可能需要更多數據,才能以線性方式得出結論。\n", "\n", "讓我們再看看我們的建模數據:\n" ], "metadata": { "id": "Ve64wVbwtobI" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Display first 5 rows\n", "new_pumpkins %>% \n", " slice_head(n = 5)" ], "outputs": [], "metadata": { "id": "HFQX2ng1tuSJ" } }, { "cell_type": "markdown", "source": [ "如果我們想根據屬於字符類型的 `city` 或 `package` 欄位來預測南瓜的 `price`,該怎麼做呢?或者更簡單地說,我們如何找到例如 `package` 和 `price` 之間的相關性(這需要兩個輸入都是數值型)?🤷🤷\n", "\n", "機器學習模型在處理數值特徵時效果最佳,而不是文字值,因此通常需要將分類特徵轉換為數值表示。\n", "\n", "這意味著我們必須找到一種方法來重新格式化我們的預測變數,使其更容易被模型有效使用,這個過程被稱為 `特徵工程`。\n" ], "metadata": { "id": "7hsHoxsStyjJ" } }, { "cell_type": "markdown", "source": [ "## 3. 使用 recipes 預處理建模數據 👩‍🍳👨‍🍳\n", "\n", "將預測變數重新格式化以便模型更有效使用的活動被稱為 `特徵工程`。\n", "\n", "不同的模型有不同的預處理需求。例如,最小平方法需要對 `分類變數`(如月份、品種和城市名稱)進行 `編碼`。這通常涉及將包含 `分類值` 的列 `轉換` 為一個或多個 `數值列`,以取代原始列。\n", "\n", "舉例來說,假設你的數據包含以下分類特徵:\n", "\n", "| 城市 |\n", "|:-------:|\n", "| 丹佛 |\n", "| 奈洛比 |\n", "| 東京 |\n", "\n", "你可以使用 *序數編碼* 為每個類別替換一個唯一的整數值,如下所示:\n", "\n", "| 城市 |\n", "|:----:|\n", "| 0 |\n", "| 1 |\n", "| 2 |\n", "\n", "這就是我們將對數據進行的操作!\n", "\n", "在本節中,我們將探索另一個令人驚嘆的 Tidymodels 套件:[recipes](https://tidymodels.github.io/recipes/) - 它專為幫助你在訓練模型之前預處理數據而設計。從本質上來說,recipe 是一個物件,用於定義應該對數據集應用哪些步驟,以使其準備好進行建模。\n", "\n", "現在,讓我們創建一個 recipe,通過為預測列中的所有觀測值替換唯一的整數來準備數據進行建模:\n" ], "metadata": { "id": "AD5kQbcvt3Xl" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Specify a recipe\n", "pumpkins_recipe <- recipe(price ~ ., data = new_pumpkins) %>% \n", " step_integer(all_predictors(), zero_based = TRUE)\n", "\n", "\n", "# Print out the recipe\n", "pumpkins_recipe" ], "outputs": [], "metadata": { "id": "BNaFKXfRt9TU" } }, { "cell_type": "markdown", "source": [ "太棒了!👏 我們剛剛創建了第一個配方,指定了一個結果(價格)及其相應的預測變數,並且所有的預測變數列都被編碼為一組整數 🙌!讓我們快速拆解一下:\n", "\n", "- 使用公式調用 `recipe()` 告訴配方變數的*角色*,並以 `new_pumpkins` 數據作為參考。例如,`price` 列被分配為 `outcome` 角色,而其餘的列則被分配為 `predictor` 角色。\n", "\n", "- `step_integer(all_predictors(), zero_based = TRUE)` 指定所有的預測變數應轉換為一組整數,編號從 0 開始。\n", "\n", "我們相信你可能會有這樣的想法:「這太酷了!!但如果我需要確認這些配方是否真的按照我的預期運作呢?🤔」\n", "\n", "這是一個很棒的想法!你看,一旦你的配方被定義,你可以估算實際預處理數據所需的參數,然後提取處理後的數據。通常使用 Tidymodels 時不需要這樣做(我們稍後會看到常規方法 -> `workflows`),但當你想進行某種檢查以確認配方是否符合預期時,這會非常有用。\n", "\n", "為此,你需要另外兩個動詞:`prep()` 和 `bake()`。一如既往,我們的小 R 朋友由 [`Allison Horst`](https://github.com/allisonhorst/stats-illustrations) 創作的插圖能幫助你更好地理解!\n", "\n", "

\n", " \n", "

插圖由 @allison_horst 創作
\n" ], "metadata": { "id": "KEiO0v7kuC9O" } }, { "cell_type": "markdown", "source": [ "[`prep()`](https://recipes.tidymodels.org/reference/prep.html):從訓練集估算所需的參數,這些參數之後可以應用到其他數據集。例如,對於某個預測變數列,哪個觀測值會被分配為整數 0、1、2 等。\n", "\n", "[`bake()`](https://recipes.tidymodels.org/reference/bake.html):使用已準備好的 recipe,並將操作應用到任何數據集。\n", "\n", "話雖如此,讓我們準備並執行 recipe,來真正確認在背後,預測變數列會在模型擬合之前先被編碼。\n" ], "metadata": { "id": "Q1xtzebuuTCP" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Prep the recipe\n", "pumpkins_prep <- prep(pumpkins_recipe)\n", "\n", "# Bake the recipe to extract a preprocessed new_pumpkins data\n", "baked_pumpkins <- bake(pumpkins_prep, new_data = NULL)\n", "\n", "# Print out the baked data set\n", "baked_pumpkins %>% \n", " slice_head(n = 10)" ], "outputs": [], "metadata": { "id": "FGBbJbP_uUUn" } }, { "cell_type": "markdown", "source": [ "哇哦!🥳 處理後的數據 `baked_pumpkins` 的所有預測變數都已經編碼,這證實了我們定義的預處理步驟(作為我們的配方)確實能如預期運作。雖然這讓數據對你來說更難閱讀,但對 Tidymodels 來說卻更加易於理解!花點時間找出哪個觀測值被映射到對應的整數。\n", "\n", "另外值得一提的是,`baked_pumpkins` 是一個數據框,我們可以在其上執行計算。\n", "\n", "例如,我們可以嘗試在數據中的兩個點之間找到一個良好的相關性,從而潛在地構建一個好的預測模型。我們將使用函數 `cor()` 來完成這個操作。輸入 `?cor()` 來了解更多關於該函數的信息。\n" ], "metadata": { "id": "1dvP0LBUueAW" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Find the correlation between the city_name and the price\n", "cor(baked_pumpkins$city_name, baked_pumpkins$price)\n", "\n", "# Find the correlation between the package and the price\n", "cor(baked_pumpkins$package, baked_pumpkins$price)\n" ], "outputs": [], "metadata": { "id": "3bQzXCjFuiSV" } }, { "cell_type": "markdown", "source": [ "事實證明,城市與價格之間的相關性相當弱。不過,包裝與價格之間的相關性稍微好一些。這是合理的,對吧?通常來說,生鮮箱越大,價格就越高。\n", "\n", "既然如此,我們也來試著使用 `corrplot` 套件,將所有欄位的相關性矩陣視覺化一下吧。\n" ], "metadata": { "id": "BToPWbgjuoZw" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Load the corrplot package\n", "library(corrplot)\n", "\n", "# Obtain correlation matrix\n", "corr_mat <- cor(baked_pumpkins %>% \n", " # Drop columns that are not really informative\n", " select(-c(low_price, high_price)))\n", "\n", "# Make a correlation plot between the variables\n", "corrplot(corr_mat, method = \"shade\", shade.col = NA, tl.col = \"black\", tl.srt = 45, addCoef.col = \"black\", cl.pos = \"n\", order = \"original\")" ], "outputs": [], "metadata": { "id": "ZwAL3ksmutVR" } }, { "cell_type": "markdown", "source": [ "🤩🤩 更棒了。\n", "\n", "現在可以針對這些數據提出一個好問題:'`對於一個特定的南瓜包裝,我可以預期的價格是多少?`' 讓我們直接開始吧!\n", "\n", "> 注意:當你使用 **`bake()`** 並將預處理好的配方 **`pumpkins_prep`** 設定為 **`new_data = NULL`** 時,你會提取出已處理(即已編碼)的訓練數據。如果你有另一組數據,例如測試集,並希望查看配方如何對其進行預處理,你只需將 **`pumpkins_prep`** 與 **`new_data = test_set`** 一起使用即可。\n", "\n", "## 4. 建立線性回歸模型\n", "\n", "

\n", " \n", "

圖解由 Dasani Madipalli 提供
\n" ], "metadata": { "id": "YqXjLuWavNxW" } }, { "cell_type": "markdown", "source": [ "現在我們已經建立了一個配方,並確認數據將被適當地預處理,接下來我們來建立一個回歸模型來回答這個問題:`給定的南瓜包裝,我可以預期的價格是多少?`\n", "\n", "#### 使用訓練集訓練線性回歸模型\n", "\n", "正如你可能已經猜到的,*price* 列是 `結果` 變量,而 *package* 列是 `預測` 變量。\n", "\n", "為了做到這一點,我們首先將數據分割為 80% 的訓練集和 20% 的測試集,然後定義一個配方,將預測變量列編碼為一組整數,接著建立一個模型規範。我們不會準備和執行配方,因為我們已經知道它會按預期對數據進行預處理。\n" ], "metadata": { "id": "Pq0bSzCevW-h" } }, { "cell_type": "code", "execution_count": null, "source": [ "set.seed(2056)\n", "# Split the data into training and test sets\n", "pumpkins_split <- new_pumpkins %>% \n", " initial_split(prop = 0.8)\n", "\n", "\n", "# Extract training and test data\n", "pumpkins_train <- training(pumpkins_split)\n", "pumpkins_test <- testing(pumpkins_split)\n", "\n", "\n", "\n", "# Create a recipe for preprocessing the data\n", "lm_pumpkins_recipe <- recipe(price ~ package, data = pumpkins_train) %>% \n", " step_integer(all_predictors(), zero_based = TRUE)\n", "\n", "\n", "\n", "# Create a linear model specification\n", "lm_spec <- linear_reg() %>% \n", " set_engine(\"lm\") %>% \n", " set_mode(\"regression\")" ], "outputs": [], "metadata": { "id": "CyoEh_wuvcLv" } }, { "cell_type": "markdown", "source": [ "做得好!現在我們已經有了配方和模型規格,我們需要找到一種方法將它們打包成一個物件,這個物件首先會對數據進行預處理(在幕後完成 prep+bake),然後基於預處理後的數據擬合模型,並且還能支持潛在的後處理活動。這樣是不是讓你更安心了呢!🤩\n", "\n", "在 Tidymodels 中,這個方便的物件叫做 [`workflow`](https://workflows.tidymodels.org/),它能輕鬆地保存你的建模組件!這就像我們在 *Python* 中所說的 *pipelines*。\n", "\n", "那麼,讓我們把所有東西打包成一個 workflow 吧!📦\n" ], "metadata": { "id": "G3zF_3DqviFJ" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Hold modelling components in a workflow\n", "lm_wf <- workflow() %>% \n", " add_recipe(lm_pumpkins_recipe) %>% \n", " add_model(lm_spec)\n", "\n", "# Print out the workflow\n", "lm_wf" ], "outputs": [], "metadata": { "id": "T3olroU3v-WX" } }, { "cell_type": "markdown", "source": [ "順帶一提,工作流程可以像模型一樣進行適配或訓練。\n" ], "metadata": { "id": "zd1A5tgOwEPX" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Train the model\n", "lm_wf_fit <- lm_wf %>% \n", " fit(data = pumpkins_train)\n", "\n", "# Print the model coefficients learned \n", "lm_wf_fit" ], "outputs": [], "metadata": { "id": "NhJagFumwFHf" } }, { "cell_type": "markdown", "source": [ "從模型輸出中,我們可以看到訓練期間學到的係數。這些係數代表最佳擬合線的係數,該線使實際變數與預測變數之間的整體誤差最小化。\n", "\n", "#### 使用測試集評估模型表現\n", "\n", "現在是時候來看看模型的表現如何了 📏!我們該怎麼做呢?\n", "\n", "既然我們已經訓練了模型,就可以使用 `parsnip::predict()` 對測試集進行預測。然後,我們可以將這些預測值與實際的標籤值進行比較,以評估模型的表現是否良好(或者不良!)。\n", "\n", "讓我們從對測試集進行預測開始,然後將這些預測結果與測試集的欄位結合起來。\n" ], "metadata": { "id": "_4QkGtBTwItF" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Make predictions for the test set\n", "predictions <- lm_wf_fit %>% \n", " predict(new_data = pumpkins_test)\n", "\n", "\n", "# Bind predictions to the test set\n", "lm_results <- pumpkins_test %>% \n", " select(c(package, price)) %>% \n", " bind_cols(predictions)\n", "\n", "\n", "# Print the first ten rows of the tibble\n", "lm_results %>% \n", " slice_head(n = 10)" ], "outputs": [], "metadata": { "id": "UFZzTG0gwTs9" } }, { "cell_type": "markdown", "source": [ "是的,你剛剛訓練了一個模型並用它進行了預測!🔮 它表現如何呢?讓我們來評估模型的性能吧!\n", "\n", "在 Tidymodels 中,我們使用 `yardstick::metrics()` 來完成這項工作!對於線性回歸,我們主要關注以下幾個指標:\n", "\n", "- `均方根誤差 (RMSE)`: [MSE](https://en.wikipedia.org/wiki/Mean_squared_error) 的平方根。這是一個絕對指標,單位與標籤一致(在這個例子中是南瓜的價格)。數值越小,模型越好(簡單來說,它代表了預測錯誤的平均價格!)\n", "\n", "- `決定係數 (通常稱為 R-squared 或 R2)`: 一個相對指標,數值越高,模型的擬合效果越好。從本質上講,這個指標表示模型能解釋的預測值與實際標籤值之間的方差比例。\n" ], "metadata": { "id": "0A5MjzM7wW9M" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Evaluate performance of linear regression\n", "metrics(data = lm_results,\n", " truth = price,\n", " estimate = .pred)" ], "outputs": [], "metadata": { "id": "reJ0UIhQwcEH" } }, { "cell_type": "markdown", "source": [ "模型效能下降了。我們來看看是否可以透過視覺化包裝與價格的散佈圖,並使用預測結果疊加一條最佳擬合線,來獲得更好的指標。\n", "\n", "這意味著我們需要準備並處理測試集,以便對包裝欄位進行編碼,然後將其與模型的預測結果結合起來。\n" ], "metadata": { "id": "fdgjzjkBwfWt" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Encode package column\n", "package_encode <- lm_pumpkins_recipe %>% \n", " prep() %>% \n", " bake(new_data = pumpkins_test) %>% \n", " select(package)\n", "\n", "\n", "# Bind encoded package column to the results\n", "lm_results <- lm_results %>% \n", " bind_cols(package_encode %>% \n", " rename(package_integer = package)) %>% \n", " relocate(package_integer, .after = package)\n", "\n", "\n", "# Print new results data frame\n", "lm_results %>% \n", " slice_head(n = 5)\n", "\n", "\n", "# Make a scatter plot\n", "lm_results %>% \n", " ggplot(mapping = aes(x = package_integer, y = price)) +\n", " geom_point(size = 1.6) +\n", " # Overlay a line of best fit\n", " geom_line(aes(y = .pred), color = \"orange\", size = 1.2) +\n", " xlab(\"package\")\n", " \n" ], "outputs": [], "metadata": { "id": "R0nw719lwkHE" } }, { "cell_type": "markdown", "source": [ "太棒了!如你所見,線性回歸模型並未很好地概括包裝與其對應價格之間的關係。\n", "\n", "🎃 恭喜你!你剛剛創建了一個可以幫助預測幾種南瓜價格的模型。你的節日南瓜園將會非常美麗。不過,你可能可以創建一個更好的模型!\n", "\n", "## 5. 建立多項式回歸模型\n", "\n", "

\n", " \n", "

圖表由 Dasani Madipalli 製作
\n", "\n", "\n", "\n" ], "metadata": { "id": "HOCqJXLTwtWI" } }, { "cell_type": "markdown", "source": [ "有時候,我們的數據可能並非呈線性關係,但我們仍然希望能夠預測結果。多項式回歸可以幫助我們對更複雜的非線性關係進行預測。\n", "\n", "以我們的南瓜數據集中的包裝和價格之間的關係為例。有時候,變量之間可能存在線性關係——例如,南瓜的體積越大,價格越高——但有時候這些關係無法用平面或直線來表示。\n", "\n", "> ✅ 這裡有[更多範例](https://online.stat.psu.edu/stat501/lesson/9/9.8),展示了可以使用多項式回歸的數據\n", ">\n", "> 再次看看之前圖表中品種與價格之間的關係。這個散點圖看起來是否一定應該用直線來分析?或許不應該。在這種情況下,你可以嘗試使用多項式回歸。\n", ">\n", "> ✅ 多項式是由一個或多個變量和係數組成的數學表達式\n", "\n", "#### 使用訓練集訓練多項式回歸模型\n", "\n", "多項式回歸會創建一條*曲線*,以更好地擬合非線性數據。\n", "\n", "讓我們看看多項式模型是否能在預測中表現得更好。我們將遵循與之前類似的步驟:\n", "\n", "- 創建一個配方,指定應對數據進行的預處理步驟,以便為建模做好準備,例如:對預測變量進行編碼並計算次數為 *n* 的多項式\n", "\n", "- 建立模型規範\n", "\n", "- 將配方和模型規範捆綁到一個工作流程中\n", "\n", "- 通過擬合工作流程來創建模型\n", "\n", "- 評估模型在測試數據上的表現\n", "\n", "讓我們開始吧!\n" ], "metadata": { "id": "VcEIpRV9wzYr" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Specify a recipe\r\n", "poly_pumpkins_recipe <-\r\n", " recipe(price ~ package, data = pumpkins_train) %>%\r\n", " step_integer(all_predictors(), zero_based = TRUE) %>% \r\n", " step_poly(all_predictors(), degree = 4)\r\n", "\r\n", "\r\n", "# Create a model specification\r\n", "poly_spec <- linear_reg() %>% \r\n", " set_engine(\"lm\") %>% \r\n", " set_mode(\"regression\")\r\n", "\r\n", "\r\n", "# Bundle recipe and model spec into a workflow\r\n", "poly_wf <- workflow() %>% \r\n", " add_recipe(poly_pumpkins_recipe) %>% \r\n", " add_model(poly_spec)\r\n", "\r\n", "\r\n", "# Create a model\r\n", "poly_wf_fit <- poly_wf %>% \r\n", " fit(data = pumpkins_train)\r\n", "\r\n", "\r\n", "# Print learned model coefficients\r\n", "poly_wf_fit\r\n", "\r\n", " " ], "outputs": [], "metadata": { "id": "63n_YyRXw3CC" } }, { "cell_type": "markdown", "source": [ "#### 評估模型表現\n", "\n", "👏👏你已經建立了一個多項式模型,現在讓我們在測試集上進行預測吧!\n" ], "metadata": { "id": "-LHZtztSxDP0" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Make price predictions on test data\r\n", "poly_results <- poly_wf_fit %>% predict(new_data = pumpkins_test) %>% \r\n", " bind_cols(pumpkins_test %>% select(c(package, price))) %>% \r\n", " relocate(.pred, .after = last_col())\r\n", "\r\n", "\r\n", "# Print the results\r\n", "poly_results %>% \r\n", " slice_head(n = 10)" ], "outputs": [], "metadata": { "id": "YUFpQ_dKxJGx" } }, { "cell_type": "markdown", "source": [ "Woo-hoo,讓我們使用 `yardstick::metrics()` 評估模型在 test_set 上的表現。\n" ], "metadata": { "id": "qxdyj86bxNGZ" } }, { "cell_type": "code", "execution_count": null, "source": [ "metrics(data = poly_results, truth = price, estimate = .pred)" ], "outputs": [], "metadata": { "id": "8AW5ltkBxXDm" } }, { "cell_type": "markdown", "source": [ "🤩🤩 表現大幅提升。\n", "\n", "`rmse` 從大約 7 降至大約 3,這表示實際價格與預測價格之間的誤差減少。你可以*大致*理解為平均來說,錯誤的預測大約相差 \\$3。`rsq` 從大約 0.4 增加到 0.8。\n", "\n", "所有這些指標都顯示多項式模型的表現遠優於線性模型。幹得好!\n", "\n", "讓我們看看是否可以將這些視覺化!\n" ], "metadata": { "id": "6gLHNZDwxYaS" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Bind encoded package column to the results\r\n", "poly_results <- poly_results %>% \r\n", " bind_cols(package_encode %>% \r\n", " rename(package_integer = package)) %>% \r\n", " relocate(package_integer, .after = package)\r\n", "\r\n", "\r\n", "# Print new results data frame\r\n", "poly_results %>% \r\n", " slice_head(n = 5)\r\n", "\r\n", "\r\n", "# Make a scatter plot\r\n", "poly_results %>% \r\n", " ggplot(mapping = aes(x = package_integer, y = price)) +\r\n", " geom_point(size = 1.6) +\r\n", " # Overlay a line of best fit\r\n", " geom_line(aes(y = .pred), color = \"midnightblue\", size = 1.2) +\r\n", " xlab(\"package\")\r\n" ], "outputs": [], "metadata": { "id": "A83U16frxdF1" } }, { "cell_type": "markdown", "source": [ "你可以看到一條更符合你數據的曲線!🤩\n", "\n", "你可以通過向 `geom_smooth` 傳遞一個多項式公式來使其更加平滑,例如:\n" ], "metadata": { "id": "4U-7aHOVxlGU" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Make a scatter plot\r\n", "poly_results %>% \r\n", " ggplot(mapping = aes(x = package_integer, y = price)) +\r\n", " geom_point(size = 1.6) +\r\n", " # Overlay a line of best fit\r\n", " geom_smooth(method = lm, formula = y ~ poly(x, degree = 4), color = \"midnightblue\", size = 1.2, se = FALSE) +\r\n", " xlab(\"package\")" ], "outputs": [], "metadata": { "id": "5vzNT0Uexm-w" } }, { "cell_type": "markdown", "source": [ "就像一條流暢的曲線!🤩\n", "\n", "以下是如何進行新的預測:\n" ], "metadata": { "id": "v9u-wwyLxq4G" } }, { "cell_type": "code", "execution_count": null, "source": [ "# Make a hypothetical data frame\r\n", "hypo_tibble <- tibble(package = \"bushel baskets\")\r\n", "\r\n", "# Make predictions using linear model\r\n", "lm_pred <- lm_wf_fit %>% predict(new_data = hypo_tibble)\r\n", "\r\n", "# Make predictions using polynomial model\r\n", "poly_pred <- poly_wf_fit %>% predict(new_data = hypo_tibble)\r\n", "\r\n", "# Return predictions in a list\r\n", "list(\"linear model prediction\" = lm_pred, \r\n", " \"polynomial model prediction\" = poly_pred)\r\n" ], "outputs": [], "metadata": { "id": "jRPSyfQGxuQv" } }, { "cell_type": "markdown", "source": [ "`多項式模型` 的預測確實有道理,根據 `price` 和 `package` 的散點圖來看!而且,如果這比之前的模型更好,從相同的數據來看,你需要為這些更昂貴的南瓜預算!\n", "\n", "🏆 做得好!你在一節課中建立了兩個回歸模型。在回歸的最後一部分,你將學習邏輯回歸來判定分類。\n", "\n", "## **🚀挑戰**\n", "\n", "在這個筆記本中測試幾個不同的變數,看看相關性如何影響模型的準確性。\n", "\n", "## [**課後測驗**](https://gray-sand-07a10f403.1.azurestaticapps.net/quiz/14/)\n", "\n", "## **複習與自學**\n", "\n", "在這節課中,我們學習了線性回歸。還有其他重要的回歸類型。閱讀有關逐步回歸、Ridge 回歸、Lasso 回歸和 Elasticnet 技術的資料。一個很好的進階課程是 [Stanford Statistical Learning course](https://online.stanford.edu/courses/sohs-ystatslearning-statistical-learning)。\n", "\n", "如果你想進一步了解如何使用出色的 Tidymodels 框架,請查看以下資源:\n", "\n", "- Tidymodels 官方網站:[開始使用 Tidymodels](https://www.tidymodels.org/start/)\n", "\n", "- Max Kuhn 和 Julia Silge 的著作 [*Tidy Modeling with R*](https://www.tmwr.org/)*.*\n", "\n", "###### **特別感謝:**\n", "\n", "[Allison Horst](https://twitter.com/allison_horst?lang=en) 創作了這些令人驚嘆的插圖,使 R 更加親切和有趣。可以在她的 [畫廊](https://www.google.com/url?q=https://github.com/allisonhorst/stats-illustrations&sa=D&source=editors&ust=1626380772530000&usg=AOvVaw3zcfyCizFQZpkSLzxiiQEM) 中找到更多插圖。\n" ], "metadata": { "id": "8zOLOWqMxzk5" } }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n---\n\n**免責聲明**: \n本文件已使用 AI 翻譯服務 [Co-op Translator](https://github.com/Azure/co-op-translator) 進行翻譯。儘管我們努力確保翻譯的準確性,但請注意,自動翻譯可能包含錯誤或不準確之處。原始文件的母語版本應被視為權威來源。對於關鍵信息,建議使用專業人工翻譯。我們對因使用此翻譯而引起的任何誤解或誤釋不承擔責任。\n" ] } ] }