AutoEncoder for Dimensionality Reduction
Model Description
The AutoEncoder presented here is a neural network model based on an encoder-decoder architecture. It is designed to learn efficient representations (encodings) of the input data, typically for dimensionality reduction purposes. The encoder compresses the input into a lower-dimensional latent space, while the decoder reconstructs the input data from the latent representation.
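As a rough illustration of the encoder-decoder idea, here is a minimal sketch of a linear autoencoder in PyTorch. It is not the implementation behind this model: the class name `TinyAutoEncoder`, the hidden width of 128, and the default dimensions are assumptions chosen for illustration.

```python
import torch
from torch import nn


class TinyAutoEncoder(nn.Module):
    """Hypothetical minimal autoencoder, not the model described in this card."""

    def __init__(self, input_dim: int = 784, latent_dim: int = 32):
        super().__init__()
        # Encoder: compress the input into a lower-dimensional latent vector.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, latent_dim),
        )
        # Decoder: reconstruct the input from the latent vector.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, input_dim),
        )

    def forward(self, x: torch.Tensor):
        latent = self.encoder(x)
        reconstruction = self.decoder(latent)
        return latent, reconstruction


model = TinyAutoEncoder()
x = torch.rand(16, 784)  # batch of 16 random inputs
latent, reconstruction = model(x)
# Reconstruction error is the usual training objective for an autoencoder.
loss = nn.functional.mse_loss(reconstruction, x)
```

The `latent` tensor is the reduced representation that would be passed to a downstream task. The `loss` value is what training would minimize.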
This model is flexible and can be configured with different layer types, such as linear layers, LSTMs, GRUs, or RNNs, and supports bidirectional sequence processing. It integrates with the Hugging Face Transformers library, so it can be downloaded and instantiated through the standard `AutoConfig` and `AutoModel` classes.
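To make the idea of configurable layer types concrete, the sketch below shows one way a layer-type string could be mapped to a PyTorch module. The helper `make_layer` and the `RECURRENT_LAYERS` table are hypothetical and only illustrate the concept; the model's actual construction logic may differ.

```python
import torch
from torch import nn

# Hypothetical mapping from a layer-type string to a PyTorch recurrent class.
RECURRENT_LAYERS = {"lstm": nn.LSTM, "gru": nn.GRU, "rnn": nn.RNN}


def make_layer(layer_type: str, in_dim: int, out_dim: int, bidirectional: bool = False) -> nn.Module:
    """Build one layer of the requested type (illustrative helper only)."""
    if layer_type == "linear":
        return nn.Linear(in_dim, out_dim)
    if layer_type in RECURRENT_LAYERS:
        return RECURRENT_LAYERS[layer_type](
            in_dim, out_dim, batch_first=True, bidirectional=bidirectional
        )
    raise ValueError(f"Unsupported layer type: {layer_type}")


x = torch.rand(4, 10, 16)  # (batch_size, seq_len, input_dim)
gru = make_layer("gru", in_dim=16, out_dim=8, bidirectional=True)
gru_out, _ = gru(x)  # recurrent layers return (output, hidden_state)
linear_out = make_layer("linear", in_dim=16, out_dim=8)(x)
```

Note that a bidirectional recurrent layer concatenates its forward and backward outputs, so its output feature size is twice the hidden size. Any layer stacked after it has to account for that.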
Intended Use
This AutoEncoder is suitable for unsupervised learning tasks where dimensionality reduction or feature learning is desired. Examples include anomaly detection, data compression, and preprocessing for downstream tasks, such as reducing the number of features before classification.
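For anomaly detection, a common approach is to flag samples whose reconstruction error is unusually high. The sketch below is library-independent and uses plain Python lists. The helper names and the mean-plus-k-standard-deviations threshold are assumptions for illustration, not part of this model.

```python
from statistics import mean, pstdev


def reconstruction_errors(inputs, reconstructions):
    """Mean squared error per sample between inputs and their reconstructions."""
    return [
        sum((a - b) ** 2 for a, b in zip(x, r)) / len(x)
        for x, r in zip(inputs, reconstructions)
    ]


def flag_anomalies(errors, k=2.0):
    """Flag errors above mean + k * standard deviation (hypothetical rule)."""
    threshold = mean(errors) + k * pstdev(errors)
    return [e > threshold for e in errors]


# Toy data: the last sample is reconstructed poorly.
inputs = [[0.0, 1.0], [1.0, 0.0], [0.5, 0.5], [1.0, 1.0], [0.0, 0.0], [0.2, 0.8]]
reconstructions = [[0.1, 0.9], [0.9, 0.1], [0.5, 0.4], [1.0, 0.9], [0.1, 0.0], [0.9, 0.1]]

errors = reconstruction_errors(inputs, reconstructions)
flags = flag_anomalies(errors)
```

In practice the threshold would usually be calibrated on held-out normal data rather than on the same batch being scored.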
Basic Usage in Python
Here is a simple example of how to use the AutoEncoder model in Python:
import torch
from transformers import AutoConfig, AutoModel
# Load the configuration (the model ships custom code, hence trust_remote_code)
config = AutoConfig.from_pretrained("amaye15/autoencoder", trust_remote_code=True)
# Let's say you want to change the input_dim and latent_dim
config.input_dim = 1024 # New input dimension
config.latent_dim = 64 # New latent dimension
# Similarly, update other parameters as needed
config.layer_types = 'gru' # Change layer types to 'gru'
config.dropout_rate = 0.2 # Update dropout rate
config.num_layers = 4 # Change the number of layers
config.compression_rate = 0.6 # Update compression rate
config.bidirectional = False # Change to unidirectional
# Build the model from the updated configuration (weights are randomly initialized, not pretrained)
model = AutoModel.from_config(config, trust_remote_code=True)
# Example input data (batch_size, seq_len, input_dim)
input_data = torch.rand((32, 10, config.input_dim))  # Last dimension must match config.input_dim
# Perform encoding and decoding
with torch.no_grad(): # Assuming inference only
    output = model(input_data)
# `output` is a dataclass with the following fields:
output.logits
output.labels
output.hidden_state
output.loss