saily commited on
Commit
14e808d
·
1 Parent(s): 8c14d36

change app dtemo

Browse files
Files changed (1) hide show
  1. app.py +41 -11
app.py CHANGED
@@ -1,22 +1,52 @@
1
  import gradio as gr
2
  import spaces
 
 
3
  import torch
 
 
4
 
5
- zero = torch.Tensor([0]).cuda()
6
- print(zero.device) # <-- 'cpu' 🤔
7
 
8
- testgpu= torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
 
10
- print(testgpu)
11
 
12
- @spaces.GPU
13
- def greet(n):
14
- print(zero.device) # <-- 'cuda:0' 🤗
15
- return f"Hello {zero + n} Tensor"
16
-
17
- #demo = gr.Interface(fn=greet, inputs="text", outputs="text")
18
 
19
- demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  demo.launch()
21
 
22
 
 
1
  import gradio as gr
2
  import spaces
3
+ import os
4
+ import time
5
  import torch
6
+ from config import Config
7
+ from transformers import BertConfig, BertTokenizer, BertForSequenceClassification
8
 
 
 
9
 
 
10
 
 
11
 
12
+ def set_seed(seed):
13
+ np.random.seed(seed)
14
+ torch.manual_seed(seed)
15
+ torch.cuda.manual_seed_all(seed)
16
+ torch.backends.cudnn.deterministic = True
 
17
 
18
+ @spaces.GPU
19
+ def greet(inputStr):
20
+ set_seed(1)
21
+ config = Config("./data_12345")
22
+
23
+ tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
24
+ bert_config = BertConfig.from_pretrained("bert-base-chinese", num_labels=config.num_labels)
25
+ model = BertForSequenceClassification.from_pretrained("bert-base-chinese",
26
+ config=bert_config
27
+ )
28
+ model.to(config.device)
29
+
30
+
31
+ model.load_state_dict(torch.load(config.saved_model))
32
+ model.eval()
33
+ inputs = tokenizer(
34
+ inputStr,
35
+ max_length=config.max_seq_len,
36
+ truncation="longest_first",
37
+ return_tensors="pt")
38
+ inputs = inputs.to(config.device)
39
+ with torch.no_grad():
40
+ outputs = model(**inputs)
41
+ logits = outputs[0]
42
+ label = torch.max(logits.data, 1)[1].tolist()
43
+ print("Classification result:" + config.label_list[label[0]])
44
+ return config.label_list[label[0]]
45
+
46
+
47
+
48
+ demo = gr.Interface(fn=greet, inputs="text", outputs="text")
49
+ #demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text())
50
  demo.launch()
51
 
52