Rowan Martnishn commited on
Commit
af88ed0
·
verified ·
1 Parent(s): 662cb6b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -240
app.py CHANGED
@@ -1,235 +1,3 @@
1
- # import gradio as gr
2
- # import schedule
3
- # import time
4
- # import datetime
5
- # import praw
6
- # import joblib
7
- # import torch
8
- # import scipy.sparse as sp
9
- # import torch.nn as nn
10
- # import pandas as pd
11
- # import re
12
- # import numpy as np
13
- # import matplotlib.pyplot as plt
14
- # from scipy.interpolate import make_interp_spline
15
- # from transformers import AutoTokenizer
16
- # import matplotlib.font_manager as fm
17
-
18
- # # Load models and data (your existing code)
19
- # autovectorizer = joblib.load('AutoVectorizer.pkl')
20
- # autoclassifier = joblib.load('AutoClassifier.pkl')
21
- # MODEL = "cardiffnlp/xlm-twitter-politics-sentiment"
22
- # tokenizer = AutoTokenizer.from_pretrained(MODEL)
23
-
24
- # class ScorePredictor(nn.Module):
25
- # # ... (Your ScorePredictor class)
26
- # def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256, output_dim=1):
27
- # super(ScorePredictor, self).__init__()
28
- # self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
29
- # self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
30
- # self.fc = nn.Linear(hidden_dim, output_dim)
31
- # self.sigmoid = nn.Sigmoid()
32
-
33
- # def forward(self, input_ids, attention_mask):
34
- # embedded = self.embedding(input_ids)
35
- # lstm_out, _ = self.lstm(embedded)
36
- # final_hidden_state = lstm_out[:, -1, :]
37
- # output = self.fc(final_hidden_state)
38
- # return self.sigmoid(output)
39
-
40
- # score_model = ScorePredictor(tokenizer.vocab_size)
41
- # score_model.load_state_dict(torch.load("score_predictor.pth"))
42
- # score_model.eval()
43
-
44
- # sentiment_model = joblib.load('sentiment_forecast_model.pkl')
45
-
46
- # reddit = praw.Reddit(
47
- # client_id="PH99oWZjM43GimMtYigFvA",
48
- # client_secret="3tJsXQKEtFFYInxzLEDqRZ0s_w5z0g",
49
- # user_agent='MyAPI/0.0.1',
50
- # check_for_async=False)
51
-
52
- # subreddits = [
53
- # "florida",
54
- # "ohio",
55
- # # "libertarian",
56
- # # "southpark",
57
- # # "walkaway",
58
- # # "truechristian",
59
- # # "conservatives"
60
- # ]
61
-
62
- # # Global variables for data
63
- # global prediction_plot_base64
64
-
65
- # def process_data():
66
- # """Fetches data, performs analysis, and generates the plot."""
67
- # global prediction_plot_base64
68
- # end_date = datetime.datetime.utcnow()
69
- # start_date = end_date - datetime.timedelta(days=14)
70
-
71
- # def fetch_all_recent_posts(subreddit_name, start_time, limit=500):
72
- # # ... (Your fetch_all_recent_posts function)
73
- # subreddit = reddit.subreddit(subreddit_name)
74
- # posts = []
75
-
76
- # try:
77
- # for post in subreddit.top(limit=limit): # Fetch recent posts
78
- # post_time = datetime.datetime.utcfromtimestamp(post.created_utc)
79
- # if post_time >= start_time: # Filter only within last 14 days
80
- # posts.append({
81
- # "subreddit": subreddit_name,
82
- # "timestamp": post.created_utc,
83
- # "date": post_time.strftime('%Y-%m-%d %H:%M:%S'),
84
- # "post_text": post.title
85
- # })
86
- # except Exception as e:
87
- # print(f"Error fetching posts from r/{subreddit_name}: {e}")
88
-
89
- # return posts
90
-
91
- # def preprocess_text(text):
92
- # # ... (Your preprocess_text function)
93
- # text = text.lower()
94
- # text = re.sub(r'http\S+', '', text)
95
- # text = re.sub(r'[^a-zA-Z0-9\s.,!?]', '', text)
96
- # text = re.sub(r'\s+', ' ', text).strip()
97
- # return text
98
-
99
- # def predict_score(text):
100
- # # ... (Your predict_score function)
101
- # if not text:
102
- # return 0.0
103
- # max_length = 512
104
-
105
- # encoded_input = tokenizer(
106
- # text.split(),
107
- # return_tensors='pt',
108
- # padding=True,
109
- # truncation=True,
110
- # max_length=max_length
111
- # )
112
-
113
- # input_ids, attention_mask = encoded_input["input_ids"], encoded_input["attention_mask"]
114
- # with torch.no_grad():
115
- # score = score_model(input_ids, attention_mask)[0].item()
116
- # return score
117
-
118
- # start_time = datetime.datetime.utcnow() - datetime.timedelta(days=14)
119
- # all_posts = []
120
- # for sub in subreddits:
121
- # print(f"Fetching posts from r/{sub}")
122
- # posts = fetch_all_recent_posts(sub, start_time)
123
- # all_posts.extend(posts)
124
- # print(f"Fetched {len(posts)} posts from r/{sub}")
125
-
126
- # filtered_posts = []
127
- # for post in all_posts:
128
- # vector = autovectorizer.transform([post['post_text']])
129
- # prediction = autoclassifier.predict(vector)
130
- # if prediction[0] == 1:
131
- # filtered_posts.append(post)
132
- # all_posts = filtered_posts
133
-
134
- # df = pd.DataFrame(all_posts)
135
- # df['date'] = pd.to_datetime(df['date'])
136
- # df['date_only'] = df['date'].dt.date
137
- # df = df.sort_values(by=['date_only'])
138
- # df['sentiment_score'] = df['post_text'].apply(predict_score)
139
-
140
- # last_14_dates = df['date_only'].unique()
141
- # num_dates = min(len(last_14_dates), 14)
142
- # last_14_dates = sorted(last_14_dates, reverse=True)[:num_dates]
143
-
144
- # filtered_df = df[df['date_only'].isin(last_14_dates)]
145
- # daily_sentiment = filtered_df.groupby('date_only')['sentiment_score'].median()
146
-
147
- # if len(daily_sentiment) < 14:
148
- # mean_sentiment = daily_sentiment.mean()
149
- # padding = [mean_sentiment] * (14 - len(daily_sentiment))
150
- # daily_sentiment = np.concatenate([daily_sentiment.values, padding])
151
- # daily_sentiment = pd.Series(daily_sentiment)
152
-
153
- # sentiment_scores_np = daily_sentiment.values.reshape(1, -1)
154
- # prediction = sentiment_model.predict(sentiment_scores_np)
155
- # pred = (prediction[0])
156
-
157
- # font_path = "AfacadFlux-VariableFont_slnt,wght[1].ttf"
158
- # custom_font = fm.FontProperties(fname=font_path)
159
-
160
- # today = datetime.date.today()
161
- # days = [today + datetime.timedelta(days=i) for i in range(7)]
162
- # days_str = [day.strftime('%a %m/%d') for day in days]
163
-
164
- # xnew = np.linspace(0, 6, 300)
165
- # spline = make_interp_spline(np.arange(7), pred, k=3)
166
- # pred_smooth = spline(xnew)
167
-
168
- # fig, ax = plt.subplots(figsize=(12, 7))
169
- # ax.fill_between(xnew, pred_smooth, color='#244B48', alpha=0.4)
170
- # ax.plot(xnew, pred_smooth, color='#244B48', lw=3, label='Forecast')
171
- # ax.scatter(np.arange(7), pred, color='#244B48', s=100, zorder=5)
172
-
173
- # ax.set_title("7-Day Political Sentiment Forecast", fontsize=22, fontweight='bold', pad=20, fontproperties=custom_font)
174
- # ax.set_xlabel("Day", fontsize=16, fontproperties=custom_font)
175
- # ax.set_ylabel("Negative Sentiment (0-1)", fontsize=16, fontproperties=custom_font)
176
- # ax.set_xticks(np.arange(7))
177
- # ax.set_xticklabels(days_str, fontsize=14, fontproperties=custom_font)
178
-
179
- # # Continue from previous app.py code
180
-
181
- # ax.set_yticklabels([f"{tick:.2f}" for tick in ax.get_yticks()], fontsize=14, fontproperties=custom_font)
182
-
183
- # ax.spines['top'].set_visible(False)
184
- # ax.spines['right'].set_visible(False)
185
- # ax.spines['left'].set_visible(False)
186
- # ax.spines['bottom'].set_visible(False)
187
-
188
- # ax.legend(fontsize=14, loc='upper right', prop=custom_font)
189
- # plt.tight_layout()
190
-
191
- # import io
192
- # import base64
193
- # buffer = io.BytesIO()
194
- # plt.savefig(buffer, format='png')
195
- # buffer.seek(0)
196
- # prediction_plot_base64 = base64.b64encode(buffer.getvalue()).decode()
197
- # plt.close(fig)
198
-
199
- # def display_plot():
200
- # """Displays the plot in the Gradio interface."""
201
- # global prediction_plot_base64
202
- # if prediction_plot_base64:
203
- # return f'<img src="data:image/png;base64,{prediction_plot_base64}" alt="Prediction Plot">'
204
- # else:
205
- # return "Processing data..."
206
-
207
- # # Initial data processing
208
- # process_data()
209
-
210
- # # Schedule daily refresh
211
- # def run_daily():
212
- # process_data()
213
- # print("Data refreshed at:", datetime.datetime.now())
214
-
215
- # schedule.every().day.at("00:00").do(run_daily)
216
-
217
- # def run_schedule():
218
- # while True:
219
- # schedule.run_pending()
220
- # time.sleep(60)
221
-
222
- # import threading
223
- # thread = threading.Thread(target=run_schedule)
224
- # thread.daemon = True
225
- # thread.start()
226
-
227
- # # Gradio Interface
228
- # iface = gr.Interface(fn=display_plot, inputs=None, outputs="html")
229
- # iface.launch()
230
-
231
-
232
-
233
  import gradio as gr
