# Code owners
# Assign users and groups as approvers for specific file changes. Learn more.
# 05_02_publication_analysis.R 19.84 KiB
# General packages
library(tidyverse)
library(patchwork)
# Geo packages
library(terra)
library(sf)
library(geos)
# Stats packages
library(Rtsne)
library(cito)
source("R/utils.R")
# Load the cached R objects produced by earlier pipeline steps into the global
# environment (same files, same order: later files win on name clashes).
invisible(lapply(
  c(
    "data/r_objects/model_data_random_abs_extended.RData",
    "data/r_objects/ssdm_results.RData",
    "data/r_objects/msdm_embed_performance.RData",
    "data/r_objects/msdm_onehot_performance.RData",
    "data/r_objects/functional_groups.RData",
    "data/r_objects/sa_polygon.RData",
    "data/r_objects/range_maps.RData"
  ),
  load,
  envir = globalenv()
))

# Use planar (GEOS) geometry operations instead of spherical s2.
sf::sf_use_s2(use_s2 = FALSE)

# Keep only rows that were assigned to an evaluation fold. `species` becomes a
# factor; its level codes are used further down as row indices into the
# species embedding matrices.
model_data <- dplyr::mutate(
  dplyr::filter(model_data, !is.na(fold_eval)),
  species = as.factor(species)
)
# ------------------------------------------------------------------ #
# 1. Collect performance results ####
# ------------------------------------------------------------------ #
# Mean performance per species x model x metric, pooled over all result sources.
# With na.rm = FALSE, a single missing fold value makes the mean NA. Such NA
# (and NaN) means are replaced below by fixed fallback values: 0.5 for
# AUC/F1/accuracy/precision/recall and 0 for kappa.
# The result is returned ungrouped (`.groups = "drop"`) so later joins and
# mutates do not silently operate per group.
performance <- ssdm_results %>%
  dplyr::bind_rows(msdm_embed_performance) %>%
  dplyr::bind_rows(msdm_onehot_performance) %>%
  dplyr::group_by(species, model, metric) %>%
  dplyr::summarise(value = mean(value, na.rm = FALSE), .groups = "drop") %>%
  dplyr::mutate(
    # Normalise metric names to lower case and fix the plotting order.
    metric = factor(
      tolower(metric),
      levels = c("auc", "f1", "kappa", "accuracy", "precision", "recall")
    ),
    model = factor(
      model,
      levels = c("GAM", "GBM", "RF", "NN", "MSDM_onehot", "MSDM_embed")
    ),
    # is.na() is also TRUE for NaN, so a separate is.nan() test is not needed.
    value = dplyr::case_when(
      is.na(value) & metric %in% c("auc", "f1", "accuracy", "precision", "recall") ~ 0.5,
      is.na(value) & metric %in% c("kappa") ~ 0,
      .default = value
    )
  )
# ------------------------------------------------------------------ #
# 2. Model comparison ####
# ------------------------------------------------------------------ #
## Overall model Comparison ####
# Boxplots of per-species performance for every model, one panel per metric.
df_plot <- performance %>%
  dplyr::filter(metric %in% c("auc", "f1", "kappa", "accuracy", "precision", "recall"))

p <- ggplot(df_plot, aes(x = model, y = value, color = model)) +
  geom_boxplot(alpha = 0.5, outlier.shape = 16) +
  facet_wrap(~ metric, scales = "free_y") +
  labs(
    title = "Model Performance Across Metrics",
    x = "Model",
    y = "Value",
    color = "Model"
  ) +
  theme_minimal(base_size = 14) +
  theme(
    strip.text = element_text(face = "bold"),
    axis.text.x = element_text(angle = 45, hjust = 1),
    legend.position = "right"
  )
p

# Save exactly this figure rather than relying on last_plot().
ggsave(
  "plots/publication/model_performance_across_metrics.pdf",
  plot = p,
  device = "pdf",
  scale = 2,
  width = 18,
  height = 14,
  units = "cm"
)
## Performance vs number of records ####
# Number of presence records per species (rows with a fold assignment only).
obs_count <- model_data %>%
  sf::st_drop_geometry() %>%
  dplyr::filter(present == 1, !is.na(fold_eval)) %>%
  dplyr::group_by(species) %>%
  dplyr::summarise(obs = dplyr::n())

