# Skip to content
# Snippets Groups Projects
# Code owners
# Assign users and groups as approvers for specific file changes. Learn more.
# 05_02_publication_analysis.R 19.84 KiB
# General packages
library(tidyverse)
library(patchwork)

# Geo packages
library(terra)
library(sf)
library(geos)

# Stats packages
library(Rtsne)
library(cito)

source("R/utils.R")

# Pre-computed inputs. Each load() restores objects into the calling
# environment under the names stored in the file, so a later file silently
# overwrites an object of the same name from an earlier one.
# Presumably these provide `model_data`, `ssdm_results`,
# `msdm_embed_performance`, `msdm_onehot_performance` and `functional_groups`,
# which are the names used further down — verify against the files.
load("data/r_objects/model_data_random_abs_extended.RData")
load("data/r_objects/ssdm_results.RData")
load("data/r_objects/msdm_embed_performance.RData")
load("data/r_objects/msdm_onehot_performance.RData")
load("data/r_objects/functional_groups.RData")

# Spatial reference objects (presumably `sa_polygon` and `range_maps`).
load("data/r_objects/sa_polygon.RData")
load("data/r_objects/range_maps.RData")

# Turn off s2 spherical geometry for all subsequent sf operations.
sf::sf_use_s2(use_s2 = FALSE)

# Keep only records that have an evaluation fold assigned and store species as
# a factor; later sections rely on levels(model_data$species) and on the
# integer codes of this factor.
model_data = model_data %>% 
  dplyr::filter(!is.na(fold_eval)) %>% 
  dplyr::mutate(species = as.factor(species))

# ------------------------------------------------------------------ #
# 1. Collect performance results                                  ####
# ------------------------------------------------------------------ #
# Stack the single-species (ssdm) and multi-species (msdm) results and average
# every metric per species and model.
# mean(..., na.rm = FALSE) yields NA as soon as one contributing value is
# missing; such scores are replaced by fixed fallback values below.
# NOTE(review): 0.5 is used as fallback for auc, f1, accuracy, precision and
# recall alike, and 0 for kappa — confirm these are the intended defaults.
performance <- ssdm_results %>%
  bind_rows(msdm_embed_performance) %>%
  bind_rows(msdm_onehot_performance) %>%
  dplyr::group_by(species, model, metric) %>%
  # .groups = "drop" returns an ungrouped table. Without it the result stays
  # grouped by species and model, and every later verb on `performance`
  # silently runs per group.
  dplyr::summarise(value = mean(value, na.rm = FALSE), .groups = "drop") %>%
  dplyr::mutate(
    # Metrics/models that are not listed in `levels` become NA.
    metric = factor(tolower(metric), levels = c("auc", "f1", "kappa", "accuracy", "precision", "recall")),
    model = factor(model, levels = c("GAM", "GBM", "RF", "NN", "MSDM_onehot", "MSDM_embed")),
    # is.na() is TRUE for NaN as well, so a separate is.nan() check is not needed.
    value = case_when(
      is.na(value) & metric %in% c("auc", "f1", "accuracy", "precision", "recall") ~ 0.5,
      is.na(value) & metric %in% c("kappa") ~ 0,
      .default = value
    )
  )

# ------------------------------------------------------------------ #
# 2. Model comparison                                             ####
# ------------------------------------------------------------------ #
## Overall model Comparison ####
# Keep rows whose metric is one of the six known scores (rows with an NA
# metric are dropped by the %in% test).
df_plot <- dplyr::filter(
  performance,
  metric %in% c("auc", "f1", "kappa", "accuracy", "precision", "recall")
)

# One boxplot per model, one panel per metric with its own y-axis.
ggplot(data = df_plot, mapping = aes(x = model, y = value, color = model)) +
  geom_boxplot(alpha = 0.5, outlier.shape = 16) +
  facet_wrap(vars(metric), scales = "free_y") +
  theme_minimal(base_size = 14) +
  theme(
    strip.text = element_text(face = "bold"),
    axis.text.x = element_text(angle = 45, hjust = 1),
    legend.position = "right"
  ) +
  labs(
    title = "Model Performance Across Metrics",
    x = "Model",
    y = "Value",
    color = "Model"
  )