234
  import schedule
235
  import time
@@ -246,8 +14,6 @@ import matplotlib.pyplot as plt
246
  from scipy.interpolate import make_interp_spline
247
  from transformers import AutoTokenizer
248
  import matplotlib.font_manager as fm
249
- import io
250
- import base64
251
 
252
  # Load models and data (your existing code)
253
  autovectorizer = joblib.load('AutoVectorizer.pkl')
@@ -256,6 +22,7 @@ MODEL = "cardiffnlp/xlm-twitter-politics-sentiment"
256
  tokenizer = AutoTokenizer.from_pretrained(MODEL)
257
 
258
  class ScorePredictor(nn.Module):
 
259
  def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256, output_dim=1):
260
  super(ScorePredictor, self).__init__()
261
  self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
@@ -273,22 +40,27 @@ class ScorePredictor(nn.Module):
273
  score_model = ScorePredictor(tokenizer.vocab_size)
274
  score_model.load_state_dict(torch.load("score_predictor.pth"))
275
  score_model.eval()
 
276
  sentiment_model = joblib.load('sentiment_forecast_model.pkl')
277
 
278
  reddit = praw.Reddit(
279
  client_id="PH99oWZjM43GimMtYigFvA",
280
  client_secret="3tJsXQKEtFFYInxzLEDqRZ0s_w5z0g",
281
  user_agent='MyAPI/0.0.1',
282
- check_for_async=False
283
- )
284
 
