import os
import json
import gradio as gr
import torch
import spaces
from gliner import GLiNER
from gliner.multitask import GLiNERRelationExtractor
from typing import List, Dict, Any, Tuple
from tqdm.auto import tqdm
# Configuration
data_model_id = "rafmacalaba/gliner_re_finetuned-v3"
CACHE_DIR = os.environ.get("CACHE_DIR", None)
# Relation types
trels = [
    'acronym', 'author', 'data description',
    'data geography', 'data source', 'data type',
    'publication year', 'publisher', 'reference year', 'version',
]
# Map NER labels to relation types
TYPE2RELS = {
    "named dataset": trels,
    "unnamed dataset": trels,
    "vague dataset": trels,
}
# Load models
print("Loading NER+RE model...")
model = GLiNER.from_pretrained(data_model_id, cache_dir=CACHE_DIR)
relation_extractor = GLiNERRelationExtractor(model=model)
if torch.cuda.is_available():
    model.to("cuda")
    relation_extractor.model.to("cuda")
print("Models loaded.")
# Inference pipeline
def inference_pipeline(
    text: str,
    model: GLiNER,
    labels: List[str],
    relation_extractor: GLiNERRelationExtractor,
    TYPE2RELS: Dict[str, List[str]],
    ner_threshold: float = 0.5,
    re_threshold: float = 0.4,
    re_multi_label: bool = False,
) -> Tuple[List[Dict[str, Any]], Dict[str, List[Dict[str, Any]]]]:
    """Run NER over `text`, then extract relations for each predicted entity span.

    Returns the raw NER predictions and a mapping from entity span text to its
    relation-extraction predictions.
    """
    # NER predictions
    ner_preds = model.predict_entities(
        text,
        labels,
        flat_ner=True,
        threshold=ner_threshold,
    )

    # Relation extraction per entity span
    re_results: Dict[str, List[Dict[str, Any]]] = {}
    for ner in ner_preds:
        span = ner['text']
        rel_types = TYPE2RELS.get(ner['label'], [])
        if not rel_types:
            continue
        # One candidate label per relation type, formatted as "<entity span> <> <relation>"
        slot_labels = [f"{span} <> {r}" for r in rel_types]
        preds = relation_extractor(
            text,
            relations=None,
            entities=None,
            relation_labels=slot_labels,
            threshold=re_threshold,
            multi_label=re_multi_label,
            distance_threshold=100,
        )[0]
        re_results[span] = preds

    return ner_preds, re_results
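
# Illustrative usage sketch (kept commented out; not part of the Space's flow):
# calling the pipeline directly with the same labels the Gradio handler below uses.
# The sample sentence and the sketched output fields are assumptions; the actual
# prediction contents depend on the fine-tuned checkpoint.
#
#   ents, rels = inference_pipeline(
#       "The analysis draws on the 2018 Demographic and Health Survey.",
#       model,
#       ["named dataset", "unnamed dataset", "vague dataset"],
#       relation_extractor,
#       TYPE2RELS,
#   )
#   # ents: list of entity prediction dicts (span text, label, score, offsets)
#   # rels: {entity span text: list of relation predictions}
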
# Gradio UI - Step 2: Model Inference
@spaces.GPU(enable_queue=True, duration=120)
def model_inference(query: str) -> str:
    labels = ["named dataset", "unnamed dataset", "vague dataset"]
    ner_preds, re_results = inference_pipeline(
        query,
        model,
        labels,
        relation_extractor,
        TYPE2RELS,
    )
    output = {
        "entities": ner_preds,
        "relations": re_results,
    }
    return json.dumps(output, indent=2)
with gr.Blocks(title="Step 2: NER + Relation Inference") as demo:
    gr.Markdown(
        """
        ## Step 2: Integrate Model Inference
        Enter text and click Submit to run your GLiNER-based NER + RE pipeline.
        """
    )
    query_input = gr.Textbox(
        lines=4,
        placeholder="Type your text here...",
        label="Input Text",
    )
    submit_btn = gr.Button("Submit")
    output_box = gr.Textbox(
        lines=15,
        label="Model Output (JSON)",
    )
    submit_btn.click(
        fn=model_inference,
        inputs=[query_input],
        outputs=[output_box],
    )
if __name__ == "__main__":
    demo.launch(debug=True)