14  Kaggle

14.1 Lernsteuerung

14.1.1 Lernziele

  • Sie wissen, wie man einen Datensatz (eine “Submission”) für einen Prognosewettbwerb bei Kaggle einreicht
  • Sie kennen einige Beispiele von Notebooks auf Kaggle (für die Sprache R)
  • Sie wissen, wie man ein Workflow-Set in Tidymodels berechnet
  • Sie wissen, dass Tidymodels im Rezept keine Transformationen im Test-Sample berücksichtigt und wie man damit umgeht

14.1.2 Hinweise

  • Machen Sie sich mit Kaggle vertraut. Als Übungs-Wettbewerb dient uns TMDB Box-office Revenue (s. Aufgaben)

14.1.3 R-Pakete

In diesem Kapitel werden folgende R-Pakete benötigt:

library(tidyverse)
library(tidymodels)
library(tictoc)  # Rechenzeit messen
library(lubridate)  # Datumsangaben
library(VIM)  # fehlende Werte
library(visdat)  # Datensatz visualisieren

14.2 Was ist Kaggle?

Kaggle, a subsidiary of Google LLC, is an online community of data scientists and machine learning practitioners. Kaggle allows users to find and publish data sets, explore and build models in a web-based data-science environment, work with other data scientists and machine learning engineers, and enter competitions to solve data science challenges.

Quelle

Kaggle as AirBnB for Data Scientists?!

14.3 Fallstudie TMDB

Wir bearbeiten hier die Fallstudie TMDB Box Office Prediction - Can you predict a movie’s worldwide box office revenue?, ein Kaggle-Prognosewettbewerb.

Ziel ist es, genaue Vorhersagen zu machen, in diesem Fall für Filme.

14.3.1 Aufgabe

Reichen Sie bei Kaggle eine Submission für die Fallstudie ein! Berichten Sie den Score!

14.3.2 Hinweise

  • Sie müssen sich bei Kaggle ein Konto anlegen (kostenlos und anonym möglich); alternativ können Sie sich mit einem Google-Konto anmelden.
  • Halten Sie das Modell so einfach wie möglich. Verwenden Sie als Algorithmus die lineare Regression ohne weitere Schnörkel.
  • Logarithmieren Sie budget und revenue.
  • Minimieren Sie die Vorverarbeitung (steps) so weit als möglich.
  • Verwenden Sie tidymodels.
  • Die Zielgröße ist revenue in Dollars; nicht in “Log-Dollars”. Sie müssen also rücktransformieren, wenn Sie revenue logarithmiert haben.

14.3.3 Daten

Die Daten können Sie von der Kaggle-Projektseite beziehen oder so:

d_train_path <- "https://raw.githubusercontent.com/sebastiansauer/Lehre/main/data/tmdb-box-office-prediction/train.csv"
d_test_path <- "https://raw.githubusercontent.com/sebastiansauer/Lehre/main/data/tmdb-box-office-prediction/test.csv"

Wir importieren die Daten von der Online-Quelle:

d_train_raw <- read_csv(d_train_path)
d_test <- read_csv(d_test_path)

Mal einen Blick werfen:

glimpse(d_train_raw)
## Rows: 3,000
## Columns: 23
## $ id                    <dbl> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 1…
## $ belongs_to_collection <chr> "[{'id': 313576, 'name': 'Hot Tub Time Machine C…
## $ budget                <dbl> 1.40e+07, 4.00e+07, 3.30e+06, 1.20e+06, 0.00e+00…
## $ genres                <chr> "[{'id': 35, 'name': 'Comedy'}]", "[{'id': 35, '…
## $ homepage              <chr> NA, NA, "http://sonyclassics.com/whiplash/", "ht…
## $ imdb_id               <chr> "tt2637294", "tt0368933", "tt2582802", "tt182148…
## $ original_language     <chr> "en", "en", "en", "hi", "ko", "en", "en", "en", …
## $ original_title        <chr> "Hot Tub Time Machine 2", "The Princess Diaries …
## $ overview              <chr> "When Lou, who has become the \"father of the In…
## $ popularity            <dbl> 6.575393, 8.248895, 64.299990, 3.174936, 1.14807…
## $ poster_path           <chr> "/tQtWuwvMf0hCc2QR2tkolwl7c3c.jpg", "/w9Z7A0GHEh…
## $ production_companies  <chr> "[{'name': 'Paramount Pictures', 'id': 4}, {'nam…
## $ production_countries  <chr> "[{'iso_3166_1': 'US', 'name': 'United States of…
## $ release_date          <chr> "2/20/15", "8/6/04", "10/10/14", "3/9/12", "2/5/…
## $ runtime               <dbl> 93, 113, 105, 122, 118, 83, 92, 84, 100, 91, 119…
## $ spoken_languages      <chr> "[{'iso_639_1': 'en', 'name': 'English'}]", "[{'…
## $ status                <chr> "Released", "Released", "Released", "Released", …
## $ tagline               <chr> "The Laws of Space and Time are About to be Viol…
## $ title                 <chr> "Hot Tub Time Machine 2", "The Princess Diaries …
## $ Keywords              <chr> "[{'id': 4379, 'name': 'time travel'}, {'id': 96…
## $ cast                  <chr> "[{'cast_id': 4, 'character': 'Lou', 'credit_id'…
## $ crew                  <chr> "[{'credit_id': '59ac067c92514107af02c8c8', 'dep…
## $ revenue               <dbl> 12314651, 95149435, 13092000, 16000000, 3923970,…
glimpse(d_test)
## Rows: 4,398
## Columns: 22
## $ id                    <dbl> 3001, 3002, 3003, 3004, 3005, 3006, 3007, 3008, …
## $ belongs_to_collection <chr> "[{'id': 34055, 'name': 'Pokémon Collection', 'p…
## $ budget                <dbl> 0.00e+00, 8.80e+04, 0.00e+00, 6.80e+06, 2.00e+06…
## $ genres                <chr> "[{'id': 12, 'name': 'Adventure'}, {'id': 16, 'n…
## $ homepage              <chr> "http://www.pokemon.com/us/movies/movie-pokemon-…
## $ imdb_id               <chr> "tt1226251", "tt0051380", "tt0118556", "tt125595…
## $ original_language     <chr> "ja", "en", "en", "fr", "en", "en", "de", "en", …
## $ original_title        <chr> "ディアルガVSパルキアVSダークライ", "Attack of t…
## $ overview              <chr> "Ash and friends (this time accompanied by newco…
## $ popularity            <dbl> 3.851534, 3.559789, 8.085194, 8.596012, 3.217680…
## $ poster_path           <chr> "/tnftmLMemPLduW6MRyZE0ZUD19z.jpg", "/9MgBNBqlH1…
## $ production_companies  <chr> NA, "[{'name': 'Woolner Brothers Pictures Inc.',…
## $ production_countries  <chr> "[{'iso_3166_1': 'JP', 'name': 'Japan'}, {'iso_3…
## $ release_date          <chr> "7/14/07", "5/19/58", "5/23/97", "9/4/10", "2/11…
## $ runtime               <dbl> 90, 65, 100, 130, 92, 121, 119, 77, 120, 92, 88,…
## $ spoken_languages      <chr> "[{'iso_639_1': 'en', 'name': 'English'}, {'iso_…
## $ status                <chr> "Released", "Released", "Released", "Released", …
## $ tagline               <chr> "Somewhere Between Time & Space... A Legend Is B…
## $ title                 <chr> "Pokémon: The Rise of Darkrai", "Attack of the 5…
## $ Keywords              <chr> "[{'id': 11451, 'name': 'pok√©mon'}, {'id': 1155…
## $ cast                  <chr> "[{'cast_id': 3, 'character': 'Tonio', 'credit_i…
## $ crew                  <chr> "[{'credit_id': '52fe44e7c3a368484e03d683', 'dep…

14.3.4 Train-Set verschlanken

Da wir aus Gründen der Einfachheit einige Spalten nicht berücksichtigen, entfernen wir diese Spalten, was die Größe des Datensatzes massiv reduziert.

d_train <-
  d_train_raw %>% 
  select(popularity, runtime, revenue, budget, release_date) 

14.3.5 Datensatz kennenlernen

14.3.6 Fehlende Werte prüfen

Welche Spalten haben viele fehlende Werte?

vis_miss(d_train)

