From ec98b6e95aa96cd2a50fd561483bd442eb33e9e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20K=C3=B6nig?= <ye87zine@usr.idiv.de> Date: Thu, 5 Dec 2024 09:56:40 +0100 Subject: [PATCH] work on performance comparison --- .gitignore | 4 +- R/02_02_functional_traits_preparation.R | 44 ++-- R/05_01_performance_analysis.qmd | 2 +- R/05_02_MSDM_comparison.qmd | 283 +++++++++++++++++++++++- 4 files changed, 305 insertions(+), 28 deletions(-) diff --git a/.gitignore b/.gitignore index 4c95a7b..ac6cec0 100644 --- a/.gitignore +++ b/.gitignore @@ -11,5 +11,5 @@ renv/cache/ # Data files data/ -R/05_performance_analysis_files -R/05_performance_analysis.html \ No newline at end of file +R/*/ +R/*.html \ No newline at end of file diff --git a/R/02_02_functional_traits_preparation.R b/R/02_02_functional_traits_preparation.R index 9d66867..0268d53 100644 --- a/R/02_02_functional_traits_preparation.R +++ b/R/02_02_functional_traits_preparation.R @@ -30,15 +30,13 @@ traits_matched = trait_names_matched %>% save(traits_matched, file = "data/r_objects/traits_matched.RData") -# Calculate Distances -library("vegan") - -load("data/r_objects/range_maps.RData") -load("data/r_objects/traits_matched.RData") -target_species = unique(range_maps$name_matched[!is.na(range_maps$name_matched)]) +# Clean trait data +diet_cols = c("Diet.Inv", "Diet.Vend", "Diet.Vect", "Diet.Vfish", "Diet.Vunk", "Diet.Scav", "Diet.Fruit", "Diet.Nect", "Diet.Seed", "Diet.PlantO") +strata_cols = c("ForStrat.Value") +activity_cols = c("Activity.Nocturnal", "Activity.Crepuscular", "Activity.Diurnal") +bodymass_cols = c("BodyMass.Value") -traits_target = traits_matched %>% - dplyr::filter(species %in% !!target_species) %>% +traits_proc = traits_matched %>% mutate(match_genus = stringr::str_detect(species, Genus), match_epithet = stringr::str_detect(species, Species)) %>% dplyr::group_by(species) %>% @@ -46,17 +44,15 @@ traits_target = traits_matched %>% if (nrow(.x) == 1) { return (.x) } - x_mod = .x - print(x_mod) while (nrow(x_mod) > 1){ if(any(isFALSE(x_mod$match_genus))){ x_mod = dplyr::filter(x_mod, isTRUE(match_genus)) next } - if(any(isFALSE(x_mod$match_species))){ - x_mod = dplyr::filter(x_mod, isTRUE(match_species)) + if(any(isFALSE(x_mod$match_epithet))){ + x_mod = dplyr::filter(x_mod, isTRUE(match_epithet)) next } @@ -68,19 +64,23 @@ traits_target = traits_matched %>% } else { return(.x[1,]) } - - }) + }) %>% + dplyr::select(c("species", diet_cols, strata_cols, activity_cols, bodymass_cols)) +save(traits_proc, file = "data/r_objects/traits_proc.RData") -# some pre-processing -traits_target$`ForStrat.Value.Proc` = as.numeric(factor(traits_target$`ForStrat.Value`, levels = c("G", "S", "Ar", "A"))) -traits_target$`BodyMass.Value.Proc` = scale(log(traits_target$`BodyMass.Value`)) +# Calculate Distances +library("vegan") -# Define columns -diet_cols = c("Diet.Inv", "Diet.Vend", "Diet.Vect", "Diet.Vfish", "Diet.Vunk", "Diet.Scav", "Diet.Fruit", "Diet.Nect", "Diet.Seed", "Diet.PlantO") -strata_cols = c("ForStrat.Value.Proc") -activity_cols = c("Activity.Nocturnal", "Activity.Crepuscular", "Activity.Diurnal") -bodymass_cols = c("BodyMass.Value.Proc") +load("data/r_objects/range_maps.RData") +load("data/r_objects/traits_proc.RData") + +target_species = unique(range_maps$name_matched[!is.na(range_maps$name_matched)]) +traits_target = dplyr::filter(traits_proc, species %in% target_species) + +# some pre-processing +traits_target$`ForStrat.Value` = as.numeric(factor(traits_target$`ForStrat.Value`, levels = c("G", "S", "Ar", "A"))) +traits_target$`BodyMass.Value` = scale(log(traits_target$`BodyMass.Value`)) diet_dist = vegan::vegdist(traits_target[,diet_cols], "bray") foraging_dist = dist(traits_target[,strata_cols]) diff --git a/R/05_01_performance_analysis.qmd b/R/05_01_performance_analysis.qmd index 2fea07f..7a99fe1 100644 --- a/R/05_01_performance_analysis.qmd +++ b/R/05_01_performance_analysis.qmd @@ -529,7 +529,7 @@ bslib::card(plot, full_screen = T) Functional groups were assigned based on taxonomic order. The following groupings were used: | Functional group | Taxomic orders | -|-------------------|-----------------------------------------------------| +|--------------------|----------------------------------------------------| | large ground-dwelling | Carnivora, Artiodactyla, Cingulata, Perissodactyla | | small ground-dwelling | Rodentia, Didelphimorphia, Soricomorpha, Paucituberculata, Lagomorpha | | arboreal | Primates, Pilosa | diff --git a/R/05_02_MSDM_comparison.qmd b/R/05_02_MSDM_comparison.qmd index 667edb6..6de0327 100644 --- a/R/05_02_MSDM_comparison.qmd +++ b/R/05_02_MSDM_comparison.qmd @@ -10,6 +10,7 @@ library(tidyverse) library(sf) library(plotly) library(DT) +library(shiny) load("../data/r_objects/msdm_results_embedding_raw.RData") load("../data/r_objects/msdm_results_embedding_traits_static.RData") @@ -22,7 +23,12 @@ load("../data/r_objects/msdm_results_embedding_range_trained.RData") sf::sf_use_s2(use_s2 = FALSE) ``` + ```{r globals, echo = FALSE, include = FALSE} +# Select metrics +focal_metrics = c("auc", "f1", "accuracy") # There's a weird bug in plotly that scrambles up lines when using more than three groups + +# Prepare final dataframes results_embedding_raw = msdm_results_embedding_raw %>% dplyr::mutate( embedding = "raw", @@ -48,10 +54,57 @@ results_embedding_informed_merged = lapply(results_embedding_informed, function( results_final = results_embedding_raw %>% bind_rows(results_embedding_informed_merged) %>% - drop_na() + drop_na() %>% + mutate( + model = recode( + model, + "MSDM_embed" = "M_STD", + "MSDM_embed_informed_phylo_static" = "M_PS", + "MSDM_embed_informed_phylo_trained" = "M_PT", + "MSDM_embed_informed_traits_static" = "M_TS", + "MSDM_embed_informed_traits_trained" = "M_TT", + "MSDM_embed_informed_range_static" = "M_RS", + "MSDM_embed_informed_range_trained" = "M_RT" + ), + across(all_of(focal_metrics), round, 3) + ) + +results_final_long = results_final %>% + tidyr::pivot_longer(focal_metrics, names_to = "metric", values_to = "value") + +delta_performance = results_final %>% + dplyr::select(all_of(c("species", "model", focal_metrics))) %>% + tidyr::pivot_wider(id_cols = species, names_from = model, values_from = c(auc, accuracy, f1)) %>% + dplyr::mutate( + across(starts_with("auc_M_"), ~ . - auc_M_STD), + across(starts_with("accuracy_M_"), ~ . - accuracy_M_STD), + across(starts_with("f1_M_"), ~ . - f1_M_STD) + ) %>% + dplyr::select(-auc_M_STD, -accuracy_M_STD, -f1_M_STD) %>% + pivot_longer(-species, names_to = c("metric", "model"), names_pattern = "(^[a-zA-Z1-9]+)_(.+$)", values_to = "delta") + +# Regression functions +asym_regression = function(x, y){ + nls_fit = nls(y ~ 1 - (1-b) * exp(-c * log(x)), start = list(b = 0.1, c = 0.1)) + new_x = exp(seq(log(min(x)), log(max(x)), length.out = 100)) + data.frame( + x = new_x, + fit = predict(nls_fit, newdata = data.frame(x = new_x)) + ) +} + +lin_regression = function(x, y){ + glm_fit = suppressWarnings(glm(y~x, family = "binomial")) + new_x = seq(min(x), max(x), length.out = 100) + data.frame( + x = new_x, + fit = predict(glm_fit, newdata = data.frame(x = new_x), type = "response") + ) +} ``` -```{r globals, echo = FALSE} + +```{r create_summaries, echo = FALSE} auc_overview = results_final %>% pivot_wider(names_from = model, values_from = auc, id_cols = c(species, obs)) %>% dplyr::arrange(obs) @@ -63,4 +116,228 @@ accuracy_overview = results_final %>% f1_overview = results_final %>% pivot_wider(names_from = model, values_from = f1, id_cols = c(species, obs)) %>% dplyr::arrange(obs) -``` \ No newline at end of file +``` + +## *Model overview* + +::: panel-tabset + +### *AUC* + +```{r echo = FALSE} +datatable( + auc_overview, + options = list( + pageLength = 10, + initComplete = htmlwidgets::JS( + "function(settings, json) { + $(this.api().table().container()).css({'font-size': '10pt'}); + }" + ), + autoWidth = TRUE, + columnDefs = list(list(width = "250px", targets = 1)) + ) +) +``` + +### *Accuracy* + +```{r echo = FALSE} +datatable( + accuracy_overview, + options = list( + pageLength = 10, + initComplete = htmlwidgets::JS( + "function(settings, json) { + $(this.api().table().container()).css({'font-size': '10pt'}); + }" + ), + autoWidth = TRUE, + columnDefs = list(list(width = "250px", targets = 1)) + ) +) +``` + +### *F1 score* + +```{r echo = FALSE} +datatable( + f1_overview, + options = list( + pageLength = 10, + initComplete = htmlwidgets::JS( + "function(settings, json) { + $(this.api().table().container()).css({'font-size': '10pt'}); + }" + ), + autoWidth = TRUE, + columnDefs = list(list(width = "250px", targets = 1)) + ) +) +``` +::: + +## *Model comparison* +### *Number of records* + +```{r performance_vs_occurrences, echo = FALSE, message=FALSE, warnings=FALSE} +# Dropdown options +plotly_buttons = list() +for(metric in focal_metrics){ + plotly_buttons[[length(plotly_buttons) + 1]] = list(method = "restyle", args = list("transforms[0].value", metric), label = metric) +} + +df_plot = results_final_long + +# Calculate regression lines for each model and metric combination +suppressWarnings({ + regression_lines = df_plot %>% + group_by(model, metric) %>% + group_modify(~asym_regression(.x$obs, .x$value)) +}) + +# Create base plot +plot <- plot_ly() %>% + layout( + title = "Model Performance vs. Number of observations", + xaxis = list(title = "Number of observations", type = "log"), + yaxis = list(title = "Value"), + legend = list(x = 1.1, y = 0.5), # Move legend to the right of the plot + margin = list(r = 150), # Add right margin to accommodate legend + hovermode = 'closest', + updatemenus = list( + list( + type = "dropdown", + active = 0, + buttons = plotly_buttons + ) + ) + ) + +# Points +for (model_name in unique(df_plot$model)) { + plot = plot %>% + add_markers( + data = filter(df_plot, model == model_name), + x = ~obs, + y = ~value, + color = model_name, # Set color to match legendgroup + legendgroup = model_name, + opacity = 0.6, + name = ~model, + hoverinfo = 'text', + text = ~paste("Species:", species, "<br>Observations:", obs, "<br>Value:", round(value, 3)), + transforms = list( + list( + type = 'filter', + target = ~metric, + operation = '=', + value = focal_metrics[1] + ) + ) + ) +} + +# Add regression lines +for(model_name in unique(df_plot$model)){ + reg_data = dplyr::filter(regression_lines, model == model_name) + plot = plot %>% + add_lines( + data = reg_data, + x = ~x, + y = ~fit, + color = model_name, # Set color to match legendgroup + legendgroup = model_name, + name = paste(model_name, '(fit)'), + showlegend = FALSE, + transforms = list( + list( + type = 'filter', + target = ~metric, + operation = '=', + value = focal_metrics[1] + ) + ) + ) +} + +bslib::card(plot, full_screen = T) +``` + +```{r trait_pca, echo = FALSE, message=FALSE, warnings=FALSE} +load("../data/r_objects/traits_proc.RData") + +# Preprocess traits for PCA +traits_num = traits_proc %>% + ungroup() %>% + dplyr::mutate(temp_id = row_number()) %>% + pivot_wider( + names_from = ForStrat.Value, + values_from = temp_id, + values_fn = function(x) 1, + values_fill = 0, + names_prefix = "ForStrat." + ) %>% + drop_na(species) %>% + column_to_rownames(var = "species") %>% + scale() + +# Run PCA +traits_pca = prcomp(traits_num) +species_vcts = traits_pca$x %>% + as.data.frame() %>% + rownames_to_column(var = "species") +``` + + +```{r performance_vs_traits, echo = FALSE, message=FALSE, warnings=FALSE} +# Create plot df +df_plot = dplyr::inner_join(delta_performance, species_vcts) + +# Buttons +plotly_buttons = list() +for(metric in unique(df_plot$model)){ + plotly_buttons[[length(plotly_buttons) + 1]] = list(method = "restyle", args = list("transforms[0].value", metric), label = metric) +} + +# Create base plot +plot <- plot_ly() %>% + layout( + title = "Delta Performance Traits", + xaxis = list(title = "PC1"), + yaxis = list(title = "PC2"), + legend = list(x = 1.1, y = 0.5), # Move legend to the right of the plot + margin = list(r = 150), # Add right margin to accommodate legend + hovermode = 'closest', + updatemenus = list( + list( + type = "dropdown", + active = 0, + buttons = plotly_buttons + ) + ) + ) + +# Points +for (metric in unique(df_plot$metric)) { + plot = plot %>% + add_markers( + data = filter(df_plot, metric == !!metric), + x = ~PC1, + y = ~PC2, + color = ~delta, + opacity = 0.6, + hoverinfo = 'text', + transforms = list( + list( + type = 'filter', + target = ~metric, + operation = '=', + value = focal_metrics[1] + ) + ) + ) +} + +bslib::card(plot, full_screen = T) +``` -- GitLab