285
  subreddits = [
286
  "florida",
287
  "ohio",
 
 
 
 
 
288
  ]
289
 
290
  # Global variables for data
291
- prediction_plot_base64 = None # Initialize to None
292
 
293
  def process_data():
294
  """Fetches data, performs analysis, and generates the plot."""
@@ -297,12 +69,14 @@ def process_data():
297
  start_date = end_date - datetime.timedelta(days=14)
298
 
299
  def fetch_all_recent_posts(subreddit_name, start_time, limit=500):
 
300
  subreddit = reddit.subreddit(subreddit_name)
301
  posts = []
 
302
  try:
303
- for post in subreddit.top(limit=limit):
304
  post_time = datetime.datetime.utcfromtimestamp(post.created_utc)
305
- if post_time >= start_time:
306
  posts.append({
307
  "subreddit": subreddit_name,
308
  "timestamp": post.created_utc,
@@ -311,9 +85,11 @@ def process_data():
311
  })
312
  except Exception as e:
313
  print(f"Error fetching posts from r/{subreddit_name}: {e}")
 
314
  return posts
315
 
316
  def preprocess_text(text):
 
317
  text = text.lower()
318
  text = re.sub(r'http\S+', '', text)
319
  text = re.sub(r'[^a-zA-Z0-9\s.,!?]', '', text)
