fffiloni committed
Commit 9c9498f · verified · 1 Parent(s): f047660

Migrated from GitHub

.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/CBS_2.gif filter=lfs diff=lfs merge=lfs -text
+ assets/MLB_1.gif filter=lfs diff=lfs merge=lfs -text
+ assets/Sunny_1.gif filter=lfs diff=lfs merge=lfs -text
+ assets/Titanic_1.gif filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 Fiona Ryan
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
ORIGINAL_README.md ADDED
@@ -0,0 +1,153 @@
+ # Gaze-LLE
+ <div style="text-align:center;">
+ <img src="./assets/the_office.png" height="100"/>
+ <img src="./assets/MLB_1.gif" height="100"/>
+ <img src="./assets/succession.png" height="100"/>
+ <img src="./assets/CBS_2.gif" height="100"/>
+ </div>
+
+ [Gaze-LLE: Gaze Target Estimation via Large-Scale Learned Encoders](https://arxiv.org/abs/2412.09586) \
+ [Fiona Ryan](https://fkryan.github.io/), Ajay Bati, [Sangmin Lee](https://sites.google.com/view/sangmin-lee), [Daniel Bolya](https://dbolya.github.io/), [Judy Hoffman](https://faculty.cc.gatech.edu/~judy/)\*, [James M. Rehg](https://rehg.org/)\*
+
+ This is the official implementation of Gaze-LLE, a transformer-based approach to gaze target estimation that leverages the power of pretrained visual foundation models. Gaze-LLE provides a streamlined gaze architecture that learns only a lightweight gaze decoder on top of a frozen, pretrained visual encoder (DINOv2). Gaze-LLE learns 1-2 orders of magnitude fewer parameters than prior works and doesn't require any extra input modalities such as depth or pose!
+
+ <div style="text-align:center;">
+ <img src="./assets/gazelle_arch.png" height="200"/>
+ </div>
+
+ ## Installation
+
+ Clone this repo, then create the virtual environment.
+ ```
+ conda env create -f environment.yml
+ conda activate gazelle
+ pip install -e .
+ ```
+ If your system supports it, consider installing [xformers](https://github.com/facebookresearch/xformers) to speed up attention computation.
+ ```
+ pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu118
+ ```
+
+ ## Pretrained Models
+
+ We provide the following pretrained models for download.
+ | Name | Backbone type | Backbone name | Training data | Checkpoint |
+ | ---- | ------------- | ------------- | ------------- | ---------- |
+ | ```gazelle_dinov2_vitb14``` | DINOv2 ViT-B | ```dinov2_vitb14``` | GazeFollow | [Download](https://github.com/fkryan/gazelle/releases/download/v1.0.0/gazelle_dinov2_vitb14.pt) |
+ | ```gazelle_dinov2_vitl14``` | DINOv2 ViT-L | ```dinov2_vitl14``` | GazeFollow | [Download](https://github.com/fkryan/gazelle/releases/download/v1.0.0/gazelle_dinov2_vitl14.pt) |
+ | ```gazelle_dinov2_vitb14_inout``` | DINOv2 ViT-B | ```dinov2_vitb14``` | GazeFollow -> VideoAttentionTarget | [Download](https://github.com/fkryan/gazelle/releases/download/v1.0.0/gazelle_dinov2_vitb14_inout.pt) |
+ | ```gazelle_dinov2_vitl14_inout``` | DINOv2 ViT-L | ```dinov2_vitl14``` | GazeFollow -> VideoAttentionTarget | [Download](https://github.com/fkryan/gazelle/releases/download/v1.0.0/gazelle_dinov2_vitl14_inout.pt) |
+
+ Note that our Gaze-LLE checkpoints contain only the gaze decoder weights - the DINOv2 backbone weights are downloaded from ```facebookresearch/dinov2``` on PyTorch Hub when the Gaze-LLE model is created in our code.
+
+ The GazeFollow-trained models output a spatial heatmap of gaze locations over the scene with values in the range ```[0,1]```, where 1 represents the highest probability of the location being a gaze target. The models that are additionally finetuned on VideoAttentionTarget also predict an in/out-of-frame gaze score in the range ```[0,1]```, where 1 means the person's gaze target is in the frame.
+
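+ For reference, here is a minimal sketch of converting a predicted ```[64, 64]``` heatmap (like ```predicted_heatmap``` in the usage example below) into a normalized gaze point by taking its argmax, mirroring how the evaluation scripts compute the L2 error:
+ ```
+ # illustrative sketch: turn a [64, 64] heatmap into a normalized (x, y) gaze point
+ flat_idx = predicted_heatmap.argmax().item()
+ row, col = divmod(flat_idx, predicted_heatmap.shape[1])
+ gaze_x = col / predicted_heatmap.shape[1]  # normalized x in [0, 1]
+ gaze_y = row / predicted_heatmap.shape[0]  # normalized y in [0, 1]
+ ```
+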
+ ### PyTorch Hub
+
+ The models are also available on PyTorch Hub for easy use without installing from source.
+ ```
+ model, transform = torch.hub.load('fkryan/gazelle', 'gazelle_dinov2_vitb14')
+ model, transform = torch.hub.load('fkryan/gazelle', 'gazelle_dinov2_vitl14')
+ model, transform = torch.hub.load('fkryan/gazelle', 'gazelle_dinov2_vitb14_inout')
+ model, transform = torch.hub.load('fkryan/gazelle', 'gazelle_dinov2_vitl14_inout')
+ ```
+
+
+ ## Usage
+ ### Colab Demo Notebook
+ Check out our [Demo Notebook](https://colab.research.google.com/drive/1TSoyFvNs1-au9kjOZN_fo5ebdzngSPDq?usp=sharing) on Google Colab for how to detect gaze for all people in an image.
+
+ ### Gaze Prediction
+ Gaze-LLE is set up for multi-person inference: for a single image, Gaze-LLE encodes the scene only once and then reuses the features to predict the gaze of multiple people in that image. The input is a batch of image tensors and, for each image, a list of bounding boxes for the heads of the people whose gaze should be predicted. The bounding boxes are tuples of the form ```(xmin, ymin, xmax, ymax)``` in ```[0,1]```-normalized image coordinates. Below we show how to perform inference for a single person in a single image.
+ ```
+ from PIL import Image
+ import torch
+ from gazelle.model import get_gazelle_model
+
+ model, transform = get_gazelle_model("gazelle_dinov2_vitl14_inout")
+ model.load_gazelle_state_dict(torch.load("/path/to/checkpoint.pt", weights_only=True))
+ model.eval()
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model.to(device)
+
+ image = Image.open("path/to/image.png").convert("RGB")
+ input = {
+     "images": transform(image).unsqueeze(dim=0).to(device),  # tensor of shape [1, 3, 448, 448]
+     "bboxes": [[(0.1, 0.2, 0.5, 0.7)]]  # list of lists of bbox tuples
+ }
+
+ with torch.no_grad():
+     output = model(input)
+
+ predicted_heatmap = output["heatmap"][0][0]  # prediction for the first person in the first image; tensor of size [64, 64]
+ predicted_inout = output["inout"][0][0]  # in/out-of-frame score (1 = in frame) (output["inout"] is None for non-inout models)
+ ```
+ We empirically find that Gaze-LLE is effective without a bounding box input for scenes with just one person. However, providing a bounding box can improve results, and it is necessary for scenes with multiple people in order to specify whose gaze to estimate. To run inference without a bounding box, use None in place of a bounding box tuple in the bbox list (e.g. ```input["bboxes"] = [[None]]``` in the example above), as shown in the multi-person sketch below.
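+
+ As an illustrative sketch of the multi-person case (hypothetical bbox values), the example above can be extended to two heads in the same image, one of them left unspecified with None:
+ ```
+ # two people in the same image: person 0 has a head bbox, person 1 is unspecified (None)
+ input = {
+     "images": transform(image).unsqueeze(dim=0).to(device),
+     "bboxes": [[(0.1, 0.2, 0.5, 0.7), None]]
+ }
+
+ with torch.no_grad():
+     output = model(input)
+
+ heatmap_person0 = output["heatmap"][0][0]  # [64, 64] heatmap for the first head
+ heatmap_person1 = output["heatmap"][0][1]  # [64, 64] heatmap for the second head
+ ```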
+
+
+ We also provide a function to visualize the predicted heatmap for an image.
+ ```
+ import matplotlib.pyplot as plt
+ from gazelle.utils import visualize_heatmap
+
+ viz = visualize_heatmap(image, predicted_heatmap)
+ plt.imshow(viz)
+ plt.show()
+ ```
+
+
+ ## Evaluate
+ We provide evaluation scripts for GazeFollow and VideoAttentionTarget below to reproduce our results from our checkpoints.
+ ### GazeFollow
+ Download the GazeFollow dataset [here](https://github.com/ejcgt/attention-target-detection?tab=readme-ov-file#dataset). We provide a preprocessing script, ```data_prep/preprocess_gazefollow.py```, which preprocesses the annotations and compiles them into a JSON file for each split within the dataset folder. Run the preprocessing script as
+ ```
+ python data_prep/preprocess_gazefollow.py --data_path /path/to/gazefollow/data_new
+ ```
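+ For reference, each frame entry in the resulting JSON files roughly follows the structure below (the values are illustrative placeholders, not real annotations):
+ ```
+ # one entry of test_preprocessed.json (placeholder values)
+ {
+     "path": "path/to/test/image.jpg", "width": 640, "height": 480, "num_heads": 1,
+     "crop_region": [120.0, 60.0, 350.0, 250.0],
+     "heads": [{
+         "bbox": [120.0, 60.0, 200.0, 150.0],
+         "bbox_norm": [0.19, 0.12, 0.31, 0.31],
+         "gazex": [350.0], "gazey": [250.0], "gazex_norm": [0.55], "gazey_norm": [0.52],
+         "inout": 1, "num_annot": 1,
+         "crop_region": [120.0, 60.0, 350.0, 250.0],
+         "crop_region_norm": [0.19, 0.12, 0.55, 0.52],
+         "head_id": 0
+     }]
+ }
+ ```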
+ Download the pretrained model checkpoints above and use ```--model_name``` and ```--ckpt_path``` to specify the model type and checkpoint for evaluation.
+
+ ```
+ python scripts/eval_gazefollow.py \
+     --data_path /path/to/gazefollow/data_new \
+     --model_name gazelle_dinov2_vitl14 \
+     --ckpt_path /path/to/checkpoint.pt \
+     --batch_size 128
+ ```
+
+
+ ### VideoAttentionTarget
+ Download the VideoAttentionTarget dataset [here](https://github.com/ejcgt/attention-target-detection?tab=readme-ov-file#dataset-1). We provide a preprocessing script, ```data_prep/preprocess_vat.py```, which preprocesses the annotations and compiles them into a JSON file for each split within the dataset folder. Run the preprocessing script as
+ ```
+ python data_prep/preprocess_vat.py --data_path /path/to/videoattentiontarget
+ ```
+ Download the pretrained model checkpoints above and use ```--model_name``` and ```--ckpt_path``` to specify the model type and checkpoint for evaluation.
+ ```
+ python scripts/eval_vat.py \
+     --data_path /path/to/videoattentiontarget \
+     --model_name gazelle_dinov2_vitl14_inout \
+     --ckpt_path /path/to/checkpoint.pt \
+     --batch_size 64
+ ```
+
+ ## Citation
+
+ ```
+ @article{ryan2024gazelle,
+   author = {Ryan, Fiona and Bati, Ajay and Lee, Sangmin and Bolya, Daniel and Hoffman, Judy and Rehg, James M},
+   title = {Gaze-LLE: Gaze Target Estimation via Large-Scale Learned Encoders},
+   journal = {arXiv preprint arXiv:2412.09586},
+   year = {2024},
+ }
+ ```
+
+ ## References
+
+ - Our models are built on top of pretrained DINOv2 models from PyTorch Hub ([Github repo](https://github.com/facebookresearch/dinov2)).
+
+ - Our GazeFollow and VideoAttentionTarget preprocessing code is based on [Detecting Attended Targets in Video](https://github.com/ejcgt/attention-target-detection).
+
+ - We use [PyTorch Image Models (timm)](https://github.com/huggingface/pytorch-image-models) for our transformer implementation.
+
+ - We use [xFormers](https://github.com/facebookresearch/xformers) for efficient multi-head attention.
assets/CBS_2.gif ADDED

Git LFS Details

  • SHA256: 350d577f58dc36b436cd0e900d92228b5a88c4cc35d2e4527f6262dc886dcb96
  • Pointer size: 134 Bytes
  • Size of remote file: 101 MB
assets/MLB_1.gif ADDED

Git LFS Details

  • SHA256: 39ae696054f546f13ae72cc72ceb88404a76a550a34f2aaedc19a7206a2bdfbb
  • Pointer size: 133 Bytes
  • Size of remote file: 26.7 MB
assets/Sunny_1.gif ADDED

Git LFS Details

  • SHA256: 4364a39dedd8d92f8a08ff08c3a4d80c1e03cd0bd2e22bf3e9a2782f9d9f74e1
  • Pointer size: 132 Bytes
  • Size of remote file: 7.58 MB
assets/Titanic_1.gif ADDED

Git LFS Details

  • SHA256: ad54e4362747c94225fe29a2beb2e9fcb10bf92b0c58f0a54dfdf007146064ff
  • Pointer size: 133 Bytes
  • Size of remote file: 18.2 MB
assets/gazelle_arch.png ADDED
assets/succession.png ADDED
assets/the_office.png ADDED
data_prep/preprocess_gazefollow.py ADDED
@@ -0,0 +1,188 @@
+ import os
+ import pandas as pd
+ import json
+ from PIL import Image
+ import argparse
+
+ # preprocessing adapted from https://github.com/ejcgt/attention-target-detection/blob/master/dataset.py
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--data_path", type=str, default="./data/gazefollow")
+ args = parser.parse_args()
+
+
+ def main(DATA_PATH):
+
+     # TRAIN
+     train_csv_path = os.path.join(DATA_PATH, "train_annotations_release.txt")
+     column_names = ['path', 'idx', 'body_bbox_x', 'body_bbox_y', 'body_bbox_w', 'body_bbox_h', 'eye_x', 'eye_y',
+                     'gaze_x', 'gaze_y', 'bbox_x_min', 'bbox_y_min', 'bbox_x_max', 'bbox_y_max', 'inout', 'source', 'meta']
+     df = pd.read_csv(train_csv_path, header=None, names=column_names, index_col=False)
+     df = df[df['inout'] != -1]
+     df = df.groupby("path").agg(list)  # aggregate annotations by frame
+
+     multiperson_ex = 0
+     TRAIN_FRAMES = []
+     for path, row in df.iterrows():
+         img_path = os.path.join(DATA_PATH, path)
+         img = Image.open(img_path)
+         width, height = img.size
+
+         num_people = len(row['idx'])
+         if num_people > 1:
+             multiperson_ex += 1
+         heads = []
+         crop_constraint_xs = []
+         crop_constraint_ys = []
+
+         for i in range(num_people):
+             xmin, ymin, xmax, ymax = row['bbox_x_min'][i], row['bbox_y_min'][i], row['bbox_x_max'][i], row['bbox_y_max'][i]
+             gazex = row['gaze_x'][i] * float(width)
+             gazey = row['gaze_y'][i] * float(height)
+             gazex_norm = row['gaze_x'][i]
+             gazey_norm = row['gaze_y'][i]
+
+             if xmin > xmax:
+                 xmin, xmax = xmax, xmin
+             if ymin > ymax:
+                 ymin, ymax = ymax, ymin
+
+             # clamp out-of-frame bbox annotations back inside the image
+             xmin = max(xmin, 0)
+             ymin = max(ymin, 0)
+             xmax = min(xmax, width)
+             ymax = min(ymax, height)
+
+             # precalculate feasible crop region (containing bbox and gaze target)
+             crop_xmin = min(xmin, gazex)
+             crop_ymin = min(ymin, gazey)
+             crop_xmax = max(xmax, gazex)
+             crop_ymax = max(ymax, gazey)
+             crop_constraint_xs.extend([crop_xmin, crop_xmax])
+             crop_constraint_ys.extend([crop_ymin, crop_ymax])
+
+             heads.append({
+                 'bbox': [xmin, ymin, xmax, ymax],
+                 'bbox_norm': [xmin / float(width), ymin / float(height), xmax / float(width), ymax / float(height)],
+                 'inout': row['inout'][i],
+                 'gazex': [gazex],  # convert to list for consistency with multi-annotation format
+                 'gazey': [gazey],
+                 'gazex_norm': [gazex_norm],
+                 'gazey_norm': [gazey_norm],
+                 'crop_region': [crop_xmin, crop_ymin, crop_xmax, crop_ymax],
+                 'crop_region_norm': [crop_xmin / float(width), crop_ymin / float(height), crop_xmax / float(width), crop_ymax / float(height)],
+                 'head_id': i
+             })
+         TRAIN_FRAMES.append({
+             'path': path,
+             'heads': heads,
+             'num_heads': num_people,
+             'width': width,
+             'height': height,
+             'crop_region': [min(crop_constraint_xs), min(crop_constraint_ys), max(crop_constraint_xs), max(crop_constraint_ys)],
+         })
+
+     print("Train set: {} frames, {} multi-person".format(len(TRAIN_FRAMES), multiperson_ex))
+     out_file = open(os.path.join(DATA_PATH, "train_preprocessed.json"), "w")
+     json.dump(TRAIN_FRAMES, out_file)
+
+     # TEST
+     test_csv_path = os.path.join(DATA_PATH, "test_annotations_release.txt")
+     column_names = ['path', 'idx', 'body_bbox_x', 'body_bbox_y', 'body_bbox_w', 'body_bbox_h', 'eye_x', 'eye_y',
+                     'gaze_x', 'gaze_y', 'bbox_x_min', 'bbox_y_min', 'bbox_x_max', 'bbox_y_max', 'source', 'meta']
+     df = pd.read_csv(test_csv_path, header=None, names=column_names, index_col=False)
+
+     TEST_FRAME_DICT = {}
+     df = df.groupby(["path", "eye_x"]).agg(list)  # aggregate the multiple annotations per person
+     for id, row in df.iterrows():  # aggregate by frame
+         path, _ = id
+         if path in TEST_FRAME_DICT.keys():
+             TEST_FRAME_DICT[path].append(row)
+         else:
+             TEST_FRAME_DICT[path] = [row]
+
+     multiperson_ex = 0
+     TEST_FRAMES = []
+     for path in TEST_FRAME_DICT.keys():
+         img_path = os.path.join(DATA_PATH, path)
+         img = Image.open(img_path)
+         width, height = img.size
+
+         item = TEST_FRAME_DICT[path]
+         num_people = len(item)
+         heads = []
+         crop_constraint_xs = []
+         crop_constraint_ys = []
+
+         for i in range(num_people):
+             row = item[i]
+             assert row['bbox_x_min'].count(row['bbox_x_min'][0]) == len(row['bbox_x_min'])  # quick check that all bboxes are equivalent
+             xmin, ymin, xmax, ymax = row['bbox_x_min'][0], row['bbox_y_min'][0], row['bbox_x_max'][0], row['bbox_y_max'][0]
+
+             if xmin > xmax:
+                 xmin, xmax = xmax, xmin
+             if ymin > ymax:
+                 ymin, ymax = ymax, ymin
+
+             # clamp out-of-frame bbox annotations back inside the image
+             xmin = max(xmin, 0)
+             ymin = max(ymin, 0)
+             xmax = min(xmax, width)
+             ymax = min(ymax, height)
+
+             gazex_norm = [x for x in row['gaze_x']]
+             gazey_norm = [y for y in row['gaze_y']]
+             gazex = [x * float(width) for x in row['gaze_x']]
+             gazey = [y * float(height) for y in row['gaze_y']]
+
+             # precalculate feasible crop region (containing bbox and gaze targets)
+             crop_xmin = min(xmin, *gazex)
+             crop_ymin = min(ymin, *gazey)
+             crop_xmax = max(xmax, *gazex)
+             crop_ymax = max(ymax, *gazey)
+             crop_constraint_xs.extend([crop_xmin, crop_xmax])
+             crop_constraint_ys.extend([crop_ymin, crop_ymax])
+
+             heads.append({
+                 'bbox': [xmin, ymin, xmax, ymax],
+                 'bbox_norm': [xmin / float(width), ymin / float(height), xmax / float(width), ymax / float(height)],
+                 'gazex': gazex,
+                 'gazey': gazey,
+                 'gazex_norm': gazex_norm,
+                 'gazey_norm': gazey_norm,
+                 'inout': 1,  # all test frames are in frame
+                 'num_annot': len(gazex),
+                 'crop_region': [crop_xmin, crop_ymin, crop_xmax, crop_ymax],
+                 'crop_region_norm': [crop_xmin / float(width), crop_ymin / float(height), crop_xmax / float(width), crop_ymax / float(height)],
+                 'head_id': i
+             })
+
+         # visualize_heads(img_path, heads)
+         TEST_FRAMES.append({
+             'path': path,
+             'heads': heads,
+             'num_heads': num_people,
+             'width': width,
+             'height': height,
+             'crop_region': [min(crop_constraint_xs), min(crop_constraint_ys), max(crop_constraint_xs), max(crop_constraint_ys)],
+         })
+         if num_people > 1:
+             multiperson_ex += 1
+
+     print("Test set: {} frames, {} multi-person".format(len(TEST_FRAMES), multiperson_ex))
+     out_file = open(os.path.join(DATA_PATH, "test_preprocessed.json"), "w")
+     json.dump(TEST_FRAMES, out_file)
+
+
+ if __name__ == "__main__":
+     main(args.data_path)
data_prep/preprocess_vat.py ADDED
@@ -0,0 +1,116 @@
+ import argparse
+ import glob
+ from functools import reduce
+ import os
+ import pandas as pd
+ import json
+ import numpy as np
+ from PIL import Image
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--data_path", type=str, default="./data/videoattentiontarget")
+ args = parser.parse_args()
+
+ # preprocessing adapted from https://github.com/ejcgt/attention-target-detection/blob/master/dataset.py
+
+ def merge_dfs(ls):
+     for i, df in enumerate(ls):  # give columns unique names
+         df.columns = [col if col == "path" else f"{col}_df{i}" for col in df.columns]
+     merged_df = reduce(
+         lambda left, right: pd.merge(left, right, on=["path"], how="outer"), ls
+     )
+     merged_df = merged_df.sort_values(by=["path"])
+     merged_df = merged_df.reset_index(drop=True)
+     return merged_df
+
+ def smooth_by_conv(window_size, df, col):
+     """Temporal smoothing on labels to match the original VideoAttentionTarget evaluation.
+     Adapted from https://github.com/ejcgt/attention-target-detection/blob/acd264a3c9e6002b71244dea8c1873e5c5818500/utils/myutils.py"""
+     values = df[col].values
+     padded_track = np.concatenate([values[0].repeat(window_size // 2), values, values[-1].repeat(window_size // 2)])
+     smoothed_signals = np.convolve(
+         padded_track.squeeze(), np.ones(window_size) / window_size, mode="valid"
+     )
+     return smoothed_signals
+
+ def smooth_df(window_size, df):
+     df["xmin"] = smooth_by_conv(window_size, df, "xmin")
+     df["ymin"] = smooth_by_conv(window_size, df, "ymin")
+     df["xmax"] = smooth_by_conv(window_size, df, "xmax")
+     df["ymax"] = smooth_by_conv(window_size, df, "ymax")
+     return df
+
+
+ def main(PATH):
+     # preprocess by sequence and person track
+     splits = ["train", "test"]
+
+     for split in splits:
+         sequences = []
+         max_num_ppl = 0
+         seq_idx = 0
+         for seq_path in glob.glob(
+             os.path.join(PATH, "annotations", split, "*", "*")
+         ):
+             seq_img_path = os.path.join("images", *seq_path.split("/")[-2:])
+             sample_image = os.path.join(PATH, seq_img_path, os.listdir(os.path.join(PATH, seq_img_path))[0])
+             width, height = Image.open(sample_image).size
+             seq_dict = {"path": seq_img_path, "width": width, "height": height}
+             frames = []
+             person_files = glob.glob(os.path.join(seq_path, "*"))
+             num_ppl = len(person_files)
+             if num_ppl > max_num_ppl:
+                 max_num_ppl = num_ppl
+             person_dfs = [
+                 pd.read_csv(
+                     file,
+                     header=None,
+                     index_col=False,
+                     names=["path", "xmin", "ymin", "xmax", "ymax", "gazex", "gazey"],
+                 )
+                 for file in person_files
+             ]
+             # moving-average smoothing to match the original benchmark's evaluation
+             window_size = 11
+             person_dfs = [smooth_df(window_size, df) for df in person_dfs]
+             merged_df = merge_dfs(person_dfs)  # merge per-person annotations for the same frames
+             for frame_idx, row in merged_df.iterrows():
+                 frame_dict = {
+                     "path": os.path.join(seq_img_path, row["path"]),
+                     "heads": [],
+                 }
+                 p_idx = 0
+                 for i in range(1, num_ppl * 6 + 1, 6):
+                     if not np.isnan(row.iloc[i]):  # nan indicates a lack of continuity (a person leaving the frame for a period of time)
+                         xmin, ymin, xmax, ymax, gazex, gazey = row.iloc[i:i + 6].values.tolist()
+                         # match the original benchmark's preprocessing of annotations
+                         if gazex >= 0 and gazey < 0:
+                             gazey = 0
+                         elif gazey >= 0 and gazex < 0:
+                             gazex = 0
+                         inout = int(gazex >= 0 and gazey >= 0)
+                         frame_dict["heads"].append({
+                             "bbox": [xmin, ymin, xmax, ymax],
+                             "bbox_norm": [xmin / float(width), ymin / float(height), xmax / float(width), ymax / float(height)],
+                             "gazex": [gazex],
+                             "gazex_norm": [gazex / float(width)],
+                             "gazey": [gazey],
+                             "gazey_norm": [gazey / float(height)],
+                             "inout": inout
+                         })
+                     p_idx = p_idx + 1
+
+                 frames.append(frame_dict)
+             seq_dict["frames"] = frames
+             sequences.append(seq_dict)
+             seq_idx += 1
+
+         print("{} max people per image {}".format(split, max_num_ppl))
+         print("{} num unique video sequences {}".format(split, len(sequences)))
+
+         out_file = open(os.path.join(PATH, "{}_preprocessed.json".format(split)), "w")
+         json.dump(sequences, out_file)
+
+ if __name__ == "__main__":
+     main(args.data_path)
environment.yml ADDED
@@ -0,0 +1,16 @@
+ name: gazelle
+ channels:
+   - nvidia
+   - pytorch
+   - conda-forge
+   - defaults
+ dependencies:
+   - python=3.9
+   - pytorch=2.5.1
+   - torchvision=0.20.1
+   - torchaudio=2.5.1
+   - pytorch-cuda=11.8
+   - timm
+   - scikit-learn
+   - matplotlib
+   - pandas
gazelle/backbone.py ADDED
@@ -0,0 +1,55 @@
+ from abc import ABC, abstractmethod
+ import torch
+ import torch.nn as nn
+ import torchvision.transforms as transforms
+
+ # Abstract Backbone class
+ class Backbone(nn.Module, ABC):
+     def __init__(self):
+         super(Backbone, self).__init__()
+
+     @abstractmethod
+     def forward(self, x):
+         pass
+
+     @abstractmethod
+     def get_dimension(self):
+         pass
+
+     @abstractmethod
+     def get_out_size(self, in_size):
+         pass
+
+     def get_transform(self):
+         pass
+
+
+ # Official DINOv2 backbones from torch hub (https://github.com/facebookresearch/dinov2#pretrained-backbones-via-pytorch-hub)
+ class DinoV2Backbone(Backbone):
+     def __init__(self, model_name):
+         super(DinoV2Backbone, self).__init__()
+         self.model = torch.hub.load('facebookresearch/dinov2', model_name)
+
+     def forward(self, x):
+         b, c, h, w = x.shape
+         out_h, out_w = self.get_out_size((h, w))
+         x = self.model.forward_features(x)['x_norm_patchtokens']
+         x = x.view(x.size(0), out_h, out_w, -1).permute(0, 3, 1, 2)  # "b (out_h out_w) c -> b c out_h out_w"
+         return x
+
+     def get_dimension(self):
+         return self.model.embed_dim
+
+     def get_out_size(self, in_size):
+         h, w = in_size
+         return (h // self.model.patch_size, w // self.model.patch_size)
+
+     def get_transform(self, in_size):
+         return transforms.Compose([
+             transforms.ToTensor(),
+             transforms.Normalize(
+                 mean=[0.485, 0.456, 0.406],
+                 std=[0.229, 0.224, 0.225]
+             ),
+             transforms.Resize(in_size),
+         ])
gazelle/model.py ADDED
@@ -0,0 +1,189 @@
+ import torch
+ import torch.nn as nn
+ import torchvision
+ from timm.models.vision_transformer import Block
+ import math
+
+ import gazelle.utils as utils
+ from gazelle.backbone import DinoV2Backbone
+
+
+ class GazeLLE(nn.Module):
+     def __init__(self, backbone, inout=False, dim=256, num_layers=3, in_size=(448, 448), out_size=(64, 64)):
+         super().__init__()
+         self.backbone = backbone
+         self.dim = dim
+         self.num_layers = num_layers
+         self.featmap_h, self.featmap_w = backbone.get_out_size(in_size)
+         self.in_size = in_size
+         self.out_size = out_size
+         self.inout = inout
+
+         self.linear = nn.Conv2d(backbone.get_dimension(), self.dim, 1)
+         self.register_buffer("pos_embed", positionalencoding2d(self.dim, self.featmap_h, self.featmap_w).squeeze(dim=0).squeeze(dim=0))
+         self.transformer = nn.Sequential(*[
+             Block(
+                 dim=self.dim,
+                 num_heads=8,
+                 mlp_ratio=4,
+                 drop_path=0.1)
+             for i in range(num_layers)
+         ])
+         self.heatmap_head = nn.Sequential(
+             nn.ConvTranspose2d(dim, dim, kernel_size=2, stride=2),
+             nn.Conv2d(dim, 1, kernel_size=1, bias=False),
+             nn.Sigmoid()
+         )
+         self.head_token = nn.Embedding(1, self.dim)
+         if self.inout:
+             self.inout_head = nn.Sequential(
+                 nn.Linear(self.dim, 128),
+                 nn.ReLU(),
+                 nn.Dropout(0.1),
+                 nn.Linear(128, 1),
+                 nn.Sigmoid()
+             )
+             self.inout_token = nn.Embedding(1, self.dim)
+
+     def forward(self, input):
+         # input["images"]: [B, 3, H, W] tensor of images
+         # input["bboxes"]: list of lists of bbox tuples [[(xmin, ymin, xmax, ymax)]] per image in normalized image coords
+
+         num_ppl_per_img = [len(bbox_list) for bbox_list in input["bboxes"]]
+         x = self.backbone.forward(input["images"])
+         x = self.linear(x)
+         x = x + self.pos_embed
+         x = utils.repeat_tensors(x, num_ppl_per_img)  # repeat image features along people dimension per image
+         head_maps = torch.cat(self.get_input_head_maps(input["bboxes"]), dim=0).to(x.device)  # [sum(N_p), 32, 32]
+         head_map_embeddings = head_maps.unsqueeze(dim=1) * self.head_token.weight.unsqueeze(-1).unsqueeze(-1)
+         x = x + head_map_embeddings
+         x = x.flatten(start_dim=2).permute(0, 2, 1)  # "b c h w -> b (h w) c"
+
+         if self.inout:
+             x = torch.cat([self.inout_token.weight.unsqueeze(dim=0).repeat(x.shape[0], 1, 1), x], dim=1)
+
+         x = self.transformer(x)
+
+         if self.inout:
+             inout_tokens = x[:, 0, :]
+             inout_preds = self.inout_head(inout_tokens).squeeze(dim=-1)
+             inout_preds = utils.split_tensors(inout_preds, num_ppl_per_img)
+             x = x[:, 1:, :]  # slice off inout tokens from scene tokens
+
+         x = x.reshape(x.shape[0], self.featmap_h, self.featmap_w, x.shape[2]).permute(0, 3, 1, 2)  # "b (h w) c -> b c h w"
+         x = self.heatmap_head(x).squeeze(dim=1)
+         x = torchvision.transforms.functional.resize(x, self.out_size)
+         heatmap_preds = utils.split_tensors(x, num_ppl_per_img)  # resplit per image
+
+         return {"heatmap": heatmap_preds, "inout": inout_preds if self.inout else None}
+
+     def get_input_head_maps(self, bboxes):
+         # bboxes: [[(xmin, ymin, xmax, ymax)]] - list of list of head bboxes per image
+         head_maps = []
+         for bbox_list in bboxes:
+             img_head_maps = []
+             for bbox in bbox_list:
+                 if bbox is None:  # no bbox provided, use empty head map
+                     img_head_maps.append(torch.zeros(self.featmap_h, self.featmap_w))
+                 else:
+                     xmin, ymin, xmax, ymax = bbox
+                     width, height = self.featmap_w, self.featmap_h
+                     xmin = round(xmin * width)
+                     ymin = round(ymin * height)
+                     xmax = round(xmax * width)
+                     ymax = round(ymax * height)
+                     head_map = torch.zeros((height, width))
+                     head_map[ymin:ymax, xmin:xmax] = 1
+                     img_head_maps.append(head_map)
+             head_maps.append(torch.stack(img_head_maps))
+         return head_maps
+
+     def get_gazelle_state_dict(self, include_backbone=False):
+         if include_backbone:
+             return self.state_dict()
+         else:
+             return {k: v for k, v in self.state_dict().items() if not k.startswith("backbone")}
+
+     def load_gazelle_state_dict(self, ckpt_state_dict, include_backbone=False):
+         current_state_dict = self.state_dict()
+         keys1 = current_state_dict.keys()
+         keys2 = ckpt_state_dict.keys()
+
+         if not include_backbone:
+             keys1 = set([k for k in keys1 if not k.startswith("backbone")])
+             keys2 = set([k for k in keys2 if not k.startswith("backbone")])
+         else:
+             keys1 = set(keys1)
+             keys2 = set(keys2)
+
+         if len(keys2 - keys1) > 0:
+             print("WARNING unused keys in provided state dict: ", keys2 - keys1)
+         if len(keys1 - keys2) > 0:
+             print("WARNING provided state dict does not have values for keys: ", keys1 - keys2)
+
+         for k in list(keys1 & keys2):
+             current_state_dict[k] = ckpt_state_dict[k]
+
+         self.load_state_dict(current_state_dict, strict=False)
+
+
+ # From https://github.com/wzlxjtu/PositionalEncoding2D/blob/master/positionalembedding2d.py
+ def positionalencoding2d(d_model, height, width):
+     """
+     :param d_model: dimension of the model
+     :param height: height of the positions
+     :param width: width of the positions
+     :return: d_model*height*width position matrix
+     """
+     if d_model % 4 != 0:
+         raise ValueError("Cannot use sin/cos positional encoding with "
+                          "odd dimension (got dim={:d})".format(d_model))
+     pe = torch.zeros(d_model, height, width)
+     # Each dimension uses half of d_model
+     d_model = int(d_model / 2)
+     div_term = torch.exp(torch.arange(0., d_model, 2) *
+                          -(math.log(10000.0) / d_model))
+     pos_w = torch.arange(0., width).unsqueeze(1)
+     pos_h = torch.arange(0., height).unsqueeze(1)
+     pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
+     pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
+     pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
+     pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
+
+     return pe
+
+
+ # models
+ def get_gazelle_model(model_name):
+     factory = {
+         "gazelle_dinov2_vitb14": gazelle_dinov2_vitb14,
+         "gazelle_dinov2_vitl14": gazelle_dinov2_vitl14,
+         "gazelle_dinov2_vitb14_inout": gazelle_dinov2_vitb14_inout,
+         "gazelle_dinov2_vitl14_inout": gazelle_dinov2_vitl14_inout,
+     }
+     assert model_name in factory.keys(), "invalid model name"
+     return factory[model_name]()
+
+ def gazelle_dinov2_vitb14():
+     backbone = DinoV2Backbone('dinov2_vitb14')
+     transform = backbone.get_transform((448, 448))
+     model = GazeLLE(backbone)
+     return model, transform
+
+ def gazelle_dinov2_vitl14():
+     backbone = DinoV2Backbone('dinov2_vitl14')
+     transform = backbone.get_transform((448, 448))
+     model = GazeLLE(backbone)
+     return model, transform
+
+ def gazelle_dinov2_vitb14_inout():
+     backbone = DinoV2Backbone('dinov2_vitb14')
+     transform = backbone.get_transform((448, 448))
+     model = GazeLLE(backbone, inout=True)
+     return model, transform
+
+ def gazelle_dinov2_vitl14_inout():
+     backbone = DinoV2Backbone('dinov2_vitl14')
+     transform = backbone.get_transform((448, 448))
+     model = GazeLLE(backbone, inout=True)
+     return model, transform
gazelle/utils.py ADDED
@@ -0,0 +1,39 @@
+ import torch
+ from PIL import Image, ImageDraw
+ import numpy as np
+ import matplotlib.pyplot as plt
+
+ def repeat_tensors(tensor, repeat_counts):
+     repeated_tensors = [tensor[i:i+1].repeat(repeat, *[1] * (tensor.ndim - 1)) for i, repeat in enumerate(repeat_counts)]
+     return torch.cat(repeated_tensors, dim=0)
+
+ def split_tensors(tensor, split_counts):
+     indices = torch.cumsum(torch.tensor([0] + split_counts), dim=0)
+     return [tensor[indices[i]:indices[i+1]] for i in range(len(split_counts))]
+
+ def visualize_heatmap(pil_image, heatmap, bbox=None):
+     if isinstance(heatmap, torch.Tensor):
+         heatmap = heatmap.detach().cpu().numpy()
+     heatmap = Image.fromarray((heatmap * 255).astype(np.uint8)).resize(pil_image.size, Image.Resampling.BILINEAR)
+     heatmap = plt.cm.jet(np.array(heatmap) / 255.)
+     heatmap = (heatmap[:, :, :3] * 255).astype(np.uint8)
+     heatmap = Image.fromarray(heatmap).convert("RGBA")
+     heatmap.putalpha(128)
+     overlay_image = Image.alpha_composite(pil_image.convert("RGBA"), heatmap)
+
+     if bbox is not None:
+         width, height = pil_image.size
+         xmin, ymin, xmax, ymax = bbox
+         draw = ImageDraw.Draw(overlay_image)
+         draw.rectangle([xmin * width, ymin * height, xmax * width, ymax * height], outline="green", width=3)
+     return overlay_image
+
+ def stack_and_pad(tensor_list):
+     max_size = max([t.shape[0] for t in tensor_list])
+     padded_list = []
+     for t in tensor_list:
+         if t.shape[0] == max_size:
+             padded_list.append(t)
+         else:
+             padded_list.append(torch.cat([t, torch.zeros(max_size - t.shape[0], *t.shape[1:])], dim=0))
+     return torch.stack(padded_list)
hubconf.py ADDED
@@ -0,0 +1,28 @@
+ dependencies = ['torch', 'timm']
+
+ import torch
+ from gazelle.model import get_gazelle_model
+
+ def gazelle_dinov2_vitb14():
+     model, transform = get_gazelle_model('gazelle_dinov2_vitb14')
+     ckpt_path = "https://github.com/fkryan/gazelle/releases/download/v1.0.0/gazelle_dinov2_vitb14_hub.pt"
+     model.load_gazelle_state_dict(torch.hub.load_state_dict_from_url(ckpt_path))
+     return model, transform
+
+ def gazelle_dinov2_vitl14():
+     model, transform = get_gazelle_model('gazelle_dinov2_vitl14')
+     ckpt_path = "https://github.com/fkryan/gazelle/releases/download/v1.0.0/gazelle_dinov2_vitl14.pt"
+     model.load_gazelle_state_dict(torch.hub.load_state_dict_from_url(ckpt_path))
+     return model, transform
+
+ def gazelle_dinov2_vitb14_inout():
+     model, transform = get_gazelle_model('gazelle_dinov2_vitb14_inout')
+     ckpt_path = "https://github.com/fkryan/gazelle/releases/download/v1.0.0/gazelle_dinov2_vitb14_inout.pt"
+     model.load_gazelle_state_dict(torch.hub.load_state_dict_from_url(ckpt_path))
+     return model, transform
+
+ def gazelle_dinov2_vitl14_inout():
+     model, transform = get_gazelle_model('gazelle_dinov2_vitl14_inout')
+     ckpt_path = "https://github.com/fkryan/gazelle/releases/download/v1.0.0/gazelle_dinov2_vitl14_inout.pt"
+     model.load_gazelle_state_dict(torch.hub.load_state_dict_from_url(ckpt_path))
+     return model, transform
scripts/eval_gazefollow.py ADDED
@@ -0,0 +1,115 @@
+ import argparse
+ import torch
+ from PIL import Image
+ import json
+ import os
+ import numpy as np
+ from sklearn.metrics import roc_auc_score
+ from tqdm import tqdm
+
+ from gazelle.model import get_gazelle_model
+ from gazelle.model import GazeLLE
+ from gazelle.backbone import DinoV2Backbone
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--data_path", type=str, default="./data/gazefollow")
+ parser.add_argument("--model_name", type=str, default="gazelle_dinov2_vitl14_inout")
+ parser.add_argument("--ckpt_path", type=str, default="./checkpoints/gazelle_dinov2_vitl14_inout.pt")
+ parser.add_argument("--batch_size", type=int, default=128)
+ args = parser.parse_args()
+
+ class GazeFollow(torch.utils.data.Dataset):
+     def __init__(self, path, img_transform):
+         self.images = json.load(open(os.path.join(path, "test_preprocessed.json"), "rb"))
+         self.path = path
+         self.transform = img_transform
+
+     def __getitem__(self, idx):
+         item = self.images[idx]
+         image = self.transform(Image.open(os.path.join(self.path, item['path'])).convert("RGB"))
+         height = item['height']
+         width = item['width']
+         bboxes = [head['bbox_norm'] for head in item['heads']]
+         gazex = [head['gazex_norm'] for head in item['heads']]
+         gazey = [head['gazey_norm'] for head in item['heads']]
+
+         return image, bboxes, gazex, gazey, height, width
+
+     def __len__(self):
+         return len(self.images)
+
+ def collate(batch):
+     images, bboxes, gazex, gazey, height, width = zip(*batch)
+     return torch.stack(images), list(bboxes), list(gazex), list(gazey), list(height), list(width)
+
+ # GazeFollow calculates AUC using the original image size, with the GT (x,y) coordinates set to 1 and everything else set to 0
+ # References:
+ # https://github.com/ejcgt/attention-target-detection/blob/acd264a3c9e6002b71244dea8c1873e5c5818500/eval_on_gazefollow.py#L78
+ # https://github.com/ejcgt/attention-target-detection/blob/acd264a3c9e6002b71244dea8c1873e5c5818500/utils/imutils.py#L67
+ # https://github.com/ejcgt/attention-target-detection/blob/acd264a3c9e6002b71244dea8c1873e5c5818500/utils/evaluation.py#L7
+ def gazefollow_auc(heatmap, gt_gazex, gt_gazey, height, width):
+     target_map = np.zeros((height, width))
+     for point in zip(gt_gazex, gt_gazey):
+         if point[0] >= 0:
+             x, y = map(int, [point[0] * float(width), point[1] * float(height)])
+             x = min(x, width - 1)
+             y = min(y, height - 1)
+             target_map[y, x] = 1
+     resized_heatmap = torch.nn.functional.interpolate(heatmap.unsqueeze(dim=0).unsqueeze(dim=0), (height, width), mode='bilinear').squeeze()
+     auc = roc_auc_score(target_map.flatten(), resized_heatmap.cpu().flatten())
+
+     return auc
+
+ # Reference: https://github.com/ejcgt/attention-target-detection/blob/acd264a3c9e6002b71244dea8c1873e5c5818500/eval_on_gazefollow.py#L81
+ def gazefollow_l2(heatmap, gt_gazex, gt_gazey):
+     argmax = heatmap.flatten().argmax().item()
+     pred_y, pred_x = np.unravel_index(argmax, (64, 64))
+     pred_x = pred_x / 64.
+     pred_y = pred_y / 64.
+
+     gazex = np.array(gt_gazex)
+     gazey = np.array(gt_gazey)
+
+     avg_l2 = np.sqrt((pred_x - gazex.mean())**2 + (pred_y - gazey.mean())**2)
+     all_l2s = np.sqrt((pred_x - gazex)**2 + (pred_y - gazey)**2)
+     min_l2 = all_l2s.min().item()
+
+     return avg_l2, min_l2
+
+
+ @torch.no_grad()
+ def main():
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     print("Running on {}".format(device))
+
+     model, transform = get_gazelle_model(args.model_name)
+     model.load_gazelle_state_dict(torch.load(args.ckpt_path, weights_only=True))
+     model.to(device)
+     model.eval()
+
+     dataset = GazeFollow(args.data_path, transform)
+     dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, collate_fn=collate)
+
+     aucs = []
+     min_l2s = []
+     avg_l2s = []
+
+     for _, (images, bboxes, gazex, gazey, height, width) in tqdm(enumerate(dataloader), desc="Evaluating", total=len(dataloader)):
+         preds = model.forward({"images": images.to(device), "bboxes": bboxes})
+
+         # eval each instance (head)
+         for i in range(images.shape[0]):  # per image
+             for j in range(len(bboxes[i])):  # per head
+                 auc = gazefollow_auc(preds['heatmap'][i][j], gazex[i][j], gazey[i][j], height[i], width[i])
+                 avg_l2, min_l2 = gazefollow_l2(preds['heatmap'][i][j], gazex[i][j], gazey[i][j])
+                 aucs.append(auc)
+                 avg_l2s.append(avg_l2)
+                 min_l2s.append(min_l2)
+
+     print("AUC: {}".format(np.array(aucs).mean()))
+     print("Avg L2: {}".format(np.array(avg_l2s).mean()))
+     print("Min L2: {}".format(np.array(min_l2s).mean()))
+
+
+ if __name__ == "__main__":
+     main()
scripts/eval_vat.py ADDED
@@ -0,0 +1,116 @@
+ import argparse
+ import torch
+ from PIL import Image
+ import json
+ import os
+ import numpy as np
+ from sklearn.metrics import roc_auc_score, average_precision_score
+ from tqdm import tqdm
+
+ from gazelle.model import get_gazelle_model
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--data_path", type=str, default="./data/videoattentiontarget")
+ parser.add_argument("--model_name", type=str, default="gazelle_dinov2_vitl14_inout")
+ parser.add_argument("--ckpt_path", type=str, default="./checkpoints/gazelle_dinov2_vitl14_inout.pt")
+ parser.add_argument("--batch_size", type=int, default=64)
+ args = parser.parse_args()
+
+ class VideoAttentionTarget(torch.utils.data.Dataset):
+     def __init__(self, path, img_transform):
+         self.sequences = json.load(open(os.path.join(path, "test_preprocessed.json"), "rb"))
+         self.frames = []
+         for i in range(len(self.sequences)):
+             for j in range(len(self.sequences[i]['frames'])):
+                 self.frames.append((i, j))
+         self.path = path
+         self.transform = img_transform
+
+     def __getitem__(self, idx):
+         seq_idx, frame_idx = self.frames[idx]
+         seq = self.sequences[seq_idx]
+         frame = seq['frames'][frame_idx]
+         image = self.transform(Image.open(os.path.join(self.path, frame['path'])).convert("RGB"))
+         bboxes = [head['bbox_norm'] for head in frame['heads']]
+         gazex = [head['gazex_norm'] for head in frame['heads']]
+         gazey = [head['gazey_norm'] for head in frame['heads']]
+         inout = [head['inout'] for head in frame['heads']]
+
+         return image, bboxes, gazex, gazey, inout
+
+     def __len__(self):
+         return len(self.frames)
+
+ def collate(batch):
+     images, bboxes, gazex, gazey, inout = zip(*batch)
+     return torch.stack(images), list(bboxes), list(gazex), list(gazey), list(inout)
+
+ # VideoAttentionTarget calculates AUC on the 64x64 heatmap, defining a rectangular tolerance region of 6*(sigma=3) + 1 pixels
+ # (it uses 2D Gaussian code but binary-thresholds > 0, resulting in a rectangle)
+ # References:
+ # https://github.com/ejcgt/attention-target-detection/blob/acd264a3c9e6002b71244dea8c1873e5c5818500/eval_on_videoatttarget.py#L106
+ # https://github.com/ejcgt/attention-target-detection/blob/acd264a3c9e6002b71244dea8c1873e5c5818500/utils/imutils.py#L31
+ def vat_auc(heatmap, gt_gazex, gt_gazey):
+     res = 64
+     sigma = 3
+     assert heatmap.shape[0] == res and heatmap.shape[1] == res
+     target_map = np.zeros((res, res))
+     gazex = gt_gazex * res
+     gazey = gt_gazey * res
+     ul = [max(0, int(gazex - 3 * sigma)), max(0, int(gazey - 3 * sigma))]
+     br = [min(int(gazex + 3 * sigma + 1), res - 1), min(int(gazey + 3 * sigma + 1), res - 1)]
+     target_map[ul[1]:br[1], ul[0]:br[0]] = 1
+     auc = roc_auc_score(target_map.flatten(), heatmap.cpu().flatten())
+     return auc
+
+ # Reference: https://github.com/ejcgt/attention-target-detection/blob/acd264a3c9e6002b71244dea8c1873e5c5818500/eval_on_videoatttarget.py#L118
+ def vat_l2(heatmap, gt_gazex, gt_gazey):
+     argmax = heatmap.flatten().argmax().item()
+     pred_y, pred_x = np.unravel_index(argmax, (64, 64))
+     pred_x = pred_x / 64.
+     pred_y = pred_y / 64.
+
+     l2 = np.sqrt((pred_x - gt_gazex)**2 + (pred_y - gt_gazey)**2)
+
+     return l2
+
+
+ @torch.no_grad()
+ def main():
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     print("Running on {}".format(device))
+
+     model, transform = get_gazelle_model(args.model_name)
+     model.load_gazelle_state_dict(torch.load(args.ckpt_path, weights_only=True))
+     model.to(device)
+     model.eval()
+
+     dataset = VideoAttentionTarget(args.data_path, transform)
+     dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, collate_fn=collate)
+
+     aucs = []
+     l2s = []
+     inout_preds = []
+     inout_gts = []
+
+     for _, (images, bboxes, gazex, gazey, inout) in tqdm(enumerate(dataloader), desc="Evaluating", total=len(dataloader)):
+         preds = model.forward({"images": images.to(device), "bboxes": bboxes})
+
+         # eval each instance (head)
+         for i in range(images.shape[0]):  # per image
+             for j in range(len(bboxes[i])):  # per head
+                 if inout[i][j] == 1:  # in frame
+                     auc = vat_auc(preds['heatmap'][i][j], gazex[i][j][0], gazey[i][j][0])
+                     l2 = vat_l2(preds['heatmap'][i][j], gazex[i][j][0], gazey[i][j][0])
+                     aucs.append(auc)
+                     l2s.append(l2)
+                 inout_preds.append(preds['inout'][i][j].item())
+                 inout_gts.append(inout[i][j])
+
+     print("AUC: {}".format(np.array(aucs).mean()))
+     print("Avg L2: {}".format(np.array(l2s).mean()))
+     print("Inout AP: {}".format(average_precision_score(inout_gts, inout_preds)))
+
+
+ if __name__ == "__main__":
+     main()
setup.py ADDED
@@ -0,0 +1,9 @@
+ import setuptools
+
+ setuptools.setup(
+     name="gazelle",
+     version="0.0.1",
+     author="Fiona Ryan",
+     description="Gaze-LLE: Gaze Target Estimation via Large-Scale Learned Encoders",
+     packages=setuptools.find_packages()
+ )