library(dplyr)
library(data.table)
library(ggplot2)
library(depstats)

# Sample sizes used as the x-axis breaks of every power plot.
sample_size <- c(30, 50, seq(100, 400, by = 100))

# -----------------------------------------------------------------------------
# Version 2 plotting choices
# -----------------------------------------------------------------------------
# Colour-blind-safe palette based on the Okabe-Ito palette.
# The three highlighted models are given distinct colours and line types.
# The remaining methods are deliberately kept light grey so that they act as
# background references rather than competing visually with the main curves.
#
# Additional fix: the grey/background curves are drawn first and the three
# highlighted coloured curves are drawn last, so the highlighted curves always
# appear above the grey ones when curves overlap.

# Colour lookup for the manual colour scales. The four colours (combined,
# score, image, others) are keyed twice: first by short model id, then by the
# display label, so either spelling in the data finds its colour.
highlight_colours <- local({
  palette <- c("#0072B2", "#D55E00", "#009E73", "grey75")
  stats::setNames(
    rep(palette, times = 2L),
    c("combined", "score", "image", "others",
      "Combined Model", "Score Model", "Image Model", "Others")
  )
})

# Line-type lookup for the manual linetype scales, keyed exactly like the
# colour lookup: short model ids first, then display labels.
highlight_linetypes <- local({
  styles <- c("solid", "dashed", "dotdash", "solid")
  stats::setNames(
    rep(styles, times = 2L),
    c("combined", "score", "image", "others",
      "Combined Model", "Score Model", "Image Model", "Others")
  )
})

# Linetype helpers: the variable mapped to colour is also mapped to the
# linetype aesthetic (done in force_linetype_mapping() below). This is
# necessary because plotpowers() has already constructed the line layers;
# adding scale_linetype_manual() afterwards changes only the scale/legend and
# does not make the curves dashed or dot-dashed unless the line layers map a
# variable to linetype. get_colour_mapping() extracts the colour mapping.
# Return the colour mapping stored in an aesthetic mapping.
#
# `mapping` is an aes()-style list (or NULL). The British spelling `colour`
# wins; the American `color` is the fallback. Returns NULL when neither is
# present or when `mapping` itself is NULL.
get_colour_mapping <- function(mapping) {
  if (is.null(mapping)) {
    return(NULL)
  }
  found <- mapping$colour
  if (is.null(found)) {
    found <- mapping$color
  }
  found
}

# TRUE when a layer's geom is line-like (line, path, step or smooth), i.e.
# a layer whose curves may need linetype/draw-order treatment.
is_line_layer <- function(layer) {
  line_geoms <- c("GeomLine", "GeomPath", "GeomStep", "GeomSmooth")
  # inherits() with several classes is TRUE if any of them matches.
  inherits(layer$geom, line_geoms)
}

# Guess which column of `dat` identifies the method/model of each row.
#
# Candidate names are tried in priority order, each stem first in lower case
# and then capitalised: "model", "Model", "method", "Method", ... Returns
# the first candidate present in names(dat), or NA_character_ if none is.
detect_group_variable <- function(dat) {
  stems <- c("model", "method", "statistic", "indicator", "test", "name")
  capitalised <- paste0(toupper(substr(stems, 1L, 1L)), substring(stems, 2L))
  # rbind() + c() interleaves the two vectors column by column.
  candidates <- c(rbind(stems, capitalised))

  # match() gives NA_integer_ when nothing is found; indexing with it yields
  # NA_character_.
  candidates[match(TRUE, candidates %in% names(dat))]
}

# Draw-order helpers: background methods are drawn before highlighted
# methods. This matters because curves can intersect; a grey curve drawn
# afterwards could partially hide one of the coloured curves.
# Canonical form of a method label: coerced to character, with every run of
# non-alphanumeric characters removed, then lower-cased
# (e.g. "Combined Model" -> "combinedmodel"). NA stays NA.
normalise_method_label <- function(x) {
  labels <- as.character(x)
  alnum_only <- gsub("[^[:alnum:]]+", "", labels)
  tolower(alnum_only)
}

