# Spaces: Sleeping
import logging

import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Load the FinBERT model and tokenizer once, at import time, so every call to
# analyze_sentiment reuses them.
# This might download the model files the first time it's run.
MODEL_NAME = "ProsusAI/finbert"  # single source of truth for the checkpoint id
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
def analyze_sentiment(text):
    """
    Analyze the sentiment of a piece of text with the FinBERT model.

    Args:
        text (str): The input text (e.g., a news headline or description).

    Returns:
        tuple: A 3-tuple ``(sentiment_label, sentiment_score, scores)``:
            - sentiment_label (str): the highest-probability label taken from
              ``model.config.id2label`` ('positive', 'negative', or 'neutral').
            - sentiment_score (float): the softmax probability of that label.
            - scores (dict): label -> softmax probability (float) for every
              label ('positive', 'negative', 'neutral').
        Returns ``(None, None, None)`` if ``text`` is not a non-blank string,
        or if tokenization/inference raises (the error is logged, not raised).
    """
    # Reject non-strings (None, NaN floats, numbers, ...) and whitespace-only text.
    if not isinstance(text, str) or not text.strip():
        return None, None, None

    try:
        inputs = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True,
        )
        with torch.no_grad():  # inference only: skip autograd bookkeeping
            outputs = model(**inputs)

        # Softmax over the label dimension; row 0 is the single input sequence.
        # .cpu() keeps the .numpy() conversion valid even if the model is on a GPU.
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
        scores = probabilities[0].cpu().numpy()

        # np.argmax returns a numpy integer; cast to a plain int for the dict lookup.
        predicted_class_id = int(np.argmax(scores))
        id2label = model.config.id2label
        sentiment_label = id2label[predicted_class_id]
        sentiment_score = float(scores[predicted_class_id])
        all_scores = {id2label[i]: float(p) for i, p in enumerate(scores)}
        return sentiment_label, sentiment_score, all_scores
    except Exception:
        # Deliberately best-effort: record the full traceback, then signal
        # failure to the caller with the documented (None, None, None) sentinel.
        logging.getLogger(__name__).exception(
            "Error during sentiment analysis for text: '%s...'", text[:50]
        )
        return None, None, None
# Example usage (for testing the module directly)
if __name__ == '__main__':
    test_texts = [
        "Stocks rallied on positive economic news.",
        "The company reported a significant drop in profits.",
        "Market remains flat amid uncertainty.",
        "",    # Empty string test
        None,  # None test
    ]
    print("--- Testing Sentiment Analysis ---")
    for t in test_texts:
        label, score, all_scores_dict = analyze_sentiment(t)
        # analyze_sentiment returns a None label for invalid input or failure.
        if label is not None:
            print(f"Text: '{t}'")
            print(f"  Sentiment: {label} (Score: {score:.4f})")
            print(f"  All Scores: {all_scores_dict}")
        else:
            print(f"Text: '{t}' -> Invalid input or error during analysis.")