# ggsave() without a `plot` argument writes the most recently created plot.
ggsave(
  filename = "plots/publication/model_performance_across_metrics.pdf",
  device = "pdf",
  width = 18,
  height = 14,
  units = "cm",
  scale = 2
)

## Performance vs number of records ####
# Number of presence records per species (rows with present == 1 that have an
# evaluation fold).
obs_count <- model_data %>%
  sf::st_drop_geometry() %>%
  dplyr::filter(present == 1, !is.na(fold_eval)) %>%
  dplyr::group_by(species) %>%
  dplyr::summarise(obs = dplyr::n())

# Attach the record count to every performance row.
df_plot <- performance %>%
  dplyr::filter(
    metric %in% c("auc", "f1", "kappa", "accuracy", "precision", "recall")
  ) %>%
  dplyr::left_join(obs_count, by = "species")

# Scatter plus smoother of score against record count (log10 x-axis),
# one panel per metric.
ggplot(
  data = df_plot,
  mapping = aes(x = obs, y = value, color = model, fill = model)
) +
  geom_point(alpha = 0.1) +
  geom_smooth() +
  facet_wrap(vars(metric), scales = "free_y") +
  scale_x_log10() +
  theme_minimal(base_size = 14) +
  theme(
    strip.text = element_text(face = "bold"),
    axis.text.x = element_text(angle = 45, hjust = 1),
    legend.position = "right"
  ) +
  labs(
    title = "Model Performance vs Number of Records",
    x = "Number of Records (log scale)",
    y = "Value",
    color = "Model",
    fill = "Model"
  )

ggsave(
  filename = "plots/publication/model_performance_vs_number_of_records.pdf",
  device = "pdf",
  width = 18,
  height = 14,
  units = "cm",
  scale = 2
)

## Performance vs functional group ####
# Functional group lookup (same file that is loaded at the top of the script).
load("data/r_objects/functional_groups.RData")

# Restrict to three metrics and attach each species' functional group.
df_plot <- performance %>%
  dplyr::filter(metric %in% c("auc", "f1", "kappa")) %>%
  dplyr::left_join(
    functional_groups,
    by = dplyr::join_by(species == name_matched)
  )

# Grid of boxplots: metrics in rows, functional groups in columns.
ggplot(data = df_plot, mapping = aes(x = model, y = value, color = model)) +
  geom_boxplot(alpha = 0.5, outlier.shape = 16) +
  facet_grid(metric ~ functional_group, scales = "free_y", switch = "both") +
  theme_minimal(base_size = 14) +
  theme(
    strip.text = element_text(face = "bold"),
    panel.border = element_rect(color = "gray", fill = NA),
    axis.text.x = element_text(angle = 45, hjust = 1),
    legend.position = "right"
  ) +
  labs(
    title = "Model Performance Across Functional Groups",
    x = "Model",
    y = "Value",
    color = "Model"
  )

ggsave(
  filename = "plots/publication/model_performance_vs_functional_groups.pdf",
  device = "pdf",
  width = 18,
  height = 14,
  units = "cm",
  scale = 2
)

# ------------------------------------------------------------------ #
# 3. Range predictions                                            ####
# ------------------------------------------------------------------ #
library(caret)
library(gam)
library(gbm)
library(cito)
library(ranger)