Mit VIM kann man einen Datensatz gut auf fehlende Werte hin untersuchen:

aggr(d_train)

14.4 Rezept

14.4.1 Rezept definieren

rec1 <-
  recipe(revenue ~ ., data = d_train) %>% 
  #update_role(all_predictors(), new_role = "id") %>% 
  #update_role(popularity, runtime, revenue, budget, original_language) %>% 
  #update_role(revenue, new_role = "outcome") %>% 
  step_mutate(budget = if_else(budget < 10, 10, budget)) %>% 
  step_log(budget) %>% 
  step_mutate(release_date = mdy(release_date)) %>% 
  step_date(release_date, features = c("year", "month"), 
keep_original_cols = FALSE) %>% 
  step_impute_knn(all_predictors()) %>% 
  step_dummy(all_nominal())

rec1
tidy(rec1)

14.4.2 Check das Rezept

prep(rec1, verbose = TRUE)
## oper 1 step mutate [training] 
## oper 2 step log [training] 
## oper 3 step mutate [training] 
## oper 4 step date [training] 
## oper 5 step impute knn [training] 
## oper 6 step dummy [training] 
## The retained training set is ~ 0.37 Mb  in memory.
prep(rec1) %>% 
  bake(new_data = NULL) 

Wir definieren eine Helper-Funktion:

sum_isna <- function(x) {sum(is.na(x))}

Und wenden diese auf jede Spalte an:

prep(rec1) %>% 
  bake(new_data = NULL) %>%  
  map_df(sum_isna)

Keine fehlenden Werte mehr in den Prädiktoren.

Nach fehlenden Werten könnte man z.B. auch so suchen:

datawizard::describe_distribution(d_train)

So bekommt man gleich noch ein paar Infos über die Verteilung der Variablen. Praktische Sache.

14.4.3 Check Test-Sample

Das Test-Sample backen wir auch mal.

Wichtig: Wir preppen den Datensatz mit dem Train-Sample.

bake(prep(rec1), new_data = d_test) %>% 
  head()

14.5 Kreuzvalidierung

cv_scheme <- vfold_cv(d_train,
  v = 5, 
  repeats = 3)

14.6 Modelle

14.6.1 Baum

mod_tree <-
  decision_tree(cost_complexity = tune(),
tree_depth = tune(),
mode = "regression")

14.6.2 Random Forest

doParallel::registerDoParallel()
mod_rf <-
  rand_forest(mtry = tune(),
  min_n = tune(),
  trees = 1000,
  mode = "regression") %>% 
  set_engine("ranger", num.threads = 4)

14.6.3 XGBoost

mod_boost <- boost_tree(mtry = tune(),
min_n = tune(),
trees = tune()) %>% 
  set_engine("xgboost", nthreads = parallel::detectCores()) %>% 
  set_mode("regression")

14.6.4 LM

mod_lm <-
  linear_reg()

14.7 Workflows

preproc <- list(rec1 = rec1)
models <- list(tree1 = mod_tree, rf1 = mod_rf, boost1 = mod_boost, lm1 = mod_lm)
 
 
all_workflows <- workflow_set(preproc, models)

14.8 Fitten und tunen

if (file.exists("objects/tmdb_model_set.rds")) {
  tmdb_model_set <- read_rds("objects/tmdb_model_set.rds")
} else {
  tic()
  tmdb_model_set <-
all_workflows %>% 
workflow_map(
  resamples = cv_scheme,
  grid = 10,
#  metrics = metric_set(rmse),
  seed = 42,  # reproducibility
  verbose = TRUE)
  toc()
}

Man könnte sich das Ergebnisobjekt abspeichern, um künftig Rechenzeit zu sparen:

write_rds(tmdb_model_set, "objects/tmdb_model_set.rds")

Aber Achtung: Wenn Sie vergessen, das Objekt auf der Festplatte zu aktualisieren, haben Sie eine zusätzliche Fehlerquelle. Gefahr im Verzug. Professioneller ist der Ansatz mit dem R-Paket target.

14.9 Finalisieren

14.9.1 Welcher Algorithmus schneidet am besten ab?

Genauer geagt, welches Modell, denn es ist ja nicht nur ein Algorithmus, sondern ein Algorithmus plus ein Rezept plus die Parameterinstatiierung plus ein spezifischer Datensatz.

