Spaces:
Running
Running
Commit
·
2463f9e
1
Parent(s):
ab71a6e
add labels and rels
Browse files
app.py
CHANGED
@@ -3,7 +3,7 @@ import re
|
|
3 |
import json
|
4 |
from collections import defaultdict
|
5 |
import gradio as gr
|
6 |
-
|
7 |
# Load environment variable for cache dir (useful on Spaces)
|
8 |
_CACHE_DIR = os.environ.get("CACHE_DIR", None)
|
9 |
|
@@ -41,6 +41,46 @@ TYPE2RELS = {
|
|
41 |
"vague dataset": rels,
|
42 |
}
|
43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
def prune_acronym_and_self_relations(ner_preds, rel_preds):
|
45 |
# 1) Find acronym targets strictly shorter than their source
|
46 |
acronym_targets = {
|
|
|
3 |
import json
|
4 |
from collections import defaultdict
|
5 |
import gradio as gr
|
6 |
+
from typing import List, Dict, Any, Tuple
|
7 |
# Load environment variable for cache dir (useful on Spaces)
|
8 |
_CACHE_DIR = os.environ.get("CACHE_DIR", None)
|
9 |
|
|
|
41 |
"vague dataset": rels,
|
42 |
}
|
43 |
|
44 |
+
def inference_pipeline(
    text: str,
    model,
    labels: List[str],
    relation_extractor: "GLiNERRelationExtractor",
    TYPE2RELS: Dict[str, List[str]],
    ner_threshold: float = 0.5,
    re_threshold: float = 0.4,
    re_multi_label: bool = False,
) -> Tuple[List[Dict[str, Any]], Dict[str, List[Dict[str, Any]]]]:
    """Run entity extraction on ``text``, then relation extraction per entity.

    Args:
        text: Input passed to both the NER model and the relation extractor.
        model: Object exposing ``predict_entities(text, labels, flat_ner=...,
            threshold=...)``. Each returned item is read via its ``'text'``
            and ``'label'`` keys.
        labels: Entity labels handed to ``model.predict_entities``.
        relation_extractor: Callable invoked as
            ``relation_extractor(text, relations=None, entities=None,
            relation_labels=..., threshold=..., multi_label=...,
            distance_threshold=100)``; the first element of its return value
            is kept. Presumably a ``GLiNERRelationExtractor`` -- the
            annotation is a string so this definition does not require the
            name to be importable.
        TYPE2RELS: Maps an entity label to the relation names to query for
            entities of that label. Entities whose label is missing or maps
            to an empty list are skipped.
        ner_threshold: Threshold forwarded to ``model.predict_entities``.
        re_threshold: Threshold forwarded to ``relation_extractor``.
        re_multi_label: ``multi_label`` flag forwarded to
            ``relation_extractor``.

    Returns:
        ``(ner_preds, re_results)``: the NER predictions unchanged, and a dict
        keyed by entity span text holding the relation predictions for that
        span. When several entities share the same span text, the entry
        reflects the last such entity that has relation types.
    """
    ner_preds = model.predict_entities(
        text,
        labels,
        flat_ner=True,
        threshold=ner_threshold,
    )

    re_results: Dict[str, List[Dict[str, Any]]] = {}
    # Results are keyed by span text, so repeated mentions of one entity would
    # otherwise trigger identical extractor calls whose output just overwrites
    # the same key. Memoize on the slot labels (the only per-entity input);
    # this assumes the extractor is deterministic for identical arguments.
    preds_by_slots: Dict[Tuple[str, ...], List[Dict[str, Any]]] = {}

    for ner in ner_preds:
        span = ner['text']
        rel_types = TYPE2RELS.get(ner['label'], [])
        if not rel_types:
            continue

        # One "<span> <> <relation>" query label per candidate relation.
        slot_labels = [f"{span} <> {r}" for r in rel_types]
        cache_key = tuple(slot_labels)

        if cache_key not in preds_by_slots:
            # A single text is passed in, so only the first result is kept.
            preds_by_slots[cache_key] = relation_extractor(
                text,
                relations=None,
                entities=None,
                relation_labels=slot_labels,
                threshold=re_threshold,
                multi_label=re_multi_label,
                distance_threshold=100,
            )[0]

        re_results[span] = preds_by_slots[cache_key]

    return ner_preds, re_results
|
83 |
+
|
84 |
def prune_acronym_and_self_relations(ner_preds, rel_preds):
|
85 |
# 1) Find acronym targets strictly shorter than their source
|
86 |
acronym_targets = {
|