File size: 5,865 Bytes
8a1304d
 
 
 
 
 
 
 
 
 
 
 
 
12faaae
8a1304d
12faaae
 
8a1304d
12faaae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a1304d
 
 
 
 
 
12faaae
8a1304d
 
 
 
 
 
 
12faaae
 
 
 
 
 
 
 
 
 
 
8a1304d
12faaae
8a1304d
 
12faaae
 
8a1304d
12faaae
 
 
 
8a1304d
 
12faaae
 
 
 
 
 
 
 
 
 
8a1304d
12faaae
8a1304d
12faaae
8a1304d
 
 
 
 
 
 
 
 
12faaae
 
8a1304d
12faaae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a1304d
 
12faaae
 
8a1304d
12faaae
 
 
 
 
 
 
 
 
8a1304d
 
 
12faaae
 
 
 
 
 
 
 
 
8a1304d
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import streamlit as st

# Set up the Streamlit page - this must be the first st command
st.set_page_config(
    page_title="Paper Classification Service",
    page_icon="📚",  # must be a real emoji character (or an image) for the favicon
    layout="wide"
)

import PyPDF2
import io
from model import PaperClassifier

# Initialize the classifier with model selection
@st.cache_resource
def load_classifier(model_type):
    """Construct and return a ``PaperClassifier`` for *model_type*.

    Wrapped in ``st.cache_resource`` so the (potentially expensive) model
    object is built once per distinct *model_type* and reused on reruns
    instead of being reconstructed every time the script executes.
    """
    instance = PaperClassifier(model_type)
    return instance

# Cache the PDF text extraction
@st.cache_data
def extract_pdf_text(pdf_bytes):
    """Extract text from a PDF and split it into a (title, abstract) guess.

    The title is taken to be the first non-blank line of the extracted text;
    the "abstract" is everything after that line (i.e. the rest of the
    document, not just an abstract section). Both values are stripped.

    Args:
        pdf_bytes: Raw bytes of the PDF file.

    Returns:
        A ``(title, abstract)`` tuple of strings; both are ``""`` when the PDF
        yields no text at all.
    """
    pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))
    # Treat a missing/None page result as empty so one image-only page does not
    # abort the whole extraction; join avoids quadratic string concatenation.
    text = "\n".join((page.extract_text() or "") for page in pdf_reader.pages)

    lines = text.split("\n")
    # Skip leading blank lines: many PDFs start with whitespace, and using
    # lines[0] blindly would report "no title" for a perfectly readable file.
    first = next((i for i, line in enumerate(lines) if line.strip()), None)
    if first is None:
        return "", ""

    title = lines[first]
    abstract = "\n".join(lines[first + 1:])
    return title.strip(), abstract.strip()

# Offer every model registered on PaperClassifier, in registration order
available_models = [*PaperClassifier.AVAILABLE_MODELS]

# Everything written inside this context lands in the sidebar
with st.sidebar:
    st.title("Model Settings")
    selected_model = st.selectbox(
        "Select Model",
        available_models,
        index=0,
        help="Choose the model to use for classification",
    )

    # Summarize the chosen model's metadata beneath the selector
    model_info = PaperClassifier.AVAILABLE_MODELS[selected_model]
    st.markdown(f"""
### Selected Model
**Name:** {model_info['name']}  
**Description:** {model_info['description']}
""")

# Build (or reuse from cache) the classifier for the current selection
classifier = load_classifier(selected_model)

# Title and description shown at the top of the main page
st.title("📚 Academic Paper Classification")
st.markdown("""
This service helps you classify academic papers into different categories.
You can either:
- Enter the paper's title and abstract separately
- Upload a PDF file
""")

def _render_classification(result):
    """Display a classification result produced by ``classifier.classify_paper``.

    Reads ``input_type``, ``model_used`` and ``top_categories`` from *result*;
    each ``top_categories`` entry is read for ``category``, ``arxiv_category``
    and ``probability``. Shared by the manual-input and PDF paths so both
    render results identically.
    """
    st.success("Classification Complete!")
    st.write(f"**Input Type:** {result['input_type'].replace('_', ' ').title()}")
    st.write(f"**Model Used:** {result['model_used']}")

    # Show top categories
    st.subheader("Top Categories (95% Confidence)")
    total_prob = 0.0
    for cat_info in result['top_categories']:
        # float() normalizes non-builtin numeric scalars (e.g. numpy floats, if the
        # model returns them -- not visible here) that st.progress may reject;
        # the clamp guards against rounding pushing a value just outside [0, 1].
        prob = float(cat_info['probability'])
        total_prob += prob
        st.progress(
            min(max(prob, 0.0), 1.0),
            text=f"{cat_info['category']} ({cat_info['arxiv_category']}): {prob:.1%}",
        )

    st.info(f"Total probability of shown categories: {total_prob:.1%}")


# Create two columns for input methods
col1, col2 = st.columns(2)

with col1:
    st.subheader("Option 1: Manual Input")

    # Title input
    title_input = st.text_input(
        "Paper Title:",
        placeholder="Enter the paper title..."
    )

    # Abstract input
    abstract_input = st.text_area(
        "Paper Abstract (optional):",
        height=200,
        placeholder="Enter the paper abstract (optional)..."
    )

    if st.button("Classify Paper"):
        if title_input.strip():
            with st.spinner("Classifying..."):
                result = classifier.classify_paper(
                    title=title_input,
                    abstract=abstract_input if abstract_input.strip() else None
                )
                _render_classification(result)
        else:
            st.warning("Please enter at least the paper title.")

with col2:
    st.subheader("Option 2: PDF Upload")
    uploaded_file = st.file_uploader("Upload a PDF file", type="pdf")

    if uploaded_file is not None:
        if st.button("Classify PDF"):
            try:
                with st.spinner("Processing PDF..."):
                    # getvalue() returns the whole upload regardless of the stream
                    # position; read() would return b"" once the buffer is consumed.
                    title, abstract = extract_pdf_text(uploaded_file.getvalue())

                    if not title:
                        # Report and skip classification without st.stop(): stopping
                        # would also prevent the sidebar model list and footer below
                        # from rendering on this run.
                        st.error("Could not extract title from PDF.")
                    else:
                        # Show extracted text
                        with st.expander("Show extracted text"):
                            st.write("**Extracted Title:**")
                            st.write(title)
                            if abstract:
                                st.write("**Extracted Abstract:**")
                                st.write(abstract)

                        # Classify the paper
                        result = classifier.classify_paper(
                            title=title,
                            abstract=abstract if abstract else None
                        )
                        _render_classification(result)
            except Exception as e:
                # UI boundary: surface any extraction/classification failure to the user
                st.error(f"Error processing PDF: {str(e)}")

# Static sidebar overview of the model families users can pick from
_MODEL_OVERVIEW_MD = """
- **DistilBERT**: Fast and lightweight
- **DeBERTa v3**: Advanced performance
- **T5**: Versatile text-to-text
- **RoBERTa**: Strong performance
- **SciBERT**: Specialized for science
"""

st.sidebar.markdown("---")  # visual separator below the model settings
st.sidebar.title("Available Models")
st.sidebar.markdown(_MODEL_OVERVIEW_MD)

# Add footer
st.markdown("---")
st.markdown("Made with ❤️ using Streamlit and Transformers")