import os
import json
from collections import defaultdict
from typing import Any, Dict, List, Tuple

import gradio as gr

# Import the GLiNER model and the relation extractor built on top of it
from gliner import GLiNER
from relation_extraction import CustomGLiNERRelationExtractor

# Cache directory for model weights (useful on Spaces)
_CACHE_DIR = os.environ.get("CACHE_DIR")

# Initialize the fine-tuned model and relation extractor
DATA_MODEL_ID = "rafmacalaba/gliner_re_finetuned-v7-pos"
model = GLiNER.from_pretrained(DATA_MODEL_ID, cache_dir=_CACHE_DIR)
relation_extractor = CustomGLiNERRelationExtractor(model=model, return_index=True)
# Sample text shown in the input box on load
SAMPLE_TEXT = (
    "Encuesta Nacional de Hogares (ENAHO) is the Peruvian version of the Living "
    "Standards Measurement Survey, i.e. a nationally representative household survey "
    "collected monthly on a continuous basis. For our analysis, we use data from "
    "January 2007 to December 2020. The survey covers a wide variety of topics, "
    "including basic demographics, educational background, labor market conditions, "
    "crime victimization, and a module on respondent’s perceptions about the main "
    "problems in the country and trust in different local and national-level institutions."
)
# Entity labels and the relation slots to query for each entity type
labels = ['named dataset', 'unnamed dataset', 'vague dataset']
rels = [
    'acronym', 'author', 'data description', 'data geography', 'data source',
    'data type', 'publication year', 'publisher', 'reference population',
    'reference year', 'version',
]
TYPE2RELS = {
    "named dataset": rels,
    "unnamed dataset": rels,
    "vague dataset": rels,
}
def inference_pipeline(
    text: str,
    model,
    labels: List[str],
    relation_extractor: CustomGLiNERRelationExtractor,
    TYPE2RELS: Dict[str, List[str]],
    ner_threshold: float = 0.7,
    rel_threshold: float = 0.5,
    re_multi_label: bool = False,
    return_index: bool = False,
) -> Tuple[List[Dict[str, Any]], Dict[str, List[Dict[str, Any]]]]:
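    """Run NER over `text`, then score candidate relations for each entity.

    Returns (ner_preds, re_results): the raw entity predictions and a dict
    mapping each entity span to its relation predictions.
    """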
    ner_preds = model.predict_entities(
        text,
        labels,
        flat_ner=True,
        threshold=ner_threshold,
    )
    re_results: Dict[str, List[Dict[str, Any]]] = {}
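    # For each detected entity, build "span <> relation" prompts and let the
    # extractor score them against the text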
    for ner in ner_preds:
        span = ner['text']
        rel_types = TYPE2RELS.get(ner['label'], [])
        if not rel_types:
            continue
        slot_labels = [f"{span} <> {r}" for r in rel_types]
        preds = relation_extractor(
            text,
            relations=None,
            entities=None,
            relation_labels=slot_labels,
            threshold=rel_threshold,
            multi_label=re_multi_label,
            return_index=return_index,
        )[0]
        re_results[span] = preds
    return ner_preds, re_results
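# Example usage (illustrative; exact spans and scores depend on the model):
#   ner_preds, re_results = inference_pipeline(
#       SAMPLE_TEXT, model, labels, relation_extractor, TYPE2RELS
#   )
#   # ner_preds:  [{"text": ..., "label": ..., "start": ..., "end": ...}, ...]
#   # re_results: {"ENAHO": [{"source": ..., "relation": ..., "target": ..., ...}], ...}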
# Post-processing: prune acronym duplicates and self-relations
def prune_acronym_and_self_relations(ner_preds, rel_preds):
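    """Drop entities that are merely acronyms of longer dataset names, and
    drop relations whose target is the source span itself."""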
    # 1) Find acronym targets strictly shorter than their source
    acronym_targets = {
        r["target"]
        for src, rels in rel_preds.items()
        for r in rels
        if r["relation"] == "acronym" and len(r["target"]) < len(src)
    }
    # 2) Filter NER: drop any named dataset whose text is in that set
    filtered_ner = [
        ent for ent in ner_preds
        if not (ent["label"] == "named dataset" and ent["text"] in acronym_targets)
    ]
    # 3) Filter RE: drop blocks for acronym sources, and self-relations
    filtered_re = {}
    for src, rels in rel_preds.items():
        if src in acronym_targets:
            continue
        kept = [r for r in rels if r["target"] != src]
        if kept:
            filtered_re[src] = kept
    return filtered_ner, filtered_re
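# Example: if "Encuesta Nacional de Hogares" carries an acronym relation to
# "ENAHO", the shorter "ENAHO" entity and its own relation block are dropped.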
# Highlighting function
def highlight_text(text, ner_threshold, rel_threshold):
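    """Run the full pipeline and format spans for gr.HighlightedText.

    Returns the highlight payload and the raw predictions (kept in gr.State).
    """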
    # 1) Inference
    ner_preds, rel_preds = inference_pipeline(
        text,
        model=model,
        labels=labels,
        relation_extractor=relation_extractor,
        TYPE2RELS=TYPE2RELS,
        ner_threshold=ner_threshold,
        rel_threshold=rel_threshold,
        re_multi_label=False,
        return_index=True,
    )
    ner_preds, rel_preds = prune_acronym_and_self_relations(ner_preds, rel_preds)
    # 2) Compute the length of the RE prompt prefix.
    # This must match exactly what the extractor prepends to the text.
    prefix = f"{relation_extractor.prompt} \n "
    prefix_len = len(prefix)
    # 3) Gather spans
    spans = []
    for ent in ner_preds:
        spans.append((ent["start"], ent["end"], ent["label"]))
    # Use the extractor-returned start/end, minus prefix_len
    for src, rels in rel_preds.items():
        for r in rels:
            # Shift the indices back onto the raw text
            s = r["start"] - prefix_len
            e = r["end"] - prefix_len
            # Skip anything that fell outside the raw text
            if s < 0 or e < 0:
                continue
            label = f"{r['source']} <> {r['relation']}"
            spans.append((s, e, label))
    # 4) Merge labels that share the same span and build the entity list
    merged = defaultdict(list)
    for s, e, lbl in spans:
        merged[(s, e)].append(lbl)
    entities = []
    for (s, e), lbls in sorted(merged.items(), key=lambda x: x[0]):
        entities.append({
            "entity": ", ".join(lbls),
            "start": s,
            "end": e,
        })
    return {"text": text, "entities": entities}, {"ner": ner_preds, "relations": rel_preds}
# JSON output function
def _cached_predictions(state):
    if not state:
        return "📋 No predictions yet. Click **Submit** first."
    return json.dumps(state, indent=2)
with gr.Blocks() as demo:
    gr.Markdown(f"""# Data Use Detector

This Space demonstrates our fine-tuned GLiNER model’s ability to spot **dataset mentions** and **relations** in any input text. It identifies dataset names via NER, then extracts relations such as **publisher**, **acronym**, **publication year**, **data geography**, and more.

**How it works**

1. **NER**: Recognizes dataset names in your text.
2. **RE**: Links each dataset to its attributes (e.g., publisher, year, acronym).
3. **Visualization**: Highlights entities and relation spans inline.

**Instructions**

1. Paste or edit your text in the box below.
2. Tweak the **NER** & **RE** confidence sliders.
3. Click **Submit** to see highlights.
4. Click **Get Model Predictions** to view the raw JSON output.

**Resources**

- **Model:** [{DATA_MODEL_ID}](https://huggingface.co/{DATA_MODEL_ID})
- **Paper:** _Large Language Models and Synthetic Data for Monitoring Dataset Mentions in Research Papers_ – arXiv: [2502.10263](https://arxiv.org/pdf/2502.10263)
- [GLiNER GitHub Repo](https://github.com/urchade/GLiNER)
- [Project Docs](https://worldbank.github.io/ai4data-use/docs/introduction.html)
""")
    txt_in = gr.Textbox(
        label="Input Text",
        lines=4,
        value=SAMPLE_TEXT,
    )
    ner_slider = gr.Slider(
        0, 1, value=0.7, step=0.01,
        label="NER Threshold",
        info="Minimum confidence for named-entity spans.",
    )
    re_slider = gr.Slider(
        0, 1, value=0.5, step=0.01,
        label="RE Threshold",
        info="Minimum confidence for relation extractions.",
    )
    highlight_btn = gr.Button("Submit")
    txt_out = gr.HighlightedText(label="Annotated Entities")
    get_pred_btn = gr.Button("Get Model Predictions")
    json_out = gr.Textbox(label="Model Predictions (JSON)", lines=15)
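    # Holds the most recent raw predictions for the JSON viewer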
    state = gr.State()

    # Wire up interactions
    highlight_btn.click(
        fn=highlight_text,
        inputs=[txt_in, ner_slider, re_slider],
        outputs=[txt_out, state],
    )
    get_pred_btn.click(
        fn=_cached_predictions,
        inputs=[state],
        outputs=[json_out],
    )

# Enable queue for concurrency
demo.queue(default_concurrency_limit=5)

# Launch the app
demo.launch(debug=True, inline=True)