sentivity commited on
Commit
1831c48
·
verified ·
1 Parent(s): 527812e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -248
app.py CHANGED
@@ -1,41 +1,21 @@
1
  import gradio as gr
2
- import schedule
3
- import time
4
- import datetime
5
- import praw
6
- import joblib
7
  import torch
8
- import scipy.sparse as sp
9
  import torch.nn as nn
10
- import pandas as pd
11
  import re
12
- import numpy as np
13
- import matplotlib.pyplot as plt
14
- from scipy.interpolate import make_interp_spline
15
  from transformers import AutoTokenizer
16
- import matplotlib.font_manager as fm
17
- import pytz
18
 
19
- # Load models and data (your existing code)
20
- autovectorizer = joblib.load('AutoVectorizer.pkl')
21
- autoclassifier = joblib.load('AutoClassifier.pkl')
22
  MODEL = "cardiffnlp/xlm-twitter-politics-sentiment"
23
  tokenizer = AutoTokenizer.from_pretrained(MODEL)
24
 
25
-
26
-
27
-
28
-
29
-
30
  class ScorePredictor(nn.Module):
31
- # ... (Your ScorePredictor class)
32
  def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256, output_dim=1):
33
  super(ScorePredictor, self).__init__()
34
  self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
35
  self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
36
  self.fc = nn.Linear(hidden_dim, output_dim)
37
  self.sigmoid = nn.Sigmoid()
38
-
39
  def forward(self, input_ids, attention_mask):
40
  embedded = self.embedding(input_ids)
41
  lstm_out, _ = self.lstm(embedded)
@@ -47,230 +27,92 @@ score_model = ScorePredictor(tokenizer.vocab_size)
47
  score_model.load_state_dict(torch.load("score_predictor.pth"))
48
  score_model.eval()
49
 
