Krish Patel commited on
Commit
6b30efd
·
1 Parent(s): ecb4f5b

Added api endpoint 1

Browse files
Files changed (1) hide show
  1. app.py +132 -92
app.py CHANGED
@@ -1,14 +1,68 @@
1
- # import streamlit as st
2
- # import torch
3
- # from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
-
5
- # # Load the model and tokenizer
 
 
 
 
 
 
6
  # # @st.cache_resource
7
  # # def load_model():
8
- # # tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small')
9
  # # model = AutoModelForSequenceClassification.from_pretrained("./results/checkpoint-753")
10
  # # model.eval()
11
  # # return tokenizer, model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  # @st.cache_resource
13
  # def load_model():
14
  # tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small', use_fast=False)
@@ -16,6 +70,7 @@
16
  # model.eval()
17
  # return tokenizer, model
18
 
 
19
  # def predict_news(text, tokenizer, model):
20
  # inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
21
  # with torch.no_grad():
@@ -25,109 +80,94 @@
25
  # confidence = probabilities[0][predicted_label].item()
26
  # return "FAKE" if predicted_label == 1 else "REAL", confidence
27
 
28
- # def main():
29
- # st.title("News Classifier")
30
-
31
- # # Load model
 
 
 
32
  # tokenizer, model = load_model()
33
-
34
- # # Text input
35
- # news_text = st.text_area("Enter news text to analyze:", height=200)
36
-
37
- # if st.button("Classify"):
38
- # if news_text:
39
- # with st.spinner('Analyzing...'):
40
- # prediction, confidence = predict_news(news_text, tokenizer, model)
41
-
42
- # # Display results
43
- # if prediction == "FAKE":
44
- # st.error(f"⚠️ {prediction} NEWS")
45
- # else:
46
- # st.success(f"✅ {prediction} NEWS")
47
-
48
- # st.info(f"Confidence: {confidence*100:.2f}%")
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
- # if __name__ == "__main__":
51
  # main()
52
 
 
 
 
 
53
 
54
- import streamlit as st
55
- import torch
56
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
57
- from fastapi import FastAPI, Request
 
 
 
 
58
  from pydantic import BaseModel
59
- from threading import Thread
60
- from streamlit.web import cli
61
 
62
- # FastAPI app
63
- api_app = FastAPI()
64
 
65
- # Load the model and tokenizer
66
- @st.cache_resource
67
- def load_model():
68
- tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small', use_fast=False)
69
- model = AutoModelForSequenceClassification.from_pretrained("./results/checkpoint-753")
70
- model.eval()
71
- return tokenizer, model
 
72
 
73
  # Prediction function
74
- def predict_news(text, tokenizer, model):
75
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
76
  with torch.no_grad():
77
  outputs = model(**inputs)
78
  probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
79
  predicted_label = torch.argmax(probabilities, dim=-1).item()
80
  confidence = probabilities[0][predicted_label].item()
81
- return "FAKE" if predicted_label == 1 else "REAL", confidence
82
-
83
- # FastAPI request model
84
- class NewsInput(BaseModel):
85
- text: str
86
-
87
- # FastAPI route for POST requests
88
- @api_app.post("/classify")
89
- async def classify_news(data: NewsInput):
90
- tokenizer, model = load_model()
91
- prediction, confidence = predict_news(data.text, tokenizer, model)
92
  return {
93
- "prediction": prediction,
94
- "confidence": f"{confidence*100:.2f}%"
95
  }
96
 
97
- # Streamlit app
98
- def run_streamlit():
99
- def main():
100
- st.title("News Classifier")
101
-
102
- # Load model
103
- tokenizer, model = load_model()
104
-
105
- # Text input
106
- news_text = st.text_area("Enter news text to analyze:", height=200)
107
-
108
- if st.button("Classify"):
109
- if news_text:
110
- with st.spinner('Analyzing...'):
111
- prediction, confidence = predict_news(news_text, tokenizer, model)
112
-
113
- # Display results
114
- if prediction == "FAKE":
115
- st.error(f"⚠️ {prediction} NEWS")
116
- else:
117
- st.success(f"✅ {prediction} NEWS")
118
-
119
- st.info(f"Confidence: {confidence*100:.2f}%")
120
-
121
- main()
122
-
123
- # Threaded execution for FastAPI and Streamlit
124
- def start_fastapi():
125
- import uvicorn
126
- uvicorn.run(api_app, host="0.0.0.0", port=8502)
127
-
128
- if __name__ == "__main__":
129
- fastapi_thread = Thread(target=start_fastapi, daemon=True)
130
- fastapi_thread.start()
131
 
132
- # Start Streamlit
133
- cli.main()
 
1
+ # # import streamlit as st
2
+ # # import torch
3
+ # # from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+
5
+ # # # Load the model and tokenizer
6
+ # # # @st.cache_resource
7
+ # # # def load_model():
8
+ # # # tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small')
9
+ # # # model = AutoModelForSequenceClassification.from_pretrained("./results/checkpoint-753")
10
+ # # # model.eval()
11
+ # # # return tokenizer, model
12
  # # @st.cache_resource
13
  # # def load_model():
14
+ # # tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small', use_fast=False)
15
  # # model = AutoModelForSequenceClassification.from_pretrained("./results/checkpoint-753")
16
  # # model.eval()
17
  # # return tokenizer, model
