File size: 7,954 Bytes
fc4b12d
fd0fe48
d8c3809
fc4b12d
3b9fb2c
2463f9e
fc4b12d
 
 
 
 
2ae65ac
fc4b12d
 
cce0575
fc4b12d
428a649
2ae65ac
fc4b12d
 
 
4c028c5
fc4b12d
 
 
 
ab71a6e
ce82a96
ab71a6e
 
 
 
 
 
 
2463f9e
 
 
 
2ae65ac
2463f9e
2ae65ac
 
2463f9e
2ae65ac
2463f9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ae65ac
2463f9e
2ae65ac
2463f9e
 
 
 
 
 
fc4b12d
 
 
 
 
 
 
cd683ff
fc4b12d
 
 
 
 
cd683ff
c35975c
fc4b12d
 
 
 
 
 
 
 
 
 
 
 
 
2ae65ac
28e7655
eb6e673
 
 
 
 
 
 
2ae65ac
 
 
eb6e673
fc4b12d
 
28e7655
 
 
 
 
 
fc4b12d
 
 
28e7655
 
fc4b12d
079aa2a
28e7655
 
 
 
 
 
 
 
 
 
eb6e673
28e7655
 
eb6e673
 
28e7655
eb6e673
 
28e7655
 
eb6e673
28e7655
72c5156
3d53082
fc4b12d
72c5156
 
 
 
 
b99c40c
39d4f87
2222dbb
593f17e
2222dbb
593f17e
 
 
 
2222dbb
593f17e
 
 
 
 
2222dbb
593f17e
4c028c5
593f17e
 
 
 
 
fc4b12d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13e7831
d799589
fc4b12d
215cbc3
 
72c5156
 
fc4b12d
 
 
 
72c5156
fc4b12d
 
72c5156
 
 
d8c3809
13e7831
fc4b12d
 
 
 
13e7831
2ae65ac
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
import os
import re
import json
from collections import defaultdict
import gradio as gr
from typing import List, Dict, Any, Tuple
# Optional model cache directory, read from the CACHE_DIR environment
# variable; when it is unset, None is passed through to from_pretrained.
_CACHE_DIR = os.environ.get("CACHE_DIR", None)

# GLiNER provides the base model; the same instance is wrapped by the relation
# extractor below and also used directly for NER.
from gliner import GLiNER
# NOTE(review): the relation_extraction import used to live here and was moved
# below the model load; the reason is not visible in this file — confirm
# before hoisting it back into the top-level import block.

# Load the fine-tuned checkpoint (cached under _CACHE_DIR when set) and wrap it
# in the relation extractor. Both are module-level objects reused by every
# highlight_text call.
DATA_MODEL_ID = "rafmacalaba/gliner_re_finetuned-v7-pos"
model = GLiNER.from_pretrained(DATA_MODEL_ID, cache_dir=_CACHE_DIR)
from relation_extraction import CustomGLiNERRelationExtractor
# return_index=True is forwarded to the extractor's constructor; presumably it
# makes predictions carry start/end offsets — TODO confirm in relation_extraction.
relation_extractor = CustomGLiNERRelationExtractor(model=model, return_index=True)

# Example input pre-filled in the "Input Text" box.
SAMPLE_TEXT = (
    "Encuesta Nacional de Hogares (ENAHO) is the Peruvian version of the Living "
    "Standards Measurement Survey, e.g. a nationally representative household "
    "survey collected monthly on a continuous basis. For our analysis, we use "
    "data from January 2007 to December 2020. The survey covers a wide variety "
    "of topics, including basic demographics, educational background, labor "
    "market conditions, crime victimization, and a module on respondent’s "
    "perceptions about the main problems in the country and trust in different "
    "local and national‐level institutions."
)

# Entity types the NER step is asked to detect.
labels = ['named dataset', 'unnamed dataset', 'vague dataset']

# Relation slots queried for every detected dataset mention.
rels = [
    'acronym',
    'author',
    'data description',
    'data geography',
    'data source',
    'data type',
    'publication year',
    'publisher',
    'reference population',
    'reference year',
    'version',
]

# Entity type -> relation slots to query. Every type gets the full relation
# set; all keys share the same `rels` list object.
TYPE2RELS = {entity_type: rels for entity_type in labels}

