import torch
import gradio as gr

from transformers import AutoTokenizer, AutoModelForCausalLM


# Map each fine-tuned prompt-enhancing adapter to its base model repo
choices_base_models = {
    'groloch/Llama-3.2-3B-Instruct-PromptEnhancing': 'meta-llama/Llama-3.2-3B-Instruct',
    'groloch/gemma-2-2b-it-PromptEnhancing': 'google/gemma-2-2b-it',
    'groloch/Qwen2.5-3B-Instruct-PromptEnhancing': 'Qwen/Qwen2.5-3B-Instruct',
    'groloch/Ministral-3b-instruct-PromptEnhancing': 'ministral/Ministral-3b-instruct'
}
# Role marker that each chat template emits before the generated answer;
# used below to strip the prompt from the decoded output.
choices_gen_token = {
    'groloch/Llama-3.2-3B-Instruct-PromptEnhancing': 'assistant',
    'groloch/gemma-2-2b-it-PromptEnhancing': 'model',
    'groloch/Qwen2.5-3B-Instruct-PromptEnhancing': 'assistant',
    # NOTE: the original entry duplicated the base repo id here, which would
    # break output splitting; 'assistant' is the likely role marker.
    'groloch/Ministral-3b-instruct-PromptEnhancing': 'assistant'
}

previous_choice = ''
model = None
tokenizer = None


def load_model(adapter_repo_id: str):
    """Load the base model in bfloat16 and attach the selected adapter (requires peft)."""
    global model, tokenizer
    base_repo_id = choices_base_models[adapter_repo_id]
    tokenizer = AutoTokenizer.from_pretrained(base_repo_id)
    model = AutoModelForCausalLM.from_pretrained(base_repo_id,
                                                 torch_dtype=torch.bfloat16)
    model.load_adapter(adapter_repo_id)


def generate(prompt_to_enhance: str,
             choice: str,
             max_tokens: float,
             temperature: float,
             top_p: float,
             repetition_penalty: float):
    if not prompt_to_enhance:
        raise gr.Error('Please enter a prompt')

    # Only reload the model when the user switches to a different adapter
    global previous_choice
    if choice != previous_choice:
        previous_choice = choice
        load_model(choice)

    chat = [
        {'role': 'user', 'content': prompt_to_enhance}
    ]
    # tokenize=False returns a plain string, so no return_tensors here;
    # the string is tokenized in the next step.
    prompt = tokenizer.apply_chat_template(chat,
                                           tokenize=False,
                                           add_generation_prompt=True)
    encoding = tokenizer(prompt, return_tensors='pt')

    generation_config = model.generation_config
    generation_config.do_sample = True
    generation_config.max_new_tokens = int(max_tokens)
    generation_config.temperature = float(temperature)
    generation_config.top_p = float(top_p)
    generation_config.num_return_sequences = 1
    generation_config.pad_token_id = tokenizer.eos_token_id
    generation_config.eos_token_id = tokenizer.eos_token_id
    generation_config.repetition_penalty = float(repetition_penalty)

    with torch.inference_mode():
        outputs = model.generate(
            input_ids=encoding.input_ids,
            attention_mask=encoding.attention_mask,
            generation_config=generation_config
        )

    # Decode the full sequence and keep only the text after the assistant role marker
    return tokenizer.decode(outputs[0], skip_special_tokens=True).split(choices_gen_token[choice])[-1]


#
# Inputs
#
model_choice = gr.Dropdown(
    label='Model choice',
    choices=list(choices_base_models.keys()),
    value='groloch/Llama-3.2-3B-Instruct-PromptEnhancing'
)
input_prompt = gr.Text(
    label='Prompt to enhance'
)

#
# Additional inputs
#
input_max_tokens = gr.Number(
    label='Max generated tokens',
    value=64,
    minimum=16,
    maximum=128
)
input_temperature = gr.Number(
    label='Temperature',
    value=0.3,
    minimum=0.0,
    maximum=1.5,
    step=0.05
)
input_top_p = gr.Number(
    label='Top p',
    value=0.9,
    minimum=0.0,
    maximum=1.0,
    step=0.05
)
input_repetition_penalty = gr.Number(
    label='Repetition penalty',
    value=2.0,
    minimum=0.0,
    maximum=5.0,
    step=0.1
)

demo = gr.Interface(
    generate,
    title='Prompt Enhancing Playground',
    description='This space is a tool to compare the different prompt enhancing models '
                'I have fine-tuned. Feel free to experiment as you want!',
    inputs=[input_prompt, model_choice],
    additional_inputs=[input_max_tokens,
                       input_temperature,
                       input_top_p,
                       input_repetition_penalty],
    outputs=['text']
)

if __name__ == '__main__':
    demo.launch()