# phi2_finetune / app.py
import sys

import gradio as gr
import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM
# --- 1. Check CUDA Availability and Set Device ---
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using device: {device} ({torch.cuda.get_device_name(0)})")
else:
    print("CUDA is not available. Falling back to CPU.")
    device = torch.device("cpu")
# --- 2. Load Tokenizer (with error handling) ---
MODEL_PATH = "sagar007/phi2_25k"
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    # Phi-2 ships without a pad token; reuse EOS so padding and generation work.
    tokenizer.pad_token = tokenizer.eos_token
except Exception as e:
    print(f"Error loading tokenizer: {e}")
    sys.exit(1)
# --- 3. Load Base Model (Optimized for GPU) ---
try:
    base_model = AutoModelForCausalLM.from_pretrained(
        "microsoft/phi-2",
        torch_dtype=torch.float16,  # float16 halves memory use and speeds up GPU inference
        device_map="auto",          # let accelerate place the weights on available devices
        trust_remote_code=True,
    )
except Exception as e:
    print(f"Error loading base model: {e}")
    sys.exit(1)
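# Optional: on smaller GPUs the base model can instead be loaded in 4-bit.
# A minimal sketch, assuming the bitsandbytes package is installed (this
# variant is not part of the original app):
# from transformers import BitsAndBytesConfig
# base_model = AutoModelForCausalLM.from_pretrained(
#     "microsoft/phi-2",
#     quantization_config=BitsAndBytesConfig(load_in_4bit=True),
#     device_map="auto",
#     trust_remote_code=True,
# )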
# --- 4. Load PEFT Model (Optimized for GPU) ---
try:
    # PeftModel.from_pretrained reads the adapter config from MODEL_PATH itself,
    # so no separate PeftConfig load is needed.
    model = PeftModel.from_pretrained(base_model, MODEL_PATH)
except Exception as e:
    print(f"Error loading PEFT model: {e}")
    sys.exit(1)
# device_map="auto" has already placed the weights, so an explicit .to(device)
# is redundant here (and can raise an error on an accelerate-dispatched model).
model.eval()
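# Optional: merging the LoRA adapters into the base weights removes the small
# per-forward adapter overhead. A sketch, assuming the adapters at MODEL_PATH
# are standard LoRA layers supported by peft's merge_and_unload():
# model = model.merge_and_unload()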
# --- 5. Generation Function (Optimized for GPU) ---
def generate_response(instruction, max_new_tokens=512):
    prompt = f"Instruction: {instruction}\nResponse:"
    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                # max_new_tokens bounds only the generated text; max_length would
                # also count the prompt and could starve long prompts of output.
                max_new_tokens=max_new_tokens,
                num_return_sequences=1,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,  # silence the missing-pad-token warning
            )
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Keep only the text after the first "Response:"; [-1] avoids an
        # IndexError if the marker is missing from the decoded output.
        return response.split("Response:", 1)[-1].strip()
    except Exception as e:
        print(f"Error during generation: {e}")
        return "Error during response generation."
# --- 6. Gradio Interface ---
def chatbot(message, history):
    # The model was fine-tuned on single instruction/response pairs,
    # so the conversation history is intentionally unused.
    return generate_response(message)
demo = gr.ChatInterface(
    chatbot,
    title="Fine-tuned Phi-2 Chatbot (GPU)",
    description="A chatbot backed by a fine-tuned version of the Phi-2 model, running on GPU.",
    theme="default",
    examples=[
        "Explain the concept of machine learning.",
        "Write a short story about a robot learning to paint.",
        "What are some effective ways to reduce stress?",
    ],
    cache_examples=False,  # don't pre-run the examples at startup
)
if __name__ == "__main__":
    demo.launch()
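    # On Hugging Face Spaces the default launch() settings suffice; when
    # running locally, share=True gives a temporary public URL (an optional
    # tweak, not part of the original app):
    # demo.launch(share=True)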