Krish Patel committed on
Commit
842adb5
·
1 Parent(s): 13e414c
Files changed (1) hide show
  1. app.py +142 -142
app.py CHANGED
@@ -1,68 +1,14 @@
1
- # # import streamlit as st
2
- # # import torch
3
- # # from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
-
5
- # # # Load the model and tokenizer
6
- # # # @st.cache_resource
7
- # # # def load_model():
8
- # # # tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small')
9
- # # # model = AutoModelForSequenceClassification.from_pretrained("./results/checkpoint-753")
10
- # # # model.eval()
11
- # # # return tokenizer, model
12
- # # @st.cache_resource
13
- # # def load_model():
14
- # # tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small', use_fast=False)
15
- # # model = AutoModelForSequenceClassification.from_pretrained("./results/checkpoint-753")
16
- # # model.eval()
17
- # # return tokenizer, model
18
-
19
- # # def predict_news(text, tokenizer, model):
20
- # # inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
21
- # # with torch.no_grad():
22
- # # outputs = model(**inputs)
23
- # # probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
24
- # # predicted_label = torch.argmax(probabilities, dim=-1).item()
25
- # # confidence = probabilities[0][predicted_label].item()
26
- # # return "FAKE" if predicted_label == 1 else "REAL", confidence
27
-
28
- # # def main():
29
- # # st.title("News Classifier")
30
-
31
- # # # Load model
32
- # # tokenizer, model = load_model()
33
-
34
- # # # Text input
35
- # # news_text = st.text_area("Enter news text to analyze:", height=200)
36
-
37
- # # if st.button("Classify"):
38
- # # if news_text:
39
- # # with st.spinner('Analyzing...'):
40
- # # prediction, confidence = predict_news(news_text, tokenizer, model)
41
-
42
- # # # Display results
43
- # # if prediction == "FAKE":
44
- # # st.error(f"⚠️ {prediction} NEWS")
45
- # # else:
46
- # # st.success(f"✅ {prediction} NEWS")
47
-
48
- # # st.info(f"Confidence: {confidence*100:.2f}%")
49
-
50
- # # if __name__ == "__main__":
51
- # # main()
52
-
53
-
54
  import streamlit as st
55
  import torch
56
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
57
- from fastapi import FastAPI, Request
58
- from pydantic import BaseModel
59
- from threading import Thread
60
- from streamlit.web import cli
61
-
62
- # FastAPI app
63
- api_app = FastAPI()
64
 
65
  # Load the model and tokenizer
 
 
 
 
 
 
66
  @st.cache_resource
67
  def load_model():
68
  tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small', use_fast=False)
@@ -70,7 +16,6 @@ def load_model():
70
  model.eval()
71
  return tokenizer, model
72
 
73
- # Prediction function
74
  def predict_news(text, tokenizer, model):
75
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
76
  with torch.no_grad():
@@ -80,104 +25,159 @@ def predict_news(text, tokenizer, model):
80
  confidence = probabilities[0][predicted_label].item()
81
  return "FAKE" if predicted_label == 1 else "REAL", confidence
82
 
83
- # FastAPI request model
84
- class NewsInput(BaseModel):
85
- text: str
86
-
87
- # FastAPI route for POST requests
88
- @api_app.post("/classify")
89
- async def classify_news(data: NewsInput):
90
  tokenizer, model = load_model()
91
- prediction, confidence = predict_news(data.text, tokenizer, model)
92
- return {
93
- "prediction": prediction,
94
- "confidence": f"{confidence*100:.2f}%"
95
- }
96
-
97
- # Streamlit app
98
- def run_streamlit():
99
- def main():
100
- st.title("News Classifier")
101
-
102
- # Load model
103
- tokenizer, model = load_model()
104
-
105
- # Text input
106
- news_text = st.text_area("Enter news text to analyze:", height=200)
107
-
108
- if st.button("Classify"):
109
- if news_text:
110
- with st.spinner('Analyzing...'):
111
- prediction, confidence = predict_news(news_text, tokenizer, model)
112
-
113
- # Display results
114
- if prediction == "FAKE":
115
- st.error(f"⚠️ {prediction} NEWS")
116
- else:
117
- st.success(f"✅ {prediction} NEWS")
118
-
119
- st.info(f"Confidence: {confidence*100:.2f}%")
120
-
121
- main()
122
-
123
- # Threaded execution for FastAPI and Streamlit
124
- def start_fastapi():
125
- import uvicorn
126
- uvicorn.run(api_app, host="0.0.0.0", port=8502)
127
 
