cjerzak commited on
Commit
2c61538
·
verified ·
1 Parent(s): b22d438

Update app.R

Browse files
Files changed (1) hide show
  1. app.R +244 -43
app.R CHANGED
@@ -1,58 +1,259 @@
 
 
1
  library(shiny)
2
- library(bslib)
3
- library(dplyr)
4
  library(ggplot2)
 
 
5
 
6
- df <- readr::read_csv("penguins.csv")
7
- # Find subset of columns that are suitable for scatter plot
8
- df_num <- df |> select(where(is.numeric), -Year)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- ui <- page_sidebar(
11
- theme = bs_theme(bootswatch = "minty"),
12
- title = "Penguins explorer",
13
- sidebar = sidebar(
14
- varSelectInput("xvar", "X variable", df_num, selected = "Bill Length (mm)"),
15
- varSelectInput("yvar", "Y variable", df_num, selected = "Bill Depth (mm)"),
16
- checkboxGroupInput("species", "Filter by species",
17
- choices = unique(df$Species), selected = unique(df$Species)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  ),
19
- hr(), # Add a horizontal rule
20
- checkboxInput("by_species", "Show species", TRUE),
21
- checkboxInput("show_margins", "Show marginal plots", TRUE),
22
- checkboxInput("smooth", "Add smoother"),
23
- ),
24
- plotOutput("scatter")
 
 
 
 
 
 
 
 
 
 
 
25
  )
