"""NER annotation module using GLiNER models."""

from typing import List, Dict, Union, Optional
import torch
import random
from gliner import GLiNER
from ..utils.text_processing import tokenize_text

class AutoAnnotator:
    """A class for automatic NER annotation using GLiNER models."""
    
    def __init__(
        self,
        model: str = "BookingCare/gliner-multi-healthcare",
        device: Optional[torch.device] = None
    ) -> None:
        """Initialize the annotator with a GLiNER model.
        
        Args:
            model: Name or path of the GLiNER model to use
            device: Device to run the model on (CPU/GPU)
        """
        if device is None:
            device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        # Set PyTorch memory management settings
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.set_per_process_memory_fraction(0.8)  # Use 80% of available GPU memory

        self.model = GLiNER.from_pretrained(model).to(device)
        self.annotated_data = []
        self.stat = {
            "total": None,
            "current": -1
        }

    def auto_annotate(
        self,
        data: List[str],
        labels: List[str],
        prompt: Optional[Union[str, List[str]]] = None,
        threshold: float = 0.5,
        nested_ner: bool = False
    ) -> List[Dict]:
        """Annotate a list of texts with NER labels.
        
        Args:
            data: List of texts to annotate
            labels: List of entity labels to detect
            prompt: Optional prompt or list of prompts to use
            threshold: Confidence threshold for entity detection
            nested_ner: Whether to allow nested entities
            
        Returns:
            List of annotated examples
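
        Each returned example follows the _transform_data output shape:
        {"tokenized_text": [...], "ner": [[start, end, label], ...], "validated": False}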
        """
        self.stat["total"] = len(data)
        self.stat["current"] = -1
        
        # Process texts in batches
        processed_data = []
        batch_size = 8  # Reduced batch size to prevent OOM errors
        
        for i in range(0, len(data), batch_size):
            batch_texts = data[i:i + batch_size]
            batch_with_prompts = []
            
            # Add prompts to batch texts
            for text in batch_texts:
                if isinstance(prompt, list):
                    prompt_text = random.choice(prompt)
                else:
                    prompt_text = prompt
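                # Note: the prompt is prepended to the text, so entity character
                # offsets and the stored "text" include the prompt prefix.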
                text_with_prompt = f"{prompt_text}\n{text}" if prompt_text else text
                batch_with_prompts.append(text_with_prompt)
            
            # Process batch
            batch_results = self._batch_annotate_text(
                batch_with_prompts,
                labels,
                threshold,
                nested_ner
            )
            processed_data.extend(batch_results)
            
            # Clear CUDA cache after each batch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            
            # Update progress
            self.stat["current"] = min(i + batch_size, len(data))
        
        self.annotated_data = processed_data
        return self.annotated_data

    def _batch_annotate_text(
        self,
        texts: List[str],
        labels: List[str],
        threshold: float,
        nested_ner: bool
    ) -> List[Dict]:
        """Annotate multiple texts in batch.
        
        Args:
            texts: List of texts to annotate
            labels: List of entity labels
            threshold: Confidence threshold
            nested_ner: Whether to allow nested entities
            
        Returns:
            List of annotated examples
        """
        batch_entities = self.model.batch_predict_entities(
            texts,
            labels,
            flat_ner=not nested_ner,
            threshold=threshold
        )
        
        results = []
        for text, entities in zip(texts, batch_entities):
            r = {
                "text": text,
                "entities": [
                    {
                        "entity": entity["label"],
                        "word": entity["text"],
                        "start": entity["start"],
                        "end": entity["end"],
                        "score": 0,
                    }
                    for entity in entities
                ],
            }
            r["entities"] = self._merge_entities(r["entities"])
            results.append(self._transform_data(r))
        return results

    def _merge_entities(self, entities: List[Dict]) -> List[Dict]:
        """Merge adjacent entities of the same type.
        
        Args:
            entities: List of entity dictionaries
            
        Returns:
            List of merged entities
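
        Example (hypothetical values): two "drug" spans "folic" (chars 0-5)
        and "acid" (chars 6-10) merge into one "folic acid" entity spanning 0-10.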
        """
        if not entities:
            return []
        merged = []
        current = dict(entities[0])  # copy so the caller's dicts are not mutated
        for next_entity in entities[1:]:
            # Merge same-label spans that touch or sit one character apart
            # (typically a space); the joined form assumes a single space.
            if (next_entity['entity'] == current['entity'] and
                (next_entity['start'] == current['end'] + 1 or
                 next_entity['start'] == current['end'])):
                current['word'] += ' ' + next_entity['word']
                current['end'] = next_entity['end']
            else:
                merged.append(current)
                current = dict(next_entity)
        merged.append(current)
        return merged

    def _transform_data(self, data: Dict) -> Dict:
        """Transform raw annotation data into tokenized format.
        
        Args:
            data: Raw annotation data
            
        Returns:
            Transformed data with tokenized text and NER spans
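
        Example (hypothetical): the entity "folic acid" inside the text
        "take folic acid daily" yields the span [1, 2, "drug"].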
        """
        tokens = tokenize_text(data['text'])
        spans = []

        for entity in data['entities']:
            entity_tokens = tokenize_text(entity['word'])
            entity_length = len(entity_tokens)

            # Find the token span of each entity in the tokenized text;
            # only the first match is used, so repeated surface forms all
            # map to the earliest occurrence.
            for i in range(len(tokens) - entity_length + 1):
                if tokens[i:i + entity_length] == entity_tokens:
                    spans.append([i, i + entity_length - 1, entity['entity']])
                    break

        return {
            "tokenized_text": tokens,
            "ner": spans,
            "validated": False
        }
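

# Minimal usage sketch, not part of the module API; the text and labels below
# are hypothetical placeholders. Because this module uses a relative import,
# run it as a module (python -m <package>.annotation) rather than as a script.
if __name__ == "__main__":
    annotator = AutoAnnotator()  # loads the default GLiNER healthcare model
    examples = annotator.auto_annotate(
        data=["The patient was prescribed ibuprofen for knee pain."],
        labels=["drug", "symptom"],
        threshold=0.5,
    )
    for example in examples:
        print(example["tokenized_text"])
        print(example["ner"])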