LinkLinkWu commited on
Commit
dd3df57
·
verified ·
1 Parent(s): 15e8ca2

Update func.py

Browse files
Files changed (1) hide show
  1. func.py +26 -4
func.py CHANGED
@@ -7,12 +7,21 @@ import requests
7
  model_id = "LinkLinkWu/ISOM5240HKUSTBASE"
8
  sentiment_tokenizer = AutoTokenizer.from_pretrained(model_id)
9
  sentiment_model = AutoModelForSequenceClassification.from_pretrained(model_id)
10
- sentiment_pipeline = pipeline("sentiment-analysis", model=sentiment_model, tokenizer=sentiment_tokenizer)
 
 
 
 
11
 
12
  # NER pipeline
13
  ner_tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
14
  ner_model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
15
- ner_pipeline = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, grouped_entities=True)
 
 
 
 
 
16
 
17
  # ----------- Core Functions -----------
18
  def fetch_news(ticker):
@@ -42,9 +51,9 @@ def fetch_news(ticker):
42
  for row in news_table.findAll('tr')[:30]:
43
  a_tag = row.find('a')
44
  if a_tag:
45
- title = a_tag.get_text()
46
  link = a_tag['href']
47
- news.append({'title': title, 'link': link})
48
  return news
49
  except Exception:
50
  return []
@@ -70,3 +79,16 @@ def extract_org_entities(text):
70
  return org_entities
71
  except Exception:
72
  return []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  model_id = "LinkLinkWu/ISOM5240HKUSTBASE"
8
  sentiment_tokenizer = AutoTokenizer.from_pretrained(model_id)
9
  sentiment_model = AutoModelForSequenceClassification.from_pretrained(model_id)
10
+ sentiment_pipeline = pipeline(
11
+ "sentiment-analysis",
12
+ model=sentiment_model,
13
+ tokenizer=sentiment_tokenizer
14
+ )
15
 
16
  # NER pipeline
17
  ner_tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
18
  ner_model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
19
+ ner_pipeline = pipeline(
20
+ "ner",
21
+ model=ner_model,
22
+ tokenizer=ner_tokenizer,
23
+ grouped_entities=True
24
+ )
25
 
26
  # ----------- Core Functions -----------
27
  def fetch_news(ticker):
 
51
  for row in news_table.findAll('tr')[:30]:
52
  a_tag = row.find('a')
53
  if a_tag:
54
+ title_text = a_tag.get_text()
55
  link = a_tag['href']
56
+ news.append({'title': title_text, 'link': link})
57
  return news
58
  except Exception:
59
  return []
 
79
  return org_entities
80
  except Exception:
81
  return []
82
+
83
+ # ----------- Helper Functions for Imports -----------
84
+ def get_sentiment_pipeline():
85
+ """
86
+ Return the pre-initialized sentiment-analysis pipeline.
87
+ """
88
+ return sentiment_pipeline
89
+
90
+ def get_ner_pipeline():
91
+ """
92
+ Return the pre-initialized NER pipeline.
93
+ """
94
+ return ner_pipeline