26
 
 
27
  server <- function(input, output, session) {
28
- subsetted <- reactive({
29
- req(input$species)
30
- df |> filter(Species %in% input$species)
 
 
 
 
 
 
 
 
 
31
  })
32
-
33
- output$scatter <- renderPlot(
34
- {
35
- p <- ggplot(subsetted(), aes(!!input$xvar, !!input$yvar)) +
36
- theme_light() +
37
- list(
38
- theme(legend.position = "bottom"),
39
- if (input$by_species) aes(color = Species),
40
- geom_point(),
41
- if (input$smooth) geom_smooth()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  )
43
-
44
- if (input$show_margins) {
45
- margin_type <- if (input$by_species) "density" else "histogram"
46
- p <- p |> ggExtra::ggMarginal(
47
- type = margin_type, margins = "both",
48
- size = 8, groupColour = input$by_species, groupFill = input$by_species
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  )
 
 
 
 
 
 
 
 
 
 
50
  }
51
-
52
- p
53
- },
54
- res = 100
55
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  }
57
 
 
58
  shinyApp(ui, server)
 
1
+ # setwd("~/Dropbox/OptimizingSI/Analysis/ono")
2
+
3
  library(shiny)
 
 
4
  library(ggplot2)
5
+ library(strategize)
6
+ library(dplyr)
7
 
8
+ # Custom plotting function for optimal strategy distributions
9
+ plot_factor <- function(pi_star_list, pi_star_se_list, factor_name, zStar = 1.96) {
10
+ probs <- lapply(pi_star_list, function(x) x[[factor_name]])
11
+ ses <- lapply(pi_star_se_list, function(x) x[[factor_name]])
12
+ levels <- names(probs[[1]])
13
+ n_strategies <- length(probs)
14
+
15
+ # Create data frame for plotting
16
+ df <- do.call(rbind, lapply(1:n_strategies, function(i) {
17
+ data.frame(
18
+ Strategy = if (n_strategies == 1) "Optimal" else c("Democrat", "Republican")[i],
19
+ Level = levels,
20
+ Probability = probs[[i]],
21
+ SE = ses[[i]]
22
+ )
23
+ }))
24
+
25
+ # Plot with ggplot2
26
+ p <- ggplot(df, aes(x = Level, y = Probability, fill = Strategy)) +
27
+ geom_bar(stat = "identity", position = position_dodge(width = 0.9), width = 0.8) +
28
+ geom_errorbar(aes(ymin = Probability - zStar * SE, ymax = Probability + zStar * SE),
29
+ position = position_dodge(width = 0.9), width = 0.25) +
30
+ labs(title = paste("Optimal Distribution for", factor_name),
31
+ x = "Level", y = "Probability") +
32
+ theme_minimal() +
33
+ theme(axis.text.x = element_text(angle = 45, hjust = 1),
34
+ legend.position = "top") +
35
+ scale_fill_manual(values = c("Democrat" = "#89cff0", "Republican" = "red", "Optimal" = "black"))
36
+
37
+ return(p)
38
+ }
39
 
40
+ # UI Definition
41
+ ui <- fluidPage(
42
+ titlePanel("Exploring strategize with the candidate choice conjoint data"),
43
+
44
+ sidebarLayout(
45
+ sidebarPanel(
46
+ h4("Analysis Options"),
47
+ radioButtons("case_type", "Case Type:",
48
+ choices = c("Average", "Adversarial"),
49
+ selected = "Average"),
50
+ conditionalPanel(
51
+ condition = "input.case_type == 'Average'",
52
+ selectInput("respondent_group", "Respondent Group:",
53
+ choices = c("All", "Democrat", "Independent", "Republican"),
54
+ selected = "All")
55
+ ),
56
+ # Add a single numeric input for lambda
57
+ numericInput("lambda_input", "Lambda (regularization):",
58
+ value = 0.01, min = 1e-6, max = 10, step = 0.01),
59
+ actionButton("compute", "Compute Results", class = "btn-primary"),
60
+ hr(),
61
+ h4("Visualization"),
62
+ selectInput("factor", "Select Factor to Display:",
63
+ choices = NULL),
64
+ hr(),
65
+ h5("Instructions:"),
66
+ p("1. Select a case type and, for Average case, a respondent group."),
67
+ p("2. Specify the single lambda to be used by strategize."),
68
+ p("3. Click 'Compute Results' to generate optimal strategies."),
69
+ p("4. Choose a factor to view its distribution.")
70
  ),
71
+
72
+ mainPanel(
73
+ tabsetPanel(
74
+ tabPanel("Optimal Strategy Plot",
75
+ plotOutput("strategy_plot", height = "600px")),
76
+ tabPanel("Q Value",
77
+ verbatimTextOutput("q_value"),
78
+ p("Q represents the estimated outcome (e.g., selection probability) under the optimal strategy, with 95% confidence interval.")),
79
+ tabPanel("About",
80
+ h3("About This App"),
81
+ p("This Shiny app explores the `strategize` package using Ono experimental data. It computes optimal strategies for Average (optimizing for a respondent group) and Adversarial (optimizing for both parties in competition) cases on the fly."),
82
+ p("**Average Case**: Optimizes candidate characteristics for a selected respondent group."),
83
+ p("**Adversarial Case**: Finds equilibrium strategies for Democrats and Republicans, identified by 'Pro-life' stance.")
84
+ )
85
+ )
86
+ )
87
+ )
88
  )
89
 
90
+ # Server Definition
91
  server <- function(input, output, session) {
92
+ # Load data
93
+ load("Processed_OnoData.RData")
94
+ Primary2016 <- read.csv("PrimaryCandidates2016 - Sheet1.csv")
95
+
96
+ # Update factor choices dynamically
97
+ observe({
98
+ if (input$case_type == "Average") {
99
+ factors <- colnames(FACTOR_MAT_FULL)[!colnames(FACTOR_MAT_FULL) %in% c("Office")]
100
+ } else {
101
+ factors <- colnames(FACTOR_MAT_FULL)[!colnames(FACTOR_MAT_FULL) %in% c("Office", "Party.affiliation", "Party.competition")]
102
+ }
103
+ updateSelectInput(session, "factor", choices = factors, selected = factors[1])
104
  })
105
+
106
+ # Reactive computation triggered by button
107
+ result <- eventReactive(input$compute, {
108
+ withProgress(message = "Computing optimal strategies...", value = 0, {
109
+ # Increment progress
110
+ incProgress(0.2, detail = "Preparing data...")
111
+
112
+ # Common hyperparameters (mirroring QRun_Apps.R)
113
+ params <- list(
114
+ nSGD = 1000L,
115
+ batch_size = 50L,
116
+ penalty_type = "KL",
117
+ nFolds = 3L,
118
+ use_optax = TRUE,
119
+ compute_se = FALSE, # Set to FALSE for quicker results
120
+ conf_level = 0.95,
121
+ conda_env = "strategize",
122
+ conda_env_required = TRUE
123
+ )
124
+
125
+ # Grab the single user-chosen lambda
126
+ my_lambda <- input$lambda_input
127
+
128
+ if (input$case_type == "Average") {
129
+ # Subset data for Average case
130
+ if (input$respondent_group == "All") {
131
+ indices <- 1:nrow(my_data_FULL)
132
+ } else {
133
+ indices <- which(my_data_FULL$R_Partisanship == input$respondent_group)
134
+ }
135
+
136
+ FACTOR_MAT <- FACTOR_MAT_FULL[indices, !colnames(FACTOR_MAT_FULL) %in% "Office"]
137
+ Yobs <- Yobs_FULL[indices]
138
+ X <- X_FULL[indices, ]
139
+ log_pr_w <- log_pr_w_FULL[indices]
140
+ assignmentProbList <- assignmentProbList_FULL[!names(assignmentProbList_FULL) %in% "Office"]
141
+
142
+ incProgress(0.4, detail = "Running strategize...")
143
+
144
+ # Compute with strategize using a single lambda
145
+ Qoptimized <- strategize(
146
+ Y = Yobs,
147
+ W = FACTOR_MAT,
148
+ X = X,
149
+ p_list = assignmentProbList,
150
+ lambda = my_lambda,
151
+ adversarial = FALSE,
152
+ K = 1L, # Base analysis
153
+ nSGD = params$nSGD,
154
+ penalty_type = params$penalty_type,
155
+ folds = params$nFolds,
156
+ use_optax = params$use_optax,
157
+ compute_se = params$compute_se,
158
+ conf_level = params$conf_level,
159
+ conda_env = params$conda_env,
160
+ conda_env_required = params$conda_env_required
161
  )
162
+ } else { # Adversarial case
163
+ # Use full data, drop specific factors
164
+ DROP_FACTORS <- c("Office", "Party.affiliation", "Party.competition")
165
+ FACTOR_MAT <- FACTOR_MAT_FULL[, !colnames(FACTOR_MAT_FULL) %in% DROP_FACTORS]
166
+ Yobs <- Yobs_FULL
167
+ X <- X_FULL
168
+ log_pr_w <- log_pr_w_FULL
169
+ assignmentProbList <- assignmentProbList_FULL[!names(assignmentProbList_FULL) %in% DROP_FACTORS]
170
+
171
+ # Prepare slate_list (simplified from QRun_Apps.R)
172
+ incProgress(0.3, detail = "Preparing slate data...")
173
+ FactorOptions <- apply(FACTOR_MAT, 2, table)
174
+ prior_alpha <- 10
175
+ Primary_D <- Primary2016[Primary2016$Party == "Democratic", colnames(FACTOR_MAT)]
176
+ Primary_R <- Primary2016[Primary2016$Party == "Republican", colnames(FACTOR_MAT)]
177
+
178
+ Primary_D_slate <- lapply(colnames(Primary_D), function(col) {
179
+ posterior_alpha <- FactorOptions[[col]]; posterior_alpha[] <- prior_alpha
180
+ Empirical_ <- table(Primary_D[[col]])
181
+ Empirical_ <- Empirical_[names(Empirical_) != "Unclear"]
182
+ posterior_alpha[names(Empirical_)] <- posterior_alpha[names(Empirical_)] + Empirical_
183
+ prop.table(posterior_alpha)
184
+ })
185
+ names(Primary_D_slate) <- colnames(Primary_D)
186
+
187
+ Primary_R_slate <- lapply(colnames(Primary_R), function(col) {
188
+ posterior_alpha <- FactorOptions[[col]]; posterior_alpha[] <- prior_alpha
189
+ Empirical_ <- table(Primary_R[[col]])
190
+ Empirical_ <- Empirical_[names(Empirical_) != "Unclear"]
191
+ posterior_alpha[names(Empirical_)] <- posterior_alpha[names(Empirical_)] + Empirical_
192
+ prop.table(posterior_alpha)
193
+ })
194
+ names(Primary_R_slate) <- colnames(Primary_R)
195
+
196
+ slate_list <- list("Democratic" = Primary_D_slate, "Republican" = Primary_R_slate)
197
+
198
+ incProgress(0.4, detail = "Running strategize...")
199
+
200
+ # Compute with strategize using a single lambda
201
+ Qoptimized <- strategize(
202
+ Y = Yobs,
203
+ W = FACTOR_MAT,
204
+ X = X,
205
+ p_list = assignmentProbList,
206
+ slate_list = slate_list,
207
+ competing_group_variable_respondent = my_data_FULL$R_Partisanship,
208
+ competing_group_variable_candidate = ifelse(my_data_FULL$Party.affiliation == "Republican Party", "Republican",
209
+ ifelse(my_data_FULL$Party.affiliation == "Democratic Party", "Democrat", "Independent")),
210
+ lambda = my_lambda,
211
+ adversarial = TRUE,
212
+ K = 1L,
213
+ nMonte_adversarial = 100L,
214
+ nSGD = params$nSGD,
215
+ penalty_type = params$penalty_type,
216
+ folds = params$nFolds,
217
+ use_optax = params$use_optax,
218
+ compute_se = params$compute_se,
219
+ conf_level = params$conf_level,
220
+ conda_env = params$conda_env,
221
+ conda_env_required = params$conda_env_required
222
  )
223
+
224
+ # Identify Democrat vs Republican based on "Pro-life" stance
225
+ prolife_probs <- c(Qoptimized$pi_star_point$k1$Position.on.abortion["Pro-life"],
226
+ Qoptimized$pi_star_point$k2$Position.on.abortion["Pro-life"])
227
+ which_repub <- which.max(prolife_probs)
228
+ if (which_repub == 1) {
229
+ # Swap
230
+ Qoptimized$pi_star_point <- list(k1 = Qoptimized$pi_star_point$k2, k2 = Qoptimized$pi_star_point$k1)
231
+ Qoptimized$pi_star_se <- list(k1 = Qoptimized$pi_star_se$k2, k2 = Qoptimized$pi_star_se$k1)
232
+ }
233
  }
234
+
235
+ incProgress(0.8, detail = "Finalizing results...")
236
+ return(Qoptimized)
237
+ })
238
+ })
239
+
240
+ # Render strategy plot
241
+ output$strategy_plot <- renderPlot({
242
+ req(result())
243
+ factor_name <- input$factor
244
+ pi_star_list <- result()$pi_star_point
245
+ pi_star_se_list <- result()$pi_star_se
246
+ plot_factor(pi_star_list, pi_star_se_list, factor_name)
247
+ })
248
+
249
+ # Render Q value
250
+ output$q_value <- renderText({
251
+ req(result())
252
+ q_point <- result()$Q_point_mEst
253
+ q_se <- result()$Q_se_mEst
254
+ paste("Estimated Q Value: ", sprintf("%.3f ± %.3f", q_point, 1.96 * q_se))
255
+ })
256
  }
257
 
258
+ # Run the app
259
  shinyApp(ui, server)