Rowan Martnishn committed on
Commit
662cb6b
·
verified ·
1 Parent(s): c1ea7de

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +240 -65
app.py CHANGED
@@ -1,3 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import schedule
3
  import time
@@ -14,6 +246,8 @@ import matplotlib.pyplot as plt
14
  from scipy.interpolate import make_interp_spline
15
  from transformers import AutoTokenizer
16
  import matplotlib.font_manager as fm
 
 
17
 
18
  # Load models and data (your existing code)
19
  autovectorizer = joblib.load('AutoVectorizer.pkl')
@@ -22,7 +256,6 @@ MODEL = "cardiffnlp/xlm-twitter-politics-sentiment"
22
  tokenizer = AutoTokenizer.from_pretrained(MODEL)
23
 
24
  class ScorePredictor(nn.Module):
25
- # ... (Your ScorePredictor class)
26
  def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256, output_dim=1):
27
  super(ScorePredictor, self).__init__()
28
  self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
@@ -40,27 +273,22 @@ class ScorePredictor(nn.Module):
40
  score_model = ScorePredictor(tokenizer.vocab_size)
41
  score_model.load_state_dict(torch.load("score_predictor.pth"))
42
  score_model.eval()
43
-
44
  sentiment_model = joblib.load('sentiment_forecast_model.pkl')
45
 
46
  reddit = praw.Reddit(
47
  client_id="PH99oWZjM43GimMtYigFvA",
48
  client_secret="3tJsXQKEtFFYInxzLEDqRZ0s_w5z0g",
49
  user_agent='MyAPI/0.0.1',
50
- check_for_async=False)
 
51
 
52
  subreddits = [
53
  "florida",
54
  "ohio",
55
- # "libertarian",
56
- # "southpark",
57
- # "walkaway",
58
- # "truechristian",
59
- # "conservatives"
60
  ]
61
 
62
  # Global variables for data
63
- global prediction_plot_base64
64
 
65
  def process_data():
66
  """Fetches data, performs analysis, and generates the plot."""
@@ -69,14 +297,12 @@ def process_data():
69
  start_date = end_date - datetime.timedelta(days=14)
70
 
71
  def fetch_all_recent_posts(subreddit_name, start_time, limit=500):
72
- # ... (Your fetch_all_recent_posts function)
73
  subreddit = reddit.subreddit(subreddit_name)
74
  posts = []
75
-
76
  try:
77
- for post in subreddit.top(limit=limit): # Fetch recent posts
78
  post_time = datetime.datetime.utcfromtimestamp(post.created_utc)
79
- if post_time >= start_time: # Filter only within last 14 days
80
  posts.append({
81
  "subreddit": subreddit_name,
82
  "timestamp": post.created_utc,
@@ -85,11 +311,9 @@ def process_data():
85
  })
86
  except Exception as e:
87
  print(f"Error fetching posts from r/{subreddit_name}: {e}")
88
-
89
  return posts
90
 
91
  def preprocess_text(text):
92
- # ... (Your preprocess_text function)
93
  text = text.lower()
94
  text = re.sub(r'http\S+', '', text)
95
  text = re.sub(r'[^a-zA-Z0-9\s.,!?]', '', text)
@@ -97,11 +321,9 @@ def process_data():
97
  return text
98
 
99
  def predict_score(text):
100
- # ... (Your predict_score function)
101
  if not text:
102
  return 0.0
103
  max_length = 512
104
-
105
  encoded_input = tokenizer(
106
  text.split(),
107
  return_tensors='pt',
@@ -109,7 +331,6 @@ def process_data():
109
  truncation=True,
110
  max_length=max_length
111
  )
112
-
113
  input_ids, attention_mask = encoded_input["input_ids"], encoded_input["attention_mask"]
114
  with torch.no_grad():
115
  score = score_model(input_ids, attention_mask)[0].item()
@@ -175,55 +396,9 @@ def process_data():
175
  ax.set_ylabel("Negative Sentiment (0-1)", fontsize=16, fontproperties=custom_font)
