# app.py
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, TextDataset, DataCollatorForLanguageModeling
import torch
import os
# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Base model name and output directory for the fine-tuned weights
PRETRAINED_MODEL = "distilgpt2"
MODEL_DIR = "./fine_tuned_model"
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)
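# Note: GPT-2-style tokenizers such as distilgpt2 ship without a padding token;
# generation below reuses the EOS token id as the pad id to avoid the warning.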


def fine_tune_model(files):
    # Combine uploaded files into one text
    if not files:
        return "No files uploaded."
    text_data = ""
    for file in files:
        text = file.decode('utf-8')
        text_data += text + "\n"
    # Save combined text to a file
    with open("train.txt", "w") as f:
        f.write(text_data)
    # Create dataset
    dataset = TextDataset(
        tokenizer=tokenizer,
        file_path="train.txt",
        block_size=128
    )
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=False,
    )
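    # TextDataset chunks train.txt into fixed 128-token blocks; with mlm=False the
    # collator builds causal-LM labels (copies of the inputs) for each block.
    # Note: TextDataset is deprecated in recent transformers releases in favour of
    # the `datasets` library.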
    # Load pre-trained model
    model = AutoModelForCausalLM.from_pretrained(PRETRAINED_MODEL)
    model.to(device)
    # Set training arguments
    training_args = TrainingArguments(
        output_dir=MODEL_DIR,
        overwrite_output_dir=True,
        num_train_epochs=1,
        per_device_train_batch_size=4,
        save_steps=500,
        save_total_limit=2,
        logging_steps=100,
    )
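    # One epoch with a small per-device batch keeps memory use and runtime modest;
    # checkpoints go to MODEL_DIR, and save_total_limit keeps at most two of them.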
    # Initialize Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=dataset,
    )
    # Fine-tune model
    trainer.train()
    # Save the model
    trainer.save_model(MODEL_DIR)
    tokenizer.save_pretrained(MODEL_DIR)
    return "Fine-tuning completed successfully!"


def generate_response(prompt, temperature, max_length, top_p):
    # Load fine-tuned model if available
    if os.path.exists(MODEL_DIR):
        model = AutoModelForCausalLM.from_pretrained(MODEL_DIR)
        tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
    else:
        model = AutoModelForCausalLM.from_pretrained(PRETRAINED_MODEL)
        tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)
    model.to(device)
    # Encode prompt
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
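    # Sampling controls: temperature rescales the logits, top_p (nucleus sampling)
    # limits choices to the smallest token set whose cumulative probability exceeds
    # top_p, and max_length caps the total prompt-plus-completion length in tokens.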
    # Generate output
    output = model.generate(
        input_ids,
        do_sample=True,
        max_length=int(max_length),
        temperature=float(temperature),
        top_p=float(top_p),
        pad_token_id=tokenizer.eos_token_id
    )
    response = tokenizer.decode(output[0], skip_special_tokens=True)
    return response


# Build Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# 🚀 Language Model Fine-Tuner and Chatbot")
    with gr.Tab("Fine-Tune Model"):
        gr.Markdown("## 📚 Fine-Tune the Model with Your Documents")
        file_inputs = gr.File(label="Upload Text Files", file_count="multiple", type="binary")
        fine_tune_button = gr.Button("Start Fine-Tuning")
        fine_tune_status = gr.Textbox(label="Status", interactive=False)
        fine_tune_button.click(fine_tune_model, inputs=file_inputs, outputs=fine_tune_status)
    with gr.Tab("Chat with Model"):
        gr.Markdown("## 💬 Chat with the Fine-Tuned Model")
        user_input = gr.Textbox(label="Your Message")
        with gr.Accordion("Advanced Parameters", open=False):
            temperature = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")
            max_length = gr.Slider(20, 200, value=100, step=10, label="Max Length")
            top_p = gr.Slider(0.1, 1.0, value=0.9, label="Top-p")
        generate_button = gr.Button("Generate Response")
        bot_response = gr.Textbox(label="Model Response", interactive=False)
        generate_button.click(generate_response, inputs=[user_input, temperature, max_length, top_p], outputs=bot_response)

demo.launch()
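
# To run locally (assuming gradio, transformers and torch are installed):
#     python app.py
# Gradio then prints a local URL (http://127.0.0.1:7860 by default).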