# Hugging Face Space source file — author: rafmacalaba
# Commit ce82a96: "new example with new rels" (file size 7.93 kB)
import os
import re
import json
from collections import defaultdict
import gradio as gr
from typing import List, Dict, Any, Tuple
# Optional cache directory for downloaded model files (useful on Spaces).
# An empty value (`CACHE_DIR=`) is treated the same as "unset", so it falls
# back to passing `cache_dir=None` instead of an empty-string path.
_CACHE_DIR = os.environ.get("CACHE_DIR") or None

# Import GLiNER and load the fine-tuned NER/RE checkpoint once at start-up.
from gliner import GLiNER

DATA_MODEL_ID = "rafmacalaba/gliner_re_finetuned-v6-pos"
model = GLiNER.from_pretrained(DATA_MODEL_ID, cache_dir=_CACHE_DIR)

# The relation extractor wraps the already-loaded model. It is imported after
# the model is created; NOTE(review): ordering kept as-is, presumably
# intentional — confirm before hoisting this import to the top of the file.
from relation_extraction import CustomGLiNERRelationExtractor

# return_index=True: the highlighting code reads "start"/"end" offsets from
# each relation prediction.
relation_extractor = CustomGLiNERRelationExtractor(model=model, return_index=True)
# Example input pre-filled in the demo's text box.
SAMPLE_TEXT = (
    "In 2010, Smith published the third round of the Demographic and Health Survey (DHS), a nationally representative cross-sectional survey funded and published by the World Bank with fieldwork reference year 2019 and a reference population of women aged 15–49. Conducted in 2020, it serves as the principal data source for collecting household composition, fertility and mortality rates, maternal and child health indicators, and access to water and sanitation across Nigeria, Kenya, and Ghana."
)

# Entity labels requested from the NER model.
labels = ['named dataset', 'unnamed dataset', 'vague dataset']

# Relation types queried for each detected dataset span.
rels = ['acronym', 'author', 'data description', 'data geography', 'data source', 'data type', 'publication year', 'publisher', 'reference population', 'reference year', 'version']

# Entity label -> relation types to query for entities with that label.
# Every label currently uses the full relation set. Each entry gets its own
# list copy, so customising one label's relations later cannot silently
# change the others (they would all alias `rels` otherwise).
TYPE2RELS = {label: list(rels) for label in labels}
def inference_pipeline(
    text: str,
    model,
    labels: List[str],
    relation_extractor: "CustomGLiNERRelationExtractor",
    TYPE2RELS: Dict[str, List[str]],
    ner_threshold: float = 0.7,
    rel_threshold: float = 0.5,
    re_multi_label: bool = False,
    return_index: bool = False,
) -> Tuple[List[Dict[str, Any]], Dict[str, List[Dict[str, Any]]]]:
    """Run NER on *text*, then query relations for every detected entity span.

    Args:
        text: Raw input text.
        model: Object exposing ``predict_entities(text, labels, flat_ner=...,
            threshold=...)``; it is called with ``flat_ner=True``.
        labels: Entity labels passed to the NER model.
        relation_extractor: Callable extractor. It receives one relation label
            per slot, formatted as ``"<span> <> <relation>"``, and the first
            element (``[0]``) of its result is kept.
        TYPE2RELS: Maps an entity label to the relation types to query for
            entities of that label. Entities whose label is missing, or maps
            to an empty list, are skipped.
        ner_threshold: Minimum NER confidence.
        rel_threshold: Minimum relation confidence, forwarded as ``threshold``.
        re_multi_label: Forwarded to the extractor as ``multi_label``.
        return_index: Forwarded to the extractor; presumably makes it include
            start/end offsets in each prediction — confirm in the extractor.

    Returns:
        ``(ner_preds, re_results)``. ``ner_preds`` is the raw NER output.
        ``re_results`` maps an entity's surface text to its relation
        predictions. Because keys are surface strings, if the same text is
        detected more than once, the entry for the last occurrence wins.
    """
    ner_preds = model.predict_entities(
        text,
        labels,
        flat_ner=True,
        threshold=ner_threshold,
    )

    re_results: Dict[str, List[Dict[str, Any]]] = {}
    # Repeated mentions of the same span with the same relation types lead to
    # identical extractor calls. Run the extractor once per (span, relation
    # types) and reuse the result. This assumes the extractor is
    # deterministic for identical inputs.
    computed: Dict[Tuple[str, Tuple[str, ...]], List[Dict[str, Any]]] = {}
    for ner in ner_preds:
        span = ner["text"]
        rel_types = TYPE2RELS.get(ner["label"], [])
        if not rel_types:
            continue
        key = (span, tuple(rel_types))
        if key not in computed:
            slot_labels = [f"{span} <> {r}" for r in rel_types]
            computed[key] = relation_extractor(
                text,
                relations=None,
                entities=None,
                relation_labels=slot_labels,
                threshold=rel_threshold,
                multi_label=re_multi_label,
                return_index=return_index,
            )[0]
        # Keyed by surface text: a later mention of the same span overwrites
        # an earlier one.
        re_results[span] = computed[key]
    return ner_preds, re_results
