minwoosun commited on
Commit
4481b1f
·
verified ·
1 Parent(s): 8bc0d41

Modularize code

Browse files
Files changed (1) hide show
  1. app.py +93 -191
app.py CHANGED
@@ -13,248 +13,150 @@ from sklearn.linear_model import LogisticRegression
13
  from huggingface_hub import hf_hub_download
14
 
15
 
16
- def load_and_predict_with_classifier(x, model_path, output_path, save):
17
-
18
- # Load the model parameters from the JSON file
19
  with open(model_path, 'r') as f:
20
  model_params = json.load(f)
21
-
22
- # Reconstruct the logistic regression model
23
- model_loaded = LogisticRegression(multi_class='multinomial', solver='lbfgs', max_iter=1000)
24
- model_loaded.coef_ = np.array(model_params["coef"])
25
- model_loaded.intercept_ = np.array(model_params["intercept"])
26
- model_loaded.classes_ = np.array(model_params["classes"])
27
-
28
- # output predictions
29
- y_pred = model_loaded.predict(x)
30
-
31
- # Convert the array to a Pandas DataFrame
 
 
 
 
 
 
 
 
 
32
  if save:
33
- df = pd.DataFrame(y_pred, columns=["predicted_cell_type"])
34
- df.to_csv(output_path, index=False, header=False)
35
-
36
  return y_pred
37
 
38
-
39
  def plot_umap(adata):
40
-
41
  labels = pd.Categorical(adata.obs["cell_type"])
42
-
43
  reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42)
44
  embedding = reducer.fit_transform(adata.obsm["X_uce"])
45
 
46
  plt.figure(figsize=(10, 8))
47
-
48
- # Create the scatter plot
49
  scatter = plt.scatter(embedding[:, 0], embedding[:, 1], c=labels.codes, cmap='Set1', s=50, alpha=0.6)
50
 
51
- # Create a legend
52
- handles = []
53
- for i, cell_type in enumerate(labels.categories):
54
- handles.append(plt.Line2D([0], [0], marker='o', color='w', label=cell_type,
55
- markerfacecolor=plt.cm.Set1(i / len(labels.categories)), markersize=10))
56
-
57
  plt.legend(handles=handles, title='Cell Type')
58
  plt.title('UMAP projection of the data')
59
  plt.xlabel('UMAP1')
60
  plt.ylabel('UMAP2')
61
-
62
- # Save plot to a BytesIO object
63
  buf = BytesIO()
64
  plt.savefig(buf, format='png')
65
  buf.seek(0)
66
-
67
- # Read the image from BytesIO object
68
  img = plt.imread(buf, format='png')
69
-
70
  return img
71
 
72
-
73
  def toggle_file_input(default_dataset):
 
74
  if default_dataset != "None":
75
- return gr.update(interactive=False) # Disable the file input if a default dataset is selected
76
  else:
77
- return gr.update(interactive=True) # Enable the file input if no default dataset is selected
78
-
79
-
80
- def main(input_file_path, species, default_dataset):
81
-
82
- # Get the current working directory
83
- current_working_directory = os.getcwd()
84
-
85
- # Print the current working directory
86
- print("Current Working Directory:", current_working_directory)
87
-
88
- # clone and cd into UCE repo
89
- os.system('git clone https://github.com/minwoosun/UCE.git')
90
- os.chdir('/home/user/app/UCE')
91
-
92
- # Get the current working directory
93
- current_working_directory = os.getcwd()
94
-
95
- # Print the current working directory
96
- print("Current Working Directory:", current_working_directory)
97
-
98
- # Specify the path to the directory you want to add
99
- new_directory = "/home/user/app/UCE"
100
-
101
- # Add the directory to the Python path
102
- sys.path.append(new_directory)
103
-
104
- # Set default dataset path
105
- default_dataset_1_path = hf_hub_download(repo_id="minwoosun/uce-misc", filename="100_pbmcs_proc_subset.h5ad")
106
- default_dataset_2_path = hf_hub_download(repo_id="minwoosun/uce-misc", filename="1k_pbmcs_proc_subset.h5ad")
107
-
108
- # If the user selects a default dataset, use that instead of the uploaded file
109
- if default_dataset == "PBMC 100 cells":
110
- input_file_path = default_dataset_1_path
111
- elif default_dataset == "PBMC 1000 cells":
112
- input_file_path = default_dataset_2_path
113
-
114
- ##############
115
- # UCE #
116
- ##############
117
- from evaluate import AnndataProcessor
118
- from accelerate import Accelerator
119
 
120
- dir_path = '/home/user/app/UCE/'
121
- model_loc = 'minwoosun/uce-100m'
122
-
123
- print(input_file_path)
124
- print(dir_path)
125
- print(model_loc)
126
-
127
- # Construct the command
128
  command = [
129
- 'python',
130
- '/home/user/app/UCE/eval_single_anndata.py',
131
  '--adata_path', input_file_path,
132
- '--dir', dir_path,
133
  '--model_loc', model_loc
134
  ]