def inference_pipeline(
    text: str,
    model,
    labels: List[str],
    relation_extractor: "CustomGLiNERRelationExtractor",
    TYPE2RELS: Dict[str, List[str]],
    ner_threshold: float = 0.7,
    rel_threshold: float = 0.5,
    re_multi_label: bool = False,
    return_index: bool = False,
) -> Tuple[List[Dict[str, Any]], Dict[str, List[Dict[str, Any]]]]:
    """Run NER on ``text``, then query relation slots for each typed entity.

    Args:
        text: Raw input text.
        model: Object exposing ``predict_entities(text, labels, flat_ner=...,
            threshold=...)``; its result is returned unchanged as the NER list.
        labels: Entity types passed to ``predict_entities``.
        relation_extractor: Callable taking the text plus keyword arguments
            ``relations``, ``entities``, ``relation_labels``, ``threshold``,
            ``multi_label`` and ``return_index``; element ``[0]`` of its
            result is stored per span.
        TYPE2RELS: Entity label -> relation names to query. Entities whose
            label is missing (or maps to an empty list) are skipped.
        ner_threshold: Confidence threshold for NER.
        rel_threshold: Confidence threshold for relation extraction.
        re_multi_label: Forwarded as ``multi_label`` to the extractor.
        return_index: Forwarded to the extractor.

    Returns:
        ``(ner_preds, re_results)``. ``re_results`` maps an entity's surface
        text to its relation predictions. If the same text is queried with
        different relation sets, the last query wins.
    """
    ner_preds = model.predict_entities(
        text,
        labels,
        flat_ner=True,
        threshold=ner_threshold,
    )

    re_results: Dict[str, List[Dict[str, Any]]] = {}
    # span -> relation set of the query currently stored in re_results[span].
    # The slot labels depend only on (span, relation set). Repeated mentions of
    # the same dataset would therefore re-run an identical (expensive)
    # extractor call and overwrite the result with the same value. This
    # assumes the extractor is deterministic for identical inputs.
    queried: Dict[str, Tuple[str, ...]] = {}
    for ner in ner_preds:
        span = ner['text']
        rel_types = TYPE2RELS.get(ner['label'], [])
        if not rel_types:
            continue

        signature = tuple(rel_types)
        if queried.get(span) == signature:
            continue
        queried[span] = signature

        slot_labels = [f"{span} <> {r}" for r in rel_types]

        re_results[span] = relation_extractor(
            text,
            relations=None,
            entities=None,
            relation_labels=slot_labels,
            threshold=rel_threshold,
            multi_label=re_multi_label,
            return_index=return_index,
        )[0]

    return ner_preds, re_results

def prune_acronym_and_self_relations(ner_preds, rel_preds):
    """Remove acronym duplicates and self-referential relations.

    A relation block keyed by source ``S`` may contain an ``"acronym"``
    relation whose target ``T`` is strictly shorter than ``S``. Each such
    ``T`` is treated as an alias, and:

    * every ``"named dataset"`` entity whose text equals ``T`` is dropped
      from the NER list;
    * the relation block keyed by ``T`` (if present) is dropped.

    In the remaining blocks, relations whose target equals their own source
    are discarded. Blocks left empty are omitted.

    Args:
        ner_preds: List of entity dicts (reads ``"label"`` and ``"text"``).
        rel_preds: Mapping of source text -> list of relation dicts (reads
            ``"relation"`` and ``"target"``).

    Returns:
        ``(filtered_ner, filtered_re)`` as a new list and a new dict. The
        input containers are not modified.
    """
    # Short forms that are acronyms of a longer source.
    short_forms = set()
    for source, relations in rel_preds.items():
        for rel in relations:
            if rel["relation"] == "acronym" and len(rel["target"]) < len(source):
                short_forms.add(rel["target"])

    kept_entities = []
    for ent in ner_preds:
        is_acronym_entity = ent["label"] == "named dataset" and ent["text"] in short_forms
        if not is_acronym_entity:
            kept_entities.append(ent)

    kept_relations = {}
    for source, relations in rel_preds.items():
        if source not in short_forms:
            survivors = list(filter(lambda rel: rel["target"] != source, relations))
            if survivors:
                kept_relations[source] = survivors

    return kept_entities, kept_relations

# Highlighting function

def highlight_text(text, ner_threshold, rel_threshold):
    """Run NER + RE on ``text`` and build the payload for ``gr.HighlightedText``.

    Args:
        text: Raw input text from the textbox.
        ner_threshold: Minimum confidence for NER spans.
        rel_threshold: Minimum confidence for relation spans.

    Returns:
        A 2-tuple ``(highlighted, state)``:

        * ``highlighted``: ``{"text": text, "entities": [...]}``. Each entity
          is ``{"entity": <comma-joined labels>, "start": s, "end": e}``.
          Labels of spans with identical offsets are merged into one entity,
          and entities are sorted by ``(start, end)``.
        * ``state``: ``{"ner": ..., "relations": ...}`` holding the pruned
          predictions, for the raw-JSON view.
    """
    # 1) Inference, then drop acronym duplicates and self-relations.
    ner_preds, rel_preds = inference_pipeline(
        text,
        model=model,
        labels=labels,
        relation_extractor=relation_extractor,
        TYPE2RELS=TYPE2RELS,
        ner_threshold=ner_threshold,
        rel_threshold=rel_threshold,
        re_multi_label=False,
        return_index=True,
    )
    ner_preds, rel_preds = prune_acronym_and_self_relations(ner_preds, rel_preds)

    # 2) Relation offsets are relative to the extractor's prompt-prefixed
    #    input. This prefix must match exactly what the extractor prepends.
    prefix_len = len(f"{relation_extractor.prompt} \n ")
    text_len = len(text)

    # 3) Gather (start, end, label) spans: NER spans first, then relations.
    spans = [(ent["start"], ent["end"], ent["label"]) for ent in ner_preds]
    for rels in rel_preds.values():
        for r in rels:
            # Map the indices back onto the raw text.
            s = r["start"] - prefix_len
            e = r["end"] - prefix_len
            # Skip spans that start inside the prefix, run past the end of the
            # text, or are empty/inverted. None of these can be highlighted.
            if s < 0 or e > text_len or s >= e:
                continue
            spans.append((s, e, f"{r['source']} <> {r['relation']}"))

    # 4) Merge labels that share the exact same offsets, then emit entities.
    merged = defaultdict(list)
    for s, e, lbl in spans:
        merged[(s, e)].append(lbl)

    entities = [
        {"entity": ", ".join(lbls), "start": s, "end": e}
        for (s, e), lbls in sorted(merged.items(), key=lambda item: item[0])
    ]

    return {"text": text, "entities": entities}, {"ner": ner_preds, "relations": rel_preds}