def prune_acronym_and_self_relations(ner_preds, rel_preds):
    """Remove acronym duplicates and self-referential relations.

    An "alias" is the target of an ``"acronym"`` relation whose text is
    strictly shorter than the source span it was extracted for. Aliases are
    treated as short forms of a longer dataset name, so:

    * NER entities labelled ``"named dataset"`` whose text is an alias are
      dropped;
    * relation blocks keyed by an alias are dropped;
    * in the remaining blocks, relations whose target equals the block's
      source text are dropped, and blocks left empty are omitted.

    Args:
        ner_preds: Iterable of entity dicts with at least ``"label"`` and
            ``"text"`` keys.
        rel_preds: Mapping of source span text -> list of relation dicts with
            at least ``"relation"`` and ``"target"`` keys.

    Returns:
        ``(filtered_ner, filtered_re)``: a new list and a new dict. The inputs
        are not modified.
    """
    # Pass 1: gather aliases (acronym targets shorter than their source).
    aliases = set()
    for source, relations in rel_preds.items():
        for rel in relations:
            if rel["relation"] != "acronym":
                continue
            if len(rel["target"]) < len(source):
                aliases.add(rel["target"])

    # Pass 2: keep every entity except named datasets that are mere aliases.
    def _is_alias_entity(ent):
        return ent["label"] == "named dataset" and ent["text"] in aliases

    pruned_ner = [ent for ent in ner_preds if not _is_alias_entity(ent)]

    # Pass 3: drop alias-sourced blocks, then self-relations, then empties.
    pruned_re = {}
    candidates = ((src, rs) for src, rs in rel_preds.items() if src not in aliases)
    for source, relations in candidates:
        non_self = [rel for rel in relations if rel["target"] != source]
        if non_self:
            pruned_re[source] = non_self

    return pruned_ner, pruned_re
# Highlighting function
def highlight_text(text, ner_threshold, rel_threshold):
    """Run NER + RE on *text* and build the ``gr.HighlightedText`` payload.

    Args:
        text: Input text from the text box.
        ner_threshold: Minimum NER confidence (slider value).
        rel_threshold: Minimum relation confidence (slider value).

    Returns:
        A pair ``(highlighted, predictions)``:

        * ``highlighted`` is ``{"text": text, "entities": [...]}``. Each
          entity is ``{"entity": <comma-joined labels>, "start": s, "end": e}``.
          Spans with identical offsets are merged, and entities are sorted by
          ``(start, end)``.
        * ``predictions`` is ``{"ner": ..., "relations": ...}`` after
          acronym/self-relation pruning, kept in ``gr.State`` for the JSON
          view.
    """
    # Nothing to analyse: skip both model calls entirely.
    if not text or not text.strip():
        return {"text": text or "", "entities": []}, {"ner": [], "relations": {}}

    # 1) Inference, then prune acronym duplicates and self-relations.
    ner_preds, rel_preds = inference_pipeline(
        text,
        model=model,
        labels=labels,
        relation_extractor=relation_extractor,
        TYPE2RELS=TYPE2RELS,
        ner_threshold=ner_threshold,
        rel_threshold=rel_threshold,
        re_multi_label=False,
        return_index=True,
    )
    ner_preds, rel_preds = prune_acronym_and_self_relations(ner_preds, rel_preds)

    # 2) Relation offsets are relative to the extractor's own input, which
    # starts with its prompt. This prefix must match exactly what the
    # extractor prepends. NOTE(review): coupled to extractor internals.
    prefix_len = len(f"{relation_extractor.prompt} \n ")
    text_len = len(text)

    # 3) Gather (start, end, label) spans from NER and from relations.
    spans = [(ent["start"], ent["end"], ent["label"]) for ent in ner_preds]
    for rel_list in rel_preds.values():
        for r in rel_list:
            # Shift the indices back onto the raw text.
            s = r["start"] - prefix_len
            e = r["end"] - prefix_len
            # Skip spans that do not land inside the raw text: hits inside the
            # prompt, offsets past the end, or empty/inverted ranges.
            if s < 0 or e > text_len or s >= e:
                continue
            spans.append((s, e, f"{r['source']} <> {r['relation']}"))

    # 4) Merge labels that share identical offsets and build the entity list.
    merged = defaultdict(list)
    for s, e, lbl in spans:
        merged[(s, e)].append(lbl)
    entities = [
        {"entity": ", ".join(lbls), "start": s, "end": e}
        for (s, e), lbls in sorted(merged.items(), key=lambda item: item[0])
    ]
    return {"text": text, "entities": entities}, {"ner": ner_preds, "relations": rel_preds}
