Spaces:
Sleeping
Sleeping
import streamlit as st | |
import torch | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
# Load the model and tokenizer | |
# @st.cache_resource | |
# def load_model(): | |
# tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small') | |
# model = AutoModelForSequenceClassification.from_pretrained("./results/checkpoint-753") | |
# model.eval() | |
# return tokenizer, model | |
@st.cache_resource
def load_model(tokenizer_name='microsoft/deberta-v3-small',
               model_path="./results/checkpoint-753"):
    """Load the tokenizer and fine-tuned sequence classifier, cached.

    Streamlit re-executes the whole script on every widget interaction.
    ``st.cache_resource`` keeps one loaded (tokenizer, model) pair per
    distinct set of arguments, instead of reloading the weights from disk
    on each rerun.

    Args:
        tokenizer_name: Hugging Face Hub id or local path of the tokenizer.
            Defaults to the base model the checkpoint was presumably
            fine-tuned from (TODO confirm).
        model_path: Path to the fine-tuned sequence-classification
            checkpoint directory.

    Returns:
        A ``(tokenizer, model)`` tuple with the model in eval mode.
    """
    # use_fast=False selects the slow (Python/sentencepiece) tokenizer,
    # matching the original loading code.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=False)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    # Inference only: put the model in eval mode (disables dropout).
    model.eval()
    return tokenizer, model
