import hashlib
import html
import json
import os
import pickle as pkl
import re

import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
from gradio import update
from plotly.colors import sample_colorscale
from sklearn.manifold import TSNE

from utils.gram2vec_feat_utils import get_shorthand, get_fullform
from utils.interp_space_utils import compute_clusters_style_representation_3, compute_clusters_g2v_representation
from utils.llm_feat_utils import split_features

# Matches the leading "Cluster <n>" of a dropdown label.
_CLUSTER_LABEL_RE = re.compile(r"Cluster\s+(\d+)")

# How many times a zoom event is processed before giving up.
_MAX_ZOOM_ATTEMPTS = 3

# Marker symbols and annotation offsets (ax, ay) used for candidate authors.
# They are cycled if more candidates than entries are supplied.
_CANDIDATE_MARKERS = ('diamond', 'pentagon', 'x')
_CANDIDATE_LABEL_OFFSETS = ((-40, -30), (40, -30), (0, 40))


def clean_text(text: str) -> str:
    """
    Make ``text`` safe to embed inside HTML element content.

    ``&``, ``<`` and ``>`` are replaced by their HTML entities and every
    newline becomes a ``<br>`` tag.

    The function is not idempotent: running it on its own output escapes the
    inserted ``<br>`` tags, so call it exactly once per piece of text.
    """
    return html.escape(text, quote=False).replace('\n', '<br>')


def get_instances(instances_to_explain_path: str = 'datasets/instances_to_explain.json'):
    """
    Loads the JSON and returns:
      - instances_to_explain: the raw dict/list of instances
      - instance_ids: list of keys (if dict) or indices (if list)
    """
    # Context manager so the file handle is closed even if parsing fails.
    with open(instances_to_explain_path, encoding='utf-8') as f:
        instances_to_explain = json.load(f)

    if isinstance(instances_to_explain, dict):
        instance_ids = list(instances_to_explain.keys())
    else:
        instance_ids = list(range(len(instances_to_explain)))
    return instances_to_explain, instance_ids


def _resolve_instance(instance_id, instances_to_explain):
    """Return the instance stored under ``instance_id`` in a dict or list."""
    if not isinstance(instances_to_explain, dict):
        return instances_to_explain[int(instance_id)]

    # A dict loaded from JSON has string keys, while the UI may hand back an
    # int (or vice versa), so try the id as given, then as str, then as int.
    keys_to_try = [instance_id, str(instance_id)]
    try:
        keys_to_try.append(int(instance_id))
    except (TypeError, ValueError):
        pass
    for key in keys_to_try:
        if key in instances_to_explain:
            return instances_to_explain[key]
    raise KeyError(instance_id)


def load_instance(instance_id, instances_to_explain: dict):
    """
    Build the HTML snippets for one instance.

    Args:
        instance_id: Key (dict) or index (list) of the instance to show.
        instances_to_explain: Data as returned by ``get_instances``.

    Returns:
        (header_html, mystery_html, c0_html, c1_html, c2_html)

    Raises:
        KeyError / IndexError: if ``instance_id`` does not identify an instance.
    """
    data = _resolve_instance(instance_id, instances_to_explain)

    predicted_author = data['latent_rank'][0]
    ground_truth_author = data['gt_idx']

    # NOTE(review): the wrapper tags/styles in the three templates below are a
    # minimal reconstruction -- confirm them against the intended page design.
    header_html = """
    <div style="border:1px solid #ccc; padding:10px; margin-bottom:10px;">
      <h3>Here’s the mystery passage alongside three candidate texts—look for the green highlight to see the predicted author.</h3>
    </div>
    """

    # Escape exactly once; a second pass would escape the inserted <br> tags.
    mystery_text = clean_text(data['Q_fullText'])
    mystery_html = f"""
    <div style="border:1px solid #ccc; padding:10px; margin-bottom:10px;">
      <h3>Mystery Author</h3>
      <p>{mystery_text}</p>
    </div>
    """

    # Candidate boxes
    candidate_htmls = []
    for i in range(3):
        text = data[f'a{i}_fullText']
        title = f"Candidate {i+1}"
        extra_style = ""

        is_predicted = predicted_author == i
        is_true = ground_truth_author == i
        if is_predicted:
            title += " (Predicted and True Author)" if is_true else " (Predicted Author)"
            extra_style = (
                "border:2px solid #228B22; "      # dark green border
                "background-color: #e6ffe6; "     # light green fill
                "padding:10px; "
            )
        elif is_true:
            # highlight the true author only if it differs from the predicted one
            title += " (True Author)"
            extra_style = (
                "border: 2px solid #ff5722; "
                "background: #fff3e0; "
                "padding:10px; "
            )

        # extra_style comes last so it overrides the default border.
        candidate_htmls.append(f"""
        <div style="border:1px solid #ccc; padding:10px; margin-bottom:10px; {extra_style}">
          <h4>{title}</h4>
          <p>{clean_text(text)}</p>
        </div>
        """)

    return header_html, mystery_html, candidate_htmls[0], candidate_htmls[1], candidate_htmls[2]