# Predict presence/absence maps for one species with several fitted models and
# overlay the observations and the range map.
#
# Arguments:
#   spec         Species name; matched against `model_data$species` and
#                `range_maps$name_matched`.
#   model_data   sf object with the columns `species` and `present`.
#   raster_data  terra SpatRaster holding the predictor layers.
#   algorithms   Character vector of model identifiers. "msdm_onehot",
#                "msdm_embed", "nn" and the four "msdm_rf*" variants are handled
#                explicitly; any other value `a` is read from
#                "<spec>_<a>_fit.RData" and must provide the object `<a>_fit`.
#   num_threads  Threads passed to predict() for the random forest variants
#                (default 48, the value that was previously hard-coded).
#
# Returns a named list with one ggplot object per algorithm.
#
# Every model file is loaded into its own environment and the expected object
# is looked up there only, so a missing object raises an error instead of
# silently reusing a model left over from an earlier iteration or from the
# global environment.
plot_predictions <- function(spec, model_data, raster_data, algorithms, num_threads = 48){
  # Species data
  load("data/r_objects/range_maps.RData")

  pa_spec <- model_data %>%
    dplyr::filter(species == !!spec)
  range_spec <- range_maps %>%
    dplyr::filter(name_matched == !!spec) %>%
    sf::st_transform(sf::st_crs(pa_spec))

  # Fail early with a clear message; without records or a range map the steps
  # below cannot build a bounding box or a species column.
  if (nrow(pa_spec) == 0 || nrow(range_spec) == 0) {
    stop("No occurrence records or no range map available for species: ", spec, call. = FALSE)
  }

  # Extract raster values into df
  # expand_bbox() is defined outside this function (presumably in R/utils.R).
  bbox_spec <- sf::st_bbox(range_spec) %>%
    expand_bbox(expansion = 0.25)

  raster_crop <- terra::crop(raster_data, bbox_spec)
  pa_crop <- pa_spec[st_intersects(pa_spec, st_as_sfc(bbox_spec), sparse = FALSE), ]

  # Rows containing NA are dropped; the remaining row names are used further
  # down as cell indices when the predictions are written back to the raster.
  new_data <- raster_crop %>%
    terra::values(dataframe = TRUE, na.rm = TRUE) %>%
    dplyr::mutate(species = unique(pa_spec$species))

  # The four random forest variants differ only in the file they are read from.
  rf_model_files <- c(
    msdm_rf = "data/r_objects/msdm_rf/msdm_rf_fit_full.RData",
    msdm_rf_random_abs = "data/r_objects/msdm_rf/msdm_rf_fit_random_abs_full.RData",
    msdm_rf_fit_no_species = "data/r_objects/msdm_rf/msdm_rf_fit_no_species_full.RData",
    msdm_rf_fit_no_species_random_abs = "data/r_objects/msdm_rf/msdm_rf_fit_no_species_random_abs_full.RData"
  )

  # Load `path` into a private environment and return the object `object_name`.
  load_fit <- function(path, object_name){
    fit_env <- new.env()
    load(path, envir = fit_env)
    if (!exists(object_name, envir = fit_env, inherits = FALSE)) {
      stop("Expected object '", object_name, "' not found in ", path, call. = FALSE)
    }
    get(object_name, envir = fit_env, inherits = FALSE)
  }

  # Turn predicted probabilities into the classes "A" (0) / "P" (1) by rounding.
  classify <- function(probabilities){
    factor(round(probabilities), levels = c("0", "1"), labels = c("A", "P"))
  }

  plots <- list()
  # Make predictions
  for (algorithm in algorithms) {
    message(algorithm)
    # Load model
    if (algorithm == "msdm_onehot") {
      fit <- load_fit("data/r_objects/msdm_onehot_results/msdm_onehot_fit_full.RData", "msdm_onehot_fit")
      predictions <- classify(predict(fit, new_data, type = "response")[, 1])
    } else if (algorithm == "msdm_embed") {
      fit <- load_fit("data/r_objects/msdm_embed_results/msdm_embed_fit_full.RData", "msdm_embed_fit")
      predictions <- classify(predict(fit, new_data, type = "response")[, 1])
    } else if (algorithm %in% names(rf_model_files)) {
      fit <- load_fit(rf_model_files[[algorithm]], "rf_fit")
      predictions <- predict(fit, new_data, type = "raw", num.threads = num_threads)
    } else if (algorithm == "nn") {
      fit <- load_fit(paste0("data/r_objects/ssdm_results/full_models/", spec, "_nn_fit.RData"), "nn_fit")
      # Single-species models are predicted without the species column.
      new_data_tmp <- dplyr::select(new_data, -species)
      predictions <- classify(predict(fit, new_data_tmp, type = "response")[, 1])
    } else {
      fit_name <- paste0(algorithm, "_fit")
      fit <- load_fit(paste0("data/r_objects/ssdm_results/full_models/", spec, "_", algorithm, "_fit.RData"), fit_name)
      new_data_tmp <- dplyr::select(new_data, -species)
      predictions <- predict(fit, new_data_tmp, type = "raw")
    }

    # Empty single-layer raster with the geometry of the cropped predictors;
    # the factor codes 1/2 are shifted to 0/1 and written to the predicted cells.
    raster_pred <- terra::rast(raster_crop, nlyrs = 1)
    values(raster_pred)[as.integer(rownames(new_data))] <- as.integer(predictions) - 1

    # Plot
    p <- ggplot() +
      tidyterra::geom_spatraster(data = as.factor(raster_pred), maxcell = 5e7) +
      scale_fill_manual(values = c("0" = "black", "1" = "yellow"), name = "Prediction", na.translate = FALSE) +
      geom_sf(data = pa_crop, aes(shape = as.factor(present)), color = "#FFFFFF99") +
      geom_sf(data = range_spec, col = "red", fill = NA) +
      scale_shape_manual(values = c("0" = 1, "1" = 4), name = "Observation") +
      theme_minimal() +
      coord_sf() +
      labs(title = paste0("Predictions vs Observations (", algorithm, "): ", spec))  +
      guides(shape = guide_legend(
        override.aes = list(color = "black")  # This makes shapes black in the legend only
      ))

    plots[[algorithm]] <- p
  }

  plots
}

