#Functions to define correlations between Beals, SYNCSA, ENV
# W - Sp. composition
# X - fuzzy weighted
# Y - Beals
# B - traits
# E - Environment


require(tidyverse)
require(SYNCSA)
require(vegan)
require(abind)
require(ade4)
require(parallel)
require(doParallel)

#### Function 1 - CorXY ####
# Correlation between the Beals-smoothed composition (Y) and the
# fuzzy-weighted composition (X) obtained from a selection of traits.
#
# Arguments:
#   comm      - community table; columns are species (columns whose sum is 0
#               are treated as missing species and removed).
#   traits    - trait table; rows must be in the same order as the columns of
#               `comm`, because empty species are removed by position.
#   trait.sel - column indices of `traits` to use, or "all" for every column.
#   stat      - any subset of "mantel", "RV", "procrustes".
#
# Returns a data.frame with columns Trait.comb, Test, Coef and pvalue, one row
# per requested statistic (NULL if no known statistic was requested).
get.corXY <- function(comm, traits, trait.sel="all", stat=c("mantel", "RV", "procrustes")){
  if(identical(trait.sel, "all")) {trait.sel <- seq_len(ncol(traits))}
  ii <- trait.sel
  lab.tmp <- paste(ii, collapse="_")
  ### delete potential missing species
  empty <- which(colSums(comm)==0)
  if(length(empty)>0){
    # drop=FALSE keeps both objects two-dimensional even when a single trait
    # column or a single species is left; otherwise the subsetting below fails
    traits <- traits[-empty, , drop=FALSE]
    comm <- comm[, -empty, drop=FALSE]
  }
  syn.out.tmp <- matrix.x(comm=comm, traits=traits[, ii, drop=FALSE], scale=TRUE)$matrix.X
  W.beals <- as.data.frame(beals(comm, include=TRUE, type=2))
  corXY <- NULL
  if("mantel" %in% stat){
    W.beals.d <- dist(W.beals)
    mantel.tmp <- mantel(W.beals.d, dist(syn.out.tmp))
    corXY <- rbind(corXY,
                   data.frame(Trait.comb=lab.tmp, Test="Mantel", Coef=mantel.tmp$statistic, pvalue=mantel.tmp$signif))
  }
  if("RV" %in% stat){
    RV.tmp <- RV.rtest(W.beals, as.data.frame(syn.out.tmp))
    corXY <- rbind(corXY,
                   data.frame(Trait.comb=lab.tmp, Test="RV", Coef=RV.tmp$obs, pvalue=RV.tmp$pvalue))
  }
  if("procrustes" %in% stat){
    prot.tmp <- protest(W.beals, syn.out.tmp)
    corXY <- rbind(corXY,
                   data.frame(Trait.comb=lab.tmp, Test="Procrustes", Coef=prot.tmp$t0, pvalue=prot.tmp$signif))
  }
  return(corXY)
}