50
- sentiment_model = joblib.load('sentiment_forecast_model.pkl')
51
-
52
- reddit = praw.Reddit(
53
- client_id="PH99oWZjM43GimMtYigFvA",
54
- client_secret="3tJsXQKEtFFYInxzLEDqRZ0s_w5z0g",
55
- user_agent='MyAPI/0.0.1',
56
- check_for_async=False)
57
-
58
- subreddits = [
59
- "centrist",
60
- "libertarian",
61
- "southpark",
62
- "truechristian",
63
- "conservatives"
64
- ]
65
-
66
- # Global variables for data
67
- global prediction_plot_base64
68
- prediction_plot_base64 = None
69
-
70
- def process_data():
71
- """Fetches data, performs analysis, and generates the plot."""
72
- global prediction_plot_base64
73
- end_date = datetime.datetime.utcnow()
74
- start_date = end_date - datetime.timedelta(days=14)
75
-
76
- def fetch_all_recent_posts(subreddit_name, start_time, limit=500):
77
- # ... (Your fetch_all_recent_posts function)
78
- subreddit = reddit.subreddit(subreddit_name)
79
- posts = []
80
-
81
- try:
82
- for post in subreddit.top(limit=limit): # Fetch recent posts
83
- post_time = datetime.datetime.utcfromtimestamp(post.created_utc)
84
- if post_time >= start_time: # Filter only within last 14 days
85
- posts.append({
86
- "subreddit": subreddit_name,
87
- "timestamp": post.created_utc,
88
- "date": post_time.strftime('%Y-%m-%d %H:%M:%S'),
89
- "post_text": post.title
90
- })
91
- except Exception as e:
92
- print(f"Error fetching posts from r/{subreddit_name}: {e}")
93
-
94
- return posts
95
-
96
- def preprocess_text(text):
97
- # ... (Your preprocess_text function)
98
- text = text.lower()
99
- text = re.sub(r'http\S+', '', text)
100
- text = re.sub(r'[^a-zA-Z0-9\s.,!?]', '', text)
101
- text = re.sub(r'\s+', ' ', text).strip()
102
- return text
103
-
104
- def predict_score(text):
105
- # ... (Your predict_score function)
106
- if not text:
107
- return 0.0
108
- max_length = 512
109
-
110
- encoded_input = tokenizer(
111
- text.split(),
112
- return_tensors='pt',
113
- padding=True,
114
- truncation=True,
115
- max_length=max_length
116
- )
117
-
118
- input_ids, attention_mask = encoded_input["input_ids"], encoded_input["attention_mask"]
119
- with torch.no_grad():
120
- score = score_model(input_ids, attention_mask)[0].item()
121
- return score
122
-
123
- start_time = datetime.datetime.utcnow() - datetime.timedelta(days=14)
124
- all_posts = []
125
- for sub in subreddits:
126
- print(f"Fetching posts from r/{sub}")
127
- posts = fetch_all_recent_posts(sub, start_time)
128
- all_posts.extend(posts)
129
- print(f"Fetched {len(posts)} posts from r/{sub}")
130
-
131
- filtered_posts = []
132
- for post in all_posts:
133
- vector = autovectorizer.transform([post['post_text']])
134
- prediction = autoclassifier.predict(vector)
135
- if prediction[0] == 1:
136
- filtered_posts.append(post)
137
- all_posts = filtered_posts
138
-
139
- df = pd.DataFrame(all_posts)
140
- df['date'] = pd.to_datetime(df['date'])
141
- df['date_only'] = df['date'].dt.date
142
- df = df.sort_values(by=['date_only'])
143
- df['sentiment_score'] = df['post_text'].apply(predict_score)
144
-
145
- last_14_dates = df['date_only'].unique()
146
- num_dates = min(len(last_14_dates), 14)
147
- last_14_dates = sorted(last_14_dates, reverse=True)[:num_dates]
148
-
149
- filtered_df = df[df['date_only'].isin(last_14_dates)]
150
- daily_sentiment = filtered_df.groupby('date_only')['sentiment_score'].median()
151
-
152
- if len(daily_sentiment) < 14:
153
- mean_sentiment = daily_sentiment.mean()
154
- padding = [mean_sentiment] * (14 - len(daily_sentiment))
155
- daily_sentiment = np.concatenate([daily_sentiment.values, padding])
156
- daily_sentiment = pd.Series(daily_sentiment)
157
-
158
- sentiment_scores_np = daily_sentiment.values.reshape(1, -1)
159
- prediction = sentiment_model.predict(sentiment_scores_np)
160
- pred = (prediction[0])
161
-
162
- font_path = "AfacadFlux-VariableFont_slnt,wght[1].ttf"
163
- custom_font = fm.FontProperties(fname=font_path)
164
-
165
- today = datetime.date.today()
166
- days = [today + datetime.timedelta(days=i) for i in range(7)]
167
- days_str = [day.strftime('%a %m/%d') for day in days]
168
-
169
- xnew = np.linspace(0, 6, 300)
170
- spline = make_interp_spline(np.arange(7), pred, k=3)
171
- pred_smooth = spline(xnew)
172
-
173
- fig, ax = plt.subplots(figsize=(12, 7))
174
- ax.fill_between(xnew, pred_smooth, color='#244B48', alpha=0.4)
175
- ax.plot(xnew, pred_smooth, color='#244B48', lw=3, label='Forecast')
176
- ax.scatter(np.arange(7), pred, color='#244B48', s=100, zorder=5)
177
-
178
- est_timezone = pytz.timezone('America/New_York')
179
- est_time = datetime.datetime.now(est_timezone)
180
- ax.set_title(f"7-Day Political Sentiment Forecast - {est_time.strftime('%Y-%m-%d %H:%M:%S EST')}",
181
- fontsize=70, fontweight='bold', pad=20, fontproperties=custom_font)
182
- # ax.set_title(f"7-Day Political Sentiment Forecast - {datetime.datetime.now()}", fontsize=22, fontweight='bold', pad=20, fontproperties=custom_font)
183
- ax.set_xlabel("Day", fontsize=16, fontproperties=custom_font)
184
- ax.set_ylabel("Negative Sentiment (0-1)", fontsize=16, fontproperties=custom_font)
185
- ax.set_xticks(np.arange(7))
186
- ax.set_xticklabels(days_str, fontsize=14, fontproperties=custom_font)
187
-
188
- # Continue from previous app.py code
189
-
190
- ax.set_yticklabels([f"{tick:.2f}" for tick in ax.get_yticks()], fontsize=14, fontproperties=custom_font)
191
-
192
- ax.spines['top'].set_visible(False)
193
- ax.spines['right'].set_visible(False)
194
- ax.spines['left'].set_visible(False)
195
- ax.spines['bottom'].set_visible(False)
196
-
197
- ax.legend(fontsize=14, loc='upper right', prop=custom_font)
198
- plt.tight_layout()
199
-
200
- import io
201
- import base64
202
- buffer = io.BytesIO()
203
- plt.savefig(buffer, format='png')
204
- buffer.seek(0)
205
- prediction_plot_base64 = base64.b64encode(buffer.getvalue()).decode()
206
- plt.close(fig)
207
-
208
- def display_plot():
209
- """Displays the plot in the Gradio interface."""
210
- global prediction_plot_base64
211
- if prediction_plot_base64:
212
- return f'<img src="data:image/png;base64,{prediction_plot_base64}" alt="Prediction Plot">'
213
- else:
214
- return "Processing data..."
215
-
216
-
217
-
218
- process_data()
219
-
220
- # Schedule daily refresh
221
- def run_daily():
222
- process_data()
223
- print("Data refreshed at:", datetime.datetime.now())
224
-
225
- #schedule.every().day.at("00:00").do(run_daily)
226
- schedule.every(10).seconds.do(run_daily)
227
-
228
- def run_schedule():
229
- while True:
230
- schedule.run_pending()
231
- #time.sleep(60)
232
-
233
- import threading
234
- thread = threading.Thread(target=run_schedule)
235
- thread.daemon = True
236
- thread.start()
237
-
238
-
239
-
240
-
241
- custom_css = """
242
- body, .gradio-container {
243
- margin: 0;
244
- padding: 0;
245
- }
246
- """
247
-
248
-
249
- with gr.Blocks(css=custom_css) as demo:
250
- # Initialize the HTML output with a default message
251
- html_output = gr.HTML("Processing data...")
252
-
253
- # Define the refresh function
254
- def refresh_html():
255
- if prediction_plot_base64:
256
- return (
257
- f'<img src="data:image/png;base64,{prediction_plot_base64}" '
258
- 'alt="Prediction Plot" '
259
- 'style="width: 100vw; height: 100vh; object-fit: contain;">'
260
- )
261
  else:
