"""Streamlit app that classifies news text as REAL or FAKE with a fine-tuned DeBERTa-v3 model."""
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
@st.cache_resource
def load_model():
    """Load and cache the tokenizer and the fine-tuned classifier.

    @st.cache_resource ensures the heavy model load happens once per
    server process instead of on every Streamlit rerun.

    Returns:
        (tokenizer, model): the tokenizer matching the base checkpoint and
        the fine-tuned sequence-classification model, left in eval mode.
    """
    # use_fast=False: the DeBERTa-v3 fast-tokenizer conversion can fail
    # without extra dependencies (sentencepiece) — TODO confirm this is
    # still needed in the deployed environment.
    tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small', use_fast=False)
    # Local fine-tuned checkpoint; path is relative to the working directory.
    model = AutoModelForSequenceClassification.from_pretrained("./results/checkpoint-753")
    model.eval()
    return tokenizer, model
def predict_news(text, tokenizer, model):
    """Classify a piece of news text.

    Args:
        text: Raw article text to classify.
        tokenizer: Hugging Face tokenizer matching the model.
        model: Two-label sequence-classification model
            (label index 1 is treated as FAKE — confirm against training).

    Returns:
        A ("FAKE" or "REAL", confidence) tuple, confidence in [0, 1].
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**encoded).logits
    probs = torch.nn.functional.softmax(logits, dim=-1)
    label_idx = int(torch.argmax(probs, dim=-1).item())
    score = probs[0][label_idx].item()
    label = "FAKE" if label_idx == 1 else "REAL"
    return label, score
def main():
    """Streamlit entry point: collect news text and display the verdict."""
    st.title("News Classifier")

    # Cached by @st.cache_resource inside load_model, so this is cheap
    # after the first run.
    tokenizer, model = load_model()

    news_text = st.text_area("Enter news text to analyze:", height=200)

    if st.button("Classify"):
        if news_text:
            with st.spinner('Analyzing...'):
                prediction, confidence = predict_news(news_text, tokenizer, model)
            # Display results
            if prediction == "FAKE":
                st.error(f"⚠️ {prediction} NEWS")
            else:
                st.success(f"✅ {prediction} NEWS")
            st.info(f"Confidence: {confidence*100:.2f}%")
        else:
            # Previously an empty submission was silently ignored; tell the
            # user why nothing happened.
            st.warning("Please enter some news text to classify.")


if __name__ == "__main__":
    main()