archit11 committed
Commit
fb8f6e3
Parent: 22ff5cb

Update app.py

Files changed (1): app.py (+15, -2)
app.py CHANGED
```diff
@@ -65,22 +65,24 @@ def log_comparison(model1_name: str, model2_name: str, question: str, answer1: s
     except requests.RequestException as e:
         print(f"Error sending log to server: {e}")
 
-# Function to prepare input
 def prepare_input(model_id: str, message: str, chat_history: List[Tuple[str, str]]):
     tokenizer = tokenizers[model_id]
     try:
+        # Prepare inputs for the model
         inputs = tokenizer(
             [x[1] for x in chat_history] + [message],
             return_tensors="pt",
             truncation=True,
             padding=True,
             max_length=MAX_INPUT_TOKEN_LENGTH,
+            return_attention_mask=True,  # Include the attention_mask
         )
     except Exception as e:
         print(f"Error preparing input for model {model_id}: {e}")
-        inputs = tokenizer([message], return_tensors="pt", padding=True, max_length=MAX_INPUT_TOKEN_LENGTH)
+        inputs = tokenizer([message], return_tensors="pt", padding=True, max_length=MAX_INPUT_TOKEN_LENGTH, return_attention_mask=True)
     return inputs
 
+
 # Function to generate responses from models
 @spaces.GPU(duration=120)
 def generate(
@@ -96,16 +98,26 @@ def generate(
 
     inputs = prepare_input(model_id, message, chat_history)
     input_ids = inputs.input_ids
+    attention_mask = inputs.attention_mask  # Get attention_mask
 
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        attention_mask = attention_mask[:, -MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+
+    # Ensure batch size is 1
+    if input_ids.shape[0] != 1:
+        input_ids = input_ids[:1]
+        attention_mask = attention_mask[:1]
+
     input_ids = input_ids.to(model.device)
+    attention_mask = attention_mask.to(model.device)  # Move to the same device as input_ids
 
     try:
         streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
         generate_kwargs = dict(
             input_ids=input_ids,
+            attention_mask=attention_mask,  # Pass the attention_mask
             streamer=streamer,
             max_new_tokens=max_new_tokens,
             do_sample=True,
@@ -125,6 +137,7 @@ def generate(
         print(f"Error generating response from model {model_id}: {e}")
         yield "Error generating response."
 
+
 # Function to compare two models
 def compare_models(
     model1_name: str,
```
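For context on the change: `tokenizer(..., return_attention_mask=True)` returns a mask of 1s for real tokens and 0s for padding, and passing that mask to `model.generate` keeps padded positions out of attention and avoids the transformers warning about inferring the mask from the pad token. Below is a minimal, self-contained sketch of the pattern this commit threads through `prepare_input` and `generate`; the `gpt2` checkpoint is a stand-in, since the Space loads its own models into its `models`/`tokenizers` registries.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Stand-in checkpoint; the Space uses its own model/tokenizer registries.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

inputs = tokenizer(
    ["Hello there"],
    return_tensors="pt",
    padding=True,
    return_attention_mask=True,  # mask: 1 for real tokens, 0 for padding
)
output_ids = model.generate(
    input_ids=inputs.input_ids.to(model.device),
    attention_mask=inputs.attention_mask.to(model.device),  # as in the diff above
    max_new_tokens=20,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

Note that whenever `input_ids` is sliced, the mask must be sliced identically, which is why the diff trims both tensors to the last `MAX_INPUT_TOKEN_LENGTH` positions together; the added batch-size guard similarly slices both to a single row so the streamer only ever decodes one sequence.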