|
import gradio as gr |
|
import spaces |
|
import numpy as np |
|
import os |
|
import time |
|
import torch |
|
from config import Config |
|
from ebart import PegasusSummarizer |
|
from transformers import BertConfig, BertTokenizer, BertForSequenceClassification |
|
|
|
|
|
|
|
|
|
def set_seed(seed):
    """Seed the NumPy and PyTorch RNGs so runs are reproducible.

    Also forces deterministic cuDNN kernels, trading some speed for
    run-to-run stability on GPU.
    """
    for seeder in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
|
|
|
@spaces.GPU
def greet(inputStr, max_summary_len=200, min_summary_len=64):
    """Summarize Chinese input text with Pegasus, then classify the summary
    with a fine-tuned BERT sequence classifier.

    Args:
        inputStr: Raw input text to classify.
        max_summary_len: Upper length bound passed to the summarizer
            (defaults preserve the original hard-coded 200).
        min_summary_len: Lower length bound passed to the summarizer
            (defaults preserve the original hard-coded 64).

    Returns:
        The predicted label string from ``config.label_list``.
    """
    set_seed(1)  # fixed seed so repeated calls are deterministic
    config = Config("./data_12345")

    tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
    bert_config = BertConfig.from_pretrained(
        "bert-base-chinese", num_labels=config.num_labels
    )
    model = BertForSequenceClassification.from_pretrained(
        "bert-base-chinese", config=bert_config
    )
    model.to(config.device)

    # Condense long input first; presumably the args are (text, max_len,
    # min_len) — TODO confirm against PegasusSummarizer.generate_summary.
    summarizerModel = PegasusSummarizer()
    inputStr = summarizerModel.generate_summary(
        inputStr, max_summary_len, min_summary_len
    )

    # map_location prevents a crash when the checkpoint was saved on CUDA
    # but this process is running on CPU (or a different device).
    model.load_state_dict(
        torch.load(config.saved_model, map_location=config.device)
    )
    model.eval()

    inputs = tokenizer(
        inputStr,
        max_length=config.max_seq_len,
        truncation="longest_first",
        return_tensors="pt",
    )
    inputs = inputs.to(config.device)

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs[0]
        # argmax over the label dimension -> single predicted class index
        label = torch.max(logits.data, 1)[1].tolist()
        print("Classification result:" + config.label_list[label[0]])
        return config.label_list[label[0]]
|
|
|
|
|
|
|
# Single text-in / text-out Gradio interface around the classifier.
demo = gr.Interface(fn=greet, inputs="text", outputs="text")

# Guard the server launch so importing this module (e.g. in tests) does not
# start the app; Gradio/Spaces runs the file as __main__, so behavior there
# is unchanged.
if __name__ == "__main__":
    demo.launch()
|
|
|
|
|
|