Spaces:
Runtime error
Runtime error
Ensemble!
Browse files
app.py
CHANGED
@@ -45,7 +45,7 @@ class Sequence:
|
|
45 |
max_length = 100
|
46 |
padded_ie = pad_sequences([ie], maxlen=max_length, padding='post', truncating='post')
|
47 |
all_ohe = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20] + [0]*(100-21))
|
48 |
-
return to_categorical(np.array([padded_ie[0], all_ohe]))[:1]
|
49 |
|
50 |
|
51 |
def residual_block(data, filters, d_rate):
|
@@ -96,20 +96,59 @@ def get_model():
|
|
96 |
|
97 |
return model2
|
98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
|
100 |
-
model = get_model()
|
101 |
mappings_path = cached_download(hf_hub_url("jonathang/Protein_Family_Models", 'prot_mappings.json'))
|
102 |
with open(mappings_path) as f:
|
103 |
prot_mappings = json.load(f)
|
104 |
|
105 |
def greet(Amino_Acid_Sequence):
|
106 |
-
processed_seq = Sequence.prepare(Amino_Acid_Sequence)
|
107 |
-
|
108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
fam_asc = prot_mappings['id2fam_asc'][str(idx)]
|
110 |
fam_id = prot_mappings['fam_asc2fam_id'][fam_asc]
|
111 |
gc.collect()
|
112 |
-
return f"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
|
114 |
iface = gr.Interface(fn=greet, inputs="text", outputs="text")
|
115 |
-
iface.launch()
|
|
|
45 |
max_length = 100
|
46 |
padded_ie = pad_sequences([ie], maxlen=max_length, padding='post', truncating='post')
|
47 |
all_ohe = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20] + [0]*(100-21))
|
48 |
+
return padded_ie, to_categorical(np.array([padded_ie[0], all_ohe]))[:1]
|
49 |
|
50 |
|
51 |
def residual_block(data, filters, d_rate):
|
|
|
96 |
|
97 |
return model2
|
98 |
|
99 |
+
def get_lstm_model():
    """Build the bidirectional-LSTM classifier and load its stored weights.

    Architecture, in order: a length-100 integer input, an ``Embedding(21, 128)``
    layer, a 64-unit LSTM wrapped in ``Bidirectional`` (L2 penalty of 0.01 on
    kernel, recurrent and bias weights), ``Dropout(0.3)``, and a 1000-unit
    softmax ``Dense`` output.

    The weights file ``model1.h5`` is fetched from the
    ``jonathang/Protein_Family_Models`` repository through
    ``hf_hub_url`` / ``cached_download``.

    Returns:
        The compiled model with the downloaded weights loaded.
    """
    seq_len = 100

    tokens = Input(shape=(seq_len,))
    embedded = Embedding(21, 128, input_length=seq_len)(tokens)

    recurrent_layer = LSTM(
        64,
        kernel_regularizer=l2(0.01),
        recurrent_regularizer=l2(0.01),
        bias_regularizer=l2(0.01),
    )
    encoded = Bidirectional(recurrent_layer)(embedded)
    regularised = Dropout(0.3)(encoded)

    # softmax classifier
    probabilities = Dense(1000, activation='softmax')(regularised)

    lstm_net = Model(inputs=tokens, outputs=probabilities)
    lstm_net.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )

    weights_file = cached_download(
        hf_hub_url("jonathang/Protein_Family_Models", 'model1.h5')
    )
    lstm_net.load_weights(weights_file)
    return lstm_net
|
114 |
+
|
115 |
+
|
116 |
+
# Build both ensemble members once at import time so every call to greet()
# reuses the same loaded models.
cnn_model = get_model()
lstm_model = get_lstm_model()

# Lookup tables used by greet(): 'id2fam_asc' and 'fam_asc2fam_id'.
mappings_path = cached_download(hf_hub_url("jonathang/Protein_Family_Models", 'prot_mappings.json'))
# JSON text is UTF-8; decode it explicitly instead of relying on the
# platform's default encoding, which can differ between hosts.
with open(mappings_path, encoding="utf-8") as f:
    prot_mappings = json.load(f)
|
122 |
|
123 |
def greet(Amino_Acid_Sequence):
    """Classify an amino-acid sequence with the CNN/LSTM ensemble.

    The sequence is run through ``Sequence.prepare``, which yields the padded
    integer encoding (fed to the LSTM) and the one-hot encoding (fed to the
    CNN). The two prediction vectors are blended as 0.7 * CNN + 0.3 * LSTM,
    and the arg-max of each of the three vectors is translated to a family
    accession and family ID through ``prot_mappings``.

    Args:
        Amino_Acid_Sequence: Text entered by the user in the Gradio text box.

    Returns:
        A multi-line report string with the input, the processed input, each
        model's predicted family, the ensemble's predicted family, and the raw
        blended prediction vector.
    """

    def _family_for(scores):
        # Highest-scoring class index -> accession -> family ID.
        accession = prot_mappings['id2fam_asc'][str(scores.argmax())]
        return accession, prot_mappings['fam_asc2fam_id'][accession]

    padded_seq, processed_seq = Sequence.prepare(Amino_Acid_Sequence)

    # Each model gets the encoding it expects; [0] takes the first row of the
    # predict() output.
    cnn_raw_prediction = cnn_model.predict(processed_seq)[0]
    lstm_raw_prediction = lstm_model.predict(padded_seq)[0]
    joined_prediction = 0.7 * cnn_raw_prediction + 0.3 * lstm_raw_prediction

    cnn_fam_asc, cnn_fam_id = _family_for(cnn_raw_prediction)
    lstm_fam_asc, lstm_fam_id = _family_for(lstm_raw_prediction)
    fam_asc, fam_id = _family_for(joined_prediction)

    gc.collect()

    # The report text is left-aligned on purpose: any indentation inside the
    # triple-quoted literal would become part of the returned string.
    return f"""
Input is {Amino_Acid_Sequence}.
Processed input is:
{processed_seq}

CNN says: Family Accession={cnn_fam_asc} and ID={cnn_fam_id}
LSTM says: Family Accession={lstm_fam_asc} and ID={lstm_fam_id}

0.7 * cnn and 0.3 * lstm ensemble model makes prediction which maps to:
Family Accession={fam_asc} and ID={fam_id}

Raw Joined Prediction:
{joined_prediction}
"""
|
152 |
|
153 |
# Expose greet() through a Gradio interface with text input and text output,
# then start serving it.
iface = gr.Interface(
    fn=greet,
    inputs="text",
    outputs="text",
)
iface.launch()
|