Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -3,12 +3,14 @@ import requests
|
|
3 |
import torch
|
4 |
import torch.nn as nn
|
5 |
import re
|
6 |
-
import
|
7 |
from transformers import AutoTokenizer
|
8 |
|
|
|
9 |
MODEL = "cardiffnlp/xlm-twitter-politics-sentiment"
|
10 |
tokenizer = AutoTokenizer.from_pretrained(MODEL)
|
11 |
|
|
|
12 |
class ScorePredictor(nn.Module):
|
13 |
def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256, output_dim=1):
|
14 |
super(ScorePredictor, self).__init__()
|
@@ -16,7 +18,7 @@ class ScorePredictor(nn.Module):
|
|
16 |
self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
|
17 |
self.fc = nn.Linear(hidden_dim, output_dim)
|
18 |
self.sigmoid = nn.Sigmoid()
|
19 |
-
|
20 |
def forward(self, input_ids, attention_mask):
|
21 |
embedded = self.embedding(input_ids)
|
22 |
lstm_out, _ = self.lstm(embedded)
|
@@ -24,10 +26,12 @@ class ScorePredictor(nn.Module):
|
|
24 |
output = self.fc(final_hidden_state)
|
25 |
return self.sigmoid(output)
|
26 |
|
|
|
27 |
score_model = ScorePredictor(tokenizer.vocab_size)
|
28 |
score_model.load_state_dict(torch.load("score_predictor.pth"))
|
29 |
score_model.eval()
|
30 |
|
|
|
31 |
def preprocess_text(text):
|
32 |
text = text.lower()
|
33 |
text = re.sub(r'http\S+', '', text)
|
@@ -35,6 +39,7 @@ def preprocess_text(text):
|
|
35 |
text = re.sub(r'\s+', ' ', text).strip()
|
36 |
return text
|
37 |
|
|
|
38 |
def predict_sentiment(text):
|
39 |
if not text:
|
40 |
return 0.0
|
@@ -48,15 +53,11 @@ def predict_sentiment(text):
|
|
48 |
input_ids, attention_mask = encoded_input["input_ids"], encoded_input["attention_mask"]
|
49 |
with torch.no_grad():
|
50 |
score = score_model(input_ids, attention_mask)[0].item()
|
51 |
-
|
52 |
-
scaled_score = (score - min_val) / (max_val - min_val)
|
53 |
-
# Clip to ensure it stays within [0, 1] in case original score was outside [0.3, 0.9]
|
54 |
-
scaled_score = max(0.0, min(1.0, scaled_score))
|
55 |
-
|
56 |
-
return scaled_score
|
57 |
|
|
|
58 |
def fetch_articles(ticker):
|
59 |
-
POLYGON_API_KEY =
|
60 |
url = f"https://api.polygon.io/v2/reference/news?ticker={ticker}&limit=1&apiKey={POLYGON_API_KEY}"
|
61 |
try:
|
62 |
response = requests.get(url)
|
@@ -71,54 +72,76 @@ def fetch_articles(ticker):
|
|
71 |
except Exception as e:
|
72 |
return [f"Error fetching articles for {ticker}: {str(e)}"]
|
73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
def analyze_ticker(ticker):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
articles = fetch_articles(ticker)
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
return sentiments
|
99 |
-
|
100 |
-
def gradio_interface(ticker):
|
101 |
results = analyze_ticker(ticker)
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
""
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
inputs=gr.Textbox(label="Enter Stock Ticker", placeholder="AAPL, MSFT, GOOGL..."),
|
118 |
-
outputs=gr.HTML(label="Sentiment Analysis Results"),
|
119 |
-
title="Stock News Sentiment Analyzer",
|
120 |
-
description="Enter a stock ticker to analyze the sentiment of recent news articles about that company.",
|
121 |
-
examples=[["AAPL"], ["MSFT"], ["TSLA"]]
|
122 |
-
)
|
123 |
-
|
124 |
-
iface.launch()
|
|
|
3 |
import torch
|
4 |
import torch.nn as nn
|
5 |
import re
|
6 |
+
import datetime
|
7 |
from transformers import AutoTokenizer
|
8 |
|
9 |
+
# Load tokenizer and sentiment model
|
10 |
MODEL = "cardiffnlp/xlm-twitter-politics-sentiment"
|
11 |
tokenizer = AutoTokenizer.from_pretrained(MODEL)
|
12 |
|
13 |
+
|
14 |
class ScorePredictor(nn.Module):
|
15 |
def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256, output_dim=1):
|
16 |
super(ScorePredictor, self).__init__()
|
|
|
18 |
self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
|
19 |
self.fc = nn.Linear(hidden_dim, output_dim)
|
20 |
self.sigmoid = nn.Sigmoid()
|
21 |
+
|
22 |
def forward(self, input_ids, attention_mask):
|
23 |
embedded = self.embedding(input_ids)
|
24 |
lstm_out, _ = self.lstm(embedded)
|
|
|
26 |
output = self.fc(final_hidden_state)
|
27 |
return self.sigmoid(output)
|
28 |
|
29 |
+
# Load trained score predictor model
|
30 |
score_model = ScorePredictor(tokenizer.vocab_size)
|
31 |
score_model.load_state_dict(torch.load("score_predictor.pth"))
|
32 |
score_model.eval()
|
33 |
|
34 |
+
# preprocesses text
|
35 |
def preprocess_text(text):
|
36 |
text = text.lower()
|
37 |
text = re.sub(r'http\S+', '', text)
|
|
|
39 |
text = re.sub(r'\s+', ' ', text).strip()
|
40 |
return text
|
41 |
|
42 |
+
# predicts sentiment
|
43 |
def predict_sentiment(text):
|
44 |
if not text:
|
45 |
return 0.0
|
|
|
53 |
input_ids, attention_mask = encoded_input["input_ids"], encoded_input["attention_mask"]
|
54 |
with torch.no_grad():
|
55 |
score = score_model(input_ids, attention_mask)[0].item()
|
56 |
+
return score
|
|
|
|
|
|
|
|
|
|
|
57 |
|
58 |
+
# uses Polygon API to fetch article
|
59 |
def fetch_articles(ticker):
|
60 |
+
POLYGON_API_KEY = "cMCv7jipVvV4qLBikgzllNmW_isiODRR"
|
61 |
url = f"https://api.polygon.io/v2/reference/news?ticker={ticker}&limit=1&apiKey={POLYGON_API_KEY}"
|
62 |
try:
|
63 |
response = requests.get(url)
|
|
|
72 |
except Exception as e:
|
73 |
return [f"Error fetching articles for {ticker}: {str(e)}"]
|
74 |
|
75 |
+
# allowed tickers
|
76 |
+
# Symbols the app is willing to analyze.
ALLOWED_TICKERS = {"AAPL", "GOOG", "AMZN", "NVDA", "META"}

