Krish Patel commited on
Commit
ecb4f5b
·
1 Parent(s): 5b5a804

Added supportive post request code

Browse files
Files changed (1) hide show
  1. app.py +108 -26
app.py CHANGED
@@ -1,14 +1,68 @@
1
- import streamlit as st
2
- import torch
3
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
 
5
- # Load the model and tokenizer
 
 
 
 
 
 
6
  # @st.cache_resource
7
  # def load_model():
8
- # tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small')
9
  # model = AutoModelForSequenceClassification.from_pretrained("./results/checkpoint-753")
10
  # model.eval()
11
  # return tokenizer, model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  @st.cache_resource
13
  def load_model():
14
  tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small', use_fast=False)
@@ -16,6 +70,7 @@ def load_model():
16
  model.eval()
17
  return tokenizer, model
18
 
 
19
  def predict_news(text, tokenizer, model):
20
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
21
  with torch.no_grad():
@@ -25,27 +80,54 @@ def predict_news(text, tokenizer, model):
25
  confidence = probabilities[0][predicted_label].item()
26
  return "FAKE" if predicted_label == 1 else "REAL", confidence
27
 
28
- def main():
29
- st.title("News Classifier")
30
-
31
- # Load model
 
 
 
32
  tokenizer, model = load_model()
33
-
34
- # Text input
35
- news_text = st.text_area("Enter news text to analyze:", height=200)
36
-
37
- if st.button("Classify"):
38
- if news_text:
39
- with st.spinner('Analyzing...'):
40
- prediction, confidence = predict_news(news_text, tokenizer, model)
41
-
42
- # Display results
43
- if prediction == "FAKE":
44
- st.error(f"⚠️ {prediction} NEWS")
45
- else:
46
- st.success(f"✅ {prediction} NEWS")
47
-
48
- st.info(f"Confidence: {confidence*100:.2f}%")
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
- if __name__ == "__main__":
51
  main()
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import streamlit as st
2
+ # import torch
3
+ # from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
 
5
+ # # Load the model and tokenizer
6
+ # # @st.cache_resource
7
+ # # def load_model():
8
+ # # tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small')
9
+ # # model = AutoModelForSequenceClassification.from_pretrained("./results/checkpoint-753")
10
+ # # model.eval()
11
+ # # return tokenizer, model
12
  # @st.cache_resource
13
  # def load_model():
14
+ # tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small', use_fast=False)
15
  # model = AutoModelForSequenceClassification.from_pretrained("./results/checkpoint-753")
16
  # model.eval()
17
  # return tokenizer, model
18
+
19
+ # def predict_news(text, tokenizer, model):
20
+ # inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
21
+ # with torch.no_grad():
22
+ # outputs = model(**inputs)
23
+ # probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
24
+ # predicted_label = torch.argmax(probabilities, dim=-1).item()
25
+ # confidence = probabilities[0][predicted_label].item()
26
+ # return "FAKE" if predicted_label == 1 else "REAL", confidence
27
+
28
+ # def main():
29
+ # st.title("News Classifier")
30
+
31
+ # # Load model
32
+ # tokenizer, model = load_model()
33
+
34
+ # # Text input
35
+ # news_text = st.text_area("Enter news text to analyze:", height=200)
36
+
37
+ # if st.button("Classify"):
38
+ # if news_text:
39
+ # with st.spinner('Analyzing...'):
40
+ # prediction, confidence = predict_news(news_text, tokenizer, model)
41
+
42
+ # # Display results
43
+ # if prediction == "FAKE":
44
+ # st.error(f"⚠️ {prediction} NEWS")
45
+ # else:
46
+ # st.success(f"✅ {prediction} NEWS")
47
+
48
+ # st.info(f"Confidence: {confidence*100:.2f}%")
49
+
50
+ # if __name__ == "__main__":
51
+ # main()
52
+
53
+
54
+ import streamlit as st
55
+ import torch
56
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
57
+ from fastapi import FastAPI, Request
58
+ from pydantic import BaseModel
59
+ from threading import Thread
60
+ from streamlit.web import cli
61
+
62
+ # FastAPI app
63
+ api_app = FastAPI()
64
+
65
+ # Load the model and tokenizer
66
  @st.cache_resource
67
  def load_model():
68
  tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small', use_fast=False)
 
70
  model.eval()
71
  return tokenizer, model
72
 
73
+ # Prediction function
74
  def predict_news(text, tokenizer, model):
75
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
76
  with torch.no_grad():
 
80
  confidence = probabilities[0][predicted_label].item()
81
  return "FAKE" if predicted_label == 1 else "REAL", confidence
82
 
83
+ # FastAPI request model
84
+ class NewsInput(BaseModel):
85
+ text: str
86
+
87
+ # FastAPI route for POST requests
88
+ @api_app.post("/classify")
89
+ async def classify_news(data: NewsInput):
90
  tokenizer, model = load_model()
91
+ prediction, confidence = predict_news(data.text, tokenizer, model)
92
+ return {
93
+ "prediction": prediction,
94
+ "confidence": f"{confidence*100:.2f}%"
95
+ }
96
+
97
+ # Streamlit app
98
+ def run_streamlit():
99
+ def main():
100
+ st.title("News Classifier")
101
+
102
+ # Load model
103
+ tokenizer, model = load_model()
104
+
105
+ # Text input
106
+ news_text = st.text_area("Enter news text to analyze:", height=200)
107
+
108
+ if st.button("Classify"):
109
+ if news_text:
110
+ with st.spinner('Analyzing...'):
111
+ prediction, confidence = predict_news(news_text, tokenizer, model)
112
+
113
+ # Display results
114
+ if prediction == "FAKE":
115
+ st.error(f"⚠️ {prediction} NEWS")
116
+ else:
117
+ st.success(f"✅ {prediction} NEWS")
118
+
119
+ st.info(f"Confidence: {confidence*100:.2f}%")
120
 
 
121
  main()
122
+
123
+ # Threaded execution for FastAPI and Streamlit
124
+ def start_fastapi():
125
+ import uvicorn
126
+ uvicorn.run(api_app, host="0.0.0.0", port=8502)
127
+
128
+ if __name__ == "__main__":
129
+ fastapi_thread = Thread(target=start_fastapi, daemon=True)
130
+ fastapi_thread.start()
131
+
132
+ # Start Streamlit
133
+ cli.main()