# dusense / app.py
# Author: saily
# Commit message: add model for zhaiyao ("zhaiyao" = 摘要, i.e. summarization)
# Commit: e40e3e7
import gradio as gr
import spaces
import numpy as np
import os
import time
import torch
from config import Config
from ebart import PegasusSummarizer
from transformers import BertConfig, BertTokenizer, BertForSequenceClassification
def set_seed(seed):
    """Seed every RNG this app may touch and request deterministic cuDNN kernels.

    Args:
        seed: Integer seed applied to Python's ``random`` module, NumPy, and
            PyTorch (CPU and all CUDA devices).
    """
    # Local import keeps the file's top-level import block untouched.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Safe without a GPU: PyTorch silently ignores this when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # deterministic=True alone is not enough: with benchmark mode on, cuDNN
    # may still auto-tune and pick different kernels from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def _load_classifier(config):
    """Build the fine-tuned ``bert-base-chinese`` classifier and its tokenizer.

    Args:
        config: Project ``Config``; reads ``num_labels``, ``device`` and
            ``saved_model`` (path to the fine-tuned state dict).

    Returns:
        ``(tokenizer, model)`` with the model on ``config.device`` in eval mode.
    """
    tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
    bert_config = BertConfig.from_pretrained("bert-base-chinese", num_labels=config.num_labels)
    model = BertForSequenceClassification.from_pretrained("bert-base-chinese", config=bert_config)
    model.to(config.device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host
    # (and vice versa) instead of failing on device deserialization.
    state_dict = torch.load(config.saved_model, map_location=config.device)
    model.load_state_dict(state_dict)
    model.eval()
    return tokenizer, model


@spaces.GPU
def greet(inputStr):
    """Gradio handler: summarize ``inputStr``, then classify the summary.

    Args:
        inputStr: Raw text entered in the Gradio text box.

    Returns:
        The predicted label string taken from ``config.label_list``.
    """
    set_seed(1)
    config = Config("./data_12345")
    # NOTE(review): the tokenizer, BERT weights and summarizer are all rebuilt
    # on every request, which is slow. Consider loading them once, but confirm
    # the device-placement rules of the @spaces.GPU runtime first.
    tokenizer, model = _load_classifier(config)

    summarizer = PegasusSummarizer()
    # 200 / 64 are passed through unchanged; their meaning is defined in
    # ebart.PegasusSummarizer (presumably length limits). TODO confirm.
    summary = summarizer.generate_summary(inputStr, 200, 64)

    inputs = tokenizer(
        summary,
        max_length=config.max_seq_len,
        truncation="longest_first",
        return_tensors="pt",
    ).to(config.device)
    with torch.no_grad():
        logits = model(**inputs)[0]

    # Single-example batch: take the arg-max class of row 0.
    label = config.label_list[logits.argmax(dim=1)[0].item()]
    print("Classification result:" + label)
    return label
# Text-in / text-out UI around the classifier. `demo` stays at module level
# so code that imports this module can still reach the Interface object.
demo = gr.Interface(fn=greet, inputs="text", outputs="text")

# Start the (blocking) server only when run as a script, not as an import
# side effect.
if __name__ == "__main__":
    demo.launch()