Update app.py
app.py
CHANGED
@@ -336,8 +336,8 @@ def answer_query(query_text, index, documents, llm_model, llm_tokenizer, embeddi
     retrieved_info = get_retrieved_info(documents, I, D)
     formatted_info = format_retrieved_info(retrieved_info)
     prompt = generate_prompt(query_text, formatted_info)
-    answer = answer_using_gemma(prompt, llm_model, llm_tokenizer)
-    return
+    # answer = answer_using_gemma(prompt, llm_model, llm_tokenizer)
+    return prompt
 
 
 
@@ -393,7 +393,30 @@ model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", token=HF_TOKE
 
 
 def make_inference(query, hist):
-
+    prompt = answer_query(query, index, documents, model, tokenizer, CFG.embedding_model, CFG.n_samples, CFG.device)
+    # answer = answer_using_gemma(prompt, llm_model, llm_tokenizer)
+    model_inputs = tokenizer(prompt, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
+    count_tokens = lambda text: len(tokenizer.tokenize(text))
+
+    streamer = TextIteratorStreamer(tokenizer, timeout=540., skip_prompt=True, skip_special_tokens=True)
+
+    generate_kwargs = dict(
+        model_inputs,
+        streamer=streamer,
+        max_new_tokens=6000 - count_tokens(prompt),
+        top_p=0.2,
+        top_k=20,
+        temperature=0.1,
+        repetition_penalty=2.0,
+        length_penalty=-0.5,
+        num_beams=1
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()  # Starting the generation in a separate thread.
+    partial_message = ""
+    for new_token in streamer:
+        partial_message += new_token
+        yield partial_message
 
 demo = gr.ChatInterface(fn = make_inference,
                         examples = ["What is diabetes?", "Is ginseng good for diabetes?", "What are the symptoms of diabetes?", "What is Celiac disease?"],
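For reference, the new make_inference body follows the standard Transformers streaming pattern: model.generate blocks until generation finishes, so it runs in a background Thread that pushes decoded tokens into a TextIteratorStreamer, while the Gradio callback iterates over the streamer and yields the partially built reply so gr.ChatInterface can redraw the message as it grows. Below is a minimal, self-contained sketch of that pattern; the model name and generation settings are illustrative rather than the Space's exact values, and stream_reply is a hypothetical stand-in for make_inference.

from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "google/gemma-2b-it"  # illustrative; any causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

def stream_reply(prompt):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # generate() blocks until the whole sequence is produced, so it runs in a
    # worker thread; the streamer is the bridge back to this generator.
    Thread(target=model.generate,
           kwargs=dict(inputs, streamer=streamer, max_new_tokens=256)).start()
    partial = ""
    for token_text in streamer:
        partial += token_text
        yield partial  # ChatInterface re-renders the growing message on each yield

Two details worth noting: the tokenizer's output is a dict-like BatchEncoding, so dict(model_inputs, streamer=..., ...) merges the input tensors with the generation options and the worker thread ends up calling model.generate(input_ids=..., attention_mask=..., streamer=..., ...); and because make_inference is a generator, gr.ChatInterface streams each yielded partial_message to the chat window instead of waiting for the complete answer.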