From c49fa18f5f7dc8c77429162cde8444c380944251 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?K=C3=B6nig?= <ye87zine@usr.idiv.de>
Date: Wed, 11 Dec 2024 22:50:59 +0100
Subject: [PATCH] implement multi embedding msdms. Enhance reports

---
 R/02_02_functional_traits_preparation.R |   1 +
 R/02_03_phylo_preparation.R             |   1 +
 R/04_07_msdm_embed_multi_nolonlat.R     | 107 +++++++
 R/04_08_msdm_embed_multi_lonlat.R       | 107 +++++++
 R/05_01_performance_analysis.qmd        |  24 +-
 R/05_02_MSDM_comparison.qmd             | 387 ++++++++++++++++++++++--
 6 files changed, 573 insertions(+), 54 deletions(-)
 create mode 100644 R/04_07_msdm_embed_multi_nolonlat.R
 create mode 100644 R/04_08_msdm_embed_multi_lonlat.R

diff --git a/R/02_02_functional_traits_preparation.R b/R/02_02_functional_traits_preparation.R
index 0268d53..dfec715 100644
--- a/R/02_02_functional_traits_preparation.R
+++ b/R/02_02_functional_traits_preparation.R
@@ -89,5 +89,6 @@ bodymass_dist = dist(traits_target[,bodymass_cols])
 
 func_dist = (diet_dist + foraging_dist/max(foraging_dist) + activity_dist + bodymass_dist/max(bodymass_dist)) / 4
 names(func_dist) = traits_target$species
+func_dist = as.matrix(func_dist); dimnames(func_dist) = list(traits_target$species, traits_target$species)  # as.matrix() on a dist takes labels from attr "Labels", not names()
 
 save(func_dist, file = "data/r_objects/func_dist.RData")
