diff --git a/R/02_02_functional_traits_preparation.R b/R/02_02_functional_traits_preparation.R
index 80f5887a25403fc8363e21eb443429ce1f504210..9d66867b7616d8ca25895d2442a8ef3ff1676c3b 100644
--- a/R/02_02_functional_traits_preparation.R
+++ b/R/02_02_functional_traits_preparation.R
@@ -36,7 +36,41 @@ library("vegan")
 load("data/r_objects/range_maps.RData")
 load("data/r_objects/traits_matched.RData")
 target_species = unique(range_maps$name_matched[!is.na(range_maps$name_matched)])
-traits_target = dplyr::filter(traits_matched, species %in% target_species)
+
+# Some species match multiple trait records; keep the record whose genus and
+# specific epithet agree with the matched species name, otherwise fall back
+# to the first record.
+traits_target = traits_matched %>% 
+  dplyr::filter(species %in% !!target_species) %>% 
+  dplyr::mutate(match_genus = stringr::str_detect(species, Genus),
+                match_epithet = stringr::str_detect(species, Species)) %>% 
+  dplyr::group_by(species) %>% 
+  dplyr::group_modify(~ {
+    if (nrow(.x) == 1) { 
+      return(.x)
+    }
+    
+    x_mod = .x
+    while (nrow(x_mod) > 1){
+      if(any(!x_mod$match_genus)){
+        x_mod = dplyr::filter(x_mod, match_genus)     # drop records with mismatched genus
+        next
+      }
+      
+      if(any(!x_mod$match_epithet)){
+        x_mod = dplyr::filter(x_mod, match_epithet)   # drop records with mismatched epithet
+        next
+      }
+      
+      x_mod = x_mod[1,]                               # still ambiguous, keep the first record
+    }
+    
+    if(nrow(x_mod) == 1){
+      return(x_mod)
+    } else {
+      return(.x[1,])                                  # all records filtered out, fall back to the first
+    }
+  }) %>% 
+  dplyr::ungroup()
+
 
 # some pre-processing
 traits_target$`ForStrat.Value.Proc` = as.numeric(factor(traits_target$`ForStrat.Value`, levels = c("G", "S", "Ar", "A")))
@@ -54,6 +88,6 @@ activity_dist = vegan::vegdist(traits_target[,activity_cols], "bray")
 bodymass_dist = dist(traits_target[,bodymass_cols])
 
 func_dist = (diet_dist + foraging_dist/max(foraging_dist) + activity_dist + bodymass_dist/max(bodymass_dist)) / 4
-names(func_dist) = stringr::str_replace_all(traits_target$species, pattern = " ", "_")
+attr(func_dist, "Labels") = traits_target$species  # label the dist object with species names for downstream matching
 
 save(func_dist, file = "data/r_objects/func_dist.RData")
diff --git a/R/04_01_ssdm_modeling.R b/R/04_01_ssdm_modeling.R
index bdc7c9f0df3f6acd185a28f872cf5ce97d106cb9..959f77c38589b6ad8594c0a3d465aaa6f9907f7a 100644
--- a/R/04_01_ssdm_modeling.R
+++ b/R/04_01_ssdm_modeling.R
@@ -114,7 +114,7 @@ ssdm_results = furrr::future_map(data_split, .options = furrr::furrr_options(see
     nn_fit = dnn(
       X = train_data[, predictors],
       Y = train_data$present,
-      hidden = c(500L, 200L, 50L),
+      hidden = c(200L, 200L, 200L),
       loss = "binomial",
       activation = c("sigmoid", "leaky_relu", "leaky_relu"),
       epochs = 500L, 
@@ -139,7 +139,7 @@ ssdm_results = furrr::future_map(data_split, .options = furrr::furrr_options(see
   performance_summary = tibble(
     species = data_spec$species[1],
     obs = nrow(data_spec),
-    model = c("RF", "GBM", "GLM", "NN"),
+    model = c("SSDM_RF", "SSDM_GBM", "SSDM_GLM", "SSDM_NN"),
     auc = c(rf_performance$AUC, gbm_performance$AUC, glm_performance$AUC, nn_performance$AUC),
     accuracy = c(rf_performance$Accuracy, gbm_performance$Accuracy, glm_performance$Accuracy, nn_performance$Accuracy),
     kappa = c(rf_performance$Kappa, gbm_performance$Kappa, glm_performance$Kappa, nn_performance$Kappa),
@@ -153,4 +153,4 @@ ssdm_results = furrr::future_map(data_split, .options = furrr::furrr_options(see
 
 ssdm_results = bind_rows(ssdm_results)
 
-save(ssdm_results, file = "data/r_objects/ssdm_results.RData")
\ No newline at end of file
+save(ssdm_results, file = "data/r_objects/ssdm_results.RData")
diff --git a/R/04_02_msdm_modeling.R b/R/04_02_msdm_embed.R
similarity index 84%
rename from R/04_02_msdm_modeling.R
rename to R/04_02_msdm_embed.R
index a9e5eea27027b37d37845d18320ac91ae19990ec..6c825fdab324129f47e36c035d3fef7db9c50a0f 100644
--- a/R/04_02_msdm_modeling.R
+++ b/R/04_02_msdm_embed.R
@@ -41,18 +41,8 @@ save(msdm_fit, file = "data/r_objects/msdm_fit2.RData")
 # ----------------------------------------------------------------------#
 # Evaluate model                                                     ####
 # ----------------------------------------------------------------------#
-load("data/r_objects/msdm_fit2.RData")
+load("data/r_objects/msdm_fit.RData")
 
-# Overall 
-preds_train = predict(msdm_fit, newdata = as.matrix(train_data), type = "response")
-preds_test = predict(msdm_fit, newdata = as.matrix(test_data), type = "response")
-
-hist(preds_train)
-hist(preds_test)
-
-eval_overall = evaluate_model(msdm_fit,  test_data)
-
-# Per species
 data_split = split(model_data, model_data$species)
 
 msdm_results = lapply(data_split, function(data_spec){
@@ -68,7 +58,7 @@ msdm_results = lapply(data_split, function(data_spec){
   performance_summary = tibble(
     species = data_spec$species[1],
     obs = nrow(data_spec),
-    model = "NN_MSDM",
+    model = "MSDM_embed",
     auc = msdm_performance$AUC,
     accuracy = msdm_performance$Accuracy,
     kappa = msdm_performance$Kappa,
@@ -78,4 +68,4 @@ msdm_results = lapply(data_split, function(data_spec){
   )
 }) %>% bind_rows()
 
-save(msdm_results, file = "data/r_objects/msdm_results2.RData")
+save(msdm_results, file = "data/r_objects/msdm_results.RData")
diff --git a/R/04_03_msdm_embed_informed.R b/R/04_03_msdm_embed_informed.R
new file mode 100644
index 0000000000000000000000000000000000000000..718a8e96fbaeedb4c77b35085d60fd18eccad0bb
--- /dev/null
+++ b/R/04_03_msdm_embed_informed.R
@@ -0,0 +1,135 @@
+library(dplyr)
+library(tidyr)
+library(cito)
+
+source("R/utils.R")
+
+load("data/r_objects/model_data.RData")
+load("data/r_objects/func_dist.RData")
+
+# ----------------------------------------------------------------------#
+# Prepare data                                                       ####
+# ----------------------------------------------------------------------#
+model_species = intersect(model_data$species, labels(func_dist))
+
+model_data_final = model_data %>%
+  dplyr::filter(species %in% !!model_species) %>% 
+  # dplyr::mutate_at(vars(starts_with("layer")), scale) %>%  # Scaling seems to make things worse often
+  dplyr::mutate(species_int = as.integer(as.factor(species)))
+
+train_data = dplyr::filter(model_data_final, train == 1)
+test_data = dplyr::filter(model_data_final, train == 0)
+
+sp_ind = match(model_species, labels(func_dist))
+func_dist = as.matrix(func_dist)[sp_ind, sp_ind]
+
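+# Use the leading 20 eigenvectors of the functional distance matrix as
+# trait-informed starting values for the species embedding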
+embeddings = eigen(func_dist)$vectors[,1:20]
+predictors = paste0("layer_", 1:19)
+
+# ----------------------------------------------------------------------#
+# Without training the embedding                                     ####
+# ----------------------------------------------------------------------#
+# 1. Train
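+# train = F keeps the trait-derived embedding weights fixed during model fitting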
+formula = as.formula(paste0("present ~ ", paste(predictors, collapse = '+'), " + ", "e(species_int, weights = embeddings, lambda = 0.00001, train = F)"))
+msdm_fit_embedding_untrained = dnn(
+  formula,
+  data = train_data,
+  hidden = c(200L, 200L, 200L),
+  loss = "binomial",
+  activation = c("sigmoid", "leaky_relu", "leaky_relu"),
+  epochs = 12000L, 
+  lr = 0.01,
+  batchsize = nrow(train_data)/5,
+  dropout = 0.1,
+  burnin = 100,
+  optimizer = config_optimizer("adam", weight_decay = 0.001),
+  lr_scheduler = config_lr_scheduler("reduce_on_plateau", patience = 100, factor = 0.7),
+  early_stopping = 250,
+  validation = 0.3,
+  device = "cuda"
+)
+save(msdm_fit_embedding_untrained, file = "data/r_objects/msdm_fit_embedding_untrained.RData")
+
+# 2. Evaluate
+# Per species
+data_split = test_data %>% 
+  group_by(species_int) %>% 
+  group_split()
+
+msdm_results_embedding_untrained = lapply(data_split, function(data_spec){
+  target_species = data_spec$species[1]
+  data_spec = dplyr::select(data_spec, -species)
+  
+  msdm_performance = tryCatch({
+    evaluate_model(msdm_fit_embedding_untrained, data_spec)
+  }, error = function(e){
+    list(AUC = NA, Accuracy = NA, Kappa = NA, Precision = NA, Recall = NA, F1 = NA)
+  })
+  
+  performance_summary = tibble(
+    species = !!target_species,
+    obs = length(which(model_data$species == target_species)),
+    model = "MSDM_embed_informed_untrained",
+    auc = msdm_performance$AUC,
+    accuracy = msdm_performance$Accuracy,
+    kappa = msdm_performance$Kappa,
+    precision = msdm_performance$Precision,
+    recall = msdm_performance$Recall,
+    f1 = msdm_performance$F1
+  )
+}) %>% bind_rows()
+
+save(msdm_results_embedding_untrained, file = "data/r_objects/msdm_results_embedding_untrained.RData")
+
+# ----------------------------------------------------------------------#
+# With training the embedding                                        ####
+# ----------------------------------------------------------------------#
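+# 1. Train
+# train = T allows the trait-derived embedding weights to be updated during model fitting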
+formula = as.formula(paste0("present ~ ", paste(predictors, collapse = '+'), " + ", "e(species_int, weights = embeddings, lambda = 0.00001, train = T)"))
+msdm_fit_embedding_trained = dnn(
+  formula,
+  data = train_data,
+  hidden = c(200L, 200L, 200L),
+  loss = "binomial",
+  activation = c("sigmoid", "leaky_relu", "leaky_relu"),
+  epochs = 12000L, 
+  lr = 0.01,
+  batchsize = nrow(train_data)/5,
+  dropout = 0.1,
+  burnin = 100,
+  optimizer = config_optimizer("adam", weight_decay = 0.001),
+  lr_scheduler = config_lr_scheduler("reduce_on_plateau", patience = 100, factor = 0.7),
+  early_stopping = 250,
+  validation = 0.3,
+  device = "cuda"
+)
+save(msdm_fit_embedding_trained, file = "data/r_objects/msdm_fit_embedding_trained.RData")
+
+# 2. Evaluate
+data_split = test_data %>% 
+  group_by(species_int) %>% 
+  group_split()
+
+msdm_results_embedding_trained = lapply(data_split, function(data_spec){
+  target_species = data_spec$species[1]
+  data_spec = dplyr::select(data_spec, -species)
+  
+  msdm_performance = tryCatch({
+    evaluate_model(msdm_fit_embedding_trained, data_spec)
+  }, error = function(e){
+    list(AUC = NA, Accuracy = NA, Kappa = NA, Precision = NA, Recall = NA, F1 = NA)
+  })
+  
+  performance_summary = tibble(
+    species = !!target_species,
+    obs = length(which(model_data$species == target_species)),
+    model = "MSDM_embed_informed_trained",
+    auc = msdm_performance$AUC,
+    accuracy = msdm_performance$Accuracy,
+    kappa = msdm_performance$Kappa,
+    precision = msdm_performance$Precision,
+    recall = msdm_performance$Recall,
+    f1 = msdm_performance$F1
+  )
+}) %>% bind_rows()
+
+save(msdm_results_embedding_trained, file = "data/r_objects/msdm_results_embedding_trained.RData")
diff --git a/R/04_03_msdm_modeling_with_traits.R b/R/04_04_msdm_multiclass.R
similarity index 53%
rename from R/04_03_msdm_modeling_with_traits.R
rename to R/04_04_msdm_multiclass.R
index 34956297492e6be88e05f9213dd515c517c20736..b2af99b100ec74d44ee1d24d69d47b8254bab1e0 100644
--- a/R/04_03_msdm_modeling_with_traits.R
+++ b/R/04_04_msdm_multiclass.R
@@ -5,33 +5,29 @@ library(cito)
 source("R/utils.R")
 
 load("data/r_objects/model_data.RData")
-load("data/r_objects/func_dist.RData")
 
 # ----------------------------------------------------------------------#
-# Prepare embeddings                                                 ####
+# Prepare data                                                       ####
 # ----------------------------------------------------------------------#
-model_data$species_int = as.integer(as.factor(model_data$species))
+train_data = dplyr::filter(model_data, present == 1, train == 1) # Use only presences for training
+test_data = dplyr::filter(model_data, train == 0)                 # Evaluate on presences + pseudo-absences for comparability with binary models
 
-train_data = dplyr::filter(model_data, train == 1)
-test_data = dplyr::filter(model_data, train == 0)
+predictors = paste0("layer_", 1:19)
 
 # ----------------------------------------------------------------------#
-# Train model                                                       ####
+# Train model                                                        ####
 # ----------------------------------------------------------------------#
-predictors = paste0("layer_", 1:19)
-formula = as.formula(paste0("present ~ ", paste(predictors, collapse = '+'), " + ", "e(species_int, dim = 20, lambda = 0.000001)"))
-
-msdm_fit = dnn(
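+# Multi-class formulation: predict species identity from the environmental
+# predictors, with a softmax output over all observed species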
+formula = as.formula(paste0("species ~ ", paste(predictors, collapse = '+')))
+msdm_fit_multiclass = dnn(
   formula,
   data = train_data,
-  hidden = c(500L, 1000L, 1000L),
-  loss = "binomial",
-  activation = c("sigmoid", "leaky_relu", "leaky_relu"),
-  epochs = 10000L, 
-  lr = 0.01,   
-  baseloss = 1,
+  hidden = c(200L, 200L, 200L),
+  loss = "softmax",
+  activation = c("leaky_relu", "leaky_relu", "leaky_relu"),
+  epochs = 5000L, 
+  lr = 0.01,
   batchsize = nrow(train_data)/5,
-  dropout = 0.25,
+  dropout = 0.1,
   burnin = 100,
   optimizer = config_optimizer("adam", weight_decay = 0.001),
   lr_scheduler = config_lr_scheduler("reduce_on_plateau", patience = 100, factor = 0.7),
@@ -40,38 +36,26 @@ msdm_fit = dnn(
   device = "cuda"
 )
 
-save(msdm_fit, file = "data/r_objects/msdm_fit.RData")
+save(msdm_fit_multiclass, file = "data/r_objects/msdm_fit_multiclass.RData")
 
 # ----------------------------------------------------------------------#
 # Evaluate model                                                     ####
 # ----------------------------------------------------------------------#
-load("data/r_objects/msdm_fit.RData")
-
-# Overall 
-preds_train = predict(msdm_fit, newdata = as.matrix(train_data), type = "response")
-preds_test = predict(msdm_fit, newdata = as.matrix(test_data), type = "response")
+data_split = test_data %>% 
+  group_by(species) %>% 
+  group_split()
 
-hist(preds_train)
-hist(preds_test)
-
-eval_overall = evaluate_model(msdm_fit,  test_data)
-
-# Per species
-data_split = split(model_data, model_data$species)
-
-msdm_results = lapply(data_split, function(data_spec){
-  test_data = dplyr::filter(data_spec, train == 0)
-  
+msdm_results_multiclass = lapply(data_split, function(data_spec){
   msdm_performance = tryCatch({
-    evaluate_model(msdm_fit, test_data)
+    evaluate_multiclass_model(msdm_fit_multiclass, data_spec, k = 10) # Top-k accuracy
   }, error = function(e){
     list(AUC = NA, Accuracy = NA, Kappa = NA, Precision = NA, Recall = NA, F1 = NA)
   })
   
   performance_summary = tibble(
     species = data_spec$species[1],
-    obs = nrow(data_spec),
-    model = "NN_MSDM",
+    obs = nrow(dplyr::filter(model_data, species == data_spec$species[1])),
+    model = "MSDM_multiclass",
     auc = msdm_performance$AUC,
     accuracy = msdm_performance$Accuracy,
     kappa = msdm_performance$Kappa,
@@ -81,4 +65,4 @@ msdm_results = lapply(data_split, function(data_spec){
   )
 }) %>% bind_rows()
 
-save(msdm_results, file = "data/r_objects/msdm_results.RData")
+save(msdm_results_multiclass, file = "data/r_objects/msdm_results_multiclass.RData")
diff --git a/R/05_performance_analysis.qmd b/R/05_performance_analysis.qmd
index 7eb33f11dc84d6568cbb81df38cabb2cd9bd4a9a..ca81a81d29b4a727482d197ca3f40858384b167c 100644
--- a/R/05_performance_analysis.qmd
+++ b/R/05_performance_analysis.qmd
@@ -7,15 +7,16 @@ engine: knitr
 
 ```{r init, echo = FALSE, include = FALSE}
 library(tidyverse)
-library(Symobio)
 library(sf)
 library(plotly)
 library(DT)
 
 
 load("../data/r_objects/ssdm_results.RData")
-load("../data/r_objects/msdm_fit2.RData")
-load("../data/r_objects/msdm_results2.RData")
+load("../data/r_objects/msdm_fit.RData")
+load("../data/r_objects/msdm_results.RData")
+load("../data/r_objects/msdm_results_embedding_trained.RData.")
+load("../data/r_objects/msdm_results_multiclass.RData.")
 load("../data/r_objects/range_maps.RData")
 load("../data/r_objects/range_maps_gridded.RData")
 load("../data/r_objects/occs_final.RData")
@@ -25,7 +26,7 @@ sf::sf_use_s2(use_s2 = FALSE)
 
 ```{r globals, echo = FALSE, include = FALSE}
 # Select metrics
-focal_metrics = c("auc", "f1", "accuracy")  # There's a weird bug in plotly that scrambles up lines when using more then three groups
+focal_metrics = c("auc", "f1", "accuracy")  # There's a weird bug in plotly that scrambles up lines when using more than three groups
 
 # Dropdown options
 plotly_buttons = list()
@@ -52,72 +53,92 @@ lin_regression = function(x, y){
   )
 }
 
+# Performance table
+performance = bind_rows(ssdm_results, msdm_results, msdm_results_embedding_trained, msdm_results_multiclass) %>% 
+  pivot_longer(c(auc, accuracy, kappa, precision, recall, f1), names_to = "metric") %>% 
+  dplyr::filter(!is.na(value)) %>% 
+  dplyr::mutate(
+    metric = factor(metric, levels = c("auc", "kappa", "f1", "accuracy", "precision", "recall")),
+    value = round(pmax(value, 0, na.rm = T), 3) # Fix one weird instance of f1 < 0
+  ) %>% 
+  dplyr::filter(metric %in% focal_metrics)
 ```
 
 ## Summary
 
-This document summarizes the performance of four SDM algorithms (Random Forest, Gradient Boosting Machine, Generalized Linear Model, Deep Neural Network) for `r I(length(unique(ssdm_results$species)))` South American mammal species. We use `r I(xfun::numbers_to_words(length(focal_metrics)))` metrics (`r I(paste(focal_metrics, collapse = ', '))`) to evaluate model performance and look at how performance varies with five factors (number of records, range size, range coverage, range coverage bias, and functional group).
+This document summarizes the performance of different sSDM and mSDM algorithms for `r I(length(unique(performance$species)))` South American mammal species. Model performance is evaluated on `r I(xfun::numbers_to_words(length(focal_metrics)))` metrics (`r I(paste(focal_metrics, collapse = ', '))`) and analyzed with respect to five potential influencing factors (number of records, range size, range coverage, range coverage bias, and functional group). The comparison of sSDM and mSDM approaches is of particular interest.
 
-### Modeling decisions:
+Code can be found on [GitLab](https://git.idiv.de/ye87zine/symobio-modeling).
 
-#### Data
+### Modeling overview:
+
+#### General decisions
 
 -   Randomly sampled pseudo-absences from expanded area of extent of occurrence records (×1.25)
 -   Balanced presences and absences for each species
 -   Predictors: all 19 CHELSA bioclim variables
--   70/30 Split of training vs. test data
+-   70/30 split of training vs. test data (except for NN models)
 
-#### Algorithms
+#### sSDM Algorithms
 
-Random Forest
+Random Forest (**SSDM_RF**)
 
--   Spatial block cross-validation during training
 -   Hyperparameter tuning of `mtry`
+-   Spatial block cross-validation during training
 
-Generalized boosted machine
+Generalized boosted machine (**SSDM_GBM**)
 
--   Spatial block cross-validation during training
 -   Hyperparameter tuning across `n.trees` , `interaction.depth` , `shrinkage`, `n.minobsinnode`
+-   Spatial block cross-validation during training
 
-Generalized Linear Model
+Generalized Linear Model (**SSDM_GLM**)
 
--   Spatial block cross-validation during training
 -   Logistic model with binomial link function
+-   Spatial block cross-validation during training
 
-Neural Netwok (single-species)
+Neural Network (**SSDM_NN**)
 
--   Three hidden layers (sigmoid - 500, leaky_relu - 200, leaky_relu - 50)
--   Binomial loss, ADAM optimizer
--   Poor convergence
+-   Three hidden layers, sigmoid + leaky ReLU activations, binomial loss
+-   No spatial block cross-validation during training
 
-Neural Network (multi-species)
+#### mSDM Algorithms
 
--   Three hidden layers (sigmoid - 500, leaky_relu - 1000, leaky_relu - 1000)
--   Binomial loss, ADAM optimizer
--   very slow convergence (`r I(length(which(!is.na(msdm_fit$losses$train_l))))` epochs)
+Binary Neural Network with species embedding (**MSDM_embed**)
+
+-   definition: presence \~ environment + embedding(species)
+-   prediction: probability of occurrence given a set of (environmental) inputs and species identity
+-   embedding initialized at random
+-   three hidden layers, sigmoid + leaky ReLU activations, binomial loss
+
+Binary Neural Network with trait-informed species embedding (**MSDM_embed_informed_trained**)
+
+-   definition: presence \~ environment + embedding(species)
+-   prediction: probability of occurrence given a set of (environmental) inputs and species identity
+-   embedding initialized using eigenvectors of the functional distance matrix (see the sketch below)
+-   three hidden layers, sigmoid + leaky ReLU activations, binomial loss
+
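+A minimal sketch of this trait-informed initialization, using a toy distance matrix (the full version is in `R/04_03_msdm_embed_informed.R`):
+
+```r
+# Toy functional distance matrix for three species
+func_mat = matrix(c(0.0, 0.2, 0.8,
+                    0.2, 0.0, 0.7,
+                    0.8, 0.7, 0.0), nrow = 3)
+embeddings = eigen(func_mat)$vectors[, 1:2]   # leading eigenvectors as initial embedding weights
+# passed to cito via e(species_int, weights = embeddings, train = F) in the model formula
+```
+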
+Multi-Class Neural Network (**MSDM_multiclass**)
+
+-   definition: species identity \~ environment
+-   prediction: probability distribution across all observed species given a set of (environmental) inputs
+-   presence-only data in training
+-   three hidden layers, leaky ReLU activations, softmax loss
+-   Top-k evaluation (k = 10): a test record counts as a predicted presence if the target species ranks among the top 10 predicted species, and as a predicted absence otherwise (see the toy example below)
 
 ### Key findings:
 
--   Performance: RF \> GBM \> GLM \> NN
--   Convergence problems with Neural Notwork Models
+-   sSDM algorithms (RF, GBM) outperformed mSDMs in most cases
+-   mSDMs showed indications of better performance for rare species (\< 10-20 occurrences)
 -   More occurrence records and larger range sizes tended to improve model performance
--   Higher range coverage correlated with better performance.
+-   Higher range coverage correlated with better performance
 -   Range coverage bias and functional group showed some impact but were less consistent
+-   Convergence problems hampered NN sSDM performance
 
 ## Analysis
 
 The table below shows the analysed modeling results.
 
 ```{r performance, echo = FALSE, message=FALSE, warnings=FALSE}
-performance = bind_rows(ssdm_results, msdm_results) %>% 
-  pivot_longer(c(auc, accuracy, kappa, precision, recall, f1), names_to = "metric") %>% 
-  dplyr::filter(!is.na(value)) %>% 
-  dplyr::mutate(
-    metric = factor(metric, levels = c("auc", "kappa", "f1", "accuracy", "precision", "recall")),
-    value = round(pmax(value, 0, na.rm = T), 3) # Fix one weird instance of f1 < 0
-  ) %>% 
-  dplyr::filter(metric %in% focal_metrics)
-
 DT::datatable(performance)
 ```
 
@@ -508,7 +529,7 @@ bslib::card(plot, full_screen = T)
 Functional groups were assigned based on taxonomic order. The following groupings were used:
 
 | Functional group      | Taxomic orders                                                        |
-|------------------------|-----------------------------------------------|
+|-----------------------|-----------------------------------------------------------------------|
 | large ground-dwelling | Carnivora, Artiodactyla, Cingulata, Perissodactyla                    |
 | small ground-dwelling | Rodentia, Didelphimorphia, Soricomorpha, Paucituberculata, Lagomorpha |
 | arboreal              | Primates, Pilosa                                                      |
diff --git a/R/utils.R b/R/utils.R
index edde20d758b1d33e81cf00ef13caa997496b6fa3..b1da26abc38515507c2381fa178c92bab4e50296 100644
--- a/R/utils.R
+++ b/R/utils.R
@@ -67,3 +67,56 @@ evaluate_model <- function(model, test_data) {
     )
   )
 }
+
+evaluate_multiclass_model <- function(model, test_data, k) {
+  # Accuracy: The proportion of correctly predicted instances (both true positives and true negatives) out of the total instances.
+  # Formula: Accuracy = (TP + TN) / (TP + TN + FP + FN)
+  
+  # Precision: The proportion of true positives out of all instances predicted as positive.
+  # Formula: Precision = TP / (TP + FP)
+  
+  # Recall (Sensitivity): The proportion of true positives out of all actual positive instances.
+  # Formula: Recall = TP / (TP + FN)
+  
+  # F1 Score: The harmonic mean of Precision and Recall, balancing the two metrics.
+  # Formula: F1 = 2 * (Precision * Recall) / (Precision + Recall)
+  target_species = unique(test_data$species)
+  checkmate::assert_character(target_species, len = 1, any.missing = F)
+  
+  new_data = dplyr::select(test_data, -species)
+  
+  # Predict probabilities
+  if(inherits(model, c("citodnn", "citodnnBootstrap"))){
+    preds_overall = predict(model, as.matrix(new_data), type = "response")
+    probs <- as.vector(preds_overall[,target_species])
+    
+    rank = apply(preds_overall, 1, function(x){         # Top-K approach
+      x_sort = sort(x, decreasing = T)
+      return(which(names(x_sort) == target_species))
+    })
+    top_k = as.character(as.numeric(rank <= k))
+    preds <- factor(top_k, levels = c("0", "1"), labels = c("A", "P"))
+  } else {
+    stop("Unsupported model type: ", class(model))
+  }
+  
+  actual <- factor(test_data$present, levels = c("0", "1"), labels = c("A", "P"))
+  
+  # Calculate AUC
+  auc <- pROC::roc(actual, probs, levels = c("P", "A"), direction = ">")$auc
+  
+  # Calculate confusion matrix
+  cm <- caret::confusionMatrix(preds, actual, positive = "P")
+  
+  # Return metrics
+  return(
+    list(
+      AUC = as.numeric(auc),
+      Accuracy = cm$overall["Accuracy"],
+      Kappa = cm$overall["Kappa"],
+      Precision = cm$byClass["Precision"],
+      Recall = cm$byClass["Recall"],
+      F1 = cm$byClass["F1"]
+    )
+  )
+}