Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -10,6 +10,11 @@ import logging
|
|
10 |
import spaces
|
11 |
from threading import Thread
|
12 |
from collections.abc import Iterator
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
# Setup logging
|
15 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
|
@@ -101,15 +106,16 @@ def llama_generate(
|
|
101 |
repetition_penalty: float = 1.2,
|
102 |
) -> Iterator[str]:
|
103 |
|
104 |
-
|
|
|
105 |
|
106 |
-
if input_ids.shape[1] > LLAMA_MAX_INPUT_TOKEN_LENGTH:
|
107 |
-
input_ids = input_ids[:, -LLAMA_MAX_INPUT_TOKEN_LENGTH:]
|
108 |
gr.Warning(f"Trimmed input from conversation as it was longer than {LLAMA_MAX_INPUT_TOKEN_LENGTH} tokens.")
|
109 |
|
110 |
streamer = TextIteratorStreamer(llama_tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
111 |
generate_kwargs = dict(
|
112 |
-
|
113 |
streamer=streamer,
|
114 |
max_new_tokens=max_new_tokens,
|
115 |
do_sample=True,
|
|
|
10 |
import spaces
|
11 |
from threading import Thread
|
12 |
from collections.abc import Iterator
|
13 |
+
import csv
|
14 |
+
|
15 |
+
# Increase CSV field size limit
|
16 |
+
csv.field_size_limit(1000000) # Or an even larger value if needed
|
17 |
+
|
18 |
|
19 |
# Setup logging
|
20 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
|
|
|
106 |
repetition_penalty: float = 1.2,
|
107 |
) -> Iterator[str]:
|
108 |
|
109 |
+
inputs = llama_tokenizer(message, return_tensors="pt", padding=True, truncation=True, max_length=LLAMA_MAX_INPUT_TOKEN_LENGTH).to(llama_model.device)
|
110 |
+
# The line above was changed to add an attention mask.
|
111 |
|
112 |
+
if inputs.input_ids.shape[1] > LLAMA_MAX_INPUT_TOKEN_LENGTH:
|
113 |
+
inputs.input_ids = inputs.input_ids[:, -LLAMA_MAX_INPUT_TOKEN_LENGTH:]
|
114 |
gr.Warning(f"Trimmed input from conversation as it was longer than {LLAMA_MAX_INPUT_TOKEN_LENGTH} tokens.")
|
115 |
|
116 |
streamer = TextIteratorStreamer(llama_tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
117 |
generate_kwargs = dict(
|
118 |
+
inputs, # Pass the entire inputs dictionary
|
119 |
streamer=streamer,
|
120 |
max_new_tokens=max_new_tokens,
|
121 |
do_sample=True,
|