diff --git a/R/methods_marginaleffects.R b/R/methods_marginaleffects.R index 3361b5f5d..fc76cffdd 100644 --- a/R/methods_marginaleffects.R +++ b/R/methods_marginaleffects.R @@ -12,7 +12,7 @@ model_parameters.marginaleffects <- function(model, insight::check_if_installed("marginaleffects") # Bayesian models have posterior draws as attribute - is_bayesian <- !is.null(suppressWarnings(marginaleffects::get_draws(model, "PxD"))) + is_bayesian <- suppressWarnings(!is.null(marginaleffects::get_draws(model, "PxD"))) if (is_bayesian) { # Bayesian @@ -48,9 +48,7 @@ model_parameters.marginaleffects <- function(model, # do not print or report these columns out <- out[, !colnames(out) %in% c("predicted_lo", "predicted_hi"), drop = FALSE] - if (inherits(model, "marginalmeans")) { - attr(out, "coefficient_name") <- "Marginal Means" - } else if (inherits(model, "comparisons")) { + if (inherits(model, "comparisons")) { attr(out, "coefficient_name") <- "Estimate" attr(out, "title") <- "Contrasts between Adjusted Predictions" if ("Type" %in% colnames(out)) { @@ -137,10 +135,11 @@ model_parameters.predictions <- function(model, out$rowid <- out$Type <- out$rowid_dedup <- NULL # find at-variables - at_variables <- attributes(model)$newdata_at - if (is.null(at_variables)) { - at_variables <- attributes(model)$by - } + at_variables <- c( + marginaleffects::components(model, "variable_names_datagrid"), + marginaleffects::components(model, "variable_names_by"), + marginaleffects::components(model, "variable_names_by_hypothesis") + ) # find cofficient name - differs for Bayesian models coef_name <- intersect(c("Predicted", "Coefficient"), colnames(out))[1] @@ -153,7 +152,7 @@ model_parameters.predictions <- function(model, } # extract response, remove from data frame - reg_model <- attributes(model)$model + reg_model <- marginaleffects::components(model, "model") if (!is.null(reg_model) && insight::is_model(reg_model)) { resp <- insight::find_response(reg_model) # check if response could be extracted diff --git a/tests/testthat/test-marginaleffects.R b/tests/testthat/test-marginaleffects.R index 45c6f38b5..79b69d5fb 100644 --- a/tests/testthat/test-marginaleffects.R +++ b/tests/testthat/test-marginaleffects.R @@ -180,13 +180,10 @@ test_that("predictions, using bayestestR #1063", { skip_if(is.null(m)) d <- insight::get_datagrid(m, by = "Days", include_random = TRUE) - x <- marginaleffects::avg_predictions(m, newdata = d, by = "Days") + x <- marginaleffects::predictions(m, newdata = d, allow_new_levels = TRUE) out <- model_parameters(x) - expect_named( - out, - c( - "Median", "CI", "CI_low", "CI_high", "pd", "ROPE_CI", "ROPE_low", - "ROPE_high", "ROPE_Percentage", "Days", "subgrp", "grp", "Subject" - ) - ) + cols <- c( + "Median", "CI", "CI_low", "CI_high", "pd", "ROPE_CI", "ROPE_low", + "ROPE_high", "ROPE_Percentage", "Days", "subgrp", "grp", "Subject") + expect_named(out, cols) })