File size: 5,285 Bytes
7d68ade
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from re import compile as re_compile,findall as re_findall
from pandas import DataFrame
from torch import device as Device , load as torch_load,no_grad as torch_no_grad,max as torch_max
from torch.utils.data import Dataset,DataLoader
from transformers import AutoTokenizer
from fasttext import load_model as fasttext_load_model

class IndicBERT_Data(Dataset):
    """Minimal torch Dataset that pairs each text sample with its original
    position, so batched BERT predictions can be mapped back to their inputs."""

    def __init__(self, indices, X):
        self.i = indices
        self.x = X
        self.size = len(X)

    def __len__(self):
        # Number of (index, text) pairs available.
        return self.size

    def __getitem__(self, idx):
        # Yield the original position alongside the raw text.
        index, text = self.i[idx], self.x[idx]
        return (index, text)

class IndicLID():
    """Language identifier for Indic text.

    Routes each input through one of three models:
      - IndicLID-FTN (fastText) for native-script text,
      - IndicLID-FTR (fastText) for romanized text, and
      - IndicLID-BERT for romanized text the fastText model is unsure about
        (FTR confidence <= ``roman_lid_threshold``).

    Predictions are 4-tuples: (text, lang_code, score, model_name).
    """

    # Compiled once instead of on every char_percent_check call.
    # NOTE(review): inside a character class `\|` matches '|', not '\';
    # the pattern is kept byte-identical (now as a raw string, which
    # silences the invalid-escape DeprecationWarning) to preserve behavior.
    _SPECIAL_CHAR_RE = re_compile(r'[@_!#$%^&*()<>?/\|}{~:]')
    # Characters that count as "romanized" (ASCII letters and digits).
    _ROMAN_CHAR_RE = re_compile(r'[a-zA-Z0-9]')
    # Any whitespace (space, tab, newline, ...) is excluded from the total.
    _WHITESPACE_RE = re_compile(r'\s')

    def __init__(self, bert_Path: str, ftr_path: str, ftn_path: str, langs: list[str], device: Device, input_threshold: float = 0.5, roman_lid_threshold: float = 0.6):
        """Load the three models and build the label <-> index mappings.

        bert_Path: path to a pickled fine-tuned IndicBERT classifier.
        ftr_path / ftn_path: fastText models for roman / native script.
        langs: ordered language codes; position defines the BERT class index.
        device: torch device the BERT model and batches are moved to.
        input_threshold: roman-character fraction above which input is
            treated as romanized rather than native script.
        roman_lid_threshold: FTR confidence below which romanized input is
            escalated to the BERT model.
        """
        self.device = device
        self.IndicLID_FTN = fasttext_load_model(ftn_path)
        self.IndicLID_FTR = fasttext_load_model(ftr_path)
        # weights_only=False: the checkpoint stores a full pickled model
        # object, not just a state dict. Only load trusted checkpoints.
        self.IndicLID_BERT = torch_load(bert_Path, map_location=self.device, weights_only=False)
        self.IndicLID_BERT.eval()
        self.IndicLID_BERT_tokenizer = AutoTokenizer.from_pretrained("ai4bharat/IndicBERTv2-MLM-only")
        self.input_threshold = input_threshold
        self.model_threshold = roman_lid_threshold
        # Number of language classes the BERT head predicts.
        self.classes = 47
        self.IndicLID_lang_code_dict = {cont: ind for ind, cont in enumerate(langs)}
        self.IndicLID_lang_code_dict_reverse = {ind: cont for ind, cont in enumerate(langs)}

    def char_percent_check(self, input):
        """Return the fraction of countable characters (non-special,
        non-whitespace) that are roman letters/digits; 0 if none.

        Bug fix: the original counted newlines twice (once via ``\\s`` and
        once via an explicit ``\\n`` scan), which could make the denominator
        negative and return a bogus negative fraction for newline-heavy input.
        """
        total_chars = len(input) - (
            len(self._SPECIAL_CHAR_RE.findall(input))
            + len(self._WHITESPACE_RE.findall(input))
        )
        if total_chars <= 0:
            return 0
        return len(self._ROMAN_CHAR_RE.findall(input)) / total_chars

    def native_inference(self, input_list, output_dict):
        """Classify native-script items with the FTN fastText model.

        input_list: list of (original_index, text) pairs.
        Returns output_dict with an entry per input index.
        """
        if not input_list:
            return output_dict
        input_texts = [line[1] for line in input_list]
        IndicLID_FTN_predictions = self.IndicLID_FTN.predict(input_texts)
        # [9:] strips fastText's '__label__' prefix from the predicted label.
        for input, pred_label, pred_score in zip(input_list, IndicLID_FTN_predictions[0], IndicLID_FTN_predictions[1]):
            output_dict[input[0]] = (input[1], pred_label[0][9:], pred_score[0], 'IndicLID-FTN')
        return output_dict

    def roman_inference(self, input_list, output_dict, batch_size):
        """Classify romanized items with FTR; escalate low-confidence ones
        (score <= model_threshold) to the BERT model.
        """
        if not input_list:
            return output_dict
        input_texts = [line[1] for line in input_list]
        IndicLID_FTR_predictions = self.IndicLID_FTR.predict(input_texts)
        IndicLID_BERT_inputs = []
        for input, pred_label, pred_score in zip(input_list, IndicLID_FTR_predictions[0], IndicLID_FTR_predictions[1]):
            if pred_score[0] > self.model_threshold:
                # Confident fastText prediction: accept it directly.
                output_dict[input[0]] = (input[1], pred_label[0][9:], pred_score[0], 'IndicLID-FTR')
            else:
                IndicLID_BERT_inputs.append(input)
        return self.IndicBERT_roman_inference(IndicLID_BERT_inputs, output_dict, batch_size)

    def IndicBERT_roman_inference(self, IndicLID_BERT_inputs, output_dict, batch_size):
        """Classify remaining romanized items with the IndicBERT classifier.

        Note: the reported score is the raw logit of the argmax class, not a
        probability — it is not comparable to the fastText scores.
        """
        if not IndicLID_BERT_inputs:
            return output_dict
        df = DataFrame(IndicLID_BERT_inputs)
        dataloader = self.get_dataloaders(df.iloc[:, 0], df.iloc[:, 1], batch_size)
        with torch_no_grad():
            for data in dataloader:
                batch_indices = data[0]
                batch_inputs = data[1]
                word_embeddings = self.IndicLID_BERT_tokenizer(batch_inputs, return_tensors="pt", padding=True, truncation=True, max_length=512)
                word_embeddings = word_embeddings.to(self.device)
                batch_outputs = self.IndicLID_BERT(word_embeddings['input_ids'], token_type_ids=word_embeddings['token_type_ids'], attention_mask=word_embeddings['attention_mask'])
                _, batch_predicted = torch_max(batch_outputs.logits, 1)
                for index, input, pred_label, logit in zip(batch_indices, batch_inputs, batch_predicted, batch_outputs.logits):
                    output_dict[index] = (input, self.IndicLID_lang_code_dict_reverse[pred_label.item()], logit[pred_label.item()].item(), 'IndicLID-BERT')
        return output_dict

    def post_process(self, output_dict: dict):
        """Return predictions ordered by their original input index."""
        return [output_dict[index] for index in sorted(output_dict)]

    def get_dataloaders(self, indices, input_texts, batch_size):
        """Wrap (indices, texts) in a non-shuffling DataLoader for BERT."""
        return DataLoader(IndicBERT_Data(indices, input_texts), batch_size=batch_size, shuffle=False)

    def predict(self, input):
        """Predict the language of a single string; returns one 4-tuple."""
        return self.batch_predict([input], 1)[0]

    def batch_predict(self, input_list, batch_size):
        """Predict languages for a list of strings, preserving input order.

        Inputs whose roman-character fraction exceeds input_threshold go to
        the roman pipeline; the rest go to the native-script pipeline.
        """
        output_dict = {}
        roman_inputs = []
        native_inputs = []
        for index, input in enumerate(input_list):
            if self.char_percent_check(input) > self.input_threshold:
                roman_inputs.append((index, input))
            else:
                native_inputs.append((index, input))
        return self.post_process(self.roman_inference(roman_inputs, self.native_inference(native_inputs, output_dict), batch_size))

    def lang_detection(self, input_text):
        """Return (language_code, confidence_percent) for one input.

        Bug fix: round after scaling to a percentage — the original rounded
        the raw score first (round(score, 2) * 100), producing float
        artifacts such as 99.00000000000001.
        """
        output = self.predict(input_text)
        return output[1], round(float(output[2]) * 100, 2)