Milad Alshomary committed on
Commit
3d73c8d
Β·
1 Parent(s): b96061f
Dockerfile ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
# you will also find guides on how best to write your Dockerfile

# gradio==5.30 (pinned in requirements.txt) requires Python >= 3.10,
# so a python:3.9 base image would fail during `pip install`.
FROM python:3.10

# Run as a non-root user (HF Spaces convention: uid 1000).
RUN useradd -m -u 1000 user
USER user
ENV PATH="/home/user/.local/bin:$PATH"

WORKDIR /app

# Copy requirements first so dependency install is cached across code changes.
COPY --chown=user ./requirements.txt requirements.txt
RUN pip install --no-cache-dir --upgrade -r requirements.txt

COPY --chown=user . /app
CMD ["python", "app.py"]
add_hf_env_to_hf_space.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
"""One-off helper: copy the local OPENAI_API_KEY into the HF Space's variables."""
from huggingface_hub import HfApi
import os

repo_id = "miladalsh/explaining_authorship_attribution_models"
api = HfApi()

# Fail early with a clear message instead of a bare KeyError when the
# variable is missing from the local environment.
openai_key = os.environ.get("OPENAI_API_KEY")
if not openai_key:
    raise RuntimeError("OPENAI_API_KEY is not set in the local environment.")

api.add_space_variable(repo_id=repo_id, key="OPENAI_API_KEY", value=openai_key)
app.py ADDED
@@ -0,0 +1,516 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import json

import os

# Gradio stages uploaded/temporary files here; create the directory up front
# so the first upload cannot fail on a missing path.
os.environ["GRADIO_TEMP_DIR"] = "./datasets/temp"  # Set a custom temp directory for Gradio
os.makedirs(os.environ["GRADIO_TEMP_DIR"], exist_ok=True)

import yaml
import argparse
import urllib.request
from tqdm import tqdm

from dotenv import load_dotenv
from openai import OpenAI
from utils.file_download import download_file_override
20
def load_config(path="config/config.yaml"):
    """Load the YAML configuration file.

    Args:
        path: Path to the YAML config file (defaults to config/config.yaml).

    Returns:
        The parsed configuration as a dict (None if the file is empty).
    """
    # Explicit encoding avoids platform-dependent defaults (e.g. cp1252 on Windows).
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)
23
+
24
# Load the config once at import time; every helper below reads from it.
cfg = load_config()


# Fetch the data artifacts before importing utils.* — presumably the utils
# modules read these files at import time (TODO confirm), so order matters.
download_file_override(cfg.get('interp_space_url'), cfg.get('interp_space_path'))
download_file_override(cfg.get('instances_to_explain_url'), cfg.get('instances_to_explain_path'))
download_file_override(cfg.get('gram2vec_feats_url'), cfg.get('gram2vec_feats_path'))

# NOTE(review): star imports make it hard to tell where names like
# get_instances / styled_block / toggle_task come from — consider explicit imports.
from utils.visualizations import *
from utils.llm_feat_utils import *
from utils.gram2vec_feat_utils import *
from utils.interp_space_utils import *
from utils.ui import *

# Read OPENAI_API_KEY from a local .env (if present) and build the client.
load_dotenv()
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))


# ── load once at startup ────────────────────────────────────────
GRAM2VEC_SHORTHAND = load_code_map()
43
+
44
def validate_ground_truth(gt1, gt2, gt3):
    """Check that exactly one candidate is flagged as the ground-truth author.

    Args:
        gt1, gt2, gt3: Checkbox states for candidates 1-3.

    Returns:
        (index, message): the 0-based index of the chosen candidate and a
        status message, or (None, message) when zero or several are chosen.
    """
    chosen = [i for i, flag in enumerate((gt1, gt2, gt3)) if flag]

    # Guard clauses: reject ambiguous or empty selections.
    if len(chosen) > 1:
        return None, "Please select only one ground truth author."
    if not chosen:
        return None, "No ground truth author selected."

    idx = chosen[0]
    return idx, f"Candidate {idx+1} is marked as the ground truth author."
55
+
56
+
57
def app(share=False, use_cluster_feats=False):
    """Build and launch the Gradio UI of the AA explainability tool.

    Args:
        share: Forwarded to ``demo.launch`` (creates a public link when True).
        use_cluster_feats: When True, style features are chosen per
            precomputed cluster via a dropdown; otherwise they are derived
            from the plot region the user zooms into.
    """
    instances, instance_ids = get_instances(cfg['instances_to_explain_path'])

    interp = load_interp_space(cfg)
    # Cap the background corpus to keep embedding/plotting latency manageable.
    clustered_authors_df = interp['clustered_authors_df'][:1000]
    clustered_authors_df['fullText'] = clustered_authors_df['fullText'].map(lambda l: l[:3])  # Take at most 3 texts per author

    with gr.Blocks(title="Author Attribution Explainability Tool") as demo:
        # ── Big Centered Title ──────────────────────────────────────────
        gr.HTML(styled_block("""
            <h1 style="
                text-align:center;
                font-size:3em;    /* About 48px */
                margin-bottom:0.3em;
                font-weight:700;
            ">
                Author Attribution (AA) Explainability Tool
            </h1>
        """))

        gr.HTML(styled_block("""
            <div style="
                text-align:center;
                margin: 1em auto 2em auto;
                max-width:900px;
            ">
                <p style="font-size:1.3em; line-height:1.4;">
                    This demo helps you <strong>see inside</strong> a deep AA model's latent style space.
                </p>
                <p style="font-size:0.9em; line-height:1.4;">
                    Currently you are inspecting <a href="https://huggingface.co/rrivera1849/LUAR-MUD">LUAR</a> with pre-defined AA tasks from the <a href="https://www.iarpa.gov/images/research-programs/HIATUS/IARPA_HIATUS_Phase_1_HRS_Data.to_DAC_20240610.pdf">HRS dataset </a>
                </p>
                <div style="
                    display:flex;
                    justify-content:center;
                    gap:3em;
                    margin-top:1em;
                ">
                    <!-- Visualize -->
                    <div style="max-width:200px;">
                        <div style="font-size:2em;">🔍</div>
                        <h4 style="margin:0.2em 0;">Visualize</h4>
                        <p style="margin:0; font-size:1em; line-height:1.3;">
                            Place your AA task with respect to other background authors.
                        </p>
                    </div>
                    <!-- GENERATE -->
                    <div style="max-width:200px;">
                        <div style="font-size:2em;">✏️</div>
                        <h4 style="margin:0.2em 0;">Generate</h4>
                        <p style="margin:0; font-size:1em; line-height:1.3;">
                            Describe your investigated authors' writing style via human-readable LLM-generated style features.
                        </p>
                    </div>
                    <!-- COMPARE -->
                    <div style="max-width:200px;">
                        <div style="font-size:2em;">⚖️</div>
                        <h4 style="margin:0.2em 0;">Compare</h4>
                        <p style="margin:0; font-size:1em; line-height:1.3;">
                            Contrast with <a href="https://github.com/eric-sclafani/gram2vec">Gram2Vec</a> stylometric features.
                        </p>
                    </div>
                </div>
            </div>
        """))

        # ── Step-by-Step Guided Panel ──
        with gr.Accordion("📝 How to Use", open=True):
            gr.Markdown("""
                1. **Select** a model and a task source (pre-defined or custom)
                2. Click **Load Task & Generate Embeddings** to load the task and generate embeddings
                3. **Run Visualization** to see the mystery author and candidates in the AA model's latent space
                4. **Zoom** into the visualization to select a cluster of background authors
                5. Pick an **LLM feature** to highlight in yellow
                6. Pick a **Gram2Vec feature** to highlight in blue
                7. Click **Show Combined Spans** to compare side-by-side
                """
            )

        # ── Model Selection ─────────────────────────────────
        model_radio = gr.Radio(
            choices=[
                'gabrielloiseau/LUAR-MUD-sentence-transformers',
                'gabrielloiseau/LUAR-CRUD-sentence-transformers',
                'miladalsh/light-luar',
                'AnnaWegmann/Style-Embedding',
                'Other'
            ],
            value='gabrielloiseau/LUAR-MUD-sentence-transformers',
            label='Choose a Model to inspect'
        )
        print(f"Model choices: {model_radio.choices}")
        print(f"Model default: {model_radio.value}")
        custom_model = gr.Textbox(
            label='Custom Model ID',
            placeholder='Enter your Hugging Face Model ID here',
            visible=False,
            interactive=True
        )
        # Show the textbox when 'Other' is selected
        model_radio.change(
            fn=toggle_custom_model,
            inputs=[model_radio],
            outputs=[custom_model]
        )

        # ── Task Source Selection ─────────────────────────────────
        task_mode = gr.Radio(
            choices=["Predefined HRS Task", "Upload Your Own Task"],
            value="Predefined HRS Task",
            label="Select Task Source"
        )

        ground_truth_author = gr.State()  # To store the index of the ground truth author

        with gr.Column():
            with gr.Column(visible=True) as predefined_container:
                gr.HTML("""
                    <div style="
                        font-size: 1.3em;
                        font-weight: 600;
                        margin-bottom: 0.5em;
                    ">
                        Pick a pre-defined task to investigate (a mystery text and its three candidate authors)
                    </div>
                """)
                task_dropdown = gr.Dropdown(
                    choices=[f"Task {i}" for i in instance_ids],
                    value=f"Task {instance_ids[0]}",
                    label="Choose which mystery document to explain",
                )
            with gr.Column(visible=False) as custom_container:
                gr.HTML("""
                    <div style="
                        font-size: 1.3em;
                        font-weight: 600;
                        margin-bottom: 0.5em;
                    ">
                        Upload your own task
                    </div>
                """)
                mystery_input = gr.File(label="Mystery (.txt)", file_types=['.txt'])
                with gr.Row():
                    candidate1 = gr.File(label="Candidate 1 (.txt)", file_types=['.txt'])
                    gt1_checkbox = gr.Checkbox(label="Ground Truth?", value=False)

                with gr.Row():
                    candidate2 = gr.File(label="Candidate 2 (.txt)", file_types=['.txt'])
                    gt2_checkbox = gr.Checkbox(label="Ground Truth?", value=False)

                with gr.Row():
                    candidate3 = gr.File(label="Candidate 3 (.txt)", file_types=['.txt'])
                    gt3_checkbox = gr.Checkbox(label="Ground Truth?", value=False)

                validation_msg = gr.Textbox(label="Validation Result", interactive=False)

                # Any checkbox toggle re-validates that exactly one GT author is picked.
                for checkbox in [gt1_checkbox, gt2_checkbox, gt3_checkbox]:
                    checkbox.change(
                        fn=validate_ground_truth,
                        inputs=[gt1_checkbox, gt2_checkbox, gt3_checkbox],
                        outputs=[ground_truth_author, validation_msg]
                    )

        # ── Load Task Button ─────────────────────────────────────
        gr.HTML(instruction_callout("Click the button below to load the tasks and generate embeddings using selected model."))
        load_button = gr.Button("Load Task & Generate Embeddings")

        # ── HTML outputs for author texts ───────────────────────────
        default_outputs = load_instance(0, instances)
        # dont need defaults since they are loaded only on click of the load button
        header = gr.HTML()
        mystery = gr.HTML()
        mystery_state = gr.State()  # Store unformatted mystery text for later use
        with gr.Row():
            c0 = gr.HTML()
            c1 = gr.HTML()
            c2 = gr.HTML()
        c0_state = gr.State()  # Store unformatted candidate 1 text for later use
        c1_state = gr.State()  # Store unformatted candidate 2 text for later use
        c2_state = gr.State()  # Store unformatted candidate 3 text for later use
        # ── State to hold embeddings DataFrame ─────────────────────
        task_authors_embeddings_df = gr.State()  # Store embeddings of task authors
        background_authors_embeddings_df = gr.State()  # Store background authors DataFrame
        task_mode.change(
            fn=toggle_task,
            inputs=[task_mode],
            outputs=[predefined_container, custom_container]
        )
        # ── Wire call to load task and generate embeddings once load button is clicked ───────────────────
        predicted_author = gr.State()  # Store predicted author from the embeddings
        load_button.click(
            # Disable the button while work is in flight, run, then re-enable.
            fn=lambda: gr.update(value="⏳ Loading... Please wait", interactive=False),
            inputs=[],
            outputs=[load_button]
        ).then(
            fn=lambda mode, dropdown, mystery, c1, c2, c3, ground_truth_author, model_radio, custom_model_input:
                update_task_display(
                    mode,
                    dropdown,
                    instances,  # closed over
                    clustered_authors_df,
                    mystery,
                    c1,
                    c2,
                    c3,
                    ground_truth_author,  # true_author placeholder
                    model_radio,
                    custom_model_input
                ),
            inputs=[task_mode, task_dropdown, mystery_input, candidate1, candidate2, candidate3, ground_truth_author, model_radio, custom_model],
            outputs=[header, mystery, c0, c1, c2, mystery_state, c0_state, c1_state, c2_state, task_authors_embeddings_df, background_authors_embeddings_df, predicted_author, ground_truth_author]  # embeddings_df is a placeholder for now
        ).then(
            fn=lambda: gr.update(value="Load Task & Generate Embeddings", interactive=True),
            inputs=[],
            outputs=[load_button]
        )

        # ── Visualization for features ─────────────────────────────
        gr.HTML(instruction_callout("Run visualization to see which author is similar to the mystery document."))
        run_btn = gr.Button("Run visualization")
        bg_proj_state = gr.State()
        bg_lbls_state = gr.State()
        bg_authors_df = gr.State()  # Holds the background authors DataFrame
        with gr.Row():
            with gr.Column(scale=3):
                # Hidden textbox: the JS zoom listener below writes axis ranges here,
                # which triggers the axis_ranges.change handler further down.
                axis_ranges = gr.Textbox(visible=False, elem_id="axis-ranges")
                plot = gr.Plot(
                    label="Visualization",
                    elem_id="feature-plot",
                )
                plot.change(
                    fn=None,
                    inputs=[plot],
                    outputs=[axis_ranges],
                    js="""
                    function(){
                        console.log("------------>[JS] plot.change() triggered<------------");

                        let attempts = 0;
                        const maxAttempts = 50;

                        const tryAttach = () => {
                            const gd = document.querySelector('#feature-plot .js-plotly-plot');
                            if (!gd) {
                                if (++attempts < maxAttempts) {
                                    requestAnimationFrame(tryAttach);
                                } else {
                                    console.error(" ------------>Could not find .js-plotly-plot after multiple attempts.<------------");
                                }
                                return;
                            }

                            if (gd.__zoomListenerAttached) {
                                console.log("------------>Zoom listener already attached.<------------");
                                return;
                            }

                            gd.__zoomListenerAttached = true;
                            console.log("------------>Zoom listener attached!<------------");

                            gd.on('plotly_relayout', (ev) => {
                                if (
                                    ev['xaxis.range[0]'] === undefined ||
                                    ev['xaxis.range[1]'] === undefined ||
                                    ev['yaxis.range[0]'] === undefined ||
                                    ev['yaxis.range[1]'] === undefined
                                ) return;

                                const payload = {
                                    xaxis: [ev['xaxis.range[0]'], ev['xaxis.range[1]']],
                                    yaxis: [ev['yaxis.range[0]'], ev['yaxis.range[1]']]
                                };

                                const txtbox = document.querySelector('#axis-ranges textarea');
                                if (txtbox) {
                                    txtbox.value = JSON.stringify(payload);
                                    txtbox.dispatchEvent(new Event('input', { bubbles: true }));
                                    console.log("------------> Zoom payload dispatched:<------------", payload);
                                } else {
                                    console.warn("------------> No hidden textbox found to write zoom payload.<------------");
                                }
                            });
                        };

                        requestAnimationFrame(tryAttach);
                        return '';
                    }
                    """
                )

            with gr.Column(scale=1):
                expl_html = """
                    <h4>What am I looking at?</h4>
                    <p>
                        This plot shows the mystery author (★) and three candidate authors (◆)
                        in the AA model's latent space.<br>
                        The grey ● symbols represent the background corpus—real authors with diverse writing styles.
                        You can zoom in on any region of the plot. The system will analyze the visible authors
                        in that area and list the most representative writing style features for the zoomed-in region.<br>
                        Use this to compare your mystery text's position against nearby writing styles and
                        investigate which features distinguish it from others.
                    </p>
                """
                gr.HTML(styled_html(expl_html))

        cluster_dropdown = gr.Dropdown(choices=[], label="Select Cluster to Inspect", visible=False)
        style_map_state = gr.State()
        llm_style_feats_analysis = gr.State()
        visible_zoomed_authors = gr.State()

        if use_cluster_feats:
            # ── Dynamic Cluster Choice dropdown ──────────────────────────────────
            gr.HTML(instruction_callout("Choose a cluster from the dropdown below to inspect whether its features appear in the mystery author's text."))
            cluster_dropdown.visible = True
        else:
            gr.HTML(instruction_callout("Zoom in on the plot to select a set of background authors and see the presence of the top features from this set in candidate and mystery authors."))

        with gr.Row():
            # ── LLM Features Column ──────────────────────────────────
            with gr.Column(scale=1, min_width=0):
                # gr.Markdown("**Features from the cluster closest to the Mystery Author**")
                gr.HTML("""
                    <div style="
                        font-size: 1.3em;
                        font-weight: 600;
                        margin-bottom: 0.5em;
                    ">
                        LLM-derived style features prominent in the zoomed-in region
                    </div>
                """)
                features_rb = gr.Radio(choices=[], label="LLM-derived style features for this zoomed-in region")#, label="Features from the cluster closest to the Mystery Author", info="LLM-derived style features for this cluster")
                feature_list_state = gr.State()

            # ── Gram2Vec Features Column ─────────────────────────────
            with gr.Column(scale=1, min_width=0):
                # gr.Markdown("**Top-10 Gram2Vec Features most likely to occur in Mystery Author**")
                gr.HTML("""
                    <div style="
                        font-size: 1.3em;
                        font-weight: 600;
                        margin-bottom: 0.5em;
                    ">
                        Gram2Vec Features prominent in the zoomed-in region
                    </div>
                """)
                gram2vec_rb = gr.Radio(choices=[], label="Gram2Vec features for this zoomed-in region")#, label="Top-10 Gram2Vec Features most likely to occur in Mystery Author", info="Most prominent Gram2vec features in the mystery text")
                gram2vec_state = gr.State()

        # ── Visualization button click ───────────────────────────────
        run_btn.click(
            fn=lambda iid, model_radio, custom_model_input, task_authors_embeddings_df, background_authors_embeddings_df, predicted_author, ground_truth_author: visualize_clusters_plotly(
                int(iid.replace('Task ','')), cfg, instances, model_radio,
                custom_model_input, task_authors_embeddings_df, background_authors_embeddings_df, predicted_author, ground_truth_author
            ),
            inputs=[task_dropdown, model_radio, custom_model, task_authors_embeddings_df, background_authors_embeddings_df, predicted_author, ground_truth_author],
            outputs=[plot, style_map_state, bg_proj_state, bg_lbls_state, bg_authors_df]
        )

        # Populate feature list based on selection.
        if use_cluster_feats:
            # Use cluster-based flow
            cluster_dropdown.change(
                fn=on_cluster_change,
                inputs=[cluster_dropdown, style_map_state],
                outputs=[features_rb, gram2vec_rb , feature_list_state]
                # adding feature_list_state to persisit all llm features in the app state
            )
        else:
            axis_ranges.change(
                fn=handle_zoom_with_retries,
                inputs=[axis_ranges, bg_proj_state, bg_lbls_state, bg_authors_df, task_authors_embeddings_df],
                outputs=[features_rb, gram2vec_rb , llm_style_feats_analysis, feature_list_state, visible_zoomed_authors]
            )

        # ── Show combined feature-span highlights ──
        # combined callout + legend in one HTML block
        gr.HTML(
            instruction_callout(
                "Click \"Show Combined Spans\" to highlight the LLM (yellow) & Gram2Vec (blue) feature spans in the texts"
            )
            + """
            <div style="
                display: flex;
                align-items: center;
                justify-content: center;
                gap: 2em;
                margin-top: 0.5em;
                font-size: 0.9em;
            ">
                <div style="display: flex; align-items: center; gap: 0.5em; font-weight: 600; font-size: 1.5em;">
                    <span style="
                        display: inline-block;
                        width: 1.5em; height: 1.5em;
                        background: #FFEB3B;  /* bright yellow */
                        border: 1px solid #666;
                        vertical-align: middle;
                    "></span>
                    LLM feature
                </div>
                <div style="display: flex; align-items: center; gap: 0.5em; font-weight: 600; font-size: 1.5em;">
                    <span style="
                        display: inline-block;
                        width: 1.5em; height: 1.5em;
                        background: #5CB3FF;  /* clearer blue */
                        border: 1px solid #666;
                        vertical-align: middle;
                    "></span>
                    Gram2Vec feature
                </div>
            </div>
            """
        )

        combined_btn = gr.Button("Show Combined Spans")
        combined_html = gr.HTML()
        show_background_checkbox = gr.Checkbox(label="Show spans in background authors", value=False)
        background_html = gr.HTML(visible=False)
        # print(f"in app: all_feats={feature_list_state.value}")
        # print(f"in app: sel_feat_llm={features_rb.value}")

        combined_btn.click(
            fn=show_combined_spans_all,
            inputs=[features_rb,
                    gram2vec_rb,
                    llm_style_feats_analysis,
                    background_authors_embeddings_df,
                    task_authors_embeddings_df,
                    visible_zoomed_authors,
                    predicted_author,
                    ground_truth_author],
            outputs=[combined_html, background_html]
        )
        # mapping -->
        # iid = task_dropdown.value
        # sel_feat_llm = features_rb.value
        # all_feats = feature_list_state.value
        # sel_feat_g2v = gram2vec_rb.value
        # combined_html -> spans/html for task authors
        # background_html -> spans/html for background authors

        show_background_checkbox.change(
            fn=lambda show: gr.update(visible=show),
            inputs=[show_background_checkbox],
            outputs=[background_html]
        )

    demo.launch(share=share)