# Load raster
# Predictor layers are read in natural sort order, cropped to the bounding box
# of the study area and projected to the CRS of the model data.
raster_filepaths <- list.files("~/symobio-modeling/data/geospatial/raster/", full.names = TRUE) %>%
  stringr::str_sort(numeric = TRUE)

raster_data <- terra::rast(raster_filepaths) %>%
  terra::crop(sf::st_bbox(sa_polygon)) %>%
  terra::project(sf::st_crs(model_data)$input)

# Four randomly drawn species; one PDF with one page per model variant each.
# NOTE(review): no seed is set in this section, so the selection changes
# between runs unless a seed is set elsewhere — confirm.
specs <- sort(sample(levels(model_data$species), 4))
for (spec in specs) {
  # Build the plots before opening the device, so a failing prediction does not
  # leave an empty PDF and an open graphics device behind.
  plots <- plot_predictions(spec, model_data, raster_data, algorithms = c("msdm_rf", "msdm_rf_random_abs", "msdm_rf_fit_no_species", "msdm_rf_fit_no_species_random_abs"))

  pdf(file = paste0("plots/range_predictions/", spec, " (msdm_rf).pdf"), width = 12)
  # `finally` closes the device even if printing one of the plots fails.
  tryCatch(
    for (p in plots) {
      print(p)
    },
    finally = dev.off()
  )
}

# ------------------------------------------------------------------ #
# 4. Compare predictions across species                           ####
# ------------------------------------------------------------------ #
# Fitted multi-species models compared in this section.
# NOTE(review): plot_predictions() reads the model from both
# msdm_rf_fit_full.RData and msdm_rf_fit_random_abs_full.RData under the name
# `rf_fit`. If both files indeed store `rf_fit`, the fourth load() overwrites
# the object created by the third one, and the `rf_fit` used in this section is
# the random-absence model — confirm this is intended.
load("data/r_objects/msdm_embed_results/msdm_embed_fit_full.RData")
load("data/r_objects/msdm_onehot_results/msdm_onehot_fit_full.RData")
load("data/r_objects/msdm_rf/msdm_rf_fit_full.RData")
load("data/r_objects/msdm_rf/msdm_rf_fit_random_abs_full.RData")

# Scatter plot of one model's predictions for two randomly drawn species at the
# same randomly sampled raster cells.
#
# Arguments:
#   model        Fitted model passed on to predict_new() (defined outside this
#                block). The expression given by the caller is used as title.
#   sample_size  Number of raster cells to sample; cells containing NA are
#                dropped afterwards, so fewer points may be plotted.
#
# Reads `model_data` and `raster_data` from the enclosing environment and draws
# on the active graphics device.
compare_species_predictions <- function(model, sample_size){
  species_pair <- sample(unique(model_data$species), 2, replace = FALSE)

  # Identical environmental conditions for both species; only the species
  # column differs between the two data frames.
  cells <- terra::spatSample(raster_data, sample_size, replace = FALSE, as.df = TRUE)
  df1 <- dplyr::mutate(tidyr::drop_na(cells), species = species_pair[1])
  df2 <- dplyr::mutate(df1, species = species_pair[2])

  p1 <- predict_new(model, df1)
  p2 <- predict_new(model, df2)

  plot(
    x = p1, y = p2,
    xlab = df1$species[1],
    ylab = df2$species[1],
    xlim = c(0, 1), ylim = c(0, 1),
    pch = 16,
    main = deparse(substitute(model))
  )
  # 1:1 line and quadrant labels (first letter: x-axis species, second: y-axis).
  abline(a = 0, b = 1)
  text(
    x = c(0.1, 0.9, 0.1, 0.9),
    y = c(0.1, 0.1, 0.9, 0.9),
    labels = c("AA", "PA", "AP", "PP"),
    col = "red"
  )
  r_squared <- round(cor(p1, p2)^2, 3)
  mtext(paste0("(R² = ", r_squared, ")"), side = 3, cex = 0.8)
  # Lines at 0.5 on both axes.
  abline(h = 0.5, v = 0.5)
}

