"""NER annotation module using GLiNER models.""" | |
from typing import List, Dict, Union, Optional | |
import torch | |
import random | |
from gliner import GLiNER | |
from ..utils.text_processing import tokenize_text | |

class AutoAnnotator:
    """A class for automatic NER annotation using GLiNER models."""

    def __init__(
        self,
        model: str = "BookingCare/gliner-multi-healthcare",
        device: Optional[torch.device] = None,
    ) -> None:
        """Initialize the annotator with a GLiNER model.

        Args:
            model: Name or path of the GLiNER model to use
            device: Device to run the model on (CPU/GPU)
        """
        if device is None:
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # Set PyTorch memory management settings
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            # Cap this process at 80% of the available GPU memory
            torch.cuda.set_per_process_memory_fraction(0.8)

        self.model = GLiNER.from_pretrained(model).to(device)
        self.annotated_data = []
        self.stat = {
            "total": None,
            "current": -1,
        }

    def auto_annotate(
        self,
        data: List[str],
        labels: List[str],
        prompt: Optional[Union[str, List[str]]] = None,
        threshold: float = 0.5,
        nested_ner: bool = False,
    ) -> List[Dict]:
        """Annotate a list of texts with NER labels.

        Args:
            data: List of texts to annotate
            labels: List of entity labels to detect
            prompt: Optional prompt, or a list of prompts to sample from at random
            threshold: Confidence threshold for entity detection
            nested_ner: Whether to allow nested entities

        Returns:
            List of annotated examples
        """
        self.stat["total"] = len(data)
        self.stat["current"] = -1

        # Process texts in batches
        processed_data = []
        batch_size = 8  # Kept small to prevent OOM errors on the GPU

        for i in range(0, len(data), batch_size):
            batch_texts = data[i:i + batch_size]
            batch_with_prompts = []

            # Prepend a prompt to each text. When a list of prompts is given,
            # one is sampled at random per text. Note that the prompt becomes
            # part of the annotated text, so entity offsets include it.
            for text in batch_texts:
                if isinstance(prompt, list):
                    prompt_text = random.choice(prompt)
                else:
                    prompt_text = prompt
                text_with_prompt = f"{prompt_text}\n{text}" if prompt_text else text
                batch_with_prompts.append(text_with_prompt)

            # Process the batch
            batch_results = self._batch_annotate_text(
                batch_with_prompts,
                labels,
                threshold,
                nested_ner,
            )
            processed_data.extend(batch_results)

            # Clear the CUDA cache after each batch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Update progress
            self.stat["current"] = min(i + batch_size, len(data))

        self.annotated_data = processed_data
        return self.annotated_data
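
    # `self.stat` is updated as each batch completes, so a caller (e.g. a UI
    # progress bar; an assumed usage pattern, not something this module
    # enforces) can poll `stat["current"] / stat["total"]` during a run.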

    def _batch_annotate_text(
        self,
        texts: List[str],
        labels: List[str],
        threshold: float,
        nested_ner: bool,
    ) -> List[Dict]:
        """Annotate multiple texts in batch.

        Args:
            texts: List of texts to annotate
            labels: List of entity labels
            threshold: Confidence threshold
            nested_ner: Whether to allow nested entities

        Returns:
            List of annotated examples
        """
        batch_entities = self.model.batch_predict_entities(
            texts,
            labels,
            flat_ner=not nested_ner,
            threshold=threshold,
        )

        results = []
        for text, entities in zip(texts, batch_entities):
            r = {
                "text": text,
                "entities": [
                    {
                        "entity": entity["label"],
                        "word": entity["text"],
                        "start": entity["start"],
                        "end": entity["end"],
                        # GLiNER's confidence score; fall back to 0.0 if absent
                        "score": entity.get("score", 0.0),
                    }
                    for entity in entities
                ],
            }
            r["entities"] = self._merge_entities(r["entities"])
            results.append(self._transform_data(r))
        return results

    def _merge_entities(self, entities: List[Dict]) -> List[Dict]:
        """Merge adjacent entities of the same type.

        Args:
            entities: List of entity dictionaries, ordered by start offset

        Returns:
            List of merged entities
        """
        if not entities:
            return []

        merged = []
        current = entities[0]
        for next_entity in entities[1:]:
            # Merge when the next entity has the same label and starts at,
            # or immediately after, the end of the current one.
            if (next_entity["entity"] == current["entity"]
                    and next_entity["start"] in (current["end"], current["end"] + 1)):
                current["word"] += " " + next_entity["word"]
                current["end"] = next_entity["end"]
            else:
                merged.append(current)
                current = next_entity
        merged.append(current)
        return merged
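
    # Illustration of the merge step (hypothetical values, not from a real
    # model run): two adjacent spans with the same label, such as
    #   {"entity": "symptom", "word": "sore",   "start": 0, "end": 4, ...}
    #   {"entity": "symptom", "word": "throat", "start": 5, "end": 11, ...}
    # collapse into a single span
    #   {"entity": "symptom", "word": "sore throat", "start": 0, "end": 11, ...}.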

    def _transform_data(self, data: Dict) -> Dict:
        """Transform raw annotation data into tokenized format.

        Args:
            data: Raw annotation data

        Returns:
            Transformed data with tokenized text and NER spans
        """
        tokens = tokenize_text(data["text"])
        spans = []
        for entity in data["entities"]:
            entity_tokens = tokenize_text(entity["word"])
            entity_length = len(entity_tokens)

            # Find the token-level start and end indices of the entity by
            # scanning for the first matching token window. A repeated entity
            # string therefore always maps to its first occurrence in the text.
            for i in range(len(tokens) - entity_length + 1):
                if tokens[i:i + entity_length] == entity_tokens:
                    spans.append([i, i + entity_length - 1, entity["entity"]])
                    break

        return {
            "tokenized_text": tokens,
            "ner": spans,
            "validated": False,
        }
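

# Minimal usage sketch (illustrative only: the labels and example text are
# assumptions, and instantiation downloads the GLiNER weights on first use):
#
#     annotator = AutoAnnotator()
#     results = annotator.auto_annotate(
#         data=["The patient reported a sore throat and mild fever."],
#         labels=["symptom", "disease", "drug"],
#         threshold=0.5,
#     )
#     # Each item: {"tokenized_text": [...],
#     #             "ner": [[start_token, end_token, label], ...],
#     #             "validated": False}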