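"""Gradio demo: neuroscience fine-tuned Phi-2 (CPU-only).

Loads the microsoft/phi-2 base model, applies the alaamostafa/Microsoft-Phi-2
PEFT adapter, and serves a simple text-generation UI.
"""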
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import os
# Set up model parameters
MODEL_ID = "alaamostafa/Microsoft-Phi-2"
BASE_MODEL_ID = "microsoft/phi-2"

# Force CPU usage and set up an offload directory
device = "cpu"
print(f"Using device: {device}")
os.makedirs("offload_dir", exist_ok=True)

# Silence the bitsandbytes welcome banner (bitsandbytes itself is not used on CPU)
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
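# Note: Phi-2's tokenizer ships without a dedicated pad token, which is why
# generation further below falls back to the EOS token for padding.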
# Load the base model with a simple CPU configuration, avoiding device_map and 8-bit loading
print("Loading base model...")
try:
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        torch_dtype=torch.float32,    # float32 is the safe dtype on CPU
        trust_remote_code=True,
        low_cpu_mem_usage=True,       # optimize for lower peak memory usage
        offload_folder="offload_dir"  # set offload directory
    )

    # Load the fine-tuned adapter on top of the base model
    print(f"Loading adapter from {MODEL_ID}...")
    model = PeftModel.from_pretrained(
        base_model,
        MODEL_ID,
        offload_folder="offload_dir"
    )
    print("Model loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    # Keep a placeholder error message for the UI instead of crashing the Space
    error_message = f"Failed to load model: {e}\n\nThis Space may need a GPU to run properly."
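# Sketch (not part of the original app): on a GPU Space, the base model could
# instead be loaded in half precision with automatic device placement, e.g.
#
#   base_model = AutoModelForCausalLM.from_pretrained(
#       BASE_MODEL_ID,
#       torch_dtype=torch.float16,
#       device_map="auto",
#       trust_remote_code=True,
#   )
#
# Left commented out so the CPU-only path above remains the active one.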
def generate_text(
    prompt,
    max_length=256,  # reduced for CPU; note this count includes the prompt tokens
    temperature=0.7,
    top_p=0.9,
    top_k=40,
    repetition_penalty=1.1
):
    """Generate text based on a prompt with the fine-tuned model."""
    try:
        # Prepare input
        inputs = tokenizer(prompt, return_tensors="pt")

        # Generate text (greedy decoding when temperature is 0, sampling otherwise)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=max_length,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                do_sample=temperature > 0,
                pad_token_id=tokenizer.eos_token_id
            )

        # Decode and return the generated text
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        return f"Error generating text: {e}"
# Create the Gradio interface
css = """
.gradio-container {max-width: 800px !important}
.gr-prose code {white-space: pre-wrap !important}
"""
title = "Neuroscience Fine-tuned Phi-2 Model (CPU Version)"
description = """
This is a fine-tuned version of Microsoft's Phi-2 model, adapted specifically for neuroscience domain content.

⚠️ **Note: This model is running on CPU, so responses will be slower.** ⚠️

For best performance:

- Keep your prompts focused and clear
- Use shorter maximum length settings (128-256)
- Be patient, as generation can take 30+ seconds

**Example prompts:**

- Recent advances in neuroimaging suggest that
- The role of dopamine in learning and memory involves
- Explain the concept of neuroplasticity in simple terms
- What are the key differences between neurons and glial cells?
"""
# Check if the model loaded successfully; fall back to a simple error interface if not
if 'error_message' in locals():
    demo = gr.Interface(
        fn=lambda x: error_message,
        inputs=gr.Textbox(label="This model cannot be loaded on CPU"),
        outputs=gr.Textbox(),
        title=title,
        description=description
    )
else:
    # Full interface
    with gr.Blocks(css=css) as demo:
        gr.Markdown(f"# {title}")
        gr.Markdown(description)

        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(
                    label="Enter your prompt",
                    placeholder="Recent advances in neuroscience suggest that",
                    lines=5
                )
                with gr.Row():
                    submit_btn = gr.Button("Generate", variant="primary")
                    clear_btn = gr.Button("Clear")
                with gr.Accordion("Advanced Options", open=False):
                    max_length = gr.Slider(
                        minimum=64, maximum=512, value=256, step=64,
                        label="Maximum Length (lower is faster on CPU)"
                    )
                    temperature = gr.Slider(
                        minimum=0.0, maximum=1.5, value=0.7, step=0.1,
                        label="Temperature (0 = deterministic, 0.7 = creative, 1.5 = random)"
                    )
                    top_p = gr.Slider(
                        minimum=0.1, maximum=1.0, value=0.9, step=0.1,
                        label="Top-p (nucleus sampling)"
                    )
                    top_k = gr.Slider(
                        minimum=1, maximum=100, value=40, step=1,
                        label="Top-k"
                    )
                    repetition_penalty = gr.Slider(
                        minimum=1.0, maximum=2.0, value=1.1, step=0.1,
                        label="Repetition Penalty"
                    )
            with gr.Column():
                output = gr.Textbox(
                    label="Generated Text",
                    lines=20
                )
        # Set up event handlers
        submit_btn.click(
            fn=generate_text,
            inputs=[prompt, max_length, temperature, top_p, top_k, repetition_penalty],
            outputs=output
        )
        clear_btn.click(
            fn=lambda: ("", ""),  # return empty strings so both textboxes actually clear
            inputs=None,
            outputs=[prompt, output]
        )
        # Example prompts
        examples = [
            ["Recent advances in neuroimaging suggest that"],
            ["The role of dopamine in learning and memory involves"],
            ["Explain the concept of neuroplasticity in simple terms"],
            ["What are the key differences between neurons and glial cells?"]
        ]
        gr.Examples(
            examples=examples,
            inputs=prompt
        )
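# Optional tweak (not in the original app): since CPU generation can take 30+
# seconds, enabling Gradio's request queue may keep concurrent users from
# hitting timeouts, e.g.:
#
#   demo.queue()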
# Launch the app
demo.launch()