176
  ax.set_xticks(np.arange(7))
177
  ax.set_xticklabels(days_str, fontsize=14, fontproperties=custom_font)
178
-
179
- # Continue from previous app.py code
180
-
181
  ax.set_yticklabels([f"{tick:.2f}" for tick in ax.get_yticks()], fontsize=14, fontproperties=custom_font)
182
 
183
  ax.spines['top'].set_visible(False)
184
  ax.spines['right'].set_visible(False)
185
  ax.spines['left'].set_visible(False)
186
- ax.spines['bottom'].set_visible(False)
187
-
188
- ax.legend(fontsize=14, loc='upper right', prop=custom_font)
189
- plt.tight_layout()
190
-
191
- import io
192
- import base64
193
- buffer = io.BytesIO()
194
- plt.savefig(buffer, format='png')
195
- buffer.seek(0)
196
- prediction_plot_base64 = base64.b64encode(buffer.getvalue()).decode()
197
- plt.close(fig)
198
-
199
- def display_plot():
200
- """Displays the plot in the Gradio interface."""
201
- global prediction_plot_base64
202
- if prediction_plot_base64:
203
- return f'<img src="data:image/png;base64,{prediction_plot_base64}" alt="Prediction Plot">'
204
- else:
205
- return "Processing data..."
206
-
207
- # Initial data processing
208
- process_data()
209
-
210
- # Schedule daily refresh
211
- def run_daily():
212
- process_data()
213
- print("Data refreshed at:", datetime.datetime.now())
214
-
215
- schedule.every().day.at("00:00").do(run_daily)
216
-
217
- def run_schedule():
218
- while True:
219
- schedule.run_pending()
220
- time.sleep(60)
221
-
222
- import threading
223
- thread = threading.Thread(target=run_schedule)
224
- thread.daemon = True
225
- thread.start()
226
-
227
- # Gradio Interface
228
- iface = gr.Interface(fn=display_plot, inputs=None, outputs="html")
229
- iface.launch()
 