# Attach the record count to every species/model/metric performance value.
df_plot <- performance %>%
  dplyr::filter(metric %in% c("auc", "f1", "kappa", "accuracy", "precision", "recall")) %>%
  dplyr::left_join(obs_count, by = "species")

p <- ggplot(df_plot, aes(x = obs, y = value, color = model, fill = model)) +
  geom_point(alpha = 0.1) +
  geom_smooth() +
  facet_wrap(~ metric, scales = "free_y") +
  # `transform` replaces the `trans` argument deprecated in ggplot2 3.5.0
  # (the rest of this script already uses `transform`).
  scale_x_continuous(transform = "log10") +
  labs(
    title = "Model Performance vs Number of Records",
    x = "Number of Records (log scale)",
    y = "Value",
    color = "Model",
    fill = "Model"
  ) +
  theme_minimal(base_size = 14) +
  theme(
    strip.text = element_text(face = "bold"),
    axis.text.x = element_text(angle = 45, hjust = 1),
    legend.position = "right"
  )
p

# Save exactly this figure rather than relying on last_plot().
ggsave(
  "plots/publication/model_performance_vs_number_of_records.pdf",
  plot = p,
  device = "pdf",
  scale = 2,
  width = 18,
  height = 14,
  units = "cm"
)
## Performance vs functional group ####
load("data/r_objects/functional_groups.RData")

# AUC / F1 / kappa per model, split by the species' functional group.
df_plot <- performance %>%
  dplyr::filter(metric %in% c("auc", "f1", "kappa")) %>%
  dplyr::left_join(functional_groups, by = c("species" = "name_matched"))

p <- ggplot(df_plot, aes(x = model, y = value, color = model)) +
  geom_boxplot(alpha = 0.5, outlier.shape = 16) +
  facet_grid(
    rows = vars(metric),
    cols = vars(functional_group),
    scales = "free_y",
    switch = "both"
  ) +
  labs(
    title = "Model Performance Across Functional Groups",
    x = "Model",
    y = "Value",
    color = "Model"
  ) +
  theme_minimal(base_size = 14) +
  theme(
    strip.text = element_text(face = "bold"),
    panel.border = element_rect(color = "gray", fill = NA),
    axis.text.x = element_text(angle = 45, hjust = 1),
    legend.position = "right"
  )
p

