File size: 2,388 Bytes
5292e0d
7ef0959
a2dbd3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5292e0d
93e15af
 
a2dbd3d
 
 
 
 
 
5292e0d
0daff2d
 
 
93e15af
cae2767
5292e0d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer, PretrainedConfig, PreTrainedModel, MT5EncoderModel

class MTRankerConfig(PretrainedConfig):
    """Configuration for the ``MTRanker`` model.

    Adds a single field, ``backbone``, on top of whatever the base
    ``PretrainedConfig`` stores from ``**kwargs``.

    Args:
        backbone: Checkpoint name or path that ``MTRanker`` passes to
            ``MT5EncoderModel.from_pretrained`` to build its encoder.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    def __init__(self, backbone: str = 'google/mt5-base', **kwargs) -> None:
        # The attribute is set before delegating to the base initialiser,
        # matching the original statement order.
        self.backbone = backbone
        super().__init__(**kwargs)

class MTRanker(PreTrainedModel):
    """mT5 encoder with masked mean pooling and a two-class linear head.

    The encoder checkpoint is taken from ``config.backbone``; its hidden size
    determines the input width of the classifier.
    """

    config_class = MTRankerConfig

    def __init__(self, config):
        """Build the encoder named by ``config.backbone`` and the classifier head."""
        super().__init__(config)
        # NOTE: this loads the backbone's pretrained weights every time the
        # model is constructed, including when it is built via from_pretrained.
        self.encoder = MT5EncoderModel.from_pretrained(config.backbone)
        self.num_classes = 2
        self.classifier = torch.nn.Linear(self.encoder.config.hidden_size, self.num_classes)

    def forward(self, input_ids, attention_mask):
        """Return classification logits for a batch of tokenised inputs.

        Args:
            input_ids: Token ids, shape ``(batch, seq_len)``.
            attention_mask: Mask of the same shape; non-zero marks real tokens.

        Returns:
            Tensor of shape ``(batch, num_classes)`` with unnormalised logits.
        """
        encoder_output = self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        # A trailing singleton axis lets the mask broadcast over the hidden
        # dimension, so no explicit expand() is needed.
        mask = attention_mask.unsqueeze(-1)
        summed_hidden_state = torch.sum(encoder_output * mask, dim=1)
        # Clamp so a row whose mask is entirely zero yields zeros, not NaN.
        seq_lengths = torch.sum(attention_mask, dim=1, keepdim=True).clamp(min=1)
        pooled_hidden_state = summed_hidden_state / seq_lengths
        return self.classifier(pooled_hidden_state)


# Module-level setup: these three objects are created once at import time and
# are read by predict() below.
# The config is only used here to supply the backbone name for the tokenizer.
config = MTRankerConfig(backbone='google/mt5-base')
tokenizer = AutoTokenizer.from_pretrained(config.backbone)
# The ranker itself is loaded from a separate checkpoint identifier.
model = MTRanker.from_pretrained('ibraheemmoosa/mt-ranker-base')

def predict(source, translation1, translation2):
    """Score two candidate translations of ``source`` against each other.

    The three strings are packed into one prompt, tokenised (padded and
    truncated to 512 tokens) and run through the module-level ``model``.

    Args:
        source: Source-language sentence.
        translation1: First candidate translation.
        translation2: Second candidate translation.

    Returns:
        Dict mapping ``'Translation 1'`` and ``'Translation 2'`` to softmax
        probabilities as plain Python floats, the form ``gr.Label`` accepts.
    """
    model_input = "Source: {} Translation 0: {} Translation 1: {}".format(source, translation1, translation2)
    # Run on whatever device the model lives on; autocast needs its type.
    device = model.device
    inputs = tokenizer(
        [model_input],
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    ).to(device)
    # Inference only: skip autograd bookkeeping, run the forward pass in bf16.
    with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16):
        logits = model(inputs.input_ids, inputs.attention_mask)
    # Softmax in float32, then convert to floats so the result is serialisable.
    output_scores = torch.softmax(logits.float(), dim=1)[0].tolist()
    return {'Translation 1': output_scores[0], 'Translation 2': output_scores[1]}

# Gradio UI: three text inputs pre-filled with an example, one label output.
source_textbox = gr.Textbox(
    label="Source",
    info="Source Sentence",
    # NOTE(review): "tapis" is masculine in French, so this should probably
    # read "le tapis"; the string is left as-is here.
    value="Le chat est sur la tapis.",
)
translation1_textbox = gr.Textbox(
    label="Translation 1",
    info="Translation 1",
    value="The cat is on the bed.",
)
translation2_textbox = gr.Textbox(
    label="Translation 2",
    info="Translation 2",
    value="The cat is on the carpet.",
)
output = gr.Label(label="Result")

# Wire predict() to the components and start the app.
iface = gr.Interface(
    fn=predict,
    inputs=[source_textbox, translation1_textbox, translation2_textbox],
    outputs=output,
)
iface.launch()