minwoosun commited on
Commit
53dffc9
1 Parent(s): 4481b1f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +181 -95
app.py CHANGED
@@ -13,150 +13,236 @@ from sklearn.linear_model import LogisticRegression
13
  from huggingface_hub import hf_hub_download
14
 
15
 
16
- def load_model_params(model_path):
17
- """Load model parameters from a JSON file."""
 
18
  with open(model_path, 'r') as f:
19
  model_params = json.load(f)
20
- return model_params
21
-
22
- def reconstruct_classifier(model_params):
23
- """Reconstruct the logistic regression model from parameters."""
24
- model = LogisticRegression(multi_class='multinomial', solver='lbfgs', max_iter=1000)
25
- model.coef_ = np.array(model_params["coef"])
26
- model.intercept_ = np.array(model_params["intercept"])
27
- model.classes_ = np.array(model_params["classes"])
28
- return model
29
-
30
- def save_predictions(y_pred, output_path):
31
- """Save predictions to a CSV file."""
32
- df = pd.DataFrame(y_pred, columns=["predicted_cell_type"])
33
- df.to_csv(output_path, index=False, header=False)
34
-
35
- def load_and_predict_with_classifier(x, model_path, output_path, save=False):
36
- """Load model, predict, and optionally save predictions."""
37
- model_params = load_model_params(model_path)
38
- model = reconstruct_classifier(model_params)
39
- y_pred = model.predict(x)
40
  if save:
41
- save_predictions(y_pred, output_path)
 
 
42
  return y_pred
43
 
 
44
  def plot_umap(adata):
45
- """Generate a UMAP plot from the provided AnnData object."""
46
  labels = pd.Categorical(adata.obs["cell_type"])
 
47
  reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42)
48
  embedding = reducer.fit_transform(adata.obsm["X_uce"])
49
 
50
  plt.figure(figsize=(10, 8))
 
 
51
  scatter = plt.scatter(embedding[:, 0], embedding[:, 1], c=labels.codes, cmap='Set1', s=50, alpha=0.6)
52
 
53
- handles = [
54
- plt.Line2D([0], [0], marker='o', color='w', label=cell_type,
55
- markerfacecolor=plt.cm.Set1(i / len(labels.categories)), markersize=10)
56
- for i, cell_type in enumerate(labels.categories)
57
- ]
 
58
  plt.legend(handles=handles, title='Cell Type')
59
  plt.title('UMAP projection of the data')
60
  plt.xlabel('UMAP1')
61
  plt.ylabel('UMAP2')
62
-
 
63
  buf = BytesIO()
64
  plt.savefig(buf, format='png')
65
  buf.seek(0)
 
 
66
  img = plt.imread(buf, format='png')
 
67
  return img
68
 
 
69
  def toggle_file_input(default_dataset):
70
- """Toggle file input based on dataset selection."""
71
  if default_dataset != "None":
72
- return gr.update(interactive=False)
73
  else:
74
- return gr.update(interactive=True)
75
 
76
- def run_uce_model(input_file_path, model_dir, model_loc):
77
- """Run UCE model on the provided AnnData file."""
78
- command = [
79
- sys.executable,
80
- os.path.join(model_dir, 'eval_single_anndata.py'),
81
- '--adata_path', input_file_path,
82
- '--dir', model_dir,
83
- '--model_loc', model_loc
84
- ]
85
- subprocess.run(command, check=True)
86
 
87
  def main(input_file_path, species, default_dataset):
88
- """Main function to execute the demo logic."""
89
 
90
- # Clone the UCE repository and set paths
91
- repo_url = 'https://github.com/minwoosun/UCE.git'
92
- repo_dir = '/home/user/app/UCE'
93
- if not os.path.exists(repo_dir):
94
- subprocess.run(['git', 'clone', repo_url], check=True)
95
 
96
- sys.path.append(repo_dir)
97
-
98
- # Handle default datasets
99
- default_dataset_paths = {
100
- "PBMC 100 cells": hf_hub_download(repo_id="minwoosun/uce-misc", filename="100_pbmcs_proc_subset.h5ad"),
101
- "PBMC 1000 cells": hf_hub_download(repo_id="minwoosun/uce-misc", filename="1k_pbmcs_proc_subset.h5ad"),
102
- }
 
 
103
 
104
- if default_dataset in default_dataset_paths:
105
- input_file_path = default_dataset_paths[default_dataset]
106
 
107
- # Run UCE model
108
- run_uce_model(input_file_path, repo_dir, 'minwoosun/uce-100m')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
- # Load UCE embeddings and perform classification
111
- adata = sc.read_h5ad(os.path.join(repo_dir, f"{os.path.splitext(os.path.basename(input_file_path))[0]}_uce_adata.h5ad"))
 
 
112
  x = adata.obsm['X_uce']