def compute_tsne_with_cache(embeddings: np.ndarray, cache_path: str = 'datasets/tsne_cache.pkl') -> np.ndarray:
    """
    Compute t-SNE with caching to avoid recomputation for the same input.

    Args:
        embeddings (np.ndarray): The input embeddings to compute t-SNE on.
        cache_path (str): Path to the cache file.

    Returns:
        np.ndarray: The t-SNE transformed embeddings.
    """
    embeddings = np.asarray(embeddings)

    # Key on shape and dtype as well as the raw bytes: two arrays with the
    # same buffer but different layouts must not share a cache entry.
    digest = hashlib.md5()
    digest.update(repr((embeddings.shape, embeddings.dtype.str)).encode('utf-8'))
    digest.update(embeddings.tobytes())
    hash_key = digest.hexdigest()

    cache = {}
    if os.path.exists(cache_path):
        # NOTE: pickle executes arbitrary code on load; the cache file must
        # only ever be one this application wrote itself.
        try:
            with open(cache_path, 'rb') as f:
                cache = pkl.load(f)
        except (EOFError, pkl.UnpicklingError) as err:
            # A truncated/corrupt cache is not fatal: recompute and rewrite it.
            print(f"[WARN] Ignoring unreadable t-SNE cache {cache_path}: {err}")
            cache = {}

    if hash_key in cache:
        return cache[hash_key]

    print("Computing t-SNE")
    tsne_result = TSNE(n_components=2, learning_rate='auto', init='random', perplexity=3).fit_transform(embeddings)
    cache[hash_key] = tsne_result
    with open(cache_path, 'wb') as f:
        pkl.dump(cache, f)
    return tsne_result


def _top_k_per_source(feats_list, g2v_feat_set, top_k):
    """Keep the first ``top_k`` gram2vec features followed by the first ``top_k`` other (LLM) features."""
    g2v_feats = [feat for feat in feats_list if feat in g2v_feat_set][:top_k]
    llm_feats = [feat for feat in feats_list if feat not in g2v_feat_set][:top_k]
    return g2v_feats + llm_feats


