Lab 8: Tuning ML Models of Hydrological Data

ESS 330 - Quantitative Reasoning

library(tidymodels)
── Attaching packages ────────────────────────────────────── tidymodels 1.2.0 ──
✔ broom        1.0.7     ✔ recipes      1.1.0
✔ dials        1.3.0     ✔ rsample      1.2.1
✔ dplyr        1.1.4     ✔ tibble       3.2.1
✔ ggplot2      3.5.1     ✔ tidyr        1.3.1
✔ infer        1.0.7     ✔ tune         1.2.1
✔ modeldata    1.4.0     ✔ workflows    1.1.4
✔ parsnip      1.2.1     ✔ workflowsets 1.1.0
✔ purrr        1.0.2     ✔ yardstick    1.3.2
Warning: package 'scales' was built under R version 4.4.3
── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
✖ purrr::discard() masks scales::discard()
✖ dplyr::filter()  masks stats::filter()
✖ dplyr::lag()     masks stats::lag()
✖ recipes::step()  masks stats::step()
• Dig deeper into tidy modeling with R at https://www.tmwr.org
library(recipes)
library(yardstick)
library(ggthemes)
Warning: package 'ggthemes' was built under R version 4.4.3
library(ggplot2)
library(workflowsets)
library(patchwork)
library(ggfortify)
Warning: package 'ggfortify' was built under R version 4.4.3
Registered S3 method overwritten by 'ggfortify':
  method          from   
  autoplot.glmnet parsnip
library(parsnip)
library(tidyverse)
Warning: package 'lubridate' was built under R version 4.4.3
── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
✔ forcats   1.0.0     ✔ readr     2.1.5
✔ lubridate 1.9.4     ✔ stringr   1.5.1
── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
✖ readr::col_factor() masks scales::col_factor()
✖ purrr::discard()    masks scales::discard()
✖ dplyr::filter()     masks stats::filter()
✖ stringr::fixed()    masks recipes::fixed()
✖ dplyr::lag()        masks stats::lag()
✖ readr::spec()       masks yardstick::spec()
ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library(visdat)
Warning: package 'visdat' was built under R version 4.4.3
library(powerjoin)
Warning: package 'powerjoin' was built under R version 4.4.3
library(skimr)
Warning: package 'skimr' was built under R version 4.4.3
library(xgboost)
Warning: package 'xgboost' was built under R version 4.4.3

Attaching package: 'xgboost'

The following object is masked from 'package:dplyr':

    slice
library(dplyr)
library(purrr)
library(patchwork)
library(glue)
library(vip)
Warning: package 'vip' was built under R version 4.4.3

Attaching package: 'vip'

The following object is masked from 'package:utils':

    vi
library(baguette)
Warning: package 'baguette' was built under R version 4.4.3
# Data Import/Tidy/Transform
# Base URL under which all CAMELS files (attribute PDF and data tables) live.
root <- "https://gdex.ucar.edu/dataset/camels/file"

# download.file() does not create missing folders; without this the call
# fails with "cannot open destfile ... No such file or directory".
dir.create("data", showWarnings = FALSE, recursive = TRUE)

# Fetch the attribute documentation. A PDF is binary, so request a binary
# transfer explicitly (text mode can corrupt the file on Windows).
download.file(paste0(root, "/camels_attributes_v2.0.pdf"),
              "data/camels_attributes_v2.0.pdf",
              mode = "wb")