# One independent, initially empty cache slot per allowed symbol.
sentiment_cache = {
    symbol: {"article": None, "sentiment": None, "timestamp": None}
    for symbol in ALLOWED_TICKERS
}
|
80 |
+
|
81 |
+
# checks if cache is valid
|
82 |
+
def is_cache_valid(cached_time, max_age_minutes=30):
    """Return True when *cached_time* is younger than *max_age_minutes*.

    Args:
        cached_time: ``datetime.datetime`` at which the entry was stored, or
            ``None`` for an empty slot. Naive values are interpreted as UTC;
            timezone-aware values are compared as-is.
        max_age_minutes: Maximum accepted age of the entry, in minutes.

    Returns:
        ``False`` for ``None`` or for an entry whose age is at least
        ``max_age_minutes``; ``True`` otherwise.
    """
    if cached_time is None:
        return False

    utc = datetime.timezone.utc
    # Naive timestamps are treated as UTC so they can be subtracted from an
    # aware "now" without raising TypeError.
    if cached_time.utcoffset() is None:
        cached_time = cached_time.replace(tzinfo=utc)

    # datetime.utcnow() is deprecated; use an aware UTC clock instead.
    age = datetime.datetime.now(utc) - cached_time
    return age.total_seconds() < max_age_minutes * 60
