LinkLinkWu committed on
Commit
d25b499
·
verified ·
1 Parent(s): bdaffbb

Update func.py

Browse files
Files changed (1) hide show
  1. func.py +107 -58
func.py CHANGED
@@ -1,99 +1,148 @@
1
- from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification
 
 
 
 
 
 
 
2
  from bs4 import BeautifulSoup
3
  import requests
4
 
5
- # ----------- Eager Initialization of Pipelines -----------
6
- # Sentiment pipeline
7
- model_id = "ahmedrachid/FinancialBERT-Sentiment-Analysis"
8
- sentiment_tokenizer = AutoTokenizer.from_pretrained(model_id)
9
- sentiment_model = AutoModelForSequenceClassification.from_pretrained(model_id)
 
 
 
 
 
 
 
10
  sentiment_pipeline = pipeline(
11
  "sentiment-analysis",
12
  model=sentiment_model,
13
- tokenizer=sentiment_tokenizer
14
  )
15
 
16
- # NER pipeline
17
- ner_tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
18
- ner_model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
19
  ner_pipeline = pipeline(
20
  "ner",
21
  model=ner_model,
22
  tokenizer=ner_tokenizer,
23
- grouped_entities=True
24
  )
25
 
26
- # ----------- Core Functions -----------
27
- def fetch_news(ticker):
 
 
 
 
 
 
 
 
28
  try:
29
  url = f"https://finviz.com/quote.ashx?t={ticker}"
30
  headers = {
31
- 'User-Agent': 'Mozilla/5.0',
32
- 'Accept': 'text/html',
33
- 'Accept-Language': 'en-US,en;q=0.5',
34
- 'Referer': 'https://finviz.com/',
35
- 'Connection': 'keep-alive',
36
  }
37
- response = requests.get(url, headers=headers)
38
  if response.status_code != 200:
39
  return []
40
 
41
- soup = BeautifulSoup(response.text, 'html.parser')
42
- title = soup.title.text if soup.title else ""
43
- if ticker not in title:
 
44
  return []
45
 
46
- news_table = soup.find(id='news-table')
47
  if news_table is None:
48
  return []
49
 
50
- news = []
51
- for row in news_table.findAll('tr')[:30]:
52
- a_tag = row.find('a')
53
- if a_tag:
54
- title_text = a_tag.get_text()
55
- link = a_tag['href']
56
- news.append({'title': title_text, 'link': link})
57
- return news
 
58
  except Exception:
 
59
  return []
60
 
61
- def analyze_sentiment(text, pipe=None):
62
- """
63
- 兼容两种调用:
64
- - analyze_sentiment(text) -> 使用全局 sentiment_pipeline
65
- - analyze_sentiment(text, some_pipeline) -> 使用传入的 some_pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  """
67
  try:
68
  sentiment_pipe = pipe or sentiment_pipeline
69
- result = sentiment_pipe(text)[0]
70
- return "Positive" if result['label'] == 'POSITIVE' else "Negative"
 
 
 
71
  except Exception:
72
- return "Unknown"
73
 
74
- def extract_org_entities(text, pipe=None):
75
- """
76
- - extract_org_entities(text)
77
- - extract_org_entities(text, some_pipeline)
 
 
 
 
 
 
 
 
 
78
  """
79
- try:
80
- ner_pipe = pipe or ner_pipeline
81
- entities = ner_pipe(text)
82
- orgs = []
83
- for ent in entities:
84
- if ent["entity_group"] == "ORG":
85
- w = ent["word"].replace("##", "").strip().upper()
86
- if w not in orgs:
87
- orgs.append(w)
88
- if len(orgs) >= 5:
89
- break
90
- return orgs
91
- except Exception:
92
- return []
93
 
94
- # ----------- Helper Functions for Imports -----------
95
  def get_sentiment_pipeline():
 
96
  return sentiment_pipeline
97
 
 
98
  def get_ner_pipeline():
 
99
  return ner_pipeline
 
1
+ from typing import List, Tuple
2
+
3
+ from transformers import (
4
+ pipeline,
5
+ AutoTokenizer,
6
+ AutoModelForSequenceClassification,
7
+ AutoModelForTokenClassification,
8
+ )
9
  from bs4 import BeautifulSoup
10
  import requests
11
 