def predict_news(text, tokenizer, model):
    """Classify ``text`` as REAL or FAKE news.

    The text is tokenized (truncated to 512 tokens) and run through ``model``
    without gradient tracking. The logits are converted to probabilities.

    Args:
        text: The news text to classify.
        tokenizer: Callable tokenizer returning a mapping of PyTorch tensors.
        model: Sequence-classification model whose output has ``.logits``.

    Returns:
        ``(label, confidence)``. ``label`` is ``"FAKE"`` when class index 1
        wins and ``"REAL"`` otherwise. ``confidence`` is the winning class's
        probability as a Python float.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )

    with torch.no_grad():
        logits = model(**encoded).logits

    probs = torch.softmax(logits, dim=-1)
    winner = probs.argmax(dim=-1).item()
    score = probs[0, winner].item()

    # Class index 1 is treated as the "fake" label.
    verdict = "FAKE" if winner == 1 else "REAL"
    return verdict, score
def main():
    """Render the news-classifier page and classify user-submitted text.

    The page shows a title, a text area and a "Classify" button. When the
    button is pressed with non-blank text, the text goes to
    ``predict_news``. The verdict is shown in an error box for FAKE and a
    success box otherwise, followed by the confidence as a percentage.
    Blank or whitespace-only input produces a warning instead of being
    silently ignored or sent to the model.
    """
    st.title("News Classifier")

    # Tokenizer/model pair used for inference.
    tokenizer, model = load_model()

    news_text = st.text_area("Enter news text to analyze:", height=200)

    if not st.button("Classify"):
        return

    # Whitespace-only text has nothing to classify; tell the user rather
    # than doing nothing or producing a meaningless verdict.
    if not (news_text and news_text.strip()):
        st.warning("Please enter some news text to analyze.")
        return

    with st.spinner('Analyzing...'):
        prediction, confidence = predict_news(news_text, tokenizer, model)

    # Display results
    if prediction == "FAKE":
        st.error(f"⚠️ {prediction} NEWS")
    else:
        st.success(f"✅ {prediction} NEWS")
    st.info(f"Confidence: {confidence*100:.2f}%")
if __name__ == "__main__":
    # Entry point: build the page when this file is executed as a script.
    main()
# import streamlit as st | |
# import torch | |
# from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
# from fastapi import FastAPI, Request | |
# from pydantic import BaseModel | |
# from threading import Thread | |
# from streamlit.web import cli | |
# # FastAPI app | |
# api_app = FastAPI() | |
# # Load the model and tokenizer | |
# @st.cache_resource | |
# def load_model(): | |
# tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small', use_fast=False) | |
# model = AutoModelForSequenceClassification.from_pretrained("./results/checkpoint-753") | |
# model.eval() | |
# return tokenizer, model | |
# # Prediction function | |
# def predict_news(text, tokenizer, model): | |
# inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512) | |
# with torch.no_grad(): | |
# outputs = model(**inputs) | |
# probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1) | |
# predicted_label = torch.argmax(probabilities, dim=-1).item() | |
# confidence = probabilities[0][predicted_label].item() | |
# return "FAKE" if predicted_label == 1 else "REAL", confidence | |
# # FastAPI request model | |
# class NewsInput(BaseModel): | |
# text: str | |
# # FastAPI route for POST requests | |
# @api_app.post("/classify") | |
# async def classify_news(data: NewsInput): | |
# tokenizer, model = load_model() | |
# prediction, confidence = predict_news(data.text, tokenizer, model) | |
# return { | |
# "prediction": prediction, | |
# "confidence": f"{confidence*100:.2f}%" | |
# } | |
# # Streamlit app | |
# def run_streamlit(): | |
# def main(): | |
# st.title("News Classifier") | |
# # Load model | |
# tokenizer, model = load_model() | |
# # Text input | |
# news_text = st.text_area("Enter news text to analyze:", height=200) | |
# if st.button("Classify"): | |
# if news_text: | |
# with st.spinner('Analyzing...'): | |
# prediction, confidence = predict_news(news_text, tokenizer, model) | |
# # Display results | |
# if prediction == "FAKE": | |
# st.error(f"⚠️ {prediction} NEWS") | |
# else: | |
# st.success(f"✅ {prediction} NEWS") | |
# st.info(f"Confidence: {confidence*100:.2f}%") | |
# main() | |
# # Threaded execution for FastAPI and Streamlit | |
# def start_fastapi(): | |
# import uvicorn | |
# uvicorn.run(api_app, host="0.0.0.0", port=8502) | |
# if __name__ == "__main__": | |
# fastapi_thread = Thread(target=start_fastapi, daemon=True) | |
# fastapi_thread.start() | |
# # Start Streamlit | |
# cli.main() | |
# # from fastapi import FastAPI, HTTPException | |
# # from pydantic import BaseModel | |
# # from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
# # import torch | |
# # from fastapi.middleware.cors import CORSMiddleware | |
# # # Define the FastAPI app | |
# # app = FastAPI() | |
# # app.add_middleware( | |
# # CORSMiddleware, | |
# # allow_origins=["*"], # Update with your frontend's URL for security | |
# # allow_credentials=True, | |
# # allow_methods=["*"], | |
# # allow_headers=["*"], | |
# # ) | |
# # # Define the input data schema | |
# # class InputText(BaseModel): | |
# # text: str | |
# # # Load the model and tokenizer (ensure these paths are correct in your Space) | |
# # tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small', use_fast=False) | |
# # model = AutoModelForSequenceClassification.from_pretrained("./results/checkpoint-753") | |
# # model.eval() | |
# # # Prediction function | |
# # def predict_news(text: str): | |
# # inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512) | |
# # with torch.no_grad(): | |
# # outputs = model(**inputs) | |
# # probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1) | |
# # predicted_label = torch.argmax(probabilities, dim=-1).item() | |
# # confidence = probabilities[0][predicted_label].item() | |
# # return { | |
# # "prediction": "FAKE" if predicted_label == 1 else "REAL", | |
# # "confidence": round(confidence * 100, 2) # Return confidence as a percentage | |
# # } | |
# # # Define the POST endpoint | |
# # @app.post("/predict") | |
# # async def classify_news(input_text: InputText): | |
# # try: | |
# # result = predict_news(input_text.text) | |
# # return result | |
# # except Exception as e: | |
# # raise HTTPException(status_code=500, detail=str(e)) | |