data-silence commited on
Commit
d79d747
·
verified ·
1 Parent(s): 9ba7b64

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +25 -0
README.md CHANGED
@@ -61,11 +61,36 @@ classification of news categories politics, society and conflicts.
61
  Example of how to use the model:
62
 
63
  ```python
 
 
64
  import torch
65
  from transformers import AutoTokenizer
66
  from huggingface_hub import hf_hub_download
67
 
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  categories = ['climate', 'conflicts', 'culture', 'economy', 'gloss', 'health',
70
  'politics', 'science', 'society', 'sports', 'travel']
71
 
 
61
  Example of how to use the model:
62
 
63
  ```python
64
+ import torch.nn as nn
65
+ from transformers import BertModel
66
  import torch
67
  from transformers import AutoTokenizer
68
  from huggingface_hub import hf_hub_download
69
 
70
 
71
+ class BiLSTMClassifier(nn.Module):
72
+ def __init__(self, hidden_dim, output_dim, n_layers, dropout):
73
+ super(BiLSTMClassifier, self).__init__()
74
+ self.bert = BertModel.from_pretrained("bert-base-multilingual-cased")
75
+ self.lstm = nn.LSTM(self.bert.config.hidden_size, hidden_dim, num_layers=n_layers,
76
+ bidirectional=True, dropout=dropout, batch_first=True)
77
+ self.fc = nn.Linear(hidden_dim * 2, output_dim)
78
+ self.dropout = nn.Dropout(dropout)
79
+
80
+ def forward(self, input_ids, attention_mask, labels=None):
81
+ with torch.no_grad():
82
+ embedded = self.bert(input_ids=input_ids, attention_mask=attention_mask)[0]
83
+ lstm_out, _ = self.lstm(embedded)
84
+ pooled = torch.mean(lstm_out, dim=1)
85
+ logits = self.fc(self.dropout(pooled))
86
+
87
+ if labels is not None:
88
+ loss_fn = nn.CrossEntropyLoss()
89
+ loss = loss_fn(logits, labels)
90
+ return {"loss": loss, "logits": logits} # Возвращаем словарь
91
+ return logits # Возвращаем логиты, если метки не переданы
92
+
93
+
94
  categories = ['climate', 'conflicts', 'culture', 'economy', 'gloss', 'health',
95
  'politics', 'science', 'society', 'sports', 'travel']
96