# Elementwise TRUE for labels naming one of the three highlighted models.
#
# Both the short id ("score") and the label form ("Score Model") are accepted,
# compared after normalise_method_label() strips case and punctuation. Never
# returns NA.
is_highlighted_method <- function(x) {
  short_ids <- c("combined", "score", "image")
  accepted <- c(short_ids, paste0(short_ids, "model"))
  match(normalise_method_label(x), accepted, nomatch = 0L) > 0L
}

# Draw-priority key for method labels: 0L for background methods (including
# "Others"), 1L for highlighted ones. Sorting rows by this key puts
# background methods first and highlighted methods last.
#
# Always returns an integer vector of length(x). as.integer() is used instead
# of ifelse(), which returns logical(0) rather than integer(0) on empty input.
method_draw_order <- function(x) {
  as.integer(is_highlighted_method(x))
}

# Reorder the rows of `dat` so that background methods come before
# highlighted methods.
#
# dat:       a data frame; any other value is returned unchanged.
# group_var: name of the method column. If it is not a single, non-NA name
#            present in `dat`, the column is auto-detected with
#            detect_group_variable().
# Returns `dat` with reordered rows, or `dat` unchanged when no method column
# can be found.
reorder_data_for_drawing <- function(dat, group_var = NULL) {
  if (!is.data.frame(dat)) return(dat)

  # Length is checked before is.na() so that a zero-length or multi-element
  # name cannot reach `||` (an error in R >= 4.3).
  usable <- function(v) length(v) == 1L && !is.na(v) && v %in% names(dat)

  if (!usable(group_var)) group_var <- detect_group_variable(dat)
  if (!usable(group_var)) return(dat)

  # order() leaves ties in their original order, so rows keep their relative
  # order within the background and highlighted groups. The index is computed
  # outside `[` so data.table does not reinterpret an order() call inside i.
  row_order <- order(method_draw_order(dat[[group_var]]))
  dat[row_order, , drop = FALSE]
}

# Data a layer would draw from: the layer's own data frame if it has one,
# otherwise the plot-level data frame, otherwise NULL (e.g. when the data is
# a function or waiver).
layer_data_or_plot_data <- function(layer, p) {
  own <- layer$data
  if (is.data.frame(own)) {
    return(own)
  }
  inherited <- p$data
  if (is.data.frame(inherited)) inherited else NULL
}

# TRUE when the data a layer draws from contains at least one row of a
# highlighted method.
#
# layer:     a ggplot2 layer; its own data is used if it is a data frame,
#            otherwise the plot-level data of `p`.
# p:         the plot the layer belongs to.
# group_var: name of the method column. If it is not a single, non-NA name
#            present in the data, it is auto-detected.
# Returns FALSE when there is no data frame or no method column.
layer_contains_highlighted_method <- function(layer, p, group_var = NULL) {
  dat <- layer_data_or_plot_data(layer, p)
  if (!is.data.frame(dat)) return(FALSE)

  # Length is checked before is.na() so a zero-length or multi-element name
  # cannot reach `||` (an error in R >= 4.3).
  usable <- function(v) length(v) == 1L && !is.na(v) && v %in% names(dat)

  if (!usable(group_var)) group_var <- detect_group_variable(dat)
  if (!usable(group_var)) return(FALSE)

  any(is_highlighted_method(dat[[group_var]]), na.rm = TRUE)
}