# Save exactly this figure rather than relying on last_plot().
ggsave(
  "plots/publication/model_performance_vs_functional_groups.pdf",
  plot = p,
  device = "pdf",
  scale = 2,
  width = 18,
  height = 14,
  units = "cm"
)
# ------------------------------------------------------------------ #
# 3. Range predictions ####
# ------------------------------------------------------------------ #
library(caret)
library(gam)
library(gbm)
library(cito)
library(ranger)
# Define plotting function
# Build one presence/absence prediction map per algorithm for a single species.
#
# spec:        species name; matched against model_data$species and
#              range_maps$name_matched.
# model_data:  sf object with at least `species` and `present` columns.
# raster_data: terra SpatRaster of predictors (presumably the model covariates,
#              with matching layer names — TODO confirm).
# algorithms:  character vector. Recognised values: "msdm_onehot", "msdm_embed",
#              "msdm_rf", "msdm_rf_random_abs", "msdm_rf_fit_no_species",
#              "msdm_rf_fit_no_species_random_abs", "nn". Any other value X is
#              treated as a single-species model loaded from
#              "data/r_objects/ssdm_results/full_models/<spec>_X_fit.RData" and
#              looked up as the object named "X_fit".
# Returns a list of ggplot objects named by algorithm.
# Note: every load() below uses load()'s default environment, i.e. this
# function's own frame, so the loaded objects (range_maps, *_fit) stay local.
plot_predictions = function(spec, model_data, raster_data, algorithms){
# Species data
# range_maps is reloaded locally (the same file is also loaded at script start).
load("data/r_objects/range_maps.RData")
pa_spec = model_data %>%
dplyr::filter(species == !!spec)
# Expert range polygon(s) of the species, in the CRS of the observation data.
range_spec = range_maps %>%
dplyr::filter(name_matched == !!spec) %>%
sf::st_transform(sf::st_crs(pa_spec))
# Extract raster values into df
# expand_bbox() is a helper from R/utils.R (not shown here); presumably it pads
# the bbox by 25% — verify.
bbox_spec = sf::st_bbox(range_spec) %>%
expand_bbox(expansion = 0.25)
raster_crop = terra::crop(raster_data, bbox_spec)
# Keep only observations inside the (expanded) bbox; st_intersects with
# sparse = FALSE returns a one-column logical matrix used as the row index.
pa_crop = pa_spec[st_intersects(pa_spec, st_as_sfc(bbox_spec), sparse = FALSE),]
# One row per non-NA cell of the cropped raster. The row names of this data
# frame are used further down as cell numbers; NOTE(review): this relies on
# terra keeping the original row indices when dropping NA rows — confirm.
# unique(pa_spec$species) must be exactly one value (errors if spec has no rows).
new_data = raster_crop %>%
terra::values(dataframe = T, na.rm = T) %>%
dplyr::mutate(species = unique(pa_spec$species))
plots = list()
# Make predictions
for(algorithm in algorithms){
message(algorithm)
# Load model
# Neural-network branches: the first output column is taken as the predicted
# probability (presumably of presence — verify) and thresholded at 0.5 via
# round(), then encoded as a factor A (0) / P (1).
if(algorithm == "msdm_onehot"){
load("data/r_objects/msdm_onehot_results/msdm_onehot_fit_full.RData")
probabilities = predict(msdm_onehot_fit, new_data, type = "response")[,1]
predictions = factor(round(probabilities), levels = c("0", "1"), labels = c("A", "P"))
} else if(algorithm == "msdm_embed"){
load("data/r_objects/msdm_embed_results/msdm_embed_fit_full.RData")
probabilities = predict(msdm_embed_fit, new_data, type = "response")[,1]
predictions = factor(round(probabilities), levels = c("0", "1"), labels = c("A", "P"))
# RF branches: every file defines an object named rf_fit; type = "raw" asks
# for class labels (presumably caret train objects returning an A/P factor).
# NOTE(review): caret's predict.train() may not forward extra arguments such
# as num.threads to the underlying model — verify it has any effect.
} else if(algorithm == "msdm_rf"){
load("data/r_objects/msdm_rf/msdm_rf_fit_full.RData")
predictions = predict(rf_fit, new_data, type = "raw", num.threads = 48)
} else if(algorithm == "msdm_rf_random_abs"){
load("data/r_objects/msdm_rf/msdm_rf_fit_random_abs_full.RData")
predictions = predict(rf_fit, new_data, type = "raw", num.threads = 48)
} else if(algorithm == "msdm_rf_fit_no_species"){
load("data/r_objects/msdm_rf/msdm_rf_fit_no_species_full.RData")
predictions = predict(rf_fit, new_data, type = "raw", num.threads = 48)
} else if(algorithm == "msdm_rf_fit_no_species_random_abs"){
load("data/r_objects/msdm_rf/msdm_rf_fit_no_species_random_abs_full.RData")
predictions = predict(rf_fit, new_data, type = "raw", num.threads = 48)
} else if(algorithm == "nn") {
# Single-species NN: the `species` column is not a predictor, so drop it.
load(paste0("data/r_objects/ssdm_results/full_models/", spec, "_nn_fit.RData"))
new_data_tmp = dplyr::select(new_data, -species)
probabilities = predict(nn_fit, new_data_tmp, type = "response")[,1]
predictions = factor(round(probabilities), levels = c("0", "1"), labels = c("A", "P"))
} else {
# Other single-species models: the loaded object is fetched by name
# ("<algorithm>_fit") from this function's environment via get().
load(paste0("data/r_objects/ssdm_results/full_models/", spec, "_", algorithm, "_fit.RData"))
new_data_tmp = dplyr::select(new_data, -species)
predictions = predict(get(paste0(algorithm, "_fit")), new_data_tmp, type = "raw")
}
# Empty one-layer raster with the geometry of raster_crop, filled at the cells
# that had complete predictor values. as.integer(factor) - 1 maps the first
# level to 0 and the second to 1 (assumes absence is the first level).
raster_pred = terra::rast(raster_crop, nlyrs = 1)
values(raster_pred)[as.integer(rownames(new_data))] <- as.integer(predictions) - 1
# Plot
# Prediction raster (black = 0, yellow = 1), observations inside the bbox
# (shape by presence/absence), and the range polygon outlined in red.
p = ggplot() +
tidyterra::geom_spatraster(data = as.factor(raster_pred), maxcell = 5e7) +
scale_fill_manual(values = c("0" = "black", "1" = "yellow"), name = "Prediction", na.translate = FALSE) +
geom_sf(data = pa_crop, aes(shape = as.factor(present)), color = "#FFFFFF99") +
geom_sf(data = range_spec, col = "red", fill = NA) +
scale_shape_manual(values = c("0" = 1, "1" = 4), name = "Observation") +
theme_minimal() +
coord_sf() +
labs(title = paste0("Predictions vs Observations (", algorithm, "): ", spec)) +
guides(shape = guide_legend(
override.aes = list(color = "black") # This makes shapes black in the legend only
))
plots[[algorithm]] = p
}
return(plots)
}
# Load raster
# Predictor rasters stacked in natural (numeric-aware) file-name order, cropped
# to the study-area bbox and reprojected to the CRS of model_data.
raster_filepaths <- list.files(
  "~/symobio-modeling/data/geospatial/raster/",
  full.names = TRUE
) %>%
  stringr::str_sort(numeric = TRUE)