262
- return "Processing data..."
263
-
264
- # Use the Timer component according to the documentation
265
- timer = gr.Timer(3600, refresh_html, [], html_output)
266
-
267
-
268
-
269
-
270
-
271
-
272
- # Initial call to set the HTML content when the page loads
273
- demo.load(refresh_html, [], html_output)
274
-
275
- # Launch the demo
276
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import requests
 
 
 
 
3
  import torch
 
4
  import torch.nn as nn
 
5
  import re
 
 
 
6
  from transformers import AutoTokenizer
 
 
7
 
 
 
 
8
  MODEL = "cardiffnlp/xlm-twitter-politics-sentiment"
9
  tokenizer = AutoTokenizer.from_pretrained(MODEL)
10
 
 
 
 
 
 
11
  class ScorePredictor(nn.Module):
 
12
  def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256, output_dim=1):
13
  super(ScorePredictor, self).__init__()
14
  self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
15
  self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
16
  self.fc = nn.Linear(hidden_dim, output_dim)
17
  self.sigmoid = nn.Sigmoid()
18
+
19
  def forward(self, input_ids, attention_mask):
20
  embedded = self.embedding(input_ids)
21
  lstm_out, _ = self.lstm(embedded)
 
27
  score_model.load_state_dict(torch.load("score_predictor.pth"))
