import gradio as gr
import requests
import torch
import torch.nn as nn
import re
import datetime
import numpy as np
from transformers import AutoTokenizer, AutoConfig, AutoModelForSequenceClassification
from scipy.special import softmax

# Load tokenizer, config, and sentiment model
MODEL = "cardiffnlp/xlm-twitter-politics-sentiment"
tokenizer = AutoTokenizer.from_pretrained(MODEL)
config = AutoConfig.from_pretrained(MODEL)
model = AutoModelForSequenceClassification.from_pretrained(MODEL)
model.save_pretrained(MODEL)  # keep a local copy of the weights
class ScorePredictor(nn.Module):
    """Simple LSTM-based scorer: embedding -> LSTM -> linear -> sigmoid."""
    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256, output_dim=1):
        super(ScorePredictor, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        embedded = self.embedding(input_ids)
        lstm_out, _ = self.lstm(embedded)
        final_hidden_state = lstm_out[:, -1, :]  # last time step
        output = self.fc(final_hidden_state)
        return self.sigmoid(output)
# Load trained score predictor model
score_model = ScorePredictor(tokenizer.vocab_size)
score_model.load_state_dict(torch.load("score_predictor.pth"))
score_model.eval()
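# Note: the LSTM ScorePredictor above is loaded but never called in the Gradio
# flow below. As an illustrative sketch only (the helper name and max_length
# are assumptions, not part of the original app), it could score text like this:
def lstm_score(text: str) -> float:
    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
    with torch.no_grad():
        prob = score_model(encoded["input_ids"], encoded["attention_mask"])
    return float(prob.item())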
# preprocesses text
def preprocess_text(text):
    text = text.lower()
    text = re.sub(r'http\S+', '', text)
    text = re.sub(r'[^a-zA-Z0-9\s.,!?]', '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text
# predicts sentiment: returns the positive-class probability scaled to 0-100
def predict_sentiment(text):
    if not text:
        return 0.0
    text = preprocess_text(text)
    encoded_input = tokenizer(text, return_tensors='pt')
    output = model(**encoded_input)
    scores = output[0][0].detach().numpy()
    scores = softmax(scores)
    ranking = np.argsort(scores)
    ranking = ranking[::-1]
    positive_index = None
    for i in range(scores.shape[0]):
        if config.id2label[ranking[i]] == 'positive':
            positive_index = ranking[i]
            break
    positive_score = scores[positive_index]
    return float(positive_score) * 100
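# Illustrative usage note: predict_sentiment("Shares rally after strong earnings")
# returns the model's positive-class probability scaled to 0-100, so upbeat
# headlines should score higher than clearly negative ones.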
# uses the Polygon API to fetch the latest news article title for a ticker
def fetch_articles(ticker):
    POLYGON_API_KEY = "cMCv7jipVvV4qLBikgzllNmW_isiODRR"
    url = f"https://api.polygon.io/v2/reference/news?ticker={ticker}&limit=1&apiKey={POLYGON_API_KEY}"
    try:
        response = requests.get(url)
        data = response.json()
        if "results" in data and len(data["results"]) > 0:
            article = data["results"][0]
            title = article.get("title", "")
            return [title]
        else:
            return [f"No news articles found for {ticker}."]
    except Exception as e:
        return [f"Error fetching articles for {ticker}: {str(e)}"]
# allowed tickers
ALLOWED_TICKERS = {"AAPL", "GOOG", "AMZN", "NVDA", "META", "TSLA", "QQQ"}

# initialize per-ticker cache
sentiment_cache = {ticker: {"article": None, "sentiment": None, "timestamp": None} for ticker in ALLOWED_TICKERS}
# checks if cache is valid
def is_cache_valid(cached_time, max_age_minutes=30):
    if cached_time is None:
        return False
    now = datetime.datetime.utcnow()
    age = now - cached_time
    return age.total_seconds() < max_age_minutes * 60
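# Example of the check above (illustrative): an entry cached 45 minutes ago fails
# the default 30-minute window, so analyze_ticker() below refetches and rescores.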
# analyzes the ticker: serves a cached result when fresh, otherwise fetches and scores a new article
def analyze_ticker(ticker):
    ticker = ticker.upper()
    if ticker not in ALLOWED_TICKERS:
        return [{
            "article": f"Sorry, '{ticker}' is not supported. Please choose one of: {', '.join(sorted(ALLOWED_TICKERS))}.",
            "sentiment": 0.0
        }]

    cache_entry = sentiment_cache[ticker]
    # if cache is valid and an article exists, reuse it
    if is_cache_valid(cache_entry["timestamp"]) and cache_entry["article"] is not None:
        return [{
            "article": cache_entry["article"],
            "sentiment": cache_entry["sentiment"]
        }]

    # fetch a new article and update the cache if the cache is invalid
    articles = fetch_articles(ticker)
    if not articles:
        return [{"article": "No articles found.", "sentiment": 0.0}]

    article = articles[0]
    sentiment = predict_sentiment(article)  # predict_sentiment() preprocesses internally

    # update cache with the current time
    sentiment_cache[ticker] = {
        "article": article,
        "sentiment": sentiment,
        "timestamp": datetime.datetime.utcnow()
    }
    return [{
        "article": article,
        "sentiment": sentiment
    }]
# displays sentiment as HTML
def display_sentiment(ticker):
    results = analyze_ticker(ticker)
    html_output = "<h2>Sentiment Analysis</h2><ul>"
    for r in results:
        html_output += f"<li><b>{r['article']}</b><br>Score: {r['sentiment']:.2f}</li>"
    html_output += "</ul>"
    return html_output
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Ticker Sentiment Analysis")
    ticker_input = gr.Textbox(label="Enter Ticker Symbol (e.g., AAPL)")
    output_html = gr.HTML()
    analyze_btn = gr.Button("Analyze")
    analyze_btn.click(fn=display_sentiment, inputs=[ticker_input], outputs=[output_html])

demo.launch()