|
import gradio as gr |
|
import torch |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
from peft import PeftModel |
|
import os |
|
|
|
|
|
# Repository of the fine-tuned adapter that is applied on top of the base model.
MODEL_ID = "alaamostafa/Microsoft-Phi-2"
# Base checkpoint; the tokenizer is also loaded from this repository.
BASE_MODEL_ID = "microsoft/phi-2"

# Directory passed as `offload_folder` to both loaders below.
OFFLOAD_DIR = "offload_dir"

# This app is pinned to CPU inference.
device = "cpu"
print(f"Using device: {device}")
os.makedirs(OFFLOAD_DIR, exist_ok=True)

# Presumably silences the bitsandbytes startup banner.
# NOTE(review): this is set after transformers/peft were imported at the top
# of the file; confirm the variable is still read late enough to take effect.
os.environ["BITSANDBYTES_NOWELCOME"] = "1"

# Load tokenizer, base model and adapter. Any failure (including a tokenizer
# download/parse error) is captured in `error_message` instead of crashing the
# script. `error_message` is deliberately left undefined on success, because
# the interface code later in this file tests for its existence.
try:
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)

    print("Loading base model...")
    # NOTE(review): trust_remote_code=True lets Transformers execute Python
    # code shipped in the model repository; confirm it is still required.
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        torch_dtype=torch.float32,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        offload_folder=OFFLOAD_DIR,
    )

    print(f"Loading adapter from {MODEL_ID}...")
    model = PeftModel.from_pretrained(
        base_model,
        MODEL_ID,
        offload_folder=OFFLOAD_DIR,
    )

    print("Model loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    error_message = f"Failed to load model: {str(e)}\n\nThis Space may need a GPU to run properly."
|
|
|
def generate_text(
    prompt,
    max_length=256,
    temperature=0.7,
    top_p=0.9,
    top_k=40,
    repetition_penalty=1.1
):
    """Generate text from ``prompt`` with the fine-tuned model.

    Args:
        prompt: Text to continue. Empty, whitespace-only or ``None`` prompts
            are rejected with a short message instead of being tokenized.
        max_length: Forwarded to ``model.generate`` as ``max_length`` (in
            Transformers this bounds prompt tokens plus generated tokens).
            Cast to ``int`` because UI sliders may deliver floats.
        temperature: Sampling temperature. Sampling is enabled only when this
            is greater than 0; otherwise ``do_sample=False`` is used.
        top_p: Nucleus-sampling threshold; only sent when sampling.
        top_k: Top-k cutoff; only sent when sampling. Cast to ``int``.
        repetition_penalty: Forwarded to ``model.generate`` unchanged.

    Returns:
        The decoded output of ``model.generate`` (special tokens removed),
        ``"Please enter a prompt."`` for a blank prompt, or a string starting
        with ``"Error generating text: "`` if anything raises.
    """
    try:
        # A blank prompt gives the model nothing to condition on; report it
        # clearly rather than surfacing a low-level tokenizer/model error.
        if not prompt or not prompt.strip():
            return "Please enter a prompt."

        inputs = tokenizer(prompt, return_tensors="pt")

        generation_kwargs = {
            "max_length": int(max_length),
            "repetition_penalty": repetition_penalty,
            "pad_token_id": tokenizer.eos_token_id,
        }
        if temperature > 0:
            generation_kwargs.update(
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                top_k=int(top_k),
            )
        else:
            # Sampling-only knobs are omitted here: they have no effect when
            # do_sample is False and only trigger configuration warnings.
            generation_kwargs["do_sample"] = False

        # Inference only: no gradients needed.
        with torch.no_grad():
            outputs = model.generate(**inputs, **generation_kwargs)

        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        return f"Error generating text: {str(e)}"
|
|
|
|
|
# Custom stylesheet handed to gr.Blocks(css=...) when the full interface is built.
css = """
.gradio-container {max-width: 800px !important}
.gr-prose code {white-space: pre-wrap !important}
"""

# Page heading: used as the Interface title in the error view and rendered
# as a Markdown H1 in the full view.
title = "Neuroscience Fine-tuned Phi-2 Model (CPU Version)"
# Markdown introduction shown under the heading in both views.
description = """
This is a fine-tuned version of Microsoft's Phi-2 model, adapted specifically for neuroscience domain content.
⚠️ **Note: This model is running on CPU which means responses will be slower.** ⚠️

For best performance:
- Keep your prompts focused and clear
- Use shorter maximum length settings (128-256)
- Be patient as generation can take 30+ seconds

**Example prompts:**
- Recent advances in neuroimaging suggest that
- The role of dopamine in learning and memory involves
- Explain the concept of neuroplasticity in simple terms
- What are the key differences between neurons and glial cells?
"""
|
|
|
|
|
# `error_message` exists only when model loading failed. At module scope
# locals() is the module namespace, so this asks "did loading succeed?".
if 'error_message' not in locals():
    # Model is available: build the full generation interface.
    with gr.Blocks(css=css) as demo:
        gr.Markdown(f"# {title}")
        gr.Markdown(description)

        with gr.Row():
            # Left column: prompt, action buttons and generation settings.
            with gr.Column():
                prompt = gr.Textbox(
                    lines=5,
                    label="Enter your prompt",
                    placeholder="Recent advances in neuroscience suggest that",
                )

                with gr.Row():
                    submit_btn = gr.Button("Generate", variant="primary")
                    clear_btn = gr.Button("Clear")

                with gr.Accordion("Advanced Options", open=False):
                    max_length = gr.Slider(
                        label="Maximum Length (lower is faster on CPU)",
                        minimum=64,
                        maximum=512,
                        step=64,
                        value=256,
                    )
                    temperature = gr.Slider(
                        label="Temperature (0 = deterministic, 0.7 = creative, 1.5 = random)",
                        minimum=0.0,
                        maximum=1.5,
                        step=0.1,
                        value=0.7,
                    )
                    top_p = gr.Slider(
                        label="Top-p (nucleus sampling)",
                        minimum=0.1,
                        maximum=1.0,
                        step=0.1,
                        value=0.9,
                    )
                    top_k = gr.Slider(
                        label="Top-k",
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=40,
                    )
                    repetition_penalty = gr.Slider(
                        label="Repetition Penalty",
                        minimum=1.0,
                        maximum=2.0,
                        step=0.1,
                        value=1.1,
                    )

            # Right column: generated text.
            with gr.Column():
                output = gr.Textbox(lines=20, label="Generated Text")

        # Wire the buttons: "Generate" runs the model with the current
        # settings, "Clear" resets both the prompt and the output box.
        submit_btn.click(
            fn=generate_text,
            inputs=[prompt, max_length, temperature, top_p, top_k, repetition_penalty],
            outputs=output,
        )
        clear_btn.click(
            fn=lambda: ("", None),
            inputs=None,
            outputs=[prompt, output],
        )

        # Clickable sample prompts; each row fills the single `prompt` input.
        examples = [
            [sample]
            for sample in (
                "Recent advances in neuroimaging suggest that",
                "The role of dopamine in learning and memory involves",
                "Explain the concept of neuroplasticity in simple terms",
                "What are the key differences between neurons and glial cells?",
            )
        ]
        gr.Examples(examples=examples, inputs=prompt)
else:
    # Loading failed: expose a minimal interface that only reports the error.
    demo = gr.Interface(
        fn=lambda x: error_message,
        inputs=gr.Textbox(label="This model cannot be loaded on CPU"),
        outputs=gr.Textbox(),
        title=title,
        description=description,
    )
|
|
|
|
|
# Start the Gradio server only when this file is executed as a script, so
# that merely importing the module does not start (and block on) a server.
if __name__ == "__main__":
    demo.launch()