"""NER annotation module using GLiNER models."""
import random
from typing import Dict, List, Optional, Union

import torch
from gliner import GLiNER

from ..utils.text_processing import tokenize_text
class AutoAnnotator:
"""A class for automatic NER annotation using GLiNER models."""
def __init__(
self,
model: str = "BookingCare/gliner-multi-healthcare",
device: Optional[torch.device] = None
) -> None:
"""Initialize the annotator with a GLiNER model.
Args:
model: Name or path of the GLiNER model to use
device: Device to run the model on (CPU/GPU)
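
        Example (illustrative, forcing CPU execution):
            >>> annotator = AutoAnnotator(device=torch.device("cpu"))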
"""
if device is None:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Set PyTorch memory management settings
if torch.cuda.is_available():
torch.cuda.empty_cache()
            torch.cuda.set_per_process_memory_fraction(0.8)  # Cap this process at 80% of total GPU memory
self.model = GLiNER.from_pretrained(model).to(device)
self.annotated_data = []
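        # Progress counters exposed to callers: "total" texts queued and
        # "current" texts processed so far (-1 means annotation has not run)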
self.stat = {
"total": None,
"current": -1
}
def auto_annotate(
self,
data: List[str],
labels: List[str],
prompt: Optional[Union[str, List[str]]] = None,
threshold: float = 0.5,
nested_ner: bool = False
) -> List[Dict]:
"""Annotate a list of texts with NER labels.
Args:
data: List of texts to annotate
labels: List of entity labels to detect
prompt: Optional prompt or list of prompts to use
threshold: Confidence threshold for entity detection
nested_ner: Whether to allow nested entities
Returns:
List of annotated examples
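
        Example (hypothetical texts and labels):
            >>> annotator = AutoAnnotator()
            >>> annotator.auto_annotate(
            ...     data=["Aspirin can relieve mild headaches."],
            ...     labels=["drug", "symptom"],
            ... )
            [{'tokenized_text': [...], 'ner': [...], 'validated': False}]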
"""
self.stat["total"] = len(data)
self.stat["current"] = -1
# Process texts in batches
processed_data = []
batch_size = 8 # Reduced batch size to prevent OOM errors
for i in range(0, len(data), batch_size):
batch_texts = data[i:i + batch_size]
batch_with_prompts = []
# Add prompts to batch texts
for text in batch_texts:
if isinstance(prompt, list):
prompt_text = random.choice(prompt)
else:
prompt_text = prompt
text_with_prompt = f"{prompt_text}\n{text}" if prompt_text else text
batch_with_prompts.append(text_with_prompt)
# Process batch
batch_results = self._batch_annotate_text(
batch_with_prompts,
labels,
threshold,
nested_ner
)
processed_data.extend(batch_results)
# Clear CUDA cache after each batch
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Update progress
self.stat["current"] = min(i + batch_size, len(data))
self.annotated_data = processed_data
return self.annotated_data
def _batch_annotate_text(
self,
texts: List[str],
labels: List[str],
threshold: float,
nested_ner: bool
) -> List[Dict]:
"""Annotate multiple texts in batch.
Args:
texts: List of texts to annotate
labels: List of entity labels
threshold: Confidence threshold
nested_ner: Whether to allow nested entities
Returns:
List of annotated examples
"""
batch_entities = self.model.batch_predict_entities(
texts,
labels,
flat_ner=not nested_ner,
threshold=threshold
)
results = []
for text, entities in zip(texts, batch_entities):
r = {
"text": text,
"entities": [
{
"entity": entity["label"],
"word": entity["text"],
"start": entity["start"],
"end": entity["end"],
"score": 0,
}
for entity in entities
],
}
r["entities"] = self._merge_entities(r["entities"])
results.append(self._transform_data(r))
return results
def _merge_entities(self, entities: List[Dict]) -> List[Dict]:
"""Merge adjacent entities of the same type.
Args:
entities: List of entity dictionaries
Returns:
List of merged entities
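
        Example (illustrative character spans):
            [{"entity": "drug", "word": "acetylsalicylic", "start": 0, "end": 15},
             {"entity": "drug", "word": "acid", "start": 16, "end": 20}]
            merges into:
            [{"entity": "drug", "word": "acetylsalicylic acid", "start": 0, "end": 20}]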
"""
        if not entities:
            return []
        merged = []
        # Copy so merging never mutates the caller's entity dicts
        current = dict(entities[0])
        for next_entity in entities[1:]:
            if (next_entity['entity'] == current['entity'] and
                    next_entity['start'] in (current['end'], current['end'] + 1)):
                # Join with a space only when the spans are one character
                # apart; directly adjacent spans are concatenated as-is
                separator = ' ' if next_entity['start'] > current['end'] else ''
                current['word'] += separator + next_entity['word']
                current['end'] = next_entity['end']
            else:
                merged.append(current)
                current = dict(next_entity)
        merged.append(current)
        return merged
def _transform_data(self, data: Dict) -> Dict:
"""Transform raw annotation data into tokenized format.
Args:
data: Raw annotation data
Returns:
Transformed data with tokenized text and NER spans
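
        Example (illustrative):
            {"text": "Aspirin helps", "entities": [{"entity": "drug",
            "word": "Aspirin", ...}]} is transformed into
            {"tokenized_text": ["Aspirin", "helps"],
             "ner": [[0, 0, "drug"]], "validated": False}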
"""
tokens = tokenize_text(data['text'])
        spans = []
        # Entities arrive in document order, so each search resumes where the
        # previous match ended; this keeps repeated surface forms from all
        # mapping onto the first occurrence
        search_from = 0
        for entity in data['entities']:
            entity_tokens = tokenize_text(entity['word'])
            entity_length = len(entity_tokens)
            # Find the token span whose tokens match the entity's tokens
            for i in range(search_from, len(tokens) - entity_length + 1):
                if tokens[i:i + entity_length] == entity_tokens:
                    spans.append([i, i + entity_length - 1, entity['entity']])
                    search_from = i + entity_length
                    break
return {
"tokenized_text": tokens,
"ner": spans,
"validated": False
}
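

if __name__ == "__main__":
    # Minimal usage sketch with hypothetical texts and labels; assumes the
    # default BookingCare/gliner-multi-healthcare checkpoint can be loaded
    annotator = AutoAnnotator()
    annotated = annotator.auto_annotate(
        data=["Aspirin can relieve mild headaches and fever."],
        labels=["drug", "symptom"],
        threshold=0.5,
    )
    for example in annotated:
        print(example["tokenized_text"])
        print(example["ner"])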