Update README.md
Browse files
README.md
CHANGED
@@ -25,6 +25,8 @@ Model for extractive summarization based on [rubert-base-cased](DeepPavlov/ruber
|
|
25 |
|
26 |
#### How to use
|
27 |
|
|
|
|
|
28 |
```python
|
29 |
import razdel
|
30 |
from transformers import AutoTokenizer, BertForTokenClassification
|
@@ -48,7 +50,7 @@ inputs = tokenizer(
|
|
48 |
truncation=True,
|
49 |
return_tensors="pt",
|
50 |
)
|
51 |
-
sep_mask = inputs["input_ids"] == sep_token_id
|
52 |
|
53 |
# Fix token_type_ids
|
54 |
current_token_type_id = 0
|
@@ -60,7 +62,7 @@ for pos, input_id in enumerate(inputs["input_ids"][0]):
|
|
60 |
# Infer model
|
61 |
with torch.no_grad():
|
62 |
outputs = model(**inputs)
|
63 |
-
logits = outputs.logits[:, :, 1]
|
64 |
|
65 |
# Choose sentences
|
66 |
logits = logits[sep_mask]
|
@@ -68,7 +70,7 @@ logits, indices = logits.sort(descending=True)
|
|
68 |
logits, indices = logits.cpu().tolist(), indices.cpu().tolist()
|
69 |
pairs = list(zip(logits, indices))
|
70 |
pairs = pairs[:3]
|
71 |
-
indices = [idx for _, idx in pairs]
|
72 |
summary = " ".join([sentences[idx] for idx in indices])
|
73 |
print(summary)
|
74 |
```
|
|
|
25 |
|
26 |
#### How to use
|
27 |
|
28 |
+
Colab: [link](https://colab.research.google.com/drive/1Q8_v3H-kxdJhZIiyLYat7Kj02qDq7M1L)
|
29 |
+
|
30 |
```python
|
31 |
import razdel
|
32 |
from transformers import AutoTokenizer, BertForTokenClassification
|
|
|
50 |
truncation=True,
|
51 |
return_tensors="pt",
|
52 |
)
|
53 |
+
sep_mask = inputs["input_ids"][0] == sep_token_id
|
54 |
|
55 |
# Fix token_type_ids
|
56 |
current_token_type_id = 0
|
|
|
62 |
# Infer model
|
63 |
with torch.no_grad():
|
64 |
outputs = model(**inputs)
|
65 |
+
logits = outputs.logits[0, :, 1]
|
66 |
|
67 |
# Choose sentences
|
68 |
logits = logits[sep_mask]
|
|
|
70 |
logits, indices = logits.cpu().tolist(), indices.cpu().tolist()
|
71 |
pairs = list(zip(logits, indices))
|
72 |
pairs = pairs[:3]
|
73 |
+
indices = list(sorted([idx for _, idx in pairs]))
|
74 |
summary = " ".join([sentences[idx] for idx in indices])
|
75 |
print(summary)
|
76 |
```
|