Title: | Shed Light on Black Box Machine Learning Models |
---|---|
Description: | Shed light on black box machine learning models with the help of model performance, variable importance, global surrogate models, ICE profiles, partial dependence (Friedman J. H. (2001) <doi:10.1214/aos/1013203451>), accumulated local effects (Apley D. W. (2016) <arXiv:1612.08468>), further effects plots, interaction strength, and variable contribution breakdown (Gosiewska and Biecek (2019) <arXiv:1903.11420>). All tools are implemented to work with case weights and allow for stratified analysis. Furthermore, multiple flashlights can be combined and analyzed together. |
Authors: | Michael Mayer [aut, cre, cph] |
Maintainer: | Michael Mayer <[email protected]> |
License: | GPL (>= 2) |
Version: | 0.9.0.9000 |
Built: | 2025-02-19 04:29:53 UTC |
Source: | https://github.com/mayer79/flashlight |
Deprecated in favor of kernelshap/fastshap.
add_shap(...)
... |
Deprecated |
Error message.
Creates or updates a "flashlight" object. If a flashlight is to be created, all arguments are optional except label. If a flashlight is to be updated, all arguments except x (the flashlight to be updated) are optional.
flashlight(x, ...)
## Default S3 method:
flashlight(
  x,
  model = NULL,
  data = NULL,
  y = NULL,
  predict_function = stats::predict,
  linkinv = function(z) z,
  w = NULL,
  by = NULL,
  metrics = list(rmse = MetricsWeighted::rmse),
  label = NULL,
  shap = NULL,
  ...
)
## S3 method for class 'flashlight'
flashlight(x, check = TRUE, ...)
x |
An object of class "flashlight". If not provided, a new flashlight is created based on further input. Otherwise, x is updated based on further input. |
... |
Arguments passed from or to other functions. |
model |
A fitted model of any type. Most models require a customized predict_function. |
data |
A data.frame. |
y |
Variable name of response. |
predict_function |
A real valued function with two arguments: A model and a data set of the same structure as data. |
linkinv |
An inverse transformation function applied after predict_function. |
w |
A variable name of case weights. |
by |
A character vector with names of grouping variables. |
metrics |
A named list of metrics. Here, a metric is a function with exactly four arguments: actual, predicted, w (case weights) and ... (further arguments). |
label |
Name of the flashlight. Required. |
shap |
An optional shap object. Typically added by calling add_shap(). |
check |
When updating the flashlight: Should internal checks be performed? Default is TRUE. |
An object of class "flashlight" (and list) containing each input (except x) as element.
flashlight(default)
: Used to create a flashlight object. No x has to be passed in this case.
flashlight(flashlight)
: Used to update an existing flashlight object.
fit <- lm(Sepal.Length ~ ., data = iris)
(fl <- flashlight(model = fit, data = iris, y = "Sepal.Length", label = "ols"))
(fl_updated <- flashlight(fl, linkinv = exp))
Checks if an object inherits a specific class relevant for the flashlight package.
is.flashlight(x)
is.multiflashlight(x)
is.light(x)
is.light_performance(x)
is.light_performance_multi(x)
is.light_importance(x)
is.light_importance_multi(x)
is.light_breakdown(x)
is.light_breakdown_multi(x)
is.light_ice(x)
is.light_ice_multi(x)
is.light_profile(x)
is.light_profile_multi(x)
is.light_profile2d(x)
is.light_profile2d_multi(x)
is.light_effects(x)
is.light_effects_multi(x)
is.shap(x)
is.light_scatter(x)
is.light_scatter_multi(x)
is.light_global_surrogate(x)
is.light_global_surrogate_multi(x)
x |
Any object. |
A logical vector of length one.
is.multiflashlight()
: Check for multiflashlight object.
is.light()
: Check for light object.
is.light_performance()
: Check for light_performance object.
is.light_performance_multi()
: Check for light_performance_multi object.
is.light_importance()
: Check for light_importance object.
is.light_importance_multi()
: Check for light_importance_multi object.
is.light_breakdown()
: Check for light_breakdown object.
is.light_breakdown_multi()
: Check for light_breakdown_multi object.
is.light_ice()
: Check for light_ice object.
is.light_ice_multi()
: Check for light_ice_multi object.
is.light_profile()
: Check for light_profile object.
is.light_profile_multi()
: Check for light_profile_multi object.
is.light_profile2d()
: Check for light_profile2d object.
is.light_profile2d_multi()
: Check for light_profile2d_multi object.
is.light_effects()
: Check for light_effects object.
is.light_effects_multi()
: Check for light_effects_multi object.
is.shap()
: Check for shap object.
is.light_scatter()
: Check for light_scatter object.
is.light_scatter_multi()
: Check for light_scatter_multi object.
is.light_global_surrogate()
: Check for light_global_surrogate object.
is.light_global_surrogate_multi()
: Check for light_global_surrogate_multi object.
a <- flashlight(label = "a")
is.flashlight(a)
is.flashlight("a")
Calculates sequential additive variable contributions (approximate SHAP) to the prediction of a single observation, see Gosiewska and Biecek (2019) and the details below.
light_breakdown(x, ...)
## Default S3 method:
light_breakdown(x, ...)
## S3 method for class 'flashlight'
light_breakdown(
  x,
  new_obs,
  data = x$data,
  by = x$by,
  v = NULL,
  visit_strategy = c("importance", "permutation", "v"),
  n_max = Inf,
  n_perm = 20,
  seed = NULL,
  use_linkinv = FALSE,
  description = TRUE,
  digits = 2,
  ...
)
## S3 method for class 'multiflashlight'
light_breakdown(x, ...)
x |
An object of class "flashlight" or "multiflashlight". |
... |
Further arguments passed to |
new_obs |
One single new observation to calculate variable attribution for. Needs to be a data.frame of the same structure as data. |
data |
An optional data.frame. |
by |
An optional vector of column names used to filter data. |
v |
Vector of variable names to assess contribution for. Defaults to all except those specified by "y", "w" and "by". |
visit_strategy |
In what sequence should variables be visited? By "importance", by "permutation" (see Details), or by "v". |
n_max |
Maximum number of rows in data to consider. |
n_perm |
Number of permutations of random visit sequences. Only used if visit_strategy = "permutation". |
seed |
An integer random seed used to shuffle rows if n_max is smaller than the number of rows in data. |
use_linkinv |
Should retransformation function be applied? Default is FALSE. |
description |
Should descriptions be added? Default is TRUE. |
digits |
Passed to |
The breakdown algorithm works as follows: First, the visit order $(v_1, \ldots, v_m)$ of the variables v is specified. Then, in the query data, the column $v_1$ is set to the value of $v_1$ of the single observation new_obs to be explained. The change in the (weighted) average prediction on data measures the contribution of $v_1$ to the prediction of new_obs. This procedure is iterated over all $v_i$ until eventually, all rows in data are identical to new_obs.
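The sequential procedure above can be sketched in base R without flashlight. This is a minimal illustration under stated assumptions, not the package's implementation: breakdown_sketch() is a hypothetical helper, averages are unweighted, and the visit order is simply the order of v.

```r
# Hypothetical helper illustrating sequential breakdown (not flashlight code).
# Assumes unweighted averages and a fixed visit order given by `v`.
breakdown_sketch <- function(model, data, new_obs, v) {
  baseline <- mean(predict(model, data))   # average prediction on query data
  contrib <- setNames(numeric(length(v)), v)
  current <- data
  previous <- baseline
  for (var in v) {
    # Set column `var` in every row to the value of the observation to explain
    current[[var]] <- new_obs[[var]]
    now <- mean(predict(model, current))
    contrib[[var]] <- now - previous       # change in average prediction
    previous <- now
  }
  list(baseline = baseline, contributions = contrib)
}

fit <- lm(Sepal.Length ~ ., data = iris)
v <- c("Sepal.Width", "Petal.Length", "Petal.Width", "Species")
bd <- breakdown_sketch(fit, iris, new_obs = iris[1, ], v = v)
bd$contributions

# After the last step all rows equal new_obs, so baseline plus contributions
# reproduces the prediction for new_obs
all.equal(bd$baseline + sum(bd$contributions), unname(predict(fit, iris[1, ])))
```

The telescoping sum is why the contributions are additive; for non-additive models, the individual values depend on the visit order.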
A complication with this approach is that the visit order is relevant, at least for non-additive models. Ideally, the algorithm could be repeated for all possible permutations of v and its results averaged per variable. This is basically what SHAP values do, see the reference below for an explanation. Unfortunately, there is no efficient way to do this in a model-agnostic way. We offer two visit strategies to approximate SHAP:
"importance": Using the short-cut described in the reference below: The variables are sorted by the size of their contribution in the same way as the breakdown algorithm, but without iteration, i.e., starting from the original query data for each variable $v_i$.
"permutation": Averages contributions from a small number of random permutations of v.
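The "importance" short-cut can be sketched as follows. Assumptions: importance_order_sketch() is a hypothetical helper, averages are unweighted, and "size" is taken to mean the absolute value of each single-variable contribution.

```r
# Hypothetical helper: visit order by the "importance" short-cut.
# Each variable is evaluated starting from the ORIGINAL query data,
# i.e., without the sequential iteration of the breakdown algorithm.
importance_order_sketch <- function(model, data, new_obs, v) {
  base <- mean(predict(model, data))
  single <- vapply(v, function(var) {
    modified <- data                       # always start from original data
    modified[[var]] <- new_obs[[var]]
    mean(predict(model, modified)) - base  # single-variable contribution
  }, numeric(1))
  # Assumed: sort by absolute size, largest first
  v[order(abs(single), decreasing = TRUE)]
}

fit <- lm(Sepal.Length ~ ., data = iris)
v <- c("Sepal.Width", "Petal.Length", "Petal.Width", "Species")
visit_order <- importance_order_sketch(fit, iris, iris[1, ], v)
visit_order
```

The resulting order would then be used once in the sequential breakdown, which costs only one extra pass over the variables compared with averaging over many random permutations.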
Note that the minimum required elements in the (multi-)flashlight are a "predict_function", "model", and "data". The latter can also be passed directly to light_breakdown(). By default, no retransformation function is applied.
An object of class "light_breakdown" with the following elements:
data
A tibble with results.
by
Same as input by.
light_breakdown(default)
: Default method not implemented yet.
light_breakdown(flashlight)
: Variable attribution to a single observation for a flashlight.
light_breakdown(multiflashlight)
: Variable attribution to a single observation for a multiflashlight.
Gosiewska A. and Biecek P. (2019). iBreakDown: Uncertainty of model explanations for non-additive predictive models. arXiv:1903.11420.
fit_part <- lm(Sepal.Length ~ Species + Petal.Length, data = iris)
fl_part <- flashlight(
  model = fit_part, label = "part", data = iris, y = "Sepal.Length"
)
plot(light_breakdown(fl_part, new_obs = iris[1, ]))
# Second model
fit_full <- lm(Sepal.Length ~ ., data = iris)
fl_full <- flashlight(
  model = fit_full, label = "full", data = iris, y = "Sepal.Length"
)
fls <- multiflashlight(list(fl_part, fl_full))
plot(light_breakdown(fls, new_obs = iris[1, ]))
Checks if an object of class "flashlight" or "multiflashlight" is consistently defined.
light_check(x, ...)
## Default S3 method:
light_check(x, ...)
## S3 method for class 'flashlight'
light_check(x, ...)
## S3 method for class 'multiflashlight'
light_check(x, ...)
x |
An object of class "flashlight" or "multiflashlight". |
... |
Further arguments passed from or to other methods. |
The input x or an error message.
light_check(default)
: Default check method not implemented yet.
light_check(flashlight)
: Checks if a flashlight object is consistently defined.
light_check(multiflashlight)
: Checks if a multiflashlight object is consistently defined.
fit <- lm(Sepal.Length ~ ., data = iris)
fit_log <- lm(log(Sepal.Length) ~ ., data = iris)
fl <- flashlight(model = fit, data = iris, y = "Sepal.Length", label = "ols")
fl_log <- flashlight(model = fit_log, y = "Sepal.Length", label = "ols", linkinv = exp)
light_check(fl)
light_check(fl_log)
Combines a list of similar objects, each of class "light", by row binding their data.frame slots and retaining the other slots from the first list element.
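The combination rule can be sketched in base R. This is an illustration only: combine_sketch() and the class names "light_demo" and "light_demo_multi" are hypothetical, and the assumption that additional classes are prepended is not taken from the package source.

```r
# Hypothetical helper mimicking the combination rule described above:
# data.frame slots are row-bound, all other slots come from element 1.
combine_sketch <- function(x, new_class = NULL) {
  out <- x[[1]]
  for (nm in names(out)) {
    if (is.data.frame(out[[nm]])) {
      # Row-bind slot `nm` across all list elements
      out[[nm]] <- do.call(rbind, lapply(x, function(obj) obj[[nm]]))
    }
  }
  # Assumed: additional classes are placed in front of the existing ones
  class(out) <- union(new_class, class(out))
  out
}

# Two hypothetical "light"-like objects with one data slot each
a <- structure(list(data = data.frame(label = "a", value = 1), by = NULL),
               class = c("light_demo", "light", "list"))
b <- structure(list(data = data.frame(label = "b", value = 2), by = NULL),
               class = c("light_demo", "light", "list"))
comb <- combine_sketch(list(a, b), new_class = "light_demo_multi")
comb$data
class(comb)
```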
light_combine(x, ...)
## Default S3 method:
light_combine(x, ...)
## S3 method for class 'light'
light_combine(x, new_class = NULL, ...)
## S3 method for class 'list'
light_combine(x, new_class = NULL, ...)
x |
A list of objects of the same class. |
... |
Further arguments passed from or to other methods. |
new_class |
An optional vector with additional class names to be added to the output. |
If x is a list, an object like each element, but with the rows of the data slots combined.
light_combine(default)
: Default method not implemented yet.
light_combine(light)
: Since there is nothing to combine, the input is returned unchanged, except for the additional classes.
light_combine(list)
: Combine a list of similar light objects.
fit_lm <- lm(Sepal.Length ~ ., data = iris)
fit_glm <- glm(Sepal.Length ~ ., family = Gamma(link = "log"), data = iris)
mod_lm <- flashlight(model = fit_lm, label = "lm", data = iris, y = "Sepal.Length")
mod_glm <- flashlight(
  model = fit_glm, label = "glm", data = iris, y = "Sepal.Length",
  predict_function = function(object, newdata) predict(object, newdata, type = "response")
)
mods <- multiflashlight(list(mod_lm, mod_glm))
perf_lm <- light_performance(mod_lm)
perf_glm <- light_performance(mod_glm)
manual_comb <- light_combine(
  list(perf_lm, perf_glm),
  new_class = "light_performance_multi"
)
auto_comb <- light_performance(mods)
all.equal(manual_comb, auto_comb)
Calculates response, prediction, partial dependence, and ALE profiles of a (multi-)flashlight with respect to a covariable v.
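To make the difference between these profile types concrete, here is a base-R sketch. Assumptions: pd_sketch() is a hypothetical helper, averages are unweighted, and the grid and the 5 equal-width bins are arbitrary choices rather than flashlight's n_bins/cut_type logic. Partial dependence sets v to each grid value in all rows and averages the predictions; response and prediction profiles are averages within bins of v.

```r
# Hypothetical helper: partial dependence of `v` over a grid (unweighted).
pd_sketch <- function(model, data, v, grid) {
  vapply(grid, function(value) {
    data[[v]] <- value             # same value of v in every row (local copy)
    mean(predict(model, data))     # average prediction at this grid value
  }, numeric(1))
}

fit <- lm(Sepal.Length ~ ., data = iris)
grid <- seq(min(iris$Petal.Length), max(iris$Petal.Length), length.out = 11)
pd <- pd_sketch(fit, iris, "Petal.Length", grid)

# Response and prediction profiles: averages within bins of v instead
bins <- cut(iris$Petal.Length, breaks = 5, include.lowest = TRUE)
response_profile <- tapply(iris$Sepal.Length, bins, mean)
predicted_profile <- tapply(predict(fit, iris), bins, mean)
```

For this linear model, the partial dependence profile is a straight line whose slope equals the Petal.Length coefficient, while the binned response and prediction profiles also reflect correlated covariables.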
light_effects(x, ...)
## Default S3 method:
light_effects(x, ...)
## S3 method for class 'flashlight'
light_effects(
  x,
  v,
  data = NULL,
  by = x$by,
  stats = "mean",
  breaks = NULL,
  n_bins = 11L,
  cut_type = c("equal", "quantile"),
  use_linkinv = TRUE,
  counts_weighted = FALSE,
  v_labels = TRUE,
  pred = NULL,
  pd_indices = NULL,
  pd_n_max = 1000L,
  pd_seed = NULL,
  ale_two_sided = TRUE,
  ...
)
## S3 method for class 'multiflashlight'
light_effects(
  x,
  v,
  data = NULL,
  breaks = NULL,
  n_bins = 11L,
  cut_type = c("equal", "quantile"),
  ...
)
x |
An object of class "flashlight" or "multiflashlight". |
... |
Further arguments passed to |
v |
The variable name to be profiled. |
data |
An optional data.frame. |
by |
An optional vector of column names used to additionally group the results. |
stats |
Deprecated. Will be removed in version 1.1.0. |
breaks |
Cut breaks for a numeric v. |
n_bins |
Approximate number of unique values to evaluate for numeric v. |
cut_type |
Should a numeric v be cut into "equal" or "quantile" bins? |
use_linkinv |
Should retransformation function be applied? Default is TRUE. |
counts_weighted |
Should counts be weighted by the case weights?
If |
v_labels |
If |
pred |
Optional vector with predictions (after application of inverse link). Can be used to avoid recalculating predictions if the function is called repeatedly for different v. |
pd_indices |
A vector of row numbers to consider in calculating partial dependence profiles and "ale". |
pd_n_max |
Maximum number of ICE profiles to calculate (will be randomly
picked from |
pd_seed |
Integer random seed used to select ICE profiles for partial dependence and ALE. |
ale_two_sided |
If |
Note that ALE profiles are calibrated by (weighted) average predictions, so their level may differ considerably from that of the partial dependence profiles.
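ALE (accumulated local effects) profiles average the local change of predictions within bins of v and accumulate these averages across bins. The following base-R sketch is illustrative only: ale_sketch() is hypothetical, it uses quantile bins and unweighted averages, and its calibration (shifting the curve so that its average over the rows equals the average prediction) is an assumption; flashlight's binning and calibration details may differ.

```r
# Hypothetical helper: first-order ALE for a numeric variable `v`.
ale_sketch <- function(model, data, v, n_bins = 10) {
  x <- data[[v]]
  # Quantile-based bin edges; duplicated edges (from ties) are dropped
  breaks <- unique(quantile(x, probs = seq(0, 1, length.out = n_bins + 1), names = FALSE))
  n_int <- length(breaks) - 1
  bin <- cut(x, breaks = breaks, include.lowest = TRUE, labels = FALSE)

  # Move each row to the lower and upper edge of its own bin
  lower <- data
  upper <- data
  lower[[v]] <- breaks[bin]
  upper[[v]] <- breaks[bin + 1]
  local_effect <- predict(model, upper) - predict(model, lower)

  # Average local effect per bin (0 for empty bins), then accumulate
  avg_effect <- tapply(local_effect, factor(bin, levels = seq_len(n_int)), mean)
  avg_effect[is.na(avg_effect)] <- 0
  ale <- c(0, cumsum(avg_effect))

  # Assumed calibration: the profile, evaluated at each row's upper bin edge,
  # averages to the mean prediction
  ale <- ale - mean(ale[bin + 1]) + mean(predict(model, data))
  data.frame(value = breaks, ale = unname(ale))
}

fit <- lm(Sepal.Length ~ ., data = iris)
ale <- ale_sketch(fit, iris, "Petal.Length")
ale
```

Because the local effects only compare predictions within narrow bins, ALE avoids evaluating the model far away from the observed data, which is the main motivation compared with partial dependence.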
An object of class "light_effects" with the following elements:
response
: A tibble containing the response profiles.
Column names can be controlled by options(flashlight.column_name).
predicted
: A tibble containing the prediction profiles.
pd
: A tibble containing the partial dependence profiles.
ale
: A tibble containing the ALE profiles.
by
: Same as input by.
v
: The variable(s) evaluated.
light_effects(default)
: Default method.
light_effects(flashlight)
: Profiles for a flashlight object.
light_effects(multiflashlight)
: Effect profiles for a multiflashlight object.
light_profile(), plot.light_effects()
fit_lin <- lm(Sepal.Length ~ ., data = iris)
fl_lin <- flashlight(model = fit_lin, label = "lin", data = iris, y = "Sepal.Length")
# PDP, average response, and average prediction for Petal.Length
eff <- light_effects(fl_lin, v = "Petal.Length")
plot(eff)
# PDP and ALE
plot(eff, use = c("pd", "ale"), recode_labels = c(ale = "ALE"))
# Second model with non-linear Petal.Length effect
fit_nonlin <- lm(Sepal.Length ~ . + I(Petal.Length^2), data = iris)
fl_nonlin <- flashlight(
  model = fit_nonlin, label = "nonlin", data = iris, y = "Sepal.Length"
)
fls <- multiflashlight(list(fl_lin, fl_nonlin))
# PDP and ALE
plot(light_effects(fls, v = "Petal.Length"), use = c("pd", "ale"))
Model predictions are modelled by a single decision tree, serving as an easy-to-interpret surrogate for the original model. As suggested in Molnar (see reference below), the quality of the surrogate tree can be measured by its R-squared. The size of the tree can be modified by passing ... arguments to rpart::rpart().
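The surrogate idea can be sketched directly with rpart, a recommended package shipped with standard R installations. Assumptions: the tree is grown on the model's predictions (not on the response), maxdepth = 3 is an arbitrary illustrative choice, and R-squared is computed as one minus the residual sum of squares of the tree relative to the total sum of squares of the predictions.

```r
library(rpart)

fit <- lm(Sepal.Length ~ ., data = iris)
X <- iris[, c("Sepal.Width", "Petal.Length", "Petal.Width", "Species")]
X$pred <- predict(fit, iris)   # surrogate target: model predictions, not y

# Tree size is controlled via rpart/rpart.control arguments such as maxdepth
tree <- rpart(pred ~ ., data = X, control = rpart.control(maxdepth = 3))

# R-squared of the surrogate tree with respect to the original predictions
r_squared <- 1 - sum((X$pred - predict(tree, X))^2) /
  sum((X$pred - mean(X$pred))^2)
r_squared
```

A value close to 1 means the small tree mimics the model well, so its splits can serve as a rough global explanation.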
light_global_surrogate(x, ...)
## Default S3 method:
light_global_surrogate(x, ...)
## S3 method for class 'flashlight'
light_global_surrogate(
  x,
  data = x$data,
  by = x$by,
  v = NULL,
  use_linkinv = TRUE,
  n_max = Inf,
  seed = NULL,
  keep_max_levels = 4L,
  ...
)
## S3 method for class 'multiflashlight'
light_global_surrogate(x, ...)
x |
An object of class "flashlight" or "multiflashlight". |
... |
Arguments passed to rpart::rpart(). |
data |
An optional data.frame. |
by |
An optional vector of column names used to additionally group the results. For each group, a separate tree is grown. |
v |
Vector of variables used in the surrogate model.
Defaults to all variables in |
use_linkinv |
Should retransformation function be applied? Default is TRUE. |
n_max |
Maximum number of data rows to consider to build the tree. |
seed |
An integer random seed used to select data rows if n_max is smaller than the number of rows in data. |
keep_max_levels |
Number of levels of categorical and factor variables to keep.
Other levels are combined to a level "Other". This prevents |
An object of class "light_global_surrogate" with the following elements:
data
A tibble with results.
by
Same as input by.
light_global_surrogate(default)
: Default method not implemented yet.
light_global_surrogate(flashlight)
: Surrogate model for a flashlight.
light_global_surrogate(multiflashlight)
: Surrogate model for a multiflashlight.
Molnar C. (2019). Interpretable Machine Learning.
fit <- lm(Sepal.Length ~ ., data = iris)
x <- flashlight(model = fit, label = "lm", data = iris)
sur <- light_global_surrogate(x)
sur$data$r_squared
plot(sur)
Generates Individual Conditional Expectation (ICE) profiles. An ICE profile shows how the prediction of an observation changes if one or multiple variables are systematically changed across their ranges, holding all other values fixed (see the reference below for details). The curves can be centered in order to increase visibility of interaction effects.
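A base-R sketch of ICE curves and centering, under stated assumptions: ice_sketch() is hypothetical, the grid and the three selected rows are arbitrary, and only centering at the first grid value is shown (one of several centering choices). With an additive model, all centered curves coincide; with an interaction, they fan out.

```r
# Hypothetical helper: ICE curves of `v` for the rows of `data`.
ice_sketch <- function(model, data, v, grid, center_first = FALSE) {
  curves <- t(vapply(seq_len(nrow(data)), function(i) {
    obs <- data[rep(i, length(grid)), , drop = FALSE]  # copies of row i
    obs[[v]] <- grid                                    # vary only v
    predict(model, obs)
  }, numeric(length(grid))))
  # Subtract each curve's value at the first grid point
  if (center_first) curves <- curves - curves[, 1]
  curves  # one row per observation, one column per grid value
}

fit_add <- lm(Sepal.Length ~ ., data = iris)
fit_int <- lm(Sepal.Length ~ . + Sepal.Width:Species, data = iris)
rows <- iris[c(1, 51, 101), ]   # one flower per species
grid <- seq(2, 4.4, by = 0.6)
ice_add <- ice_sketch(fit_add, rows, "Sepal.Width", grid, center_first = TRUE)
ice_int <- ice_sketch(fit_int, rows, "Sepal.Width", grid, center_first = TRUE)
ice_add
ice_int
```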
light_ice(x, ...)
## Default S3 method:
light_ice(x, ...)
## S3 method for class 'flashlight'
light_ice(
  x,
  v = NULL,
  data = x$data,
  by = x$by,
  evaluate_at = NULL,
  breaks = NULL,
  grid = NULL,
  n_bins = 27L,
  cut_type = c("equal", "quantile"),
  indices = NULL,
  n_max = 20L,
  seed = NULL,
  use_linkinv = TRUE,
  center = c("no", "first", "middle", "last", "mean", "0"),
  ...
)
## S3 method for class 'multiflashlight'
light_ice(x, ...)
x |
An object of class "flashlight" or "multiflashlight". |
... |
Further arguments passed to or from other methods. |
v |
The variable name to be profiled. |
data |
An optional data.frame. |
by |
An optional vector of column names used to additionally group the results. |
evaluate_at |
Vector with values of v used to evaluate the profile. |
breaks |
Cut breaks for a numeric v. |
grid |
A data.frame with the evaluation grid. |
n_bins |
Approximate number of unique values to evaluate for numeric v. |
cut_type |
Should a numeric v be cut into "equal" or "quantile" bins? |
indices |
A vector of row numbers to consider. |
n_max |
If indices is not given, maximum number of rows to consider. |
seed |
An integer random seed. |
use_linkinv |
Should retransformation function be applied? Default is TRUE. |
center |
How should curves be centered? One of "no", "first", "middle", "last", "mean", or "0". |
There are two ways to specify the variable(s) to be profiled.
Pass the variable name via v and an optional vector with evaluation points evaluate_at (or breaks). This works for dependence on a single variable.
More general: Specify any grid as a data.frame with one or more columns. For instance, it can be generated by a call to expand.grid().
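The grid route can be illustrated in base R without flashlight. This sketch profiles one observation over a two-column grid built with expand.grid(); the grid values and the extra Sepal.Width:Petal.Width interaction are illustrative assumptions. Each grid row is combined with the non-profiled values of the observation.

```r
fit <- lm(Sepal.Length ~ . + Sepal.Width:Petal.Width, data = iris)

# Two-dimensional evaluation grid (illustrative values)
grid <- expand.grid(Sepal.Width = c(2.5, 3, 3.5), Petal.Width = c(0.5, 1.5))

obs <- iris[rep(1, nrow(grid)), ]   # one copy of observation 1 per grid row
obs[names(grid)] <- grid            # overwrite only the profiled columns
profile <- cbind(grid, pred = predict(fit, obs))
profile
```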
The minimum required elements in the (multi-)flashlight are "predict_function", "model", "linkinv" and "data", where the latter can be passed on the fly.
Which rows in data are profiled? This is specified by indices. If not given and n_max is smaller than the number of rows in data, then row indices will be sampled randomly from data. If the same rows should be used for all flashlights in a multiflashlight, there are two options: Either pass a seed or a vector of indices used to select rows. In both cases, data should be the same for all flashlights considered.
An object of class "light_ice" with the following elements:
data
A tibble containing the results.
by
Same as input by.
v
The variable(s) evaluated.
center
How centering was done.
light_ice(default)
: Default method not implemented yet.
light_ice(flashlight)
: ICE profiles for a flashlight object.
light_ice(multiflashlight)
: ICE profiles for a multiflashlight object.
Goldstein A. et al. (2015). Peeking inside the black box: Visualizing statistical learning with plots of individual conditional expectation. Journal of Computational and Graphical Statistics, 24(1) <doi:10.1080/10618600.2014.907095>.
light_profile(), plot.light_ice()
fit_add <- lm(Sepal.Length ~ ., data = iris)
fl_add <- flashlight(model = fit_add, label = "additive", data = iris)
plot(light_ice(fl_add, v = "Sepal.Width", n_max = 200), alpha = 0.2)
plot(light_ice(fl_add, v = "Sepal.Width", n_max = 200, center = "first"))
# Second model with interactions
fit_nonadd <- lm(Sepal.Length ~ . + Sepal.Width:Species, data = iris)
fl_nonadd <- flashlight(model = fit_nonadd, label = "nonadditive", data = iris)
fls <- multiflashlight(list(fl_add, fl_nonadd))
plot(light_ice(fls, v = "Sepal.Width", by = "Species", n_max = 200), alpha = 0.2)
plot(light_ice(fls, v = "Sepal.Width", by = "Species", n_max = 200, center = "middle"))
Importance of variable v is measured as the drop in performance caused by permuting the values of v, see Fisher et al. (2018) (reference below).
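Permutation importance with repetitions can be sketched in base R. Assumptions: perm_importance_sketch() is hypothetical, the metric is an unweighted RMSE, and the standard error is taken as the standard deviation of the drops divided by the square root of the number of repetitions. A variable the model does not use yields exactly zero importance.

```r
# Hypothetical helper: permutation importance as increase in RMSE.
perm_importance_sketch <- function(model, data, y, v, m_repetitions = 4, seed = 1) {
  set.seed(seed)  # note: changes the global random number stream
  rmse <- function(actual, predicted) sqrt(mean((actual - predicted)^2))
  base <- rmse(data[[y]], predict(model, data))
  t(vapply(v, function(var) {
    drops <- replicate(m_repetitions, {
      shuffled <- data
      shuffled[[var]] <- sample(shuffled[[var]])  # break link between var and y
      rmse(data[[y]], predict(model, shuffled)) - base
    })
    c(importance = mean(drops), se = sd(drops) / sqrt(m_repetitions))
  }, numeric(2)))
}

fit_part <- lm(Sepal.Length ~ Species + Petal.Length, data = iris)
imp <- perm_importance_sketch(
  fit_part, iris, y = "Sepal.Length",
  v = c("Species", "Petal.Length", "Sepal.Width")
)
imp
```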
light_importance(x, ...)
## Default S3 method:
light_importance(x, ...)
## S3 method for class 'flashlight'
light_importance(
  x,
  data = x$data,
  by = x$by,
  type = c("permutation", "shap"),
  v = NULL,
  n_max = Inf,
  seed = NULL,
  m_repetitions = 1L,
  metric = x$metrics[1L],
  lower_is_better = TRUE,
  use_linkinv = FALSE,
  ...
)
## S3 method for class 'multiflashlight'
light_importance(x, ...)
x |
An object of class "flashlight" or "multiflashlight". |
... |
Further arguments passed to |
data |
An optional data.frame. |
by |
An optional vector of column names used to additionally group the results. |
type |
Type of importance: "permutation" (currently the only option). |
v |
Vector of variable names to assess importance for.
Defaults to all variables in |
n_max |
Maximum number of rows to consider. |
seed |
An integer random seed used to select and shuffle rows. |
m_repetitions |
Number of permutations. Defaults to 1. A value above 1 provides more stable estimates of variable importance and allows the calculation of standard errors measuring the uncertainty from permuting. |
metric |
An optional named list of length one with a metric as element.
Defaults to the first metric in the flashlight. The metric needs to be a function
with at least four arguments: actual, predicted, case weights w and ... (further arguments). |
lower_is_better |
Logical flag indicating if lower values in the metric
are better or not. If set to |
use_linkinv |
Should retransformation function be applied? Default is FALSE. |
The minimum required elements in the (multi-)flashlight are "y", "predict_function", "model", "data" and "metrics".
An object of class "light_importance" with the following elements:
data
A tibble with results.
by
Same as input by.
type
Same as input type. For information only.
light_importance(default)
: Default method not implemented yet.
light_importance(flashlight)
: Variable importance for a flashlight.
light_importance(multiflashlight)
: Variable importance for a multiflashlight.
Fisher A., Rudin C., Dominici F. (2018). All Models are Wrong but many are Useful: Variable Importance for Black-Box, Proprietary, or Misspecified Prediction Models, using Model Class Reliance. arXiv.
most_important(), plot.light_importance()
fit_part <- lm(Sepal.Length ~ Species + Petal.Length, data = iris)
fl_part <- flashlight(
  model = fit_part, label = "part", data = iris, y = "Sepal.Length"
)
# No effect of some variables (incl. standard errors)
plot(light_importance(fl_part, m_repetitions = 4), fill = "chartreuse4")
# Second model includes all variables
fit_full <- lm(Sepal.Length ~ ., data = iris)
fl_full <- flashlight(
  model = fit_full, label = "full", data = iris, y = "Sepal.Length"
)
fls <- multiflashlight(list(fl_part, fl_full))
plot(light_importance(fls), fill = "chartreuse4")
plot(light_importance(fls, by = "Species"))
This function provides Friedman's H statistic for overall interaction strength per covariable as well as its version for pairwise interactions, see the reference below.
light_interaction(x, ...)
## Default S3 method:
light_interaction(x, ...)
## S3 method for class 'flashlight'
light_interaction(
  x,
  data = x$data,
  by = x$by,
  v = NULL,
  pairwise = FALSE,
  type = c("H", "ice"),
  normalize = TRUE,
  take_sqrt = TRUE,
  grid_size = 200L,
  n_max = 1000L,
  seed = NULL,
  use_linkinv = FALSE,
  ...
)
## S3 method for class 'multiflashlight'
light_interaction(x, ...)
x |
An object of class "flashlight" or "multiflashlight". |
... |
Further arguments passed to or from other methods. |
data |
An optional data.frame. |
by |
An optional vector of column names used to additionally group the results. |
v |
Vector of variable names to be assessed. |
pairwise |
Should overall interaction strength per variable be shown or
pairwise interactions? Defaults to FALSE. |
type |
Are measures based on Friedman's H statistic ("H") or on "ice" curves?
Option "ice" is available only if pairwise = FALSE. |
normalize |
Should the variances explained be normalized?
Default is TRUE. |
take_sqrt |
In order to reproduce Friedman's H statistic,
resulting values are root transformed. Set to FALSE to return squared values. |
grid_size |
Grid size used to form the outer product. Will be randomly
picked from data (after limiting to n_max). |
n_max |
Maximum number of data rows to consider. Will be randomly picked
from data. |
seed |
An integer random seed used for subsampling. |
use_linkinv |
Should retransformation function be applied? Default is FALSE. |
As a fast alternative to assess overall interaction strength, with type = "ice", the function offers a method based on centered ICE curves: The corresponding H* statistic measures how much of the variability of a c-ICE curve is unexplained by the main effect. As for Friedman's H statistic, it can be useful to consider unnormalized or squared values (see Details below).
Friedman's H statistic relates the interaction strength of a variable (pair) to the total effect strength of that variable (pair) based on partial dependence curves. Due to this normalization step, even variables with low importance can have high values for H. The function light_interaction() offers the option to skip normalization in order to have a more direct comparison of the interaction effects across variables (or variable pairs). The values of such unnormalized H statistics are on the scale of the response variable. Use take_sqrt = FALSE to return squared values of H. Note that, in general, for each variable (pair), predictions are made on a data set with grid_size * n_max rows, so be cautious when increasing the defaults. Still, even with larger grid_size and n_max, results can vary considerably across runs; thus, setting a seed is recommended.
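To make the pairwise statistic concrete, here is a brute-force base-R sketch. It assumes the usual pairwise formula H^2 = sum((PD_jk - PD_j - PD_k)^2) / sum(PD_jk^2), with centered partial dependence functions evaluated at the observed rows. The helpers pd_at_rows() and friedman_h() are hypothetical and unweighted; the cost grows quadratically with the number of rows, which is why subsampling (compare n_max) matters.

```r
# Hypothetical helper: partial dependence of `vars`, evaluated at each
# row's own values of `vars` (background data = all rows of `data`)
pd_at_rows <- function(model, data, vars) {
  vapply(seq_len(nrow(data)), function(i) {
    d <- data
    for (v in vars) d[[v]] <- rep(data[[v]][i], nrow(d))
    mean(predict(model, newdata = d))
  }, numeric(1))
}

# Hypothetical helper: Friedman's H for the pair (j, k)
friedman_h <- function(model, data, j, k, normalize = TRUE, take_sqrt = TRUE) {
  center <- function(z) z - mean(z)
  pd_j  <- center(pd_at_rows(model, data, j))
  pd_k  <- center(pd_at_rows(model, data, k))
  pd_jk <- center(pd_at_rows(model, data, c(j, k)))
  num <- sum((pd_jk - pd_j - pd_k)^2)   # variability of the pure interaction
  out <- if (normalize) num / sum(pd_jk^2) else num / nrow(data)
  if (take_sqrt) sqrt(out) else out
}

set.seed(1)
X <- iris[sample(nrow(iris), 50), ]  # subsample, in the spirit of n_max
fit_add <- lm(Sepal.Length ~ ., data = iris)
fit_int <- lm(Sepal.Length ~ . + Petal.Width:Species, data = iris)
friedman_h(fit_add, X, "Petal.Width", "Species")  # about 0
friedman_h(fit_int, X, "Petal.Width", "Species")  # > 0
```

With normalize = FALSE and take_sqrt = TRUE, the result is a root mean squared interaction effect and therefore on the scale of the response, matching the description above.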
The minimal required elements in the (multi-)flashlight are "predict_function", "model", and "data".
An object of class "light_importance" with the following elements:
data
A tibble containing the results. Can be used to build fully customized visualizations. Column names can be controlled by options(flashlight.column_name).
by
Same as input by.
type
Same as input type. For information only.
light_interaction(default)
: Default method not implemented yet.
light_interaction(flashlight)
: Interaction strengths for a flashlight object.
light_interaction(multiflashlight)
: Interaction strengths for a multiflashlight object.
Friedman J. H. and Popescu B. E. (2008). Predictive learning via rule ensembles. The Annals of Applied Statistics, 2:916–954.
# First model with interactions
fit_nonadd <- lm(
  Sepal.Length ~ . + Sepal.Width:Species + Petal.Width:Species,
  data = iris
)
fl_nonadd <- flashlight(
  model = fit_nonadd, label = "nonadditive", data = iris, y = "Sepal.Length"
)

# Friedman's H per feature
plot(light_interaction(fl_nonadd), fill = "chartreuse4")

# Unnormalized H^2 (on the squared scale of the response)
plot(
  light_interaction(fl_nonadd, normalize = FALSE, take_sqrt = FALSE),
  fill = "chartreuse4"
)

# Pairwise H
plot(light_interaction(fl_nonadd, pairwise = TRUE), fill = "chartreuse4")

# Second model without interactions
fit_add <- lm(Sepal.Length ~ ., data = iris)
fl_add <- flashlight(
  model = fit_add, label = "additive", data = iris, y = "Sepal.Length"
)

fls <- multiflashlight(list(fl_add, fl_nonadd))
plot(light_interaction(fls), fill = "chartreuse4")
Calculates the performance of a flashlight with respect to one or more performance measures.
light_performance(x, ...)

## Default S3 method:
light_performance(x, ...)

## S3 method for class 'flashlight'
light_performance(
  x,
  data = x$data,
  by = x$by,
  metrics = x$metrics,
  use_linkinv = FALSE,
  ...
)

## S3 method for class 'multiflashlight'
light_performance(x, ...)
x |
An object of class "flashlight" or "multiflashlight". |
... |
Arguments passed from or to other functions. |
data |
An optional data.frame. |
by |
An optional vector of column names used to additionally group the results.
Will overwrite x$by. |
metrics |
An optional named list with metrics. Each metric takes at least
four arguments: actual, predicted, case weights w, and further arguments (...). |
use_linkinv |
Should the retransformation function be applied? Default is FALSE. |
The minimal required elements in the (multi-)flashlight are "y", "predict_function", "model", "data" and "metrics". The latter two can also be passed directly to light_performance(). Note that, by default, no retransformation function is applied.
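Any function following the signature described for the metrics argument can serve as a metric. The following sketch defines a hypothetical weighted RMSE in base R; its (actual, predicted, w, ...) interface is assumed here to mirror MetricsWeighted::rmse(), the default metric.

```r
# Hypothetical custom metric with the (actual, predicted, w, ...) signature
my_rmse <- function(actual, predicted, w = NULL, ...) {
  stopifnot(length(actual) == length(predicted))
  if (is.null(w)) w <- rep(1, length(actual))  # no case weights: equal weights
  sqrt(weighted.mean((actual - predicted)^2, w = w))
}

my_rmse(c(1, 2, 3), c(1, 2, 5))                  # sqrt(4 / 3)
my_rmse(c(1, 2, 3), c(1, 2, 5), w = c(1, 1, 2))  # sqrt(8 / 4) = sqrt(2)
```

Such a function could then be supplied as metrics = list(rmse = my_rmse), either to flashlight() or directly to light_performance().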
An object of class "light_performance" with the following elements:
data
: A tibble containing the results.
by
Same as input by.
light_performance(default)
: Default method not implemented yet.
light_performance(flashlight)
: Model performance of flashlight object.
light_performance(multiflashlight)
: Model performance of multiflashlight object.
fit_part <- lm(Sepal.Length ~ Species + Petal.Length, data = iris)
fl_part <- flashlight(
  model = fit_part, label = "part", data = iris, y = "Sepal.Length"
)
plot(light_performance(fl_part, by = "Species"), fill = "chartreuse4")

# Second model
fit_full <- lm(Sepal.Length ~ ., data = iris)
fl_full <- flashlight(
  model = fit_full, label = "full", data = iris, y = "Sepal.Length"
)

fls <- multiflashlight(list(fl_part, fl_full))
plot(light_performance(fls, by = "Species"))
plot(light_performance(fls, by = "Species"), swap_dim = TRUE)
Calculates different types of profiles across covariable values. By default, partial dependence profiles are calculated (see Friedman). Other options are profiles of ALE (accumulated local effects, see Apley), response, predicted values ("M plots" or "marginal plots", see Apley), and residuals. The results are aggregated either by (weighted) means or by (weighted) quartiles.
Note that ALE profiles are calibrated by (weighted) average predictions. In contrast to the suggestions in Apley, we calculate ALE profiles of factors in the same order as the factor levels. They are not reordered based on the similarity of other variables.
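The ALE idea can be sketched in base R for a numeric variable. The hypothetical ale_profile() below is a simplified illustration, not flashlight's code: it uses quantile-based intervals, averages the local prediction differences of the observations within each interval, accumulates them, and finally shifts the curve so that its average over the observed values equals the average prediction (approximated by linear interpolation). Case weights, factors and ale_two_sided are not covered.

```r
# Hypothetical helper: ALE profile of numeric variable v, evaluated at bin edges
ale_profile <- function(model, data, v, n_bins = 10L) {
  x <- data[[v]]
  z <- unique(quantile(x, probs = seq(0, 1, length.out = n_bins + 1L),
                       names = FALSE))
  n_int <- length(z) - 1L
  # Interval index (1, ..., n_int) of each observation
  bin <- findInterval(x, z, rightmost.closed = TRUE, all.inside = TRUE)
  pred_at <- function(d, value) {
    d[[v]] <- value
    predict(model, newdata = d)
  }
  # Average local effect per interval: f(upper edge) - f(lower edge)
  local_eff <- vapply(seq_len(n_int), function(k) {
    d <- data[bin == k, , drop = FALSE]
    if (nrow(d) == 0L) return(0)
    mean(pred_at(d, z[k + 1L]) - pred_at(d, z[k]))
  }, numeric(1))
  ale <- c(0, cumsum(local_eff))  # uncentered ALE at the edges z
  # Calibrate: average ALE over observed x equals average prediction
  ale_obs <- approx(z, ale, xout = x)$y
  shift <- mean(predict(model, newdata = data)) - mean(ale_obs)
  data.frame(grid = z, ale = ale + shift)
}

fit <- lm(Sepal.Length ~ ., data = iris)
ale_profile(fit, iris, "Petal.Length")
```

Because only prediction differences within small intervals enter, ALE avoids evaluating the model far outside the observed joint distribution, which is its main advantage over partial dependence with correlated features.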
light_profile(x, ...)

## Default S3 method:
light_profile(x, ...)

## S3 method for class 'flashlight'
light_profile(
  x,
  v = NULL,
  data = NULL,
  by = x$by,
  type = c("partial dependence", "ale", "predicted", "response", "residual", "shap"),
  stats = "mean",
  breaks = NULL,
  n_bins = 11L,
  cut_type = c("equal", "quantile"),
  use_linkinv = TRUE,
  counts = TRUE,
  counts_weighted = FALSE,
  v_labels = TRUE,
  pred = NULL,
  pd_evaluate_at = NULL,
  pd_grid = NULL,
  pd_indices = NULL,
  pd_n_max = 1000L,
  pd_seed = NULL,
  pd_center = c("no", "first", "middle", "last", "mean", "0"),
  ale_two_sided = FALSE,
  ...
)

## S3 method for class 'multiflashlight'
light_profile(
  x,
  v = NULL,
  data = NULL,
  type = c("partial dependence", "ale", "predicted", "response", "residual", "shap"),
  breaks = NULL,
  n_bins = 11L,
  cut_type = c("equal", "quantile"),
  pd_evaluate_at = NULL,
  pd_grid = NULL,
  ...
)
x |
An object of class "flashlight" or "multiflashlight". |
... |
Further arguments passed to |
v |
The variable name to be profiled. |
data |
An optional data.frame. |
by |
An optional vector of column names used to additionally group the results. |
type |
Type of the profile: Either "partial dependence", "ale", "predicted", "response", or "residual". |
stats |
Deprecated. Will be removed in version 1.1.0. |
breaks |
Cut breaks for a numeric v. |
n_bins |
Approximate number of unique values to evaluate for numeric v. |
cut_type |
Should a numeric v be cut into "equal" or "quantile" bins? |
use_linkinv |
Should the retransformation function be applied? Default is TRUE. |
counts |
Should observation counts be added? |
counts_weighted |
If |
v_labels |
If |
pred |
Optional vector with predictions (after application of inverse link).
Can be used to avoid recalculating predictions over and over if the function
is to be called repeatedly for different v. |
pd_evaluate_at |
Vector with values of |
pd_grid |
A |
pd_indices |
A vector of row numbers to consider in calculating partial dependence and ALE profiles. |
pd_n_max |
Maximum number of ICE profiles to calculate (will be randomly
picked from data). |
pd_seed |
Integer random seed used to select ICE profiles for partial dependence and ALE. |
pd_center |
How should ICE curves be centered?
|
ale_two_sided |
If |
Numeric covariables v with more than n_bins distinct values are binned into n_bins bins. Alternatively, breaks can be provided to specify the binning. For partial dependence profiles (and partly also ALE profiles), this behaviour can be overridden either by providing a vector of evaluation points (pd_evaluate_at) or an evaluation grid (pd_grid). The latter is a data frame whose column(s) contain a (multi-)variate evaluation grid.
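A minimal base-R sketch of a partial dependence profile at user-chosen evaluation points, in the spirit of pd_evaluate_at, pd_n_max and pd_seed. The helper partial_dependence() is hypothetical, unweighted and restricted to a numeric v; it is not flashlight's implementation.

```r
# Hypothetical helper: partial dependence of numeric v at given points
partial_dependence <- function(model, data, v, evaluate_at,
                               n_max = 1000L, seed = NULL) {
  if (!is.null(seed)) set.seed(seed)
  # Limit the number of ICE curves by random subsampling
  if (nrow(data) > n_max) data <- data[sample(nrow(data), n_max), , drop = FALSE]
  vapply(evaluate_at, function(g) {
    data[[v]] <- g                        # modifies a local copy only
    mean(predict(model, newdata = data))  # average over all ICE curves
  }, numeric(1))
}

fit <- lm(Sepal.Length ~ ., data = iris)
partial_dependence(fit, iris, "Petal.Length", evaluate_at = c(2, 4, 6))
```

For a linear model, the profile is a straight line whose slope equals the coefficient of v, which makes it a convenient sanity check.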
For partial dependence, ALE, and prediction profiles, "model", "predict_function", "linkinv" and "data" are required. For response profiles, it is "y", "linkinv" and "data". "data" can also be passed on the fly.
An object of class "light_profile" with the following elements:
data
A tibble containing results.
by
Names of group by variable.
v
The variable(s) evaluated.
type
Same as input type. For information only.
light_profile(default)
: Default method not implemented yet.
light_profile(flashlight)
: Profiles for flashlight.
light_profile(multiflashlight)
: Profiles for multiflashlight.
Friedman J. H. (2001). Greedy function approximation: A gradient boosting machine. The Annals of Statistics, 29:1189–1232.
Apley D. W. (2016). Visualizing the effects of predictor variables in black box supervised learning models. arXiv:1612.08468.
light_effects(), plot.light_profile()
fit_lin <- lm(Sepal.Length ~ ., data = iris)
fl_lin <- flashlight(model = fit_lin, label = "lin", data = iris, y = "Sepal.Length")

# PDP by Species
plot(light_profile(fl_lin, v = "Petal.Length", by = "Species"))

# Average predicted
plot(light_profile(fl_lin, v = "Petal.Length", type = "predicted"))

# Second model with non-linear Petal.Length effect
fit_nonlin <- lm(Sepal.Length ~ . + I(Petal.Length^2), data = iris)
fl_nonlin <- flashlight(
  model = fit_nonlin, label = "nonlin", data = iris, y = "Sepal.Length"
)
fls <- multiflashlight(list(fl_lin, fl_nonlin))

# PDP by Species
plot(light_profile(fls, v = "Petal.Length", by = "Species"))
plot(light_profile(fls, v = "Petal.Length", by = "Species"), swap_dim = TRUE)

# Average residuals (calibration)
plot(light_profile(fls, v = "Petal.Length", type = "residual"))
Calculates different types of 2D-profiles across two variables. By default, partial dependence profiles are calculated (see Friedman). Other options are response, predicted values, and residuals. The results are aggregated by (weighted) means.
light_profile2d(x, ...)

## Default S3 method:
light_profile2d(x, ...)

## S3 method for class 'flashlight'
light_profile2d(
  x,
  v = NULL,
  data = NULL,
  by = x$by,
  type = c("partial dependence", "predicted", "response", "residual", "shap"),
  breaks = NULL,
  n_bins = 11L,
  cut_type = "equal",
  use_linkinv = TRUE,
  counts = TRUE,
  counts_weighted = FALSE,
  pd_evaluate_at = NULL,
  pd_grid = NULL,
  pd_indices = NULL,
  pd_n_max = 1000L,
  pd_seed = NULL,
  ...
)

## S3 method for class 'multiflashlight'
light_profile2d(
  x,
  v = NULL,
  data = NULL,
  type = c("partial dependence", "predicted", "response", "residual", "shap"),
  breaks = NULL,
  n_bins = 11L,
  cut_type = "equal",
  pd_evaluate_at = NULL,
  pd_grid = NULL,
  ...
)
x |
An object of class "flashlight" or "multiflashlight". |
... |
Further arguments passed to |
v |
A vector of exactly two variable names to be profiled. |
data |
An optional data.frame. |
by |
An optional vector of column names used to additionally group the results. |
type |
Type of the profile: Either "partial dependence", "predicted", "response", or "residual". |
breaks |
Named list of cut breaks specifying how to bin one or more numeric
variables. Used to overwrite automatic binning via n_bins and cut_type. |
n_bins |
Approximate number of unique values to evaluate for numeric v. |
cut_type |
Should numeric v be cut into "equal" or "quantile" bins? |
use_linkinv |
Should the retransformation function be applied? Default is TRUE. |
counts |
Should observation counts be added? |
counts_weighted |
If |
pd_evaluate_at |
A named list of evaluation points for one or more variables. Only relevant for type = "partial dependence". |
pd_grid |
An evaluation |
pd_indices |
A vector of row numbers to consider in calculating partial dependence profiles. Only used for type = "partial dependence". |
pd_n_max |
Maximum number of ICE profiles to calculate
(will be randomly picked from data). Only used for type = "partial dependence". |
pd_seed |
Integer random seed used to select ICE profiles. Only used for type = "partial dependence". |
Different binning options are available (see the arguments above).
For high-resolution partial dependence plots, it might be necessary to specify breaks, pd_evaluate_at or pd_grid in order to avoid empty parts in the plot. A high value of n_bins might not have the desired effect, as it is internally capped at the number of distinct values of a variable.
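The following base-R sketch computes a 2D partial dependence surface on an explicit evaluation grid, similar in spirit to passing pd_grid: every grid cell is evaluated, so no empty parts can arise. partial_dependence_2d() is a hypothetical, unweighted helper, not flashlight's implementation.

```r
# Hypothetical helper: 2D partial dependence on an explicit grid.
# The two column names of `grid` define the profiled variables.
partial_dependence_2d <- function(model, data, grid) {
  v <- names(grid)
  stopifnot(length(v) == 2L, all(v %in% names(data)))
  grid$pd <- vapply(seq_len(nrow(grid)), function(r) {
    d <- data
    d[[v[1]]] <- grid[[v[1]]][r]   # fix both variables at the grid cell
    d[[v[2]]] <- grid[[v[2]]][r]
    mean(predict(model, newdata = d))
  }, numeric(1))
  grid
}

fit_part <- lm(Sepal.Length ~ Species + Petal.Length, data = iris)
grid <- expand.grid(Petal.Length = seq(1, 7, by = 1),
                    Petal.Width = c(0.5, 1.5, 2.5))
res <- partial_dependence_2d(fit_part, iris, grid)
head(res)
```

Since fit_part does not use Petal.Width, the resulting surface is constant along Petal.Width, in line with the "No effect of Petal.Width" example.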
For partial dependence and prediction profiles, "model", "predict_function", "linkinv" and "data" are required. For response profiles it is "y", "linkinv" and "data". "data" can also be passed on the fly.
An object of class "light_profile2d" with the following elements:
data
A tibble containing results.
by
Names of group by variables.
v
The two variable names evaluated.
type
Same as input type. For information only.
light_profile2d(default)
: Default method not implemented yet.
light_profile2d(flashlight)
: 2D profiles for flashlight.
light_profile2d(multiflashlight)
: 2D profiles for multiflashlight.
Friedman J. H. (2001). Greedy function approximation: A gradient boosting machine. The Annals of Statistics, 29:1189–1232.
light_profile(), plot.light_profile2d()
fit_part <- lm(Sepal.Length ~ Species + Petal.Length, data = iris)
fl_part <- flashlight(
  model = fit_part, label = "part", data = iris, y = "Sepal.Length"
)

# No effect of Petal.Width
plot(light_profile2d(fl_part, v = c("Petal.Length", "Petal.Width")))

# Second model includes Petal.Width
fit_full <- lm(Sepal.Length ~ ., data = iris)
fl_full <- flashlight(
  model = fit_full, label = "full", data = iris, y = "Sepal.Length"
)
fls <- multiflashlight(list(fl_part, fl_full))
plot(light_profile2d(fls, v = c("Petal.Length", "Petal.Width")))
DEPRECATED
Recode Factor Columns - DEPRECATED
light_recode(...)
... |
Deprecated. |
Error message.
Deprecated.
This function prepares values for drawing a scatter plot of predicted values, responses, or residuals against a selected variable.
light_scatter(x, ...)

## Default S3 method:
light_scatter(x, ...)

## S3 method for class 'flashlight'
light_scatter(
  x,
  v,
  data = x$data,
  by = x$by,
  type = c("predicted", "response", "residual", "shap"),
  use_linkinv = TRUE,
  n_max = 400,
  seed = NULL,
  ...
)

## S3 method for class 'multiflashlight'
light_scatter(x, ...)
x |
An object of class "flashlight" or "multiflashlight". |
... |
Further arguments passed from or to other methods. |
v |
The variable name to be shown on the x-axis. |
data |
An optional data.frame. |
by |
An optional vector of column names used to additionally group the results. |
type |
Type of the profile: Either "predicted", "response", or "residual". |
use_linkinv |
Should the retransformation function be applied? Default is TRUE. |
n_max |
Maximum number of data rows to select. Will be randomly picked. |
seed |
An integer random seed used for subsampling. |
An object of class "light_scatter" with the following elements:
data
: A tibble with results.
by
: Same as input by.
v
: The variable name evaluated.
type
: Same as input type. For information only.
light_scatter(default)
: Default method not implemented yet.
light_scatter(flashlight)
: Variable profile for a flashlight.
light_scatter(multiflashlight)
: light_scatter for a multiflashlight.
fit_a <- lm(Sepal.Length ~ . - Petal.Length, data = iris)
fit_b <- lm(Sepal.Length ~ ., data = iris)

fl_a <- flashlight(model = fit_a, label = "no Petal.Length")
fl_b <- flashlight(model = fit_b, label = "all")
fls <- multiflashlight(list(fl_a, fl_b), data = iris, y = "Sepal.Length")

plot(light_scatter(fls, v = "Petal.Width"), color = "darkred")

sc <- light_scatter(fls, "Petal.Length", by = "Species", type = "residual")
plot(sc)
Returns the most important variable names, sorted in descending order of importance.
most_important(x, top_m = Inf)
x |
An object of class "light_importance". |
top_m |
Maximum number of important variables to be returned. |
A character vector of variable names sorted in descending importance.
fit <- lm(Sepal.Length ~ ., data = iris)
fl <- flashlight(model = fit, label = "lm", data = iris, y = "Sepal.Length")
imp <- light_importance(fl)
most_important(imp)
most_important(imp, top_m = 2)
Combines a list of flashlights into an object of class "multiflashlight" and/or updates a multiflashlight.
multiflashlight(x, ...)

## Default S3 method:
multiflashlight(x, ...)

## S3 method for class 'flashlight'
multiflashlight(x, ...)

## S3 method for class 'list'
multiflashlight(x, ...)

## S3 method for class 'multiflashlight'
multiflashlight(x, ...)
x |
An object of class "multiflashlight", "flashlight" or a list of flashlights. |
... |
Optional arguments in the flashlights to update, see examples. |
An object of class "multiflashlight" (a named list of flashlight objects).
multiflashlight(default)
: Used to create a flashlight object. No x has to be passed in this case.
multiflashlight(flashlight)
: Updates an existing flashlight object and turns it into a multiflashlight.
multiflashlight(list)
: Creates (and updates) a multiflashlight from a list of flashlights.
multiflashlight(multiflashlight)
: Updates an object of class "multiflashlight".
fit_lm <- lm(Sepal.Length ~ ., data = iris)
fit_glm <- glm(Sepal.Length ~ ., family = Gamma(link = log), data = iris)
mod_lm <- flashlight(model = fit_lm, label = "lm")
mod_glm <- flashlight(model = fit_glm, label = "glm")
(mods <- multiflashlight(list(mod_lm, mod_glm)))
DEPRECATED
plot_counts(...)
... |
Any input. |
Error message.
Minimal visualization of an object of class "light_breakdown" as waterfall plot. The object returned is of class "ggplot" and can be further customized.
## S3 method for class 'light_breakdown'
plot(x, facet_scales = "free", facet_ncol = 1, rotate_x = FALSE, ...)
x |
An object of class "light_breakdown". |
facet_scales |
Scales argument passed to ggplot2::facet_wrap(). |
facet_ncol |
Number of columns (ncol) passed to ggplot2::facet_wrap(). |
rotate_x |
Should x axis labels be rotated by 45 degrees? |
... |
Further arguments passed to |
The waterfall plot is to be read from top to bottom. The first line describes the (weighted) average prediction in the query data used to start with. Then, each additional line shows how the prediction changes due to the impact of the corresponding variable. The last line finally shows the original prediction of the selected observation. Multiple flashlights are shown in different facets. Positive and negative impacts are visualized with different colors.
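The sequential logic behind such a waterfall can be sketched in base R. The hypothetical breakdown() below fixes the variables of one observation, one at a time and in a user-given order, within the query data and records how the average prediction moves. flashlight's own choice of variable order and its use of case weights are not reproduced here.

```r
# Hypothetical helper: sequential breakdown of one prediction
breakdown <- function(model, data, new_obs, v) {
  d <- data
  avg <- mean(predict(model, newdata = d))    # baseline: average prediction
  for (var in v) {
    d[[var]] <- rep(new_obs[[var]], nrow(d))  # fix var at the observation's value
    avg <- c(avg, mean(predict(model, newdata = d)))
  }
  data.frame(step = c("baseline", v), prediction = avg,
             impact = c(NA, diff(avg)))
}

fit <- lm(Sepal.Length ~ ., data = iris)
breakdown(fit, iris, new_obs = iris[1, ],
          v = c("Petal.Length", "Petal.Width", "Sepal.Width", "Species"))
```

Once all model variables are fixed, the last row equals the prediction of the selected observation, and the impacts add up to the difference between that prediction and the baseline. For non-additive models, the impacts generally depend on the visiting order.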
An object of class "ggplot".
Visualizes response, prediction, partial dependence, and/or ALE profiles of a (multi-)flashlight with respect to a covariable v.
Different flashlights or a single flashlight with one "by" variable are separated by a facet wrap.
## S3 method for class 'light_effects'
plot(
  x,
  use = c("response", "predicted", "pd"),
  zero_counts = TRUE,
  size_factor = 1,
  facet_scales = "free_x",
  facet_nrow = 1L,
  rotate_x = TRUE,
  show_points = TRUE,
  recode_labels = NULL,
  ...
)
x |
An object of class "light_effects". |
use |
A vector of elements to show. Any subset of ("response", "predicted", "pd", "ale") or "all". Defaults to all except "ale". |
zero_counts |
Logical flag if 0 count levels should be shown on the x axis. |
size_factor |
Factor used to enlarge default |
facet_scales |
Scales argument passed to ggplot2::facet_wrap(). |
facet_nrow |
Number of rows in ggplot2::facet_wrap(). |
rotate_x |
Should x axis labels be rotated by 45 degrees? |
show_points |
Should points be added to the line (default is TRUE)? |
recode_labels |
Named vector of curve labels. The names refer to the usual labels, while the values are the desired labels, e.g., c("partial dependence" = "PDP", "ale" = "ALE"). |
... |
Further arguments passed to geoms. |
An object of class "ggplot".
light_effects(), plot_counts()
Use rpart.plot::rpart.plot() to visualize trees fitted by light_global_surrogate().
## S3 method for class 'light_global_surrogate'
plot(x, type = 5, auto_main = TRUE, mfrow = NULL, ...)
x |
An object of class "light_global_surrogate". |
type |
Plot type, see help of rpart.plot::rpart.plot(). |
auto_main |
Automatic plot titles (only if multiple trees are shown). |
mfrow |
If multiple trees are shown in the same figure:
what value of |
... |
Further arguments passed to rpart.plot::rpart.plot(). |
An object of class "ggplot".
Minimal visualization of an object of class "light_ice" as ggplot2::geom_line().
The object returned is of class "ggplot" and can be further customized.
## S3 method for class 'light_ice'
plot(x, facet_scales = "fixed", rotate_x = FALSE, ...)
x |
An object of class "light_ice". |
facet_scales |
Scales argument passed to |
rotate_x |
Should x axis labels be rotated by 45 degrees? |
... |
Further arguments passed to |
Each observation is visualized by a line. The first "by" variable is represented by the color, a second "by" variable or a multiflashlight by facets.
An object of class "ggplot".
Visualization of an object of class "light_importance" via ggplot2::geom_bar().
If available, standard errors are added by ggplot2::geom_errorbar().
The object returned is of class "ggplot" and can be further customized.
## S3 method for class 'light_importance'
plot(
  x,
  top_m = Inf,
  swap_dim = FALSE,
  facet_scales = "fixed",
  rotate_x = FALSE,
  error_bars = TRUE,
  ...
)
x |
An object of class "light_importance". |
top_m |
Maximum number of important variables to be returned. |
swap_dim |
If multiflashlight and one "by" variable or single flashlight with two "by" variables, swap the role of dodge/fill variable and facet variable. If multiflashlight or one "by" variable, use facets instead of colors. |
facet_scales |
Scales argument passed to ggplot2::facet_wrap(). |
rotate_x |
Should x axis labels be rotated by 45 degrees? |
error_bars |
Should error bars be added? Defaults to TRUE. |
... |
Further arguments passed to |
The plot is organized as a bar plot with variable names as x-aesthetic.
Up to two additional dimensions (multiflashlight and one "by" variable or single
flashlight with two "by" variables) can be visualized by faceting and dodge/fill.
Set swap_dim = TRUE to revert the role of these two dimensions.
One single additional dimension is visualized by a facet wrap or, if swap_dim = TRUE, by dodge/fill.
An object of class "ggplot".
Minimal visualization of an object of class "light_performance" as
ggplot2::geom_bar(). The object returned has class "ggplot" and can be further customized.
## S3 method for class 'light_performance'
plot(
  x,
  swap_dim = FALSE,
  geom = c("bar", "point"),
  facet_scales = "free_y",
  rotate_x = FALSE,
  ...
)
x |
An object of class "light_performance". |
swap_dim |
Should the representation of dimensions
(either two "by" variables or one "by" variable and multiflashlight)
as x aesthetic and dodge/fill aesthetic be swapped? Default is FALSE. |
geom |
Geometry of plot (either "bar" or "point"). |
facet_scales |
Scales argument passed to |
rotate_x |
Should x axis labels be rotated by 45 degrees? |
... |
Further arguments passed to |
The plot is organized as a bar plot as follows: For flashlights without "by" variable specified, a single bar is drawn. Otherwise, the "by" variable (or the flashlight label if there is no "by" variable) is represented by the "x" aesthetic.
The flashlight label (in case of one "by" variable) is represented by dodged bars. This strategy makes it easiest to compare the performance of different flashlights. Set "swap_dim = TRUE" to revert the role of dodging and x aesthetic. Different metrics are always represented by facets.
An object of class "ggplot".
Minimal visualization of an object of class "light_profile". The object returned is of class "ggplot" and can be further customized.
## S3 method for class 'light_profile'
plot(
  x,
  swap_dim = FALSE,
  facet_scales = "free_x",
  rotate_x = x$type != "partial dependence",
  show_points = TRUE,
  ...
)
x |
An object of class "light_profile". |
swap_dim |
If multiflashlight and one "by" variable or single flashlight with two "by" variables, swap the role of dodge/fill variable and facet variable. If multiflashlight or one "by" variable, use facets instead of colors. |
facet_scales |
Scales argument passed to |
rotate_x |
Should x axis labels be rotated by 45 degrees? |
show_points |
Should points be added to the line (default is |
... |
Further arguments passed to |
Either lines and points are plotted (if stats = "mean") or quartile boxes.
If there is a "by" variable or a multiflashlight, this first dimension is represented by color (or, if swap_dim = TRUE, by facets).
If there are two "by" variables or a multiflashlight with one "by" variable, the first "by" variable is visualized as color, while the second one or the multiflashlight is shown via facets (change with swap_dim).
An object of class "ggplot".
light_profile(), plot.light_effects()
Minimal visualization of an object of class "light_profile2d". The object returned is of class "ggplot" and can be further customized.
## S3 method for class 'light_profile2d'
plot(x, swap_dim = FALSE, rotate_x = TRUE, numeric_as_factor = FALSE, ...)
x |
An object of class "light_profile2d". |
swap_dim |
Swap the |
rotate_x |
Should the x axis labels be rotated by 45 degrees? Default is TRUE. |
numeric_as_factor |
Should numeric x and y values be converted to factors first?
Default is FALSE. |
... |
Further arguments passed to |
The main geometry is ggplot2::geom_tile(). Additional dimensions ("by" variable(s) and/or multiflashlight) are represented by facet_wrap/grid.
For all types of profiles except "partial dependence", it is natural to see empty parts in the plot. These are combinations of the v variables that do not appear in the data. Even for type "partial dependence", such gaps can occur, e.g., for cut_type = "quantile" or if n_bins is larger than the number of distinct values of a v variable. Such gaps can be suppressed by setting numeric_as_factor = TRUE or by using the arguments breaks, pd_evaluate_at or pd_grid in light_profile2d().
An object of class "ggplot".
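As an illustrative sketch (not from the original manual), the following computes a two-dimensional profile of type "predicted", for which empty tiles are expected, and plots it once as is and once with numeric_as_factor = TRUE. It assumes light_profile2d() takes a length-2 vector v and a type argument, as the text above suggests.

```r
library(flashlight)

fit <- lm(Sepal.Length ~ ., data = iris)
fl <- flashlight(model = fit, label = "lm", data = iris, y = "Sepal.Length")

# Average predictions over binned combinations of two numeric variables;
# bin combinations absent from the data can appear as empty tiles
pr2d <- light_profile2d(fl, v = c("Petal.Width", "Petal.Length"), type = "predicted")
p_default <- plot(pr2d)

# Converting numeric values to factors first is one documented way to suppress gaps
p_factor <- plot(pr2d, numeric_as_factor = TRUE)
```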
Values are plotted against a variable. The object returned is of class "ggplot"
and can be further customized. To avoid overplotting, try alpha = 0.2 or
position = "jitter".
## S3 method for class 'light_scatter' plot(x, swap_dim = FALSE, facet_scales = "free_x", rotate_x = FALSE, ...)
x |
An object of class "light_scatter". |
swap_dim |
If a multiflashlight with one "by" variable, or a single flashlight with two "by" variables, swap the roles of the color variable and the facet variable. If a multiflashlight or one "by" variable, use colors instead of facets. |
facet_scales |
Scales argument passed to |
rotate_x |
Should x axis labels be rotated by 45 degrees? |
... |
Further arguments passed to |
An object of class "ggplot".
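A minimal sketch, not taken from the original manual, of the overplotting advice above: predictions are scattered against Petal.Width, stratified by Species, and alpha = 0.2 is passed through ... to the point layer. It assumes light_scatter() accepts v, type, and by as shown.

```r
library(flashlight)

fit <- lm(Sepal.Length ~ ., data = iris)
fl <- flashlight(model = fit, label = "lm", data = iris, y = "Sepal.Length")

# Predictions against Petal.Width, stratified by the "by" variable Species
sc <- light_scatter(fl, v = "Petal.Width", type = "predicted", by = "Species")

# Extra arguments such as alpha are forwarded via ... to reduce overplotting
p <- plot(sc, alpha = 0.2)

# swap_dim = TRUE switches between color and facet representation of Species
p_swapped <- plot(sc, swap_dim = TRUE, alpha = 0.2)
```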
Predict method for an object of class "flashlight".
Pass additional elements to update the flashlight, typically data.
## S3 method for class 'flashlight' predict(object, ...)
object |
An object of class "flashlight". |
... |
Arguments used to update the flashlight. |
A vector with predictions.
fit <- lm(Sepal.Length ~ ., data = iris)
fl <- flashlight(model = fit, data = iris, y = "Sepal.Length", label = "ols")
predict(fl)[1:5]
predict(fl, data = iris[1:5, ])
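Because the extra arguments update the flashlight before predicting, they are not limited to data. This illustrative sketch, not among the original examples, passes linkinv = exp to show that the inverse link is applied on top of the model's raw predictions.

```r
library(flashlight)

fit <- lm(Sepal.Length ~ ., data = iris)
fl <- flashlight(model = fit, data = iris, y = "Sepal.Length", label = "ols")

# Update both data and the inverse link in a single call
p_link <- predict(fl, data = iris[1:5, ], linkinv = exp)

# Equivalent: transform the identity-link predictions manually
p_manual <- exp(predict(fl, data = iris[1:5, ]))
```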
Predict method for an object of class "multiflashlight".
Pass additional elements to update the flashlight, typically data.
## S3 method for class 'multiflashlight' predict(object, ...)
object |
An object of class "multiflashlight". |
... |
Arguments used to update the multiflashlight. |
A named list of prediction vectors.
fit_part <- lm(Sepal.Length ~ Petal.Length, data = iris)
fit_full <- lm(Sepal.Length ~ ., data = iris)
mod_full <- flashlight(model = fit_full, label = "full")
mod_part <- flashlight(model = fit_part, label = "part")
mods <- multiflashlight(list(mod_full, mod_part), data = iris, y = "Sepal.Length")
predict(mods, data = iris[1:5, ])
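Since the result is a list with one prediction vector per flashlight, it can be bound into a matrix for side-by-side comparison. This is a sketch using base R's do.call() and cbind(), not a function of the package.

```r
library(flashlight)

fit_part <- lm(Sepal.Length ~ Petal.Length, data = iris)
fit_full <- lm(Sepal.Length ~ ., data = iris)
mods <- multiflashlight(
  list(
    flashlight(model = fit_full, label = "full"),
    flashlight(model = fit_part, label = "part")
  ),
  data = iris, y = "Sepal.Length"
)

# One prediction vector per flashlight; cbind() gives one column per model
pred <- do.call(cbind, predict(mods, data = iris[1:5, ]))
pred
```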
Print method for an object of class "flashlight".
## S3 method for class 'flashlight' print(x, ...)
x |
An object of class "flashlight". |
... |
Further arguments passed from other methods. |
Invisibly, the input is returned.
fit <- lm(Sepal.Length ~ ., data = iris)
x <- flashlight(model = fit, label = "lm", y = "Sepal.Length", data = iris)
x
Print method for an object of class "light".
## S3 method for class 'light' print(x, ...)
x |
An object of class "light". |
... |
Further arguments passed from other methods. |
Invisibly, the input is returned.
fit <- lm(Sepal.Length ~ ., data = iris)
fl <- flashlight(model = fit, label = "lm", y = "Sepal.Length", data = iris)
light_performance(fl, v = "Species")
Print method for an object of class "multiflashlight".
## S3 method for class 'multiflashlight' print(x, ...)
x |
An object of class "multiflashlight". |
... |
Further arguments passed to |
Invisibly, the input is returned.
fit_lm <- lm(Sepal.Length ~ ., data = iris)
fit_glm <- glm(Sepal.Length ~ ., family = Gamma(link = log), data = iris)
fl_lm <- flashlight(model = fit_lm, label = "lm")
fl_glm <- flashlight(model = fit_glm, label = "glm")
multiflashlight(list(fl_lm, fl_glm), data = iris)
Residuals method for an object of class "flashlight". Pass additional elements to update the flashlight before calculation of residuals.
## S3 method for class 'flashlight' residuals(object, ...)
object |
An object of class "flashlight". |
... |
Arguments used to update the flashlight before calculating the residuals. |
A numeric vector with residuals.
fit <- lm(Sepal.Length ~ ., data = iris)
x <- flashlight(model = fit, data = iris, y = "Sepal.Length", label = "ols")
residuals(x)[1:5]
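As a consistency check, and assuming the residuals are computed as response minus prediction (an assumption, not stated in this section), the flashlight residuals of an identity-link linear model evaluated on its training data should match those returned by stats::residuals() for the lm fit.

```r
library(flashlight)

fit <- lm(Sepal.Length ~ ., data = iris)
x <- flashlight(model = fit, data = iris, y = "Sepal.Length", label = "ols")

# Residuals from the flashlight (names dropped for comparison)
r_fl <- unname(residuals(x))

# Assumed definition: response minus predictions
r_manual <- unname(response(x) - predict(x))

# Identity link, training data: should agree with the usual lm residuals
r_lm <- unname(residuals(fit))
```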
Residuals method for an object of class "multiflashlight". Pass additional elements to update the multiflashlight before calculation of residuals.
## S3 method for class 'multiflashlight' residuals(object, ...)
object |
An object of class "multiflashlight". |
... |
Arguments used to update the multiflashlight before calculating the residuals. |
A named list with residuals per flashlight.
fit_part <- lm(Sepal.Length ~ Petal.Length, data = iris)
fit_full <- lm(Sepal.Length ~ ., data = iris)
mod_full <- flashlight(model = fit_full, label = "full")
mod_part <- flashlight(model = fit_part, label = "part")
mods <- multiflashlight(list(mod_full, mod_part), data = iris, y = "Sepal.Length")
residuals(mods, data = head(iris))
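The named list of residuals lends itself to quick per-model summaries. The sketch below, which is illustrative and not part of the original examples, computes an unweighted root mean squared error per flashlight with base R's sapply(). Because the partial model is nested in the full one, its in-sample RMSE cannot be smaller.

```r
library(flashlight)

fit_part <- lm(Sepal.Length ~ Petal.Length, data = iris)
fit_full <- lm(Sepal.Length ~ ., data = iris)
mod_full <- flashlight(model = fit_full, label = "full")
mod_part <- flashlight(model = fit_part, label = "part")
mods <- multiflashlight(list(mod_full, mod_part), data = iris, y = "Sepal.Length")

# One residual vector per flashlight, summarized by its root mean squared error
res <- residuals(mods)
rmse <- sapply(res, function(r) sqrt(mean(r^2)))
rmse
```

For weighted or stratified performance, light_performance() is the package's dedicated tool; its default metric is RMSE according to the flashlight() signature.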
Extracts the response from an object of class "flashlight".
response(object, ...)
## Default S3 method:
response(object, ...)
## S3 method for class 'flashlight'
response(object, ...)
## S3 method for class 'multiflashlight'
response(object, ...)
object |
An object of class "flashlight". |
... |
Arguments used to update the flashlight before extracting the response. |
A numeric vector of responses.
response(default): Default method not implemented yet.
response(flashlight): Extract response from flashlight object.
response(multiflashlight): Extract responses from multiflashlight object.
fit <- lm(Sepal.Length ~ ., data = iris)
(fl <- flashlight(model = fit, data = iris, y = "Sepal.Length", label = "ols"))
response(fl)[1:5]
response(fl, data = iris[1:5, ])
response(fl, data = iris[1:5, ], linkinv = exp)