JERNGOC committed on
Commit
f4972b5
·
verified ·
1 Parent(s): 9af66f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -38
app.py CHANGED
@@ -44,14 +44,20 @@ if not torch.cuda.is_available():
44
  model = None
45
  tokenizer = None
46
 
47
- if torch.cuda.is_available():
48
- model_id = "apple/OpenELM-3B-Instruct"
49
- model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", trust_remote_code=True, low_cpu_mem_usage=True)
50
- tokenizer_id = "meta-llama/Llama-2-7b-hf"
51
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
52
- if tokenizer.pad_token == None:
53
- tokenizer.pad_token = tokenizer.eos_token
54
- tokenizer.pad_token_id = tokenizer.eos_token_id
 
 
 
 
 
 
55
 
56
  @spaces.GPU
57
  def generate(
@@ -63,36 +69,43 @@ def generate(
63
  top_k: int = 50,
64
  repetition_penalty: float = 1.4,
65
  ) -> Iterator[str]:
66
- global model, tokenizer # Access global variables
67
-
68
- input_ids = tokenizer([message], return_tensors="pt").input_ids
69
- if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
70
- input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
71
- gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
72
- input_ids = input_ids.to(model.device)
73
-
74
- streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
75
- generate_kwargs = dict(
76
- {"input_ids": input_ids},
77
- streamer=streamer,
78
- max_new_tokens=max_new_tokens,
79
- do_sample=True,
80
- top_p=top_p,
81
- top_k=top_k,
82
- temperature=temperature,
83
- num_beams=1,
84
- pad_token_id = tokenizer.eos_token_id,
85
- repetition_penalty=repetition_penalty,
86
- no_repeat_ngram_size=5,
87
- early_stopping=True,
88
- )
89
- t = Thread(target=model.generate, kwargs=generate_kwargs)
90
- t.start()
91
-
92
- outputs = []
93
- for text in streamer:
94
- outputs.append(text)
95
- yield "".join(outputs)
 
 
 
 
 
 
 
96
 
97
  chat_interface = gr.ChatInterface(
98
  fn=generate,
 
44
  model = None
45
  tokenizer = None
46
 
47
+ def initialize_model_and_tokenizer():
48
+ global model, tokenizer
49
+ if torch.cuda.is_available():
50
+ model_id = "apple/OpenELM-3B-Instruct"
51
+ model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", trust_remote_code=True, low_cpu_mem_usage=True)
52
+ tokenizer_id = "meta-llama/Llama-2-7b-hf"
53
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
54
+ if tokenizer.pad_token is None:
55
+ tokenizer.pad_token = tokenizer.eos_token
56
+ tokenizer.pad_token_id = tokenizer.eos_token_id
57
+ else:
58
+ print("CUDA is not available. Model and tokenizer will not be initialized.")
59
+
60
+ initialize_model_and_tokenizer()
61
 
62
  @spaces.GPU
63
  def generate(
 
69
  top_k: int = 50,
70
  repetition_penalty: float = 1.4,
71
  ) -> Iterator[str]:
72
+ global model, tokenizer
73
+
74
+ if tokenizer is None or model is None:
75
+ yield "Error: Model or tokenizer not initialized. Make sure you have GPU support and the necessary model access."
76
+ return
77
+
78
+ try:
79
+ input_ids = tokenizer([message], return_tensors="pt").input_ids
80
+ if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
81
+ input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
82
+ gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
83
+ input_ids = input_ids.to(model.device)
84
+
85
+ streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
86
+ generate_kwargs = dict(
87
+ input_ids=input_ids,
88
+ streamer=streamer,
89
+ max_new_tokens=max_new_tokens,
90
+ do_sample=True,
91
+ top_p=top_p,
92
+ top_k=top_k,
93
+ temperature=temperature,
94
+ num_beams=1,
95
+ pad_token_id=tokenizer.eos_token_id,
96
+ repetition_penalty=repetition_penalty,
97
+ no_repeat_ngram_size=5,
98
+ early_stopping=True,
99
+ )
100
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
101
+ t.start()
102
+
103
+ outputs = []
104
+ for text in streamer:
105
+ outputs.append(text)
106
+ yield "".join(outputs)
107
+ except Exception as e:
108
+ yield f"An error occurred: {str(e)}"
109
 
110
  chat_interface = gr.ChatInterface(
111
  fn=generate,