|
import gradio as gr |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
import numpy as np |
|
import plotly.graph_objects as go |
|
|
|
|
|
# Models offered in the UI dropdown. Keys are passed through to
# ``from_pretrained``; values are display names (the dropdown currently
# lists the keys themselves).
MODEL_OPTIONS = {
    "waleko/roberta-arxiv-tags": "RoBERTa Arxiv Tags",
}
|
|
|
def load_model(model_name):
    """Load a sequence-classification model together with its tokenizer.

    Args:
        model_name: Identifier accepted by ``from_pretrained`` (one of the
            ``MODEL_OPTIONS`` keys when called from the UI).

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    loaded_model = AutoModelForSequenceClassification.from_pretrained(model_name)
    return (loaded_model, loaded_tokenizer)
|
|
|
# Single-slot cache for the most recently loaded model/tokenizer pair.
current_model = None
current_tokenizer = None
# Name of the model currently held in the cache (None until the first load).
current_model_name = None


def get_model_and_tokenizer(model_name):
    """Return the model and tokenizer for ``model_name``, loading them on demand.

    Only one model is kept in memory. Requesting a name different from the
    cached one replaces the cache instead of silently returning the
    previously loaded model.

    Args:
        model_name: Identifier forwarded to ``load_model``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    global current_model, current_tokenizer, current_model_name

    needs_load = (
        current_model is None
        or current_tokenizer is None
        # Without this check, switching the dropdown would keep the first model.
        or current_model_name != model_name
    )
    if needs_load:
        current_model, current_tokenizer = load_model(model_name)
        current_model_name = model_name

    return current_model, current_tokenizer
|
|
|
def create_visualization(probs, labels):
    """Build a donut chart of the selected tag probabilities.

    Args:
        probs: Probabilities of the selected tags, one per label.
        labels: Tag names matching ``probs`` (any sequence: list, tuple, ...).

    Returns:
        A ``plotly.graph_objects.Figure`` with a single pie trace. When the
        given probabilities leave a non-negligible share of the total mass
        uncovered, that share is shown as an extra "Others" slice.
    """
    values = [float(p) for p in probs]
    names = list(labels)  # copy so tuples/arrays work and the caller's list is untouched

    # Probability mass not covered by the selected tags, computed once.
    # A tiny positive remainder is just floating-point rounding; it would only
    # add an empty "Others" slice and legend entry, so ignore it.
    remainder = 1.0 - sum(values)
    if remainder > 1e-6:
        names.append('Others')
        values.append(remainder)

    return go.Figure(data=[go.Pie(
        labels=names,
        values=values,
        textinfo='percent',
        textposition='inside',
        hole=.3,
        showlegend=True,
    )])
|
|
|
def classify_text(title, abstract, model_name):
    """Predict arXiv tags for a paper from its title and/or abstract.

    Reports the smallest set of most-probable tags whose cumulative
    probability reaches 95% (always at least one tag).

    Args:
        title: Paper title; may be empty or None.
        abstract: Paper abstract; may be empty or None.
        model_name: Model identifier forwarded to ``get_model_and_tokenizer``.

    Returns:
        ``(html, figure)``: an HTML snippet listing the predicted tags (the
        top tag in bold) and the chart built by ``create_visualization``.
        If both title and abstract are empty or whitespace-only, returns an
        error message and ``None`` instead.
    """
    title = (title or "").strip()
    abstract = (abstract or "").strip()
    if not title and not abstract:
        return "Error: At least one of title or abstract must be provided.", None

    model, tokenizer = get_model_and_tokenizer(model_name)

    text = "Title: " + title + "\n\nAbstract: " + abstract
    # Never truncate to more tokens than the tokenizer says the model supports.
    # The default model is RoBERTa-based, whose position embeddings typically
    # allow only 512 tokens, so a fixed 1024 limit can overflow them on long
    # inputs. min() still bounds the length at 1024 for tokenizers that report
    # a very large placeholder instead of a real limit.
    max_length = min(1024, tokenizer.model_max_length)
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)

    with torch.no_grad():
        logits = model(**inputs).logits
    # The batch holds a single text: take its row and softmax over the classes.
    probs = torch.nn.functional.softmax(logits[0], dim=0).cpu().numpy()

    sorted_idx = np.argsort(probs)[::-1]  # class indices, most probable first
    sorted_probs = probs[sorted_idx]

    # Number of top tags needed to cover 95% of the probability mass.
    k = 1
    if sorted_probs[0] < 0.95:
        k = int(np.argmax(np.cumsum(sorted_probs) >= 0.95)) + 1

    id2label = model.config.id2label
    tags = [id2label[int(idx)] for idx in sorted_idx[:k]]

    # Top tag in bold, remaining tags as plain space-separated text.
    compact_pred = " ".join(
        [f'<span style="font-weight: 800;">{tags[0]}</span>', *tags[1:]]
    )

    viz_data = create_visualization(sorted_probs[:k], tags)

    html_output = f"""
    <div>
        <h3>Predicted Tags</h3>
        <p>{compact_pred}</p>
    </div>
    """

    return html_output, viz_data
|
|
|
|
|
# Gradio UI: model choice and paper text on the left, predictions on the right,
# with a "Classify" button wired to classify_text.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Arxiv Tags Classification

        Classify academic papers into arXiv categories using state-of-the-art language models.
        """
    )

    with gr.Row():
        # Left column: inputs.
        with gr.Column(scale=1):
            model_dropdown = gr.Dropdown(
                label="Select Model",
                info="Choose the model for classification",
                choices=list(MODEL_OPTIONS),
                value=next(iter(MODEL_OPTIONS)),  # first configured model
            )
            title_input = gr.Textbox(
                label="Title",
                lines=1,
                placeholder="Enter paper title (optional if abstract is provided)",
            )
            abstract_input = gr.Textbox(
                label="Abstract",
                lines=5,
                placeholder="Enter paper abstract (optional if title is provided)",
            )

        # Right column: outputs (HTML tag list + probability chart).
        with gr.Column(scale=1):
            output_html = gr.HTML(label="Predicted Tags")
            output_plot = gr.Plot(show_label=True, label="Probability Distribution")

    # Order must match classify_text(title, abstract, model_name).
    inputs = [title_input, abstract_input, model_dropdown]
    btn = gr.Button("Classify", variant="primary")
    btn.click(
        fn=classify_text,
        inputs=inputs,
        outputs=[output_html, output_plot],
    )


if __name__ == "__main__":
    demo.launch()