511
+
512
if __name__ == "__main__":
    # CLI entry point: optionally switch to the cluster-based feature flow.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--use_cluster_feats", action="store_true", help="Use cluster-based selection for features")
    cli_args = arg_parser.parse_args()
    app(share=True, use_cluster_feats=cli_args.use_cluster_feats)
config/config.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
# config.yaml
# Each *_path entry is a local cache location; its *_url sibling is the
# Hugging Face dataset URL it is downloaded from at startup (see app.py).
instances_to_explain_path: "./datasets/hrs_explanations.json"
instances_to_explain_url: "https://huggingface.co/datasets/miladalsh/explanation_tool_files/raw/main/hrs_explanations.json?download"
interp_space_path: "./datasets/luar_interp_space_cluster_19/"
interp_space_url: "https://huggingface.co/datasets/miladalsh/explanation_tool_files/resolve/main/luar_interp_space_cluster.zip?download=true"
gram2vec_feats_path: "./datasets/gram2vec_feats.csv"
gram2vec_feats_url: "https://huggingface.co/datasets/miladalsh/explanation_tool_files/resolve/main/gram2vec_feats.csv?download=true"

# Column holding the style-feature weights, and how many features to surface.
style_feat_clm: "llm_tfidf_weights"
top_k: 10
# Feature-source toggles — presumably restrict the display to one feature
# family when set; verify against the utils that read them.
only_llm_feats: false
only_gram2vec_feats: false
datasets/placeholder.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ test
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
nltk
spacy
scikit-learn
openai
python-dotenv
# NOTE(review): gradio 5.x requires Python >= 3.10 — keep the Docker base image in sync.
gradio==5.30
pyyaml
plotly
sentence_transformers
# Fork of gram2vec used for the stylometric feature comparison
git+https://github.com/MiladAlshomary/gram2vec
utils/augmented_human_readable.txt ADDED
@@ -0,0 +1,617 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Adjective:ADJ
2
+ Adposition:ADP
3
+ Adverb:ADV
4
+ Auxiliary verb:AUX
5
+ Coordinating conjunction:CCONJ
6
+ Determiner:DET
7
+ Interjection:INTJ
8
+ Noun:NOUN
9
+ Numeral:NUM
10
+ Particle:PART
11
+ Pronoun:PRON
12
+ Proper noun:PROPN
13
+ Punctuation:PUNCT
14
+ Subordinating conjunction:SCONJ
15
+ Symbol:SYM
16
+ Verb:VERB
17
+ Other:X
18
+ Space:SPACE
19
+ Other (foreign words, typos, abbreviations):X
20
+ ADP: Adposition (preposition or postposition)
21
+
22
+ Perfect aspect:Aspect=Perf
23
+ Progressive aspect:Aspect=Prog
24
+ Accusative case:Case=Acc
25
+ Nominative case:Case=Nom
26
+ Definite article:Definite=Def
27
+ Indefinite article:Definite=Ind
28
+ Comparative degree:Degree=Cmp
29
+ Positive degree:Degree=Pos
30
+ Superlative degree:Degree=Sup
31
+ Feminine gender:Gender=Fem
32
+ Masculine gender:Gender=Masc
33
+ Indicative mood:Mood=Ind
34
+ Plural number:Number=Plur
35
+ Singular number:Number=Sing
36
+ First person:Person=1
37
+ Second person:Person=2
38
+ Third person:Person=3
39
+ Past tense:Tense=Past
40
+ Present tense:Tense=Pres
41
+ Finite verb form:VerbForm=Fin
42
+ Infinitive verb form:VerbForm=Inf
43
+
44
+ Adjective followed by Adjective:ADJ ADJ
45
+ Adjective followed by Adposition:ADJ ADP
46
+ Adjective followed by Adverb:ADJ ADV
47
+ Adjective followed by Auxiliary verb:ADJ AUX
48
+ Adjective followed by Coordinating conjunction:ADJ CCONJ
49
+ Adjective followed by Determiner:ADJ DET
50
+ Adjective followed by Interjection:ADJ INTJ
51
+ Adjective followed by Noun:ADJ NOUN
52
+ Adjective followed by Numeral:ADJ NUM
53
+ Adjective followed by Other:ADJ X
54
+ Adjective followed by Particle:ADJ PART
55
+ Adjective followed by Pronoun:ADJ PRON
56
+ Adjective followed by Proper noun:ADJ PROPN
57
+ Adjective followed by Punctuation:ADJ PUNCT
58
+ Adjective followed by Subordinating conjunction:ADJ SCONJ
59
+ Adjective followed by Symbol:ADJ SYM
60
+ Adjective followed by Verb:ADJ VERB
61
+ Adposition followed by Adjective:ADP ADJ
62
+ Adposition followed by Adposition:ADP ADP
63
+ Adposition followed by Adverb:ADP ADV
64
+ Adposition followed by Auxiliary verb:ADP AUX
65
+ Adposition followed by Coordinating conjunction:ADP CCONJ
66
+ Adposition followed by Determiner:ADP DET
67
+ Adposition followed by Interjection:ADP INTJ
68
+ Adposition followed by Noun:ADP NOUN
69
+ Adposition followed by Numeral:ADP NUM
70
+ Adposition followed by Other:ADP X
71
+ Adposition followed by Particle:ADP PART
72
+ Adposition followed by Pronoun:ADP PRON
73
+ Adposition followed by Proper noun:ADP PROPN
74
+ Adposition followed by Punctuation:ADP PUNCT
75
+ Adposition followed by Subordinating conjunction:ADP SCONJ
76
+ Adposition followed by Symbol:ADP SYM
77
+ Adposition followed by Verb:ADP VERB
78
+ Adverb followed by Adjective:ADV ADJ
79
+ Adverb followed by Adposition:ADV ADP
80
+ Adverb followed by Adverb:ADV ADV
81
+ Adverb followed by Auxiliary verb:ADV AUX
82
+ Adverb followed by Coordinating conjunction:ADV CCONJ
83
+ Adverb followed by Determiner:ADV DET
84
+ Adverb followed by Interjection:ADV INTJ
85
+ Adverb followed by Noun:ADV NOUN
86
+ Adverb followed by Numeral:ADV NUM
87
+ Adverb followed by Other:ADV X
88
+ Adverb followed by Particle:ADV PART
89
+ Adverb followed by Pronoun:ADV PRON
90
+ Adverb followed by Proper noun:ADV PROPN
91
+ Adverb followed by Punctuation:ADV PUNCT
92
+ Adverb followed by Subordinating conjunction:ADV SCONJ
93
+ Adverb followed by Symbol:ADV SYM
94
+ Adverb followed by Verb:ADV VERB
95
+ Auxiliary verb followed by Adjective:AUX ADJ
96
+ Auxiliary verb followed by Adposition:AUX ADP
97
+ Auxiliary verb followed by Adverb:AUX ADV
98
+ Auxiliary verb followed by Auxiliary verb:AUX AUX
99
+ Auxiliary verb followed by Coordinating conjunction:AUX CCONJ
100
+ Auxiliary verb followed by Determiner:AUX DET
101
+ Auxiliary verb followed by Interjection:AUX INTJ
102
+ Auxiliary verb followed by Noun:AUX NOUN
103
+ Auxiliary verb followed by Numeral:AUX NUM
104
+ Auxiliary verb followed by Other:AUX X
105
+ Auxiliary verb followed by Particle:AUX PART
106
+ Auxiliary verb followed by Pronoun:AUX PRON
107
+ Auxiliary verb followed by Proper noun:AUX PROPN
108
+ Auxiliary verb followed by Punctuation:AUX PUNCT
109
+ Auxiliary verb followed by Subordinating conjunction:AUX SCONJ
110
+ Auxiliary verb followed by Symbol:AUX SYM
111
+ Auxiliary verb followed by Verb:AUX VERB
112
+ Coordinating conjunction followed by Adjective:CCONJ ADJ
113
+ Coordinating conjunction followed by Adposition:CCONJ ADP
114
+ Coordinating conjunction followed by Adverb:CCONJ ADV
115
+ Coordinating conjunction followed by Auxiliary verb:CCONJ AUX
116
+ Coordinating conjunction followed by Coordinating conjunction:CCONJ CCONJ
117
+ Coordinating conjunction followed by Determiner:CCONJ DET
118
+ Coordinating conjunction followed by Interjection:CCONJ INTJ
119
+ Coordinating conjunction followed by Noun:CCONJ NOUN
120
+ Coordinating conjunction followed by Numeral:CCONJ NUM
121
+ Coordinating conjunction followed by Other:CCONJ X
122
+ Coordinating conjunction followed by Particle:CCONJ PART
123
+ Coordinating conjunction followed by Pronoun:CCONJ PRON
124
+ Coordinating conjunction followed by Proper noun:CCONJ PROPN
125
+ Coordinating conjunction followed by Punctuation:CCONJ PUNCT
126
+ Coordinating conjunction followed by Subordinating conjunction:CCONJ SCONJ
127
+ Coordinating conjunction followed by Symbol:CCONJ SYM
128
+ Coordinating conjunction followed by Verb:CCONJ VERB
129
+ Determiner followed by Adjective:DET ADJ
130
+ Determiner followed by Adposition:DET ADP
131
+ Determiner followed by Adverb:DET ADV
132
+ Determiner followed by Auxiliary verb:DET AUX
133
+ Determiner followed by Coordinating conjunction:DET CCONJ
134
+ Determiner followed by Determiner:DET DET
135
+ Determiner followed by Interjection:DET INTJ
136
+ Determiner followed by Noun:DET NOUN
137
+ Determiner followed by Numeral:DET NUM
138
+ Determiner followed by Other:DET X
139
+ Determiner followed by Particle:DET PART
140
+ Determiner followed by Pronoun:DET PRON
141
+ Determiner followed by Proper noun:DET PROPN
142
+ Determiner followed by Punctuation:DET PUNCT
143
+ Determiner followed by Subordinating conjunction:DET SCONJ
144
+ Determiner followed by Symbol:DET SYM
145
+ Determiner followed by Verb:DET VERB
146
+ Interjection followed by Adjective:INTJ ADJ
147
+ Interjection followed by Adposition:INTJ ADP
148
+ Interjection followed by Adverb:INTJ ADV
149
+ Interjection followed by Auxiliary verb:INTJ AUX
150
+ Interjection followed by Coordinating conjunction:INTJ CCONJ
151
+ Interjection followed by Determiner:INTJ DET
152
+ Interjection followed by Interjection:INTJ INTJ
153
+ Interjection followed by Noun:INTJ NOUN
154
+ Interjection followed by Numeral:INTJ NUM
155
+ Interjection followed by Other:INTJ X
156
+ Interjection followed by Particle:INTJ PART
157
+ Interjection followed by Pronoun:INTJ PRON
158
+ Interjection followed by Proper noun:INTJ PROPN
159
+ Interjection followed by Punctuation:INTJ PUNCT
160
+ Interjection followed by Subordinating conjunction:INTJ SCONJ
161
+ Interjection followed by Symbol:INTJ SYM
162
+ Interjection followed by Verb:INTJ VERB
163
+ Noun followed by Adjective:NOUN ADJ
164
+ Noun followed by Adposition:NOUN ADP
165
+ Noun followed by Adverb:NOUN ADV
166
+ Noun followed by Auxiliary verb:NOUN AUX
167
+ Noun followed by Coordinating conjunction:NOUN CCONJ
168
+ Noun followed by Determiner:NOUN DET
169
+ Noun followed by Interjection:NOUN INTJ
170
+ Noun followed by Noun:NOUN NOUN
171
+ Noun followed by Numeral:NOUN NUM
172
+ Noun followed by Other:NOUN X
173
+ Noun followed by Particle:NOUN PART
174
+ Noun followed by Pronoun:NOUN PRON
175
+ Noun followed by Proper noun:NOUN PROPN
176
+ Noun followed by Punctuation:NOUN PUNCT
177
+ Noun followed by Subordinating conjunction:NOUN SCONJ
178
+ Noun followed by Symbol:NOUN SYM
179
+ Noun followed by Verb:NOUN VERB
180
+ Numeral followed by Adjective:NUM ADJ
181
+ Numeral followed by Adposition:NUM ADP
182
+ Numeral followed by Adverb:NUM ADV
183
+ Numeral followed by Auxiliary verb:NUM AUX
184
+ Numeral followed by Coordinating conjunction:NUM CCONJ
185
+ Numeral followed by Determiner:NUM DET
186
+ Numeral followed by Interjection:NUM INTJ
187
+ Numeral followed by Noun:NUM NOUN
188
+ Numeral followed by Numeral:NUM NUM
189
+ Numeral followed by Other:NUM X
190
+ Numeral followed by Particle:NUM PART
191
+ Numeral followed by Pronoun:NUM PRON
192
+ Numeral followed by Proper noun:NUM PROPN
193
+ Numeral followed by Punctuation:NUM PUNCT
194
+ Numeral followed by Subordinating conjunction:NUM SCONJ
195
+ Numeral followed by Symbol:NUM SYM
196
+ Numeral followed by Verb:NUM VERB
197
+ Other followed by Adjective:X ADJ
198
+ Other followed by Adposition:X ADP
199
+ Other followed by Adverb:X ADV
200
+ Other followed by Auxiliary verb:X AUX
201
+ Other followed by Coordinating conjunction:X CCONJ
202
+ Other followed by Determiner:X DET
203
+ Other followed by Interjection:X INTJ
204
+ Other followed by Noun:X NOUN
205
+ Other followed by Numeral:X NUM
206
+ Other followed by Other:X X
207
+ Other followed by Particle:X PART
208
+ Other followed by Pronoun:X PRON
209
+ Other followed by Proper noun:X PROPN
210
+ Other followed by Punctuation:X PUNCT
211
+ Other followed by Subordinating conjunction:X SCONJ
212
+ Other followed by Symbol:X SYM
213
+ Other followed by Verb:X VERB
214
+ Particle followed by Adjective:PART ADJ
215
+ Particle followed by Adposition:PART ADP
216
+ Particle followed by Adverb:PART ADV
217
+ Particle followed by Auxiliary verb:PART AUX
218
+ Particle followed by Coordinating conjunction:PART CCONJ
219
+ Particle followed by Determiner:PART DET
220
+ Particle followed by Interjection:PART INTJ
221
+ Particle followed by Noun:PART NOUN
222
+ Particle followed by Numeral:PART NUM
223
+ Particle followed by Other:PART X
224
+ Particle followed by Particle:PART PART
225
+ Particle followed by Pronoun:PART PRON
226
+ Particle followed by Proper noun:PART PROPN
227
+ Particle followed by Punctuation:PART PUNCT
228
+ Particle followed by Subordinating conjunction:PART SCONJ
229
+ Particle followed by Symbol:PART SYM
230
+ Particle followed by Verb:PART VERB
231
+ Pronoun followed by Adjective:PRON ADJ
232
+ Pronoun followed by Adposition:PRON ADP
233
+ Pronoun followed by Adverb:PRON ADV
234
+ Pronoun followed by Auxiliary verb:PRON AUX
235
+ Pronoun followed by Coordinating conjunction:PRON CCONJ
236
+ Pronoun followed by Determiner:PRON DET
237
+ Pronoun followed by Interjection:PRON INTJ
238
+ Pronoun followed by Noun:PRON NOUN
239
+ Pronoun followed by Numeral:PRON NUM
240
+ Pronoun followed by Other:PRON X
241
+ Pronoun followed by Particle:PRON PART
242
+ Pronoun followed by Pronoun:PRON PRON
243
+ Pronoun followed by Proper noun:PRON PROPN
244
+ Pronoun followed by Punctuation:PRON PUNCT
245
+ Pronoun followed by Subordinating conjunction:PRON SCONJ
246
+ Pronoun followed by Symbol:PRON SYM
247
+ Pronoun followed by Verb:PRON VERB
248
+ Proper noun followed by Adjective:PROPN ADJ
249
+ Proper noun followed by Adposition:PROPN ADP
250
+ Proper noun followed by Adverb:PROPN ADV
251
+ Proper noun followed by Auxiliary verb:PROPN AUX
252
+ Proper noun followed by Coordinating conjunction:PROPN CCONJ
253
+ Proper noun followed by Determiner:PROPN DET
254
+ Proper noun followed by Interjection:PROPN INTJ
255
+ Proper noun followed by Noun:PROPN NOUN
256
+ Proper noun followed by Numeral:PROPN NUM
257
+ Proper noun followed by Other:PROPN X
258
+ Proper noun followed by Particle:PROPN PART
259
+ Proper noun followed by Pronoun:PROPN PRON
260
+ Proper noun followed by Proper noun:PROPN PROPN
261
+ Proper noun followed by Punctuation:PROPN PUNCT
262
+ Proper noun followed by Subordinating conjunction:PROPN SCONJ
263
+ Proper noun followed by Symbol:PROPN SYM
264
+ Proper noun followed by Verb:PROPN VERB
265
+ Punctuation followed by Adjective:PUNCT ADJ
266
+ Punctuation followed by Adposition:PUNCT ADP
267
+ Punctuation followed by Adverb:PUNCT ADV
268
+ Punctuation followed by Auxiliary verb:PUNCT AUX
269
+ Punctuation followed by Coordinating conjunction:PUNCT CCONJ
270
+ Punctuation followed by Determiner:PUNCT DET
271
+ Punctuation followed by Interjection:PUNCT INTJ
272
+ Punctuation followed by Noun:PUNCT NOUN
273
+ Punctuation followed by Numeral:PUNCT NUM
274
+ Punctuation followed by Other:PUNCT X
275
+ Punctuation followed by Particle:PUNCT PART
276
+ Punctuation followed by Pronoun:PUNCT PRON
277
+ Punctuation followed by Proper noun:PUNCT PROPN
278
+ Punctuation followed by Punctuation:PUNCT PUNCT
279
+ Punctuation followed by Subordinating conjunction:PUNCT SCONJ
280
+ Punctuation followed by Symbol:PUNCT SYM
281
+ Punctuation followed by Verb:PUNCT VERB
282
+ Subordinating conjunction followed by Adjective:SCONJ ADJ
283
+ Subordinating conjunction followed by Adposition:SCONJ ADP
284
+ Subordinating conjunction followed by Adverb:SCONJ ADV
285
+ Subordinating conjunction followed by Auxiliary verb:SCONJ AUX
286
+ Subordinating conjunction followed by Coordinating conjunction:SCONJ CCONJ
287
+ Subordinating conjunction followed by Determiner:SCONJ DET
288
+ Subordinating conjunction followed by Interjection:SCONJ INTJ
289
+ Subordinating conjunction followed by Noun:SCONJ NOUN
290
+ Subordinating conjunction followed by Numeral:SCONJ NUM
291
+ Subordinating conjunction followed by Other:SCONJ X
292
+ Subordinating conjunction followed by Particle:SCONJ PART
293
+ Subordinating conjunction followed by Pronoun:SCONJ PRON
294
+ Subordinating conjunction followed by Proper noun:SCONJ PROPN
295
+ Subordinating conjunction followed by Punctuation:SCONJ PUNCT
296
+ Subordinating conjunction followed by Subordinating conjunction:SCONJ SCONJ
297
+ Subordinating conjunction followed by Symbol:SCONJ SYM
298
+ Subordinating conjunction followed by Verb:SCONJ VERB
299
+ Symbol followed by Adjective:SYM ADJ
300
+ Symbol followed by Adposition:SYM ADP
301
+ Symbol followed by Adverb:SYM ADV
302
+ Symbol followed by Auxiliary verb:SYM AUX
303
+ Symbol followed by Coordinating conjunction:SYM CCONJ
304
+ Symbol followed by Determiner:SYM DET
305
+ Symbol followed by Interjection:SYM INTJ
306
+ Symbol followed by Noun:SYM NOUN
307
+ Symbol followed by Numeral:SYM NUM
308
+ Symbol followed by Other:SYM X
309
+ Symbol followed by Particle:SYM PART
310
+ Symbol followed by Pronoun:SYM PRON
311
+ Symbol followed by Proper noun:SYM PROPN
312
+ Symbol followed by Punctuation:SYM PUNCT
313
+ Symbol followed by Subordinating conjunction:SYM SCONJ
314
+ Symbol followed by Symbol:SYM SYM
315
+ Symbol followed by Verb:SYM VERB
316
+ Verb followed by Adjective:VERB ADJ
317
+ Verb followed by Adposition:VERB ADP
318
+ Verb followed by Adverb:VERB ADV
319
+ Verb followed by Auxiliary verb:VERB AUX
320
+ Verb followed by Coordinating conjunction:VERB CCONJ
321
+ Verb followed by Determiner:VERB DET
322
+ Verb followed by Interjection:VERB INTJ
323
+ Verb followed by Noun:VERB NOUN
324
+ Verb followed by Numeral:VERB NUM
325
+ Verb followed by Other:VERB X
326
+ Verb followed by Particle:VERB PART
327
+ Verb followed by Pronoun:VERB PRON
328
+ Verb followed by Proper noun:VERB PROPN
329
+ Verb followed by Punctuation:VERB PUNCT
330
+ Verb followed by Subordinating conjunction:VERB SCONJ
331
+ Verb followed by Symbol:VERB SYM
332
+ Verb followed by Verb:VERB VERB
333
+
334
+ Accusative case:Case=Acc
335
+ Comparative degree:Degree=Cmp
336
+ Definite article:Definite=Def
337
+ Feminine gender:Gender=Fem
338
+ Finite verb form:VerbForm=Fin
339
+ First person:Person=1
340
+ Indefinite article:Definite=Ind
341
+ Indicative mood:Mood=Ind
342
+ Infinitive verb form:VerbForm=Inf
343
+ Masculine gender:Gender=Masc
344
+ Nominative case:Case=Nom
345
+ Past tense:Tense=Past
346
+ Perfect aspect:Aspect=Perf
347
+ Plural number:Number=Plur
348
+ Positive degree:Degree=Pos
349
+ Present tense:Tense=Pres
350
+ Progressive aspect:Aspect=Prog
351
+ Second person:Person=2
352
+ Singular number:Number=Sing
353
+ Superlative degree:Degree=Sup
354
+ Third person:Person=3
355
+ Number of Tokens:num_tokens
356
+
357
+ Adjectival clause:acl
358
+ Adjectival complement:acomp
359
+ Adjectival modifier:amod
360
+ Adverbial clause modifier:advcl
361
+ Adverbial modifier:advmod
362
+ Agent (in passive voice):agent
363
+ Appositional modifier:appos
364
+ Attribute:attr
365
+ Case marking:case
366
+ Clausal complement:ccomp
367
+ Clausal subject:csubj
368
+ Clausal subject (passive):csubjpass
369
+ Complement of preposition:pcomp
370
+ Compound word:compound
371
+ Conjunct:conj
372
+ Coordinating conjunction:cc
373
+ Adposition (preposition or postposition):ADP
374
+ Dative:dative
375
+ Direct object:dobj
376
+ Expletive:expl
377
+ Marker (introducing adverbial clause):mark
378
+ Meta modifier:meta
379
+ Negation modifier:neg
380
+ Nominal modifier:nmod
381
+ Nominal subject (passive):nsubjpass
382
+ Noun phrase as adverbial modifier:npadvmod
383
+ Numeric modifier:nummod
384
+ Object of preposition:pobj
385
+ Object predicate:oprd
386
+ Open clausal complement:xcomp
387
+ Parataxis:parataxis
388
+ Passive auxiliary:auxpass
389
+ Possession modifier:poss
390
+ Pre-correlative conjunction:preconj
391
+ Predeterminer:predet
392
+ Prepositional modifier:prep
393
+ Quantifier modifier:quantmod
394
+ Relative clause modifier:relcl
395
+ Root of the sentence:ROOT
396
+ Unspecified dependency:dep
397
+
398
+ Article pronoun type:PronType=Art
399
+ Bracket punctuation type:PunctType=Brck
400
+ Cardinal number:NumType=Card
401
+ Comma punctuation type:PunctType=Comm
402
+ Comparative conjunction type:ConjType=Cmp
403
+ Dash punctuation type:PunctType=Dash
404
+ Demonstrative pronoun type:PronType=Dem
405
+ Final punctuation:PunctSide=Fin
406
+ Foreign word:Foreign=Yes
407
+ Gerund verb form:VerbForm=Ger
408
+ Hyphenated:Hyph=Yes
409
+ Indefinite pronoun type:PronType=Ind
410
+ Initial punctuation:PunctSide=Ini
411
+ Modal verb type:VerbType=Mod
412
+ Multiplicative number:NumType=Mult
413
+ Negative polarity:Polarity=Neg
414
+ Neuter gender:Gender=Neut
415
+ Ordinal number:NumType=Ord
416
+ Participle verb form:VerbForm=Part
417
+ Period punctuation type:PunctType=Peri
418
+ Possessive:Poss=Yes
419
+ Quotation punctuation type:PunctType=Quot
420
+ Reflexive:Reflex=Yes
421
+ Relative pronoun type:PronType=Rel
422
+
423
+ ❀️:❀️
424
+ πŸ‘:πŸ‘
425
+ πŸ˜‚:πŸ˜‚
426
+ 😍:😍
427
+
428
+ !:!
429
+ ":\\"
430
+ %:%
431
+ &:&
432
+ ':'
433
+ (:(
434
+ ):\)
435
+ *:*
436
+ ,:,
437
+ -:-
438
+ .:.
439
+ ;:;
440
+ ?:?
441
+ _:_
442
+ `:`
443
+ –:–
444
+ ':'
445
+ ':'
446
+
447
+ all-cleft:all-cleft
448
+ coordinate-clause:coordinate-clause
449
+ if-because-cleft:if-because-cleft
450
+ it-cleft:it-cleft
451
+ obj-relcl:obj-relcl
452
+ passive:passive
453
+ pseudo-cleft:pseudo-cleft
454
+ subj-relcl:subj-relcl
455
+ tag-question:tag-question
456
+ there-cleft:there-cleft
457
+
458
+ punctuation:punctuation
459
+
460
+ Articles:DET
461
+ Auxiliary Verbs:AUX
462
+ Conjunctions:CCONJ
463
+ Prepositions:ADP
464
+
465
+ Personal Pronouns:category:Personal Pronouns
466
+ Demonstrative Pronouns:category:Demonstrative Pronouns
467
+ Interrogative Pronouns:category:Interrogative Pronouns
468
+ Modal Verbs:category:Modal Verbs
469
+ Contractions:category:Contractions
470
+ Adverbs:category:Adverbs
471
+ Other:category:Other
472
+
473
+ i:i
474
+ me:me
475
+ my:my
476
+ myself:myself
477
+ we:we
478
+ our:our
479
+ ours:ours
480
+ ourselves:ourselves
481
+ you:you
482
+ 're:'re
483
+ 've:'ve
484
+ 'll:'ll
485
+ 'd:'d
486
+ 's:'s
487
+ 't:'t
488
+ your:your
489
+ yours:yours
490
+ yourself:yourself
491
+ yourselves:yourselves
492
+ he:he
493
+ him:him
494
+ his:his
495
+ himself:himself
496
+ she:she
497
+ her:her
498
+ hers:hers
499
+ herself:herself
500
+ it:it
501
+ its:its
502
+ itself:itself
503
+ they:they
504
+ them:them
505
+ their:their
506
+ theirs:theirs
507
+ themselves:themselves
508
+ what:what
509
+ which:which
510
+ who:who
511
+ this:this
512
+ that:that
513
+ these:these
514
+ those:those
515
+ am:am
516
+ is:is
517
+ are:are
518
+ was:was
519
+ were:were
520
+ be:be
521
+ been:been
522
+ being:being
523
+ have:have
524
+ has:has
525
+ had:had
526
+ having:having
527
+ do:do
528
+ does:does
529
+ did:did
530
+ doing:doing
531
+ a:a
532
+ an:an
533
+ the:the
534
+ and:and
535
+ but:but
536
+ if:if
537
+ or:or
538
+ because:because
539
+ as:as
540
+ until:until
541
+ while:while
542
+ of:of
543
+ at:at
544
+ by:by
545
+ for:for
546
+ with:with
547
+ about:about
548
+ against:against
549
+ between:between
550
+ into:into
551
+ through:through
552
+ during:during
553
+ before:before
554
+ after:after
555
+ above:above
556
+ below:below
557
+ to:to
558
+ from:from
559
+ up:up
560
+ down:down
561
+ in:in
562
+ out:out
563
+ on:on
564
+ off:off
565
+ over:over
566
+ under:under
567
+ again:again
568
+ further:further
569
+ then:then
570
+ once:once
571
+ here:here
572
+ there:there
573
+ when:when
574
+ where:where
575
+ why:why
576
+ how:how
577
+ all:all
578
+ any:any
579
+ both:both
580
+ each:each
581
+ few:few
582
+ more:more
583
+ most:most
584
+ other:other
585
+ some:some
586
+ such:such
587
+ no:no
588
+ nor:nor
589
+ not:not
590
+ only:only
591
+ own:own
592
+ same:same
593
+ so:so
594
+ than:than
595
+ too:too
596
+ very:very
597
+ can:can
598
+ will:will
599
+ just:just
600
+ don:don
601
+ should:should
602
+ now:now
603
+ ain:ain
604
+ aren:aren
605
+ couldn:couldn
606
+ didn:didn
607
+ doesn:doesn
608
+ hadn:hadn
609
+ hasn:hasn
610
+ haven:haven
611
+ isn:isn
612
+ ma:ma
613
+ shouldn:shouldn
614
+ wasn:wasn
615
+ weren:weren
616
+ won:won
617
+ wouldn:wouldn
utils/clustering_utils.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Required for clustering_author function:
2
+ import pandas as pd
3
+ import numpy as np
4
+ from sklearn.cluster import DBSCAN
5
+ from sklearn.metrics import silhouette_score
6
+ # Required for analyze_space_distance_preservation
7
+ from sklearn.metrics.pairwise import cosine_distances, cosine_similarity
8
+ from scipy.stats import pearsonr
9
+ from typing import List, Dict, Any
10
+
11
def _find_best_dbscan_eps(X: np.ndarray,
                          eps_values: List[float],
                          min_samples: int,
                          metric: str) -> "tuple[float | None, np.ndarray | None, float]":
    """
    Iterate through ``eps_values`` for DBSCAN and return the parameters
    that yield the highest silhouette score.

    Args:
        X (np.ndarray): The input data (embeddings), shape (n_samples, n_features).
        eps_values (List[float]): Candidate eps values; non-positive values are skipped.
        min_samples (int): DBSCAN ``min_samples`` parameter.
        metric (str): Distance metric for DBSCAN and silhouette score.

    Returns:
        tuple:
            - best_eps: eps value with the best score, or None if no suitable clustering.
            - best_labels: cluster labels from the best DBSCAN run, or None.
            - best_score: highest silhouette score achieved.
    """
    # NOTE: the return annotation is quoted because PEP 604 unions (`X | None`)
    # are evaluated at function-definition time and raise TypeError on
    # Python 3.9 — the version pinned by this project's Dockerfile. String
    # annotations are never evaluated at runtime.
    best_score = -1.001  # silhouette score lies in [-1, 1]; start just below
    best_labels = None
    best_eps = None

    for eps in eps_values:
        if eps <= 1e-9:  # eps must be strictly positive for DBSCAN
            continue
        db = DBSCAN(eps=eps, min_samples=min_samples, metric=metric)
        labels = db.fit_predict(X)

        unique_labels_set = set(labels)
        # Label -1 marks noise points; they do not count as a cluster.
        n_clusters_ = len(unique_labels_set) - (1 if -1 in unique_labels_set else 0)

        if n_clusters_ > 1:
            clustered_mask = (labels != -1)
            if np.sum(clustered_mask) >= 2:  # silhouette needs >= 2 non-noise points
                X_clustered = X[clustered_mask]
                labels_clustered = labels[clustered_mask]
                try:
                    score = silhouette_score(X_clustered, labels_clustered, metric=metric)
                    if score > best_score:
                        best_score = score
                        best_labels = labels.copy()
                        best_eps = eps
                except ValueError:  # silhouette_score can reject degenerate inputs
                    pass
        elif n_clusters_ == 1 and best_labels is None:  # fallback: single cluster only
            if not all(l == -1 for l in labels):
                # Nominal score below any real silhouette so a multi-cluster
                # result from a later eps will still win.
                current_score_for_single_cluster = -0.5
                if current_score_for_single_cluster > best_score:
                    best_score = current_score_for_single_cluster
                    best_labels = labels.copy()
                    best_eps = eps
    return best_eps, best_labels, best_score