@@ -321,9 +97,11 @@ def process_data():
321
  return text
322
 
323
  def predict_score(text):
 
324
  if not text:
325
  return 0.0
326
  max_length = 512
 
327
  encoded_input = tokenizer(
328
  text.split(),
329
  return_tensors='pt',
@@ -331,6 +109,7 @@ def process_data():
331
  truncation=True,
332
  max_length=max_length
333
  )
 
334
  input_ids, attention_mask = encoded_input["input_ids"], encoded_input["attention_mask"]
335
  with torch.no_grad():
336
  score = score_model(input_ids, attention_mask)[0].item()
@@ -396,9 +175,56 @@ def process_data():
396
  ax.set_ylabel("Negative Sentiment (0-1)", fontsize=16, fontproperties=custom_font)
397
  ax.set_xticks(np.arange(7))
398
  ax.set_xticklabels(days_str, fontsize=14, fontproperties=custom_font)
 
 
 
399
  ax.set_yticklabels([f"{tick:.2f}" for tick in ax.get_yticks()], fontsize=14, fontproperties=custom_font)
400
 
401
  ax.spines['top'].set_visible(False)
402
  ax.spines['right'].set_visible(False)
403
  ax.spines['left'].set_visible(False)
404
- ax.spines['bottom'].set_visible
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import schedule
3
  import time
 
14
  from scipy.interpolate import make_interp_spline
15
  from transformers import AutoTokenizer
16
  import matplotlib.font_manager as fm
 
 
17
 
18
  # Load models and data (your existing code)
19
  autovectorizer = joblib.load('AutoVectorizer.pkl')
 
22
  tokenizer = AutoTokenizer.from_pretrained(MODEL)
23
 
24
  class ScorePredictor(nn.Module):
25
+ # ... (Your ScorePredictor class)
26
  def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256, output_dim=1):
27
  super(ScorePredictor, self).__init__()
28
  self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
 
40
  score_model = ScorePredictor(tokenizer.vocab_size)
41
  score_model.load_state_dict(torch.load("score_predictor.pth"))
42
  score_model.eval()
43
+
44
  sentiment_model = joblib.load('sentiment_forecast_model.pkl')
45
 
46
  reddit = praw.Reddit(
47
  client_id="PH99oWZjM43GimMtYigFvA",
48
  client_secret="3tJsXQKEtFFYInxzLEDqRZ0s_w5z0g",
49
  user_agent='MyAPI/0.0.1',
50
+ check_for_async=False)
 
51
 
52
  subreddits = [
53
  "florida",
54
  "ohio",
55
+ "libertarian",
56
+ "southpark",
57
+ "walkaway",
58
+ "truechristian",
59
+ "conservatives"
60
  ]
61
 
62
  # Global variables for data
