# combined_cnn_mlp.R ----------------------------------------------------

# 1) Load libraries -----------------------------------------------------
library(keras)
library(reticulate)
library(fst)
library(dplyr)
library(pROC)
library(ggplot2)

# 2) Path & timestamp ---------------------------------------------------
# One timestamp is shared by every artefact this run writes, so the model,
# its history and the data snapshot can be matched up afterwards.
ts <- format(Sys.time(), "%Y%m%d_%H%M%S")
finetuned_model_path <- paste0("cnn_mlp_trained_", ts, ".h5")
history_path <- paste0("cnn_mlp_history_", ts, ".rds")
data_info_path <- paste0("cnn_mlp_datainfo_", ts, ".rds")

# 3) Directories & sample sizes ----------------------------------------
# Sample sizes double as sub-folder names below each root directory
# (<root>/<size>/<file>.fst).
sample_sizes <- c(50, 100, 200, 400)
ind_dir <- file.path("..", "0_Training", "independent_samples")
dep_dir <- file.path("..", "0_Training", "dependent_samples")

# 4) Helper to load a set of .fst files into one data.frame ------------
# Read one .fst file per sample size and stack them into a single table.
#
# base_dir: root folder that contains one sub-folder per sample size.
# sizes:    sample sizes; each is used verbatim as a sub-folder name.
# filename: name of the .fst file inside every size sub-folder.
# Returns the row-bound result (dplyr::bind_rows), in the order of `sizes`.
load_block_df <- function(base_dir, sizes, filename) {
  # Layout is <base_dir>/<size>/<filename>; file.path() vectorises over sizes.
  paths <- file.path(base_dir, as.character(sizes), filename)
  blocks <- vector("list", length(paths))
  for (i in seq_along(paths)) {
    blocks[[i]] <- as.data.frame(read_fst(paths[[i]]))
  }
  bind_rows(blocks)
}

# 5) Load image & score tables and labels ------------------------------
#   assumes every <size>/ sub-folder of ind_dir holds ind_image.fst and ind.fst,
#   and every <size>/ sub-folder of dep_dir holds dep_image.fst and dep.fst.

# Image tables for the independent and dependent groups
img_ind <- load_block_df(base_dir = ind_dir, sizes = sample_sizes,
                         filename = "ind_image.fst")
img_dep <- load_block_df(base_dir = dep_dir, sizes = sample_sizes,
                         filename = "dep_image.fst")
# Score tables; expected to be row-aligned with the image tables above
scores_ind <- load_block_df(base_dir = ind_dir, sizes = sample_sizes,
                            filename = "ind.fst")
scores_dep <- load_block_df(base_dir = dep_dir, sizes = sample_sizes,
                            filename = "dep.fst")

# 6) Combine, shuffle, and extract matrices + labels -------------------
# Stack independent rows first, then dependent rows; labels follow that order
# (0 = independent, 1 = dependent).
df_img    <- bind_rows(img_ind,    img_dep)
df_scores <- bind_rows(scores_ind, scores_dep)
y_all     <- c(rep(0L, nrow(img_ind)), rep(1L, nrow(img_dep)))

# Images and scores are paired purely by row position. A row-count mismatch
# would otherwise silently NA-pad or misalign the score rows after shuffling.
# The image width check gives a clear error before the 25 x 25 reshape.
stopifnot(
  nrow(scores_ind) == nrow(img_ind),
  nrow(scores_dep) == nrow(img_dep),
  ncol(df_img) == 25L * 25L
)

# Shuffle images, scores and labels with one shared permutation
set.seed(123)
perm      <- sample.int(nrow(df_img))
df_img    <- df_img[perm, , drop = FALSE]
df_scores <- df_scores[perm, , drop = FALSE]
y_all     <- y_all[perm]

# array_reshape() in C order goes through NumPy and therefore initialises
# Python. Bind the intended virtualenv first, so the interpreter is not fixed
# to whatever reticulate would pick by default.
use_virtualenv("r-reticulate", required = TRUE)

# Convert to arrays: each 625-value image row becomes a 25 x 25 x 1 tensor
# (row-major fill, the array_reshape() default).
N         <- nrow(df_img)
x_img     <- array_reshape(as.matrix(df_img), c(N, 25L, 25L, 1L))
x_scores  <- as.matrix(df_scores)  # expected 20 columns; score_input uses ncol()