raster_data <- terra::rast(raster_filepaths) %>%
  terra::crop(sf::st_bbox(sa_polygon)) %>%
  terra::project(sf::st_crs(model_data)$input)

# One PDF of multi-species RF prediction maps per randomly drawn species.
# The PDF device is closed in `finally`, so an error while plotting one species
# no longer leaves a dangling device that would swallow all later plots.
specs <- sort(sample(levels(model_data$species), 4))
for (spec in specs) {
  pdf(file = paste0("plots/range_predictions/", spec, " (msdm_rf).pdf"), width = 12)
  tryCatch(
    {
      plots <- plot_predictions(
        spec, model_data, raster_data,
        algorithms = c(
          "msdm_rf", "msdm_rf_random_abs",
          "msdm_rf_fit_no_species", "msdm_rf_fit_no_species_random_abs"
        )
      )
      invisible(lapply(plots, print))
    },
    finally = dev.off()
  )
}
# ------------------------------------------------------------------ #
# 4. Compare predictions across species ####
# ------------------------------------------------------------------ #
# Full multi-species models. NOTE(review): both RF files define an object named
# `rf_fit` (see plot_predictions), so after these loads rf_fit is the
# random-absence variant — confirm that is the RF intended below.
invisible(lapply(
  c(
    "data/r_objects/msdm_embed_results/msdm_embed_fit_full.RData",
    "data/r_objects/msdm_onehot_results/msdm_onehot_fit_full.RData",
    "data/r_objects/msdm_rf/msdm_rf_fit_full.RData",
    "data/r_objects/msdm_rf/msdm_rf_fit_random_abs_full.RData"
  ),
  load,
  envir = globalenv()
))

