Shangding-Gu committed on
Commit
38e6bfc
·
1 Parent(s): a0b0102

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -20
app.py CHANGED
@@ -5,7 +5,6 @@
5
 
6
  import sys
7
  import os
8
- import torch
9
  import transformers
10
  import json
11
 
@@ -14,18 +13,6 @@ assert (
14
  ), "Please reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
15
  from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
16
 
17
- if torch.cuda.is_available():
18
- device = "cuda"
19
- else:
20
- device = "cpu"
21
-
22
- try:
23
- if torch.backends.mps.is_available():
24
- device = "mps"
25
- except:
26
- pass
27
-
28
-
29
  base_model = "https://huggingface.co/Shangding-Gu/Lunyu-LLM/"
30
 
31
  tokenizer = LlamaTokenizer.from_pretrained(base_model)
@@ -48,11 +35,6 @@ model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk
48
  model.config.bos_token_id = 1
49
  model.config.eos_token_id = 2
50
 
51
- if not load_8bit:
52
- model.half() # seems to fix bugs for some users.
53
- if torch.__version__ >= "2" and sys.platform != "win32":
54
- model = torch.compile(model)
55
-
56
  class Call_model():
57
  model.eval()
58
  def evaluate(self, instruction):
@@ -84,8 +66,7 @@ class Call_model():
84
  num_beams=num_beams,
85
  **kwargs,
86
  )
87
- with torch.no_grad():
88
- generation_output = model.generate(
89
  input_ids=input_ids,
90
  generation_config=generation_config,
91
  return_dict_in_generate=True,
 
5
 
6
  import sys
7
  import os
 
8
  import transformers
9
  import json
10
 
 
13
  ), "Please reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
14
  from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
15
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  base_model = "https://huggingface.co/Shangding-Gu/Lunyu-LLM/"
17
 
18
  tokenizer = LlamaTokenizer.from_pretrained(base_model)
 
35
  model.config.bos_token_id = 1
36
  model.config.eos_token_id = 2
37
 
 
 
 
 
 
38
  class Call_model():
39
  model.eval()
40
  def evaluate(self, instruction):
 
66
  num_beams=num_beams,
67
  **kwargs,
68
  )
69
+ generation_output = model.generate(
 
70
  input_ids=input_ids,
71
  generation_config=generation_config,
72
  return_dict_in_generate=True,