# One comparison plot per model. The argument expression is used verbatim as
# the plot title (via deparse(substitute(model))), and every call draws its own
# random species pair and raster sample.
compare_species_predictions(msdm_embed_fit, 500) 
compare_species_predictions(msdm_onehot_fit, 500) 
compare_species_predictions(rf_fit, 500) 

# Observations noted by the author from these plots:
## --> Predictions for different species are weakly/moderately correlated in NN models (makes sense)
## --> Predictions for different species are always highly correlated in RF model (seems problematic)

# ------------------------------------------------------------------ #
# 5. Compare predictions across models                            ####
# ------------------------------------------------------------------ #
# Load the three fitted models compared in this section.
# NOTE(review): presumably this also resets `rf_fit` to the model stored in
# msdm_rf_fit_full.RData, in case a file loaded earlier in the script replaced
# it — confirm which random forest is meant to be compared here.
load("data/r_objects/msdm_embed_results/msdm_embed_fit_full.RData")
load("data/r_objects/msdm_onehot_results/msdm_onehot_fit_full.RData")
load("data/r_objects/msdm_rf/msdm_rf_fit_full.RData")

# Scatter plot of the predictions of two models for one randomly drawn species
# at the same randomly sampled raster cells.
#
# Arguments:
#   model1, model2  Fitted models passed on to predict_new() (defined outside
#                   this block). The expressions given by the caller are used
#                   as axis labels.
#   sample_size     Number of raster cells to sample; cells containing NA are
#                   dropped afterwards, so fewer points may be plotted.
#
# Reads `model_data` and `raster_data` from the enclosing environment and draws
# on the active graphics device.
compare_model_predictions <- function(model1, model2, sample_size){
  spec <- sample(unique(model_data$species), 1)

  cells <- terra::spatSample(raster_data, sample_size, replace = FALSE, as.df = TRUE)
  new_data <- dplyr::mutate(tidyr::drop_na(cells), species = spec)

  p1 <- predict_new(model1, new_data)
  p2 <- predict_new(model2, new_data)

  # model1 on the x-axis, model2 on the y-axis
  plot(
    x = p1, y = p2,
    xlab = deparse(substitute(model1)),
    ylab = deparse(substitute(model2)),
    xlim = c(0, 1), ylim = c(0, 1),
    pch = 16,
    main = spec
  )
  # 1:1 line and quadrant labels (first letter: model1, second: model2).
  abline(a = 0, b = 1)
  text(
    x = c(0.1, 0.9, 0.1, 0.9),
    y = c(0.1, 0.1, 0.9, 0.9),
    labels = c("AA", "PA", "AP", "PP"),
    col = "red"
  )
  r_squared <- round(cor(p1, p2)^2, 3)
  mtext(paste0("(R² = ", r_squared, ")"), side = 3, cex = 0.8)
  # Lines at 0.5 on both axes.
  abline(h = 0.5, v = 0.5)
}

# Pairwise model comparisons. The argument expressions are used verbatim as
# axis labels (via deparse(substitute(...))), and every call draws its own
# random species and raster sample.
compare_model_predictions(msdm_embed_fit, msdm_onehot_fit, 500) 
compare_model_predictions(msdm_embed_fit, rf_fit, 500)
compare_model_predictions(msdm_onehot_fit, rf_fit, 500) 

# Observations noted by the author from these plots:
## --> Predictions for the same species from NN_embed and NN_onehot are moderately/strongly correlated (makes sense)
## --> Predictions for the same species from NN and RF are weakly/not correlated (seems problematic)