18
+
19
+ # # def predict_news(text, tokenizer, model):
20
+ # # inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
21
+ # # with torch.no_grad():
22
+ # # outputs = model(**inputs)
23
+ # # probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
24
+ # # predicted_label = torch.argmax(probabilities, dim=-1).item()
25
+ # # confidence = probabilities[0][predicted_label].item()
26
+ # # return "FAKE" if predicted_label == 1 else "REAL", confidence
27
+
28
+ # # def main():
29
+ # # st.title("News Classifier")
30
+
31
+ # # # Load model
32
+ # # tokenizer, model = load_model()
33
+
34
+ # # # Text input
35
+ # # news_text = st.text_area("Enter news text to analyze:", height=200)
36
+
37
+ # # if st.button("Classify"):
38
+ # # if news_text:
39
+ # # with st.spinner('Analyzing...'):
40
+ # # prediction, confidence = predict_news(news_text, tokenizer, model)
41
+
42
+ # # # Display results
43
+ # # if prediction == "FAKE":
44
+ # # st.error(f"⚠️ {prediction} NEWS")
45
+ # # else:
46
+ # # st.success(f"✅ {prediction} NEWS")
47
+
48
+ # # st.info(f"Confidence: {confidence*100:.2f}%")
49
+
50
+ # # if __name__ == "__main__":
51
+ # # main()
52
+
53
+
54
+ # import streamlit as st
55
+ # import torch
56
+ # from transformers import AutoTokenizer, AutoModelForSequenceClassification
57
+ # from fastapi import FastAPI, Request
58
+ # from pydantic import BaseModel
59
+ # from threading import Thread
60
+ # from streamlit.web import cli
61
+
62
+ # # FastAPI app
63
+ # api_app = FastAPI()
64
+
65
+ # # Load the model and tokenizer
66
  # @st.cache_resource
67
  # def load_model():
68
  # tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small', use_fast=False)
 
70
  # model.eval()
71
  # return tokenizer, model
72
 
73
+ # # Prediction function
74
  # def predict_news(text, tokenizer, model):
75
  # inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
76
  # with torch.no_grad():
 
80
  # confidence = probabilities[0][predicted_label].item()
81
  # return "FAKE" if predicted_label == 1 else "REAL", confidence
82
 
83
+ # # FastAPI request model
84
+ # class NewsInput(BaseModel):
85
+ # text: str
86
+
87
+ # # FastAPI route for POST requests
88
+ # @api_app.post("/classify")
89
+ # async def classify_news(data: NewsInput):
90
  # tokenizer, model = load_model()
91
+ # prediction, confidence = predict_news(data.text, tokenizer, model)
92
+ # return {
93
+ # "prediction": prediction,
94
+ # "confidence": f"{confidence*100:.2f}%"
95
+ # }
96
+
97
+ # # Streamlit app
98
+ # def run_streamlit():
99
+ # def main():
100
+ # st.title("News Classifier")
101
+
102
+ # # Load model
103
+ # tokenizer, model = load_model()
104
+
105
+ # # Text input
106
+ # news_text = st.text_area("Enter news text to analyze:", height=200)
107
+
108
+ # if st.button("Classify"):
109
+ # if news_text:
110
+ # with st.spinner('Analyzing...'):
111
+ # prediction, confidence = predict_news(news_text, tokenizer, model)
112
+
113
+ # # Display results
114
+ # if prediction == "FAKE":
115
+ # st.error(f"⚠️ {prediction} NEWS")
116
+ # else:
117
+ # st.success(f"✅ {prediction} NEWS")
118
+
119
+ # st.info(f"Confidence: {confidence*100:.2f}%")
120
 
 
121
  # main()
122
 
123
+ # # Threaded execution for FastAPI and Streamlit
124
+ # def start_fastapi():
125
+ # import uvicorn
126
+ # uvicorn.run(api_app, host="0.0.0.0", port=8502)
127
 
128
+ # if __name__ == "__main__":
129
+ # fastapi_thread = Thread(target=start_fastapi, daemon=True)
130
+ # fastapi_thread.start()
131
+
132
+ # # Start Streamlit
133
+ # cli.main()
134
+
135
+ from fastapi import FastAPI, HTTPException
136
  from pydantic import BaseModel
137
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
138
+ import torch
139
 
140
+ # Define the FastAPI app
141
+ app = FastAPI()
142
 
143
+ # Define the input data schema
144
+ class InputText(BaseModel):
145
+ text: str
146
+
147
+ # Load the model and tokenizer (ensure these paths are correct in your Space)
148
+ tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small', use_fast=False)
149
+ model = AutoModelForSequenceClassification.from_pretrained("./results/checkpoint-753")
150
+ model.eval()
151
 
152
  # Prediction function
153
+ def predict_news(text: str):
154
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
155
  with torch.no_grad():
156
  outputs = model(**inputs)
157
  probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
158
  predicted_label = torch.argmax(probabilities, dim=-1).item()
159
  confidence = probabilities[0][predicted_label].item()
 
 
 
 
 
 
 
 
 
 
 
160
  return {
161
+ "prediction": "FAKE" if predicted_label == 1 else "REAL",
162
+ "confidence": round(confidence * 100, 2) # Return confidence as a percentage
163
  }
164
 
165
+ # Define the POST endpoint
166
+ @app.post("/predict")
167
+ async def classify_news(input_text: InputText):
168
+ try:
169
+ result = predict_news(input_text.text)
170
+ return result
171
+ except Exception as e:
172
+ raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173