135
-
136
- # Print the command for debugging
137
- print("Running command:", command)
138
 
139
- print("---> RUNNING UCE")
140
- result = subprocess.run(command, capture_output=True, text=True, check=True)
141
- print(result.stdout)
142
- print(result.stderr)
143
- print("---> FINSIH UCE")
144
 
145
- ################################
146
- # Cell-type classification #
147
- ################################
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
- # Set output file path
150
- file_name_with_ext = os.path.basename(input_file_path)
151
- file_name = os.path.splitext(file_name_with_ext)[0]
152
- pred_file = "/home/user/app/UCE/" + f"{file_name}_predictions.csv"
153
- model_path = hf_hub_download(repo_id="minwoosun/uce-misc", filename="tabula_sapiens_v1_logistic_regression_model_weights.json")
154
 
155
- file_name_with_ext = os.path.basename(input_file_path)
156
- file_name = os.path.splitext(file_name_with_ext)[0]
157
- output_file = "/home/user/app/UCE/" + f"{file_name}_uce_adata.h5ad"
158
- adata = sc.read_h5ad(output_file)
159
  x = adata.obsm['X_uce']
160
 
 
 
161
  y_pred = load_and_predict_with_classifier(x, model_path, pred_file, save=True)
162
-
163
- ##############
164
- # UMAP #
165
- ##############
166
  img = plot_umap(adata)
167
-
168
- return img, output_file, pred_file
169
 
 
170
 
171
- if __name__ == "__main__":
172
 
 
 
 
173
  with gr.Blocks() as demo:
174
- gr.Markdown(
175
- '''
176
- <div style="text-align:center; margin-bottom:20px;">
177
- <span style="font-size:3em; font-weight:bold;">UCE 100M Demo 🦠</span>
178
- </div>
179
- <div style="text-align:center; margin-bottom:10px;">
180
- <span style="font-size:1.5em; font-weight:bold;">Universal Cell Embeddings: Zero-Shot Cell-Type Classification in Action!</span>
181
- </div>
182
  <div style="text-align:center; margin-bottom:20px;">
183
- <a href="https://github.com/minwoosun/UCE">
184
- <img src="https://badges.aleen42.com/src/github.svg" alt="GitHub" style="display:inline-block; margin-right:10px;">
185
- </a>
186
- <a href="https://www.biorxiv.org/content/10.1101/2023.11.28.568918v1">
187
- <img src="https://img.shields.io/badge/bioRxiv-2023.11.28.568918-green?style=plastic" alt="Paper" style="display:inline-block; margin-right:10px;">
188
- </a>
189
- <a href="https://colab.research.google.com/drive/1opud0BVWr76IM8UnGgTomVggui_xC4p0?usp=sharing">
190
- <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" style="display:inline-block; margin-right:10px;">
191
- </a>
192
- </div>
193
- <div style="text-align:left; margin-bottom:20px;">
194
- Upload a `.h5ad` single cell gene expression file and select the species (Human/Mouse).
195
- The demo will generate UMAP projections of the embeddings and allow you to download the embeddings for further analysis.
196
- </div>
197
- <div style="margin-bottom:20px;">
198
- <ol style="list-style:none; padding-left:0;">
199
- <li>1. Upload your `.h5ad` file or select one of the default datasets (subset of 10x pbmc data)</li>
200
- <li>2. Select the species</li>
201
- <li>3. Click "Run" to view the UMAP scatter plot</li>
202
- <li>4. Download the UCE embeddings and predicted cell-types</li>
203
- </ol>
204
- </div>
205
- <div style="text-align:left; line-height:1.8;">
206
- Please consider citing the following paper if you use this tool in your research:
207
  </div>
208
- <div style="text-align:left; line-height:1.8;">
209
- Rosen, Y., Roohani, Y., Agarwal, A., Samotorčan, L., Tabula Sapiens Consortium, Quake, S. R., & Leskovec, J. Universal Cell Embeddings: A Foundation Model for Cell Biology. bioRxiv. https://doi.org/10.1101/2023.11.28.568918
210
- </div>
211
- '''
212
- )
213
-
214
- # download default datasets and assign paths
215
- default_dataset_1_path = hf_hub_download(repo_id="minwoosun/uce-misc", filename="100_pbmcs_proc_subset.h5ad")
216
- default_dataset_2_path = hf_hub_download(repo_id="minwoosun/uce-misc", filename="1k_pbmcs_proc_subset.h5ad")
217
-
218
- # Define Gradio inputs and outputs
219
  file_input = gr.File(label="Upload a .h5ad single cell gene expression file or select a default dataset below")