28
  score_model.eval()
29
 
30
+ def preprocess_text(text):
31
+ text = text.lower()
32
+ text = re.sub(r'http\S+', '', text)
33
+ text = re.sub(r'[^a-zA-Z0-9\s.,!?]', '', text)
34
+ text = re.sub(r'\s+', ' ', text).strip()
35
+ return text
36
+
37
+ def predict_sentiment(text):
38
+ if not text:
39
+ return 0.0
40
+ encoded_input = tokenizer(
41
+ text.split(),
42
+ return_tensors='pt',
43
+ padding=True,
44
+ truncation=True,
45
+ max_length=512
46
+ )
47
+ input_ids, attention_mask = encoded_input["input_ids"], encoded_input["attention_mask"]
48
+ with torch.no_grad():
49
+ score = score_model(input_ids, attention_mask)[0].item()
50
+ return score
51
+
52
+ def fetch_articles(ticker):
53
+ POLYGON_API_KEY = "cMCv7jipVvV4qLBikgzllNmW_isiODRR"
54
+ url = f"https://api.polygon.io/v2/reference/news?ticker={ticker}&limit=1&apiKey={POLYGON_API_KEY}"
55
+ try:
56
+ response = requests.get(url)
57
+ data = response.json()
58
+ if "results" in data and len(data["results"]) > 0:
59
+ article = data["results"][0]
60
+ title = article.get("title", "")
61
+ description = article.get("description", "")
62
+ return [title + " " + description]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  else:
64
+ return [f"No news articles found for {ticker}."]
65
+ except Exception as e:
66
+ return [f"Error fetching articles for {ticker}: {str(e)}"]
67
+
68
+ def analyze_ticker(ticker):
69
+ articles = fetch_articles(ticker)
70
+ sentiments = []
71
+ for article in articles:
72
+ clean_text = preprocess_text(article)
73
+ sentiment = predict_sentiment(clean_text)
74
+
75
+ # Determine sentiment label
76
+ if sentiment > 0.6:
77
+ sentiment_label = "Negative"
78
+ emoji = "😊"
79
+ elif sentiment < 0.4:
80
+ sentiment_label = "Positive"
81
+ emoji = "😞"
82
+ else:
83
+ sentiment_label = "Neutral"
84
+ emoji = "😐"
85
+
86
+ sentiments.append({
87
+ "article": article,
88
+ "sentiment": sentiment,
89
+ "sentiment_label": sentiment_label,
90
+ "emoji": emoji
91
+ })
92
+ return sentiments
93
+
94
+ def gradio_interface(ticker):
95
+ results = analyze_ticker(ticker)
96
+ output = f"""
97
+ <h2>Sentiment Analysis for {ticker}</h2>
98
+ <div style='border: 1px solid #ccc; padding: 15px; border-radius: 5px; margin-bottom: 20px;'>
99
+ <h3>Article:</h3>
100
+ <p>{results[0]['article']}</p>
101
+ <h3>Sentiment:</h3>
102
+ <p>Score: {results[0]['sentiment']:.4f}</p>
103
+ <p>Label: {results[0]['sentiment_label']} {results[0]['emoji']}</p>
104
+ </div>
105
+ """
106
+ return output
107
+
108
+ # Create Gradio interface
109
+ iface = gr.Interface(
110
+ fn=gradio_interface,
111
+ inputs=gr.Textbox(label="Enter Stock Ticker", placeholder="AAPL, MSFT, GOOGL..."),
112
+ outputs=gr.HTML(label="Sentiment Analysis Results"),
113
+ title="Stock News Sentiment Analyzer",
114
+ description="Enter a stock ticker to analyze the sentiment of recent news articles about that company.",
115
+ examples=[["AAPL"], ["MSFT"], ["TSLA"]]
116
+ )
117
+
118
+ iface.launch()