# Scatter-plot one model's predictions for two randomly drawn species at the
# same randomly sampled raster cells.
#   model:       fitted model handed to predict_new() (from R/utils.R); the
#                argument expression is used as the plot title.
#   sample_size: number of cells sampled from raster_data (rows with NA
#                predictors are dropped afterwards).
# Uses the globals model_data and raster_data.
compare_species_predictions <- function(model, sample_size) {
  plot_title <- deparse(substitute(model))
  species_pair <- sample(unique(model_data$species), 2, replace = FALSE)

  env_first <- spatSample(raster_data, sample_size, replace = FALSE, as.df = TRUE) %>%
    drop_na() %>%
    dplyr::mutate(species = species_pair[1])
  env_second <- dplyr::mutate(env_first, species = species_pair[2])

  pred_first <- predict_new(model, env_first)
  pred_second <- predict_new(model, env_second)

  plot(
    x = pred_first, y = pred_second,
    xlab = env_first$species[1], ylab = env_second$species[1],
    xlim = c(0, 1), ylim = c(0, 1), pch = 16,
    main = plot_title
  )
  abline(a = 0, b = 1)
  # Quadrant labels: first letter = x species, second letter = y species.
  text(
    x = c(0.1, 0.9, 0.1, 0.9), y = c(0.1, 0.1, 0.9, 0.9),
    c("AA", "PA", "AP", "PP"), col = "red"
  )
  mtext(paste0("(R² = ", round(cor(pred_first, pred_second)^2, 3), ")"), side = 3, cex = 0.8)
  abline(h = 0.5)
  abline(v = 0.5)
}

compare_species_predictions(msdm_embed_fit, 500)
compare_species_predictions(msdm_onehot_fit, 500)
compare_species_predictions(rf_fit, 500)
## --> Predictions for different species are weakly/moderately correlated in NN models (makes sense)
## --> Predictions for different species are always highly correlated in RF model (seems problematic)
# ------------------------------------------------------------------ #
# 5. Compare predictions across models ####
# ------------------------------------------------------------------ #
# Reloading msdm_rf_fit_full.RData resets `rf_fit` to that file's RF model.
invisible(lapply(
  c(
    "data/r_objects/msdm_embed_results/msdm_embed_fit_full.RData",
    "data/r_objects/msdm_onehot_results/msdm_onehot_fit_full.RData",
    "data/r_objects/msdm_rf/msdm_rf_fit_full.RData"
  ),
  load,
  envir = globalenv()
))

# Scatter-plot the predictions of two models for one randomly drawn species at
# the same randomly sampled raster cells.
#   model1, model2: fitted models handed to predict_new() (from R/utils.R); the
#                   argument expressions label the x and y axes.
#   sample_size:    number of cells sampled from raster_data (rows with NA
#                   predictors are dropped afterwards).
# Uses the globals model_data and raster_data.
compare_model_predictions <- function(model1, model2, sample_size) {
  x_label <- deparse(substitute(model1))
  y_label <- deparse(substitute(model2))
  focal_species <- sample(unique(model_data$species), 1)

  env_cells <- spatSample(raster_data, sample_size, replace = FALSE, as.df = TRUE) %>%
    drop_na() %>%
    dplyr::mutate(species = focal_species)

  pred_x <- predict_new(model1, env_cells)
  pred_y <- predict_new(model2, env_cells)

  plot(
    x = pred_x, y = pred_y,
    xlab = x_label, ylab = y_label,
    xlim = c(0, 1), ylim = c(0, 1), pch = 16,
    main = focal_species
  )
  abline(a = 0, b = 1)
  # Quadrant labels: first letter = model1, second letter = model2.
  text(
    x = c(0.1, 0.9, 0.1, 0.9), y = c(0.1, 0.1, 0.9, 0.9),
    c("AA", "PA", "AP", "PP"), col = "red"
  )
  mtext(paste0("(R² = ", round(cor(pred_x, pred_y)^2, 3), ")"), side = 3, cex = 0.8)
  abline(h = 0.5)
  abline(v = 0.5)
}

compare_model_predictions(msdm_embed_fit, msdm_onehot_fit, 500)
compare_model_predictions(msdm_embed_fit, rf_fit, 500)
compare_model_predictions(msdm_onehot_fit, rf_fit, 500)
## --> Predictions for the same species from NN_embed and NN_onehot are moderately/strongly correlated (makes sense)
## --> Predictions for the same species from NN and RF are weakly/not correlated (seems problematic)
# ------------------------------------------------------------------ #
# 6. Compare species embeddings across folds ####
# ------------------------------------------------------------------ #
# Presence-record count per species (rows with a fold assignment only).
obs_count <- model_data %>%
  sf::st_drop_geometry() %>%
  dplyr::filter(present == 1, !is.na(fold_eval)) %>%
  dplyr::group_by(species) %>%
  dplyr::summarise(obs = n())

