import random  # used below to simulate download counts

import gradio as gr
import pandas as pd
import numpy as np
# Raw benchmark results table (parsed below)
data_str = """
Dataset Size GPT-4o GPT-4o-mini Gemini-2.0-Flash Qwen2-VL Qwen2.5-VL AIN Tesseract EasyOCR Paddle Surya Microsoft Qari Gemma3 ArabicNougat
Metrics CHrF CER WER CHrF CER WER CHrF CER WER CHrF CER WER CHrF CER WER CHrF CER WER CHrF CER WER CHrF CER WER ChrF CER WER ChrF CER WER ChrF CER WER ChrF CER WER ChrF CER WER ChrF CER WER
PATS 500 88.82 0.23 0.30 64.51 0.53 0.71 98.90 0.01 0.02 63.35 1.02 1.02 83.27 0.26 0.36 99.76 0.00 0.00 79.76 0.14 0.28 77.10 0.54 0.73 20.34 0.77 1.00 13.09 4.66 4.67 95.99 0.03 0.10 75.62 0.98 1.03 22.36 1.34 1.61 60.79 1.51 1.60
SythenAR 500 86.27 0.09 0.20 74.82 0.14 0.32 87.73 0.07 0.17 34.19 0.59 1.13 76.15 0.21 0.40 90.65 0.04 0.16 58.06 0.31 0.72 64.96 0.45 0.76 19.16 0.80 1.01 16.19 4.82 7.90 85.80 0.10 0.27 55.48 1.68 1.69 54.81 0.36 0.69 61.00 1.14 1.40
HistoryAr 200 38.99 0.51 0.82 23.90 0.67 0.96 56.37 0.28 0.64 13.99 3.46 2.86 40.52 0.47 0.83 58.23 0.26 0.54 18.15 0.72 1.25 37.56 0.46 0.97 13.91 0.79 1.01 5.02 10.32 12.78 58.81 0.24 0.68 14.92 3.48 3.39 17.92 1.07 1.46 10.09 2.72 2.93
HistoricalBooks 10 43.16 0.41 0.76 27.35 0.59 0.88 88.49 0.05 0.22 20.98 1.90 2.16 44.51 0.33 0.72 13.83 0.84 0.88 13.37 0.74 0.99 27.36 0.60 0.98 18.28 0.71 1.00 6.28 6.81 6.30 58.87 0.29 0.71 22.26 0.67 0.97 27.04 0.92 1.32 9.87 0.82 1.00
Khatt 200 45.44 0.45 0.74 27.97 0.64 0.91 67.09 0.19 0.45 28.41 1.12 0.88 27.25 5.04 5.19 89.13 0.07 0.22 20.56 0.61 1.14 25.09 0.67 1.06 14.86 0.76 1.00 13.35 4.25 3.77 15.15 0.83 0.92 27.26 1.60 1.80 18.84 0.89 1.22 16.60 1.46 1.86
Adab 200 51.08 0.30 0.73 43.28 0.35 0.83 64.00 0.19 0.56 20.44 0.63 1.10 29.45 0.68 1.08 99.59 0.00 0.01 23.45 1.00 1.00 29.47 1.00 1.00 8.79 0.88 1.15 0.08 7.28 8.71 0.78 0.99 0.99 31.47 0.91 1.11 23.93 0.50 1.01 5.80 7.47 9.35
Muharaf 200 25.70 0.56 0.90 20.86 0.63 0.94 47.16 0.33 0.69 8.01 3.57 2.87 22.75 0.61 0.96 67.50 0.38 0.54 12.28 0.77 1.28 16.06 0.70 1.02 11.41 0.80 1.01 5.99 6.19 7.48 32.12 0.52 0.82 8.70 2.40 2.74 16.18 0.77 1.17 7.74 1.83 2.37
OnlineKhatt 200 52.50 0.29 0.63 38.52 0.41 0.76 68.54 0.17 0.44 30.97 1.30 2.01 47.55 0.36 0.70 92.74 0.03 0.12 21.26 0.59 1.21 30.64 0.56 1.08 15.40 0.78 1.03 9.67 6.71 6.95 25.28 0.72 0.85 31.81 1.52 1.53 27.05 0.51 0.91 15.84 1.68 2.31
Khatt 200 45.44 0.45 0.74 27.97 0.64 0.91 67.09 0.19 0.45 28.41 1.12 0.88 27.25 5.04 5.19 89.13 0.07 0.22 20.56 0.61 1.14 25.09 0.67 1.06 14.86 0.76 1.00 13.35 4.25 3.77 15.15 0.83 0.92 27.26 1.60 1.80 18.84 0.89 1.22 16.60 1.46 1.86
ISI-PPT 500 89.96 0.08 0.18 79.44 0.15 0.31 90.45 0.06 0.15 55.48 1.03 1.01 73.15 0.36 0.54 52.42 0.52 0.53 68.32 0.31 0.43 59.80 0.55 0.77 18.63 0.81 1.03 33.34 2.75 3.58 2.53 0.98 0.98 34.36 1.27 1.39 16.69 0.82 1.46 46.98 1.95 2.30
ArabicOCR 50 83.47 0.06 0.26 70.21 0.16 0.46 98.79 0.00 0.02 58.87 1.25 1.51 63.84 1.00 1.00 99.26 0.00 0.01 98.99 0.01 0.02 75.84 0.56 0.76 26.49 0.77 1.00 80.93 0.15 0.20 99.38 0.01 0.11 94.89 0.02 0.08 51.06 0.53 0.79 83.58 0.18 0.34
Hindawi 200 60.13 0.34 0.56 43.20 0.48 0.71 97.77 0.01 0.04 22.56 1.82 2.05 24.31 1.00 1.00 89.89 0.11 0.15 61.36 0.31 0.50 64.88 0.40 0.72 22.04 0.76 1.00 66.42 0.26 0.42 89.75 0.06 0.28 67.05 0.27 0.42 36.48 0.63 0.87 65.11 0.24 0.51
EvArest 800 82.19 0.20 0.38 71.65 0.25 0.51 80.93 0.18 0.36 55.57 0.41 0.67 80.00 0.19 0.36 76.11 0.30 0.32 18.94 0.85 0.96 57.28 0.38 0.65 13.26 0.89 1.04 4.18 5.91 6.38 72.93 0.32 0.50 31.01 4.65 4.75 60.33 0.37 0.65 2.35 33.12 31.54
Average 3,760 61.01 0.31 0.55 47.21 0.43 0.71 77.95 0.13 0.32 33.94 1.48 1.55 49.23 1.20 1.41 78.33 0.20 0.28 39.62 0.54 0.84 45.47 0.58 0.89 16.73 0.79 1.02 20.61 4.95 5.61 50.97 0.52 0.69 39.77 1.80 1.93 30.02 1.05 1.45 30.52 4.37 4.67
"""
# Process the raw table into Python structures for the leaderboard
lines = data_str.strip().split('\n')
headers = lines[0].split('\t')
subheaders = lines[1].split('\t')
# Extract model names from the header row: model-name cells are non-empty,
# while the two metric columns following each name are blank
model_names = []
for i, header in enumerate(headers):
    if i >= 2 and header:  # skip the 'Dataset' and 'Size' columns
        model_names.append(header)
# Create a processed dataset for the main leaderboard
models_data = []
# Organization behind each model (built once, outside the loop)
org_map = {
    "GPT-4o": "OpenAI",
    "GPT-4o-mini": "OpenAI",
    "Gemini-2.0-Flash": "Google",
    "Qwen2-VL": "Alibaba",
    "Qwen2.5-VL": "Alibaba",
    "AIN": "MBZUAI",
    "Tesseract": "Google",
    "EasyOCR": "JaidedAI",
    "Paddle": "Baidu",
    "Surya": "VikParuchuri",
    "Microsoft": "Microsoft",
    "Qari": "Sakana AI",
    "Gemma3": "Google",
    "ArabicNougat": "Arabic NLP"
}
# The last row of the table holds the per-model averages
last_row = lines[-1].split('\t')
for model in ["GPT-4o", "GPT-4o-mini", "Gemini-2.0-Flash", "Qwen2-VL", "Qwen2.5-VL",
              "AIN", "Tesseract", "EasyOCR", "Paddle", "Surya", "Microsoft", "Qari",
              "Gemma3", "ArabicNougat"]:
    # Find the column indices for this model
    model_idx = -1
for i, header in enumerate(headers):
if header == model:
model_idx = i
break
if model_idx == -1:
# Try finding as a substring
for i, header in enumerate(headers):
if model in header:
model_idx = i
break
if model_idx != -1:
# Get CHrF, CER, WER
chrf_idx = model_idx
cer_idx = model_idx + 1
wer_idx = model_idx + 2
try:
# Parse metrics
chrf = float(last_row[chrf_idx]) if chrf_idx < len(last_row) else 0
cer = float(last_row[cer_idx]) if cer_idx < len(last_row) else 0
wer = float(last_row[wer_idx]) if wer_idx < len(last_row) else 0
# Determine model type
model_type = "Closed-source" if model in ["GPT-4o", "GPT-4o-mini", "Gemini-2.0-Flash", "Claude-3-Opus"] else "Open-source"
# Add framework category
if model in ["Tesseract", "EasyOCR", "Paddle", "Surya"]:
model_type = "Framework"
            organization = org_map.get(model, "Unknown")
            # Simulated download counts (placeholder values, not real Hub stats)
            downloads = f"{random.randint(10, 600)}K"
# Add to models data
models_data.append({
"model": model,
"organization": organization,
"type": model_type,
"task": "OCR/Arabic",
"metrics": {
"chrf": chrf,
"cer": cer,
"wer": wer
},
"downloads": downloads,
"last_updated": "2025-04-01",
"model_url": f"https://huggingface.co/{organization}/{model}",
"paper_url": "https://arxiv.org/abs/2502.14949",
})
except Exception as e:
print(f"Error processing {model}: {e}")
continue
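# Each models_data entry bundles the averaged metrics with display metadata,
# e.g. models_data[0]["metrics"] == {"chrf": 61.01, "cer": 0.31, "wer": 0.55}
# for GPT-4o, taken from the Average row above.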
# Create detailed dataset for per-dataset comparisons
dataset_names = []
dataset_sizes = []
dataset_metrics = {}
for i in range(2, len(lines)-1): # Skip headers and the average line
parts = lines[i].split('\t')
if len(parts) > 1:
dataset = parts[0]
        size = parts[1]  # the guard above already ensures at least two fields
dataset_names.append(dataset)
dataset_sizes.append(size)
metrics = {}
for j, model in enumerate(model_names):
base_idx = j*3 + 2 # Starting column for each model (+2 for Dataset and Size columns)
if base_idx + 2 < len(parts):
try:
chrf = float(parts[base_idx]) if parts[base_idx] else 0
cer = float(parts[base_idx + 1]) if parts[base_idx + 1] else 0
wer = float(parts[base_idx + 2]) if parts[base_idx + 2] else 0
metrics[model] = {
"chrf": chrf,
"cer": cer,
"wer": wer
}
except (ValueError, IndexError) as e:
print(f"Error parsing metrics for {dataset}, {model}: {e}")
metrics[model] = {"chrf": 0, "cer": 0, "wer": 0}
dataset_metrics[dataset] = metrics
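# The parsed structure is keyed by dataset, then by model, e.g.:
#   dataset_metrics["PATS"]["GPT-4o"] == {"chrf": 88.82, "cer": 0.23, "wer": 0.30}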
# Define CSS for styling
css = """
#leaderboard-title {
text-align: center;
margin-bottom: 0;
}
#leaderboard-subtitle {
text-align: center;
margin-top: 0;
color: #6B7280;
font-size: 1rem;
}
.gradio-container {
max-width: 1200px !important;
}
.header {
background: linear-gradient(90deg, #FFDE59 0%, #FFC532 100%);
padding: 20px;
border-radius: 8px;
margin-bottom: 20px;
display: flex;
align-items: center;
justify-content: space-between;
}
.header img {
height: 40px;
margin-right: 15px;
}
.header-content {
display: flex;
align-items: center;
}
.header-text {
display: flex;
flex-direction: column;
}
.header-text h1 {
margin: 0;
font-size: 1.5rem;
font-weight: bold;
color: black;
}
.header-text p {
margin: 0;
color: rgba(0, 0, 0, 0.8);
}
.filter-container {
display: flex;
flex-wrap: wrap;
gap: 10px;
margin-bottom: 20px;
}
table {
width: 100%;
border-collapse: collapse;
}
th {
background-color: #F9FAFB;
text-align: left;
padding: 12px;
font-weight: 600;
color: #374151;
border-bottom: 1px solid #E5E7EB;
position: sticky;
top: 0;
z-index: 10;
}
td {
padding: 12px;
border-bottom: 1px solid #E5E7EB;
}
tr:hover {
background-color: #F9FAFB;
}
a {
color: #2563EB;
text-decoration: none;
}
a:hover {
text-decoration: underline;
}
.footer {
display: flex;
justify-content: space-between;
align-items: center;
padding: 10px 0;
color: #6B7280;
font-size: 0.875rem;
margin-top: 20px;
}
.footer a {
color: #2563EB;
text-decoration: none;
display: inline-flex;
align-items: center;
}
.footer a:hover {
text-decoration: underline;
}
.metric-table {
max-height: 600px;
overflow-y: auto;
}
.dataset-row:nth-child(odd) {
background-color: #F9FAFB;
}
.dataset-row:hover {
background-color: #EFF6FF;
}
.tab-active {
border-bottom: 2px solid #2563EB !important;
color: #2563EB !important;
font-weight: 600;
}
.metric-badge {
padding: 2px 8px;
border-radius: 9999px;
font-weight: 600;
font-size: 0.75rem;
display: inline-block;
}
.metric-good {
background-color: #DCFCE7;
color: #166534;
}
.metric-medium {
background-color: #FEF3C7;
color: #92400E;
}
.metric-poor {
background-color: #FEE2E2;
color: #B91C1C;
}
.chart-container {
margin-top: 20px;
overflow-x: auto;
}
"""
# Format a metric as a color-coded badge (ChrF: higher is better; CER/WER: lower is better)
def format_metric(metric_name, value):
if metric_name == "chrf":
if value > 75:
return f'<span class="metric-badge metric-good">{value:.1f}</span>'
elif value > 50:
return f'<span class="metric-badge metric-medium">{value:.1f}</span>'
else:
return f'<span class="metric-badge metric-poor">{value:.1f}</span>'
    elif metric_name in ("cer", "wer"):  # lower is better
if value < 0.5:
return f'<span class="metric-badge metric-good">{value:.2f}</span>'
elif value < 1.0:
return f'<span class="metric-badge metric-medium">{value:.2f}</span>'
else:
return f'<span class="metric-badge metric-poor">{value:.2f}</span>'
return f"{value:.2f}"
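# Example: format_metric("cer", 0.31) renders a green "good" badge, while
# format_metric("chrf", 47.2) renders a red "poor" one (ChrF below 50).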
# Function to filter models based on type
def filter_by_type(models, type_filter):
if type_filter == "All":
return models
return [model for model in models if model["type"] == type_filter]
# Function to filter models based on search term
def filter_by_search(models, search_term):
if not search_term:
return models
# Convert search term to lowercase for case-insensitive search
search_term = search_term.lower()
# Filter based on model, organization, or task
filtered_models = []
for model in models:
if (search_term in model["model"].lower() or
search_term in model["organization"].lower() or
search_term in model["task"].lower()):
filtered_models.append(model)
return filtered_models
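# Example: filter_by_search(models_data, "openai") keeps GPT-4o and GPT-4o-mini,
# since the search also matches the organization field case-insensitively.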
# Function to generate the main leaderboard HTML
def generate_main_leaderboard(models, sort_by, sort_order):
# Sort models
reverse = sort_order == "Descending"
    # Key function for sorting on the selected column
    def get_sort_key(model):
        if sort_by in ("model", "organization", "type", "task"):
            return model[sort_by]
        elif sort_by == "downloads":
            # Extract the numeric part of the download string (e.g., "24.5K" -> 24.5)
            try:
                return float(model[sort_by].replace("K", ""))
            except ValueError:
                return 0
        elif sort_by in ("chrf", "cer", "wer"):
            return model["metrics"][sort_by]
        return 0
    # For CER and WER lower is better, so flip the direction: "Descending"
    # consistently means best-performing models first
    if sort_by in ["cer", "wer"]:
        reverse = not reverse
sorted_models = sorted(models, key=get_sort_key, reverse=reverse)
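    # Example: sort_by="cer" with sort_order="Descending" becomes an ascending
    # sort after the flip, so the lowest (best) CER tops the table.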
html = """
<div style="overflow-x: auto;">
<table style="width:100%">
<thead>
<tr>
<th>Model</th>
<th>Organization</th>
<th>Type</th>
<th>Task</th>
<th>CHrF ↑</th>
<th>CER ↓</th>
<th>WER ↓</th>
<th>Downloads</th>
<th>Links</th>
</tr>
</thead>
<tbody>
"""
for model in sorted_models:
html += f"""
<tr>
<td>
<div style="font-weight: 500;">{model['model']}</div>
</td>
<td>{model['organization']}</td>
<td>
<span style="background-color: {'#DBEAFE' if model['type'] == 'Open-source' else '#FEF3C7' if model['type'] == 'Closed-source' else '#E0F2FE'};
padding: 2px 6px;
border-radius: 9999px;
font-size: 0.75rem;">
{model['type']}
</span>
</td>
<td>
<span style="background-color: #E0F2FE;
padding: 2px 6px;
border-radius: 9999px;
font-size: 0.75rem;">
{model['task']}
</span>
</td>
<td>{format_metric('chrf', model['metrics']['chrf'])}</td>
<td>{format_metric('cer', model['metrics']['cer'])}</td>
<td>{format_metric('wer', model['metrics']['wer'])}</td>
<td>{model['downloads']}</td>
<td>
<a href="{model['model_url']}" target="_blank">Model</a> |
<a href="{model['paper_url']}" target="_blank">Paper</a>
</td>
</tr>
"""
html += """
</tbody>
</table>
</div>
"""
return html
# Function to generate per-dataset comparison HTML
def generate_dataset_comparison(selected_datasets, selected_models, metric):
html = f"""
<div class="metric-table">
<table style="width:100%">
<thead>
<tr>
<th>Dataset</th>
<th>Size</th>
"""
for model in selected_models:
html += f"<th>{model}</th>"
html += """
</tr>
</thead>
<tbody>
"""
    for dataset in selected_datasets:
size = dataset_sizes[dataset_names.index(dataset)]
html += f"""
<tr class="dataset-row">
<td style="font-weight: 500;">{dataset}</td>
<td>{size}</td>
"""
for model in selected_models:
if model in dataset_metrics[dataset]:
value = dataset_metrics[dataset][model][metric.lower()]
html += f"<td>{format_metric(metric.lower(), value)}</td>"
else:
html += "<td>-</td>"
html += "</tr>"
html += """
</tbody>
</table>
</div>
"""
return html
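# Example: generate_dataset_comparison(["PATS", "Khatt"], ["GPT-4o", "AIN"], "CER")
# renders a two-row table; the metric name is lower-cased before the lookup,
# so the UI can pass "CHrF", "CER", or "WER" directly.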
# Create the Gradio interface
def create_leaderboard_interface():
with gr.Blocks(css=css) as demo:
gr.HTML(f"""
<div class="header">
<div class="header-content">
<div>
<svg xmlns="http://www.w3.org/2000/svg" width="40" height="40" viewBox="0 0 40 40" fill="none">
<path d="M9 16H11V24H9V16Z" fill="black"/>
<path d="M13 11H15V29H13V11Z" fill="black"/>
<path d="M17 15H19V25H17V15Z" fill="black"/>
<path d="M21 11H23V29H21V11Z" fill="black"/>
<path d="M25 16H27V24H25V16Z" fill="black"/>
<path d="M29 14H31V26H29V14Z" fill="black"/>
</svg>
</div>
<div class="header-text">
<h1>KITAB-Bench Leaderboard</h1>
<p>Arabic OCR and Document Understanding Benchmark</p>
</div>
</div>
<div>
<a href="https://huggingface.co/spaces" target="_blank" style="color: black; text-decoration: underline;">
Powered by 🤗 Spaces
</a>
</div>
</div>
""")
with gr.Tabs() as tabs:
with gr.TabItem("Main Leaderboard", id=0):
# Filter controls
with gr.Row(equal_height=True):
type_filter = gr.Radio(
["All", "Open-source", "Closed-source", "Framework"],
label="Model Type",
value="All",
interactive=True
)
search_input = gr.Textbox(
label="Search Models, Organizations, or Tasks",
placeholder="Type to search...",
interactive=True
)
with gr.Row(equal_height=True):
sort_by = gr.Dropdown(
["model", "organization", "type", "chrf", "cer", "wer", "downloads"],
label="Sort by",
value="chrf",
interactive=True
)
sort_order = gr.Radio(
["Descending", "Ascending"],
label="Sort Order",
value="Descending",
interactive=True
)
# Table output
leaderboard_output = gr.HTML()
# Update function for the main leaderboard
def update_leaderboard(type_filter, search_term, sort_by, sort_order):
filtered_models = filter_by_type(models_data, type_filter)
filtered_models = filter_by_search(filtered_models, search_term)
html = generate_main_leaderboard(filtered_models, sort_by, sort_order)
footer = f"""
<div class="footer">
<span>Showing {len(filtered_models)} of {len(models_data)} models</span>
<div>
<a href="https://github.com/mbzuai-oryx/KITAB-Bench" target="_blank">GitHub Repository</a>
<span style="margin: 0 8px;">|</span>
<a href="https://arxiv.org/abs/2502.14949" target="_blank">KITAB-Bench Paper</a>
</div>
</div>
"""
return html + footer
# Set up event handlers for main leaderboard
type_filter.change(update_leaderboard, [type_filter, search_input, sort_by, sort_order], leaderboard_output)
search_input.change(update_leaderboard, [type_filter, search_input, sort_by, sort_order], leaderboard_output)
sort_by.change(update_leaderboard, [type_filter, search_input, sort_by, sort_order], leaderboard_output)
sort_order.change(update_leaderboard, [type_filter, search_input, sort_by, sort_order], leaderboard_output)
with gr.TabItem("Dataset Comparison", id=1):
with gr.Row():
dataset_selector = gr.CheckboxGroup(
dataset_names,
label="Select Datasets",
value=dataset_names[:5], # Default to first 5 datasets
interactive=True)