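# Gradio Space comparing three table representations ("text_2d" layout-preserving
# text, "text_1d" flattened text, and raw "html") as model inputs on
# TableVQA-Bench. Assumes precomputed evaluation JSONs under eval_output/ and
# per-sample 2D text dumps under dataset_tablevqa_<split>_2d_text/.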
import json
import re
from pathlib import Path

import gradio as gr
import pandas as pd
from datasets import load_dataset
from matplotlib import pyplot as plt

HEAD_HTML = """
<link href='https://fonts.googleapis.com/css?family=PT+Mono' rel='stylesheet'>
"""


def normalize_spaces(text):
    # Collapse runs of two or more spaces into one, line by line.
    return '\n'.join(re.sub(r" {2,}", " ", line) for line in text.split('\n'))
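# Illustrative example (not from the dataset): normalize_spaces("a   b\nc  d")
# yields "a b\nc d": the aligned 2D layout degrades to plain 1D text while
# newlines are preserved.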


def load_json(file_path):
    with open(file_path, "r") as file:
        return json.load(file)


def on_select(evt: gr.SelectData, current_split):
    # evt.row_value holds the clicked dataframe row: [id, filename, question, answer].
    item_id = evt.row_value[0]
    filename = evt.row_value[1]
    output_methods = []
    for method in METHOD_LIST:
        output_methods.extend(
            [
                item_by_id_dict[current_split][filename][method],
                evaluation_dict[current_split][method][filename]["pred"],
                evaluation_dict[current_split][method][filename]["score"] == 1,
            ]
        )
    return output_methods + [
        item_by_id_dict[current_split][filename]["image"],
        input_dataframe[current_split]["questions"][item_id],
        input_dataframe[current_split]["answers"][item_id],
    ]
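# Return-order contract: three values per method (input text/html, prediction,
# correct-or-not) in METHOD_LIST order, followed by image, question, and ground
# truth. This must match output_elements + [demo_image, question, answer_gt]
# in the file_list.select(...) wiring below.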


def on_dataset_change(current_split):
    # Rebuild the average-score plot and the file-list table for the selected split.
    plot = generate_plot(
        providers=METHOD_LIST,
        scores=[method_scores[current_split][method] for method in METHOD_LIST],
    )
    dataframe = pd.DataFrame(input_dataframe[current_split])
    return plot, dataframe
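# The (plot, dataframe) pair lines up with outputs=[plot_avg, file_list] in the
# dataset_name.change(...) wiring below.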


def generate_plot(providers, scores):
    fig, ax = plt.subplots(figsize=(4, 3))
    # Reverse so the first method ends up at the top of the horizontal bar chart.
    bars = ax.barh(providers[::-1], scores[::-1])
    min_score = min(scores)
    max_score = max(scores)
    # Customize plot
    ax.set_title("Methods Average Scores")
    ax.set_ylabel("Methods")
    ax.set_xlabel("Scores")
    # Zoom the x-axis around the observed range; scores live in [0, 1].
    ax.set_xlim(min_score - 0.1, min(max_score + 0.1, 1.0))
    # Annotate each bar with its exact score.
    for bar in bars:
        width = bar.get_width()
        ax.text(
            width,
            bar.get_y() + bar.get_height() / 2.0,
            f"{width:.3f}",
            ha="left",
            va="center",
        )
    plt.tight_layout()
    return fig
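# Illustrative call (made-up numbers): generate_plot(["text_2d", "html"], [0.72, 0.65])
# returns a matplotlib Figure with one annotated horizontal bar per method.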


evaluation_json_dir = Path("eval_output")
dataset = load_dataset(path="terryoo/TableVQA-Bench")

SPLIT_NAMES = ["fintabnetqa", "vwtq_syn"]
DEFAULT_SPLIT_NAME = "fintabnetqa"
METHOD_LIST = ["text_2d", "text_1d", "html"]

# Per-split lookup tables, filled in by the loop below.
item_by_id_dict = {}
input_dataframe = {}
evaluation_dict = {}
method_scores = {}

for split_name in SPLIT_NAMES:
    input_text_path = Path(f"dataset_tablevqa_{split_name}_2d_text")
    item_by_id_dict[split_name] = {}
    input_dataframe[split_name] = {
        "ids": [],
        "filenames": [],
        "questions": [],
        "answers": [],
    }
    evaluation_dict[split_name] = {}
    method_scores[split_name] = {}
    for idx, sample in enumerate(dataset[split_name]):
        sample_id = sample["qa_id"]
        # The 2D layout-preserving text lives in one .txt file per sample.
        text_path = input_text_path / f"{sample_id}.txt"
        with open(text_path, "r") as f:
            text_2d = f.read()
        item_by_id_dict[split_name][sample_id] = {
            "text_2d": text_2d,
            "text_1d": normalize_spaces(text_2d),
            "image": sample["image"],
            "html": sample["text_html_table"],
        }
        input_dataframe[split_name]["ids"].append(idx)
        input_dataframe[split_name]["filenames"].append(sample_id)
        input_dataframe[split_name]["questions"].append(sample["question"])
        input_dataframe[split_name]["answers"].append(sample["gt"])
    for method in METHOD_LIST:
        evaluation_json_path = evaluation_json_dir / f"{split_name}_{method}.json"
        evaluation_data = load_json(evaluation_json_path)
        evaluation_dict[split_name][method] = {
            item["qa_id"]: {
                "pred": item["pred"],
                "score": item["scores"]["a"],
            }
            for item in evaluation_data["instances"]
        }
        # average_scores is reported as a percentage; rescale to [0, 1].
        method_scores[split_name][method] = round(
            evaluation_data["evaluation_meta"]["average_scores"][0] / 100,
            2,
        )
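# Shapes of the lookup tables built above:
#   item_by_id_dict[split][qa_id]          -> {"text_2d", "text_1d", "image", "html"}
#   input_dataframe[split]                 -> columns backing the file-list table
#   evaluation_dict[split][method][qa_id]  -> {"pred", "score"}
#   method_scores[split][method]           -> average score rescaled to [0, 1]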


with gr.Blocks(
    theme=gr.themes.Ocean(
        font_mono="PT Mono",
    ),
    head=HEAD_HTML,
) as demo:
    gr.Markdown(
        "# 2D Layout-Preserving Text Benchmark\n"
        "Dataset: [TableVQA-Bench](https://huggingface.co/datasets/terryoo/TableVQA-Bench)\n"
    )
    dataset_name = gr.Dropdown(
        label="Dataset split",
        value=DEFAULT_SPLIT_NAME,
        choices=SPLIT_NAMES,
    )
    gr.Markdown("### File List")
    plot_avg = gr.Plot(
        label="Average scores",
        value=generate_plot(
            providers=METHOD_LIST,
            scores=[method_scores[DEFAULT_SPLIT_NAME][method] for method in METHOD_LIST],
        ),
        container=False,
    )
    file_list = gr.Dataframe(
        value=pd.DataFrame(input_dataframe[DEFAULT_SPLIT_NAME]),
        max_height=300,
        show_row_numbers=False,
        show_search=True,
        column_widths=["10%", "30%", "30%", "30%"],
    )
    with gr.Row():
        with gr.Column():
            demo_image = gr.Image(
                label="Input Image",
                interactive=False,
                height=400,
                width=600,
            )
        with gr.Column():
            question = gr.Textbox(
                label="Question",
                interactive=False,
            )
            answer_gt = gr.Textbox(
                label="GT Answer",
                interactive=False,
            )
    # One tab per method: the model input plus its prediction and score.
    output_elements = []
    with gr.Tabs():
        for method in METHOD_LIST:
            with gr.Tab(method):
                if method == "html":
                    # Render the raw HTML table directly.
                    output = gr.HTML(
                        container=False,
                        show_label=False,
                    )
                else:
                    output = gr.Code(
                        container=False,
                        language="markdown",
                        show_line_numbers=False,
                    )
                pred = gr.Textbox(
                    label="Predicted Answer",
                    interactive=False,
                )
                score = gr.Textbox(
                    label="Score",
                    interactive=False,
                )
                output_elements.extend([output, pred, score])
    # Wire events: row selection populates the detail views; changing the split
    # refreshes the plot and the file list.
    file_list.select(
        fn=on_select,
        inputs=[dataset_name],
        outputs=output_elements + [demo_image, question, answer_gt],
    )
    dataset_name.change(
        fn=on_dataset_change,
        inputs=dataset_name,
        outputs=[plot_avg, file_list],
    )

demo.launch()