|
--- |
|
language: ru |
|
license: apache-2.0 |
|
datasets: |
|
- mlenjoyneer/RuTextSegNews |
|
- mlenjoyneer/RuTextSegWiki |
|
--- |
|
|
|
# RuTextSegModel |
|
|
|
Model for Russian text segmentation, trained on wiki and news corpora |
|
|
|
## Model description |
|
|
|
This model is a top-level part of HierBERT model and solves the problem of text segmentation as a token classification at the sentence level. The ai-forever/sbert_large_nlu_ru with max pooling is used as a low-level model (sentence embedding generator). It's recommended to use this model only with specified low-level model with defined pooling for embeddings. |
|
|
|
## Intended uses & limitations |
|
|
|
### How to use |
|
|
|
Here is how to use this model in PyTorch: |
|
|
|
```python |
|
import torch |
|
import torch.nn as nn |
|
from transformers import BertForTokenClassification, AutoModel, AutoTokenizer |
|
from razdel import sentenize |
|
|
|
class BertForTextSegmentationEmbeddings(nn.Module): |
|
def __init__(self, config, embeddings_dim=768): |
|
super(BertForTextSegmentationEmbeddings, self).__init__() |
|
|
|
self.config = config |
|
self.position_embeddings = torch.nn.Embedding(config.max_position_embeddings, config.hidden_size) |
|
|
|
self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) |
|
self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) |
|
|
|
def forward(self, inputs_embeds, position_ids=None, input_ids=None, token_type_ids=None, past_key_values_length=None): |
|
input_shape = inputs_embeds.size()[:-1] |
|
seq_length = input_shape[1] |
|
device = inputs_embeds.device |
|
|
|
assert seq_length <= self.config.max_position_embeddings, \ |
|
f"Too long sequence is passed {seq_length}. Maximum allowed sequence length is {self.config.max_position_embeddings}" |
|
|
|
if position_ids is None: |
|
position_ids = torch.arange(seq_length, dtype=torch.long, device=device) |
|
position_ids = position_ids.unsqueeze(0).expand(input_shape) |
|
|
|
position_embeddings = self.position_embeddings(position_ids) |
|
|
|
embeddings = inputs_embeds + position_embeddings |
|
embeddings = self.LayerNorm(embeddings) |
|
embeddings = self.dropout(embeddings) |
|
return embeddings |
|
|
|
class BertForTextSegmentation(BertForTokenClassification): |
|
def __init__(self, config): |
|
super(BertForTextSegmentation, self).__init__(config) |
|
|
|
self.bert.base_model.embeddings = BertForTextSegmentationEmbeddings(config) |
|
|
|
self.init_weights() |
|
|
|
def max_pooling(model_output, attention_mask): |
|
token_embeddings = model_output[0] #First element of model_output contains all token embeddings |
|
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() |
|
token_embeddings[input_mask_expanded == 0] = -1e9 # Set padding tokens to large negative value |
|
return torch.max(token_embeddings, 1)[0] |
|
|
|
def create_embeddings(sentences, tokenizer, model): |
|
# Tokenize sentences |
|
encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt') |
|
# Compute token embeddings |
|
with torch.no_grad(): |
|
model_output = model(**encoded_input.to(device)) |
|
# Perform pooling. In this case, max pooling. |
|
sentence_embeddings = max_pooling(model_output, encoded_input['attention_mask']) |
|
|
|
return sentence_embeddings |
|
|
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |
|
|
|
emb_tokenizer = AutoTokenizer.from_pretrained("ai-forever/sbert_large_nlu_ru") |
|
emb_model = AutoModel.from_pretrained("ai-forever/sbert_large_nlu_ru") |
|
model = BertForTextSegmentation.from_pretrained("mlenjoyneer/RuTextSegModel") |
|
|
|
emb_model.to(device) |
|
model.to(device) |
|
|
|
text = """В Норильске за годы работы телефона доверия консультанты приняли в общей сложности порядка 75 тысяч обращений, сообщает «Заполярная Правда». Служба психологической помощи появилась в 2000 году. Руководитель службы профилактики наркомании Елена Слатвицкая рассказала журналистам, что в Заполярье настал период, когда ухудшается психо– эмоциональное состояние населения. Это происходит на входе в полярную ночь и на выходе из нее. Осень является кризисным моментом. Сейчас на телефоне доверия работают 15 специалистов. Каждый — под своим псевдонимом. Тему беседы определяет звонящий. Это могут быть наркомания и алкоголизм, ВИЧ–инфекция и прочие заболевания и зависимости, кризисы семейных отношений и многое другое. Сотрудники службы отметили, что больше стало звонков по поводу суицидальных намерений. Наибольшее количество обращений по суицидам пришлось на октябрь — ноябрь. Много звонков как от мужчин, так и от женщин с вопросами об одиночестве. Лидерами по количеству обращений пока остаются женщины. В сентябре в Норильске обнаружили тело девятиклассницы. По версии следствия, девочка сбросилась с крыши. В январе подросток нанес себе порезы стеклом от разбитой бутылки, пытаясь покончить с собой. Мальчик поссорился с матерью и в ходе ссоры нанес себе несколько порезов. Проводится расследование.""" |
|
|
|
input_embeds = create_embeddings([s.text for s in sentenize(text)], emb_tokenizer, emb_model).unsqueeze(0) |
|
outputs = model(inputs_embeds=input_embeds) |
|
|
|
logits = outputs.logits.cpu() |
|
preds = logits.argmax(axis=2).tolist()[0] # [0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0] |
|
|
|
# true_labels = [0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0] |
|
``` |
|
|
|
## Training data |
|
|
|
Model trained on mlenjoyneer/RuTextSegNews and mlenjoyneer/RuTextSegWiki datasets. |
|
|
|
## Evaluation results |
|
|
|
| Train Dataset | Test Dataset | F1_total | F1_1 | Pk | Pk_5 | WinDiff | WinDiff_5 | |
|
|:-------------:|:------------:|:--------:|:-----:|:----:|:----:|:-------:|:---------:| |
|
| News+Wiki | News | 0.88 | 0.80 | 0.16 | 0.11 | 0.20 | 0.35 | |
|
| News+Wiki | Wiki | 0.89 | 0.80 | 0.18 | 0.16 | 0.09 | 0.19 | |
|
|
|
|
|
### Citation info |
|
|
|
```bibtex |
|
In progress |
|
``` |
|
|