# Spaces: Running
import json
import os
from typing import Any, Dict, List, Tuple

import gradio as gr
import spaces
import torch
from gliner import GLiNER
from gliner.multitask import GLiNERRelationExtractor
from tqdm.auto import tqdm
# Configuration
# Model identifier handed to `GLiNER.from_pretrained` when the app starts.
data_model_id = "rafmacalaba/gliner_re_finetuned-v3"
# Passed as `cache_dir` when loading the model; None when the env var is unset.
CACHE_DIR = os.environ.get("CACHE_DIR", None)

# Relation types
trels = [
    'acronym', 'author', 'data description',
    'data geography', 'data source', 'data type',
    'publication year', 'publisher', 'reference year', 'version',
]

# Map NER labels to relation types
# NOTE: all three labels reference the same `trels` list object, so an
# in-place mutation through one key is visible through the others.
TYPE2RELS = {
    "named dataset": trels,
    "unnamed dataset": trels,
    "vague dataset": trels,
}
# Load models
# Runs once at import time: a single GLiNER checkpoint is loaded and the same
# instance is handed to the relation-extraction pipeline wrapper.
print("Loading NER+RE model...")
model = GLiNER.from_pretrained(data_model_id, cache_dir=CACHE_DIR)
relation_extractor = GLiNERRelationExtractor(model=model)
if torch.cuda.is_available():
    # Move both handles explicitly; the wrapper was built from `model`, so
    # these presumably refer to the same object (second call kept as-is).
    model.to("cuda")
    relation_extractor.model.to("cuda")
print("Models loaded.")
# Inference pipeline
def inference_pipeline(
    text: str,
    model,
    labels: List[str],
    relation_extractor: GLiNERRelationExtractor,
    TYPE2RELS: Dict[str, List[str]],
    ner_threshold: float = 0.5,
    re_threshold: float = 0.4,
    re_multi_label: bool = False,
) -> Tuple[List[Dict[str, Any]], Dict[str, List[Dict[str, Any]]]]:
    """Run NER on *text*, then relation extraction for each detected span.

    Args:
        text: Input passage. Blank or whitespace-only input short-circuits
            to empty results without calling either model.
        model: Object exposing ``predict_entities(text, labels, flat_ner=...,
            threshold=...)``; each returned item is read via its ``'text'``
            and ``'label'`` keys.
        labels: Entity labels forwarded to ``model.predict_entities``.
        relation_extractor: Callable pipeline invoked once per distinct
            (span, relation-types) pair; the first element of its return
            value is kept.
        TYPE2RELS: Maps an entity label to the relation types to query for
            spans of that label. Labels that are missing or map to an empty
            list are skipped.
        ner_threshold: Forwarded as ``threshold`` to the NER call.
        re_threshold: Forwarded as ``threshold`` to the relation extractor.
        re_multi_label: Forwarded as ``multi_label`` to the relation extractor.

    Returns:
        ``(ner_preds, re_results)``: the NER predictions exactly as returned
        by the model, and a dict keyed by span text holding that span's
        relation predictions. If the same span text occurs more than once,
        the entry for the last occurrence wins.
    """
    # Blank input has nothing to extract; skip both model calls.
    if not text or not text.strip():
        return [], {}

    # NER predictions
    ner_preds = model.predict_entities(
        text,
        labels,
        flat_ner=True,
        threshold=ner_threshold,
    )

    # Relation extraction per entity span
    re_results: Dict[str, List[Dict[str, Any]]] = {}
    # Repeated mentions of the same span would issue an identical extractor
    # call on the same text; reuse the first result instead of re-running it.
    # Assumes the extractor is deterministic for identical inputs.
    seen: Dict[Tuple[str, Tuple[str, ...]], List[Dict[str, Any]]] = {}
    for ner in ner_preds:
        span = ner['text']
        rel_types = TYPE2RELS.get(ner['label'], [])
        if not rel_types:
            continue
        key = (span, tuple(rel_types))
        if key not in seen:
            # One "<span> <> <relation>" prompt label per relation type.
            slot_labels = [f"{span} <> {r}" for r in rel_types]
            seen[key] = relation_extractor(
                text,
                relations=None,
                entities=None,
                relation_labels=slot_labels,
                threshold=re_threshold,
                multi_label=re_multi_label,
                distance_threshold=100,
            )[0]  # single input text -> take the first (only) result
        re_results[span] = seen[key]
    return ner_preds, re_results
# Gradio UI - Step 2: Model Inference
def model_inference(query: str) -> str:
    """Run the NER + RE pipeline on *query* and return the result as JSON text.

    Args:
        query: Raw text from the UI textbox.

    Returns:
        A 2-space-indented JSON string with two keys: ``"entities"`` (the NER
        predictions) and ``"relations"`` (relation predictions keyed by span
        text). Blank input yields empty containers without invoking the models.
    """
    # NOTE(review): `spaces` is imported by this file but never used. If this
    # Space targets ZeroGPU hardware, this handler presumably needs the
    # `@spaces.GPU` decorator - confirm against the deployment settings.

    # Submitting an empty textbox should not trigger model inference.
    if not query or not query.strip():
        return json.dumps({"entities": [], "relations": {}}, indent=2)

    labels = ["named dataset", "unnamed dataset", "vague dataset"]
    ner_preds, re_results = inference_pipeline(
        query,
        model,
        labels,
        relation_extractor,
        TYPE2RELS,
    )
    output = {
        "entities": ner_preds,
        "relations": re_results,
    }
    # ensure_ascii=False keeps non-ASCII entity text readable in the textbox
    # instead of rendering it as \uXXXX escapes.
    return json.dumps(output, indent=2, ensure_ascii=False)
# UI layout: one input textbox, a submit button, and a textbox that shows the
# JSON string returned by `model_inference`.
with gr.Blocks(title="Step 2: NER + Relation Inference") as demo:
    gr.Markdown(
        """
        ## Step 2: Integrate Model Inference
        Enter text and click submit to run your GLiNER-based NER + RE pipeline.
        """
    )
    query_input = gr.Textbox(
        lines=4,
        placeholder="Type your text here...",
        label="Input Text",
    )
    submit_btn = gr.Button("Submit")
    output_box = gr.Textbox(
        lines=15,
        label="Model Output (JSON)",
    )
    # Clicking the button sends the textbox content through the pipeline and
    # writes the returned JSON string into the output box.
    submit_btn.click(
        fn=model_inference,
        inputs=[query_input],
        outputs=[output_box],
    )
# Start the Gradio server only when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch(debug=True)