# Hugging Face Space by groloch
# Commit 21a5563: "Added app" (4.23 kB)
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
# Adapter repo id -> base model repo id. load_model() loads the base model from
# this repo and then attaches the adapter on top of it.
choices_base_models = {
    'groloch/Llama-3.2-3B-Instruct-PromptEnhancing': 'meta-llama/Llama-3.2-3B-Instruct',
    'groloch/gemma-2-2b-it-PromptEnhancing': 'google/gemma-2-2b-it',
    'groloch/Qwen2.5-3B-Instruct-PromptEnhancing': 'Qwen/Qwen2.5-3B-Instruct',
    'groloch/Ministral-3b-instruct-PromptEnhancing': 'ministral/Ministral-3b-instruct',
}

# Adapter repo id -> text marker that precedes the model's reply in the decoded
# chat transcript (the reply is whatever follows the marker's last occurrence).
choices_gen_token = {
    'groloch/Llama-3.2-3B-Instruct-PromptEnhancing': 'assistant',
    'groloch/gemma-2-2b-it-PromptEnhancing': 'model',
    'groloch/Qwen2.5-3B-Instruct-PromptEnhancing': 'assistant',
    # Previously the base repo id (a copy/paste slip); that string never occurs in
    # the generated transcript, so the whole transcript, prompt included, was returned.
    # NOTE(review): assumes Ministral's chat template is Mistral-style
    # "[INST] ... [/INST]" with the tags surviving decoding — confirm.
    'groloch/Ministral-3b-instruct-PromptEnhancing': '[/INST]',
}
# Adapter repo id whose weights are currently in memory ('' before the first request).
previous_choice = ''
# Populated by load_model(); both remain None until a model is first requested.
model = tokenizer = None
def load_model(adapter_repo_id: str):
    """Load the base model for ``adapter_repo_id`` and attach the adapter to it.

    The tokenizer and the bfloat16 causal-LM of the base repo listed in
    ``choices_base_models`` replace the module-level ``tokenizer`` and ``model``.

    Args:
        adapter_repo_id: Adapter repo id; must be a key of ``choices_base_models``.

    Raises:
        KeyError: If ``adapter_repo_id`` is not a key of ``choices_base_models``
            (raised before the currently loaded model is touched).
    """
    global model, tokenizer
    base_repo_id = choices_base_models[adapter_repo_id]

    # Drop the references to the previous pair first so its weights can be freed
    # before the next model is allocated. If anything below raises, both globals
    # stay None rather than pairing a new tokenizer with the old model.
    model = None
    tokenizer = None

    new_tokenizer = AutoTokenizer.from_pretrained(base_repo_id)
    new_model = AutoModelForCausalLM.from_pretrained(base_repo_id, torch_dtype=torch.bfloat16)
    new_model.load_adapter(adapter_repo_id)

    # Publish tokenizer and model together, only after everything loaded.
    tokenizer, model = new_tokenizer, new_model
