alaamostafa committed
Commit 4f06d80 · verified · 1 Parent(s): afba7e5

Create app.py

Files changed (1):
app.py (+184 -0)
app.py ADDED
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import os

# Set up model parameters
MODEL_ID = "alaamostafa/Microsoft-Phi-2"
BASE_MODEL_ID = "microsoft/phi-2"

# Force CPU usage and set up an offload directory
device = "cpu"
print(f"Using device: {device}")
os.makedirs("offload_dir", exist_ok=True)

# Suppress the bitsandbytes welcome banner (8-bit loading is not used on CPU)
os.environ["BITSANDBYTES_NOWELCOME"] = "1"

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)

# Load base model with simple CPU configuration, avoiding device_map and 8-bit loading
print("Loading base model...")
try:
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        torch_dtype=torch.float32,    # Use float32 for CPU
        trust_remote_code=True,
        low_cpu_mem_usage=True,       # Optimize for lower memory usage
        offload_folder="offload_dir"  # Set offload directory
    )

    # Load the fine-tuned adapter
    print(f"Loading adapter from {MODEL_ID}...")
    model = PeftModel.from_pretrained(
        base_model,
        MODEL_ID,
        offload_folder="offload_dir"
    )

    print("Model loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    # Create a placeholder error message for the UI
    error_message = f"Failed to load model: {str(e)}\n\nThis Space may need a GPU to run properly."
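
# Optional speed-up (a suggestion, not part of the original app): once the
# adapter is loaded, the LoRA weights could be merged into the base model
# for a faster CPU forward pass, e.g. `model = model.merge_and_unload()`,
# at the cost of no longer being able to detach the adapter.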

def generate_text(
    prompt,
    max_length=256,  # Reduced for CPU
    temperature=0.7,
    top_p=0.9,
    top_k=40,
    repetition_penalty=1.1
):
    """Generate text from a prompt with the fine-tuned model."""
    try:
        # Prepare input
        inputs = tokenizer(prompt, return_tensors="pt")

        # Generate text
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=max_length,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                do_sample=temperature > 0,  # Greedy decoding when temperature is 0
                pad_token_id=tokenizer.eos_token_id
            )

        # Decode and return the generated text
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return generated_text
    except Exception as e:
        return f"Error generating text: {str(e)}"

# Create the Gradio interface
css = """
.gradio-container {max-width: 800px !important}
.gr-prose code {white-space: pre-wrap !important}
"""

title = "Neuroscience Fine-tuned Phi-2 Model (CPU Version)"
description = """
This is a fine-tuned version of Microsoft's Phi-2 model, adapted specifically for neuroscience domain content.

⚠️ **Note: This model is running on CPU, so responses will be slower.** ⚠️

For best performance:
- Keep your prompts focused and clear
- Use shorter maximum length settings (128-256)
- Be patient, as generation can take 30+ seconds

**Example prompts:**
- Recent advances in neuroimaging suggest that
- The role of dopamine in learning and memory involves
- Explain the concept of neuroplasticity in simple terms
- What are the key differences between neurons and glial cells?
"""

# Check if the model loaded successfully
if "error_message" in locals():
    # Simple error interface
    demo = gr.Interface(
        fn=lambda x: error_message,
        inputs=gr.Textbox(label="This model cannot be loaded on CPU"),
        outputs=gr.Textbox(),
        title=title,
        description=description
    )
else:
    # Full interface
    with gr.Blocks(css=css) as demo:
        gr.Markdown(f"# {title}")
        gr.Markdown(description)

        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(
                    label="Enter your prompt",
                    placeholder="Recent advances in neuroscience suggest that",
                    lines=5
                )

                with gr.Row():
                    submit_btn = gr.Button("Generate", variant="primary")
                    clear_btn = gr.Button("Clear")

                with gr.Accordion("Advanced Options", open=False):
                    max_length = gr.Slider(
                        minimum=64, maximum=512, value=256, step=64,
                        label="Maximum Length (lower is faster on CPU)"
                    )
                    temperature = gr.Slider(
                        minimum=0.0, maximum=1.5, value=0.7, step=0.1,
                        label="Temperature (0 = deterministic, 0.7 = creative, 1.5 = random)"
                    )
                    top_p = gr.Slider(
                        minimum=0.1, maximum=1.0, value=0.9, step=0.1,
                        label="Top-p (nucleus sampling)"
                    )
                    top_k = gr.Slider(
                        minimum=1, maximum=100, value=40, step=1,
                        label="Top-k"
                    )
                    repetition_penalty = gr.Slider(
                        minimum=1.0, maximum=2.0, value=1.1, step=0.1,
                        label="Repetition Penalty"
                    )

            with gr.Column():
                output = gr.Textbox(
                    label="Generated Text",
                    lines=20
                )

        # Set up event handlers
        submit_btn.click(
            fn=generate_text,
            inputs=[prompt, max_length, temperature, top_p, top_k, repetition_penalty],
            outputs=output
        )
        clear_btn.click(
            fn=lambda: ("", ""),  # Clear both the prompt and the output
            inputs=None,
            outputs=[prompt, output]
        )

        # Example prompts
        examples = [
            ["Recent advances in neuroimaging suggest that"],
            ["The role of dopamine in learning and memory involves"],
            ["Explain the concept of neuroplasticity in simple terms"],
            ["What are the key differences between neurons and glial cells?"]
        ]

        gr.Examples(
            examples=examples,
            inputs=prompt
        )
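
# Note (a suggestion, not part of the original app): generations can take 30+
# seconds on CPU, so enabling Gradio's request queue before launching, e.g.
# `demo.queue()`, could keep concurrent requests from timing out.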

# Launch the app
demo.launch()