# app.py — Data Use Detector: GLiNER-based dataset-mention NER + relation
# extraction demo, served as a Gradio app (e.g. on Hugging Face Spaces).
import os
import re
import json
from collections import defaultdict
import gradio as gr
from typing import List, Dict, Any, Tuple
# Cache directory for model downloads (useful on Spaces, where a persistent
# cache avoids re-downloading weights on every restart). None = default cache.
_CACHE_DIR = os.environ.get("CACHE_DIR", None)

# Import GLiNER model and relation extractor
from gliner import GLiNER
#from relation_extraction import CustomGLiNERRelationExtractor

# Cache and initialize model + relation extractor.
# Fine-tuned GLiNER checkpoint used for both the NER and RE stages.
DATA_MODEL_ID = "rafmacalaba/gliner_re_finetuned-v7-pos"
model = GLiNER.from_pretrained(DATA_MODEL_ID, cache_dir=_CACHE_DIR)
# NOTE(review): imported here (after model creation) rather than at top of
# file; the extractor wraps the same GLiNER model. return_index=True makes it
# return span offsets — presumably character offsets into its prompt+text
# input, given how highlight_text subtracts the prompt prefix; confirm against
# relation_extraction.py.
from relation_extraction import CustomGLiNERRelationExtractor
relation_extractor = CustomGLiNERRelationExtractor(model=model, return_index=True)
# Sample text pre-filled in the input box: an abstract mentioning the ENAHO
# household survey, chosen so the demo finds both a dataset and its acronym.
SAMPLE_TEXT = (
    "Encuesta Nacional de Hogares (ENAHO) is the Peruvian version of the Living Standards Measurement Survey, e.g. a nationally representative household survey collected monthly on a continuous basis. For our analysis, we use data from January 2007 to December 2020. The survey covers a wide variety of topics, including basic demographics, educational background, labor market conditions, crime victimization, and a module on respondent’s perceptions about the main problems in the country and trust in different local and national‐level institutions."
)

# Post-processing: prune acronyms and self-relations
# Entity labels predicted by the NER stage.
labels = ['named dataset', 'unnamed dataset', 'vague dataset']
# Relation slot types probed for each detected dataset mention.
rels = ['acronym', 'author', 'data description', 'data geography', 'data source', 'data type', 'publication year', 'publisher', 'reference population', 'reference year', 'version']
# Entity label -> relation types to extract; every dataset kind shares `rels`.
TYPE2RELS = {
    "named dataset": rels,
    "unnamed dataset": rels,
    "vague dataset": rels,
}
def inference_pipeline(
    text: str,
    model,
    labels: List[str],
    relation_extractor: "CustomGLiNERRelationExtractor",
    TYPE2RELS: Dict[str, List[str]],
    ner_threshold: float = 0.7,
    rel_threshold: float = 0.5,
    re_multi_label: bool = False,
    return_index: bool = False,
) -> Tuple[List[Dict[str, Any]], Dict[str, List[Dict[str, Any]]]]:
    """Run NER over *text*, then relation extraction for each entity found.

    Args:
        text: Raw input text.
        model: GLiNER-style model exposing ``predict_entities``.
        labels: Entity labels for the NER pass.
        relation_extractor: Callable extractor; called with slot labels of the
            form ``"<span> <> <relation>"`` and returns a batch whose first
            element is the list of relation predictions.
        TYPE2RELS: Maps an entity label to the relation types to probe.
        ner_threshold: Minimum confidence for entity spans.
        rel_threshold: Minimum confidence for relation predictions.
        re_multi_label: Whether the extractor may assign multiple labels.
        return_index: Ask the extractor to include span offsets in its output.

    Returns:
        ``(ner_preds, re_results)`` where ``re_results`` maps each unique
        entity surface text to its relation predictions.
    """
    ner_preds = model.predict_entities(
        text,
        labels,
        flat_ner=True,
        threshold=ner_threshold
    )

    re_results: Dict[str, List[Dict[str, Any]]] = {}
    for ner in ner_preds:
        span = ner['text']
        # Entities sharing the same surface text produce identical slot
        # labels; skip duplicates instead of re-running the extractor and
        # clobbering the earlier result (fix: previously last-write-wins).
        if span in re_results:
            continue
        rel_types = TYPE2RELS.get(ner['label'], [])
        if not rel_types:
            continue
        # One "slot" per candidate relation, e.g. "ENAHO <> publisher".
        slot_labels = [f"{span} <> {r}" for r in rel_types]
        preds = relation_extractor(
            text,
            relations=None,
            entities=None,
            relation_labels=slot_labels,
            threshold=rel_threshold,
            multi_label=re_multi_label,
            return_index=return_index,
        )[0]
        re_results[span] = preds
    return ner_preds, re_results
def prune_acronym_and_self_relations(ner_preds, rel_preds):
    """Remove acronym-duplicate entities and self-referencing relations.

    An acronym target (e.g. "ENAHO" for "Encuesta Nacional de Hogares") is
    redundant with its long-form source, so both its NER entry and its
    relation block are dropped; relations whose target equals their own
    source span are also discarded.
    """
    # Acronym targets that are strictly shorter than the span they abbreviate.
    acronyms = set()
    for source, relations in rel_preds.items():
        for rel in relations:
            if rel["relation"] == "acronym" and len(rel["target"]) < len(source):
                acronyms.add(rel["target"])

    # Keep every entity except named datasets whose text is a known acronym.
    pruned_ner = [
        ent for ent in ner_preds
        if ent["label"] != "named dataset" or ent["text"] not in acronyms
    ]

    # Rebuild the relation map: drop whole blocks keyed by an acronym, and
    # drop individual relations that point back at their own source.
    pruned_re = {}
    for source, relations in rel_preds.items():
        if source in acronyms:
            continue
        surviving = [rel for rel in relations if rel["target"] != source]
        if surviving:
            pruned_re[source] = surviving
    return pruned_ner, pruned_re