# 7) To NumPy arrays ----------------------------------------------------
# Select the Python environment that provides NumPy / TensorFlow.
use_virtualenv("r-reticulate", required = TRUE)
np <- import("numpy")

# Cast every model input and the target vector to float32 NumPy arrays.
x_img_py <- np$array(x_img,            dtype = "float32")
x_scr_py <- np$array(x_scores,         dtype = "float32")
y_py     <- np$array(as.double(y_all), dtype = "float32")

# 8) Metrics helper -----------------------------------------------------
# Binary-classification metrics from predicted probabilities and hard labels.
#
# y_true:  ground-truth labels coded 0/1.
# y_proba: predicted probability of class 1 (used for the AUC).
# y_pred:  hard 0/1 predictions (used for the confusion-matrix metrics).
# Returns a named list with AUC, Accuracy, Precision, Recall and F1. A ratio
# whose denominator is zero is NA, and so is the AUC when y_true contains
# fewer than two classes.
compute_metrics <- function(y_true, y_proba, y_pred) {
  # Fix the ROC orientation: 0 = controls, 1 = cases, and a higher probability
  # means "more likely class 1". pROC's default direction = "auto" picks the
  # orientation from the data. That can flip the curve and report AUC >= 0.5
  # for a model that actually ranks worse than chance.
  auc_v <- if (length(unique(y_true)) < 2L) {
    NA_real_
  } else {
    roc_obj <- roc(y_true, y_proba, levels = c(0, 1), direction = "<",
                   quiet = TRUE)
    as.numeric(auc(roc_obj))
  }

  tp <- sum(y_true == 1 & y_pred == 1)
  tn <- sum(y_true == 0 & y_pred == 0)
  fp <- sum(y_true == 0 & y_pred == 1)
  fn <- sum(y_true == 1 & y_pred == 0)

  acc  <- (tp + tn) / length(y_true)
  prec <- if (tp + fp > 0) tp / (tp + fp) else NA_real_
  rec  <- if (tp + fn > 0) tp / (tp + fn) else NA_real_
  f1   <- if (!is.na(prec) && !is.na(rec) && (prec + rec) > 0) {
    2 * prec * rec / (prec + rec)
  } else {
    NA_real_
  }

  list(AUC = auc_v, Accuracy = acc, Precision = prec, Recall = rec, F1 = f1)
}

# 9) Build the two-branch model from scratch ----------------------------
# Two inputs: a 25 x 25 single-channel image and a flat score-feature vector
# whose width is taken from x_scores.
image_input <- layer_input(
  shape = c(25, 25, 1), dtype = "float32", name = "image_input"
)
score_input <- layer_input(
  shape = ncol(x_scores), dtype = "float32", name = "score_input"
)

# Image branch: two 3x3 same-padded conv layers, 3x3 max-pool, dropout, flatten
x_cnn <- image_input %>%
  layer_conv_2d(
    filters = 32, kernel_size = 3, activation = "relu",
    padding = "same", name = "conv2d_2"
  ) %>%
  layer_conv_2d(
    filters = 32, kernel_size = 3, activation = "relu",
    padding = "same", name = "conv2d_3"
  ) %>%
  layer_max_pooling_2d(pool_size = 3, name = "max_pooling2d") %>%
  layer_dropout(rate = 0.2, name = "dropout") %>%
  layer_flatten(name = "flatten")

# Score branch: a single 32-unit ReLU layer
x_mlp <- layer_dense(
  score_input, units = 32, activation = "relu", name = "score_dense_32"
)

# Fuse both branches, then a 256 -> 128 -> 32 ReLU head and a sigmoid output
x <- layer_concatenate(list(x_cnn, x_mlp), name = "concat")
x <- x %>%
  layer_dense(units = 256, activation = "relu",    name = "dense_256") %>%
  layer_dense(units = 128, activation = "relu",    name = "dense_128") %>%
  layer_dense(units = 32,  activation = "relu",    name = "dense_32") %>%
  layer_dense(units = 1,   activation = "sigmoid", name = "output")

model_image <- keras_model(
  inputs = list(image_input, score_input),
  outputs = x,
  name = "cnn_plus_score"
)

