Sadjad Alikhani

📡 LWM: Large Wireless Model

🚀 Click here to try the Interactive Demo!

Welcome to the LWM (Large Wireless Model) repository! This project hosts a pre-trained model designed to process and extract features from wireless communication datasets, specifically the DeepMIMO dataset. Follow the instructions below to set up your environment, clone the repository, install the required packages, load the data, and perform inference with LWM.


🛠 How to Use for Beginners

1. Install Conda or Mamba (via Miniforge)

First, you need to have a package manager like Conda or Mamba (a faster alternative) installed to manage your Python environments and packages.

Option A: Install Conda

If you prefer to use Conda, you can download and install Anaconda or Miniconda.

  • Anaconda includes a full scientific package suite, but it is larger in size. Download it here.
  • Miniconda is a lightweight version that only includes Conda and Python. Download it here.

Option B: Install Mamba (via Miniforge)

Mamba is a much faster alternative to Conda. You can install Mamba by installing Miniforge.

  • Miniforge is a smaller, community-based installer for Conda that includes Mamba. Download it here.

After installation, you can use either conda or mamba for environment management. The commands are identical; simply replace conda with mamba.


2. Create a New Environment

Once you have Conda or Mamba installed, follow these steps to create a new environment and install the necessary packages.

Step 1: Create a new environment

Create a new environment called lwm_env (or any name you prefer) with Python 3.9 (or another Python version your setup requires):

# If you're using Conda:
conda create -n lwm_env python=3.9

# If you're using Mamba:
mamba create -n lwm_env python=3.9

Step 2: Activate the environment

Activate the environment you just created:

# For both Conda and Mamba:
conda activate lwm_env
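To confirm that the right environment is active before continuing, you can run a small check like the one below from Python. This is a sketch, not part of the repository: `check_environment` is a hypothetical helper, and it assumes that activation sets the `CONDA_DEFAULT_ENV` variable (as `conda activate` does) and that Python 3.9 is the minimum version you want.

```python
import os
import sys

def check_environment(environ, version_info, expected_env="lwm_env", min_version=(3, 9)):
    """Return a list of problems with the active environment (an empty list means OK).

    Hypothetical helper: assumes the environment name is in CONDA_DEFAULT_ENV.
    """
    problems = []
    active = environ.get("CONDA_DEFAULT_ENV")  # set by `conda activate`
    if active != expected_env:
        problems.append(f"active environment is {active!r}, expected {expected_env!r}")
    if tuple(version_info[:2]) < tuple(min_version):
        problems.append(
            f"Python {version_info[0]}.{version_info[1]} is older than "
            f"{min_version[0]}.{min_version[1]}"
        )
    return problems

for problem in check_environment(os.environ, sys.version_info):
    print("Warning:", problem)
```

Passing the environment mapping and version explicitly (instead of reading them inside the function) keeps the helper easy to test.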

3. Clone the Repository

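After setting up the environment, clone the Hugging Face repository to your local machine and import its functions using the following Python code. Because the script imports PyTorch and the repository's modules, install PyTorch first (see Section 4); if other dependencies are still missing, install them from requirements.txt in the freshly cloned ./LWM folder and re-run the script (the clone step is skipped once ./LWM exists).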

import subprocess
import os
import sys
import importlib.util
import torch

# Hugging Face public repository URL
repo_url = "https://huggingface.co/sadjadalikhani/LWM"

# Directory where the repo will be cloned
clone_dir = "./LWM"

# Step 1: Clone the repository if it hasn't been cloned already
if not os.path.exists(clone_dir):
    print(f"Cloning repository from {repo_url} into {clone_dir}...")
    result = subprocess.run(["git", "clone", repo_url, clone_dir], capture_output=True, text=True)

    if result.returncode != 0:
        print(f"Error cloning repository: {result.stderr}")
        sys.exit(1)
    print(f"Repository cloned successfully into {clone_dir}")
else:
    print(f"Repository already cloned into {clone_dir}")

# Step 2: Add the cloned directory to Python path
sys.path.append(clone_dir)

# Step 3: Import necessary functions
def import_functions_from_file(module_name, file_path):
    try:
        spec = importlib.util.spec_from_file_location(module_name, file_path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)

        for function_name in dir(module):
            if callable(getattr(module, function_name)) and not function_name.startswith("__"):
                globals()[function_name] = getattr(module, function_name)
        return module
    except FileNotFoundError:
        print(f"Error: {file_path} not found!")
        sys.exit(1)

# Step 4: Import functions from the repository
import_functions_from_file("lwm_model", os.path.join(clone_dir, "lwm_model.py"))
import_functions_from_file("inference", os.path.join(clone_dir, "inference.py"))
import_functions_from_file("load_data", os.path.join(clone_dir, "load_data.py"))
import_functions_from_file("input_preprocess", os.path.join(clone_dir, "input_preprocess.py"))
print("All required functions imported successfully.")
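Because the script copies functions and classes into globals() rather than using regular imports, a missing or renamed function only surfaces later, when a step calls it. A quick check like the sketch below confirms that the names used in Steps 5 to 8 of this README exist; `missing_names` is a hypothetical helper, not part of the repository.

```python
# Names that later steps of this README call directly
REQUIRED_NAMES = ("load_DeepMIMO_data", "tokenizer", "LWM", "dataset_gen")

def missing_names(namespace, required=REQUIRED_NAMES):
    """Return the required names that are not defined in `namespace`."""
    return [name for name in required if name not in namespace]

missing = missing_names(globals())
if missing:
    print(f"Missing after import: {missing}. Check that the clone and imports succeeded.")
else:
    print("All names used in Steps 5 to 8 are available.")
```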

4. Install Required Packages

Install the necessary packages inside your activated lwm_env environment. The requirements.txt file is part of the cloned repository, so run these commands from the directory that contains the LWM folder (or adjust the path).

# If you're using Conda:
conda install pytorch torchvision torchaudio -c pytorch
pip install -r LWM/requirements.txt

# If you're using Mamba:
mamba install pytorch torchvision torchaudio -c pytorch
pip install -r LWM/requirements.txt

This installs PyTorch, Torchvision, and Torchaudio, plus the remaining dependencies listed in the repository's requirements.txt.
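To verify the installation from inside Python, a sketch like the following reports which of the packages are importable and their installed versions. `package_report` is a hypothetical helper; it uses only the standard library (`importlib.util.find_spec` and `importlib.metadata.version`), so it runs even when PyTorch is missing.

```python
import importlib.metadata
import importlib.util

def package_report(names=("torch", "torchvision", "torchaudio")):
    """Map each package name to its installed version, or None if it cannot be imported."""
    report = {}
    for name in names:
        if importlib.util.find_spec(name) is None:
            report[name] = None
            continue
        try:
            report[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            report[name] = "unknown"  # importable, but no distribution metadata found
    return report

print(package_report())
```

Any entry that prints as None points to a package that still needs to be installed in lwm_env.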


5. Load the DeepMIMO Dataset

Before tokenization and preprocessing, you must first load the DeepMIMO datasets, or any dataset generated with the operational settings listed below. The table below lists the available datasets and their DeepMIMO scenario pages:

📊 Dataset Overview

| 📊 Dataset | 🏙️ City | 👥 Number of Users | 🔗 DeepMIMO Page |
|---|---|---|---|
| Dataset 0 | 🌆 Denver | 1354 | DeepMIMO City Scenario 18 |
| Dataset 1 | 🏙️ Indianapolis | 3248 | DeepMIMO City Scenario 15 |
| Dataset 2 | 🌇 Oklahoma | 3455 | DeepMIMO City Scenario 19 |
| Dataset 3 | 🌆 Fort Worth | 1902 | DeepMIMO City Scenario 12 |
| Dataset 4 | 🌉 Santa Clara | 2689 | DeepMIMO City Scenario 11 |
| Dataset 5 | 🌅 San Diego | 2192 | DeepMIMO City Scenario 7 |

Note that these six datasets were not used to pre-train LWM, so high-quality embeddings on them indicate that LWM generalizes to unseen scenarios rather than overfitting to its training data.

The operational settings below were used in generating the datasets for both the pre-training of LWM and the downstream tasks. If you intend to use custom datasets, please ensure they adhere to these configurations:

Operational Settings:

  • Antennas at BS: 32
  • Antennas at UEs: 1
  • Subcarriers: 32
  • Paths: 20

Load Data Code:

Choose which datasets to process by adjusting dataset_idxs. In the example below, torch.arange(2) selects the first two datasets (indices 0 and 1: Denver and Indianapolis).

# Step 5: Load the DeepMIMO dataset
import torch  # needed for torch.arange below

print("Loading the DeepMIMO dataset...")
deepmimo_data = load_DeepMIMO_data()
print("DeepMIMO dataset loaded successfully.")

# Select which of the loaded datasets to process (indices 0 and 1 here)
dataset_idxs = torch.arange(2)  # Adjust the number of datasets as needed
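To see what a given dataset_idxs selects, the sketch below maps indices to the cities and user counts from the Dataset Overview table. `DATASETS` and `describe_selection` are hypothetical helpers built only from that table; the total is the number of users listed there, which may differ from the number of samples the loader actually returns.

```python
# Dataset Overview table: index -> (city, number of users, DeepMIMO City Scenario)
DATASETS = {
    0: ("Denver", 1354, 18),
    1: ("Indianapolis", 3248, 15),
    2: ("Oklahoma", 3455, 19),
    3: ("Fort Worth", 1902, 12),
    4: ("Santa Clara", 2689, 11),
    5: ("San Diego", 2192, 7),
}

def describe_selection(dataset_idxs):
    """Return (city names, total listed users) for the selected dataset indices."""
    idxs = [int(i) for i in dataset_idxs]  # accepts lists, ranges, or 1-D torch tensors
    unknown = [i for i in idxs if i not in DATASETS]
    if unknown:
        raise ValueError(f"unknown dataset indices: {unknown}")
    cities = [DATASETS[i][0] for i in idxs]
    total_users = sum(DATASETS[i][1] for i in idxs)
    return cities, total_users

cities, total = describe_selection(range(2))  # same indices as torch.arange(2)
print(cities, total)  # ['Denver', 'Indianapolis'] 4602
```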

6. Tokenize the DeepMIMO Dataset

After loading the data, tokenize the selected DeepMIMO datasets. This step converts the channels into the token sequences that LWM takes as input.

Tokenization Code:

# Step 6: Tokenize the dataset
print("Tokenizing the DeepMIMO dataset...")

# Tokenize the loaded datasets
preprocessed_chs = tokenizer(deepmimo_data, dataset_idxs, gen_raw=True)
print("Dataset tokenized successfully.")
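For intuition about what tokenization involves, here is a minimal, hypothetical sketch that splits one 32×32 complex channel (32 BS antennas × 32 subcarriers, with 1 UE antenna, per the operational settings) into fixed-length real-valued patches. The patch length of 16 and the real/imaginary concatenation are assumptions for illustration only; the repository's tokenizer may use a different layout and also handles masking and special tokens.

```python
import numpy as np

def patchify_channel(h: np.ndarray, patch_len: int = 16) -> np.ndarray:
    """Split one complex channel matrix into flat real-valued patches (illustrative only)."""
    if not np.iscomplexobj(h):
        raise ValueError("expected a complex-valued channel matrix")
    # Place real and imaginary parts side by side: (32, 32) complex -> (32, 64) real
    real_view = np.concatenate([h.real, h.imag], axis=-1)
    flat = real_view.reshape(-1)  # 32 * 64 = 2048 values
    if flat.size % patch_len != 0:
        raise ValueError("channel size must be divisible by patch_len")
    return flat.reshape(-1, patch_len)  # (num_patches, patch_len)

rng = np.random.default_rng(0)
h = rng.standard_normal((32, 32)) + 1j * rng.standard_normal((32, 32))
patches = patchify_channel(h)
print(patches.shape)  # (128, 16)
```

Each row of `patches` would then play the role of one input token before any embedding layer.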

7. Load the LWM Model

Once the dataset is tokenized, load the pre-trained LWM model using the following code:

# Step 7: Load the LWM model (on the GPU if available, otherwise on the CPU)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Loading the LWM model on {device}...")
model = LWM.from_pretrained(device=device)

8. LWM Inference

Once the dataset is tokenized and the model is loaded, generate either raw channels or the inferred LWM embeddings by choosing the input type.

# Step 8: Generate the dataset for inference
input_type = ['cls_emb', 'channel_emb', 'raw'][1]  # Index 1 selects 'channel_emb'; change as needed
dataset = dataset_gen(preprocessed_chs, input_type, model)

You can choose between:

  • cls_emb: LWM CLS token embeddings
  • channel_emb: LWM channel embeddings
  • raw: Raw wireless channel data

9. Post-processing for Downstream Task

Use the Dataset in Downstream Tasks

Finally, use the generated dataset for your downstream tasks, such as classification, prediction, or analysis.

# Step 9: Print results
print(f"Dataset generated with shape: {dataset.shape}")
print("Inference completed successfully.")
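As one illustration of a downstream task, the sketch below fits a nearest-centroid classifier on flattened features. The labels are hypothetical (this README does not provide any), and the toy arrays stand in for `dataset`; if `dataset` is a PyTorch tensor, convert it first with `dataset.detach().cpu().numpy()`. Complex inputs (such as raw channels) are split into real and imaginary parts. This is a simple baseline for checking that the features separate your classes, not the authors' downstream method.

```python
import numpy as np

def flatten_features(dataset):
    """Turn (N, ...) samples into an (N, D) real-valued feature matrix."""
    x = np.asarray(dataset)
    x = x.reshape(x.shape[0], -1)
    if np.iscomplexobj(x):
        x = np.concatenate([x.real, x.imag], axis=1)  # keep both parts of complex data
    return x.astype(np.float64)

def fit_centroids(x, y):
    """Compute one mean feature vector (centroid) per class label."""
    classes = np.unique(y)
    centroids = np.stack([x[y == c].mean(axis=0) for c in classes])
    return classes, centroids

def predict(x, classes, centroids):
    """Assign each sample to the class whose centroid is nearest (squared Euclidean)."""
    dists = ((x[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    return classes[dists.argmin(axis=1)]

# Toy stand-in for `dataset` with hypothetical labels: two well-separated groups
features = flatten_features(np.concatenate([np.zeros((5, 2, 3)), np.full((5, 2, 3), 10.0)]))
labels = np.array([0] * 5 + [1] * 5)
classes, centroids = fit_centroids(features, labels)
accuracy = (predict(features, classes, centroids) == labels).mean()
print(f"Training accuracy: {accuracy:.2f}")  # 1.00 on this toy data
```

In practice you would evaluate on a held-out split rather than on the training samples.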

📋 Requirements

  • Python 3.x (the examples above use Python 3.9)
  • PyTorch
  • Git

Summary of Steps:

  1. Install Conda/Mamba: Install a package manager for environment management.
  2. Create Environment: Use Conda or Mamba to create a new environment.
  3. Clone the Repository: Download the project files from Hugging Face.
  4. Install Packages: Install PyTorch and other dependencies.
  5. Load and Tokenize Data: Load the DeepMIMO dataset and prepare it for the model.
  6. Load Model and Perform Inference: Use the LWM model for generating embeddings or raw channels.