# ------------------------------------------------------------------ #
# 6. Compare species embeddings across folds                      ####
# ------------------------------------------------------------------ #
# Number of presence records per species (rows with present == 1 that have an
# evaluation fold).
obs_count <- model_data %>%
  sf::st_drop_geometry() %>%
  dplyr::filter(present == 1, !is.na(fold_eval)) %>%
  dplyr::group_by(species) %>%
  dplyr::summarise(obs = dplyr::n())

# First element of the first coefficient block of each fold model; it is used
# below as the species embedding matrix (one row per species).
all_embeddings <- lapply(seq_len(5), function(fold){
  fit_path <- paste0("data/r_objects/msdm_embed_results/msdm_embed_fit_fold", fold, ".RData")
  load(fit_path)
  coef(msdm_embed_fit)[[1]][[1]]
})

# All 10 unordered pairs of folds.
pairs <- combn(1:5, 2, simplify = FALSE)

## Correlations
### Embedding matrices across folds ####
pairwise_rv <- lapply(pairs, function(pair){
  FactoMineR::coeffRV(
    all_embeddings[[pair[1]]],
    all_embeddings[[pair[2]]]
  )
})

# RV coefficients and their p-values go into the lower triangle of a 5 x 5
# matrix (row = higher fold number, column = lower fold number); all other
# cells stay NA.
rv_matrix <- matrix(NA, 5, 5)
p_matrix <- matrix(NA, 5, 5)
for (i in seq_along(pairs)) {
  rv_matrix[pairs[[i]][2], pairs[[i]][1]] <- pairwise_rv[[i]][["rv"]]
  p_matrix[pairs[[i]][2], pairs[[i]][1]] <- pairwise_rv[[i]][["p.value"]]
}

# --> Embeddings from different folds are highly correlated
# --> Models seem to learn consistent underlying structure in the data

## Embedding vectors across folds ####
# For every species: correlation of its embedding vector between each pair of
# folds. The integer code of the species factor is used as row index into the
# embedding matrices.
r_embed <- lapply(as.integer(unique(model_data$species)), function(species_idx){
  vapply(pairs, function(pair){
    cor(
      all_embeddings[[pair[1]]][species_idx, ],
      all_embeddings[[pair[2]]][species_idx, ]
    )
  }, numeric(1))
})

# Mean and variance of the pairwise correlations per species, in long format.
r_embed_df <- data.frame(
  species = unique(model_data$species),
  r_mean = vapply(r_embed, mean, numeric(1)),
  r_var = vapply(r_embed, var, numeric(1))
) %>%
  tidyr::pivot_longer(
    cols = c(r_mean, r_var),
    names_to = "stat",
    values_to = "value"
  )

df_plot <- dplyr::left_join(obs_count, r_embed_df, by = "species")

# Both statistics against the number of records (log10 x-axis).
ggplot(data = df_plot, mapping = aes(x = obs, y = value)) +
  geom_point() +
  geom_smooth() +
  scale_x_log10() +
  facet_wrap(vars(stat)) +
  labs(title = "Mean and variance of correlation coefficient of species embeddings across folds") +
  theme_minimal()

# --> Individual species embeddings are not correlated across folds
# --> Embeddings rotated?

## Pairwise distances of embedding vectors across folds ####
# Euclidean distance matrix between all species embeddings, one per fold.
all_embeddings_dist <- lapply(all_embeddings, function(embedding){
  as.matrix(dist(embedding, method = "euclidean"))
})

# For every species: correlation of its distance profile (its row of the
# distance matrix) between each pair of folds.
# NOTE(review): each row includes the species' distance to itself (0 in every
# fold) — confirm that it should be part of the correlation.
r_dist <- lapply(as.integer(unique(model_data$species)), function(species_idx){
  vapply(pairs, function(pair){
    cor(
      all_embeddings_dist[[pair[1]]][species_idx, ],
      all_embeddings_dist[[pair[2]]][species_idx, ]
    )
  }, numeric(1))
})

# Mean and variance of the pairwise correlations per species, in long format.
r_dist_df <- data.frame(
  species = unique(model_data$species),
  r_mean = vapply(r_dist, mean, numeric(1)),
  r_var = vapply(r_dist, var, numeric(1))
) %>%
  tidyr::pivot_longer(
    cols = c(r_mean, r_var),
    names_to = "stat",
    values_to = "value"
  )

