File size: 1,859 Bytes
f36a10a
 
 
 
 
2f55336
 
 
 
 
 
f36a10a
 
2f55336
f36a10a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

@st.cache_resource
def load_model(
    model_path="./results/checkpoint-753",
    tokenizer_name="microsoft/deberta-v3-small",
):
    """Load the tokenizer and the fine-tuned sequence classifier.

    The ``st.cache_resource`` decorator caches the returned objects, so
    Streamlit reruns reuse them instead of reloading the weights from disk
    (one cached entry per distinct argument combination).

    Args:
        model_path: Local directory (or hub id) of the fine-tuned
            sequence-classification checkpoint.
        tokenizer_name: Hugging Face name or path of the tokenizer paired
            with the checkpoint.

    Returns:
        A ``(tokenizer, model)`` tuple with the model already in eval mode.
    """
    # use_fast=False selects the slow tokenizer implementation.
    # NOTE(review): the reason isn't recorded; presumably it avoids
    # fast-tokenizer conversion issues with DeBERTa-v3. Confirm before changing.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=False)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    # Inference only: eval() turns off training-time behavior such as dropout.
    model.eval()
    return tokenizer, model

def predict_news(text, tokenizer, model):
    """Classify one news text as ``"REAL"`` or ``"FAKE"``.

    Args:
        text: The news text to classify.
        tokenizer: Tokenizer callable; invoked with PyTorch tensor output,
            truncation and padding enabled, and ``max_length=512``.
        model: Sequence-classification model whose output exposes ``.logits``.

    Returns:
        ``(label, confidence)`` where ``label`` is ``"FAKE"`` when class
        index 1 wins and ``"REAL"`` otherwise. ``confidence`` is the softmax
        probability of the winning class, as a Python float.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )

    # Forward pass only; no gradients are needed for inference.
    with torch.no_grad():
        logits = model(**encoded).logits

    probs = torch.softmax(logits, dim=-1)
    best = probs.argmax(dim=-1).item()
    # Row 0: a single input text produces a single row of probabilities.
    score = probs[0, best].item()

    label = "FAKE" if best == 1 else "REAL"
    return label, score

def main():
    """Render the news-classifier page.

    Shows a text area and a "Classify" button. When the button is clicked
    with non-blank text, the text is run through ``predict_news`` and the
    verdict plus its confidence are displayed. Clicking with empty or
    whitespace-only text shows a warning instead of doing nothing.
    """
    st.title("News Classifier")

    # load_model is wrapped in st.cache_resource, so reruns reuse the loaded objects.
    tokenizer, model = load_model()

    news_text = st.text_area("Enter news text to analyze:", height=200)

    if not st.button("Classify"):
        return

    # Previously an empty click silently did nothing, and whitespace-only
    # text was sent to the model; give explicit feedback instead.
    if not news_text or not news_text.strip():
        st.warning("Please enter some news text to classify.")
        return

    with st.spinner('Analyzing...'):
        prediction, confidence = predict_news(news_text, tokenizer, model)

    # Display results
    if prediction == "FAKE":
        st.error(f"⚠️ {prediction} NEWS")
    else:
        st.success(f"✅ {prediction} NEWS")

    st.info(f"Confidence: {confidence*100:.2f}%")

# Entry point: render the page only when this file runs as the main script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()