IlyaGusev committed
Commit 3fa04d6
1 Parent(s): 97807dc

Update README.md

Files changed (1): README.md (+5 -3)
README.md CHANGED
@@ -25,6 +25,8 @@ Model for extractive summarization based on [rubert-base-cased](DeepPavlov/ruber
 
 #### How to use
 
+Colab: [link](https://colab.research.google.com/drive/1Q8_v3H-kxdJhZIiyLYat7Kj02qDq7M1L)
+
 ```python
 import razdel
 from transformers import AutoTokenizer, BertForTokenClassification
@@ -48,7 +50,7 @@ inputs = tokenizer(
     truncation=True,
     return_tensors="pt",
 )
-sep_mask = inputs["input_ids"] == sep_token_id
+sep_mask = inputs["input_ids"][0] == sep_token_id
 
 # Fix token_type_ids
 current_token_type_id = 0
@@ -60,7 +62,7 @@ for pos, input_id in enumerate(inputs["input_ids"][0]):
 # Infer model
 with torch.no_grad():
     outputs = model(**inputs)
-logits = outputs.logits[:, :, 1][0]
+logits = outputs.logits[0, :, 1]
 
 # Choose sentences
 logits = logits[sep_mask]
@@ -68,7 +70,7 @@ logits, indices = logits.sort(descending=True)
 logits, indices = logits.cpu().tolist(), indices.cpu().tolist()
 pairs = list(zip(logits, indices))
 pairs = pairs[:3]
-indices = [idx for _, idx in pairs]
+indices = list(sorted([idx for _, idx in pairs]))
 summary = " ".join([sentences[idx] for idx in indices])
 print(summary)
 ```
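
Taken together, the three changed lines fix how sentences are scored and ordered. Building the mask from `inputs["input_ids"][0]` makes `sep_mask` one-dimensional, so it can index the one-dimensional score vector (the old `[1, seq_len]` mask does not line up with it). `outputs.logits[0, :, 1]` selects the same positive-class scores as before, just more directly. Finally, `sorted(...)` puts the chosen sentence indices back into document order so the summary reads in the original sequence. Below is a minimal sketch of the corrected selection step with dummy tensors; the token ids, shapes, and top-k value are illustrative only, not taken from the actual model or tokenizer.

```python
import torch

# Illustrative stand-ins: one document, 12 tokens, 2 labels; 102 plays the role
# of the [SEP] token id here (hypothetical value, not read from the tokenizer).
sep_token_id = 102
input_ids = torch.tensor([[101, 5, 6, 102, 7, 8, 102, 9, 10, 102, 0, 0]])
logits_all = torch.randn(1, 12, 2)       # [batch, seq_len, num_labels], like outputs.logits

sep_mask = input_ids[0] == sep_token_id  # 1-D boolean mask over the sequence dimension
logits = logits_all[0, :, 1]             # 1-D "keep this sentence" scores for document 0
scores = logits[sep_mask]                # one score per sentence (per [SEP] token)

# Take the top 2 sentences, then restore document order so the summary reads naturally.
top_scores, top_indices = scores.sort(descending=True)
pairs = list(zip(top_scores.tolist(), top_indices.tolist()))[:2]
indices = sorted(idx for _, idx in pairs)
print(indices)  # e.g. [0, 2]
```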