change app dtemo
Browse files
app.py
CHANGED
@@ -1,22 +1,52 @@
|
|
1 |
import gradio as gr
|
2 |
import spaces
|
|
|
|
|
3 |
import torch
|
|
|
|
|
4 |
|
5 |
-
zero = torch.Tensor([0]).cuda()
|
6 |
-
print(zero.device) # <-- 'cpu' 🤔
|
7 |
|
8 |
-
testgpu= torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
9 |
|
10 |
-
print(testgpu)
|
11 |
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
#demo = gr.Interface(fn=greet, inputs="text", outputs="text")
|
18 |
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
demo.launch()
|
21 |
|
22 |
|
|
|
1 |
import gradio as gr
|
2 |
import spaces
|
3 |
+
import os
|
4 |
+
import time
|
5 |
import torch
|
6 |
+
from config import Config
|
7 |
+
from transformers import BertConfig, BertTokenizer, BertForSequenceClassification
|
8 |
|
|
|
|
|
9 |
|
|
|
10 |
|
|
|
11 |
|
12 |
+
def set_seed(seed):
|
13 |
+
np.random.seed(seed)
|
14 |
+
torch.manual_seed(seed)
|
15 |
+
torch.cuda.manual_seed_all(seed)
|
16 |
+
torch.backends.cudnn.deterministic = True
|
|
|
17 |
|
18 |
+
@spaces.GPU
|
19 |
+
def greet(inputStr):
|
20 |
+
set_seed(1)
|
21 |
+
config = Config("./data_12345")
|
22 |
+
|
23 |
+
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
|
24 |
+
bert_config = BertConfig.from_pretrained("bert-base-chinese", num_labels=config.num_labels)
|
25 |
+
model = BertForSequenceClassification.from_pretrained("bert-base-chinese",
|
26 |
+
config=bert_config
|
27 |
+
)
|
28 |
+
model.to(config.device)
|
29 |
+
|
30 |
+
|
31 |
+
model.load_state_dict(torch.load(config.saved_model))
|
32 |
+
model.eval()
|
33 |
+
inputs = tokenizer(
|
34 |
+
inputStr,
|
35 |
+
max_length=config.max_seq_len,
|
36 |
+
truncation="longest_first",
|
37 |
+
return_tensors="pt")
|
38 |
+
inputs = inputs.to(config.device)
|
39 |
+
with torch.no_grad():
|
40 |
+
outputs = model(**inputs)
|
41 |
+
logits = outputs[0]
|
42 |
+
label = torch.max(logits.data, 1)[1].tolist()
|
43 |
+
print("Classification result:" + config.label_list[label[0]])
|
44 |
+
return config.label_list[label[0]]
|
45 |
+
|
46 |
+
|
47 |
+
|
48 |
+
demo = gr.Interface(fn=greet, inputs="text", outputs="text")
|
49 |
+
#demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text())
|
50 |
demo.launch()
|
51 |
|
52 |
|