# Arrange plot data and layers so highlighted methods come after background
# methods.
#
# p:   a ggplot object (presumably from plotpowers()).
# dat: optional data used only to detect the method column name.
# Returns the modified plot. The following are all reordered:
#   1. rows of the plot-level data,
#   2. rows of data stored directly on line layers,
#   3. the layers: background-only line layers, then non-line layers, then
#      line layers containing a highlighted method.
force_highlighted_curves_on_top <- function(p, dat = NULL) {
  # NULL lets reorder_data_for_drawing() detect the column per data frame.
  group_var <- NULL
  if (!is.null(dat)) {
    group_var <- detect_group_variable(dat)
  }

  # Step 1: plot-level data (the case of one geom_line() drawing all curves).
  if (is.data.frame(p$data)) {
    p$data <- reorder_data_for_drawing(p$data, group_var)
  }

  # Step 2: data attached directly to line layers.
  for (k in seq_along(p$layers)) {
    if (!is_line_layer(p$layers[[k]]) || !is.data.frame(p$layers[[k]]$data)) {
      next
    }
    p$layers[[k]]$data <- reorder_data_for_drawing(p$layers[[k]]$data, group_var)
  }

  # Step 3: reorder the layers themselves, which needs at least two.
  if (length(p$layers) <= 1L) {
    return(p)
  }

  is_line <- vapply(p$layers, is_line_layer, logical(1))
  has_highlight <- vapply(
    p$layers,
    function(lyr) layer_contains_highlighted_method(lyr, p, group_var),
    logical(1)
  )

  # 1 = background-only line layers, 2 = non-line layers,
  # 3 = line layers with a highlighted method (drawn last, hence on top).
  priority <- rep(2L, length(p$layers))
  priority[is_line & !has_highlight] <- 1L
  priority[is_line & has_highlight] <- 3L

  p$layers <- p$layers[order(priority)]
  p
}

# Make every line-like layer map linetype to the same variable as colour.
#
# p:   a ggplot object whose line layers already exist (from plotpowers()).
# dat: optional data; its detected method column is the fallback linetype
#      variable when neither the plot nor any line layer maps colour.
# Returns the modified plot. Fixed `linetype` parameters on line layers are
# removed so they cannot override the mapping.
force_linetype_mapping <- function(p, dat = NULL) {
  colour_mapping <- get_colour_mapping(p$mapping)

  # If the colour mapping is only defined inside the line layers, borrow it
  # from the first line layer that has one.
  if (is.null(colour_mapping)) {
    for (i in seq_along(p$layers)) {
      if (is_line_layer(p$layers[[i]])) {
        colour_mapping <- get_colour_mapping(p$layers[[i]]$mapping)
        if (!is.null(colour_mapping)) break
      }
    }
  }

  # Fallback mapping straight to the detected method column.
  # detect_group_variable() returns NA_character_ (not NULL) when no candidate
  # column exists. NA must be rejected here, otherwise the plot would carry a
  # `.data[[NA]]` linetype mapping that fails when the plot is built.
  group_var <- if (!is.null(dat)) detect_group_variable(dat) else NULL
  fallback_mapping <- NULL
  if (length(group_var) == 1L && !is.na(group_var)) {
    fallback_mapping <- aes(linetype = .data[[group_var]])$linetype
  }

  # Add a global linetype mapping whenever possible.
  if (!is.null(colour_mapping)) {
    p$mapping$linetype <- colour_mapping
  } else if (!is.null(fallback_mapping)) {
    p$mapping$linetype <- fallback_mapping
  }

  # Give each line-like layer the same linetype variable as its colour, and
  # drop any fixed linetype (e.g. "solid") that would override the mapping.
  for (i in seq_along(p$layers)) {
    if (!is_line_layer(p$layers[[i]])) next

    p$layers[[i]]$aes_params$linetype <- NULL

    layer_colour_mapping <- get_colour_mapping(p$layers[[i]]$mapping)
    if (!is.null(layer_colour_mapping)) {
      p$layers[[i]]$mapping$linetype <- layer_colour_mapping
    } else if (!is.null(colour_mapping)) {
      p$layers[[i]]$mapping$linetype <- colour_mapping
    } else if (!is.null(fallback_mapping)) {
      p$layers[[i]]$mapping$linetype <- fallback_mapping
    }
  }

  p
}

# Robust fix for visibility: add_highlight_overlay() (below) redraws the
# highlighted curves as the final layers. This does not depend on how
# plotpowers() orders groups internally: even if a grey curve is drawn late by
# plotpowers(), the highlighted curves are drawn again afterwards and remain
# visible. merge_mappings() and evaluate_mapped_values() are its helpers.
# Combine a plot-level and a layer-level aesthetic mapping the way a layer
# sees them.
#
# When `inherit` is TRUE, the plot mapping is the starting point; otherwise
# an empty aes(). Every aesthetic named in `layer_mapping` is then written
# over the result, so layer aesthetics take precedence.
merge_mappings <- function(plot_mapping, layer_mapping, inherit = TRUE) {
  merged <- if (isTRUE(inherit)) plot_mapping else ggplot2::aes()

  # length(NULL) is 0, so this also covers a NULL layer mapping.
  if (length(layer_mapping) == 0L) {
    return(merged)
  }

  for (aesthetic in names(layer_mapping)) {
    merged[[aesthetic]] <- layer_mapping[[aesthetic]]
  }
  merged
}

