alfredplpl committed on
Commit
44c960c
·
verified ·
1 Parent(s): 377528c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -10
app.py CHANGED
@@ -46,7 +46,7 @@ h1 {
46
  # Load the tokenizer and model
47
  tokenizer = AutoTokenizer.from_pretrained("llm-jp/llm-jp-13b-instruct-full-ac_001_16x-dolly-ichikara_004_001_single-oasst-oasst2-v2.0")
48
  model = AutoModelForCausalLM.from_pretrained("llm-jp/llm-jp-13b-instruct-full-ac_001_16x-dolly-ichikara_004_001_single-oasst-oasst2-v2.0", device_map="auto", torch_dtype=torch.bfloat16)
49
- #model=model.eval()
50
 
51
  @spaces.GPU()
52
  def chat_llm_jp_v2(message: str,
@@ -71,15 +71,14 @@ def chat_llm_jp_v2(message: str,
71
  conversation.append({"role": "user", "content": message})
72
 
73
  tokenized_input = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, tokenize=True, return_tensors="pt").to(model.device)
74
- with torch.no_grad():
75
- output = model.generate(
76
- tokenized_input,
77
- max_new_tokens=max_new_tokens,
78
- do_sample=True,
79
- top_p=0.95,
80
- temperature=temperature,
81
- repetition_penalty=1.05,
82
- )[0]
83
  return tokenizer.decode(output)
84
 
85
 
 
46
  # Load the tokenizer and model
47
  tokenizer = AutoTokenizer.from_pretrained("llm-jp/llm-jp-13b-instruct-full-ac_001_16x-dolly-ichikara_004_001_single-oasst-oasst2-v2.0")
48
  model = AutoModelForCausalLM.from_pretrained("llm-jp/llm-jp-13b-instruct-full-ac_001_16x-dolly-ichikara_004_001_single-oasst-oasst2-v2.0", device_map="auto", torch_dtype=torch.bfloat16)
49
+ model=model.eval()
50
 
51
  @spaces.GPU()
52
  def chat_llm_jp_v2(message: str,
 
71
  conversation.append({"role": "user", "content": message})
72
 
73
  tokenized_input = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, tokenize=True, return_tensors="pt").to(model.device)
74
+ output = model.generate(
75
+ tokenized_input,
76
+ max_new_tokens=max_new_tokens,
77
+ do_sample=True,
78
+ top_p=0.95,
79
+ temperature=temperature,
80
+ repetition_penalty=1.05,
81
+ )[0]
 
82
  return tokenizer.decode(output)
83
 
84