Spaces:
Running
Running
import os | |
import re | |
import json | |
from collections import defaultdict | |
import gradio as gr | |
from typing import List, Dict, Any, Tuple | |
# Optional model cache directory, read from the CACHE_DIR environment variable
# (useful on Spaces). When unset this is None, which is passed straight through
# to `from_pretrained` — presumably meaning "use the library default"; confirm.
_CACHE_DIR = os.environ.get("CACHE_DIR")

# Load the fine-tuned GLiNER model once at import time; the module-level
# `model` and `relation_extractor` are shared by every UI request.
from gliner import GLiNER

DATA_MODEL_ID = "rafmacalaba/gliner_re_finetuned-v6-pos"
model = GLiNER.from_pretrained(DATA_MODEL_ID, cache_dir=_CACHE_DIR)

# NOTE(review): this import sits after the model load rather than at the top
# of the file; the reason is not visible here — confirm before hoisting it.
from relation_extraction import CustomGLiNERRelationExtractor

relation_extractor = CustomGLiNERRelationExtractor(model=model, return_index=True)
# Example input pre-filled in the text box when the app loads.
SAMPLE_TEXT = (
    "In 2010, Smith published the third round of the Demographic and Health Survey (DHS), a nationally representative cross-sectional survey funded and published by the World Bank with fieldwork reference year 2019 and a reference population of women aged 15–49. Conducted in 2020, it serves as the principal data source for collecting household composition, fertility and mortality rates, maternal and child health indicators, and access to water and sanitation across Nigeria, Kenya, and Ghana."
)

# Entity labels requested from the NER model.
labels = ['named dataset', 'unnamed dataset', 'vague dataset']

# Relation types queried for each detected dataset mention; inference_pipeline
# turns each one into a "<span> <> <relation>" slot label.
rels = ['acronym', 'author', 'data description', 'data geography', 'data source', 'data type', 'publication year', 'publisher', 'reference population', 'reference year', 'version']

# Entity label -> relation types to extract for entities with that label.
# Every dataset label currently shares the same `rels` list.
TYPE2RELS = {label: rels for label in labels}
def inference_pipeline(
    text: str,
    model,
    labels: List[str],
    relation_extractor: "CustomGLiNERRelationExtractor",
    TYPE2RELS: Dict[str, List[str]],
    ner_threshold: float = 0.7,
    rel_threshold: float = 0.5,
    re_multi_label: bool = False,
    return_index: bool = False,
) -> Tuple[List[Dict[str, Any]], Dict[str, List[Dict[str, Any]]]]:
    """Run NER over ``text``, then relation extraction for each detected entity.

    Args:
        text: Raw input text.
        model: Object exposing ``predict_entities(text, labels, flat_ner=...,
            threshold=...)`` (the GLiNER model in this app).
        labels: Entity labels passed to the NER model.
        relation_extractor: Callable invoked with ``text`` and per-entity
            ``relation_labels`` of the form ``"<span> <> <relation>"``; the
            first element of its return value is used as that entity's
            relation predictions.
        TYPE2RELS: Maps an entity label to the relation names to query for it.
            Entities whose label maps to no relations are skipped.
        ner_threshold: Threshold forwarded to ``predict_entities``.
        rel_threshold: Threshold forwarded to the relation extractor.
        re_multi_label: Forwarded to the relation extractor as ``multi_label``.
        return_index: Forwarded to the relation extractor.

    Returns:
        ``(ner_preds, re_results)``: ``ner_preds`` is the NER output as
        returned by the model; ``re_results`` maps each entity's surface text
        to its relation predictions. If the same surface text is detected more
        than once, the entry for its last occurrence is kept.
    """
    ner_preds = model.predict_entities(
        text,
        labels,
        flat_ner=True,
        threshold=ner_threshold,
    )

    re_results: Dict[str, List[Dict[str, Any]]] = {}
    # Within one call, the extractor's inputs differ only by `slot_labels`,
    # which are fully determined by (span, rel_types). Repeated mentions of the
    # same dataset therefore reuse one result instead of re-running the model
    # (assumes the extractor is deterministic for identical inputs).
    cache: Dict[Tuple[str, Tuple[str, ...]], List[Dict[str, Any]]] = {}
    for ner in ner_preds:
        span = ner["text"]
        rel_types = TYPE2RELS.get(ner["label"], [])
        if not rel_types:
            continue
        key = (span, tuple(rel_types))
        if key not in cache:
            slot_labels = [f"{span} <> {r}" for r in rel_types]
            cache[key] = relation_extractor(
                text,
                relations=None,
                entities=None,
                relation_labels=slot_labels,
                threshold=rel_threshold,
                multi_label=re_multi_label,
                return_index=return_index,
            )[0]
        re_results[span] = cache[key]
    return ner_preds, re_results