65
+
66
def clustering_author(background_corpus_df: pd.DataFrame,
                      embedding_clm: str = 'style_embedding',
                      eps_values: "List[float] | None" = None,
                      min_samples: int = 5,
                      metric: str = 'cosine') -> pd.DataFrame:
    """
    Performs DBSCAN clustering on embeddings in a DataFrame.

    Experiments with different `eps` parameters to find a clustering
    that maximizes the silhouette score, indicating well-separated clusters.
    Note: the input DataFrame is mutated in place (a 'cluster_label' column
    is added) and also returned.

    Args:
        background_corpus_df (pd.DataFrame): DataFrame with an embedding column.
        embedding_clm (str): Name of the column containing embeddings.
                             Each embedding should be a list or NumPy array.
        eps_values (List[float], optional): Specific `eps` values to test.
                                            If None, a default range is used.
                                            For 'cosine' metric, eps is typically in [0, 2].
                                            For 'euclidean', scale depends on embedding magnitudes.
        min_samples (int): DBSCAN `min_samples` parameter. Minimum number of
                           samples in a neighborhood for a point to be a core point.
        metric (str): The distance metric to use for DBSCAN and silhouette score
                      (e.g., 'cosine', 'euclidean').

    Returns:
        pd.DataFrame: The input DataFrame with a new 'cluster_label' column.
                      Labels are from the DBSCAN run with the highest silhouette score.
                      If no suitable clustering is found, labels might be all -1 (noise).

    Raises:
        ValueError: If `embedding_clm` is not a column of the DataFrame.
    """
    # The eps_values annotation is quoted: the original `List[float] = None`
    # mis-declared an optional parameter, and an unquoted `| None` union would
    # raise at import time on Python 3.9 (the Dockerfile's pinned version).
    if embedding_clm not in background_corpus_df.columns:
        raise ValueError(f"Embedding column '{embedding_clm}' not found in DataFrame.")

    embeddings_list = background_corpus_df[embedding_clm].tolist()

    X_list = []
    original_indices = []  # To map results back to the original DataFrame's positions

    for i, emb_val in enumerate(embeddings_list):
        if emb_val is not None:
            try:
                e = np.asarray(emb_val, dtype=float)
                if e.ndim == 1 and e.size > 0:  # Standard 1D vector
                    X_list.append(e)
                    original_indices.append(i)
                elif e.ndim == 0 and e.size == 1:  # Scalar value, treat as 1D vector of size 1
                    X_list.append(np.array([e.item()]))
                    original_indices.append(i)
                # Silently skip empty arrays or improperly shaped arrays
            except (TypeError, ValueError):
                # Silently skip if conversion to float array fails
                pass

    # Initialize labels for all rows in the original DataFrame to -1 (noise/unprocessed)
    final_labels_for_df = pd.Series(-1, index=background_corpus_df.index, dtype=int)

    if not X_list:
        print(f"No valid embeddings found in column '{embedding_clm}'. Assigning all 'cluster_label' as -1.")
        background_corpus_df['cluster_label'] = final_labels_for_df
        return background_corpus_df

    X = np.array(X_list)  # Creates a 2D array from the list of 1D arrays

    if X.shape[0] == 1:
        print("Only one valid embedding found. Assigning cluster label 0 to it.")
        if original_indices:  # Should always be true if X.shape[0]==1 from X_list
            final_labels_for_df.iloc[original_indices[0]] = 0
        background_corpus_df['cluster_label'] = final_labels_for_df
        return background_corpus_df

    if X.shape[0] < min_samples:
        print(f"Number of valid embeddings ({X.shape[0]}) is less than min_samples ({min_samples}). "
              f"All valid embeddings will be marked as noise (-1).")
        for original_idx in original_indices:
            final_labels_for_df.iloc[original_idx] = -1
        background_corpus_df['cluster_label'] = final_labels_for_df
        return background_corpus_df

    if eps_values is None:
        if metric == 'cosine':
            # Cosine distance lives in [0, 2], so a fixed grid is reasonable.
            eps_values = [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
        else:
            # For other metrics, derive candidates from the data spread.
            if X.shape[0] > 1:
                data_spread = np.std(X)
                eps_values = [round(data_spread * f, 2) for f in [0.25, 0.5, 1.0]]
                eps_values = [e for e in eps_values if e > 1e-6]
            if not eps_values or X.shape[0] <= 1:
                eps_values = [0.5, 1.0, 1.5]
        print(f"Warning: `eps_values` not provided. Using default range for metric '{metric}': {eps_values}. "
              f"It's recommended to supply `eps_values` tuned to your data.")

    print(f"Performing DBSCAN clustering (min_samples={min_samples}, metric='{metric}') with eps values: "
          f"{[f'{e:.2f}' for e in eps_values]}")

    best_eps, best_labels, best_score = _find_best_dbscan_eps(X, eps_values, min_samples, metric)

    if best_labels is not None:
        num_found_clusters = len(set(best_labels) - {-1})
        print(f"Best clustering found: eps={best_eps:.2f}, Silhouette Score={best_score:.4f} ({num_found_clusters} clusters).")
        for i, label in enumerate(best_labels):
            original_df_idx = original_indices[i]
            final_labels_for_df.iloc[original_df_idx] = label
    else:
        print("No suitable DBSCAN clustering found meeting criteria. All processed embeddings marked as noise (-1).")

    background_corpus_df['cluster_label'] = final_labels_for_df
    return background_corpus_df
172
+
173
+
174
+ def _safe_embeddings_to_matrix(embeddings_column: pd.Series) -> np.ndarray:
175
+ """
176
+ Converts a pandas Series of embeddings (expected to be lists of floats or 1D np.arrays)
177
+ into a 2D NumPy matrix. Handles None values and attempts to stack consistently.
178
+ Returns an empty 2D array (e.g., shape (0,0) or (0,D)) if conversion fails or no valid data.
179
+ """
180
+ embeddings_list = embeddings_column.tolist()
181
+
182
+ processed_1d_arrays = []
183
+ for emb in embeddings_list:
184
+ if emb is not None:
185
+ if hasattr(emb, '__iter__') and not isinstance(emb, (str, bytes)):
186
+ try:
187
+ arr = np.asarray(emb, dtype=float)
188
+ if arr.ndim == 1 and arr.size > 0:
189
+ processed_1d_arrays.append(arr)
190
+ except (TypeError, ValueError):
191
+ pass # Ignore embeddings that cannot be converted
192
+
193
+ if not processed_1d_arrays:
194
+ return np.empty((0,0))
195
+
196
+ # Check for consistent dimensionality before vstacking
197
+ first_len = processed_1d_arrays[0].shape[0]
198
+ consistent_embeddings = [arr for arr in processed_1d_arrays if arr.shape[0] == first_len]
199
+
200
+ if not consistent_embeddings:
201
+ return np.empty((0, first_len if processed_1d_arrays else 0)) # (0,D) or (0,0)
202
+
203
+ try:
204
+ return np.vstack(consistent_embeddings)
205
+ except ValueError:
206
+ # Should not happen if lengths are consistent
207
+ return np.empty((0, first_len))
208
+
209
+
210
def _compute_cluster_centroids(
    df_clustered_items: pd.DataFrame,  # DataFrame already filtered for non-noise items
    embedding_clm: str,
    cluster_label_clm: str
) -> Dict[Any, np.ndarray]:
    """Return {cluster_id: mean embedding vector} for each cluster in the frame."""
    if df_clustered_items.empty:
        return {}

    centroids: Dict[Any, np.ndarray] = {}
    for cluster_id, group in df_clustered_items.groupby(cluster_label_clm):
        matrix = _safe_embeddings_to_matrix(group[embedding_clm])

        # Only average when a non-degenerate 2D matrix could be built.
        if matrix.ndim == 2 and matrix.shape[0] > 0 and matrix.shape[1] > 0:
            centroids[cluster_id] = matrix.mean(axis=0)
    return centroids
226
+
227
+
228
def _project_to_centroid_space(
    original_embeddings_matrix: np.ndarray,   # (n_items, n_original_features)
    centroids_map: Dict[Any, np.ndarray]      # {cluster_id: centroid (n_original_features,)}
) -> np.ndarray:
    """
    Represent each embedding by its cosine similarity to every cluster centroid.

    Centroid columns are ordered by sorted cluster id; centroids whose shape
    does not match the embedding dimensionality are ignored. Returns an
    (n_items, 0) array when no projection is possible.
    """
    unusable_input = (
        not centroids_map
        or original_embeddings_matrix.ndim != 2
        or original_embeddings_matrix.shape[0] == 0
        or original_embeddings_matrix.shape[1] == 0
    )
    if unusable_input:
        return np.empty((original_embeddings_matrix.shape[0], 0))

    expected_dim = original_embeddings_matrix.shape[1]
    usable_centroids = []
    for cid in sorted(centroids_map.keys()):
        vec = centroids_map[cid]
        if isinstance(vec, np.ndarray) and vec.ndim == 1 and vec.size == expected_dim:
            usable_centroids.append(vec)

    if not usable_centroids:
        return np.empty((original_embeddings_matrix.shape[0], 0))

    centroid_matrix = np.vstack(usable_centroids)  # (n_valid_centroids, n_features)

    # Each row becomes a similarity profile: (n_items, n_valid_centroids).
    return cosine_similarity(original_embeddings_matrix, centroid_matrix)
254
+
255
+
256
def _get_pairwise_cosine_distances(embeddings_matrix: np.ndarray) -> np.ndarray:
    """Return the unique pairwise cosine distances (flattened upper triangle)."""
    usable = (
        isinstance(embeddings_matrix, np.ndarray)
        and embeddings_matrix.ndim == 2
        and embeddings_matrix.shape[0] >= 2
        and embeddings_matrix.shape[1] > 0
    )
    if not usable:
        return np.array([])  # fewer than two samples, or no features

    full_matrix = cosine_distances(embeddings_matrix)
    # Upper triangle above the diagonal: each unordered pair exactly once.
    upper = np.triu_indices(full_matrix.shape[0], k=1)
    return full_matrix[upper]
265
+
266
+
267
def analyze_space_distance_preservation(
    df: pd.DataFrame,
    embedding_clm: str = 'style_embedding',
    cluster_label_clm: str = 'cluster_label'
) -> "float | None":
    """
    Analyzes how well a new space, defined by cluster centroids, preserves
    the cosine distance relationships from the original embedding space.

    Args:
        df (pd.DataFrame): DataFrame with original embeddings and cluster labels.
        embedding_clm (str): Column name for original embeddings.
        cluster_label_clm (str): Column name for cluster labels.

    Returns:
        float | None: Pearson correlation coefficient. Returns None if analysis
                      cannot be performed (e.g., <2 clusters, <2 items), or 0.0
                      if correlation is NaN (e.g. due to zero variance in distances).
    """
    # NOTE: the return annotation is quoted because an unquoted `float | None`
    # is evaluated at definition time and raises TypeError on Python 3.9,
    # the version pinned by the project's Dockerfile.
    df_valid_items = df[df[cluster_label_clm] != -1].copy()

    if df_valid_items.shape[0] < 2:
        return None  # Need at least 2 items for pairwise distances

    original_embeddings_matrix = _safe_embeddings_to_matrix(df_valid_items[embedding_clm])
    if original_embeddings_matrix.ndim != 2 or original_embeddings_matrix.shape[0] < 2 or \
       original_embeddings_matrix.shape[1] == 0:
        return None  # Valid matrix from original embeddings could not be formed

    centroids = _compute_cluster_centroids(df_valid_items, embedding_clm, cluster_label_clm)
    if len(centroids) < 2:  # Need at least 2 centroids for a multi-dimensional new space
        return None

    projected_embeddings_matrix = _project_to_centroid_space(original_embeddings_matrix, centroids)
    if projected_embeddings_matrix.ndim != 2 or projected_embeddings_matrix.shape[0] < 2 or \
       projected_embeddings_matrix.shape[1] < 2:  # New space needs at least 2 dimensions (centroids)
        return None

    distances_original_space = _get_pairwise_cosine_distances(original_embeddings_matrix)
    distances_new_space = _get_pairwise_cosine_distances(projected_embeddings_matrix)

    if distances_original_space.size == 0 or distances_new_space.size == 0 or \
       distances_original_space.size != distances_new_space.size:
        return None  # Mismatch or empty distances

    # Handle cases where variance is zero in one of the distance arrays
    # (Pearson correlation would be NaN for a constant input).
    if np.all(distances_new_space == distances_new_space[0]) or \
       np.all(distances_original_space == distances_original_space[0]):
        return 0.0  # Correlation is undefined or 0 if one variable is constant

    try:
        correlation, _ = pearsonr(distances_original_space, distances_new_space)
    except ValueError:  # Should be caught by variance checks, but as a safeguard
        return None

    if np.isnan(correlation):
        return 0.0  # Default for NaN correlation

    return correlation
utils/file_download.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import urllib.request
4
+ import zipfile
5
+ import tempfile
6
+ from tqdm import tqdm
7
+ from urllib.parse import urlparse
8
+
9
class TqdmUpTo(tqdm):
    """tqdm subclass whose ``update_to`` matches urlretrieve's reporthook signature."""

    def update_to(self, b=1, bsize=1, tsize=None):
        # b: blocks transferred so far; bsize: size of each block; tsize: total size.
        if tsize is not None:
            self.total = tsize
        downloaded = b * bsize
        self.update(downloaded - self.n)  # advance only by the delta since last call
14
+
15
def download_file_override(url: str, dest_path: str):
    """
    Download a file from a URL and always overwrite the target.

    If the download is a zip archive, its contents are extracted directly into
    ``dest_path`` (treated as a directory, with no extra folder level).
    Otherwise the file is saved directly to ``dest_path``. Any pre-existing
    target is removed first.

    Args:
        url (str): Source URL.
        dest_path (str): Target file path, or target directory for zip archives.
    """
    # Ensure the parent directory exists (dest_path itself when it ends in a separator).
    dest_dir = dest_path if dest_path.endswith(('/', '\\')) else os.path.dirname(dest_path)
    if dest_dir:
        os.makedirs(dest_dir, exist_ok=True)

    # Download into a temp file first so a failed transfer never clobbers dest_path.
    tmp_fd, tmp_path = tempfile.mkstemp()
    os.close(tmp_fd)

    filename = os.path.basename(urlparse(url).path) or "downloaded.file"
    # Fix: the original printed the literal placeholder "(unknown)" and ignored
    # the filename it had just computed.
    print(f"Downloading {filename}...")

    try:
        with TqdmUpTo(unit='B', unit_scale=True, unit_divisor=1024, miniters=1, desc=filename) as t:
            urllib.request.urlretrieve(url, filename=tmp_path, reporthook=t.update_to)

        if zipfile.is_zipfile(tmp_path):
            # Replace dest_path wholesale with the archive contents.
            if os.path.exists(dest_path):
                shutil.rmtree(dest_path)
            os.makedirs(dest_path, exist_ok=True)

            # Extract into a temp folder first, then move the *contents* up,
            # so the archive does not introduce an extra nesting level.
            with tempfile.TemporaryDirectory() as tmp_extract_dir:
                with zipfile.ZipFile(tmp_path, 'r') as z:
                    z.extractall(tmp_extract_dir)

                for item in os.listdir(tmp_extract_dir):
                    # shutil.move handles both files and directories, so the
                    # original's identical if/else branches are collapsed.
                    shutil.move(os.path.join(tmp_extract_dir, item),
                                os.path.join(dest_path, item))

            print(f"Extracted zip contents into '{dest_path}'.")
        else:
            # Ensure parent dir exists, then overwrite the target file.
            os.makedirs(os.path.dirname(dest_path) or ".", exist_ok=True)
            if os.path.exists(dest_path):
                os.remove(dest_path)
            shutil.move(tmp_path, dest_path)
            tmp_path = None  # moved into place; nothing left to clean up
            print(f"Saved file to '{dest_path}'.")

    finally:
        # Remove the temp download unless it was already moved into place.
        if tmp_path and os.path.exists(tmp_path):
            os.remove(tmp_path)
utils/generate_augmented_mapping.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import re
3
+
4
def load_original_map_and_extract_morph(path="human_readable.txt"):
    """
    Parse the code:human mapping file.

    Returns:
        (human_to_code, morph_entries) where ``human_to_code`` maps
        human-readable names to codes (POS tags etc.) and ``morph_entries``
        is a list of (human, code) pairs for morphological tags, i.e. keys
        containing '='.
    """
    human_to_code = {}
    morph_entries = []

    with open(path, "r", encoding="utf-8") as fh:
        for raw in fh:
            entry = raw.strip()
            # Skip blanks, comments and lines without a separator.
            if not entry or ":" not in entry or entry.startswith("#"):
                continue
            code, human = (part.strip() for part in entry.split(":", 1))

            if "=" in code:
                # Morphological tag such as Aspect=Perf → keep (human, code).
                morph_entries.append((human, code))
            else:
                human_to_code[human] = code

    return human_to_code, morph_entries
22
+
23
def extract_bigrams_from_csv(csv_path="../datasets/gram2vec_feats.csv"):
    """
    Collect human-readable POS-bigram descriptions from the feature CSV.

    Only rows whose 'gram2vec_feats' value starts with
    'Part-of-Speech Bigram:' and whose description contains 'followed by'
    are kept.
    """
    found = set()
    with open(csv_path, "r", encoding="utf-8") as fh:
        for record in csv.DictReader(fh):
            value = record["gram2vec_feats"]
            if not value.startswith("Part-of-Speech Bigram:"):
                continue
            description = value.split(":", 1)[1].strip()
            if "followed by" in description:
                found.add(description)
    return found
34
+
35
def generate_bigram_code_map(human_to_code, bigrams):
    """
    Translate human-readable bigrams ('X followed by Y') into code pairs.

    Returns a dict mapping each bigram description to 'CODE_X CODE_Y'.
    Descriptions that cannot be parsed or fully mapped are reported on
    stdout and skipped.
    """
    bigram_re = re.compile(r"(.+?) followed by (.+)")
    code_map = {}

    for description in bigrams:
        m = bigram_re.match(description)
        if not m:
            print(f"Not matched: {description}")
            continue
        first = m.group(1).strip()
        second = m.group(2).strip()
        first_code = human_to_code.get(first)
        second_code = human_to_code.get(second)
        if first_code and second_code:
            code_map[description] = f"{first_code} {second_code}"
        else:
            print(f"Could not map: {description} → {first_code}, {second_code}")
    return code_map
53
+
54
def write_augmented_map(pos_bigram_map, morph_entries, original_path="human_readable.txt", output_path="augmented_human_readable.txt"):
    """
    Write the augmented human-readable map.

    The original file's 'code: human' lines are flipped to 'human:code',
    then a section of POS-bigram mappings and a section of morphological
    tag mappings are appended (each preceded by a blank line, sorted for
    reproducibility).
    """
    with open(output_path, "w", encoding="utf-8") as out:
        # Flip original lines: write human-readable:code instead of code:human.
        with open(original_path, "r", encoding="utf-8") as source:
            for raw in source:
                entry = raw.strip()
                if not entry or entry.startswith("#"):
                    out.write(entry + "\n")
                    continue
                if ":" not in entry:
                    continue
                code, human = (part.strip() for part in entry.split(":", 1))
                out.write(f"{human}:{code}\n")

        # New section for POS bigrams.
        out.write("\n")
        for human, code in sorted(pos_bigram_map.items()):
            out.write(f"{human}:{code}\n")

        # Re-add morph tag mappings.
        out.write("\n")
        for human, code in sorted(morph_entries):
            out.write(f"{human}:{code}\n")

    print(f"Augmented map written to {output_path}")
81
+
82
# Run all
# Script-style entry point: builds the augmented human-readable mapping file
# by flipping the original map, adding POS-bigram code pairs extracted from
# the feature CSV, and re-appending the morphological tag mappings.
# NOTE: these statements run at import time (no __main__ guard), so importing
# this module reads the CSV and rewrites the output file.
human_to_code, morph_entries = load_original_map_and_extract_morph()
bigrams = extract_bigrams_from_csv()
pos_bigram_map = generate_bigram_code_map(human_to_code, bigrams)
write_augmented_map(pos_bigram_map, morph_entries)
utils/gram2vec_feat_utils.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import html
3
+
4
+ from collections import namedtuple
5
+ from gram2vec.feature_locator import find_feature_spans
6
+ from functools import lru_cache
7
+
8
+ from utils.llm_feat_utils import generate_feature_spans_cached
9
+ import pandas as pd
10
+ Span = namedtuple('Span', ['start_char', 'end_char'])
11
+
12
+ from gram2vec import vectorizer
13
+
14
# ── the FEATURE_HANDLERS & loader ────────────
# Maps each human-readable gram2vec feature category (the text before the
# ':' in a feature string) to the short register name used by gram2vec
# (see get_shorthand / get_fullform below).
FEATURE_HANDLERS = {
    "Part-of-Speech Unigram": "pos_unigrams",
    "Part-of-Speech Bigram": "pos_bigrams",
    "Function Word": "func_words",
    "Punctuation": "punctuation",
    "Letter": "letters",
    "Dependency Label": "dep_labels",
    "Morphology Tag": "morph_tags",
    "Sentence Type": "sentences",
    "Emoji": "emojis",
    "Number of Tokens": "num_tokens"
}
27
+
28
@lru_cache(maxsize=1)
def load_code_map(txt_path: str = "utils/augmented_human_readable.txt") -> dict:
    """
    Load the human-readable → code mapping file.

    Each data line has the form 'Human Readable Name:CODE'. Blank lines,
    comment lines ('#...') and malformed lines without a ':' are skipped —
    previously any non-empty line lacking a ':' raised ValueError on unpack.

    Returns:
        dict mapping human-readable feature names to their short codes.
    """
    code_map = {}
    with open(txt_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            # Skip blanks, comments, and separator-less lines.
            if not line or line.startswith("#") or ":" not in line:
                continue
            human, code = [p.strip() for p in line.split(":", 1)]
            code_map[human] = code
    return code_map
39
+
40
def get_shorthand(feature_str: str) -> str:
    """
    Convert 'Category:Human-Readable' to the gram2vec shorthand
    'register:CODE', e.g. 'Part-of-Speech Unigram:Adjective' →
    'pos_unigrams:ADJ'.

    Returns None when the string has no ':', the category is unknown to
    FEATURE_HANDLERS, or the human-readable name has no code in the map.
    """
    try:
        category, human = [p.strip() for p in feature_str.split(":", 1)]
        # print(f"Category: {category}, Human: {human}")
    except ValueError:
        # No ':' separator → not a valid feature string.
        # print("Invalid format for feature string:", feature_str)
        return None
    if category not in FEATURE_HANDLERS:
        return None
    code = load_code_map().get(human)
    if code is None:
        # print(f"Warning: No code found for human-readable feature '{human}'")
        return None  # no known code for this human-readable name
    return f"{FEATURE_HANDLERS[category]}:{code}"
57
+
58
def get_fullform(shorthand: str) -> str:
    """
    Inverse of get_shorthand: 'prefix:code' (e.g. 'pos_unigrams:ADJ') →
    'Category:Human-Readable' (e.g. 'Part-of-Speech Unigram:Adjective'),
    or None when either part cannot be resolved.
    """
    # A shorthand without a ':' separator is invalid.
    if ":" not in shorthand:
        return None
    prefix, code = shorthand.split(":", 1)

    # Invert FEATURE_HANDLERS to look up the category from the register name.
    category = {reg: cat for cat, reg in FEATURE_HANDLERS.items()}.get(prefix)
    if category is None:
        return None

    # Invert the code map to recover the human-readable feature name.
    human = {c: h for h, c in load_code_map().items()}.get(code)
    if human is None:
        return None

    return f"{category}:{human}"
82
+
83
def highlight_both_spans(text, llm_spans, gram_spans):
    """
    Render `text` as HTML with both span sets highlighted.

    The string is walked exactly once, injecting <mark> tags at the right
    character offsets, so nested or overlapping highlights never stomp on
    each other.
    """
    # Inline CSS : mark-llm is in yellow, mark-gram in blue
    style = """
    <style>
    .mark-llm { background-color: #fff176; }
    .mark-gram { background-color: #90caf9; }
    </style>
    """

    # Each span contributes an open and a close event.
    events = [(s.start_char, 'open', 'llm') for s in llm_spans]
    events += [(s.end_char, 'close', 'llm') for s in llm_spans]
    events += [(s.start_char, 'open', 'gram') for s in gram_spans]
    events += [(s.end_char, 'close', 'gram') for s in gram_spans]

    # Stable sort by position; opens come before closes at the same offset.
    events.sort(key=lambda ev: (ev[0], 0 if ev[1] == 'open' else 1))

    pieces = []
    cursor = 0
    for pos, kind, cls in events:
        # Escape the raw text between the previous event and this one.
        pieces.append(html.escape(text[cursor:pos]))
        pieces.append(f'<mark class="mark-{cls}">' if kind == 'open' else '</mark>')
        cursor = pos

    pieces.append(html.escape(text[cursor:]))
    rendered = "".join(pieces).replace('\n', '<br>')

    return style + rendered
126
+
127
+
128
def show_combined_spans_all(selected_feature_llm, selected_feature_g2v,
    llm_style_feats_analysis, background_authors_embeddings_df, task_authors_embeddings_df, visible_authors, predicted_author=None, ground_truth_author=None, max_num_authors=7):
    """
    For mystery + 3 candidates (plus visible background authors):
    1. get LLM spans from the cached style analysis
    2. get gram2vec spans via find_feature_spans
    3. merge both span sets and render highlighted HTML

    Returns:
        (task_html, background_html): rendered sections for the first four
        (task) authors and the remaining background authors.
    """
    print(f"\n\n\n\n\nShowing combined spans for LLM feature '{selected_feature_llm}' and Gram2Vec feature '{selected_feature_g2v}'")
    print(f"predicted_author: {predicted_author}, ground_truth_author: {ground_truth_author}")
    print(f" keys = {background_authors_embeddings_df.keys()}")

    # Keep only the visible background authors, then put task authors first
    # so indices 0-3 are mystery + candidates.
    background_authors_embeddings_df = background_authors_embeddings_df[background_authors_embeddings_df.authorID.isin(visible_authors)]
    background_and_task_authors = pd.concat([task_authors_embeddings_df, background_authors_embeddings_df])

    authors_texts = ['\n\n =========== \n\n'.join(x) if type(x) == list else x for x in background_and_task_authors[:max_num_authors]['fullText'].tolist()]
    authors_names = background_and_task_authors[:max_num_authors]['authorID'].tolist()
    print(f"Number of authors to show: {len(authors_texts)}")
    print(f"Authors names: {authors_names}")
    texts = list(zip(authors_names, authors_texts))

    # Bug fix: `short` was only assigned inside the Gram2Vec branch below but
    # is passed to create_html unconditionally, so selecting no Gram2Vec
    # feature raised NameError. Initialise it up front.
    short = None

    if selected_feature_llm and selected_feature_llm != "None":
        author_list = list(llm_style_feats_analysis['spans'].values())
        llm_spans_list = []
        for i, (_, txt) in enumerate(texts):
            author_spans_list = []
            for txt_span in author_list[i][selected_feature_llm]:
                # Locate each cached span text inside the author's document.
                author_spans_list.append(Span(txt.find(txt_span), txt.find(txt_span) + len(txt_span)))
            llm_spans_list.append(author_spans_list)
    else:
        print("Skipping LLM span extraction: feature is None")
        llm_spans_list = [[] for _ in texts]

    if selected_feature_g2v and selected_feature_g2v != "None":
        # get gram2vec spans
        gram_spans_list = []
        print(f"Selected Gram2Vec feature: {selected_feature_g2v}")
        short = get_shorthand(selected_feature_g2v)
        print(f"short hand: {short}")
        for role, txt in texts:
            try:
                print(f"Finding spans for {short} {role}")
                spans = find_feature_spans(txt, short)
            except Exception:
                # Narrowed from a bare `except:` so KeyboardInterrupt etc.
                # still propagate.
                print(f"Error finding spans for {short} {role}")
                spans = []
            gram_spans_list.append(spans)
    else:
        print("Skipping Gram2Vec span extraction: feature is None")
        gram_spans_list = [[] for _ in texts]

    # build HTML blocks
    print(f" ----> Number of authors: {len(texts)}")

    html_task_authors = create_html(
        texts[:4], #first 4 are task
        llm_spans_list,
        gram_spans_list,
        selected_feature_llm,
        selected_feature_g2v,
        short,
        background = False,
        predicted_author=predicted_author,
        ground_truth_author=ground_truth_author
    )
    combined_html = "<div>" + "\n<hr>\n".join(html_task_authors) + "</div>"

    html_background_authors = create_html(
        texts[4:], #remaining are background
        # Bug fix: span lists must be sliced in step with texts[4:] —
        # previously background authors were rendered with the task
        # authors' spans (indices 0..).
        llm_spans_list[4:],
        gram_spans_list[4:],
        selected_feature_llm,
        selected_feature_g2v,
        short,
        background = True,
        predicted_author=predicted_author,
        ground_truth_author=ground_truth_author
    )
    background_html = "<div>" + "\n<hr>\n".join(html_background_authors) + "</div>"
    return combined_html, background_html
214
+
215
def get_label(label: str, predicted_author=None, ground_truth_author=None, bg_id: int=0) -> str:
    """
    Map an internal author identifier to a display label, annotating the
    predicted / ground-truth candidate when both indices are provided.
    """
    print(f"get_label called with label: {label}, predicted_author: {predicted_author}, ground_truth_author: {ground_truth_author}, bg_id: {bg_id}")

    if label.startswith(("Mystery", "Q_author")):
        return "Mystery Author"

    if label.startswith(("a0_author", "a1_author", "a2_author", "Candidate")):
        if label.startswith("Candidate"):
            idx = int(label.split(" ")[2])  # number after 'Candidate Author'
        else:
            idx = int(label.split("_")[0][-1])  # digit of a0/a1/a2
        display = f"Candidate {idx + 1}"
        # Annotations only apply when both indices are known.
        if predicted_author is None or ground_truth_author is None:
            return display
        if idx == predicted_author and idx == ground_truth_author:
            return f"{display} (Predicted & Ground Truth)"
        if idx == predicted_author:
            return f"{display} (Predicted)"
        if idx == ground_truth_author:
            return f"{display} (Ground Truth)"
        return display

    # Anything else is a background author, numbered by position.
    return f"Background Author {bg_id + 1}"
240
+
241
def create_html(texts, llm_spans_list, gram_spans_list, selected_feature_llm, selected_feature_g2v, short=None, background = False, predicted_author=None, ground_truth_author=None):
    """
    Build one HTML section per (label, text) pair, combining LLM and
    Gram2Vec highlights and prepending notice boxes when a feature is
    unselected, unmapped, or produced no spans for that author.
    """
    sections = []  # renamed from `html` to avoid shadowing the html module
    for i, (label, txt) in enumerate(texts):
        # Background authors are numbered by their position in `texts`.
        if background:
            display_label = get_label(label, predicted_author, ground_truth_author, i)
        else:
            display_label = get_label(label, predicted_author, ground_truth_author)

        combined = highlight_both_spans(txt, llm_spans_list[i], gram_spans_list[i])

        notice = ""
        if selected_feature_llm == "None":
            notice += """
            <div style="padding:8px; background:#eee; border:1px solid #aaa;">
                <em>No LLM feature selected.</em>
            </div>
            """
        elif not llm_spans_list[i]:
            notice += f"""
            <div style="padding:8px; background:#fee; border:1px solid #f00;">
                <em>No spans found for LLM feature "{selected_feature_llm}".</em>
            </div>
            """
        if selected_feature_g2v == "None":
            notice += """
            <div style="padding:8px; background:#eee; border:1px solid #aaa;">
                <em>No Gram2Vec feature selected.</em>
            </div>
            """
        elif not short:
            notice += f"""
            <div style="padding:8px; background:#fee; border:1px solid #f00;">
                <em>Invalid or unmapped feature: "{selected_feature_g2v}".</em>
            </div>
            """
        elif not gram_spans_list[i]:
            notice += f"""
            <div style="padding:8px; background:#fee; border:1px solid #f00;">
                <em>No spans found for Gram2Vec feature "{selected_feature_g2v}".</em>
            </div>
            """

        sections.append(f"""
        <h3>{display_label}</h3>
        {notice}
        <div style="border:1px solid #ccc; padding:8px; margin-bottom:1em;">
            {combined}
        </div>
        """)
    return sections
utils/human_readable.txt ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ADJ: Adjective
2
+ ADP: Adposition
3
+ ADV: Adverb
4
+ AUX: Auxiliary verb
5
+ CCONJ: Coordinating conjunction
6
+ DET: Determiner
7
+ INTJ: Interjection
8
+ NOUN: Noun
9
+ NUM: Numeral
10
+ PART: Particle
11
+ PRON: Pronoun
12
+ PROPN: Proper noun
13
+ PUNCT: Punctuation
14
+ SCONJ: Subordinating conjunction
15
+ SYM: Symbol
16
+ VERB: Verb
17
+ X: Other
18
+ SPACE: Space
19
+
20
+ Aspect=Perf: Perfect aspect
21
+ Aspect=Prog: Progressive aspect
22
+ Case=Acc: Accusative case
23
+ Case=Nom: Nominative case
24
+ Definite=Def: Definite article
25
+ Definite=Ind: Indefinite article
26
+ Degree=Cmp: Comparative degree
27
+ Degree=Pos: Positive degree
28
+ Degree=Sup: Superlative degree
29
+ Gender=Fem: Feminine gender
30
+ Gender=Masc: Masculine gender
31
+ Mood=Ind: Indicative mood
32
+ Number=Plur: Plural number
33
+ Number=Sing: Singular number
34
+ Person=1: First person
35
+ Person=2: Second person
36
+ Person=3: Third person
37
+ Tense=Past: Past tense
38
+ Tense=Pres: Present tense
39
+ VerbForm=Fin: Finite verb form
40
+ VerbForm=Inf: Infinitive verb form
utils/interp_space_utils.py ADDED
@@ -0,0 +1,638 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ import pandas as pd
4
+ import numpy as np
5
+ import math
6
+ from collections import Counter, defaultdict
7
+ from typing import List, Any
8
+ from sklearn.feature_extraction.text import TfidfVectorizer
9
+ import os
10
+ import pickle
11
+ import hashlib
12
+ import json
13
+ from gram2vec import vectorizer
14
+ from openai import OpenAI
15
+ from openai.lib._pydantic import to_strict_json_schema
16
+ from pydantic import BaseModel
17
+ from pydantic import ValidationError
18
+ import time
19
+ from utils.llm_feat_utils import generate_feature_spans_cached
20
+ from collections import Counter
21
+ import numpy as np
22
+ from sklearn.metrics.pairwise import cosine_similarity
23
+
24
# On-disk cache for embedding / feature computations, keyed by MD5 digests
# computed in the functions below.
CACHE_DIR = "datasets/embeddings_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
# Bump this whenever there is a change etc...
# NOTE(review): CACHE_VERSION is not folded into any cache digest in this
# module — confirm that bumping it actually invalidates old cache entries.
CACHE_VERSION = 1
28
+
29
class style_analysis_schema(BaseModel):
    # Structured-output schema for the combined LLM style analysis:
    # `features` lists feature names; `spans` maps author name →
    # feature → list of supporting text spans.
    features: list[str]
    spans: dict[str, dict[str, list[str]]]
32
+
33
class FeatureIdentificationSchema(BaseModel):
    # Structured-output schema for the feature-identification step only.
    features: list[str]
35
+
36
class SpanExtractionSchema(BaseModel):
    # Structured-output schema for the span-extraction step.
    spans: dict[str, dict[str, list[str]]] # {author_name: {feature: [spans]}}
38
+
39
+
40
+
41
def compute_g2v_features(clustered_authors_df: pd.DataFrame, task_authors_df: pd.DataFrame=None, text_clm='fullText') -> pd.DataFrame:
    """
    Computes gram2vec feature vectors for each author and adds them to the DataFrame.
    This effectively creates a mapping from each author to their vector.

    Args:
        clustered_authors_df: Background-corpus authors; `fullText` must hold
            a list of documents per author.
        task_authors_df: Optional task authors, vectorized together with the
            background corpus (prepended before vectorization, split off again
            before returning).
        text_clm: Column name used in the cache key.

    Returns:
        (background_vectors, task_vectors): lists of {feature_name: z-score}
        dicts; task_vectors is None when `task_authors_df` is not given.
    """
    if task_authors_df is not None:
        print (f"concatenating task authors and background corpus authors")
        print(f"Number of task authors: {len(task_authors_df)}")
        print(f"task authors author_ids: {task_authors_df.authorID.tolist()}")
        print(f"task authors -->")
        print(task_authors_df)
        print(f"Number of background corpus authors: {len(clustered_authors_df)}")
        clustered_authors_df = pd.concat([task_authors_df, clustered_authors_df])
        print(f"Number of authors after concatenation: {len(clustered_authors_df)}")

    # Join each author's documents into one blob for vectorization.
    author_texts = ['\n\n'.join(x) for x in clustered_authors_df.fullText.tolist()]

    print(f"Number of author_texts: {len(author_texts)}")

    # Create a reproducible JSON serialization of the texts
    serialized = json.dumps({
        "col": text_clm,
        "texts": author_texts
    }, sort_keys=True, ensure_ascii=False)

    # Compute MD5 hash
    digest = hashlib.md5(serialized.encode("utf-8")).hexdigest()
    cache_path = os.path.join(CACHE_DIR, f"{digest}.pkl")

    # If cache hit, load and return
    if os.path.exists(cache_path):
        print(f"Cache hit...")
        with open(cache_path, "rb") as f:
            clustered_authors_df = pickle.load(f)

    else: # Else compute and cache
        g2v_feats_df = vectorizer.from_documents(author_texts, batch_size=16)

        print(f"Number of g2v features: {len(g2v_feats_df)}")
        print(f"Number of clustered_authors_df.authorID.tolist(): {len(clustered_authors_df.authorID.tolist())}")
        print(f"Number of g2v_feats_df.to_numpy().tolist(): {len(g2v_feats_df.to_numpy().tolist())}")

        ids = clustered_authors_df.authorID.tolist()
        counter = Counter(ids)
        duplicates = [k for k, v in counter.items() if v > 1]

        print(f"Duplicate authorIDs: {duplicates}")
        print(f"Number of duplicates: {len(ids) - len(set(ids))}")

        # NOTE(review): duplicate authorIDs collapse to a single dict entry
        # here, which would misalign vectors with rows below — confirm IDs
        # are unique upstream.
        author_to_g2v_feats = {x[0]: x[1] for x in zip(clustered_authors_df.authorID.tolist(), g2v_feats_df.to_numpy().tolist())}

        print(f"Number of authors with g2v features: {len(author_to_g2v_feats)}")

        # Z-normalize each feature across authors; zero-variance features
        # get std forced to 1 to avoid division by zero.
        vector_std = np.std(list(author_to_g2v_feats.values()), axis=0)
        vector_mean = np.mean(list(author_to_g2v_feats.values()), axis=0)
        vector_std[vector_std == 0] = 1.0
        author_to_g2v_feats_z_normalized = {x[0]: (x[1] - vector_mean) / vector_std for x in author_to_g2v_feats.items()}

        print(f"Number of authors with g2v features normalized: {len(author_to_g2v_feats_z_normalized)}")
        print(f" len of clustered authors df: {len(clustered_authors_df)}")

        # Add the vectors as a new column of the DataFrame
        # (one {feature_name: z-score} dict per author).
        clustered_authors_df['g2v_vector'] = [{x[1]: x[0] for x in zip(val, g2v_feats_df.columns.tolist())}
                                              for val in author_to_g2v_feats_z_normalized.values()]

        with open(cache_path, "wb") as f:
            pickle.dump(clustered_authors_df, f)

    if task_authors_df is not None:
        # Split the task authors back out of the combined frame.
        task_authors_df = clustered_authors_df[clustered_authors_df.authorID.isin(task_authors_df.authorID.tolist())]
        clustered_authors_df = clustered_authors_df[~clustered_authors_df.authorID.isin(task_authors_df.authorID.tolist())]
        return clustered_authors_df['g2v_vector'].tolist(), task_authors_df['g2v_vector'].tolist()

    # Bug fix: this path previously subscripted task_authors_df (None) and
    # raised TypeError whenever no task authors were supplied.
    return clustered_authors_df['g2v_vector'].tolist(), None
119
+
120
+
121
def get_task_authors_from_background_df(background_df):
    """Return only the task-author rows (mystery + three candidates)."""
    task_ids = ["Q_author", "a0_author", "a1_author", "a2_author"]
    return background_df[background_df.authorID.isin(task_ids)]
124
+
125
def instance_to_df(instance, predicted_author=None, ground_truth_author=None):
    """
    Build a 4-row DataFrame (mystery + 3 candidates) from a task instance.

    `predicted` / `ground_truth` flags are set per candidate from the given
    0-based indices; the mystery row carries None for both.
    """
    names = ['Mystery author', 'Candidate Author 1', 'Candidate Author 2', 'Candidate Author 3']
    keys = ['Q_fullText', 'a0_fullText', 'a1_fullText', 'a2_fullText']

    rows = []
    for position, (name, key) in enumerate(zip(names, keys)):
        candidate_idx = position - 1  # -1 for the mystery author
        rows.append({
            'authorID': name,
            'fullText': instance[key],
            'predicted': None if position == 0 else predicted_author == candidate_idx,
            'ground_truth': None if position == 0 else ground_truth_author == candidate_idx,
        })
    task_authors_df = pd.DataFrame(rows)

    # When each author has a list of documents, collapse rows per author.
    if type(instance['Q_fullText']) == list:
        task_authors_df = task_authors_df.groupby('authorID').agg({'fullText': lambda x: list(x)}).reset_index()

    return task_authors_df
139
+
140
+
141
def generate_style_embedding(background_corpus_df: pd.DataFrame, text_clm: str, model_name: str) -> pd.DataFrame:
    """
    Generates style embeddings for documents in a background corpus using a specified model.
    If a row in `text_clm` contains a list of strings, the final embedding for that row
    is the average of the embeddings of all strings in the list.

    Args:
        background_corpus_df (pd.DataFrame): DataFrame containing the corpus.
        text_clm (str): Name of the column containing the text data (either string or list of strings).
        model_name (str): Name of the model to use for generating embeddings.

    Returns:
        pd.DataFrame: The input DataFrame with a new column for style embeddings,
            named '<model short name>_style_embedding'. If `model_name` is not in
            the supported list, the DataFrame is returned unchanged.
    """
    # Imported lazily so the module can be loaded without these heavy deps.
    from sentence_transformers import SentenceTransformer
    import torch

    # Only these checkpoints are supported; anything else is a no-op.
    if model_name not in [
        'gabrielloiseau/LUAR-MUD-sentence-transformers',
        'gabrielloiseau/LUAR-CRUD-sentence-transformers',
        'miladalsh/light-luar',
        'AnnaWegmann/Style-Embedding',

    ]:
        print('Model is not supported')
        return background_corpus_df

    print(f"Generating style embeddings using {model_name} on column '{text_clm}'...")

    model = SentenceTransformer(model_name)
    embedding_dim = model.get_sentence_embedding_dimension()

    # Heuristic to check if the column contains lists of strings by checking the first valid item.
    # This assumes the column is homogenous.
    is_list_column = False
    if not background_corpus_df.empty:
        # Get the first non-NaN value to inspect its type
        series_no_na = background_corpus_df[text_clm].dropna()
        if not series_no_na.empty:
            first_valid_item = series_no_na.iloc[0]
            if isinstance(first_valid_item, list):
                is_list_column = True

    if is_list_column:
        # Flatten all texts into a single list for batch processing
        texts_to_encode = []
        # Number of texts contributed by each row, in row order; used below
        # to slice the flat embedding matrix back into per-row groups.
        row_lengths = []
        for text_list in background_corpus_df[text_clm]:
            # Ensure we handle None, empty lists, or other non-list types gracefully
            if isinstance(text_list, list) and text_list:
                texts_to_encode.extend(text_list)
                row_lengths.append(len(text_list))
            else:
                row_lengths.append(0)

        if texts_to_encode:
            all_embeddings = model.encode(texts_to_encode, convert_to_tensor=True, show_progress_bar=True)
        else:
            # No usable text anywhere: empty (0, dim) tensor keeps slicing valid.
            all_embeddings = torch.empty((0, embedding_dim), device=model.device)

        # Reconstruct and average embeddings for each row
        final_embeddings = []
        current_pos = 0
        for length in row_lengths:
            if length > 0:
                row_embeddings = all_embeddings[current_pos:current_pos + length]
                avg_embedding = torch.mean(row_embeddings, dim=0)
                final_embeddings.append(avg_embedding.cpu().numpy())
                current_pos += length
            else:
                # Rows with no usable text get a zero vector of matching size.
                final_embeddings.append(np.zeros(embedding_dim))
    else:
        # Column contains single strings
        texts = background_corpus_df[text_clm].fillna("").tolist()
        # convert_to_tensor=False is faster if we just need numpy arrays
        embeddings = model.encode(texts, show_progress_bar=True)
        final_embeddings = list(embeddings)

    # Create a clean column name from the model name
    col_name = f'{model_name.split("/")[-1]}_style_embedding'
    background_corpus_df[col_name] = final_embeddings

    return background_corpus_df
224
+
225
# ── wrapper with caching ───────────────────────────────────────
def cached_generate_style_embedding(background_corpus_df: pd.DataFrame,
                                    text_clm: str,
                                    model_name: str) -> pd.DataFrame:
    """
    Caching wrapper around `generate_style_embedding`.

    Results are stored as pickles keyed by an MD5 of (model_name, column
    name, texts); on a cache hit the pickle is loaded instead of recomputing.
    """
    # Gather the input texts (preserves list-of-strings if any).
    texts = background_corpus_df[text_clm].fillna("").tolist()

    # Deterministic cache key over model + column + contents.
    payload = json.dumps({
        "model": model_name,
        "col": text_clm,
        "texts": texts
    }, sort_keys=True, ensure_ascii=False)
    cache_file = os.path.join(
        CACHE_DIR,
        hashlib.md5(payload.encode("utf-8")).hexdigest() + ".pkl"
    )

    if os.path.exists(cache_file):
        print(f"Cache hit for {model_name} on column '{text_clm}'")
        print(cache_file)
        with open(cache_file, "rb") as f:
            return pickle.load(f)

    # Cache miss: compute, persist, return.
    df_with_emb = generate_style_embedding(background_corpus_df, text_clm, model_name)
    print(f"Computing embeddings for {model_name} on column '{text_clm}', saving to {cache_file}")
    with open(cache_file, "wb") as f:
        pickle.dump(df_with_emb, f)
    return df_with_emb
262
+
263
def get_style_feats_distribution(documentIDs, style_feats_dict):
    """
    Build a tf-idf-like weight vector over style features for a document set.

    For each feature key in `style_feats_dict`, the output element is
    (count of that feature across the selected documents) * (the dict's
    value for that feature), in the dict's iteration order.

    NOTE(review): this reads the global `document_to_style_feats`, which is
    not defined anywhere in this module — confirm it is injected elsewhere
    before calling, otherwise this raises NameError.
    """
    style_feats = []
    for documentId in documentIDs:
        # Documents with no recorded features are silently skipped.
        if documentId not in document_to_style_feats:
            #print(documentId)
            continue

        style_feats+= document_to_style_feats[documentId]

    # Frequency of each feature weighted by its per-feature value.
    tfidf = [style_feats.count(key) * val for key, val in style_feats_dict.items()]

    return tfidf
275
+
276
def get_cluster_top_feats(style_feats_distribution, style_feats_list, top_k=5):
    """Return the names of the top-k features with strictly positive weight,
    highest weight first."""
    ranked = np.argsort(style_feats_distribution)[::-1][:top_k]
    return [style_feats_list[idx] for idx in ranked if style_feats_distribution[idx] > 0]
280
+
281
def compute_clusters_style_representation(
    background_corpus_df: pd.DataFrame,
    cluster_ids: List[Any],
    other_cluster_ids: List[Any],
    features_clm_name: str,
    cluster_label_clm_name: str = 'cluster_label',
    top_n: int = 10
) -> List[str]:
    """
    Rank the features most characteristic of the target clusters.

    TF-IDF is computed over the whole corpus, treating each document's
    stored feature list as already-tokenised input. A feature's score is
    its summed TF-IDF across documents in `cluster_ids`, minus the summed
    TF-IDF across documents in `other_cluster_ids` (when given), so
    features prominent in the contrast clusters are down-weighted.

    Parameters:
    - background_corpus_df: must contain `cluster_label_clm_name` and
      `features_clm_name`; the latter holds lists of feature strings.
    - cluster_ids: target clusters whose representative features are wanted.
    - other_cluster_ids: contrast clusters; pass [] or None to skip.
    - features_clm_name: column with each document's feature list.
    - cluster_label_clm_name: column with the cluster labels.
    - top_n: maximum number of features to return.

    Returns:
    - Up to `top_n` feature names with a strictly positive adjusted score,
      best first. Empty list when no document matches `cluster_ids`.
    """
    assert background_corpus_df[features_clm_name].apply(
        lambda x: isinstance(x, list) and all(isinstance(feat, str) for feat in x)
    ).all(), f"Column '{features_clm_name}' must contain lists of strings."

    # Identity tokenizer/preprocessor: feature lists are the tokens.
    # (Local name also avoids shadowing the module-level gram2vec vectorizer.)
    tfidf_vec = TfidfVectorizer(
        tokenizer=lambda x: x,
        preprocessor=lambda x: x,
        token_pattern=None
    )
    tfidf_matrix = tfidf_vec.fit_transform(background_corpus_df[features_clm_name])
    feature_names = tfidf_vec.get_feature_names_out()

    # Rows belonging to the target clusters.
    target_mask = background_corpus_df[cluster_label_clm_name].isin(cluster_ids).to_numpy()
    if not target_mask.any():
        return []

    # Per-feature TF-IDF mass inside the target clusters.
    scores = tfidf_matrix[target_mask].sum(axis=0).A1.copy()

    # Penalise features that are also prominent in the contrast clusters.
    if other_cluster_ids:
        contrast_mask = background_corpus_df[cluster_label_clm_name].isin(other_cluster_ids).to_numpy()
        if contrast_mask.any():
            scores -= tfidf_matrix[contrast_mask].sum(axis=0).A1

    ranked = sorted(zip(feature_names, scores), key=lambda item: item[1], reverse=True)
    return [name for name, score in ranked if score > 0][:top_n]
364
+
365
def compute_clusters_style_representation_2(
    background_corpus_df: pd.DataFrame,
    cluster_ids: List[Any],
    cluster_label_clm_name: str = 'cluster_label',
    max_num_feats: int = 5,
    max_num_documents_per_author=3,
    max_num_authors=5):
    """
    Call OpenAI to analyze the common writing style features of the texts of
    authors whose cluster label is in `cluster_ids`.

    Responses are cached on disk under CACHE_DIR, keyed by an MD5 of the exact
    prompt, so repeated calls with identical inputs skip the API.

    Returns the parsed JSON response shaped by `style_analysis_schema`.
    """
    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

    # FIX: work on a copy — the original overwrote 'fullText' on the caller's
    # DataFrame in place, silently corrupting the corpus for later callers.
    background_corpus_df = background_corpus_df.copy()
    # Collapse each author's document list into one blob (capped per author).
    background_corpus_df['fullText'] = background_corpus_df['fullText'].map(
        lambda x: '\n\n'.join(x[:max_num_documents_per_author]) if isinstance(x, list) else x)
    background_corpus_df = background_corpus_df[background_corpus_df[cluster_label_clm_name].isin(cluster_ids)]

    author_texts = background_corpus_df['fullText'].tolist()[:max_num_authors]
    author_texts = "\n\n".join(["""Author {}:\n""".format(i+1) + text for i, text in enumerate(author_texts)])
    author_names = background_corpus_df[cluster_label_clm_name].tolist()[:max_num_authors]
    print(f"Number of authors: {len(background_corpus_df)}")
    print(author_names)

    prompt = f"""First identify a list of {max_num_feats} writing style features that are common between the given texts. Second for every author text and style feature, extract all spans that represent the feature. Output for every author all style features with their spans.
Author Texts:
\"\"\"{author_texts}\"\"\"
"""

    # Cache key: MD5 of the exact prompt text (so any prompt change busts it).
    digest = hashlib.md5(prompt.encode("utf-8")).hexdigest()
    cache_path = os.path.join(CACHE_DIR, f"{digest}.pkl")

    if os.path.exists(cache_path):
        print(f"Loading authors writing style from cache ...")
        with open(cache_path, "rb") as f:
            parsed_response = pickle.load(f)
    else:
        # FIX: persona message uses role "system"; the original sent it with
        # role "assistant", which the API treats as prior model output.
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": "You are a forensic linguistic who knows how to analyze similarites in writing styles."},
                {"role": "user", "content": prompt}],
            response_format={"type": "json_schema", "json_schema": {"name": "style_analysis_schema", "schema": to_strict_json_schema(style_analysis_schema)}}
        )
        parsed_response = json.loads(response.choices[0].message.content)
        with open(cache_path, "wb") as f:
            pickle.dump(parsed_response, f)

    return parsed_response