tune::autoplot(tmdb_model_set) +
  theme(legend.position = "bottom")

R-Quadrat ist nicht entscheidend; rmse ist wichtiger.

Die Ergebnislage ist nicht ganz klar, aber einiges spricht für das Boosting-Modell, rec1_boost1.

tmdb_model_set %>% 
  collect_metrics() %>% 
  arrange(-mean) %>% 
  head(10)
best_model_params <-
extract_workflow_set_result(tmdb_model_set, "rec1_boost1") %>% 
  select_best()

best_model_params
best_wf <- 
all_workflows %>% 
  extract_workflow("rec1_boost1")

#best_wf
best_wf_finalized <- 
  best_wf %>% 
  finalize_workflow(best_model_params)

best_wf_finalized
## ══ Workflow ════════════════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: boost_tree()
## 
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 6 Recipe Steps
## 
## • step_mutate()
## • step_log()
## • step_mutate()
## • step_date()
## • step_impute_knn()
## • step_dummy()
## 
## ── Model ───────────────────────────────────────────────────────────────────────
## Boosted Tree Model Specification (regression)
## 
## Main Arguments:
##   mtry = 6
##   trees = 100
##   min_n = 4
## 
## Engine-Specific Arguments:
##   nthreads = parallel::detectCores()
## 
## Computational engine: xgboost

14.9.2 Final Fit

fit_final <-
  best_wf_finalized %>% 
  fit(d_train)
## [21:41:04] WARNING: src/learner.cc:767: 
## Parameters: { "nthreads" } are not used.

fit_final
## ══ Workflow [trained] ══════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: boost_tree()
## 
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 6 Recipe Steps
## 
## • step_mutate()
## • step_log()
## • step_mutate()
## • step_date()
## • step_impute_knn()
## • step_dummy()
## 
## ── Model ───────────────────────────────────────────────────────────────────────
## ##### xgb.Booster
## raw: 257.9 Kb 
## call:
##   xgboost::xgb.train(params = list(eta = 0.3, max_depth = 6, gamma = 0, 
##     colsample_bytree = 1, colsample_bynode = 0.4, min_child_weight = 4L, 
##     subsample = 1), data = x$data, nrounds = 100L, watchlist = x$watchlist, 
##     verbose = 0, nthreads = 8L, nthread = 1, objective = "reg:squarederror")
## params (as set within xgb.train):
##   eta = "0.3", max_depth = "6", gamma = "0", colsample_bytree = "1", colsample_bynode = "0.4", min_child_weight = "4", subsample = "1", nthreads = "8", nthread = "1", objective = "reg:squarederror", validate_parameters = "TRUE"
## xgb.attributes:
##   niter
## callbacks:
##   cb.evaluation.log()
## # of features: 15 
## niter: 100
## nfeatures : 15 
## evaluation_log:
##     iter training_rmse
##        1     122355586
##        2     100873316
## ---                   
##       99      28074964
##      100      27979563
d_test$revenue <- NA

final_preds <- 
  fit_final %>% 
  predict(new_data = d_test) %>% 
  bind_cols(d_test)

14.10 Submission

14.10.1 Submission vorbereiten

submission_df <-
  final_preds %>% 
  select(id, revenue = .pred)

Abspeichern und einreichen:

write_csv(submission_df, file = "objects/submission.csv")

Diese CSV-Datei reichen wir dann bei Kagglei ein.

14.10.2 Kaggle Score

Diese Submission erzielte einen Score von 4.79227 (RMSLE).

14.11 Miniprojekt

Reichen Sie Ihre Vorhersagen für die TMDB-Competition bei Kaggle ein!

Stellen Sie auch (im Rahmen dieses Wettbewerbs) Ihre Syntax offen.

Bereiten Sie sich vor, Ihre Analyse zu präsentieren.

14.12 Aufgaben

Schauen Sie sich mal die Kategorie tmdb auf Datenwerk an.

Alternativ bietet die Kategorie tidymodels eine Sammlung von Aufgaben rund um das R-Paket Tidymodels; dort können Sie sich Aufgaben anpassen.

14.13 Kaggle-Fallstudien

14.14 Vertiefung