# Load required libraries
library(keras)
library(reticulate)
library(fst)
library(dplyr)
library(pROC)    # for AUC
library(ggplot2)

# Run timestamp; embedded in every output filename so runs never overwrite
# each other.
ts <- format(Sys.time(), "%Y%m%d_%H%M%S")

# Model / output files ------------------------------------------------------
base_model_path      <- "score.h5"
finetuned_model_path <- paste0("score_finetuned_", ts, ".h5")
history_path         <- paste0("training_history_score_", ts, ".rds")
data_info_path       <- paste0("training_data_info_score_", ts, ".rds")

# 0) Sample-size levels and the per-class data directories -----------------
sample_sizes <- c(50, 100, 200, 400)
ind_dir <- file.path("..", "0_Training", "independent_samples")
dep_dir <- file.path("..", "0_Training", "dependent_samples")

# 1) Helper to load one class of samples
#
# Reads `<base_dir>/<size>/<filename>` for every entry of `sizes`, tags each
# table with its `sample_size` and the class `label` (stored in column `y`),
# and stacks the tables row-wise with dplyr::bind_rows().
load_block <- function(base_dir, sizes, filename, label) {
  read_one_size <- function(size) {
    fst_path <- file.path(base_dir, as.character(size), filename)
    block <- as.data.frame(read_fst(fst_path))
    block$sample_size <- size
    block$y <- label
    block
  }
  bind_rows(lapply(sizes, read_one_size))
}

# 2) Load independent (y = 0) and dependent (y = 1) samples
df_ind <- load_block(ind_dir, sample_sizes, "ind.fst", 0L)
df_dep <- load_block(dep_dir, sample_sizes, "dep.fst", 1L)

# 3) Combine, shuffle, and split into features/labels.
# The shuffle is seeded so runs are reproducible. This matters because Keras'
# validation_split holds out the *last* rows passed to fit(), so the row order
# decides which samples land in the validation set.
shuffle_seed <- 42L
set.seed(shuffle_seed)
df_all <- bind_rows(df_ind, df_dep) %>% slice_sample(prop = 1)

# Every column except the label is used as a feature.
# NOTE(review): this includes the `sample_size` column added by load_block();
# confirm that feeding it to the model is intended.
x_all <- as.matrix(df_all %>% select(-y))
y_all <- df_all$y

# NOTE(review): the original comment expected 20 features here -- confirm
# against the .fst files.
input_dim <- ncol(x_all)

# Binary-classification metrics for 0/1 labels.
#
# y_true  : observed labels (0/1).
# y_proba : predicted probability of class 1; used only for the ROC AUC.
# y_pred  : thresholded 0/1 predictions; used for the confusion-matrix metrics.
#
# Returns a named list (AUC, Accuracy, Precision, Recall, F1).
# - AUC is NA when y_true does not contain both classes.
# - Accuracy is NA for empty input.
# - Precision, Recall and F1 are NA when their denominators are zero.
compute_metrics <- function(y_true, y_proba, y_pred) {
  # Fix the control/case levels and the direction ("controls < cases") so the
  # AUC always measures "higher probability => class 1". pROC's default
  # direction = "auto" picks whichever orientation gives the larger AUC. That
  # would report a worse-than-chance model as better than chance and make the
  # base vs fine-tuned comparison meaningless. roc() also errors when only one
  # class is present, so that case yields NA instead.
  auc_value <- if (all(c(0, 1) %in% y_true)) {
    roc_obj <- pROC::roc(y_true, y_proba,
                         levels = c(0, 1), direction = "<", quiet = TRUE)
    as.numeric(pROC::auc(roc_obj))
  } else {
    NA_real_
  }

  tp <- sum(y_true == 1 & y_pred == 1)
  tn <- sum(y_true == 0 & y_pred == 0)
  fp <- sum(y_true == 0 & y_pred == 1)
  fn <- sum(y_true == 1 & y_pred == 0)

  acc  <- if (length(y_true) > 0) (tp + tn) / length(y_true) else NA_real_
  prec <- if (tp + fp > 0) tp / (tp + fp) else NA_real_
  rec  <- if (tp + fn > 0) tp / (tp + fn) else NA_real_
  f1   <- if (!is.na(prec) && !is.na(rec) && prec + rec > 0) {
    2 * prec * rec / (prec + rec)
  } else {
    NA_real_
  }

  list(AUC = auc_value, Accuracy = acc, Precision = prec, Recall = rec, F1 = f1)
}


