# Load required libraries
library(keras)
library(reticulate)
library(fst)
library(dplyr)
library(pROC)    # for AUC
library(ggplot2)

# Run configuration ----
# Every artefact name embeds a timestamp of this run so that successive runs
# never overwrite each other's outputs.
ts <- format(Sys.time(), "%Y%m%d_%H%M%S")

# Input model (loaded if present) and per-run output files
base_model_path <- "image.h5"
finetuned_model_path <- paste0("image_finetuned_", ts, ".h5")
history_path <- paste0("training_history_image_", ts, ".rds")
data_info_path <- paste0("training_data_info_image_", ts, ".rds")

# 0) Sample-size levels (one sub-directory per level) and data roots
sample_sizes <- c(50, 100, 200, 400)
ind_dir <- "../0_Training/independent_samples"
dep_dir <- "../0_Training/dependent_samples"

# 1) Helper to load one class of samples
#
# For every entry of `sizes`, read base_dir/<size>/<filename> with read_fst(),
# tag each of its rows with the class `label` in a new column `y`, and return
# all sizes stacked into a single data frame (in the order of `sizes`).
load_block <- function(base_dir, sizes, filename, label) {
  blocks <- vector("list", length(sizes))
  for (i in seq_along(sizes)) {
    block <- as.data.frame(
      read_fst(file.path(base_dir, as.character(sizes[[i]]), filename))
    )
    block[["y"]] <- label
    blocks[[i]] <- block
  }
  bind_rows(blocks)
}

# 2) Load independent (y = 0) and dependent (y = 1) samples
df_ind <- load_block(ind_dir, sample_sizes, "ind_image.fst", 0L)
df_dep <- load_block(dep_dir, sample_sizes, "dep_image.fst", 1L)

# 3) Combine, shuffle, and split into features/labels.
# The shuffle matters: Keras' validation_split takes the LAST rows, so without
# it the validation set would consist only of dependent (y = 1) samples.
df_all <- bind_rows(df_ind, df_dep) %>% slice_sample(prop = 1)
x_all  <- as.matrix(df_all %>% select(-y))
y_all  <- df_all$y

# Each row must hold exactly one flattened 25 x 25 image.
img_side <- 25L
stopifnot(ncol(x_all) == img_side * img_side)

# 4) Reshape into (N, 25, 25, 1) for the CNN.
# The virtualenv must be selected BEFORE anything touches Python:
# array_reshape() goes through NumPy, and once reticulate has initialised an
# interpreter it cannot switch to the one requested here.
use_virtualenv("r-reticulate", required = TRUE)
n     <- nrow(x_all)
x_img <- array_reshape(x_all, c(n, img_side, img_side, 1L))

np   <- import("numpy")
x_py <- np$array(x_img, dtype = "float32")
y_py <- np$array(as.numeric(y_all), dtype = "float32")

