You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
686 lines
30 KiB
686 lines
30 KiB
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Vytvoření modelu logistické regrese - Lekce 4\n",
|
|
"\n",
|
|
"\n",
|
|
"\n",
|
|
"#### **[Kvíz před lekcí](https://gray-sand-07a10f403.1.azurestaticapps.net/quiz/15/)**\n",
|
|
"\n",
|
|
"#### Úvod\n",
|
|
"\n",
|
|
"V této závěrečné lekci o regresi, jednom ze základních *klasických* technik strojového učení, se podíváme na logistickou regresi. Tuto techniku byste použili k objevování vzorců pro predikci binárních kategorií. Je tato cukrovinka čokoláda, nebo ne? Je tato nemoc nakažlivá, nebo ne? Vybere si tento zákazník tento produkt, nebo ne?\n",
|
|
"\n",
|
|
"V této lekci se naučíte:\n",
|
|
"\n",
|
|
"- Techniky logistické regrese\n",
|
|
"\n",
|
|
"✅ Prohlubte své znalosti práce s tímto typem regrese v tomto [výukovém modulu](https://learn.microsoft.com/training/modules/introduction-classification-models/?WT.mc_id=academic-77952-leestott)\n",
|
|
"\n",
|
|
"## Předpoklady\n",
|
|
"\n",
|
|
"Po práci s daty o dýních jsme nyní dostatečně obeznámeni s tím, že existuje jedna binární kategorie, se kterou můžeme pracovat: `Color`.\n",
|
|
"\n",
|
|
"Postavme model logistické regrese, který předpoví, na základě některých proměnných, *jakou barvu bude mít daná dýně* (oranžová 🎃 nebo bílá 👻).\n",
|
|
"\n",
|
|
"> Proč mluvíme o binární klasifikaci v lekci zaměřené na regresi? Pouze z jazykového pohodlí, protože logistická regrese je [ve skutečnosti klasifikační metoda](https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression), i když založená na lineární regresi. O dalších způsobech klasifikace dat se dozvíte v další skupině lekcí.\n",
|
|
"\n",
|
|
"Pro tuto lekci budeme potřebovat následující balíčky:\n",
|
|
"\n",
|
|
"- `tidyverse`: [tidyverse](https://www.tidyverse.org/) je [kolekce balíčků pro R](https://www.tidyverse.org/packages), která usnadňuje, zrychluje a zpříjemňuje práci s datovou vědou!\n",
|
|
"\n",
|
|
"- `tidymodels`: [tidymodels](https://www.tidymodels.org/) je rámec [kolekce balíčků](https://www.tidymodels.org/packages/) pro modelování a strojové učení.\n",
|
|
"\n",
|
|
"- `janitor`: [janitor balíček](https://github.com/sfirke/janitor) poskytuje jednoduché nástroje pro zkoumání a čištění špinavých dat.\n",
|
|
"\n",
|
|
"- `ggbeeswarm`: [ggbeeswarm balíček](https://github.com/eclarke/ggbeeswarm) nabízí metody pro vytváření grafů ve stylu \"beeswarm\" pomocí ggplot2.\n",
|
|
"\n",
|
|
"Můžete je nainstalovat pomocí:\n",
|
|
"\n",
|
|
"`install.packages(c(\"tidyverse\", \"tidymodels\", \"janitor\", \"ggbeeswarm\"))`\n",
|
|
"\n",
|
|
"Alternativně níže uvedený skript zkontroluje, zda máte potřebné balíčky pro dokončení tohoto modulu, a v případě jejich absence je nainstaluje.\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"vscode": {
|
|
"languageId": "r"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"suppressWarnings(if (!require(\"pacman\"))install.packages(\"pacman\"))\n",
|
|
"\n",
|
|
"pacman::p_load(tidyverse, tidymodels, janitor, ggbeeswarm)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## **Definujte otázku**\n",
|
|
"\n",
|
|
"Pro naše účely ji vyjádříme jako binární: 'Bílá' nebo 'Ne bílá'. V našem datasetu je také kategorie 'pruhovaná', ale má málo záznamů, takže ji nebudeme používat. Stejně zmizí, jakmile odstraníme nulové hodnoty z datasetu.\n",
|
|
"\n",
|
|
"> 🎃 Zajímavost: bílé dýně někdy nazýváme 'duchové' dýně. Nejsou příliš snadné na vyřezávání, takže nejsou tak populární jako oranžové, ale vypadají zajímavě! Mohli bychom tedy naši otázku přeformulovat na: 'Duch' nebo 'Ne duch'. 👻\n",
|
|
"\n",
|
|
"## **O logistické regresi**\n",
|
|
"\n",
|
|
"Logistická regrese se liší od lineární regrese, kterou jste se naučili dříve, v několika důležitých ohledech.\n",
|
|
"\n",
|
|
"#### **Binární klasifikace**\n",
|
|
"\n",
|
|
"Logistická regrese nenabízí stejné funkce jako lineární regrese. První z nich poskytuje předpověď o `binární kategorii` (\"oranžová nebo ne oranžová\"), zatímco druhá je schopna předpovídat `kontinuální hodnoty`, například na základě původu dýně a času sklizně, *o kolik se zvýší její cena*.\n",
|
|
"\n",
|
|
"\n",
|
|
"\n",
|
|
"### Další klasifikace\n",
|
|
"\n",
|
|
"Existují i jiné typy logistické regrese, včetně multinomiální a ordinální:\n",
|
|
"\n",
|
|
"- **Multinomiální**, která zahrnuje více než jednu kategorii - \"Oranžová, Bílá a Pruhovaná\".\n",
|
|
"\n",
|
|
"- **Ordinální**, která zahrnuje uspořádané kategorie, užitečné, pokud bychom chtěli logicky uspořádat naše výsledky, například naše dýně, které jsou uspořádány podle konečného počtu velikostí (mini,sm,med,lg,xl,xxl).\n",
|
|
"\n",
|
|
"\n",
|
|
"\n",
|
|
"#### **Proměnné NEMUSÍ být korelované**\n",
|
|
"\n",
|
|
"Pamatujete si, jak lineární regrese fungovala lépe s více korelovanými proměnnými? Logistická regrese je opakem - proměnné nemusí být v souladu. To funguje pro tato data, která mají poměrně slabé korelace.\n",
|
|
"\n",
|
|
"#### **Potřebujete hodně čistých dat**\n",
|
|
"\n",
|
|
"Logistická regrese poskytne přesnější výsledky, pokud použijete více dat; náš malý dataset není pro tento úkol optimální, takže na to pamatujte.\n",
|
|
"\n",
|
|
"✅ Zamyslete se nad typy dat, které by se dobře hodily pro logistickou regresi.\n",
|
|
"\n",
|
|
"## Cvičení - úprava dat\n",
|
|
"\n",
|
|
"Nejprve data trochu vyčistěte, odstraňte nulové hodnoty a vyberte pouze některé sloupce:\n",
|
|
"\n",
|
|
"1. Přidejte následující kód:\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"vscode": {
|
|
"languageId": "r"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Load the core tidyverse packages\n",
|
|
"library(tidyverse)\n",
|
|
"\n",
|
|
"# Import the data and clean column names\n",
|
|
"pumpkins <- read_csv(file = \"https://raw.githubusercontent.com/microsoft/ML-For-Beginners/main/2-Regression/data/US-pumpkins.csv\") %>% \n",
|
|
" clean_names()\n",
|
|
"\n",
|
|
"# Select desired columns\n",
|
|
"pumpkins_select <- pumpkins %>% \n",
|
|
" select(c(city_name, package, variety, origin, item_size, color)) \n",
|
|
"\n",
|
|
"# Drop rows containing missing values and encode color as factor (category)\n",
|
|
"pumpkins_select <- pumpkins_select %>% \n",
|
|
" drop_na() %>% \n",
|
|
" mutate(color = factor(color))\n",
|
|
"\n",
|
|
"# View the first few rows\n",
|
|
"pumpkins_select %>% \n",
|
|
" slice_head(n = 5)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Vždy se můžete podívat na svůj nový dataframe pomocí funkce [*glimpse()*](https://pillar.r-lib.org/reference/glimpse.html), jak je uvedeno níže:\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"vscode": {
|
|
"languageId": "r"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"pumpkins_select %>% \n",
|
|
" glimpse()\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Pojďme si potvrdit, že se skutečně budeme zabývat problémem binární klasifikace:\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"vscode": {
|
|
"languageId": "r"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Subset distinct observations in outcome column\n",
|
|
"pumpkins_select %>% \n",
|
|
" distinct(color)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Vizualizace - kategorický graf\n",
|
|
"Nyní jste znovu načetli data o dýních a vyčistili je tak, aby obsahovala dataset s několika proměnnými, včetně barvy. Pojďme vizualizovat dataframe v notebooku pomocí knihovny ggplot.\n",
|
|
"\n",
|
|
"Knihovna ggplot nabízí několik skvělých způsobů, jak vizualizovat vaše data. Například můžete porovnat rozložení dat pro každou odrůdu a barvu v kategorickém grafu.\n",
|
|
"\n",
|
|
"1. Vytvořte takový graf pomocí funkce geombar, použijte naše data o dýních a specifikujte barevné mapování pro každou kategorii dýní (oranžová nebo bílá):\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"vscode": {
|
|
"languageId": "python"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Specify colors for each value of the hue variable\n",
|
|
"palette <- c(ORANGE = \"orange\", WHITE = \"wheat\")\n",
|
|
"\n",
|
|
"# Create the bar plot\n",
|
|
"ggplot(pumpkins_select, aes(y = variety, fill = color)) +\n",
|
|
" geom_bar(position = \"dodge\") +\n",
|
|
" scale_fill_manual(values = palette) +\n",
|
|
" labs(y = \"Variety\", fill = \"Color\") +\n",
|
|
" theme_minimal()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Pozorováním dat můžete vidět, jak se údaje o barvě vztahují k odrůdě.\n",
|
|
"\n",
|
|
"✅ Na základě tohoto kategorického grafu, jaké zajímavé průzkumy si dokážete představit?\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Předzpracování dat: kódování atributů\n",
|
|
"\n",
|
|
"Náš dataset dýní obsahuje textové hodnoty ve všech svých sloupcích. Práce s kategoriálními daty je pro lidi intuitivní, ale pro stroje nikoliv. Algoritmy strojového učení pracují dobře s čísly. Proto je kódování velmi důležitým krokem ve fázi předzpracování dat, protože nám umožňuje převést kategoriální data na numerická, aniž bychom ztratili jakékoliv informace. Dobré kódování vede k vytvoření kvalitního modelu.\n",
|
|
"\n",
|
|
"Pro kódování atributů existují dva hlavní typy kódovačů:\n",
|
|
"\n",
|
|
"1. Ordinalní kódovač: hodí se pro ordinalní proměnné, což jsou kategoriální proměnné, jejichž data mají logické pořadí, jako například sloupec `item_size` v našem datasetu. Vytváří mapování, kde každá kategorie je reprezentována číslem, které odpovídá pořadí kategorie ve sloupci.\n",
|
|
"\n",
|
|
"2. Kategoriální kódovač: hodí se pro nominální proměnné, což jsou kategoriální proměnné, jejichž data nemají logické pořadí, jako všechny atributy kromě `item_size` v našem datasetu. Jedná se o one-hot kódování, což znamená, že každá kategorie je reprezentována binárním sloupcem: kódovaná proměnná je rovna 1, pokud dýně patří do dané Variety, a 0 v opačném případě.\n",
|
|
"\n",
|
|
"Tidymodels poskytuje další šikovný balíček: [recipes](https://recipes.tidymodels.org/) - balíček pro předzpracování dat. Definujeme `recipe`, který specifikuje, že všechny sloupce prediktorů by měly být kódovány na sadu celých čísel, `prep` pro odhad potřebných množství a statistik potřebných pro jakékoliv operace, a nakonec `bake` pro aplikaci výpočtů na nová data.\n",
|
|
"\n",
|
|
"> Obvykle se recipes používá jako předzpracovatel pro modelování, kde definuje, jaké kroky by měly být aplikovány na dataset, aby byl připraven pro modelování. V takovém případě je **vysoce doporučeno**, abyste použili `workflow()` místo ručního odhadu recipe pomocí prep a bake. Vše si ukážeme za chvíli.\n",
|
|
">\n",
|
|
"> Prozatím však používáme recipes + prep + bake k tomu, abychom specifikovali, jaké kroky by měly být aplikovány na dataset, aby byl připraven pro analýzu dat, a následně extrahovali předzpracovaná data s aplikovanými kroky.\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"vscode": {
|
|
"languageId": "r"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Preprocess and extract data to allow some data analysis\n",
|
|
"baked_pumpkins <- recipe(color ~ ., data = pumpkins_select) %>%\n",
|
|
" # Define ordering for item_size column\n",
|
|
" step_mutate(item_size = ordered(item_size, levels = c('sml', 'med', 'med-lge', 'lge', 'xlge', 'jbo', 'exjbo'))) %>%\n",
|
|
" # Convert factors to numbers using the order defined above (Ordinal encoding)\n",
|
|
" step_integer(item_size, zero_based = F) %>%\n",
|
|
" # Encode all other predictors using one hot encoding\n",
|
|
" step_dummy(all_nominal(), -all_outcomes(), one_hot = TRUE) %>%\n",
|
|
" prep(data = pumpkin_select) %>%\n",
|
|
" bake(new_data = NULL)\n",
|
|
"\n",
|
|
"# Display the first few rows of preprocessed data\n",
|
|
"baked_pumpkins %>% \n",
|
|
" slice_head(n = 5)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"✅ Jaké jsou výhody použití ordinálního enkodéru pro sloupec Item Size?\n",
|
|
"\n",
|
|
"### Analýza vztahů mezi proměnnými\n",
|
|
"\n",
|
|
"Nyní, když jsme předzpracovali naše data, můžeme analyzovat vztahy mezi vlastnostmi a štítkem, abychom získali představu o tom, jak dobře bude model schopen předpovědět štítek na základě vlastností. Nejlepší způsob, jak provést tento typ analýzy, je vizualizace dat. \n",
|
|
"Opět použijeme funkci ggplot geom_boxplot_, abychom zobrazili vztahy mezi Item Size, Variety a Color v kategorickém grafu. Pro lepší vizualizaci dat použijeme zakódovaný sloupec Item Size a nekódovaný sloupec Variety.\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"vscode": {
|
|
"languageId": "r"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Define the color palette\n",
|
|
"palette <- c(ORANGE = \"orange\", WHITE = \"wheat\")\n",
|
|
"\n",
|
|
"# We need the encoded Item Size column to use it as the x-axis values in the plot\n",
|
|
"pumpkins_select_plot<-pumpkins_select\n",
|
|
"pumpkins_select_plot$item_size <- baked_pumpkins$item_size\n",
|
|
"\n",
|
|
"# Create the grouped box plot\n",
|
|
"ggplot(pumpkins_select_plot, aes(x = `item_size`, y = color, fill = color)) +\n",
|
|
" geom_boxplot() +\n",
|
|
" facet_grid(variety ~ ., scales = \"free_x\") +\n",
|
|
" scale_fill_manual(values = palette) +\n",
|
|
" labs(x = \"Item Size\", y = \"\") +\n",
|
|
" theme_minimal() +\n",
|
|
" theme(strip.text = element_text(size = 12)) +\n",
|
|
" theme(axis.text.x = element_text(size = 10)) +\n",
|
|
" theme(axis.title.x = element_text(size = 12)) +\n",
|
|
" theme(axis.title.y = element_blank()) +\n",
|
|
" theme(legend.position = \"bottom\") +\n",
|
|
" guides(fill = guide_legend(title = \"Color\")) +\n",
|
|
" theme(panel.spacing = unit(0.5, \"lines\"))+\n",
|
|
" theme(strip.text.y = element_text(size = 4, hjust = 0)) \n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"#### Použijte swarm plot\n",
|
|
"\n",
|
|
"Protože Barva je binární kategorie (Bílá nebo Ne), vyžaduje '[specializovaný přístup](https://github.com/rstudio/cheatsheets/blob/main/data-visualization.pdf) k vizualizaci'.\n",
|
|
"\n",
|
|
"Vyzkoušejte `swarm plot`, abyste zobrazili rozložení barvy vzhledem k velikosti položky.\n",
|
|
"\n",
|
|
"Použijeme balíček [ggbeeswarm](https://github.com/eclarke/ggbeeswarm), který poskytuje metody pro vytváření grafů ve stylu beeswarm pomocí ggplot2. Beeswarm grafy jsou způsobem vykreslování bodů, které by se za normálních okolností překrývaly, tak, aby se zobrazily vedle sebe.\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"vscode": {
|
|
"languageId": "r"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Create beeswarm plots of color and item_size\n",
|
|
"baked_pumpkins %>% \n",
|
|
" mutate(color = factor(color)) %>% \n",
|
|
" ggplot(mapping = aes(x = color, y = item_size, color = color)) +\n",
|
|
" geom_quasirandom() +\n",
|
|
" scale_color_brewer(palette = \"Dark2\", direction = -1) +\n",
|
|
" theme(legend.position = \"none\")\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Nyní, když máme představu o vztahu mezi binárními kategoriemi barvy a širší skupinou velikostí, pojďme prozkoumat logistickou regresi, abychom určili pravděpodobnou barvu dané dýně.\n",
|
|
"\n",
|
|
"## Vytvořte svůj model\n",
|
|
"\n",
|
|
"Vyberte proměnné, které chcete použít ve svém klasifikačním modelu, a rozdělte data na trénovací a testovací sady. [rsample](https://rsample.tidymodels.org/), balíček v Tidymodels, poskytuje infrastrukturu pro efektivní rozdělování dat a resampling:\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"vscode": {
|
|
"languageId": "r"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Split data into 80% for training and 20% for testing\n",
|
|
"set.seed(2056)\n",
|
|
"pumpkins_split <- pumpkins_select %>% \n",
|
|
" initial_split(prop = 0.8)\n",
|
|
"\n",
|
|
"# Extract the data in each split\n",
|
|
"pumpkins_train <- training(pumpkins_split)\n",
|
|
"pumpkins_test <- testing(pumpkins_split)\n",
|
|
"\n",
|
|
"# Print out the first 5 rows of the training set\n",
|
|
"pumpkins_train %>% \n",
|
|
" slice_head(n = 5)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"🙌 Nyní jsme připraveni trénovat model tím, že přizpůsobíme tréninkové vlastnosti tréninkovému štítku (barvě).\n",
|
|
"\n",
|
|
"Začneme vytvořením receptu, který určuje kroky předzpracování, jež by měly být provedeny na našich datech, aby byla připravena pro modelování, tj. kódování kategoriálních proměnných do sady celých čísel. Stejně jako `baked_pumpkins` vytvoříme `pumpkins_recipe`, ale nebudeme používat `prep` a `bake`, protože to bude zahrnuto do pracovního postupu, který uvidíte za pár kroků.\n",
|
|
"\n",
|
|
"Existuje poměrně mnoho způsobů, jak specifikovat logistický regresní model v Tidymodels. Podívejte se na `?logistic_reg()`. Prozatím specifikujeme logistický regresní model pomocí výchozího enginu `stats::glm()`.\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"vscode": {
|
|
"languageId": "r"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Create a recipe that specifies preprocessing steps for modelling\n",
|
|
"pumpkins_recipe <- recipe(color ~ ., data = pumpkins_train) %>% \n",
|
|
" step_mutate(item_size = ordered(item_size, levels = c('sml', 'med', 'med-lge', 'lge', 'xlge', 'jbo', 'exjbo'))) %>%\n",
|
|
" step_integer(item_size, zero_based = F) %>% \n",
|
|
" step_dummy(all_nominal(), -all_outcomes(), one_hot = TRUE)\n",
|
|
"\n",
|
|
"# Create a logistic model specification\n",
|
|
"log_reg <- logistic_reg() %>% \n",
|
|
" set_engine(\"glm\") %>% \n",
|
|
" set_mode(\"classification\")\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Nyní, když máme recept a specifikaci modelu, musíme najít způsob, jak je spojit dohromady do objektu, který nejprve předzpracuje data (příprava + pečení na pozadí), přizpůsobí model na předzpracovaná data a také umožní případné aktivity po zpracování.\n",
|
|
"\n",
|
|
"V Tidymodels se tento praktický objekt nazývá [`workflow`](https://workflows.tidymodels.org/) a pohodlně uchovává vaše modelovací komponenty.\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"vscode": {
|
|
"languageId": "r"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Bundle modelling components in a workflow\n",
|
|
"log_reg_wf <- workflow() %>% \n",
|
|
" add_recipe(pumpkins_recipe) %>% \n",
|
|
" add_model(log_reg)\n",
|
|
"\n",
|
|
"# Print out the workflow\n",
|
|
"log_reg_wf\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Poté, co byl pracovní postup *definován*, může být model `natrénován` pomocí funkce [`fit()`](https://tidymodels.github.io/parsnip/reference/fit.html). Pracovní postup odhadne recept a předzpracuje data před tréninkem, takže to nebudeme muset dělat ručně pomocí funkcí prep a bake.\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"vscode": {
|
|
"languageId": "r"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Train the model\n",
|
|
"wf_fit <- log_reg_wf %>% \n",
|
|
" fit(data = pumpkins_train)\n",
|
|
"\n",
|
|
"# Print the trained workflow\n",
|
|
"wf_fit\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Model vytiskne koeficienty naučené během trénování.\n",
|
|
"\n",
|
|
"Nyní, když jsme model natrénovali pomocí trénovacích dat, můžeme provádět predikce na testovacích datech pomocí [parsnip::predict()](https://parsnip.tidymodels.org/reference/predict.model_fit.html). Začněme tím, že použijeme model k predikci štítků pro naši testovací sadu a pravděpodobností pro každý štítek. Pokud je pravděpodobnost vyšší než 0,5, predikovaná třída je `WHITE`, jinak `ORANGE`.\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"vscode": {
|
|
"languageId": "r"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Make predictions for color and corresponding probabilities\n",
|
|
"results <- pumpkins_test %>% select(color) %>% \n",
|
|
" bind_cols(wf_fit %>% \n",
|
|
" predict(new_data = pumpkins_test)) %>%\n",
|
|
" bind_cols(wf_fit %>%\n",
|
|
" predict(new_data = pumpkins_test, type = \"prob\"))\n",
|
|
"\n",
|
|
"# Compare predictions\n",
|
|
"results %>% \n",
|
|
" slice_head(n = 10)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Velmi pěkné! To poskytuje další vhledy do toho, jak funguje logistická regrese.\n",
|
|
"\n",
|
|
"### Lepší pochopení pomocí matice záměn\n",
|
|
"\n",
|
|
"Porovnávání každé předpovědi s odpovídající „skutečnou hodnotou“ není příliš efektivní způsob, jak zjistit, jak dobře model předpovídá. Naštěstí má Tidymodels ještě několik triků v rukávu: [`yardstick`](https://yardstick.tidymodels.org/) - balíček používaný k měření účinnosti modelů pomocí metrik výkonu.\n",
|
|
"\n",
|
|
"Jednou z metrik výkonu spojených s klasifikačními problémy je [`matice záměn`](https://wikipedia.org/wiki/Confusion_matrix). Matice záměn popisuje, jak dobře klasifikační model funguje. Matice záměn zaznamenává, kolik příkladů v každé třídě bylo modelem správně klasifikováno. V našem případě vám ukáže, kolik oranžových dýní bylo klasifikováno jako oranžové a kolik bílých dýní bylo klasifikováno jako bílé; matice záměn také ukazuje, kolik bylo klasifikováno do **nesprávných** kategorií.\n",
|
|
"\n",
|
|
"Funkce [**`conf_mat()`**](https://tidymodels.github.io/yardstick/reference/conf_mat.html) z yardsticku vypočítává tuto křížovou tabulku pozorovaných a předpovězených tříd.\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"vscode": {
|
|
"languageId": "r"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Confusion matrix for prediction results\n",
|
|
"conf_mat(data = results, truth = color, estimate = .pred_class)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Pojďme si vysvětlit matici záměn. Náš model má za úkol klasifikovat dýně do dvou binárních kategorií, kategorie `bílá` a kategorie `ne-bílá`.\n",
|
|
"\n",
|
|
"- Pokud váš model předpoví dýni jako bílou a ve skutečnosti patří do kategorie 'bílá', nazýváme to `pravý pozitivní`, což je znázorněno číslem v levém horním rohu.\n",
|
|
"\n",
|
|
"- Pokud váš model předpoví dýni jako ne-bílou a ve skutečnosti patří do kategorie 'bílá', nazýváme to `falešný negativní`, což je znázorněno číslem v levém dolním rohu.\n",
|
|
"\n",
|
|
"- Pokud váš model předpoví dýni jako bílou a ve skutečnosti patří do kategorie 'ne-bílá', nazýváme to `falešný pozitivní`, což je znázorněno číslem v pravém horním rohu.\n",
|
|
"\n",
|
|
"- Pokud váš model předpoví dýni jako ne-bílou a ve skutečnosti patří do kategorie 'ne-bílá', nazýváme to `pravý negativní`, což je znázorněno číslem v pravém dolním rohu.\n",
|
|
"\n",
|
|
"| Skutečnost |\n",
|
|
"|:----------:|\n",
|
|
"\n",
|
|
"\n",
|
|
"| | | |\n",
|
|
"|---------------|--------|-------|\n",
|
|
"| **Předpověď** | BÍLÁ | ORANŽOVÁ |\n",
|
|
"| BÍLÁ | TP | FP |\n",
|
|
"| ORANŽOVÁ | FN | TN |\n",
|
|
"\n",
|
|
"Jak jste možná uhodli, je žádoucí mít větší počet pravých pozitivních a pravých negativních a nižší počet falešných pozitivních a falešných negativních, což naznačuje, že model funguje lépe.\n",
|
|
"\n",
|
|
"Matice záměn je užitečná, protože z ní vycházejí další metriky, které nám mohou pomoci lépe vyhodnotit výkon klasifikačního modelu. Pojďme si je projít:\n",
|
|
"\n",
|
|
"🎓 Přesnost: `TP/(TP + FP)` definovaná jako podíl předpovězených pozitivních, které jsou skutečně pozitivní. Také se nazývá [pozitivní prediktivní hodnota](https://en.wikipedia.org/wiki/Positive_predictive_value \"Positive predictive value\").\n",
|
|
"\n",
|
|
"🎓 Záchyt: `TP/(TP + FN)` definovaný jako podíl pozitivních výsledků z počtu vzorků, které byly skutečně pozitivní. Také známý jako `citlivost`.\n",
|
|
"\n",
|
|
"🎓 Specifičnost: `TN/(TN + FP)` definovaná jako podíl negativních výsledků z počtu vzorků, které byly skutečně negativní.\n",
|
|
"\n",
|
|
"🎓 Přesnost modelu: `TP + TN/(TP + TN + FP + FN)` Procento štítků správně předpovězených pro vzorek.\n",
|
|
"\n",
|
|
"🎓 F míra: Vážený průměr přesnosti a záchytu, přičemž nejlepší hodnota je 1 a nejhorší je 0.\n",
|
|
"\n",
|
|
"Pojďme tyto metriky vypočítat!\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"vscode": {
|
|
"languageId": "r"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Combine metric functions and calculate them all at once\n",
|
|
"eval_metrics <- metric_set(ppv, recall, spec, f_meas, accuracy)\n",
|
|
"eval_metrics(data = results, truth = color, estimate = .pred_class)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Vizualizace ROC křivky tohoto modelu\n",
|
|
"\n",
|
|
"Provedeme ještě jednu vizualizaci, abychom si prohlédli tzv. [`ROC křivku`](https://en.wikipedia.org/wiki/Receiver_operating_characteristic):\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"vscode": {
|
|
"languageId": "r"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Make a roc_curve\n",
|
|
"results %>% \n",
|
|
" roc_curve(color, .pred_ORANGE) %>% \n",
|
|
" autoplot()\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"ROC křivky se často používají k zobrazení výstupu klasifikátoru z hlediska jeho skutečných vs. falešných pozitivních výsledků. ROC křivky obvykle zobrazují `True Positive Rate`/citlivost na ose Y a `False Positive Rate`/1-specificitu na ose X. Strmost křivky a prostor mezi středovou čarou a křivkou jsou důležité: chcete křivku, která rychle stoupá a překračuje čáru. V našem případě se na začátku objevují falešné pozitivní výsledky, poté čára správně stoupá a překračuje.\n",
|
|
"\n",
|
|
"Nakonec použijme `yardstick::roc_auc()` k výpočtu skutečné plochy pod křivkou. Jedním ze způsobů interpretace AUC je pravděpodobnost, že model zařadí náhodný pozitivní příklad výše než náhodný negativní příklad.\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"vscode": {
|
|
"languageId": "r"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Calculate area under curve\n",
|
|
"results %>% \n",
|
|
" roc_auc(color, .pred_ORANGE)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Výsledek je přibližně `0.975`. Vzhledem k tomu, že AUC se pohybuje v rozmezí od 0 do 1, chcete dosáhnout vysokého skóre, protože model, který je ve svých předpovědích 100% přesný, bude mít AUC rovno 1; v tomto případě je model *docela dobrý*.\n",
|
|
"\n",
|
|
"V budoucích lekcích o klasifikacích se naučíte, jak zlepšit skóre svého modelu (například jak se vypořádat s nevyváženými daty v tomto případě).\n",
|
|
"\n",
|
|
"## 🚀Výzva\n",
|
|
"\n",
|
|
"Logistická regrese nabízí mnohem více k prozkoumání! Nejlepší způsob, jak se učit, je experimentovat. Najděte dataset, který se hodí k tomuto typu analýzy, a vytvořte s ním model. Co jste se naučili? tip: zkuste [Kaggle](https://www.kaggle.com/search?q=logistic+regression+datasets) pro zajímavé datové sady.\n",
|
|
"\n",
|
|
"## Přehled & Samostudium\n",
|
|
"\n",
|
|
"Přečtěte si první několik stran [tohoto článku ze Stanfordu](https://web.stanford.edu/~jurafsky/slp3/5.pdf) o praktickém využití logistické regrese. Zamyslete se nad úkoly, které jsou lépe vhodné pro jeden nebo druhý typ regresních úkolů, které jsme dosud studovali. Co by fungovalo nejlépe?\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"\n---\n\n**Prohlášení**: \nTento dokument byl přeložen pomocí služby pro automatický překlad [Co-op Translator](https://github.com/Azure/co-op-translator). I když se snažíme o přesnost, mějte prosím na paměti, že automatické překlady mohou obsahovat chyby nebo nepřesnosti. Původní dokument v jeho původním jazyce by měl být považován za autoritativní zdroj. Pro důležité informace se doporučuje profesionální lidský překlad. Neodpovídáme za žádná nedorozumění nebo nesprávné interpretace vyplývající z použití tohoto překladu.\n"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"anaconda-cloud": "",
|
|
"kernelspec": {
|
|
"display_name": "R",
|
|
"langauge": "R",
|
|
"name": "ir"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": "r",
|
|
"file_extension": ".r",
|
|
"mimetype": "text/x-r-source",
|
|
"name": "R",
|
|
"pygments_lexer": "r",
|
|
"version": "3.4.1"
|
|
},
|
|
"coopTranslator": {
|
|
"original_hash": "feaf125f481a89c468fa115bf2aed580",
|
|
"translation_date": "2025-09-04T06:48:07+00:00",
|
|
"source_file": "2-Regression/4-Logistic/solution/R/lesson_4-R.ipynb",
|
|
"language_code": "cs"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 1
|
|
} |