# Species embedding matrix of each of the five cross-validation fold models.
all_embeddings <- lapply(1:5, function(k) {
  load(sprintf("data/r_objects/msdm_embed_results/msdm_embed_fit_fold%d.RData", k))
  coef(msdm_embed_fit)[[1]][[1]]
})
# All 10 unordered pairs of folds.
pairs <- combn(1:5, 2, simplify = FALSE)

## Correlations
### Embedding matrices across folds ####
pairwise_rv <- lapply(pairs, function(fp) {
  FactoMineR::coeffRV(all_embeddings[[fp[1]]], all_embeddings[[fp[2]]])
})

# 5 x 5 lower-triangular matrices: entry [later fold, earlier fold] holds the
# RV coefficient (rv_matrix) or its p-value (p_matrix); the rest stays NA.
rv_matrix <- matrix(NA, 5, 5)
rv_matrix[do.call(rbind, pairs)[, 2:1]] <- vapply(pairwise_rv, function(res) res[["rv"]], numeric(1))
p_matrix <- matrix(NA, 5, 5)
p_matrix[do.call(rbind, pairs)[, 2:1]] <- vapply(pairwise_rv, function(res) res[["p.value"]], numeric(1))
# --> Embeddings from different folds are highly correlated
# --> Models seem to learn consistent underlying structure in the data
## Embedding vectors across folds ####
# For every species: correlation of its embedding vector between each pair of
# folds (10 values per species). The species' factor level code is used as the
# row index into the fold embedding matrices (assumes embedding rows follow the
# factor level order — TODO confirm).
r_embed <- lapply(unique(model_data$species), function(sp) {
  row_id <- as.integer(sp)
  vapply(pairs, function(fp) {
    cor(
      all_embeddings[[fp[1]]][row_id, ],
      all_embeddings[[fp[2]]][row_id, ]
    )
  }, numeric(1))
})

# Long format: one row per species and statistic (mean / variance of r).
r_embed_df <- data.frame(
  species = unique(model_data$species),
  r_mean = vapply(r_embed, mean, numeric(1)),
  r_var = vapply(r_embed, var, numeric(1))
) %>%
  tidyr::pivot_longer(cols = c(r_mean, r_var), names_to = "stat", values_to = "value")

df_plot <- dplyr::left_join(obs_count, r_embed_df, by = "species")
ggplot(data = df_plot, aes(x = obs, y = value)) +
  geom_point() +
  geom_smooth() +
  labs(title = "Mean and variance of correlation coefficient of species embeddings across folds") +
  scale_x_continuous(transform = "log10") +
  facet_wrap("stat") +
  theme_minimal()
# --> Individual species embeddings are not correlated across folds
# --> Embeddings rotated?
## Pairwise distances of embedding vectors across folds ####
# Species x species Euclidean distance matrix of each fold's embedding.
all_embeddings_dist <- lapply(all_embeddings, function(e) {
  as.matrix(dist(e, method = "euclidean"))
})

# For every species: correlation of its distance profile (distances to all
# OTHER species) between each pair of folds. The self-distance is excluded:
# it is 0 in every fold, so keeping it adds an identical extreme point to both
# vectors and inflates the correlation. The factor level code is used as the
# row/column index (assumes embedding rows follow the factor level order).
r_dist <- lapply(unique(model_data$species), function(spec) {
  idx <- as.integer(spec)
  vapply(pairs, function(pair) {
    cor(
      all_embeddings_dist[[pair[1]]][idx, -idx],
      all_embeddings_dist[[pair[2]]][idx, -idx]
    )
  }, numeric(1))
})