128
  if __name__ == "__main__":
129
- fastapi_thread = Thread(target=start_fastapi, daemon=True)
130
- fastapi_thread.start()
131
 
132
- # Start Streamlit
133
- cli.main()
134
 
135
- # from fastapi import FastAPI, HTTPException
136
- # from pydantic import BaseModel
137
- # from transformers import AutoTokenizer, AutoModelForSequenceClassification
138
  # import torch
 
 
 
 
 
139
 
140
- # from fastapi.middleware.cors import CORSMiddleware
141
-
142
-
143
- # # Define the FastAPI app
144
- # app = FastAPI()
145
-
146
- # app.add_middleware(
147
- # CORSMiddleware,
148
- # allow_origins=["*"], # Update with your frontend's URL for security
149
- # allow_credentials=True,
150
- # allow_methods=["*"],
151
- # allow_headers=["*"],
152
- # )
153
- # # Define the input data schema
154
- # class InputText(BaseModel):
155
- # text: str
156
 
157
- # # Load the model and tokenizer (ensure these paths are correct in your Space)
158
- # tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small', use_fast=False)
159
- # model = AutoModelForSequenceClassification.from_pretrained("./results/checkpoint-753")
160
- # model.eval()
 
 
 
161
 
162
  # # Prediction function
163
- # def predict_news(text: str):
164
  # inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
165
  # with torch.no_grad():
166
  # outputs = model(**inputs)
167
  # probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
168
  # predicted_label = torch.argmax(probabilities, dim=-1).item()
169
  # confidence = probabilities[0][predicted_label].item()
 
 
 
 
 
 
 
 
 
 
 
170
  # return {
171
- # "prediction": "FAKE" if predicted_label == 1 else "REAL",
172
- # "confidence": round(confidence * 100, 2) # Return confidence as a percentage
173
  # }
174
 
175
- # # Define the POST endpoint
176
- # @app.post("/predict")
177
- # async def classify_news(input_text: InputText):
178
- # try:
179
- # result = predict_news(input_text.text)
180
- # return result
181
- # except Exception as e:
182
- # raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
 
 
 
 
 
 
4
 
5
  # Load the model and tokenizer
6
+ # @st.cache_resource
7
+ # def load_model():
8
+ # tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small')
9
+ # model = AutoModelForSequenceClassification.from_pretrained("./results/checkpoint-753")
10
+ # model.eval()
11
+ # return tokenizer, model
12
  @st.cache_resource
13
  def load_model():
14
  tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small', use_fast=False)
 
16
  model.eval()
17
  return tokenizer, model
18
 
 
19
  def predict_news(text, tokenizer, model):
20
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
21
  with torch.no_grad():
 
25
  confidence = probabilities[0][predicted_label].item()
26
  return "FAKE" if predicted_label == 1 else "REAL", confidence
27
 
28
+ def main():
29
+ st.title("News Classifier")
30
+
31
+ # Load model
 
 
 
32
  tokenizer, model = load_model()
33
+
34
+ # Text input
35
+ news_text = st.text_area("Enter news text to analyze:", height=200)
36
+
37
+ if st.button("Classify"):
38
+ if news_text:
39
+ with st.spinner('Analyzing...'):
40
+ prediction, confidence = predict_news(news_text, tokenizer, model)
41
+
42
+ # Display results
43
+ if prediction == "FAKE":
44
+ st.error(f"⚠️ {prediction} NEWS")
45
+ else:
46
+ st.success(f"✅ {prediction} NEWS")
47
+
48
+ st.info(f"Confidence: {confidence*100:.2f}%")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  if __name__ == "__main__":
51
+ main()
 
52
 
 
 
53
 
54
+ # import streamlit as st
 
 
55
  # import torch
56
+ # from transformers import AutoTokenizer, AutoModelForSequenceClassification
57
+ # from fastapi import FastAPI, Request
58
+ # from pydantic import BaseModel
59
+ # from threading import Thread
60
+ # from streamlit.web import cli
61
 
62
+ # # FastAPI app
63
+ # api_app = FastAPI()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
+ # # Load the model and tokenizer
66
+ # @st.cache_resource
67
+ # def load_model():
68
+ # tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small', use_fast=False)
69
+ # model = AutoModelForSequenceClassification.from_pretrained("./results/checkpoint-753")
70
+ # model.eval()
71
+ # return tokenizer, model
72
 