# Highlighting function
def highlight_text(text, ner_threshold, rel_threshold):
    """Run the full pipeline and format spans for gr.HighlightedText.

    Returns a ``(highlight_payload, raw_predictions)`` pair: the first feeds
    the HighlightedText component, the second is cached in gr.State for the
    JSON view.
    """
    # 1) NER + RE inference, then prune acronym duplicates / self-relations.
    ner_preds, rel_preds = inference_pipeline(
        text,
        model=model,
        labels=labels,
        relation_extractor=relation_extractor,
        TYPE2RELS=TYPE2RELS,
        ner_threshold=ner_threshold,
        rel_threshold=rel_threshold,
        re_multi_label=False,
        return_index=True,
    )
    ner_preds, rel_preds = prune_acronym_and_self_relations(ner_preds, rel_preds)

    # 2) The extractor prepends its prompt to the text, so relation offsets
    # are shifted by the prompt prefix. This must match the extractor exactly.
    offset = len(f"{relation_extractor.prompt} \n ")

    # 3) Collect spans: entity spans as-is, relation spans shifted back onto
    # the raw text (anything falling before position 0 is dropped).
    spans = [(ent["start"], ent["end"], ent["label"]) for ent in ner_preds]
    for _, relations in rel_preds.items():
        for rel in relations:
            start = rel["start"] - offset
            end = rel["end"] - offset
            if start >= 0 and end >= 0:
                spans.append((start, end, f"{rel['source']} <> {rel['relation']}"))

    # 4) Merge labels that share a (start, end) range into one entity entry.
    grouped = defaultdict(list)
    for start, end, label in spans:
        grouped[(start, end)].append(label)
    entities = [
        {"entity": ", ".join(label_list), "start": start, "end": end}
        for (start, end), label_list in sorted(grouped.items(), key=lambda kv: kv[0])
    ]
    return {"text": text, "entities": entities}, {"ner": ner_preds, "relations": rel_preds}
# JSON output function
def _cached_predictions(state):
if not state:
return "📋 No predictions yet. Click **Submit** first."
return json.dumps(state, indent=2)
# --- Gradio UI --------------------------------------------------------------
with gr.Blocks() as demo:
    # Static description / usage instructions shown at the top of the Space.
    gr.Markdown(f"""# Data Use Detector
This Space demonstrates our fine-tuned GLiNER model’s ability to spot **dataset mentions** and **relations** in any input text. It identifies dataset names via NER, then extracts relations such as **publisher**, **acronym**, **publication year**, **data geography**, and more.
**How it works**
1. **NER**: Recognizes dataset names in your text.
2. **RE**: Links each dataset to its attributes (e.g., publisher, year, acronym).
3. **Visualization**: Highlights entities and relation spans inline.
**Instructions**
1. Paste or edit your text in the box below.
2. Tweak the **NER** & **RE** confidence sliders.
3. Click **Submit** to see highlights.
4. Click **Get Model Predictions** to view the raw JSON output.
**Resources**
- **Model:** [{DATA_MODEL_ID}](https://huggingface.co/{DATA_MODEL_ID})
- **Paper:** _Large Language Models and Synthetic Data for Monitoring Dataset Mentions in Research Papers_ – ArXiv: [2502.10263](https://arxiv.org/pdf/2502.10263)
- [GLiNER GitHub Repo](https://github.com/urchade/GLiNER)
- [Project Docs](https://worldbank.github.io/ai4data-use/docs/introduction.html)
""")
    # Input text plus per-stage confidence sliders (defaults match the
    # pipeline defaults: 0.7 NER, 0.5 RE).
    txt_in = gr.Textbox(
        label="Input Text",
        lines=4,
        value=SAMPLE_TEXT
    )
    ner_slider = gr.Slider(
        0, 1, value=0.7, step=0.01,
        label="NER Threshold",
        info="Minimum confidence for named-entity spans."
    )
    re_slider = gr.Slider(
        0, 1, value=0.5, step=0.01,
        label="RE Threshold",
        info="Minimum confidence for relation extractions."
    )
    highlight_btn = gr.Button("Submit")
    txt_out = gr.HighlightedText(label="Annotated Entities")
    get_pred_btn = gr.Button("Get Model Predictions")
    json_out = gr.Textbox(label="Model Predictions (JSON)", lines=15)
    # Caches the raw predictions from the last Submit so the JSON button can
    # display them without re-running inference.
    state = gr.State()

    # Wire up interactions: Submit -> highlights + cached state; the second
    # button dumps the cached state as pretty-printed JSON.
    highlight_btn.click(
        fn=highlight_text,
        inputs=[txt_in, ner_slider, re_slider],
        outputs=[txt_out, state]
    )
    get_pred_btn.click(
        fn=_cached_predictions,
        inputs=[state],
        outputs=[json_out]
    )

# Enable queue for concurrency
demo.queue(default_concurrency_limit=5)

# Launch the app
demo.launch(debug=True, inline=True)