## ----include = FALSE----------------------------------------------------------
# Default options applied to every knitr chunk in this document
# (`collapse = TRUE`; printed output is prefixed with "#>").
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)

## ----setup--------------------------------------------------------------------
library(ggplot2)
library(dplyr)
library(tibble)
library(purrr)
library(patchwork)
library(masc)

## -----------------------------------------------------------------------------
# masc_time_pressure_final.R - Reproduces Figure 10 from
# Gluth, S., Deakin, J., & Rieskamp, J. (2026). A theory of multiattribute search and choice. Psychological Review.

# Draw a random attribute matrix in which no option dominates.
#
# Standard-normal attribute values are redrawn until no option is at least as
# good as every other option on every attribute, so each simulated choice
# involves a genuine trade-off.
#
# Arguments:
#   n_options    - number of options (matrix rows); must be >= 2, because a
#                  lone option always dominates and the search would never end.
#   n_attributes - number of attributes (matrix columns); must be >= 1.
#   lambda       - unused; kept only so existing calls keep working.
# Returns an n_options x n_attributes numeric matrix.
generate_valid_attributes <- function(n_options, n_attributes, lambda = 1) {
  if (n_options < 2 || n_attributes < 1) {
    stop(
      "`n_options` must be >= 2 and `n_attributes` >= 1 for a ",
      "non-dominated attribute matrix to exist.",
      call. = FALSE
    )
  }

  repeat {
    # One rnorm() draw per attempt, exactly as before, so seeded runs
    # reproduce the same matrices.
    x <- matrix(rnorm(n_options * n_attributes), nrow = n_options)

    # Option i dominates iff, on every attribute, it is at least as large as
    # all other options -- equivalently, it attains each column's maximum.
    col_max <- apply(x, 2, max)
    at_col_max <- x >= rep(col_max, each = n_options)
    dominates <- rowSums(at_col_max) == n_attributes

    if (!any(dominates)) {
      return(x)
    }
  }
}

# Summarise the fixation sequence of a single simulated trial.
#
# Fixation indices are decoded with options varying fastest: index k refers
# to option ((k - 1) %% n_options) + 1 and attribute ceiling(k / n_options).
# NOTE(review): this mirrors the original decoding; confirm it matches how
# rMASC numbers option-attribute pairs in `fix_sequence`.
#
# Arguments:
#   fix_seq      - vector of fixated option-attribute-pair indices.
#   n_options    - number of options in the trial.
#   n_attributes - number of attributes in the trial.
# Returns a list with
#   attr_prop - proportion of fixations on each attribute (NaN if no fixations),
#   opt_prop  - proportion of fixations on each option (NaN if no fixations),
#   payne     - Payne index, (option-wise - attribute-wise) / (their sum),
#               or NA when the trial has no such transition.
.trial_fixation_metrics <- function(fix_seq, n_options, n_attributes) {
  n_fix <- length(fix_seq)
  fix_opt <- ((fix_seq - 1) %% n_options) + 1
  fix_attr <- ceiling(fix_seq / n_options)

  # tabulate() counts each index and ignores out-of-range values; an empty
  # sequence gives 0 / 0 = NaN, which the na.rm aggregation later drops.
  attr_prop <- tabulate(fix_attr, nbins = n_attributes) / n_fix
  opt_prop <- tabulate(fix_opt, nbins = n_options) / n_fix

  # Undefined (NA) rather than 0 when there is nothing to count, so such
  # trials do not pull the subject's mean Payne index toward 0.
  payne <- NA_real_
  if (n_fix > 1) {
    same_opt <- fix_opt[-1] == fix_opt[-n_fix]
    same_attr <- fix_attr[-1] == fix_attr[-n_fix]
    # Count only transitions that change exactly one dimension: option-wise
    # (same option, new attribute) or attribute-wise (same attribute, new
    # option). Refixations of the same cell and diagonal moves are ignored.
    n_option_wise <- sum(same_opt & !same_attr)
    n_attribute_wise <- sum(same_attr & !same_opt)
    n_transitions <- n_option_wise + n_attribute_wise
    if (n_transitions > 0) {
      payne <- (n_option_wise - n_attribute_wise) / n_transitions
    }
  }

  list(attr_prop = attr_prop, opt_prop = opt_prop, payne = payne)
}

