schuler commited on
Commit
c0252bb
·
verified ·
1 Parent(s): 8f7074e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -13,7 +13,7 @@ def load_model(repo_name):
13
  tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
14
  generator_conf = GenerationConfig.from_pretrained(repo_name)
15
  model = AutoModelForCausalLM.from_pretrained(repo_name, trust_remote_code=True, torch_dtype=torch.bfloat16, attn_implementation="eager")
16
- # model.to('cuda')
17
  return tokenizer, generator_conf, model
18
 
19
  tokenizer, generator_conf, model = load_model(REPO_NAME)
@@ -61,7 +61,8 @@ def respond(
61
  max_new_tokens=max_tokens,
62
  do_sample=True,
63
  top_p=top_p,
64
- repetition_penalty=1.2
 
65
  )
66
 
67
  generated_text = response_output[0]['generated_text']
@@ -101,7 +102,7 @@ demo = gr.ChatInterface(
101
  additional_inputs=[
102
  gr.Textbox(value="" + global_error, label="System message"),
103
  gr.Slider(minimum=1, maximum=2048, value=64, step=1, label="Max new tokens"),
104
- # gr.Slider(minimum=0.1, maximum=4.0, value=1.0, step=0.1, label="Temperature"),
105
  gr.Slider(
106
  minimum=0.1,
107
  maximum=1.0,
 
13
  tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
14
  generator_conf = GenerationConfig.from_pretrained(repo_name)
15
  model = AutoModelForCausalLM.from_pretrained(repo_name, trust_remote_code=True, torch_dtype=torch.bfloat16, attn_implementation="eager")
16
+ model.to('cuda')
17
  return tokenizer, generator_conf, model
18
 
19
  tokenizer, generator_conf, model = load_model(REPO_NAME)
 
61
  max_new_tokens=max_tokens,
62
  do_sample=True,
63
  top_p=top_p,
64
+ repetition_penalty=1.2,
65
+ temperature=temperature
66
  )
67
 
68
  generated_text = response_output[0]['generated_text']
 
102
  additional_inputs=[
103
  gr.Textbox(value="" + global_error, label="System message"),
104
  gr.Slider(minimum=1, maximum=2048, value=64, step=1, label="Max new tokens"),
105
+ gr.Slider(minimum=0.1, maximum=4.0, value=1.0, step=0.1, label="Temperature"),
106
  gr.Slider(
107
  minimum=0.1,
108
  maximum=1.0,