# Long format: one row per species and statistic (mean / variance of r).
r_dist_df <- data.frame(
  species = unique(model_data$species),
  r_mean = vapply(r_dist, mean, numeric(1)),
  r_var = vapply(r_dist, var, numeric(1))
) %>%
  tidyr::pivot_longer(cols = c(r_mean, r_var), names_to = "stat", values_to = "value")

df_plot <- obs_count %>%
  dplyr::left_join(r_dist_df, by = "species")
ggplot(data = df_plot, aes(x = obs, y = value)) +
  geom_point() +
  geom_smooth() +
  labs(title = "Mean and variance of species' pairwise distances in embedding space across folds") +
  scale_x_continuous(transform = "log10") +
  facet_wrap("stat") +
  theme_minimal()
## --> Pairwise distances are highly correlated across folds
## --> Distance vectors for abundant species tend to be more similar across folds
## --> Some potential for informed embeddings with respect to rare species?
# ------------------------------------------------------------------ #
# 7. Analyse embedding of full model ####
# ------------------------------------------------------------------ #
load("data/r_objects/msdm_embed_results/msdm_embed_fit_full.RData")

# One row per species level; genus = first word of the species name.
species_lookup <- data.frame(species = levels(model_data$species)) %>%
  dplyr::mutate(
    genus = stringr::str_extract(species, "([a-zA-Z]+) (.*)", group = 1)
  )

# Species embedding matrix of the full model, rows labelled by species.
# species_lookup$species is a character column, so the former
# `levels(species_lookup$species)` was NULL and silently dropped the row names.
embeddings <- coef(msdm_embed_fit)[[1]][[1]]
rownames(embeddings) <- species_lookup$species

## Dimensionality reduction ####
### PCA ####
# NOTE(review): PCA runs on t(embeddings), so species are the variables and the
# embedding dimensions the observations; species coordinates are then taken
# from the loadings ($rotation). Confirm this is intended rather than
# prcomp(embeddings)$x.
pca_result <- prcomp(t(embeddings), scale. = TRUE)
var_explained <- pca_result$sdev^2 / sum(pca_result$sdev^2)
plot(var_explained) # --> First two dimensions explain ~40% of variance

# First two components as X/Y; row names are cleared so the matrix binds to
# species_lookup purely by position.
coords <- pca_result$rotation[, 1:2]
dimnames(coords) <- list(NULL, c("X", "Y"))

df_plot <- species_lookup %>%
  dplyr::left_join(functional_groups, by = c("species" = "name_matched")) %>%
  dplyr::bind_cols(coords)
ggplot(df_plot, aes(x = X, y = Y, col = functional_group, label = genus)) +
  geom_point() +
  geom_text(hjust = 0, vjust = 0) +
  theme_minimal()
### T-SNE ####
# Two-dimensional t-SNE of the species embeddings, coloured by functional group
# and labelled with the genus.
tsne_result <- Rtsne(embeddings, verbose = TRUE)
coords <- tsne_result$Y
colnames(coords) <- c("X", "Y")

df_plot <- dplyr::bind_cols(
  dplyr::left_join(species_lookup, functional_groups, by = c("species" = "name_matched")),
  coords
)

p <- ggplot(df_plot, aes(x = X, y = Y, col = functional_group, label = genus)) +
  geom_point() +
  geom_text(hjust = 0, vjust = 0) +
  labs(
    title = "T-SNE Ordination of species embeddings",
    color = "Functional group"
  ) +
  theme_minimal(base_size = 14) +
  theme(
    strip.text = element_text(face = "bold"),
    axis.text.x = element_text(angle = 45, hjust = 1),
    legend.position = "right"
  )
p

# Save exactly this figure rather than relying on last_plot().
ggsave(
  "plots/publication/tsne_ordination.pdf",
  plot = p,
  device = "pdf",
  scale = 2,
  width = 20,
  height = 18,
  units = "cm"
)
### KNN ####
rownames(embeddings) <- species_lookup$species
embeddings_dist <- as.matrix(dist(embeddings))
k <- 9