113
 
114
- model_path = hf_hub_download(repo_id="minwoosun/uce-misc", filename="tabula_sapiens_v1_logistic_regression_model_weights.json")
115
- pred_file = os.path.join(repo_dir, f"{os.path.splitext(os.path.basename(input_file_path))[0]}_predictions.csv")
116
  y_pred = load_and_predict_with_classifier(x, model_path, pred_file, save=True)
117
-
118
- # Generate UMAP plot
 
 
119
  img = plot_umap(adata)
 
 
120
 
121
- return img, os.path.join(repo_dir, f"{os.path.splitext(os.path.basename(input_file_path))[0]}_uce_adata.h5ad"), pred_file
122
 
123
- # Gradio UI
124
 
125
- def create_demo():
126
- """Create and launch the Gradio demo."""
127
-
128
  with gr.Blocks() as demo:
129
- gr.Markdown("""
 
130
  <div style="text-align:center; margin-bottom:20px;">
131
- <h1>UCE 100M Demo 🦠</h1>
132
- <h2>Universal Cell Embeddings: Zero-Shot Cell-Type Classification in Action!</h2>
133
- <div style="margin-top:10px;">
134
- <a href="https://github.com/minwoosun/UCE"><img src="https://badges.aleen42.com/src/github.svg" alt="GitHub"></a>
135
- <a href="https://www.biorxiv.org/content/10.1101/2023.11.28.568918v1"><img src="https://img.shields.io/badge/bioRxiv-2023.11.28.568918-green?style=plastic" alt="Paper"></a>
136
- <a href="https://colab.research.google.com/drive/1opud0BVWr76IM8UnGgTomVggui_xC4p0?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
137
- </div>
138
- <p>Upload a `.h5ad` single cell gene expression file or select the species to generate UMAP projections and download the embeddings.</p>
139
  </div>
140
- """)
141
-
142
- # Inputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  file_input = gr.File(label="Upload a .h5ad single cell gene expression file or select a default dataset below")
144
- species_input = gr.Dropdown(choices=["human", "mouse"], label="Select species")
145
- default_dataset_input = gr.Dropdown(choices=["None", "PBMC 100 cells", "PBMC 1000 cells"], label="Select default dataset")
146
-
147
- default_dataset_input.change(toggle_file_input, inputs=[default_dataset_input], outputs=[file_input])
 
 
 
 
 
 
 
 
 
148
 
149
- # Outputs
150
- run_button = gr.Button("Run")
151
  with gr.Row():
152
- image_output = gr.Image(type="numpy", label="UMAP of UCE Embeddings")
153
  file_output = gr.File(label="Download embeddings")
154
  pred_output = gr.File(label="Download predictions")
 
 
 
 
155
 
156
- # Run the function on button click
157
- run_button.click(fn=main, inputs=[file_input, species_input, default_dataset_input], outputs=[image_output, file_output, pred_output])
 
 
 
 
158
 
159
- demo.launch()
160
 
161
- if __name__ == "__main__":
162
- create_demo()
 
13
  from huggingface_hub import hf_hub_download
14
 
15
 
16
+ def load_and_predict_with_classifier(x, model_path, output_path, save):
17
+
18
+ # Load the model parameters from the JSON file
19
  with open(model_path, 'r') as f:
20
  model_params = json.load(f)
21
+
22
+ # Reconstruct the logistic regression model
23
+ model_loaded = LogisticRegression(multi_class='multinomial', solver='lbfgs', max_iter=1000)
24
+ model_loaded.coef_ = np.array(model_params["coef"])
25
+ model_loaded.intercept_ = np.array(model_params["intercept"])
26
+ model_loaded.classes_ = np.array(model_params["classes"])
27
+
28
+ # output predictions
29
+ y_pred = model_loaded.predict(x)
30
+
31
+ # Convert the array to a Pandas DataFrame
 
 
 
 
 
 
 
 
 
32
  if save:
33
+ df = pd.DataFrame(y_pred, columns=["predicted_cell_type"])
34
+ df.to_csv(output_path, index=False, header=False)
35
+
36
  return y_pred
37
 
38
+
39
  def plot_umap(adata):
40
+
41
  labels = pd.Categorical(adata.obs["cell_type"])
42
+
43
  reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42)
44
  embedding = reducer.fit_transform(adata.obsm["X_uce"])
45
 
46
  plt.figure(figsize=(10, 8))
