---
language:
- ru
library_name: lstm
pipeline_tag: text-classification
tags:
- news
- media
- russian
datasets:
- data-silence/rus_news_classifier
---

# LSTM Text Classifier

This is an LSTM model for text classification, trained on
my [news dataset](https://huggingface.co/datasets/data-silence/rus_news_classifier), which is hosted on the Hugging Face Hub.
The training set is a well-balanced sample of news from the last five years.

## Model Description

This model uses a bidirectional LSTM over `bert-base-multilingual-cased` embeddings to classify text into 11 categories.
It was trained on ~70,000 examples and achieves an accuracy of 0.8691 on the test set.
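
The evaluation code is not part of this card. As a minimal sketch, assuming accuracy here means the plain fraction of correctly classified examples, it could be computed as follows. The `accuracy` helper and the toy labels are hypothetical and only illustrate the metric.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the reference labels."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    if not y_true:
        raise ValueError("cannot compute accuracy on an empty set")
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


# Toy labels for illustration only, not taken from the real test set.
y_true = ['sports', 'politics', 'society', 'economy']
y_pred = ['sports', 'society', 'society', 'economy']
print(accuracy(y_true, y_pred))  # 0.75
```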

## Task

The model is designed to classify Russian-language news articles into 11 categories.

## Categories

The classifier assigns each news item to one of 11 categories:

- climate (климат)
- conflicts (конфликты)
- culture (культура)
- economy (экономика)
- gloss (глянец)
- health (здоровье)
- politics (политика)
- science (наука)
- society (общество)
- sports (спорт)
- travel (путешествия)
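
The list above is in the same order as the `categories` list in the usage example, where the predicted class index is mapped to a label by position. A small sketch of that mapping, with the Russian display names added for convenience; the names `CATEGORIES`, `RU_NAMES`, `id2label`, `label2id`, and `label2ru` are illustrative and not part of the model files.

```python
# Same order as the `categories` list in the usage example.
CATEGORIES = ['climate', 'conflicts', 'culture', 'economy', 'gloss', 'health',
              'politics', 'science', 'society', 'sports', 'travel']

# Russian display names copied from the list above.
RU_NAMES = ['климат', 'конфликты', 'культура', 'экономика', 'глянец', 'здоровье',
            'политика', 'наука', 'общество', 'спорт', 'путешествия']

id2label = dict(enumerate(CATEGORIES))                       # class index -> label
label2id = {label: idx for idx, label in id2label.items()}   # label -> class index
label2ru = dict(zip(CATEGORIES, RU_NAMES))                   # label -> Russian name

print(id2label[9], label2id['gloss'], label2ru['travel'])
```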


## Intended uses & limitations

This model was trained and uploaded for educational purposes only.

You should not use it to solve practical problems: an LSTM is neither the best nor the fastest solution for text classification.
Moreover, the model architecture is not compatible with the standard Hugging Face tooling (pipelines, inference endpoints, etc. are not supported).

The "gloss" category collects tabloid, trashy, and dubious news. The model can confuse the
politics, society, and conflicts categories.
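
Because these three categories can be close, it may help to look at the top few probabilities rather than only the best label. The sketch below is one way to do that, assuming the model's logits for a single text have been converted to a plain list of floats (for example with `output[0].tolist()`). The helper `top_k_categories` and the example logits are hypothetical.

```python
import math

CATEGORIES = ['climate', 'conflicts', 'culture', 'economy', 'gloss', 'health',
              'politics', 'science', 'society', 'sports', 'travel']


def top_k_categories(logits, k=3):
    """Return the k most probable (category, probability) pairs."""
    # Numerically stable softmax: subtract the max logit before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    ranked = sorted(zip(CATEGORIES, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Made-up logits for illustration only, not real model output.
example_logits = [0.1, 2.0, -1.0, 0.3, -0.5, 0.0, 2.2, -0.2, 1.9, -1.5, -0.8]
for name, prob in top_k_categories(example_logits):
    print(f"{name}: {prob:.3f}")
```

When the top two or three probabilities are nearly equal, the single predicted label is best treated as uncertain.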

## Usage

Example of how to use the model:

```python
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, BertModel


class BiLSTMClassifier(nn.Module):
    def __init__(self, hidden_dim, output_dim, n_layers, dropout):
        super(BiLSTMClassifier, self).__init__()
        self.bert = BertModel.from_pretrained("bert-base-multilingual-cased")
        self.lstm = nn.LSTM(self.bert.config.hidden_size, hidden_dim, num_layers=n_layers, 
                            bidirectional=True, dropout=dropout, batch_first=True)
        self.fc = nn.Linear(hidden_dim * 2, output_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, input_ids, attention_mask, labels=None):
        # BERT is used as a frozen feature extractor, so no gradients are tracked for it.
        with torch.no_grad():
            embedded = self.bert(input_ids=input_ids, attention_mask=attention_mask)[0]
        lstm_out, _ = self.lstm(embedded)
        # Mean-pool the BiLSTM outputs over the sequence dimension.
        pooled = torch.mean(lstm_out, dim=1)
        logits = self.fc(self.dropout(pooled))

        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
            return {"loss": loss, "logits": logits}  # Return a dict when labels are given
        return logits  # Return only the logits when no labels are passed


categories = ['climate', 'conflicts', 'culture', 'economy', 'gloss', 'health',
              'politics', 'science', 'society', 'sports', 'travel']

repo_id = "data-silence/lstm-news-classifier"
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model_path = hf_hub_download(repo_id=repo_id, filename="model.pth")

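# This example treats the checkpoint as a pickled full model, so the BiLSTMClassifier
# class above must be defined before loading. PyTorch 2.6+ defaults to weights_only=True,
# which rejects full-model pickles; only pass weights_only=False for files you trust.
model = torch.load(model_path, map_location="cpu", weights_only=False)
model.eval()  # Disable dropout for inference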

def get_predictions(news: str, model) -> str:
    with torch.no_grad():
        # Truncate to BERT's 512-token limit so long articles do not raise an error.
        inputs = tokenizer(news, return_tensors="pt", truncation=True, max_length=512)
        inputs.pop('token_type_ids', None)  # forward() does not accept token_type_ids
        output = model(**inputs)
    id_best_label = torch.argmax(output[0, :], dim=-1).item()
    return categories[id_best_label]


# Using the classifier
get_predictions('В Париже завершилась церемония завершения Олимпийский игр', model=model)
# 'sports'
```