# 1) Load libraries -----------------------------------------------------
library(keras)
library(tensorflow)
library(fst)
library(data.table)
library(progress)

# 2) Sample sizes -------------------------------------------------------
# One threshold per model is computed for each of these sizes. The main loop
# also interpolates each value into the input path threshold/<size>/.
sample_sizes <- c(
  30, 50, 100,
  200, 300, 400
)

# 3) Load models (no metrics warnings) ----------------------------------
# The models are only used for prediction in this script, so they are loaded
# without compiling (compile = FALSE).
load_uncompiled <- function(path) {
  load_model_hdf5(path, compile = FALSE)
}

image_model    <- load_uncompiled("../1_Models/image.h5")
score_model    <- load_uncompiled("../1_Models/score.h5")
combined_model <- load_uncompiled("../1_Models/score_image.h5")

# 4) Prep result tables & progress bar ---------------------------------
# Three independent result tables, one row per sample size. The main loop
# adds a `thres` column to each of them by reference.
image_thresholds    <- data.table(sample_size = sample_sizes)
score_thresholds    <- data.table(sample_size = sample_sizes)
combined_thresholds <- data.table(sample_size = sample_sizes)

# One tick per model (image, score, combined) for every sample size.
n_ticks <- 3 * length(sample_sizes)

pb <- progress_bar$new(
  format = "  Computing thresholds [:bar] :percent eta: :eta",
  total  = n_ticks,
  stream = stdout()
)

# 5) Fixed column indices (from your data) -----------------------------
# Layout of the table the main loop builds with cbind(thr, n = n, thr_image):
#   columns 1 .. thr_cols        -> thres.fst
#   column  thr_cols + 1         -> n (the sample size)
#   columns start_img .. end_img -> flattened 25 x 25 x 1 image
thr_cols  <- 19                       # number of cols in thres.fst
pix_count <- 25 * 25 * 1              # = 625 pixels per image
start_img <- (thr_cols + 1) + 1       # = 21, first column after `n`
end_img   <- thr_cols + 1 + pix_count # = 645, last image column

# Width of the score model's input, read from the model itself
# (expected to be 20 = thr columns + n).
score_dim <- as.integer(score_model$input_shape[[2]])

# 6) Main loop ----------------------------------------------------------
# For every sample size: read that size's two fst files, slice the model
# inputs out of the combined table by fixed column position, run the three
# models, and store the 95% quantile of each model's predictions as the
# threshold for that sample size.

# 95% quantile of a model's predictions, flattened to a numeric vector.
threshold_q95 <- function(preds) {
  quantile(as.numeric(preds), 0.95)
}

for (n in sample_sizes) {
  thr       <- fst::read.fst(sprintf("threshold/%s/thres.fst",       n))
  thr_image <- fst::read.fst(sprintf("threshold/%s/thres_image.fst", n))

  # Every slice below is positional. If thres.fst does not have exactly
  # thr_cols columns, the `n` column and all image columns shift and the
  # models would silently be fed the wrong data.
  if (ncol(thr) != thr_cols) {
    stop(
      sprintf(
        "thres.fst for n = %s has %d columns, expected %d",
        n, ncol(thr), as.integer(thr_cols)
      ),
      call. = FALSE
    )
  }
  # cbind() on data frames can recycle rows when one row count divides the
  # other, so a mismatch between the two files must be caught explicitly.
  if (nrow(thr) != nrow(thr_image)) {
    stop(
      sprintf(
        "row mismatch for n = %s: thres.fst has %d rows, thres_image.fst has %d",
        n, nrow(thr), nrow(thr_image)
      ),
      call. = FALSE
    )
  }

  # Combined layout: thr columns | n | image pixel columns.
  d <- cbind(thr, n = n, thr_image)

  if (ncol(d) < max(end_img, score_dim)) {
    stop(
      sprintf(
        "combined table for n = %s has %d columns, need at least %d",
        n, ncol(d), as.integer(max(end_img, score_dim))
      ),
      call. = FALSE
    )
  }

  # 1) Image-only branch: columns start_img..end_img (21–645), reshaped to
  #    one 25 x 25 x 1 array per row.
  x_img <- array_reshape(
    as.matrix(d[, start_img:end_img, drop = FALSE]),
    c(nrow(d), 25L, 25L, 1L)
  )
  # verbose = 0L keeps Keras' own per-call progress output from interleaving
  # with the progress bar.
  image_thresholds[sample_size == n,
    thres := threshold_q95(image_model$predict(x_img, verbose = 0L))
  ]
  pb$tick()

  # 2) Score-only branch: the first score_dim columns (thr + n).
  x_score <- as.matrix(d[, seq_len(score_dim), drop = FALSE])
  score_thresholds[sample_size == n,
    thres := threshold_q95(score_model$predict(x_score, verbose = 0L))
  ]
  pb$tick()

  # 3) Combined branch:
  #    – use the SAME x_img,
  #    – but only the thr_cols thr columns (without n) for the MLP
  x_comb_score <- as.matrix(d[, seq_len(thr_cols), drop = FALSE])
  preds_comb   <- combined_model$predict(
    list(x_img, x_comb_score),
    verbose = 0L
  )
  combined_thresholds[sample_size == n,
    thres := threshold_q95(preds_comb)
  ]
  pb$tick()
}

# 7) Save outputs -------------------------------------------------------
# Each threshold table goes to its own RDS file under threshold/.
threshold_outputs <- list(
  "threshold/image_thresholds.RDS"    = image_thresholds,
  "threshold/score_thresholds.RDS"    = score_thresholds,
  "threshold/combined_thresholds.RDS" = combined_thresholds
)

for (out_path in names(threshold_outputs)) {
  saveRDS(threshold_outputs[[out_path]], out_path)
}
