sd_depth_regression_v2

Image Regression Model

This model was trained with the Image Regression Model Trainer. It takes an image as input and outputs a single float, the predicted 'sd_depth' value.

from ImageRegression import predict
predict(repo_id='BrownEnergy/sd_depth_regression_v2',image_path='image.jpg')

Dataset

Dataset: BrownEnergy/secchi_depth
Value Column: 'sd_depth'
Test Split: 0.05 (fraction of the dataset held out for testing)
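
To make the 0.05 split concrete, the sketch below computes how many rows would land in each partition for a dataset of a given size. The helper name `split_sizes` is hypothetical, and rounding the test partition up is an assumption; the trainer's actual rounding may differ by one row.

```python
import math
from typing import Tuple


def split_sizes(num_examples: int, test_split: float) -> Tuple[int, int]:
    """Return (num_train, num_test) for a fractional test split.

    Assumption: the test partition is rounded up to a whole row and the
    remaining rows go to training. This is an illustration, not the
    trainer's confirmed behaviour.
    """
    if not 0.0 < test_split < 1.0:
        raise ValueError("test_split must be between 0 and 1 (exclusive)")
    num_test = math.ceil(num_examples * test_split)
    num_train = num_examples - num_test
    return num_train, num_test
```

For example, `split_sizes(30, 0.05)` keeps 28 rows for training and holds out 2 for testing.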


Training

Base Model: google/vit-base-patch16-224
Epochs: 10
Learning Rate: 0.0001


Usage

Download

git clone https://github.com/TonyAssi/ImageRegression.git
cd ImageRegression

Installation

pip install -r requirements.txt

Import

from ImageRegression import train_model, upload_model, predict

Inference (Prediction)

  • repo_id: 🤗 repo id of the model
  • image_path: path to the image
predict(repo_id='BrownEnergy/sd_depth_regression_v2',
        image_path='image.jpg')

The first time this function is called, it downloads the safetensors model. Subsequent calls run faster.
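
Since predict takes one image path per call, scoring a whole folder needs a small loop. The sketch below is one way to do that. The helper `predict_folder` and the suffix list are hypothetical and not part of ImageRegression, and the code assumes that predict returns a value convertible to float. The prediction function is passed in as an argument so the helper can be tried with a stand-in function.

```python
from pathlib import Path
from typing import Callable, Dict

# Assumption: these are the image types of interest; adjust as needed.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def predict_folder(predict_fn: Callable[..., float],
                   repo_id: str,
                   folder: str) -> Dict[str, float]:
    """Call predict_fn once per image file in folder.

    Returns a dict mapping each file name to its predicted value.
    Files with other suffixes and subdirectories are skipped.
    """
    results: Dict[str, float] = {}
    for path in sorted(Path(folder).iterdir()):
        if path.is_file() and path.suffix.lower() in IMAGE_SUFFIXES:
            value = predict_fn(repo_id=repo_id, image_path=str(path))
            results[path.name] = float(value)
    return results
```

With the library imported, a call might look like predict_folder(predict, 'BrownEnergy/sd_depth_regression_v2', './images').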

Train Model

  • dataset_id: 🤗 dataset id
  • value_column_name: name of the dataset column that holds the values to predict
  • test_split: fraction of the dataset held out for testing
  • output_dir: directory where the checkpoints will be saved
  • num_train_epochs: number of training epochs
  • learning_rate: learning rate
train_model(dataset_id='BrownEnergy/secchi_depth',
            value_column_name='sd_depth',
            test_split=0.05,
            output_dir='./results',
            num_train_epochs=10,
            learning_rate=0.0001)

The trainer saves checkpoints in the output_dir location. The model.safetensors file in each checkpoint folder contains the trained weights you'll use for inference (prediction).
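
After training, output_dir may hold several checkpoint folders. The sketch below picks the one with the highest step number that contains a model.safetensors file. The helper `latest_checkpoint` is hypothetical, and the `checkpoint-<step>` folder naming is assumed from the upload example in this document (`./results/checkpoint-940`).

```python
import re
from pathlib import Path
from typing import Optional

# Assumption: checkpoint folders are named "checkpoint-<step>".
CHECKPOINT_PATTERN = re.compile(r"^checkpoint-(\d+)$")


def latest_checkpoint(output_dir: str) -> Optional[str]:
    """Return the path of the highest-step checkpoint folder, or None.

    Only folders that contain a model.safetensors file are considered.
    """
    best_step = -1
    best_path: Optional[Path] = None
    for child in Path(output_dir).iterdir():
        match = CHECKPOINT_PATTERN.match(child.name)
        if not match or not child.is_dir():
            continue
        if not (child / "model.safetensors").exists():
            continue
        step = int(match.group(1))
        if step > best_step:
            best_step, best_path = step, child
    return str(best_path) if best_path is not None else None
```

The highest step is the most recent checkpoint, which is not necessarily the one with the lowest validation error, so compare the logged metrics before choosing a folder to upload.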

Upload Model

This function will upload your model to the 🤗 Hub.

  • model_id: name of the model repo to create on the 🤗 Hub
  • token: your 🤗 access token (create one in your 🤗 account settings)
  • checkpoint_dir: checkpoint folder that will be uploaded
upload_model(model_id='sd_depth_regression_v2',
             token='YOUR_HF_TOKEN',
             checkpoint_dir='./results/checkpoint-940')
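
To keep the token out of source files, it can be read from an environment variable instead of being typed into the call. The sketch below shows one way to do that. The helper `get_hf_token` is hypothetical, and using the `HF_TOKEN` variable name is an assumption about how your environment is set up.

```python
import os


def get_hf_token(env_var: str = "HF_TOKEN") -> str:
    """Read the access token from an environment variable.

    Raises RuntimeError if the variable is unset or blank, so a missing
    token fails early instead of during the upload.
    """
    token = os.environ.get(env_var, "").strip()
    if not token:
        raise RuntimeError(
            f"Set the {env_var} environment variable to your access token."
        )
    return token
```

The result can then be passed as token=get_hf_token() in the upload_model call above.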