# Simulate MASC search and choice under time pressure.
#
# For every combination of search sensitivity and threshold, runs rMASC for
# each simulated subject on the same stimuli and averages fixation and choice
# metrics across subjects. Prints progress with cat().
#
# Arguments:
#   n_trials, n_subjects    - trials per subject and number of subjects.
#   threshold_levels        - values passed to rMASC as `theta` (the
#                             time-pressure manipulation).
#   sensitivity_levels      - values passed to rMASC as `alpha`.
#   n_options, n_attributes - size of each choice problem (n_options >= 2).
#   max_fixations           - passed to rMASC as `max_steps`.
#   seed                    - RNG seed set before drawing weights and stimuli
#                             (default 2025, as before); NULL leaves the
#                             caller's RNG state untouched.
# Returns a tibble with one row per (search_sensitivity, threshold) pair and
# subject-averaged columns payne_index, attr_variance, opt_variance,
# most_important_prop, choice_consistency, fixations and reward_rate.
simulate_time_pressure_effects <- function(
    n_trials = 50,                     # Number of trials per condition
    n_subjects = 25,                   # Number of simulated participants
    threshold_levels = seq(0.001, 0.2, length.out = 10), # Threshold parameters
    sensitivity_levels = seq(0, 3, length.out = 5),      # Search sensitivity levels
    n_options = 2,                     # Number of options
    n_attributes = 3,                  # Number of attributes
    max_fixations = 100,               # Maximum number of fixations
    seed = 2025                        # RNG seed (NULL = do not reseed)
) {
  if (!is.null(seed)) {
    set.seed(seed)
  }

  # Per-subject attribute weights: Beta(0.75, 0.75) draws normalised to sum
  # to 1. Drawn before the stimuli to keep the original RNG order.
  weights_list <- lapply(seq_len(n_subjects), function(s) {
    w <- rbeta(n_attributes, 0.75, 0.75)
    w / sum(w)
  })

  # Non-dominated stimuli per subject and trial, shared by all conditions.
  attr_values <- lapply(seq_len(n_subjects), function(s) {
    lapply(seq_len(n_trials), function(t) {
      generate_valid_attributes(n_options, n_attributes)
    })
  })

  # One row per trial. as.vector() flattens each n_options x n_attributes
  # matrix column-major (options vary fastest), so the column names must
  # follow that same order for each label to describe its value.
  # NOTE(review): confirm the column layout rMASC expects.
  data_names <- paste0(
    "opt", rep(seq_len(n_options), times = n_attributes),
    "_att", rep(seq_len(n_attributes), each = n_options)
  )
  # Built once here: the stimuli do not depend on the condition.
  subject_data <- lapply(attr_values, function(trials) {
    trial_df <- as.data.frame(do.call(rbind, lapply(trials, as.vector)))
    colnames(trial_df) <- data_names
    trial_df
  })

  # Names and order of the per-subject metrics returned below.
  metric_template <- c(
    payne_index = 0, attr_variance = 0, opt_variance = 0,
    most_important_prop = 0, choice_consistency = 0,
    avg_fixations = 0, reward_rate = 0
  )

  # Run rMASC for subject `s` in one condition; returns a named metric vector.
  simulate_subject <- function(s, search_sensitivity, threshold) {
    weights <- weights_list[[s]]
    most_important <- which.max(weights)

    result <- rMASC(
      data = subject_data[[s]],
      w = weights,
      sigma = 1,  # Fixed sampling noise
      alpha = search_sensitivity,
      delta = 0.01,  # Fixed threshold increase
      theta = threshold,
      max_steps = max_fixations
    )

    prop_fix_attr <- matrix(0, nrow = n_trials, ncol = n_attributes)
    prop_fix_opt <- matrix(0, nrow = n_trials, ncol = n_options)
    payne_indices <- rep(NA_real_, n_trials)
    for (t in seq_len(n_trials)) {
      trial <- .trial_fixation_metrics(
        result$raw[[t]]$fix_sequence, n_options, n_attributes
      )
      prop_fix_attr[t, ] <- trial$attr_prop
      prop_fix_opt[t, ] <- trial$opt_prop
      payne_indices[t] <- trial$payne
    }

    consistency <- mean(result$results$correct)
    avg_fixations <- mean(result$results$rt)
    c(
      payne_index = mean(payne_indices, na.rm = TRUE),
      attr_variance = mean(apply(prop_fix_attr, 1, var), na.rm = TRUE),
      opt_variance = mean(apply(prop_fix_opt, 1, var), na.rm = TRUE),
      most_important_prop = mean(prop_fix_attr[, most_important], na.rm = TRUE),
      choice_consistency = consistency,
      avg_fixations = avg_fixations,
      reward_rate = consistency / avg_fixations
    )
  }

  # Preallocate one result row per condition instead of growing a tibble.
  condition_rows <- vector(
    "list", length(sensitivity_levels) * length(threshold_levels)
  )
  row_idx <- 0L

  for (sensitivity_idx in seq_along(sensitivity_levels)) {
    search_sensitivity <- sensitivity_levels[sensitivity_idx]
    cat(sprintf("Processing search sensitivity %.2f (%d/%d)\n",
                search_sensitivity, sensitivity_idx, length(sensitivity_levels)))

    for (thresh_idx in seq_along(threshold_levels)) {
      threshold <- threshold_levels[thresh_idx]
      cat(sprintf("  Processing threshold %.3f (%d/%d)\n",
                  threshold, thresh_idx, length(threshold_levels)))

      # Rows = metrics (named as in metric_template), columns = subjects.
      per_subject <- vapply(
        seq_len(n_subjects), simulate_subject, metric_template,
        search_sensitivity = search_sensitivity, threshold = threshold
      )
      subject_means <- apply(per_subject, 1, mean, na.rm = TRUE)

      row_idx <- row_idx + 1L
      condition_rows[[row_idx]] <- tibble(
        search_sensitivity = search_sensitivity,
        threshold = threshold,
        payne_index = subject_means[["payne_index"]],
        attr_variance = subject_means[["attr_variance"]],
        opt_variance = subject_means[["opt_variance"]],
        most_important_prop = subject_means[["most_important_prop"]],
        choice_consistency = subject_means[["choice_consistency"]],
        fixations = subject_means[["avg_fixations"]],
        reward_rate = subject_means[["reward_rate"]]
      )
    }
  }

  bind_rows(condition_rows)
}

