from re import findall as re_findall

from pandas import DataFrame
from torch import device as Device, load as torch_load, no_grad as torch_no_grad, max as torch_max
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from fasttext import load_model as fasttext_load_model


class IndicBERT_Data(Dataset):
    """Pairs each input sentence with its original position so that
    batched predictions can be restored to input order afterwards."""

    def __init__(self, indices, X):
        self.size = len(X)
        self.x = X
        self.i = indices

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return (self.i[idx], self.x[idx])

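# With the default collate function, a DataLoader over IndicBERT_Data yields
# batches of the form (tensor_of_indices, list_of_texts), e.g.
#     [tensor([0, 1]), ['namaste duniya', 'kaise ho']]
# which IndicBERT_roman_inference below unpacks as data[0] and data[1].
# (The example texts are hypothetical.)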


class IndicLID():

    def __init__(self, bert_Path: str, ftr_path: str, ftn_path: str, langs: list[str], device: Device, input_threshold=0.5, roman_lid_threshold=0.6):
        self.device = device

        # Three classifiers: a fasttext model for native-script text (FTN),
        # a fasttext model for romanized text (FTR), and an IndicBERT
        # classifier used as a fallback when FTR is not confident.
        self.IndicLID_FTN = fasttext_load_model(ftn_path)
        self.IndicLID_FTR = fasttext_load_model(ftr_path)
        self.IndicLID_BERT = torch_load(bert_Path, map_location=self.device, weights_only=False)
        self.IndicLID_BERT.eval()
        self.IndicLID_BERT_tokenizer = AutoTokenizer.from_pretrained("ai4bharat/IndicBERTv2-MLM-only")

        # Share of Latin characters above which an input is treated as romanized.
        self.input_threshold = input_threshold
        # FTR confidence at or below which the BERT fallback is consulted.
        self.model_threshold = roman_lid_threshold
        self.classes = 47

        # Map language codes to class indices and back.
        self.IndicLID_lang_code_dict = {lang: idx for idx, lang in enumerate(langs)}
        self.IndicLID_lang_code_dict_reverse = {idx: lang for idx, lang in enumerate(langs)}

    def char_percent_check(self, input):
        # Fraction of Latin letters and digits among the "visible" characters:
        # the listed punctuation and all whitespace are excluded from the
        # denominator. Note r'\s' already matches '\n', so newlines are
        # subtracted exactly once.
        total_chars = len(input) - (len(re_findall(r'[@_!#$%^&*()<>?/\|}{~:]', input)) + len(re_findall(r'\s', input)))
        if total_chars == 0:
            return 0
        return len(re_findall(r'[a-zA-Z0-9]', input)) / total_chars
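
    # A quick, hypothetical illustration of the routing heuristic:
    #   char_percent_check("namaste")  -> 1.0  (all Latin, goes to the roman pipeline)
    #   char_percent_check("नमस्ते")     -> 0.0  (no Latin, goes to the native pipeline)
    # With the default input_threshold of 0.5, a mixed-script line is routed
    # to the roman pipeline only when Latin characters form the majority.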

    def native_inference(self, input_list, output_dict):
        if not input_list:
            return output_dict

        # Native-script lines are handled by the fasttext FTN model alone.
        input_texts = [line[1] for line in input_list]
        IndicLID_FTN_predictions = self.IndicLID_FTN.predict(input_texts)

        # fasttext labels look like '__label__xxx_Yyyy'; slicing off the
        # 9-character '__label__' prefix leaves the language code.
        for input, pred_label, pred_score in zip(input_list, IndicLID_FTN_predictions[0], IndicLID_FTN_predictions[1]):
            output_dict[input[0]] = (input[1], pred_label[0][9:], pred_score[0], 'IndicLID-FTN')
        return output_dict

    def roman_inference(self, input_list, output_dict, batch_size):
        if not input_list:
            return output_dict

        # First pass: score every romanized line with the fast FTR model.
        input_texts = [line[1] for line in input_list]
        IndicLID_FTR_predictions = self.IndicLID_FTR.predict(input_texts)

        # Keep confident FTR predictions; defer the rest to the BERT fallback.
        IndicLID_BERT_inputs = []
        for input, pred_label, pred_score in zip(input_list, IndicLID_FTR_predictions[0], IndicLID_FTR_predictions[1]):
            if pred_score[0] > self.model_threshold:
                output_dict[input[0]] = (input[1], pred_label[0][9:], pred_score[0], 'IndicLID-FTR')
            else:
                IndicLID_BERT_inputs.append(input)
        return self.IndicBERT_roman_inference(IndicLID_BERT_inputs, output_dict, batch_size)

    def IndicBERT_roman_inference(self, IndicLID_BERT_inputs, output_dict, batch_size):
        if not IndicLID_BERT_inputs:
            return output_dict

        df = DataFrame(IndicLID_BERT_inputs)
        dataloader = self.get_dataloaders(df.iloc[:, 0], df.iloc[:, 1], batch_size)

        with torch_no_grad():
            for data in dataloader:
                batch_indices = data[0]
                batch_inputs = data[1]

                word_embeddings = self.IndicLID_BERT_tokenizer(batch_inputs, return_tensors="pt", padding=True, truncation=True, max_length=512)
                word_embeddings = word_embeddings.to(self.device)

                batch_outputs = self.IndicLID_BERT(word_embeddings['input_ids'],
                                                   token_type_ids=word_embeddings['token_type_ids'],
                                                   attention_mask=word_embeddings['attention_mask'])

                # The predicted class is the argmax over the language logits.
                _, batch_predicted = torch_max(batch_outputs.logits, 1)

                for index, input, pred_label, logit in zip(batch_indices, batch_inputs, batch_predicted, batch_outputs.logits):
                    # index arrives as a 0-dim tensor from the default collate;
                    # .item() keeps the dict keys as plain ints, consistent
                    # with the integer keys used by the fasttext paths.
                    output_dict[index.item()] = (input, self.IndicLID_lang_code_dict_reverse[pred_label.item()], logit[pred_label.item()].item(), 'IndicLID-BERT')
        return output_dict

    def post_process(self, output_dict: dict):
        # Emit results in the original input order.
        return [output_dict[index] for index in sorted(output_dict.keys())]

    def get_dataloaders(self, indices, input_texts, batch_size):
        return DataLoader(IndicBERT_Data(indices, input_texts), batch_size=batch_size, shuffle=False)

    def predict(self, input):
        return self.batch_predict([input], 1)[0]

    def batch_predict(self, input_list, batch_size):
        output_dict = {}

        # Route each line by script: mostly-Latin lines go to the roman
        # pipeline, everything else to the native-script pipeline.
        roman_inputs = []
        native_inputs = []
        for index, input in enumerate(input_list):
            if self.char_percent_check(input) > self.input_threshold:
                roman_inputs.append((index, input))
            else:
                native_inputs.append((index, input))

        return self.post_process(self.roman_inference(roman_inputs, self.native_inference(native_inputs, output_dict), batch_size))

    def lang_detection(self, input_text):
        output = self.predict(input_text)
        # Return the language code plus the confidence as a percentage,
        # scaling before rounding so the result is rounded to two decimal
        # places of the percentage itself.
        return output[1], round(float(output[2]) * 100, 2)
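

# ---------------------------------------------------------------------------
# Minimal usage sketch. Every path below is a placeholder, and the two-entry
# language list stands in for the 47-entry inventory that ships with the
# IndicLID checkpoints (the order must match the one used at training time).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    langs = ["asm_Latn", "hin_Deva"]  # placeholder; the real list has 47 codes
    model = IndicLID(
        bert_Path="path/to/indiclid-bert.pt",  # placeholder path
        ftr_path="path/to/indiclid-ftr.bin",   # placeholder path
        ftn_path="path/to/indiclid-ftn.bin",   # placeholder path
        langs=langs,
        device=Device("cpu"),
    )
    # batch_predict returns (text, lang_code, score, classifier) per line,
    # in input order.
    print(model.batch_predict(["namaste duniya", "नमस्ते दुनिया"], batch_size=2))
    # lang_detection returns (lang_code, confidence_percent) for one line.
    print(model.lang_detection("namaste duniya"))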