Warning in
download.file("https://gdex.ucar.edu/dataset/camels/file/camels_attributes_v2.0.pdf",
: URL https://gdex.ucar.edu/dataset/camels/file/camels_attributes_v2.0.pdf:
cannot open destfile 'data/camels_attributes_v2.0.pdf', reason 'No such file or
directory'
Warning in
download.file("https://gdex.ucar.edu/dataset/camels/file/camels_attributes_v2.0.pdf",
: download had nonzero exit status
# CAMELS attribute tables to fetch; each one is stored as camels_<type>.txt.
types <- c("clim", "geol", "soil", "topo", "vege", "hydro")

# Where the files live online ...
remote_files <- glue("{root}/camels_{types}.txt")

# ... and where the local copies are written.
local_dir <- "../data/lab_data/camels_hydro_data"
local_files <- glue("{local_dir}/camels_{types}.txt")

# download.file() cannot create missing folders, so make sure the target
# directory exists before downloading.
dir.create(local_dir, recursive = TRUE, showWarnings = FALSE)
walk2(remote_files, local_files, download.file, quiet = TRUE)

# Read and merge data: one table per attribute type, joined on the gauge id.
camels <- map(local_files, read_delim, show_col_types = FALSE)
camels <- power_full_join(camels, by = "gauge_id")

# Add log(q_mean) to df.
# Only numeric columns are coerced to double. Coercing every column would
# turn the categorical columns (e.g. geol_1st_class, dom_land_cover) into
# all-NA with "NAs introduced by coercion" warnings.
camels <- camels %>%
  mutate(logQmean = log(q_mean)) %>%
  mutate(across(where(is.numeric), as.double))
Warning: There were 5 warnings in `mutate()`.
The first warning was:
ℹ In argument: `across(everything(), as.double)`.
Caused by warning:
! NAs introduced by coercion
ℹ Run `dplyr::last_dplyr_warnings()` to see the 4 remaining warnings.
# Per-column summary (missing counts, quantiles, mini histograms) of the merged data.
camels |> skim()
Data summary
Name camels
Number of rows 671
Number of columns 59
_______________________
Column type frequency:
numeric 59
________________________
Group variables None

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
gauge_id 0 1.00 6265830.84 3976867.52 1013500.00 2370650.00 6278300.00 9382765.00 14400000.00 ▇▃▅▃▃
p_mean 0 1.00 3.26 1.41 0.64 2.37 3.23 3.78 8.94 ▃▇▂▁▁
pet_mean 0 1.00 2.79 0.55 1.90 2.34 2.69 3.15 4.74 ▇▇▅▂▁
p_seasonality 0 1.00 -0.04 0.53 -1.44 -0.26 0.08 0.22 0.92 ▁▂▃▇▂
frac_snow 0 1.00 0.18 0.20 0.00 0.04 0.10 0.22 0.91 ▇▂▁▁▁
aridity 0 1.00 1.06 0.62 0.22 0.70 0.86 1.27 5.21 ▇▂▁▁▁
high_prec_freq 0 1.00 20.93 4.55 7.90 18.50 22.00 24.23 32.70 ▂▃▇▇▁
high_prec_dur 0 1.00 1.35 0.19 1.08 1.21 1.28 1.44 2.09 ▇▅▂▁▁
high_prec_timing 671 0.00 NaN NA NA NA NA NA NA
low_prec_freq 0 1.00 254.65 35.12 169.90 232.70 255.85 278.92 348.70 ▂▅▇▅▁
low_prec_dur 0 1.00 5.95 3.20 2.79 4.24 4.95 6.70 36.51 ▇▁▁▁▁
low_prec_timing 671 0.00 NaN NA NA NA NA NA NA
geol_1st_class 671 0.00 NaN NA NA NA NA NA NA
glim_1st_class_frac 0 1.00 0.79 0.20 0.30 0.61 0.83 1.00 1.00 ▁▃▃▃▇
geol_2nd_class 671 0.00 NaN NA NA NA NA NA NA
glim_2nd_class_frac 0 1.00 0.16 0.14 0.00 0.00 0.14 0.27 0.49 ▇▃▃▂▁
carbonate_rocks_frac 0 1.00 0.12 0.26 0.00 0.00 0.00 0.04 1.00 ▇▁▁▁▁
geol_porostiy 3 1.00 0.13 0.07 0.01 0.07 0.13 0.19 0.28 ▇▆▇▇▂
geol_permeability 0 1.00 -13.89 1.18 -16.50 -14.77 -13.96 -13.00 -10.90 ▂▅▇▅▂
soil_depth_pelletier 0 1.00 10.87 16.24 0.27 1.00 1.23 12.89 50.00 ▇▁▁▁▁
soil_depth_statsgo 0 1.00 1.29 0.27 0.40 1.11 1.46 1.50 1.50 ▁▁▂▂▇
soil_porosity 0 1.00 0.44 0.02 0.37 0.43 0.44 0.46 0.68 ▃▇▁▁▁
soil_conductivity 0 1.00 1.74 1.52 0.45 0.93 1.35 1.93 13.96 ▇▁▁▁▁
max_water_content 0 1.00 0.53 0.15 0.09 0.43 0.56 0.64 1.05 ▁▅▇▃▁
sand_frac 0 1.00 36.47 15.63 8.18 25.44 35.27 44.46 91.98 ▅▇▅▁▁
silt_frac 0 1.00 33.86 13.25 2.99 23.95 34.06 43.64 67.77 ▂▆▇▆▁
clay_frac 0 1.00 19.89 9.32 1.85 14.00 18.66 25.42 50.35 ▃▇▅▂▁
water_frac 0 1.00 0.10 0.94 0.00 0.00 0.00 0.00 19.35 ▇▁▁▁▁
organic_frac 0 1.00 0.59 3.84 0.00 0.00 0.00 0.00 57.86 ▇▁▁▁▁
other_frac 0 1.00 9.82 16.83 0.00 0.00 1.31 11.74 99.38 ▇▁▁▁▁
gauge_lat 0 1.00 39.24 5.21 27.05 35.70 39.25 43.21 48.82 ▂▃▇▆▅
gauge_lon 0 1.00 -95.79 16.21 -124.39 -110.41 -92.78 -81.77 -67.94 ▆▃▇▇▅
elev_mean 0 1.00 759.42 786.00 10.21 249.67 462.72 928.88 3571.18 ▇▂▁▁▁
slope_mean 0 1.00 46.20 47.12 0.82 7.43 28.80 73.17 255.69 ▇▂▂▁▁
area_gages2 0 1.00 792.62 1701.95 4.03 122.28 329.68 794.30 25791.04 ▇▁▁▁▁
area_geospa_fabric 0 1.00 808.08 1709.85 4.10 127.98 340.70 804.50 25817.78 ▇▁▁▁▁
frac_forest 0 1.00 0.64 0.37 0.00 0.28 0.81 0.97 1.00 ▃▁▁▂▇
lai_max 0 1.00 3.22 1.52 0.37 1.81 3.37 4.70 5.58 ▅▆▃▅▇
lai_diff 0 1.00 2.45 1.33 0.15 1.20 2.34 3.76 4.83 ▇▇▇▆▇
gvf_max 0 1.00 0.72 0.17 0.18 0.61 0.78 0.86 0.92 ▁▁▂▃▇
gvf_diff 0 1.00 0.32 0.15 0.03 0.19 0.32 0.46 0.65 ▃▇▅▇▁
dom_land_cover_frac 0 1.00 0.81 0.18 0.31 0.65 0.86 1.00 1.00 ▁▂▃▃▇
dom_land_cover 671 0.00 NaN NA NA NA NA NA NA
root_depth_50 24 0.96 0.18 0.03 0.12 0.17 0.18 0.19 0.25 ▃▃▇▂▂
root_depth_99 24 0.96 1.83 0.30 1.50 1.52 1.80 2.00 3.10 ▇▃▂▁▁
q_mean 1 1.00 1.49 1.54 0.00 0.63 1.13 1.75 9.69 ▇▁▁▁▁
runoff_ratio 1 1.00 0.39 0.23 0.00 0.24 0.35 0.51 1.36 ▆▇▂▁▁
slope_fdc 1 1.00 1.24 0.51 0.00 0.90 1.28 1.63 2.50 ▂▅▇▇▁
baseflow_index 0 1.00 0.49 0.16 0.01 0.40 0.50 0.60 0.98 ▁▃▇▅▁
stream_elas 1 1.00 1.83 0.78 -0.64 1.32 1.70 2.23 6.24 ▁▇▃▁▁
q5 1 1.00 0.17 0.27 0.00 0.01 0.08 0.22 2.42 ▇▁▁▁▁
q95 1 1.00 5.06 4.94 0.00 2.07 3.77 6.29 31.82 ▇▂▁▁▁
high_q_freq 1 1.00 25.74 29.07 0.00 6.41 15.10 35.79 172.80 ▇▂▁▁▁
high_q_dur 1 1.00 6.91 10.07 0.00 1.82 2.85 7.55 92.56 ▇▁▁▁▁
low_q_freq 1 1.00 107.62 82.24 0.00 37.44 96.00 162.14 356.80 ▇▆▅▂▁
low_q_dur 1 1.00 22.28 21.66 0.00 10.00 15.52 26.91 209.88 ▇▁▁▁▁
zero_q_freq 1 1.00 0.03 0.11 0.00 0.00 0.00 0.00 0.97 ▇▁▁▁▁
hfd_mean 1 1.00 182.52 33.53 112.25 160.16 173.77 204.05 287.75 ▂▇▃▂▁
logQmean 1 1.00 -0.11 1.17 -5.39 -0.46 0.12 0.56 2.27 ▁▁▂▇▂
# Visual overview of the merged data frame
camels |> vis_dat()

# Fix the RNG state so the split and the folds below are reproducible
set.seed(567)

# Hold out 20% of the rows for testing; train on the remaining 80%
camels_split <- camels |> initial_split(prop = 0.8)
camels_tr <- camels_split |> training()
camels_te <- camels_split |> testing()

# 10-fold cross-validation, built from the training portion only
camels_10cv <- camels_tr |> vfold_cv(v = 10)

# Recipe: preprocessing shared by every model in this lab.
# NOTE(review): runoff_ratio comes from the hydro table; presumably it is
# derived from mean flow itself, which would leak the outcome into the
# predictors -- confirm against the CAMELS attribute documentation.
rec <- recipe(
  logQmean ~ pet_mean + p_mean + aridity + runoff_ratio + baseflow_index +
    slope_mean + area_geospa_fabric,
  data = camels_tr
) |>
  # Yeo-Johnson transform on every predictor
  step_YeoJohnson(all_predictors()) |>
  # Pairwise interaction terms
  step_interact(
    terms = ~ pet_mean:p_mean + aridity:runoff_ratio + area_geospa_fabric:slope_mean
  ) |>
  # Remove highly correlated predictors to avoid multicollinearity
  step_corr(all_predictors(), threshold = 0.9) |>
  # Center and scale the remaining predictors
  step_normalize(all_predictors()) |>
  # Drop rows with a missing predictor or outcome
  step_naomit(all_predictors(), all_outcomes())

# Define and Train Models
# Every model shares the same preprocessing, so start each workflow from a
# common recipe-only base and attach the model specification to it.
base_wf <- workflow() |> add_recipe(rec)

## Random forest (ranger)
rf_model <- rand_forest(mode = "regression", engine = "ranger")
rf_wf <- base_wf |>
  add_model(rf_model) |>
  fit(data = camels_tr)
rf_predictions <- augment(rf_wf, new_data = camels_te)

## Boosted trees (xgboost)
xg_model <- boost_tree(mode = "regression", engine = "xgboost")
xg_wf <- base_wf |>
  add_model(xg_model) |>
  fit(data = camels_tr)
xg_predictions <- augment(xg_wf, new_data = camels_te)

## Bagged neural network (nnet)
nn_model <- bag_mlp(mode = "regression", engine = "nnet")
nn_wf <- base_wf |>
  add_model(nn_model) |>
  fit(data = camels_tr)
nn_predictions <- augment(nn_wf, new_data = camels_te)

## Linear regression (lm)
lm_model <- linear_reg(mode = "regression", engine = "lm")
lm_wf <- base_wf |>
  add_model(lm_model) |>
  fit(data = camels_tr)
lm_predictions <- augment(lm_wf, new_data = camels_te)
  
  # Workflow set: pair the shared recipe with each candidate model, then
  # evaluate every pairing on the same 10 cross-validation folds.
  candidate_models <- list(
    rf = rf_model,
    xg = xg_model,
    nn = nn_model,
    lm = lm_model
  )
  ml_wf_set <- workflow_set(preproc = list(rec), models = candidate_models)
  ml_wf_set <- workflow_map(ml_wf_set, 'fit_resamples', resamples = camels_10cv)
Warning: package 'ranger' was built under R version 4.4.3
# Plot the resampled metrics for every workflow in the set
ml_wf_set |> autoplot()

# Rank the workflows by R-squared, keeping the best configuration of each
ml_wf_set |> rank_results(rank_metric = "rsq", select_best = TRUE)
# A tibble: 8 × 9
  wflow_id  .config        .metric   mean std_err     n preprocessor model  rank
  <chr>     <chr>          <chr>    <dbl>   <dbl> <int> <chr>        <chr> <int>
1 recipe_nn Preprocessor1… rmse    0.0304 1.04e-2    10 recipe       bag_…     1
2 recipe_nn Preprocessor1… rsq     0.999  4.79e-4    10 recipe       bag_…     1
3 recipe_xg Preprocessor1… rmse    0.122  1.61e-2    10 recipe       boos…     2
4 recipe_xg Preprocessor1… rsq     0.989  2.04e-3    10 recipe       boos…     2
5 recipe_rf Preprocessor1… rmse    0.175  2.43e-2    10 recipe       rand…     3
6 recipe_rf Preprocessor1… rsq     0.981  3.14e-3    10 recipe       rand…     3
7 recipe_lm Preprocessor1… rmse    0.217  2.40e-2    10 recipe       line…     4
8 recipe_lm Preprocessor1… rsq     0.967  4.71e-3    10 recipe       line…     4
# Model tuning: bagged neural network with hidden_units and penalty left
# as placeholders to be filled in by the grid search.
tuned_nn_model <- bag_mlp(
  mode = "regression",
  engine = "nnet",
  hidden_units = tune(),
  penalty = tune()
)

wf_tune <- workflow() |>
  add_recipe(rec) |>
  add_model(tuned_nn_model)

# Tunable parameters (and their ranges) found in the workflow
dials <- wf_tune |> extract_parameter_set_dials()

# Define search space: 20 space-filling candidates
my.grid <- dials |> grid_space_filling(size = 20)

# Evaluate every candidate on the cross-validation folds, keeping the
# out-of-fold predictions for later inspection.
tuning_metrics <- metric_set(rmse, rsq, mae)
tuning_control <- control_grid(save_pred = TRUE)

model_params <- wf_tune |>
  tune_grid(
    resamples = camels_10cv,
    grid = my.grid,
    metrics = tuning_metrics,
    control = tuning_control
  )
# Metric profiles across the tuning grid
model_params |> autoplot()

# Resample-averaged metrics for every candidate
model_params |> collect_metrics()
# A tibble: 60 × 8
   hidden_units       penalty .metric .estimator   mean     n std_err .config   
          <int>         <dbl> <chr>   <chr>       <dbl> <int>   <dbl> <chr>     
 1            1 0.000000144   mae     standard   0.104     10 0.00738 Preproces…
 2            1 0.000000144   rmse    standard   0.173     10 0.0246  Preproces…
 3            1 0.000000144   rsq     standard   0.980     10 0.00397 Preproces…
 4            1 0.0000616     mae     standard   0.109     10 0.00715 Preproces…
 5            1 0.0000616     rmse    standard   0.174     10 0.0242  Preproces…
 6            1 0.0000616     rsq     standard   0.979     10 0.00393 Preproces…
 7            2 0.0264        mae     standard   0.0399    10 0.00453 Preproces…
 8            2 0.0264        rmse    standard   0.0824    10 0.0205  Preproces…
 9            2 0.0264        rsq     standard   0.995     10 0.00191 Preproces…
10            2 0.00000000113 mae     standard   0.0339    10 0.00442 Preproces…
# ℹ 50 more rows
# Best candidate by mean absolute error.
# show_best() returns one row per candidate with the tuned values
# (hidden_units, penalty), the metric averaged over the resamples (mean),
# the number of resamples (n), and the standard error of that mean (std_err).
best_mae <- model_params |> show_best(metric = "mae", n = 1)

# Just the winning hyperparameter combination, ready for finalize_workflow()
hp_best <- model_params |> select_best(metric = "mae")

# Plug the chosen hyperparameters into the workflow, fit it on the training
# split, and evaluate it on the test split.
final_wf <- wf_tune |> finalize_workflow(hp_best)
final_fit <- final_wf |> last_fit(split = camels_split)
final_metrics <- final_fit |> collect_metrics()

# The final model's RMSE is 0.010 and its R-squared (rsq) is 0.999, meaning the model explains 99.9% of the variance in logQmean on the test set, which is an excellent result. The RMSE is the typical size of a prediction error; because the outcome is on a natural-log scale, an RMSE of about 0.01 corresponds to roughly a 1% error in mean flow, which is quite good. The model performs very well, but it may be less efficient than simpler models, which are less demanding computationally and still give acceptable (if somewhat worse) rsq and RMSE values.

# Actual vs. predicted for the final model on the held-out test set.
# last_fit() trains the finalized workflow on the training split and predicts
# the test split, so its predictions describe the final model. The tuning
# object instead holds out-of-fold predictions for every grid candidate, which
# would mix all 20 candidates into one plot.
predictions <- collect_predictions(final_fit)

ggplot(predictions, aes(x = .pred, y = logQmean)) +
  # Points give the colour scale below something to act on
  geom_point(aes(color = logQmean), alpha = 0.7) +
  geom_smooth(method = "lm", color = "blue") +
  # 1:1 line: a perfect model would put every point on it
  geom_abline(slope = 1, intercept = 0, linetype = "dashed") +
  scale_color_gradient() +
  labs(
    title = "Actual vs. Predicted Values",
    x = "Predicted",
    y = "Actual",
    color = "Actual")
`geom_smooth()` using formula = 'y ~ x'

# Refit the finalized workflow on the complete data set and predict every basin.
final_fit_full <- fit(final_wf, data = camels)
augmented_preds <- augment(final_fit_full, new_data = camels)

# Squared residual per basin (NA where the observed logQmean is missing)
augmented_preds <- augmented_preds %>%
  mutate(residual_sq = (logQmean - .pred)^2)

# Both panels are maps: each gauge is placed at its longitude/latitude and
# coloured by the quantity of interest.
map_preds <- ggplot(augmented_preds, aes(x = gauge_lon, y = gauge_lat)) +
  geom_point(aes(color = .pred), size = 3, alpha = 0.8) +
  scale_color_viridis_c(name = "Predicted") +
  coord_fixed() +
  labs(title = "Map of Predicted logQmean", x = "Longitude", y = "Latitude") +
  theme_minimal()

map_resid <- ggplot(augmented_preds, aes(x = gauge_lon, y = gauge_lat)) +
  geom_point(aes(color = residual_sq), size = 3, alpha = 0.8) +
  scale_color_viridis_c(name = "Residual²") +
  coord_fixed() +
  labs(title = "Map of Squared Residuals", x = "Longitude", y = "Latitude") +
  theme_minimal()

# patchwork: place the two maps side by side
maps_combined <- map_preds | map_resid

print(maps_combined)
Warning: Removed 1 row containing missing values or values outside the scale range
(`geom_point()`).
Removed 1 row containing missing values or values outside the scale range
(`geom_point()`).