# 10) Compile ------------------------------------------------------------
# Binary cross-entropy on the single sigmoid output, Adam with explicit
# hyper-parameters, and accuracy tracked alongside the loss.
# NOTE(review): run_eagerly = TRUE disables graph compilation and usually
# slows training a lot; it looks like a debugging leftover, so confirm
# before turning it off.
model_image$compile(
  loss        = "binary_crossentropy",
  metrics     = list("accuracy"),
  optimizer   = optimizer_adam(
    learning_rate = 1e-3, beta_1 = 0.9, beta_2 = 0.999,
    epsilon = 1e-7, amsgrad = FALSE
  ),
  run_eagerly = TRUE
)

model_image$summary()

# 11) Callbacks ----------------------------------------------------------
callbacks_module <- import("tensorflow.keras.callbacks")

# Both callbacks watch val_loss. The list must stay unnamed so reticulate
# passes it to Python as a list, not a dict.
py_cbs <- list(
  # Stop after 3 epochs without improvement and keep the best weights seen
  callbacks_module$EarlyStopping(
    restore_best_weights = TRUE, patience = 3L, monitor = "val_loss"
  ),
  # Halve the learning rate after 2 stagnant epochs, never going below 1e-6
  callbacks_module$ReduceLROnPlateau(
    min_lr = 1e-6, factor = 0.5, patience = 2L, monitor = "val_loss"
  )
)

# 12) Fit ---------------------------------------------------------------
# Keras holds out the last 20% of rows (before its own per-epoch shuffling)
# as validation data. The rows were shuffled up front so that slice is random.
history <- model_image$fit(
  x = list(x_img_py, x_scr_py), y = y_py,
  epochs = 50L, batch_size = 128L,
  validation_split = 0.2,
  callbacks = py_cbs
)

# 13) Plot training history ---------------------------------------------
# Bring the per-epoch metrics into R and plot training vs validation curves.
hist_list <- py_to_r(history$history)
df_hist   <- data.frame(
  epoch        = seq_along(hist_list$loss),
  loss         = hist_list$loss,
  val_loss     = hist_list$val_loss,
  accuracy     = hist_list$accuracy,
  val_accuracy = hist_list$val_accuracy
)

# Map a constant split label to linetype so each plot gets a legend that
# tells the training (solid) and validation (dashed) curves apart.
split_scale <- scale_linetype_manual(
  name = "split", values = c(train = "solid", validation = "dashed")
)

p1 <- ggplot(df_hist, aes(x = epoch)) +
  geom_line(aes(y = loss, linetype = "train")) +
  geom_line(aes(y = val_loss, linetype = "validation")) +
  split_scale +
  labs(title = "Loss", y = "loss")
p2 <- ggplot(df_hist, aes(x = epoch)) +
  geom_line(aes(y = accuracy, linetype = "train")) +
  geom_line(aes(y = val_accuracy, linetype = "validation")) +
  split_scale +
  labs(title = "Accuracy", y = "accuracy")
print(p1)
print(p2)

# 14) Save model, history, data info ------------------------------------
save_model_hdf5(model_image, finetuned_model_path)

# `history` is a reference to a live Python object. saveRDS() would store
# only a pointer that is dead in any new R session, so persist the R-side
# per-epoch metrics list (converted with py_to_r() above) instead.
saveRDS(hist_list, history_path)

# Snapshot of the shuffled inputs and labels used for this run
saveRDS(list(
  images       = df_img,
  scores       = df_scores,
  y            = y_all,
  sample_sizes = sample_sizes
), data_info_path)

# 15) Final evaluation --------------------------------------------------
# Predict on every row with a 0.5 threshold. The full-data metrics are
# optimistic because most of these rows were used for training.
preds <- as.numeric(model_image$predict(list(x_img_py, x_scr_py)))
yhat  <- ifelse(preds > 0.5, 1, 0)
met   <- compute_metrics(y_all, preds, yhat)
cat("\nFinal metrics on full data:\n")
print(unlist(met))

# Also report metrics on the rows fit() held out via validation_split = 0.2
# (keep this value in sync with the fit call). Keras trains on the first
# floor(N * (1 - split)) rows and validates on the rest. Those rows still
# steered EarlyStopping / ReduceLROnPlateau, so they are not a fully
# independent test set.
n_fit   <- floor(N * (1 - 0.2))
val_idx <- which(seq_len(N) > n_fit)
if (length(val_idx) > 0L) {
  met_val <- compute_metrics(y_all[val_idx], preds[val_idx], yhat[val_idx])
  cat("\nMetrics on held-out validation rows:\n")
  print(unlist(met_val))
}

# eof -------------------------------------------------------------------