47
+
48
+ # Create the scatter plot
49
  scatter = plt.scatter(embedding[:, 0], embedding[:, 1], c=labels.codes, cmap='Set1', s=50, alpha=0.6)
50
 
51
+ # Create a legend
52
+ handles = []
53
+ for i, cell_type in enumerate(labels.categories):
54
+ handles.append(plt.Line2D([0], [0], marker='o', color='w', label=cell_type,
55
+ markerfacecolor=plt.cm.Set1(i / len(labels.categories)), markersize=10))
56
+
57
  plt.legend(handles=handles, title='Cell Type')
58
  plt.title('UMAP projection of the data')
59
  plt.xlabel('UMAP1')
60
  plt.ylabel('UMAP2')
61
+
62
+ # Save plot to a BytesIO object
63
  buf = BytesIO()
64
  plt.savefig(buf, format='png')
65
  buf.seek(0)
66
+
67
+ # Read the image from BytesIO object
68
  img = plt.imread(buf, format='png')
69
+
70
  return img
71
 
72
+
73
  def toggle_file_input(default_dataset):
 
74
  if default_dataset != "None":
75
+ return gr.update(interactive=False) # Disable the file input if a default dataset is selected
76
  else:
77
+ return gr.update(interactive=True) # Enable the file input if no default dataset is selected
78
 
 
 
 
 
 
 
 
 
 
 
79
 
80
  def main(input_file_path, species, default_dataset):
 
81
 
82
+ # Get the current working directory
83
+ current_working_directory = os.getcwd()
 
 
 
84
 
85
+ # Print the current working directory
86
+ print("Current Working Directory:", current_working_directory)
87
+
88
+ # clone and cd into UCE repo
89
+ os.system('git clone https://github.com/minwoosun/UCE.git')
90
+ os.chdir('/home/user/app/UCE')
91
+
92
+ # Get the current working directory
93
+ current_working_directory = os.getcwd()
94
 
95
+ # Print the current working directory
96
+ print("Current Working Directory:", current_working_directory)
97
 
98
+ # Specify the path to the directory you want to add
99
+ new_directory = "/home/user/app/UCE"
100
+
101
+ # Add the directory to the Python path
102
+ sys.path.append(new_directory)
103
+
104
+ # Set default dataset path
105
+ default_dataset_1_path = hf_hub_download(repo_id="minwoosun/uce-misc", filename="100_pbmcs_proc_subset.h5ad")
106
+ default_dataset_2_path = hf_hub_download(repo_id="minwoosun/uce-misc", filename="1k_pbmcs_proc_subset.h5ad")
107
+
108
+ # If the user selects a default dataset, use that instead of the uploaded file
109
+ if default_dataset == "PBMC 100 cells":
110
+ input_file_path = default_dataset_1_path
111
+ elif default_dataset == "PBMC 1000 cells":
112
+ input_file_path = default_dataset_2_path
113
+
114
+ ##############
115
+ # UCE #
116
+ ##############
117
+ from evaluate import AnndataProcessor
118
+ from accelerate import Accelerator
119
+
120
+ dir_path = '/home/user/app/UCE/'
121
+ model_loc = 'minwoosun/uce-100m'
122
+
123
+ print(input_file_path)
124
+ print(dir_path)
125
+ print(model_loc)
126
+
127
+ # Construct the command
128
+ command = [
129
+ 'python',
130
+ '/home/user/app/UCE/eval_single_anndata.py',
131
+ '--adata_path', input_file_path,
132
+ '--dir', dir_path,
133
+ '--model_loc', model_loc
134
+ ]
135
+
136
+ # Print the command for debugging
137
+ print("Running command:", command)
138
+
139
+ print("---> RUNNING UCE")
140
+ result = subprocess.run(command, capture_output=True, text=True, check=True)
141
+ print(result.stdout)
142
+ print(result.stderr)
143
+ print("---> FINSIH UCE")
144
+
145
+ ################################
146
+ # Cell-type classification #
147
+ ################################
148
+
149
+ # Set output file path
150
+ file_name_with_ext = os.path.basename(input_file_path)
151
+ file_name = os.path.splitext(file_name_with_ext)[0]
152
+ pred_file = "/home/user/app/UCE/" + f"{file_name}_predictions.csv"
153
+ model_path = hf_hub_download(repo_id="minwoosun/uce-misc", filename="tabula_sapiens_v1_logistic_regression_model_weights.json")
154
 
155
+ file_name_with_ext = os.path.basename(input_file_path)
156
+ file_name = os.path.splitext(file_name_with_ext)[0]
157
+ output_file = "/home/user/app/UCE/" + f"{file_name}_uce_adata.h5ad"
158
+ adata = sc.read_h5ad(output_file)
159
  x = adata.obsm['X_uce']