df_plot <- dplyr::left_join(obs_count, r_dist_df, by = "species")

# Both statistics against the number of records (log10 x-axis).
ggplot(data = df_plot, mapping = aes(x = obs, y = value)) +
  geom_point() +
  geom_smooth() +
  scale_x_log10() +
  facet_wrap(vars(stat)) +
  labs(title = "Mean and variance of species' pairwise distances in embedding space across folds") +
  theme_minimal()

## --> Pairwise distances are highly correlated across folds
## --> Distance vectors for abundant species tend to be more similar across folds
## --> Some potential for informed embeddings with respect to rare species?

# ------------------------------------------------------------------ #
# 7. Analyse embedding of full model                              ####
# ------------------------------------------------------------------ #
# Embedding model fitted on the full data set.
load("data/r_objects/msdm_embed_results/msdm_embed_fit_full.RData")

# One row per species level; the genus is the first word of the species name.
species_lookup <- data.frame(species = levels(model_data$species)) %>%
  dplyr::mutate(
    genus = stringr::str_extract(species, "([a-zA-Z]+) (.*)", group = 1)
  )

# First element of the first coefficient block, used as the species embedding
# matrix (one row per species).
embeddings <- coef(msdm_embed_fit)[[1]][[1]]

# Label the rows with the species names. `species_lookup$species` is a
# character column when data.frame() does not convert strings to factors
# (the default since R 4.0.0), so levels() of it is NULL and would leave the
# matrix unlabelled; the column itself has to be used.
stopifnot(nrow(embeddings) == nrow(species_lookup))
rownames(embeddings) <- as.character(species_lookup$species)

##  Dimensionality reduction        ####
### PCA                             ####
# NOTE(review): the PCA is run on the transposed matrix, so species are the
# variables and the plotted coordinates are loadings (`rotation`) rather than
# scores (`x`) — confirm this is intended.
pca_result <- prcomp(t(embeddings), scale. = TRUE)
var_explained <- prop.table(pca_result$sdev^2)
plot(var_explained) # --> First two dimensions explain ~40% of variance

# First two components as plotting coordinates.
coords <- pca_result[["rotation"]][, c(1, 2)]
colnames(coords) <- c("X", "Y")

# Row order of `coords` follows the row order of `embeddings`, which is
# assumed to match `species_lookup`.
df_plot <- species_lookup %>%
  left_join(functional_groups, by = dplyr::join_by(species == name_matched)) %>%
  bind_cols(coords)

ggplot(
  data = df_plot,
  mapping = aes(x = X, y = Y, col = functional_group, label = genus)
) +
  geom_point() +
  geom_text(hjust = 0, vjust = 0) +
  theme_minimal()

### T-SNE                            ####
# Two-dimensional t-SNE of the species embeddings (Rtsne defaults).
# NOTE(review): no seed is set in this section, so the layout may differ
# between runs unless a seed is set elsewhere — confirm.
tsne_result <- Rtsne(X = embeddings, verbose = TRUE)
coords <- tsne_result[["Y"]]
colnames(coords) <- c("X", "Y")

# Row order of `coords` follows the row order of `embeddings`, which is
# assumed to match `species_lookup`.
df_plot <- species_lookup %>%
  left_join(functional_groups, by = dplyr::join_by(species == name_matched)) %>%
  bind_cols(coords)

ggplot(
  data = df_plot,
  mapping = aes(x = X, y = Y, col = functional_group, label = genus)
) +
  geom_point() +
  geom_text(hjust = 0, vjust = 0) +
  theme_minimal(base_size = 14) +
  theme(
    strip.text = element_text(face = "bold"),
    axis.text.x = element_text(angle = 45, hjust = 1),
    legend.position = "right"
  ) +
  labs(
    title = "T-SNE Ordination of species embeddings",
    color = "Functional group"
  )

ggsave(
  filename = "plots/publication/tsne_ordination.pdf",
  device = "pdf",
  width = 20,
  height = 18,
  units = "cm",
  scale = 2
)

### KNN                             ####
# Label the embedding rows so the distance matrix can be indexed by species.
rownames(embeddings) <- species_lookup$species
embeddings_dist <- as.matrix(dist(embeddings))