def load_interp_space(cfg):
    """
    Load the interpretable space and the clustered background authors.

    Args:
        cfg: Mapping read for the keys ``interp_space_path``,
            ``style_feat_clm``, ``only_llm_feats``, ``only_gram2vec_feats``
            and ``top_k``.

    Returns:
        dict with the keys ``dimension_to_latent``, ``dimension_to_style``,
        ``author_embedding``, ``author_labels``, ``author_ids`` and
        ``clustered_authors_df``.
    """
    base_path = cfg['interp_space_path']
    interp_space_path = os.path.join(base_path, 'interpretable_space.pkl')
    interp_space_rep_path = os.path.join(base_path, 'interpretable_space_representations.json')
    gram2vec_feats_path = os.path.join(base_path, '..', 'gram2vec_feats.csv')
    clustered_authors_path = os.path.join(base_path, 'train_authors.pkl')

    # Load authors embeddings and their cluster labels
    clustered_authors_df = pd.read_pickle(clustered_authors_path)
    # .copy() so the column assignment below targets a real frame, not a slice.
    clustered_authors_df = clustered_authors_df[clustered_authors_df.cluster_label != -1].copy()
    author_embedding = clustered_authors_df.author_embedding.tolist()
    author_labels = clustered_authors_df.cluster_label.tolist()
    author_ids = clustered_authors_df.authorID.tolist()

    # filter out gram2vec features that don't have a representation
    clustered_authors_df['gram2vec_feats'] = clustered_authors_df.gram2vec_feats.apply(
        lambda feats: [feat for feat in feats if get_shorthand(feat) is not None]
    )

    # Load the list of gram2vec features --> used to tell whether a cluster
    # representation comes from gram2vec or from the LLMs. A set keeps the
    # membership tests below cheap.
    gram2vec_df = pd.read_csv(gram2vec_feats_path)
    gram2vec_feats = set(gram2vec_df.gram2vec_feats.unique().tolist())

    # Load interpretable space embeddings and the representation of each dimension
    with open(interp_space_path, 'rb') as f:
        interpretable_space = pkl.load(f)
    # DBSCAN generates a cluster -1 of all outliers. We don't want this
    # cluster; pop() tolerates a space that has no outlier cluster at all.
    interpretable_space.pop(-1, None)

    dimension_to_latent = {key: interpretable_space[key][0] for key in interpretable_space}

    interpretable_space_rep_df = pd.read_json(interp_space_rep_path)
    cluster_labels = interpretable_space_rep_df.cluster_label.tolist()
    feature_weights = interpretable_space_rep_df[cfg['style_feat_clm']].tolist()
    # Feature names per cluster, ordered by descending weight.
    dimension_to_style = {
        label: [name for name, _ in sorted(weights.items(), key=lambda feat_w: -feat_w[1])]
        for label, weights in zip(cluster_labels, feature_weights)
    }

    if cfg['only_llm_feats']:
        dimension_to_style = {
            label: [feat for feat in feats if feat not in gram2vec_feats]
            for label, feats in dimension_to_style.items()
        }

    if cfg['only_gram2vec_feats']:
        dimension_to_style = {
            label: [feat for feat in feats if feat in gram2vec_feats]
            for label, feats in dimension_to_style.items()
        }

    # Take top features from g2v and llm
    dimension_to_style = {
        label: _top_k_per_source(feats, gram2vec_feats, cfg['top_k'])
        for label, feats in dimension_to_style.items()
    }

    return {
        'dimension_to_latent': dimension_to_latent,
        'dimension_to_style': dimension_to_style,
        'author_embedding': author_embedding,
        'author_labels': author_labels,
        'author_ids': author_ids,
        'clustered_authors_df': clustered_authors_df,
    }


def _blank_zoom_result():
    """Outputs used when a zoom event carries no usable axis ranges."""
    return gr.update(value=""), gr.update(value=""), None, None, None


