Milad Alshomary
commited on
Commit
Β·
3d73c8d
1
Parent(s):
b96061f
updates
Browse files- Dockerfile +16 -0
- add_hf_env_to_hf_space.py +6 -0
- app.py +516 -0
- config/config.yaml +12 -0
- datasets/placeholder.txt +1 -0
- requirements.txt +10 -0
- utils/augmented_human_readable.txt +617 -0
- utils/clustering_utils.py +325 -0
- utils/file_download.py +70 -0
- utils/generate_augmented_mapping.py +86 -0
- utils/gram2vec_feat_utils.py +284 -0
- utils/human_readable.txt +40 -0
- utils/interp_space_utils.py +638 -0
- utils/llm_feat_utils.py +138 -0
- utils/ui.py +225 -0
- utils/visualizations.py +564 -0
Dockerfile
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
|
2 |
+
# you will also find guides on how best to write your Dockerfile
|
3 |
+
|
4 |
+
FROM python:3.9
|
5 |
+
|
6 |
+
RUN useradd -m -u 1000 user
|
7 |
+
USER user
|
8 |
+
ENV PATH="/home/user/.local/bin:$PATH"
|
9 |
+
|
10 |
+
WORKDIR /app
|
11 |
+
|
12 |
+
COPY --chown=user ./requirements.txt requirements.txt
|
13 |
+
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
14 |
+
|
15 |
+
COPY --chown=user . /app
|
16 |
+
CMD ["python", "app.py"]
|
add_hf_env_to_hf_space.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from huggingface_hub import HfApi
|
2 |
+
import os
|
3 |
+
repo_id = "miladalsh/explaining_authorship_attribution_models"
|
4 |
+
api = HfApi()
|
5 |
+
|
6 |
+
api.add_space_variable(repo_id=repo_id, key="OPENAI_API_KEY", value=os.environ["OPENAI_API_KEY"])
|
app.py
ADDED
@@ -0,0 +1,516 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import json
|
3 |
+
|
4 |
+
|
5 |
+
import os
|
6 |
+
os.environ["GRADIO_TEMP_DIR"] = "./datasets/temp" # Set a custom temp directory for Gradio
|
7 |
+
os.makedirs(os.environ["GRADIO_TEMP_DIR"], exist_ok=True)
|
8 |
+
|
9 |
+
import yaml
|
10 |
+
import argparse
|
11 |
+
import os
|
12 |
+
import urllib.request
|
13 |
+
from tqdm import tqdm
|
14 |
+
|
15 |
+
from dotenv import load_dotenv
|
16 |
+
from openai import OpenAI
|
17 |
+
from utils.file_download import download_file_override
|
18 |
+
|
19 |
+
|
20 |
+
def load_config(path="config/config.yaml"):
|
21 |
+
with open(path, "r") as f:
|
22 |
+
return yaml.safe_load(f)
|
23 |
+
|
24 |
+
cfg = load_config()
|
25 |
+
|
26 |
+
|
27 |
+
download_file_override(cfg.get('interp_space_url'), cfg.get('interp_space_path'))
|
28 |
+
download_file_override(cfg.get('instances_to_explain_url'), cfg.get('instances_to_explain_path'))
|
29 |
+
download_file_override(cfg.get('gram2vec_feats_url'), cfg.get('gram2vec_feats_path'))
|
30 |
+
|
31 |
+
from utils.visualizations import *
|
32 |
+
from utils.llm_feat_utils import *
|
33 |
+
from utils.gram2vec_feat_utils import *
|
34 |
+
from utils.interp_space_utils import *
|
35 |
+
from utils.ui import *
|
36 |
+
|
37 |
+
load_dotenv()
|
38 |
+
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
39 |
+
|
40 |
+
|
41 |
+
# ββ load once at startup ββββββββββββββββββββββββββββββββββββββββ
|
42 |
+
GRAM2VEC_SHORTHAND = load_code_map()
|
43 |
+
|
44 |
+
def validate_ground_truth(gt1, gt2, gt3):
|
45 |
+
selected = [gt1, gt2, gt3]
|
46 |
+
selected_count = sum(selected)
|
47 |
+
|
48 |
+
if selected_count > 1:
|
49 |
+
return None, "Please select only one ground truth author."
|
50 |
+
elif selected_count == 0:
|
51 |
+
return None, "No ground truth author selected."
|
52 |
+
|
53 |
+
index = selected.index(True)
|
54 |
+
return index, f"Candidate {index+1} is marked as the ground truth author."
|
55 |
+
|
56 |
+
|
57 |
+
def app(share=False, use_cluster_feats=False):
|
58 |
+
instances, instance_ids = get_instances(cfg['instances_to_explain_path'])
|
59 |
+
|
60 |
+
interp = load_interp_space(cfg)
|
61 |
+
clustered_authors_df = interp['clustered_authors_df'][:1000]
|
62 |
+
clustered_authors_df['fullText'] = clustered_authors_df['fullText'].map(lambda l: l[:3]) # Take at most 3 texts per author
|
63 |
+
|
64 |
+
with gr.Blocks(title="Author Attribution Explainability Tool") as demo:
|
65 |
+
# ββ Big Centered Title ββββββββββββββββββββββββββββββββββββββββββ
|
66 |
+
gr.HTML(styled_block("""
|
67 |
+
<h1 style="
|
68 |
+
text-align:center;
|
69 |
+
font-size:3em; /* About 48px */
|
70 |
+
margin-bottom:0.3em;
|
71 |
+
font-weight:700;
|
72 |
+
">
|
73 |
+
Author Attribution (AA) Explainability Tool
|
74 |
+
</h1>
|
75 |
+
"""))
|
76 |
+
|
77 |
+
gr.HTML(styled_block("""
|
78 |
+
<div style="
|
79 |
+
text-align:center;
|
80 |
+
margin: 1em auto 2em auto;
|
81 |
+
max-width:900px;
|
82 |
+
">
|
83 |
+
<p style="font-size:1.3em; line-height:1.4;">
|
84 |
+
This demo helps you <strong>see inside</strong> a deep AA modelβs latent style space.
|
85 |
+
</p>
|
86 |
+
<p style="font-size:0.9em; line-height:1.4;">
|
87 |
+
Currently you are inspecting <a href="https://huggingface.co/rrivera1849/LUAR-MUD">LUAR</a> with pre-defined AA tasks from the <a href="https://www.iarpa.gov/images/research-programs/HIATUS/IARPA_HIATUS_Phase_1_HRS_Data.to_DAC_20240610.pdf">HRS dataset </a>
|
88 |
+
</p>
|
89 |
+
<div style="
|
90 |
+
display:flex;
|
91 |
+
justify-content:center;
|
92 |
+
gap:3em;
|
93 |
+
margin-top:1em;
|
94 |
+
">
|
95 |
+
<!-- Visualize -->
|
96 |
+
<div style="max-width:200px;">
|
97 |
+
<div style="font-size:2em;">π</div>
|
98 |
+
<h4 style="margin:0.2em 0;">Visualize</h4>
|
99 |
+
<p style="margin:0; font-size:1em; line-height:1.3;">
|
100 |
+
Place your AA task with respect to other background authors.
|
101 |
+
</p>
|
102 |
+
</div>
|
103 |
+
<!-- GENERATE -->
|
104 |
+
<div style="max-width:200px;">
|
105 |
+
<div style="font-size:2em;">βοΈ</div>
|
106 |
+
<h4 style="margin:0.2em 0;">Generate</h4>
|
107 |
+
<p style="margin:0; font-size:1em; line-height:1.3;">
|
108 |
+
Describe your investigated authors' writing style via human-readable LLM-generated style features.
|
109 |
+
</p>
|
110 |
+
</div>
|
111 |
+
<!-- COMPARE -->
|
112 |
+
<div style="max-width:200px;">
|
113 |
+
<div style="font-size:2em;">βοΈ</div>
|
114 |
+
<h4 style="margin:0.2em 0;">Compare</h4>
|
115 |
+
<p style="margin:0; font-size:1em; line-height:1.3;">
|
116 |
+
Contrast with <a href=""https://github.com/eric-sclafani/gram2vec>Gram2Vec</a> stylometric features.
|
117 |
+
</p>
|
118 |
+
</div>
|
119 |
+
</div>
|
120 |
+
</div>
|
121 |
+
"""))
|
122 |
+
|
123 |
+
|
124 |
+
# ββ Step-by-Step Guided Panel ββ
|
125 |
+
with gr.Accordion("π How to Use", open=True):
|
126 |
+
gr.Markdown("""
|
127 |
+
1. **Select** a model and a task source (pre-defined or custom)
|
128 |
+
2. Click **Load Task & Generate Embeddings** to load the task and generate embeddings
|
129 |
+
3. **Run Visualization** to see the mystery author and candidates in the AA model's latent space
|
130 |
+
4. **Zoom** into the visualization to select a cluster of background authors
|
131 |
+
5. Pick an **LLM feature** to highlight in yellow
|
132 |
+
6. Pick a **Gram2Vec feature** to highlight in blue
|
133 |
+
7. Click **Show Combined Spans** to compare side-by-side
|
134 |
+
"""
|
135 |
+
)
|
136 |
+
|
137 |
+
# ββ Model Selection βββββββββββββββββββββββββββββββββ
|
138 |
+
model_radio = gr.Radio(
|
139 |
+
choices=[
|
140 |
+
'gabrielloiseau/LUAR-MUD-sentence-transformers',
|
141 |
+
'gabrielloiseau/LUAR-CRUD-sentence-transformers',
|
142 |
+
'miladalsh/light-luar',
|
143 |
+
'AnnaWegmann/Style-Embedding',
|
144 |
+
'Other'
|
145 |
+
],
|
146 |
+
value='gabrielloiseau/LUAR-MUD-sentence-transformers',
|
147 |
+
label='Choose a Model to inspect'
|
148 |
+
)
|
149 |
+
print(f"Model choices: {model_radio.choices}")
|
150 |
+
print(f"Model default: {model_radio.value}")
|
151 |
+
custom_model = gr.Textbox(
|
152 |
+
label='Custom Model ID',
|
153 |
+
placeholder='Enter your Hugging Face Model ID here',
|
154 |
+
visible=False,
|
155 |
+
interactive=True
|
156 |
+
)
|
157 |
+
# Show the textbox when 'Other' is selected
|
158 |
+
model_radio.change(
|
159 |
+
fn=toggle_custom_model,
|
160 |
+
inputs=[model_radio],
|
161 |
+
outputs=[custom_model]
|
162 |
+
)
|
163 |
+
|
164 |
+
# ββ Task Source Selection βββββββββββββββββββββββββββββββββ
|
165 |
+
task_mode = gr.Radio(
|
166 |
+
choices=["Predefined HRS Task", "Upload Your Own Task"],
|
167 |
+
value="Predefined HRS Task",
|
168 |
+
label="Select Task Source"
|
169 |
+
)
|
170 |
+
|
171 |
+
ground_truth_author = gr.State() # To store the index of the ground truth author
|
172 |
+
|
173 |
+
with gr.Column():
|
174 |
+
with gr.Column(visible=True) as predefined_container:
|
175 |
+
gr.HTML("""
|
176 |
+
<div style="
|
177 |
+
font-size: 1.3em;
|
178 |
+
font-weight: 600;
|
179 |
+
margin-bottom: 0.5em;
|
180 |
+
">
|
181 |
+
Pick a pre-defined task to investigate (a mystery text and its three candidate authors)
|
182 |
+
</div>
|
183 |
+
""")
|
184 |
+
task_dropdown = gr.Dropdown(
|
185 |
+
choices=[f"Task {i}" for i in instance_ids],
|
186 |
+
value=f"Task {instance_ids[0]}",
|
187 |
+
label="Choose which mystery document to explain",
|
188 |
+
)
|
189 |
+
with gr.Column(visible=False) as custom_container:
|
190 |
+
gr.HTML("""
|
191 |
+
<div style="
|
192 |
+
font-size: 1.3em;
|
193 |
+
font-weight: 600;
|
194 |
+
margin-bottom: 0.5em;
|
195 |
+
">
|
196 |
+
Upload your own task
|
197 |
+
</div>
|
198 |
+
""")
|
199 |
+
mystery_input = gr.File(label="Mystery (.txt)", file_types=['.txt'])
|
200 |
+
with gr.Row():
|
201 |
+
candidate1 = gr.File(label="Candidate 1 (.txt)", file_types=['.txt'])
|
202 |
+
gt1_checkbox = gr.Checkbox(label="Ground Truth?", value=False)
|
203 |
+
|
204 |
+
with gr.Row():
|
205 |
+
candidate2 = gr.File(label="Candidate 2 (.txt)", file_types=['.txt'])
|
206 |
+
gt2_checkbox = gr.Checkbox(label="Ground Truth?", value=False)
|
207 |
+
|
208 |
+
with gr.Row():
|
209 |
+
candidate3 = gr.File(label="Candidate 3 (.txt)", file_types=['.txt'])
|
210 |
+
gt3_checkbox = gr.Checkbox(label="Ground Truth?", value=False)
|
211 |
+
|
212 |
+
validation_msg = gr.Textbox(label="Validation Result", interactive=False)
|
213 |
+
|
214 |
+
for checkbox in [gt1_checkbox, gt2_checkbox, gt3_checkbox]:
|
215 |
+
checkbox.change(
|
216 |
+
fn=validate_ground_truth,
|
217 |
+
inputs=[gt1_checkbox, gt2_checkbox, gt3_checkbox],
|
218 |
+
outputs=[ground_truth_author, validation_msg]
|
219 |
+
)
|
220 |
+
|
221 |
+
|
222 |
+
# ββ Load Task Button βββββββββββββββββββββββββββββββββββββ
|
223 |
+
gr.HTML(instruction_callout("Click the button below to load the tasks and generate embeddings using selected model."))
|
224 |
+
load_button = gr.Button("Load Task & Generate Embeddings")
|
225 |
+
|
226 |
+
# ββ HTML outputs for author texts βββββββββββββββββββββββββββ
|
227 |
+
default_outputs = load_instance(0, instances)
|
228 |
+
#dont need defaults since they are loaded only on click of the load button
|
229 |
+
header = gr.HTML()
|
230 |
+
mystery = gr.HTML()
|
231 |
+
mystery_state = gr.State() # Store unformatted mystery text for later use
|
232 |
+
with gr.Row():
|
233 |
+
c0 = gr.HTML()
|
234 |
+
c1 = gr.HTML()
|
235 |
+
c2 = gr.HTML()
|
236 |
+
c0_state = gr.State() # Store unformatted candidate 1 text for later use
|
237 |
+
c1_state = gr.State() # Store unformatted candidate 2 text for later use
|
238 |
+
c2_state = gr.State() # Store unformatted candidate 3 text for later use
|
239 |
+
# ββ State to hold embeddings DataFrame βββββββββββββββββββββ
|
240 |
+
task_authors_embeddings_df = gr.State() # Store embeddings of task authors
|
241 |
+
background_authors_embeddings_df = gr.State() # Store background authors DataFrame
|
242 |
+
task_mode.change(
|
243 |
+
fn=toggle_task,
|
244 |
+
inputs=[task_mode],
|
245 |
+
outputs=[predefined_container, custom_container]
|
246 |
+
)
|
247 |
+
# ββ Wire call to load task and generate embeddings once load button is clicked βββββββββββββββββββ
|
248 |
+
predicted_author = gr.State() # Store predicted author from the embeddings
|
249 |
+
load_button.click(
|
250 |
+
fn=lambda: gr.update(value="β³ Loading... Please wait", interactive=False),
|
251 |
+
inputs=[],
|
252 |
+
outputs=[load_button]
|
253 |
+
).then(
|
254 |
+
fn=lambda mode, dropdown, mystery, c1, c2, c3, ground_truth_author, model_radio, custom_model_input:
|
255 |
+
update_task_display(
|
256 |
+
mode,
|
257 |
+
dropdown,
|
258 |
+
instances, # closed over
|
259 |
+
clustered_authors_df,
|
260 |
+
mystery,
|
261 |
+
c1,
|
262 |
+
c2,
|
263 |
+
c3,
|
264 |
+
ground_truth_author, # true_author placeholder
|
265 |
+
model_radio,
|
266 |
+
custom_model_input
|
267 |
+
),
|
268 |
+
inputs=[task_mode, task_dropdown, mystery_input, candidate1, candidate2, candidate3, ground_truth_author, model_radio, custom_model],
|
269 |
+
outputs=[header, mystery, c0, c1, c2, mystery_state, c0_state, c1_state, c2_state, task_authors_embeddings_df, background_authors_embeddings_df, predicted_author, ground_truth_author] # embeddings_df is a placeholder for now
|
270 |
+
).then(
|
271 |
+
fn=lambda: gr.update(value="Load Task & Generate Embeddings", interactive=True),
|
272 |
+
inputs=[],
|
273 |
+
outputs=[load_button]
|
274 |
+
)
|
275 |
+
|
276 |
+
# ββ Visualization for features βββββββββββββββββββββββββββββ
|
277 |
+
gr.HTML(instruction_callout("Run visualization to see which author is similar to the mystery document."))
|
278 |
+
run_btn = gr.Button("Run visualization")
|
279 |
+
bg_proj_state = gr.State()
|
280 |
+
bg_lbls_state = gr.State()
|
281 |
+
bg_authors_df = gr.State() # Holds the background authors DataFrame
|
282 |
+
with gr.Row():
|
283 |
+
with gr.Column(scale=3):
|
284 |
+
axis_ranges = gr.Textbox(visible=False, elem_id="axis-ranges")
|
285 |
+
plot = gr.Plot(
|
286 |
+
label="Visualization",
|
287 |
+
elem_id="feature-plot",
|
288 |
+
)
|
289 |
+
plot.change(
|
290 |
+
fn=None,
|
291 |
+
inputs=[plot],
|
292 |
+
outputs=[axis_ranges],
|
293 |
+
js="""
|
294 |
+
function(){
|
295 |
+
console.log("------------>[JS] plot.change() triggered<------------");
|
296 |
+
|
297 |
+
let attempts = 0;
|
298 |
+
const maxAttempts = 50;
|
299 |
+
|
300 |
+
const tryAttach = () => {
|
301 |
+
const gd = document.querySelector('#feature-plot .js-plotly-plot');
|
302 |
+
if (!gd) {
|
303 |
+
if (++attempts < maxAttempts) {
|
304 |
+
requestAnimationFrame(tryAttach);
|
305 |
+
} else {
|
306 |
+
console.error(" ------------>Could not find .js-plotly-plot after multiple attempts.<------------");
|
307 |
+
}
|
308 |
+
return;
|
309 |
+
}
|
310 |
+
|
311 |
+
if (gd.__zoomListenerAttached) {
|
312 |
+
console.log("------------>Zoom listener already attached.<------------");
|
313 |
+
return;
|
314 |
+
}
|
315 |
+
|
316 |
+
gd.__zoomListenerAttached = true;
|
317 |
+
console.log("------------>Zoom listener attached!<------------");
|
318 |
+
|
319 |
+
gd.on('plotly_relayout', (ev) => {
|
320 |
+
if (
|
321 |
+
ev['xaxis.range[0]'] === undefined ||
|
322 |
+
ev['xaxis.range[1]'] === undefined ||
|
323 |
+
ev['yaxis.range[0]'] === undefined ||
|
324 |
+
ev['yaxis.range[1]'] === undefined
|
325 |
+
) return;
|
326 |
+
|
327 |
+
const payload = {
|
328 |
+
xaxis: [ev['xaxis.range[0]'], ev['xaxis.range[1]']],
|
329 |
+
yaxis: [ev['yaxis.range[0]'], ev['yaxis.range[1]']]
|
330 |
+
};
|
331 |
+
|
332 |
+
const txtbox = document.querySelector('#axis-ranges textarea');
|
333 |
+
if (txtbox) {
|
334 |
+
txtbox.value = JSON.stringify(payload);
|
335 |
+
txtbox.dispatchEvent(new Event('input', { bubbles: true }));
|
336 |
+
console.log("------------> Zoom payload dispatched:<------------", payload);
|
337 |
+
} else {
|
338 |
+
console.warn("------------> No hidden textbox found to write zoom payload.<------------");
|
339 |
+
}
|
340 |
+
});
|
341 |
+
};
|
342 |
+
|
343 |
+
requestAnimationFrame(tryAttach);
|
344 |
+
return '';
|
345 |
+
}
|
346 |
+
"""
|
347 |
+
)
|
348 |
+
|
349 |
+
|
350 |
+
with gr.Column(scale=1):
|
351 |
+
expl_html = """
|
352 |
+
<h4>What am I looking at?</h4>
|
353 |
+
<p>
|
354 |
+
This plot shows the mystery author (β
) and three candidate authors (β)
|
355 |
+
in the AA modelβs latent space.<br>
|
356 |
+
The grey β symbols represent the background corpusβreal authors with diverse writing styles.
|
357 |
+
You can zoom in on any region of the plot. The system will analyze the visible authors
|
358 |
+
in that area and list the most representative writing style features for the zoomed-in region.<br>
|
359 |
+
Use this to compare your mystery textβs position against nearby writing styles and
|
360 |
+
investigate which features distinguish it from others.
|
361 |
+
</p>
|
362 |
+
"""
|
363 |
+
gr.HTML(styled_html(expl_html))
|
364 |
+
|
365 |
+
cluster_dropdown = gr.Dropdown(choices=[], label="Select Cluster to Inspect", visible=False)
|
366 |
+
style_map_state = gr.State()
|
367 |
+
llm_style_feats_analysis = gr.State()
|
368 |
+
visible_zoomed_authors = gr.State()
|
369 |
+
|
370 |
+
if use_cluster_feats:
|
371 |
+
# ββ Dynamic Cluster Choice dropdown ββββββββββββββββββββββββββββββββββ
|
372 |
+
gr.HTML(instruction_callout("Choose a cluster from the dropdown below to inspect whether its features appear in the mystery authorβs text."))
|
373 |
+
cluster_dropdown.visible = True
|
374 |
+
else:
|
375 |
+
gr.HTML(instruction_callout("Zoom in on the plot to select a set of background authors and see the presence of the top features from this set in candidate and mystery authors."))
|
376 |
+
|
377 |
+
with gr.Row():
|
378 |
+
# ββ LLM Features Column ββββββββββββββββββββββββββββββββββ
|
379 |
+
with gr.Column(scale=1, min_width=0):
|
380 |
+
# gr.Markdown("**Features from the cluster closest to the Mystery Author**")
|
381 |
+
gr.HTML("""
|
382 |
+
<div style="
|
383 |
+
font-size: 1.3em;
|
384 |
+
font-weight: 600;
|
385 |
+
margin-bottom: 0.5em;
|
386 |
+
">
|
387 |
+
LLM-derived style features prominent in the zoomed-in region
|
388 |
+
</div>
|
389 |
+
""")
|
390 |
+
features_rb = gr.Radio(choices=[], label="LLM-derived style features for this zoomed-in region")#, label="Features from the cluster closest to the Mystery Author", info="LLM-derived style features for this cluster")
|
391 |
+
feature_list_state = gr.State()
|
392 |
+
|
393 |
+
# ββ Gram2Vec Features Column βββββββββββββββββββββββββββββ
|
394 |
+
with gr.Column(scale=1, min_width=0):
|
395 |
+
# gr.Markdown("**Top-10 Gram2Vec Features most likely to occur in Mystery Author**")
|
396 |
+
gr.HTML("""
|
397 |
+
<div style="
|
398 |
+
font-size: 1.3em;
|
399 |
+
font-weight: 600;
|
400 |
+
margin-bottom: 0.5em;
|
401 |
+
">
|
402 |
+
Gram2Vec Features prominent in the zoomed-in region
|
403 |
+
</div>
|
404 |
+
""")
|
405 |
+
gram2vec_rb = gr.Radio(choices=[], label="Gram2Vec features for this zoomed-in region")#, label="Top-10 Gram2Vec Features most likely to occur in Mystery Author", info="Most prominent Gram2Vec features in the mystery text")
|
406 |
+
gram2vec_state = gr.State()
|
407 |
+
|
408 |
+
# ββ Visualization button click βββββββββββββββββββββββββββββββ
|
409 |
+
run_btn.click(
|
410 |
+
fn=lambda iid, model_radio, custom_model_input, task_authors_embeddings_df, background_authors_embeddings_df, predicted_author, ground_truth_author: visualize_clusters_plotly(
|
411 |
+
int(iid.replace('Task ','')), cfg, instances, model_radio,
|
412 |
+
custom_model_input, task_authors_embeddings_df, background_authors_embeddings_df, predicted_author, ground_truth_author
|
413 |
+
),
|
414 |
+
inputs=[task_dropdown, model_radio, custom_model, task_authors_embeddings_df, background_authors_embeddings_df, predicted_author, ground_truth_author],
|
415 |
+
outputs=[plot, style_map_state, bg_proj_state, bg_lbls_state, bg_authors_df]
|
416 |
+
)
|
417 |
+
|
418 |
+
# Populate feature list based on selection.
|
419 |
+
if use_cluster_feats:
|
420 |
+
# Use cluster-based flow
|
421 |
+
cluster_dropdown.change(
|
422 |
+
fn=on_cluster_change,
|
423 |
+
inputs=[cluster_dropdown, style_map_state],
|
424 |
+
outputs=[features_rb, gram2vec_rb , feature_list_state]
|
425 |
+
#adding feature_list_state to persisit all llm features in the app state
|
426 |
+
)
|
427 |
+
else:
|
428 |
+
|
429 |
+
axis_ranges.change(
|
430 |
+
fn=handle_zoom_with_retries,
|
431 |
+
inputs=[axis_ranges, bg_proj_state, bg_lbls_state, bg_authors_df, task_authors_embeddings_df],
|
432 |
+
outputs=[features_rb, gram2vec_rb , llm_style_feats_analysis, feature_list_state, visible_zoomed_authors]
|
433 |
+
)
|
434 |
+
|
435 |
+
|
436 |
+
# ββ Show combined featureβspan highlights ββ
|
437 |
+
# combined callout + legend in one HTML block
|
438 |
+
gr.HTML(
|
439 |
+
instruction_callout(
|
440 |
+
"Click \"Show Combined Spans\" to highlight the LLM (yellow) & Gram2Vec (blue) feature spans in the texts"
|
441 |
+
)
|
442 |
+
+ """
|
443 |
+
<div style="
|
444 |
+
display: flex;
|
445 |
+
align-items: center;
|
446 |
+
justify-content: center;
|
447 |
+
gap: 2em;
|
448 |
+
margin-top: 0.5em;
|
449 |
+
font-size: 0.9em;
|
450 |
+
">
|
451 |
+
<div style="display: flex; align-items: center; gap: 0.5em; font-weight: 600; font-size: 1.5em;">
|
452 |
+
<span style="
|
453 |
+
display: inline-block;
|
454 |
+
width: 1.5em; height: 1.5em;
|
455 |
+
background: #FFEB3B; /* bright yellow */
|
456 |
+
border: 1px solid #666;
|
457 |
+
vertical-align: middle;
|
458 |
+
"></span>
|
459 |
+
LLM feature
|
460 |
+
</div>
|
461 |
+
<div style="display: flex; align-items: center; gap: 0.5em; font-weight: 600; font-size: 1.5em;">
|
462 |
+
<span style="
|
463 |
+
display: inline-block;
|
464 |
+
width: 1.5em; height: 1.5em;
|
465 |
+
background: #5CB3FF; /* clearer blue */
|
466 |
+
border: 1px solid #666;
|
467 |
+
vertical-align: middle;
|
468 |
+
"></span>
|
469 |
+
Gram2Vec feature
|
470 |
+
</div>
|
471 |
+
</div>
|
472 |
+
"""
|
473 |
+
)
|
474 |
+
|
475 |
+
|
476 |
+
combined_btn = gr.Button("Show Combined Spans")
|
477 |
+
combined_html = gr.HTML()
|
478 |
+
show_background_checkbox = gr.Checkbox(label="Show spans in background authors", value=False)
|
479 |
+
background_html = gr.HTML(visible=False)
|
480 |
+
# print(f"in app: all_feats={feature_list_state.value}")
|
481 |
+
# print(f"in app: sel_feat_llm={features_rb.value}")
|
482 |
+
|
483 |
+
|
484 |
+
combined_btn.click(
|
485 |
+
fn=show_combined_spans_all,
|
486 |
+
inputs=[features_rb,
|
487 |
+
gram2vec_rb,
|
488 |
+
llm_style_feats_analysis,
|
489 |
+
background_authors_embeddings_df,
|
490 |
+
task_authors_embeddings_df,
|
491 |
+
visible_zoomed_authors,
|
492 |
+
predicted_author,
|
493 |
+
ground_truth_author],
|
494 |
+
outputs=[combined_html, background_html]
|
495 |
+
)
|
496 |
+
# mapping -->
|
497 |
+
# iid = task_dropdown.value
|
498 |
+
# sel_feat_llm = features_rb.value
|
499 |
+
# all_feats = feature_list_state.value
|
500 |
+
# sel_feat_g2v = gram2vec_rb.value
|
501 |
+
# combined_html -> spans/html for task authors
|
502 |
+
# background_html -> spans/html for background authors
|
503 |
+
|
504 |
+
show_background_checkbox.change(
|
505 |
+
fn=lambda show: gr.update(visible=show),
|
506 |
+
inputs=[show_background_checkbox],
|
507 |
+
outputs=[background_html]
|
508 |
+
)
|
509 |
+
|
510 |
+
demo.launch(share=share)
|
511 |
+
|
512 |
+
if __name__ == "__main__":
|
513 |
+
parser = argparse.ArgumentParser()
|
514 |
+
parser.add_argument("--use_cluster_feats", action="store_true", help="Use cluster-based selection for features")
|
515 |
+
args = parser.parse_args()
|
516 |
+
app(share=True, use_cluster_feats=args.use_cluster_feats)
|
config/config.yaml
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# config.yaml
|
2 |
+
instances_to_explain_path: "./datasets/hrs_explanations.json"
|
3 |
+
instances_to_explain_url: "https://huggingface.co/datasets/miladalsh/explanation_tool_files/raw/main/hrs_explanations.json?download"
|
4 |
+
interp_space_path: "./datasets/luar_interp_space_cluster_19/"
|
5 |
+
interp_space_url: "https://huggingface.co/datasets/miladalsh/explanation_tool_files/resolve/main/luar_interp_space_cluster.zip?download=true"
|
6 |
+
gram2vec_feats_path: "./datasets/gram2vec_feats.csv"
|
7 |
+
gram2vec_feats_url: "https://huggingface.co/datasets/miladalsh/explanation_tool_files/resolve/main/gram2vec_feats.csv?download=true"
|
8 |
+
|
9 |
+
style_feat_clm: "llm_tfidf_weights"
|
10 |
+
top_k: 10
|
11 |
+
only_llm_feats: false
|
12 |
+
only_gram2vec_feats: false
|
datasets/placeholder.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
test
|
requirements.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
nltk
|
2 |
+
spacy
|
3 |
+
scikit-learn
|
4 |
+
openai
|
5 |
+
python-dotenv
|
6 |
+
gradio==5.30
|
7 |
+
pyyaml
|
8 |
+
plotly
|
9 |
+
sentence_transformers
|
10 |
+
git+https://github.com/MiladAlshomary/gram2vec
|
utils/augmented_human_readable.txt
ADDED
@@ -0,0 +1,617 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Adjective:ADJ
|
2 |
+
Adposition:ADP
|
3 |
+
Adverb:ADV
|
4 |
+
Auxiliary verb:AUX
|
5 |
+
Coordinating conjunction:CCONJ
|
6 |
+
Determiner:DET
|
7 |
+
Interjection:INTJ
|
8 |
+
Noun:NOUN
|
9 |
+
Numeral:NUM
|
10 |
+
Particle:PART
|
11 |
+
Pronoun:PRON
|
12 |
+
Proper noun:PROPN
|
13 |
+
Punctuation:PUNCT
|
14 |
+
Subordinating conjunction:SCONJ
|
15 |
+
Symbol:SYM
|
16 |
+
Verb:VERB
|
17 |
+
Other:X
|
18 |
+
Space:SPACE
|
19 |
+
Other (foreign words, typos, abbreviations):X
|
20 |
+
ADP: Adposition (preposition or postposition)
|
21 |
+
|
22 |
+
Perfect aspect:Aspect=Perf
|
23 |
+
Progressive aspect:Aspect=Prog
|
24 |
+
Accusative case:Case=Acc
|
25 |
+
Nominative case:Case=Nom
|
26 |
+
Definite article:Definite=Def
|
27 |
+
Indefinite article:Definite=Ind
|
28 |
+
Comparative degree:Degree=Cmp
|
29 |
+
Positive degree:Degree=Pos
|
30 |
+
Superlative degree:Degree=Sup
|
31 |
+
Feminine gender:Gender=Fem
|
32 |
+
Masculine gender:Gender=Masc
|
33 |
+
Indicative mood:Mood=Ind
|
34 |
+
Plural number:Number=Plur
|
35 |
+
Singular number:Number=Sing
|
36 |
+
First person:Person=1
|
37 |
+
Second person:Person=2
|
38 |
+
Third person:Person=3
|
39 |
+
Past tense:Tense=Past
|
40 |
+
Present tense:Tense=Pres
|
41 |
+
Finite verb form:VerbForm=Fin
|
42 |
+
Infinitive verb form:VerbForm=Inf
|
43 |
+
|
44 |
+
Adjective followed by Adjective:ADJ ADJ
|
45 |
+
Adjective followed by Adposition:ADJ ADP
|
46 |
+
Adjective followed by Adverb:ADJ ADV
|
47 |
+
Adjective followed by Auxiliary verb:ADJ AUX
|
48 |
+
Adjective followed by Coordinating conjunction:ADJ CCONJ
|
49 |
+
Adjective followed by Determiner:ADJ DET
|
50 |
+
Adjective followed by Interjection:ADJ INTJ
|
51 |
+
Adjective followed by Noun:ADJ NOUN
|
52 |
+
Adjective followed by Numeral:ADJ NUM
|
53 |
+
Adjective followed by Other:ADJ X
|
54 |
+
Adjective followed by Particle:ADJ PART
|
55 |
+
Adjective followed by Pronoun:ADJ PRON
|
56 |
+
Adjective followed by Proper noun:ADJ PROPN
|
57 |
+
Adjective followed by Punctuation:ADJ PUNCT
|
58 |
+
Adjective followed by Subordinating conjunction:ADJ SCONJ
|
59 |
+
Adjective followed by Symbol:ADJ SYM
|
60 |
+
Adjective followed by Verb:ADJ VERB
|
61 |
+
Adposition followed by Adjective:ADP ADJ
|
62 |
+
Adposition followed by Adposition:ADP ADP
|
63 |
+
Adposition followed by Adverb:ADP ADV
|
64 |
+
Adposition followed by Auxiliary verb:ADP AUX
|
65 |
+
Adposition followed by Coordinating conjunction:ADP CCONJ
|
66 |
+
Adposition followed by Determiner:ADP DET
|
67 |
+
Adposition followed by Interjection:ADP INTJ
|
68 |
+
Adposition followed by Noun:ADP NOUN
|
69 |
+
Adposition followed by Numeral:ADP NUM
|
70 |
+
Adposition followed by Other:ADP X
|
71 |
+
Adposition followed by Particle:ADP PART
|
72 |
+
Adposition followed by Pronoun:ADP PRON
|
73 |
+
Adposition followed by Proper noun:ADP PROPN
|
74 |
+
Adposition followed by Punctuation:ADP PUNCT
|
75 |
+
Adposition followed by Subordinating conjunction:ADP SCONJ
|
76 |
+
Adposition followed by Symbol:ADP SYM
|
77 |
+
Adposition followed by Verb:ADP VERB
|
78 |
+
Adverb followed by Adjective:ADV ADJ
|
79 |
+
Adverb followed by Adposition:ADV ADP
|
80 |
+
Adverb followed by Adverb:ADV ADV
|
81 |
+
Adverb followed by Auxiliary verb:ADV AUX
|
82 |
+
Adverb followed by Coordinating conjunction:ADV CCONJ
|
83 |
+
Adverb followed by Determiner:ADV DET
|
84 |
+
Adverb followed by Interjection:ADV INTJ
|
85 |
+
Adverb followed by Noun:ADV NOUN
|
86 |
+
Adverb followed by Numeral:ADV NUM
|
87 |
+
Adverb followed by Other:ADV X
|
88 |
+
Adverb followed by Particle:ADV PART
|
89 |
+
Adverb followed by Pronoun:ADV PRON
|
90 |
+
Adverb followed by Proper noun:ADV PROPN
|
91 |
+
Adverb followed by Punctuation:ADV PUNCT
|
92 |
+
Adverb followed by Subordinating conjunction:ADV SCONJ
|
93 |
+
Adverb followed by Symbol:ADV SYM
|
94 |
+
Adverb followed by Verb:ADV VERB
|
95 |
+
Auxiliary verb followed by Adjective:AUX ADJ
|
96 |
+
Auxiliary verb followed by Adposition:AUX ADP
|
97 |
+
Auxiliary verb followed by Adverb:AUX ADV
|
98 |
+
Auxiliary verb followed by Auxiliary verb:AUX AUX
|
99 |
+
Auxiliary verb followed by Coordinating conjunction:AUX CCONJ
|
100 |
+
Auxiliary verb followed by Determiner:AUX DET
|
101 |
+
Auxiliary verb followed by Interjection:AUX INTJ
|
102 |
+
Auxiliary verb followed by Noun:AUX NOUN
|
103 |
+
Auxiliary verb followed by Numeral:AUX NUM
|
104 |
+
Auxiliary verb followed by Other:AUX X
|
105 |
+
Auxiliary verb followed by Particle:AUX PART
|
106 |
+
Auxiliary verb followed by Pronoun:AUX PRON
|
107 |
+
Auxiliary verb followed by Proper noun:AUX PROPN
|
108 |
+
Auxiliary verb followed by Punctuation:AUX PUNCT
|
109 |
+
Auxiliary verb followed by Subordinating conjunction:AUX SCONJ
|
110 |
+
Auxiliary verb followed by Symbol:AUX SYM
|
111 |
+
Auxiliary verb followed by Verb:AUX VERB
|
112 |
+
Coordinating conjunction followed by Adjective:CCONJ ADJ
|
113 |
+
Coordinating conjunction followed by Adposition:CCONJ ADP
|
114 |
+
Coordinating conjunction followed by Adverb:CCONJ ADV
|
115 |
+
Coordinating conjunction followed by Auxiliary verb:CCONJ AUX
|
116 |
+
Coordinating conjunction followed by Coordinating conjunction:CCONJ CCONJ
|
117 |
+
Coordinating conjunction followed by Determiner:CCONJ DET
|
118 |
+
Coordinating conjunction followed by Interjection:CCONJ INTJ
|
119 |
+
Coordinating conjunction followed by Noun:CCONJ NOUN
|
120 |
+
Coordinating conjunction followed by Numeral:CCONJ NUM
|
121 |
+
Coordinating conjunction followed by Other:CCONJ X
|
122 |
+
Coordinating conjunction followed by Particle:CCONJ PART
|
123 |
+
Coordinating conjunction followed by Pronoun:CCONJ PRON
|
124 |
+
Coordinating conjunction followed by Proper noun:CCONJ PROPN
|
125 |
+
Coordinating conjunction followed by Punctuation:CCONJ PUNCT
|
126 |
+
Coordinating conjunction followed by Subordinating conjunction:CCONJ SCONJ
|
127 |
+
Coordinating conjunction followed by Symbol:CCONJ SYM
|
128 |
+
Coordinating conjunction followed by Verb:CCONJ VERB
|
129 |
+
Determiner followed by Adjective:DET ADJ
|
130 |
+
Determiner followed by Adposition:DET ADP
|
131 |
+
Determiner followed by Adverb:DET ADV
|
132 |
+
Determiner followed by Auxiliary verb:DET AUX
|
133 |
+
Determiner followed by Coordinating conjunction:DET CCONJ
|
134 |
+
Determiner followed by Determiner:DET DET
|
135 |
+
Determiner followed by Interjection:DET INTJ
|
136 |
+
Determiner followed by Noun:DET NOUN
|
137 |
+
Determiner followed by Numeral:DET NUM
|
138 |
+
Determiner followed by Other:DET X
|
139 |
+
Determiner followed by Particle:DET PART
|
140 |
+
Determiner followed by Pronoun:DET PRON
|
141 |
+
Determiner followed by Proper noun:DET PROPN
|
142 |
+
Determiner followed by Punctuation:DET PUNCT
|
143 |
+
Determiner followed by Subordinating conjunction:DET SCONJ
|
144 |
+
Determiner followed by Symbol:DET SYM
|
145 |
+
Determiner followed by Verb:DET VERB
|
146 |
+
Interjection followed by Adjective:INTJ ADJ
|
147 |
+
Interjection followed by Adposition:INTJ ADP
|
148 |
+
Interjection followed by Adverb:INTJ ADV
|
149 |
+
Interjection followed by Auxiliary verb:INTJ AUX
|
150 |
+
Interjection followed by Coordinating conjunction:INTJ CCONJ
|
151 |
+
Interjection followed by Determiner:INTJ DET
|
152 |
+
Interjection followed by Interjection:INTJ INTJ
|
153 |
+
Interjection followed by Noun:INTJ NOUN
|
154 |
+
Interjection followed by Numeral:INTJ NUM
|
155 |
+
Interjection followed by Other:INTJ X
|
156 |
+
Interjection followed by Particle:INTJ PART
|
157 |
+
Interjection followed by Pronoun:INTJ PRON
|
158 |
+
Interjection followed by Proper noun:INTJ PROPN
|
159 |
+
Interjection followed by Punctuation:INTJ PUNCT
|
160 |
+
Interjection followed by Subordinating conjunction:INTJ SCONJ
|
161 |
+
Interjection followed by Symbol:INTJ SYM
|
162 |
+
Interjection followed by Verb:INTJ VERB
|
163 |
+
Noun followed by Adjective:NOUN ADJ
|
164 |
+
Noun followed by Adposition:NOUN ADP
|
165 |
+
Noun followed by Adverb:NOUN ADV
|
166 |
+
Noun followed by Auxiliary verb:NOUN AUX
|
167 |
+
Noun followed by Coordinating conjunction:NOUN CCONJ
|
168 |
+
Noun followed by Determiner:NOUN DET
|
169 |
+
Noun followed by Interjection:NOUN INTJ
|
170 |
+
Noun followed by Noun:NOUN NOUN
|
171 |
+
Noun followed by Numeral:NOUN NUM
|
172 |
+
Noun followed by Other:NOUN X
|
173 |
+
Noun followed by Particle:NOUN PART
|
174 |
+
Noun followed by Pronoun:NOUN PRON
|
175 |
+
Noun followed by Proper noun:NOUN PROPN
|
176 |
+
Noun followed by Punctuation:NOUN PUNCT
|
177 |
+
Noun followed by Subordinating conjunction:NOUN SCONJ
|
178 |
+
Noun followed by Symbol:NOUN SYM
|
179 |
+
Noun followed by Verb:NOUN VERB
|
180 |
+
Numeral followed by Adjective:NUM ADJ
|
181 |
+
Numeral followed by Adposition:NUM ADP
|
182 |
+
Numeral followed by Adverb:NUM ADV
|
183 |
+
Numeral followed by Auxiliary verb:NUM AUX
|
184 |
+
Numeral followed by Coordinating conjunction:NUM CCONJ
|
185 |
+
Numeral followed by Determiner:NUM DET
|
186 |
+
Numeral followed by Interjection:NUM INTJ
|
187 |
+
Numeral followed by Noun:NUM NOUN
|
188 |
+
Numeral followed by Numeral:NUM NUM
|
189 |
+
Numeral followed by Other:NUM X
|
190 |
+
Numeral followed by Particle:NUM PART
|
191 |
+
Numeral followed by Pronoun:NUM PRON
|
192 |
+
Numeral followed by Proper noun:NUM PROPN
|
193 |
+
Numeral followed by Punctuation:NUM PUNCT
|
194 |
+
Numeral followed by Subordinating conjunction:NUM SCONJ
|
195 |
+
Numeral followed by Symbol:NUM SYM
|
196 |
+
Numeral followed by Verb:NUM VERB
|
197 |
+
Other followed by Adjective:X ADJ
|
198 |
+
Other followed by Adposition:X ADP
|
199 |
+
Other followed by Adverb:X ADV
|
200 |
+
Other followed by Auxiliary verb:X AUX
|
201 |
+
Other followed by Coordinating conjunction:X CCONJ
|
202 |
+
Other followed by Determiner:X DET
|
203 |
+
Other followed by Interjection:X INTJ
|
204 |
+
Other followed by Noun:X NOUN
|
205 |
+
Other followed by Numeral:X NUM
|
206 |
+
Other followed by Other:X X
|
207 |
+
Other followed by Particle:X PART
|
208 |
+
Other followed by Pronoun:X PRON
|
209 |
+
Other followed by Proper noun:X PROPN
|
210 |
+
Other followed by Punctuation:X PUNCT
|
211 |
+
Other followed by Subordinating conjunction:X SCONJ
|
212 |
+
Other followed by Symbol:X SYM
|
213 |
+
Other followed by Verb:X VERB
|
214 |
+
Particle followed by Adjective:PART ADJ
|
215 |
+
Particle followed by Adposition:PART ADP
|
216 |
+
Particle followed by Adverb:PART ADV
|
217 |
+
Particle followed by Auxiliary verb:PART AUX
|
218 |
+
Particle followed by Coordinating conjunction:PART CCONJ
|
219 |
+
Particle followed by Determiner:PART DET
|
220 |
+
Particle followed by Interjection:PART INTJ
|
221 |
+
Particle followed by Noun:PART NOUN
|
222 |
+
Particle followed by Numeral:PART NUM
|
223 |
+
Particle followed by Other:PART X
|
224 |
+
Particle followed by Particle:PART PART
|
225 |
+
Particle followed by Pronoun:PART PRON
|
226 |
+
Particle followed by Proper noun:PART PROPN
|
227 |
+
Particle followed by Punctuation:PART PUNCT
|
228 |
+
Particle followed by Subordinating conjunction:PART SCONJ
|
229 |
+
Particle followed by Symbol:PART SYM
|
230 |
+
Particle followed by Verb:PART VERB
|
231 |
+
Pronoun followed by Adjective:PRON ADJ
|
232 |
+
Pronoun followed by Adposition:PRON ADP
|
233 |
+
Pronoun followed by Adverb:PRON ADV
|
234 |
+
Pronoun followed by Auxiliary verb:PRON AUX
|
235 |
+
Pronoun followed by Coordinating conjunction:PRON CCONJ
|
236 |
+
Pronoun followed by Determiner:PRON DET
|
237 |
+
Pronoun followed by Interjection:PRON INTJ
|
238 |
+
Pronoun followed by Noun:PRON NOUN
|
239 |
+
Pronoun followed by Numeral:PRON NUM
|
240 |
+
Pronoun followed by Other:PRON X
|
241 |
+
Pronoun followed by Particle:PRON PART
|
242 |
+
Pronoun followed by Pronoun:PRON PRON
|
243 |
+
Pronoun followed by Proper noun:PRON PROPN
|
244 |
+
Pronoun followed by Punctuation:PRON PUNCT
|
245 |
+
Pronoun followed by Subordinating conjunction:PRON SCONJ
|
246 |
+
Pronoun followed by Symbol:PRON SYM
|
247 |
+
Pronoun followed by Verb:PRON VERB
|
248 |
+
Proper noun followed by Adjective:PROPN ADJ
|
249 |
+
Proper noun followed by Adposition:PROPN ADP
|
250 |
+
Proper noun followed by Adverb:PROPN ADV
|
251 |
+
Proper noun followed by Auxiliary verb:PROPN AUX
|
252 |
+
Proper noun followed by Coordinating conjunction:PROPN CCONJ
|
253 |
+
Proper noun followed by Determiner:PROPN DET
|
254 |
+
Proper noun followed by Interjection:PROPN INTJ
|
255 |
+
Proper noun followed by Noun:PROPN NOUN
|
256 |
+
Proper noun followed by Numeral:PROPN NUM
|
257 |
+
Proper noun followed by Other:PROPN X
|
258 |
+
Proper noun followed by Particle:PROPN PART
|
259 |
+
Proper noun followed by Pronoun:PROPN PRON
|
260 |
+
Proper noun followed by Proper noun:PROPN PROPN
|
261 |
+
Proper noun followed by Punctuation:PROPN PUNCT
|
262 |
+
Proper noun followed by Subordinating conjunction:PROPN SCONJ
|
263 |
+
Proper noun followed by Symbol:PROPN SYM
|
264 |
+
Proper noun followed by Verb:PROPN VERB
|
265 |
+
Punctuation followed by Adjective:PUNCT ADJ
|
266 |
+
Punctuation followed by Adposition:PUNCT ADP
|
267 |
+
Punctuation followed by Adverb:PUNCT ADV
|
268 |
+
Punctuation followed by Auxiliary verb:PUNCT AUX
|
269 |
+
Punctuation followed by Coordinating conjunction:PUNCT CCONJ
|
270 |
+
Punctuation followed by Determiner:PUNCT DET
|
271 |
+
Punctuation followed by Interjection:PUNCT INTJ
|
272 |
+
Punctuation followed by Noun:PUNCT NOUN
|
273 |
+
Punctuation followed by Numeral:PUNCT NUM
|
274 |
+
Punctuation followed by Other:PUNCT X
|
275 |
+
Punctuation followed by Particle:PUNCT PART
|
276 |
+
Punctuation followed by Pronoun:PUNCT PRON
|
277 |
+
Punctuation followed by Proper noun:PUNCT PROPN
|
278 |
+
Punctuation followed by Punctuation:PUNCT PUNCT
|
279 |
+
Punctuation followed by Subordinating conjunction:PUNCT SCONJ
|
280 |
+
Punctuation followed by Symbol:PUNCT SYM
|
281 |
+
Punctuation followed by Verb:PUNCT VERB
|
282 |
+
Subordinating conjunction followed by Adjective:SCONJ ADJ
|
283 |
+
Subordinating conjunction followed by Adposition:SCONJ ADP
|
284 |
+
Subordinating conjunction followed by Adverb:SCONJ ADV
|
285 |
+
Subordinating conjunction followed by Auxiliary verb:SCONJ AUX
|
286 |
+
Subordinating conjunction followed by Coordinating conjunction:SCONJ CCONJ
|
287 |
+
Subordinating conjunction followed by Determiner:SCONJ DET
|
288 |
+
Subordinating conjunction followed by Interjection:SCONJ INTJ
|
289 |
+
Subordinating conjunction followed by Noun:SCONJ NOUN
|
290 |
+
Subordinating conjunction followed by Numeral:SCONJ NUM
|
291 |
+
Subordinating conjunction followed by Other:SCONJ X
|
292 |
+
Subordinating conjunction followed by Particle:SCONJ PART
|
293 |
+
Subordinating conjunction followed by Pronoun:SCONJ PRON
|
294 |
+
Subordinating conjunction followed by Proper noun:SCONJ PROPN
|
295 |
+
Subordinating conjunction followed by Punctuation:SCONJ PUNCT
|
296 |
+
Subordinating conjunction followed by Subordinating conjunction:SCONJ SCONJ
|
297 |
+
Subordinating conjunction followed by Symbol:SCONJ SYM
|
298 |
+
Subordinating conjunction followed by Verb:SCONJ VERB
|
299 |
+
Symbol followed by Adjective:SYM ADJ
|
300 |
+
Symbol followed by Adposition:SYM ADP
|
301 |
+
Symbol followed by Adverb:SYM ADV
|
302 |
+
Symbol followed by Auxiliary verb:SYM AUX
|
303 |
+
Symbol followed by Coordinating conjunction:SYM CCONJ
|
304 |
+
Symbol followed by Determiner:SYM DET
|
305 |
+
Symbol followed by Interjection:SYM INTJ
|
306 |
+
Symbol followed by Noun:SYM NOUN
|
307 |
+
Symbol followed by Numeral:SYM NUM
|
308 |
+
Symbol followed by Other:SYM X
|
309 |
+
Symbol followed by Particle:SYM PART
|
310 |
+
Symbol followed by Pronoun:SYM PRON
|
311 |
+
Symbol followed by Proper noun:SYM PROPN
|
312 |
+
Symbol followed by Punctuation:SYM PUNCT
|
313 |
+
Symbol followed by Subordinating conjunction:SYM SCONJ
|
314 |
+
Symbol followed by Symbol:SYM SYM
|
315 |
+
Symbol followed by Verb:SYM VERB
|
316 |
+
Verb followed by Adjective:VERB ADJ
|
317 |
+
Verb followed by Adposition:VERB ADP
|
318 |
+
Verb followed by Adverb:VERB ADV
|
319 |
+
Verb followed by Auxiliary verb:VERB AUX
|
320 |
+
Verb followed by Coordinating conjunction:VERB CCONJ
|
321 |
+
Verb followed by Determiner:VERB DET
|
322 |
+
Verb followed by Interjection:VERB INTJ
|
323 |
+
Verb followed by Noun:VERB NOUN
|
324 |
+
Verb followed by Numeral:VERB NUM
|
325 |
+
Verb followed by Other:VERB X
|
326 |
+
Verb followed by Particle:VERB PART
|
327 |
+
Verb followed by Pronoun:VERB PRON
|
328 |
+
Verb followed by Proper noun:VERB PROPN
|
329 |
+
Verb followed by Punctuation:VERB PUNCT
|
330 |
+
Verb followed by Subordinating conjunction:VERB SCONJ
|
331 |
+
Verb followed by Symbol:VERB SYM
|
332 |
+
Verb followed by Verb:VERB VERB
|
333 |
+
|
334 |
+
Accusative case:Case=Acc
|
335 |
+
Comparative degree:Degree=Cmp
|
336 |
+
Definite article:Definite=Def
|
337 |
+
Feminine gender:Gender=Fem
|
338 |
+
Finite verb form:VerbForm=Fin
|
339 |
+
First person:Person=1
|
340 |
+
Indefinite article:Definite=Ind
|
341 |
+
Indicative mood:Mood=Ind
|
342 |
+
Infinitive verb form:VerbForm=Inf
|
343 |
+
Masculine gender:Gender=Masc
|
344 |
+
Nominative case:Case=Nom
|
345 |
+
Past tense:Tense=Past
|
346 |
+
Perfect aspect:Aspect=Perf
|
347 |
+
Plural number:Number=Plur
|
348 |
+
Positive degree:Degree=Pos
|
349 |
+
Present tense:Tense=Pres
|
350 |
+
Progressive aspect:Aspect=Prog
|
351 |
+
Second person:Person=2
|
352 |
+
Singular number:Number=Sing
|
353 |
+
Superlative degree:Degree=Sup
|
354 |
+
Third person:Person=3
|
355 |
+
Number of Tokens:num_tokens
|
356 |
+
|
357 |
+
Adjectival clause:acl
|
358 |
+
Adjectival complement:acomp
|
359 |
+
Adjectival modifier:amod
|
360 |
+
Adverbial clause modifier:advcl
|
361 |
+
Adverbial modifier:advmod
|
362 |
+
Agent (in passive voice):agent
|
363 |
+
Appositional modifier:appos
|
364 |
+
Attribute:attr
|
365 |
+
Case marking:case
|
366 |
+
Clausal complement:ccomp
|
367 |
+
Clausal subject:csubj
|
368 |
+
Clausal subject (passive):csubjpass
|
369 |
+
Complement of preposition:pcomp
|
370 |
+
Compound word:compound
|
371 |
+
Conjunct:conj
|
372 |
+
Coordinating conjunction:cc
|
373 |
+
Adposition (preposition or postposition):ADP
|
374 |
+
Dative:dative
|
375 |
+
Direct object:dobj
|
376 |
+
Expletive:expl
|
377 |
+
Marker (introducing adverbial clause):mark
|
378 |
+
Meta modifier:meta
|
379 |
+
Negation modifier:neg
|
380 |
+
Nominal modifier:nmod
|
381 |
+
Nominal subject (passive):nsubjpass
|
382 |
+
Noun phrase as adverbial modifier:npadvmod
|
383 |
+
Numeric modifier:nummod
|
384 |
+
Object of preposition:pobj
|
385 |
+
Object predicate:oprd
|
386 |
+
Open clausal complement:xcomp
|
387 |
+
Parataxis:parataxis
|
388 |
+
Passive auxiliary:auxpass
|
389 |
+
Possession modifier:poss
|
390 |
+
Pre-correlative conjunction:preconj
|
391 |
+
Predeterminer:predet
|
392 |
+
Prepositional modifier:prep
|
393 |
+
Quantifier modifier:quantmod
|
394 |
+
Relative clause modifier:relcl
|
395 |
+
Root of the sentence:ROOT
|
396 |
+
Unspecified dependency:dep
|
397 |
+
|
398 |
+
Article pronoun type:PronType=Art
|
399 |
+
Bracket punctuation type:PunctType=Brck
|
400 |
+
Cardinal number:NumType=Card
|
401 |
+
Comma punctuation type:PunctType=Comm
|
402 |
+
Comparative conjunction type:ConjType=Cmp
|
403 |
+
Dash punctuation type:PunctType=Dash
|
404 |
+
Demonstrative pronoun type:PronType=Dem
|
405 |
+
Final punctuation:PunctSide=Fin
|
406 |
+
Foreign word:Foreign=Yes
|
407 |
+
Gerund verb form:VerbForm=Ger
|
408 |
+
Hyphenated:Hyph=Yes
|
409 |
+
Indefinite pronoun type:PronType=Ind
|
410 |
+
Initial punctuation:PunctSide=Ini
|
411 |
+
Modal verb type:VerbType=Mod
|
412 |
+
Multiplicative number:NumType=Mult
|
413 |
+
Negative polarity:Polarity=Neg
|
414 |
+
Neuter gender:Gender=Neut
|
415 |
+
Ordinal number:NumType=Ord
|
416 |
+
Participle verb form:VerbForm=Part
|
417 |
+
Period punctuation type:PunctType=Peri
|
418 |
+
Possessive:Poss=Yes
|
419 |
+
Quotation punctuation type:PunctType=Quot
|
420 |
+
Reflexive:Reflex=Yes
|
421 |
+
Relative pronoun type:PronType=Rel
|
422 |
+
|
423 |
+
β€οΈ:β€οΈ
|
424 |
+
π:π
|
425 |
+
π:π
|
426 |
+
π:π
|
427 |
+
|
428 |
+
!:!
|
429 |
+
":\\"
|
430 |
+
%:%
|
431 |
+
&:&
|
432 |
+
':'
|
433 |
+
(:(
|
434 |
+
):\)
|
435 |
+
*:*
|
436 |
+
,:,
|
437 |
+
-:-
|
438 |
+
.:.
|
439 |
+
;:;
|
440 |
+
?:?
|
441 |
+
_:_
|
442 |
+
`:`
|
443 |
+
β:β
|
444 |
+
':'
|
445 |
+
':'
|
446 |
+
|
447 |
+
all-cleft:all-cleft
|
448 |
+
coordinate-clause:coordinate-clause
|
449 |
+
if-because-cleft:if-because-cleft
|
450 |
+
it-cleft:it-cleft
|
451 |
+
obj-relcl:obj-relcl
|
452 |
+
passive:passive
|
453 |
+
pseudo-cleft:pseudo-cleft
|
454 |
+
subj-relcl:subj-relcl
|
455 |
+
tag-question:tag-question
|
456 |
+
there-cleft:there-cleft
|
457 |
+
|
458 |
+
punctuation:punctuation
|
459 |
+
|
460 |
+
Articles:DET
|
461 |
+
Auxiliary Verbs:AUX
|
462 |
+
Conjunctions:CCONJ
|
463 |
+
Prepositions:ADP
|
464 |
+
|
465 |
+
Personal Pronouns:category:Personal Pronouns
|
466 |
+
Demonstrative Pronouns:category:Demonstrative Pronouns
|
467 |
+
Interrogative Pronouns:category:Interrogative Pronouns
|
468 |
+
Modal Verbs:category:Modal Verbs
|
469 |
+
Contractions:category:Contractions
|
470 |
+
Adverbs:category:Adverbs
|
471 |
+
Other:category:Other
|
472 |
+
|
473 |
+
i:i
|
474 |
+
me:me
|
475 |
+
my:my
|
476 |
+
myself:myself
|
477 |
+
we:we
|
478 |
+
our:our
|
479 |
+
ours:ours
|
480 |
+
ourselves:ourselves
|
481 |
+
you:you
|
482 |
+
're:'re
|
483 |
+
've:'ve
|
484 |
+
'll:'ll
|
485 |
+
'd:'d
|
486 |
+
's:'s
|
487 |
+
't:'t
|
488 |
+
your:your
|
489 |
+
yours:yours
|
490 |
+
yourself:yourself
|
491 |
+
yourselves:yourselves
|
492 |
+
he:he
|
493 |
+
him:him
|
494 |
+
his:his
|
495 |
+
himself:himself
|
496 |
+
she:she
|
497 |
+
her:her
|
498 |
+
ers:ers
|
499 |
+
herself:herself
|
500 |
+
it:it
|
501 |
+
its:its
|
502 |
+
itself:itself
|
503 |
+
they:they
|
504 |
+
them:them
|
505 |
+
their:their
|
506 |
+
theirs:theirs
|
507 |
+
themselves:themselves
|
508 |
+
what:what
|
509 |
+
which:which
|
510 |
+
who:who
|
511 |
+
this:this
|
512 |
+
that:that
|
513 |
+
these:these
|
514 |
+
those:those
|
515 |
+
am:am
|
516 |
+
is:is
|
517 |
+
are:are
|
518 |
+
was:was
|
519 |
+
were:were
|
520 |
+
be:be
|
521 |
+
been:been
|
522 |
+
being:being
|
523 |
+
have:have
|
524 |
+
has:has
|
525 |
+
had:had
|
526 |
+
having:having
|
527 |
+
do:do
|
528 |
+
does:does
|
529 |
+
did:did
|
530 |
+
doing:doing
|
531 |
+
a:a
|
532 |
+
an:an
|
533 |
+
the:the
|
534 |
+
and:and
|
535 |
+
but:but
|
536 |
+
if:if
|
537 |
+
or:or
|
538 |
+
because:because
|
539 |
+
as:as
|
540 |
+
until:until
|
541 |
+
while:while
|
542 |
+
of:of
|
543 |
+
at:at
|
544 |
+
by:by
|
545 |
+
for:for
|
546 |
+
with:with
|
547 |
+
about:about
|
548 |
+
against:against
|
549 |
+
between:between
|
550 |
+
into:into
|
551 |
+
through:through
|
552 |
+
during:during
|
553 |
+
before:before
|
554 |
+
after:after
|
555 |
+
above:above
|
556 |
+
below:below
|
557 |
+
to:to
|
558 |
+
from:from
|
559 |
+
up:up
|
560 |
+
down:down
|
561 |
+
in:in
|
562 |
+
out:out
|
563 |
+
on:on
|
564 |
+
off:off
|
565 |
+
over:over
|
566 |
+
under:under
|
567 |
+
again:again
|
568 |
+
further:further
|
569 |
+
then:then
|
570 |
+
once:once
|
571 |
+
here:here
|
572 |
+
there:there
|
573 |
+
when:when
|
574 |
+
where:where
|
575 |
+
why:why
|
576 |
+
how:how
|
577 |
+
all:all
|
578 |
+
any:any
|
579 |
+
both:both
|
580 |
+
each:each
|
581 |
+
few:few
|
582 |
+
more:more
|
583 |
+
most:most
|
584 |
+
other:other
|
585 |
+
some:some
|
586 |
+
such:such
|
587 |
+
no:no
|
588 |
+
nor:nor
|
589 |
+
not:not
|
590 |
+
only:only
|
591 |
+
own:own
|
592 |
+
same:same
|
593 |
+
so:so
|
594 |
+
than:than
|
595 |
+
too:too
|
596 |
+
very:very
|
597 |
+
can:can
|
598 |
+
will:will
|
599 |
+
just:just
|
600 |
+
don:don
|
601 |
+
should:should
|
602 |
+
now:now
|
603 |
+
ain:ain
|
604 |
+
aren:aren
|
605 |
+
couldn:couldn
|
606 |
+
didn:didn
|
607 |
+
doesn:doesn
|
608 |
+
hadn:hadn
|
609 |
+
hasn:hasn
|
610 |
+
haven:haven
|
611 |
+
isn:isn
|
612 |
+
ma:ma
|
613 |
+
shouldn:shouldn
|
614 |
+
wasn:wasn
|
615 |
+
weren:weren
|
616 |
+
won:won
|
617 |
+
wouldn:wouldn
|
utils/clustering_utils.py
ADDED
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Required for clustering_author function:
|
2 |
+
import pandas as pd
|
3 |
+
import numpy as np
|
4 |
+
from sklearn.cluster import DBSCAN
|
5 |
+
from sklearn.metrics import silhouette_score
|
6 |
+
# Required for analyze_space_distance_preservation
|
7 |
+
from sklearn.metrics.pairwise import cosine_distances, cosine_similarity
|
8 |
+
from scipy.stats import pearsonr
|
9 |
+
from typing import List, Dict, Any
|
10 |
+
|
11 |
+
def _find_best_dbscan_eps(X: np.ndarray,
|
12 |
+
eps_values: List[float],
|
13 |
+
min_samples: int,
|
14 |
+
metric: str) -> tuple[float | None, np.ndarray | None, float]:
|
15 |
+
"""
|
16 |
+
Iterates through eps_values for DBSCAN and returns the parameters
|
17 |
+
that yield the highest silhouette score.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
X (np.ndarray): The input data (embeddings).
|
21 |
+
eps_values (List[float]): List of eps values to try.
|
22 |
+
min_samples (int): DBSCAN min_samples parameter.
|
23 |
+
metric (str): Distance metric for DBSCAN and silhouette score.
|
24 |
+
|
25 |
+
Returns:
|
26 |
+
tuple[float | None, np.ndarray | None, float]:
|
27 |
+
- best_eps: The eps value that resulted in the best score. None if no suitable clustering.
|
28 |
+
- best_labels: The cluster labels from the best DBSCAN run. None if no suitable clustering.
|
29 |
+
- best_score: The highest silhouette score achieved.
|
30 |
+
"""
|
31 |
+
best_score = -1.001 # Silhouette score is in [-1, 1]
|
32 |
+
best_labels = None
|
33 |
+
best_eps = None
|
34 |
+
|
35 |
+
for eps in eps_values:
|
36 |
+
if eps <= 1e-9: # eps must be positive
|
37 |
+
continue
|
38 |
+
db = DBSCAN(eps=eps, min_samples=min_samples, metric=metric)
|
39 |
+
labels = db.fit_predict(X)
|
40 |
+
|
41 |
+
unique_labels_set = set(labels)
|
42 |
+
n_clusters_ = len(unique_labels_set) - (1 if -1 in unique_labels_set else 0)
|
43 |
+
|
44 |
+
if n_clusters_ > 1:
|
45 |
+
clustered_mask = (labels != -1)
|
46 |
+
if np.sum(clustered_mask) >= 2: # Need at least 2 non-noise points
|
47 |
+
X_clustered = X[clustered_mask]
|
48 |
+
labels_clustered = labels[clustered_mask]
|
49 |
+
try:
|
50 |
+
score = silhouette_score(X_clustered, labels_clustered, metric=metric)
|
51 |
+
if score > best_score:
|
52 |
+
best_score = score
|
53 |
+
best_labels = labels.copy()
|
54 |
+
best_eps = eps
|
55 |
+
except ValueError: # Catch errors from silhouette_score
|
56 |
+
pass
|
57 |
+
elif n_clusters_ == 1 and best_labels is None: # Fallback for single cluster
|
58 |
+
if not all(l == -1 for l in labels):
|
59 |
+
current_score_for_single_cluster = -0.5 # Nominal score
|
60 |
+
if current_score_for_single_cluster > best_score:
|
61 |
+
best_score = current_score_for_single_cluster
|
62 |
+
best_labels = labels.copy()
|
63 |
+
best_eps = eps
|
64 |
+
return best_eps, best_labels, best_score
|
65 |
+
|
66 |
+
def clustering_author(background_corpus_df: pd.DataFrame,
|
67 |
+
embedding_clm: str = 'style_embedding',
|
68 |
+
eps_values: List[float] = None,
|
69 |
+
min_samples: int = 5,
|
70 |
+
metric: str = 'cosine') -> pd.DataFrame:
|
71 |
+
"""
|
72 |
+
Performs DBSCAN clustering on embeddings in a DataFrame.
|
73 |
+
|
74 |
+
Experiments with different `eps` parameters to find a clustering
|
75 |
+
that maximizes the silhouette score, indicating well-separated clusters.
|
76 |
+
|
77 |
+
Args:
|
78 |
+
background_corpus_df (pd.DataFrame): DataFrame with an embedding column.
|
79 |
+
embedding_clm (str): Name of the column containing embeddings.
|
80 |
+
Each embedding should be a list or NumPy array.
|
81 |
+
eps_values (List[float], optional): Specific `eps` values to test.
|
82 |
+
If None, a default range is used.
|
83 |
+
For 'cosine' metric, eps is typically in [0, 2].
|
84 |
+
For 'euclidean', scale depends on embedding magnitudes.
|
85 |
+
min_samples (int): DBSCAN `min_samples` parameter. Minimum number of
|
86 |
+
samples in a neighborhood for a point to be a core point.
|
87 |
+
metric (str): The distance metric to use for DBSCAN and silhouette score
|
88 |
+
(e.g., 'cosine', 'euclidean').
|
89 |
+
|
90 |
+
Returns:
|
91 |
+
pd.DataFrame: The input DataFrame with a new 'cluster_label' column.
|
92 |
+
Labels are from the DBSCAN run with the highest silhouette score.
|
93 |
+
If no suitable clustering is found, labels might be all -1 (noise).
|
94 |
+
"""
|
95 |
+
if embedding_clm not in background_corpus_df.columns:
|
96 |
+
raise ValueError(f"Embedding column '{embedding_clm}' not found in DataFrame.")
|
97 |
+
|
98 |
+
embeddings_list = background_corpus_df[embedding_clm].tolist()
|
99 |
+
|
100 |
+
X_list = []
|
101 |
+
original_indices = [] # To map results back to the original DataFrame's indices
|
102 |
+
|
103 |
+
for i, emb_val in enumerate(embeddings_list):
|
104 |
+
if emb_val is not None:
|
105 |
+
try:
|
106 |
+
e = np.asarray(emb_val, dtype=float)
|
107 |
+
if e.ndim == 1 and e.size > 0: # Standard 1D vector
|
108 |
+
X_list.append(e)
|
109 |
+
original_indices.append(i)
|
110 |
+
elif e.ndim == 0 and e.size == 1: # Scalar value, treat as 1D vector of size 1
|
111 |
+
X_list.append(np.array([e.item()]))
|
112 |
+
original_indices.append(i)
|
113 |
+
# Silently skip empty arrays or improperly shaped arrays
|
114 |
+
except (TypeError, ValueError):
|
115 |
+
# Silently skip if conversion to float array fails
|
116 |
+
pass
|
117 |
+
|
118 |
+
# Initialize labels for all rows in the original DataFrame to -1 (noise/unprocessed)
|
119 |
+
final_labels_for_df = pd.Series(-1, index=background_corpus_df.index, dtype=int)
|
120 |
+
|
121 |
+
if not X_list:
|
122 |
+
print(f"No valid embeddings found in column '{embedding_clm}'. Assigning all 'cluster_label' as -1.")
|
123 |
+
background_corpus_df['cluster_label'] = final_labels_for_df
|
124 |
+
return background_corpus_df
|
125 |
+
|
126 |
+
X = np.array(X_list) # Creates a 2D array from the list of 1D arrays
|
127 |
+
|
128 |
+
if X.shape[0] == 1:
|
129 |
+
print("Only one valid embedding found. Assigning cluster label 0 to it.")
|
130 |
+
if original_indices: # Should always be true if X.shape[0]==1 from X_list
|
131 |
+
final_labels_for_df.iloc[original_indices[0]] = 0
|
132 |
+
background_corpus_df['cluster_label'] = final_labels_for_df
|
133 |
+
return background_corpus_df
|
134 |
+
|
135 |
+
if X.shape[0] < min_samples:
|
136 |
+
print(f"Number of valid embeddings ({X.shape[0]}) is less than min_samples ({min_samples}). "
|
137 |
+
f"All valid embeddings will be marked as noise (-1).")
|
138 |
+
for original_idx in original_indices:
|
139 |
+
final_labels_for_df.iloc[original_idx] = -1
|
140 |
+
background_corpus_df['cluster_label'] = final_labels_for_df
|
141 |
+
return background_corpus_df
|
142 |
+
|
143 |
+
if eps_values is None:
|
144 |
+
if metric == 'cosine':
|
145 |
+
eps_values = [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
|
146 |
+
else:
|
147 |
+
if X.shape[0] > 1:
|
148 |
+
data_spread = np.std(X)
|
149 |
+
eps_values = [round(data_spread * f, 2) for f in [0.25, 0.5, 1.0]]
|
150 |
+
eps_values = [e for e in eps_values if e > 1e-6]
|
151 |
+
if not eps_values or X.shape[0] <=1:
|
152 |
+
eps_values = [0.5, 1.0, 1.5]
|
153 |
+
print(f"Warning: `eps_values` not provided. Using default range for metric '{metric}': {eps_values}. "
|
154 |
+
f"It's recommended to supply `eps_values` tuned to your data.")
|
155 |
+
|
156 |
+
print(f"Performing DBSCAN clustering (min_samples={min_samples}, metric='{metric}') with eps values: "
|
157 |
+
f"{[f'{e:.2f}' for e in eps_values]}")
|
158 |
+
|
159 |
+
best_eps, best_labels, best_score = _find_best_dbscan_eps(X, eps_values, min_samples, metric)
|
160 |
+
|
161 |
+
if best_labels is not None:
|
162 |
+
num_found_clusters = len(set(best_labels) - {-1})
|
163 |
+
print(f"Best clustering found: eps={best_eps:.2f}, Silhouette Score={best_score:.4f} ({num_found_clusters} clusters).")
|
164 |
+
for i, label in enumerate(best_labels):
|
165 |
+
original_df_idx = original_indices[i]
|
166 |
+
final_labels_for_df.iloc[original_df_idx] = label
|
167 |
+
else:
|
168 |
+
print("No suitable DBSCAN clustering found meeting criteria. All processed embeddings marked as noise (-1).")
|
169 |
+
|
170 |
+
background_corpus_df['cluster_label'] = final_labels_for_df
|
171 |
+
return background_corpus_df
|
172 |
+
|
173 |
+
|
174 |
+
def _safe_embeddings_to_matrix(embeddings_column: pd.Series) -> np.ndarray:
|
175 |
+
"""
|
176 |
+
Converts a pandas Series of embeddings (expected to be lists of floats or 1D np.arrays)
|
177 |
+
into a 2D NumPy matrix. Handles None values and attempts to stack consistently.
|
178 |
+
Returns an empty 2D array (e.g., shape (0,0) or (0,D)) if conversion fails or no valid data.
|
179 |
+
"""
|
180 |
+
embeddings_list = embeddings_column.tolist()
|
181 |
+
|
182 |
+
processed_1d_arrays = []
|
183 |
+
for emb in embeddings_list:
|
184 |
+
if emb is not None:
|
185 |
+
if hasattr(emb, '__iter__') and not isinstance(emb, (str, bytes)):
|
186 |
+
try:
|
187 |
+
arr = np.asarray(emb, dtype=float)
|
188 |
+
if arr.ndim == 1 and arr.size > 0:
|
189 |
+
processed_1d_arrays.append(arr)
|
190 |
+
except (TypeError, ValueError):
|
191 |
+
pass # Ignore embeddings that cannot be converted
|
192 |
+
|
193 |
+
if not processed_1d_arrays:
|
194 |
+
return np.empty((0,0))
|
195 |
+
|
196 |
+
# Check for consistent dimensionality before vstacking
|
197 |
+
first_len = processed_1d_arrays[0].shape[0]
|
198 |
+
consistent_embeddings = [arr for arr in processed_1d_arrays if arr.shape[0] == first_len]
|
199 |
+
|
200 |
+
if not consistent_embeddings:
|
201 |
+
return np.empty((0, first_len if processed_1d_arrays else 0)) # (0,D) or (0,0)
|
202 |
+
|
203 |
+
try:
|
204 |
+
return np.vstack(consistent_embeddings)
|
205 |
+
except ValueError:
|
206 |
+
# Should not happen if lengths are consistent
|
207 |
+
return np.empty((0, first_len))
|
208 |
+
|
209 |
+
|
210 |
+
def _compute_cluster_centroids(
|
211 |
+
df_clustered_items: pd.DataFrame, # DataFrame already filtered for non-noise items
|
212 |
+
embedding_clm: str,
|
213 |
+
cluster_label_clm: str
|
214 |
+
) -> Dict[Any, np.ndarray]:
|
215 |
+
"""Computes the centroid for each cluster from a pre-filtered DataFrame."""
|
216 |
+
centroids = {}
|
217 |
+
if df_clustered_items.empty:
|
218 |
+
return centroids
|
219 |
+
|
220 |
+
for cluster_id, group in df_clustered_items.groupby(cluster_label_clm):
|
221 |
+
embeddings_matrix = _safe_embeddings_to_matrix(group[embedding_clm])
|
222 |
+
|
223 |
+
if embeddings_matrix.ndim == 2 and embeddings_matrix.shape[0] > 0 and embeddings_matrix.shape[1] > 0:
|
224 |
+
centroids[cluster_id] = np.mean(embeddings_matrix, axis=0)
|
225 |
+
return centroids
|
226 |
+
|
227 |
+
|
228 |
+
def _project_to_centroid_space(
|
229 |
+
original_embeddings_matrix: np.ndarray, # (n_items, n_original_features)
|
230 |
+
centroids_map: Dict[Any, np.ndarray] # {cluster_id: centroid_vector (n_original_features,)}
|
231 |
+
) -> np.ndarray:
|
232 |
+
"""Projects embeddings into a new space defined by cluster centroids using cosine similarity."""
|
233 |
+
if not centroids_map or original_embeddings_matrix.ndim != 2 or \
|
234 |
+
original_embeddings_matrix.shape[0] == 0 or original_embeddings_matrix.shape[1] == 0:
|
235 |
+
return np.empty((original_embeddings_matrix.shape[0], 0)) # (n_items, 0_new_features)
|
236 |
+
|
237 |
+
sorted_cluster_ids = sorted(centroids_map.keys())
|
238 |
+
|
239 |
+
valid_centroid_vectors = []
|
240 |
+
for cid in sorted_cluster_ids:
|
241 |
+
centroid_vec = centroids_map[cid]
|
242 |
+
if isinstance(centroid_vec, np.ndarray) and centroid_vec.ndim == 1 and \
|
243 |
+
centroid_vec.size == original_embeddings_matrix.shape[1]:
|
244 |
+
valid_centroid_vectors.append(centroid_vec)
|
245 |
+
|
246 |
+
if not valid_centroid_vectors:
|
247 |
+
return np.empty((original_embeddings_matrix.shape[0], 0))
|
248 |
+
|
249 |
+
centroid_matrix = np.vstack(valid_centroid_vectors) # (n_valid_centroids, n_original_features)
|
250 |
+
|
251 |
+
# Result: (n_items, n_valid_centroids)
|
252 |
+
projected_matrix = cosine_similarity(original_embeddings_matrix, centroid_matrix)
|
253 |
+
return projected_matrix
|
254 |
+
|
255 |
+
|
256 |
+
def _get_pairwise_cosine_distances(embeddings_matrix: np.ndarray) -> np.ndarray:
|
257 |
+
"""Calculates unique pairwise cosine distances from an embedding matrix."""
|
258 |
+
if not isinstance(embeddings_matrix, np.ndarray) or embeddings_matrix.ndim != 2 or \
|
259 |
+
embeddings_matrix.shape[0] < 2 or embeddings_matrix.shape[1] == 0:
|
260 |
+
return np.array([]) # Not enough samples or features
|
261 |
+
|
262 |
+
dist_matrix = cosine_distances(embeddings_matrix)
|
263 |
+
iu = np.triu_indices(dist_matrix.shape[0], k=1) # Upper triangle, excluding diagonal
|
264 |
+
return dist_matrix[iu]
|
265 |
+
|
266 |
+
|
267 |
+
def analyze_space_distance_preservation(
|
268 |
+
df: pd.DataFrame,
|
269 |
+
embedding_clm: str = 'style_embedding',
|
270 |
+
cluster_label_clm: str = 'cluster_label'
|
271 |
+
) -> float | None:
|
272 |
+
"""
|
273 |
+
Analyzes how well a new space, defined by cluster centroids, preserves
|
274 |
+
the cosine distance relationships from the original embedding space.
|
275 |
+
|
276 |
+
Args:
|
277 |
+
df (pd.DataFrame): DataFrame with original embeddings and cluster labels.
|
278 |
+
embedding_clm (str): Column name for original embeddings.
|
279 |
+
cluster_label_clm (str): Column name for cluster labels.
|
280 |
+
|
281 |
+
Returns:
|
282 |
+
float | None: Pearson correlation coefficient. Returns None if analysis
|
283 |
+
cannot be performed (e.g., <2 clusters, <2 items), or 0.0
|
284 |
+
if correlation is NaN (e.g. due to zero variance in distances).
|
285 |
+
"""
|
286 |
+
df_valid_items = df[df[cluster_label_clm] != -1].copy()
|
287 |
+
|
288 |
+
if df_valid_items.shape[0] < 2:
|
289 |
+
return None # Need at least 2 items for pairwise distances
|
290 |
+
|
291 |
+
original_embeddings_matrix = _safe_embeddings_to_matrix(df_valid_items[embedding_clm])
|
292 |
+
if original_embeddings_matrix.ndim != 2 or original_embeddings_matrix.shape[0] < 2 or \
|
293 |
+
original_embeddings_matrix.shape[1] == 0:
|
294 |
+
return None # Valid matrix from original embeddings could not be formed
|
295 |
+
|
296 |
+
centroids = _compute_cluster_centroids(df_valid_items, embedding_clm, cluster_label_clm)
|
297 |
+
if len(centroids) < 2: # Need at least 2 centroids for a multi-dimensional new space
|
298 |
+
return None
|
299 |
+
|
300 |
+
projected_embeddings_matrix = _project_to_centroid_space(original_embeddings_matrix, centroids)
|
301 |
+
if projected_embeddings_matrix.ndim != 2 or projected_embeddings_matrix.shape[0] < 2 or \
|
302 |
+
projected_embeddings_matrix.shape[1] < 2: # New space needs at least 2 dimensions (centroids)
|
303 |
+
return None
|
304 |
+
|
305 |
+
distances_original_space = _get_pairwise_cosine_distances(original_embeddings_matrix)
|
306 |
+
distances_new_space = _get_pairwise_cosine_distances(projected_embeddings_matrix)
|
307 |
+
|
308 |
+
if distances_original_space.size == 0 or distances_new_space.size == 0 or \
|
309 |
+
distances_original_space.size != distances_new_space.size:
|
310 |
+
return None # Mismatch or empty distances
|
311 |
+
|
312 |
+
# Handle cases where variance is zero in one of the distance arrays (leads to NaN correlation)
|
313 |
+
if np.all(distances_new_space == distances_new_space[0]) or \
|
314 |
+
np.all(distances_original_space == distances_original_space[0]):
|
315 |
+
return 0.0 # Correlation is undefined or 0 if one variable is constant
|
316 |
+
|
317 |
+
try:
|
318 |
+
correlation, _ = pearsonr(distances_original_space, distances_new_space)
|
319 |
+
except ValueError: # Should be caught by variance checks, but as a safeguard
|
320 |
+
return None
|
321 |
+
|
322 |
+
if np.isnan(correlation):
|
323 |
+
return 0.0 # Default for NaN correlation
|
324 |
+
|
325 |
+
return correlation
|
utils/file_download.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
import urllib.request
|
4 |
+
import zipfile
|
5 |
+
import tempfile
|
6 |
+
from tqdm import tqdm
|
7 |
+
from urllib.parse import urlparse
|
8 |
+
|
9 |
+
class TqdmUpTo(tqdm):
|
10 |
+
def update_to(self, b=1, bsize=1, tsize=None):
|
11 |
+
if tsize is not None:
|
12 |
+
self.total = tsize
|
13 |
+
self.update(b * bsize - self.n)
|
14 |
+
|
15 |
+
def download_file_override(url: str, dest_path: str):
|
16 |
+
"""
|
17 |
+
Download a file from a URL and always overwrite the target.
|
18 |
+
If it's a zip, extract its contents directly into dest_path (no extra folder level).
|
19 |
+
If it's not a zip, save it directly to dest_path.
|
20 |
+
"""
|
21 |
+
|
22 |
+
# Ensure parent dir for files
|
23 |
+
dest_dir = dest_path if dest_path.endswith(('/', '\\')) else os.path.dirname(dest_path)
|
24 |
+
if dest_dir:
|
25 |
+
os.makedirs(dest_dir, exist_ok=True)
|
26 |
+
|
27 |
+
# Temp file for download
|
28 |
+
tmp_fd, tmp_path = tempfile.mkstemp()
|
29 |
+
os.close(tmp_fd)
|
30 |
+
|
31 |
+
filename = os.path.basename(urlparse(url).path) or "downloaded.file"
|
32 |
+
print(f"Downloading {filename}...")
|
33 |
+
|
34 |
+
try:
|
35 |
+
with TqdmUpTo(unit='B', unit_scale=True, unit_divisor=1024, miniters=1, desc=filename) as t:
|
36 |
+
urllib.request.urlretrieve(url, filename=tmp_path, reporthook=t.update_to)
|
37 |
+
|
38 |
+
if zipfile.is_zipfile(tmp_path):
|
39 |
+
# Remove dest_path if exists
|
40 |
+
if os.path.exists(dest_path):
|
41 |
+
shutil.rmtree(dest_path)
|
42 |
+
os.makedirs(dest_path, exist_ok=True)
|
43 |
+
|
44 |
+
# Extract into temp folder first
|
45 |
+
with tempfile.TemporaryDirectory() as tmp_extract_dir:
|
46 |
+
with zipfile.ZipFile(tmp_path, 'r') as z:
|
47 |
+
z.extractall(tmp_extract_dir)
|
48 |
+
|
49 |
+
# Move *contents* of extracted folder into dest_path
|
50 |
+
for item in os.listdir(tmp_extract_dir):
|
51 |
+
src = os.path.join(tmp_extract_dir, item)
|
52 |
+
dst = os.path.join(dest_path, item)
|
53 |
+
if os.path.isdir(src):
|
54 |
+
shutil.move(src, dst)
|
55 |
+
else:
|
56 |
+
shutil.move(src, dst)
|
57 |
+
|
58 |
+
print(f"Extracted zip contents into '{dest_path}'.")
|
59 |
+
else:
|
60 |
+
# Ensure parent dir exists
|
61 |
+
os.makedirs(os.path.dirname(dest_path) or ".", exist_ok=True)
|
62 |
+
if os.path.exists(dest_path):
|
63 |
+
os.remove(dest_path)
|
64 |
+
shutil.move(tmp_path, dest_path)
|
65 |
+
tmp_path = None
|
66 |
+
print(f"Saved file to '{dest_path}'.")
|
67 |
+
|
68 |
+
finally:
|
69 |
+
if tmp_path and os.path.exists(tmp_path):
|
70 |
+
os.remove(tmp_path)
|
utils/generate_augmented_mapping.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import csv
|
2 |
+
import re
|
3 |
+
|
4 |
+
def load_original_map_and_extract_morph(path="human_readable.txt"):
|
5 |
+
human_to_code = {}
|
6 |
+
morph_entries = []
|
7 |
+
|
8 |
+
with open(path, "r", encoding="utf-8") as f:
|
9 |
+
for line in f:
|
10 |
+
line = line.strip()
|
11 |
+
if not line or ":" not in line or line.startswith("#"):
|
12 |
+
continue
|
13 |
+
key, val = [p.strip() for p in line.split(":", 1)]
|
14 |
+
|
15 |
+
# If key looks like Aspect=Perf, it's a morphological tag
|
16 |
+
if "=" in key:
|
17 |
+
morph_entries.append((val, key)) # human:code
|
18 |
+
else:
|
19 |
+
human_to_code[val] = key # human:code for POS/etc.
|
20 |
+
|
21 |
+
return human_to_code, morph_entries
|
22 |
+
|
23 |
+
def extract_bigrams_from_csv(csv_path="../datasets/gram2vec_feats.csv"):
|
24 |
+
bigrams = set()
|
25 |
+
with open(csv_path, "r", encoding="utf-8") as f:
|
26 |
+
reader = csv.DictReader(f)
|
27 |
+
for row in reader:
|
28 |
+
feat = row["gram2vec_feats"]
|
29 |
+
if feat.startswith("Part-of-Speech Bigram:"):
|
30 |
+
human_bigram = feat.split(":", 1)[1].strip()
|
31 |
+
if "followed by" in human_bigram:
|
32 |
+
bigrams.add(human_bigram)
|
33 |
+
return bigrams
|
34 |
+
|
35 |
+
def generate_bigram_code_map(human_to_code, bigrams):
|
36 |
+
pattern = re.compile(r"(.+?) followed by (.+)")
|
37 |
+
code_map = {}
|
38 |
+
|
39 |
+
for bigram in bigrams:
|
40 |
+
match = pattern.match(bigram)
|
41 |
+
if match:
|
42 |
+
x = match.group(1).strip()
|
43 |
+
y = match.group(2).strip()
|
44 |
+
code_x = human_to_code.get(x)
|
45 |
+
code_y = human_to_code.get(y)
|
46 |
+
if code_x and code_y:
|
47 |
+
code_map[bigram] = f"{code_x} {code_y}"
|
48 |
+
else:
|
49 |
+
print(f"Could not map: {bigram} β {code_x}, {code_y}")
|
50 |
+
else:
|
51 |
+
print(f"Not matched: {bigram}")
|
52 |
+
return code_map
|
53 |
+
|
54 |
+
def write_augmented_map(pos_bigram_map, morph_entries, original_path="human_readable.txt", output_path="augmented_human_readable.txt"):
|
55 |
+
with open(output_path, "w", encoding="utf-8") as f:
|
56 |
+
# Flip original lines: write human-readable:code instead of code:human
|
57 |
+
with open(original_path, "r", encoding="utf-8") as orig:
|
58 |
+
for line in orig:
|
59 |
+
line = line.strip()
|
60 |
+
if not line or line.startswith("#"):
|
61 |
+
f.write(line + "\n")
|
62 |
+
continue
|
63 |
+
if ":" not in line:
|
64 |
+
continue
|
65 |
+
key, val = [p.strip() for p in line.split(":", 1)]
|
66 |
+
flipped_line = f"{val}:{key}\n"
|
67 |
+
f.write(flipped_line)
|
68 |
+
|
69 |
+
|
70 |
+
# Add new section for POS bigrams
|
71 |
+
f.write("\n")
|
72 |
+
for human, code in sorted(pos_bigram_map.items()):
|
73 |
+
f.write(f"{human}:{code}\n")
|
74 |
+
|
75 |
+
# Re-add morph tag mappings
|
76 |
+
f.write("\n")
|
77 |
+
for human, code in sorted(morph_entries):
|
78 |
+
f.write(f"{human}:{code}\n")
|
79 |
+
|
80 |
+
print(f"Augmented map written to {output_path}")
|
81 |
+
|
82 |
+
# Run all
|
83 |
+
human_to_code, morph_entries = load_original_map_and_extract_morph()
|
84 |
+
bigrams = extract_bigrams_from_csv()
|
85 |
+
pos_bigram_map = generate_bigram_code_map(human_to_code, bigrams)
|
86 |
+
write_augmented_map(pos_bigram_map, morph_entries)
|
utils/gram2vec_feat_utils.py
ADDED
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import html
|
3 |
+
|
4 |
+
from collections import namedtuple
|
5 |
+
from gram2vec.feature_locator import find_feature_spans
|
6 |
+
from functools import lru_cache
|
7 |
+
|
8 |
+
from utils.llm_feat_utils import generate_feature_spans_cached
|
9 |
+
import pandas as pd
|
10 |
+
Span = namedtuple('Span', ['start_char', 'end_char'])
|
11 |
+
|
12 |
+
from gram2vec import vectorizer
|
13 |
+
|
14 |
+
# ββ the FEATURE_HANDLERS & loader ββββββββββββ
|
15 |
+
FEATURE_HANDLERS = {
|
16 |
+
"Part-of-Speech Unigram": "pos_unigrams",
|
17 |
+
"Part-of-Speech Bigram": "pos_bigrams",
|
18 |
+
"Function Word": "func_words",
|
19 |
+
"Punctuation": "punctuation",
|
20 |
+
"Letter": "letters",
|
21 |
+
"Dependency Label": "dep_labels",
|
22 |
+
"Morphology Tag": "morph_tags",
|
23 |
+
"Sentence Type": "sentences",
|
24 |
+
"Emoji": "emojis",
|
25 |
+
"Number of Tokens": "num_tokens"
|
26 |
+
}
|
27 |
+
|
28 |
+
@lru_cache(maxsize=1)
|
29 |
+
def load_code_map(txt_path: str = "utils/augmented_human_readable.txt") -> dict:
|
30 |
+
code_map = {}
|
31 |
+
with open(txt_path, "r", encoding="utf-8") as f:
|
32 |
+
for line in f:
|
33 |
+
line = line.strip()
|
34 |
+
if not line:
|
35 |
+
continue
|
36 |
+
human, code = [p.strip() for p in line.split(":", 1)]
|
37 |
+
code_map[human] = code
|
38 |
+
return code_map
|
39 |
+
|
40 |
+
def get_shorthand(feature_str: str) -> str:
|
41 |
+
"""
|
42 |
+
Expects 'Category:Human-Readable', returns e.g. 'pos_unigrams:ADJ' or None.
|
43 |
+
"""
|
44 |
+
try:
|
45 |
+
category, human = [p.strip() for p in feature_str.split(":", 1)]
|
46 |
+
# print(f"Category: {category}, Human: {human}")
|
47 |
+
except ValueError:
|
48 |
+
# print("Invalid format for feature string:", feature_str)
|
49 |
+
return None
|
50 |
+
if category not in FEATURE_HANDLERS:
|
51 |
+
return None
|
52 |
+
code = load_code_map().get(human)
|
53 |
+
if code is None:
|
54 |
+
# print(f"Warning: No code found for human-readable feature '{human}'")
|
55 |
+
return None # fallback to the human-readable name
|
56 |
+
return f"{FEATURE_HANDLERS[category]}:{code}"
|
57 |
+
|
58 |
+
def get_fullform(shorthand: str) -> str:
|
59 |
+
"""
|
60 |
+
Expects 'prefix:code' (e.g., 'pos_unigrams:ADJ'), returns 'Category:Human-Readable'
|
61 |
+
(e.g., 'Part-of-Speech Unigram:Adjective'), or None if invalid.
|
62 |
+
"""
|
63 |
+
try:
|
64 |
+
prefix, code = shorthand.split(":", 1)
|
65 |
+
except ValueError:
|
66 |
+
return None
|
67 |
+
|
68 |
+
# Reverse FEATURE_HANDLERS
|
69 |
+
reverse_handlers = {v: k for k, v in FEATURE_HANDLERS.items()}
|
70 |
+
category = reverse_handlers.get(prefix)
|
71 |
+
if category is None:
|
72 |
+
return None
|
73 |
+
|
74 |
+
# Reverse code map
|
75 |
+
code_map = load_code_map()
|
76 |
+
reverse_code_map = {v: k for k, v in code_map.items()}
|
77 |
+
human = reverse_code_map.get(code)
|
78 |
+
if human is None:
|
79 |
+
return None
|
80 |
+
|
81 |
+
return f"{category}:{human}"
|
82 |
+
|
83 |
+
def highlight_both_spans(text, llm_spans, gram_spans):
|
84 |
+
"""
|
85 |
+
Walk the original `text` once, injecting <mark> tags at the correct offsets,
|
86 |
+
so that nested or overlapping highlights never stomp on each other.
|
87 |
+
"""
|
88 |
+
|
89 |
+
# Inline CSS : mark-llm is in yellow, mark-gram in blue
|
90 |
+
style = """
|
91 |
+
<style>
|
92 |
+
.mark-llm { background-color: #fff176; }
|
93 |
+
.mark-gram { background-color: #90caf9; }
|
94 |
+
</style>
|
95 |
+
"""
|
96 |
+
|
97 |
+
# Turn each span into two βeventsβ: open and close
|
98 |
+
events = []
|
99 |
+
for s in llm_spans:
|
100 |
+
events.append((s.start_char, 'open', 'llm'))
|
101 |
+
events.append((s.end_char, 'close', 'llm'))
|
102 |
+
for s in gram_spans:
|
103 |
+
events.append((s.start_char, 'open', 'gram'))
|
104 |
+
events.append((s.end_char, 'close', 'gram'))
|
105 |
+
|
106 |
+
# Sort by position;
|
107 |
+
events.sort(key=lambda e: (e[0], 0 if e[1]=='open' else 1))
|
108 |
+
|
109 |
+
out = []
|
110 |
+
last_idx = 0
|
111 |
+
for idx, typ, cls in events:
|
112 |
+
# escape the slice between last index and this event
|
113 |
+
out.append(html.escape(text[last_idx:idx]))
|
114 |
+
if typ == 'open':
|
115 |
+
out.append(f'<mark class="mark-{cls}">')
|
116 |
+
else:
|
117 |
+
out.append('</mark>')
|
118 |
+
last_idx = idx
|
119 |
+
|
120 |
+
out.append(html.escape(text[last_idx:]))
|
121 |
+
highlighted = "".join(out)
|
122 |
+
|
123 |
+
highlighted = highlighted.replace('\n', '<br>')
|
124 |
+
|
125 |
+
return style + highlighted
|
126 |
+
|
127 |
+
|
128 |
+
def show_combined_spans_all(selected_feature_llm, selected_feature_g2v,
|
129 |
+
llm_style_feats_analysis, background_authors_embeddings_df, task_authors_embeddings_df, visible_authors, predicted_author=None, ground_truth_author=None, max_num_authors=7):
|
130 |
+
"""
|
131 |
+
For mystery + 3 candidates:
|
132 |
+
1. get llm spans via your existing cache+API
|
133 |
+
2. get gram2vec spans via find_feature_spans
|
134 |
+
3. merge and highlight both
|
135 |
+
"""
|
136 |
+
print(f"\n\n\n\n\nShowing combined spans for LLM feature '{selected_feature_llm}' and Gram2Vec feature '{selected_feature_g2v}'")
|
137 |
+
print(f"predicted_author: {predicted_author}, ground_truth_author: {ground_truth_author}")
|
138 |
+
print(f" keys = {background_authors_embeddings_df.keys()}")
|
139 |
+
|
140 |
+
# background_and_task_authors = pd.concat([task_authors_embeddings_df, background_authors_embeddings_df])
|
141 |
+
# background_and_task_authors = background_and_task_authors[background_and_task_authors.authorID.isin(visible_authors)]
|
142 |
+
|
143 |
+
#get the visible background authors
|
144 |
+
background_authors_embeddings_df = background_authors_embeddings_df[background_authors_embeddings_df.authorID.isin(visible_authors)]
|
145 |
+
background_and_task_authors = pd.concat([task_authors_embeddings_df, background_authors_embeddings_df])
|
146 |
+
|
147 |
+
authors_texts = ['\n\n =========== \n\n'.join(x) if type(x) == list else x for x in background_and_task_authors[:max_num_authors]['fullText'].tolist()]
|
148 |
+
authors_names = background_and_task_authors[:max_num_authors]['authorID'].tolist()
|
149 |
+
print(f"Number of authors to show: {len(authors_texts)}")
|
150 |
+
print(f"Authors names: {authors_names}")
|
151 |
+
texts = list(zip(authors_names, authors_texts))
|
152 |
+
|
153 |
+
if selected_feature_llm and selected_feature_llm != "None":
|
154 |
+
# print(llm_style_feats_analysis)
|
155 |
+
author_list = list(llm_style_feats_analysis['spans'].values())
|
156 |
+
llm_spans_list = []
|
157 |
+
for i, (_, txt) in enumerate(texts):
|
158 |
+
author_spans_list = []
|
159 |
+
for txt_span in author_list[i][selected_feature_llm]:
|
160 |
+
author_spans_list.append(Span(txt.find(txt_span), txt.find(txt_span) + len(txt_span)))
|
161 |
+
llm_spans_list.append(author_spans_list)
|
162 |
+
else:
|
163 |
+
print("Skipping LLM span extraction: feature is None")
|
164 |
+
llm_spans_list = [[] for _ in texts]
|
165 |
+
|
166 |
+
if selected_feature_g2v and selected_feature_g2v != "None":
|
167 |
+
# get gram2vec spans
|
168 |
+
gram_spans_list = []
|
169 |
+
print(f"Selected Gram2Vec feature: {selected_feature_g2v}")
|
170 |
+
short = get_shorthand(selected_feature_g2v)
|
171 |
+
print(f"short hand: {short}")
|
172 |
+
for role, txt in texts:
|
173 |
+
try:
|
174 |
+
print(f"Finding spans for {short} {role}")
|
175 |
+
spans = find_feature_spans(txt, short)
|
176 |
+
# spans = [Span(fs.start_char, fs.end_char) for fs in raw_spans]
|
177 |
+
except:
|
178 |
+
print(f"Error finding spans for {short} {role}")
|
179 |
+
spans = []
|
180 |
+
gram_spans_list.append(spans)
|
181 |
+
else:
|
182 |
+
print("Skipping Gram2Vec span extraction: feature is None")
|
183 |
+
gram_spans_list = [[] for _ in texts]
|
184 |
+
|
185 |
+
# build HTML blocks
|
186 |
+
print(f" ----> Number of authors: {len(texts)}")
|
187 |
+
|
188 |
+
html_task_authors = create_html(
|
189 |
+
texts[:4], #first 4 are task
|
190 |
+
llm_spans_list,
|
191 |
+
gram_spans_list,
|
192 |
+
selected_feature_llm,
|
193 |
+
selected_feature_g2v,
|
194 |
+
short,
|
195 |
+
background = False,
|
196 |
+
predicted_author=predicted_author,
|
197 |
+
ground_truth_author=ground_truth_author
|
198 |
+
)
|
199 |
+
combined_html = "<div>" + "\n<hr>\n".join(html_task_authors) + "</div>"
|
200 |
+
|
201 |
+
html_background_authors = create_html(
|
202 |
+
texts[4:], #last three are background
|
203 |
+
llm_spans_list,
|
204 |
+
gram_spans_list,
|
205 |
+
selected_feature_llm,
|
206 |
+
selected_feature_g2v,
|
207 |
+
short,
|
208 |
+
background = True,
|
209 |
+
predicted_author=predicted_author,
|
210 |
+
ground_truth_author=ground_truth_author
|
211 |
+
)
|
212 |
+
background_html = "<div>" + "\n<hr>\n".join(html_background_authors) + "</div>"
|
213 |
+
return combined_html, background_html
|
214 |
+
|
215 |
+
def get_label(label: str, predicted_author=None, ground_truth_author=None, bg_id: int=0) -> str:
|
216 |
+
"""
|
217 |
+
Returns a human-readable label for the author.
|
218 |
+
"""
|
219 |
+
print(f"get_label called with label: {label}, predicted_author: {predicted_author}, ground_truth_author: {ground_truth_author}, bg_id: {bg_id}")
|
220 |
+
if label.startswith("Mystery") or label.startswith("Q_author"):
|
221 |
+
return "Mystery Author"
|
222 |
+
elif label.startswith("a0_author") or label.startswith("a1_author") or label.startswith("a2_author") or label.startswith("Candidate"):
|
223 |
+
if label.startswith("Candidate"):
|
224 |
+
id = int(label.split(" ")[2]) # Get the number after 'Candidate Author'
|
225 |
+
else:
|
226 |
+
id = label.split("_")[0][-1] # Get the last character of the first part (a0, a1, a2)
|
227 |
+
if predicted_author is not None and ground_truth_author is not None:
|
228 |
+
if int(id) == predicted_author and int(id) == ground_truth_author:
|
229 |
+
return f"Candidate {int(id)+1} (Predicted & Ground Truth)"
|
230 |
+
elif int(id) == predicted_author:
|
231 |
+
return f"Candidate {int(id)+1} (Predicted)"
|
232 |
+
elif int(id) == ground_truth_author:
|
233 |
+
return f"Candidate {int(id)+1} (Ground Truth)"
|
234 |
+
else:
|
235 |
+
return f"Candidate {int(id)+1}"
|
236 |
+
else:
|
237 |
+
return f"Candidate {int(id)+1}"
|
238 |
+
else:
|
239 |
+
return f"Background Author {bg_id+1}"
|
240 |
+
|
241 |
+
def create_html(texts, llm_spans_list, gram_spans_list, selected_feature_llm, selected_feature_g2v, short=None, background = False, predicted_author=None, ground_truth_author=None):
|
242 |
+
html = []
|
243 |
+
for i, (label, txt) in enumerate(texts):
|
244 |
+
label = get_label(label, predicted_author, ground_truth_author, i) if background else get_label(label, predicted_author, ground_truth_author)
|
245 |
+
combined = highlight_both_spans(txt, llm_spans_list[i], gram_spans_list[i])
|
246 |
+
notice = ""
|
247 |
+
if selected_feature_llm == "None":
|
248 |
+
notice += f"""
|
249 |
+
<div style="padding:8px; background:#eee; border:1px solid #aaa;">
|
250 |
+
<em>No LLM feature selected.</em>
|
251 |
+
</div>
|
252 |
+
"""
|
253 |
+
elif not llm_spans_list[i]:
|
254 |
+
notice += f"""
|
255 |
+
<div style="padding:8px; background:#fee; border:1px solid #f00;">
|
256 |
+
<em>No spans found for LLM feature "{selected_feature_llm}".</em>
|
257 |
+
</div>
|
258 |
+
"""
|
259 |
+
if selected_feature_g2v == "None":
|
260 |
+
notice += f"""
|
261 |
+
<div style="padding:8px; background:#eee; border:1px solid #aaa;">
|
262 |
+
<em>No Gram2Vec feature selected.</em>
|
263 |
+
</div>
|
264 |
+
"""
|
265 |
+
elif not short:
|
266 |
+
notice += f"""
|
267 |
+
<div style="padding:8px; background:#fee; border:1px solid #f00;">
|
268 |
+
<em>Invalid or unmapped feature: "{selected_feature_g2v}".</em>
|
269 |
+
</div>
|
270 |
+
"""
|
271 |
+
elif not gram_spans_list[i]:
|
272 |
+
notice += f"""
|
273 |
+
<div style="padding:8px; background:#fee; border:1px solid #f00;">
|
274 |
+
<em>No spans found for Gram2Vec feature "{selected_feature_g2v}".</em>
|
275 |
+
</div>
|
276 |
+
"""
|
277 |
+
html.append(f"""
|
278 |
+
<h3>{label}</h3>
|
279 |
+
{notice}
|
280 |
+
<div style="border:1px solid #ccc; padding:8px; margin-bottom:1em;">
|
281 |
+
{combined}
|
282 |
+
</div>
|
283 |
+
""")
|
284 |
+
return html
|
utils/human_readable.txt
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ADJ: Adjective
|
2 |
+
ADP: Adposition
|
3 |
+
ADV: Adverb
|
4 |
+
AUX: Auxiliary verb
|
5 |
+
CCONJ: Coordinating conjunction
|
6 |
+
DET: Determiner
|
7 |
+
INTJ: Interjection
|
8 |
+
NOUN: Noun
|
9 |
+
NUM: Numeral
|
10 |
+
PART: Particle
|
11 |
+
PRON: Pronoun
|
12 |
+
PROPN: Proper noun
|
13 |
+
PUNCT: Punctuation
|
14 |
+
SCONJ: Subordinating conjunction
|
15 |
+
SYM: Symbol
|
16 |
+
VERB: Verb
|
17 |
+
X: Other
|
18 |
+
SPACE: Space
|
19 |
+
|
20 |
+
Aspect=Perf: Perfect aspect
|
21 |
+
Aspect=Prog: Progressive aspect
|
22 |
+
Case=Acc: Accusative case
|
23 |
+
Case=Nom: Nominative case
|
24 |
+
Definite=Def: Definite article
|
25 |
+
Definite=Ind: Indefinite article
|
26 |
+
Degree=Cmp: Comparative degree
|
27 |
+
Degree=Pos: Positive degree
|
28 |
+
Degree=Sup: Superlative degree
|
29 |
+
Gender=Fem: Feminine gender
|
30 |
+
Gender=Masc: Masculine gender
|
31 |
+
Mood=Ind: Indicative mood
|
32 |
+
Number=Plur: Plural number
|
33 |
+
Number=Sing: Singular number
|
34 |
+
Person=1: First person
|
35 |
+
Person=2: Second person
|
36 |
+
Person=3: Third person
|
37 |
+
Tense=Past: Past tense
|
38 |
+
Tense=Pres: Present tense
|
39 |
+
VerbForm=Fin: Finite verb form
|
40 |
+
VerbForm=Inf: Infinitive verb form
|
utils/interp_space_utils.py
ADDED
@@ -0,0 +1,638 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
|
3 |
+
import pandas as pd
|
4 |
+
import numpy as np
|
5 |
+
import math
|
6 |
+
from collections import Counter, defaultdict
|
7 |
+
from typing import List, Any
|
8 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
9 |
+
import os
|
10 |
+
import pickle
|
11 |
+
import hashlib
|
12 |
+
import json
|
13 |
+
from gram2vec import vectorizer
|
14 |
+
from openai import OpenAI
|
15 |
+
from openai.lib._pydantic import to_strict_json_schema
|
16 |
+
from pydantic import BaseModel
|
17 |
+
from pydantic import ValidationError
|
18 |
+
import time
|
19 |
+
from utils.llm_feat_utils import generate_feature_spans_cached
|
20 |
+
from collections import Counter
|
21 |
+
import numpy as np
|
22 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
23 |
+
|
24 |
+
CACHE_DIR = "datasets/embeddings_cache"
|
25 |
+
os.makedirs(CACHE_DIR, exist_ok=True)
|
26 |
+
# Bump this whenever there is a change etc...
|
27 |
+
CACHE_VERSION = 1
|
28 |
+
|
29 |
+
class style_analysis_schema(BaseModel):
|
30 |
+
features: list[str]
|
31 |
+
spans: dict[str, dict[str, list[str]]]
|
32 |
+
|
33 |
+
class FeatureIdentificationSchema(BaseModel):
|
34 |
+
features: list[str]
|
35 |
+
|
36 |
+
class SpanExtractionSchema(BaseModel):
|
37 |
+
spans: dict[str, dict[str, list[str]]] # {author_name: {feature: [spans]}}
|
38 |
+
|
39 |
+
|
40 |
+
|
41 |
+
def compute_g2v_features(clustered_authors_df: pd.DataFrame, task_authors_df: pd.DataFrame=None, text_clm='fullText') -> pd.DataFrame:
|
42 |
+
"""
|
43 |
+
Computes gram2vec feature vectors for each author and adds them to the DataFrame.
|
44 |
+
This effectively creates a mapping from each author to their vector.
|
45 |
+
"""
|
46 |
+
if task_authors_df is not None:
|
47 |
+
print (f"concatenating task authors and background corpus authors")
|
48 |
+
print(f"Number of task authors: {len(task_authors_df)}")
|
49 |
+
print(f"task authors author_ids: {task_authors_df.authorID.tolist()}")
|
50 |
+
print(f"task authors -->")
|
51 |
+
print(task_authors_df)
|
52 |
+
print(f"Number of background corpus authors: {len(clustered_authors_df)}")
|
53 |
+
clustered_authors_df = pd.concat([task_authors_df, clustered_authors_df])
|
54 |
+
print(f"Number of authors after concatenation: {len(clustered_authors_df)}")
|
55 |
+
|
56 |
+
# Gather the input texts (preserves list-of-strings if any)
|
57 |
+
#texts = background_corpus_df[text_clm].fillna("").tolist()
|
58 |
+
author_texts = ['\n\n'.join(x) for x in clustered_authors_df.fullText.tolist()]
|
59 |
+
|
60 |
+
print(f"Number of author_texts: {len(author_texts)}")
|
61 |
+
|
62 |
+
# Create a reproducible JSON serialization of the texts
|
63 |
+
serialized = json.dumps({
|
64 |
+
"col": text_clm,
|
65 |
+
"texts": author_texts
|
66 |
+
}, sort_keys=True, ensure_ascii=False)
|
67 |
+
|
68 |
+
# Compute MD5 hash
|
69 |
+
digest = hashlib.md5(serialized.encode("utf-8")).hexdigest()
|
70 |
+
cache_path = os.path.join(CACHE_DIR, f"{digest}.pkl")
|
71 |
+
|
72 |
+
# If cache hit, load and return
|
73 |
+
if os.path.exists(cache_path):
|
74 |
+
print(f"Cache hit...")
|
75 |
+
with open(cache_path, "rb") as f:
|
76 |
+
clustered_authors_df = pickle.load(f)
|
77 |
+
|
78 |
+
else: # Else compute and cache
|
79 |
+
g2v_feats_df = vectorizer.from_documents(author_texts, batch_size=16)
|
80 |
+
|
81 |
+
print(f"Number of g2v features: {len(g2v_feats_df)}")
|
82 |
+
print(f"Number of clustered_authors_df.authorID.tolist(): {len(clustered_authors_df.authorID.tolist())}")
|
83 |
+
print(f"Number of g2v_feats_df.to_numpy().tolist(): {len(g2v_feats_df.to_numpy().tolist())}")
|
84 |
+
|
85 |
+
ids = clustered_authors_df.authorID.tolist()
|
86 |
+
counter = Counter(ids)
|
87 |
+
duplicates = [k for k, v in counter.items() if v > 1]
|
88 |
+
|
89 |
+
print(f"Duplicate authorIDs: {duplicates}")
|
90 |
+
print(f"Number of duplicates: {len(ids) - len(set(ids))}")
|
91 |
+
|
92 |
+
author_to_g2v_feats = {x[0]: x[1] for x in zip(clustered_authors_df.authorID.tolist(), g2v_feats_df.to_numpy().tolist())}
|
93 |
+
|
94 |
+
print(f"Number of authors with g2v features: {len(author_to_g2v_feats)}")
|
95 |
+
|
96 |
+
# apply normalization
|
97 |
+
vector_std = np.std(list(author_to_g2v_feats.values()), axis=0)
|
98 |
+
vector_mean = np.mean(list(author_to_g2v_feats.values()), axis=0)
|
99 |
+
vector_std[vector_std == 0] = 1.0
|
100 |
+
author_to_g2v_feats_z_normalized = {x[0]: (x[1] - vector_mean) / vector_std for x in author_to_g2v_feats.items()}
|
101 |
+
|
102 |
+
print(f"Number of authors with g2v features normalized: {len(author_to_g2v_feats_z_normalized)}")
|
103 |
+
print(f" len of clustered authors df: {len(clustered_authors_df)}")
|
104 |
+
|
105 |
+
|
106 |
+
# Add the vectors as a new column of the DataFrame.
|
107 |
+
clustered_authors_df['g2v_vector'] = [{x[1]: x[0] for x in zip(val, g2v_feats_df.columns.tolist())}
|
108 |
+
for val in author_to_g2v_feats_z_normalized.values()]
|
109 |
+
|
110 |
+
with open(cache_path, "wb") as f:
|
111 |
+
pickle.dump(clustered_authors_df, f)
|
112 |
+
|
113 |
+
if task_authors_df is not None:
|
114 |
+
task_authors_df = clustered_authors_df[clustered_authors_df.authorID.isin(task_authors_df.authorID.tolist())]
|
115 |
+
clustered_authors_df = clustered_authors_df[~clustered_authors_df.authorID.isin(task_authors_df.authorID.tolist())]
|
116 |
+
|
117 |
+
|
118 |
+
return clustered_authors_df['g2v_vector'].tolist(), task_authors_df['g2v_vector'].tolist()
|
119 |
+
|
120 |
+
|
121 |
+
def get_task_authors_from_background_df(background_df):
|
122 |
+
task_authors_df = background_df[background_df.authorID.isin(["Q_author", "a0_author", "a1_author", "a2_author"])]
|
123 |
+
return task_authors_df
|
124 |
+
|
125 |
+
def instance_to_df(instance, predicted_author=None, ground_truth_author=None):
|
126 |
+
#create a dataframe of the task authors
|
127 |
+
task_authos_df = pd.DataFrame([
|
128 |
+
{'authorID': 'Mystery author', 'fullText': instance['Q_fullText'], 'predicted': None, 'ground_truth': None},
|
129 |
+
{'authorID': 'Candidate Author 1', 'fullText': instance['a0_fullText'], 'predicted': predicted_author == 0, 'ground_truth': ground_truth_author == 0},
|
130 |
+
{'authorID': 'Candidate Author 2', 'fullText': instance['a1_fullText'], 'predicted': predicted_author == 1, 'ground_truth': ground_truth_author == 1},
|
131 |
+
{'authorID': 'Candidate Author 3', 'fullText': instance['a2_fullText'], 'predicted': predicted_author == 2, 'ground_truth': ground_truth_author == 2}
|
132 |
+
|
133 |
+
])
|
134 |
+
|
135 |
+
if type(instance['Q_fullText']) == list:
|
136 |
+
task_authos_df = task_authos_df.groupby('authorID').agg({'fullText': lambda x: list(x)}).reset_index()
|
137 |
+
|
138 |
+
return task_authos_df
|
139 |
+
|
140 |
+
|
141 |
+
def generate_style_embedding(background_corpus_df: pd.DataFrame, text_clm: str, model_name: str) -> pd.DataFrame:
|
142 |
+
"""
|
143 |
+
Generates style embeddings for documents in a background corpus using a specified model.
|
144 |
+
If a row in `text_clm` contains a list of strings, the final embedding for that row
|
145 |
+
is the average of the embeddings of all strings in the list.
|
146 |
+
|
147 |
+
Args:
|
148 |
+
background_corpus_df (pd.DataFrame): DataFrame containing the corpus.
|
149 |
+
text_clm (str): Name of the column containing the text data (either string or list of strings).
|
150 |
+
model_name (str): Name of the model to use for generating embeddings.
|
151 |
+
|
152 |
+
Returns:
|
153 |
+
pd.DataFrame: The input DataFrame with a new column for style embeddings.
|
154 |
+
"""
|
155 |
+
from sentence_transformers import SentenceTransformer
|
156 |
+
import torch
|
157 |
+
|
158 |
+
if model_name not in [
|
159 |
+
'gabrielloiseau/LUAR-MUD-sentence-transformers',
|
160 |
+
'gabrielloiseau/LUAR-CRUD-sentence-transformers',
|
161 |
+
'miladalsh/light-luar',
|
162 |
+
'AnnaWegmann/Style-Embedding',
|
163 |
+
|
164 |
+
]:
|
165 |
+
print('Model is not supported')
|
166 |
+
return background_corpus_df
|
167 |
+
|
168 |
+
print(f"Generating style embeddings using {model_name} on column '{text_clm}'...")
|
169 |
+
|
170 |
+
model = SentenceTransformer(model_name)
|
171 |
+
embedding_dim = model.get_sentence_embedding_dimension()
|
172 |
+
|
173 |
+
# Heuristic to check if the column contains lists of strings by checking the first valid item.
|
174 |
+
# This assumes the column is homogenous.
|
175 |
+
is_list_column = False
|
176 |
+
if not background_corpus_df.empty:
|
177 |
+
# Get the first non-NaN value to inspect its type
|
178 |
+
series_no_na = background_corpus_df[text_clm].dropna()
|
179 |
+
if not series_no_na.empty:
|
180 |
+
first_valid_item = series_no_na.iloc[0]
|
181 |
+
if isinstance(first_valid_item, list):
|
182 |
+
is_list_column = True
|
183 |
+
|
184 |
+
if is_list_column:
|
185 |
+
# Flatten all texts into a single list for batch processing
|
186 |
+
texts_to_encode = []
|
187 |
+
row_lengths = []
|
188 |
+
for text_list in background_corpus_df[text_clm]:
|
189 |
+
# Ensure we handle None, empty lists, or other non-list types gracefully
|
190 |
+
if isinstance(text_list, list) and text_list:
|
191 |
+
texts_to_encode.extend(text_list)
|
192 |
+
row_lengths.append(len(text_list))
|
193 |
+
else:
|
194 |
+
row_lengths.append(0)
|
195 |
+
|
196 |
+
if texts_to_encode:
|
197 |
+
all_embeddings = model.encode(texts_to_encode, convert_to_tensor=True, show_progress_bar=True)
|
198 |
+
else:
|
199 |
+
all_embeddings = torch.empty((0, embedding_dim), device=model.device)
|
200 |
+
|
201 |
+
# Reconstruct and average embeddings for each row
|
202 |
+
final_embeddings = []
|
203 |
+
current_pos = 0
|
204 |
+
for length in row_lengths:
|
205 |
+
if length > 0:
|
206 |
+
row_embeddings = all_embeddings[current_pos:current_pos + length]
|
207 |
+
avg_embedding = torch.mean(row_embeddings, dim=0)
|
208 |
+
final_embeddings.append(avg_embedding.cpu().numpy())
|
209 |
+
current_pos += length
|
210 |
+
else:
|
211 |
+
final_embeddings.append(np.zeros(embedding_dim))
|
212 |
+
else:
|
213 |
+
# Column contains single strings
|
214 |
+
texts = background_corpus_df[text_clm].fillna("").tolist()
|
215 |
+
# convert_to_tensor=False is faster if we just need numpy arrays
|
216 |
+
embeddings = model.encode(texts, show_progress_bar=True)
|
217 |
+
final_embeddings = list(embeddings)
|
218 |
+
|
219 |
+
# Create a clean column name from the model name
|
220 |
+
col_name = f'{model_name.split("/")[-1]}_style_embedding'
|
221 |
+
background_corpus_df[col_name] = final_embeddings
|
222 |
+
|
223 |
+
return background_corpus_df
|
224 |
+
|
225 |
+
# ββ wrapper with caching βββββββββββββββββββββββββββββββββββββββ
|
226 |
+
def cached_generate_style_embedding(background_corpus_df: pd.DataFrame,
|
227 |
+
text_clm: str,
|
228 |
+
model_name: str) -> pd.DataFrame:
|
229 |
+
"""
|
230 |
+
Wraps `generate_style_embedding`, caching its output in pickle files
|
231 |
+
keyed by an MD5 of (model_name + text list). If the cache exists,
|
232 |
+
loads and returns it instead of recomputing.
|
233 |
+
"""
|
234 |
+
|
235 |
+
# Gather the input texts (preserves list-of-strings if any)
|
236 |
+
texts = background_corpus_df[text_clm].fillna("").tolist()
|
237 |
+
|
238 |
+
# Create a reproducible JSON serialization of the texts
|
239 |
+
serialized = json.dumps({
|
240 |
+
"model": model_name,
|
241 |
+
"col": text_clm,
|
242 |
+
"texts": texts
|
243 |
+
}, sort_keys=True, ensure_ascii=False)
|
244 |
+
|
245 |
+
# Compute MD5 hash
|
246 |
+
digest = hashlib.md5(serialized.encode("utf-8")).hexdigest()
|
247 |
+
cache_path = os.path.join(CACHE_DIR, f"{digest}.pkl")
|
248 |
+
|
249 |
+
# If cache hit, load and return
|
250 |
+
if os.path.exists(cache_path):
|
251 |
+
print(f"Cache hit for {model_name} on column '{text_clm}'")
|
252 |
+
print(cache_path)
|
253 |
+
with open(cache_path, "rb") as f:
|
254 |
+
return pickle.load(f)
|
255 |
+
|
256 |
+
# Otherwise, compute, cache, and return
|
257 |
+
df_with_emb = generate_style_embedding(background_corpus_df, text_clm, model_name)
|
258 |
+
print(f"Computing embeddings for {model_name} on column '{text_clm}', saving to {cache_path}")
|
259 |
+
with open(cache_path, "wb") as f:
|
260 |
+
pickle.dump(df_with_emb, f)
|
261 |
+
return df_with_emb
|
262 |
+
|
263 |
+
def get_style_feats_distribution(documentIDs, style_feats_dict):
|
264 |
+
style_feats = []
|
265 |
+
for documentId in documentIDs:
|
266 |
+
if documentId not in document_to_style_feats:
|
267 |
+
#print(documentId)
|
268 |
+
continue
|
269 |
+
|
270 |
+
style_feats+= document_to_style_feats[documentId]
|
271 |
+
|
272 |
+
tfidf = [style_feats.count(key) * val for key, val in style_feats_dict.items()]
|
273 |
+
|
274 |
+
return tfidf
|
275 |
+
|
276 |
+
def get_cluster_top_feats(style_feats_distribution, style_feats_list, top_k=5):
|
277 |
+
sorted_feats = np.argsort(style_feats_distribution)[::-1]
|
278 |
+
top_feats = [style_feats_list[x] for x in sorted_feats[:top_k] if style_feats_distribution[x] > 0]
|
279 |
+
return top_feats
|
280 |
+
|
281 |
+
def compute_clusters_style_representation(
|
282 |
+
background_corpus_df: pd.DataFrame,
|
283 |
+
cluster_ids: List[Any],
|
284 |
+
other_cluster_ids: List[Any],
|
285 |
+
features_clm_name: str,
|
286 |
+
cluster_label_clm_name: str = 'cluster_label',
|
287 |
+
top_n: int = 10
|
288 |
+
) -> List[str]:
|
289 |
+
"""
|
290 |
+
Given a DataFrame with document IDs, cluster IDs, and feature lists,
|
291 |
+
return the top N features that are most important in the specified `cluster_ids`
|
292 |
+
while having low importance in `other_cluster_ids`.
|
293 |
+
Importance is determined by TF-IDF scores. The final score for a feature is
|
294 |
+
(summed TF-IDF in `cluster_ids`) - (summed TF-IDF in `other_cluster_ids`).
|
295 |
+
|
296 |
+
Parameters:
|
297 |
+
- background_corpus_df: pd.DataFrame. Must contain the columns specified by
|
298 |
+
`cluster_label_clm_name` and `features_clm_name`.
|
299 |
+
The column `features_clm_name` should contain lists of strings (features).
|
300 |
+
- cluster_ids: List of cluster IDs for which to find representative features (target clusters).
|
301 |
+
- other_cluster_ids: List of cluster IDs whose features should be down-weighted.
|
302 |
+
Features prominent in these clusters will have their scores reduced.
|
303 |
+
Pass an empty list or None if no contrastive clusters are needed.
|
304 |
+
- features_clm_name: The name of the column in `background_corpus_df` that
|
305 |
+
contains the list of features for each document.
|
306 |
+
- cluster_label_clm_name: The name of the column in `background_corpus_df`
|
307 |
+
that contains the cluster labels. Defaults to 'cluster_label'.
|
308 |
+
- top_n: Number of top features to return.
|
309 |
+
Returns:
|
310 |
+
- List[str]: A list of feature names. These are up to `top_n` features
|
311 |
+
ranked by their adjusted TF-IDF scores (score in `cluster_ids`
|
312 |
+
minus score in `other_cluster_ids`). Only features with a final
|
313 |
+
adjusted score > 0 are included.
|
314 |
+
"""
|
315 |
+
|
316 |
+
assert background_corpus_df[features_clm_name].apply(
|
317 |
+
lambda x: isinstance(x, list) and all(isinstance(feat, str) for feat in x)
|
318 |
+
).all(), f"Column '{features_clm_name}' must contain lists of strings."
|
319 |
+
|
320 |
+
# Compute TF-IDF on the entire corpus
|
321 |
+
vectorizer = TfidfVectorizer(
|
322 |
+
tokenizer=lambda x: x,
|
323 |
+
preprocessor=lambda x: x,
|
324 |
+
token_pattern=None # Disable default token pattern, treat items in list as tokens
|
325 |
+
)
|
326 |
+
tfidf_matrix = vectorizer.fit_transform(background_corpus_df[features_clm_name])
|
327 |
+
feature_names = vectorizer.get_feature_names_out()
|
328 |
+
|
329 |
+
# Get boolean mask for documents in selected clusters
|
330 |
+
selected_mask = background_corpus_df[cluster_label_clm_name].isin(cluster_ids).to_numpy()
|
331 |
+
|
332 |
+
if not selected_mask.any():
|
333 |
+
return [] # No documents found for the given cluster_ids
|
334 |
+
|
335 |
+
# Subset the TF-IDF matrix using the boolean mask
|
336 |
+
selected_tfidf = tfidf_matrix[selected_mask]
|
337 |
+
|
338 |
+
# Sum TF-IDF scores across documents for each feature in the target clusters
|
339 |
+
target_feature_scores_sum = selected_tfidf.sum(axis=0).A1 # Convert to 1D array
|
340 |
+
|
341 |
+
# Initialize adjusted scores with target scores
|
342 |
+
adjusted_feature_scores = target_feature_scores_sum.copy()
|
343 |
+
|
344 |
+
# If other_cluster_ids are provided and not empty, subtract their TF-IDF sums
|
345 |
+
if other_cluster_ids: # Checks if the list is not None and not empty
|
346 |
+
other_selected_mask = background_corpus_df[cluster_label_clm_name].isin(other_cluster_ids).to_numpy()
|
347 |
+
|
348 |
+
if other_selected_mask.any():
|
349 |
+
other_selected_tfidf = tfidf_matrix[other_selected_mask]
|
350 |
+
contrast_feature_scores_sum = other_selected_tfidf.sum(axis=0).A1
|
351 |
+
|
352 |
+
# Element-wise subtraction; assumes feature_names aligns for both sums
|
353 |
+
adjusted_feature_scores -= contrast_feature_scores_sum
|
354 |
+
|
355 |
+
# Map scores to feature names
|
356 |
+
feature_score_dict = dict(zip(feature_names, adjusted_feature_scores))
|
357 |
+
# Sort features by score
|
358 |
+
sorted_features = sorted(feature_score_dict.items(), key=lambda item: item[1], reverse=True)
|
359 |
+
|
360 |
+
# Return the names of the top_n features that have a score > 0
|
361 |
+
top_features = [feature for feature, score in sorted_features if score > 0][:top_n]
|
362 |
+
|
363 |
+
return top_features
|
364 |
+
|
365 |
+
def compute_clusters_style_representation_2(
|
366 |
+
background_corpus_df: pd.DataFrame,
|
367 |
+
cluster_ids: List[Any],
|
368 |
+
cluster_label_clm_name: str = 'cluster_label',
|
369 |
+
max_num_feats: int = 5,
|
370 |
+
max_num_documents_per_author=3,
|
371 |
+
max_num_authors=5):
|
372 |
+
"""
|
373 |
+
Call openAI to analyze the common writing style features of the given list of texts
|
374 |
+
"""
|
375 |
+
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
376 |
+
|
377 |
+
background_corpus_df['fullText'] = background_corpus_df['fullText'].map(lambda x: '\n\n'.join(x[:max_num_documents_per_author]) if isinstance(x, list) else x)
|
378 |
+
background_corpus_df = background_corpus_df[background_corpus_df[cluster_label_clm_name].isin(cluster_ids)]
|
379 |
+
|
380 |
+
author_texts = background_corpus_df['fullText'].tolist()[:max_num_authors]
|
381 |
+
author_texts = "\n\n".join(["""Author {}:\n""".format(i+1) + text for i, text in enumerate(author_texts)])
|
382 |
+
author_names = background_corpus_df[cluster_label_clm_name].tolist()[:max_num_authors]
|
383 |
+
print(f"Number of authors: {len(background_corpus_df)}")
|
384 |
+
print(author_names)
|
385 |
+
print(author_texts)
|
386 |
+
print(f"Number of authors: {len(author_names)}")
|
387 |
+
print(f"Number of authors: {len(author_texts)}")
|
388 |
+
|
389 |
+
prompt = f"""First identify a list of {max_num_feats} writing style features that are common between the given texts. Second for every author text and style feature, extract all spans that represent the feature. Output for every author all style features with their spans.
|
390 |
+
Author Texts:
|
391 |
+
\"\"\"{author_texts}\"\"\"
|
392 |
+
"""
|
393 |
+
|
394 |
+
# Compute MD5 hash
|
395 |
+
digest = hashlib.md5(prompt.encode("utf-8")).hexdigest()
|
396 |
+
cache_path = os.path.join(CACHE_DIR, f"{digest}.pkl")
|
397 |
+
|
398 |
+
# If cache hit, load and return
|
399 |
+
if os.path.exists(cache_path):
|
400 |
+
print(f"Loading authors writing style from cache ...")
|
401 |
+
with open(cache_path, "rb") as f:
|
402 |
+
parsed_response = pickle.load(f)
|
403 |
+
|
404 |
+
else: # Else compute and cache
|
405 |
+
|
406 |
+
response = client.chat.completions.create(
|
407 |
+
model="gpt-4o-mini",
|
408 |
+
messages=[
|
409 |
+
{"role":"assistant","content":"You are a forensic linguistic who knows how to analyze similarites in writing styles."},
|
410 |
+
{"role":"user","content":prompt}],
|
411 |
+
response_format={"type": "json_schema", "json_schema": {"name": "style_analysis_schema", "schema": to_strict_json_schema(style_analysis_schema)}}
|
412 |
+
)
|
413 |
+
|
414 |
+
parsed_response = json.loads(response.choices[0].message.content)
|
415 |
+
|
416 |
+
with open(cache_path, "wb") as f:
|
417 |
+
pickle.dump(parsed_response, f)
|
418 |
+
|
419 |
+
return parsed_response
|
420 |
+
|
421 |
+
def identify_style_features(author_texts: list[str], max_num_feats: int = 5) -> list[str]:
|
422 |
+
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
423 |
+
prompt = f"""Identify {max_num_feats} writing style features that are commonly found across the following texts. Do not extract spans. Just return the feature names as a list.
|
424 |
+
Author Texts:
|
425 |
+
\"\"\"{chr(10).join(author_texts)}\"\"\"
|
426 |
+
"""
|
427 |
+
|
428 |
+
def _make_call():
|
429 |
+
response = client.chat.completions.create(
|
430 |
+
model="gpt-4o-mini",
|
431 |
+
messages=[
|
432 |
+
{"role": "assistant", "content": "You are a forensic linguist specializing in writing styles."},
|
433 |
+
{"role": "user", "content": prompt}
|
434 |
+
],
|
435 |
+
response_format={
|
436 |
+
"type": "json_schema",
|
437 |
+
"json_schema": {
|
438 |
+
"name": "FeatureIdentificationSchema",
|
439 |
+
"schema": to_strict_json_schema(FeatureIdentificationSchema)
|
440 |
+
}
|
441 |
+
}
|
442 |
+
)
|
443 |
+
return json.loads(response.choices[0].message.content)
|
444 |
+
|
445 |
+
return retry_call(_make_call, FeatureIdentificationSchema).features
|
446 |
+
|
447 |
+
def retry_call(call_fn, schema_class, max_attempts=3, wait_sec=2):
|
448 |
+
for attempt in range(max_attempts):
|
449 |
+
try:
|
450 |
+
result = call_fn()
|
451 |
+
# Validate against schema
|
452 |
+
validated = schema_class(**result)
|
453 |
+
return validated
|
454 |
+
except (ValidationError, KeyError, json.JSONDecodeError) as e:
|
455 |
+
print(f"Attempt {attempt + 1} failed with error: {e}")
|
456 |
+
time.sleep(wait_sec)
|
457 |
+
raise RuntimeError("All retry attempts failed for OpenAI call.")
|
458 |
+
|
459 |
+
def extract_all_spans(authors_df: pd.DataFrame, features: list[str], cluster_label_clm_name: str = 'authorID') -> dict[str, dict[str, list[str]]]:
|
460 |
+
"""
|
461 |
+
For each author, use `generate_feature_spans_cached` to get feature->span mappings.
|
462 |
+
Returns a dict: {author_name: {feature: [spans]}}
|
463 |
+
"""
|
464 |
+
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
465 |
+
|
466 |
+
spans_by_author = {}
|
467 |
+
|
468 |
+
for _, row in authors_df.iterrows():
|
469 |
+
author_name = str(row[cluster_label_clm_name])
|
470 |
+
print(author_name)
|
471 |
+
role = f"{author_name}"
|
472 |
+
full_text = row['fullText']
|
473 |
+
spans = generate_feature_spans_cached(client, full_text, features, role)
|
474 |
+
spans_by_author[author_name] = spans
|
475 |
+
|
476 |
+
return spans_by_author
|
477 |
+
|
478 |
+
def compute_clusters_style_representation_3(
|
479 |
+
background_corpus_df: pd.DataFrame,
|
480 |
+
cluster_ids: List[Any],
|
481 |
+
cluster_label_clm_name: str = 'authorID',
|
482 |
+
max_num_feats: int = 5,
|
483 |
+
max_num_documents_per_author=3,
|
484 |
+
max_num_authors=5
|
485 |
+
):
|
486 |
+
|
487 |
+
print(f"Computing style representation for visible clusters: {len(cluster_ids)}")
|
488 |
+
# STEP 1: Identify features on 5 visible authors
|
489 |
+
background_corpus_df['fullText'] = background_corpus_df['fullText'].map(lambda x: '\n\n'.join(x[:max_num_documents_per_author]) if isinstance(x, list) else x)
|
490 |
+
background_corpus_df_feat_id = background_corpus_df[background_corpus_df[cluster_label_clm_name].isin(cluster_ids)]
|
491 |
+
|
492 |
+
author_texts = background_corpus_df_feat_id['fullText'].tolist()[:max_num_authors]
|
493 |
+
author_texts = "\n\n".join(["""Author {}:\n""".format(i+1) + text for i, text in enumerate(author_texts)])
|
494 |
+
author_names = background_corpus_df_feat_id[cluster_label_clm_name].tolist()[:max_num_authors]
|
495 |
+
print(f"Number of authors: {len(background_corpus_df_feat_id)}")
|
496 |
+
print(author_names)
|
497 |
+
print(author_texts)
|
498 |
+
print(f"Number of authors: {len(author_names)}")
|
499 |
+
print(f"Number of authors: {len(author_texts)}")
|
500 |
+
features = identify_style_features(author_texts, max_num_feats=max_num_feats)
|
501 |
+
|
502 |
+
# STEP 2: Prepare author pool for span extraction
|
503 |
+
|
504 |
+
span_df = background_corpus_df.iloc[:7]
|
505 |
+
author_names = span_df[cluster_label_clm_name].tolist()[:7]
|
506 |
+
print(f"Number of authors for span detection : {len(span_df)}")
|
507 |
+
print(author_names)
|
508 |
+
spans_by_author = extract_all_spans(span_df, features, cluster_label_clm_name)
|
509 |
+
|
510 |
+
return {
|
511 |
+
"features": features,
|
512 |
+
"spans": spans_by_author
|
513 |
+
}
|
514 |
+
|
515 |
+
|
516 |
+
def compute_clusters_g2v_representation(
|
517 |
+
background_corpus_df: pd.DataFrame,
|
518 |
+
author_ids: List[Any],
|
519 |
+
other_author_ids: List[Any],
|
520 |
+
features_clm_name: str,
|
521 |
+
top_n: int = 10
|
522 |
+
) -> List[str]:
|
523 |
+
|
524 |
+
|
525 |
+
# Get boolean mask for documents in selected clusters
|
526 |
+
selected_mask = background_corpus_df['authorID'].isin(author_ids).to_numpy()
|
527 |
+
|
528 |
+
if not selected_mask.any():
|
529 |
+
return [] # No documents found for the given cluster_ids
|
530 |
+
|
531 |
+
selected_feats = background_corpus_df[selected_mask][features_clm_name].tolist()
|
532 |
+
all_g2v_feats = list(selected_feats[0].keys())
|
533 |
+
all_g2v_values = np.array([list(x.values()) for x in selected_feats]).mean(axis=0)
|
534 |
+
|
535 |
+
|
536 |
+
other_selected_feats = background_corpus_df[~selected_mask][features_clm_name].tolist()
|
537 |
+
all_g2v_other_feats = list(other_selected_feats[0].keys())
|
538 |
+
all_g2v_other_values = np.array([list(x.values()) for x in other_selected_feats]).mean(axis=0)
|
539 |
+
|
540 |
+
final_g2v_feats_values = all_g2v_values - all_g2v_other_values
|
541 |
+
|
542 |
+
|
543 |
+
top_g2v_feats = sorted(list(zip(all_g2v_feats, final_g2v_feats_values)), key=lambda x: -x[1])
|
544 |
+
print(top_g2v_feats[:top_n])
|
545 |
+
|
546 |
+
return [x[0] for x in top_g2v_feats[:top_n]]
|
547 |
+
|
548 |
+
|
549 |
+
def generate_interpretable_space_representation(interp_space_path, styles_df_path, feat_clm, output_clm, num_feats=5):
|
550 |
+
|
551 |
+
styles_df = pd.read_csv(styles_df_path)[[feat_clm, "documentID"]]
|
552 |
+
|
553 |
+
# A dictionary of style features and their IDF
|
554 |
+
style_feats_agg_df = styles_df.groupby(feat_clm).agg({'documentID': lambda x : len(list(x))}).reset_index()
|
555 |
+
style_feats_agg_df['document_freq'] = style_feats_agg_df.documentID
|
556 |
+
style_to_feats_dfreq = {x[0]: math.log(styles_df.documentID.nunique()/x[1]) for x in zip(style_feats_agg_df[feat_clm].tolist(), style_feats_agg_df.document_freq.tolist())}
|
557 |
+
|
558 |
+
# A list of style features we work with
|
559 |
+
style_feats_list = style_feats_agg_df[feat_clm].tolist()
|
560 |
+
print('Number of style feats ', len(style_feats_list))
|
561 |
+
|
562 |
+
# A list of documents and what list of style features each has
|
563 |
+
doc_style_agg_df = styles_df.groupby('documentID').agg({feat_clm: lambda x : list(x)}).reset_index()
|
564 |
+
document_to_feats_dict = {x[0]: x[1] for x in zip(doc_style_agg_df.documentID.tolist(), doc_style_agg_df[feat_clm].tolist())}
|
565 |
+
|
566 |
+
|
567 |
+
|
568 |
+
# Load the clustering information
|
569 |
+
df = pd.read_pickle(interp_space_path)
|
570 |
+
df = df[df.cluster_label != -1]
|
571 |
+
# A cluster to list of documents
|
572 |
+
clusterd_df = df.groupby('cluster_label').agg({
|
573 |
+
'documentID': lambda x: [d_id for doc_ids in x for d_id in doc_ids]
|
574 |
+
}).reset_index()
|
575 |
+
|
576 |
+
# Filter-in only documents that has a style description
|
577 |
+
clusterd_df['documentID'] = clusterd_df.documentID.apply(lambda documentIDs: [documentID for documentID in documentIDs if documentID in document_to_feats_dict])
|
578 |
+
# Map from cluster label to list of features through the document information
|
579 |
+
clusterd_df[feat_clm] = clusterd_df.documentID.apply(lambda doc_ids: [f for d_id in doc_ids for f in document_to_feats_dict[d_id]])
|
580 |
+
|
581 |
+
def compute_tfidf(row):
|
582 |
+
style_counts = Counter(row[feat_clm])
|
583 |
+
total_num_styles = sum(style_counts.values())
|
584 |
+
#print(style_counts, total_num_styles)
|
585 |
+
style_distribution = {
|
586 |
+
style: math.log(1+count) * style_to_feats_dfreq[style] if style in style_to_feats_dfreq else 0 for style, count in style_counts.items()
|
587 |
+
} #TF-IDF
|
588 |
+
|
589 |
+
return style_distribution
|
590 |
+
|
591 |
+
def create_tfidf_rep(tfidf_dist, num_feats):
|
592 |
+
style_feats = sorted(tfidf_dist.items(), key=lambda x: -x[1])
|
593 |
+
top_k_feats = [x[0] for x in style_feats[:num_feats] if str(x[0]) != 'nan']
|
594 |
+
return top_k_feats
|
595 |
+
|
596 |
+
clusterd_df[output_clm +'_dist'] = clusterd_df.apply(lambda row: compute_tfidf(row), axis=1)
|
597 |
+
clusterd_df[output_clm] = clusterd_df[output_clm +'_dist'].apply(lambda dist: create_tfidf_rep(dist, num_feats))
|
598 |
+
|
599 |
+
|
600 |
+
return clusterd_df
|
601 |
+
|
602 |
+
def compute_predicted_author(task_authors_df: pd.DataFrame, col_name: str) -> int:
|
603 |
+
"""
|
604 |
+
Computes the predicted author based on the style features.
|
605 |
+
"""
|
606 |
+
print("Computing predicted author using LUAR-MUD-style embeddings...")
|
607 |
+
|
608 |
+
# Extract LUAR embeddings from task authors dataframe
|
609 |
+
mystery_embedding = np.array(task_authors_df.iloc[0][col_name]).reshape(1, -1)
|
610 |
+
candidate_embeddings = np.array([
|
611 |
+
task_authors_df.iloc[1][col_name],
|
612 |
+
task_authors_df.iloc[2][col_name],
|
613 |
+
task_authors_df.iloc[3][col_name]
|
614 |
+
])
|
615 |
+
|
616 |
+
# Compute cosine similarities
|
617 |
+
similarities = cosine_similarity(mystery_embedding, candidate_embeddings)[0]
|
618 |
+
predicted_author = int(np.argmax(similarities))
|
619 |
+
print(f"Predicted author is Candidate {predicted_author + 1}")
|
620 |
+
|
621 |
+
return predicted_author
|
622 |
+
|
623 |
+
|
624 |
+
if __name__ == "__main__":
|
625 |
+
background_corpus = pd.read_pickle('../datasets/luar_interp_space_cluster_19/train_authors.pkl')
|
626 |
+
print(background_corpus.columns)
|
627 |
+
print(background_corpus[['authorID', 'fullText', 'cluster_label']].head())
|
628 |
+
# # Example: Find features for clusters [2,3,4] that are NOT prominent in cluster [1]
|
629 |
+
# feats = compute_clusters_style_representation(
|
630 |
+
# background_corpus_df=background_corpus,
|
631 |
+
# cluster_ids=['00005a5c-5c06-3a36-37f9-53c6422a31d8',],
|
632 |
+
# other_cluster_ids=[], # Pass the contrastive cluster IDs here
|
633 |
+
# cluster_label_clm_name='authorID',
|
634 |
+
# features_clm_name='final_attribute_name'
|
635 |
+
# )
|
636 |
+
# print(feats)
|
637 |
+
generate_style_embedding(background_corpus, 'fullText', 'AnnaWegmann/Style-Embedding')
|
638 |
+
print(background_corpus.columns)
|
utils/llm_feat_utils.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import hashlib
|
4 |
+
import time
|
5 |
+
from json import JSONDecodeError
|
6 |
+
|
7 |
+
CACHE_DIR = "datasets/feature_spans_cache"
|
8 |
+
os.makedirs(CACHE_DIR, exist_ok=True)
|
9 |
+
import pandas as pd
|
10 |
+
|
11 |
+
#read and create the Gram2Vec feature set once
|
12 |
+
_g2v_df = pd.read_csv("datasets/gram2vec_feats.csv")
|
13 |
+
GRAM2VEC_SET = set(_g2v_df['gram2vec_feats'].unique())
|
14 |
+
MAX_ATTEMPTS = 3
|
15 |
+
WAIT_SECONDS = 2
|
16 |
+
|
17 |
+
# Bump this whenever there is a change prompt, feature space, etc...
|
18 |
+
CACHE_VERSION = 2
|
19 |
+
|
20 |
+
def _feat_hash(feature: str, text: str) -> str:
|
21 |
+
blob = json.dumps({
|
22 |
+
"version": CACHE_VERSION,
|
23 |
+
"text": text,
|
24 |
+
"features": sorted(feature)
|
25 |
+
}, sort_keys=True).encode()
|
26 |
+
return hashlib.md5(blob).hexdigest()
|
27 |
+
|
28 |
+
|
29 |
+
def generate_feature_spans(client, text: str, features: list[str]) -> str:
|
30 |
+
print("Calling OpenAI to extract spans")
|
31 |
+
"""
|
32 |
+
Call to OpenAI to extract spans. Returns a JSON string.
|
33 |
+
"""
|
34 |
+
prompt = f"""You are a linguistic specialist. Given a writing sample and a list of descriptive features, identify the exact text spans that demonstrate each feature.
|
35 |
+
|
36 |
+
Important:
|
37 |
+
- The headers like "Document 1:" etc are NOT part of the original text β ignore them.
|
38 |
+
- For each feature, even if there is no match, return an empty list.
|
39 |
+
- Only return exact phrases from the text.
|
40 |
+
|
41 |
+
Respond in JSON format like:
|
42 |
+
{{
|
43 |
+
"feature1": ["span1", "span2"],
|
44 |
+
"feature2": [],
|
45 |
+
β¦
|
46 |
+
}}
|
47 |
+
|
48 |
+
Text:
|
49 |
+
\"\"\"{text}\"\"\"
|
50 |
+
|
51 |
+
Style Features:
|
52 |
+
{features}
|
53 |
+
"""
|
54 |
+
response = client.chat.completions.create(
|
55 |
+
model="gpt-4",
|
56 |
+
messages=[{"role":"user","content":prompt}],
|
57 |
+
temperature=0.3,
|
58 |
+
)
|
59 |
+
return response.choices[0].message.content
|
60 |
+
|
61 |
+
def generate_feature_spans_with_retries(client, text: str, features: list[str]) -> dict:
|
62 |
+
"""
|
63 |
+
Calls `generate_feature_spans` with retries on failure.
|
64 |
+
Returns the parsed JSON dict mapping feature->list[spans].
|
65 |
+
"""
|
66 |
+
for attempt in range(MAX_ATTEMPTS):
|
67 |
+
try:
|
68 |
+
response_str = generate_feature_spans(client, text, features)
|
69 |
+
result = json.loads(response_str)
|
70 |
+
return result
|
71 |
+
except (JSONDecodeError, ValueError) as e:
|
72 |
+
print(f"Attempt {attempt+1} failed: {e}")
|
73 |
+
if attempt < MAX_ATTEMPTS - 1:
|
74 |
+
wait_sec = WAIT_SECONDS * (2 ** attempt)
|
75 |
+
print(f"Retrying after {wait_sec} seconds...")
|
76 |
+
time.sleep(wait_sec)
|
77 |
+
raise RuntimeError("All retry attempts failed for OpenAI call.")
|
78 |
+
|
79 |
+
|
80 |
+
def generate_feature_spans_cached(client, text: str, features: list[str], role: str = "mystery" ) -> dict:
|
81 |
+
"""
|
82 |
+
Computes a cache key from text + feature list,
|
83 |
+
then either loads or calls the API and saves to disk.
|
84 |
+
Returns the parsed JSON dict mapping feature->list[spans].
|
85 |
+
"""
|
86 |
+
print(f"Generating spans for ({role})")
|
87 |
+
# print(f"feature list {features}")
|
88 |
+
role = role.replace(" ", "_").replace("/", "_").replace("-", "_")
|
89 |
+
print(f"Cache dir: {CACHE_DIR}")
|
90 |
+
os.makedirs(CACHE_DIR, exist_ok=True)
|
91 |
+
cache_path = os.path.join(CACHE_DIR, f"{role}.json")
|
92 |
+
if os.path.exists(cache_path):
|
93 |
+
with open(cache_path) as f:
|
94 |
+
cache: dict[str, dict] = json.load(f)
|
95 |
+
else:
|
96 |
+
cache = {}
|
97 |
+
result: dict[str, list[str]] = {}
|
98 |
+
missing_feats: list[str] = []
|
99 |
+
|
100 |
+
for feat in features:
|
101 |
+
if feat == "None":
|
102 |
+
result[feat] = []
|
103 |
+
continue
|
104 |
+
|
105 |
+
h = _feat_hash(feat, text)
|
106 |
+
if h in cache:
|
107 |
+
result[feat] = cache[h]["spans"]
|
108 |
+
else:
|
109 |
+
missing_feats.append(feat)
|
110 |
+
|
111 |
+
if missing_feats:
|
112 |
+
|
113 |
+
mapping = generate_feature_spans_with_retries(client, text, missing_feats)
|
114 |
+
# 4) update cache & result for each missing feature
|
115 |
+
for feat in missing_feats:
|
116 |
+
h = _feat_hash(feat, text)
|
117 |
+
spans = mapping.get(feat)
|
118 |
+
cache[h] = {
|
119 |
+
"feature": feat,
|
120 |
+
"spans": spans
|
121 |
+
}
|
122 |
+
result[feat] = spans
|
123 |
+
|
124 |
+
# 5) write back the combined cache
|
125 |
+
with open(cache_path, "w") as f:
|
126 |
+
json.dump(cache, f, indent=2)
|
127 |
+
return result
|
128 |
+
|
129 |
+
|
130 |
+
def split_features(all_feats):
|
131 |
+
"""
|
132 |
+
Given a list of mixed features, returns two lists:
|
133 |
+
- llm_feats: those NOT in the Gram2Vec CSV
|
134 |
+
- g2v_feats: those present in the CSV
|
135 |
+
"""
|
136 |
+
g2v_feats = [feat for feat in all_feats if feat in GRAM2VEC_SET]
|
137 |
+
llm_feats = [feat for feat in all_feats if feat not in GRAM2VEC_SET]
|
138 |
+
return llm_feats, g2v_feats
|
utils/ui.py
ADDED
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import pandas as pd
|
3 |
+
from utils.visualizations import load_instance, get_instances, clean_text
|
4 |
+
from utils.interp_space_utils import cached_generate_style_embedding, instance_to_df, compute_g2v_features, compute_predicted_author
|
5 |
+
|
6 |
+
|
7 |
+
# ββ Global CSS to be prepended to every block βββββββββββββββββββββββββββββββββ
|
8 |
+
GLOBAL_CSS = """
|
9 |
+
<style>
|
10 |
+
/* Bold only the topβlevel field labels (not every label) */
|
11 |
+
.gradio-container .input_label {
|
12 |
+
font-weight: 600 !important;
|
13 |
+
font-size: 1.1em !important;
|
14 |
+
}
|
15 |
+
|
16 |
+
/* Reset radioβoption labels to normal weight/size */
|
17 |
+
.gradio-container .radio-container .radio-option-label {
|
18 |
+
font-weight: normal !important;
|
19 |
+
font-size: 1em !important;
|
20 |
+
}
|
21 |
+
/* Give HTML output blocks a stronger border and padding */
|
22 |
+
.gradio-container .output-html {
|
23 |
+
border: 2px solid #888 !important;
|
24 |
+
border-radius: 4px !important;
|
25 |
+
padding: 0.5em !important;
|
26 |
+
margin-bottom: 1em !important;
|
27 |
+
font-size: 1em !important;
|
28 |
+
line-height: 1.4 !important;
|
29 |
+
}
|
30 |
+
</style>
|
31 |
+
"""
|
32 |
+
|
33 |
+
def styled_block(content: str) -> str:
|
34 |
+
"""
|
35 |
+
Injects GLOBAL_CSS before the provided content.
|
36 |
+
Returns a single HTML blob safe to pass into gr.HTML().
|
37 |
+
"""
|
38 |
+
return GLOBAL_CSS + "\n" + content
|
39 |
+
|
40 |
+
def styled_html(html_content: str) -> str:
|
41 |
+
"""
|
42 |
+
Wraps raw HTML content with global CSS. Pass the result to gr.HTML().
|
43 |
+
"""
|
44 |
+
return styled_block(html_content)
|
45 |
+
|
46 |
+
def instruction_callout(text: str) -> str:
|
47 |
+
"""
|
48 |
+
Returns a full HTML string (with global CSS) rendering `text`
|
49 |
+
as a bold, full-width callout box.
|
50 |
+
|
51 |
+
Usage:
|
52 |
+
gr.HTML(instruction_callout(
|
53 |
+
"Run visualization to see which author cluster contains the mystery document."
|
54 |
+
))
|
55 |
+
"""
|
56 |
+
callout = f"""
|
57 |
+
<div style="
|
58 |
+
background: #e3f2fd; /* light blue background */
|
59 |
+
border-left: 5px solid #2196f3; /* bold accent stripe */
|
60 |
+
padding: 12px 16px;
|
61 |
+
margin-bottom: 12px;
|
62 |
+
font-weight: 600;
|
63 |
+
font-size: 1.1em;
|
64 |
+
">
|
65 |
+
{text}
|
66 |
+
</div>
|
67 |
+
"""
|
68 |
+
return styled_html(callout)
|
69 |
+
|
70 |
+
def read_txt(f):
|
71 |
+
if not f:
|
72 |
+
return ""
|
73 |
+
path = f.name if hasattr(f, 'name') else f
|
74 |
+
try:
|
75 |
+
with open(path, 'r', encoding='utf-8') as fh:
|
76 |
+
return fh.read().strip()
|
77 |
+
except Exception:
|
78 |
+
return "(Could not read file)"
|
79 |
+
|
80 |
+
# Toggle which input UI is visible
|
81 |
+
def toggle_task(mode):
|
82 |
+
print(mode)
|
83 |
+
return (
|
84 |
+
gr.update(visible=(mode == "Predefined HRS Task")),
|
85 |
+
gr.update(visible=(mode == "Upload Your Own Task"))
|
86 |
+
)
|
87 |
+
|
88 |
+
# Update displayed texts based on mode
|
89 |
+
def update_task_display(mode, iid, instances, background_df, mystery_file, cand1_file, cand2_file, cand3_file, true_author, model_radio, custom_model_input):
|
90 |
+
model_name = model_radio if model_radio != "Other" else custom_model_input
|
91 |
+
if mode == "Predefined HRS Task":
|
92 |
+
iid = int(iid.replace('Task ', ''))
|
93 |
+
data = instances[iid]
|
94 |
+
predicted_author = data['latent_rank'][0]
|
95 |
+
ground_truth_author = data['gt_idx']
|
96 |
+
mystery_txt = data['Q_fullText']
|
97 |
+
c1_txt = data['a0_fullText']
|
98 |
+
c2_txt = data['a1_fullText']
|
99 |
+
c3_txt = data['a2_fullText']
|
100 |
+
candidate_texts = [c1_txt, c2_txt, c3_txt]
|
101 |
+
|
102 |
+
#create a dataframe of the task authors
|
103 |
+
task_authors_df = instance_to_df(instances[iid])
|
104 |
+
print(f"\n\n\n ----> Loaded task {iid} with {len(task_authors_df)} authors\n\n\n")
|
105 |
+
print(task_authors_df)
|
106 |
+
else:
|
107 |
+
header_html = "<h3>Custom Uploaded Task</h3>"
|
108 |
+
mystery_txt = read_txt(mystery_file)
|
109 |
+
c1_txt = read_txt(cand1_file)
|
110 |
+
c2_txt = read_txt(cand2_file)
|
111 |
+
c3_txt = read_txt(cand3_file)
|
112 |
+
candidate_texts = [c1_txt, c2_txt, c3_txt]
|
113 |
+
ground_truth_author = true_author
|
114 |
+
print(f"Ground truth author: {ground_truth_author} ; {true_author}")
|
115 |
+
custom_task_instance = {
|
116 |
+
'Q_fullText': mystery_txt,
|
117 |
+
'a0_fullText': c1_txt,
|
118 |
+
'a1_fullText': c2_txt,
|
119 |
+
'a2_fullText': c3_txt
|
120 |
+
}
|
121 |
+
task_authors_df = instance_to_df(custom_task_instance)
|
122 |
+
print(task_authors_df)
|
123 |
+
|
124 |
+
print(f"Generating embeddings for {model_name} on task authors")
|
125 |
+
task_authors_df = cached_generate_style_embedding(task_authors_df, 'fullText', model_name)
|
126 |
+
print("Task authors after embedding generation:")
|
127 |
+
print(task_authors_df)
|
128 |
+
# Generate the new embedding of all the background_df authors
|
129 |
+
print(f"Generating embeddings for {model_name} on background corpus")
|
130 |
+
background_df = cached_generate_style_embedding(background_df, 'fullText', model_name)
|
131 |
+
print(f"Generated embeddings for {len(background_df)} texts using model '{model_name}'")
|
132 |
+
|
133 |
+
# computing g2v features
|
134 |
+
print("Generating g2v features for on background corpus")
|
135 |
+
background_g2v, task_authors_g2v = compute_g2v_features(background_df, task_authors_df)
|
136 |
+
background_df['g2v_vector'] = background_g2v
|
137 |
+
task_authors_df['g2v_vector'] = task_authors_g2v
|
138 |
+
print(f"Gram2Vec feature generation complete")
|
139 |
+
|
140 |
+
print(background_df.columns)
|
141 |
+
|
142 |
+
# Computing predicted author by checking pairwise cosine similarity over luar embeddings
|
143 |
+
col_name = f'{model_name.split("/")[-1]}_style_embedding'
|
144 |
+
predicted_author = compute_predicted_author(task_authors_df, col_name)
|
145 |
+
|
146 |
+
#generating html for the task
|
147 |
+
header_html, mystery_html, candidate_htmls = task_HTML(mystery_txt, candidate_texts, predicted_author, ground_truth_author)
|
148 |
+
|
149 |
+
return [
|
150 |
+
header_html,
|
151 |
+
mystery_html,
|
152 |
+
candidate_htmls[0],
|
153 |
+
candidate_htmls[1],
|
154 |
+
candidate_htmls[2],
|
155 |
+
mystery_txt,
|
156 |
+
c1_txt,
|
157 |
+
c2_txt,
|
158 |
+
c3_txt,
|
159 |
+
task_authors_df,
|
160 |
+
background_df,
|
161 |
+
predicted_author,
|
162 |
+
ground_truth_author
|
163 |
+
]
|
164 |
+
|
165 |
+
def task_HTML(mystery_text, candidate_texts, predicted_author, ground_truth_author):
|
166 |
+
header_html = f"""
|
167 |
+
<div style="border:1px solid #ccc; padding:10px; margin-bottom:10px;">
|
168 |
+
<h3>Hereβs the mystery passage alongside three candidate textsβlook for the green highlight to see the predicted author.</h3>
|
169 |
+
</div>
|
170 |
+
"""
|
171 |
+
# mystery_text = clean_text(mystery_text)
|
172 |
+
mystery_html = f"""
|
173 |
+
<div style="
|
174 |
+
border: 2px solid #ff5722; /* accent border */
|
175 |
+
background: #fff3e0; /* very light matching wash */
|
176 |
+
border-radius: 6px;
|
177 |
+
padding: 1em;
|
178 |
+
margin-bottom: 1em;
|
179 |
+
">
|
180 |
+
<h3 style="margin-top:0; color:#bf360c;">Mystery Author</h3>
|
181 |
+
<p>{clean_text(mystery_text)}</p>
|
182 |
+
</div>
|
183 |
+
"""
|
184 |
+
|
185 |
+
print(f"Predicted author: {predicted_author}, Ground truth author: {ground_truth_author}")
|
186 |
+
|
187 |
+
# Candidate boxes
|
188 |
+
candidate_htmls = []
|
189 |
+
for i in range(3):
|
190 |
+
text = candidate_texts[i]
|
191 |
+
title = f"Candidate {i+1}"
|
192 |
+
extra_style = ""
|
193 |
+
|
194 |
+
if ground_truth_author == i:
|
195 |
+
if ground_truth_author != predicted_author: # highlight the true author only if its different than the predictd one
|
196 |
+
title += " (True Author)"
|
197 |
+
extra_style = (
|
198 |
+
"border: 2px solid #ff5722; "
|
199 |
+
"background: #fff3e0; "
|
200 |
+
"padding:10px; "
|
201 |
+
)
|
202 |
+
|
203 |
+
|
204 |
+
if predicted_author == i:
|
205 |
+
if predicted_author == ground_truth_author:
|
206 |
+
title += " (Predicted and True Author)"
|
207 |
+
else:
|
208 |
+
title += " (Predicted Author)"
|
209 |
+
extra_style = (
|
210 |
+
"border:2px solid #228B22; " # dark green border
|
211 |
+
"background-color: #e6ffe6; " # light green fill
|
212 |
+
"padding:10px; "
|
213 |
+
)
|
214 |
+
|
215 |
+
|
216 |
+
candidate_htmls.append(f"""
|
217 |
+
<div style="border:1px solid #ccc; padding:10px; {extra_style}">
|
218 |
+
<h4>{title}</h4>
|
219 |
+
<p>{clean_text(text)}</p>
|
220 |
+
</div>
|
221 |
+
""")
|
222 |
+
return header_html, mystery_html, candidate_htmls
|
223 |
+
|
224 |
+
def toggle_custom_model(choice):
|
225 |
+
return gr.update(visible=(choice == "Other"))
|
utils/visualizations.py
ADDED
@@ -0,0 +1,564 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import json
|
3 |
+
import numpy as np
|
4 |
+
from sklearn.manifold import TSNE
|
5 |
+
import pickle as pkl
|
6 |
+
import os
|
7 |
+
import hashlib
|
8 |
+
import pandas as pd
|
9 |
+
import plotly.graph_objects as go
|
10 |
+
from plotly.colors import sample_colorscale
|
11 |
+
from gradio import update
|
12 |
+
import re
|
13 |
+
from utils.interp_space_utils import compute_clusters_style_representation_3, compute_clusters_g2v_representation
|
14 |
+
from utils.llm_feat_utils import split_features
|
15 |
+
from utils.gram2vec_feat_utils import get_shorthand, get_fullform
|
16 |
+
|
17 |
+
import plotly.io as pio
|
18 |
+
|
19 |
+
def clean_text(text: str) -> str:
|
20 |
+
"""
|
21 |
+
Cleans the text by replacing HTML tags with their escaped versions.
|
22 |
+
"""
|
23 |
+
return text.replace('<','<').replace('>','>').replace('\n', '<br>')
|
24 |
+
|
25 |
+
def get_instances(instances_to_explain_path: str = 'datasets/instances_to_explain.json'):
|
26 |
+
"""
|
27 |
+
Loads the JSON and returns:
|
28 |
+
- instances_to_explain: the raw dict/list of instances
|
29 |
+
- instance_ids: list of keys (if dict) or indices (if list)
|
30 |
+
"""
|
31 |
+
instances_to_explain = json.load(open(instances_to_explain_path))
|
32 |
+
if isinstance(instances_to_explain, dict):
|
33 |
+
instance_ids = list(instances_to_explain.keys())
|
34 |
+
else:
|
35 |
+
instance_ids = list(range(len(instances_to_explain)))
|
36 |
+
return instances_to_explain, instance_ids
|
37 |
+
|
38 |
+
def load_instance(instance_id, instances_to_explain: dict):
|
39 |
+
"""
|
40 |
+
Given a selected instance_id and the loaded data,
|
41 |
+
returns (mystery_html, c0_html, c1_html, c2_html).
|
42 |
+
"""
|
43 |
+
# normalize instance_id
|
44 |
+
try:
|
45 |
+
iid = int(instance_id)
|
46 |
+
except ValueError:
|
47 |
+
iid = instance_id
|
48 |
+
data = instances_to_explain[iid]
|
49 |
+
|
50 |
+
predicted_author = data['latent_rank'][0]
|
51 |
+
ground_truth_author = data['gt_idx']
|
52 |
+
|
53 |
+
header_html = f"""
|
54 |
+
<div style="border:1px solid #ccc; padding:10px; margin-bottom:10px;">
|
55 |
+
<h3>Hereβs the mystery passage alongside three candidate textsβlook for the green highlight to see the predicted author.</h3>
|
56 |
+
</div>
|
57 |
+
"""
|
58 |
+
mystery_text = clean_text(data['Q_fullText'])
|
59 |
+
mystery_html = f"""
|
60 |
+
<div style="
|
61 |
+
border: 2px solid #ff5722; /* accent border */
|
62 |
+
background: #fff3e0; /* very light matching wash */
|
63 |
+
border-radius: 6px;
|
64 |
+
padding: 1em;
|
65 |
+
margin-bottom: 1em;
|
66 |
+
">
|
67 |
+
<h3 style="margin-top:0; color:#bf360c;">Mystery Author</h3>
|
68 |
+
<p>{clean_text(mystery_text)}</p>
|
69 |
+
</div>
|
70 |
+
"""
|
71 |
+
|
72 |
+
# Candidate boxes
|
73 |
+
candidate_htmls = []
|
74 |
+
for i in range(3):
|
75 |
+
text = data[f'a{i}_fullText']
|
76 |
+
title = f"Candidate {i+1}"
|
77 |
+
extra_style = ""
|
78 |
+
|
79 |
+
if ground_truth_author == i:
|
80 |
+
if ground_truth_author != predicted_author: # highlight the true author only if its different than the predictd one
|
81 |
+
title += " (True Author)"
|
82 |
+
extra_style = (
|
83 |
+
"border: 2px solid #ff5722; "
|
84 |
+
"background: #fff3e0; "
|
85 |
+
"padding:10px; "
|
86 |
+
)
|
87 |
+
|
88 |
+
|
89 |
+
if predicted_author == i:
|
90 |
+
if predicted_author == ground_truth_author:
|
91 |
+
title += " (Predicted and True Author)"
|
92 |
+
else:
|
93 |
+
title += " (Predicted Author)"
|
94 |
+
extra_style = (
|
95 |
+
"border:2px solid #228B22; " # dark green border
|
96 |
+
"background-color: #e6ffe6; " # light green fill
|
97 |
+
"padding:10px; "
|
98 |
+
)
|
99 |
+
|
100 |
+
|
101 |
+
candidate_htmls.append(f"""
|
102 |
+
<div style="border:1px solid #ccc; padding:10px; {extra_style}">
|
103 |
+
<h4>{title}</h4>
|
104 |
+
<p>{clean_text(text)}</p>
|
105 |
+
</div>
|
106 |
+
""")
|
107 |
+
|
108 |
+
return header_html, mystery_html, candidate_htmls[0], candidate_htmls[1], candidate_htmls[2]
|
109 |
+
|
110 |
+
def compute_tsne_with_cache(embeddings: np.ndarray, cache_path: str = 'datasets/tsne_cache.pkl') -> np.ndarray:
|
111 |
+
"""
|
112 |
+
Compute t-SNE with caching to avoid recomputation for the same input.
|
113 |
+
|
114 |
+
Args:
|
115 |
+
embeddings (np.ndarray): The input embeddings to compute t-SNE on.
|
116 |
+
cache_path (str): Path to the cache file.
|
117 |
+
|
118 |
+
Returns:
|
119 |
+
np.ndarray: The t-SNE transformed embeddings.
|
120 |
+
"""
|
121 |
+
# Create a hash of the input embeddings to use as a key
|
122 |
+
hash_key = hashlib.md5(embeddings.tobytes()).hexdigest()
|
123 |
+
|
124 |
+
if os.path.exists(cache_path):
|
125 |
+
with open(cache_path, 'rb') as f:
|
126 |
+
cache = pkl.load(f)
|
127 |
+
else:
|
128 |
+
cache = {}
|
129 |
+
|
130 |
+
if hash_key in cache:
|
131 |
+
return cache[hash_key]
|
132 |
+
else:
|
133 |
+
print("Computing t-SNE")
|
134 |
+
tsne_result = TSNE(n_components=2, learning_rate='auto',
|
135 |
+
init='random', perplexity=3).fit_transform(embeddings)
|
136 |
+
cache[hash_key] = tsne_result
|
137 |
+
with open(cache_path, 'wb') as f:
|
138 |
+
pkl.dump(cache, f)
|
139 |
+
return tsne_result
|
140 |
+
|
141 |
+
def load_interp_space(cfg):
|
142 |
+
interp_space_path = cfg['interp_space_path'] + 'interpretable_space.pkl'
|
143 |
+
interp_space_rep_path = cfg['interp_space_path'] + 'interpretable_space_representations.json'
|
144 |
+
gram2vec_feats_path = cfg['interp_space_path'] + '/../gram2vec_feats.csv'
|
145 |
+
clustered_authors_path = cfg['interp_space_path'] + 'train_authors.pkl'
|
146 |
+
|
147 |
+
# Load authors embeddings and their cluster labels
|
148 |
+
clustered_authors_df = pd.read_pickle(clustered_authors_path)
|
149 |
+
clustered_authors_df = clustered_authors_df[clustered_authors_df.cluster_label != -1]
|
150 |
+
author_embedding = clustered_authors_df.author_embedding.tolist()
|
151 |
+
author_labels = clustered_authors_df.cluster_label.tolist()
|
152 |
+
author_ids = clustered_authors_df.authorID.tolist()
|
153 |
+
|
154 |
+
# filter out gram2vec features that doesn't have representation
|
155 |
+
clustered_authors_df['gram2vec_feats'] = clustered_authors_df.gram2vec_feats.apply(lambda feats: [feat for feat in feats if get_shorthand(feat) is not None])
|
156 |
+
|
157 |
+
# Load a list of gram2vec features --> we use it to distinguish the cluster representations whether they come from gram2vec or llms
|
158 |
+
gram2vec_df = pd.read_csv(gram2vec_feats_path)
|
159 |
+
gram2vec_feats = gram2vec_df.gram2vec_feats.unique().tolist()
|
160 |
+
|
161 |
+
# Load interpretable space embeddings and the representation of each dimension
|
162 |
+
interpretable_space = pkl.load(open(interp_space_path, 'rb'))
|
163 |
+
del interpretable_space[-1] #DBSCAN generate a cluster -1 of all outliers. We don't want this cluster
|
164 |
+
dimension_to_latent = {key: interpretable_space[key][0] for key in interpretable_space}
|
165 |
+
|
166 |
+
interpretable_space_rep_df = pd.read_json(interp_space_rep_path)
|
167 |
+
#dimension_to_style = {x[0]: x[1] for x in zip(interpretable_space_rep_df.cluster_label.tolist(), interpretable_space_rep_df[style_feat_clm].tolist())}
|
168 |
+
dimension_to_style = {x[0]: [feat[0] for feat in sorted(x[1].items(), key=lambda feat_w:-feat_w[1])] for x in zip(interpretable_space_rep_df.cluster_label.tolist(), interpretable_space_rep_df[cfg['style_feat_clm']].tolist())}
|
169 |
+
|
170 |
+
if cfg['only_llm_feats']:
|
171 |
+
#print('only llm feats')
|
172 |
+
dimension_to_style = {dim[0]:[feat for feat in dim[1] if feat not in gram2vec_feats] for dim in dimension_to_style.items()}
|
173 |
+
|
174 |
+
if cfg['only_gram2vec_feats']:
|
175 |
+
#print('only gra2vec feats')
|
176 |
+
dimension_to_style = {dim[0]:[feat for feat in dim[1] if feat in gram2vec_feats] for dim in dimension_to_style.items()}
|
177 |
+
|
178 |
+
# Take top features from g2v and llm
|
179 |
+
def take_to_k_llm_and_g2v_feats(feats_list, top_k):
|
180 |
+
g2v_feats = [x for x in feats_list if x in gram2vec_feats][:top_k]
|
181 |
+
llm_feats = [x for x in feats_list if x not in gram2vec_feats][:top_k]
|
182 |
+
return g2v_feats + llm_feats
|
183 |
+
dimension_to_style = {dim[0]: take_to_k_llm_and_g2v_feats(dim[1], cfg['top_k']) for dim in dimension_to_style.items()}
|
184 |
+
|
185 |
+
|
186 |
+
return {
|
187 |
+
'dimension_to_latent': dimension_to_latent,
|
188 |
+
'dimension_to_style' : dimension_to_style,
|
189 |
+
'author_embedding' : author_embedding,
|
190 |
+
'author_labels' : author_labels,
|
191 |
+
'author_ids' : author_ids,
|
192 |
+
'clustered_authors_df' : clustered_authors_df
|
193 |
+
|
194 |
+
}
|
195 |
+
|
196 |
+
#function to handle zoom events
|
197 |
+
def handle_zoom(event_json, bg_proj, bg_lbls, clustered_authors_df, task_authors_df):
|
198 |
+
"""
|
199 |
+
event_json β stringified JSON from JS listener
|
200 |
+
bg_proj β (N,2) numpy array with 2D coordinates
|
201 |
+
bg_lbls β list of N author IDs
|
202 |
+
clustered_authors_df β pd.DataFrame containing authorID and final_attribute_name
|
203 |
+
"""
|
204 |
+
print("[INFO] Handling zoom event")
|
205 |
+
|
206 |
+
if not event_json:
|
207 |
+
return gr.update(value=""), gr.update(value=""), None, None, None
|
208 |
+
|
209 |
+
try:
|
210 |
+
ranges = json.loads(event_json)
|
211 |
+
(x_min, x_max) = ranges["xaxis"]
|
212 |
+
(y_min, y_max) = ranges["yaxis"]
|
213 |
+
except (json.JSONDecodeError, KeyError, ValueError):
|
214 |
+
return gr.update(value=""), gr.update(value=""), None, None, None
|
215 |
+
|
216 |
+
# Find points within the zoomed region
|
217 |
+
mask = (
|
218 |
+
(bg_proj[:, 0] >= x_min) & (bg_proj[:, 0] <= x_max) &
|
219 |
+
(bg_proj[:, 1] >= y_min) & (bg_proj[:, 1] <= y_max)
|
220 |
+
)
|
221 |
+
|
222 |
+
visible_authors = [lbl for lbl, keep in zip(bg_lbls, mask) if keep]
|
223 |
+
|
224 |
+
print(f"[INFO] Zoomed region includes {len(visible_authors)} authors:{visible_authors}")
|
225 |
+
|
226 |
+
# Example: Find features for clusters [2,3,4] that are NOT prominent in cluster [1]
|
227 |
+
# llm_feats = compute_clusters_style_representation(
|
228 |
+
# background_corpus_df=clustered_authors_df,
|
229 |
+
# cluster_ids=visible_authors,
|
230 |
+
# cluster_label_clm_name='authorID',
|
231 |
+
# other_cluster_ids=[],
|
232 |
+
# features_clm_name='final_attribute_name_manually_processed'
|
233 |
+
# )
|
234 |
+
print(f"Task authors: {len(task_authors_df)}, Clustered authors: {len(clustered_authors_df)}")
|
235 |
+
merged_authors_df = pd.concat([task_authors_df, clustered_authors_df])
|
236 |
+
print(f"Merged authors DataFrame:\n{len(merged_authors_df)}")
|
237 |
+
style_analysis_response = compute_clusters_style_representation_3(
|
238 |
+
background_corpus_df=merged_authors_df,
|
239 |
+
cluster_ids=visible_authors,
|
240 |
+
cluster_label_clm_name='authorID',
|
241 |
+
)
|
242 |
+
|
243 |
+
llm_feats = ['None'] + style_analysis_response['features']
|
244 |
+
|
245 |
+
|
246 |
+
merged_authors_df = pd.concat([task_authors_df, clustered_authors_df])
|
247 |
+
g2v_feats = compute_clusters_g2v_representation(
|
248 |
+
background_corpus_df=merged_authors_df,
|
249 |
+
author_ids=visible_authors,
|
250 |
+
other_author_ids=[],
|
251 |
+
features_clm_name='g2v_vector'
|
252 |
+
)
|
253 |
+
|
254 |
+
# Gram2vec features are already in shorthand. convert to human readable for display
|
255 |
+
HR_g2v_list = []
|
256 |
+
for feat in g2v_feats:
|
257 |
+
HR_g2v = get_fullform(feat)
|
258 |
+
print(f"\n\n feat: {feat} ---> Human Readable: {HR_g2v}")
|
259 |
+
if HR_g2v is None:
|
260 |
+
print(f"Skipping Gram2Vec feature without human readable form: {feat}")
|
261 |
+
else:
|
262 |
+
HR_g2v_list.append(HR_g2v)
|
263 |
+
|
264 |
+
HR_g2v_list = ["None"] + HR_g2v_list
|
265 |
+
|
266 |
+
print(f"[INFO] Found {len(llm_feats)} LLM features and {len(g2v_feats)} Gram2Vec features in the zoomed region.")
|
267 |
+
print(f"[INFO] unfiltered g2v features: {g2v_feats}")
|
268 |
+
|
269 |
+
print(f"[INFO] LLM features: {llm_feats}")
|
270 |
+
print(f"[INFO] Gram2Vec features: {HR_g2v_list}")
|
271 |
+
|
272 |
+
return (
|
273 |
+
gr.update(choices=llm_feats, value=llm_feats[0]),
|
274 |
+
gr.update(choices=HR_g2v_list, value=HR_g2v_list[0]),
|
275 |
+
style_analysis_response,
|
276 |
+
llm_feats,
|
277 |
+
visible_authors
|
278 |
+
)
|
279 |
+
# return gr.update(value="\n".join(llm_feats).join("\n").join(g2v_feats)), llm_feats, g2v_feats
|
280 |
+
|
281 |
+
def handle_zoom_with_retries(event_json, bg_proj, bg_lbls, clustered_authors_df, task_authors_df):
|
282 |
+
"""
|
283 |
+
event_json β stringified JSON from JS listener
|
284 |
+
bg_proj β (N,2) numpy array with 2D coordinates
|
285 |
+
bg_lbls β list of N author IDs
|
286 |
+
clustered_authors_df β pd.DataFrame containing authorID and final_attribute_name
|
287 |
+
task_authors_df β pd.DataFrame containing authorID and final_attribute_name
|
288 |
+
"""
|
289 |
+
print("[INFO] Handling zoom event with retries")
|
290 |
+
|
291 |
+
for attempt in range(3):
|
292 |
+
try:
|
293 |
+
return handle_zoom(event_json, bg_proj, bg_lbls, clustered_authors_df, task_authors_df)
|
294 |
+
except Exception as e:
|
295 |
+
print(f"[ERROR] Attempt {attempt + 1} failed: {e}")
|
296 |
+
if attempt < 2:
|
297 |
+
print("[INFO] Retrying...")
|
298 |
+
return (
|
299 |
+
None,
|
300 |
+
None,
|
301 |
+
None,
|
302 |
+
None,
|
303 |
+
None
|
304 |
+
)
|
305 |
+
|
306 |
+
|
307 |
+
def visualize_clusters_plotly(iid, cfg, instances, model_radio, custom_model_input, task_authors_df, background_authors_embeddings_df, pred_idx=None, gt_idx=None):
|
308 |
+
model_name = model_radio if model_radio != "Other" else custom_model_input
|
309 |
+
embedding_col_name = f'{model_name.split("/")[-1]}_style_embedding'
|
310 |
+
print(background_authors_embeddings_df.columns)
|
311 |
+
print("Generating cluster visualization")
|
312 |
+
iid = int(iid)
|
313 |
+
interp = load_interp_space(cfg)
|
314 |
+
# dim2lat = interp['dimension_to_latent']
|
315 |
+
style_names = interp['dimension_to_style']
|
316 |
+
# bg_emb = np.array(interp['author_embedding'])
|
317 |
+
# print(f"bg_emb shape: {bg_emb.shape}")
|
318 |
+
#replace with cached embedddings
|
319 |
+
bg_emb = np.array(background_authors_embeddings_df[embedding_col_name].tolist()) #placeholder for background embeddings
|
320 |
+
print(f"bg_emb shape: {bg_emb.shape}")
|
321 |
+
# print("interp.keys():", interp.keys())
|
322 |
+
#bg_lbls = interp['author_labels']
|
323 |
+
#bg_ids = interp['author_ids']
|
324 |
+
bg_ids = task_authors_df['authorID'].tolist() + background_authors_embeddings_df['authorID'].tolist()
|
325 |
+
# inst = instances[iid]
|
326 |
+
# print("inst.keys():", inst.keys())
|
327 |
+
# q_lat = np.array(inst['author_latents'][:1])
|
328 |
+
# print(f"q_lat shape: {q_lat.shape}")
|
329 |
+
# c_lat = np.array(inst['author_latents'][1:])
|
330 |
+
# print(f"c_lat shape: {c_lat.shape}")
|
331 |
+
# pred_idx = inst['latent_rank'][0]
|
332 |
+
# gt_idx = inst['gt_idx']
|
333 |
+
q_lat = np.array(task_authors_df[embedding_col_name].iloc[0]).reshape(1, -1) # Mystery author latent
|
334 |
+
print(f"q_lat shape: {q_lat.shape}")
|
335 |
+
c_lat = np.array(task_authors_df[embedding_col_name].iloc[1:].tolist()) # Candidate authors latents
|
336 |
+
print(f"c_lat shape: {c_lat.shape}")
|
337 |
+
|
338 |
+
# cent_emb = np.array([v for _,v in dim2lat.items()])
|
339 |
+
# cent_lbl = np.array([k for k,_ in dim2lat.items()])
|
340 |
+
|
341 |
+
# all_emb = np.vstack([q_lat, c_lat, bg_emb, cent_emb])
|
342 |
+
all_emb = np.vstack([q_lat, c_lat, bg_emb])
|
343 |
+
proj = compute_tsne_with_cache(all_emb)
|
344 |
+
|
345 |
+
# split
|
346 |
+
q_proj = proj[0]
|
347 |
+
c_proj = proj[1:4]
|
348 |
+
#bg_proj = proj[4:4+len(bg_lbls)]
|
349 |
+
bg_proj = proj
|
350 |
+
|
351 |
+
# cent_proj = proj[4+len(bg_lbls):]
|
352 |
+
|
353 |
+
|
354 |
+
# find nearest centroid
|
355 |
+
# dists = np.linalg.norm(cent_proj - q_proj, axis=1)
|
356 |
+
# idx = int(np.argmin(dists))
|
357 |
+
# cluster_label_query = cent_lbl[idx]
|
358 |
+
# features of the nearest centroid to display
|
359 |
+
# feature_list = style_names[cluster_label_query]
|
360 |
+
|
361 |
+
# cluster_labels_per_candidate = [
|
362 |
+
# cent_lbl[int(np.argmin(np.linalg.norm(cent_proj - c_proj[i], axis=1)))]
|
363 |
+
# for i in range(c_proj.shape[0])
|
364 |
+
# ]
|
365 |
+
|
366 |
+
# prepare colorscale
|
367 |
+
# n_cent = len(cent_lbl)
|
368 |
+
# cent_colors = sample_colorscale("algae", [i/(n_cent-1) for i in range(n_cent)])
|
369 |
+
# map each cluster label to its color
|
370 |
+
# color_map = { label: cent_colors[i] for i, label in enumerate(cent_lbl) }
|
371 |
+
|
372 |
+
# uncomment the following line to show background authors
|
373 |
+
## background author colors pulled from their cluster label
|
374 |
+
# bg_colors = [ color_map[label] for label in bg_lbls ]
|
375 |
+
|
376 |
+
# 2) build Plotly figure
|
377 |
+
fig = go.Figure()
|
378 |
+
|
379 |
+
fig.update_layout(
|
380 |
+
template='plotly_white',
|
381 |
+
margin=dict(l=40,r=40,t=60,b=40),
|
382 |
+
autosize=True,
|
383 |
+
hovermode='closest',
|
384 |
+
# Enable zoom events
|
385 |
+
dragmode='zoom'
|
386 |
+
)
|
387 |
+
|
388 |
+
# fig.update_layout(
|
389 |
+
# template='plotly_white',
|
390 |
+
# margin=dict(l=40,r=40,t=60,b=40),
|
391 |
+
# autosize=True,
|
392 |
+
# hovermode='closest')
|
393 |
+
|
394 |
+
|
395 |
+
# uncomment the following line to show background authors
|
396 |
+
## background authors (light grey dots)
|
397 |
+
fig.add_trace(go.Scattergl(
|
398 |
+
x=bg_proj[:,0], y=bg_proj[:,1],
|
399 |
+
mode='markers',
|
400 |
+
marker=dict(size=6, color="#d3d3d3"),# color=bg_colors
|
401 |
+
name='Background authors',
|
402 |
+
hoverinfo='skip'
|
403 |
+
))
|
404 |
+
|
405 |
+
# centroids (rainbow colors + hovertext of your top-k features)
|
406 |
+
# hover_texts = [
|
407 |
+
# f"Cluster {lbl}<br>" + "<br>".join(style_names[lbl])
|
408 |
+
# for lbl in cent_lbl
|
409 |
+
# ]
|
410 |
+
# fig.add_trace(go.Scattergl(
|
411 |
+
# x=cent_proj[:,0], y=cent_proj[:,1],
|
412 |
+
# mode='markers',
|
413 |
+
# marker=dict(symbol='triangle-up', size=10, color="#d3d3d3"),#color=cent_colors
|
414 |
+
# name='Cluster centroids',
|
415 |
+
# hovertext=hover_texts,
|
416 |
+
# hoverinfo='text'
|
417 |
+
# ))
|
418 |
+
|
419 |
+
# three candidates
|
420 |
+
marker_syms = ['diamond','pentagon','x']
|
421 |
+
for i in range(3):
|
422 |
+
# label = f"Candidate {i+1}" + (" (predicted)" if i==pred_idx else "")
|
423 |
+
base = f"Candidate {i+1}"
|
424 |
+
# pick the right suffix
|
425 |
+
if i == pred_idx and i == gt_idx:
|
426 |
+
suffix = " (Predicted & Ground Truth)"
|
427 |
+
elif i == pred_idx:
|
428 |
+
suffix = " (Predicted)"
|
429 |
+
elif i == gt_idx:
|
430 |
+
suffix = "(Ground Truth)"
|
431 |
+
else:
|
432 |
+
suffix = ""
|
433 |
+
|
434 |
+
label = base + suffix
|
435 |
+
fig.add_trace(go.Scattergl(
|
436 |
+
x=[c_proj[i,0]], y=[c_proj[i,1]],
|
437 |
+
mode='markers',
|
438 |
+
marker=dict(symbol=marker_syms[i], size=12, color='darkblue'),
|
439 |
+
name=label,
|
440 |
+
hoverinfo='skip'
|
441 |
+
))
|
442 |
+
|
443 |
+
# query author
|
444 |
+
fig.add_trace(go.Scattergl(
|
445 |
+
x=[q_proj[0]], y=[q_proj[1]],
|
446 |
+
mode='markers',
|
447 |
+
marker=dict(symbol='star', size=14, color='red'),
|
448 |
+
name='Mystery author',
|
449 |
+
hoverinfo='skip'
|
450 |
+
))
|
451 |
+
|
452 |
+
# ββ Arrowed annotations for mystery + candidates ββββββββββββββββββββββββββ
|
453 |
+
# Mystery author (red star)
|
454 |
+
fig.add_annotation(
|
455 |
+
x=q_proj[0], y=q_proj[1],
|
456 |
+
xref='x', yref='y',
|
457 |
+
text="Mystery",
|
458 |
+
showarrow=True,
|
459 |
+
arrowhead=2,
|
460 |
+
arrowsize=1,
|
461 |
+
arrowwidth=1.5,
|
462 |
+
ax=40, # tail offset in pixels: moves the label 40px to the right
|
463 |
+
ay=-40, # moves the label 40px up
|
464 |
+
font=dict(color='red', size=12)
|
465 |
+
)
|
466 |
+
|
467 |
+
# Candidate authors (dark blue β)
|
468 |
+
offsets = [(-40, -30), (40, -30), (0, 40)] # [(ax,ay) for Cand1, Cand2, Cand3]
|
469 |
+
for i in range(3):
|
470 |
+
# build the right label
|
471 |
+
if i == pred_idx and i == gt_idx:
|
472 |
+
label = f"Candidate {i+1} (Predicted & Ground Truth)"
|
473 |
+
elif i == pred_idx:
|
474 |
+
label = f"Candidate {i+1} (Predicted)"
|
475 |
+
elif i == gt_idx:
|
476 |
+
label = f"Candidate {i+1} (Ground Truth)"
|
477 |
+
else:
|
478 |
+
label = f"Candidate {i+1}"
|
479 |
+
|
480 |
+
fig.add_annotation(
|
481 |
+
x=c_proj[i,0], y=c_proj[i,1],
|
482 |
+
xref='x', yref='y',
|
483 |
+
text= label,
|
484 |
+
showarrow=True,
|
485 |
+
arrowhead=2,
|
486 |
+
arrowsize=1,
|
487 |
+
arrowwidth=1.5,
|
488 |
+
ax=offsets[i][0],
|
489 |
+
ay=offsets[i][1],
|
490 |
+
font=dict(color='darkblue', size=12)
|
491 |
+
)
|
492 |
+
|
493 |
+
print('Done processing....')
|
494 |
+
# Prepare outputs for the new clusterβdropdown UI
|
495 |
+
# all_clusters = sorted(style_names.keys())
|
496 |
+
# --- build display names for the dropdown ---
|
497 |
+
# sorted_labels = sorted([int(lbl) for lbl in cent_lbl])
|
498 |
+
# display_clusters = []
|
499 |
+
# for lbl in sorted_labels:
|
500 |
+
# name = f"Cluster {lbl}"
|
501 |
+
# if lbl == cluster_label_query:
|
502 |
+
# name += " (closest to mystery author)"
|
503 |
+
# matching_indices = [i + 1 for i, val in enumerate(cluster_labels_per_candidate) if int(val) == lbl]
|
504 |
+
# if matching_indices:
|
505 |
+
# if len(matching_indices) == 1:
|
506 |
+
# name += f" (closest to Candidate {matching_indices[0]} author)"
|
507 |
+
# else:
|
508 |
+
# candidate_str = ", ".join(f"Candidate {i}" for i in matching_indices)
|
509 |
+
# name += f" (closest to {candidate_str} authors)"
|
510 |
+
# display_clusters.append(name)
|
511 |
+
# print(f"All clusters: {all_clusters}")
|
512 |
+
# return: figure, dropdown payload, full style_map
|
513 |
+
return (
|
514 |
+
fig,
|
515 |
+
# update(choices=display_clusters, value=display_clusters[cluster_label_query]),
|
516 |
+
style_names,
|
517 |
+
bg_proj, # Return background points
|
518 |
+
bg_ids, # Return background labels
|
519 |
+
background_authors_embeddings_df, # Return the DataFrame for zoom handling
|
520 |
+
|
521 |
+
)
|
522 |
+
# return fig, update(choices=feature_list, value=feature_list[0]),feature_list
|
523 |
+
|
524 |
+
|
525 |
+
def extract_cluster_key(display_label: str) -> int:
|
526 |
+
"""
|
527 |
+
Given a dropdown label like
|
528 |
+
"Cluster 5 (closest to mystery author; closest to Candidate 1 author)"
|
529 |
+
returns the integer 5.
|
530 |
+
"""
|
531 |
+
m = re.match(r"Cluster\s+(\d+)", display_label)
|
532 |
+
if not m:
|
533 |
+
raise ValueError(f"Unrecognized cluster label: {display_label}")
|
534 |
+
return int(m.group(1))
|
535 |
+
|
536 |
+
|
537 |
+
|
538 |
+
# When a cluster is selected, split features and populate radio buttons
|
539 |
+
def on_cluster_change(selected_cluster, style_map):
|
540 |
+
cluster_key = extract_cluster_key(selected_cluster)
|
541 |
+
all_feats = style_map[cluster_key]
|
542 |
+
llm_feats, g2v_feats = split_features(all_feats)
|
543 |
+
# print(f"Selected cluster: {selected_cluster} ({cluster_key})")
|
544 |
+
# print(f"LLM features: {llm_feats}")
|
545 |
+
|
546 |
+
# Add "None" as a default selectable option
|
547 |
+
llm_feats = ["None"] + llm_feats
|
548 |
+
|
549 |
+
# filter out any g2v feature without a shorthand
|
550 |
+
filtered_g2v = []
|
551 |
+
for feat in g2v_feats:
|
552 |
+
if get_shorthand(feat) is None:
|
553 |
+
print(f"Skipping Gram2Vec feature without shorthand: {feat}")
|
554 |
+
else:
|
555 |
+
filtered_g2v.append(feat)
|
556 |
+
|
557 |
+
# Add "None" as a default selectable option
|
558 |
+
filtered_g2v = ["None"] + filtered_g2v
|
559 |
+
|
560 |
+
return (
|
561 |
+
gr.update(choices=llm_feats, value=llm_feats[0]),
|
562 |
+
gr.update(choices=filtered_g2v, value=filtered_g2v[0]),
|
563 |
+
llm_feats
|
564 |
+
)
|