# Evaluate the first usable aesthetic of `mapping` against `dat`.
#
# aesthetics: names tried in order (e.g. c("colour", "linetype", "group")).
# Returns the value of the first present aesthetic whose evaluation succeeds
# and is non-NULL. Evaluation errors are swallowed and the next aesthetic is
# tried. Returns NULL if nothing works, or if `mapping` is NULL or `dat` is
# not a data frame.
evaluate_mapped_values <- function(mapping, dat, aesthetics) {
  if (is.null(mapping) || !is.data.frame(dat)) {
    return(NULL)
  }

  present <- Filter(function(aes_name) !is.null(mapping[[aes_name]]), aesthetics)

  for (aes_name in present) {
    result <- tryCatch(
      rlang::eval_tidy(mapping[[aes_name]], data = dat),
      error = function(e) NULL
    )
    if (!is.null(result)) {
      return(result)
    }
  }

  NULL
}

# Redraw the highlighted methods on top of everything else.
#
# For each line-like layer present on entry, two layers restricted to the
# highlighted methods' rows are appended:
#   1. a wide white "halo" line, so grey curves do not visibly cut through
#      the highlighted curves at crossings;
#   2. the highlighted curves themselves, with colour and linetype mapped.
# p:   a ggplot object.
# dat: data to fall back on when neither the layer nor the plot has a data
#      frame.
# Returns the plot with the overlay layers added (legend hidden for them).
add_highlight_overlay <- function(p, dat = NULL) {
  if (length(p$layers) == 0L) return(p)

  # seq_along() is evaluated once, so only the layers present on entry are
  # visited; the overlay layers appended inside the loop are not.
  for (i in seq_along(p$layers)) {
    layer <- p$layers[[i]]
    if (!is_line_layer(layer)) next

    layer_dat <- layer_data_or_plot_data(layer, p)
    if (!is.data.frame(layer_dat) && is.data.frame(dat)) layer_dat <- dat
    if (!is.data.frame(layer_dat)) next
    n_rows <- nrow(layer_dat)

    mapping <- merge_mappings(p$mapping, layer$mapping, layer$inherit.aes)

    # Identify each row's method from the first usable mapped aesthetic.
    method_values <- evaluate_mapped_values(
      mapping,
      layer_dat,
      c("colour", "color", "linetype", "group")
    )
    # Values that are neither one per row nor a single constant cannot be
    # lined up with the rows; logical indexing would silently recycle them.
    if (!(length(method_values) %in% c(1L, n_rows))) method_values <- NULL

    # Otherwise fall back to a conventionally named method column.
    group_var <- NA_character_
    if (is.null(method_values)) {
      group_var <- detect_group_variable(layer_dat)
      if (length(group_var) == 1L && !is.na(group_var) &&
          group_var %in% names(layer_dat)) {
        method_values <- layer_dat[[group_var]]
      } else {
        group_var <- NA_character_
      }
    }

    if (is.null(method_values)) next

    keep <- rep_len(is_highlighted_method(method_values), n_rows)
    if (!any(keep, na.rm = TRUE)) next

    highlight_dat <- layer_dat[keep, , drop = FALSE]

    # Ensure the overlaid highlighted curves use the same variable for colour
    # and linetype. This is deliberately a mapped aesthetic, not a fixed value.
    colour_mapping <- get_colour_mapping(mapping)
    if (!is.null(colour_mapping) && is.null(mapping$linetype)) {
      mapping$linetype <- colour_mapping
    }

    # Aesthetic that separates one method's curve from another. The halo drops
    # colour and linetype, so without an explicit group it would join every
    # highlighted method into a single zig-zag path. The same is true for the
    # overlay in the column-detection fallback.
    # The column name is injected with !! because `group_var` is reassigned on
    # every iteration and the mapping is only evaluated when the plot is built.
    method_group <- mapping$group
    if (is.null(method_group)) method_group <- colour_mapping
    if (is.null(method_group)) method_group <- mapping$linetype
    if (is.null(method_group) && !is.na(group_var)) {
      method_group <- ggplot2::aes(group = .data[[!!group_var]])$group
    }
    if (is.null(mapping$group) && is.null(colour_mapping) &&
        is.null(mapping$linetype)) {
      mapping$group <- method_group
    }

    overlay_width <- layer$aes_params$linewidth
    if (is.null(overlay_width)) overlay_width <- layer$aes_params$size
    if (is.null(overlay_width)) overlay_width <- 1.05
    overlay_alpha <- layer$aes_params$alpha
    if (is.null(overlay_alpha)) overlay_alpha <- 1

    # White halo under the highlighted curves: grouped by method, with the
    # colour/linetype aesthetics removed so it is a plain white line.
    halo_mapping <- mapping
    halo_mapping$group <- method_group
    halo_mapping$colour <- NULL
    halo_mapping$color <- NULL
    halo_mapping$linetype <- NULL

    p <- p + ggplot2::geom_line(
      data = highlight_dat,
      mapping = halo_mapping,
      inherit.aes = FALSE,
      colour = "white",
      linetype = "solid",
      linewidth = 1.8 * overlay_width,
      alpha = 0.95,
      lineend = "round",
      na.rm = TRUE,
      show.legend = FALSE
    )

    p <- p + ggplot2::geom_line(
      data = highlight_dat,
      mapping = mapping,
      inherit.aes = FALSE,
      linewidth = overlay_width,
      alpha = overlay_alpha,
      lineend = "round",
      na.rm = TRUE,
      show.legend = FALSE
    )
  }

  p
}

