# Nexus_NLP_model/app.py
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Load the model and tokenizer (cached so it is only loaded once per session)
@st.cache_resource
def load_model():
    # use_fast=False: the fast-tokenizer conversion for DeBERTa-v3 needs extra
    # dependencies and can fail in some environments, so the slow tokenizer is used
    tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small', use_fast=False)
    model = AutoModelForSequenceClassification.from_pretrained("./results/checkpoint-753")
    model.eval()
    return tokenizer, model

def predict_news(text, tokenizer, model):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    # Convert logits to probabilities and pick the most likely class
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
    predicted_label = torch.argmax(probabilities, dim=-1).item()
    confidence = probabilities[0][predicted_label].item()
    return "FAKE" if predicted_label == 1 else "REAL", confidence

def main():
    st.title("News Classifier")

    # Load model (cached across reruns by st.cache_resource)
    tokenizer, model = load_model()

    # Text input
    news_text = st.text_area("Enter news text to analyze:", height=200)

    if st.button("Classify"):
        if news_text:
            with st.spinner('Analyzing...'):
                prediction, confidence = predict_news(news_text, tokenizer, model)

            # Display results
            if prediction == "FAKE":
                st.error(f"⚠️ {prediction} NEWS")
            else:
                st.success(f"✅ {prediction} NEWS")
            st.info(f"Confidence: {confidence*100:.2f}%")
        else:
            # Guard against an empty submission
            st.warning("Please enter some text to classify.")


if __name__ == "__main__":
    main()
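
# To serve the app locally (assuming streamlit is installed):
#   streamlit run app.py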

# import streamlit as st
# import torch
# from transformers import AutoTokenizer, AutoModelForSequenceClassification
# from fastapi import FastAPI, Request
# from pydantic import BaseModel
# from threading import Thread
# from streamlit.web import cli
#
# # FastAPI app
# api_app = FastAPI()
#
# # Load the model and tokenizer
# @st.cache_resource
# def load_model():
#     tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small', use_fast=False)
#     model = AutoModelForSequenceClassification.from_pretrained("./results/checkpoint-753")
#     model.eval()
#     return tokenizer, model
#
# # Prediction function
# def predict_news(text, tokenizer, model):
#     inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
#     with torch.no_grad():
#         outputs = model(**inputs)
#     probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
#     predicted_label = torch.argmax(probabilities, dim=-1).item()
#     confidence = probabilities[0][predicted_label].item()
#     return "FAKE" if predicted_label == 1 else "REAL", confidence
#
# # FastAPI request model
# class NewsInput(BaseModel):
#     text: str
#
# # FastAPI route for POST requests
# @api_app.post("/classify")
# async def classify_news(data: NewsInput):
#     tokenizer, model = load_model()
#     prediction, confidence = predict_news(data.text, tokenizer, model)
#     return {
#         "prediction": prediction,
#         "confidence": f"{confidence*100:.2f}%"
#     }
#
# # Streamlit app
# def run_streamlit():
#     def main():
#         st.title("News Classifier")
#
#         # Load model
#         tokenizer, model = load_model()
#
#         # Text input
#         news_text = st.text_area("Enter news text to analyze:", height=200)
#
#         if st.button("Classify"):
#             if news_text:
#                 with st.spinner('Analyzing...'):
#                     prediction, confidence = predict_news(news_text, tokenizer, model)
#
#                 # Display results
#                 if prediction == "FAKE":
#                     st.error(f"⚠️ {prediction} NEWS")
#                 else:
#                     st.success(f"✅ {prediction} NEWS")
#                 st.info(f"Confidence: {confidence*100:.2f}%")
#     main()
#
# # Threaded execution for FastAPI and Streamlit
# def start_fastapi():
#     import uvicorn
#     uvicorn.run(api_app, host="0.0.0.0", port=8502)
#
# if __name__ == "__main__":
#     fastapi_thread = Thread(target=start_fastapi, daemon=True)
#     fastapi_thread.start()
#     # Start Streamlit
#     cli.main()
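#
# Hypothetical request against the /classify endpoint above (port 8502 as
# configured in start_fastapi; response values are illustrative, not recorded output):
#   curl -X POST http://localhost:8502/classify \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Some news article text"}'
#   => {"prediction": "REAL", "confidence": "97.12%"}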

# # from fastapi import FastAPI, HTTPException
# # from pydantic import BaseModel
# # from transformers import AutoTokenizer, AutoModelForSequenceClassification
# # import torch
# # from fastapi.middleware.cors import CORSMiddleware
# #
# # # Define the FastAPI app
# # app = FastAPI()
# # app.add_middleware(
# #     CORSMiddleware,
# #     allow_origins=["*"],  # Update with your frontend's URL for security
# #     allow_credentials=True,
# #     allow_methods=["*"],
# #     allow_headers=["*"],
# # )
# #
# # # Define the input data schema
# # class InputText(BaseModel):
# #     text: str
# #
# # # Load the model and tokenizer (ensure these paths are correct in your Space)
# # tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small', use_fast=False)
# # model = AutoModelForSequenceClassification.from_pretrained("./results/checkpoint-753")
# # model.eval()
# #
# # # Prediction function
# # def predict_news(text: str):
# #     inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
# #     with torch.no_grad():
# #         outputs = model(**inputs)
# #     probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
# #     predicted_label = torch.argmax(probabilities, dim=-1).item()
# #     confidence = probabilities[0][predicted_label].item()
# #     return {
# #         "prediction": "FAKE" if predicted_label == 1 else "REAL",
# #         "confidence": round(confidence * 100, 2)  # Return confidence as a percentage
# #     }
# #
# # # Define the POST endpoint
# # @app.post("/predict")
# # async def classify_news(input_text: InputText):
# #     try:
# #         result = predict_news(input_text.text)
# #         return result
# #     except Exception as e:
# #         raise HTTPException(status_code=500, detail=str(e))
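#
# Hypothetical invocation of the standalone FastAPI variant above, assuming it is
# served with uvicorn and that app.py exposes the `app` instance (response values
# are illustrative):
#   uvicorn app:app --host 0.0.0.0 --port 8000
#   curl -X POST http://localhost:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Some news article text"}'
#   => {"prediction": "FAKE", "confidence": 42.0}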