diff --git a/Scripts/logit/chr_vol_treat.R b/Scripts/logit/chr_vol_treat.R index 67ddbf8e9ad8ff22c5b59db968ddd700ec0fc74a..399ec5456d29a9dbfbf3254420db5923e7ad3ccf 100644 --- a/Scripts/logit/chr_vol_treat.R +++ b/Scripts/logit/chr_vol_treat.R @@ -94,6 +94,7 @@ data <- data %>% # Split the data into labeled and unlabeled sets labeled_data <- filter(data, Choice_Treat==1| Choice_Treat==0) unlabeled_data <- filter(data, is.na(Choice_Treat)) +labeled_data_id<-labeled_data labeled_data<-select(labeled_data,-id) # Assuming the group information is in the column called 'Group' labeled_data$Choice_Treat<- as.factor(labeled_data$Choice_Treat) @@ -140,6 +141,10 @@ labeled_data$PredictedGroup <- labeled_predictions table(labeled_data$Choice_Treat, labeled_data$PredictedGroup) unlabeled_predictions <- predict(model3, newdata = unlabeled_data) +labeled_data_id$PredictedGroup <- labeled_predictions +data_prediction_labeled<-select(labeled_data_id, c("id", "PredictedGroup")) +saveRDS(data_prediction_labeled, "Data/predictions_labeled.RDS") + unlabeled_data$PredictedGroup <- unlabeled_predictions data_prediction<-select(unlabeled_data, c("id", "PredictedGroup")) saveRDS(data_prediction, "Data/predictions.RDS")