schuler committed on
Commit
ff32de9
·
verified ·
1 Parent(s): 8aee95a

Update app.py

Browse files

local_generate

Files changed (1) hide show
  1. app.py +26 -6
app.py CHANGED
@@ -30,6 +30,29 @@ except Exception as e:
30
  global_error = f"Failed to load model: {str(e)}"
31
 
32
  @spaces.GPU()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  def respond(
34
  message,
35
  history: list[tuple[str, str]],
@@ -82,18 +105,15 @@ def respond(
82
  full_result = ''
83
  while ( (tokens_cnt < max_tokens) and (last_token_len > 0) ):
84
  # Generate the response
85
- response_output = generator(
86
  prompt,
87
  generation_config=generator_conf,
88
- max_new_tokens=tokens_inc,
89
  do_sample=True,
90
  top_p=top_p,
91
  repetition_penalty=1.2,
92
  temperature=temperature
93
- )
94
- generated_text = response_output[0]['generated_text']
95
- # Extract the assistant's response
96
- result = generated_text[len(prompt):]
97
  full_result = full_result + result
98
  prompt = prompt + result
99
  tokens_cnt = tokens_cnt + tokens_inc
 
30
  global_error = f"Failed to load model: {str(e)}"
31
 
32
  @spaces.GPU()
33
+ def local_generate(
34
+ prompt,
35
+ generation_config,
36
+ max_new_tokens,
37
+ do_sample,
38
+ top_p,
39
+ repetition_penalty,
40
+ temperature=temperature
41
+ ):
42
+ response_output = generator(
43
+ prompt,
44
+ generation_config=generation_config,
45
+ max_new_tokens=max_new_tokens,
46
+ do_sample=do_sample,
47
+ top_p=top_p,
48
+ repetition_penalty=repetition_penalty,
49
+ temperature=temperature
50
+ )
51
+ generated_text = response_output[0]['generated_text']
52
+ # Extract the assistant's response
53
+ result = generated_text[len(prompt):]
54
+ return result
55
+
56
  def respond(
57
  message,
58
  history: list[tuple[str, str]],
 
105
  full_result = ''
106
  while ( (tokens_cnt < max_tokens) and (last_token_len > 0) ):
107
  # Generate the response
108
+ result = local_generate(
109
  prompt,
110
  generation_config=generator_conf,
111
+ max_new_tokens=max_tokens,
112
  do_sample=True,
113
  top_p=top_p,
114
  repetition_penalty=1.2,
115
  temperature=temperature
116
+ )
 
 
 
117
  full_result = full_result + result
118
  prompt = prompt + result
119
  tokens_cnt = tokens_cnt + tokens_inc