karths committed on
Commit
dd28b0b
·
verified ·
1 Parent(s): 34063b9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -10,6 +10,11 @@ import logging
10
  import spaces
11
  from threading import Thread
12
  from collections.abc import Iterator
 
 
 
 
 
13
 
14
  # Setup logging
15
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
@@ -101,15 +106,16 @@ def llama_generate(
101
  repetition_penalty: float = 1.2,
102
  ) -> Iterator[str]:
103
 
104
- input_ids = llama_tokenizer.encode(message, return_tensors="pt").to(llama_model.device)
 
105
 
106
- if input_ids.shape[1] > LLAMA_MAX_INPUT_TOKEN_LENGTH:
107
- input_ids = input_ids[:, -LLAMA_MAX_INPUT_TOKEN_LENGTH:]
108
  gr.Warning(f"Trimmed input from conversation as it was longer than {LLAMA_MAX_INPUT_TOKEN_LENGTH} tokens.")
109
 
110
  streamer = TextIteratorStreamer(llama_tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
111
  generate_kwargs = dict(
112
- {"input_ids": input_ids},
113
  streamer=streamer,
114
  max_new_tokens=max_new_tokens,
115
  do_sample=True,
 
10
  import spaces
11
  from threading import Thread
12
  from collections.abc import Iterator
13
+ import csv
14
+
15
+ # Increase CSV field size limit
16
+ csv.field_size_limit(1000000) # Or an even larger value if needed
17
+
18
 
19
  # Setup logging
20
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
 
106
  repetition_penalty: float = 1.2,
107
  ) -> Iterator[str]:
108
 
109
+ inputs = llama_tokenizer(message, return_tensors="pt", padding=True, truncation=True, max_length=LLAMA_MAX_INPUT_TOKEN_LENGTH).to(llama_model.device)
110
+ #The line above was changed to add attention mask
111
 
112
+ if inputs.input_ids.shape[1] > LLAMA_MAX_INPUT_TOKEN_LENGTH:
113
+ inputs.input_ids = inputs.input_ids[:, -LLAMA_MAX_INPUT_TOKEN_LENGTH:]
114
  gr.Warning(f"Trimmed input from conversation as it was longer than {LLAMA_MAX_INPUT_TOKEN_LENGTH} tokens.")
115
 
116
  streamer = TextIteratorStreamer(llama_tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
117
  generate_kwargs = dict(
118
+ inputs, # Pass the entire inputs dictionary
119
  streamer=streamer,
120
  max_new_tokens=max_new_tokens,
121
  do_sample=True,