1
+ # import gradio as gr
2
+ # import schedule
3
+ # import time
4
+ # import datetime
5
+ # import praw
6
+ # import joblib
7
+ # import torch
8
+ # import scipy.sparse as sp
9
+ # import torch.nn as nn
10
+ # import pandas as pd
11
+ # import re
12
+ # import numpy as np
13
+ # import matplotlib.pyplot as plt
14
+ # from scipy.interpolate import make_interp_spline
15
+ # from transformers import AutoTokenizer
16
+ # import matplotlib.font_manager as fm
17
+
18
+ # # Load models and data (your existing code)
19
+ # autovectorizer = joblib.load('AutoVectorizer.pkl')
20
+ # autoclassifier = joblib.load('AutoClassifier.pkl')
21
+ # MODEL = "cardiffnlp/xlm-twitter-politics-sentiment"
22
+ # tokenizer = AutoTokenizer.from_pretrained(MODEL)
23
+
24
+ # class ScorePredictor(nn.Module):
25
+ # # ... (Your ScorePredictor class)
26
+ # def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256, output_dim=1):
27
+ # super(ScorePredictor, self).__init__()
28
+ # self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
29
+ # self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
30
+ # self.fc = nn.Linear(hidden_dim, output_dim)
31
+ # self.sigmoid = nn.Sigmoid()
32
+
33
+ # def forward(self, input_ids, attention_mask):
34
+ # embedded = self.embedding(input_ids)
35
+ # lstm_out, _ = self.lstm(embedded)
36
+ # final_hidden_state = lstm_out[:, -1, :]
37
+ # output = self.fc(final_hidden_state)
38
+ # return self.sigmoid(output)
39
+
40
+ # score_model = ScorePredictor(tokenizer.vocab_size)
41
+ # score_model.load_state_dict(torch.load("score_predictor.pth"))
42
+ # score_model.eval()
43
+
44
+ # sentiment_model = joblib.load('sentiment_forecast_model.pkl')
45
+
46
+ # reddit = praw.Reddit(
47
+ # client_id="PH99oWZjM43GimMtYigFvA",
48
+ # client_secret="3tJsXQKEtFFYInxzLEDqRZ0s_w5z0g",
49
+ # user_agent='MyAPI/0.0.1',
50
+ # check_for_async=False)
51
+
52
+ # subreddits = [
53
+ # "florida",
54
+ # "ohio",
55
+ # # "libertarian",
56
+ # # "southpark",
57
+ # # "walkaway",
58
+ # # "truechristian",
59
+ # # "conservatives"
60
+ # ]
61
+
62
+ # # Global variables for data
63
+ # global prediction_plot_base64
64
+
65
+ # def process_data():
66
+ # """Fetches data, performs analysis, and generates the plot."""
67
+ # global prediction_plot_base64
68
+ # end_date = datetime.datetime.utcnow()
69
+ # start_date = end_date - datetime.timedelta(days=14)
70
+
71
+ # def fetch_all_recent_posts(subreddit_name, start_time, limit=500):
72
+ # # ... (Your fetch_all_recent_posts function)
73
+ # subreddit = reddit.subreddit(subreddit_name)
74
+ # posts = []
75
+
76
+ # try:
77
+ # for post in subreddit.top(limit=limit): # Fetch recent posts
78
+ # post_time = datetime.datetime.utcfromtimestamp(post.created_utc)
79
+ # if post_time >= start_time: # Filter only within last 14 days
80
+ # posts.append({
81
+ # "subreddit": subreddit_name,
82
+ # "timestamp": post.created_utc,
83
+ # "date": post_time.strftime('%Y-%m-%d %H:%M:%S'),
84
+ # "post_text": post.title
85
+ # })
86
+ # except Exception as e:
87
+ # print(f"Error fetching posts from r/{subreddit_name}: {e}")
88
+
89
+ # return posts
90
+
91
+ # def preprocess_text(text):
92
+ # # ... (Your preprocess_text function)
93
+ # text = text.lower()
94
+ # text = re.sub(r'http\S+', '', text)
95
+ # text = re.sub(r'[^a-zA-Z0-9\s.,!?]', '', text)
96
+ # text = re.sub(r'\s+', ' ', text).strip()
97
+ # return text
98
+
99
+ # def predict_score(text):
100
+ # # ... (Your predict_score function)
101
+ # if not text:
102
+ # return 0.0
103
+ # max_length = 512
104
+
105
+ # encoded_input = tokenizer(
106
+ # text.split(),
107
+ # return_tensors='pt',
108
+ # padding=True,
109
+ # truncation=True,
110
+ # max_length=max_length
111
+ # )
112
+
113
+ # input_ids, attention_mask = encoded_input["input_ids"], encoded_input["attention_mask"]
114
+ # with torch.no_grad():
115
+ # score = score_model(input_ids, attention_mask)[0].item()
116
+ # return score
117
+
118
+ # start_time = datetime.datetime.utcnow() - datetime.timedelta(days=14)
119
+ # all_posts = []
120
+ # for sub in subreddits:
121
+ # print(f"Fetching posts from r/{sub}")
122
+ # posts = fetch_all_recent_posts(sub, start_time)
123
+ # all_posts.extend(posts)
124
+ # print(f"Fetched {len(posts)} posts from r/{sub}")
125
+
126
+ # filtered_posts = []
127
+ # for post in all_posts:
128
+ # vector = autovectorizer.transform([post['post_text']])
129
+ # prediction = autoclassifier.predict(vector)
130
+ # if prediction[0] == 1:
131
+ # filtered_posts.append(post)
132
+ # all_posts = filtered_posts
133
+
134
+ # df = pd.DataFrame(all_posts)
135
+ # df['date'] = pd.to_datetime(df['date'])
136
+ # df['date_only'] = df['date'].dt.date
137
+ # df = df.sort_values(by=['date_only'])
138
+ # df['sentiment_score'] = df['post_text'].apply(predict_score)
139
+
140
+ # last_14_dates = df['date_only'].unique()
141
+ # num_dates = min(len(last_14_dates), 14)
142
+ # last_14_dates = sorted(last_14_dates, reverse=True)[:num_dates]
143
+
144
+ # filtered_df = df[df['date_only'].isin(last_14_dates)]
145
+ # daily_sentiment = filtered_df.groupby('date_only')['sentiment_score'].median()
146
+
147
+ # if len(daily_sentiment) < 14:
148
+ # mean_sentiment = daily_sentiment.mean()
149
+ # padding = [mean_sentiment] * (14 - len(daily_sentiment))
150
+ # daily_sentiment = np.concatenate([daily_sentiment.values, padding])
151
+ # daily_sentiment = pd.Series(daily_sentiment)
152
+
153
+ # sentiment_scores_np = daily_sentiment.values.reshape(1, -1)
154
+ # prediction = sentiment_model.predict(sentiment_scores_np)
155
+ # pred = (prediction[0])
156
+
157
+ # font_path = "AfacadFlux-VariableFont_slnt,wght[1].ttf"
158
+ # custom_font = fm.FontProperties(fname=font_path)
159
+
160
+ # today = datetime.date.today()
161
+ # days = [today + datetime.timedelta(days=i) for i in range(7)]
162
+ # days_str = [day.strftime('%a %m/%d') for day in days]
163
+
164
+ # xnew = np.linspace(0, 6, 300)
165
+ # spline = make_interp_spline(np.arange(7), pred, k=3)
166
+ # pred_smooth = spline(xnew)
167
+
168
+ # fig, ax = plt.subplots(figsize=(12, 7))
169
+ # ax.fill_between(xnew, pred_smooth, color='#244B48', alpha=0.4)
170
+ # ax.plot(xnew, pred_smooth, color='#244B48', lw=3, label='Forecast')
171
+ # ax.scatter(np.arange(7), pred, color='#244B48', s=100, zorder=5)
172
+
173
+ # ax.set_title("7-Day Political Sentiment Forecast", fontsize=22, fontweight='bold', pad=20, fontproperties=custom_font)
174
+ # ax.set_xlabel("Day", fontsize=16, fontproperties=custom_font)
175
+ # ax.set_ylabel("Negative Sentiment (0-1)", fontsize=16, fontproperties=custom_font)
176
+ # ax.set_xticks(np.arange(7))
177
+ # ax.set_xticklabels(days_str, fontsize=14, fontproperties=custom_font)
178
+
179
+ # # Continue from previous app.py code
180
+
181
+ # ax.set_yticklabels([f"{tick:.2f}" for tick in ax.get_yticks()], fontsize=14, fontproperties=custom_font)
182
+
183
+ # ax.spines['top'].set_visible(False)
184
+ # ax.spines['right'].set_visible(False)
185
+ # ax.spines['left'].set_visible(False)
186
+ # ax.spines['bottom'].set_visible(False)
187
+
188
+ # ax.legend(fontsize=14, loc='upper right', prop=custom_font)
189
+ # plt.tight_layout()
190
+
191
+ # import io
192
+ # import base64
193
+ # buffer = io.BytesIO()
194
+ # plt.savefig(buffer, format='png')
195
+ # buffer.seek(0)
196
+ # prediction_plot_base64 = base64.b64encode(buffer.getvalue()).decode()
197
+ # plt.close(fig)
198
+
199
+ # def display_plot():
200
+ # """Displays the plot in the Gradio interface."""
201
+ # global prediction_plot_base64
202
+ # if prediction_plot_base64:
203
+ # return f'<img src="data:image/png;base64,{prediction_plot_base64}" alt="Prediction Plot">'
204
+ # else:
205
+ # return "Processing data..."
206
+
207
+ # # Initial data processing
208
+ # process_data()
209
+
210
+ # # Schedule daily refresh
211
+ # def run_daily():
212
+ # process_data()
213
+ # print("Data refreshed at:", datetime.datetime.now())
214
+
215
+ # schedule.every().day.at("00:00").do(run_daily)
216
+
217
+ # def run_schedule():
218
+ # while True:
219
+ # schedule.run_pending()
220
+ # time.sleep(60)
221
+
222
+ # import threading
223
+ # thread = threading.Thread(target=run_schedule)
224
+ # thread.daemon = True
225
+ # thread.start()
226
+
227
+ # # Gradio Interface
228
+ # iface = gr.Interface(fn=display_plot, inputs=None, outputs="html")
229
+ # iface.launch()
230
+
231
+
232
+
233
  import gradio as gr
