kz209 commited on
Commit
9dfac6e
1 Parent(s): 34ffea3
Files changed (2) hide show
  1. pages/arena.py +1 -1
  2. utils/model.py +2 -2
pages/arena.py CHANGED
@@ -22,7 +22,7 @@ def create_arena():
22
  submit_button = gr.Button("✨ Submit ✨")
23
 
24
  with gr.Row():
25
- columns = [gr.Textbox(label=f"Column {i+1}", lines=10) for i in range(len(prompts))]
26
 
27
  content_list = [prompt + '\n{' + datapoint + '}\n\nsummary:' for prompt in prompts]
28
  model = get_model_batch_generation("Qwen/Qwen2-1.5B-Instruct")
 
22
  submit_button = gr.Button("✨ Submit ✨")
23
 
24
  with gr.Row():
25
+ columns = [gr.Textbox(label=f"Prompt {i+1}", lines=10) for i in range(len(prompts))]
26
 
27
  content_list = [prompt + '\n{' + datapoint + '}\n\nsummary:' for prompt in prompts]
28
  model = get_model_batch_generation("Qwen/Qwen2-1.5B-Instruct")
utils/model.py CHANGED
@@ -55,7 +55,7 @@ class Model(torch.nn.Module):
55
  def return_model(self):
56
  return self.pipeline
57
 
58
- def gen(self, content_list, temp=0.1, max_length=500, streaming=False):
59
  # Convert list of texts to input IDs
60
  input_ids = self.tokenizer(content_list, return_tensors="pt", padding=True, truncation=True).input_ids.to(self.model.device)
61
 
@@ -74,7 +74,7 @@ class Model(torch.nn.Module):
74
  return_dict_in_generate=True,
75
  output_scores=True,
76
  streamer=streamer):
77
- pass # TextStreamer automatically handles the streaming, no need to manually handle the output
78
  else:
79
  outputs = self.model.generate(
80
  input_ids,
 
55
  def return_model(self):
56
  return self.pipeline
57
 
58
+ def gen(self, content_list, temp=0.001, max_length=500, streaming=False):
59
  # Convert list of texts to input IDs
60
  input_ids = self.tokenizer(content_list, return_tensors="pt", padding=True, truncation=True).input_ids.to(self.model.device)
61
 
 
74
  return_dict_in_generate=True,
75
  output_scores=True,
76
  streamer=streamer):
77
+ yield output # TextStreamer automatically handles the streaming, no need to manually handle the output
78
  else:
79
  outputs = self.model.generate(
80
  input_ids,