# 4) Build or load the model -------------------------------------------------
#
# If a pretrained model exists at `base_model_path`, it is fine-tuned on the
# combined data and compared with its own pre-fine-tuning predictions.
# Otherwise a new network is trained from scratch. Both paths share the
# helpers below, so the optimizer, callbacks, plotting and saving are
# identical.

# Training settings shared by both paths.
fit_batch_size <- 128L
fit_epochs     <- 50L
fit_val_split  <- 0.2

# Fresh Python-side Adam optimizer with the settings used in both paths.
make_adam <- function() {
  optimizers <- import("tensorflow.keras.optimizers")
  optimizers$Adam(
    learning_rate = 0.001,
    beta_1        = 0.9,
    beta_2        = 0.999,
    epsilon       = 1e-7,
    amsgrad       = FALSE
  )
}

# Fresh Python-side callbacks. Callbacks carry state, so new ones are built
# for every fit().
# - EarlyStopping on val_loss (patience 3) restores the best weights.
# - ReduceLROnPlateau halves the learning rate after 2 stagnant epochs.
make_callbacks <- function() {
  callbacks <- import("tensorflow.keras.callbacks")
  list(
    callbacks$EarlyStopping(
      monitor              = "val_loss",
      patience             = 3L,
      restore_best_weights = TRUE
    ),
    callbacks$ReduceLROnPlateau(
      monitor  = "val_loss",
      factor   = 0.5,
      patience = 2L,
      min_lr   = 1e-6
    )
  )
}

# Compile `model` in place (binary cross-entropy, accuracy metric).
compile_score_model <- function(model) {
  model$compile(
    optimizer = make_adam(),
    loss      = "binary_crossentropy",
    metrics   = list("accuracy")
  )
  invisible(model)
}

# Build the from-scratch network. The architecture is:
#   input -> Dense(32, relu) -> Dropout(0.2)
#         -> Dense(32, relu) -> Dropout(0.2) -> Dense(1, sigmoid)
# Layer names and the model name are kept from the original script.
build_score_model <- function(n_features) {
  dense <- function(object, units, activation, name) {
    layer_dense(
      object,
      units              = units,
      activation         = activation,
      use_bias           = TRUE,
      kernel_initializer = initializer_glorot_uniform(),
      bias_initializer   = initializer_zeros(),
      name               = name
    )
  }
  inputs <- layer_input(
    shape = c(n_features),
    dtype = "float32",
    name  = "input"
  )
  outputs <- inputs %>%
    dense(32, "relu", "dense") %>%
    layer_dropout(rate = 0.2, name = "dropout") %>%
    dense(32, "relu", "dense_1") %>%
    layer_dropout(rate = 0.2, name = "dropout_1") %>%
    dense(1, "sigmoid", "dense_2")
  keras_model(inputs = inputs, outputs = outputs, name = "sequential")
}

# Train `model` with its Python fit() method and return the per-epoch history
# as a plain R list.
fit_score_model <- function(model, x, y) {
  history <- model$fit(
    x                = x,
    y                = y,
    batch_size       = fit_batch_size,
    epochs           = fit_epochs,
    validation_split = fit_val_split,
    callbacks        = make_callbacks()
  )
  reticulate::py_to_r(history$history)
}

# Per-epoch history list -> data.frame, used for printing and plotting.
history_to_df <- function(hist_list) {
  data.frame(
    epoch        = seq_along(hist_list$loss),
    loss         = hist_list$loss,
    val_loss     = hist_list$val_loss,
    accuracy     = hist_list$accuracy,
    val_accuracy = hist_list$val_accuracy
  )
}

