LinkLinkWu committed
Commit 64d5a00 · verified · 1 parent: c7f60fc

Update func.py

Files changed (1)
  1. func.py +41 -50
func.py CHANGED
@@ -1,4 +1,4 @@
-from typing import List, Tuple
+from typing import List
 
 from transformers import (
     pipeline,
@@ -10,14 +10,16 @@ from bs4 import BeautifulSoup
 import requests
 
 # ---------------------------------------------------------------------------
-# Model identifiers
+# Model identifiers – use your custom sentiment model hosted on Hugging Face
 # ---------------------------------------------------------------------------
-SENTIMENT_MODEL_ID = "LinkLinkWu/Stock_Analysis_Test_Ahamed"
+SENTIMENT_MODEL_ID = "LinkLinkWu/Stock_Analysis_Test_Ahamed"  # binary sentiment
 NER_MODEL_ID = "dslim/bert-base-NER"
 
 # ---------------------------------------------------------------------------
-# Eager initialisation of Hugging Face pipelines (shared singletons)
+# Eager initialisation (singletons shared by the whole Streamlit session)
 # ---------------------------------------------------------------------------
+# Sentiment pipeline – returns one label with its score. We will *ignore* the
+# numeric score down‑stream to satisfy the "no numbers" requirement.
 sentiment_tokenizer = AutoTokenizer.from_pretrained(SENTIMENT_MODEL_ID)
 sentiment_model = AutoModelForSequenceClassification.from_pretrained(SENTIMENT_MODEL_ID)
 sentiment_pipeline = pipeline(
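Note on the hunk above: the pipelines are built eagerly at import time, so one copy is shared by every Streamlit session in the process. If lazy construction were ever preferred, a cached factory is a common alternative – a minimal sketch, assuming Streamlit 1.18+ where st.cache_resource exists; get_cached_sentiment_pipeline is a hypothetical name, not part of this commit:

    import streamlit as st
    from transformers import pipeline

    @st.cache_resource  # built once per server process, reused across reruns
    def get_cached_sentiment_pipeline():
        # hypothetical lazy equivalent of the module-level singleton
        return pipeline("sentiment-analysis", model=SENTIMENT_MODEL_ID)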
@@ -26,6 +28,7 @@ sentiment_pipeline = pipeline(
     tokenizer=sentiment_tokenizer,
 )
 
+# Named‑entity‑recognition pipeline (ORG extraction)
 ner_tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_ID)
 ner_model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_ID)
 ner_pipeline = pipeline(
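As the new comments say, the sentiment pipeline returns a single label with a score that downstream code will ignore. For illustration only – the headline and numbers below are made up, and the raw label names depend on the model's config:

    >>> sentiment_pipeline("Shares jump after strong earnings")
    [{'label': 'LABEL_1', 'score': 0.97}]   # hypothetical output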
@@ -36,15 +39,11 @@ ner_pipeline = pipeline(
 )
 
 # ---------------------------------------------------------------------------
-# Web‑scraping helper
+# Web‑scraping helper (Finviz)
 # ---------------------------------------------------------------------------
 
 def fetch_news(ticker: str) -> List[dict]:
-    """Return up to 30 latest Finviz headlines for *ticker* (title & link).
-
-    Empty list on network / parsing errors or if Finviz redirects to a generic
-    page (e.g. wrong ticker).
-    """
+    """Return at most 30 latest Finviz headlines for *ticker* ("title" & "link")."""
     try:
         url = f"https://finviz.com/quote.ashx?t={ticker}"
         headers = {
@@ -60,79 +59,71 @@ def fetch_news(ticker: str) -> List[dict]:
 
         soup = BeautifulSoup(r.text, "html.parser")
         if ticker.upper() not in (soup.title.text if soup.title else "").upper():
-            return []  # Finviz placeholder page
+            return []  # possibly a redirect page
 
         table = soup.find(id="news-table")
         if table is None:
             return []
 
-        news: List[dict] = []
+        headlines: List[dict] = []
         for row in table.find_all("tr")[:30]:
             link_tag = row.find("a")
             if link_tag:
-                news.append({"title": link_tag.get_text(strip=True), "link": link_tag["href"]})
-        return news
+                headlines.append({"title": link_tag.get_text(strip=True), "link": link_tag["href"]})
+        return headlines
     except Exception:
         return []
 
 # ---------------------------------------------------------------------------
-# Sentiment helpers
+# Sentiment helpers – binary classification, *no* numeric score exposed
 # ---------------------------------------------------------------------------
-_POSITIVE = "positive"
-_DEFAULT_THRESHOLD = 0.55  # per‑headline probability cut‑off
+_LABEL_MAP = {"LABEL_0": "Negative", "LABEL_1": "Positive"}  # adjust if model config differs
 
 
-def analyze_sentiment(
-    text: str,
-    pipe=None,
-    threshold: float = _DEFAULT_THRESHOLD,
-) -> Tuple[str, float]:
-    """Classify *text* and return ``(label, positive_probability)``.
+def analyze_sentiment(text: str, pipe=None) -> str:
+    """Return **"Positive"** or **"Negative"** for a single headline.
 
-    * Binary label (*Positive* / *Negative*) is determined by comparing the
-      *positive* probability with *threshold*.
-    * Neutral headlines are mapped to *Negative* by design.
-    * On any internal error → ("Unknown", 0.0).
+    *Neutral* outputs (if ever returned by the model) are coerced to *Negative*.
+    Numeric confidence scores are deliberately discarded to honour the
+    "no numbers" requirement.
     """
     try:
         sentiment_pipe = pipe or sentiment_pipeline
-        scores = sentiment_pipe(text, return_all_scores=True, truncation=True)[0]
-        pos_prob = 0.0
-        for item in scores:
-            if item["label"].lower() == _POSITIVE:
-                pos_prob = item["score"]
-                break
-        label = "Positive" if pos_prob >= threshold else "Negative"
-        return label, pos_prob
+        result = sentiment_pipe(text, truncation=True, return_all_scores=False)[0]
+        raw_label = result.get("label", "").upper()
+        label = _LABEL_MAP.get(raw_label, "Negative")  # default to Negative
+        return label
     except Exception:
-        return "Unknown", 0.0
+        return "Unknown"
 
 # ---------------------------------------------------------------------------
-# Aggregation – average positive probability → binary overall label
+# Aggregation – majority vote (Positive‑ratio) → binary label
 # ---------------------------------------------------------------------------
 
-def aggregate_sentiments(
-    results: List[Tuple[str, float]],
-    avg_threshold: float = _DEFAULT_THRESHOLD,
-) -> str:
-    """Compute overall **Positive/Negative** based on *mean* positive probability.
+_POS_RATIO_THRESHOLD = 0.6  # ≥60 % positives → overall Positive
+
+
+def aggregate_sentiments(labels: List[str], pos_ratio_threshold: float = _POS_RATIO_THRESHOLD) -> str:
+    """Combine individual headline labels into an overall binary sentiment.
 
-    * *results* list returned by ``analyze_sentiment`` for each headline.
-    * If the average positive probability ≥ *avg_threshold* → *Positive*.
+    * If the *Positive* proportion is ≥ *pos_ratio_threshold* → *Positive*.
+    * Otherwise → *Negative*.
     * Empty list → *Unknown*.
     """
-    if not results:
+    if not labels:
         return "Unknown"
 
-    avg_pos = sum(prob for _, prob in results) / len(results)
-    return "Positive" if avg_pos >= avg_threshold else "Negative"
+    total = len(labels)
+    positives = sum(1 for l in labels if l == "Positive")
+    ratio = positives / total
+    return "Positive" if ratio >= pos_ratio_threshold else "Negative"
 
 # ---------------------------------------------------------------------------
-# ORG‑entity extraction (for ticker discovery)
+# ORG‑entity extraction (ticker discovery)
 # ---------------------------------------------------------------------------
 
 def extract_org_entities(text: str, pipe=None, max_entities: int = 5) -> List[str]:
-    """Return up to *max_entities* unique ORG tokens (upper‑case, de‑hashed)."""
+    """Extract up to *max_entities* unique ORG tokens (upper‑case, de‑hashed)."""
     try:
         ner_pipe = pipe or ner_pipeline
         entities = ner_pipe(text)
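Taken together, the hunk above replaces per-headline probabilities with plain labels and a majority vote. A minimal usage sketch of the revised API (the ticker "AAPL" is illustrative only):

    headlines = fetch_news("AAPL")                   # up to 30 {"title", "link"} dicts
    labels = [analyze_sentiment(h["title"]) for h in headlines]
    overall = aggregate_sentiments(labels)           # e.g. 7 "Positive" of 10 -> 0.7 >= 0.6 -> "Positive"

One consequence of the vote: "Unknown" results from failed classifications still count toward the denominator but never toward the positives, so errors nudge the aggregate toward "Negative".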
@@ -149,7 +140,7 @@ def extract_org_entities(text: str, pipe=None, max_entities: int = 5) -> List[str]:
         return []
 
 # ---------------------------------------------------------------------------
-# Public accessors (backward compatibility with app.py)
+# Public accessors (legacy compatibility)
 # ---------------------------------------------------------------------------
 
 def get_sentiment_pipeline():
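The hard-coded _LABEL_MAP introduced above is flagged "adjust if model config differs"; the actual mapping can be read off the already-loaded model's config (a standard transformers attribute), e.g.:

    >>> sentiment_model.config.id2label
    {0: 'LABEL_0', 1: 'LABEL_1'}   # example output; verify against your checkpoint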