# This script uses the "r-reticulate" Python virtualenv, installs TensorFlow
# and h5py into it, loads the pre-trained Keras model stored in image.h5,
# prints its architecture (summary, per-layer table, activations), and draws
# a hand-coded schematic of the network with DiagrammeR.

# Setup: install the Python dependencies (only needs to run once per machine).

# Run this in a fresh R session, before keras/tensorflow are loaded:
library(reticulate)

py_env  <- "r-reticulate"
# Python modules the rest of the script needs; h5py is required to read the
# .h5 model file.
py_deps <- c("tensorflow", "h5py")

# use_virtualenv(required = TRUE) errors if the environment does not exist,
# so create it on first run.
if (!virtualenv_exists(py_env)) {
  virtualenv_create(py_env)
}

# Bind this R session to the virtualenv. This must happen before any Python
# code (including keras/tensorflow) is initialised.
use_virtualenv(py_env, required = TRUE)

# Install only modules that are not importable yet, so re-running the script
# does not invoke pip every time.
missing_deps <- py_deps[!vapply(py_deps, py_module_available, logical(1))]
if (length(missing_deps) > 0) {
  virtualenv_install(envname = py_env, packages = missing_deps)
  # Make sure the running interpreter sees the freshly installed packages.
  import("importlib")$invalidate_caches()
}

# Confirm the modules are importable; stop with a clear message otherwise.
still_missing <- py_deps[!vapply(py_deps, py_module_available, logical(1))]
if (length(still_missing) > 0) {
  stop(
    "Python modules not available in virtualenv '", py_env, "': ",
    paste(still_missing, collapse = ", "),
    call. = FALSE
  )
}

# Load Keras and the pre-trained model ----
library(keras)

model_path <- "image.h5"
# Fail early with a clear R error instead of an error raised from inside
# Python when the file is missing.
if (!file.exists(model_path)) {
  stop(
    "Model file not found: ", normalizePath(model_path, mustWork = FALSE),
    call. = FALSE
  )
}
model_image <- keras::load_model_hdf5(model_path)

# Print the architecture and per-layer parameter counts.
model_image$summary()


# Per-layer description ----
# Build one named list per Keras layer: name, R class, output shape,
# parameter count and a short "details" string for units/filters layers.
layer_info <- lapply(model_image$layers, function(lyr) {
  cfg <- lyr$get_config()
  cfg_keys <- names(cfg)

  # Output shape rendered as "(d1,d2,...)". Any error (e.g. the shape cannot
  # be read) is deliberately turned into NA so one layer cannot abort the
  # whole table.
  shape_str <- tryCatch(
    {
      dims <- py_to_r(lyr$output$shape)
      paste0("(", paste(dims, collapse = ","), ")")
    },
    error = function(e) NA_character_
  )

  # Key hyper-parameters: "units = N", or filters plus kernel size; NA for
  # layers whose config has neither entry.
  detail_str <- NA_character_
  if ("units" %in% cfg_keys) {
    detail_str <- paste0("units = ", cfg$units)
  } else if ("filters" %in% cfg_keys) {
    kernel_str <- paste(cfg$kernel_size, collapse = "x")
    detail_str <- paste0("filters = ", cfg$filters, ", kernel = ", kernel_str)
  }

  list(
    name         = cfg$name,
    class        = class(lyr)[1],
    output_shape = shape_str,
    params       = lyr$count_params(),
    details      = detail_str
  )
})

# Per-layer table ----
library(dplyr)
library(tibble)

# One tibble row per layer, with columns reordered for display.
layer_df <- layer_info %>%
  lapply(as_tibble) %>%
  bind_rows() %>%
  select(name, class, output_shape, details, params)

print(layer_df)

# Loss the model was compiled with.
print(model_image$loss)

# Activation of every layer, as a character vector named by layer name.
# vapply() guarantees a flat character result: layers whose config has no
# "activation" entry give NA instead of a NULL that would make sapply()
# return a ragged list.
configs <- lapply(model_image$layers, function(layer) layer$get_config())
activations <- vapply(
  configs,
  function(cfg) {
    act <- cfg[["activation"]]
    if (is.null(act)) NA_character_ else toString(unlist(act))
  },
  character(1)
)
names(activations) <- vapply(configs, function(cfg) cfg[["name"]], character(1))
print(activations)