# k nearest neighbours of every species in embedding space: a list named by
# species, each element a named vector of the k smallest distances. The focal
# species is removed by name, not by skipping the first sorted entry, so a tie
# at distance 0 cannot leave a species as its own neighbour. head() also
# returns fewer than k neighbours instead of NA when there are not enough.
knn_results <- sapply(species_lookup$species, FUN = function(spec) {
  d <- embeddings_dist[spec, ]
  head(sort(d[names(d) != spec]), k)
}, USE.NAMES = TRUE, simplify = FALSE)

# Map the presence records of `spec` next to those of its nearest neighbours
# (one facet each). Only rows with present == 1 are drawn, so absence rows do
# not blur the ranges. Uses the globals model_data, sa_polygon and knn_results.
plot_knn_ranges <- function(spec) {
  spec_df <- model_data %>%
    dplyr::filter(species == !!spec, present == 1)
  p_spec_range <- ggplot() +
    geom_sf(data = st_as_sf(sa_polygon)) +
    geom_sf(data = spec_df) +
    ggtitle(paste0(spec)) +
    theme_minimal()

  knn_df <- model_data %>%
    dplyr::filter(species %in% names(knn_results[[spec]]), present == 1)
  p_knn_ranges <- ggplot() +
    geom_sf(data = st_as_sf(sa_polygon)) +
    geom_sf(data = knn_df, aes(color = species)) +
    facet_wrap(facets = "species", ncol = 3) +
    theme_minimal() +
    guides(color = "none")

  p_spec_range + p_knn_ranges
}
plot_knn_ranges(sample(species_lookup$species, 1)) # Repeat this line to plot random species
## Clustering ####
# Complete-linkage hierarchical clustering of species in embedding space.
# embeddings_dist is kept as a `dist` object (labelled by species) for the
# correlation analyses that follow.
rownames(embeddings) <- species_lookup$species
embeddings_dist <- stats::dist(embeddings, method = "euclidean")
hclust_complete <- stats::hclust(embeddings_dist, method = "complete")
plot(
  hclust_complete,
  main = "Complete Linkage",
  xlab = "",
  sub = "",
  cex = 0.5
)
## Correlations ####
### Phylo dist ####
# Each sub-block compares species' embedding distances with an external
# species x species distance matrix (phylogenetic, functional, range) via the
# RV coefficient, restricted to species present in both.
# intersect() keeps the order of the external matrix's column names, and
# e_indices re-orders the embedding distance matrix (rows ordered as
# species_lookup$species) to that same order, so both matrices line up.
# The external matrices must carry species names as dimnames, since they are
# indexed by name.
# NOTE(review): coeffRV() treats its inputs as data tables (columns as
# variables); applying it directly to distance matrices is not the usual
# double-centred RV for distances — verify against the FactoMineR docs.
load("data/r_objects/phylo_dist.RData")
species_intersect = intersect(colnames(phylo_dist), species_lookup$species)
phylo_dist_subset = phylo_dist[species_intersect, species_intersect]
e_indices = match(species_intersect, species_lookup$species)
embeddings_dist_subset = as.matrix(embeddings_dist)[e_indices, e_indices]
FactoMineR::coeffRV(embeddings_dist_subset, phylo_dist_subset)
### Functional dist ####
# Same comparison against functional-trait distances.
load("data/r_objects/func_dist.RData")
species_intersect = intersect(colnames(func_dist), species_lookup$species)
func_dist_subset = func_dist[species_intersect, species_intersect]
e_indices = match(species_intersect, species_lookup$species)
embeddings_dist_subset = as.matrix(embeddings_dist)[e_indices, e_indices]
FactoMineR::coeffRV(embeddings_dist_subset, func_dist_subset)
### Range dist ####
# Same comparison against range distances.
load("data/r_objects/range_dist.RData")
species_intersect = intersect(colnames(range_dist), species_lookup$species)
range_dist_subset = range_dist[species_intersect, species_intersect]
e_indices = match(species_intersect, species_lookup$species)
embeddings_dist_subset = as.matrix(embeddings_dist)[e_indices, e_indices]
FactoMineR::coeffRV(embeddings_dist_subset, range_dist_subset)