alaamostafa committed on
Commit
2951f6b
·
verified ·
1 Parent(s): 66810bf

Create app.py

Files changed (1)
  1. app.py +154 -0
app.py ADDED
@@ -0,0 +1,154 @@
+ import gradio as gr
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from peft import PeftModel
+
+ # Set up model identifiers: the fine-tuned adapter and its base model
+ MODEL_ID = "alaamostafa/Microsoft-Phi-2"
+ BASE_MODEL_ID = "microsoft/phi-2"
+
+ # Check if CUDA is available and set the device accordingly
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Using device: {device}")
+
+ # Load the tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
+
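+ # Note: Phi-2's tokenizer defines no pad token, so generation below falls
+ # back to the EOS token id for padding.
+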
+ # Load base model with appropriate dtype based on available hardware
+ print("Loading base model...")
+ base_model = AutoModelForCausalLM.from_pretrained(
+     BASE_MODEL_ID,
+     torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+     trust_remote_code=True,
+     device_map="auto"
+ )
+
+ # Load the fine-tuned adapter on top of the base model
+ print(f"Loading adapter from {MODEL_ID}...")
+ model = PeftModel.from_pretrained(
+     base_model,
+     MODEL_ID,
+     device_map="auto"
+ )
+
+ print("Model loaded successfully!")
+
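+ # Note: the device_map="auto" calls above rely on the accelerate package;
+ # on CPU the full fp32 model takes roughly 11 GB of RAM (2.7B parameters).
+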
+ def generate_text(
+     prompt,
+     max_length=512,
+     temperature=0.7,
+     top_p=0.9,
+     top_k=40,
+     repetition_penalty=1.1
+ ):
+     """Generate text from a prompt with the fine-tuned model."""
+     # Prepare the input and move it to the model's device (with
+     # device_map="auto" the model may not be on `device`)
+     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+     # Generate text; max_length counts prompt tokens plus the completion
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_length=max_length,
+             temperature=temperature,
+             top_p=top_p,
+             top_k=top_k,
+             repetition_penalty=repetition_penalty,
+             do_sample=temperature > 0,  # temperature of 0 means greedy decoding
+             pad_token_id=tokenizer.eos_token_id
+         )
+
+     # Decode and return the generated text (includes the original prompt)
+     generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     return generated_text
+
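+ # Example call (prompt is illustrative):
+ #   generate_text("The hippocampus supports", max_length=128)
+ # returns the prompt followed by the model's continuation.
+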
+ # Create the Gradio interface
+ css = """
+ .gradio-container {max-width: 800px !important}
+ .gr-prose code {white-space: pre-wrap !important}
+ """
+
+ title = "Neuroscience Fine-tuned Phi-2 Model"
+ description = """
+ This is a fine-tuned version of Microsoft's Phi-2 model, adapted specifically for neuroscience domain content.
+ Use this interface to interact with the model and see how it handles neuroscience-related queries.
+
+ **Example prompts:**
+ - Recent advances in neuroimaging suggest that
+ - The role of dopamine in learning and memory involves
+ - Explain the concept of neuroplasticity in simple terms
+ - What are the key differences between neurons and glial cells?
+ """
+
+ with gr.Blocks(css=css) as demo:
+     gr.Markdown(f"# {title}")
+     gr.Markdown(description)
+
+     with gr.Row():
+         with gr.Column():
+             prompt = gr.Textbox(
+                 label="Enter your prompt",
+                 placeholder="Recent advances in neuroscience suggest that",
+                 lines=5
+             )
+
+             with gr.Row():
+                 submit_btn = gr.Button("Generate", variant="primary")
+                 clear_btn = gr.Button("Clear")
+
+             with gr.Accordion("Advanced Options", open=False):
+                 max_length = gr.Slider(
+                     minimum=64, maximum=1024, value=512, step=64,
+                     label="Maximum Length"
+                 )
+                 temperature = gr.Slider(
+                     minimum=0.0, maximum=1.5, value=0.7, step=0.1,
+                     label="Temperature (0 = deterministic, 0.7 = creative, 1.5 = random)"
+                 )
+                 top_p = gr.Slider(
+                     minimum=0.1, maximum=1.0, value=0.9, step=0.1,
+                     label="Top-p (nucleus sampling)"
+                 )
+                 top_k = gr.Slider(
+                     minimum=1, maximum=100, value=40, step=1,
+                     label="Top-k"
+                 )
+                 repetition_penalty = gr.Slider(
+                     minimum=1.0, maximum=2.0, value=1.1, step=0.1,
+                     label="Repetition Penalty"
+                 )
+
+         with gr.Column():
+             output = gr.Textbox(
+                 label="Generated Text",
+                 lines=20
+             )
+
+     # Set up event handlers
+     submit_btn.click(
+         fn=generate_text,
+         inputs=[prompt, max_length, temperature, top_p, top_k, repetition_penalty],
+         outputs=output
+     )
+     clear_btn.click(
+         fn=lambda: ("", ""),
+         inputs=None,
+         outputs=[prompt, output]
+     )
+
+     # Example prompts
+     examples = [
+         ["Recent advances in neuroimaging suggest that"],
+         ["The role of dopamine in learning and memory involves"],
+         ["Explain the concept of neuroplasticity in simple terms"],
+         ["What are the key differences between neurons and glial cells?"]
+     ]
+
+     gr.Examples(
+         examples=examples,
+         inputs=prompt
+     )
+
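+ # Optional: demo.queue() here would queue concurrent requests, which helps
+ # when CPU generation takes more than a few seconds per call.
+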
+ # Launch the app
+ demo.launch()