# Plot training vs. validation curves for `metric` (its validation twin is
# "val_<metric>"). The explicit print() is required: ggplot objects are not
# auto-printed inside if/else bodies or when the script is source()d.
plot_history <- function(hist_df, metric, label) {
  p <- ggplot(hist_df, aes(x = epoch)) +
    geom_line(aes(y = .data[[metric]], linetype = "training")) +
    geom_line(aes(y = .data[[paste0("val_", metric)]],
                  linetype = "validation")) +
    labs(title = paste("Training vs. Validation", label),
         y = label, linetype = NULL) +
    theme_minimal()
  print(p)
  invisible(p)
}

# Print the last epochs and draw the loss and accuracy curves.
report_history <- function(hist_df) {
  print(tail(hist_df))
  plot_history(hist_df, "loss", "Loss")
  plot_history(hist_df, "accuracy", "Accuracy")
  invisible(hist_df)
}

# Predict probabilities, threshold them at `threshold`, and score the result
# with compute_metrics().
score_model <- function(model, x, y_true, threshold = 0.5) {
  proba <- as.numeric(reticulate::py_to_r(model$predict(x)))
  compute_metrics(y_true, proba, ifelse(proba > threshold, 1, 0))
}

# Persist the model, the history and the training data. The history is saved
# as a plain R list: a reticulate reference to the Python History object
# cannot survive saveRDS() and would be a dead pointer on reload.
save_outputs <- function(model, hist_list) {
  save_model_hdf5(model, finetuned_model_path)
  saveRDS(hist_list, history_path)
  saveRDS(list(x = x_all, y = y_all, sample_sizes = sample_sizes),
          data_info_path)
}

# Convert the data to float32 NumPy arrays once. convert = FALSE keeps the
# results as NumPy objects. With the default conversion, np$array() would
# hand back an R double array and the float32 cast would be lost.
np   <- import("numpy", convert = FALSE)
x_py <- np$array(as.matrix(x_all), dtype = "float32")
y_py <- np$array(as.numeric(y_all), dtype = "float32")

if (file.exists(base_model_path)) {
  cat("Found pretrained model. Loading and fine-tuning...\n")
  model_score <- load_model_hdf5(base_model_path, compile = FALSE)
  compile_score_model(model_score)

  # Score the pretrained weights before fine-tuning changes them. This avoids
  # reloading the base model from disk afterwards.
  metrics_base <- score_model(model_score, x_py, y_all)

  hist_list <- fit_score_model(model_score, x_py, y_py)
  hist_df   <- history_to_df(hist_list)
  report_history(hist_df)

  save_outputs(model_score, hist_list)
  cat("Fine-tuning complete. Model and history saved.\n")

  # NOTE(review): both models are scored on the full data set, which includes
  # the rows the fine-tuned model was trained on. This comparison therefore
  # favours the fine-tuned model; a held-out split would be fairer.
  metrics_ft <- score_model(model_score, x_py, y_all)

  cat("\nComparison of Base vs Fine-Tuned Models (higher is better for all):\n")
  print(rbind(Base      = unlist(metrics_base),
              FineTuned = unlist(metrics_ft)))
} else {
  cat("No pretrained model found. Training from scratch...\n")
  model_score <- build_score_model(ncol(x_all))
  compile_score_model(model_score)
  model_score$summary()

  hist_list <- fit_score_model(model_score, x_py, y_py)
  hist_df   <- history_to_df(hist_list)
  report_history(hist_df)

  save_outputs(model_score, hist_list)
  cat("Training from scratch complete. Model and history saved.\n")

  # NOTE(review): scored on the full data set, including the training rows.
  metrics_ft <- score_model(model_score, x_py, y_all)

  cat("\nMetrics:\n")
  print(unlist(metrics_ft))
}
    


# Final model summary: print the layer-by-layer architecture of `model_score`,
# i.e. whichever model the if/else above loaded or trained.
model_score$summary()
