bwang0911 edwardjross commited on
Commit
54da5c0
1 Parent(s): 120c5ad

Enable saving in SentenceTransformers by adding get_config_dict (#32)

Browse files

- Enable saving in SentenceTransformers by adding get_config_dict (7e1ab6246d86fe02a992b477256aa9f12deb4aea)


Co-authored-by: Edward Ross <[email protected]>

Files changed (1) hide show
  1. custom_st.py +3 -0
custom_st.py CHANGED
@@ -160,6 +160,9 @@ class Transformer(nn.Module):
160
  )
161
  return output
162
 
 
 
 
163
  def save(self, output_path: str, safe_serialization: bool = True) -> None:
164
  self.auto_model.save_pretrained(output_path, safe_serialization=safe_serialization)
165
  self.tokenizer.save_pretrained(output_path)
 
160
  )
161
  return output
162
 
163
+ def get_config_dict(self) -> dict[str, Any]:
164
+ return {key: self.__dict__[key] for key in self.config_keys}
165
+
166
  def save(self, output_path: str, safe_serialization: bool = True) -> None:
167
  self.auto_model.save_pretrained(output_path, safe_serialization=safe_serialization)
168
  self.tokenizer.save_pretrained(output_path)