# Apply the shared look to a plotpowers() plot.
#
# The layers are adjusted first, in this order: linetype mapped like colour,
# background-before-highlight ordering, then the halo + overlay redraw. After
# that, the common scales, axes and theme are added.
# Manual-scale levels without a named colour/linetype get na.value
# (grey / solid), which is how non-highlighted methods become background.
style_power_plot <- function(p, dat = NULL) {
  layer_steps <- list(
    force_linetype_mapping,
    force_highlighted_curves_on_top,
    add_highlight_overlay
  )
  restyled <- Reduce(function(plot, step) step(plot, dat), layer_steps, p)

  finishing <- list(
    scale_colour_manual(values = highlight_colours, na.value = "grey75"),
    scale_linetype_manual(values = highlight_linetypes, na.value = "solid"),
    scale_x_continuous(breaks = sample_size),
    coord_cartesian(ylim = c(0, 1)),
    labs(x = "Sample size", y = "Power"),
    # The legend is drawn separately (legend_plot), so hide the built-in one.
    guides(colour = "none", linetype = "none"),
    theme_minimal(base_size = 12),
    theme(
      plot.title = element_text(size = 13, face = "bold", hjust = 0.5),
      axis.title = element_text(size = 12),
      axis.text  = element_text(size = 10),
      panel.grid.minor = element_blank(),
      legend.position = "none"
    )
  )

  restyled + finishing
}

# Build one styled power plot: order the rows (background methods first),
# draw them with plotpowers(), then apply style_power_plot().
make_power_plot <- function(dat, title) {
  drawing_data <- reorder_data_for_drawing(dat)
  base_plot <- plotpowers(drawing_data, title)
  style_power_plot(base_plot, drawing_data)
}

# Compact horizontal legend used below each panel: for each model, a short
# line sample (x0 -> x1) followed by its left-aligned label starting at xt,
# all on the single row y = 1.
legend_data <- local({
  model_ids <- c("combined", "score", "image", "others")
  data.table(
    model = factor(model_ids, levels = model_ids),
    label = c("Combined Model", "Score Model", "Image Model", "Others"),
    x0 = c(0.0, 2.45, 4.55, 6.55),
    x1 = c(0.65, 3.10, 5.20, 7.20),
    xt = c(0.78, 3.23, 5.33, 7.33),
    y  = 1
  )
})

