# Hugging Face Space app — commit "change model" (cce0575) by rafmacalaba
import os
import re
import json
from collections import defaultdict
import gradio as gr
from typing import List, Dict, Any, Tuple
# Load environment variable for cache dir (useful on Spaces)
_CACHE_DIR = os.environ.get("CACHE_DIR", None)
# Import GLiNER model and relation extractor
from gliner import GLiNER
#from relation_extraction import CustomGLiNERRelationExtractor
# Cache and initialize model + relation extractor
# Hugging Face Hub ID of the fine-tuned GLiNER NER + relation-extraction checkpoint.
DATA_MODEL_ID = "rafmacalaba/gliner_re_finetuned-v7-pos"
# Downloads on first run; CACHE_DIR keeps the checkpoint across Space restarts.
model = GLiNER.from_pretrained(DATA_MODEL_ID, cache_dir=_CACHE_DIR)
# NOTE(review): imported mid-file, after the model is created — presumably the
# extractor module expects the model to exist first; confirm before moving to top.
from relation_extraction import CustomGLiNERRelationExtractor
# return_index=True makes relation predictions carry start/end character offsets,
# which highlight_text() later maps back onto the raw input text.
relation_extractor = CustomGLiNERRelationExtractor(model=model, return_index=True)
# Sample text pre-filled in the input box on startup.
SAMPLE_TEXT = (
"Encuesta Nacional de Hogares (ENAHO) is the Peruvian version of the Living Standards Measurement Survey, e.g. a nationally representative household survey collected monthly on a continuous basis. For our analysis, we use data from January 2007 to December 2020. The survey covers a wide variety of topics, including basic demographics, educational background, labor market conditions, crime victimization, and a module on respondent’s perceptions about the main problems in the country and trust in different local and national‐level institutions."
)
# Entity labels the NER stage looks for.
labels = ['named dataset', 'unnamed dataset', 'vague dataset']
# Relation slot types probed for every detected dataset mention.
rels = ['acronym', 'author', 'data description', 'data geography', 'data source', 'data type', 'publication year', 'publisher', 'reference population', 'reference year', 'version']
# Candidate relation types per entity label (all three labels share the same set).
TYPE2RELS = {
    "named dataset": rels,
    "unnamed dataset": rels,
    "vague dataset": rels,
}
def inference_pipeline(
    text: str,
    model,
    labels: List[str],
    relation_extractor: CustomGLiNERRelationExtractor,
    TYPE2RELS: Dict[str, List[str]],
    ner_threshold: float = 0.7,
    rel_threshold: float = 0.5,
    re_multi_label: bool = False,
    return_index: bool = False,
) -> Tuple[List[Dict[str, Any]], Dict[str, List[Dict[str, Any]]]]:
    """Run two-stage inference: NER over *text*, then relation extraction.

    Returns a tuple ``(entity_predictions, relations_by_span)`` where the
    second element maps each detected entity's surface text to the relation
    predictions extracted for it.
    """
    # Stage 1: detect dataset-mention entities.
    entity_preds = model.predict_entities(
        text,
        labels,
        flat_ner=True,
        threshold=ner_threshold,
    )

    # Stage 2: for each entity, probe its candidate relation types using
    # "span <> relation" slot labels understood by the extractor.
    relations_by_span: Dict[str, List[Dict[str, Any]]] = {}
    for entity in entity_preds:
        candidate_rels = TYPE2RELS.get(entity['label'], [])
        if not candidate_rels:
            # No relation types configured for this entity label.
            continue
        span_text = entity['text']
        slot_labels = [f"{span_text} <> {rel}" for rel in candidate_rels]
        # The extractor returns one result list per input; take the first.
        relations_by_span[span_text] = relation_extractor(
            text,
            relations=None,
            entities=None,
            relation_labels=slot_labels,
            threshold=rel_threshold,
            multi_label=re_multi_label,
            return_index=return_index,
        )[0]
    return entity_preds, relations_by_span
def prune_acronym_and_self_relations(ner_preds, rel_preds):
    """Remove acronym-duplicate entities and self-referential relations.

    An entity whose text is the (strictly shorter) target of an ``acronym``
    relation is treated as a duplicate of its long-form source: it is dropped
    from the NER predictions and its own relation block is discarded.
    Relations whose target equals their source span are also removed, and
    sources left with no relations are omitted entirely.
    """
    # 1) Collect acronym strings that are strictly shorter than their source.
    acronyms = set()
    for source, relations in rel_preds.items():
        for rel in relations:
            if rel["relation"] == "acronym" and len(rel["target"]) < len(source):
                acronyms.add(rel["target"])

    # 2) Keep every entity except named datasets that are known acronyms.
    kept_entities = []
    for entity in ner_preds:
        is_acronym_dup = entity["label"] == "named dataset" and entity["text"] in acronyms
        if not is_acronym_dup:
            kept_entities.append(entity)

    # 3) Rebuild the relation map: skip acronym sources, drop self-relations,
    #    and omit sources whose relation list becomes empty.
    kept_relations = {}
    for source, relations in rel_preds.items():
        if source in acronyms:
            continue
        survivors = [rel for rel in relations if rel["target"] != source]
        if survivors:
            kept_relations[source] = survivors
    return kept_entities, kept_relations
