sentivity commited on
Commit
5a3607d
·
verified ·
1 Parent(s): 519defd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -57
app.py CHANGED
@@ -3,12 +3,14 @@ import requests
3
  import torch
4
  import torch.nn as nn
5
  import re
6
- import os
7
  from transformers import AutoTokenizer
8
 
 
9
  MODEL = "cardiffnlp/xlm-twitter-politics-sentiment"
10
  tokenizer = AutoTokenizer.from_pretrained(MODEL)
11
 
 
12
  class ScorePredictor(nn.Module):
13
  def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256, output_dim=1):
14
  super(ScorePredictor, self).__init__()
@@ -16,7 +18,7 @@ class ScorePredictor(nn.Module):
16
  self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
17
  self.fc = nn.Linear(hidden_dim, output_dim)
18
  self.sigmoid = nn.Sigmoid()
19
-
20
  def forward(self, input_ids, attention_mask):
21
  embedded = self.embedding(input_ids)
22
  lstm_out, _ = self.lstm(embedded)
@@ -24,10 +26,12 @@ class ScorePredictor(nn.Module):
24
  output = self.fc(final_hidden_state)
25
  return self.sigmoid(output)
26
 
 
27
  score_model = ScorePredictor(tokenizer.vocab_size)
28
  score_model.load_state_dict(torch.load("score_predictor.pth"))
29
  score_model.eval()
30
 
 
31
  def preprocess_text(text):
32
  text = text.lower()
33
  text = re.sub(r'http\S+', '', text)
@@ -35,6 +39,7 @@ def preprocess_text(text):
35
  text = re.sub(r'\s+', ' ', text).strip()
36
  return text
37
 
 
38
  def predict_sentiment(text):
39
  if not text:
40
  return 0.0
@@ -48,15 +53,11 @@ def predict_sentiment(text):
48
  input_ids, attention_mask = encoded_input["input_ids"], encoded_input["attention_mask"]
49
  with torch.no_grad():
50
  score = score_model(input_ids, attention_mask)[0].item()
51
- min_val, max_val = 0.3, 0.9
52
- scaled_score = (score - min_val) / (max_val - min_val)
53
- # Clip to ensure it stays within [0, 1] in case original score was outside [0.3, 0.9]
54
- scaled_score = max(0.0, min(1.0, scaled_score))
55
-
56
- return scaled_score
57
 
 
58
  def fetch_articles(ticker):
59
- POLYGON_API_KEY = os.getenv('poly_api')
60
  url = f"https://api.polygon.io/v2/reference/news?ticker={ticker}&limit=1&apiKey={POLYGON_API_KEY}"
61
  try:
62
  response = requests.get(url)
@@ -71,54 +72,76 @@ def fetch_articles(ticker):
71
  except Exception as e:
72
  return [f"Error fetching articles for {ticker}: {str(e)}"]
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  def analyze_ticker(ticker):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  articles = fetch_articles(ticker)
76
- sentiments = []
77
- for article in articles:
78
- clean_text = preprocess_text(article)
79
- sentiment = predict_sentiment(clean_text)
80
-
81
- # Determine sentiment label
82
- if sentiment > 0.6:
83
- sentiment_label = "Negative"
84
- emoji = "😊"
85
- elif sentiment < 0.4:
86
- sentiment_label = "Positive"
87
- emoji = "😞"
88
- else:
89
- sentiment_label = "Neutral"
90
- emoji = "😐"
91
-
92
- sentiments.append({
93
- "article": article,
94
- "sentiment": sentiment,
95
- "sentiment_label": sentiment_label,
96
- "emoji": emoji
97
- })
98
- return sentiments
99
-
100
- def gradio_interface(ticker):
101
  results = analyze_ticker(ticker)
102
- output = f"""
103
- <h2>Sentiment Analysis for {ticker}</h2>
104
- <div style='border: 1px solid #ccc; padding: 15px; border-radius: 5px; margin-bottom: 20px;'>
105
- <h3>Article:</h3>
106
- <p>{results[0]['article']}</p>
107
- <h3>Sentiment:</h3>
108
- <p>Score: {results[0]['sentiment']:.4f}</p>
109
- <p>Label: {results[0]['sentiment_label']} {results[0]['emoji']}</p>
110
- </div>
111
- """
112
- return output
113
-
114
- # Create Gradio interface
115
- iface = gr.Interface(
116
- fn=gradio_interface,
117
- inputs=gr.Textbox(label="Enter Stock Ticker", placeholder="AAPL, MSFT, GOOGL..."),
118
- outputs=gr.HTML(label="Sentiment Analysis Results"),
119
- title="Stock News Sentiment Analyzer",
120
- description="Enter a stock ticker to analyze the sentiment of recent news articles about that company.",
121
- examples=[["AAPL"], ["MSFT"], ["TSLA"]]
122
- )
123
-
124
- iface.launch()
 
3
  import torch
4
  import torch.nn as nn
5
  import re
6
+ import datetime
7
  from transformers import AutoTokenizer
8
 
9
+ # Load tokenizer and sentiment model
10
  MODEL = "cardiffnlp/xlm-twitter-politics-sentiment"