# function to handle zoom events
def handle_zoom(event_json, bg_proj, bg_lbls, clustered_authors_df, task_authors_df):
    """
    event_json            – stringified JSON from JS listener
    bg_proj               – (N,2) numpy array with 2D coordinates
    bg_lbls               – list of N author IDs
    clustered_authors_df  – pd.DataFrame containing authorID and final_attribute_name
    task_authors_df       – pd.DataFrame of the task authors; concatenated with
                            clustered_authors_df before the feature analysis

    Returns a 5-tuple: (LLM-feature dropdown update, Gram2Vec-feature dropdown
    update, raw style analysis response, LLM feature list, visible author IDs).
    """
    print("[INFO] Handling zoom event")

    if not event_json:
        return _blank_zoom_result()

    try:
        ranges = json.loads(event_json)
        # sorted() also copes with reversed axes, whose range arrives as (hi, lo).
        x_min, x_max = sorted(ranges["xaxis"])
        y_min, y_max = sorted(ranges["yaxis"])
    except (json.JSONDecodeError, KeyError, ValueError, TypeError):
        # TypeError: valid JSON that is not a mapping of 2-element ranges.
        return _blank_zoom_result()

    # Find points within the zoomed region
    bg_proj = np.asarray(bg_proj)
    mask = (
        (bg_proj[:, 0] >= x_min) & (bg_proj[:, 0] <= x_max) &
        (bg_proj[:, 1] >= y_min) & (bg_proj[:, 1] <= y_max)
    )

    visible_authors = [lbl for lbl, keep in zip(bg_lbls, mask) if keep]
    print(f"[INFO] Zoomed region includes {len(visible_authors)} authors:{visible_authors}")

    print(f"Task authors: {len(task_authors_df)}, Clustered authors: {len(clustered_authors_df)}")
    merged_authors_df = pd.concat([task_authors_df, clustered_authors_df])
    print(f"Merged authors DataFrame:\n{len(merged_authors_df)}")
    style_analysis_response = compute_clusters_style_representation_3(
        background_corpus_df=merged_authors_df,
        cluster_ids=visible_authors,
        cluster_label_clm_name='authorID',
    )
    llm_feats = ['None'] + style_analysis_response['features']

    # NOTE(review): the frame is rebuilt rather than reused, presumably so the
    # g2v step is unaffected by anything the style analysis did to its input.
    merged_authors_df = pd.concat([task_authors_df, clustered_authors_df])
    g2v_feats = compute_clusters_g2v_representation(
        background_corpus_df=merged_authors_df,
        author_ids=visible_authors,
        other_author_ids=[],
        features_clm_name='g2v_vector'
    )

    # Gram2vec features are already in shorthand. Convert to human readable for display.
    HR_g2v_list = []
    for feat in g2v_feats:
        HR_g2v = get_fullform(feat)
        print(f"\n\n feat: {feat} ---> Human Readable: {HR_g2v}")
        if HR_g2v is None:
            print(f"Skipping Gram2Vec feature without human readable form: {feat}")
        else:
            HR_g2v_list.append(HR_g2v)

    # "None" is the default selectable option.
    HR_g2v_list = ["None"] + HR_g2v_list

    print(f"[INFO] Found {len(llm_feats)} LLM features and {len(g2v_feats)} Gram2Vec features in the zoomed region.")
    print(f"[INFO] unfiltered g2v features: {g2v_feats}")
    print(f"[INFO] LLM features: {llm_feats}")
    print(f"[INFO] Gram2Vec features: {HR_g2v_list}")

    return (
        gr.update(choices=llm_feats, value=llm_feats[0]),
        gr.update(choices=HR_g2v_list, value=HR_g2v_list[0]),
        style_analysis_response,
        llm_feats,
        visible_authors
    )


def handle_zoom_with_retries(event_json, bg_proj, bg_lbls, clustered_authors_df, task_authors_df):
    """
    Run ``handle_zoom``, retrying when it raises.

    event_json            – stringified JSON from JS listener
    bg_proj               – (N,2) numpy array with 2D coordinates
    bg_lbls               – list of N author IDs
    clustered_authors_df  – pd.DataFrame containing authorID and final_attribute_name
    task_authors_df       – pd.DataFrame containing authorID and final_attribute_name

    Returns the result of the first successful ``handle_zoom`` call, or five
    ``None`` values when every attempt failed.
    """
    print("[INFO] Handling zoom event with retries")

    for attempt in range(_MAX_ZOOM_ATTEMPTS):
        try:
            return handle_zoom(event_json, bg_proj, bg_lbls, clustered_authors_df, task_authors_df)
        except Exception as e:  # UI boundary: log and retry instead of crashing the callback
            print(f"[ERROR] Attempt {attempt + 1} failed: {e}")
            if attempt < _MAX_ZOOM_ATTEMPTS - 1:
                print("[INFO] Retrying...")

    return (
        None,
        None,
        None,
        None,
        None
    )


