rafmacalaba committed on
Commit
fc4b12d
·
1 Parent(s): d799589

deploy without gpu

Browse files
Files changed (1) hide show
  1. app.py +147 -61
app.py CHANGED
@@ -1,84 +1,170 @@
 
1
  import re
2
  import json
 
3
  import gradio as gr
4
 
5
- # Your model’s raw NER output (we trust these start/end indices)
6
- ner = [
7
- {
8
- 'start': 11,
9
- 'end': 29,
10
- 'text': 'Home Visits Survey',
11
- 'label': 'named dataset',
12
- 'score': 0.9947463870048523
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  }
14
- ]
15
-
16
- # Your model’s raw RE output
17
- relations = {
18
- 'Home Visits Survey': [
19
- {'source': 'Home Visits Survey', 'relation': 'data geography', 'target': 'Jordan', 'score': 0.6180844902992249},
20
- {'source': 'Home Visits Survey', 'relation': 'version', 'target': 'Round II', 'score': 0.9688164591789246},
21
- {'source': 'Home Visits Survey', 'relation': 'acronym', 'target': 'HV', 'score': 0.9140607714653015},
22
- {'source': 'Home Visits Survey', 'relation': 'publisher', 'target': 'UNHCR', 'score': 0.7762154340744019},
23
- {'source': 'Home Visits Survey', 'relation': 'publisher', 'target': 'World Food Programme', 'score': 0.6582539677619934},
24
- {'source': 'Home Visits Survey', 'relation': 'reference year', 'target': '2013', 'score': 0.524115264415741},
25
- {'source': 'Home Visits Survey', 'relation': 'reference year', 'target': '2014', 'score': 0.6853994131088257},
26
- {'source': 'Home Visits Survey', 'relation': 'data description', 'target': 'detailed socio-economic, health, and protection data', 'score': 0.6544178128242493},
27
  ]
28
- }
29
 
