sentivity's picture
Update app.py
5c05c21 verified
raw
history blame
4.02 kB
import gradio as gr
import requests
import torch
import torch.nn as nn
import re
import os
from transformers import AutoTokenizer
# Model id (or path) handed to AutoTokenizer.from_pretrained. Only the
# tokenizer is taken from it; the scoring weights are loaded separately from
# score_predictor.pth.
MODEL = "cardiffnlp/xlm-twitter-politics-sentiment"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=MODEL)
class ScorePredictor(nn.Module):
    """Embedding -> single-layer LSTM -> linear -> sigmoid scorer.

    Produces one sigmoid score per sequence, computed from the LSTM output at
    the final time step.

    The attribute names (embedding, lstm, fc, sigmoid) define the state_dict
    keys, so they must stay in sync with saved checkpoints.
    """

    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256, output_dim=1):
        super().__init__()
        # Index 0 is the padding row (excluded from gradient updates).
        # NOTE(review): confirm 0 is actually the tokenizer's pad id.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        """Return sigmoid scores of shape (batch, output_dim).

        Args:
            input_ids: (batch, seq_len) token ids (the LSTM is batch_first).
            attention_mask: accepted for API compatibility but NOT used; the
                output at the last position is read even if it is padding.
        """
        sequence_out, _ = self.lstm(self.embedding(input_ids))
        last_step = sequence_out[:, -1]
        return self.sigmoid(self.fc(last_step))
# Size the embedding table to the tokenizer vocabulary; this must equal the
# size used when score_predictor.pth was saved or load_state_dict fails.
score_model = ScorePredictor(tokenizer.vocab_size)
# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on
# a CPU-only host; nothing in this app moves the model or inputs to a GPU.
score_model.load_state_dict(torch.load("score_predictor.pth", map_location="cpu"))
# Switch to evaluation mode for inference.
score_model.eval()
def preprocess_text(text):
    """Normalise raw article text before scoring.

    Lower-cases the text, drops ``http...`` URLs, removes every character
    other than ASCII letters, digits, whitespace and ``.,!?``, then collapses
    whitespace runs into single spaces and trims both ends.

    NOTE(review): the character filter also deletes accented/non-Latin
    letters (e.g. "café" -> "caf"); presumably this mirrors the training
    preprocessing, so it is left as is.
    """
    cleaned = text.lower()
    substitutions = (
        (r'http\S+', ''),              # URLs
        (r'[^a-zA-Z0-9\s.,!?]', ''),   # anything outside the allowed set
        (r'\s+', ' '),                 # whitespace runs -> one space
    )
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()
def predict_sentiment(text):
    """Score ``text`` with the module-level ``score_model``.

    Args:
        text: Cleaned text (see ``preprocess_text``).

    Returns:
        float: the model's sigmoid output for the whole text, or 0.0 when
        ``text`` is empty.
    """
    if not text:
        # NOTE(review): 0.0 falls below the 0.4 cut-off used by
        # analyze_ticker, so empty text ends up labelled "Positive";
        # confirm this sentinel is intended.
        return 0.0
    # Encode the text as ONE sequence. The previous ``text.split()`` made the
    # tokenizer treat each word as a separate batch item, and ``[0]`` then
    # returned the score of the first word only.
    encoded_input = tokenizer(
        text,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=512,
    )
    input_ids = encoded_input["input_ids"]
    attention_mask = encoded_input["attention_mask"]
    with torch.no_grad():
        # Output is (batch=1, 1); take the single sequence's score.
        score = score_model(input_ids, attention_mask)[0].item()
    return score
def fetch_articles(ticker):
    """Fetch the newest Polygon.io news item for ``ticker``.

    The API key is read from the ``poly_api`` environment variable.

    Args:
        ticker: Stock symbol as typed by the user; surrounding whitespace is
            stripped and it is upper-cased before the request.

    Returns:
        list[str]: always one element -- ``"<title> <description>"`` of the
        newest article, or a human-readable "no news" / error message.
    """
    ticker = ticker.strip().upper()
    # os.getenv -- the original called the nonexistent os.get_env, which
    # raised AttributeError outside the try block on every call.
    api_key = os.getenv('poly_api')
    if not api_key:
        return [f"Error fetching articles for {ticker}: Polygon API key is not set (environment variable 'poly_api')."]
    url = "https://api.polygon.io/v2/reference/news"
    # Let requests URL-encode the user-supplied ticker rather than splicing
    # it into the query string by hand.
    params = {"ticker": ticker, "limit": 1, "apiKey": api_key}
    try:
        response = requests.get(url, params=params, timeout=10)
        response.raise_for_status()
        data = response.json()
        results = data.get("results") or []
        if not results:
            return [f"No news articles found for {ticker}."]
        article = results[0]
        # Fields may be missing or null in the response.
        title = article.get("title") or ""
        description = article.get("description") or ""
        return [f"{title} {description}".strip()]
    except Exception as e:
        # UI boundary: report instead of crashing. Exception text can include
        # the request URL (which carries the API key), so redact the key
        # before the message is shown to the user.
        message = str(e).replace(api_key, "***")
        return [f"Error fetching articles for {ticker}: {message}"]
def analyze_ticker(ticker):
    """Fetch news for ``ticker`` and attach a sentiment score and label to each item.

    Args:
        ticker: Stock symbol, passed through to ``fetch_articles``.

    Returns:
        list[dict]: one dict per string from ``fetch_articles`` with keys
        "article", "sentiment" (float), "sentiment_label" and "emoji".
    """
    sentiments = []
    # NOTE(review): fetch_articles also returns "no news"/error messages as
    # plain strings, and those are scored like articles here.
    for article in fetch_articles(ticker):
        sentiment = predict_sentiment(preprocess_text(article))
        # The score thresholds and label texts are unchanged: high scores are
        # labelled Negative and low scores Positive. The emojis were
        # previously swapped ("Negative" 😊, "Positive" 😞) and now match
        # their labels.
        if sentiment > 0.6:
            sentiment_label, emoji = "Negative", "😞"
        elif sentiment < 0.4:
            sentiment_label, emoji = "Positive", "😊"
        else:
            sentiment_label, emoji = "Neutral", "😐"
        sentiments.append({
            "article": article,
            "sentiment": sentiment,
            "sentiment_label": sentiment_label,
            "emoji": emoji,
        })
    return sentiments
def gradio_interface(ticker):
    """Render the sentiment results for ``ticker`` as an HTML fragment.

    Args:
        ticker: Raw text from the Gradio ticker textbox.

    Returns:
        str: an <h2> heading followed by one bordered card per result from
        ``analyze_ticker``. Each card shows the article text, the score to
        4 decimals, and the label with its emoji.
    """
    import html

    # The ticker is user input and the article text comes from an external
    # news API; escape both so neither can inject markup into the gr.HTML
    # output.
    cards = []
    # Render every result (previously only results[0], which raised
    # IndexError on an empty list).
    for result in analyze_ticker(ticker):
        cards.append(f"""
<div style='border: 1px solid #ccc; padding: 15px; border-radius: 5px; margin-bottom: 20px;'>
<h3>Article:</h3>
<p>{html.escape(result['article'])}</p>
<h3>Sentiment:</h3>
<p>Score: {result['sentiment']:.4f}</p>
<p>Label: {html.escape(result['sentiment_label'])} {result['emoji']}</p>
</div>
""")
    body = "".join(cards)
    return f"""
<h2>Sentiment Analysis for {html.escape(ticker)}</h2>
{body}"""
# Gradio UI: a ticker textbox in, the HTML produced by gradio_interface out.
ticker_input = gr.Textbox(label="Enter Stock Ticker", placeholder="AAPL, MSFT, GOOGL...")
results_output = gr.HTML(label="Sentiment Analysis Results")
iface = gr.Interface(
    fn=gradio_interface,
    inputs=ticker_input,
    outputs=results_output,
    title="Stock News Sentiment Analyzer",
    description="Enter a stock ticker to analyze the sentiment of recent news articles about that company.",
    examples=[[symbol] for symbol in ("AAPL", "MSFT", "TSLA")],
)
# Start the Gradio app.
iface.launch()