# JSON output function
def _cached_predictions(state):
    """Render the stored predictions as pretty-printed JSON text.

    Args:
        state: The ``{"ner": ..., "relations": ...}`` dict produced by the
            Submit handler, or a falsy value (``None`` / empty) before the
            first run.

    Returns:
        A placeholder message when nothing has been predicted yet. Otherwise,
        the state serialized with 2-space indentation. Non-ASCII characters
        (accented dataset names, curly quotes) are emitted as-is instead of
        being escaped to ASCII escape sequences, so the textbox stays readable.
    """
    if not state:
        return "📋 No predictions yet. Click **Submit** first."
    return json.dumps(state, indent=2, ensure_ascii=False)

# Gradio UI.
# NOTE(review): gr.Blocks presumably lays components out in creation order, so
# the statement order below also defines the page layout — keep it when editing.
with gr.Blocks() as demo:
    # Header and usage notes; the model link is built from DATA_MODEL_ID.
    gr.Markdown(f"""# Data Use Detector
    
    This Space demonstrates our fine-tuned GLiNER model’s ability to spot **dataset mentions** and **relations** in any input text. It identifies dataset names via NER, then extracts relations such as **publisher**, **acronym**, **publication year**, **data geography**, and more.
    
    **How it works**  
    1. **NER**: Recognizes dataset names in your text.  
    2. **RE**: Links each dataset to its attributes (e.g., publisher, year, acronym).  
    3. **Visualization**: Highlights entities and relation spans inline.
    
    **Instructions**  
    1. Paste or edit your text in the box below.  
    2. Tweak the **NER** & **RE** confidence sliders.  
    3. Click **Submit** to see highlights.  
    4. Click **Get Model Predictions** to view the raw JSON output.
    
    **Resources**  
    - **Model:** [{DATA_MODEL_ID}](https://huggingface.co/{DATA_MODEL_ID}) 
    - **Paper:** _Large Language Models and Synthetic Data for Monitoring Dataset Mentions in Research Papers_ – ArXiv: [2502.10263](https://arxiv.org/pdf/2502.10263)  
    - [GLiNER GitHub Repo](https://github.com/urchade/GLiNER)  
    - [Project Docs](https://worldbank.github.io/ai4data-use/docs/introduction.html)  
    """)


    # Input text, pre-filled with SAMPLE_TEXT.
    txt_in = gr.Textbox(
        label="Input Text",
        lines=4,
        value=SAMPLE_TEXT
    )

    # Confidence thresholds in [0, 1]. The defaults (0.7 / 0.5) match the
    # ner_threshold / rel_threshold defaults of inference_pipeline.
    ner_slider = gr.Slider(
        0, 1, value=0.7, step=0.01,
        label="NER Threshold",
        info="Minimum confidence for named-entity spans."
    )
    re_slider = gr.Slider(
        0, 1, value=0.5, step=0.01,
        label="RE Threshold",
        info="Minimum confidence for relation extractions."
    )

    # "Submit" runs highlight_text; the highlighted spans go to txt_out.
    highlight_btn = gr.Button("Submit")
    txt_out = gr.HighlightedText(label="Annotated Entities")

    # "Get Model Predictions" shows whatever the last Submit stored in `state`.
    get_pred_btn = gr.Button("Get Model Predictions")
    json_out      = gr.Textbox(label="Model Predictions (JSON)", lines=15)
    state         = gr.State()
    # Wire up interactions.
    # highlight_text returns (highlighted payload, raw predictions), which are
    # routed to (txt_out, state) respectively.
    highlight_btn.click(
        fn=highlight_text,
        inputs=[txt_in, ner_slider, re_slider],
        outputs=[txt_out, state]
    )
    # _cached_predictions reads `state` and renders it as JSON text.
    get_pred_btn.click(
        fn=_cached_predictions,
        inputs=[state],
        outputs=[json_out]
    )

    # Enable the request queue; default_concurrency_limit=5 is the limit passed
    # to Gradio for concurrent event handling.
    demo.queue(default_concurrency_limit=5)

# Launch the app. debug=True and inline=True are passed straight to
# gr.Blocks.launch.
# NOTE(review): this runs at import time (there is no
# `if __name__ == "__main__":` guard) — confirm that is intended for the
# hosting environment before importing this module elsewhere.

demo.launch(debug=True, inline=True)