def generate(prompt_to_enhance: str,
             choice: str,
             max_tokens: float,
             temperature: float,
             top_p: float,
             repetition_penalty: float
             ):
    """Enhance ``prompt_to_enhance`` with the fine-tuned model selected by ``choice``.

    The requested adapter is (re)loaded when it differs from the one in memory.
    The prompt is wrapped in the tokenizer's chat template, a reply is generated,
    and only the newly generated text is returned.

    Args:
        prompt_to_enhance: Prompt typed by the user; must contain non-whitespace text.
        choice: Adapter repo id, i.e. a key of ``choices_base_models``.
        max_tokens: Upper bound on generated tokens (converted with ``int``).
        temperature: Sampling temperature; values ``<= 0`` select greedy decoding.
        top_p: Nucleus-sampling threshold, only used when sampling.
        repetition_penalty: Repetition penalty; must be strictly positive.

    Returns:
        The generated reply, with special tokens removed and surrounding
        whitespace stripped.

    Raises:
        gr.Error: If the prompt is empty/blank or ``repetition_penalty <= 0``.
    """
    global previous_choice

    if prompt_to_enhance is None or not prompt_to_enhance.strip():
        raise gr.Error('Please enter a prompt')
    repetition_penalty = float(repetition_penalty)
    if repetition_penalty <= 0:
        # transformers raises a ValueError for non-positive penalties; report it in the UI.
        raise gr.Error('Repetition penalty must be greater than 0')

    if choice != previous_choice or model is None:
        load_model(choice)
        # Remember the choice only after a successful load, so a failed load is
        # retried on the next request instead of reusing a stale or missing model.
        previous_choice = choice

    chat = [{'role': 'user', 'content': prompt_to_enhance}]
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    # The rendered chat template already contains the special tokens it needs
    # (e.g. BOS); letting the tokenizer add them again would duplicate them.
    encoding = tokenizer(prompt, return_tensors='pt', add_special_tokens=False)

    # Stop on every EOS id the checkpoint's generation config lists (chat models may
    # end a turn with a token other than tokenizer.eos_token) plus the tokenizer's
    # own EOS, instead of overwriting the former with the latter.
    eos_ids = model.generation_config.eos_token_id
    if eos_ids is None:
        eos_ids = []
    elif isinstance(eos_ids, int):
        eos_ids = [eos_ids]
    else:
        eos_ids = list(eos_ids)
    if tokenizer.eos_token_id is not None and tokenizer.eos_token_id not in eos_ids:
        eos_ids.append(tokenizer.eos_token_id)

    # Passed as keyword overrides: generate() applies them to a copy of the model's
    # generation config, so one request's settings never leak into the next.
    generation_kwargs = {
        'max_new_tokens': int(max_tokens),
        'num_return_sequences': 1,
        'pad_token_id': tokenizer.eos_token_id,
        'eos_token_id': eos_ids or None,
        'repetition_penalty': repetition_penalty,
    }
    temperature = float(temperature)
    if temperature > 0:
        generation_kwargs.update(do_sample=True, temperature=temperature, top_p=float(top_p))
    else:
        # Sampling requires a strictly positive temperature; treat 0 as greedy decoding.
        generation_kwargs['do_sample'] = False

    with torch.inference_mode():
        outputs = model.generate(
            input_ids=encoding.input_ids,
            attention_mask=encoding.attention_mask,
            **generation_kwargs,
        )

    # outputs[0] is prompt + continuation. Decode only the continuation rather than
    # splitting the full transcript on a role word, which truncates replies that
    # themselves contain that word (e.g. "model" or "assistant").
    prompt_length = encoding.input_ids.shape[-1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()
#
# Inputs
#
# The dropdown entries come straight from the adapter table, so a new model only
# has to be registered once (dict insertion order keeps the listed order).
model_choice = gr.Dropdown(
    label='Model choice',
    choices=list(choices_base_models),
    value='groloch/Llama-3.2-3B-Instruct-PromptEnhancing'
)
input_prompt = gr.Text(
    label='Prompt to enhance'
)
#
# Additional inputs
#
input_max_tokens = gr.Number(
    label='Max generated tokens',
    value=64,
    minimum=16,
    maximum=128,
    precision=0,  # deliver an int: the value is used as the max_new_tokens count
)
input_temperature = gr.Number(
    label='Temperature',
    value=0.3,
    minimum=0.0,
    maximum=1.5,
    step=0.05
)
input_top_p = gr.Number(
    label='Top p',
    value=0.9,
    minimum=0.0,
    maximum=1.0,
    step=0.05
)
input_repetition_penalty = gr.Number(
    label='Repetition penalty',
    value=2.0,
    # transformers rejects non-positive repetition penalties (1.0 disables the
    # penalty), so 0.0 must not be selectable.
    minimum=0.1,
    maximum=5.0,
    step=0.1
)
# `inputs` feed generate()'s first two parameters (prompt, model choice) and
# `additional_inputs` the remaining four, in this exact order.
demo = gr.Interface(
    generate,
    title='Prompt Enhancing Playground',
    # Implicit string concatenation rather than a backslash line continuation inside
    # the literal, which would embed the next line's indentation in the page text.
    description=('This space is a tool to compare the different prompt enhancing '
                 'models I have finetuned. Feel free to experiment as you want!'),
    inputs=[input_prompt, model_choice],
    additional_inputs=[input_max_tokens,
                       input_temperature,
                       input_top_p,
                       input_repetition_penalty],
    outputs=['text'],
)
# Start the Gradio server only when this file is executed directly, not when imported.
if __name__ == "__main__":
    demo.launch()