Untuk membuat sebuah model machine learning diperlukan beberapa proses seperti praproses data, modeling, sampai model evaluation. R memiliki banyak sekali packages yang dapat digunakan dalam proses pemodelan. Untuk mempermudah seorang data scientist dalam membuat model, Rstudio mengembangkan kumpulan packages yang diberi nama Tidymodels.
Tidymodels merupakan framework yang berisi kumpulan packages untuk membuat model machine learning menggunakan konsep Tidyverse. Beberapa packages seperti rsample
, recipes
, parsnip
, dials
, tune
, yardstick
, workflows
sudah saling terintegrasi untuk membuat satu proses pemodelan yang utuh.
Artikel ini berfokus membahas bagaimana cara menggunakan tidymodels
dalam proses pemodelan machine learning. Artikel ini mengasumsikan pembaca sudah mengetahui konsep dasar dari tidyverse
(piping) dan proses pembuatan model machine learning seperti praproses data, cross validation, modeling, model evaluation dll.
Jika anda belum menginstall packages tidymodels, silahkan install terlebih dahulu.
install.packages("tidymodels")
Data Exploration
library(tidyverse) # data wrangling
library(inspectdf) # data exploration
library(tidymodels) # modeling
library(themis)
Read Data
Data yang digunakan pada artikel ini adalah data Telco Customer Churn yang diperoleh dari kaggle. Pada kasus ini model yang dibuat untuk memprediksi apakah seorang pelanggan akan churn dari produk yang digunakan atau tidak (klasifikasi biner).
churn <- read_csv("data_input/tidymodels/watson-churn.csv")
head(churn)
#> # A tibble: 6 x 19
#> customerID gender SeniorCitizen Partner Dependents tenure MultipleLines
#> <chr> <chr> <dbl> <chr> <chr> <dbl> <chr>
#> 1 5575-GNVDE Male 0 No No 34 No
#> 2 3668-QPYBK Male 0 No No 2 No
#> 3 9237-HQITU Female 0 No No 2 No
#> 4 9305-CDSKC Female 0 No No 8 Yes
#> 5 1452-KIOVK Male 0 No Yes 22 Yes
#> 6 7892-POOKP Female 0 Yes No 28 Yes
#> # ... with 12 more variables: OnlineSecurity <chr>, OnlineBackup <chr>,
#> # DeviceProtection <chr>, TechSupport <chr>, StreamingTV <chr>,
#> # StreamingMovies <chr>, Contract <chr>, PaperlessBilling <chr>,
#> # PaymentMethod <chr>, MonthlyCharges <dbl>, TotalCharges <dbl>, Churn <chr>
Data Structure
Dilihat secara sekilas, data yang digunakan terdiri dari 19 kolom, dimana mayoritas kolom adalah kategorikal. Untuk melihat lebih detail struktur data bisa menggunakan fungsi glimpse()
.
glimpse(churn)
#> Rows: 4,835
#> Columns: 19
#> $ customerID <chr> "5575-GNVDE", "3668-QPYBK", "9237-HQITU", "9305-CD...
#> $ gender <chr> "Male", "Male", "Female", "Female", "Male", "Femal...
#> $ SeniorCitizen <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...
#> $ Partner <chr> "No", "No", "No", "No", "No", "Yes", "No", "Yes", ...
#> $ Dependents <chr> "No", "No", "No", "No", "Yes", "No", "Yes", "Yes",...
#> $ tenure <dbl> 34, 2, 2, 8, 22, 28, 62, 13, 58, 49, 25, 69, 71, 1...
#> $ MultipleLines <chr> "No", "No", "No", "Yes", "Yes", "Yes", "No", "No",...
#> $ OnlineSecurity <chr> "Yes", "Yes", "No", "No", "No", "No", "Yes", "Yes"...
#> $ OnlineBackup <chr> "No", "Yes", "No", "No", "Yes", "No", "Yes", "No",...
#> $ DeviceProtection <chr> "Yes", "No", "No", "Yes", "No", "Yes", "No", "No",...
#> $ TechSupport <chr> "No", "No", "No", "No", "No", "Yes", "No", "No", "...
#> $ StreamingTV <chr> "No", "No", "No", "Yes", "Yes", "Yes", "No", "No",...
#> $ StreamingMovies <chr> "No", "No", "No", "Yes", "No", "Yes", "No", "No", ...
#> $ Contract <chr> "One year", "Month-to-month", "Month-to-month", "M...
#> $ PaperlessBilling <chr> "No", "Yes", "Yes", "Yes", "Yes", "Yes", "No", "Ye...
#> $ PaymentMethod <chr> "Mailed check", "Mailed check", "Electronic check"...
#> $ MonthlyCharges <dbl> 56.95, 53.85, 70.70, 99.65, 89.10, 104.80, 56.15, ...
#> $ TotalCharges <dbl> 1889.50, 108.15, 151.65, 820.50, 1949.40, 3046.05,...
#> $ Churn <chr> "No", "Yes", "Yes", "Yes", "No", "Yes", "No", "No"...
Dari informasi diatas diketahui bahwa data terdiri dari 19 kolom dan 4835 observasi yang didominasi oleh kategorikal. Untuk melihat persebaran data yang bertipe kategorikal lebih jelas, kita dapat menggunakan fungsi inspect_cat()
dari package inspectdf
.
inspect_cat(churn)
#> # A tibble: 15 x 5
#> col_name cnt common common_pcnt levels
#> <chr> <int> <chr> <dbl> <named list>
#> 1 Churn 2 No 67.2 <tibble [2 x 3]>
#> 2 Contract 3 Month-to-month 61.5 <tibble [3 x 3]>
#> 3 customerID 4835 0002-ORFBO 0.0207 <tibble [4,835 x 3]>
#> 4 Dependents 2 No 73.9 <tibble [2 x 3]>
#> 5 DeviceProtection 2 No 56.2 <tibble [2 x 3]>
#> 6 gender 2 Male 50.2 <tibble [2 x 3]>
#> 7 MultipleLines 2 Yes 54.4 <tibble [2 x 3]>
#> 8 OnlineBackup 2 No 55.7 <tibble [2 x 3]>
#> 9 OnlineSecurity 2 No 64.1 <tibble [2 x 3]>
#> 10 PaperlessBilling 2 Yes 69.0 <tibble [2 x 3]>
#> 11 Partner 2 No 51.3 <tibble [2 x 3]>
#> 12 PaymentMethod 4 Electronic check 41.7 <tibble [4 x 3]>
#> 13 StreamingMovies 2 Yes 50.3 <tibble [2 x 3]>
#> 14 StreamingTV 2 Yes 50.1 <tibble [2 x 3]>
#> 15 TechSupport 2 No 63.7 <tibble [2 x 3]>
Dari 19 kolom yang ada 15 diantaranya merupakan kategorikal. semua data kategorikal akan diubah menjadi factor kecuali customerID
. customerID
merupakan ID dari pelanggan oleh yang bersifat unique sehingga tidak diubah menjadi factor. Persebaran data yang bertipe numerik dapat dilihat juga menggunakan fungsi inspect_num()
.
inspect_num(churn)
#> # A tibble: 4 x 10
#> col_name min q1 median mean q3 max sd pcnt_na hist
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <named l>
#> 1 SeniorCiti~ 0 0 0 0.204 0 1 4.03e-1 0 <tibble ~
#> 2 tenure 0 9 30 33.0 56 72 2.46e+1 0 <tibble ~
#> 3 MonthlyCha~ 42.9 69.8 82.5 81.8 95.7 119. 1.83e+1 0 <tibble ~
#> 4 TotalCharg~ 42.9 660. 2351. 2902. 4871. 8685. 2.42e+3 0.0620 <tibble ~
Kolom SeniorCitizen
memiliki sebaran yang perlu dicurigai karena nilai min dan q3 nya sama yaitu 0 dan nilai maksimumnya adalah 1, ada kemungkinan kolom ini seharusnya factor yang berisi 0 dan 1 saja.
churn %>%
count(SeniorCitizen, name = "freq")
#> # A tibble: 2 x 2
#> SeniorCitizen freq
#> <dbl> <int>
#> 1 0 3849
#> 2 1 986
Setelah dilihat lebih detail ternyata benar bahwa nilai pada kolom SeniorCitizen
hanya berisi 2 nilai yaitu 1 (Yes) dan 0 (No), oleh sebab itu kolom ini akan di ubah menjadi factor.
Missing Values
Pengecekan missing value dapat dilakukan untuk setiap kolom dengan menggabungkan fungsi is.na()
dan colsum()
.
churn %>%
is.na() %>%
colSums()
#> customerID gender SeniorCitizen Partner
#> 0 0 0 0
#> Dependents tenure MultipleLines OnlineSecurity
#> 0 0 0 0
#> OnlineBackup DeviceProtection TechSupport StreamingTV
#> 0 0 0 0
#> StreamingMovies Contract PaperlessBilling PaymentMethod
#> 0 0 0 0
#> MonthlyCharges TotalCharges Churn
#> 0 3 0
Dilihat dari hasil diatas terdapat 3 missing value pada kolom TotalCharges
. Proses imputasi akan dilakukan untuk mengisi missing value.
Class Proportion
Sebelum melakukan pemodelan penting untuk melihat seberapa seimbang proporsi kelas target. Data yang tidak seimbang akan membuat model lebih banyak belajar dari kelas mayoritas, sehingga model tidak mengenali ciri dari kelas minoritas.
churn %>%
pull(Churn) %>%
table() %>%
prop.table() %>%
round(3)
#> .
#> No Yes
#> 0.672 0.328
Kelas target didominasi oleh kelas No sebanyak 67.2%, pada kondisi seperti ini data tidak seimbang namun tidak terlalu ekstream. Teknik resampling seperti downsampling atau upsampling dapat digunakan untuk menyeimbangkan kelas.
Data Preprocessing
Data Preprocessing atau sering disebut sebagai data cleansing merupakan tahapan yang cukup penting dilakukan sebelum melakukan pemodelan. Pada tahap ini data akan dibersihkan berdasarkan informasi yang didapat pada tahap data exploration.
Cross Validation (Training and Testing)
Sebelum melakukan data cleansing, data perlu dibagi menjadi 2 bagian yaitu training data dan testing data. Training data digunakan dalam proses pemodelan, sedangkan testing data digunakan dalam proses pengujian model.
Kenapa splitting data dilakukan sebelum data cleansing? Kenapa tidak setelahnya? Jawabannya adalah karena kita ingin menjaga originalitas dari testing data yang diasumsikan sebagai data “apa adanya”, karena data yang akan datang (unseen data) tidak dapat ketahui tingkat kebersihannya.
cross validation dapat dilakukan dengan fungsi initial_split()
dari package rsample
yang merupakan bagian dari tidymodels. Fungsi tersebut memiliki 3 paramater yaitu :
* data
: Data yang digunakan
* prop
: Proporsi data training
* Strata
: Kolom target
setelah melakukan cross validation, training data dan testing data dapat dipisahkan dengan fungsi training()
dan testing()
.
set.seed(123)
churn_split <- initial_split(data = churn,prop = 0.8, strata = Churn)
churn_train <- training(churn_split)
churn_test <- testing(churn_split)
Data Preprocessing
Informasi yang didapatkan pada data exploration akan digunakan pada tahap ini, beberapa praproses yang dibutuhkan sebelum melakukan pemodelan yaitu:
- Mengubah role
customerID
menjadi Id - Mengubah tipe data chr menjadi factor
- Mengubah tipe data
SeniorCitizen
menjadi factor - Melakukan imputation terhadapap missing value
- melakukan resampling
Proses cleansing data sebenarnya bisa dilakukan dengan data wrangling sederhana menggunakan tidyverse, namun cara tersebut tidak scalable sehingga sulit untuk diterapkan pada unseen data nantinya. Permasalahan tersebut bisa diselesaikan dengan package recipes
.
Recipes
bertugas membuat sebuah blueprint praproses data. Proses pembuatan blueprint menggunakan recipes
diawali dengan function recipes()
yang didalamnya terdapat parameter formula
dan data
, setelah itu dilanjutkan dengan tahapan praproses yang diinginkan. Fungsi praproses yang dilakukan selalu diawali dengan step_
.
churn_rec <- recipe(formula = Churn~., data = churn_train) %>%
update_role(customerID, new_role = "ID") %>%
step_string2factor(all_nominal(), -customerID, skip = TRUE) %>%
step_num2factor(SeniorCitizen, transform = function(x) x +1, levels = c("No", "Yes")) %>%
step_medianimpute(TotalCharges) %>%
step_upsample(Churn,over_ratio = 4/5)
Rincian dari praproses diatas adalah sebagai berikut:
- Pada data
churn_train
kolomChurn
merupakan target dan sisanya menjadi prediktor. - Mengubah role
customerID
menjadiID
, sehingga kolom tersebut tidak digunakan dalam proses pemodelan, namun tidak dibuang dari data. - Mengubah semua data yang bertipe kategorikal (
all_nominal()
) kecualiCustomerID
yang semula bertipe data string menjadi factor. - Mengubah
SeniorCitizen
yang semula numerik menjadi factor. Bila nilainya 0 ubah menjadi No bila 1 menjadi Yes. - Mengisi missing value pada
TotalCharge
dengan nilai median. - Melakukan upsampling sehingga perbandingan kelas Yes dan No menjadi 4 : 5
- Setelah semua proses itu dilakukan simpan blueprint kedalam objek
churn_rec
.
Hasil dari proses diatas hanya berupa design metrics saja bukan sebuah dataframe yang sudah dilakukan praproses. Objek recipe yang sudah dibuat bisa digunakan langsung dengan dengan menggunakan fungsi prep()
lalu diikuti oleh fungsi juice()
untuk data train dan bake()
untuk data test.
# praproses pada data train
churn_rec %>%
prep() %>%
juice() %>%
head()
#> # A tibble: 6 x 19
#> customerID gender SeniorCitizen Partner Dependents tenure MultipleLines
#> <fct> <fct> <fct> <fct> <fct> <dbl> <fct>
#> 1 8160-HOWOX Female No No No 7 No
#> 2 0679-TDGAK Male No Yes Yes 50 No
#> 3 9620-ENEJV Female No No No 6 No
#> 4 5575-TPIZQ Male No No No 46 Yes
#> 5 7105-MXJLL Female Yes Yes No 26 No
#> 6 5343-SGUBI Female No No No 52 Yes
#> # ... with 12 more variables: OnlineSecurity <fct>, OnlineBackup <fct>,
#> # DeviceProtection <fct>, TechSupport <fct>, StreamingTV <fct>,
#> # StreamingMovies <fct>, Contract <fct>, PaperlessBilling <fct>,
#> # PaymentMethod <fct>, MonthlyCharges <dbl>, TotalCharges <dbl>, Churn <fct>
# praproses pada data test
churn_rec %>%
prep() %>%
bake(churn_test) %>%
head()
#> # A tibble: 6 x 19
#> customerID gender SeniorCitizen Partner Dependents tenure MultipleLines
#> <fct> <fct> <fct> <fct> <fct> <dbl> <fct>
#> 1 <NA> Female No No No 2 No
#> 2 <NA> Male No No Yes 22 Yes
#> 3 <NA> Female No Yes Yes 10 No
#> 4 <NA> Female No No No 46 No
#> 5 <NA> Female No No No 11 Yes
#> 6 <NA> Male No No No 2 No
#> # ... with 12 more variables: OnlineSecurity <fct>, OnlineBackup <fct>,
#> # DeviceProtection <fct>, TechSupport <fct>, StreamingTV <fct>,
#> # StreamingMovies <fct>, Contract <fct>, PaperlessBilling <fct>,
#> # PaymentMethod <fct>, MonthlyCharges <dbl>, TotalCharges <dbl>, Churn <fct>
Fungsi juice
dan bake
akan mengembalikan data yang sudah dilakukan praproses, namun pada artikel ini kedua fungsi tersebut tidak akan digunakan. Pada artikel ini object recipes yang sudah dibuat akan digunakan langsung pada saat proses pemodelan.
Modeling
Pada umumnya ketika melakukan pemodelan di R cukup dengan satu fungsi saja seperti code dibawah ini adalah pemodelan menggunakan metode Random Forest.
randomForest::randomForest(x = .., y = .., ntree = ..., mtry = ...)
atau apabila anda sudah terbiasa dengan packages caret
bisa ditambahkan K-Fold dan tunning seperti code dibawah ini
ctrl <- trainControl(method="repeatedcv", number=5, repeats=2)
fb_forest <- train(formula = y ~ ., data=..., method="rf", trControl = ctrl)
Dua cara diatas merupakan cara yang cukup mudah untuk digunakan, namun pada proses penggunaanya model tersebut tidak scalable. Tidymodels membuat proses pemodelan menjadi modular sehingga setiap proses dibuat secara terpisah.
Model Interface
Tahap awal dalam pemodelan yaitu membuat model interface menggunakan package parsnip
dari tidymodels
. Ada banyak function dengan berbagai macam parameter yang biasa digunakan untuk membuat model, parsnip
melakukan standarisasi interface dan output dari setiap function.
Terdapat 3 hal yang harus didefinisikan ketika membuat model interface yaitu:
– Metode yang digunakan
– Engine atau package yang digunakan
– Jenis pemodelan
Pada pemodelan kali ini metode yang digunakan adalah Random Forest yang berasal dari pacakges ranger
model_rf <- rand_forest(mtry = tune(),
trees = tune(),
min_n =tune()) %>%
set_engine("ranger") %>%
set_mode("classification")
model_rf
#> Random Forest Model Specification (classification)
#>
#> Main Arguments:
#> mtry = tune()
#> trees = tune()
#> min_n = tune()
#>
#> Computational engine: ranger
Fungsi rand_forest()
digunakan untuk membuat model Random Forest, didalamnya terdapat 3 parameter yaitu :
– mtry
: Jumlah prediktor yang akan dipilih secara acak pada setiap kali tree ingin membuat cabang.
– trees
: Jumlah tree yang digunakan pada Random Forest.
– min_n
: Jumlah minimum data yang dibutuhkan pada setiap node untuk membuat percabangan.
Ketiga parameter tersebut akan ditunning untuk mencari nilai yang paling optimal menggunakan grid search.
set_engine()
digunakan untuk memberitahu mesin package apa yang digunakan dalam pembuatan model. Pada kasus ini model Random Forest akan dibuat menggunakan package ranger
. Terakhir kita harus memberitahu jenis pemodelan yang akan dilakukan menggunakan fungsi set_mode()
. Terdapat 2 jenis mode yaitu “classification” dan “regression”.
Grid Search
Salah satu bagian paling “menyenangkan” dari pembuatan model machine learning adalah tunning parameter. Tujuan dari tunning parameter adalah mendapatkan nilai optimum parameter yang dapat menghasilkan model terbaik. dials
merupakan package yang digunakan untuk tunning parameter.
Terdapat beberapa jenis grid yang bisa digunakan untuk tunning parameter, pada artikel ini grid yang digunakan adalah grid_max_entropy()
. Grid tersebut akan digunakan untuk tunning 3 parameter yang ada pada model_rf
yaitu mtry
, tree
, dan min_n
.
set.seed(123)
model_grid <- grid_max_entropy(x=finalize(object = mtry(),x = churn_train[,-19]),
trees(),
min_n(),
size = 20)
model_grid %>%
summary()
#> mtry trees min_n
#> Min. : 1.00 Min. : 7.0 Min. : 2.00
#> 1st Qu.: 5.75 1st Qu.: 261.8 1st Qu.: 9.50
#> Median :11.50 Median :1104.0 Median :18.00
#> Mean : 9.90 Mean : 959.9 Mean :18.60
#> 3rd Qu.:15.25 3rd Qu.:1467.5 3rd Qu.:23.75
#> Max. :17.00 Max. :1979.0 Max. :40.00
Parameter yang ada dalam fungsi grid_max_entropy()
adalah parameter yang ingin ditunning pada model. Parameter size
menunjukkan seberapa banyak kombinasi yang ingin dibuat.
parameter trees
dan min_n
dapat digunakan secara langsung karena range nilainya tidak dipengaruhi oleh dimensi data. Parameter mtry
tidak dapat digunakan secara langsung karena jumalah mtry dipengaruhi oleh banyaknya kolom, oleh sebab itu perlu disesuaikan dengan fungsi finalize()
. Hasil dari pembuatan grid adalah sebuah dataframe yang berisi nilai yang akan digunakan pada proses tunning parameter.
Metrics Evaluation
Sebelum melakukan fitting model, penentuan metrics evaluasi perlu dilakukan, pada kasus ini metrics yang dilihat adalah specificity (recall untuk kelas Yes) dan AUC dari ROC. Metrics tersebut dapat dibuat dengan fungsi metric_set
dari package yardstick
.
model_metrics <- metric_set(roc_auc, specificity)
K-Fold Cross Validation (Training and validation)
K-Fold merupakan bentuk berkembangan dari cross validation. K-Fold akan membagi training data menjadi K bagian (fold) dimana setiap fold akan dijadikan train dan test secara bergantian. Tujuan dari proses ini adalah untuk mencari parameter yang menghasilkan model terbaik. Untuk melakukan K-Fold bisa menggunakan fungsi vfold()
.
churn_folds <- vfold_cv(data = churn_train, v = 5)
Tunning Parameters
Sejauh ini sudah ada beberapa bagian yang berhasil dibuat yaitu :
– Praproses menggunakan recipes
– Model interface menggunakan parsnip
– Grid search menggunakan dials
– Metrics evaluation menggunakan yardstick
– K-Fold menggunakan rsamples
Pada tahap ini semua bagian yang sudah dibuat akan digabungkan untuk mencari parameter terbaik, proses ini disebut sebagai tunning parameters. function yang digunakan yaitu tune_grid()
dari tune
.
Untuk mempercepat proses komputasi bisa menggunakan pemprosesan pararel dari package doParallel
.
doParallel::registerDoParallel()
set.seed(123)
rf_result <- tune_grid(object = model_rf,
preprocessor = churn_rec,
resamples = churn_folds,
grid = model_grid,
metrics = model_metrics)
untuk melihat kombinasi parameter terbaik yang dihasilkan berdasarkan metrics yang dipilih bisa menggunakan fungsi collect_metrics()
rf_result %>%
collect_metrics() %>%
group_by(.metric) %>%
slice_max(mean,n = 2)
#> # A tibble: 4 x 9
#> # Groups: .metric [2]
#> mtry trees min_n .metric .estimator mean n std_err .config
#> <int> <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
#> 1 2 1040 17 roc_auc binary 0.812 5 0.00846 Preprocessor1_Model06
#> 2 1 1836 8 roc_auc binary 0.809 5 0.00847 Preprocessor1_Model04
#> 3 1 487 4 spec binary 0.677 5 0.0257 Preprocessor1_Model02
#> 4 1 283 21 spec binary 0.676 5 0.0241 Preprocessor1_Model05
Nilai Recall tertinggi dihasilkan oleh Model19
yaitu 0.79 sedangkan nilai AUC tertinggi sebesar 0.81 yang dihasilkan oleh Model06
.
Finalization
Tahap terakhir dari pemodelan yaitu menggabungkan semua komponen menjadi satu kerangka yang utuh menggunakan package workflows
. Package workflow
menggabungkan pre-processing, modeling, dan post-processing.
churn_wf <- workflow() %>%
add_model(model_rf) %>%
add_recipe(churn_rec) %>%
finalize_workflow(rf_result %>%
show_best("roc_auc", 1)
)
churn_wf
#> == Workflow ====================================================================
#> Preprocessor: Recipe
#> Model: rand_forest()
#>
#> -- Preprocessor ----------------------------------------------------------------
#> 4 Recipe Steps
#>
#> * step_string2factor()
#> * step_num2factor()
#> * step_medianimpute()
#> * step_upsample()
#>
#> -- Model -----------------------------------------------------------------------
#> Random Forest Model Specification (classification)
#>
#> Main Arguments:
#> mtry = 2
#> trees = 1040
#> min_n = 17
#>
#> Computational engine: ranger
Fungsi add_model()
menambahkan model interface kedalam workflow, fungsi add_recipe()
menambahkan objek recipe kedalam workflow, dan fungsi finalize_workflow()
menambahkan parameters terbaik berdasarkan metrics yang dipilih dari hasil tunning parameters sebelumnya. Setelah semuanya digabungkan model terakhir dapat dibentuk menggunakan fungsi fit()
.
set.seed(123)
final_model_rf <- fit(object = churn_wf, data = churn_train)
Model Evaluation
Setelah membuat final model tahap terakhir adalah menguji performa model terhadap data test dengan melakukan prediksi. Ketika melakukan prediksi menggunakan fungsi predict()
terdapat 2 jenis type yang digunakan yaitu “prob” yang mengembalikan berupa peluang, sedangkan “class” menghasilkan kelas prediksi.
pred_prob <- predict(final_model_rf, churn_test, type = "prob")
pred_class <- predict(final_model_rf, churn_test, type = "class")
churn_results <- churn_test %>%
transmute(truth = as.factor(Churn)) %>%
bind_cols(pred_prob, pred_class)
churn_results %>%
head()
#> # A tibble: 6 x 4
#> truth .pred_No .pred_Yes .pred_class
#> <fct> <dbl> <dbl> <fct>
#> 1 Yes 0.228 0.772 Yes
#> 2 No 0.433 0.567 Yes
#> 3 Yes 0.662 0.338 No
#> 4 No 0.686 0.314 No
#> 5 Yes 0.212 0.788 Yes
#> 6 No 0.374 0.626 Yes
Confusion Matrix
Hasil prediksi dapat di evaluasi menggunakan confusion matrix untuk mengetahui seberapa banyak data yang berhasil atau gagal diprediksi.
churn_results %>%
conf_mat(truth, .pred_class) %>%
autoplot(type = "heatmap")
Specificity
Specificity yang dihasilkan sebesar 0.65 angka ini menunjukkan seberapa baik model dalam memprediksi kelas Yes
bila dibandingkan dengan data aktual.
churn_results %>%
specificity(truth, .pred_class)
#> # A tibble: 1 x 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 spec binary 0.653
ROC Curve
Sedangkan untuk melihat seberapa baik model dalam membedakan kedua kelas dapat dilihat dari nilai ROC AUC nya. Nilai AUC dari ROC curve didapat sebesar 0.79.
churn_results %>%
roc_auc(truth, .pred_Yes, event_level = 'second')
#> # A tibble: 1 x 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 roc_auc binary 0.799
Closing
Tidymodels merupakan kumpulan packages yang digunakan dalam pembuatan model machine lerning. Beberapa kelebihan menggunakan tidymodels adalah:
– Fleksibilitas dalam mengubah setiap komponen karena setiap proses dibuat secara modular.
– Dapat menggabungkan praproses data hingga model menjadi satu objek workflows
.
– Model interface yang dibuat bisa digunakan pada kasus yang berbeda sehingga code yang dibuat menjadi reusable.
Full code dari artikel ini bisa ditemukan pada (repositori github saya)[]