jonathang commited on
Commit
fc1e518
1 Parent(s): 0bc7b8a
Files changed (1) hide show
  1. app.py +46 -7
app.py CHANGED
@@ -45,7 +45,7 @@ class Sequence:
45
  max_length = 100
46
  padded_ie = pad_sequences([ie], maxlen=max_length, padding='post', truncating='post')
47
  all_ohe = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20] + [0]*(100-21))
48
- return to_categorical(np.array([padded_ie[0], all_ohe]))[:1]
49
 
50
 
51
  def residual_block(data, filters, d_rate):
@@ -96,20 +96,59 @@ def get_model():
96
 
97
  return model2
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
- model = get_model()
101
  mappings_path = cached_download(hf_hub_url("jonathang/Protein_Family_Models", 'prot_mappings.json'))
102
  with open(mappings_path) as f:
103
  prot_mappings = json.load(f)
104
 
105
  def greet(Amino_Acid_Sequence):
106
- processed_seq = Sequence.prepare(Amino_Acid_Sequence)
107
- raw_prediction = model.predict(processed_seq)[0]
108
- idx = raw_prediction.argmax()
 
 
 
 
 
 
 
 
109
  fam_asc = prot_mappings['id2fam_asc'][str(idx)]
110
  fam_id = prot_mappings['fam_asc2fam_id'][fam_asc]
111
  gc.collect()
112
- return f"Input is {Amino_Acid_Sequence}.\nProcessed input is:\n{processed_seq}\n\nModel makes prediction which maps to:\nFamily Accession={fam_asc} and ID={fam_id}\n\nRaw Prediction:\n{raw_prediction}\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
  iface = gr.Interface(fn=greet, inputs="text", outputs="text")
115
- iface.launch()
 
45
  max_length = 100
46
  padded_ie = pad_sequences([ie], maxlen=max_length, padding='post', truncating='post')
47
  all_ohe = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20] + [0]*(100-21))
48
+ return padded_ie, to_categorical(np.array([padded_ie[0], all_ohe]))[:1]
49
 
50
 
51
  def residual_block(data, filters, d_rate):
 
96
 
97
  return model2
98
 
99
+ def get_lstm_model():
100
+ x_input = Input(shape=(100,))
101
+ emb = Embedding(21, 128, input_length=100)(x_input)
102
+ bi_rnn = Bidirectional(LSTM(64, kernel_regularizer=l2(0.01), recurrent_regularizer=l2(0.01), bias_regularizer=l2(0.01)))(emb)
103
+ # bi_rnn = CuDNNLSTM(64, kernel_regularizer=l2(0.01), recurrent_regularizer=l2(0.01), bias_regularizer=l2(0.01))(emb)
104
+ x = Dropout(0.3)(bi_rnn)
105
+
106
+ # softmax classifier
107
+ x_output = Dense(1000, activation='softmax')(x)
108
+
109
+ model1 = Model(inputs=x_input, outputs=x_output)
110
+ model1.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
111
+ weights = cached_download(hf_hub_url("jonathang/Protein_Family_Models", 'model1.h5'))
112
+ model1.load_weights(weights)
113
+ return model1
114
+
115
+
116
+ cnn_model = get_model()
117
+ lstm_model = get_lstm_model()
118
 
 
119
  mappings_path = cached_download(hf_hub_url("jonathang/Protein_Family_Models", 'prot_mappings.json'))
120
  with open(mappings_path) as f:
121
  prot_mappings = json.load(f)
122
 
123
  def greet(Amino_Acid_Sequence):
124
+ padded_seq, processed_seq = Sequence.prepare(Amino_Acid_Sequence)
125
+ cnn_raw_prediction = cnn_model.predict(processed_seq)[0]
126
+ lstm_raw_prediction = lstm_model.predict(padded_seq)[0]
127
+ joined_prediction = cnn_raw_prediction*0.7 + lstm_raw_prediction*0.3
128
+ cnn_idx = cnn_raw_prediction.argmax()
129
+ lstm_idx = lstm_raw_prediction.argmax()
130
+ idx = joined_prediction.argmax()
131
+ cnn_fam_asc = prot_mappings['id2fam_asc'][str(cnn_idx)]
132
+ cnn_fam_id = prot_mappings['fam_asc2fam_id'][cnn_fam_asc]
133
+ lstm_fam_asc = prot_mappings['id2fam_asc'][str(lstm_idx)]
134
+ lstm_fam_id = prot_mappings['fam_asc2fam_id'][lstm_fam_asc]
135
  fam_asc = prot_mappings['id2fam_asc'][str(idx)]
136
  fam_id = prot_mappings['fam_asc2fam_id'][fam_asc]
137
  gc.collect()
138
+ return f"""
139
+ Input is {Amino_Acid_Sequence}.
140
+ Processed input is:
141
+ {processed_seq}
142
+
143
+ CNN says: Family Accession={cnn_fam_asc} and ID={cnn_fam_id}
144
+ LSTM says: Family Accession={lstm_fam_asc} and ID={lstm_fam_id}
145
+
146
+ 0.7 * cnn and 0.3 * lstm ensemble model makes prediction which maps to:
147
+ Family Accession={fam_asc} and ID={fam_id}
148
+
149
+ Raw Joined Prediction:
150
+ {joined_prediction}
151
+ """
152
 
153
  iface = gr.Interface(fn=greet, inputs="text", outputs="text")
154
+ iface.launch()