import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Set up model parameters
MODEL_ID = "alaamostafa/Microsoft-Phi-2"
BASE_MODEL_ID = "microsoft/phi-2"

# Check if CUDA is available and set device accordingly
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)

# Load base model with appropriate dtype based on available hardware
print("Loading base model...")
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    trust_remote_code=True,
    device_map="auto"
)
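# Rough memory note (estimate, not from the original code): Phi-2 has ~2.7B
# parameters, so the float32 CPU path above needs roughly 11 GB of RAM, while
# the float16 GPU path needs on the order of 5-6 GB of VRAM.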
# Load the fine-tuned adapter
print(f"Loading adapter from {MODEL_ID}...")
model = PeftModel.from_pretrained(
    base_model,
    MODEL_ID,
    device_map="auto"
)
print("Model loaded successfully!")
def generate_text(
    prompt,
    max_length=512,
    temperature=0.7,
    top_p=0.9,
    top_k=40,
    repetition_penalty=1.1
):
    """Generate text based on prompt with the fine-tuned model"""
    # Prepare input and move it to the model's device
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Generate text; sampling is disabled when temperature is 0 (greedy decoding)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            do_sample=temperature > 0,
            pad_token_id=tokenizer.eos_token_id
        )

    # Decode and return the generated text
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_text
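# Optional sanity check (illustrative, not part of the Space UI): uncomment to
# verify generation works before the Gradio interface is built. The prompt and
# max_length below are arbitrary examples.
# print(generate_text("The hippocampus is involved in", max_length=128))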
# Create the Gradio interface
css = """
.gradio-container {max-width: 800px !important}
.gr-prose code {white-space: pre-wrap !important}
"""

title = "Neuroscience Fine-tuned Phi-2 Model"

description = """
This is a fine-tuned version of Microsoft's Phi-2 model, adapted specifically for neuroscience domain content.
Use this interface to interact with the model and see how it handles neuroscience-related queries.

**Example prompts:**

- Recent advances in neuroimaging suggest that
- The role of dopamine in learning and memory involves
- Explain the concept of neuroplasticity in simple terms
- What are the key differences between neurons and glial cells?
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Enter your prompt",
                placeholder="Recent advances in neuroscience suggest that",
                lines=5
            )
            with gr.Row():
                submit_btn = gr.Button("Generate", variant="primary")
                clear_btn = gr.Button("Clear")
            with gr.Accordion("Advanced Options", open=False):
                max_length = gr.Slider(
                    minimum=64, maximum=1024, value=512, step=64,
                    label="Maximum Length"
                )
                temperature = gr.Slider(
                    minimum=0.0, maximum=1.5, value=0.7, step=0.1,
                    label="Temperature (0 = deterministic, 0.7 = creative, 1.5 = random)"
                )
                top_p = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.9, step=0.1,
                    label="Top-p (nucleus sampling)"
                )
                top_k = gr.Slider(
                    minimum=1, maximum=100, value=40, step=1,
                    label="Top-k"
                )
                repetition_penalty = gr.Slider(
                    minimum=1.0, maximum=2.0, value=1.1, step=0.1,
                    label="Repetition Penalty"
                )
        with gr.Column():
            output = gr.Textbox(
                label="Generated Text",
                lines=20
            )

    # Set up event handlers
    submit_btn.click(
        fn=generate_text,
        inputs=[prompt, max_length, temperature, top_p, top_k, repetition_penalty],
        outputs=output
    )
    clear_btn.click(
        fn=lambda: ("", None),
        inputs=None,
        outputs=[prompt, output]
    )

    # Example prompts
    examples = [
        ["Recent advances in neuroimaging suggest that"],
        ["The role of dopamine in learning and memory involves"],
        ["Explain the concept of neuroplasticity in simple terms"],
        ["What are the key differences between neurons and glial cells?"]
    ]
    gr.Examples(
        examples=examples,
        inputs=prompt
    )

# Launch the app
demo.launch()
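# Note: when running this script locally rather than on Spaces, passing
# share=True to demo.launch() would expose a temporary public URL; the
# defaults above are what the Space itself uses.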