11
  tokenizer = AutoTokenizer.from_pretrained(MODEL)
12
 
13
+
14
  class ScorePredictor(nn.Module):
15
  def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256, output_dim=1):
16
  super(ScorePredictor, self).__init__()
 
18
  self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
19
  self.fc = nn.Linear(hidden_dim, output_dim)
20
  self.sigmoid = nn.Sigmoid()
21
+
22
  def forward(self, input_ids, attention_mask):
23
  embedded = self.embedding(input_ids)
24
  lstm_out, _ = self.lstm(embedded)
 
26
  output = self.fc(final_hidden_state)
27
  return self.sigmoid(output)
28
 
29
+ # Load trained score predictor model
30
  score_model = ScorePredictor(tokenizer.vocab_size)
31
  score_model.load_state_dict(torch.load("score_predictor.pth"))
32
  score_model.eval()
33
 
34
+ # preprocesses text
35
  def preprocess_text(text):
36
  text = text.lower()
37
  text = re.sub(r'http\S+', '', text)
 
39
  text = re.sub(r'\s+', ' ', text).strip()
40
  return text
41
 
42
+ # predicts sentiment
43
  def predict_sentiment(text):
44
  if not text:
45
  return 0.0
 
53
  input_ids, attention_mask = encoded_input["input_ids"], encoded_input["attention_mask"]
54
  with torch.no_grad():
55
  score = score_model(input_ids, attention_mask)[0].item()
56
+ return score
 
 
 
 
 
57
 
58
+ # uses Polygon API to fetch article
59
  def fetch_articles(ticker):
60
+ POLYGON_API_KEY = "cMCv7jipVvV4qLBikgzllNmW_isiODRR"
61
  url = f"https://api.polygon.io/v2/reference/news?ticker={ticker}&limit=1&apiKey={POLYGON_API_KEY}"
62
  try:
63
  response = requests.get(url)
 
72
  except Exception as e:
73
  return [f"Error fetching articles for {ticker}: {str(e)}"]
74
 
75
+ # allowed tickers
76
+ ALLOWED_TICKERS = {"AAPL", "GOOG", "AMZN", "NVDA", "META"}
77
+
78
+ # initialize cache
79
+ sentiment_cache = {ticker: {"article": None, "sentiment": None, "timestamp": None} for ticker in ALLOWED_TICKERS}
80
+
81
+ # checks if cache is valid
82
+ def is_cache_valid(cached_time, max_age_minutes=30):
83
+ if cached_time is None:
84
+ return False
85
+ now = datetime.datetime.utcnow()
86
+ age = now - cached_time
87
+ return age.total_seconds() < max_age_minutes * 60
88
+
89
+ # analyzes the tikcers
90
  def analyze_ticker(ticker):
91
+ ticker = ticker.upper()
92
+ if ticker not in ALLOWED_TICKERS:
93
+ return [{
94
+ "article": f"Sorry, '{ticker}' is not supported. Please choose one of: {', '.join(sorted(ALLOWED_TICKERS))}.",
95
+ "sentiment": 0.0
96
+ }]
97
+
98
+ cache_entry = sentiment_cache[ticker]
99
+
100
+ # if cache is valid and article exists
101
+ if is_cache_valid(cache_entry["timestamp"]) and cache_entry["article"] is not None:
102
+
103
+ return [{
104
+ "article": cache_entry["article"],
105
+ "sentiment": cache_entry["sentiment"]
106
+ }]
107
+
108
+ # fetch new article and update cache if cache is invalid
109
  articles = fetch_articles(ticker)
110
+ if not articles:
111
+ return [{"article": "No articles found.", "sentiment": 0.0}]
112
+
113
+ article = articles[0]
114
+
115
+ clean_text = preprocess_text(article)
116
+ sentiment = predict_sentiment(clean_text)
117
+
118
+ # update cache with current time
119
+ sentiment_cache[ticker] = {
120
+ "article": article,
121
+ "sentiment": sentiment,
122
+ "timestamp": datetime.datetime.utcnow()
123
+ }
124
+
125
+ return [{
126
+ "article": article,
127
+ "sentiment": sentiment
128
+ }]
129
+
130
+ # display's sentiment
131
+ def display_sentiment(ticker):
 
 
 
132
  results = analyze_ticker(ticker)
133
+ html_output = "<h2>Sentiment Analysis</h2><ul>"
134
+ for r in results:
135
+ html_output += f"<li><b>{r['article']}</b><br>Score: {r['sentiment']:.2f}</li>"
136
+ html_output += "</ul>"
137
+ return html_output
138
+
139
+ # search feature
140
+ with gr.Blocks() as demo:
141
+ gr.Markdown("# Ticker Sentiment Analysis")
142
+ ticker_input = gr.Textbox(label="Enter Ticker Symbol (e.g., AAPL)")
143
+ output_html = gr.HTML()
144
+ analyze_btn = gr.Button("Analyze")
145
+ analyze_btn.click(fn=display_sentiment, inputs=[ticker_input], outputs=[output_html])
146
+
147
+ demo.launch()