import json
import re
from pathlib import Path

import gradio as gr
import pandas as pd
from datasets import load_dataset
from matplotlib import pyplot as plt

HEAD_HTML = """
<link href='https://fonts.googleapis.com/css?family=PT+Mono' rel='stylesheet'>
"""
def normalize_spaces(text):
    """Collapse runs of 2+ spaces per line, e.g. "a    b" -> "a b"."""
    return '\n'.join(re.sub(r" {2,}", " ", line) for line in text.split('\n'))


def load_json(file_path):
    """Read a JSON file and return the parsed object."""
    with open(file_path, "r") as file:
        return json.load(file)

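# ---- Gradio event handlers ----
# on_select fires when a row of the file list is clicked; on_dataset_change
# fires when the split dropdown changes (both wired up at the bottom of the file).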
def on_select(evt: gr.SelectData, current_split):
    """Populate the detail view for the selected row of the file list."""
    item_id = evt.row_value[0]   # dataframe column 0: running index
    filename = evt.row_value[1]  # dataframe column 1: qa_id
    output_methods = []
    for method in METHOD_LIST:
        output_methods.extend(
            [
                item_by_id_dict[current_split][filename][method],
                evaluation_dict[current_split][method][filename]["pred"],
                evaluation_dict[current_split][method][filename]["score"] == 1,
            ]
        )
    # Return order must match the `outputs` list registered in
    # file_list.select(...): per-method (text, pred, score) triples,
    # then image, question, and ground-truth answer.
    return output_methods + [
        item_by_id_dict[current_split][filename]["image"],
        input_dataframe[current_split]["questions"][item_id],
        input_dataframe[current_split]["answers"][item_id],
    ]

def on_dataset_change(current_split):
    """Rebuild the score plot and file list when the split dropdown changes."""
    plot = generate_plot(
        providers=METHOD_LIST,
        scores=[method_scores[current_split][method] for method in METHOD_LIST],
    )
    dataframe = pd.DataFrame(input_dataframe[current_split])
    return plot, dataframe

def generate_plot(providers, scores):
    """Horizontal bar chart of per-method average scores."""
    fig, ax = plt.subplots(figsize=(4, 3))
    bars = ax.barh(providers[::-1], scores[::-1])
    min_score = min(scores)
    max_score = max(scores)
    # Customize plot
    ax.set_title("Methods Average Scores")
    ax.set_ylabel("Methods")
    ax.set_xlabel("Scores")
    ax.set_xlim(min_score - 0.1, min(max_score + 0.1, 1.0))
    # Annotate each bar with its exact score
    for bar in bars:
        width = bar.get_width()
        ax.text(
            width,
            bar.get_y() + bar.get_height() / 2.0,
            f"{width:.3f}",
            ha="left",
            va="center",
        )
    plt.tight_layout()
    return fig

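# ---- Data loading ----
# Expected on-disk layout (relative to this script):
#   eval_output/{split}_{method}.json              evaluation results per split/method
#   dataset_tablevqa_{split}_2d_text/{qa_id}.txt   pre-rendered 2D text per sample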
evaluation_json_dir = Path("eval_output")
dataset = load_dataset(path="terryoo/TableVQA-Bench")
SPLIT_NAMES = ["fintabnetqa", "vwtq_syn"]
DEFAULT_SPLIT_NAME = "fintabnetqa"
METHOD_LIST = ["text_2d", "text_1d", "html"]
item_by_id_dict = {}
input_dataframe = {}
evaluation_dict = {}
method_scores = {}
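# Lookup structures, all keyed by split name:
#   item_by_id_dict[split][qa_id]         -> per-method input texts + table image
#   input_dataframe[split]                -> columns backing the file-list table
#   evaluation_dict[split][method][qa_id] -> prediction and score per sample
#   method_scores[split][method]          -> average score, normalized to [0, 1]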
for split_name in SPLIT_NAMES:
    input_text_path = Path(f"dataset_tablevqa_{split_name}_2d_text")
    item_by_id_dict[split_name] = {}
    input_dataframe[split_name] = {
        "ids": [],
        "filenames": [],
        "questions": [],
        "answers": [],
    }
    evaluation_dict[split_name] = {}
    method_scores[split_name] = {}
    # Pair each dataset sample with its pre-rendered 2D text file
    for idx, sample in enumerate(dataset[split_name]):
        sample_id = sample["qa_id"]
        text_path = input_text_path / f"{sample_id}.txt"
        with open(text_path, "r") as f:
            text_2d = f.read()
        item_by_id_dict[split_name][sample_id] = {
            "text_2d": text_2d,
            "text_1d": normalize_spaces(text_2d),
            "image": sample["image"],
            "html": sample["text_html_table"],
        }
        input_dataframe[split_name]["ids"].append(idx)
        input_dataframe[split_name]["filenames"].append(sample_id)
        input_dataframe[split_name]["questions"].append(sample["question"])
        input_dataframe[split_name]["answers"].append(sample["gt"])
    # Load per-method predictions and scores
    for method in METHOD_LIST:
        evaluation_json_path = evaluation_json_dir / f"{split_name}_{method}.json"
        evaluation_data = load_json(evaluation_json_path)
        evaluation_dict[split_name][method] = {
            item["qa_id"]: {
                "pred": item["pred"],
                "score": item["scores"]["a"],
            }
            for item in evaluation_data["instances"]
        }
        # average_scores is reported in percent; normalize to [0, 1]
        method_scores[split_name][method] = round(
            evaluation_data["evaluation_meta"]["average_scores"][0] / 100,
            2,
        )

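# ---- UI layout ----
# Top: split selector, average-score plot, and searchable file list.
# Bottom: selected sample's image, question, GT answer, and per-method outputs.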
with gr.Blocks(
    theme=gr.themes.Ocean(
        font_mono="PT Mono",
    ),
    head=HEAD_HTML,
) as demo:
    gr.Markdown(
        "# 2D Layout-Preserving Text Benchmark\n"
        "Dataset: [TableVQA-Bench](https://huggingface.co/datasets/terryoo/TableVQA-Bench)\n"
    )
    dataset_name = gr.Dropdown(
        label="Dataset split",
        value=DEFAULT_SPLIT_NAME,
        choices=SPLIT_NAMES,
    )
gr.Markdown("### File List")
plot_avg = gr.Plot(
label="Average scores",
value=generate_plot(
providers=METHOD_LIST,
scores=[
method_scores[DEFAULT_SPLIT_NAME][method]
for method in METHOD_LIST
],
),
container=False,
)
file_list = gr.Dataframe(
value=pd.DataFrame(input_dataframe[DEFAULT_SPLIT_NAME]),
max_height=300,
show_row_numbers=False,
show_search=True,
column_widths=["10%", "30%", "30%", "30%"],
)
    with gr.Row():
        with gr.Column():
            demo_image = gr.Image(
                label="Input Image",
                interactive=False,
                height=400,
                width=600,
            )
        with gr.Column():
            question = gr.Textbox(
                label="Question",
                interactive=False,
            )
            answer_gt = gr.Textbox(
                label="GT Answer",
                interactive=False,
            )
    output_elements = []
    with gr.Tabs():
        # One tab per method: rendered HTML for the raw table, a code view
        # for the 2D/1D text variants, plus prediction and score boxes.
        for method in METHOD_LIST:
            with gr.Tab(method):
                if "html" in method:
                    output = gr.HTML(
                        container=False,
                        show_label=False,
                    )
                else:
                    output = gr.Code(
                        container=False,
                        language="markdown",
                        show_line_numbers=False,
                    )
                pred = gr.Textbox(
                    label="Predicted Answer",
                    interactive=False,
                )
                score = gr.Textbox(
                    label="Score",
                    interactive=False,
                )
                output_elements.extend([output, pred, score])
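    # ---- Event wiring ----
    # Outputs must line up with on_select's return order: per-method
    # (text, pred, score) triples first, then image, question, GT answer.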
    file_list.select(
        fn=on_select,
        inputs=[dataset_name],
        outputs=output_elements + [demo_image, question, answer_gt],
    )
    dataset_name.change(
        fn=on_dataset_change,
        inputs=dataset_name,
        outputs=[plot_avg, file_list],
    )

demo.launch()