def prune_acronym_and_self_relations(ner_preds, rel_preds):
    """Drop acronym short forms and self-referencing relations.

    An ``"acronym"`` relation whose target text is strictly shorter than its
    source key marks that target as a short form. Short forms are pruned:

    * NER: entities labelled ``"named dataset"`` whose text is a short form
      are removed (entities with any other label are kept).
    * RE: the whole relation list keyed by a short form is removed.

    Separately, every relation whose target equals its own source key is
    removed, and sources left with no relations are omitted.

    Args:
        ner_preds: List of entity dicts with at least ``"label"`` and ``"text"``.
        rel_preds: Mapping of source text to a list of relation dicts with at
            least ``"relation"`` and ``"target"``.

    Returns:
        ``(filtered_ner, filtered_re)`` in the original order; the input
        containers are not modified.
    """
    # Collect short forms introduced by acronym relations.
    short_forms = set()
    for source, relation_list in rel_preds.items():
        for rel in relation_list:
            if rel["relation"] != "acronym":
                continue
            if len(rel["target"]) < len(source):
                short_forms.add(rel["target"])

    def _is_short_named_dataset(entity):
        return entity["label"] == "named dataset" and entity["text"] in short_forms

    kept_entities = [
        entity for entity in ner_preds if not _is_short_named_dataset(entity)
    ]

    kept_relations = {}
    for source, relation_list in rel_preds.items():
        if source in short_forms:
            continue
        others = [rel for rel in relation_list if rel["target"] != source]
        if others:
            kept_relations[source] = others

    return kept_entities, kept_relations
# Highlighting function
def highlight_text(text, ner_threshold, rel_threshold):
    """Run the full pipeline on ``text`` and build the Gradio outputs.

    Args:
        text: Raw input text from the UI textbox.
        ner_threshold: NER confidence threshold (from the NER slider).
        rel_threshold: Relation-extraction threshold (from the RE slider).

    Returns:
        A pair ``(highlighted, state)``:

        - ``highlighted`` is ``{"text": text, "entities": [...]}`` for
          ``gr.HighlightedText``. Each entity has ``entity`` (the
          comma-joined labels of every span sharing those offsets), ``start``
          and ``end``. Entities are sorted by ``(start, end)``.
        - ``state`` is ``{"ner": ..., "relations": ...}`` with the pruned
          predictions, kept in ``gr.State`` for the JSON view.
    """
    # 1) Inference, then drop acronym short forms and self-relations.
    ner_preds, rel_preds = inference_pipeline(
        text,
        model=model,
        labels=labels,
        relation_extractor=relation_extractor,
        TYPE2RELS=TYPE2RELS,
        ner_threshold=ner_threshold,
        rel_threshold=rel_threshold,
        re_multi_label=False,
        return_index=True,
    )
    ner_preds, rel_preds = prune_acronym_and_self_relations(ner_preds, rel_preds)

    # 2) Relation offsets are shifted back by the length of the prompt prefix
    #    so they index into the raw text.
    #    NOTE(review): this prefix must match exactly what the extractor
    #    prepends — re-check if its prompt formatting changes.
    prefix_len = len(f"{relation_extractor.prompt} \n ")
    text_len = len(text)

    # 3) Gather (start, end, label) spans from NER and RE.
    spans = [(ent["start"], ent["end"], ent["label"]) for ent in ner_preds]
    for relations in rel_preds.values():
        for rel in relations:
            s = rel["start"] - prefix_len
            e = rel["end"] - prefix_len
            # Keep only non-empty spans that land inside the raw text. A
            # prefix-length mismatch can push offsets below 0 or past the end.
            if not 0 <= s < e <= text_len:
                continue
            spans.append((s, e, f"{rel['source']} <> {rel['relation']}"))

    # 4) Merge spans with identical offsets into one entity with joined labels.
    merged = defaultdict(list)
    for s, e, lbl in spans:
        merged[(s, e)].append(lbl)
    entities = [
        {"entity": ", ".join(lbls), "start": s, "end": e}
        for (s, e), lbls in sorted(merged.items())
    ]

    return {"text": text, "entities": entities}, {"ner": ner_preds, "relations": rel_preds}
