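"""Post-processing utilities for token-classification predictions.

The helpers below map token-level probabilities onto character positions
(`get_char_probs`), threshold them into "start end" location strings
(`get_results`), parse those strings back into span lists (`get_predictions`),
run batched inference with a PyTorch model (`inference_fn`), and pull the
predicted text out of a context string (`get_text`).
"""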
import itertools
import torch 
import numpy as np
from tqdm.auto import tqdm

def get_char_probs(texts, predictions, tokenizer):
    """
    Maps prediction from encoded offset mapping to the text

    Prediction = 466 sequence length * batch
    text = 768 * batch
    Using offset mapping [(0, 4), ] -- 466

    creates results that is size of texts

    for each text result[i]
    result[0, 4] = pred[0] like wise for all
    
    """
    results = [np.zeros(len(t)) for t in texts]
    for i, (text, prediction) in enumerate(zip(texts, predictions)):
        encoded = tokenizer(text, 
                            add_special_tokens=True,
                            return_offsets_mapping=True)
        for idx, (offset_mapping, pred) in enumerate(zip(encoded['offset_mapping'], prediction)):
            start = offset_mapping[0]
            end = offset_mapping[1]
            results[i][start:end] = pred
    return results

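# Usage sketch for get_char_probs (illustrative only: the checkpoint name and
# the random token probabilities below are placeholders; any HuggingFace
# *fast* tokenizer that supports return_offsets_mapping would do):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#   texts = ["dad with recent heart attack"]
#   n_tokens = len(tokenizer(texts[0])["input_ids"])
#   token_probs = [np.random.rand(n_tokens)]   # one probability per token
#   char_probs = get_char_probs(texts, token_probs, tokenizer)
#   # char_probs[0] has shape (len(texts[0]),): one probability per character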

def get_results(char_probs, th=0.5):
    """
      Get the list of probabilites with size of text
      And then get the index of the characters which are more than th
      example:
          char_prob = [0.1, 0.1, 0.9, 0.9, 0.9, 0.9, 0.2, 0.2, 0.2, 0.7, 0.7, 0.7] ## length == 766
          where > 0.5 index ## [ 2,  3,  4,  5,  9, 10, 11]

          Groupby same one -- [[2, 3, 4, 5], [9, 10, 11]]
          And get the max and min and output the results

    """
    results = []
    for char_prob in char_probs:
        result = np.where(char_prob >= th)[0] + 1
        result = [list(g) for _, g in itertools.groupby(result, key=lambda n, c=itertools.count(): n - next(c))]
        result = [f"{min(r)} {max(r)}" for r in result]
        result = ";".join(result)
        results.append(result)
    return results

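# Worked example for get_results (made-up probabilities, mirroring the
# docstring above):
#
#   char_prob = np.array([0.1, 0.1, 0.9, 0.9, 0.9, 0.9, 0.2, 0.2, 0.2, 0.7, 0.7, 0.7])
#   get_results([char_prob], th=0.5)   # -> ['3 6;10 12']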

def get_predictions(results):
    """
      Will get the location, as a string, just like location in the df
      results = ['2 5', '9 11']

      loop through, split it and save it as start and end and store it in array
    """
    predictions = []
    for result in results:
        prediction = []
        if result != "":
            for loc in [s.split() for s in result.split(';')]:
                start, end = int(loc[0]), int(loc[1])
                prediction.append([start, end])
        predictions.append(prediction)
    return predictions
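
# Example for get_predictions (made-up location strings; an empty string means
# no predicted span for that text):
#
#   get_predictions(['3 6;10 12', ''])   # -> [[[3, 6], [10, 12]], []]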

def inference_fn(test_loader, model, device):
    """Run the model over test_loader and return sigmoid probabilities as a numpy array."""
    preds = []
    model.eval()
    model.to(device)
    tk0 = tqdm(test_loader, total=len(test_loader))
    for inputs in tk0:
        # Move every input tensor in the batch dict to the target device.
        for k, v in inputs.items():
            inputs[k] = v.to(device)
        with torch.no_grad():
            y_preds = model(inputs)
        # Move to CPU before converting to numpy (required when device is a GPU).
        preds.append(y_preds.sigmoid().cpu().numpy())
    predictions = np.concatenate(preds)
    return predictions
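
# Usage sketch for inference_fn (illustrative; `model`, `test_dataset`,
# `test_texts`, and `tokenizer` are assumed to exist elsewhere in the pipeline,
# and the model is assumed to take the whole batch dict as a single argument,
# matching the model(inputs) call above):
#
#   test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   predictions = inference_fn(test_loader, model, device)   # token-level probabilities
#   char_probs = get_char_probs(test_texts, predictions, tokenizer)
#   results = get_results(char_probs, th=0.5)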

def get_text(context, indexes):
    """Extract the predicted text from `context` given a "start end[;start end...]" location string."""
    if not indexes:
        return 'Not found in this Context'
    if ';' in indexes:
        # Multiple spans: concatenate them, each preceded by a space.
        answer = ''
        for idx in indexes.split(';'):
            start_index = int(idx.split(' ')[0])
            end_index = int(idx.split(' ')[1])
            answer += ' ' + context[start_index:end_index]
        return answer
    # Single span.
    start_index = int(indexes.split(' ')[0])
    end_index = int(indexes.split(' ')[1])
    return context[start_index:end_index]