63
+ global prediction_plot_base64
64
 
65
  def process_data():
66
  """Fetches data, performs analysis, and generates the plot."""
 
69
  start_date = end_date - datetime.timedelta(days=14)
70
 
71
  def fetch_all_recent_posts(subreddit_name, start_time, limit=500):
72
+ # ... (Your fetch_all_recent_posts function)
73
  subreddit = reddit.subreddit(subreddit_name)
74
  posts = []
75
+
76
  try:
77
+ for post in subreddit.top(limit=limit): # Fetch recent posts
78
  post_time = datetime.datetime.utcfromtimestamp(post.created_utc)
79
+ if post_time >= start_time: # Filter only within last 14 days
80
  posts.append({
81
  "subreddit": subreddit_name,
82
  "timestamp": post.created_utc,
 
85
  })
86
  except Exception as e:
87
  print(f"Error fetching posts from r/{subreddit_name}: {e}")
88
+
89
  return posts
90
 
91
  def preprocess_text(text):
92
+ # ... (Your preprocess_text function)
93
  text = text.lower()
94
  text = re.sub(r'http\S+', '', text)
95
  text = re.sub(r'[^a-zA-Z0-9\s.,!?]', '', text)
 
97
  return text
98
 
99
  def predict_score(text):
100
+ # ... (Your predict_score function)
101
  if not text:
102
  return 0.0
103
  max_length = 512
104
+
105
  encoded_input = tokenizer(
106
  text.split(),
107
  return_tensors='pt',
 
109
  truncation=True,
110
  max_length=max_length
111
  )
112
+
113
  input_ids, attention_mask = encoded_input["input_ids"], encoded_input["attention_mask"]
114
  with torch.no_grad():
115
  score = score_model(input_ids, attention_mask)[0].item()
 
175
  ax.set_ylabel("Negative Sentiment (0-1)", fontsize=16, fontproperties=custom_font)
176
  ax.set_xticks(np.arange(7))
177
  ax.set_xticklabels(days_str, fontsize=14, fontproperties=custom_font)
178
+
179
+ # Continue from previous app.py code
180
+
181
  ax.set_yticklabels([f"{tick:.2f}" for tick in ax.get_yticks()], fontsize=14, fontproperties=custom_font)
182
 
183
  ax.spines['top'].set_visible(False)
184
  ax.spines['right'].set_visible(False)
185
  ax.spines['left'].set_visible(False)
186
+ ax.spines['bottom'].set_visible(False)
187
+
188
+ ax.legend(fontsize=14, loc='upper right', prop=custom_font)
189
+ plt.tight_layout()
190
+
191
+ import io
192
+ import base64
193
+ buffer = io.BytesIO()
194
+ plt.savefig(buffer, format='png')
195
+ buffer.seek(0)
196
+ prediction_plot_base64 = base64.b64encode(buffer.getvalue()).decode()
197
+ plt.close(fig)
198
+
199
+ def display_plot():
200
+ """Displays the plot in the Gradio interface."""
201
+ global prediction_plot_base64
202
+ if prediction_plot_base64:
203
+ return f'<img src="data:image/png;base64,{prediction_plot_base64}" alt="Prediction Plot">'
204
+ else:
205
+ return "Processing data..."
206
+
207
+ # Initial data processing
208
+ process_data()
209
+
210
+ # Schedule daily refresh
211
+ def run_daily():
212
+ process_data()
213
+ print("Data refreshed at:", datetime.datetime.now())
214
+
215
+ schedule.every().day.at("00:00").do(run_daily)
216
+
217
+ def run_schedule():
218
+ while True:
219
+ schedule.run_pending()
220
+ time.sleep(60)
221
+
222
+ import threading
223
+ thread = threading.Thread(target=run_schedule)
224
+ thread.daemon = True
225
+ thread.start()
226
+
227
+ # Gradio Interface
228
+ iface = gr.Interface(fn=display_plot, inputs=None, outputs="html")
229
+ iface.launch()
230
+