420
+
421
def identify_style_features(author_texts: list[str], max_num_feats: int = 5) -> list[str]:
    """
    Ask OpenAI to name `max_num_feats` writing style features shared across
    the given texts (names only, no span extraction).

    Args:
        author_texts: one text per author; joined with newlines into the prompt.
        max_num_feats: number of feature names requested.

    Returns:
        The validated list of feature names (via `retry_call` and
        `FeatureIdentificationSchema`).
    """
    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    prompt = f"""Identify {max_num_feats} writing style features that are commonly found across the following texts. Do not extract spans. Just return the feature names as a list.
Author Texts:
\"\"\"{chr(10).join(author_texts)}\"\"\"
"""

    def _make_call():
        # One API attempt; retry/validation policy lives in retry_call.
        # FIX: persona message uses role "system" (was "assistant", which the
        # API interprets as prior model output rather than an instruction).
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": "You are a forensic linguist specializing in writing styles."},
                {"role": "user", "content": prompt}
            ],
            response_format={
                "type": "json_schema",
                "json_schema": {
                    "name": "FeatureIdentificationSchema",
                    "schema": to_strict_json_schema(FeatureIdentificationSchema)
                }
            }
        )
        return json.loads(response.choices[0].message.content)

    return retry_call(_make_call, FeatureIdentificationSchema).features
