Spaces:
Runtime error
Runtime error
project-baize
commited on
Commit
•
9d3530f
1
Parent(s):
7e2a3f2
Update app.py
Browse files
app.py
CHANGED
@@ -17,8 +17,8 @@ base_model = "decapoda-research/llama-7b-hf"
|
|
17 |
adapter_model = "project-baize/baize-lora-7B"
|
18 |
tokenizer,model,device = load_tokenizer_and_model(base_model,adapter_model)
|
19 |
|
20 |
-
global
|
21 |
-
|
22 |
def predict(text,
|
23 |
chatbot,
|
24 |
history,
|
@@ -44,8 +44,8 @@ def predict(text,
|
|
44 |
begin_length = len(prompt)
|
45 |
torch.cuda.empty_cache()
|
46 |
input_ids = inputs["input_ids"].to(device)
|
47 |
-
|
48 |
-
print(
|
49 |
with torch.no_grad():
|
50 |
for x in greedy_search(input_ids,model,tokenizer,stop_words=["[|Human|]", "[|AI|]"],max_length=max_length_tokens,temperature=temperature,top_p=top_p):
|
51 |
if is_stop_word_or_prefix(x,["[|Human|]", "[|AI|]"]) is False:
|
|
|
17 |
adapter_model = "project-baize/baize-lora-7B"
|
18 |
tokenizer,model,device = load_tokenizer_and_model(base_model,adapter_model)
|
19 |
|
20 |
+
global total_count
|
21 |
+
total_count = 0
|
22 |
def predict(text,
|
23 |
chatbot,
|
24 |
history,
|
|
|
44 |
begin_length = len(prompt)
|
45 |
torch.cuda.empty_cache()
|
46 |
input_ids = inputs["input_ids"].to(device)
|
47 |
+
total_count += 1
|
48 |
+
print(total_count)
|
49 |
with torch.no_grad():
|
50 |
for x in greedy_search(input_ids,model,tokenizer,stop_words=["[|Human|]", "[|AI|]"],max_length=max_length_tokens,temperature=temperature,top_p=top_p):
|
51 |
if is_stop_word_or_prefix(x,["[|Human|]", "[|AI|]"]) is False:
|