160
 
 
 
161
  y_pred = load_and_predict_with_classifier(x, model_path, pred_file, save=True)
162
+
163
+ ##############
164
+ # UMAP #
165
+ ##############
166
  img = plot_umap(adata)
167
+
168
+ return img, output_file, pred_file
169
 
 
170
 
171
+ if __name__ == "__main__":
172
 
 
 
 
173
  with gr.Blocks() as demo:
174
+ gr.Markdown(
175
+ '''
176
  <div style="text-align:center; margin-bottom:20px;">
177
+ <span style="font-size:3em; font-weight:bold;">UCE 100M Demo 🦠</span>
 
 
 
 
 
 
 
178
  </div>
179
+ <div style="text-align:center; margin-bottom:10px;">
180
+ <span style="font-size:1.5em; font-weight:bold;">Universal Cell Embeddings: Zero-Shot Cell-Type Classification in Action!</span>
181
+ </div>
182
+ <div style="text-align:center; margin-bottom:20px;">
183
+ <a href="https://github.com/minwoosun/UCE">
184
+ <img src="https://badges.aleen42.com/src/github.svg" alt="GitHub" style="display:inline-block; margin-right:10px;">
185
+ </a>
186
+ <a href="https://www.biorxiv.org/content/10.1101/2023.11.28.568918v1">
187
+ <img src="https://img.shields.io/badge/bioRxiv-2023.11.28.568918-green?style=plastic" alt="Paper" style="display:inline-block; margin-right:10px;">
188
+ </a>
189
+ <a href="https://colab.research.google.com/drive/1opud0BVWr76IM8UnGgTomVggui_xC4p0?usp=sharing">
190
+ <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" style="display:inline-block; margin-right:10px;">
191
+ </a>
192
+ </div>
193
+ <div style="text-align:left; margin-bottom:20px;">
194
+ Upload a `.h5ad` single cell gene expression file and select the species (Human/Mouse).
195
+ The demo will generate UMAP projections of the embeddings and allow you to download the embeddings for further analysis.
196
+ </div>
197
+ <div style="margin-bottom:20px;">
198
+ <ol style="list-style:none; padding-left:0;">
199
+ <li>1. Upload your `.h5ad` file or select one of the default datasets (subset of 10x pbmc data)</li>
200
+ <li>2. Select the species</li>
201
+ <li>3. Click "Run" to view the UMAP scatter plot</li>
202
+ <li>4. Download the UCE embeddings and predicted cell-types</li>
203
+ </ol>
204
+ </div>
205
+ <div style="text-align:left; line-height:1.8;">
206
+ Please consider citing the following paper if you use this tool in your research:
207
+ </div>
208
+ <div style="text-align:left; line-height:1.8;">
209
+ Rosen, Y., Roohani, Y., Agarwal, A., Samotorčan, L., Tabula Sapiens Consortium, Quake, S. R., & Leskovec, J. Universal Cell Embeddings: A Foundation Model for Cell Biology. bioRxiv. https://doi.org/10.1101/2023.11.28.568918
210
+ </div>
211
+ '''
212
+ )
213
+
214
+ # Define Gradio inputs and outputs
215
  file_input = gr.File(label="Upload a .h5ad single cell gene expression file or select a default dataset below")
216
+ # species_input = gr.Dropdown(choices=["human", "mouse"], label="Select species")
217
+ with gr.Row():
218
+ species_input = gr.Dropdown(choices=["human", "mouse"], label="Select species")
219
+ default_dataset_input = gr.Dropdown(choices=["None", "PBMC 100 cells", "PBMC 1000 cells"], label="Select default dataset")
220
+
221
+ # Attach the `change` event to the dropdown
222
+ default_dataset_input.change(
223
+ toggle_file_input,
224
+ inputs=[default_dataset_input],
225
+ outputs=[file_input]
226
+ )
227
+
228
+ run_button = gr.Button("Run", elem_classes="run-button")
229
 
230
+ # Arrange UMAP plot and file output side by side
 
231
  with gr.Row():
232
+ image_output = gr.Image(type="numpy", label="UMAP_of_UCE_Embeddings")
233
  file_output = gr.File(label="Download embeddings")
234
  pred_output = gr.File(label="Download predictions")
235
+
236
+ print(image_output)
237
+ print(file_output)
238
+ print(pred_output)
239
 
240
+ # Add the components and link to the function
241
+ run_button.click(
242
+ fn=main,
243
+ inputs=[file_input, species_input, default_dataset_input],
244
+ outputs=[image_output, file_output, pred_output]
245
+ )
246
 
 
247
 
248
+ demo.launch()