import os
import re
import json
from collections import defaultdict

import gradio as gr

# GLiNER model and relation extractor
from gliner import GLiNER
from gliner.multitask import GLiNERRelationExtractor

# Cache directory for model weights (useful on Spaces)
_CACHE_DIR = os.environ.get("CACHE_DIR", None)

# Download and initialize the model + relation extractor
DATA_MODEL_ID = "rafmacalaba/gliner_re_finetuned-v3"
model = GLiNER.from_pretrained(DATA_MODEL_ID, cache_dir=_CACHE_DIR)
relation_extractor = GLiNERRelationExtractor(model=model)

# Sample text
SAMPLE_TEXT = (
    "In early 2012, the World Bank published the full report of the 2011 Demographic and Health Survey (DHS) "
    "for the Republic of Mali. Conducted between June and December 2011 under the technical oversight of Mali’s "
    "National Institute of Statistics and paired with on-the-ground data-collection teams, this nationally representative survey "
    "gathered detailed information on household composition, education levels, employment and income, fertility and family planning, "
    "maternal and child health, nutrition, mortality, and access to basic services. By combining traditional census modules with "
    "specialized questionnaires on women’s and children’s health, the DHS offers policymakers, development partners, and researchers "
    "a rich dataset of socioeconomic characteristics—ranging from literacy and school attendance to water and sanitation infrastructure—"
    "that can be used to monitor progress on poverty reduction, inform targeted social programs, and guide longer-term economic planning."
)

# Entity labels and the relation types queried for each label
labels = ['named dataset', 'unnamed dataset', 'vague dataset']
rels = [
    'acronym', 'author', 'data description',
    'data geography', 'data source', 'data type',
    'publication year', 'publisher', 'reference year', 'version',
]
TYPE2RELS = {
    "named dataset": rels,
    "unnamed dataset": rels,
    "vague dataset": rels,
}
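
# NOTE: `inference_pipeline` is called by the handlers below but is not defined
# or imported anywhere in this file. The following is a minimal, hypothetical
# stand-in so the script runs end to end. It assumes only GLiNER's documented
# `predict_entities` API and approximates relation extraction by re-querying
# the model with composed "{source} <> {relation}" labels — the same convention
# this app's display code uses. The real helper in the original project may
# differ; `relation_extractor` is accepted here purely to match the call sites,
# since its exact call signature is not shown in this file.
def inference_pipeline(
    text,
    model,
    labels,
    relation_extractor,
    TYPE2RELS,
    ner_threshold=0.5,
    re_threshold=0.5,
    re_multi_label=False,
):
    # 1) Named-entity prediction over the dataset labels.
    ner_preds = model.predict_entities(text, labels, threshold=ner_threshold)

    # 2) For each predicted entity, query relation targets with
    #    "{source} <> {relation}" labels (assumed prompting scheme).
    rel_preds = {}
    for ent in ner_preds:
        src, src_label = ent["text"], ent["label"]
        rel_labels = [f"{src} <> {rel}" for rel in TYPE2RELS.get(src_label, [])]
        if not rel_labels:
            continue
        hits = model.predict_entities(text, rel_labels, threshold=re_threshold)
        found = [
            {
                "relation": hit["label"].split(" <> ", 1)[1],
                "target": hit["text"],
                "score": hit["score"],
            }
            for hit in hits
        ]
        # Without multi-label RE, keep only the best-scoring target per relation.
        if not re_multi_label and found:
            best = {}
            for f in found:
                r = f["relation"]
                if r not in best or f["score"] > best[r]["score"]:
                    best[r] = f
            found = list(best.values())
        if found:
            rel_preds[src] = found
    return ner_preds, rel_preds
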
# Post-processing: prune acronym entities and self-relations
def prune_acronym_and_self_relations(ner_preds, rel_preds):
    # 1) Find acronym targets strictly shorter than their source
    acronym_targets = {
        r["target"]
        for src, rel_list in rel_preds.items()
        for r in rel_list
        if r["relation"] == "acronym" and len(r["target"]) < len(src)
    }

    # 2) Filter NER: drop any named dataset whose text is in that set
    filtered_ner = [
        ent
        for ent in ner_preds
        if not (ent["label"] == "named dataset" and ent["text"] in acronym_targets)
    ]

    # 3) Filter RE: drop blocks for acronym sources, and self-relations
    filtered_re = {}
    for src, rel_list in rel_preds.items():
        if src in acronym_targets:
            continue
        kept = [r for r in rel_list if r["target"] != src]
        if kept:
            filtered_re[src] = kept

    return filtered_ner, filtered_re


# Highlighting function for the annotated-text view
def highlight_text(text, ner_threshold, re_threshold):
    # Run inference
    ner_preds, rel_preds = inference_pipeline(
        text,
        model=model,
        labels=labels,
        relation_extractor=relation_extractor,
        TYPE2RELS=TYPE2RELS,
        ner_threshold=ner_threshold,
        re_threshold=re_threshold,
        re_multi_label=False,
    )

    # Post-process
    ner_preds, rel_preds = prune_acronym_and_self_relations(ner_preds, rel_preds)

    # Gather all spans: entity spans plus relation-target spans
    spans = []
    for ent in ner_preds:
        spans.append((ent["start"], ent["end"], ent["label"]))
    for src, rel_list in rel_preds.items():
        for r in rel_list:
            for m in re.finditer(re.escape(r["target"]), text):
                spans.append((m.start(), m.end(), f"{src} <> {r['relation']}"))

    # Merge labels that share the same span
    merged = defaultdict(list)
    for start, end, lbl in spans:
        merged[(start, end)].append(lbl)

    # Build the entity list Gradio's HighlightedText expects
    entities = []
    for (start, end), lbls in sorted(merged.items(), key=lambda x: x[0]):
        entities.append({
            "entity": ", ".join(lbls),
            "start": start,
            "end": end,
        })

    return {"text": text, "entities": entities}


# JSON output function
def get_model_predictions(text, ner_threshold, re_threshold):
    ner_preds, rel_preds = inference_pipeline(
        text,
        model=model,
        labels=labels,
        relation_extractor=relation_extractor,
        TYPE2RELS=TYPE2RELS,
        ner_threshold=ner_threshold,
        re_threshold=re_threshold,
        re_multi_label=False,
    )
    ner_preds, rel_preds = prune_acronym_and_self_relations(ner_preds, rel_preds)
    return json.dumps({"ner": ner_preds, "relations": rel_preds}, indent=2)


# Build the Gradio UI
demo = gr.Blocks()
with demo:
    gr.Markdown(
        "## Data Use Detector\n"
        "Adjust the sliders below to set thresholds, then:\n"
        "- **Submit** to highlight entities.\n"
        "- **Get Model Predictions** to see the raw JSON output."
    )
    txt_in = gr.Textbox(label="Input Text", lines=4, value=SAMPLE_TEXT)
    ner_slider = gr.Slider(
        0, 1, value=0.7, step=0.01,
        label="NER Threshold",
        info="Minimum confidence for named-entity spans.",
    )
    re_slider = gr.Slider(
        0, 1, value=0.5, step=0.01,
        label="RE Threshold",
        info="Minimum confidence for relation extractions.",
    )
    highlight_btn = gr.Button("Submit")
    txt_out = gr.HighlightedText(label="Annotated Entities")
    get_pred_btn = gr.Button("Get Model Predictions")
    ner_rel_box = gr.Textbox(label="Model Predictions (JSON)", lines=15)

    # Wire up interactions
    highlight_btn.click(
        fn=highlight_text,
        inputs=[txt_in, ner_slider, re_slider],
        outputs=txt_out,
    )
    get_pred_btn.click(
        fn=get_model_predictions,
        inputs=[txt_in, ner_slider, re_slider],
        outputs=ner_rel_box,
    )

# Enable queueing for concurrent requests
demo.queue(default_concurrency_limit=5)

# Launch the app
demo.launch(debug=True)