220
- # species_input = gr.Dropdown(choices=["human", "mouse"], label="Select species")
221
- with gr.Row():
222
- species_input = gr.Dropdown(choices=["human", "mouse"], label="Select species")
223
- default_dataset_input = gr.Dropdown(choices=["None", "PBMC 100 cells", "PBMC 1000 cells"], label="Select default dataset")
224
-
225
- # Attach the `change` event to the dropdown
226
- default_dataset_input.change(
227
- toggle_file_input,
228
- inputs=[default_dataset_input],
229
- outputs=[file_input]
230
- )
231
-
232
- run_button = gr.Button("Run", elem_classes="run-button")
233
 
234
- # Arrange UMAP plot and file output side by side
 
 
 
235
  with gr.Row():
236
- image_output = gr.Image(type="numpy", label="UMAP_of_UCE_Embeddings")
237
  file_output = gr.File(label="Download embeddings")
238
  pred_output = gr.File(label="Download predictions")
239
-
240
- print(image_output)
241
- print(file_output)
242
- print(pred_output)
243
 
244
- # Add the components and link to the function
245
- run_button.click(
246
- fn=main,
247
- inputs=[file_input, species_input, default_dataset_input],
248
- outputs=[image_output, file_output, pred_output]
249
- )
250
-
251
- examples = gr.Examples(
252
- examples=[[default_dataset_1_path, "human", "PBMC 100 cells"],[default_dataset_2_path, "human", "PBMC 1000 cells"]],
253
- inputs=[file_input, species_input, default_dataset_input],
254
- outputs=[image_output, file_output, pred_output],
255
- fn=main,
256
- cache_examples=True
257
- )
258
 
259
  demo.launch()
260
 
 
 
 
13
  from huggingface_hub import hf_hub_download
14
 
15
 
16
+ def load_model_params(model_path):
17
+ """Load model parameters from a JSON file."""
 
18
  with open(model_path, 'r') as f:
19
  model_params = json.load(f)
20
+ return model_params
21
+
22
+ def reconstruct_classifier(model_params):
23
+ """Reconstruct the logistic regression model from parameters."""
24
+ model = LogisticRegression(multi_class='multinomial', solver='lbfgs', max_iter=1000)
25
+ model.coef_ = np.array(model_params["coef"])
26
+ model.intercept_ = np.array(model_params["intercept"])
27
+ model.classes_ = np.array(model_params["classes"])
28
+ return model
29
+
30
+ def save_predictions(y_pred, output_path):
31
+ """Save predictions to a CSV file."""
32
+ df = pd.DataFrame(y_pred, columns=["predicted_cell_type"])
33
+ df.to_csv(output_path, index=False, header=False)
34
+
35
+ def load_and_predict_with_classifier(x, model_path, output_path, save=False):
36
+ """Load model, predict, and optionally save predictions."""
37
+ model_params = load_model_params(model_path)
38
+ model = reconstruct_classifier(model_params)
39
+ y_pred = model.predict(x)
40
  if save:
41
+ save_predictions(y_pred, output_path)
 
 
42
  return y_pred
43
 
 
44
  def plot_umap(adata):
45
+ """Generate a UMAP plot from the provided AnnData object."""
46
  labels = pd.Categorical(adata.obs["cell_type"])
 
47
  reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42)
48
  embedding = reducer.fit_transform(adata.obsm["X_uce"])
49
 
50
  plt.figure(figsize=(10, 8))
 
 
51
  scatter = plt.scatter(embedding[:, 0], embedding[:, 1], c=labels.codes, cmap='Set1', s=50, alpha=0.6)
52
 
53
+ handles = [
54
+ plt.Line2D([0], [0], marker='o', color='w', label=cell_type,
55
+ markerfacecolor=plt.cm.Set1(i / len(labels.categories)), markersize=10)
56
+ for i, cell_type in enumerate(labels.categories)
57
+ ]
 
58
  plt.legend(handles=handles, title='Cell Type')
59
  plt.title('UMAP projection of the data')
60
  plt.xlabel('UMAP1')
61
  plt.ylabel('UMAP2')
62
+
 
63
  buf = BytesIO()
64
  plt.savefig(buf, format='png')
65
  buf.seek(0)
 
 
66
  img = plt.imread(buf, format='png')
 
67
  return img
68
 
 
69
  def toggle_file_input(default_dataset):
70
+ """Toggle file input based on dataset selection."""
71
  if default_dataset != "None":
72
+ return gr.update(interactive=False)
73
  else:
74
+ return gr.update(interactive=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
+ def run_uce_model(input_file_path, model_dir, model_loc):
77
+ """Run UCE model on the provided AnnData file."""
 
 
 
 
 
 
78
  command = [
79
+ sys.executable,
80
+ os.path.join(model_dir, 'eval_single_anndata.py'),
81
  '--adata_path', input_file_path,
82
+ '--dir', model_dir,
83
  '--model_loc', model_loc
84
  ]
85
+ subprocess.run(command, check=True)
 
 
86
 
87
+ def main(input_file_path, species, default_dataset):
88
+ """Main function to execute the demo logic."""
 
 
 
89
 
90
+ # Clone the UCE repository and set paths
91
+ repo_url = 'https://github.com/minwoosun/UCE.git'
92
+ repo_dir = '/home/user/app/UCE'
93
+ if not os.path.exists(repo_dir):
94
+ subprocess.run(['git', 'clone', repo_url], check=True)
95
+
96
+ sys.path.append(repo_dir)
97
+
98
+ # Handle default datasets
99
+ default_dataset_paths = {
100
+ "PBMC 100 cells": hf_hub_download(repo_id="minwoosun/uce-misc", filename="100_pbmcs_proc_subset.h5ad"),
101
+ "PBMC 1000 cells": hf_hub_download(repo_id="minwoosun/uce-misc", filename="1k_pbmcs_proc_subset.h5ad"),
102
+ }
103
+
104
+ if default_dataset in default_dataset_paths:
105
+ input_file_path = default_dataset_paths[default_dataset]
106
 
107
+ # Run UCE model
108
+ run_uce_model(input_file_path, repo_dir, 'minwoosun/uce-100m')
 
 
 
109
 
110
+ # Load UCE embeddings and perform classification
111
+ adata = sc.read_h5ad(os.path.join(repo_dir, f"{os.path.splitext(os.path.basename(input_file_path))[0]}_uce_adata.h5ad"))
 
 
112
  x = adata.obsm['X_uce']
113
 
114
+ model_path = hf_hub_download(repo_id="minwoosun/uce-misc", filename="tabula_sapiens_v1_logistic_regression_model_weights.json")
115
+ pred_file = os.path.join(repo_dir, f"{os.path.splitext(os.path.basename(input_file_path))[0]}_predictions.csv")
116
  y_pred = load_and_predict_with_classifier(x, model_path, pred_file, save=True)
117
+
118
+ # Generate UMAP plot
 
 
119
  img = plot_umap(adata)
 
 
120
 
121
+ return img, os.path.join(repo_dir, f"{os.path.splitext(os.path.basename(input_file_path))[0]}_uce_adata.h5ad"), pred_file
122
 
123
+ # Gradio UI
124
 
125
+ def create_demo():
126
+ """Create and launch the Gradio demo."""
127
+
128
  with gr.Blocks() as demo:
129
+ gr.Markdown("""
 
 
 
 
 
 
 
130
  <div style="text-align:center; margin-bottom:20px;">
131
+ <h1>UCE 100M Demo 🦠</h1>
132
+ <h2>Universal Cell Embeddings: Zero-Shot Cell-Type Classification in Action!</h2>
133
+ <div style="margin-top:10px;">
134
+ <a href="https://github.com/minwoosun/UCE"><img src="https://badges.aleen42.com/src/github.svg" alt="GitHub"></a>
135
+ <a href="https://www.biorxiv.org/content/10.1101/2023.11.28.568918v1"><img src="https://img.shields.io/badge/bioRxiv-2023.11.28.568918-green?style=plastic" alt="Paper"></a>
136
+ <a href="https://colab.research.google.com/drive/1opud0BVWr76IM8UnGgTomVggui_xC4p0?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
137
+ </div>
138
+ <p>Upload a `.h5ad` single cell gene expression file or select the species to generate UMAP projections and download the embeddings.</p>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  </div>
140
+ """)
141
+
142
+ # Inputs
 
 
 
 
 
 
 
 
143
  file_input = gr.File(label="Upload a .h5ad single cell gene expression file or select a default dataset below")
144
+ species_input = gr.Dropdown(choices=["human", "mouse"], label="Select species")
145
+ default_dataset_input = gr.Dropdown(choices=["None", "PBMC 100 cells", "PBMC 1000 cells"], label="Select default dataset")
 
 
 
 
 
 
 
 
 
 
 
146
 
147
+ default_dataset_input.change(toggle_file_input, inputs=[default_dataset_input], outputs=[file_input])
148
+
149
+ # Outputs
150
+ run_button = gr.Button("Run")
151
  with gr.Row():
152
+ image_output = gr.Image(type="numpy", label="UMAP of UCE Embeddings")
153
  file_output = gr.File(label="Download embeddings")
154
  pred_output = gr.File(label="Download predictions")
 
 
 
 
155
 
156
+ # Run the function on button click
157
+ run_button.click(fn=main, inputs=[file_input, species_input, default_dataset_input], outputs=[image_output, file_output, pred_output])
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
  demo.launch()
160
 
161
+ if __name__ == "__main__":
162
+ create_demo()