File size: 6,342 Bytes
8314677 6eecf76 ae44182 d25b499 7832e21 64ffc8f d25b499 6eecf76 d25b499 0f1a02a d25b499 6eecf76 d25b499 6eecf76 dd3df57 6eecf76 dd3df57 64ffc8f 6eecf76 dd3df57 6eecf76 d25b499 dd3df57 64ffc8f 6eecf76 e12e190 6eecf76 e12e190 6eecf76 d25b499 e12e190 d25b499 8602bc9 6eecf76 64ffc8f d25b499 64ffc8f c7f60fc 64ffc8f c7f60fc 6eecf76 64ffc8f c7f60fc 64ffc8f 64d5a00 6eecf76 d25b499 6eecf76 64d5a00 7832e21 64ffc8f ae44182 e12e190 7c727fa 64d5a00 6eecf76 e12e190 64d5a00 e12e190 6eecf76 7c727fa 6eecf76 d25b499 64d5a00 d25b499 dd3df57 e12e190 dd3df57 d25b499 dd3df57 e12e190 dd3df57 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 |
"""
* **Single** `analyze_sentiment` implementation – no more duplicates.
* Returns **label string by default**, optional probability via `return_prob`.
* Threshold lowered to **0.50** and Neutral treated as Positive.
* Helper pipelines cached at module level.
"""
from __future__ import annotations
from typing import List, Tuple
from transformers import (
pipeline,
AutoTokenizer,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
)
from bs4 import BeautifulSoup
import requests
# ---------------------------------------------------------------------------
# Model identifiers (Hugging Face)
# ---------------------------------------------------------------------------
SENTIMENT_MODEL_ID = "LinkLinkWu/Boss_Stock_News_Analysis" # LABEL_0 = Negative, LABEL_1 = Positive
NER_MODEL_ID = "dslim/bert-base-NER"
# ---------------------------------------------------------------------------
# Pipeline singletons – loaded once on first import
# ---------------------------------------------------------------------------
# Sentiment
_sent_tok = AutoTokenizer.from_pretrained(SENTIMENT_MODEL_ID)
_sent_model = AutoModelForSequenceClassification.from_pretrained(SENTIMENT_MODEL_ID)
sentiment_pipeline = pipeline(
"text-classification",
model=_sent_model,
tokenizer=_sent_tok,
return_all_scores=True,
)
# NER
_ner_tok = AutoTokenizer.from_pretrained(NER_MODEL_ID)
_ner_model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_ID)
ner_pipeline = pipeline(
"ner",
model=_ner_model,
tokenizer=_ner_tok,
grouped_entities=True,
)
# ---------------------------------------------------------------------------
# Sentiment helpers
# ---------------------------------------------------------------------------
_POSITIVE_RAW = "LABEL_1"  # positive class id in model output
_NEUTRAL_RAW = "NEUTRAL"   # some models add a neutral class
_SINGLE_THRESHOLD = 0.50   # >= 50% positive prob -> Positive
_LABEL_NEG = "Negative"
_LABEL_POS = "Positive"
_LABEL_UNK = "Unknown"


def analyze_sentiment(
    text: str,
    *,
    pipe=None,
    threshold: float = _SINGLE_THRESHOLD,
    return_prob: bool = False,
):
    """Classify *text* as Positive / Negative.

    Parameters
    ----------
    text : str
        Input sentence (e.g. news headline).
    pipe : transformers.Pipeline, optional
        Custom sentiment pipeline; defaults to the module-level singleton.
    threshold : float, default 0.50
        Positive-probability cut-off.
    return_prob : bool, default False
        If *True*, returns ``(label, positive_probability)`` tuple;
        otherwise returns just the label string.

    Notes
    -----
    * When the underlying model emits *NEUTRAL*, we treat it the same
      as *Positive* – finance headlines often sound cautious.
    * Accepts both pipeline output shapes: the nested list produced by
      ``return_all_scores=True`` and the flat list produced by the newer
      ``top_k=None`` API.
    * Function never raises; on failure returns ``"Unknown"`` (or
      ``("Unknown", 0.0)`` when *return_prob* is *True*).
    """
    try:
        # Explicit None-check: a falsy-but-valid custom pipe must not be
        # silently replaced by the module singleton.
        s_pipe = pipe if pipe is not None else sentiment_pipeline
        raw = s_pipe(text, truncation=True)
        # return_all_scores=True -> [[{label, score}, ...]]
        # top_k=None (newer API) -> [{label, score}, ...]
        # Empty output falls through raw[0] -> IndexError -> "Unknown".
        scores = raw[0] if isinstance(raw[0], list) else raw
        score_map = {item["label"].upper(): item["score"] for item in scores}
        pos_prob = score_map.get(_POSITIVE_RAW, 0.0)
        if _NEUTRAL_RAW in score_map:  # treat Neutral as Positive
            pos_prob = max(pos_prob, score_map[_NEUTRAL_RAW])
        label = _LABEL_POS if pos_prob >= threshold else _LABEL_NEG
        return (label, pos_prob) if return_prob else label
    except Exception:
        # Contract: never raise — signal failure via the Unknown label.
        return (_LABEL_UNK, 0.0) if return_prob else _LABEL_UNK