# Summarise binary-classifier performance.
#
# @param y_true  Observed labels coded 0/1.
# @param y_proba Scores for class 1 (higher = more likely class 1); used only
#   for the AUC.
# @param y_pred  Hard 0/1 predictions; used for the confusion-matrix metrics.
# @return Named list of numeric scalars: AUC, Accuracy, Precision, Recall, F1.
#   A metric is NA when it is undefined (e.g. AUC when only one class is
#   present, Precision when nothing is predicted positive).
compute_metrics <- function(y_true, y_proba, y_pred) {
  y_true  <- as.numeric(y_true)
  y_proba <- as.numeric(y_proba)
  y_pred  <- as.numeric(y_pred)
  stopifnot(
    length(y_proba) == length(y_true),
    length(y_pred) == length(y_true),
    all(y_true %in% c(0, 1))
  )

  is_pos <- y_true == 1
  is_neg <- y_true == 0
  # Doubles, so n_pos * n_neg cannot overflow R's 32-bit integers.
  n_pos <- as.numeric(sum(is_pos))
  n_neg <- as.numeric(sum(is_neg))

  # AUC via the Mann-Whitney identity: the probability that a random positive
  # outscores a random negative, with ties counting one half. The direction is
  # fixed to "higher score = class 1". pROC::roc()'s default
  # direction = "auto" would instead silently flip a worse-than-chance model
  # to AUC >= 0.5, which hides regressions when comparing models.
  auc_value <- if (n_pos > 0 && n_neg > 0) {
    score_ranks <- rank(y_proba)
    (sum(score_ranks[is_pos]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
  } else {
    NA_real_
  }

  tp <- sum(is_pos & y_pred == 1)
  tn <- sum(is_neg & y_pred == 0)
  fp <- sum(is_neg & y_pred == 1)
  fn <- sum(is_pos & y_pred == 0)

  acc  <- if (length(y_true) > 0) (tp + tn) / length(y_true) else NA_real_
  prec <- if (tp + fp > 0) tp / (tp + fp) else NA_real_
  rec  <- if (tp + fn > 0) tp / (tp + fn) else NA_real_
  f1   <- if (!is.na(prec) && !is.na(rec) && prec + rec > 0) {
    2 * prec * rec / (prec + rec)
  } else {
    NA_real_
  }

  list(AUC = auc_value, Accuracy = acc, Precision = prec, Recall = rec, F1 = f1)
}


# 4) Build or load the model ----

# Fraction of rows Keras sets aside for validation in fit(). Keras takes the
# validation rows from the END of x/y, before its own per-epoch shuffling.
# The same fraction is reused below to score models on those held-out rows.
val_split <- 0.2

# Python-side callbacks shared by both training paths:
# - stop after 3 epochs without val_loss improvement, restoring the best weights;
# - halve the learning rate after 2 stagnant epochs, with a floor of 1e-6.
make_training_callbacks <- function() {
  callbacks_module <- import("tensorflow.keras.callbacks")
  list(
    callbacks_module$EarlyStopping(
      monitor              = "val_loss",
      patience             = 3L,
      restore_best_weights = TRUE
    ),
    callbacks_module$ReduceLROnPlateau(
      monitor  = "val_loss",
      factor   = 0.5,
      patience = 2L,
      min_lr   = 1e-6
    )
  )
}

# Train an already-compiled model through its Python fit() method with the
# shared settings. Returns the Python History object.
fit_image_model <- function(model, x, y) {
  model$fit(
    x                = x,
    y                = y,
    batch_size       = 128L,
    epochs           = 50L,
    validation_split = val_split,
    callbacks        = make_training_callbacks()
  )
}

# Flatten a converted History dict into one row per epoch. unlist() guards
# against per-epoch values arriving as an R list rather than a vector.
history_to_df <- function(hist_list) {
  data.frame(
    epoch        = seq_along(hist_list$loss),
    loss         = unlist(hist_list$loss),
    val_loss     = unlist(hist_list$val_loss),
    accuracy     = unlist(hist_list$accuracy),
    val_accuracy = unlist(hist_list$val_accuracy)
  )
}

# Plot a training metric (solid) against its validation counterpart
# (dashed), with a legend that tells the two apart.
plot_history <- function(df, train_col, val_col, title, y_label) {
  ggplot(df, aes(x = epoch)) +
    geom_line(aes(y = .data[[train_col]], linetype = "training")) +
    geom_line(aes(y = .data[[val_col]], linetype = "validation")) +
    scale_linetype_manual(
      values = c(training = "solid", validation = "dashed")
    ) +
    labs(title = title, y = y_label, linetype = NULL) +
    theme_minimal()
}

# Show the last epochs and both learning curves. print() is required here:
# inside an if/else body (or a function), ggplot objects and tail() results
# are not auto-printed, so without it the plots are silently discarded.
report_history <- function(df) {
  print(tail(df))
  print(plot_history(df, "loss", "val_loss",
                     "Training vs. Validation Loss", "Loss"))
  print(plot_history(df, "accuracy", "val_accuracy",
                     "Training vs. Validation Accuracy", "Accuracy"))
  invisible(df)
}

# Persist the trained model, its training history, and the training data.
# The history is saved as the converted R list. The Python History object is
# a reticulate reference, which does not survive a saveRDS()/readRDS() round
# trip.
# NOTE(review): both branches save to finetuned_model_path, and this script
# never writes base_model_path. A model trained from scratch is therefore
# never picked up as "pretrained" on the next run unless it is renamed.
save_run_artifacts <- function(model, hist_list) {
  save_model_hdf5(model, finetuned_model_path)
  saveRDS(hist_list, history_path)
  saveRDS(list(x = x_all, y = y_all, sample_sizes = sample_sizes),
          data_info_path)
  invisible(NULL)
}

# Run compute_metrics() on a subset of rows at a 0.5 decision threshold and
# flatten the result to a named numeric vector for tabulation.
metrics_on_rows <- function(proba, rows) {
  yhat <- ifelse(proba[rows] > 0.5, 1, 0)
  unlist(compute_metrics(y_all[rows], proba[rows], yhat))
}

# Rows fit() used for validation: everything after the first
# floor(n * (1 - val_split)) rows. These rows never drive gradient updates,
# although EarlyStopping does use them for model selection.
# NOTE(review): this mirrors the floor-based split in tf.keras; confirm it
# against the installed Keras version.
n_rows       <- length(y_all)
heldout_rows <- setdiff(seq_len(n_rows),
                        seq_len(floor(n_rows * (1 - val_split))))

if (file.exists(base_model_path)) {

  cat("Found pretrained model. Loading and fine-tuning...\n")

  model_image <- load_model_hdf5(base_model_path, compile = FALSE)

  # The model was loaded with compile = FALSE, so compile it through the
  # Python API with an explicit Adam configuration.
  optim_module <- import("tensorflow.keras.optimizers")
  model_image$compile(
    optimizer = optim_module$Adam(
      learning_rate = 0.001,
      beta_1        = 0.9,
      beta_2        = 0.999,
      epsilon       = 1e-7,
      amsgrad       = FALSE
    ),
    loss    = "binary_crossentropy",
    metrics = list("accuracy")
  )

  history   <- fit_image_model(model_image, x_py, y_py)
  hist_list <- reticulate::py_to_r(history$history)
  df        <- history_to_df(hist_list)
  report_history(df)

  save_run_artifacts(model_image, hist_list)
  cat("Fine-tuning complete. Model and history saved.\n")

  # --- Compare base vs fine-tuned ---
  # Reload the untouched pretrained weights. predict() needs no compile step.
  base_model <- load_model_hdf5(base_model_path, compile = FALSE)

  preds_base <- as.numeric(base_model$predict(x_py))
  preds_ft   <- as.numeric(model_image$predict(x_py))

  yhat_base <- ifelse(preds_base > 0.5, 1, 0)
  yhat_ft   <- ifelse(preds_ft   > 0.5, 1, 0)

  # Full data: mostly rows the fine-tuned model was trained on, so this
  # comparison is optimistic for the fine-tuned model.
  metrics_base <- compute_metrics(y_all, preds_base, yhat_base)
  metrics_ft   <- compute_metrics(y_all, preds_ft,   yhat_ft)

  cat("\nComparison of Base vs Fine-Tuned Models (higher is better for all):\n")
  print(rbind(Base      = unlist(metrics_base),
              FineTuned = unlist(metrics_ft)))

  if (length(heldout_rows) > 0) {
    cat("\nSame comparison on the held-out validation rows only:\n")
    print(rbind(Base      = metrics_on_rows(preds_base, heldout_rows),
                FineTuned = metrics_on_rows(preds_ft, heldout_rows)))
  }

} else {

  cat("No pretrained model found. Training from scratch...\n")

  # Build the CNN with the functional API. Layer names match the saved model.
  inputs <- layer_input(
    shape = c(25, 25, 1),
    dtype = "float32",
    name  = "input_image"
  )

  x <- inputs %>%
    layer_conv_2d(
      filters            = 32,
      kernel_size        = c(3, 3),
      activation         = "relu",
      padding            = "same",
      kernel_initializer = initializer_glorot_uniform(),
      bias_initializer   = initializer_zeros(),
      name               = "conv2d_2"
    ) %>%
    layer_conv_2d(
      filters            = 32,
      kernel_size        = c(3, 3),
      activation         = "relu",
      padding            = "same",
      kernel_initializer = initializer_glorot_uniform(),
      bias_initializer   = initializer_zeros(),
      name               = "conv2d_3"
    ) %>%
    layer_max_pooling_2d(
      pool_size = c(3, 3),
      name      = "max_pooling2d"
    ) %>%
    layer_dropout(
      rate = 0.2,
      name = "dropout"
    ) %>%
    layer_flatten(name = "flatten") %>%
    layer_dense(
      units              = 256,
      activation         = "relu",
      kernel_initializer = initializer_glorot_uniform(),
      bias_initializer   = initializer_zeros(),
      name               = "dense"
    ) %>%
    layer_dropout(
      rate = 0.2,
      name = "dropout_1"
    ) %>%
    layer_dense(
      units              = 128,
      activation         = "relu",
      kernel_initializer = initializer_glorot_uniform(),
      bias_initializer   = initializer_zeros(),
      name               = "dense_1"
    ) %>%
    layer_dropout(
      rate = 0.2,
      name = "dropout_2"
    ) %>%
    layer_dense(
      units      = 1,
      activation = "sigmoid",
      name       = "dense_2"
    )

  model_image <- keras_model(
    inputs  = inputs,
    outputs = x,
    name    = "cnn_image_model"
  )

  # Compile with the same Adam settings as the fine-tuning path.
  model_image$compile(
    optimizer = optimizer_adam(
      learning_rate = 0.001,
      beta_1        = 0.9,
      beta_2        = 0.999,
      epsilon       = 1e-7,
      amsgrad       = FALSE
    ),
    loss    = "binary_crossentropy",
    metrics = list("accuracy")
  )

  model_image$summary()

  history   <- fit_image_model(model_image, x_py, y_py)
  hist_list <- reticulate::py_to_r(history$history)
  df        <- history_to_df(hist_list)
  report_history(df)

  save_run_artifacts(model_image, hist_list)
  cat("Training from scratch complete. Model and history saved.\n")

  preds_ft <- as.numeric(model_image$predict(x_py))
  yhat_ft  <- ifelse(preds_ft > 0.5, 1, 0)

  # Full data: mostly training rows, so these numbers are optimistic.
  metrics_ft <- compute_metrics(y_all, preds_ft, yhat_ft)

  cat("\nMetrics:\n")
  print(unlist(metrics_ft))

  if (length(heldout_rows) > 0) {
    cat("\nMetrics on the held-out validation rows:\n")
    print(metrics_on_rows(preds_ft, heldout_rows))
  }

}
    


# Final model summary: print the architecture of whichever model the
# if/else above left in `model_image` (fine-tuned from image.h5, or newly
# trained from scratch) via the Keras model's Python summary() method.
model_image$summary()