def _candidate_label(i, pred_idx, gt_idx):
    """Display name for candidate ``i`` (0-based), tagged as predicted and/or ground truth."""
    label = f"Candidate {i+1}"
    if i == pred_idx and i == gt_idx:
        return label + " (Predicted & Ground Truth)"
    if i == pred_idx:
        return label + " (Predicted)"
    if i == gt_idx:
        return label + " (Ground Truth)"
    return label


def visualize_clusters_plotly(iid, cfg, instances, model_radio, custom_model_input, task_authors_df, background_authors_embeddings_df, pred_idx=None, gt_idx=None):
    """
    Project the task and background authors to 2D with t-SNE and plot them.

    Args:
        iid: Instance id; only validated to be int-convertible here.
        cfg: Configuration passed on to ``load_interp_space``.
        instances: Not used by this function; kept for interface compatibility.
        model_radio: Selected model name, or "Other" to use ``custom_model_input``.
        custom_model_input: Model name used when ``model_radio`` is "Other".
        task_authors_df: Frame whose first row is the mystery author and whose
            remaining rows are the candidates; read for ``authorID`` and the
            ``<model>_style_embedding`` column.
        background_authors_embeddings_df: Background authors; read for the
            same two columns.
        pred_idx: 0-based index of the predicted candidate, if known.
        gt_idx: 0-based index of the ground-truth candidate, if known.

    Returns:
        (fig, style_names, bg_proj, bg_ids, background_authors_embeddings_df)
        where ``bg_proj`` holds the 2D coordinates of ALL points (task authors
        first, then background authors) and ``bg_ids`` the matching author IDs.

    Raises:
        ValueError: if no model name was supplied.
    """
    model_name = model_radio if model_radio != "Other" else custom_model_input
    if not model_name:
        raise ValueError("No model name given: select a model or fill in the custom model field")
    embedding_col_name = f'{model_name.split("/")[-1]}_style_embedding'
    print(background_authors_embeddings_df.columns)
    print("Generating cluster visualization")
    iid = int(iid)  # kept so a non-numeric id still fails early, as before

    interp = load_interp_space(cfg)
    style_names = interp['dimension_to_style']

    bg_emb = np.array(background_authors_embeddings_df[embedding_col_name].tolist())
    print(f"bg_emb shape: {bg_emb.shape}")

    # Same ordering as the rows stacked into all_emb below.
    bg_ids = task_authors_df['authorID'].tolist() + background_authors_embeddings_df['authorID'].tolist()

    q_lat = np.array(task_authors_df[embedding_col_name].iloc[0]).reshape(1, -1)  # Mystery author latent
    print(f"q_lat shape: {q_lat.shape}")
    c_lat = np.array(task_authors_df[embedding_col_name].iloc[1:].tolist())  # Candidate authors latents
    print(f"c_lat shape: {c_lat.shape}")
    n_candidates = c_lat.shape[0]

    all_emb = np.vstack([q_lat, c_lat, bg_emb])
    proj = compute_tsne_with_cache(all_emb)

    # split
    q_proj = proj[0]
    c_proj = proj[1:1 + n_candidates]
    # All points (task authors included) are exposed to the zoom handler, so
    # bg_proj lines up with bg_ids.
    bg_proj = proj

    # NOTE: cluster-centroid rendering (nearest-centroid lookup, per-cluster
    # colours, centroid hover text) is currently disabled.

    # 2) build Plotly figure
    fig = go.Figure()
    fig.update_layout(
        template='plotly_white',
        margin=dict(l=40, r=40, t=60, b=40),
        autosize=True,
        hovermode='closest',
        # Enable zoom events
        dragmode='zoom'
    )

    # background authors (light grey dots)
    fig.add_trace(go.Scattergl(
        x=bg_proj[:, 0], y=bg_proj[:, 1],
        mode='markers',
        marker=dict(size=6, color="#d3d3d3"),
        name='Background authors',
        hoverinfo='skip'
    ))

    # candidates
    for i in range(n_candidates):
        fig.add_trace(go.Scattergl(
            x=[c_proj[i, 0]], y=[c_proj[i, 1]],
            mode='markers',
            marker=dict(symbol=_CANDIDATE_MARKERS[i % len(_CANDIDATE_MARKERS)], size=12, color='darkblue'),
            name=_candidate_label(i, pred_idx, gt_idx),
            hoverinfo='skip'
        ))

    # query author
    fig.add_trace(go.Scattergl(
        x=[q_proj[0]], y=[q_proj[1]],
        mode='markers',
        marker=dict(symbol='star', size=14, color='red'),
        name='Mystery author',
        hoverinfo='skip'
    ))

    # ── Arrowed annotations for mystery + candidates ──────────────────────────
    # Mystery author (red star)
    fig.add_annotation(
        x=q_proj[0], y=q_proj[1],
        xref='x', yref='y',
        text="Mystery",
        showarrow=True,
        arrowhead=2,
        arrowsize=1,
        arrowwidth=1.5,
        ax=40,    # tail offset in pixels: moves the label 40px to the right
        ay=-40,   # moves the label 40px up
        font=dict(color='red', size=12)
    )

    # Candidate authors (dark blue markers)
    for i in range(n_candidates):
        offset_x, offset_y = _CANDIDATE_LABEL_OFFSETS[i % len(_CANDIDATE_LABEL_OFFSETS)]
        fig.add_annotation(
            x=c_proj[i, 0], y=c_proj[i, 1],
            xref='x', yref='y',
            text=_candidate_label(i, pred_idx, gt_idx),
            showarrow=True,
            arrowhead=2,
            arrowsize=1,
            arrowwidth=1.5,
            ax=offset_x,
            ay=offset_y,
            font=dict(color='darkblue', size=12)
        )

    print('Done processing....')

    return (
        fig,
        style_names,
        bg_proj,                            # Return background points
        bg_ids,                             # Return background labels
        background_authors_embeddings_df,   # Return the DataFrame for zoom handling
    )