#### Function 2 - CorTE ####
# Correlation between community-level trait values (matrix T, as returned by
# matrix.t) and environmental variables (E).
#
# Arguments:
#   comm      - community table; columns are species (columns whose sum is 0
#               are treated as missing species and removed).
#   traits    - trait table; rows must be in the same order as the columns of
#               `comm`, because empty species are removed by position.
#   envir     - environmental table with one row per row of `comm`.
#   trait.sel - column indices of matrix T to use, or "all".
#   env.sel   - column indices of `envir` to use, or "all".
#   stat      - any subset of "mantel", "RV", "procrustes". Note that the
#               default here does not include "procrustes".
#
# Returns a data.frame with columns Trait.comb, Env.comb, Test, Coef and
# pvalue, one row per requested statistic (NULL if none was requested).
get.corTE <- function(comm, traits, envir,  trait.sel="all", env.sel="all", stat=c("mantel", "RV")){
  if(identical(trait.sel, "all")) {trait.sel <- seq_len(ncol(traits))}
  if(identical(env.sel, "all")) {env.sel <- seq_len(ncol(envir))}
  ii <- trait.sel
  lab.tmp <- paste(ii, collapse="_")
  ## delete potential missing species
  empty <- which(colSums(comm)==0)
  if(length(empty)>0){
    # drop=FALSE keeps both objects two-dimensional even when a single trait
    # column or a single species is left
    traits <- traits[-empty, , drop=FALSE]
    comm <- comm[, -empty, drop=FALSE]
  }
  # matrix T is computed on all traits; the selection is applied afterwards
  syn.out.tmp <- matrix.t(comm=comm, traits=traits, scale=TRUE)$matrix.T
  ee <- env.sel
  lab.env <- paste(ee, collapse="_")
  env.tmp <- envir[, ee, drop=FALSE]
  trait.tmp <- syn.out.tmp[, ii, drop=FALSE]
  corTE <- NULL
  if("mantel" %in% stat){
    mantel.tmp <- mantel(dist(env.tmp), dist(trait.tmp))
    corTE <- rbind(corTE,
                   data.frame(Trait.comb=lab.tmp, Env.comb=lab.env, Test="Mantel", Coef=mantel.tmp$statistic, pvalue=mantel.tmp$signif))
  }
  if("RV" %in% stat){
    RV.tmp <- RV.rtest(as.data.frame(env.tmp), as.data.frame(trait.tmp))
    corTE <- rbind(corTE,
                   data.frame(Trait.comb=lab.tmp, Env.comb=lab.env, Test="RV", Coef=RV.tmp$obs, pvalue=RV.tmp$pvalue))
  }
  if("procrustes" %in% stat){
    prot.tmp <- protest(env.tmp, trait.tmp)
    corTE <- rbind(corTE,
                   data.frame(Trait.comb=lab.tmp, Env.comb=lab.env, Test="Procrustes", Coef=prot.tmp$t0, pvalue=prot.tmp$signif))
  }
  return(corTE)
}


#### Function 3 - CorXE ####
# Correlation between the fuzzy-weighted composition (X) obtained from a
# selection of traits and environmental variables (E).
#
# Arguments:
#   comm      - community table; columns are species (columns whose sum is 0
#               are treated as missing species and removed).
#   traits    - trait table; rows must be in the same order as the columns of
#               `comm`, because empty species are removed by position.
#   envir     - environmental table with one row per row of `comm`.
#   trait.sel - column indices of `traits` to use, or "all".
#   env.sel   - column indices of `envir` to use, or "all".
#   stat      - any subset of "mantel", "RV", "procrustes".
#
# Returns a data.frame with columns Trait.comb, Env.comb, Test, Coef and
# pvalue, one row per requested statistic (NULL if none was requested).
get.corXE <- function(comm, traits, envir, trait.sel="all", env.sel="all", stat=c("mantel", "RV", "procrustes")){
  if(identical(trait.sel, "all")) {trait.sel <- seq_len(ncol(traits))}
  if(identical(env.sel, "all")) {env.sel <- seq_len(ncol(envir))}
  ii <- trait.sel
  lab.tmp <- paste(ii, collapse="_")
  ### delete potential missing species
  empty <- which(colSums(comm)==0)
  if(length(empty)>0){
    # drop=FALSE keeps both objects two-dimensional even when a single trait
    # column or a single species is left; otherwise the subsetting below fails
    traits <- traits[-empty, , drop=FALSE]
    comm <- comm[, -empty, drop=FALSE]
  }
  syn.out.tmp <- matrix.x(comm=comm, traits=traits[, ii, drop=FALSE], scale=TRUE)$matrix.X
  ee <- env.sel
  lab.env <- paste(ee, collapse="_")
  env.tmp <- envir[, ee, drop=FALSE]
  corXE <- NULL
  if("mantel" %in% stat){
    mantel.tmp <- mantel(dist(env.tmp), dist(syn.out.tmp))
    corXE <- rbind(corXE,
                   data.frame(Trait.comb=lab.tmp, Env.comb=lab.env, Test="Mantel", Coef=mantel.tmp$statistic, pvalue=mantel.tmp$signif))
  }
  if("RV" %in% stat){
    RV.tmp <- RV.rtest(as.data.frame(env.tmp), as.data.frame(syn.out.tmp))
    corXE <- rbind(corXE,
                   data.frame(Trait.comb=lab.tmp, Env.comb=lab.env, Test="RV", Coef=RV.tmp$obs, pvalue=RV.tmp$pvalue))
  }
  if("procrustes" %in% stat){
    prot.tmp <- protest(env.tmp, syn.out.tmp)
    corXE <- rbind(corXE,
                   data.frame(Trait.comb=lab.tmp, Env.comb=lab.env, Test="Procrustes", Coef=prot.tmp$t0, pvalue=prot.tmp$signif))
  }
  return(corXE)
}