234
  import schedule
235
  import time
 
246
  from scipy.interpolate import make_interp_spline
247
  from transformers import AutoTokenizer
248
  import matplotlib.font_manager as fm
249
+ import io
250
+ import base64
251
 
252
  # Load models and data (your existing code)
253
  autovectorizer = joblib.load('AutoVectorizer.pkl')
 
256
  tokenizer = AutoTokenizer.from_pretrained(MODEL)
257
 
258
  class ScorePredictor(nn.Module):
 
259
  def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256, output_dim=1):
260
  super(ScorePredictor, self).__init__()
261
  self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
 
273
  score_model = ScorePredictor(tokenizer.vocab_size)
274
  score_model.load_state_dict(torch.load("score_predictor.pth"))
275
  score_model.eval()
 
276
  sentiment_model = joblib.load('sentiment_forecast_model.pkl')
277
 
278
  reddit = praw.Reddit(
279
  client_id="PH99oWZjM43GimMtYigFvA",
280
  client_secret="3tJsXQKEtFFYInxzLEDqRZ0s_w5z0g",
281
  user_agent='MyAPI/0.0.1',
282
+ check_for_async=False
283
+ )
284
 
285
  subreddits = [
286
  "florida",
287
  "ohio",
 
 
 
 
 
288
  ]
