# 📡 **LWM: Large Wireless Model**
**[🚀 Click here to try the Interactive Demo!](https://huggingface.co/spaces/wi-lab/lwm-interactive-demo)**
Welcome to **LWM** (Large Wireless Model), a pre-trained model designed for processing and feature extraction from wireless communication datasets, particularly the **DeepMIMO** dataset. This guide provides step-by-step instructions to set up your environment, install the required packages, clone the repository, load data, and perform inference using LWM.
---
## 🛠 **How to Use**
### 1. **Install Conda**
First, ensure that you have a package manager like **Conda** installed to manage your Python environments and packages. You can install **Conda** via **Anaconda** or **Miniconda**.
- **Anaconda** includes a comprehensive scientific package suite. Download it [here](https://www.anaconda.com/products/distribution).
- **Miniconda** is a lightweight version that includes only Conda and Python. Download it [here](https://docs.conda.io/en/latest/miniconda.html).
Once installed, you can use Conda to manage environments.
---
### 2. **Create a New Environment**
After installing Conda, follow these steps to create a new environment and install the required packages.
#### **Step 1: Create a new environment**
Create a new environment named `lwm_env`:
```bash
conda create -n lwm_env
```
#### **Step 2: Activate the environment**
Activate the environment:
```bash
conda activate lwm_env
```
---
### 3. **Install Required Packages**
Once the environment is activated, install the necessary packages.
#### **Install CUDA-enabled PyTorch**
While inference runs efficiently on CPU, you may require a GPU for training downstream tasks. Follow the instructions below to install CUDA-enabled PyTorch. Be sure to adjust the `pytorch-cuda` version according to your system's specifications.
```bash
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
```
> **Note:** If you encounter issues installing CUDA-enabled PyTorch, verify your CUDA version compatibility. Conflicting installation attempts can also cause failures; in that case, try again in a fresh environment.
#### **Install Other Required Packages via Conda Forge**
```bash
conda install python numpy pandas matplotlib tqdm -c conda-forge
```
#### **Install DeepMIMOv3 with pip**
```bash
pip install DeepMIMOv3
```
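Before moving on, it can help to confirm that the packages installed above are visible from the active environment. The following is a minimal sketch, assuming the import names match the package names used in the install commands (in particular that `DeepMIMOv3` is imported under that name). The helper `find_missing_modules` is a hypothetical name introduced here and is not part of the LWM repository.

```python
import importlib.util

# Import names assumed to correspond to the packages installed above.
REQUIRED_MODULES = ["torch", "numpy", "pandas", "matplotlib", "tqdm", "DeepMIMOv3"]

def find_missing_modules(module_names):
    """Return the names in module_names that cannot be located for import."""
    return [name for name in module_names if importlib.util.find_spec(name) is None]

missing = find_missing_modules(REQUIRED_MODULES)
if missing:
    print("Missing modules:", ", ".join(missing))
else:
    print("All required modules were found.")

# Report CUDA availability only when PyTorch itself can be found.
if "torch" not in missing:
    import torch
    print("CUDA available:", torch.cuda.is_available())
```

If `CUDA available` prints `False` on a machine with a GPU, the `pytorch-cuda` version chosen above is a likely place to look first.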
---
### 4. **Clone the Dataset Scenarios**
The following function clones a single dataset scenario folder from the dataset repository using a Git sparse checkout, so only the scenarios you request are downloaded:
```python
import subprocess
import os

# Function to clone a specific dataset scenario folder
def clone_dataset_scenario(scenario_name, repo_url, model_repo_dir="./LWM", scenarios_dir="scenarios"):
    # Create the scenarios directory if it doesn't exist
    scenarios_path = os.path.join(model_repo_dir, scenarios_dir)
    if not os.path.exists(scenarios_path):
        os.makedirs(scenarios_path)

    scenario_path = os.path.join(scenarios_path, scenario_name)

    # Initialize sparse checkout for the dataset repository
    if not os.path.exists(os.path.join(scenarios_path, ".git")):
        print(f"Initializing sparse checkout in {scenarios_path}...")
        subprocess.run(["git", "clone", "--sparse", repo_url, "."], cwd=scenarios_path, check=True)
        subprocess.run(["git", "sparse-checkout", "init", "--cone"], cwd=scenarios_path, check=True)
        subprocess.run(["git", "lfs", "install"], cwd=scenarios_path, check=True)  # Install Git LFS if needed

    # Add the requested scenario folder to sparse checkout
    print(f"Adding {scenario_name} to sparse checkout...")
    subprocess.run(["git", "sparse-checkout", "add", scenario_name], cwd=scenarios_path, check=True)

    # Pull large files if needed (using Git LFS)
    subprocess.run(["git", "lfs", "pull"], cwd=scenarios_path, check=True)

    print(f"Successfully cloned {scenario_name} into {scenarios_path}.")
```
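The function above shells out to `git` and `git lfs`, so both tools need to be on your `PATH`; `git sparse-checkout` and `git clone --sparse` also require Git 2.25 or newer. As a hedged pre-flight check, the sketch below reports which tools are missing. The helper name `missing_tools` is hypothetical and only used for illustration.

```python
import shutil

def missing_tools(tools=("git", "git-lfs")):
    """Return the command-line tools from `tools` that are not found on PATH."""
    return [tool for tool in tools if shutil.which(tool) is None]

missing = missing_tools()
if missing:
    print("Install these tools before cloning:", ", ".join(missing))
else:
    print("git and git-lfs are available.")
```

Running this before the cloning steps avoids a `subprocess.CalledProcessError` or `FileNotFoundError` partway through a download.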
---
### 5. **Clone the Model Repository**
Now, clone the **LWM** model repository to your local system.
```python
# Step 1: Clone the model repository (if not already cloned)
# Uses the `os` and `subprocess` modules imported in the previous step.
model_repo_url = "https://huggingface.co/wi-lab/lwm"
model_repo_dir = "./LWM"
if not os.path.exists(model_repo_dir):
    print(f"Cloning model repository from {model_repo_url}...")
    subprocess.run(["git", "clone", model_repo_url, model_repo_dir], check=True)
```
---
### 6. **Clone the Desired Dataset Scenarios**
You can now clone specific scenarios from the DeepMIMO dataset, as detailed in the table below:
📊 **Dataset Overview**
| 📊 **Dataset** | 🏙️ **City** | 👥 **Number of Users** | 🔗 **DeepMIMO Page** |
|----------------|----------------------|------------------------|------------------------------------------------------------------------------------------------------------|
| Dataset 0 | 🌆 Denver | 1354 | [DeepMIMO City Scenario 18](https://www.deepmimo.net/scenarios/deepmimo-city-scenario18/) |
| Dataset 1 | 🏙️ Indianapolis | 3248 | [DeepMIMO City Scenario 15](https://www.deepmimo.net/scenarios/deepmimo-city-scenario15/) |
| Dataset 2 | 🌇 Oklahoma | 3455 | [DeepMIMO City Scenario 19](https://www.deepmimo.net/scenarios/deepmimo-city-scenario19/) |
| Dataset 3 | 🌆 Fort Worth | 1902 | [DeepMIMO City Scenario 12](https://www.deepmimo.net/scenarios/deepmimo-city-scenario12/) |
| Dataset 4 | 🌉 Santa Clara | 2689 | [DeepMIMO City Scenario 11](https://www.deepmimo.net/scenarios/deepmimo-city-scenario11/) |
| Dataset 5 | 🌅 San Diego | 2192 | [DeepMIMO City Scenario 7](https://www.deepmimo.net/scenarios/deepmimo-city-scenario7/) |
#### **Clone the Scenarios:**
```python
import numpy as np

dataset_repo_url = "https://huggingface.co/datasets/wi-lab/lwm"  # Base URL for dataset repo
scenario_names = np.array([
    "city_18_denver", "city_15_indianapolis", "city_19_oklahoma",
    "city_12_fortworth", "city_11_santaclara", "city_7_sandiego"
])
scenario_idxs = np.array([0, 1, 2, 3, 4, 5])  # Select the scenario indexes
selected_scenario_names = scenario_names[scenario_idxs]

# Clone the requested scenarios one at a time
for scenario_name in selected_scenario_names:
    clone_dataset_scenario(scenario_name, dataset_repo_url, model_repo_dir)
```
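To get a feel for how much data a given selection covers, the user counts from the table above can be summed per selection. This is a small sketch using only the numbers listed in the table; whether the tokenizer later returns exactly this many channels is an assumption, not something this guide states. The names `USERS_PER_SCENARIO` and `total_users` are hypothetical.

```python
# User counts copied from the Dataset Overview table above.
USERS_PER_SCENARIO = {
    "city_18_denver": 1354,
    "city_15_indianapolis": 3248,
    "city_19_oklahoma": 3455,
    "city_12_fortworth": 1902,
    "city_11_santaclara": 2689,
    "city_7_sandiego": 2192,
}

def total_users(scenario_names):
    """Sum the tabulated user counts for the given scenario names."""
    return sum(USERS_PER_SCENARIO[name] for name in scenario_names)

print(total_users(USERS_PER_SCENARIO))                      # all six scenarios: 14840
print(total_users(["city_18_denver", "city_7_sandiego"]))   # 3546
```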
---
### 7. **Change the Working Directory to LWM**
```python
if os.path.exists(model_repo_dir):
    os.chdir(model_repo_dir)
    print(f"Changed working directory to {os.getcwd()}")
else:
    print(f"Directory {model_repo_dir} does not exist. Please check if the repository is cloned properly.")
```
---
### 8. **Tokenize and Load the Model**
Before we dive into tokenizing the dataset and loading the model, let's understand how the tokenization process is adapted to the wireless communication context. In this case, **tokenization** refers to segmenting each wireless channel into patches, similar to how Vision Transformers (ViTs) work with images. Each wireless channel is structured as a \(32 \times 32\) matrix, where rows represent antennas and columns represent subcarriers.
The tokenization process involves **dividing the channel matrix into patches**, with each patch containing information from 16 consecutive subcarriers. These patches are then **embedded** into a 64-dimensional space, providing the Transformer with a richer context for each patch. In this process, **positional encodings** are added to preserve the structural relationships within the channel, ensuring the Transformer captures both spatial and frequency dependencies.
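The patching and embedding steps described above can be illustrated with a short NumPy sketch. It is not the repository's `tokenizer`: it assumes a real-valued \(32 \times 32\) matrix (how real and imaginary parts are handled is not covered here), and it uses a random projection and a random positional table as stand-ins for learned parameters. The function name `patchify` is hypothetical.

```python
import numpy as np

PATCH_SIZE = 16  # subcarriers per patch, as stated above
EMBED_DIM = 64   # embedding dimension, as stated above

def patchify(channel, patch_size=PATCH_SIZE):
    """Split an (antennas, subcarriers) matrix into patches of consecutive subcarriers."""
    n_ant, n_sub = channel.shape
    assert n_sub % patch_size == 0, "subcarrier count must be a multiple of patch_size"
    # Row-major reshape keeps each patch inside a single antenna row.
    return channel.reshape(n_ant * (n_sub // patch_size), patch_size)

rng = np.random.default_rng(0)
channel = rng.standard_normal((32, 32))  # stand-in for one real-valued channel
patches = patchify(channel)              # shape (64, 16)

# Hypothetical embedding: random linear projection plus a random positional table.
projection = rng.standard_normal((PATCH_SIZE, EMBED_DIM))
positions = rng.standard_normal((patches.shape[0], EMBED_DIM))
tokens = patches @ projection + positions  # shape (64, 64)

print(patches.shape, tokens.shape)
```

In this sketch, patch 0 holds subcarriers 0 to 15 of antenna 0 and patch 1 holds subcarriers 16 to 31 of the same antenna, which is why the positional term is needed to tell the model where each patch came from.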
If you choose to apply **Masked Channel Modeling (MCM)** during inference (by setting `gen_raw=False`), LWM will mask certain patches, as it did during pre-training. However, for standard inference, masking isn't necessary unless you want to test LWM's resilience to noisy inputs.
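As a rough illustration of what masking patches means, the sketch below zeroes out a random subset of patches. The masking ratio of 0.15 and the choice to replace masked patches with zeros are assumptions made for this example only; the repository's MCM implementation may select and fill masked patches differently. The function name `mask_patches` is hypothetical.

```python
import numpy as np

def mask_patches(patches, mask_ratio=0.15, rng=None):
    """Zero out a random subset of patches; return the masked copy and the masked indices."""
    rng = np.random.default_rng() if rng is None else rng
    n_patches = patches.shape[0]
    n_masked = max(1, int(round(mask_ratio * n_patches)))
    masked_idx = np.sort(rng.choice(n_patches, size=n_masked, replace=False))
    masked = patches.copy()      # leave the input untouched
    masked[masked_idx] = 0.0     # assumed fill value for masked patches
    return masked, masked_idx

patches = np.ones((64, 16))      # placeholder patches for illustration
masked, masked_idx = mask_patches(patches, rng=np.random.default_rng(0))
print(len(masked_idx), "of", patches.shape[0], "patches masked")
```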
Now, let's move on to tokenize the dataset and load the pre-trained LWM model.
```python
from input_preprocess import tokenizer
from lwm_model import lwm
import torch
preprocessed_chs = tokenizer(
    selected_scenario_names=selected_scenario_names,  # Selects predefined DeepMIMOv3 scenarios. Set to None to load your own dataset.
    manual_data=None,  # If using a custom dataset, ensure it is a wireless channel dataset of size (N,32,32) based on the settings provided above.
    gen_raw=True  # Set gen_raw=False to apply masked channel modeling (MCM), as used in LWM pre-training. For inference, masking is unnecessary unless you want to evaluate LWM's ability to handle noisy inputs.
)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Loading the LWM model on {device}...")
model = lwm.from_pretrained(device=device)
```
With this setup, you're ready to pass your tokenized wireless channels through the pre-trained model, extracting rich, context-aware embeddings that are ready for use in downstream tasks.
---
### 9. **Perform Inference**
Before running the inference, it's important to understand the benefits of the different embedding types. The **CLS embeddings (cls_emb)** provide a highly compressed, holistic view of the entire wireless channel, making them ideal for tasks requiring a general understanding, such as classification or high-level decision-making. On the other hand, **channel embeddings (channel_emb)** capture detailed spatial and frequency information from the wireless channel, making them more suitable for complex tasks like beamforming or channel prediction.
You can now perform inference on the preprocessed data using the LWM model.
```python
from inference import lwm_inference, create_raw_dataset
input_types = ['cls_emb', 'channel_emb', 'raw']
selected_input_type = input_types[1] # Change the index to select LWM CLS embeddings, LWM channel embeddings, or the original input channels.
if selected_input_type in ['cls_emb', 'channel_emb']:
    dataset = lwm_inference(preprocessed_chs, selected_input_type, model, device)
else:
    dataset = create_raw_dataset(preprocessed_chs, device)
```
By selecting either `cls_emb` or `channel_emb`, you leverage the pre-trained model's rich feature extraction capabilities to transform raw channels into highly informative embeddings. If you prefer to work with the original raw data, you can choose the `raw` input type.
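The difference between the two embedding types can be pictured with a toy array. The sketch below assumes the common Transformer convention that the CLS token sits at position 0 of the output sequence, followed by one token per patch. The shapes and the layout are illustrative assumptions and may not match what `lwm_inference` actually returns.

```python
import numpy as np

# Hypothetical encoder output: N channels, 1 CLS token + P patch tokens, D features each.
N, P, D = 4, 64, 64
encoder_output = np.random.default_rng(0).standard_normal((N, 1 + P, D))

cls_emb = encoder_output[:, 0, :]       # one summary vector per channel, shape (N, D)
channel_emb = encoder_output[:, 1:, :]  # one vector per patch, shape (N, P, D)

print(cls_emb.shape, channel_emb.shape)
```

Under these assumptions, `cls_emb` is far smaller per sample than `channel_emb`, which matches the trade-off described above: a compact summary for classification-style tasks versus per-patch detail for tasks such as beamforming or channel prediction.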
---
### 10. **Explore the Interactive Demo**
To experience **LWM** interactively, visit our demo hosted on Hugging Face Spaces:
[**Try the Interactive Demo!**](https://huggingface.co/spaces/wi-lab/lwm-interactive-demo)
---
You're now ready to explore the power of **LWM** in wireless communications! Start processing datasets and generate high-quality embeddings to advance your research or applications.