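"""Prompt Enhancing Playground.

A small Gradio app for comparing prompt-enhancing adapters fine-tuned on top of
several instruct base models (see `choices_base_models`). The selected adapter
is loaded onto its base model and rewrites the user's short prompt into an
enhanced one.
"""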
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM


choices_base_models = {
    'groloch/Llama-3.2-3B-Instruct-PromptEnhancing': 'meta-llama/Llama-3.2-3B-Instruct',
    'groloch/gemma-2-2b-it-PromptEnhancing': 'google/gemma-2-2b-it',
    'groloch/Qwen2.5-3B-Instruct-PromptEnhancing': 'Qwen/Qwen2.5-3B-Instruct',
    'groloch/Ministral-3b-instruct-PromptEnhancing': 'ministral/Ministral-3b-instruct'
}
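
# Role marker that precedes the assistant turn in each base model's chat
# template; used at the end of generate() to slice the reply out of the
# decoded output.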
choices_gen_token = {
    'groloch/Llama-3.2-3B-Instruct-PromptEnhancing': 'assistant',
    'groloch/gemma-2-2b-it-PromptEnhancing': 'model',
    'groloch/Qwen2.5-3B-Instruct-PromptEnhancing': 'assistant',
    # Assumed marker: the original entry repeated the base repo id, which never
    # appears in the decoded text. Adjust if the Ministral chat template differs.
    'groloch/Ministral-3b-instruct-PromptEnhancing': 'assistant'
}
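
# The model and tokenizer are loaded lazily and replaced only when the
# dropdown selection changes.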
previous_choice = ''

model = None
tokenizer = None
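

# (Re)load the tokenizer and base model for the chosen adapter, then attach
# the adapter weights (requires `peft` to be installed).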
def load_model(adapter_repo_id: str):
    global model, tokenizer
    base_repo_id = choices_base_models[adapter_repo_id]

    tokenizer = AutoTokenizer.from_pretrained(base_repo_id)
    model = AutoModelForCausalLM.from_pretrained(base_repo_id, torch_dtype=torch.bfloat16)

    model.load_adapter(adapter_repo_id)
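

# Gradio callback: enhance the prompt with the selected adapter, using the
# sampling settings from the UI.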
def generate(prompt_to_enhance: str,
             choice: str,
             max_tokens: float,
             temperature: float,
             top_p: float,
             repetition_penalty: float
             ):
    if prompt_to_enhance is None or prompt_to_enhance == '':
        raise gr.Error('Please enter a prompt')
    global previous_choice

    if choice != previous_choice:
        previous_choice = choice
        load_model(choice)
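
    # Wrap the raw prompt in a single-turn chat and render it with the base
    # model's chat template.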
    chat = [
        {'role': 'user', 'content': prompt_to_enhance}
    ]

    # return_tensors is dropped here: apply_chat_template returns plain text
    # when tokenize=False, and the prompt is tokenized explicitly below.
    prompt = tokenizer.apply_chat_template(chat,
                                           tokenize=False,
                                           add_generation_prompt=True)

    encoding = tokenizer(prompt, return_tensors="pt")
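
    # Start from the model's default generation config and override the
    # sampling settings exposed in the UI.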
    generation_config = model.generation_config
    generation_config.do_sample = True
    generation_config.max_new_tokens = int(max_tokens)
    generation_config.temperature = float(temperature)
    generation_config.top_p = float(top_p)
    generation_config.num_return_sequences = 1
    generation_config.pad_token_id = tokenizer.eos_token_id
    generation_config.eos_token_id = tokenizer.eos_token_id
    generation_config.repetition_penalty = float(repetition_penalty)
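
    # Generate without gradient tracking.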
    with torch.inference_mode():
        outputs = model.generate(
            input_ids=encoding.input_ids,
            attention_mask=encoding.attention_mask,
            generation_config=generation_config
        )
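
    # Decode the full sequence and keep only the text after the assistant role
    # marker, i.e. the enhanced prompt.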
    return tokenizer.decode(outputs[0], skip_special_tokens=True).split(choices_gen_token[choice])[-1]
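

# UI components.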
model_choice = gr.Dropdown(
    label='Model choice',
    choices=['groloch/Llama-3.2-3B-Instruct-PromptEnhancing',
             'groloch/gemma-2-2b-it-PromptEnhancing',
             'groloch/Qwen2.5-3B-Instruct-PromptEnhancing',
             'groloch/Ministral-3b-instruct-PromptEnhancing'
             ],
    value='groloch/Llama-3.2-3B-Instruct-PromptEnhancing'
)
input_prompt = gr.Text(
    label='Prompt to enhance'
)
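
# Sampling controls, passed to the interface as additional inputs.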
input_max_tokens = gr.Number(
    label='Max generated tokens',
    value=64,
    minimum=16,
    maximum=128
)
input_temperature = gr.Number(
    label='Temperature',
    value=0.3,
    minimum=0.0,
    maximum=1.5,
    step=0.05
)
input_top_p = gr.Number(
    label='Top p',
    value=0.9,
    minimum=0.0,
    maximum=1.0,
    step=0.05
)
input_repetition_penalty = gr.Number(
    label='Repetition penalty',
    value=2.0,
    minimum=0.0,
    maximum=5.0,
    step=0.1
)
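
# Assemble the app: generate() receives the prompt and model choice first,
# followed by the additional sampling inputs, in that order.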
demo = gr.Interface(
    generate,
    title='Prompt Enhancing Playground',
    description='This space is a tool to compare the different prompt enhancing models I have fine-tuned. \
                 Feel free to experiment as you want!',
    inputs=[input_prompt, model_choice],
    additional_inputs=[input_max_tokens,
                       input_temperature,
                       input_top_p,
                       input_repetition_penalty
                       ],
    outputs=['text']
)


if __name__ == "__main__":
    demo.launch()