{
"nbformat": 4,
"nbformat_minor": 2,
"metadata": {
"colab": {
"name": "lesson_3-R.ipynb",
"provenance": [],
"collapsed_sections": [],
"toc_visible": true
},
"kernelspec": {
"name": "ir",
"display_name": "R"
},
"language_info": {
"name": "R"
},
"coopTranslator": {
"original_hash": "5015d65d61ba75a223bfc56c273aa174",
"translation_date": "2025-09-03T19:27:58+00:00",
"source_file": "2-Regression/3-Linear/solution/R/lesson_3-R.ipynb",
"language_code": "hk"
}
},
"cells": [
{
"cell_type": "markdown",
"source": [],
"metadata": {
"id": "EgQw8osnsUV-"
}
},
{
"cell_type": "markdown",
"source": [
"## 南瓜定價的線性回歸與多項式回歸 - 第三課\n",
"<p >\n",
" <img src=\"../../images/linear-polynomial.png\"\n",
" width=\"800\"/>\n",
" <figcaption>資訊圖表由 Dasani Madipalli 製作</figcaption>\n",
"\n",
"\n",
"<!--![資訊圖表由 Dasani Madipalli 製作](../../../../../../translated_images/linear-polynomial.5523c7cb6576ccab0fecbd0e3505986eb2d191d9378e785f82befcf3a578a6e7.hk.png){width=\"800\"}-->\n",
"\n",
"#### 簡介\n",
"\n",
"到目前為止,你已經透過我們將在本課中使用的南瓜定價數據集,探索了回歸分析的概念,並使用 `ggplot2` 將其可視化。💪\n",
"\n",
"現在,你已準備好更深入地了解機器學習中的回歸分析。在本課中,你將學習兩種類型的回歸分析:*基本線性回歸* 和 *多項式回歸*,以及這些技術背後的一些數學原理。\n",
"\n",
"> 在整個課程中,我們假設學生的數學知識有限,並致力於讓來自其他領域的學生也能輕鬆理解。因此,請留意筆記、🧮 提示、圖表和其他學習工具,幫助你更好地掌握內容。\n",
"\n",
"#### 準備工作\n",
"\n",
"提醒一下,你正在載入這些數據以便對其進行分析並提出問題。\n",
"\n",
"- 什麼時候是購買南瓜的最佳時機?\n",
"\n",
"- 一箱迷你南瓜的價格大約是多少?\n",
"\n",
"- 我應該購買半蒲式耳籃裝的還是 1 1/9 蒲式耳箱裝的?讓我們繼續深入挖掘這些數據。\n",
"\n",
"在上一課中,你創建了一個 `tibble`(一種現代化的數據框架),並用部分原始數據集填充它,將價格標準化為以蒲式耳為單位。然而,這樣做後,你只能收集到大約 400 個數據點,並且僅限於秋季月份。也許通過進一步清理數據,我們可以獲得更多細節?我們拭目以待吧... 🕵️‍♀️\n",
"\n",
"為了完成這項任務,我們需要以下套件:\n",
"\n",
"- `tidyverse` [tidyverse](https://www.tidyverse.org/) 是一個 [R 套件集合](https://www.tidyverse.org/packages),旨在讓數據科學更快速、更簡單、更有趣!\n",
"\n",
"- `tidymodels` [tidymodels](https://www.tidymodels.org/) 框架是一個 [套件集合](https://www.tidymodels.org/packages/),用於建模和機器學習。\n",
"\n",
"- `janitor` [janitor 套件](https://github.com/sfirke/janitor) 提供了一些簡單的小工具,用於檢查和清理髒數據。\n",
"\n",
"- `corrplot` [corrplot 套件](https://cran.r-project.org/web/packages/corrplot/vignettes/corrplot-intro.html) 提供了一種可視化相關矩陣的工具,支持自動變量重排序,以幫助發現變量之間隱藏的模式。\n",
"\n",
"你可以通過以下方式安裝這些套件:\n",
"\n",
"`install.packages(c(\"tidyverse\", \"tidymodels\", \"janitor\", \"corrplot\"))`\n",
"\n",
"以下腳本將檢查你是否已安裝完成本模組所需的套件,並在缺少時自動為你安裝。\n"
],
"metadata": {
"id": "WqQPS1OAsg3H"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"suppressWarnings(if (!require(\"pacman\")) install.packages(\"pacman\"))\n",
"\n",
"pacman::p_load(tidyverse, tidymodels, janitor, corrplot)"
],
"outputs": [],
"metadata": {
"id": "tA4C2WN3skCf",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "c06cd805-5534-4edc-f72b-d0d1dab96ac0"
}
},
{
"cell_type": "markdown",
"source": [
"我們稍後會載入這些很棒的套件,並將它們在我們當前的 R 工作環境中啟用。(這只是為了說明,`pacman::p_load()` 已經幫你完成了這件事)\n",
"\n",
"## 1. 線性回歸線\n",
"\n",
"正如你在第一課中學到的,線性回歸的目標是能夠繪製一條*最佳擬合線*,以便:\n",
"\n",
"- **展示變數之間的關係**。顯示變數之間的關聯性。\n",
"\n",
"- **進行預測**。準確預測新數據點在這條線上的位置。\n",
"\n",
"為了繪製這種類型的線,我們使用一種稱為**最小平方法**的統計技術。`最小平方法`的意思是,將回歸線周圍的所有數據點的誤差平方後相加。理想情況下,這個最終的總和應該越小越好,因為我們希望誤差數值越低越好,也就是`最小平方`。因此,最佳擬合線就是使平方誤差總和最小的那條線——這就是*最小平方法回歸*名稱的由來。\n",
"\n",
"我們這樣做是因為我們希望建模出一條與所有數據點的累積距離最小的線。我們在相加之前對誤差進行平方,因為我們關注的是誤差的大小,而不是方向。\n",
"\n",
"> **🧮 數學公式展示**\n",
">\n",
"> 這條線,稱為*最佳擬合線*,可以用[一個方程式](https://en.wikipedia.org/wiki/Simple_linear_regression)來表示:\n",
">\n",
"> Y = a + bX\n",
">\n",
"> `X` 是「`解釋變數`或`預測變數`」。`Y` 是「`應變數`或`結果變數`」。`b` 是線的斜率,而 `a` 是 y 截距,表示當 `X = 0` 時 `Y` 的值。\n",
">\n",
"\n",
"> ![](../../../../../../2-Regression/3-Linear/solution/images/slope.png \"slope = $y/x$\")\n",
" 圖解由 Jen Looper 提供\n",
">\n",
"> 首先,計算斜率 `b`。\n",
">\n",
"> 換句話說,參考我們南瓜數據的原始問題:「根據月份預測每蒲式耳南瓜的價格」,`X` 代表價格,`Y` 代表銷售月份。\n",
">\n",
"> ![](../../../../../../translated_images/calculation.989aa7822020d9d0ba9fc781f1ab5192f3421be86ebb88026528aef33c37b0d8.hk.png)\n",
" 圖解由 Jen Looper 提供\n",
"> \n",
"> 計算 Y 的值。如果你支付大約 4 美元,那應該是四月!\n",
">\n",
"> 計算這條線的數學公式必須展示線的斜率,而這也取決於截距,或者說當 `X = 0` 時 `Y` 的位置。\n",
">\n",
"> 你可以在 [Math is Fun](https://www.mathsisfun.com/data/least-squares-regression.html) 網站上觀察這些值的計算方法。還可以訪問[這個最小平方法計算器](https://www.mathsisfun.com/data/least-squares-calculator.html),觀察數值如何影響這條線。\n",
"\n",
"不那麼可怕吧?🤓\n",
"\n",
"#### 相關性\n",
"\n",
"還有一個需要理解的術語是給定 X 和 Y 變數之間的**相關係數**。使用散點圖,你可以快速地可視化這個係數。數據點整齊排列成一條線的圖表具有高相關性,而數據點在 X 和 Y 之間隨意分佈的圖表則具有低相關性。\n",
"\n",
"一個好的線性回歸模型應該是使用最小平方法回歸線,並且具有高(接近 1 而非 0的相關係數的模型。\n"
],
"metadata": {
"id": "cdX5FRpvsoP5"
}
},
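{
"cell_type": "markdown",
"source": [
"The two ideas above, the least-squares line and the correlation coefficient, can be tried out on a handful of numbers. The sketch below uses base R only. The vectors `x` and `y` are made-up toy values (not pumpkin data), and the closed-form expressions are the standard simple linear regression estimates rather than anything specific to this lesson.\n",
"\n",
"```r\n",
"# Made-up toy data, for illustration only (not taken from the pumpkin dataset)\n",
"x <- c(1, 2, 3, 4, 5)\n",
"y <- c(2.1, 3.9, 6.2, 7.8, 10.1)\n",
"\n",
"# Closed-form least-squares estimates for the line Y = a + bX\n",
"b <- sum((x - mean(x)) * (y - mean(y))) / sum((x - mean(x))^2)\n",
"a <- mean(y) - b * mean(x)\n",
"\n",
"# The same line fitted with R's built-in lm()\n",
"fit <- lm(y ~ x)\n",
"\n",
"# Compare the hand-computed (intercept, slope) with lm()'s coefficients\n",
"all.equal(unname(coef(fit)), c(a, b))\n",
"\n",
"# Correlation coefficient between x and y\n",
"cor(x, y)\n",
"```\n",
"\n",
"Because the toy points lie almost on a straight line, the correlation should come out close to 1. Shuffling `y` is a quick way to see it drop.\n"
],
"metadata": {}
},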
{
"cell_type": "markdown",
"source": [
"## **2. 與數據共舞:建立用於建模的數據框**\n",
"\n",
"<p >\n",
" <img src=\"../../images/janitor.jpg\"\n",
" width=\"700\"/>\n",
" <figcaption>插圖由 @allison_horst 提供</figcaption>\n",
"\n",
"\n",
"<!--![插圖由 \\@allison_horst 提供](../../../../../../translated_images/janitor.e4a77dd3d3e6a32e25327090b8a9c00dc7cf459c44fa9f184c5ecb0d48ce3794.hk.jpg){width=\"700\"}-->\n"
],
"metadata": {
"id": "WdUKXk7Bs8-V"
}
},
{
"cell_type": "markdown",
"source": [
"載入所需的庫和數據集。將數據轉換為包含部分數據的數據框:\n",
"\n",
"- 只選擇以蒲式耳定價的南瓜\n",
"\n",
"- 將日期轉換為月份\n",
"\n",
"- 計算價格為高價和低價的平均值\n",
"\n",
"- 將價格轉換為反映按蒲式耳數量定價\n",
"\n",
"> 我們在[上一課](https://github.com/microsoft/ML-For-Beginners/blob/main/2-Regression/2-Data/solution/lesson_2-R.ipynb)中已經涵蓋了這些步驟。\n"
],
"metadata": {
"id": "fMCtu2G2s-p8"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Load the core Tidyverse packages\n",
"library(tidyverse)\n",
"library(lubridate)\n",
"\n",
"# Import the pumpkins data\n",
"pumpkins <- read_csv(file = \"https://raw.githubusercontent.com/microsoft/ML-For-Beginners/main/2-Regression/data/US-pumpkins.csv\")\n",
"\n",
"\n",
"# Get a glimpse and dimensions of the data\n",
"glimpse(pumpkins)\n",
"\n",
"\n",
"# Print the first 50 rows of the data set\n",
"pumpkins %>% \n",
" slice_head(n = 5)"
],
"outputs": [],
"metadata": {
"id": "ryMVZEEPtERn"
}
},
{
"cell_type": "markdown",
"source": [
"在純粹冒險的精神下,讓我們探索 [`janitor package`](../../../../../../2-Regression/3-Linear/solution/R/github.com/sfirke/janitor),它提供了簡單的函數來檢查和清理髒數據。例如,讓我們看看我們數據的列名稱:\n"
],
"metadata": {
"id": "xcNxM70EtJjb"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Return column names\n",
"pumpkins %>% \n",
" names()"
],
"outputs": [],
"metadata": {
"id": "5XtpaIigtPfW"
}
},
{
"cell_type": "markdown",
"source": [
"🤔 我們可以做得更好。讓我們通過使用 `janitor::clean_names` 將這些列名稱轉換為 [snake_case](https://en.wikipedia.org/wiki/Snake_case) 規範,並將它們命名為 `friendR`。想了解更多關於此函數的信息:`?clean_names`\n"
],
"metadata": {
"id": "IbIqrMINtSHe"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Clean names to the snake_case convention\n",
"pumpkins <- pumpkins %>% \n",
" clean_names(case = \"snake\")\n",
"\n",
"# Return column names\n",
"pumpkins %>% \n",
" names()"
],
"outputs": [],
"metadata": {
"id": "a2uYvclYtWvX"
}
},
{
"cell_type": "markdown",
"source": [
"非常整潔 🧹!現在,像上一課一樣,用 `dplyr` 與數據共舞吧!💃\n"
],
"metadata": {
"id": "HfhnuzDDtaDd"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Select desired columns\n",
"pumpkins <- pumpkins %>% \n",
" select(variety, city_name, package, low_price, high_price, date)\n",
"\n",
"\n",
"\n",
"# Extract the month from the dates to a new column\n",
"pumpkins <- pumpkins %>%\n",
" mutate(date = mdy(date),\n",
" month = month(date)) %>% \n",
" select(-date)\n",
"\n",
"\n",
"\n",
"# Create a new column for average Price\n",
"pumpkins <- pumpkins %>% \n",
" mutate(price = (low_price + high_price)/2)\n",
"\n",
"\n",
"# Retain only pumpkins with the string \"bushel\"\n",
"new_pumpkins <- pumpkins %>% \n",
" filter(str_detect(string = package, pattern = \"bushel\"))\n",
"\n",
"\n",
"# Normalize the pricing so that you show the pricing per bushel, not per 1 1/9 or 1/2 bushel\n",
"new_pumpkins <- new_pumpkins %>% \n",
" mutate(price = case_when(\n",
" str_detect(package, \"1 1/9\") ~ price/(1 + 1/9),\n",
" str_detect(package, \"1/2\") ~ price*2,\n",
" TRUE ~ price))\n",
"\n",
"# Relocate column positions\n",
"new_pumpkins <- new_pumpkins %>% \n",
" relocate(month, .before = variety)\n",
"\n",
"\n",
"# Display the first 5 rows\n",
"new_pumpkins %>% \n",
" slice_head(n = 5)"
],
"outputs": [],
"metadata": {
"id": "X0wU3gQvtd9f"
}
},
{
"cell_type": "markdown",
"source": [
"做得好!👌 現在你擁有一個乾淨整潔的數據集,可以用來建立新的回歸模型!\n",
"\n",
"要來個散點圖嗎?\n"
],
"metadata": {
"id": "UpaIwaxqth82"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Set theme\n",
"theme_set(theme_light())\n",
"\n",
"# Make a scatter plot of month and price\n",
"new_pumpkins %>% \n",
" ggplot(mapping = aes(x = month, y = price)) +\n",
" geom_point(size = 1.6)\n"
],
"outputs": [],
"metadata": {
"id": "DXgU-j37tl5K"
}
},
{
"cell_type": "markdown",
"source": [
"散佈圖提醒我們,目前只有八月至十二月的月份數據。可能需要更多數據才能以線性方式得出結論。\n",
"\n",
"讓我們再次查看我們的建模數據:\n"
],
"metadata": {
"id": "Ve64wVbwtobI"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Display first 5 rows\n",
"new_pumpkins %>% \n",
" slice_head(n = 5)"
],
"outputs": [],
"metadata": {
"id": "HFQX2ng1tuSJ"
}
},
{
"cell_type": "markdown",
"source": [
"如果我們想根據屬於字符類型的 `city` 或 `package` 欄位來預測南瓜的 `price`,該怎麼辦?或者更簡單地說,我們如何找到例如 `package` 和 `price` 之間的相關性(這需要兩個輸入都是數值型)?🤷🤷\n",
"\n",
"機器學習模型在處理數值特徵時效果最佳,而不是文本值,因此通常需要將分類特徵轉換為數值表示。\n",
"\n",
"這意味著我們需要找到一種方法來重新格式化我們的預測變數,使其更容易被模型有效使用,這個過程被稱為 `特徵工程`。\n"
],
"metadata": {
"id": "7hsHoxsStyjJ"
}
},
{
"cell_type": "markdown",
"source": [
"## 3. 使用 recipes 👩‍🍳👨‍🍳 預處理建模數據\n",
"\n",
"將預測值重新格式化,使其更容易被模型有效使用的活動被稱為 `特徵工程`。\n",
"\n",
"不同的模型有不同的預處理需求。例如,最小二乘法需要 `編碼分類變數`,例如月份、品種和城市名稱。這簡單地涉及將包含 `分類值` 的列 `轉換` 為一個或多個 `數值列`,以取代原始列。\n",
"\n",
"例如,假設你的數據包含以下分類特徵:\n",
"\n",
"| 城市 |\n",
"|:-------:|\n",
"| 丹佛 |\n",
"| 奈洛比 |\n",
"| 東京 |\n",
"\n",
"你可以應用 *序數編碼*,為每個類別替換一個唯一的整數值,如下所示:\n",
"\n",
"| 城市 |\n",
"|:----:|\n",
"| 0 |\n",
"| 1 |\n",
"| 2 |\n",
"\n",
"這就是我們將對數據進行的操作!\n",
"\n",
"在本節中,我們將探索另一個令人驚嘆的 Tidymodels 套件:[recipes](https://tidymodels.github.io/recipes/) - 它旨在幫助你在訓練模型之前預處理數據。其核心是一個 recipe 對象,定義了應該對數據集應用哪些步驟,以使其準備好進行建模。\n",
"\n",
"現在,讓我們創建一個 recipe通過為預測列中的所有觀測值替換唯一的整數來準備數據進行建模\n"
],
"metadata": {
"id": "AD5kQbcvt3Xl"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Specify a recipe\n",
"pumpkins_recipe <- recipe(price ~ ., data = new_pumpkins) %>% \n",
" step_integer(all_predictors(), zero_based = TRUE)\n",
"\n",
"\n",
"# Print out the recipe\n",
"pumpkins_recipe"
],
"outputs": [],
"metadata": {
"id": "BNaFKXfRt9TU"
}
},
{
"cell_type": "markdown",
"source": [
"太棒了!👏 我們剛剛創建了第一個配方,指定了一個結果(價格)及其相應的預測變數,並且所有的預測變數列都應該被編碼為一組整數 🙌!讓我們快速拆解一下:\n",
"\n",
"- 使用公式調用 `recipe()` 告訴配方變數的*角色*,並以 `new_pumpkins` 數據作為參考。例如,`price` 列被分配為 `outcome` 角色,而其餘的列則被分配為 `predictor` 角色。\n",
"\n",
"- `step_integer(all_predictors(), zero_based = TRUE)` 指定所有的預測變數應該被轉換為一組整數,編號從 0 開始。\n",
"\n",
"我們相信你可能會有這樣的想法:「這太酷了!!但如果我需要確認這些配方是否真的按照我的預期運作呢?🤔」\n",
"\n",
"這是一個很棒的想法!你看,一旦你的配方被定義,你可以估算實際預處理數據所需的參數,然後提取處理後的數據。通常使用 Tidymodels 時不需要這樣做(我們稍後會看到常規方法 -> `workflows`),但當你想進行某種檢查以確認配方是否符合預期時,這會非常有用。\n",
"\n",
"為此,你需要另外兩個動詞:`prep()` 和 `bake()`。一如既往,我們的小 R 朋友由 [`Allison Horst`](https://github.com/allisonhorst/stats-illustrations) 創作的插圖會幫助你更好地理解!\n",
"\n",
"<p >\n",
" <img src=\"../../images/recipes.png\"\n",
" width=\"550\"/>\n",
" <figcaption>插圖由 @allison_horst 創作</figcaption>\n"
],
"metadata": {
"id": "KEiO0v7kuC9O"
}
},
{
"cell_type": "markdown",
"source": [
"[`prep()`](https://recipes.tidymodels.org/reference/prep.html):從訓練集估算所需的參數,這些參數可以稍後應用到其他數據集。例如,對於某個預測變數列,哪個觀察值會被分配為整數 0、1、2 等。\n",
"\n",
"[`bake()`](https://recipes.tidymodels.org/reference/bake.html):使用已準備好的配方,並將操作應用到任何數據集。\n",
"\n",
"話雖如此,讓我們準備並應用配方,真正確認在底層,預測變數列會先被編碼,然後再進行模型擬合。\n"
],
"metadata": {
"id": "Q1xtzebuuTCP"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Prep the recipe\n",
"pumpkins_prep <- prep(pumpkins_recipe)\n",
"\n",
"# Bake the recipe to extract a preprocessed new_pumpkins data\n",
"baked_pumpkins <- bake(pumpkins_prep, new_data = NULL)\n",
"\n",
"# Print out the baked data set\n",
"baked_pumpkins %>% \n",
" slice_head(n = 10)"
],
"outputs": [],
"metadata": {
"id": "FGBbJbP_uUUn"
}
},
{
"cell_type": "markdown",
"source": [
"Woo-hoo!🥳 處理過的數據 `baked_pumpkins` 的所有預測變數都已編碼,這確認了我們所定義的預處理步驟(即我們的配方)確實能如預期般運作。雖然這讓你閱讀起來更困難,但對於 Tidymodels 來說卻更加容易理解!花點時間找出哪個觀察值已被映射到對應的整數。\n",
"\n",
"另外值得一提的是,`baked_pumpkins` 是一個可以進行計算的數據框。\n",
"\n",
"例如,我們可以嘗試在你的數據中找到兩個點之間的良好相關性,以便可能構建一個好的預測模型。我們將使用函數 `cor()` 來完成這項工作。輸入 `?cor()` 來了解更多有關該函數的信息。\n"
],
"metadata": {
"id": "1dvP0LBUueAW"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Find the correlation between the city_name and the price\n",
"cor(baked_pumpkins$city_name, baked_pumpkins$price)\n",
"\n",
"# Find the correlation between the package and the price\n",
"cor(baked_pumpkins$package, baked_pumpkins$price)\n"
],
"outputs": [],
"metadata": {
"id": "3bQzXCjFuiSV"
}
},
{
"cell_type": "markdown",
"source": [
"事實證明,城市與價格之間的相關性相當弱。然而,包裝與其價格之間的相關性稍微好一些。這是合理的,對吧?通常來說,生產箱越大,價格就越高。\n",
"\n",
"既然如此,我們也可以嘗試使用 `corrplot` 套件來視覺化所有欄位的相關性矩陣。\n"
],
"metadata": {
"id": "BToPWbgjuoZw"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Load the corrplot package\n",
"library(corrplot)\n",
"\n",
"# Obtain correlation matrix\n",
"corr_mat <- cor(baked_pumpkins %>% \n",
" # Drop columns that are not really informative\n",
" select(-c(low_price, high_price)))\n",
"\n",
"# Make a correlation plot between the variables\n",
"corrplot(corr_mat, method = \"shade\", shade.col = NA, tl.col = \"black\", tl.srt = 45, addCoef.col = \"black\", cl.pos = \"n\", order = \"original\")"
],
"outputs": [],
"metadata": {
"id": "ZwAL3ksmutVR"
}
},
{
"cell_type": "markdown",
"source": [
"🤩🤩 好得多。\n",
"\n",
"現在可以問這些數據的一個好問題是:'`我可以預期某個南瓜包裝的價格是多少?`' 讓我們直接開始吧!\n",
"\n",
"> 注意:當你使用 **`new_data = NULL`** 烘焙(**`bake()`**)已準備好的配方 **`pumpkins_prep`** 時,你會提取處理過的(即編碼過的)訓練數據。如果你有另一個數據集,例如測試集,並希望查看配方如何預處理它,你只需使用 **`new_data = test_set`** 烘焙 **`pumpkins_prep`** 即可。\n",
"\n",
"## 4. 建立線性回歸模型\n",
"\n",
"<p >\n",
" <img src=\"../../images/linear-polynomial.png\"\n",
" width=\"800\"/>\n",
" <figcaption>Dasani Madipalli 的資訊圖表</figcaption>\n",
"\n",
"\n",
"<!--![Dasani Madipalli 的資訊圖表](../../../../../../translated_images/linear-polynomial.5523c7cb6576ccab0fecbd0e3505986eb2d191d9378e785f82befcf3a578a6e7.hk.png){width=\"800\"}-->\n"
],
"metadata": {
"id": "YqXjLuWavNxW"
}
},
{
"cell_type": "markdown",
"source": [
"現在我們已經建立了一個配方,並確認數據會被適當地預處理,接下來讓我們建立一個回歸模型來回答這個問題:`我可以預期某個南瓜包裝的價格是多少?`\n",
"\n",
"#### 使用訓練集訓練線性回歸模型\n",
"\n",
"如你可能已經猜到,*price* 列是 `結果` 變數,而 *package* 列是 `預測` 變數。\n",
"\n",
"為了完成這項工作,我們首先將數據分割成 80% 的訓練集和 20% 的測試集,然後定義一個配方,將預測變數列編碼為一組整數,接著建立模型規範。我們不會準備和烘焙配方,因為我們已經知道它會按預期預處理數據。\n"
],
"metadata": {
"id": "Pq0bSzCevW-h"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"set.seed(2056)\n",
"# Split the data into training and test sets\n",
"pumpkins_split <- new_pumpkins %>% \n",
" initial_split(prop = 0.8)\n",
"\n",
"\n",
"# Extract training and test data\n",
"pumpkins_train <- training(pumpkins_split)\n",
"pumpkins_test <- testing(pumpkins_split)\n",
"\n",
"\n",
"\n",
"# Create a recipe for preprocessing the data\n",
"lm_pumpkins_recipe <- recipe(price ~ package, data = pumpkins_train) %>% \n",
" step_integer(all_predictors(), zero_based = TRUE)\n",
"\n",
"\n",
"\n",
"# Create a linear model specification\n",
"lm_spec <- linear_reg() %>% \n",
" set_engine(\"lm\") %>% \n",
" set_mode(\"regression\")"
],
"outputs": [],
"metadata": {
"id": "CyoEh_wuvcLv"
}
},
{
"cell_type": "markdown",
"source": [
"做得好!現在我們已經有了一個食譜和一個模型規範,接下來我們需要找到一種方法,將它們打包成一個物件。這個物件將會先對數據進行預處理(在背後執行 prep + bake然後基於預處理後的數據來訓練模型並且還能支持潛在的後處理操作。這樣是不是讓你更安心了呢🤩\n",
"\n",
"在 Tidymodels 中,這個方便的物件叫做 [`workflow`](https://workflows.tidymodels.org/),它能輕鬆地保存你的建模組件!這就像我們在 *Python* 中所說的 *pipelines*。\n",
"\n",
"那麼,讓我們把所有東西打包到一個 workflow 裡吧!📦\n"
],
"metadata": {
"id": "G3zF_3DqviFJ"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Hold modelling components in a workflow\n",
"lm_wf <- workflow() %>% \n",
" add_recipe(lm_pumpkins_recipe) %>% \n",
" add_model(lm_spec)\n",
"\n",
"# Print out the workflow\n",
"lm_wf"
],
"outputs": [],
"metadata": {
"id": "T3olroU3v-WX"
}
},
{
"cell_type": "markdown",
"source": [
"順帶一提,工作流程可以像模型一樣進行適配或訓練。\n"
],
"metadata": {
"id": "zd1A5tgOwEPX"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Train the model\n",
"lm_wf_fit <- lm_wf %>% \n",
" fit(data = pumpkins_train)\n",
"\n",
"# Print the model coefficients learned \n",
"lm_wf_fit"
],
"outputs": [],
"metadata": {
"id": "NhJagFumwFHf"
}
},
{
"cell_type": "markdown",
"source": [
"從模型輸出中,我們可以看到訓練過程中學到的係數。這些係數代表最佳擬合線的係數,該線能讓實際值與預測值之間的整體誤差降到最低。\n",
"\n",
"#### 使用測試集評估模型表現\n",
"\n",
"是時候看看模型的表現如何了 📏!我們該怎麼做呢?\n",
"\n",
"現在模型已經完成訓練,我們可以使用 `parsnip::predict()` 為測試集進行預測。接著,我們可以將這些預測值與實際標籤值進行比較,以評估模型的表現是否良好(或者不良!)。\n",
"\n",
"首先,讓我們為測試集進行預測,然後將這些預測結果與測試集的欄位結合起來。\n"
],
"metadata": {
"id": "_4QkGtBTwItF"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Make predictions for the test set\n",
"predictions <- lm_wf_fit %>% \n",
" predict(new_data = pumpkins_test)\n",
"\n",
"\n",
"# Bind predictions to the test set\n",
"lm_results <- pumpkins_test %>% \n",
" select(c(package, price)) %>% \n",
" bind_cols(predictions)\n",
"\n",
"\n",
"# Print the first ten rows of the tibble\n",
"lm_results %>% \n",
" slice_head(n = 10)"
],
"outputs": [],
"metadata": {
"id": "UFZzTG0gwTs9"
}
},
{
"cell_type": "markdown",
"source": [
"你剛剛訓練了一個模型並用它進行了預測!🔮 它表現如何呢?讓我們來評估模型的表現吧!\n",
"\n",
"在 Tidymodels 中,我們使用 `yardstick::metrics()` 來完成這項工作!對於線性回歸,我們主要關注以下指標:\n",
"\n",
"- `均方根誤差 (RMSE)`: [MSE](https://en.wikipedia.org/wiki/Mean_squared_error) 的平方根。這是一個絕對指標,單位與標籤一致(在這個例子中是南瓜的價格)。數值越小,模型越好(簡單來說,它代表了預測錯誤的平均價格!)\n",
"\n",
"- `決定係數 (通常稱為 R-squared 或 R2)`: 一個相對指標,數值越高,模型的擬合效果越好。基本上,這個指標表示模型能解釋多少預測值與實際標籤值之間的差異。\n"
],
"metadata": {
"id": "0A5MjzM7wW9M"
}
},
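{
"cell_type": "markdown",
"source": [
"For reference, the two metrics can be written out as follows. This is a sketch in notation introduced here: $y_i$ is the actual price, $\\hat{y}_i$ the predicted price and $n$ the number of test rows. The $R^2$ expression is the squared-correlation form, which, as far as we are aware, is what `yardstick` reports as `rsq`:\n",
"\n",
"$$\\mathrm{RMSE} = \\sqrt{\\frac{1}{n} \\sum_{i=1}^{n} (y_i - \\hat{y}_i)^2}, \\qquad R^2 = \\operatorname{cor}(y, \\hat{y})^2$$\n"
],
"metadata": {}
},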
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Evaluate performance of linear regression\n",
"metrics(data = lm_results,\n",
" truth = price,\n",
" estimate = .pred)"
],
"outputs": [],
"metadata": {
"id": "reJ0UIhQwcEH"
}
},
{
"cell_type": "markdown",
"source": [
"模型表現下降了。讓我們試試看能否透過視覺化包裝與價格的散點圖,並使用模型的預測結果疊加最佳擬合線,來獲得更好的指標。\n",
"\n",
"這意味著我們需要準備並處理測試集,以便對包裝欄位進行編碼,然後將其與模型的預測結果結合起來。\n"
],
"metadata": {
"id": "fdgjzjkBwfWt"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Encode package column\n",
"package_encode <- lm_pumpkins_recipe %>% \n",
" prep() %>% \n",
" bake(new_data = pumpkins_test) %>% \n",
" select(package)\n",
"\n",
"\n",
"# Bind encoded package column to the results\n",
"lm_results <- lm_results %>% \n",
" bind_cols(package_encode %>% \n",
" rename(package_integer = package)) %>% \n",
" relocate(package_integer, .after = package)\n",
"\n",
"\n",
"# Print new results data frame\n",
"lm_results %>% \n",
" slice_head(n = 5)\n",
"\n",
"\n",
"# Make a scatter plot\n",
"lm_results %>% \n",
" ggplot(mapping = aes(x = package_integer, y = price)) +\n",
" geom_point(size = 1.6) +\n",
" # Overlay a line of best fit\n",
" geom_line(aes(y = .pred), color = \"orange\", size = 1.2) +\n",
" xlab(\"package\")\n",
" \n"
],
"outputs": [],
"metadata": {
"id": "R0nw719lwkHE"
}
},
{
"cell_type": "markdown",
"source": [
"很好!如你所見,線性回歸模型並未能很好地概括包裝與其對應價格之間的關係。\n",
"\n",
"🎃 恭喜你!你剛剛創建了一個可以幫助預測幾種南瓜價格的模型。你的節日南瓜田將會非常美麗。但你可能可以創建一個更好的模型!\n",
"\n",
"## 5. 建立多項式回歸模型\n",
"\n",
"<p >\n",
" <img src=\"../../images/linear-polynomial.png\"\n",
" width=\"800\"/>\n",
" <figcaption>資訊圖表由 Dasani Madipalli 製作</figcaption>\n",
"\n",
"\n",
"<!--![資訊圖表由 Dasani Madipalli 製作](../../../../../../translated_images/linear-polynomial.5523c7cb6576ccab0fecbd0e3505986eb2d191d9378e785f82befcf3a578a6e7.hk.png){width=\"800\"}-->\n"
],
"metadata": {
"id": "HOCqJXLTwtWI"
}
},
{
"cell_type": "markdown",
"source": [
"有時候,我們的數據可能並非呈線性關係,但我們仍然希望能夠預測結果。多項式回歸可以幫助我們針對更複雜的非線性關係進行預測。\n",
"\n",
"舉例來說,南瓜數據集中包裝與價格之間的關係。有時候,變量之間可能存在線性關係——例如南瓜的體積越大,價格越高——但有時候這些關係無法用平面或直線來表示。\n",
"\n",
"> ✅ 這裡有[更多例子](https://online.stat.psu.edu/stat501/lesson/9/9.8),展示了哪些數據可以使用多項式回歸\n",
">\n",
"> 再看看之前圖表中品種與價格之間的關係。這個散點圖看起來是否一定要用直線來分析?可能不是。在這種情況下,你可以嘗試使用多項式回歸。\n",
">\n",
"> ✅ 多項式是可能包含一個或多個變量和係數的數學表達式\n",
"\n",
"#### 使用訓練集訓練多項式回歸模型\n",
"\n",
"多項式回歸會創建一條*曲線*,以更好地擬合非線性數據。\n",
"\n",
"讓我們看看多項式模型是否能在預測中表現得更好。我們將遵循與之前類似的步驟:\n",
"\n",
"- 創建一個配方,指定數據預處理步驟,使其準備好進行建模,例如:編碼預測變量並計算*n*次多項式\n",
"\n",
"- 建立模型規範\n",
"\n",
"- 將配方和模型規範打包到工作流程中\n",
"\n",
"- 通過擬合工作流程來創建模型\n",
"\n",
"- 評估模型在測試數據上的表現\n",
"\n",
"讓我們開始吧!\n"
],
"metadata": {
"id": "VcEIpRV9wzYr"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Specify a recipe\r\n",
"poly_pumpkins_recipe <-\r\n",
" recipe(price ~ package, data = pumpkins_train) %>%\r\n",
" step_integer(all_predictors(), zero_based = TRUE) %>% \r\n",
" step_poly(all_predictors(), degree = 4)\r\n",
"\r\n",
"\r\n",
"# Create a model specification\r\n",
"poly_spec <- linear_reg() %>% \r\n",
" set_engine(\"lm\") %>% \r\n",
" set_mode(\"regression\")\r\n",
"\r\n",
"\r\n",
"# Bundle recipe and model spec into a workflow\r\n",
"poly_wf <- workflow() %>% \r\n",
" add_recipe(poly_pumpkins_recipe) %>% \r\n",
" add_model(poly_spec)\r\n",
"\r\n",
"\r\n",
"# Create a model\r\n",
"poly_wf_fit <- poly_wf %>% \r\n",
" fit(data = pumpkins_train)\r\n",
"\r\n",
"\r\n",
"# Print learned model coefficients\r\n",
"poly_wf_fit\r\n",
"\r\n",
" "
],
"outputs": [],
"metadata": {
"id": "63n_YyRXw3CC"
}
},
{
"cell_type": "markdown",
"source": [
"#### 評估模型表現\n",
"\n",
"👏👏你已經建立了一個多項式模型,現在讓我們在測試集上進行預測吧!\n"
],
"metadata": {
"id": "-LHZtztSxDP0"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Make price predictions on test data\r\n",
"poly_results <- poly_wf_fit %>% predict(new_data = pumpkins_test) %>% \r\n",
" bind_cols(pumpkins_test %>% select(c(package, price))) %>% \r\n",
" relocate(.pred, .after = last_col())\r\n",
"\r\n",
"\r\n",
"# Print the results\r\n",
"poly_results %>% \r\n",
" slice_head(n = 10)"
],
"outputs": [],
"metadata": {
"id": "YUFpQ_dKxJGx"
}
},
{
"cell_type": "markdown",
"source": [
"Woo-hoo讓我們使用 `yardstick::metrics()` 評估模型在 test_set 上的表現。\n"
],
"metadata": {
"id": "qxdyj86bxNGZ"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"metrics(data = poly_results, truth = price, estimate = .pred)"
],
"outputs": [],
"metadata": {
"id": "8AW5ltkBxXDm"
}
},
{
"cell_type": "markdown",
"source": [
"🤩🤩 表現大幅提升。\n",
"\n",
"`rmse` 從大約 7 降至大約 3這表示實際價格與預測價格之間的誤差減少。你可以*大致*理解為平均來說,錯誤的預測大約相差 \\$3。`rsq` 從大約 0.4 增加到 0.8。\n",
"\n",
"所有這些指標都顯示多項式模型的表現遠勝於線性模型。幹得好!\n",
"\n",
"讓我們看看能否將這些視覺化!\n"
],
"metadata": {
"id": "6gLHNZDwxYaS"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Bind encoded package column to the results\r\n",
"poly_results <- poly_results %>% \r\n",
" bind_cols(package_encode %>% \r\n",
" rename(package_integer = package)) %>% \r\n",
" relocate(package_integer, .after = package)\r\n",
"\r\n",
"\r\n",
"# Print new results data frame\r\n",
"poly_results %>% \r\n",
" slice_head(n = 5)\r\n",
"\r\n",
"\r\n",
"# Make a scatter plot\r\n",
"poly_results %>% \r\n",
" ggplot(mapping = aes(x = package_integer, y = price)) +\r\n",
" geom_point(size = 1.6) +\r\n",
" # Overlay a line of best fit\r\n",
" geom_line(aes(y = .pred), color = \"midnightblue\", size = 1.2) +\r\n",
" xlab(\"package\")\r\n"
],
"outputs": [],
"metadata": {
"id": "A83U16frxdF1"
}
},
{
"cell_type": "markdown",
"source": [
"你可以看到一條更符合你數據的曲線!🤩\n",
"\n",
"你可以通過向 `geom_smooth` 傳遞一個多項式公式來使曲線更加平滑,例如:\n"
],
"metadata": {
"id": "4U-7aHOVxlGU"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Make a scatter plot\r\n",
"poly_results %>% \r\n",
" ggplot(mapping = aes(x = package_integer, y = price)) +\r\n",
" geom_point(size = 1.6) +\r\n",
" # Overlay a line of best fit\r\n",
" geom_smooth(method = lm, formula = y ~ poly(x, degree = 4), color = \"midnightblue\", size = 1.2, se = FALSE) +\r\n",
" xlab(\"package\")"
],
"outputs": [],
"metadata": {
"id": "5vzNT0Uexm-w"
}
},
{
"cell_type": "markdown",
"source": [
"就像一條流暢的曲線!🤩\n",
"\n",
"以下是如何進行新的預測:\n"
],
"metadata": {
"id": "v9u-wwyLxq4G"
}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Make a hypothetical data frame\r\n",
"hypo_tibble <- tibble(package = \"bushel baskets\")\r\n",
"\r\n",
"# Make predictions using linear model\r\n",
"lm_pred <- lm_wf_fit %>% predict(new_data = hypo_tibble)\r\n",
"\r\n",
"# Make predictions using polynomial model\r\n",
"poly_pred <- poly_wf_fit %>% predict(new_data = hypo_tibble)\r\n",
"\r\n",
"# Return predictions in a list\r\n",
"list(\"linear model prediction\" = lm_pred, \r\n",
" \"polynomial model prediction\" = poly_pred)\r\n"
],
"outputs": [],
"metadata": {
"id": "jRPSyfQGxuQv"
}
},
{
"cell_type": "markdown",
"source": [
"`polynomial model` 的預測確實合理,根據 `price` 和 `package` 的散點圖來看!而且,如果這個模型比之前的模型更好,基於相同的數據,你需要為這些更昂貴的南瓜預算做好準備!\n",
"\n",
"🏆 做得好!你在一節課中建立了兩個回歸模型。在回歸的最後一部分,你將學習如何使用邏輯回歸來判定分類。\n",
"\n",
"## **🚀挑戰**\n",
"\n",
"在這個筆記本中測試幾個不同的變數,看看相關性如何影響模型的準確性。\n",
"\n",
"## [**課後測驗**](https://gray-sand-07a10f403.1.azurestaticapps.net/quiz/14/)\n",
"\n",
"## **回顧與自學**\n",
"\n",
"在這節課中,我們學習了線性回歸。還有其他重要的回歸類型。閱讀有關逐步回歸、嶺回歸、套索回歸和彈性網技術的資料。一門很好的課程是 [Stanford Statistical Learning course](https://online.stanford.edu/courses/sohs-ystatslearning-statistical-learning)。\n",
"\n",
"如果你想了解更多如何使用出色的 Tidymodels 框架,請查看以下資源:\n",
"\n",
"- Tidymodels 網站:[開始使用 Tidymodels](https://www.tidymodels.org/start/)\n",
"\n",
"- Max Kuhn 和 Julia Silge[*Tidy Modeling with R*](https://www.tmwr.org/)*.*\n",
"\n",
"###### **感謝:**\n",
"\n",
"[Allison Horst](https://twitter.com/allison_horst?lang=en) 創作了令人驚嘆的插圖,使 R 更加友好和吸引人。可以在她的 [畫廊](https://www.google.com/url?q=https://github.com/allisonhorst/stats-illustrations&sa=D&source=editors&ust=1626380772530000&usg=AOvVaw3zcfyCizFQZpkSLzxiiQEM) 中找到更多插圖。\n"
],
"metadata": {
"id": "8zOLOWqMxzk5"
}
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n---\n\n**免責聲明** \n此文件已使用人工智能翻譯服務 [Co-op Translator](https://github.com/Azure/co-op-translator) 進行翻譯。我們致力於提供準確的翻譯,但請注意,自動翻譯可能包含錯誤或不準確之處。應以原始語言的文件作為權威來源。對於關鍵資訊,建議尋求專業人工翻譯。我們對因使用此翻譯而引起的任何誤解或誤釋不承擔責任。\n"
]
}
]
}