12
+ # ---------------------------------------------------------------------------
13
+ # Model identifiers
14
+ # ---------------------------------------------------------------------------
15
+ SENTIMENT_MODEL_ID = "ahmedrachid/FinancialBERT-Sentiment-Analysis" # returns: positive / neutral / negative
16
+ NER_MODEL_ID = "dslim/bert-base-NER"
17
+
18
+ # ---------------------------------------------------------------------------
19
+ # Eager initialisation of Hugging Face pipelines (shared across requests)
20
+ # ---------------------------------------------------------------------------
21
+ # Sentiment pipeline (binary decision will be made later)
22
+ sentiment_tokenizer = AutoTokenizer.from_pretrained(SENTIMENT_MODEL_ID)
23
+ sentiment_model = AutoModelForSequenceClassification.from_pretrained(SENTIMENT_MODEL_ID)
24
  sentiment_pipeline = pipeline(
25
  "sentiment-analysis",
26
  model=sentiment_model,
27
+ tokenizer=sentiment_tokenizer,
28
  )
29
 
30
+ # Named‑entity‑recognition pipeline (ORG extraction)
31
+ ner_tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_ID)
32
+ ner_model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_ID)
33
  ner_pipeline = pipeline(
34
  "ner",
35
  model=ner_model,
36
  tokenizer=ner_tokenizer,
37
+ grouped_entities=True,
38
  )
39
 
40
+ # ---------------------------------------------------------------------------
41
+ # Core functionality
42
+ # ---------------------------------------------------------------------------
43
+
44
+ def fetch_news(ticker: str) -> List[dict]:
45
+ """Scrape *up to* 30 recent headlines from Finviz for a given *ticker*.
46
+
47
+ Returns a list of dictionaries with ``{"title": str, "link": str}`` or an
48
+ empty list on any error/edge‑case (e.g. anti‑scraping redirect).
49
+ """
50
  try:
51
  url = f"https://finviz.com/quote.ashx?t={ticker}"
52
  headers = {
53
+ "User-Agent": "Mozilla/5.0",
54
+ "Accept": "text/html",
55
+ "Accept-Language": "en-US,en;q=0.5",
56
+ "Referer": "https://finviz.com/",
57
+ "Connection": "keep-alive",
58
  }
59
+ response = requests.get(url, headers=headers, timeout=10)
60
  if response.status_code != 200:
61
  return []
62
 
63
+ soup = BeautifulSoup(response.text, "html.parser")
64
+ page_title = soup.title.text if soup.title else ""
65
+ if ticker.upper() not in page_title.upper():
66
+ # Finviz sometimes redirects to a placeholder page if the ticker is unknown.
67
  return []
68
 
69
+ news_table = soup.find(id="news-table")
70
  if news_table is None:
71
  return []
72
 
73
+ latest_news: List[dict] = []
74
+ for row in news_table.find_all("tr")[:30]: # keep only the 30 most recent rows
75
+ link_tag = row.find("a")
76
+ if link_tag:
77
+ latest_news.append({
78
+ "title": link_tag.get_text(strip=True),
79
+ "link": link_tag["href"],
80
+ })
81
+ return latest_news
82
  except Exception:
83
+ # swallow all exceptions and degrade gracefully
84
  return []
85
 
116
 
117
+ # ---------------------------------------------------------------------------
118
+ # Aggregation logic – turning many headlines into one overall label
119
+ # ---------------------------------------------------------------------------
120
+
121
+ def aggregate_sentiments(
122
+ results: List[Tuple[str, float]],
123
+ avg_threshold: float = _DEFAULT_THRESHOLD,
124
+ ) -> str:
125
+ """Combine individual headline results into a single overall label.
126
+
127
+ The rule is simple: compute the *mean* positive probability across all
128
+ headlines and compare it with *avg_threshold*. If the list is empty, the
129
+ function returns ``"Unknown"``.
130
  """
131
+ if not results:
132
+ return "Unknown"
133
+
134
+ avg_pos = sum(score for _, score in results) / len(results)
135
+ return "Positive" if avg_pos >= avg_threshold else "Negative"
136
+
137
+ # ---------------------------------------------------------------------------
138
+ # Public helpers (kept for backward compatibility with app.py)
139
+ # ---------------------------------------------------------------------------
 
 
 
 
 
140
 
 
141
  def get_sentiment_pipeline():
142
+ """Expose the initialised sentiment pipeline (singleton)."""
143
  return sentiment_pipeline
144
 
145
+
146
  def get_ner_pipeline():
147
+ """Expose the initialised NER pipeline (singleton)."""
148
  return ner_pipeline