def extract_cluster_key(display_label: str) -> int:
    """
    Given a dropdown label like
    "Cluster 5 (closest to mystery author; closest to Candidate 1 author)"
    returns the integer 5.

    Raises:
        ValueError: if the label does not start with "Cluster <number>".
    """
    m = _CLUSTER_LABEL_RE.match(display_label) if isinstance(display_label, str) else None
    if not m:
        raise ValueError(f"Unrecognized cluster label: {display_label}")
    return int(m.group(1))


# When a cluster is selected, split features and populate radio buttons
def on_cluster_change(selected_cluster, style_map):
    """
    Build the feature choices for the selected cluster.

    Args:
        selected_cluster: Dropdown label accepted by ``extract_cluster_key``.
        style_map: Mapping from cluster key to that cluster's feature list.

    Returns:
        (LLM-feature update, Gram2Vec-feature update, LLM feature list); both
        choice lists start with the default option "None".
    """
    cluster_key = extract_cluster_key(selected_cluster)
    # Keys may have been stringified on the way here (e.g. by JSON), so fall
    # back to the string form before giving up with a KeyError.
    if cluster_key in style_map:
        all_feats = style_map[cluster_key]
    else:
        all_feats = style_map[str(cluster_key)]
    llm_feats, g2v_feats = split_features(all_feats)

    # Add "None" as a default selectable option
    llm_feats = ["None"] + llm_feats

    # filter out any g2v feature without a shorthand
    filtered_g2v = ["None"]
    for feat in g2v_feats:
        if get_shorthand(feat) is None:
            print(f"Skipping Gram2Vec feature without shorthand: {feat}")
            continue
        filtered_g2v.append(feat)

    return (
        gr.update(choices=llm_feats, value=llm_feats[0]),
        gr.update(choices=filtered_g2v, value=filtered_g2v[0]),
        llm_feats
    )