### Get SES (both parametric and non parametric)
#obs.df = output from one of the get functions above
#perm.df = matrix (trait combinations x permutations) holding the chosen correlation statistic computed on permuted data

# Standardized effect size of the observed statistic against its permutation
# distribution, in a parametric (mean / sd) and a non-parametric (quantile
# based) version.
#
# Arguments:
#   obs.df  - output of one of the get.cor* functions (needs the columns
#             Trait.comb, Test, Coef and pvalue).
#   perm.df - trait combinations (rows, named "t" + Trait.comb) by
#             permutations (columns) with the statistic on permuted data.
#   stat    - value of obs.df$Test to keep.
#
# Returns one row per trait combination, sorted by decreasing SES.np.
get.SES <- function(obs.df, perm.df, stat="RV") {
  nperm <- ncol(perm.df)

  # observed coefficient per trait combination, keyed like the rows of perm.df
  obs.tab <- obs.df %>%
    filter(Test==stat) %>%
    dplyr::select(-pvalue, -Test) %>%
    mutate(Trait.comb=paste0("t", Trait.comb)) %>%
    rename(obs=Coef)

  # long format: one row per trait combination and permutation
  perm.long <- as.data.frame.table(perm.df) %>%
    rename(Trait.comb=Var1, perm=Var2, coef=Freq) %>%
    group_by(Trait.comb) %>%
    left_join(obs.tab, by=c("Trait.comb"))

  # summary of the permutation distribution of each trait combination
  perm.summary <- perm.long %>%
    summarize(q15 = quantile(coef, probs = 0.15865),
              q50 = quantile(coef, .5),
              q84 = quantile(coef, 0.84135),
              mean.perm=mean(coef),
              sd.perm=sd(coef),
              greater.than.obs=sum( (abs(coef)>=abs(obs))/n()))

  # the observed value is lost by summarize(), so it is joined back in
  SES.out <- perm.summary %>%
    left_join(obs.tab, by=c("Trait.comb")) %>%
    mutate(SES.np=ifelse(obs>=q50, (obs-q50)/(q84-q50), (obs-q50)/(q50-q15)),
           SES=(obs-mean.perm)/sd.perm,
           conf.m=obs-1.65*(sd.perm/sqrt(nperm)),
           conf.p=obs+1.65*(sd.perm/sqrt(nperm))) %>%
    arrange(desc(SES.np))
  return(SES.out)
}




