Commit 643e472f authored by ye87zine

additional model evaluation and analysis

parent e0cabef3
@@ -87,16 +87,16 @@ for(pa_spec in data_split){
   })
   ## Gradient Boosted Machine ####
-  xgb_performance = tryCatch({
-    xgb_fit = train(
+  gbm_performance = tryCatch({
+    gbm_fit = train(
       x = data_train[, predictors],
       y = data_train$present_fct,
-      method = "xgbTree",
+      method = "gbm",
       trControl = train_ctrl,
       tuneLength = 8,
       verbose = F
     )
-    evaluate_model(xgb_fit, data_test)
+    evaluate_model(gbm_fit, data_test)
   }, error = function(e){
     na_performance
   })
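Note: this hunk swaps caret's "xgbTree" learner for "gbm". A minimal sketch of the setup the snippet relies on; train_ctrl is not shown in this diff, so the cross-validation spec below is an assumption, and the toy data is hypothetical.

library(caret)
set.seed(1)
toy = data.frame(x1 = rnorm(200), x2 = rnorm(200))              # hypothetical predictors
toy$present_fct = factor(ifelse(toy$x1 + rnorm(200) > 0, "P", "A"))
train_ctrl = trainControl(method = "cv", number = 5, classProbs = TRUE)  # assumed spec
gbm_fit = train(
  x = toy[, c("x1", "x2")],
  y = toy$present_fct,
  method = "gbm",          # needs the gbm package; "xgbTree" needed xgboost instead
  trControl = train_ctrl,
  tuneLength = 2,
  verbose = FALSE
)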
@@ -89,7 +89,7 @@ save(rf_fit, file = "data/r_objects/msdm_rf/msdm_rf_fit_full.RData")
 # Evaluate model ####
 # ----------------------------------------------------------------------#
 msdm_rf_performance = lapply(1:5, function(fold){
-  load(paste0("data/r_objects/msdm_rf/rf_fit_fold", fold, ".RData"))
+  load(paste0("data/r_objects/msdm_rf/msdm_rf_fit_fold", fold, ".RData"))
   test_data = dplyr::filter(model_data, fold_eval == fold) %>%
     sf::st_drop_geometry()
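Note: the corrected path feeds the per-fold evaluation above. A hedged sketch of collapsing the resulting list into one data frame for comparison (assumes evaluate_model() returns a named list of scalar metrics and dplyr is loaded):

msdm_rf_performance_df = msdm_rf_performance %>%
  lapply(as.data.frame) %>%
  dplyr::bind_rows(.id = "fold")   # fold becomes "1".."5"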
@@ -154,7 +154,7 @@ library(cito)
 library(ranger)
 # Define plotting function
-plot_predictions = function(spec, model_data, raster_data, algorithms = c("gam", "gbm", "rf", "nn", "msdm_embed", "msdm_onehot")){
+plot_predictions = function(spec, model_data, raster_data, algorithms = c("gam", "gbm", "rf", "nn", "msdm_embed", "msdm_onehot", "msdm_rf")){
   # Species data
   load("data/r_objects/range_maps.RData")
@@ -166,7 +166,7 @@ plot_predictions = function(spec, model_data, raster_data, algorithms = c("gam",
   # Extract raster values into df
   bbox_spec = sf::st_bbox(pa_spec) %>%
-    expand_bbox(expansion = 0.75)
+    expand_bbox(expansion = 0.5)
   raster_crop = terra::crop(raster_data, bbox_spec)
   new_data = raster_crop %>%
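Note: expand_bbox() is a project helper; only its signature (bbox, min_span = 1, expansion = 0.25) appears later in this diff, and its body is elided. A sketch of behaviour consistent with that signature, offered purely as an assumption:

expand_bbox_sketch = function(bbox, min_span = 1, expansion = 0.25) {
  dx = max(bbox["xmax"] - bbox["xmin"], min_span)  # enforce a minimum span
  dy = max(bbox["ymax"] - bbox["ymin"], min_span)
  bbox["xmin"] = bbox["xmin"] - expansion * dx     # pad each side proportionally
  bbox["xmax"] = bbox["xmax"] + expansion * dx
  bbox["ymin"] = bbox["ymin"] - expansion * dy
  bbox["ymax"] = bbox["ymax"] + expansion * dy
  return(bbox)
}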
@@ -185,6 +185,9 @@ plot_predictions = function(spec, model_data, raster_data, algorithms = c("gam",
     load("data/r_objects/msdm_embed_results/msdm_embed_fit_full.RData")
     probabilities = predict(msdm_embed_fit, new_data, type = "response")[,1]
     predictions = factor(round(probabilities), levels = c("0", "1"), labels = c("A", "P"))
+  } else if(algorithm == "msdm_rf"){
+    load("data/r_objects/msdm_rf/msdm_rf_fit_full.RData")
+    predictions = predict(rf_fit, new_data, type = "raw")
   } else if(algorithm == "nn") {
     load(paste0("data/r_objects/ssdm_results/full_models/", spec, "_nn_fit.RData"))
     new_data_tmp = dplyr::select(new_data, -species)
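Note: the new msdm_rf branch returns hard class labels via type = "raw". caret fits also expose per-class probabilities (provided the model was trained with classProbs = TRUE), so a probability-based variant, consistent with how the NN branches threshold, could look like:

probs = predict(rf_fit, new_data, type = "prob")$P   # per-class probabilities
predictions = factor(as.integer(probs > 0.5), levels = c(0, 1), labels = c("A", "P"))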
@@ -218,46 +221,89 @@ raster_data = terra::rast(raster_filepaths) %>%
   terra::crop(sf::st_bbox(sa_polygon)) %>%
   terra::project(sf::st_crs(model_data)$input)
-specs = sort(sample(levels(model_data$species), 20))
+specs = sort(sample(levels(model_data$species), 30))
 for(spec in specs){
-  pdf(file = paste0("plots/range_predictions/", spec, ".pdf"))
+  pdf(file = paste0("plots/range_predictions/", spec, " (msdm).pdf"))
   tryCatch({
-    plot_predictions(spec, model_data, raster_data, algorithms = c("gam", "xgb", "rf", "msdm_embed", "msdm_onehot"))
+    plot_predictions(spec, model_data, raster_data, algorithms = c("msdm_embed", "msdm_onehot", "msdm_rf"))
   }, finally = {
     dev.off()
   })
 }
 # ------------------------------------------------------------------ #
-# 4. Compare msdm predictions (embed vs. onehot) ####
+# 4. Compare predictions across species ####
 # ------------------------------------------------------------------ #
 load("data/r_objects/msdm_embed_results/msdm_embed_fit_full.RData")
 load("data/r_objects/msdm_onehot_results/msdm_onehot_fit_full.RData")
+load("data/r_objects/msdm_rf/msdm_rf_fit_full.RData")
-plot_embed_vs_onehot = function(){
+compare_species_predictions = function(model, sample_size){
+  specs = sample(unique(model_data$species), 2, replace = F)
+  df1 = spatSample(raster_data, sample_size, replace = F, as.df = T) %>%
+    drop_na() %>%
+    dplyr::mutate(species = specs[1])
+  df2 = df1 %>%
+    dplyr::mutate(species = specs[2])
+  p1 = predict_new(model, df1)
+  p2 = predict_new(model, df2)
+  plot(x = p1, y = p2,
+       xlab = df1$species[1],
+       ylab = df2$species[1], xlim = c(0,1), ylim = c(0,1), pch = 16,
+       main = deparse(substitute(model)))
+  abline(a = 0, b = 1)
+  text(x = c(0.1,0.9,0.1,0.9), y = c(0.1,0.1,0.9,0.9), c("AA", "PA", "AP", "PP"), col = "red")
+  mtext(paste0("(R² = ", round(cor(p1,p2) ^ 2, 3), ")"), side = 3, cex = 0.8)
+  abline(h = 0.5)
+  abline(v = 0.5)
+}
+compare_species_predictions(msdm_embed_fit, 500)
+compare_species_predictions(msdm_onehot_fit, 500)
+compare_species_predictions(rf_fit, 500)
+## --> Predictions for different species are weakly/moderately correlated in NN models (makes sense)
+## --> Predictions for different species are always highly correlated in the RF model (seems problematic)
+# ------------------------------------------------------------------ #
+# 5. Compare predictions across models ####
+# ------------------------------------------------------------------ #
+load("data/r_objects/msdm_embed_results/msdm_embed_fit_full.RData")
+load("data/r_objects/msdm_onehot_results/msdm_onehot_fit_full.RData")
+load("data/r_objects/msdm_rf/msdm_rf_fit_full.RData")
+compare_model_predictions = function(model1, model2, sample_size){
   spec = sample(unique(model_data$species), 1)
-  new_data = spatSample(raster_data, 100, replace = F, as.df = T) %>%
+  new_data = spatSample(raster_data, sample_size, replace = F, as.df = T) %>%
     drop_na() %>%
     dplyr::mutate(species = spec)
-  p1_em = predict(msdm_embed_fit, new_data, type = "response")
-  p1_oh = predict(msdm_onehot_fit, new_data, type = "response")
+  p1 = predict_new(model1, new_data)
+  p2 = predict_new(model2, new_data)
-  # Compare onehot vs embed
-  plot(x = p1_em, y = p1_oh, xlab = "embed", ylab = "onehot", xlim=c(0,1), ylim = c(0,1), pch = 16, main = spec)
+  plot(x = p1, y = p2,
+       xlab = deparse(substitute(model1)),
+       ylab = deparse(substitute(model2)),
+       xlim = c(0,1), ylim = c(0,1), pch = 16, main = spec)
   abline(a = 0, b = 1)
   text(x = c(0.1,0.9,0.1,0.9), y = c(0.1,0.1,0.9,0.9), c("AA", "PA", "AP", "PP"), col = "red")
+  mtext(paste0("(R² = ", round(cor(p1,p2) ^ 2, 3), ")"), side = 3, cex = 0.8)
+  abline(h = 0.5)
+  abline(v = 0.5)
 }
-plot_embed_vs_onehot()
+compare_model_predictions(msdm_embed_fit, msdm_onehot_fit, 500)
+compare_model_predictions(msdm_embed_fit, rf_fit, 500)
+compare_model_predictions(msdm_onehot_fit, rf_fit, 500)
-## --> Predictions are similar between msdms
-## --> Onehot model seems to predict slightly more presences than embedded model (???)
+## --> Predictions for the same species from NN_embed and NN_onehot are moderately/strongly correlated (makes sense)
+## --> Predictions for the same species from NN and RF are weakly/not correlated (seems problematic)
 # ------------------------------------------------------------------ #
-# 5. Compare species embeddings across folds ####
+# 6. Compare species embeddings across folds ####
 # ------------------------------------------------------------------ #
 obs_count = model_data %>%
   sf::st_drop_geometry() %>%
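Note: the observations in sections 4 and 5 rest on single random species pairs. A hedged sketch that repeats the species-pair check over many draws, to test whether the high between-species correlation of the RF model is systematic (assumes model_data, raster_data, rf_fit, and predict_new() are in scope, with terra, tidyr, and dplyr loaded):

set.seed(42)
pair_r2 = sapply(1:20, function(i){
  specs = sample(unique(model_data$species), 2)
  df = spatSample(raster_data, 250, replace = F, as.df = T) %>% drop_na()
  p1 = predict_new(rf_fit, dplyr::mutate(df, species = specs[1]))
  p2 = predict_new(rf_fit, dplyr::mutate(df, species = specs[2]))
  cor(p1, p2)^2
})
summary(pair_r2)  # consistently high values would confirm the problem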
@@ -359,7 +405,7 @@ ggplot(data = df_plot, aes(x = obs, y = value)) +
 ## --> Some potential for informed embeddings with respect to rare species?
 # ------------------------------------------------------------------ #
-# 6. Analyse embedding of full model ####
+# 7. Analyse embedding of full model ####
 # ------------------------------------------------------------------ #
 load("data/r_objects/msdm_embed_results/msdm_embed_fit_full.RData")
 species_lookup = data.frame(species = levels(model_data$species)) %>%
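Note: for the embedding analysis in section 7, pairwise cosine similarity between species vectors is one obvious summary. A sketch, assuming an embedding matrix (rows = species, columns = latent dimensions) has already been extracted from the fitted model; embedding_matrix is hypothetical:

cosine_sim = function(E){
  En = E / sqrt(rowSums(E^2))  # L2-normalise each species vector
  En %*% t(En)                 # species x species similarity matrix
}
# sim = cosine_sim(embedding_matrix)
# dimnames(sim) = list(species_lookup$species, species_lookup$species)  # assumed row order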
@@ -25,6 +25,35 @@ expand_bbox <- function(bbox, min_span = 1, expansion = 0.25) {
   return(bbox)
 }
+predict_new = function(model, data, type = "prob"){
+  stopifnot(type %in% c("prob", "class"))
+  if(class(model) %in% c("citodnn", "citodnnBootstrap")){
+    probs = predict(model, data, type = "response")[,1]
+    if(type == "prob"){
+      return(probs)
+    } else {
+      preds = factor(round(probs), levels = c("0", "1"), labels = c("A", "P"))
+      return(preds)
+    }
+  } else {
+    probs = predict(model, data, type = "prob")$P
+    if(type == "prob"){
+      return(probs)
+    } else {
+      preds = predict(model, data, type = "raw")
+      return(preds)
+    }
+  }
+}
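Note: a brief usage sketch for the helper above; the fit object and data frame are stand-ins from other scripts, so treat the calls as illustrative:

probs = predict_new(msdm_onehot_fit, new_data)                  # numeric P(presence)
preds = predict_new(msdm_onehot_fit, new_data, type = "class")  # factor with levels A/P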
+if(class(model) %in% c("citodnn", "citodnnBootstrap")){
+  p1 = predict(model, df1, type = "response")[,1]
+  p2 = predict(model, df2, type = "response")[,1]
+} else {
+  p1 = predict(model, df1, type = "prob")$P
+  p2 = predict(model, df2, type = "prob")$P
+}
 evaluate_model <- function(model, data) {
   # Accuracy: The proportion of correctly predicted instances (both true positives and true negatives) out of the total instances.
   # Formula: Accuracy = (TP + TN) / (TP + TN + FP + FN)
@@ -40,7 +69,6 @@ evaluate_model <- function(model, data) {
   # Predict probabilities
   if(class(model) %in% c("citodnn", "citodnnBootstrap")){
-    data = dplyr::select(data, any_of(all.vars(model$old_formula)))
     probs = predict(model, data, type = "response")[,1]
     preds = factor(round(probs), levels = c("0", "1"), labels = c("A", "P"))
   } else {
@@ -72,56 +100,3 @@ evaluate_model <- function(model, data) {
     )
   )
 }
-evaluate_multiclass_model <- function(model, test_data, k) {
-  # Accuracy: The proportion of correctly predicted instances (both true positives and true negatives) out of the total instances.
-  # Formula: Accuracy = (TP + TN) / (TP + TN + FP + FN)
-  # Precision: The proportion of true positives out of all instances predicted as positive.
-  # Formula: Precision = TP / (TP + FP)
-  # Recall (Sensitivity): The proportion of true positives out of all actual positive instances.
-  # Formula: Recall = TP / (TP + FN)
-  # F1 Score: The harmonic mean of Precision and Recall, balancing the two metrics.
-  # Formula: F1 = 2 * (Precision * Recall) / (Precision + Recall)
-  target_species = unique(test_data$species)
-  checkmate::assert_character(target_species, len = 1, any.missing = F)
-  new_data = dplyr::select(test_data, -species)
-  # Predict probabilities
-  if(class(model) %in% c("citodnn", "citodnnBootstrap")){
-    preds_overall = predict(model, as.matrix(new_data), type = "response")
-    probs <- as.vector(preds_overall[,target_species])
-    rank = apply(preds_overall, 1, function(x){ # Top-K approach
-      x_sort = sort(x, decreasing = T)
-      return(which(names(x_sort) == target_species))
-    })
-    top_k = as.character(as.numeric(rank <= k))
-    preds <- factor(top_k, levels = c("0", "1"), labels = c("A", "P"))
-  } else {
-    stop("Unsupported model type: ", class(model))
-  }
-  actual <- factor(test_data$present, levels = c("0", "1"), labels = c("A", "P"))
-  # Calculate AUC
-  auc <- pROC::roc(actual, probs, levels = c("P", "A"), direction = ">")$auc
-  # Calculate confusion matrix
-  cm <- caret::confusionMatrix(preds, actual, positive = "P")
-  # Return metrics
-  return(
-    list(
-      AUC = as.numeric(auc),
-      Accuracy = cm$overall["Accuracy"],
-      Kappa = cm$overall["Kappa"],
-      Precision = cm$byClass["Precision"],
-      Recall = cm$byClass["Recall"],
-      F1 = cm$byClass["F1"]
-    )
-  )
-}
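Note: with evaluate_multiclass_model() removed, evaluation runs through the binary evaluate_model() only. An illustrative call, reusing object names from the training script; the metric names are assumed to mirror the removed function's return list:

perf = evaluate_model(gbm_fit, data_test)
perf$AUC  # plus Accuracy, Kappa, Precision, Recall, F1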