import gradio as gr
from transformers import TextIteratorStreamer
from threading import Thread
from unsloth import FastLanguageModel
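
# Gradio app: collect ingredients from the user and stream a recipe
# suggestion from one of two models loaded with Unsloth.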

load_in_4bit = True

recipe_model_id = "ID2223JR/recipe_model"
lora_model_id = "ID2223JR/lora_model"
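
# Load both model checkpoints in 4-bit and switch them to inference mode.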
recipe_model, recipe_tokenizer = FastLanguageModel.from_pretrained(
    model_name=recipe_model_id,
    load_in_4bit=load_in_4bit,
)
FastLanguageModel.for_inference(recipe_model)

lora_model, lora_tokenizer = FastLanguageModel.from_pretrained(
    model_name=lora_model_id,
    load_in_4bit=load_in_4bit,
)
FastLanguageModel.for_inference(lora_model)
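
# Ingredients added so far. This is module-level state, so it is shared by
# every visitor of the running app.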
ingredients_list = []


def add_ingredient(ingredient, quantity):
    # Append a valid ingredient, then clear the input widgets for the next one.
    if ingredient and quantity is not None and int(quantity) > 0:
        ingredients_list.append(f"{ingredient}, {quantity} grams")
    return (
        "\n".join(ingredients_list),
        gr.update(value="", interactive=True),
        gr.update(value=None, interactive=True),
    )


def validate_inputs(ingredient, quantity):
    # Enable the "Add Ingredient" button only when both fields are filled in.
    # quantity is None until a number is typed, so guard the int() conversion
    # to avoid a TypeError on the first change event.
    if ingredient and quantity is not None and int(quantity) > 0:
        return gr.update(interactive=True)
    return gr.update(interactive=False)


def submit_to_model(selected_model):
    # Generator handler: Gradio streams each yielded string into the output box.
    if not ingredients_list:
        # yield (not return) so the message actually reaches the UI.
        yield "Ingredients list is empty! Please add ingredients first."
        return

    if selected_model == "Recipe Model":
        model = recipe_model
        tokenizer = recipe_tokenizer
    elif selected_model == "LoRA Model":
        model = lora_model
        tokenizer = lora_tokenizer
    else:
        yield "Invalid model selected!"
        return

    prompt = "Using the following ingredients, suggest a recipe:\n\n" + "\n".join(
        ingredients_list
    )
    ingredients_list.clear()

    messages = [
        {"role": "user", "content": prompt},
    ]
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to("cuda")  # assumes a CUDA GPU is available

    # Stream tokens as they are produced; skip_prompt drops the echoed input.
    text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)

    generation_kwargs = dict(
        inputs=inputs,
        streamer=text_streamer,
        max_new_tokens=512,  # upper bound on the reply length; adjust as needed
        use_cache=True,
        do_sample=True,  # sampling must be enabled for temperature/min_p to apply
        temperature=0.3,
        min_p=0.1,
    )
    # Run generation in a background thread so the streamer can be consumed here.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    content = ""
    for text in text_streamer:
        print(text)
        content += text
        # Strip the end-of-turn token once it appears at the end of the output.
        if content.endswith("<|eot_id|>"):
            content = content.replace("<|eot_id|>", "")
        yield content
    thread.join()
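

# Build the Gradio UI: ingredient entry, the running ingredient list, a model
# picker, and the streamed recipe output.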
def app():
    with gr.Blocks() as demo:
        with gr.Row():
            ingredient_input = gr.Textbox(
                label="Ingredient", placeholder="Enter ingredient name"
            )
            quantity_input = gr.Number(label="Quantity (grams)", value=None)

        add_button = gr.Button("Add Ingredient", interactive=False)
        output = gr.Textbox(label="Ingredients List", lines=10, interactive=False)
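
        # Model picker, submit button, and the streamed recipe output.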
        with gr.Row():
            submit_button = gr.Button("Give me a meal!")
            model_dropdown = gr.Dropdown(
                label="Choose Model",
                choices=["Recipe Model", "LoRA Model"],
                value="Recipe Model",
            )

        with gr.Row():
            model_output = gr.Textbox(
                label="Recipe Suggestion", lines=10, interactive=False
            )
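
        # Wire up the events: validate on every change, add on click, and
        # stream the recipe when the submit button is pressed.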
        ingredient_input.change(
            validate_inputs, [ingredient_input, quantity_input], add_button
        )
        quantity_input.change(
            validate_inputs, [ingredient_input, quantity_input], add_button
        )

        add_button.click(
            add_ingredient,
            [ingredient_input, quantity_input],
            [output, ingredient_input, quantity_input],
        )

        submit_button.click(
            submit_to_model,
            inputs=model_dropdown,
            outputs=model_output,
        )

    return demo
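

# Build the interface and start the Gradio server.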
demo = app()
demo.launch()