Aakash Vardhan committed
Commit 12289b8 · 1 Parent(s): 21aa301
Files changed (1):
  1. app.py +32 -24
app.py CHANGED
@@ -5,38 +5,46 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 from config import load_config
 
 config = load_config("config.yaml")
-
 model_config = config["model_config"]
-
 model_name = model_config.pop("model_name")
-
-# Convert torch_dtype from string to torch.dtype
-if "torch_dtype" in model_config:
-    if model_config["torch_dtype"] == "float32":
-        model_config["torch_dtype"] = torch.float32
-    elif model_config["torch_dtype"] == "float16":
-        model_config["torch_dtype"] = torch.float16
-    elif model_config["torch_dtype"] == "bfloat16":
-        model_config["torch_dtype"] = torch.bfloat16
-
-# Load the model without quantization config
-model = AutoModelForCausalLM.from_pretrained(
-    model_name,
-    low_cpu_mem_usage=True,
-    **model_config
-)
-
 checkpoint_model = "checkpoint_dir/checkpoint-650"
 
-model.load_adapter(checkpoint_model)
+# Global variables for model and tokenizer
+model = None
+tokenizer = None
+pipe = None
+
+def load_model_and_tokenizer():
+    global model, tokenizer, pipe
+    if model is None:
+        print("Loading model and tokenizer...")
+        # Convert torch_dtype from string to torch.dtype
+        if "torch_dtype" in model_config:
+            if model_config["torch_dtype"] == "float32":
+                model_config["torch_dtype"] = torch.float32
+            elif model_config["torch_dtype"] == "float16":
+                model_config["torch_dtype"] = torch.float16
+            elif model_config["torch_dtype"] == "bfloat16":
+                model_config["torch_dtype"] = torch.bfloat16
+
+        # Load the model without quantization config
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            low_cpu_mem_usage=True,
+            **model_config
+        )
+
+        model.load_adapter(checkpoint_model)
 
-tokenizer = AutoTokenizer.from_pretrained(checkpoint_model, trust_remote_code=True)
-tokenizer.pad_token = tokenizer.eos_token
-tokenizer.padding_side = "right"
+        tokenizer = AutoTokenizer.from_pretrained(checkpoint_model, trust_remote_code=True)
+        tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.padding_side = "right"
 
-pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
+        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
+        print("Model and tokenizer loaded successfully.")
 
 def respond(message, history):
+    load_model_and_tokenizer()
     system_message = """You are General Knowledge Assistant.
     Answer the questions based on the provided information.
     Be succinct and use first-principles thinking to answer the questions."""
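
Note on configuration: the code above expects load_config("config.yaml") to return a dict whose "model_config" entry contains a "model_name" key plus optional keyword arguments such as a "torch_dtype" string ("float32", "float16", or "bfloat16"). The config file itself is not part of this commit, so the sketch below is only an inferred shape, and the model name is a placeholder.

# Hypothetical return value of load_config("config.yaml"), inferred from how
# app.py consumes it (pop("model_name"), optional "torch_dtype" string that
# load_model_and_tokenizer() maps to a torch.dtype). Not part of this commit.
config = {
    "model_config": {
        "model_name": "your-base-model-id",  # placeholder base model identifier
        "torch_dtype": "bfloat16",           # one of "float32", "float16", "bfloat16"
    }
}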