diff --git a/R/02_03_phylo_preparation.R b/R/02_03_phylo_preparation.R
index 3963892..a72c833 100644
--- a/R/02_03_phylo_preparation.R
+++ b/R/02_03_phylo_preparation.R
@@ -49,4 +49,5 @@ dists = lapply(indices, function(i){
 
 # Save result
 phylo_dist = Reduce("+", dists) / length(dists)
+phylo_dist = phylo_dist / max(phylo_dist)
 save(phylo_dist, file = "data/r_objects/phylo_dist.RData")
diff --git a/R/04_07_msdm_embed_multi_nolonlat.R b/R/04_07_msdm_embed_multi_nolonlat.R
new file mode 100644
index 0000000..2044e03
--- /dev/null
+++ b/R/04_07_msdm_embed_multi_nolonlat.R
@@ -0,0 +1,107 @@
+library(dplyr)
+library(tidyr)
+library(cito)
+
+source("R/utils.R")
+
+load("data/r_objects/model_data.RData")
+load("data/r_objects/func_dist.RData")
+load("data/r_objects/phylo_dist.RData")
+load("data/r_objects/range_dist.RData")
+
+# ----------------------------------------------------------------------#
+# Prepare data                                                       ####
+# ----------------------------------------------------------------------#
+model_species = Reduce(
+  intersect,
+  list(unique(model_data$species), colnames(range_dist), colnames(phylo_dist), colnames(func_dist))
+)
+# species_int must be the position in model_species: the embedding rows below use that order
+model_data_final = model_data %>%
+  dplyr::filter(species %in% !!model_species) %>%
+  dplyr::mutate(species_int = match(as.character(species), !!model_species))
+
+train_data = dplyr::filter(model_data_final, train == 1)
+test_data = dplyr::filter(model_data_final, train == 0)
+
+# Create embeddings (distance matrices are reordered to model_species first)
+func_ind = match(model_species, colnames(func_dist))
+func_dist = func_dist[func_ind, func_ind]
+func_embeddings = eigen(func_dist)$vectors[,1:20]
+
+phylo_ind = match(model_species, colnames(phylo_dist))
+phylo_dist = phylo_dist[phylo_ind, phylo_ind]
+phylo_embeddings = eigen(phylo_dist)$vectors[,1:20]
+
+range_ind = match(model_species, colnames(range_dist))
+range_dist = range_dist[range_ind, range_ind]
+range_embeddings = eigen(range_dist)$vectors[,1:20]
+
+
+# ----------------------------------------------------------------------#
+# Train model                                                        ####
+# ----------------------------------------------------------------------#
+predictors = paste0("layer_", 1:19)
+# Three e() terms on species_int, each with a precomputed weight matrix and train = FALSE
+formula = as.formula(
+  paste0("present ~ ",
+         paste(predictors, collapse = '+'),
+         " + e(species_int, weights = func_embeddings, lambda = 0.00001, train = FALSE)",
+         " + e(species_int, weights = phylo_embeddings, lambda = 0.00001, train = FALSE)",
+         " + e(species_int, weights = range_embeddings, lambda = 0.00001, train = FALSE)"
+  )
+)
+
+plot(1, type="n", xlab="", ylab="", xlim=c(0, 25000), ylim=c(0, 0.7)) # empty plot with better limits, draw points in there
+msdm_fit_embedding_multi_nolonlat = dnn(
+  formula,
+  data = train_data,
+  hidden = c(500L, 500L, 500L),
+  loss = "binomial",
+  activation = c("sigmoid", "leaky_relu", "leaky_relu"),
+  epochs = 30000L,
+  lr = 0.01,
+  baseloss = 1,
+  batchsize = nrow(train_data),
+  dropout = 0.1,
+  burnin = 100,
+  optimizer = config_optimizer("adam", weight_decay = 0.001),
+  lr_scheduler = config_lr_scheduler("reduce_on_plateau", patience = 100, factor = 0.7),
+  early_stopping = 250,
+  validation = 0.3,
+  device = "cuda"
+)
+save(msdm_fit_embedding_multi_nolonlat, file = "data/r_objects/msdm_fit_embedding_multi_nolonlat.RData")
+
+# ----------------------------------------------------------------------#
+# Evaluate results                                                   ####
+# ----------------------------------------------------------------------#
+load("data/r_objects/msdm_fit_embedding_multi_nolonlat.RData")
+data_split = test_data %>%
+  group_by(species_int) %>%
+  group_split()
+
+msdm_results_embedding_multi_nolonlat = lapply(data_split, function(data_spec){
+  target_species = data_spec$species[1]
+  data_spec = dplyr::select(data_spec, -species)
+  # Keep going when one species fails, but report why instead of silently returning NA
+  msdm_performance = tryCatch({
+    evaluate_model(msdm_fit_embedding_multi_nolonlat, data_spec)
+  }, error = function(e){
+    warning("Evaluation failed for ", target_species, ": ", conditionMessage(e), call. = FALSE)
+    list(AUC = NA, Accuracy = NA, Kappa = NA, Precision = NA, Recall = NA, F1 = NA)
+  })
+  performance_summary = tibble(
+    species = !!target_species,
+    obs = sum(model_data$species == target_species, na.rm = TRUE),
+    model = "MSDM_embed_informed_multi_nolonlat",
+    auc = msdm_performance$AUC,
+    accuracy = msdm_performance$Accuracy,
+    kappa = msdm_performance$Kappa,
+    precision = msdm_performance$Precision,
+    recall = msdm_performance$Recall,
+    f1 = msdm_performance$F1
+  )
+}) %>% bind_rows()
+
+save(msdm_results_embedding_multi_nolonlat, file = "data/r_objects/msdm_results_embedding_multi_nolonlat.RData")
\ No newline at end of file
diff --git a/R/04_08_msdm_embed_multi_lonlat.R b/R/04_08_msdm_embed_multi_lonlat.R
new file mode 100644
index 0000000..2f2e03f
--- /dev/null
+++ b/R/04_08_msdm_embed_multi_lonlat.R
@@ -0,0 +1,107 @@
+library(dplyr)
+library(tidyr)
+library(cito)
+
+source("R/utils.R")
+
+load("data/r_objects/model_data.RData")
+load("data/r_objects/func_dist.RData")
+load("data/r_objects/phylo_dist.RData")
+load("data/r_objects/range_dist.RData")
+
+# ----------------------------------------------------------------------#
+# Prepare data                                                       ####
+# ----------------------------------------------------------------------#
+model_species = Reduce(
+  intersect,
+  list(unique(model_data$species), colnames(range_dist), colnames(phylo_dist), colnames(func_dist))
+)
+# species_int must be the position in model_species: the embedding rows below use that order
+model_data_final = model_data %>%
+  dplyr::filter(species %in% !!model_species) %>%
+  dplyr::mutate(species_int = match(as.character(species), !!model_species))
+
+train_data = dplyr::filter(model_data_final, train == 1)
+test_data = dplyr::filter(model_data_final, train == 0)
+
+# Create embeddings (distance matrices are reordered to model_species first)
+func_ind = match(model_species, colnames(func_dist))
+func_dist = func_dist[func_ind, func_ind]
+func_embeddings = eigen(func_dist)$vectors[,1:20]
+
+phylo_ind = match(model_species, colnames(phylo_dist))
+phylo_dist = phylo_dist[phylo_ind, phylo_ind]
+phylo_embeddings = eigen(phylo_dist)$vectors[,1:20]
+
+range_ind = match(model_species, colnames(range_dist))
+range_dist = range_dist[range_ind, range_ind]
+range_embeddings = eigen(range_dist)$vectors[,1:20]
+
+
+# ----------------------------------------------------------------------#
+# Train model                                                        ####
+# ----------------------------------------------------------------------#
+predictors = c("longitude", "latitude", paste0("layer_", 1:19))
+# Three e() terms on species_int, each with a precomputed weight matrix and train = FALSE
+formula = as.formula(
+  paste0("present ~ ",
+         paste(predictors, collapse = '+'),
+         " + e(species_int, weights = func_embeddings, lambda = 0.00001, train = FALSE)",
+         " + e(species_int, weights = phylo_embeddings, lambda = 0.00001, train = FALSE)",
+         " + e(species_int, weights = range_embeddings, lambda = 0.00001, train = FALSE)"
+  )
+)
+
+plot(1, type="n", xlab="", ylab="", xlim=c(0, 25000), ylim=c(0, 0.7)) # empty plot with better limits, draw points in there
+msdm_fit_embedding_multi_lonlat = dnn(
+  formula,
+  data = train_data,
+  hidden = c(500L, 500L, 500L),
+  loss = "binomial",
+  activation = c("sigmoid", "leaky_relu", "leaky_relu"),
+  epochs = 30000L,
+  lr = 0.01,
+  baseloss = 1,
+  batchsize = nrow(train_data),
+  dropout = 0.1,
+  burnin = 100,
+  optimizer = config_optimizer("adam", weight_decay = 0.001),
+  lr_scheduler = config_lr_scheduler("reduce_on_plateau", patience = 100, factor = 0.7),
+  early_stopping = 250,
+  validation = 0.3,
+  device = "cuda"
+)
+save(msdm_fit_embedding_multi_lonlat, file = "data/r_objects/msdm_fit_embedding_multi_lonlat.RData")
+
+# ----------------------------------------------------------------------#
+# Evaluate results                                                   ####
+# ----------------------------------------------------------------------#
+load("data/r_objects/msdm_fit_embedding_multi_lonlat.RData")
+data_split = test_data %>%
+  group_by(species_int) %>%
+  group_split()
+
+msdm_results_embedding_multi_lonlat = lapply(data_split, function(data_spec){
+  target_species = data_spec$species[1]
+  data_spec = dplyr::select(data_spec, -species)
+  # Keep going when one species fails, but report why instead of silently returning NA
+  msdm_performance = tryCatch({
+    evaluate_model(msdm_fit_embedding_multi_lonlat, data_spec)
+  }, error = function(e){
+    warning("Evaluation failed for ", target_species, ": ", conditionMessage(e), call. = FALSE)
+    list(AUC = NA, Accuracy = NA, Kappa = NA, Precision = NA, Recall = NA, F1 = NA)
+  })
+  performance_summary = tibble(
+    species = !!target_species,
+    obs = sum(model_data$species == target_species, na.rm = TRUE),
+    model = "MSDM_embed_informed_multi_lonlat",
+    auc = msdm_performance$AUC,
+    accuracy = msdm_performance$Accuracy,
+    kappa = msdm_performance$Kappa,
+    precision = msdm_performance$Precision,
+    recall = msdm_performance$Recall,
+    f1 = msdm_performance$F1
+  )
+}) %>% bind_rows()
+
+save(msdm_results_embedding_multi_lonlat, file = "data/r_objects/msdm_results_embedding_multi_lonlat.RData")
\ No newline at end of file
diff --git a/R/05_01_performance_analysis.qmd b/R/05_01_performance_analysis.qmd
index 7a99fe1..e2c2f33 100644
--- a/R/05_01_performance_analysis.qmd
+++ b/R/05_01_performance_analysis.qmd
@@ -529,7 +529,7 @@ bslib::card(plot, full_screen = T)
 Functional groups were assigned based on taxonomic order. The following groupings were used:
 
 | Functional group      | Taxomic orders                                                        |
-|--------------------|----------------------------------------------------|
+|------------------|-----------------------------------------------------|
 | large ground-dwelling | Carnivora, Artiodactyla, Cingulata, Perissodactyla                    |
 | small ground-dwelling | Rodentia, Didelphimorphia, Soricomorpha, Paucituberculata, Lagomorpha |
 | arboreal              | Primates, Pilosa                                                      |
@@ -583,25 +583,3 @@ plot <- plot %>%
 bslib::card(plot, full_screen = T)
 ```
 
-### Relative performance
-
-The table below summarizes the relative performance of the models across different observation frequency ranges. The `rank` column indicates the model's performance rank compared to all other models for a given combination of model and metric. The subsequent columns, `(1,10]`, `(10,25]`, ..., `(5000, Inf]`, represent bins of observation frequency. The values in these columns show how many times the model's performance was ranked at the specified `rank` within the respective frequency range.
-
-```{r msdm_vs_ssdm, echo = FALSE, message=FALSE, warnings=FALSE}
-freq_thresholds = c(1, 10, 25, 50, 100, 250, 500, 1000, 2500, 5000, Inf)
-
-df_print = performance %>%
-  mutate(freq_class = cut(obs, freq_thresholds, dig.lab = 5)) %>%
-  group_by(species, metric, freq_class) %>%
-  dplyr::mutate(rank = order(value, decreasing = T)) %>% 
-  group_by(model, metric, rank, freq_class) %>% 
-  tally() %>% 
-  pivot_wider(names_from = freq_class, values_from = n) %>% 
-  dplyr::select("model", "metric", "rank", "(1,10]", "(10,25]", 
-                "(25,50]", "(50,100]", "(100,250]", "(250,500]", 
-                "(500,1000]", "(1000,2500]", "(2500,5000]", "(5000,Inf]") %>% 
-  dplyr::arrange(metric, rank, model) %>% 
-  replace(is.na(.), 0)
-
-DT::datatable(df_print)
-```
diff --git a/R/05_02_MSDM_comparison.qmd b/R/05_02_MSDM_comparison.qmd
index 6de0327..0adece2 100644
--- a/R/05_02_MSDM_comparison.qmd
+++ b/R/05_02_MSDM_comparison.qmd
@@ -23,7 +23,6 @@ load("../data/r_objects/msdm_results_embedding_range_trained.RData")
 sf::sf_use_s2(use_s2 = FALSE)
 ```
 
-
 ```{r globals, echo = FALSE, include = FALSE}
 # Select metrics
 focal_metrics = c("auc", "f1", "accuracy")  # There's a weird bug in plotly that scrambles up lines when using more than three groups
@@ -103,7 +102,6 @@ lin_regression = function(x, y){
 }
 ```
 
-
 ```{r create_summaries, echo = FALSE}
 auc_overview = results_final %>% 
   pivot_wider(names_from = model, values_from = auc, id_cols = c(species, obs)) %>% 
@@ -120,8 +118,21 @@ f1_overview = results_final %>%
 
 ## *Model overview*
 
-::: panel-tabset
+### *Performance summary*
+
+```{r echo = FALSE, message=FALSE, warning=FALSE}
+p_summary = results_final_long %>% 
+  group_by(model, metric) %>% 
+  dplyr::summarize(value = round(mean(value, na.rm = TRUE), 3), .groups = "drop") %>% 
+  pivot_wider(names_from = model, values_from = value) %>% 
+  dplyr::select(metric, M_STD, M_TS, M_PS, M_RS, M_TT, M_PT, M_RT)
+
+datatable(p_summary)
+```
+
+Exact results for different performance metrics across all models are shown below.
 
+::: panel-tabset
 ### *AUC*
 
 ```{r echo = FALSE}
@@ -177,8 +188,7 @@ datatable(
 ```
 :::
 
-## *Model comparison*
-### *Number of records*
+## *Number of records*
 
 ```{r performance_vs_occurrences, echo = FALSE, message=FALSE, warnings=FALSE}
 # Dropdown options
@@ -264,6 +274,54 @@ for(model_name in unique(df_plot$model)){
 bslib::card(plot, full_screen = T)
 ```
 
+
+## *Relative Performance*
+
+### *Ranking*
+
+
+```{r delta, echo = FALSE, message=FALSE, warning=FALSE}
+results_ranked = results_final_long %>% 
+  group_by(species, metric) %>% 
+  mutate(rank = rank(-value)) %>%   # rank 1 = highest value; rev(rank()) would only reverse element order
+  group_by(model, metric) %>% 
+  summarize(mean_rank = mean(rank, na.rm = TRUE)) %>% 
+  group_by(metric) %>% 
+  mutate(position = rank(mean_rank))
+
+results_ranked_obs = results_final_long %>% 
+  group_by(species,  metric) %>% 
+  mutate(rank = rank(-value))       # rank 1 = highest value within species and metric
+
+ggplot(data = results_ranked_obs, aes(x = obs, y = rank, color = model)) +
+  geom_point(alpha = 0.1) +
+  scale_y_continuous(name = "rank (lower is better)") +
+  scale_x_log10() +
+  geom_smooth() +
+  theme_minimal()
+
+# The table below summarizes the relative performance of the models across different observation frequency ranges. The `rank` column indicates the model's performance rank compared to all other models for a given combination of model and metric. The subsequent columns, `(1,10]`, `(10,25]`, ..., `(5000, Inf]`, represent bins of observation frequency. The values in these columns show how many times the model's performance was ranked at the specified `rank` within the respective frequency range.
+
+# freq_thresholds = c(1, 10, 25, 50, 100, 250, 500, 1000, 2500, 5000, Inf)
+# 
+# df_print = performance %>%
+#   mutate(freq_class = cut(obs, freq_thresholds, dig.lab = 5)) %>%
+#   group_by(species, metric, freq_class) %>%
+#   dplyr::mutate(rank = order(value, decreasing = T)) %>% 
+#   group_by(model, metric, rank, freq_class) %>% 
+#   tally() %>% 
+#   pivot_wider(names_from = freq_class, values_from = n) %>% 
+#   dplyr::select("model", "metric", "rank", "(1,10]", "(10,25]", 
+#                 "(25,50]", "(50,100]", "(100,250]", "(250,500]", 
+#                 "(500,1000]", "(1000,2500]", "(2500,5000]", "(5000,Inf]") %>% 
+#   dplyr::arrange(metric, rank, model) %>% 
+#   replace(is.na(.), 0)
+# 
+# DT::datatable(df_print)
+```
+
+### *Trait space*
+
 ```{r trait_pca, echo = FALSE, message=FALSE, warnings=FALSE}
 load("../data/r_objects/traits_proc.RData")
 
@@ -287,25 +345,65 @@ traits_pca = prcomp(traits_num)
 species_vcts = traits_pca$x %>% 
   as.data.frame() %>% 
   rownames_to_column(var = "species")
+
+df_performance_vs_traits = dplyr::inner_join(delta_performance, species_vcts)
 ```
 
+Functional traits for `r I(nrow(traits_proc))` species were used to construct a multidimensional trait space. Before ordination, categorical variables were converted into dummy variables, and all variables were scaled and centered. A Principal Component Analysis (PCA) was performed on the preprocessed dataset, and the first three axes are used to visualize species' positions within the trait space. For reference, the rotation vectors of the original traits are also plotted.
 
-```{r performance_vs_traits, echo = FALSE, message=FALSE, warnings=FALSE}
+::: panel-tabset
+#### *AUC*
+
+```{r auc_trait_space, echo = FALSE, message=FALSE, warning=FALSE}
 # Create plot df
-df_plot = dplyr::inner_join(delta_performance, species_vcts)
+df_plot = dplyr::filter(df_performance_vs_traits, metric == "auc", !is.na(delta)) %>% 
+  group_by(species) %>% 
+  mutate(     # add jitter on all three plotted axes; one offset per species
+    PC1 = PC1 + runif(1, -0.15, 0.15),
+    PC2 = PC2 + runif(1, -0.15, 0.15),
+    PC3 = PC3 + runif(1, -0.15, 0.15))
 
 # Buttons
 plotly_buttons = list()
-for(metric in unique(df_plot$model)){
-  plotly_buttons[[length(plotly_buttons) + 1]] = list(method = "restyle", args = list("transforms[0].value", metric), label = metric)
+for(model in c("M_PS", "M_PT", "M_TS", "M_TT", "M_RS", "M_RT")){
+  plotly_buttons[[length(plotly_buttons) + 1]] = list(method = "restyle", args = list("transforms[0].value", model), label = model)
 }
 
-# Create base plot
-plot <- plot_ly() %>% 
+# Create plot
+plot <- plot_ly(
+  marker = list(           
+    color = ~delta,
+    colors = colorRamp(c("blue", "lightgrey", "red")),
+    cauto = F,
+    cmin = -1,
+    cmid = 0, 
+    cmax = 1,
+    opacity = 0.6,
+    colorbar = list(
+      title = "\u0394 AUC",
+      titleside = "right",  
+      x = 1.1,  # Adjust x position of colorbar
+      y = 0.5   # Adjust y position of colorbar
+    )
+  )) %>%
+  add_markers(
+    data = df_plot,
+    x = ~PC1,
+    y = ~PC2,
+    z = ~PC3,
+    hoverinfo = 'text',
+    text = ~paste("Species:", species, "<br>\u0394:", round(delta, 3)),
+    transforms = list(
+      list(
+        type = 'filter',
+        target = ~model,
+        operation = '=',
+        value = "M_TS"
+      )
+    )
+  ) %>% 
   layout(
-    title = "Delta Performance Traits",
-    xaxis = list(title = "PC1"),
-    yaxis = list(title = "PC2"),
+    title = "Relative model performance (trait space)",
     legend = list(x = 1.1, y = 0.5),  # Move legend to the right of the plot
     margin = list(r = 150),  # Add right margin to accommodate legend
     hovermode = 'closest',
@@ -316,28 +414,255 @@ plot <- plot_ly() %>%
         buttons = plotly_buttons
       )
     )
-  )
+  ) 
+
+trait_loadings = sweep(traits_pca$rotation, 2, traits_pca$sdev, "*")  # scale each PC column by that PC's sdev
+for (i in seq_len(nrow(trait_loadings))) {
+  plot <- plot %>%
+    add_trace(
+      x = c(0, trait_loadings[i, 1]),
+      y = c(0, trait_loadings[i, 2]),
+      z = c(0, trait_loadings[i, 3]),
+      type = "scatter3d",
+      mode = "lines+markers+text",
+      line = list(
+        color = "black",
+        width = 2,
+        arrow = list(
+          end = 1,
+          type = "open",
+          length = 0.1
+        )
+      ),
+      marker = list(
+        color = "black",
+        size = 1,
+        symbol="arrow"
+      ),
+      text = rownames(trait_loadings)[i],
+      textposition = "top right",
+      opacity = 1,
+      hoverinfo = "text",
+      # hover shows the same trait name that is drawn as the text label
+      showlegend = FALSE
+    )
+}
 
-# Points
-for (metric in unique(df_plot$metric)) {
-  plot = plot %>%
-    add_markers(
-      data = filter(df_plot, metric == !!metric),
-      x = ~PC1,
-      y = ~PC2,
-      color = ~delta,
-      opacity = 0.6,
-      hoverinfo = 'text',
-      transforms = list(
-        list(
-          type = 'filter',
-          target = ~metric,
-          operation = '=',
-          value = focal_metrics[1]
+
+bslib::card(plot, full_screen = T)
+```
+
+#### *Accuracy*
+
+```{r accuracy_trait_space, echo = FALSE, message=FALSE, warning=FALSE}
+# Create plot df
+df_plot = dplyr::filter(df_performance_vs_traits, metric == "accuracy", !is.na(delta)) %>% 
+  group_by(species) %>% 
+  mutate(     # add jitter on all three plotted axes; one offset per species
+    PC1 = PC1 + runif(1, -0.15, 0.15),
+    PC2 = PC2 + runif(1, -0.15, 0.15),
+    PC3 = PC3 + runif(1, -0.15, 0.15)
+  ) 
+
+# Buttons
+plotly_buttons = list()
+for(model in c("M_PS", "M_PT", "M_TS", "M_TT", "M_RS", "M_RT")){
+  plotly_buttons[[length(plotly_buttons) + 1]] = list(method = "restyle", args = list("transforms[0].value", model), label = model)
+}
+
+# Create plot
+plot <- plot_ly(
+  marker = list(           
+    color = ~delta,
+    colors = colorRamp(c("blue", "lightgrey", "red")),
+    cauto = F,
+    cmin = -1,
+    cmid = 0, 
+    cmax = 1,
+    opacity = 0.6,
+    colorbar = list(
+      title = "\u0394 Accuracy",
+      titleside = "right",  
+      x = 1.1,  # Adjust x position of colorbar
+      y = 0.5   # Adjust y position of colorbar
+    )
+  )) %>%
+  add_markers(
+    data = df_plot,
+    x = ~PC1,
+    y = ~PC2,
+    z = ~PC3,
+    hoverinfo = 'text',
+    text = ~paste("Species:", species, "<br>\u0394:", round(delta, 3)),
+    transforms = list(
+      list(
+        type = 'filter',
+        target = ~model,
+        operation = '=',
+        value = "M_TS"
+      )
+    )
+  ) %>% 
+  layout(
+    title = "Relative model performance (trait space)",
+    legend = list(x = 1.1, y = 0.5),  # Move legend to the right of the plot
+    margin = list(r = 150),  # Add right margin to accommodate legend
+    hovermode = 'closest',
+    updatemenus = list(
+      list(
+        type = "dropdown",
+        active = 0,
+        buttons = plotly_buttons
+      )
+    )
+  ) 
+
+trait_loadings = sweep(traits_pca$rotation, 2, traits_pca$sdev, "*")  # scale each PC column by that PC's sdev
+for (i in seq_len(nrow(trait_loadings))) {
+  plot <- plot %>%
+    add_trace(
+      x = c(0, trait_loadings[i, 1]),
+      y = c(0, trait_loadings[i, 2]),
+      z = c(0, trait_loadings[i, 3]),
+      type = "scatter3d",
+      mode = "lines+markers+text",
+      line = list(
+        color = "black",
+        width = 2,
+        arrow = list(
+          end = 1,
+          type = "open",
+          length = 0.1
         )
+      ),
+      marker = list(
+        color = "black",
+        size = 1,
+        symbol="arrow"
+      ),
+      text = rownames(trait_loadings)[i],
+      textposition = "top right",
+      opacity = 1,
+      hoverinfo = "text",
+      # hover shows the same trait name that is drawn as the text label
+      showlegend = FALSE
+    )
+}
+
+
+bslib::card(plot, full_screen = T)
+```
+
+#### *F1*
+
+```{r f1_trait_space, echo = FALSE, message=FALSE, warning=FALSE}
+# Create plot df
+df_plot = dplyr::filter(df_performance_vs_traits, metric == "f1", !is.na(delta)) %>% 
+  group_by(species) %>% 
+  mutate(     # add jitter on all three plotted axes; one offset per species
+    PC1 = PC1 + runif(1, -0.15, 0.15),
+    PC2 = PC2 + runif(1, -0.15, 0.15),
+    PC3 = PC3 + runif(1, -0.15, 0.15))
+
+# Buttons
+plotly_buttons = list()
+for(model in c("M_PS", "M_PT", "M_TS", "M_TT", "M_RS", "M_RT")){
+  plotly_buttons[[length(plotly_buttons) + 1]] = list(method = "restyle", args = list("transforms[0].value", model), label = model)
+}
+
+# Create plot
+plot <- plot_ly(
+  marker = list(           
+    color = ~delta,
+    colors = colorRamp(c("blue", "lightgrey", "red")),
+    cauto = F,
+    cmin = -1,
+    cmid = 0, 
+    cmax = 1,
+    opacity = 0.6,
+    colorbar = list(
+      title = "\u0394 F1",
+      titleside = "right",  
+      x = 1.1,  # Adjust x position of colorbar
+      y = 0.5   # Adjust y position of colorbar
+    )
+  )) %>%
+  add_markers(
+    data = df_plot,
+    x = ~PC1,
+    y = ~PC2,
+    z = ~PC3,
+    hoverinfo = 'text',
+    text = ~paste("Species:", species, "<br>\u0394:", round(delta, 3)),
+    transforms = list(
+      list(
+        type = 'filter',
+        target = ~model,
+        operation = '=',
+        value = "M_TS"
       )
     )
+  ) %>% 
+  layout(
+    title = "Relative model performance (trait space)",
+    legend = list(x = 1.1, y = 0.5),  # Move legend to the right of the plot
+    margin = list(r = 150),  # Add right margin to accommodate legend
+    hovermode = 'closest',
+    updatemenus = list(
+      list(
+        type = "dropdown",
+        active = 0,
+        buttons = plotly_buttons
+      )
+    )
+  ) 
+
+trait_loadings = sweep(traits_pca$rotation, 2, traits_pca$sdev, "*")  # scale each PC column by that PC's sdev
+for (i in seq_len(nrow(trait_loadings))) {
+  plot <- plot %>%
+    add_trace(
+      x = c(0, trait_loadings[i, 1]),
+      y = c(0, trait_loadings[i, 2]),
+      z = c(0, trait_loadings[i, 3]),
+      type = "scatter3d",
+      mode = "lines+markers+text",
+      line = list(
+        color = "black",
+        width = 2,
+        arrow = list(
+          end = 1,
+          type = "open",
+          length = 0.1
+        )
+      ),
+      marker = list(
+        color = "black",
+        size = 1,
+        symbol="arrow"
+      ),
+      text = rownames(trait_loadings)[i],
+      textposition = "top right",
+      opacity = 1,
+      hoverinfo = "text",
+      # hover shows the same trait name that is drawn as the text label
+      showlegend = FALSE
+    )
 }
 
 bslib::card(plot, full_screen = T)
 ```
+:::
+
+### *Taxonomy*
+
+```{r taxonomy, echo = FALSE, message=FALSE, warnings=FALSE}
+# load("../data/r_objects/functional_groups.RData")
+# 
+# df_plot = results_final %>% 
+#   dplyr::left_join(functional_groups, by = c("species" = "name_matched"))
+# 
+# plot = 
+#   
+# bslib::card(plot, full_screen = T)
+
+```
-- 
GitLab