# ---------------------------------------------------------------------------
# Web-scraping helper (Finviz)
# ---------------------------------------------------------------------------
def fetch_news(ticker: str, max_items: int = 30) -> List[dict]:
    """Scrape the Finviz quote page for *ticker* and collect headlines.

    Returns ``[{'title': str, 'link': str}, ...]``. At most *max_items*
    table rows are inspected; rows without an anchor are skipped, so the
    result may be shorter. Any failure (network error, non-200 status,
    redirect to a placeholder page, missing news table) yields ``[]``
    rather than raising.
    """
    try:
        request_headers = {
            "User-Agent": "Mozilla/5.0",
            "Accept": "text/html",
            "Accept-Language": "en-US,en;q=0.5",
            "Referer": "https://finviz.com/",
            "Connection": "keep-alive",
        }
        resp = requests.get(
            f"https://finviz.com/quote.ashx?t={ticker}",
            headers=request_headers,
            timeout=10,
        )
        if resp.status_code != 200:
            return []
        soup = BeautifulSoup(resp.text, "html.parser")
        # Finviz redirects unknown tickers to a generic page whose <title>
        # does not contain the symbol — treat that as "no news".
        page_title = soup.title.text if soup.title else ""
        if ticker.upper() not in page_title.upper():
            return []
        news_table = soup.find(id="news-table")
        if news_table is None:
            return []
        items: List[dict] = []
        for row in news_table.find_all("tr")[:max_items]:
            anchor = row.find("a")
            if not anchor:
                continue
            items.append({"title": anchor.text.strip(), "link": anchor["href"]})
        return items
    except Exception:
        # Best-effort scraper: callers expect [] on any failure.
        return []
# ---------------------------------------------------------------------------
# Named-entity extraction helper
# ---------------------------------------------------------------------------
def extract_org_entities(text: str, pipe=None, max_entities: int = 5) -> List[str]:
    """Pull *ORG* entities out of *text*, upper-cased and de-duplicated.

    At most *max_entities* unique strings are returned, in first-seen
    order, suitable for Finviz / Yahoo ticker queries. Any pipeline
    failure results in an empty list instead of an exception.
    """
    try:
        recognizer = pipe or ner_pipeline
        found: List[str] = []
        for entity in recognizer(text):
            if entity.get("entity_group") != "ORG":
                continue
            # Strip word-piece markers left by the tokenizer, then normalise.
            name = entity["word"].replace("##", "").strip().upper()
            if name and name not in found:
                found.append(name)
                if len(found) >= max_entities:
                    break
        return found
    except Exception:
        # Best-effort extraction: never propagate model errors to callers.
        return []
# ---------------------------------------------------------------------------
# Public accessors (legacy compatibility)
# ---------------------------------------------------------------------------
def get_sentiment_pipeline():
    """Accessor kept for backward compatibility.

    Hands back the shared sentiment pipeline built once at import time.
    """
    return sentiment_pipeline
def get_ner_pipeline():
    """Accessor kept for backward compatibility.

    Hands back the shared NER pipeline built once at import time.
    """
    return ner_pipeline
|