446
+
447
def retry_call(call_fn, schema_class, max_attempts=3, wait_sec=2):
    """
    Run `call_fn`, validate its dict result by constructing `schema_class`
    (a pydantic model), and return the validated instance.

    Retries up to `max_attempts` times on validation/JSON/key errors,
    sleeping `wait_sec` seconds between attempts.

    Raises:
        RuntimeError: chained to the last underlying error, when all
        attempts fail.
    """
    last_err = None
    for attempt in range(max_attempts):
        try:
            result = call_fn()
            return schema_class(**result)
        except (ValidationError, KeyError, json.JSONDecodeError) as e:
            last_err = e
            print(f"Attempt {attempt + 1} failed with error: {e}")
            # FIX: do not sleep after the final attempt (the original always
            # slept, adding a pointless delay before raising).
            if attempt < max_attempts - 1:
                time.sleep(wait_sec)
    # FIX: chain the cause so callers can see why every attempt failed.
    raise RuntimeError("All retry attempts failed for OpenAI call.") from last_err
458
+
459
def extract_all_spans(authors_df: pd.DataFrame, features: list[str], cluster_label_clm_name: str = 'authorID') -> dict[str, dict[str, list[str]]]:
    """
    Collect, per author, the text spans illustrating each style feature.

    Delegates every row of `authors_df` to `generate_feature_spans_cached`
    and assembles the results as {author_name: {feature: [spans]}}.
    """
    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

    spans_by_author: dict[str, dict[str, list[str]]] = {}
    for _, author_row in authors_df.iterrows():
        author_name = str(author_row[cluster_label_clm_name])
        print(author_name)
        role = f"{author_name}"
        spans_by_author[author_name] = generate_feature_spans_cached(
            client, author_row['fullText'], features, role)

    return spans_by_author
