sentivity's picture
Update app.py
2cbfc6f verified
raw
history blame
6.11 kB
import datetime
import html
import os
import re

import gradio as gr
import numpy as np
import requests
import torch
import torch.nn as nn
from scipy.special import softmax
from transformers import AutoConfig
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from transformers import TFAutoModelForSequenceClassification
# Sentiment model: tokenizer, config and sequence-classification model are all
# loaded from the same Hugging Face Hub checkpoint.
MODEL = "cardiffnlp/xlm-twitter-politics-sentiment"
tokenizer = AutoTokenizer.from_pretrained(MODEL)
config = AutoConfig.from_pretrained(MODEL)
model = AutoModelForSequenceClassification.from_pretrained(MODEL)
# NOTE: the model is intentionally NOT re-saved under MODEL. Saving it wrote the
# weights and config (but no tokenizer files) to a local directory with the same
# name as the Hub id; from_pretrained() resolves an existing local directory
# before the Hub, so later launches picked up that partial copy and the
# tokenizer could no longer load. Delete any such stale directory left behind.
class ScorePredictor(nn.Module):
    """Embedding -> single-layer LSTM -> linear head -> sigmoid scorer.

    Each sequence is scored from the LSTM output at its final time step, and
    the sigmoid keeps every output strictly between 0 and 1.

    Args:
        vocab_size: number of rows in the embedding table (index 0 is the
            padding index).
        embedding_dim: width of each token embedding.
        hidden_dim: LSTM hidden-state width.
        output_dim: number of values produced by the linear head.
    """

    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256, output_dim=1):
        super().__init__()
        # These attribute names are the keys in saved state dicts; keep them stable.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        """Score a batch of token-id sequences.

        Args:
            input_ids: integer token ids; expected 2-D ``(batch, seq)`` since
                the ``[:, -1, :]`` indexing below needs a 3-D LSTM output.
            attention_mask: accepted for call compatibility; not used.

        Returns:
            Tensor of shape ``(batch, output_dim)`` with values in (0, 1).
        """
        # NOTE(review): attention_mask is ignored, so with right padding the
        # "last" step may be a pad position. Presumably the saved weights were
        # trained this way; confirm before changing it.
        sequence_out, _state = self.lstm(self.embedding(input_ids))
        last_step = sequence_out[:, -1, :]
        logits = self.fc(last_step)
        return self.sigmoid(logits)
# Load the trained LSTM score predictor from score_predictor.pth.
# NOTE(review): no live code path in this file appears to call score_model;
# loading it only makes startup depend on the .pth file being present.
score_model = ScorePredictor(tokenizer.vocab_size)
# map_location="cpu" lets a checkpoint saved from GPU tensors load on a
# CPU-only host; load_state_dict copies the values into this CPU model anyway.
score_model.load_state_dict(torch.load("score_predictor.pth", map_location="cpu"))
score_model.eval()
# preprocesses text
def preprocess_text(text):
text = text.lower()
text = re.sub(r'http\S+', '', text)
text = re.sub(r'[^a-zA-Z0-9\s.,!?]', '', text)
text = re.sub(r'\s+', ' ', text).strip()
return text
# predicts sentiment
def predict_sentiment(text):
if not text:
return 0.0
# encoded_input = tokenizer(
# text.split(),
# return_tensors='pt',
# padding=True,
# truncation=True,
# max_length=512
# )
# input_ids, attention_mask = encoded_input["input_ids"], encoded_input["attention_mask"]
# with torch.no_grad():
# score = score_model(input_ids, attention_mask)[0].item()
# k = 20
# midpoint = 0.7
# scaled_score = 1 / (1 + np.exp(-k * (score - midpoint)))
# final_output = scaled_score * 100
# return 1-final_output
text = preprocess_text(text)
encoded_input = tokenizer(text, return_tensors='pt')
output = model(**encoded_input)
scores = output[0][0].detach().numpy()
scores = softmax(scores)
ranking = np.argsort(scores)
ranking = ranking[::-1]
negative_id = -1
for idx, label in config.id2label.items():
if label.lower() == 'negative':
negative_id = idx
negative_score = scores[negative_id]
return (1-(float(negative_score)))*100
# uses Polygon API to fetch article
def fetch_articles(ticker):
POLYGON_API_KEY = "cMCv7jipVvV4qLBikgzllNmW_isiODRR"
url = f"https://api.polygon.io/v2/reference/news?ticker={ticker}&limit=1&apiKey={POLYGON_API_KEY}"
try:
response = requests.get(url)
data = response.json()
if "results" in data and len(data["results"]) > 0:
article = data["results"][0]
title = article.get("title", "")
description = article.get("description", "")
return [title+description]
else:
return [f"No news articles found for {ticker}."]
except Exception as e:
return [f"Error fetching articles for {ticker}: {str(e)}"]
# allowed tickers
ALLOWED_TICKERS = {"AAPL", "GOOG", "AMZN", "NVDA", "META",'TSLA','QQQ'}
# initialize cache
sentiment_cache = {ticker: {"article": None, "sentiment": None, "timestamp": None} for ticker in ALLOWED_TICKERS}
# checks if cache is valid
def is_cache_valid(cached_time, max_age_minutes=30):
if cached_time is None:
return False
now = datetime.datetime.utcnow()
age = now - cached_time
return age.total_seconds() < max_age_minutes * 60
# analyzes the tikcers
def analyze_ticker(ticker):
ticker = ticker.upper()
if ticker not in ALLOWED_TICKERS:
return [{
"article": f"Sorry, '{ticker}' is not supported. Please choose one of: {', '.join(sorted(ALLOWED_TICKERS))}.",
"sentiment": 0.0
}]
cache_entry = sentiment_cache[ticker]
# if cache is valid and article exists
if is_cache_valid(cache_entry["timestamp"]) and cache_entry["article"] is not None:
return [{
"article": cache_entry["article"],
"sentiment": cache_entry["sentiment"]
}]
# fetch new article and update cache if cache is invalid
articles = fetch_articles(ticker)
if not articles:
return [{"article": "No articles found.", "sentiment": 0.0}]
article = articles[0]
clean_text = preprocess_text(article)
sentiment = predict_sentiment(clean_text)
# update cache with current time
sentiment_cache[ticker] = {
"article": article,
"sentiment": sentiment,
"timestamp": datetime.datetime.utcnow()
}
return [{
"article": article,
"sentiment": sentiment
}]
def display_sentiment(ticker):
results = analyze_ticker(ticker)
html_output = "<h2>Sentiment Analysis</h2><ul>"
if results:
for r in results:
html_output += f"<li><b>{r['article']}</b><br>Score: {r['sentiment']:.2f}</li>"
else:
html_output += "<li>No sentiment data available for this ticker.</li>"
html_output += "</ul>"
return html_output
with gr.Blocks() as demo:
gr.Markdown("# Ticker Sentiment Analysis")
ticker_input = gr.Textbox(label="Enter Ticker Symbol (e.g., AAPL)")
output_html = gr.HTML()
analyze_btn = gr.Button("Analyze")
analyze_btn.click(
fn=lambda t: "<h2>Fetching and scoring sentiment for... " + t + "</h2><p>Please wait...</p>",
inputs=[ticker_input],
outputs=[output_html],
queue=False
).then(
fn=display_sentiment,
inputs=[ticker_input],
outputs=[output_html]
)
demo.launch()