# Hugging Face Space app (status: Running)
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
@st.cache_resource
def load_model(
    tokenizer_name="microsoft/deberta-v3-small",
    model_path="./results/checkpoint-753",
):
    """Load the tokenizer and fine-tuned sequence classifier.

    Streamlit re-executes this whole script on every widget interaction.
    ``st.cache_resource`` keeps the returned objects across those reruns,
    so the checkpoint is read from disk once rather than on every rerun.

    Args:
        tokenizer_name: Hub id or local path passed to
            ``AutoTokenizer.from_pretrained``.
        model_path: Path of the fine-tuned checkpoint passed to
            ``AutoModelForSequenceClassification.from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple with the model switched to eval mode.
    """
    # use_fast=False selects the slow (pure-Python) tokenizer class;
    # presumably chosen to avoid fast-tokenizer conversion issues with
    # DeBERTa-v3 -- TODO confirm before switching back to the fast one.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=False)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    # Inference only: eval() disables dropout so predictions are deterministic.
    model.eval()
    return tokenizer, model
def predict_news(text, tokenizer, model):
    """Classify one news text and report how confident the model is.

    The text is tokenized to PyTorch tensors (truncated to at most 512
    tokens) and run through ``model`` with gradients disabled. The logits
    are turned into probabilities with a softmax over the last dimension.

    Args:
        text: A single news string. The ``.item()`` calls below require the
            batch to hold exactly one example.
        tokenizer: Callable tokenizer returning keyword inputs for ``model``.
        model: Sequence classifier whose output exposes ``.logits``.

    Returns:
        ``(label, confidence)``. ``label`` is ``"FAKE"`` when class index 1
        wins and ``"REAL"`` otherwise. ``confidence`` is the winning class's
        probability as a Python float.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )

    with torch.no_grad():
        logits = model(**encoded).logits

    probs = torch.nn.functional.softmax(logits, dim=-1)
    winner = probs.argmax(dim=-1).item()
    score = probs[0, winner].item()

    # Class index 1 is treated as the "fake" class; every other index is "real".
    if winner == 1:
        verdict = "FAKE"
    else:
        verdict = "REAL"
    return verdict, score
def main():
    """Render the News Classifier page and classify text on demand.

    Shows a text area and a "Classify" button. When the button is pressed
    with non-blank text, this loads the model, runs ``predict_news`` and
    displays the label and its confidence. Blank or whitespace-only input
    gets a warning instead of being silently ignored.
    """
    st.title("News Classifier")

    # Text input
    news_text = st.text_area("Enter news text to analyze:", height=200)

    if not st.button("Classify"):
        return

    # Whitespace-only input would be classified from special tokens alone,
    # producing a meaningless verdict, so reject it before touching the model.
    if not (news_text or "").strip():
        st.warning("Please enter some news text to classify.")
        return

    with st.spinner('Analyzing...'):
        # Load lazily: Streamlit reruns this script on every widget
        # interaction, and only runs that actually classify need the model.
        tokenizer, model = load_model()
        prediction, confidence = predict_news(news_text, tokenizer, model)

    # Display results
    if prediction == "FAKE":
        st.error(f"⚠️ {prediction} NEWS")
    else:
        st.success(f"✅ {prediction} NEWS")
    st.info(f"Confidence: {confidence*100:.2f}%")
# Script entry point: build the page only when this file is executed as the
# main script, not when it is imported as a module.
if __name__ == "__main__":
    main()