# lab2/recipe_lora.py — Gradio recipe-suggestion demo (Hugging Face Space)
import gradio as gr
from transformers import TextIteratorStreamer
from threading import Thread
from unsloth import FastLanguageModel
# Load both checkpoints in 4-bit quantization to reduce GPU memory usage.
load_in_4bit = True
# Hub IDs of the two fine-tuned checkpoints the user can choose between.
recipe_model_id = "ID2223JR/recipe_model"
lora_model_id = "ID2223JR/lora_model"
# Base recipe model + tokenizer.
recipe_model, recipe_tokenizer = FastLanguageModel.from_pretrained(
    model_name=recipe_model_id,
    load_in_4bit=load_in_4bit,
)
# Switch to inference mode (Unsloth's optimized generation path).
FastLanguageModel.for_inference(recipe_model)
# LoRA-adapted model + tokenizer.
lora_model, lora_tokenizer = FastLanguageModel.from_pretrained(
    model_name=lora_model_id,
    load_in_4bit=load_in_4bit,
)
FastLanguageModel.for_inference(lora_model)
# Module-level accumulator of "name, N grams" strings, shared by the handlers below.
ingredients_list = []
def add_ingredient(ingredient, quantity):
    """Append an ingredient entry and reset the input widgets.

    Args:
        ingredient: Ingredient name from the textbox (may be empty).
        quantity: Amount in grams from the number input (may be ``None``,
            the widget's initial value).

    Returns:
        Tuple of (ingredients text, textbox update, number update) — always
        all three, since the Gradio click handler is wired to 3 outputs.
    """
    # Guard quantity against None before int(): gr.Number starts at None
    # and int(None) raises TypeError.
    if ingredient and quantity is not None and int(quantity) > 0:
        ingredients_list.append(f"{ingredient}, {quantity} grams")
    # Always return the full 3-tuple; the original returned None on the
    # invalid path, which makes Gradio error ("expected 3 outputs").
    return (
        "\n".join(ingredients_list),
        gr.update(value="", interactive=True),
        gr.update(value=None, interactive=True),
    )
def validate_inputs(ingredient, quantity):
    """Enable the Add button only when both inputs are valid.

    Runs on every change of either widget, so ``quantity`` is frequently
    ``None`` (the Number widget's initial value); the original raised
    ``TypeError`` from ``int(None)`` in that state. Treat any
    non-numeric/missing quantity as invalid instead of crashing.
    """
    try:
        valid = bool(ingredient) and quantity is not None and int(quantity) > 0
    except (TypeError, ValueError):
        valid = False
    return gr.update(interactive=valid)
def submit_to_model(selected_model):
    """Stream a recipe suggestion for the collected ingredients.

    Generator that yields the progressively generated text so the Gradio
    output textbox updates live.

    Args:
        selected_model: Dropdown value, "Recipe Model" or "LoRA Model".

    Yields:
        The accumulated generated text (or an error message).
    """
    # NOTE: error messages must be *yielded*, not returned — inside a
    # generator a plain `return value` only terminates iteration and the
    # message would never reach the UI (bug in the original).
    if not ingredients_list:
        yield "Ingredients list is empty! Please add ingredients first."
        return
    if selected_model == "Recipe Model":
        model, tokenizer = recipe_model, recipe_tokenizer
    elif selected_model == "LoRA Model":
        model, tokenizer = lora_model, lora_tokenizer
    else:
        yield "Invalid model selected!"
        return
    prompt = "Using the following ingredients, suggest a recipe:\n\n" + "\n".join(
        ingredients_list
    )
    # One-shot consumption: the next submit needs freshly added ingredients.
    ingredients_list.clear()
    messages = [
        {"role": "user", "content": prompt},
    ]
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,  # Must add for generation
        return_tensors="pt",
    ).to("cuda")
    # Generate in a background thread; the streamer lets us yield tokens
    # to the UI as they are produced.
    text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    generation_kwargs = dict(
        inputs=inputs,
        streamer=text_streamer,
        use_cache=True,
        temperature=0.3,
        min_p=0.1,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    content = ""
    for text in text_streamer:
        content += text
        # Strip the end-of-turn marker once the model emits it.
        if content.endswith("<|eot_id|>"):
            content = content.replace("<|eot_id|>", "")
        yield content
    # Make sure the generation thread has fully finished before exiting.
    thread.join()
def app():
    """Assemble and return the Gradio Blocks UI for the recipe demo."""
    with gr.Blocks() as demo:
        with gr.Row():
            ingredient_box = gr.Textbox(
                label="Ingredient", placeholder="Enter ingredient name"
            )
            quantity_box = gr.Number(label="Quantity (grams)", value=None)
            add_btn = gr.Button("Add Ingredient", interactive=False)
        ingredients_display = gr.Textbox(
            label="Ingredients List", lines=10, interactive=False
        )
        with gr.Row():
            submit_btn = gr.Button("Give me a meal!")
            model_choice = gr.Dropdown(
                label="Choose Model",
                choices=["Recipe Model", "LoRA Model"],
                value="Recipe Model",
            )
        with gr.Row():
            recipe_box = gr.Textbox(
                label="Recipe Suggestion", lines=10, interactive=False
            )
        # Re-check button enablement whenever either input changes.
        for widget in (ingredient_box, quantity_box):
            widget.change(
                validate_inputs, [ingredient_box, quantity_box], add_btn
            )
        add_btn.click(
            add_ingredient,
            [ingredient_box, quantity_box],
            [ingredients_display, ingredient_box, quantity_box],
        )
        submit_btn.click(
            submit_to_model,
            inputs=model_choice,
            outputs=recipe_box,
        )
    return demo
# Build the UI and start the Gradio server at import/run time.
demo = app()
demo.launch()