Spaces:
Running
Running
Commit
·
fc4b12d
1
Parent(s):
d799589
deploy without gpu
Browse files
app.py
CHANGED
@@ -1,84 +1,170 @@
|
|
|
|
1 |
import re
|
2 |
import json
|
|
|
3 |
import gradio as gr
|
4 |
|
5 |
-
#
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
}
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
{'source': 'Home Visits Survey', 'relation': 'data geography', 'target': 'Jordan', 'score': 0.6180844902992249},
|
20 |
-
{'source': 'Home Visits Survey', 'relation': 'version', 'target': 'Round II', 'score': 0.9688164591789246},
|
21 |
-
{'source': 'Home Visits Survey', 'relation': 'acronym', 'target': 'HV', 'score': 0.9140607714653015},
|
22 |
-
{'source': 'Home Visits Survey', 'relation': 'publisher', 'target': 'UNHCR', 'score': 0.7762154340744019},
|
23 |
-
{'source': 'Home Visits Survey', 'relation': 'publisher', 'target': 'World Food Programme', 'score': 0.6582539677619934},
|
24 |
-
{'source': 'Home Visits Survey', 'relation': 'reference year', 'target': '2013', 'score': 0.524115264415741},
|
25 |
-
{'source': 'Home Visits Survey', 'relation': 'reference year', 'target': '2014', 'score': 0.6853994131088257},
|
26 |
-
{'source': 'Home Visits Survey', 'relation': 'data description', 'target': 'detailed socio-economic, health, and protection data', 'score': 0.6544178128242493},
|
27 |
]
|
28 |
-
}
|
29 |
|
30 |
-
#
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
-
#
|
37 |
-
|
|
|
|
|
38 |
|
39 |
-
|
40 |
entities = []
|
41 |
-
|
42 |
-
for ent in ner:
|
43 |
entities.append({
|
44 |
-
"entity":
|
45 |
-
"start":
|
46 |
-
"end":
|
47 |
})
|
48 |
-
|
49 |
-
for rel_list in relations.values():
|
50 |
-
for r in rel_list:
|
51 |
-
for m in re.finditer(re.escape(r["target"]), text):
|
52 |
-
entities.append({
|
53 |
-
"entity": r["relation"],
|
54 |
-
"start": m.start(),
|
55 |
-
"end": m.end(),
|
56 |
-
})
|
57 |
return {"text": text, "entities": entities}
|
58 |
|
59 |
-
|
60 |
-
return json.dumps({"ner": ner, "relations": relations}, indent=2)
|
61 |
|
62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
gr.Markdown("## Data Use Detector\n"
|
64 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
-
txt_in = gr.Textbox(label="Input Text", lines=4, value=SAMPLE_TEXT)
|
67 |
highlight_btn = gr.Button("Submit")
|
68 |
-
txt_out
|
69 |
|
70 |
get_pred_btn = gr.Button("Get Model Predictions")
|
71 |
-
ner_rel_box
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
)
|
77 |
|
78 |
-
#
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
|
83 |
-
|
84 |
-
demo.launch()
|
|
|
import os
import re
import json
from collections import defaultdict
import gradio as gr

# Load environment variable for cache dir (useful on Spaces)
# Falls back to None (library default cache location) when CACHE_DIR is unset.
_CACHE_DIR = os.environ.get("CACHE_DIR", None)

# Import GLiNER model and relation extractor
from gliner import GLiNER
from gliner.multitask import GLiNERRelationExtractor

# Import inference pipeline and configuration
from my_project.pipeline import inference_pipeline
from my_project.config import TYPE2RELS, labels

# Cache and initialize model + relation extractor
# Loaded once at import time and shared by every request handler below.
# NOTE(review): from_pretrained downloads the checkpoint on first run —
# cold starts on Spaces will be slow; confirm CACHE_DIR persists between runs.
DATA_MODEL_ID = "rafmacalaba/gliner_re_finetuned-v3"
model = GLiNER.from_pretrained(DATA_MODEL_ID, cache_dir=_CACHE_DIR)
relation_extractor = GLiNERRelationExtractor(model=model)

# Sample text pre-filling the input textbox in the UI.
SAMPLE_TEXT = (
    "In early 2012, the World Bank published the full report of the 2011 Demographic and Health Survey (DHS) "
    "for the Republic of Mali. Conducted between June and December 2011 under the technical oversight of Mali’s "
    "National Institute of Statistics and paired with on-the-ground data-collection teams, this nationally representative survey "
    "gathered detailed information on household composition, education levels, employment and income, fertility and family planning, "
    "maternal and child health, nutrition, mortality, and access to basic services. By combining traditional census modules with "
    "specialized questionnaires on women’s and children’s health, the DHS offers policymakers, development partners, and researchers "
    "a rich dataset of socioeconomic characteristics—ranging from literacy and school attendance to water and sanitation infrastructure—"
    "that can be used to monitor progress on poverty reduction, inform targeted social programs, and guide longer-term economic planning."
)
# Post-processing: prune acronyms and self-relations

