diff --git a/R/02_03_phylo_preparation.R b/R/02_03_phylo_preparation.R
index f00499354ffea8349ca3f0c1e38550490daefaf9..3963892f71b712031b89d47792836947f6f2abb5 100644
--- a/R/02_03_phylo_preparation.R
+++ b/R/02_03_phylo_preparation.R
@@ -3,7 +3,7 @@ library(phytools)
 library(ape)
 library(Symobio)
 
-forest = read.nexus("data/phylogenies/Small_phylogeny.nex")
+forest = read.nexus("data/phylogenies/Complete_phylogeny.nex")
 load("data/r_objects/range_maps_names_matched.RData")
 
 # Get taxonomic information for target species
@@ -22,7 +22,9 @@ phylo_names_matched  = lapply(phylo_names_unique, function(name){
   }, error = function(e){
     return(NULL)
   })
-}) %>% bind_rows()
+}) %>% 
+  bind_rows() %>% 
+  drop_na()
 
 save(phylo_names_matched, file = "data/r_objects/phylo_names_matched.RData")
 
@@ -32,14 +34,14 @@ phylo_names_target = dplyr::filter(phylo_names_matched, name_matched %in% range_
 forest_pruned = lapply(forest, function(x) {
   labels_drop = setdiff(x$tip.label, phylo_names_target$name_orig)
   x_pruned = drop.tip(x, labels_drop)
-  labels_new = phylo_names_matched$name_matched[match(phylo_names_target$name_orig, x_pruned$tip.label)]
+  labels_new = phylo_names_target$name_matched[match(phylo_names_target$name_orig, x_pruned$tip.label)]
   x_pruned$tip.label = labels_new
   return(x_pruned)
 })
 
 class(forest_pruned) <- "multiPhylo"
 
-# Calculate pairwise phylogenetic distances across random sample of forests
+# Calculate pairwise phylogenetic distances across a random sample 50 of forests
 indices = sample(length(forest_pruned), 50)
 dists = lapply(indices, function(i){
   cophenetic.phylo(forest_pruned[[i]])
@@ -47,4 +49,4 @@ dists = lapply(indices, function(i){
 
 # Save result
 phylo_dist = Reduce("+", dists) / length(dists)
-save(phylo_dist, file = "data/r_objects/phylo_dist.RData")
\ No newline at end of file
+save(phylo_dist, file = "data/r_objects/phylo_dist.RData")
diff --git a/R/04_04_msdm_multiclass.R b/R/04_02_msdm_multiclass.R
similarity index 100%
rename from R/04_04_msdm_multiclass.R
rename to R/04_02_msdm_multiclass.R
diff --git a/R/04_02_msdm_embed_raw.R b/R/04_03_msdm_embed_raw.R
similarity index 100%
rename from R/04_02_msdm_embed_raw.R
rename to R/04_03_msdm_embed_raw.R
diff --git a/R/04_03_msdm_embed_traits.R b/R/04_04_msdm_embed_traits.R
similarity index 97%
rename from R/04_03_msdm_embed_traits.R
rename to R/04_04_msdm_embed_traits.R
index 774121e2cc3692fc5d0c10a7af9297a5beb5e1af..9dd0b34ac126ae121c90aba426caac94db58f756 100644
--- a/R/04_03_msdm_embed_traits.R
+++ b/R/04_04_msdm_embed_traits.R
@@ -5,12 +5,12 @@ library(cito)
 source("R/utils.R")
 
 load("data/r_objects/model_data.RData")
-load("data/r_objects/func_dist.RData")
+load("data/r_objects/phylo_dist.RData")
 
 # ----------------------------------------------------------------------#
 # Prepare data                                                       ####
 # ----------------------------------------------------------------------#
-model_species = intersect(model_data$species, names(func_dist)) 
+model_species = intersect(model_data$species, colnames(phylo_dist)) 
 
 model_data_final = model_data %>%
   dplyr::filter(species %in% !!model_species) %>% 
diff --git a/R/04_05_msdm_embed_phylo.R b/R/04_05_msdm_embed_phylo.R
new file mode 100644
index 0000000000000000000000000000000000000000..df8b57ab681a51cedfd4f641123de24ee7f33838
--- /dev/null
+++ b/R/04_05_msdm_embed_phylo.R
@@ -0,0 +1,139 @@
+library(dplyr)
+library(tidyr)
+library(cito)
+
+source("R/utils.R")
+
+load("data/r_objects/model_data.RData")
+load("data/r_objects/phylo_dist.RData")
+
+# ----------------------------------------------------------------------#
+# Prepare data                                                       ####
+# ----------------------------------------------------------------------#
+model_species = intersect(model_data$species, colnames(phylo_dist)) 
+
+model_data_final = model_data %>%
+  dplyr::filter(species %in% !!model_species) %>% 
+  # dplyr::mutate_at(vars(starts_with("layer")), scale) %>%  # Scaling seems to make things worse often
+  dplyr::mutate(species_int = as.integer(as.factor(species)))
+
+train_data = dplyr::filter(model_data_final, train == 1)
+test_data = dplyr::filter(model_data_final, train == 0)
+
+sp_ind = match(model_species, colnames(phylo_dist))
+phylo_dist = phylo_dist[sp_ind, sp_ind]
+
+embeddings = eigen(phylo_dist)$vectors[,1:20]
+predictors = paste0("layer_", 1:19)
+
+# ----------------------------------------------------------------------#
+# Without training the embedding                                     ####
+# ----------------------------------------------------------------------#
+# 1. Train
+formula = as.formula(paste0("present ~ ", paste(predictors, collapse = '+'), " + ", "e(species_int, weights = embeddings, lambda = 0.00001, train = F)"))
+msdm_fit_embedding_phylo_static = dnn(
+  formula,
+  data = train_data,
+  hidden = c(500L, 500L, 500L),
+  loss = "binomial",
+  activation = c("sigmoid", "leaky_relu", "leaky_relu"),
+  epochs = 15000L, 
+  lr = 0.01,   
+  baseloss = 1,
+  batchsize = nrow(train_data),
+  dropout = 0.1,
+  burnin = 100,
+  optimizer = config_optimizer("adam", weight_decay = 0.001),
+  lr_scheduler = config_lr_scheduler("reduce_on_plateau", patience = 100, factor = 0.7),
+  early_stopping = 250,
+  validation = 0.3,
+  device = "cuda",
+)
+save(msdm_fit_embedding_phylo_static, file = "data/r_objects/msdm_fit_embedding_phylo_static.RData")
+
+# 2. Evaluate
+# Per species
+load("data/r_objects/msdm_fit_embedding_phylo_static.RData")
+data_split = test_data %>% 
+  group_by(species_int) %>% 
+  group_split()
+
+msdm_results_embedding_phylo_static = lapply(data_split, function(data_spec){
+  target_species =  data_spec$species[1]
+  data_spec = dplyr::select(data_spec, -species)
+  
+  msdm_performance = tryCatch({
+    evaluate_model(msdm_fit_embedding_phylo_static, data_spec)
+  }, error = function(e){
+    list(AUC = NA, Accuracy = NA, Kappa = NA, Precision = NA, Recall = NA, F1 = NA)
+  })
+  
+  performance_summary = tibble(
+    species = !!target_species,
+    obs = length(which(model_data$species == target_species)),
+    model = "MSDM_embed_informed_phylo_static",
+    auc = msdm_performance$AUC,
+    accuracy = msdm_performance$Accuracy,
+    kappa = msdm_performance$Kappa,
+    precision = msdm_performance$Precision,
+    recall = msdm_performance$Recall,
+    f1 = msdm_performance$F1
+  )
+}) %>% bind_rows()
+
+save(msdm_results_embedding_phylo_static, file = "data/r_objects/msdm_results_embedding_phylo_static.RData")
+
+# -------------------------------------------------------------------#
+# With training the embedding                                     ####
+# ------------------------------------------------------------ ------#
+formula = as.formula(paste0("present ~ ", paste(predictors, collapse = '+'), " + ", "e(species_int, weights = embeddings, lambda = 0.00001, train = T)"))
+msdm_fit_embedding_phylo_trained = dnn(
+  formula,
+  data = train_data,
+  hidden = c(500L, 500L, 500L),
+  loss = "binomial",
+  activation = c("sigmoid", "leaky_relu", "leaky_relu"),
+  epochs = 15000L, 
+  lr = 0.01,   
+  baseloss = 1,
+  batchsize = nrow(train_data),
+  dropout = 0.1,
+  burnin = 100,
+  optimizer = config_optimizer("adam", weight_decay = 0.001),
+  lr_scheduler = config_lr_scheduler("reduce_on_plateau", patience = 100, factor = 0.7),
+  early_stopping = 250,
+  validation = 0.3,
+  device = "cuda",
+)
+save(msdm_fit_embedding_phylo_trained, file = "data/r_objects/msdm_fit_embedding_phylo_trained.RData")
+
+# 2. Evaluate
+load("data/r_objects/msdm_fit_embedding_phylo_trained.RData")
+data_split = test_data %>% 
+  group_by(species_int) %>% 
+  group_split()
+
+msdm_results_embedding_phylo_trained = lapply(data_split, function(data_spec){
+  target_species =  data_spec$species[1]
+  data_spec = dplyr::select(data_spec, -species)
+  
+  msdm_performance = tryCatch({
+    evaluate_model(msdm_fit_embedding_phylo_trained, data_spec)
+  }, error = function(e){
+    list(AUC = NA, Accuracy = NA, Kappa = NA, Precision = NA, Recall = NA, F1 = NA)
+  })
+  
+  performance_summary = tibble(
+    species = !!target_species,
+    obs = length(which(model_data$species == target_species)),
+    model = "MSDM_embed_informed_phylo_trained",
+    auc = msdm_performance$AUC,
+    accuracy = msdm_performance$Accuracy,
+    kappa = msdm_performance$Kappa,
+    precision = msdm_performance$Precision,
+    recall = msdm_performance$Recall,
+    f1 = msdm_performance$F1
+  )
+}) %>% bind_rows()
+
+save(msdm_results_embedding_phylo_trained, file = "data/r_objects/msdm_results_embedding_phylo_trained.RData")