477
+
478
def compute_clusters_style_representation_3(
    background_corpus_df: pd.DataFrame,
    cluster_ids: List[Any],
    cluster_label_clm_name: str = 'authorID',
    max_num_feats: int = 5,
    max_num_documents_per_author=3,
    max_num_authors=5
):
    """
    Two-step style analysis:
      1. identify up to `max_num_feats` shared style features from at most
         `max_num_authors` authors whose label is in `cluster_ids`;
      2. extract the spans exhibiting those features for the first 7 authors
         of the corpus.

    Returns:
        {"features": [...], "spans": {author: {feature: [spans]}}}
    """
    print(f"Computing style representation for visible clusters: {len(cluster_ids)}")

    # FIX: work on a copy — the original overwrote 'fullText' on the caller's
    # DataFrame in place.
    background_corpus_df = background_corpus_df.copy()
    background_corpus_df['fullText'] = background_corpus_df['fullText'].map(
        lambda x: '\n\n'.join(x[:max_num_documents_per_author]) if isinstance(x, list) else x)

    # STEP 1: identify features on the visible authors.
    feat_df = background_corpus_df[background_corpus_df[cluster_label_clm_name].isin(cluster_ids)]
    author_texts = feat_df['fullText'].tolist()[:max_num_authors]
    # FIX: keep per-author texts as a *list* — identify_style_features expects
    # list[str]; the original passed one pre-joined string, which its
    # chr(10).join(...) then split character-by-character in the prompt.
    author_texts = ["""Author {}:\n""".format(i+1) + text for i, text in enumerate(author_texts)]
    author_names = feat_df[cluster_label_clm_name].tolist()[:max_num_authors]
    print(f"Number of authors: {len(author_names)}")
    print(author_names)
    features = identify_style_features(author_texts, max_num_feats=max_num_feats)

    # STEP 2: span extraction on a small author pool (first 7 rows of the corpus).
    span_df = background_corpus_df.iloc[:7]
    print(f"Number of authors for span detection : {len(span_df)}")
    print(span_df[cluster_label_clm_name].tolist()[:7])
    spans_by_author = extract_all_spans(span_df, features, cluster_label_clm_name)

    return {
        "features": features,
        "spans": spans_by_author
    }
514
+
515
+
516
def compute_clusters_g2v_representation(
    background_corpus_df: pd.DataFrame,
    author_ids: List[Any],
    other_author_ids: List[Any],
    features_clm_name: str,
    top_n: int = 10
) -> List[str]:
    """
    Rank Gram2Vec features that score high, on average, for `author_ids`
    relative to the rest of the corpus.

    Each entry of `background_corpus_df[features_clm_name]` must be a dict
    {feature_name: value}. A feature's score is
    mean(value over selected authors) - mean(value over all other authors);
    the names of the `top_n` highest-scoring features are returned
    (empty list when no rows match `author_ids`).

    NOTE(review): `other_author_ids` is accepted but unused — the contrast
    set is *everything* outside `author_ids`. Kept as-is for interface
    compatibility; confirm whether it should restrict the contrast set.
    """
    selected_mask = background_corpus_df['authorID'].isin(author_ids).to_numpy()
    if not selected_mask.any():
        return []  # no documents for the requested authors

    selected_dicts = background_corpus_df[selected_mask][features_clm_name].tolist()
    feature_names = list(selected_dicts[0].keys())
    # FIX: index every dict by feature name so rows whose dicts have a
    # different key order still align (the original zipped raw .values()).
    selected_mean = np.array([[d[f] for f in feature_names] for d in selected_dicts]).mean(axis=0)

    other_dicts = background_corpus_df[~selected_mask][features_clm_name].tolist()
    other_mean = np.array([[d[f] for f in feature_names] for d in other_dicts]).mean(axis=0)

    scores = selected_mean - other_mean

    ranked = sorted(zip(feature_names, scores), key=lambda x: -x[1])
    print(ranked[:top_n])

    return [name for name, _ in ranked[:top_n]]
547
+
548
+
549
def generate_interpretable_space_representation(interp_space_path, styles_df_path, feat_clm, output_clm, num_feats=5):
    """
    Build a per-cluster TF-IDF representation over style features.

    Reads a CSV of (feature, documentID) rows and a pickled clustering
    DataFrame, maps clusters -> documents -> features, then scores each
    feature per cluster with log(1+count) * IDF and keeps the top
    `num_feats` names.

    Returns the cluster DataFrame with two added columns:
    `output_clm + '_dist'` (feature -> TF-IDF score dict) and
    `output_clm` (top-k feature names).
    """
    styles_df = pd.read_csv(styles_df_path)[[feat_clm, "documentID"]]

    # A dictionary of style features and their IDF
    # (document_freq = number of rows mentioning the feature; IDF = log(N/df)).
    style_feats_agg_df = styles_df.groupby(feat_clm).agg({'documentID': lambda x : len(list(x))}).reset_index()
    style_feats_agg_df['document_freq'] = style_feats_agg_df.documentID
    style_to_feats_dfreq = {x[0]: math.log(styles_df.documentID.nunique()/x[1]) for x in zip(style_feats_agg_df[feat_clm].tolist(), style_feats_agg_df.document_freq.tolist())}

    # A list of style features we work with
    style_feats_list = style_feats_agg_df[feat_clm].tolist()
    print('Number of style feats ', len(style_feats_list))

    # A list of documents and what list of style features each has
    doc_style_agg_df = styles_df.groupby('documentID').agg({feat_clm: lambda x : list(x)}).reset_index()
    document_to_feats_dict = {x[0]: x[1] for x in zip(doc_style_agg_df.documentID.tolist(), doc_style_agg_df[feat_clm].tolist())}

    # Load the clustering information (-1 = noise cluster, dropped).
    df = pd.read_pickle(interp_space_path)
    df = df[df.cluster_label != -1]
    # A cluster to list of documents (each row's documentID is itself a list).
    clusterd_df = df.groupby('cluster_label').agg({
        'documentID': lambda x: [d_id for doc_ids in x for d_id in doc_ids]
    }).reset_index()

    # Filter-in only documents that has a style description
    clusterd_df['documentID'] = clusterd_df.documentID.apply(lambda documentIDs: [documentID for documentID in documentIDs if documentID in document_to_feats_dict])
    # Map from cluster label to list of features through the document information
    clusterd_df[feat_clm] = clusterd_df.documentID.apply(lambda doc_ids: [f for d_id in doc_ids for f in document_to_feats_dict[d_id]])

    def compute_tfidf(row):
        # Feature -> log(1+count) * IDF; unseen features score 0.
        style_counts = Counter(row[feat_clm])
        total_num_styles = sum(style_counts.values())
        style_distribution = {
            style: math.log(1+count) * style_to_feats_dfreq[style] if style in style_to_feats_dfreq else 0 for style, count in style_counts.items()
        } #TF-IDF
        return style_distribution

    def create_tfidf_rep(tfidf_dist, num_feats):
        # Top-k feature names by score, skipping NaN feature names.
        style_feats = sorted(tfidf_dist.items(), key=lambda x: -x[1])
        top_k_feats = [x[0] for x in style_feats[:num_feats] if str(x[0]) != 'nan']
        return top_k_feats

    clusterd_df[output_clm +'_dist'] = clusterd_df.apply(lambda row: compute_tfidf(row), axis=1)
    clusterd_df[output_clm] = clusterd_df[output_clm +'_dist'].apply(lambda dist: create_tfidf_rep(dist, num_feats))

    return clusterd_df
601
+
602
def compute_predicted_author(task_authors_df: pd.DataFrame, col_name: str) -> int:
    """
    Predict which candidate author wrote the mystery text.

    Row 0 of `task_authors_df` holds the mystery author's embedding in
    column `col_name`; every following row is a candidate.

    Returns:
        The 0-based index of the candidate whose embedding has the highest
        cosine similarity to the mystery embedding.
    """
    print("Computing predicted author using LUAR-MUD-style embeddings...")

    mystery_embedding = np.asarray(task_authors_df.iloc[0][col_name], dtype=float).ravel()
    # Generalized: use every row after the mystery row as a candidate
    # (the original hard-coded exactly rows 1..3).
    candidate_embeddings = np.array(
        [np.asarray(v, dtype=float).ravel() for v in task_authors_df[col_name].iloc[1:]])

    # Cosine similarity in plain numpy (sklearn not needed for one query);
    # zero-norm vectors get similarity 0 instead of a divide-by-zero.
    denom = np.linalg.norm(mystery_embedding) * np.linalg.norm(candidate_embeddings, axis=1)
    similarities = candidate_embeddings @ mystery_embedding / np.where(denom == 0, 1.0, denom)

    predicted_author = int(np.argmax(similarities))
    print(f"Predicted author is Candidate {predicted_author + 1}")

    return predicted_author
622
+
623
+
624
if __name__ == "__main__":
    # Smoke test: load the clustered background corpus and inspect it, then
    # generate style embeddings for its texts.
    # NOTE(review): `generate_style_embedding` is defined elsewhere in this
    # module — presumably it adds an embedding column in place; confirm.
    background_corpus = pd.read_pickle('../datasets/luar_interp_space_cluster_19/train_authors.pkl')
    print(background_corpus.columns)
    print(background_corpus[['authorID', 'fullText', 'cluster_label']].head())
    # # Example: Find features for clusters [2,3,4] that are NOT prominent in cluster [1]
    # feats = compute_clusters_style_representation(
    #     background_corpus_df=background_corpus,
    #     cluster_ids=['00005a5c-5c06-3a36-37f9-53c6422a31d8',],
    #     other_cluster_ids=[],  # Pass the contrastive cluster IDs here
    #     cluster_label_clm_name='authorID',
    #     features_clm_name='final_attribute_name'
    # )
    # print(feats)
    generate_style_embedding(background_corpus, 'fullText', 'AnnaWegmann/Style-Embedding')
    print(background_corpus.columns)
utils/llm_feat_utils.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json
import os
import hashlib
import time
from json import JSONDecodeError

# On-disk cache for per-feature span extractions
# (used by generate_feature_spans_cached below).
CACHE_DIR = "datasets/feature_spans_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
# NOTE(review): mid-file import — conventionally belongs at the top of the file.
import pandas as pd

# read and create the Gram2Vec feature set once, at import time
_g2v_df = pd.read_csv("datasets/gram2vec_feats.csv")
GRAM2VEC_SET = set(_g2v_df['gram2vec_feats'].unique())
# Retry policy for the OpenAI span-extraction calls.
MAX_ATTEMPTS = 3
WAIT_SECONDS = 2

