Skip to content
Snippets Groups Projects
Select Git revision
  • 60a58262fc8179f340e82642471dcd7d934e4446
  • main default protected
2 results

04_04_msdm_embed_traits.R

  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    04_04_msdm_embed_traits.R 5.27 KiB
    library(dplyr)
    library(tidyr)
    library(cito)
    
    source("R/utils.R")
    
    load("data/r_objects/model_data.RData")
    load("data/r_objects/func_dist.RData")
    
    # ----------------------------------------------------------------------#
    # Prepare data                                                       ####
    # ----------------------------------------------------------------------#
    # Species that occur in the model data and have an entry in func_dist.
    # The vector is sorted so that ONE explicit species order defines both the
    # integer ids (species_int) and the row order of `embeddings` below.
    # Previously the ids followed factor-level order while the embedding rows
    # followed order of first appearance in model_data, which only agree when
    # model_data happens to be sorted by species.
    model_species = sort(intersect(as.character(model_data$species), names(func_dist)))
    
    model_data_final = model_data %>%
      dplyr::filter(species %in% !!model_species) %>% 
      # dplyr::mutate_at(vars(starts_with("layer")), ~as.vector(scale(.))) %>%  # Scaling seems to make things worse often
      # Explicit levels: id i always refers to model_species[i], independent of
      # row order or of factor levels the species column may already carry.
      dplyr::mutate(species_int = as.integer(factor(species, levels = !!model_species)))
    
    # Split by the train flag stored in the data (1 = training, 0 = test)
    train_data = dplyr::filter(model_data_final, train == 1)
    test_data = dplyr::filter(model_data_final, train == 0)
    
    # Reorder the distance matrix to the species order used for species_int.
    # drop = FALSE keeps a matrix even if only a single species remains.
    sp_ind = match(model_species, names(func_dist))
    func_dist = as.matrix(func_dist)[sp_ind, sp_ind, drop = FALSE]
    
    # The first n_embed eigenvectors returned by eigen() serve as species
    # embeddings; row i of `embeddings` belongs to model_species[i].
    n_embed = 20L
    if (length(model_species) < n_embed) {
      stop(
        "Only ", length(model_species), " species are shared between model_data and func_dist; ",
        "at least ", n_embed, " are required to build the embedding.",
        call. = FALSE
      )
    }
    embeddings = eigen(func_dist)$vectors[, seq_len(n_embed), drop = FALSE]
    
    # Names of the environmental predictor columns (layer_1 ... layer_19)
    predictors = paste0("layer_", 1:19)
    
    # ----------------------------------------------------------------------#
    # Without training the embedding                                     ####
    # ----------------------------------------------------------------------#
    # 1. Train
    # Species enter the model through an embedding term on species_int whose
    # weights are taken from `embeddings` and are not trained (train = FALSE).
    # FALSE is spelled out because the string is evaluated later and `F` is an
    # ordinary, reassignable variable.
    embedding_term = "e(species_int, weights = embeddings, lambda = 0.00001, train = FALSE)"
    formula = as.formula(paste0("present ~ ", paste(predictors, collapse = '+'), " + ", embedding_term))
    
    # Empty canvas whose x-range matches the maximum number of epochs
    plot(1, type="n", xlab="", ylab="", xlim=c(0, 15000), ylim=c(0, 0.7)) # empty plot with better limits, draw points in there
    
    # Full-batch training (batchsize = number of training rows) on the GPU.
    # For the exact meaning of burnin, baseloss, validation and early_stopping
    # see ?cito::dnn.
    msdm_fit_embedding_traits_static = dnn(
      formula,
      data = train_data,
      hidden = c(500L, 500L, 500L),
      loss = "binomial",
      activation = c("sigmoid", "leaky_relu", "leaky_relu"),
      epochs = 15000L, 
      lr = 0.01,   
      baseloss = 1,
      batchsize = nrow(train_data),
      dropout = 0.1,
      burnin = 100,
      optimizer = config_optimizer("adam", weight_decay = 0.001),
      lr_scheduler = config_lr_scheduler("reduce_on_plateau", patience = 100, factor = 0.7),
      early_stopping = 250,
      validation = 0.3,
      device = "cuda"
    )
    
    # Persist the fitted model so the evaluation step can be run separately
    save(msdm_fit_embedding_traits_static, file = "data/r_objects/msdm_fit_embedding_traits_static.RData")
    
    # 2. Evaluate
    # Per species
    # Reload the stored fit so this step also works in a fresh session
    load("data/r_objects/msdm_fit_embedding_traits_static.RData")
    
    # One data frame of test records per species
    data_split = test_data %>% 
      group_by(species_int) %>% 
      group_split()
    
    # Evaluate the static-embedding model on each species' test records and
    # collect one row of performance metrics per species.
    msdm_results_embedding_traits_static = lapply(data_split, function(data_spec){
      target_species = data_spec$species[1]
      data_spec = dplyr::select(data_spec, -species)
      
      # A failing evaluation must not abort the whole run, but it should not go
      # unnoticed either: report it and fall back to NA metrics.
      msdm_performance = tryCatch({
        evaluate_model(msdm_fit_embedding_traits_static, data_spec)
      }, error = function(e){
        warning(
          "Evaluation failed for species '", as.character(target_species), "': ",
          conditionMessage(e),
          call. = FALSE
        )
        list(AUC = NA, Accuracy = NA, Kappa = NA, Precision = NA, Recall = NA, F1 = NA)
      })
      
      tibble(
        species = target_species,
        # Number of records of this species in the full model data
        obs = sum(model_data$species == target_species, na.rm = TRUE),
        model = "MSDM_embed_informed_traits_static",
        auc = msdm_performance$AUC,
        accuracy = msdm_performance$Accuracy,
        kappa = msdm_performance$Kappa,
        precision = msdm_performance$Precision,
        recall = msdm_performance$Recall,
        f1 = msdm_performance$F1
      )
    }) %>% bind_rows()
    
    save(msdm_results_embedding_traits_static, file = "data/r_objects/msdm_results_embedding_traits_static.RData")
    
    # -------------------------------------------------------------------#
    # With training the embedding                                     ####
    # -------------------------------------------------------------------#
    # 1. Train
    # Same model as the static variant, but the embedding weights taken from
    # `embeddings` are updated during training (train = TRUE).
    # TRUE is spelled out because the string is evaluated later and `T` is an
    # ordinary, reassignable variable.
    embedding_term = "e(species_int, weights = embeddings, lambda = 0.00001, train = TRUE)"
    formula = as.formula(paste0("present ~ ", paste(predictors, collapse = '+'), " + ", embedding_term))
    
    # Empty canvas whose x-range matches the maximum number of epochs
    plot(1, type="n", xlab="", ylab="", xlim=c(0, 15000), ylim=c(0, 0.7)) # empty plot with better limits, draw points in there
    
    # Full-batch training (batchsize = number of training rows) on the GPU.
    # For the exact meaning of burnin, baseloss, validation and early_stopping
    # see ?cito::dnn.
    msdm_fit_embedding_traits_trained = dnn(
      formula,
      data = train_data,
      hidden = c(500L, 500L, 500L),
      loss = "binomial",
      activation = c("sigmoid", "leaky_relu", "leaky_relu"),
      epochs = 15000L, 
      lr = 0.01,   
      baseloss = 1,
      batchsize = nrow(train_data),
      dropout = 0.1,
      burnin = 100,
      optimizer = config_optimizer("adam", weight_decay = 0.001),
      lr_scheduler = config_lr_scheduler("reduce_on_plateau", patience = 100, factor = 0.7),
      early_stopping = 250,
      validation = 0.3,
      device = "cuda"
    )
    
    # Persist the fitted model so the evaluation step can be run separately
    save(msdm_fit_embedding_traits_trained, file = "data/r_objects/msdm_fit_embedding_traits_trained.RData")
    
    # 2. Evaluate
    # Per species
    # Reload the stored fit so this step also works in a fresh session
    load("data/r_objects/msdm_fit_embedding_traits_trained.RData")
    
    # One data frame of test records per species
    data_split = test_data %>% 
      group_by(species_int) %>% 
      group_split()
    
    # Evaluate the trained-embedding model on each species' test records and
    # collect one row of performance metrics per species.
    msdm_results_embedding_traits_trained = lapply(data_split, function(data_spec){
      target_species = data_spec$species[1]
      data_spec = dplyr::select(data_spec, -species)
      
      # A failing evaluation must not abort the whole run, but it should not go
      # unnoticed either: report it and fall back to NA metrics.
      msdm_performance = tryCatch({
        evaluate_model(msdm_fit_embedding_traits_trained, data_spec)
      }, error = function(e){
        warning(
          "Evaluation failed for species '", as.character(target_species), "': ",
          conditionMessage(e),
          call. = FALSE
        )
        list(AUC = NA, Accuracy = NA, Kappa = NA, Precision = NA, Recall = NA, F1 = NA)
      })
      
      tibble(
        species = target_species,
        # Number of records of this species in the full model data
        obs = sum(model_data$species == target_species, na.rm = TRUE),
        model = "MSDM_embed_informed_traits_trained",
        auc = msdm_performance$AUC,
        accuracy = msdm_performance$Accuracy,
        kappa = msdm_performance$Kappa,
        precision = msdm_performance$Precision,
        recall = msdm_performance$Recall,
        f1 = msdm_performance$F1
      )
    }) %>% bind_rows()
    
    save(msdm_results_embedding_traits_trained, file = "data/r_objects/msdm_results_embedding_traits_trained.RData")