|
88 |
+
|
89 |
+
# analyzes the tickers
|
90 |
def analyze_ticker(ticker):
    """Return the latest article and its sentiment score for *ticker*.

    The symbol is normalised (surrounding whitespace removed, upper-cased)
    and checked against ``ALLOWED_TICKERS``. A cached result is reused while
    ``is_cache_valid`` accepts its timestamp; otherwise the first item from
    ``fetch_articles`` is cleaned with ``preprocess_text``, scored with
    ``predict_sentiment`` and written back to ``sentiment_cache``.

    Args:
        ticker: Stock symbol as entered by the user. ``None`` or an empty
            value is handled like an unsupported symbol.

    Returns:
        A one-element list containing a dict with the keys ``"article"``
        (article text or a user-facing message) and ``"sentiment"`` (the
        score, or ``0.0`` when nothing was scored).
    """
    ticker = (ticker or "").strip().upper()
    if ticker not in ALLOWED_TICKERS:
        return [{
            "article": f"Sorry, '{ticker}' is not supported. Please choose one of: {', '.join(sorted(ALLOWED_TICKERS))}.",
            "sentiment": 0.0
        }]

    cache_entry = sentiment_cache[ticker]

    # Serve from the cache while the entry is fresh and actually populated.
    if is_cache_valid(cache_entry["timestamp"]) and cache_entry["article"] is not None:
        return [{
            "article": cache_entry["article"],
            "sentiment": cache_entry["sentiment"]
        }]

    # Cache miss or stale entry: fetch a new article.
    articles = fetch_articles(ticker)
    if not articles:
        return [{"article": "No articles found.", "sentiment": 0.0}]

    article = articles[0]

    # fetch_articles reports a failed request as an "Error fetching articles
    # for <ticker>: ..." string. Surface it, but neither score it as if it
    # were news nor cache it, so the next request retries the fetch.
    if isinstance(article, str) and article.startswith(f"Error fetching articles for {ticker}:"):
        return [{"article": article, "sentiment": 0.0}]

    clean_text = preprocess_text(article)
    sentiment = predict_sentiment(clean_text)

    # Store a naive UTC timestamp (without the deprecated utcnow()) so it
    # stays comparable with naive UTC clocks.
    sentiment_cache[ticker] = {
        "article": article,
        "sentiment": sentiment,
        "timestamp": datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
    }

    return [{
        "article": article,
        "sentiment": sentiment
    }]
|
129 |
+
|
130 |
+
# displays sentiment
|
131 |
+
def display_sentiment(ticker):
    """Render the analysis results for *ticker* as an HTML fragment.

    Args:
        ticker: Stock symbol as entered by the user; passed unchanged to
            ``analyze_ticker``.

    Returns:
        A string of HTML: a heading followed by an unordered list with one
        item per result, showing the article text and its score formatted
        to two decimals.
    """
    # Local import so this function does not depend on the file's import block.
    import html

    results = analyze_ticker(ticker)
    # Article text comes from an external API and the "not supported" message
    # echoes user input, so escape both before embedding them in markup.
    items = "".join(
        f"<li><b>{html.escape(str(r['article']))}</b><br>Score: {r['sentiment']:.2f}</li>"
        for r in results
    )
    return f"<h2>Sentiment Analysis</h2><ul>{items}</ul>"
|
138 |
+
|
139 |
+
# search feature
|
140 |
+
# Build the UI. NOTE(review): `gr` is presumably `gradio`, imported above the
# visible part of the file — confirm.
with gr.Blocks() as demo:
    gr.Markdown("# Ticker Sentiment Analysis")
    ticker_input = gr.Textbox(label="Enter Ticker Symbol (e.g., AAPL)")
    output_html = gr.HTML()
    analyze_btn = gr.Button("Analyze")
    # On click, pass the text box value to display_sentiment and put the
    # returned HTML string into output_html.
    analyze_btn.click(fn=display_sentiment, inputs=[ticker_input], outputs=[output_html])

# Starts the app as a module-level side effect (there is no __main__ guard).
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|