ghostsInTheMachine
commited on
Commit
•
12b0993
1
Parent(s):
c2e328f
Update infer.py
Browse files
infer.py
CHANGED
@@ -47,10 +47,10 @@ def infer_pipe(pipe, image, task_name, seed, device):
|
|
47 |
with torch.no_grad():
|
48 |
with autocast_ctx:
|
49 |
# Convert image to tensor
|
50 |
-
img = np.array(image.convert('RGB')).astype(np.
|
51 |
test_image = torch.tensor(img).permute(2, 0, 1).unsqueeze(0)
|
52 |
test_image = test_image / 127.5 - 1.0
|
53 |
-
test_image = test_image.to(device)
|
54 |
|
55 |
# Create task_emb
|
56 |
task_emb = torch.tensor([1, 0], device=device, dtype=torch.float16).unsqueeze(0)
|
|
|
47 |
with torch.no_grad():
|
48 |
with autocast_ctx:
|
49 |
# Convert image to tensor
|
50 |
+
img = np.array(image.convert('RGB')).astype(np.float16)
|
51 |
test_image = torch.tensor(img).permute(2, 0, 1).unsqueeze(0)
|
52 |
test_image = test_image / 127.5 - 1.0
|
53 |
+
test_image = test_image.to(device)
|
54 |
|
55 |
# Create task_emb
|
56 |
task_emb = torch.tensor([1, 0], device=device, dtype=torch.float16).unsqueeze(0)
|