# JSON output function | |
def _cached_predictions(state): | |
if not state: | |
return "📋 No predictions yet. Click **Submit** first." | |
return json.dumps(state, indent=2) | |
with gr.Blocks() as demo:
    # Header. The model link is built from DATA_MODEL_ID so it always points
    # at the checkpoint this Space actually loads.
    gr.Markdown(f"""# Data Use Detector

This Space demonstrates our fine-tuned GLiNER model’s ability to spot **dataset mentions** and **relations** in any input text. It identifies dataset names via NER, then extracts relations such as **publisher**, **acronym**, **publication year**, **data geography**, and more.

**How it works**

1. **NER**: Recognizes dataset names in your text.
2. **RE**: Links each dataset to its attributes (e.g., publisher, year, acronym).
3. **Visualization**: Highlights entities and relation spans inline.

**Instructions**

1. Paste or edit your text in the box below.
2. Tweak the **NER** & **RE** confidence sliders.
3. Click **Submit** to see highlights.
4. Click **Get Model Predictions** to view the raw JSON output.

**Resources**

- **Model:** [{DATA_MODEL_ID}](https://huggingface.co/{DATA_MODEL_ID})
- **Paper:** _Large Language Models and Synthetic Data for Monitoring Dataset Mentions in Research Papers_ – ArXiv: [2502.10263](https://arxiv.org/pdf/2502.10263)
- [GLiNER GitHub Repo](https://github.com/urchade/GLiNER)
- [Project Docs](https://worldbank.github.io/ai4data-use/docs/introduction.html)
""")

    # Inputs: text plus the two thresholds forwarded to highlight_text.
    txt_in = gr.Textbox(
        label="Input Text",
        lines=4,
        value=SAMPLE_TEXT,
    )
    ner_slider = gr.Slider(
        0, 1, value=0.7, step=0.01,
        label="NER Threshold",
        info="Minimum confidence for named-entity spans.",
    )
    re_slider = gr.Slider(
        0, 1, value=0.5, step=0.01,
        label="RE Threshold",
        info="Minimum confidence for relation extractions.",
    )

    # Outputs.
    highlight_btn = gr.Button("Submit")
    txt_out = gr.HighlightedText(label="Annotated Entities")
    get_pred_btn = gr.Button("Get Model Predictions")
    json_out = gr.Textbox(label="Model Predictions (JSON)", lines=15)
    # Holds the raw predictions from the last Submit for the JSON button.
    state = gr.State()

    # Submit runs the model, renders highlights, and stores raw predictions.
    highlight_btn.click(
        fn=highlight_text,
        inputs=[txt_in, ner_slider, re_slider],
        outputs=[txt_out, state],
    )
    # The JSON button only formats what is already in `state`; no model call.
    get_pred_btn.click(
        fn=_cached_predictions,
        inputs=[state],
        outputs=[json_out],
    )
# Enable request queuing with a default per-event concurrency limit of 5.
demo.queue(default_concurrency_limit=5)

# Launch only when executed as a script (e.g. `python app.py`), so importing
# this module does not start a server as a side effect.
# NOTE(review): confirm the Space's entry point runs this file as __main__.
if __name__ == "__main__":
    demo.launch(debug=True, inline=True)