# Expand a bounding box by a relative margin, with an absolute minimum.
#
# Each axis is padded on both sides by `expansion` times its current extent.
# If the total padding this would add to an axis (2 * extent * expansion) is
# smaller than `min_span`, the axis is padded by `min_span / 2` per side
# instead, so every axis grows by at least `min_span` in total. This also gives
# degenerate boxes (e.g. a single point, extent 0) a non-zero size.
#
# Args:
#   bbox: numeric vector with elements named "xmin", "ymin", "xmax", "ymax"
#     (e.g. an sf `bbox`).
#   min_span: minimum total padding added along each axis, in bbox units.
#   expansion: padding per side as a fraction of the axis extent.
#
# Returns:
#   `bbox` with its four limits moved outward; names and other attributes
#   are kept.
expand_bbox <- function(bbox, min_span = 1, expansion = 0.25) {
  # Per-side padding for one axis. Comparing the proportional padding against
  # min_span / 2 (the fallback amount) rather than min_span means the fallback
  # can only ever enlarge the padding, never shrink it.
  pad_for <- function(extent) {
    max(extent * expansion, min_span / 2)
  }

  x_pad <- pad_for(bbox[["xmax"]] - bbox[["xmin"]])
  y_pad <- pad_for(bbox[["ymax"]] - bbox[["ymin"]])

  bbox["xmin"] <- bbox[["xmin"]] - x_pad
  bbox["xmax"] <- bbox[["xmax"]] + x_pad
  bbox["ymin"] <- bbox[["ymin"]] - y_pad
  bbox["ymax"] <- bbox[["ymax"]] + y_pad

  bbox
}

# Predict presence for new data from a cito or caret-style classifier.
#
# Args:
#   model: a fitted model. Objects inheriting from "citodnn" or
#     "citodnnBootstrap" are treated as cito networks whose first response
#     column is the probability of presence. Any other model is assumed to be
#     a caret-style classifier with classes "A" (absent) and "P" (present).
#   data: new data passed on to `predict()`.
#   type: "prob" for the probability of presence, "class" for an A/P factor.
#
# Returns:
#   A numeric vector of presence probabilities (type = "prob"), or a factor of
#   predicted classes (type = "class").
predict_new <- function(model, data, type = "prob") {
  stopifnot(length(type) == 1, type %in% c("prob", "class"))

  # inherits() rather than class(model) %in% ...: models often carry several
  # classes (caret's train is c("train", "train.formula")), and a length > 1
  # condition in `if` is an error since R 4.2.
  if (inherits(model, c("citodnn", "citodnnBootstrap"))) {
    probs <- predict(model, data, type = "response")[, 1]
    if (type == "prob") {
      return(probs)
    }
    # round() maps p > 0.5 to presence; exactly 0.5 rounds to 0 (absent).
    return(factor(round(probs), levels = c("0", "1"), labels = c("A", "P")))
  }

  # Only run the prediction that was requested: asking for probabilities when
  # classes are wanted is wasted work and fails for models trained without
  # class probabilities.
  if (type == "prob") {
    predict(model, data, type = "prob")[["P"]]
  } else {
    predict(model, data, type = "raw")
  }
}

# Evaluate a binary presence/absence classifier on labelled data.
#
# Metrics (with "P", present, as the positive class):
#   Accuracy  = (TP + TN) / (TP + TN + FP + FN): share of all instances
#               predicted correctly.
#   Precision = TP / (TP + FP): share of predicted presences that are present.
#   Recall    = TP / (TP + FN): share of actual presences that were predicted
#               (sensitivity).
#   F1        = 2 * (Precision * Recall) / (Precision + Recall): harmonic mean
#               of precision and recall.
#
# Args:
#   model: a cito model (inheriting from "citodnn" / "citodnnBootstrap") or a
#     caret-style classifier with classes "A" and "P".
#   data: data frame with the model's predictors and a `present` column coded
#     0 (absent) / 1 (present); other values become NA.
#
# Returns:
#   A list with auc, accuracy, kappa, precision, recall, f1 and the
#   confusion-matrix counts tp, fp, tn, fn.
evaluate_model <- function(model, data) {
  # inherits() rather than class(model) %in% ...: models often carry several
  # classes (caret's train is c("train", "train.formula")), and a length > 1
  # condition in `if` is an error since R 4.2.
  if (inherits(model, c("citodnn", "citodnnBootstrap"))) {
    probs <- predict(model, data, type = "response")[, 1]
    preds <- factor(round(probs), levels = c("0", "1"), labels = c("A", "P"))
  } else {
    probs <- predict(model, data, type = "prob")[["P"]]
    preds <- predict(model, data, type = "raw")
  }

  actual <- factor(data[["present"]], levels = c("0", "1"), labels = c("A", "P"))

  # pROC takes levels = c(controls, cases): here "P" are the controls and
  # direction = ">" declares that controls score higher, i.e. a higher
  # presence probability indicates presence.
  auc <- pROC::roc(actual, probs, levels = c("P", "A"), direction = ">")$auc

  # cm$table rows are predictions, columns the reference (observed) classes.
  cm <- caret::confusionMatrix(preds, actual, positive = "P")

  list(
    auc = as.numeric(auc),
    accuracy = cm$overall["Accuracy"],
    kappa = cm$overall["Kappa"],
    precision = cm$byClass["Precision"],
    recall = cm$byClass["Recall"],
    f1 = cm$byClass["F1"],
    tp = cm$table["P", "P"],
    fp = cm$table["P", "A"],
    tn = cm$table["A", "A"],
    fn = cm$table["A", "P"]
  )
}