# Bump this whenever there is a change prompt, feature space, etc...
CACHE_VERSION = 2
19
+
20
def _feat_hash(feature: str, text: str) -> str:
    """
    Stable MD5 cache key for one (feature, text) pair.

    CACHE_VERSION is folded into the key so bumping it invalidates all
    existing entries.
    """
    # BUG FIX: the original hashed sorted(feature) — i.e. the *characters* of
    # the feature name, sorted — so any two feature names that are anagrams
    # of each other collided on the same cache entry. Hash the feature string
    # itself. (Bump CACHE_VERSION when deploying so stale entries keyed by
    # the old scheme are not reused.)
    blob = json.dumps({
        "version": CACHE_VERSION,
        "text": text,
        "feature": feature
    }, sort_keys=True).encode()
    return hashlib.md5(blob).hexdigest()
27
+
28
+
29
def generate_feature_spans(client, text: str, features: list[str]) -> str:
    """
    Single OpenAI call that extracts, for each feature, the exact text spans
    demonstrating it. Returns the raw response content (a JSON string).
    """
    # FIX: docstring moved above the print — in the original the print came
    # first, so the triple-quoted string below it was a stray expression, not
    # a docstring.
    print("Calling OpenAI to extract spans")
    prompt = f"""You are a linguistic specialist. Given a writing sample and a list of descriptive features, identify the exact text spans that demonstrate each feature.

Important:
- The headers like "Document 1:" etc are NOT part of the original text — ignore them.
- For each feature, even if there is no match, return an empty list.
- Only return exact phrases from the text.

Respond in JSON format like:
{{
  "feature1": ["span1", "span2"],
  "feature2": [],
  …
}}

Text:
\"\"\"{text}\"\"\"

Style Features:
{features}
"""
    response = client.chat.completions.create(
        model="gpt-4",
        messages=[{"role":"user","content":prompt}],
        temperature=0.3,
    )
    return response.choices[0].message.content
60
+
61
def generate_feature_spans_with_retries(client, text: str, features: list[str]) -> dict:
    """
    Call `generate_feature_spans` with exponential-backoff retries.

    Returns the parsed JSON dict mapping feature -> list of spans; raises
    RuntimeError when every attempt fails to produce valid JSON.
    """
    attempt = 0
    while attempt < MAX_ATTEMPTS:
        try:
            raw = generate_feature_spans(client, text, features)
            return json.loads(raw)
        except (JSONDecodeError, ValueError) as err:
            print(f"Attempt {attempt+1} failed: {err}")
            if attempt < MAX_ATTEMPTS - 1:
                wait_sec = WAIT_SECONDS * (2 ** attempt)
                print(f"Retrying after {wait_sec} seconds...")
                time.sleep(wait_sec)
        attempt += 1
    raise RuntimeError("All retry attempts failed for OpenAI call.")
78
+
79
+
80
def generate_feature_spans_cached(client, text: str, features: list[str], role: str = "mystery" ) -> dict:
    """
    Per-feature cached span extraction.

    Loads a per-`role` JSON cache keyed by `_feat_hash(feature, text)`;
    features missing from the cache are fetched in one API call, merged in,
    and the cache is written back. Returns {feature: [spans]}.
    """
    print(f"Generating spans for ({role})")
    # Sanitize the role so it is a safe file name.
    role = role.replace(" ", "_").replace("/", "_").replace("-", "_")
    print(f"Cache dir: {CACHE_DIR}")
    os.makedirs(CACHE_DIR, exist_ok=True)
    cache_path = os.path.join(CACHE_DIR, f"{role}.json")
    if os.path.exists(cache_path):
        with open(cache_path) as f:
            cache: dict[str, dict] = json.load(f)
    else:
        cache = {}
    result: dict[str, list[str]] = {}
    missing_feats: list[str] = []

    for feat in features:
        if feat == "None":
            # Sentinel feature name — nothing to extract.
            result[feat] = []
            continue

        h = _feat_hash(feat, text)
        if h in cache:
            result[feat] = cache[h]["spans"]
        else:
            missing_feats.append(feat)

    if missing_feats:
        mapping = generate_feature_spans_with_retries(client, text, missing_feats)
        # Update cache & result for each missing feature.
        for feat in missing_feats:
            h = _feat_hash(feat, text)
            # BUG FIX: default to [] — the original used mapping.get(feat),
            # which cached and returned None whenever the model omitted a
            # feature, breaking the "feature -> list of spans" contract.
            spans = mapping.get(feat) or []
            cache[h] = {
                "feature": feat,
                "spans": spans
            }
            result[feat] = spans

        # Write back the combined cache (only when something changed).
        with open(cache_path, "w") as f:
            json.dump(cache, f, indent=2)
    return result
128
+
129
+
130
def split_features(all_feats):
    """
    Partition a mixed feature list into two lists:
      - llm_feats: features NOT present in the Gram2Vec CSV
      - g2v_feats: features present in the CSV
    """
    llm_feats, g2v_feats = [], []
    for feat in all_feats:
        target = g2v_feats if feat in GRAM2VEC_SET else llm_feats
        target.append(feat)
    return llm_feats, g2v_feats
utils/ui.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ from utils.visualizations import load_instance, get_instances, clean_text
4
+ from utils.interp_space_utils import cached_generate_style_embedding, instance_to_df, compute_g2v_features, compute_predicted_author
5
+
6
+
7
# ── Global CSS to be prepended to every block ─────────────────────────────────
# Injected ahead of each HTML blob by styled_block(); styles field labels,
# radio-option labels, and HTML output borders inside the gradio container.
GLOBAL_CSS = """
<style>
/* Bold only the top‐level field labels (not every label) */
.gradio-container .input_label {
  font-weight: 600 !important;
  font-size: 1.1em !important;
}

/* Reset radio‐option labels to normal weight/size */
.gradio-container .radio-container .radio-option-label {
  font-weight: normal !important;
  font-size: 1em !important;
}
/* Give HTML output blocks a stronger border and padding */
.gradio-container .output-html {
  border: 2px solid #888 !important;
  border-radius: 4px !important;
  padding: 0.5em !important;
  margin-bottom: 1em !important;
  font-size: 1em !important;
  line-height: 1.4 !important;
}
</style>
"""
32
+
33
def styled_block(content: str) -> str:
    """
    Prepend the shared GLOBAL_CSS to `content`.

    Returns one HTML blob that can be passed straight to gr.HTML().
    """
    return "\n".join([GLOBAL_CSS, content])
39
+
40
def styled_html(html_content: str) -> str:
    """Wrap raw HTML content with the global CSS; hand the result to gr.HTML()."""
    styled = styled_block(html_content)
    return styled
45
+
46
def instruction_callout(text: str) -> str:
    """
    Returns a full HTML string (with global CSS) rendering `text`
    as a bold, full-width callout box.

    Usage:
        gr.HTML(instruction_callout(
            "Run visualization to see which author cluster contains the mystery document."
        ))
    """
    # Light-blue box with a bold left accent stripe; styled_html prepends
    # the shared GLOBAL_CSS.
    callout = f"""
    <div style="
        background: #e3f2fd;           /* light blue background */
        border-left: 5px solid #2196f3; /* bold accent stripe */
        padding: 12px 16px;
        margin-bottom: 12px;
        font-weight: 600;
        font-size: 1.1em;
    ">
      {text}
    </div>
    """
    return styled_html(callout)
69
+
70
def read_txt(f):
    """
    Read an uploaded text file (a gradio file object or a plain path string).

    Returns "" when no file was supplied, the stripped file contents on
    success, or a placeholder string when the file cannot be read.
    """
    if not f:
        return ""
    # gradio file objects expose the on-disk path via .name
    path = getattr(f, 'name', f)
    try:
        with open(path, 'r', encoding='utf-8') as fh:
            content = fh.read()
    except Exception:
        # Best effort: show a readable placeholder instead of raising.
        return "(Could not read file)"
    return content.strip()
79
+
80
+ # Toggle which input UI is visible
81
+ def toggle_task(mode):
82
+ print(mode)
83
+ return (
84
+ gr.update(visible=(mode == "Predefined HRS Task")),
85
+ gr.update(visible=(mode == "Upload Your Own Task"))
86
+ )
87
+
88
+ # Update displayed texts based on mode
89
def update_task_display(mode, iid, instances, background_df, mystery_file, cand1_file, cand2_file, cand3_file, true_author, model_radio, custom_model_input):
    """
    Rebuild the full task view after the user changes the task or model.

    Depending on `mode`, either loads a predefined instance from `instances`
    or reads the four uploaded files; then generates style embeddings and
    Gram2Vec features for the task authors and the background corpus,
    predicts the author, and renders the HTML panels.

    Returns a list: header HTML, mystery HTML, three candidate HTMLs, the
    four raw texts, the task-authors DataFrame, the re-embedded background
    DataFrame, the predicted author index, and the ground-truth author index.
    """
    # "Other" routes to the free-text model-name box.
    model_name = model_radio if model_radio != "Other" else custom_model_input
    if mode == "Predefined HRS Task":
        # Dropdown labels look like "Task 3" — strip the prefix to index.
        iid = int(iid.replace('Task ', ''))
        data = instances[iid]
        # NOTE(review): this value is overwritten below by
        # compute_predicted_author — confirm which prediction should win.
        predicted_author = data['latent_rank'][0]
        ground_truth_author = data['gt_idx']
        mystery_txt = data['Q_fullText']
        c1_txt = data['a0_fullText']
        c2_txt = data['a1_fullText']
        c3_txt = data['a2_fullText']
        candidate_texts = [c1_txt, c2_txt, c3_txt]

        #create a dataframe of the task authors
        task_authors_df = instance_to_df(instances[iid])
        print(f"\n\n\n ----> Loaded task {iid} with {len(task_authors_df)} authors\n\n\n")
        print(task_authors_df)
    else:
        # NOTE(review): this header is overwritten by task_HTML below.
        header_html = "<h3>Custom Uploaded Task</h3>"
        mystery_txt = read_txt(mystery_file)
        c1_txt = read_txt(cand1_file)
        c2_txt = read_txt(cand2_file)
        c3_txt = read_txt(cand3_file)
        candidate_texts = [c1_txt, c2_txt, c3_txt]
        ground_truth_author = true_author
        print(f"Ground truth author: {ground_truth_author} ; {true_author}")
        custom_task_instance = {
            'Q_fullText': mystery_txt,
            'a0_fullText': c1_txt,
            'a1_fullText': c2_txt,
            'a2_fullText': c3_txt
        }
        task_authors_df = instance_to_df(custom_task_instance)
        print(task_authors_df)

    # Embed the task authors with the chosen model (disk-cached).
    print(f"Generating embeddings for {model_name} on task authors")
    task_authors_df = cached_generate_style_embedding(task_authors_df, 'fullText', model_name)
    print("Task authors after embedding generation:")
    print(task_authors_df)
    # Generate the new embedding of all the background_df authors
    print(f"Generating embeddings for {model_name} on background corpus")
    background_df = cached_generate_style_embedding(background_df, 'fullText', model_name)
    print(f"Generated embeddings for {len(background_df)} texts using model '{model_name}'")

    # computing g2v features
    print("Generating g2v features for on background corpus")
    background_g2v, task_authors_g2v = compute_g2v_features(background_df, task_authors_df)
    background_df['g2v_vector'] = background_g2v
    task_authors_df['g2v_vector'] = task_authors_g2v
    print(f"Gram2Vec feature generation complete")

    print(background_df.columns)

    # Computing predicted author by checking pairwise cosine similarity over luar embeddings
    col_name = f'{model_name.split("/")[-1]}_style_embedding'
    predicted_author = compute_predicted_author(task_authors_df, col_name)

    #generating html for the task
    header_html, mystery_html, candidate_htmls = task_HTML(mystery_txt, candidate_texts, predicted_author, ground_truth_author)

    return [
        header_html,
        mystery_html,
        candidate_htmls[0],
        candidate_htmls[1],
        candidate_htmls[2],
        mystery_txt,
        c1_txt,
        c2_txt,
        c3_txt,
        task_authors_df,
        background_df,
        predicted_author,
        ground_truth_author
    ]
164
+
165
def task_HTML(mystery_text, candidate_texts, predicted_author, ground_truth_author):
    """
    Render the task panels as HTML.

    Returns (header_html, mystery_html, candidate_htmls): the predicted
    candidate gets a green highlight; the true author gets an orange
    highlight only when the prediction is wrong.
    """
    header_html = f"""
    <div style="border:1px solid #ccc; padding:10px; margin-bottom:10px;">
        <h3>Here’s the mystery passage alongside three candidate texts—look for the green highlight to see the predicted author.</h3>
    </div>
    """
    # mystery_text = clean_text(mystery_text)
    mystery_html = f"""
    <div style="
        border: 2px solid #ff5722;   /* accent border */
        background: #fff3e0;         /* very light matching wash */
        border-radius: 6px;
        padding: 1em;
        margin-bottom: 1em;
    ">
      <h3 style="margin-top:0; color:#bf360c;">Mystery Author</h3>
      <p>{clean_text(mystery_text)}</p>
    </div>
    """

    print(f"Predicted author: {predicted_author}, Ground truth author: {ground_truth_author}")

    # Candidate boxes
    candidate_htmls = []
    for i in range(3):
        text = candidate_texts[i]
        title = f"Candidate {i+1}"
        extra_style = ""

        if ground_truth_author == i:
            if ground_truth_author != predicted_author: # highlight the true author only if its different than the predictd one
                title += " (True Author)"
                extra_style = (
                    "border: 2px solid #ff5722; "
                    "background: #fff3e0; "
                    "padding:10px; "
                )


        if predicted_author == i:
            if predicted_author == ground_truth_author:
                title += " (Predicted and True Author)"
            else:
                title += " (Predicted Author)"
            extra_style = (
                "border:2px solid #228B22; "       # dark green border
                "background-color: #e6ffe6; "      # light green fill
                "padding:10px; "
            )


        candidate_htmls.append(f"""
        <div style="border:1px solid #ccc; padding:10px; {extra_style}">
          <h4>{title}</h4>
          <p>{clean_text(text)}</p>
        </div>
        """)
    return header_html, mystery_html, candidate_htmls
223
+
224
def toggle_custom_model(choice):
    """Show the free-text model-name input only when 'Other' is selected."""
    show_custom = choice == "Other"
    return gr.update(visible=show_custom)
utils/visualizations.py ADDED
@@ -0,0 +1,564 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import json
3
+ import numpy as np
4
+ from sklearn.manifold import TSNE
5
+ import pickle as pkl
6
+ import os
7
+ import hashlib
8
+ import pandas as pd
9
+ import plotly.graph_objects as go
10
+ from plotly.colors import sample_colorscale
11
+ from gradio import update
12
+ import re
13
+ from utils.interp_space_utils import compute_clusters_style_representation_3, compute_clusters_g2v_representation
14
+ from utils.llm_feat_utils import split_features
15
+ from utils.gram2vec_feat_utils import get_shorthand, get_fullform
16
+
17
+ import plotly.io as pio
18
+
19
def clean_text(text: str) -> str:
    """
    Escape angle brackets for safe HTML display and turn newlines into <br>.
    """
    # Order matters: escape < and > first, then introduce the <br> tags.
    for old, new in (('<', '&lt;'), ('>', '&gt;'), ('\n', '<br>')):
        text = text.replace(old, new)
    return text
24
+
25
def get_instances(instances_to_explain_path: str = 'datasets/instances_to_explain.json'):
    """
    Loads the JSON and returns:
      - instances_to_explain: the raw dict/list of instances
      - instance_ids: list of keys (if dict) or indices (if list)
    """
    # FIX: use a context manager — the original json.load(open(...)) left the
    # file handle open until garbage collection.
    with open(instances_to_explain_path) as f:
        instances_to_explain = json.load(f)
    if isinstance(instances_to_explain, dict):
        instance_ids = list(instances_to_explain.keys())
    else:
        instance_ids = list(range(len(instances_to_explain)))
    return instances_to_explain, instance_ids
37
+
38
def load_instance(instance_id, instances_to_explain: dict):
    """
    Given a selected instance_id and the loaded data, build the display HTML.

    Returns (header_html, mystery_html, c0_html, c1_html, c2_html). The
    predicted candidate is highlighted green; the true author is highlighted
    orange only when the prediction is wrong.
    """
    # normalize instance_id (the dropdown may hand us a string)
    try:
        iid = int(instance_id)
    except ValueError:
        iid = instance_id
    data = instances_to_explain[iid]

    predicted_author = data['latent_rank'][0]
    ground_truth_author = data['gt_idx']

    header_html = f"""
    <div style="border:1px solid #ccc; padding:10px; margin-bottom:10px;">
        <h3>Here’s the mystery passage alongside three candidate texts—look for the green highlight to see the predicted author.</h3>
    </div>
    """
    # BUG FIX: the original cleaned the text here AND again inside the
    # f-string below; the second pass re-escaped the first pass's <br> tags
    # into a literal "&lt;br&gt;" shown to the user. Clean exactly once.
    mystery_text = data['Q_fullText']
    mystery_html = f"""
    <div style="
        border: 2px solid #ff5722;   /* accent border */
        background: #fff3e0;         /* very light matching wash */
        border-radius: 6px;
        padding: 1em;
        margin-bottom: 1em;
    ">
      <h3 style="margin-top:0; color:#bf360c;">Mystery Author</h3>
      <p>{clean_text(mystery_text)}</p>
    </div>
    """

    # Candidate boxes
    candidate_htmls = []
    for i in range(3):
        text = data[f'a{i}_fullText']
        title = f"Candidate {i+1}"
        extra_style = ""

        if ground_truth_author == i:
            if ground_truth_author != predicted_author: # highlight the true author only if its different than the predictd one
                title += " (True Author)"
                extra_style = (
                    "border: 2px solid #ff5722; "
                    "background: #fff3e0; "
                    "padding:10px; "
                )


        if predicted_author == i:
            if predicted_author == ground_truth_author:
                title += " (Predicted and True Author)"
            else:
                title += " (Predicted Author)"
            extra_style = (
                "border:2px solid #228B22; "       # dark green border
                "background-color: #e6ffe6; "      # light green fill
                "padding:10px; "
            )


        candidate_htmls.append(f"""
        <div style="border:1px solid #ccc; padding:10px; {extra_style}">
          <h4>{title}</h4>
          <p>{clean_text(text)}</p>
        </div>
        """)

    return header_html, mystery_html, candidate_htmls[0], candidate_htmls[1], candidate_htmls[2]
109
+
110
def compute_tsne_with_cache(embeddings: np.ndarray, cache_path: str = 'datasets/tsne_cache.pkl') -> np.ndarray:
    """
    Compute a 2-D t-SNE projection, memoising results on disk.

    The cache key is the MD5 digest of the raw embedding bytes, so the
    expensive fit is skipped whenever the exact same matrix is seen again.

    Args:
        embeddings (np.ndarray): The input embeddings to compute t-SNE on.
        cache_path (str): Path to the pickle file holding the cache dict.

    Returns:
        np.ndarray: The t-SNE transformed embeddings.
    """
    # Hash of the raw matrix bytes serves as the cache key.
    hash_key = hashlib.md5(embeddings.tobytes()).hexdigest()

    cache = {}
    if os.path.exists(cache_path):
        with open(cache_path, 'rb') as f:
            cache = pkl.load(f)

    if hash_key in cache:
        return cache[hash_key]

    print("Computing t-SNE")
    tsne_result = TSNE(n_components=2, learning_rate='auto',
                       init='random', perplexity=3).fit_transform(embeddings)
    cache[hash_key] = tsne_result

    # Fix: create the cache directory if it is missing; otherwise the first
    # run on a fresh checkout dies with FileNotFoundError when dumping.
    cache_dir = os.path.dirname(cache_path)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    with open(cache_path, 'wb') as f:
        pkl.dump(cache, f)
    return tsne_result