# Driver: for every combination of traits (up to `max.inter.t` traits at a
# time) compute the RV statistic on the observed data, on bootstrap resamples
# of the plots, and on the same data with the trait rows permuted.
#
# Arguments:
#   species.path - tab-delimited file with the species table (plots in rows).
#   traits.path  - tab-delimited file with the trait table; needs a column
#                  "species0" (used as row names) and the columns
#                  LeafArea:Disp.unit.leng.
#   output       - path prefix of the .RData file that is written.
#   myfunction   - name of the correlation function; it is called with the
#                  arguments comm, traits, trait.sel and stat="RV".
#   max.inter.t  - maximum number of traits per combination.
#   chunkn       - number of chunks the combinations are divided into.
#   chunk.i      - index of the chunk to run (NA = run everything).
#   nperm        - number of columns of the output matrices; column 1 is
#                  computed on the observed plots, columns 2..nperm on
#                  bootstrap resamples.
#
# Called for its side effect: saves cor.obs, cor.bootstr and cor.perm to
# "<output>_<chunk.i>.RData".
Mesobromion <- function(species.path, traits.path, output, myfunction="get.corXY", max.inter.t, chunkn, chunk.i, nperm=99){

  myfunction <- get(myfunction)
  ## calculate corr between species composition matrix and traits
  species <- read_delim(species.path, delim="\t") %>%
    as.data.frame()
  traits <- read_delim(traits.path, delim="\t") %>%
    as.data.frame()

  traits <- traits %>%
    column_to_rownames("species0") %>%
    rename_all(.funs=~gsub(pattern=".mean$", replacement="", x=.))  %>%
    # temporary: use only a subset of traits
    dplyr::select(LeafArea:Disp.unit.leng)

  ## create list of indices for each combination of traits up to a max number of interactions
  n.traits <- ncol(traits)
  allcomb.t <- unlist(lapply(seq_len(max.inter.t),
                             function(n.inter) combn(seq_len(n.traits), n.inter, simplify=FALSE)),
                      recursive=FALSE)
  nall <- length(allcomb.t)

  ## Divide in chunks if requested
  if(chunkn>1 && !is.na(chunk.i)){
    print(paste("divide in", chunkn, "chunks and run on chunk n.", chunk.i))
    indices <- seq_along(allcomb.t)
    chunks <- split(indices, sort(indices%%chunkn))
    allcomb.t <- allcomb.t[chunks[[chunk.i]]]
  }
  print(paste("Running on", length(allcomb.t),"out of", nall, "combinations"))
  names.t <- unlist(lapply(allcomb.t, paste, collapse="_"))

  print("Start main loop")
  # NOTE: cor.obs starts as NA, so the saved table begins with an all-NA row;
  # kept unchanged so that existing outputs and downstream code stay compatible
  cor.obs <- NA
  out.dimnames <- list(paste0("t", names.t), paste0("p", seq_len(nperm)))
  cor.bootstr <- matrix(NA, nrow=length(allcomb.t), ncol=nperm, dimnames=out.dimnames)
  cor.perm <- matrix(NA, nrow=length(allcomb.t), ncol=nperm, dimnames=out.dimnames)

  for(i in seq_along(allcomb.t)) {
    tt <- unlist(allcomb.t[i])
    for(b in seq_len(nperm)){
      if(b==1){
        # first column = observed data; reset for EVERY trait combination,
        # otherwise the last bootstrap sample of the previous combination
        # would be reported as the observed value
        speciesb <- species
        traitsb <- traits
      } else {
        #resample plots
        speciesb <- species[sample(seq_len(nrow(species)), replace=TRUE), , drop=FALSE]
        # delete empty species; a logical mask is used because
        # x[, -which(cond)] drops every column when no column matches
        speciesb <- speciesb[, colSums(speciesb)!=0, drop=FALSE]
        #subset traits
        traitsb <- traits[rownames(traits) %in% colnames(speciesb), , drop=FALSE]
      }
      #permute traits
      traitsb.perm <- traitsb[sample(seq_len(nrow(traitsb))), , drop=FALSE]
      tmp <- myfunction(comm=speciesb, traits=traitsb, trait.sel=tt, stat="RV")
      if(b==1) {cor.obs <- rbind(cor.obs, tmp)}
      cor.bootstr[i,b] <- tmp$Coef
      cor.perm[i,b] <- myfunction(comm=speciesb, traits=traitsb.perm, trait.sel=tt, stat="RV")$Coef
      print(paste("trait", tt, "perm", b))
    }
    print(i)
  }

  save(cor.obs, cor.bootstr, cor.perm, file = paste(output, "_", chunk.i, ".RData", sep=""))
}