# Diagram dependencies ----
# Install each package only if it is not already available, then attach.
for (pkg in c("DiagrammeR", "glue")) {
  if (!requireNamespace(pkg, quietly = TRUE)) {
    install.packages(pkg)
  }
}
library(DiagrammeR)
library(glue)

# Schematic network description ----
# Number of circles drawn for each layer. These are display counts for the
# diagram only; they are not the layer sizes of the loaded model.
layers <- c(
  Input   = 1,
  Conv1   = 3,
  Conv2   = 3,
  Pool    = 3,
  Flatten = 1,
  Dense1  = 5,
  Dense2  = 3,
  Output  = 1
)

# Node ids 1..n per layer; lapply() keeps the layer names on the result.
node_ids <- lapply(layers, seq_len)

# DOT preamble: left-to-right layout with small, fixed-size circular nodes.
dot <- paste0(
  "digraph CNN {\n",
  "  graph [rankdir=LR];\n",
  "  node [shape=circle, fixedsize=true, width=0.4];\n"
)

# Fill colour of each layer's cluster box.
cluster_colors <- c(
  Input   = "lightgrey",
  Conv1   = "lightblue",
  Conv2   = "lightblue",
  Pool    = "khaki",
  Flatten = "plum",
  Dense1  = "lightblue",
  Dense2  = "lightblue",
  Output  = "lightgreen"
)

# Emit one filled, labelled cluster per layer. Graphviz draws a subgraph as
# a box only when its name starts with "cluster".
for (ln in names(layers)) {
  ids <- node_ids[[ln]]
  # `[[` fails loudly for a layer with no colour instead of writing
  # "color=NA" into the DOT source.
  color <- cluster_colors[[ln]]
  # Line breaks use DOT's "\n" escape (a literal backslash-n in the output).
  # The final unnamed argument is the fallback: an unknown layer is labelled
  # with its own name. Without it switch() returns NULL, glue() then returns
  # character(0) and the cluster is silently dropped.
  lbl <- switch(ln,
    Input   = "Input Image\\n25x25",
    Conv1   = "Conv2D + ReLU\\n32 filters",
    Conv2   = "Conv2D + ReLU\\n32 filters",
    Pool    = "MaxPooling\\n8x8x32",
    Flatten = "Flatten",
    Dense1  = "Dense + ReLU\\n256",
    Dense2  = "Dense + ReLU\\n128",
    Output  = "Sigmoid\\n1",
    ln
  )
  dot <- paste0(dot, glue("
  subgraph cluster_{tolower(ln)} {{
    style=filled; color={color};
    node [style=filled, fillcolor=white];
    {paste0(ln, ids, collapse='; ')};
    label=\"{lbl}\";
  }}\n"))
}

# Node declarations ----
# Give every node an empty (single-space) label so only the circle is drawn.
# glue() is vectorised over the node ids, yielding one declaration per node.
for (ln in names(layers)) {
  decls <- glue("  {ln}{node_ids[[ln]]} [label=\" \"];\n")
  dot <- paste0(dot, paste(decls, collapse = ""))
}

# Connections ----
# Link every node of a layer to every node of the next layer, so the
# schematic is drawn fully connected.
layer_names <- names(layers)
# seq_along()[-1] is empty for zero or one layers. 1:(n - 1) would instead
# yield c(1, 0) for a single layer and index a non-existent "next" layer.
for (k in seq_along(layer_names)[-1L]) {
  from <- layer_names[k - 1L]
  to   <- layer_names[k]
  # Cartesian product of node ids. expand.grid() varies its first column
  # fastest, so edges come out source-major (same order as nested i/j loops).
  pairs <- expand.grid(j = node_ids[[to]], i = node_ids[[from]])
  # sprintf() returns character(0) for an empty layer, so no bogus edge is
  # emitted in that case.
  edges <- sprintf("  %s%s -> %s%s;\n", from, pairs$i, to, pairs$j)
  dot <- paste0(dot, paste(edges, collapse = ""))
}

# Render ----
# Close the digraph, then render the DOT source with DiagrammeR::grViz().
dot <- sprintf("%s}\n", dot)
graph <- DiagrammeR::grViz(diagram = dot)
print(graph)