140
+
141
def load_interp_space(cfg):
    """
    Load the interpretable style space and the clustered background authors.

    Args:
        cfg (dict): Expects 'interp_space_path' (directory containing the
            pickled space, its JSON representations and train_authors.pkl),
            plus 'style_feat_clm', 'only_llm_feats', 'only_gram2vec_feats'
            and 'top_k'.

    Returns:
        dict with keys:
            'dimension_to_latent'  – cluster label -> latent centroid vector
            'dimension_to_style'   – cluster label -> ranked style features
            'author_embedding'     – background author embeddings (list)
            'author_labels'        – their cluster labels
            'author_ids'           – their author IDs
            'clustered_authors_df' – filtered background-author DataFrame
    """
    interp_space_path = cfg['interp_space_path'] + 'interpretable_space.pkl'
    interp_space_rep_path = cfg['interp_space_path'] + 'interpretable_space_representations.json'
    gram2vec_feats_path = cfg['interp_space_path'] + '/../gram2vec_feats.csv'
    clustered_authors_path = cfg['interp_space_path'] + 'train_authors.pkl'

    # Load authors embeddings and their cluster labels
    clustered_authors_df = pd.read_pickle(clustered_authors_path)
    # Drop DBSCAN outliers (cluster label -1).
    clustered_authors_df = clustered_authors_df[clustered_authors_df.cluster_label != -1]
    author_embedding = clustered_authors_df.author_embedding.tolist()
    author_labels = clustered_authors_df.cluster_label.tolist()
    author_ids = clustered_authors_df.authorID.tolist()

    # filter out gram2vec features that doesn't have representation
    clustered_authors_df['gram2vec_feats'] = clustered_authors_df.gram2vec_feats.apply(lambda feats: [feat for feat in feats if get_shorthand(feat) is not None])

    # Load a list of gram2vec features --> we use it to distinguish the cluster representations whether they come from gram2vec or llms
    gram2vec_df = pd.read_csv(gram2vec_feats_path)
    gram2vec_feats = gram2vec_df.gram2vec_feats.unique().tolist()

    # Load interpretable space embeddings and the representation of each dimension
    interpretable_space = pkl.load(open(interp_space_path, 'rb'))
    del interpretable_space[-1] #DBSCAN generate a cluster -1 of all outliers. We don't want this cluster
    dimension_to_latent = {key: interpretable_space[key][0] for key in interpretable_space}

    interpretable_space_rep_df = pd.read_json(interp_space_rep_path)
    #dimension_to_style = {x[0]: x[1] for x in zip(interpretable_space_rep_df.cluster_label.tolist(), interpretable_space_rep_df[style_feat_clm].tolist())}
    # Rank each dimension's features by descending weight and keep names only.
    dimension_to_style = {x[0]: [feat[0] for feat in sorted(x[1].items(), key=lambda feat_w:-feat_w[1])] for x in zip(interpretable_space_rep_df.cluster_label.tolist(), interpretable_space_rep_df[cfg['style_feat_clm']].tolist())}

    if cfg['only_llm_feats']:
        #print('only llm feats')
        dimension_to_style = {dim[0]:[feat for feat in dim[1] if feat not in gram2vec_feats] for dim in dimension_to_style.items()}

    if cfg['only_gram2vec_feats']:
        #print('only gra2vec feats')
        dimension_to_style = {dim[0]:[feat for feat in dim[1] if feat in gram2vec_feats] for dim in dimension_to_style.items()}

    # Take top features from g2v and llm
    def take_to_k_llm_and_g2v_feats(feats_list, top_k):
        # Keeps up to top_k Gram2Vec features plus up to top_k LLM features.
        g2v_feats = [x for x in feats_list if x in gram2vec_feats][:top_k]
        llm_feats = [x for x in feats_list if x not in gram2vec_feats][:top_k]
        return g2v_feats + llm_feats
    dimension_to_style = {dim[0]: take_to_k_llm_and_g2v_feats(dim[1], cfg['top_k']) for dim in dimension_to_style.items()}


    return {
        'dimension_to_latent': dimension_to_latent,
        'dimension_to_style' : dimension_to_style,
        'author_embedding' : author_embedding,
        'author_labels' : author_labels,
        'author_ids' : author_ids,
        'clustered_authors_df' : clustered_authors_df

    }
195
+
196
#function to handle zoom events
def handle_zoom(event_json, bg_proj, bg_lbls, clustered_authors_df, task_authors_df):
    """
    React to a Plotly zoom event by recomputing the style features that
    characterise the authors visible inside the zoomed region.

    event_json           – stringified JSON from the JS listener; expected to
                           contain "xaxis" and "yaxis" (min, max) pairs
    bg_proj              – (N,2) numpy array with 2D coordinates
    bg_lbls              – list of N author IDs aligned with bg_proj rows
    clustered_authors_df – pd.DataFrame containing authorID and final_attribute_name
    task_authors_df      – pd.DataFrame of the task (mystery + candidate) authors

    Returns a 5-tuple:
        (LLM-feature radio update, Gram2Vec radio update,
         raw style-analysis response, list of LLM features,
         list of visible author ids)
    """
    print("[INFO] Handling zoom event")

    # No payload at all -> clear both radio groups.
    if not event_json:
        return gr.update(value=""), gr.update(value=""), None, None, None

    try:
        ranges = json.loads(event_json)
        (x_min, x_max) = ranges["xaxis"]
        (y_min, y_max) = ranges["yaxis"]
    except (json.JSONDecodeError, KeyError, ValueError):
        # Malformed payload is treated the same as no payload.
        return gr.update(value=""), gr.update(value=""), None, None, None

    # Find points within the zoomed region
    mask = (
        (bg_proj[:, 0] >= x_min) & (bg_proj[:, 0] <= x_max) &
        (bg_proj[:, 1] >= y_min) & (bg_proj[:, 1] <= y_max)
    )

    visible_authors = [lbl for lbl, keep in zip(bg_lbls, mask) if keep]

    print(f"[INFO] Zoomed region includes {len(visible_authors)} authors:{visible_authors}")

    print(f"Task authors: {len(task_authors_df)}, Clustered authors: {len(clustered_authors_df)}")
    # Fix: build the merged frame once — it was previously concatenated twice.
    merged_authors_df = pd.concat([task_authors_df, clustered_authors_df])
    print(f"Merged authors DataFrame:\n{len(merged_authors_df)}")

    # LLM-derived style features for the visible authors.
    style_analysis_response = compute_clusters_style_representation_3(
        background_corpus_df=merged_authors_df,
        cluster_ids=visible_authors,
        cluster_label_clm_name='authorID',
    )

    llm_feats = ['None'] + style_analysis_response['features']

    # Gram2Vec features for the same region.
    g2v_feats = compute_clusters_g2v_representation(
        background_corpus_df=merged_authors_df,
        author_ids=visible_authors,
        other_author_ids=[],
        features_clm_name='g2v_vector'
    )

    # Gram2vec features are already in shorthand. convert to human readable for display
    HR_g2v_list = []
    for feat in g2v_feats:
        HR_g2v = get_fullform(feat)
        print(f"\n\n feat: {feat} ---> Human Readable: {HR_g2v}")
        if HR_g2v is None:
            print(f"Skipping Gram2Vec feature without human readable form: {feat}")
        else:
            HR_g2v_list.append(HR_g2v)

    HR_g2v_list = ["None"] + HR_g2v_list

    print(f"[INFO] Found {len(llm_feats)} LLM features and {len(g2v_feats)} Gram2Vec features in the zoomed region.")
    print(f"[INFO] unfiltered g2v features: {g2v_feats}")

    print(f"[INFO] LLM features: {llm_feats}")
    print(f"[INFO] Gram2Vec features: {HR_g2v_list}")

    return (
        gr.update(choices=llm_feats, value=llm_feats[0]),
        gr.update(choices=HR_g2v_list, value=HR_g2v_list[0]),
        style_analysis_response,
        llm_feats,
        visible_authors
    )
280
+
281
def handle_zoom_with_retries(event_json, bg_proj, bg_lbls, clustered_authors_df, task_authors_df):
    """
    Wrapper around handle_zoom() that retries up to three times before
    giving up.

    event_json           – stringified JSON from JS listener
    bg_proj              – (N,2) numpy array with 2D coordinates
    bg_lbls              – list of N author IDs
    clustered_authors_df – pd.DataFrame containing authorID and final_attribute_name
    task_authors_df      – pd.DataFrame containing authorID and final_attribute_name

    Returns handle_zoom()'s 5-tuple, or (None, None, None, None, None) once
    all three attempts have failed.
    """
    print("[INFO] Handling zoom event with retries")

    for attempt in range(3):
        try:
            return handle_zoom(event_json, bg_proj, bg_lbls, clustered_authors_df, task_authors_df)
        except Exception as e:
            print(f"[ERROR] Attempt {attempt + 1} failed: {e}")
            if attempt < 2:
                print("[INFO] Retrying...")

    # Fix: the failure fallback now runs only after *all* attempts fail;
    # previously it was reachable from inside the except region, which
    # could cut the advertised retries short.
    return (
        None,
        None,
        None,
        None,
        None
    )
305
+
306
+
307
def visualize_clusters_plotly(iid, cfg, instances, model_radio, custom_model_input, task_authors_df, background_authors_embeddings_df, pred_idx=None, gt_idx=None):
    """
    Build the 2-D t-SNE scatter plot of the mystery author, the three
    candidate authors and the background authors, with arrowed labels.

    Args:
        iid: Instance id; coerced to int below.
        cfg: Config dict forwarded to load_interp_space().
        instances: Not read by the current implementation (the old
            `inst = instances[iid]` code path is commented out below);
            kept for interface compatibility.
        model_radio: Chosen model name, or "Other" to use custom_model_input.
        custom_model_input: Model name used when model_radio == "Other".
        task_authors_df: DataFrame whose row 0 holds the mystery author and
            the following rows the candidates; must expose the
            `<model>_style_embedding` column and `authorID`.
        background_authors_embeddings_df: Background authors with the same
            embedding column and `authorID`.
        pred_idx: Index (0-2) of the candidate the model predicted, or None.
        gt_idx: Index (0-2) of the ground-truth candidate, or None.

    Returns:
        (plotly Figure,
         dict mapping cluster label -> ranked style feature names,
         2-D projection array of all plotted points,
         list of author ids (task authors first, then background),
         background_authors_embeddings_df passed through for zoom handling)
    """
    model_name = model_radio if model_radio != "Other" else custom_model_input
    # Column name convention: "<last path segment of model name>_style_embedding".
    embedding_col_name = f'{model_name.split("/")[-1]}_style_embedding'
    print(background_authors_embeddings_df.columns)
    print("Generating cluster visualization")
    iid = int(iid)
    interp = load_interp_space(cfg)
    # dim2lat = interp['dimension_to_latent']
    style_names = interp['dimension_to_style']
    # bg_emb = np.array(interp['author_embedding'])
    # print(f"bg_emb shape: {bg_emb.shape}")
    #replace with cached embedddings
    bg_emb = np.array(background_authors_embeddings_df[embedding_col_name].tolist()) #placeholder for background embeddings
    print(f"bg_emb shape: {bg_emb.shape}")
    # print("interp.keys():", interp.keys())
    #bg_lbls = interp['author_labels']
    #bg_ids = interp['author_ids']
    # NOTE(review): bg_ids covers task + background authors, matching the full
    # projection below — assumes task_authors_df rows align with [query, candidates].
    bg_ids = task_authors_df['authorID'].tolist() + background_authors_embeddings_df['authorID'].tolist()
    # inst = instances[iid]
    # print("inst.keys():", inst.keys())
    # q_lat = np.array(inst['author_latents'][:1])
    # print(f"q_lat shape: {q_lat.shape}")
    # c_lat = np.array(inst['author_latents'][1:])
    # print(f"c_lat shape: {c_lat.shape}")
    # pred_idx = inst['latent_rank'][0]
    # gt_idx = inst['gt_idx']
    q_lat = np.array(task_authors_df[embedding_col_name].iloc[0]).reshape(1, -1)  # Mystery author latent
    print(f"q_lat shape: {q_lat.shape}")
    c_lat = np.array(task_authors_df[embedding_col_name].iloc[1:].tolist())  # Candidate authors latents
    print(f"c_lat shape: {c_lat.shape}")

    # cent_emb = np.array([v for _,v in dim2lat.items()])
    # cent_lbl = np.array([k for k,_ in dim2lat.items()])

    # all_emb = np.vstack([q_lat, c_lat, bg_emb, cent_emb])
    # Project query + candidates + background together so they share one space.
    all_emb = np.vstack([q_lat, c_lat, bg_emb])
    proj = compute_tsne_with_cache(all_emb)

    # split
    q_proj = proj[0]
    c_proj = proj[1:4]
    #bg_proj = proj[4:4+len(bg_lbls)]
    # NOTE(review): bg_proj is the *entire* projection (it includes the query
    # and candidate rows too) — the grey "background" trace below therefore
    # also plots those points; confirm this is intentional.
    bg_proj = proj

    # cent_proj = proj[4+len(bg_lbls):]


    # find nearest centroid
    # dists = np.linalg.norm(cent_proj - q_proj, axis=1)
    # idx = int(np.argmin(dists))
    # cluster_label_query = cent_lbl[idx]
    # features of the nearest centroid to display
    # feature_list = style_names[cluster_label_query]

    # cluster_labels_per_candidate = [
    #     cent_lbl[int(np.argmin(np.linalg.norm(cent_proj - c_proj[i], axis=1)))]
    #     for i in range(c_proj.shape[0])
    # ]

    # prepare colorscale
    # n_cent = len(cent_lbl)
    # cent_colors = sample_colorscale("algae", [i/(n_cent-1) for i in range(n_cent)])
    # map each cluster label to its color
    # color_map = { label: cent_colors[i] for i, label in enumerate(cent_lbl) }

    # uncomment the following line to show background authors
    ## background author colors pulled from their cluster label
    # bg_colors = [ color_map[label] for label in bg_lbls ]

    # 2) build Plotly figure
    fig = go.Figure()

    fig.update_layout(
        template='plotly_white',
        margin=dict(l=40,r=40,t=60,b=40),
        autosize=True,
        hovermode='closest',
        # Enable zoom events
        dragmode='zoom'
    )

    # fig.update_layout(
    #     template='plotly_white',
    #     margin=dict(l=40,r=40,t=60,b=40),
    #     autosize=True,
    #     hovermode='closest')


    # uncomment the following line to show background authors
    ## background authors (light grey dots)
    fig.add_trace(go.Scattergl(
        x=bg_proj[:,0], y=bg_proj[:,1],
        mode='markers',
        marker=dict(size=6, color="#d3d3d3"),# color=bg_colors
        name='Background authors',
        hoverinfo='skip'
    ))

    # centroids (rainbow colors + hovertext of your top-k features)
    # hover_texts = [
    #     f"Cluster {lbl}<br>" + "<br>".join(style_names[lbl])
    #     for lbl in cent_lbl
    # ]
    # fig.add_trace(go.Scattergl(
    #     x=cent_proj[:,0], y=cent_proj[:,1],
    #     mode='markers',
    #     marker=dict(symbol='triangle-up', size=10, color="#d3d3d3"),#color=cent_colors
    #     name='Cluster centroids',
    #     hovertext=hover_texts,
    #     hoverinfo='text'
    # ))

    # three candidates
    marker_syms = ['diamond','pentagon','x']
    for i in range(3):
        # label = f"Candidate {i+1}" + (" (predicted)" if i==pred_idx else "")
        base = f"Candidate {i+1}"
        # pick the right suffix
        if i == pred_idx and i == gt_idx:
            suffix = " (Predicted & Ground Truth)"
        elif i == pred_idx:
            suffix = " (Predicted)"
        elif i == gt_idx:
            suffix = "(Ground Truth)"
        else:
            suffix = ""

        label = base + suffix
        fig.add_trace(go.Scattergl(
            x=[c_proj[i,0]], y=[c_proj[i,1]],
            mode='markers',
            marker=dict(symbol=marker_syms[i], size=12, color='darkblue'),
            name=label,
            hoverinfo='skip'
        ))

    # query author
    fig.add_trace(go.Scattergl(
        x=[q_proj[0]], y=[q_proj[1]],
        mode='markers',
        marker=dict(symbol='star', size=14, color='red'),
        name='Mystery author',
        hoverinfo='skip'
    ))

    # ── Arrowed annotations for mystery + candidates ──────────────────────────
    # Mystery author (red star)
    fig.add_annotation(
        x=q_proj[0], y=q_proj[1],
        xref='x', yref='y',
        text="Mystery",
        showarrow=True,
        arrowhead=2,
        arrowsize=1,
        arrowwidth=1.5,
        ax=40,   # tail offset in pixels: moves the label 40px to the right
        ay=-40,  # moves the label 40px up
        font=dict(color='red', size=12)
    )

    # Candidate authors (dark blue ◆)
    offsets = [(-40, -30), (40, -30), (0, 40)]  # [(ax,ay) for Cand1, Cand2, Cand3]
    for i in range(3):
        # build the right label
        if i == pred_idx and i == gt_idx:
            label = f"Candidate {i+1} (Predicted & Ground Truth)"
        elif i == pred_idx:
            label = f"Candidate {i+1} (Predicted)"
        elif i == gt_idx:
            label = f"Candidate {i+1} (Ground Truth)"
        else:
            label = f"Candidate {i+1}"

        fig.add_annotation(
            x=c_proj[i,0], y=c_proj[i,1],
            xref='x', yref='y',
            text= label,
            showarrow=True,
            arrowhead=2,
            arrowsize=1,
            arrowwidth=1.5,
            ax=offsets[i][0],
            ay=offsets[i][1],
            font=dict(color='darkblue', size=12)
        )

    print('Done processing....')
    # Prepare outputs for the new cluster‐dropdown UI
    # all_clusters = sorted(style_names.keys())
    # --- build display names for the dropdown ---
    # sorted_labels = sorted([int(lbl) for lbl in cent_lbl])
    # display_clusters = []
    # for lbl in sorted_labels:
    #     name = f"Cluster {lbl}"
    #     if lbl == cluster_label_query:
    #         name += " (closest to mystery author)"
    #     matching_indices = [i + 1 for i, val in enumerate(cluster_labels_per_candidate) if int(val) == lbl]
    #     if matching_indices:
    #         if len(matching_indices) == 1:
    #             name += f" (closest to Candidate {matching_indices[0]} author)"
    #         else:
    #             candidate_str = ", ".join(f"Candidate {i}" for i in matching_indices)
    #             name += f" (closest to {candidate_str} authors)"
    #     display_clusters.append(name)
    # print(f"All clusters: {all_clusters}")
    # return: figure, dropdown payload, full style_map
    return (
        fig,
        # update(choices=display_clusters, value=display_clusters[cluster_label_query]),
        style_names,
        bg_proj,  # Return background points
        bg_ids,   # Return background labels
        background_authors_embeddings_df,  # Return the DataFrame for zoom handling

    )
    # return fig, update(choices=feature_list, value=feature_list[0]),feature_list
523
+
524
+
525
def extract_cluster_key(display_label: str) -> int:
    """
    Pull the numeric cluster id out of a dropdown display label.

    Example:
        "Cluster 5 (closest to mystery author; closest to Candidate 1 author)" -> 5

    Raises:
        ValueError: if the label does not start with "Cluster <number>".
    """
    matched = re.match(r"Cluster\s+(\d+)", display_label)
    if matched is None:
        raise ValueError(f"Unrecognized cluster label: {display_label}")
    return int(matched.group(1))
535
+
536
+
537
+
538
# Cluster selection handler: refreshes both feature radio groups.
def on_cluster_change(selected_cluster, style_map):
    """
    Populate the LLM and Gram2Vec feature radios for the chosen cluster.

    Splits the cluster's features into LLM- and Gram2Vec-derived lists,
    drops Gram2Vec features lacking a shorthand, and prepends a "None"
    default option to each list.

    Returns:
        (gr.update for the LLM radio, gr.update for the Gram2Vec radio,
         the full LLM option list including "None")
    """
    key = extract_cluster_key(selected_cluster)
    llm_feats, g2v_feats = split_features(style_map[key])
    # print(f"Selected cluster: {selected_cluster} ({key})")
    # print(f"LLM features: {llm_feats}")

    # "None" is always offered as the default, unselected choice.
    llm_options = ["None"] + llm_feats

    # Gram2Vec features without a shorthand cannot be rendered — skip them.
    g2v_options = ["None"]
    for feat in g2v_feats:
        if get_shorthand(feat) is not None:
            g2v_options.append(feat)
        else:
            print(f"Skipping Gram2Vec feature without shorthand: {feat}")

    return (
        gr.update(choices=llm_options, value=llm_options[0]),
        gr.update(choices=g2v_options, value=g2v_options[0]),
        llm_options
    )