def prune_acronym_and_self_relations(ner_preds, rel_preds):
    """Remove acronym-duplicate entities and degenerate relations.

    Args:
        ner_preds: list of entity dicts with at least "label" and "text" keys.
        rel_preds: mapping of source string -> list of relation dicts, each
            with "relation" and "target" keys.

    Returns:
        Tuple ``(filtered_ner, filtered_re)`` where acronym mentions have been
        dropped from the NER list, acronym sources removed from the relation
        map, and any relation whose target equals its source discarded.
    """
    # Step 1: collect "acronym" targets that are strictly shorter than the
    # source they abbreviate (e.g. "HV" for "Home Visits Survey").
    acronyms = set()
    for source, relations in rel_preds.items():
        for rel in relations:
            if rel["relation"] == "acronym" and len(rel["target"]) < len(source):
                acronyms.add(rel["target"])

    # Step 2: keep every NER span except named-dataset mentions whose text is
    # one of those acronyms.
    kept_ner = []
    for ent in ner_preds:
        if ent["label"] == "named dataset" and ent["text"] in acronyms:
            continue
        kept_ner.append(ent)

    # Step 3: rebuild the relation map — skip acronym sources entirely and
    # drop self-relations; sources left with no relations are omitted.
    kept_re = {
        source: remaining
        for source, relations in rel_preds.items()
        if source not in acronyms
        and (remaining := [r for r in relations if r["target"] != source])
    }

    return kept_ner, kept_re
# Highlighting function

def highlight_text(text, ner_threshold, re_threshold):
    """Annotate *text* for a Gradio HighlightedText component.

    Runs the NER + relation-extraction pipeline at the given thresholds,
    prunes acronym duplicates and self-relations, and returns
    ``{"text": ..., "entities": [...]}`` where each entity dict carries
    "entity" (comma-joined labels), "start", and "end" offsets.
    """
    ner_preds, rel_preds = inference_pipeline(
        text,
        model=model,
        labels=labels,
        relation_extractor=relation_extractor,
        TYPE2RELS=TYPE2RELS,
        ner_threshold=ner_threshold,
        re_threshold=re_threshold,
        re_multi_label=False
    )

    # Drop acronym duplicates and self-referencing relations before display.
    ner_preds, rel_preds = prune_acronym_and_self_relations(ner_preds, rel_preds)

    # Accumulate labels keyed by (start, end) so overlapping annotations on
    # the same span are merged into one entity. NER labels go in first, then
    # relation labels found by literal search for each target in the text.
    merged = defaultdict(list)
    for ent in ner_preds:
        merged[(ent["start"], ent["end"])].append(ent["label"])
    for src, rels in rel_preds.items():
        for rel in rels:
            for match in re.finditer(re.escape(rel["target"]), text):
                merged[(match.start(), match.end())].append(
                    f"{src} <> {rel['relation']}"
                )

    # Emit entities ordered by span position, one dict per distinct span.
    entities = [
        {"entity": ", ".join(span_labels), "start": start, "end": end}
        for (start, end), span_labels in sorted(merged.items())
    ]
    return {"text": text, "entities": entities}
# JSON output function

def get_model_predictions(text, ner_threshold, re_threshold):
    """Return pruned NER and relation predictions as pretty-printed JSON.

    Same pipeline invocation as ``highlight_text`` but serializes the raw
    prediction structures (``{"ner": ..., "relations": ...}``) with
    two-space indentation instead of building display spans.
    """
    raw_ner, raw_rels = inference_pipeline(
        text,
        model=model,
        labels=labels,
        relation_extractor=relation_extractor,
        TYPE2RELS=TYPE2RELS,
        ner_threshold=ner_threshold,
        re_threshold=re_threshold,
        re_multi_label=False
    )
    pruned_ner, pruned_rels = prune_acronym_and_self_relations(raw_ner, raw_rels)
    payload = {"ner": pruned_ner, "relations": pruned_rels}
    return json.dumps(payload, indent=2)
# Build Gradio UI
# Widget creation order inside the `with demo:` block defines the page layout.
demo = gr.Blocks()
with demo:
    gr.Markdown("## Data Use Detector\n"
                "Adjust the sliders below to set thresholds, then:\n"
                "- **Submit** to highlight entities.\n"
                "- **Get Model Predictions** to see the raw JSON output.")

    # Input textbox, pre-filled with the module-level sample paragraph.
    txt_in = gr.Textbox(
        label="Input Text",
        lines=4,
        value=SAMPLE_TEXT
    )

    # Confidence cut-offs forwarded verbatim to the inference pipeline.
    ner_slider = gr.Slider(
        0, 1, value=0.7, step=0.01,
        label="NER Threshold",
        info="Minimum confidence for named-entity spans."
    )
    re_slider = gr.Slider(
        0, 1, value=0.5, step=0.01,
        label="RE Threshold",
        info="Minimum confidence for relation extractions."
    )

    # Highlighted-entities path: button + annotated-text output.
    highlight_btn = gr.Button("Submit")
    txt_out = gr.HighlightedText(label="Annotated Entities")

    # Raw-JSON path: button + multi-line textbox output.
    get_pred_btn = gr.Button("Get Model Predictions")
    ner_rel_box = gr.Textbox(label="Model Predictions (JSON)", lines=15)

    # Wire up interactions
    highlight_btn.click(
        fn=highlight_text,
        inputs=[txt_in, ner_slider, re_slider],
        outputs=txt_out
    )
    get_pred_btn.click(
        fn=get_model_predictions,
        inputs=[txt_in, ner_slider, re_slider],
        outputs=ner_rel_box
    )

# Enable queue for concurrency
demo.queue(default_concurrency_limit=5)

# Launch the app
demo.launch(debug=True)