289
 
290
  # Global variables for data
291
+ prediction_plot_base64 = None # Initialize to None
292
 
293
  def process_data():
294
  """Fetches data, performs analysis, and generates the plot."""
 
297
  start_date = end_date - datetime.timedelta(days=14)
298
 
299
  def fetch_all_recent_posts(subreddit_name, start_time, limit=500):
 
300
  subreddit = reddit.subreddit(subreddit_name)
301
  posts = []
 
302
  try:
303
+ for post in subreddit.top(limit=limit):
304
  post_time = datetime.datetime.utcfromtimestamp(post.created_utc)
305
+ if post_time >= start_time:
306
  posts.append({
307
  "subreddit": subreddit_name,
308
  "timestamp": post.created_utc,
 
311
  })
312
  except Exception as e:
313
  print(f"Error fetching posts from r/{subreddit_name}: {e}")
 
314
  return posts
315
 
316
  def preprocess_text(text):
 
317
  text = text.lower()
318
  text = re.sub(r'http\S+', '', text)
319
  text = re.sub(r'[^a-zA-Z0-9\s.,!?]', '', text)
 
321
  return text
322
 
323
  def predict_score(text):
 
324
  if not text:
325
  return 0.0
326
  max_length = 512
 
327
  encoded_input = tokenizer(
328
  text.split(),
329
  return_tensors='pt',
 
331
  truncation=True,
332
  max_length=max_length
333
  )
 
334
  input_ids, attention_mask = encoded_input["input_ids"], encoded_input["attention_mask"]
335
  with torch.no_grad():
336
  score = score_model(input_ids, attention_mask)[0].item()
 
396
  ax.set_ylabel("Negative Sentiment (0-1)", fontsize=16, fontproperties=custom_font)
397
  ax.set_xticks(np.arange(7))
398
  ax.set_xticklabels(days_str, fontsize=14, fontproperties=custom_font)
 
 
 
399
  ax.set_yticklabels([f"{tick:.2f}" for tick in ax.get_yticks()], fontsize=14, fontproperties=custom_font)
400
 
401
  ax.spines['top'].set_visible(False)
402
  ax.spines['right'].set_visible(False)
403
  ax.spines['left'].set_visible(False)
404
+ ax.spines['bottom'].set_visible