expand_bbox <- function(bbox, min_span = 1, expansion = 0.25) {
  # Get current bbox dimensions
  x_range <- bbox["xmax"] - bbox["xmin"]
  y_range <- bbox["ymax"] - bbox["ymin"]
  x_expand <- expansion
  y_expand <- expansion
  
  # Make sure the bbox ends up at least `min_span` wide in each dimension:
  # when the proportional expansion would be too small, pad each side by
  # min_span / 2 instead
  if (x_range * x_expand < min_span) {
    x_range <- min_span / 2
    x_expand <- 1
  }
  
  if (y_range * y_expand < min_span) {
    y_range <- min_span / 2
    y_expand <- 1
  }
  
  # Expand the limits in both directions, clamping to valid lon/lat bounds
  bbox["xmin"] <- max(bbox["xmin"] - (x_expand * x_range), -180)
  bbox["xmax"] <- min(bbox["xmax"] + (x_expand * x_range), 180)
  bbox["ymin"] <- max(bbox["ymin"] - (y_expand * y_range), -90)
  bbox["ymax"] <- min(bbox["ymax"] + (y_expand * y_range), 90)
  
  return(bbox)
}
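
# Usage sketch (hypothetical values; assumes the named-vector layout produced
# by sf::st_bbox()). A single-point bbox is padded out to the minimum span:
#   bbox <- c(xmin = 5.1, ymin = 52.0, xmax = 5.1, ymax = 52.0)
#   expand_bbox(bbox, min_span = 1)
#   #> xmin = 4.6, ymin = 51.5, xmax = 5.6, ymax = 52.5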

evaluate_model <- function(model, test_data) {
  # Accuracy: The proportion of correctly predicted instances (both true positives and true negatives) out of the total instances.
  # Formula: Accuracy = (TP + TN) / (TP + TN + FP + FN)
  
  # Precision: The proportion of true positives out of all instances predicted as positive.
  # Formula: Precision = TP / (TP + FP)
  
  # Recall (Sensitivity): The proportion of true positives out of all actual positive instances.
  # Formula: Recall = TP / (TP + FN)
  
  # F1 Score: The harmonic mean of Precision and Recall, balancing the two metrics.
  # Formula: F1 = 2 * (Precision * Recall) / (Precision + Recall)
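  
  # Worked example (hypothetical counts): with TP = 40, TN = 45, FP = 10, FN = 5,
  # Accuracy = 85/100 = 0.85, Precision = 40/50 = 0.80, Recall = 40/45 ≈ 0.889,
  # and F1 = 2 * (0.80 * 0.889) / (0.80 + 0.889) ≈ 0.842.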
  
  # Predict probabilities
  if (inherits(model, c("citodnn", "citodnnBootstrap"))) {
    # cito models return a probability matrix; threshold at 0.5 to get classes
    # ("A" = absent, "P" = present)
    probs <- predict(model, as.matrix(test_data), type = "response")[, 1]
    preds <- factor(round(probs), levels = c("0", "1"), labels = c("A", "P"))
  } else {
    # caret-style models expose class probabilities and raw class predictions
    probs <- predict(model, test_data, type = "prob")$P
    preds <- predict(model, test_data, type = "raw")
  }
  
  actual <- factor(test_data$present, levels = c("0", "1"), labels = c("A", "P"))
  
  # Calculate AUC; levels = c(controls, cases), so "A" are controls and "P" are
  # cases, with higher predicted probability indicating presence
  auc <- pROC::roc(actual, probs, levels = c("A", "P"), direction = "<")$auc
  
  # Calculate confusion matrix
  cm <- caret::confusionMatrix(preds, actual, positive = "P")
  
  # Return metrics
  return(
    list(
      AUC = as.numeric(auc),
      Accuracy = cm$overall["Accuracy"],
      Kappa = cm$overall["Kappa"],
      Precision = cm$byClass["Precision"],
      Recall = cm$byClass["Recall"],
      F1 = cm$byClass["F1"]
    )
  )
}
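
# Usage sketch (hypothetical objects: `rf_fit` is a caret model trained with
# classProbs = TRUE and class labels "A"/"P", and `test_df` has a 0/1 `present`
# column plus the predictor columns the model was fit on):
#   metrics <- evaluate_model(rf_fit, test_df)
#   metrics$AUC; metrics$F1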

evaluate_multiclass_model <- function(model, test_data, k) {
  # Metric definitions (Accuracy, Precision, Recall, F1) are documented at the
  # top of evaluate_model() above.
  # test_data must contain exactly one species
  target_species <- unique(test_data$species)
  checkmate::assert_character(target_species, len = 1, any.missing = FALSE)
  
  new_data <- dplyr::select(test_data, -species)
  
  # Predict probabilities
  if (inherits(model, c("citodnn", "citodnnBootstrap"))) {
    preds_overall <- predict(model, as.matrix(new_data), type = "response")
    probs <- as.vector(preds_overall[, target_species])
    
    # Top-K approach: rank the target species among all per-species
    # probabilities in each row; it counts as predicted present ("1")
    # when it falls within the top k
    rank <- apply(preds_overall, 1, function(x) {
      x_sort <- sort(x, decreasing = TRUE)
      which(names(x_sort) == target_species)
    })
    top_k <- as.character(as.numeric(rank <= k))
    preds <- factor(top_k, levels = c("0", "1"), labels = c("A", "P"))
  } else {
    stop("Unsupported model type: ", paste(class(model), collapse = "/"))
  }
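  
  # Illustration of the top-K rule (hypothetical values): with k = 2 and row
  # probabilities c(sp1 = 0.5, sp2 = 0.3, sp3 = 0.2), target species "sp2"
  # ranks 2nd, so rank <= k and the row is scored as predicted present.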
  
  actual <- factor(test_data$present, levels = c("0", "1"), labels = c("A", "P"))
  
  # Calculate AUC; levels = c(controls, cases), so "A" are controls and "P" are
  # cases, with higher predicted probability indicating presence
  auc <- pROC::roc(actual, probs, levels = c("A", "P"), direction = "<")$auc
  
  # Calculate confusion matrix
  cm <- caret::confusionMatrix(preds, actual, positive = "P")
  
  # Return metrics
  return(
    list(
      AUC = as.numeric(auc),
      Accuracy = cm$overall["Accuracy"],
      Kappa = cm$overall["Kappa"],
      Precision = cm$byClass["Precision"],
      Recall = cm$byClass["Recall"],
      F1 = cm$byClass["F1"]
    )
  )
}
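
# Usage sketch (hypothetical objects: `dnn_fit` is a cito model whose output
# columns are named by species, and `test_df` holds rows for a single species
# with a `species` column, a 0/1 `present` column, and the predictors):
#   metrics <- evaluate_multiclass_model(dnn_fit, test_df, k = 5)
#   metrics$AUC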