# JSON output function
def _cached_predictions(state):
if not state:
return "📋 No predictions yet. Click **Submit** first."
return json.dumps(state, indent=2)
# Build the Gradio UI: description, inputs, outputs, and event wiring.
with gr.Blocks() as demo:
    # Blank lines keep the bold headings and lists as separate Markdown
    # blocks. The model link is derived from DATA_MODEL_ID so it always points
    # at the checkpoint this app actually loads.
    gr.Markdown(f"""# Data Use Detector

This Space demonstrates our fine-tuned GLiNER model’s ability to spot **dataset mentions** and **relations** in any input text. It identifies dataset names via NER, then extracts relations such as **publisher**, **acronym**, **publication year**, **data geography**, and more.

**How it works**

1. **NER**: Recognizes dataset names in your text.
2. **RE**: Links each dataset to its attributes (e.g., publisher, year, acronym).
3. **Visualization**: Highlights entities and relation spans inline.

**Instructions**

1. Paste or edit your text in the box below.
2. Tweak the **NER** & **RE** confidence sliders.
3. Click **Submit** to see highlights.
4. Click **Get Model Predictions** to view the raw JSON output.

**Resources**

- **Model:** [{DATA_MODEL_ID}](https://huggingface.co/{DATA_MODEL_ID})
- **Paper:** _Large Language Models and Synthetic Data for Monitoring Dataset Mentions in Research Papers_ – ArXiv: [2502.10263](https://arxiv.org/pdf/2502.10263)
- [GLiNER GitHub Repo](https://github.com/urchade/GLiNER)
- [Project Docs](https://worldbank.github.io/ai4data-use/docs/introduction.html)
""")

    # Inputs
    txt_in = gr.Textbox(
        label="Input Text",
        lines=4,
        value=SAMPLE_TEXT,
    )
    ner_slider = gr.Slider(
        0, 1, value=0.7, step=0.01,
        label="NER Threshold",
        info="Minimum confidence for named-entity spans.",
    )
    re_slider = gr.Slider(
        0, 1, value=0.5, step=0.01,
        label="RE Threshold",
        info="Minimum confidence for relation extractions.",
    )

    # Outputs
    highlight_btn = gr.Button("Submit")
    txt_out = gr.HighlightedText(label="Annotated Entities")
    get_pred_btn = gr.Button("Get Model Predictions")
    json_out = gr.Textbox(label="Model Predictions (JSON)", lines=15)
    # Holds the last predictions so the JSON view does not re-run the models.
    state = gr.State()

    # Wire up interactions
    highlight_btn.click(
        fn=highlight_text,
        inputs=[txt_in, ner_slider, re_slider],
        outputs=[txt_out, state],
    )
    get_pred_btn.click(
        fn=_cached_predictions,
        inputs=[state],
        outputs=[json_out],
    )
# Enable the request queue, with a default concurrency limit of 5 for the
# event listeners (see Gradio `Blocks.queue`).
demo.queue(default_concurrency_limit=5)

# Launch only when executed as a script (e.g. `python app.py`). This lets the
# module be imported, for example by Gradio's reload mode or to reuse `demo`,
# without starting a server as a side effect.
if __name__ == "__main__":
    demo.launch(debug=True, inline=True)