73
  # # Prediction function
74
+ # def predict_news(text, tokenizer, model):
75
  # inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
76
  # with torch.no_grad():
77
  # outputs = model(**inputs)
78
  # probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
79
  # predicted_label = torch.argmax(probabilities, dim=-1).item()
80
  # confidence = probabilities[0][predicted_label].item()
81
+ # return "FAKE" if predicted_label == 1 else "REAL", confidence
82
+
83
+ # # FastAPI request model
84
+ # class NewsInput(BaseModel):
85
+ # text: str
86
+
87
+ # # FastAPI route for POST requests
88
+ # @api_app.post("/classify")
89
+ # async def classify_news(data: NewsInput):
90
+ # tokenizer, model = load_model()
91
+ # prediction, confidence = predict_news(data.text, tokenizer, model)
92
  # return {
93
+ # "prediction": prediction,
94
+ # "confidence": f"{confidence*100:.2f}%"
95
  # }
96
 
97
+ # # Streamlit app
98
+ # def run_streamlit():
99
+ # def main():
100
+ # st.title("News Classifier")
101
+
102
+ # # Load model
103
+ # tokenizer, model = load_model()
104
+
105
+ # # Text input
106
+ # news_text = st.text_area("Enter news text to analyze:", height=200)
107
+
108
+ # if st.button("Classify"):
109
+ # if news_text:
110
+ # with st.spinner('Analyzing...'):
111
+ # prediction, confidence = predict_news(news_text, tokenizer, model)
112
+
113
+ # # Display results
114
+ # if prediction == "FAKE":
115
+ # st.error(f"⚠️ {prediction} NEWS")
116
+ # else:
117
+ # st.success(f"✅ {prediction} NEWS")
118
+
119
+ # st.info(f"Confidence: {confidence*100:.2f}%")
120
+
121
+ # main()
122
+
123
+ # # Threaded execution for FastAPI and Streamlit
124
+ # def start_fastapi():
125
+ # import uvicorn
126
+ # uvicorn.run(api_app, host="0.0.0.0", port=8502)
127
+
128
+ # if __name__ == "__main__":
129
+ # fastapi_thread = Thread(target=start_fastapi, daemon=True)
130
+ # fastapi_thread.start()
131
+
132
+ # # Start Streamlit
133
+ # cli.main()
134
+
135
+ # # from fastapi import FastAPI, HTTPException
136
+ # # from pydantic import BaseModel
137
+ # # from transformers import AutoTokenizer, AutoModelForSequenceClassification
138
+ # # import torch
139
+
140
+ # # from fastapi.middleware.cors import CORSMiddleware
141
+
142
+
143
+ # # # Define the FastAPI app
144
+ # # app = FastAPI()
145
+
146
+ # # app.add_middleware(
147
+ # # CORSMiddleware,
148
+ # # allow_origins=["*"], # Update with your frontend's URL for security
149
+ # # allow_credentials=True,
150
+ # # allow_methods=["*"],
151
+ # # allow_headers=["*"],
152
+ # # )
153
+ # # # Define the input data schema
154
+ # # class InputText(BaseModel):
155
+ # # text: str
156
+
157
+ # # # Load the model and tokenizer (ensure these paths are correct in your Space)
158
+ # # tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small', use_fast=False)
159
+ # # model = AutoModelForSequenceClassification.from_pretrained("./results/checkpoint-753")
160
+ # # model.eval()
161
+
162
+ # # # Prediction function
163
+ # # def predict_news(text: str):
164
+ # # inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
165
+ # # with torch.no_grad():
166
+ # # outputs = model(**inputs)
167
+ # # probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
168
+ # # predicted_label = torch.argmax(probabilities, dim=-1).item()
169
+ # # confidence = probabilities[0][predicted_label].item()
170
+ # # return {
171
+ # # "prediction": "FAKE" if predicted_label == 1 else "REAL",
172
+ # # "confidence": round(confidence * 100, 2) # Return confidence as a percentage
173
+ # # }
174
+
175
+ # # # Define the POST endpoint
176
+ # # @app.post("/predict")
177
+ # # async def classify_news(input_text: InputText):
178
+ # # try:
179
+ # # result = predict_news(input_text.text)
180
+ # # return result
181
+ # # except Exception as e:
182
+ # # raise HTTPException(status_code=500, detail=str(e))
183