# For every species, the k smallest distances after skipping the first sorted
# entry (the species' own distance of 0); the element names are the
# neighbouring species.
k <- 9
knn_results <- sapply(
  species_lookup$species,
  function(spec){
    distances_sorted <- sort(embeddings_dist[spec, ])
    distances_sorted[1 + seq_len(k)]
  },
  simplify = FALSE,
  USE.NAMES = TRUE
)

# Map the records of one species next to the records of its nearest neighbours
# in embedding space.
#
# Arguments:
#   spec  Species name; must be a name of `knn_results`.
#
# Reads `model_data`, `sa_polygon` and `knn_results` from the enclosing
# environment. Returns the two maps combined side by side.
plot_knn_ranges <- function(spec){
  # Study-area polygon shared by both panels.
  base_map <- ggplot() +
    geom_sf(data = st_as_sf(sa_polygon))

  neighbour_names <- names(knn_results[[spec]])
  focal_records <- dplyr::filter(model_data, species == !!spec)
  neighbour_records <- dplyr::filter(model_data, species %in% neighbour_names)

  # Left: records of the focal species.
  p_focal <- base_map +
    geom_sf(data = focal_records) +
    ggtitle(paste0(spec)) +
    theme_minimal()

  # Right: one small panel per neighbouring species.
  p_neighbours <- base_map +
    geom_sf(data = neighbour_records, aes(color = species)) +
    facet_wrap(facets = "species", ncol = 3) +
    theme_minimal() +
    guides(color = "none")

  p_focal + p_neighbours
}

# Draw the comparison for one randomly chosen species; repeat this line to
# plot another random species.
plot_knn_ranges(sample(species_lookup$species, size = 1))

## Clustering             ####
# Complete-linkage hierarchical clustering on Euclidean distances between the
# species embeddings. `embeddings_dist` is a `dist` object from here on (it
# replaces the distance matrix created in the KNN step).
rownames(embeddings) <- species_lookup$species
embeddings_dist <- dist(embeddings, method = "euclidean")
hclust_complete <- hclust(d = embeddings_dist, method = "complete")
plot(
  hclust_complete,
  main = "Complete Linkage",
  xlab = "",
  sub = "",
  cex = 0.5
)

## Correlations           ####
### Phylo dist            ####
# Phylogenetic distances (the file is expected to provide `phylo_dist` with
# species names as row and column names).
load("data/r_objects/phylo_dist.RData")

# Species present in both the phylogenetic distances and the embedding.
species_intersect <- intersect(colnames(phylo_dist), species_lookup$species)
phylo_dist_subset <- phylo_dist[species_intersect, species_intersect]

# Same species, same order, taken from the embedding distances.
e_indices <- match(species_intersect, species_lookup$species)
embeddings_dist_subset <- as.matrix(embeddings_dist)[e_indices, e_indices]

# RV coefficient between the two distance matrices.
FactoMineR::coeffRV(X = embeddings_dist_subset, Y = phylo_dist_subset)

### Functional dist       ####
# Functional distances (the file is expected to provide `func_dist` with
# species names as row and column names).
load("data/r_objects/func_dist.RData")

# Species present in both the functional distances and the embedding.
species_intersect <- intersect(colnames(func_dist), species_lookup$species)
func_dist_subset <- func_dist[species_intersect, species_intersect]

# Same species, same order, taken from the embedding distances.
e_indices <- match(species_intersect, species_lookup$species)
embeddings_dist_subset <- as.matrix(embeddings_dist)[e_indices, e_indices]

# RV coefficient between the two distance matrices.
FactoMineR::coeffRV(X = embeddings_dist_subset, Y = func_dist_subset)

### Range dist            ####
# Range distances (the file is expected to provide `range_dist` with species
# names as row and column names).
load("data/r_objects/range_dist.RData")

# Species present in both the range distances and the embedding.
species_intersect <- intersect(colnames(range_dist), species_lookup$species)
range_dist_subset <- range_dist[species_intersect, species_intersect]

# Same species, same order, taken from the embedding distances.
e_indices <- match(species_intersect, species_lookup$species)
embeddings_dist_subset <- as.matrix(embeddings_dist)[e_indices, e_indices]

# RV coefficient between the two distance matrices.
FactoMineR::coeffRV(X = embeddings_dist_subset, Y = range_dist_subset)