We’re over the moon to announce the release of orbital 0.4.0. orbital lets you predict in databases using tidymodels workflows.

You can install it from CRAN with:

install.packages("orbital")

This blog post covers the highlights: post-processing support and the new show_query() method.

You can see a full list of changes in the release notes.

Post-processing support

The biggest improvement in this version is that orbital() now works with workflows that include a tailor post-processor, as long as its adjustments are supported. See the vignette for a list of all supported post-processors.

Let’s start by fitting a classification model on the penguins data set, using {xgboost} as the engine. The adjustment we showcase only works for binary classification, so we first recode species to have the two levels "Adelie" and "not_Adelie".

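library(tidymodels) # attaches recipes, parsnip, workflows, dplyr, and modeldata (penguins)
library(tailor)     # tailor() and adjust_equivocal_zone()
library(orbital)

# Collapse Chinstrap and Gentoo into a single "not_Adelie" level
penguins$species <- forcats::fct_recode(
  penguins$species,
  not_Adelie = "Chinstrap", not_Adelie = "Gentoo"
)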
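If you would rather not depend on forcats, the same collapse can be sketched in base R: giving two levels the same name through `levels<-` merges them into one level. The small `species` factor below is an illustrative stand-in for `penguins$species`, not the real data.

```r
# Illustrative stand-in for penguins$species (same three levels)
species <- factor(c("Adelie", "Chinstrap", "Gentoo", "Adelie"))

# Rename Chinstrap and Gentoo to the same label; levels<- merges duplicated names
levels(species)[levels(species) %in% c("Chinstrap", "Gentoo")] <- "not_Adelie"

levels(species)
# returns "Adelie" "not_Adelie"
```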

Next, we set up a simple workflow with a preprocessor from recipes and a model specification from parsnip.

We also set up a post-processor using the tailor package, with a single adjustment: adjust_equivocal_zone(). This applies an equivocal zone to our binary classification model, so predictions that fall too close to the threshold are labeled "[EQ]" instead of being assigned a class. With the threshold at 0.5, setting value = 0.2 means that any observation with a predicted probability between 0.3 and 0.7 will be predicted as "[EQ]".
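The equivocal-zone rule can be sketched in plain R as a function of the predicted probability of "Adelie". This is a hypothetical helper (`equivocal_class()` and its argument names are ours, not tailor's API); it mirrors the logic of the final `.pred_class` equation that orbital generates for this workflow, including the strict inequalities, so a probability of exactly 0.7 would land in the zone.

```r
# Hypothetical helper (not part of tailor): label a binary prediction from
# P(Adelie), using an equivocal zone of +/- value around the threshold.
equivocal_class <- function(prob_adelie, threshold = 0.5, value = 0.2) {
  ifelse(
    prob_adelie > threshold + value, "Adelie",
    ifelse(prob_adelie < threshold - value, "not_Adelie", "[EQ]")
  )
}

equivocal_class(c(0.845, 0.483, 0.348, 0.291))
# returns "Adelie" "[EQ]" "[EQ]" "not_Adelie"
```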

rec_spec <- recipe(species ~ ., data = penguins) |>
  step_unknown(all_nominal_predictors()) |>
  step_dummy(all_nominal_predictors()) |>
  step_impute_mean(all_numeric_predictors()) |>
  step_zv(all_predictors())

lr_spec <- boost_tree(tree_depth = 1, trees = 5) |>
  set_mode("classification") |>
  set_engine("xgboost")

tlr_spec <- tailor() |>
  adjust_equivocal_zone(value = 0.2)

wf_spec <- workflow(rec_spec, lr_spec, tlr_spec)
wf_fit <- fit(wf_spec, data = penguins)

With this fitted workflow object, we can call orbital() on it to create an orbital object. Note that we need to set type = c("class", "prob"), because the adjust_equivocal_zone() transformation needs both the class and the probability predictions: the final .pred_class equation is computed from .pred_Adelie.

orbital_obj <- orbital(wf_fit, type = c("class", "prob"))
orbital_obj
#> 
#> ── orbital Object ───────────────────────────────────────────────────────
#> • bill_length_mm = dplyr::if_else(is.na(bill_length_mm), 43.92193, ...
#> • flipper_length_mm = dplyr::if_else(is.na(flipper_length_mm), 201 ...
#> • .pred_class = dplyr::case_when(1 - 1/(1 + exp(dplyr::case_when(b ...
#> • .pred_Adelie = 1 - 1/(1 + exp(dplyr::case_when(bill_length_mm < ...
#> • .pred_not_Adelie = 1 - (1 - 1/(1 + exp(dplyr::case_when(bill_len ...
#> • .pred_class = dplyr::case_when( .pred_Adelie > 0.5 + 0.2 ~ 'Adel ...
#> ─────────────────────────────────────────────────────────────────────────
#> 6 equations in total.

This object contains all the information needed to produce predictions, which we can do with predict().

preds <- predict(orbital_obj, penguins)
preds
#> # A tibble: 344 × 3
#>    .pred_class .pred_Adelie .pred_not_Adelie
#>    <chr>              <dbl>            <dbl>
#>  1 Adelie             0.845            0.155
#>  2 Adelie             0.845            0.155
#>  3 Adelie             0.845            0.155
#>  4 not_Adelie         0.291            0.709
#>  5 Adelie             0.845            0.155
#>  6 Adelie             0.845            0.155
#>  7 Adelie             0.845            0.155
#>  8 Adelie             0.845            0.155
#>  9 Adelie             0.845            0.155
#> 10 Adelie             0.845            0.155
#> # ℹ 334 more rows

The predictions work, but the first ten rows show no evidence that adjust_equivocal_zone() is doing anything. A call to count() reveals that 15 observations land in the equivocal zone.

count(preds, .pred_class)
#> # A tibble: 3 × 2
#>   .pred_class     n
#>   <chr>       <int>
#> 1 Adelie        144
#> 2 [EQ]           15
#> 3 not_Adelie    185

We can further verify that they are correct: every "[EQ]" prediction has a .pred_Adelie between 0.3 and 0.7.

filter(preds, .pred_class == '[EQ]')
#> # A tibble: 15 × 3
#>    .pred_class .pred_Adelie .pred_not_Adelie
#>    <chr>              <dbl>            <dbl>
#>  1 [EQ]               0.483            0.517
#>  2 [EQ]               0.483            0.517
#>  3 [EQ]               0.483            0.517
#>  4 [EQ]               0.483            0.517
#>  5 [EQ]               0.483            0.517
#>  6 [EQ]               0.483            0.517
#>  7 [EQ]               0.483            0.517
#>  8 [EQ]               0.348            0.652
#>  9 [EQ]               0.348            0.652
#> 10 [EQ]               0.348            0.652
#> 11 [EQ]               0.348            0.652
#> 12 [EQ]               0.348            0.652
#> 13 [EQ]               0.483            0.517
#> 14 [EQ]               0.483            0.517
#> 15 [EQ]               0.483            0.517
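A programmatic version of this check is sketched below. To keep it self-contained, it uses a few rows copied from the printed predictions in this post (`preds_sample` is an illustrative name); with the real results you would substitute `preds`. It asserts that the rows inside 0.5 ± 0.2 are exactly the rows labeled "[EQ]", in both directions.

```r
# A few rows copied from the printed predictions in this post
preds_sample <- data.frame(
  .pred_class  = c("Adelie", "not_Adelie", "[EQ]", "[EQ]"),
  .pred_Adelie = c(0.845, 0.291, 0.483, 0.348)
)

# Rows whose probability lies inside 0.5 +/- 0.2 ...
in_zone <- with(preds_sample, .pred_Adelie >= 0.3 & .pred_Adelie <= 0.7)

# ... should be exactly the rows labeled "[EQ]"
stopifnot(identical(in_zone, preds_sample$.pred_class == "[EQ]"))
```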

New show_query() method

One of the main purposes of orbital is to allow for predictions in databases. To try this out, we set up an in-memory SQLite database and copy the penguins data into it.

library(DBI)
library(RSQLite)

con_sqlite <- dbConnect(SQLite(), dbname = ":memory:")
penguins_sqlite <- copy_to(con_sqlite, penguins, name = "penguins_table")

With the database set up, we could use orbital_sql() to see what the SQL query looks like. For quick testing, however, its output isn’t ready to be pasted into a file as-is, because it contains <SQL> fragments.

The new show_query() method prints exactly what the generated SQL looks like, without those fragments.

show_query(orbital_obj, con_sqlite)
#> CASE WHEN ((`bill_length_mm` IS NULL)) THEN 43.9219298245614 WHEN NOT ((`bill_length_mm` IS NULL)) THEN `bill_length_mm` END AS bill_length_mm
#> CASE WHEN ((`flipper_length_mm` IS NULL)) THEN 201.0 WHEN NOT ((`flipper_length_mm` IS NULL)) THEN `flipper_length_mm` END AS flipper_length_mm
#> CASE
#> WHEN ((1.0 - 1.0 / (1.0 + EXP(((((CASE
#> WHEN (`bill_length_mm` < 42.4000015) THEN 0.627138138
#> WHEN ((`bill_length_mm` >= 42.4000015 OR (`bill_length_mm` IS NULL))) THEN (-0.449751347)
#> END + CASE
#> WHEN (`bill_length_mm` < 43.2999992) THEN 0.425288886
#> WHEN ((`bill_length_mm` >= 43.2999992 OR (`bill_length_mm` IS NULL))) THEN (-0.398178101)
#> END) + CASE
#> WHEN (`bill_length_mm` < 42.4000015) THEN 0.380251437
#> WHEN ((`bill_length_mm` >= 42.4000015 OR (`bill_length_mm` IS NULL))) THEN (-0.306771189)
#> END) + CASE
#> WHEN (`bill_length_mm` < 44.4000015) THEN 0.286071777
#> WHEN ((`bill_length_mm` >= 44.4000015 OR (`bill_length_mm` IS NULL))) THEN (-0.330096036)
#> END) + CASE
#> WHEN (`flipper_length_mm` < 203.0) THEN 0.209298179
#> WHEN ((`flipper_length_mm` >= 203.0 OR (`flipper_length_mm` IS NULL))) THEN (-0.348002464)
#> END) + LOG(0.44186047 / (1.0 - 0.44186047))))) > 0.5) THEN 'Adelie'
#> ELSE 'not_Adelie'
#> END AS .pred_class
#> 1.0 - 1.0 / (1.0 + EXP(((((CASE
#> WHEN (`bill_length_mm` < 42.4000015) THEN 0.627138138
#> WHEN ((`bill_length_mm` >= 42.4000015 OR (`bill_length_mm` IS NULL))) THEN (-0.449751347)
#> END + CASE
#> WHEN (`bill_length_mm` < 43.2999992) THEN 0.425288886
#> WHEN ((`bill_length_mm` >= 43.2999992 OR (`bill_length_mm` IS NULL))) THEN (-0.398178101)
#> END) + CASE
#> WHEN (`bill_length_mm` < 42.4000015) THEN 0.380251437
#> WHEN ((`bill_length_mm` >= 42.4000015 OR (`bill_length_mm` IS NULL))) THEN (-0.306771189)
#> END) + CASE
#> WHEN (`bill_length_mm` < 44.4000015) THEN 0.286071777
#> WHEN ((`bill_length_mm` >= 44.4000015 OR (`bill_length_mm` IS NULL))) THEN (-0.330096036)
#> END) + CASE
#> WHEN (`flipper_length_mm` < 203.0) THEN 0.209298179
#> WHEN ((`flipper_length_mm` >= 203.0 OR (`flipper_length_mm` IS NULL))) THEN (-0.348002464)
#> END) + LOG(0.44186047 / (1.0 - 0.44186047)))) AS .pred_Adelie
#> 1.0 - (1.0 - 1.0 / (1.0 + EXP(((((CASE
#> WHEN (`bill_length_mm` < 42.4000015) THEN 0.627138138
#> WHEN ((`bill_length_mm` >= 42.4000015 OR (`bill_length_mm` IS NULL))) THEN (-0.449751347)
#> END + CASE
#> WHEN (`bill_length_mm` < 43.2999992) THEN 0.425288886
#> WHEN ((`bill_length_mm` >= 43.2999992 OR (`bill_length_mm` IS NULL))) THEN (-0.398178101)
#> END) + CASE
#> WHEN (`bill_length_mm` < 42.4000015) THEN 0.380251437
#> WHEN ((`bill_length_mm` >= 42.4000015 OR (`bill_length_mm` IS NULL))) THEN (-0.306771189)
#> END) + CASE
#> WHEN (`bill_length_mm` < 44.4000015) THEN 0.286071777
#> WHEN ((`bill_length_mm` >= 44.4000015 OR (`bill_length_mm` IS NULL))) THEN (-0.330096036)
#> END) + CASE
#> WHEN (`flipper_length_mm` < 203.0) THEN 0.209298179
#> WHEN ((`flipper_length_mm` >= 203.0 OR (`flipper_length_mm` IS NULL))) THEN (-0.348002464)
#> END) + LOG(0.44186047 / (1.0 - 0.44186047))))) AS .pred_not_Adelie
#> CASE
#> WHEN (`.pred_Adelie` > (0.5 + 0.2)) THEN 'Adelie'
#> WHEN (`.pred_Adelie` < (0.5 - 0.2)) THEN 'not_Adelie'
#> ELSE '[EQ]'
#> END AS .pred_class
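To see how this SQL maps back to the model, the `.pred_Adelie` expression can be re-implemented in base R: five depth-1 trees each contribute a leaf value, the logit of the base score 0.44186047 is added, and `1 - 1/(1 + exp(m))` turns the margin into a probability (it equals the logistic function `plogis(m)`). This is a sketch; `split_leaf()` and `pred_adelie()` are hypothetical helpers, and the inputs below are illustrative measurements chosen to hit different branches, not rows from the data. They reproduce the probabilities 0.845, 0.483, and 0.348 that appear in the predictions printed earlier.

```r
# Hypothetical helper: one tree split. Like the generated SQL, missing values
# go to the ">=" branch.
split_leaf <- function(x, cutoff, left, right) {
  ifelse(!is.na(x) & x < cutoff, left, right)
}

# Hypothetical helper re-implementing the .pred_Adelie SQL expression:
# sum of five depth-1 trees plus the logit of the base score.
pred_adelie <- function(bill_length_mm, flipper_length_mm) {
  margin <-
    split_leaf(bill_length_mm, 42.4000015, 0.627138138, -0.449751347) +
    split_leaf(bill_length_mm, 43.2999992, 0.425288886, -0.398178101) +
    split_leaf(bill_length_mm, 42.4000015, 0.380251437, -0.306771189) +
    split_leaf(bill_length_mm, 44.4000015, 0.286071777, -0.330096036) +
    split_leaf(flipper_length_mm, 203.0, 0.209298179, -0.348002464) +
    log(0.44186047 / (1 - 0.44186047))
  # 1 - 1 / (1 + exp(m)) is the logistic function, i.e. plogis(m)
  1 - 1 / (1 + exp(margin))
}

round(pred_adelie(c(39.0, 42.8, 42.8), c(180, 190, 210)), 3)
# returns 0.845 0.483 0.348
```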

Acknowledgements

A big thank you to all the people who have contributed to orbital since the previous release:

@EmilHvitfeldt, @frankiethull, @jeroenjanssens, and @topepo.