30
- # 1) Non-destructive: build a new filtered dict
31
- relations = {
32
- src: [r for r in rels if r['source'] != r['target']]
33
- for src, rels in relations.items()
34
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- # Exact sample text
37
- SAMPLE_TEXT = """The Jordan Home Visits Survey, Round II (HV), was carried out by UNHCR and the World Food Programme between November 2013 and September 2014. Through in-home visits to Syrian refugee households in Jordan, it gathered detailed socio-economic, health, and protection data—each household tagged with a unique ID to allow longitudinal tracking."""
 
 
38
 
39
- def highlight_text(text):
40
  entities = []
41
- # 1) NER spans
42
- for ent in ner:
43
  entities.append({
44
- "entity": ent["label"],
45
- "start": ent["start"],
46
- "end": ent["end"],
47
  })
48
- # 2) RE spans
49
- for rel_list in relations.values():
50
- for r in rel_list:
51
- for m in re.finditer(re.escape(r["target"]), text):
52
- entities.append({
53
- "entity": r["relation"],
54
- "start": m.start(),
55
- "end": m.end(),
56
- })
57
  return {"text": text, "entities": entities}
58
 
59
- def get_model_predictions():
60
- return json.dumps({"ner": ner, "relations": relations}, indent=2)
61
 
62
- with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  gr.Markdown("## Data Use Detector\n"
64
- "Edit the sample text or input your desired text, then click **Submit** to annotate entities, or **Get Model Predictions** to see the raw JSON.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
- txt_in = gr.Textbox(label="Input Text", lines=4, value=SAMPLE_TEXT)
67
  highlight_btn = gr.Button("Submit")
68
- txt_out = gr.HighlightedText(label="Annotated Entities")
69
 
70
  get_pred_btn = gr.Button("Get Model Predictions")
71
- ner_rel_box = gr.Textbox(
72
- label="Model Predictions (JSON)",
73
- lines=15,
74
- value="",
75
- interactive=False
 
 
 
 
 
 
 
76
  )
77
 
78
- # Only trigger highlighting on click
79
- highlight_btn.click(fn=highlight_text, inputs=txt_in, outputs=txt_out)
80
- # Only show preds on click
81
- get_pred_btn.click(fn=get_model_predictions, inputs=None, outputs=ner_rel_box)
82
 
83
- if __name__ == "__main__":
84
- demo.launch()
 
1
+ import os
2
  import re
3
  import json
4
+ from collections import defaultdict
5
  import gradio as gr
6
 
7
+ # Load environment variable for cache dir (useful on Spaces)
8
+ _CACHE_DIR = os.environ.get("CACHE_DIR", None)
9
+
10
+ # Import GLiNER model and relation extractor
11
+ from gliner import GLiNER
12
+ from gliner.multitask import GLiNERRelationExtractor
13
+
14
+ # Import inference pipeline and configuration
15
+ from my_project.pipeline import inference_pipeline
16
+ from my_project.config import TYPE2RELS, labels
17
+
18
+ # Cache and initialize model + relation extractor
19
+ DATA_MODEL_ID = "rafmacalaba/gliner_re_finetuned-v3"
20
+ model = GLiNER.from_pretrained(DATA_MODEL_ID, cache_dir=_CACHE_DIR)
21
+ relation_extractor = GLiNERRelationExtractor(model=model)
22
+
23
+ # Sample text
24
+ SAMPLE_TEXT = (
25
+ "In early 2012, the World Bank published the full report of the 2011 Demographic and Health Survey (DHS) "
26
+ "for the Republic of Mali. Conducted between June and December 2011 under the technical oversight of Mali’s "
27
+ "National Institute of Statistics and paired with on-the-ground data-collection teams, this nationally representative survey "
28
+ "gathered detailed information on household composition, education levels, employment and income, fertility and family planning, "
29
+ "maternal and child health, nutrition, mortality, and access to basic services. By combining traditional census modules with "
30
+ "specialized questionnaires on women’s and children’s health, the DHS offers policymakers, development partners, and researchers "
31
+ "a rich dataset of socioeconomic characteristics—ranging from literacy and school attendance to water and sanitation infrastructure—"
32
+ "that can be used to monitor progress on poverty reduction, inform targeted social programs, and guide longer-term economic planning."
33
+ )
34
+
35
+ # Post-processing: prune acronyms and self-relations
36
+
37
def prune_acronym_and_self_relations(ner_preds, rel_preds):
    """Drop acronym-duplicate dataset entities and self-referencing relations.

    Args:
        ner_preds: list of entity dicts with at least ``text`` and ``label`` keys.
        rel_preds: mapping of source string -> list of relation dicts with
            ``relation`` and ``target`` keys.

    Returns:
        A ``(filtered_ner, filtered_re)`` pair; the inputs are not mutated.
    """
    # Step 1: collect every acronym target that is strictly shorter than
    # its source — these are duplicates of the full dataset name.
    short_acronyms = set()
    for source, relation_list in rel_preds.items():
        for rel in relation_list:
            if rel["relation"] == "acronym" and len(rel["target"]) < len(source):
                short_acronyms.add(rel["target"])

    # Step 2: keep NER spans unless they are a named dataset whose text
    # matches one of the collected acronyms.
    kept_entities = []
    for entity in ner_preds:
        is_acronym_duplicate = (
            entity["label"] == "named dataset"
            and entity["text"] in short_acronyms
        )
        if not is_acronym_duplicate:
            kept_entities.append(entity)

    # Step 3: rebuild the relation map — skip acronym sources entirely,
    # drop self-relations, and omit sources left with no relations.
    pruned_relations = {}
    for source, relation_list in rel_preds.items():
        if source in short_acronyms:
            continue
        surviving = [rel for rel in relation_list if rel["target"] != source]
        if surviving:
            pruned_relations[source] = surviving

    return kept_entities, pruned_relations
62
+
63
+ # Highlighting function
64
+
65
def highlight_text(text, ner_threshold, re_threshold):
    """Annotate *text* for ``gr.HighlightedText`` using NER + RE predictions.

    Returns a dict in Gradio's ``{"text": ..., "entities": [...]}`` format,
    where overlapping labels on an identical span are joined with ", ".
    """
    # Run the full inference pipeline at the requested confidence thresholds.
    ner_preds, rel_preds = inference_pipeline(
        text,
        model=model,
        labels=labels,
        relation_extractor=relation_extractor,
        TYPE2RELS=TYPE2RELS,
        ner_threshold=ner_threshold,
        re_threshold=re_threshold,
        re_multi_label=False
    )

    # Remove acronym duplicates and self-relations before rendering.
    ner_preds, rel_preds = prune_acronym_and_self_relations(ner_preds, rel_preds)

    # Collect (start, end, label) spans from the NER predictions...
    spans = [(ent["start"], ent["end"], ent["label"]) for ent in ner_preds]

    # ...and one span per occurrence of each relation target in the text.
    for src, rels in rel_preds.items():
        for r in rels:
            spans.extend(
                (m.start(), m.end(), f"{src} <> {r['relation']}")
                for m in re.finditer(re.escape(r["target"]), text)
            )

    # Group all labels that share the exact same (start, end) span.
    merged = defaultdict(list)
    for start, end, span_label in spans:
        merged[(start, end)].append(span_label)

    # Emit one Gradio entity per merged span, ordered by position.
    entities = [
        {"entity": ", ".join(span_labels), "start": start, "end": end}
        for (start, end), span_labels in sorted(merged.items(), key=lambda item: item[0])
    ]
    return {"text": text, "entities": entities}
105
 
106
+ # JSON output function
 
107
 
108
def get_model_predictions(text, ner_threshold, re_threshold):
    """Return the pruned NER + RE predictions for *text* as pretty JSON."""
    # Same pipeline call as highlight_text, but the raw predictions are
    # serialized instead of rendered as highlighted spans.
    raw_ner, raw_rels = inference_pipeline(
        text,
        model=model,
        labels=labels,
        relation_extractor=relation_extractor,
        TYPE2RELS=TYPE2RELS,
        ner_threshold=ner_threshold,
        re_threshold=re_threshold,
        re_multi_label=False
    )
    clean_ner, clean_rels = prune_acronym_and_self_relations(raw_ner, raw_rels)
    payload = {"ner": clean_ner, "relations": clean_rels}
    return json.dumps(payload, indent=2)
121
+
122
# Build Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## Data Use Detector\n"
                "Adjust the sliders below to set thresholds, then:\n"
                "- **Submit** to highlight entities.\n"
                "- **Get Model Predictions** to see the raw JSON output.")

    # Input text box, pre-filled with the sample passage.
    txt_in = gr.Textbox(
        label="Input Text",
        lines=4,
        value=SAMPLE_TEXT
    )

    # Confidence thresholds for the two model stages.
    ner_slider = gr.Slider(
        0, 1, value=0.7, step=0.01,
        label="NER Threshold",
        info="Minimum confidence for named-entity spans."
    )
    re_slider = gr.Slider(
        0, 1, value=0.5, step=0.01,
        label="RE Threshold",
        info="Minimum confidence for relation extractions."
    )

    highlight_btn = gr.Button("Submit")
    txt_out = gr.HighlightedText(label="Annotated Entities")

    get_pred_btn = gr.Button("Get Model Predictions")
    ner_rel_box = gr.Textbox(label="Model Predictions (JSON)", lines=15)

    # Wire up interactions
    highlight_btn.click(
        fn=highlight_text,
        inputs=[txt_in, ner_slider, re_slider],
        outputs=txt_out
    )
    get_pred_btn.click(
        fn=get_model_predictions,
        inputs=[txt_in, ner_slider, re_slider],
        outputs=ner_rel_box
    )

# Enable queue for concurrency
demo.queue(default_concurrency_limit=5)

# Guard the launch so importing this module (e.g. from tests or another
# entry point) does not start the web server as a side effect.
if __name__ == "__main__":
    demo.launch(debug=True)