# Highlighting function
def highlight_text(text, ner_threshold, rel_threshold):
    """Annotate *text* with entity and relation spans for gr.HighlightedText.

    Returns ``(highlight_payload, raw_predictions)`` — the first feeds the
    HighlightedText component, the second is cached in gr.State for the
    JSON view.
    """
    # 1) Run the full NER + RE pipeline, then prune acronym/self noise.
    ner_preds, rel_preds = inference_pipeline(
        text,
        model=model,
        labels=labels,
        relation_extractor=relation_extractor,
        TYPE2RELS=TYPE2RELS,
        ner_threshold=ner_threshold,
        rel_threshold=rel_threshold,
        re_multi_label=False,
        return_index=True,
    )
    ner_preds, rel_preds = prune_acronym_and_self_relations(ner_preds, rel_preds)

    # 2) The extractor prepends "<prompt> \n " to the text, so its returned
    #    indices are shifted by that prefix; subtract to map back to raw text.
    offset = len(f"{relation_extractor.prompt} \n ")

    # 3) Gather spans: entity spans as-is, relation spans shifted by offset.
    spans = [(ent["start"], ent["end"], ent["label"]) for ent in ner_preds]
    for relations in rel_preds.values():
        for rel in relations:
            start = rel["start"] - offset
            end = rel["end"] - offset
            if start < 0 or end < 0:
                # Index fell inside the prompt prefix — not part of raw text.
                continue
            spans.append((start, end, f"{rel['source']} <> {rel['relation']}"))

    # 4) Merge labels sharing the exact same (start, end) span, then emit
    #    entities sorted by position.
    merged = defaultdict(list)
    for start, end, label in spans:
        merged[(start, end)].append(label)
    entities = [
        {"entity": ", ".join(names), "start": start, "end": end}
        for (start, end), names in sorted(merged.items(), key=lambda item: item[0])
    ]
    return {"text": text, "entities": entities}, {"ner": ner_preds, "relations": rel_preds}
# JSON output function
def _cached_predictions(state):
if not state:
return "📋 No predictions yet. Click **Submit** first."
return json.dumps(state, indent=2)
# --- Gradio UI ---------------------------------------------------------------
# Components render in creation order inside the Blocks context.
with gr.Blocks() as demo:
    gr.Markdown(f"""# Data Use Detector
This Space demonstrates our fine-tuned GLiNER model’s ability to spot **dataset mentions** and **relations** in any input text. It identifies dataset names via NER, then extracts relations such as **publisher**, **acronym**, **publication year**, **data geography**, and more.
**How it works**
1. **NER**: Recognizes dataset names in your text.
2. **RE**: Links each dataset to its attributes (e.g., publisher, year, acronym).
3. **Visualization**: Highlights entities and relation spans inline.
**Instructions**
1. Paste or edit your text in the box below.
2. Tweak the **NER** & **RE** confidence sliders.
3. Click **Submit** to see highlights.
4. Click **Get Model Predictions** to view the raw JSON output.
**Resources**
- **Model:** [{DATA_MODEL_ID}](https://huggingface.co/{DATA_MODEL_ID})
- **Paper:** _Large Language Models and Synthetic Data for Monitoring Dataset Mentions in Research Papers_ – ArXiv: [2502.10263](https://arxiv.org/pdf/2502.10263)
- [GLiNER GitHub Repo](https://github.com/urchade/GLiNER)
- [Project Docs](https://worldbank.github.io/ai4data-use/docs/introduction.html)
""")
    # Free-text input, pre-filled with the sample paragraph.
    txt_in = gr.Textbox(
        label="Input Text",
        lines=4,
        value=SAMPLE_TEXT
    )
    # Confidence cutoffs; defaults mirror inference_pipeline's 0.7 / 0.5.
    ner_slider = gr.Slider(
        0, 1, value=0.7, step=0.01,
        label="NER Threshold",
        info="Minimum confidence for named-entity spans."
    )
    re_slider = gr.Slider(
        0, 1, value=0.5, step=0.01,
        label="RE Threshold",
        info="Minimum confidence for relation extractions."
    )
    highlight_btn = gr.Button("Submit")
    # Inline visualization of entity + relation spans.
    txt_out = gr.HighlightedText(label="Annotated Entities")
    get_pred_btn = gr.Button("Get Model Predictions")
    json_out = gr.Textbox(label="Model Predictions (JSON)", lines=15)
    # Caches the latest raw predictions so "Get Model Predictions" can show
    # them without re-running inference.
    state = gr.State()
    # Wire up interactions
    highlight_btn.click(
        fn=highlight_text,
        inputs=[txt_in, ner_slider, re_slider],
        outputs=[txt_out, state]
    )
    get_pred_btn.click(
        fn=_cached_predictions,
        inputs=[state],
        outputs=[json_out]
    )
# Enable queue for concurrency (up to 5 concurrent event handlers per event).
demo.queue(default_concurrency_limit=5)
# Launch the app
demo.launch(debug=True, inline=True)