legend_plot <- ggplot(legend_data) +
  geom_segment(
    aes(x = x0, xend = x1, y = y, yend = y, colour = model, linetype = model),
    linewidth = 1.05,
    lineend = "round"
  ) +
  geom_text(aes(x = xt, y = y, label = label), hjust = 0, size = 3.0) +
  scale_colour_manual(values = highlight_colours) +
  scale_linetype_manual(values = highlight_linetypes) +
  scale_x_continuous(limits = c(0, 8.65)) +
  scale_y_continuous(limits = c(0.75, 1.25)) +
  theme_void() +
  theme(legend.position = "none")

# Write one PDF page per scenario: the styled power plot above legend_plot.
#
# filename:  output PDF path.
# powers:    a data.table with a `scenario` column (filtered with data.table
#            syntax).
# titles:    one plot title per page.
# scenarios: scenario value plotted on each page; defaults to
#            1, ..., length(titles), which is the original behaviour.
# width, height: PDF page size in inches (defaults 7 x 5 as before).
# Returns `filename` invisibly. The device is closed even on error.
#
# gridExtra::grid.arrange() already draws the arrangement and returns a
# gtable; printing that gtable only dumps a text summary to the console. The
# grob is therefore built with arrangeGrob() and drawn explicitly.
print_power_pdf <- function(filename, powers, titles,
                            scenarios = seq_along(titles),
                            width = 7, height = 5) {
  stopifnot(length(scenarios) == length(titles))

  pdf(filename, width = width, height = height)
  on.exit(dev.off(), add = TRUE)

  for (k in seq_along(titles)) {
    scenario_id <- scenarios[[k]]
    p_main <- make_power_plot(powers[scenario == scenario_id], titles[[k]])
    page <- gridExtra::arrangeGrob(
      p_main,
      legend_plot,
      ncol = 1,
      heights = c(4.6, 0.4)
    )
    grid::grid.newpage()
    grid::grid.draw(page)
  }

  invisible(filename)
}

# -----------------------------------------------------------------------------
# Additive scenarios
# -----------------------------------------------------------------------------

powers <- as.data.table(readRDS("add_powers.RDS"))

# Two variants (A, B) of each additive design, listed design by design:
# "Correlated Laplace A", "Correlated Laplace B", "Ishigami style A", ...
titles <- paste(
  rep(c("Correlated Laplace", "Ishigami style", "Tree ring",
        "Escalating variance", "Infinity", "Pi"), each = 2L),
  c("A", "B")
)

print_power_pdf("power_curves_add.pdf", powers, titles)

# -----------------------------------------------------------------------------
# Dependence scenarios
# -----------------------------------------------------------------------------

powers <- as.data.table(readRDS("dep_powers.RDS"))

# One title per dependence scenario, in scenario order.
titles <- c(
  "Linear", "Diamond", "Triangle", "Crescent", "Points", "Exponential",
  "Circles", "Cross", "Wedge", "Cubic", "W-shape", "Parabola",
  "Two-parabola", "Sine", "Doppler", "Heavy-sine", "Heart", "Spiral",
  "Taegeuk", "Samtaegeuk"
)

print_power_pdf("power_curves_dep.pdf", powers, titles)

# -----------------------------------------------------------------------------
# Increasing-noise scenarios
# -----------------------------------------------------------------------------

powers <- as.data.table(readRDS("increasingnoise_powers.RDS"))

# One PDF per shape, four noise levels each:
# power_curves_circles.pdf, _cross.pdf, _sine.pdf, _spiral.pdf.
#
# NOTE(review): print_power_pdf() selects rows with scenario == 1, ...,
# length(titles), so all four PDFs below plot the SAME scenarios 1-4 of this
# file. If the RDS stores the shapes as consecutive blocks of four scenarios,
# each shape needs its own block selected -- confirm against the data.
for (noise_shape in c("Circles", "Cross", "Sine", "Spiral")) {
  titles <- paste0(noise_shape, ": noise level ", 1:4)
  print_power_pdf(
    paste0("power_curves_", tolower(noise_shape), ".pdf"),
    powers,
    titles
  )
}