# Plot the time-pressure simulation results (reproduction of Figure 10).
#
# Arguments:
#   results - tibble returned by simulate_time_pressure_effects(); must contain
#             search_sensitivity, threshold, payne_index, attr_variance,
#             opt_variance, most_important_prop, choice_consistency,
#             fixations and reward_rate.
# Returns a list with two patchwork objects:
#   main_plot - panels A-D,
#   full_plot - panels A-G (adds choice consistency, fixations, reward rate).
plot_time_pressure_results <- function(results) {
  # Linear RGB gradient from light to dark green, one colour per level.
  # Works for any n >= 0 (n = 0 yields an empty palette).
  create_color_palette <- function(n) {
    start_color <- c(194, 218, 184) / 255
    end_color <- c(1, 50, 32) / 255
    rgb(
      seq(start_color[1], end_color[1], length.out = n),
      seq(start_color[2], end_color[2], length.out = n),
      seq(start_color[3], end_color[3], length.out = n)
    )
  }

  colors <- create_color_palette(length(unique(results$search_sensitivity)))

  # A factor gives one discrete colour per search-sensitivity level.
  plot_data <- results %>%
    mutate(search_sensitivity = factor(search_sensitivity))

  # Keep the original [-0.05, 0.25] x-range, but widen it when thresholds
  # fall outside so ggplot does not silently drop simulated points.
  x_limits <- range(c(-0.05, 0.25, results$threshold), na.rm = TRUE)

  # Create common theme
  theme_tp <- theme_classic() +
    theme(
      plot.title = element_text(size = 10, face = "bold", hjust = 0),
      axis.title = element_text(size = 9),
      legend.position = "bottom",
      legend.title = element_text(size = 9),
      legend.text = element_text(size = 8),
      panel.grid.minor = element_blank()
    )

  # One panel: column `y_var` against threshold, coloured by sensitivity.
  make_panel <- function(y_var, title, y_label) {
    ggplot(
      plot_data,
      aes(x = threshold, y = .data[[y_var]], color = search_sensitivity)
    ) +
      geom_line() +
      geom_point() +
      scale_color_manual(values = colors, name = "Search Sensitivity") +
      labs(title = title, x = "Threshold θ", y = y_label) +
      theme_tp +
      xlim(x_limits)
  }

  p1 <- make_panel("payne_index", "A", "Payne Index")
  p2 <- make_panel("attr_variance", "B", "Attribute Variance")
  p3 <- make_panel("opt_variance", "C", "Option Variance")
  p4 <- make_panel("most_important_prop", "D", "p(Fix = Most Important)")
  p5 <- make_panel("choice_consistency", "E", "Choice Consistency")
  p6 <- make_panel("fixations", "F", "Number of Fixations")
  p7 <- make_panel("reward_rate", "G", "Reward Rate")

  figure_annotation <- plot_annotation(
    title = "Simulating Time Pressure with Varying Levels of Search Sensitivity",
    theme = theme(plot.title = element_text(size = 14, face = "bold"))
  )

  # Main four panels, with a single legend collected at the bottom.
  main_plot <- (p1 + p2) / (p3 + p4) +
    plot_layout(guides = "collect") &
    theme(legend.position = "bottom")

  # All seven panels; an empty spacer fills the eighth slot.
  full_plot <- (p1 + p2 + p3 + p4) / (p5 + p6 + p7 + plot_spacer()) +
    plot_layout(guides = "collect") &
    theme(legend.position = "bottom")

  list(
    main_plot = main_plot + figure_annotation,
    full_plot = full_plot + figure_annotation
  )
}

## -----------------------------------------------------------------------------
# Run the simulation (this can take a while). The grid of trials, subjects,
# thresholds and sensitivities is much smaller than the function defaults.
set.seed(2025)
results <- simulate_time_pressure_effects(
  n_trials = 15, n_subjects = 6,
  threshold_levels = seq(from = 0.001, to = 0.2, length.out = 4),
  sensitivity_levels = seq(from = 0, to = 3, length.out = 3)
)

## ----fig.width=12, fig.height=8, out.width="100%"-----------------------------
# Build both figure layouts from the simulation output
plots <- plot_time_pressure_results(results)

# Four-panel layout (panels A-D, as in Figure 10)
print(plots[["main_plot"]])

# Extended seven-panel layout (panels A-G)
print(plots[["full_plot"]])

