Sadjad Alikhani

LWM: Large Wireless Model

This repository contains the implementation of LWM (Large Wireless Model), a pre-trained model for processing wireless communication datasets, specifically DeepMIMO, and extracting features from them. The instructions below show how to load the LWM model and its pre-trained weights, load and tokenize DeepMIMO scenario data, and generate either raw channels or the CLS or channel embeddings inferred by LWM.

How to Use

Step-by-Step Guide

  1. Clone the Repository

    Clone the Hugging Face repository to your local machine and import its modules with the following code:

    import subprocess
    import os
    import sys
    import importlib.util
    import torch
    
    # Hugging Face public repository URL
    repo_url = "https://huggingface.co/sadjadalikhani/LWM"
    
    # Directory where the repo will be cloned
    clone_dir = "./LWM"
    
    # Step 1: Clone the repository if it hasn't been cloned already
    if not os.path.exists(clone_dir):
        print(f"Cloning repository from {repo_url} into {clone_dir}...")
        result = subprocess.run(["git", "clone", repo_url, clone_dir], capture_output=True, text=True)
    
        if result.returncode != 0:
            print(f"Error cloning repository: {result.stderr}")
            sys.exit(1)  # Exit on failure
        print(f"Repository cloned successfully into {clone_dir}")
    else:
        print(f"Repository already cloned into {clone_dir}")
    
    # Step 2: Add the cloned directory to Python path
    sys.path.append(clone_dir)
    
    # Step 3: Dynamic module import and function exposure
    def import_functions_from_file(module_name, file_path):
        try:
            spec = importlib.util.spec_from_file_location(module_name, file_path)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
    
            # Expose every public callable (functions and classes such as LWM) as a global name
            for function_name in dir(module):
                if callable(getattr(module, function_name)) and not function_name.startswith("__"):
                    globals()[function_name] = getattr(module, function_name)
    
            return module
        except FileNotFoundError:
            print(f"Error: {file_path} not found!")
            sys.exit(1)
    
    # Step 4: Import necessary functions
    import_functions_from_file("lwm_model", os.path.join(clone_dir, "lwm_model.py"))
    import_functions_from_file("inference", os.path.join(clone_dir, "inference.py"))
    import_functions_from_file("load_data", os.path.join(clone_dir, "load_data.py"))
    import_functions_from_file("input_preprocess", os.path.join(clone_dir, "input_preprocess.py"))
    print("All required functions imported successfully.")
    
  2. Load the LWM Model

    After cloning the repository, you can load the LWM model with the following code:

    # Step 5: Load the LWM model on the GPU if one is available, otherwise on the CPU
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Loading the LWM model on {device}...")
    model = LWM.from_pretrained(device=device)
    
  3. Load the DeepMIMO Dataset

    Load the DeepMIMO dataset with this code:

    # Step 6: Load dataset (direct call, no module prefix)
    print("Loading DeepMIMO dataset...")
    deepmimo_data = load_DeepMIMO_data()
    
  4. Tokenize the DeepMIMO Dataset

    Tokenize the loaded dataset. Use scenario_idxs to select which DeepMIMO scenarios to include:

    # Step 7: Tokenize the dataset (direct call, no module prefix)
    scenario_idxs = torch.arange(1)  # torch.arange(n) selects scenarios 0..n-1; here only scenario 0
    print("Tokenizing the dataset...")
    preprocessed_chs = tokenizer(deepmimo_data, scenario_idxs, gen_raw=True)
    
  5. Generate the Dataset for Inference

    Choose the type of data to generate from the tokenized dataset: cls_emb (LWM CLS embeddings), channel_emb (LWM channel embeddings), or raw (raw channels):

    # Step 8: Generate the dataset for inference (direct call, no module prefix)
    input_type = 'channel_emb'  # One of 'cls_emb', 'channel_emb', or 'raw'
    dataset = dataset_gen(preprocessed_chs, input_type, model)
    
  6. Print Results

    Finally, you can print the results and check the shape of the generated dataset:

    # Step 9: Print results
    print(f"Dataset generated with shape: {dataset.shape}")
    print("Inference completed successfully.")
    
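Step 1 copies every public callable of each module into `globals()`. That lets later steps call `LWM`, `tokenizer`, and `dataset_gen` without a prefix, but a name defined in two files is silently overwritten by whichever file loads last. A minimal alternative sketch, assuming you are willing to use module prefixes instead: load each file as a named module with `importlib`, register it in `sys.modules`, and keep the module object. The helper `load_module` is hypothetical, and the demo writes a throwaway file rather than using the real repository files.

```python
import importlib.util
import os
import sys
import tempfile
from types import ModuleType


def load_module(module_name: str, file_path: str) -> ModuleType:
    """Load a Python source file as a module (hypothetical helper).

    Unlike step 1, nothing is copied into globals(); callers use the
    returned module object, e.g. module.some_function().
    """
    # spec_from_file_location does not check that the file exists, so check first.
    if not os.path.isfile(file_path):
        raise FileNotFoundError(f"{file_path} not found")
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create an import spec for {file_path}")
    module = importlib.util.module_from_spec(spec)
    # Register before executing so that any other file importing this module
    # by name gets the same module object instead of a second copy.
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module


# Demo: a throwaway file stands in for a repository module such as lwm_model.py.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "demo_module.py")
    with open(path, "w") as f:
        f.write("def add(a, b):\n    return a + b\n")
    demo = load_module("demo_module", path)
    result = demo.add(2, 3)

print(result)  # 5
```

With this approach, later steps would use prefixed names, for example `lwm_model.LWM.from_pretrained(device=device)`, assuming `LWM` is defined in `lwm_model.py` as the file name suggests. Because step 1 already appends `clone_dir` to `sys.path`, a plain `import lwm_model` should also work for the top-level files.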

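Steps 4 and 5 rely on the repository's `tokenizer` and `dataset_gen`, whose internals this README does not describe. For intuition only, here is a generic sketch of how a transformer-style channel model is commonly fed and read out, under assumptions that may not match LWM exactly: a complex channel matrix is split into real and imaginary parts, flattened, and cut into fixed-length patches (tokens); a CLS token is prepended at position 0; and the model returns one embedding per position. Under those assumptions, `cls_emb` would correspond to position 0 and `channel_emb` to the remaining positions, while `raw` bypasses the model. The function names, the patch size of 4, and the real-then-imaginary ordering are all hypothetical.

```python
from typing import List, Sequence, Tuple


def patchify(channel: Sequence[Sequence[complex]], patch_size: int) -> List[List[float]]:
    """Split a complex matrix into real-valued patches (hypothetical layout)."""
    flat = [h for row in channel for h in row]
    # Assumed ordering: all real parts first, then all imaginary parts.
    values = [h.real for h in flat] + [h.imag for h in flat]
    if len(values) % patch_size != 0:
        raise ValueError("number of real values must be divisible by patch_size")
    return [values[i:i + patch_size] for i in range(0, len(values), patch_size)]


def add_cls(patches: List[List[float]], cls_value: float = 0.0) -> List[List[float]]:
    """Prepend a placeholder CLS token (a learned vector in a real model)."""
    width = len(patches[0])
    return [[cls_value] * width] + patches


def split_embeddings(outputs: List[List[float]]) -> Tuple[List[float], List[List[float]]]:
    """Position 0 -> CLS embedding; positions 1.. -> per-patch channel embeddings."""
    return outputs[0], outputs[1:]


# A 2 x 2 complex channel gives 8 real values, i.e. 2 patches of 4.
H = [[1 + 2j, 3 - 1j],
     [0 + 1j, -2 + 0j]]
tokens = add_cls(patchify(H, patch_size=4))
print(len(tokens))  # 3 (CLS + 2 patches)

# Stand-in for model outputs: the tokens themselves (no model is run here).
cls_emb, channel_emb = split_embeddings(tokens)
print(cls_emb, len(channel_emb))  # [0.0, 0.0, 0.0, 0.0] 2
```

In a real model the width of each output row is the embedding size rather than the patch size, and the CLS vector is learned rather than fixed.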
Requirements

  • Python 3.x
  • PyTorch
  • Git

Install these before running the script; Git is required by the clone in step 1.
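A minimal sketch for checking these requirements before running the script, using only the standard library: it reports whether the interpreter is Python 3, whether PyTorch is importable, and whether a `git` executable is on `PATH`. The function name `check_requirements` is hypothetical, and the check does not cover any additional packages the repository's own modules may import.

```python
import importlib.util
import shutil
import sys


def check_requirements() -> dict:
    """Report whether the README's listed requirements appear to be available."""
    return {
        # The README only asks for "Python 3.x".
        "python3": sys.version_info.major == 3,
        # find_spec locates the package without importing it.
        "torch": importlib.util.find_spec("torch") is not None,
        # The clone step runs the `git` executable through subprocess.
        "git": shutil.which("git") is not None,
    }


if __name__ == "__main__":
    for name, ok in check_requirements().items():
        print(f"{name}: {'OK' if ok else 'MISSING'}")
```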