Update handler.py
Browse files- handler.py +7 -7
handler.py
CHANGED
@@ -10,23 +10,23 @@ import logging
|
|
10 |
class EndpointHandler:
|
11 |
def __init__(self, model_dir="huyai123/Flux.1-dev-Image-Upscaler"):
|
12 |
# Access the environment variable
|
13 |
-
|
14 |
-
if not
|
15 |
-
raise ValueError("
|
16 |
|
17 |
# Log the token for debugging (remove this in production)
|
18 |
logging.basicConfig(level=logging.INFO)
|
19 |
-
logging.info(
|
20 |
|
21 |
# Load model and pipeline
|
22 |
self.controlnet = FluxControlNetModel.from_pretrained(
|
23 |
-
model_dir, torch_dtype=torch.bfloat16, use_auth_token=
|
24 |
)
|
25 |
self.pipe = FluxControlNetPipeline.from_pretrained(
|
26 |
"black-forest-labs/FLUX.1-dev",
|
27 |
controlnet=self.controlnet,
|
28 |
torch_dtype=torch.bfloat16,
|
29 |
-
use_auth_token=
|
30 |
)
|
31 |
self.pipe.to("cuda")
|
32 |
|
@@ -67,4 +67,4 @@ if __name__ == "__main__":
|
|
67 |
data = {'control_image': 'path/to/your/image.png', 'prompt': 'Your prompt here'}
|
68 |
handler = EndpointHandler()
|
69 |
output = handler.inference(data)
|
70 |
-
print(output)
|
|
|
10 |
class EndpointHandler:
|
11 |
def __init__(self, model_dir="huyai123/Flux.1-dev-Image-Upscaler"):
|
12 |
# Access the environment variable
|
13 |
+
HF_TOKEN = os.getenv('HF_TOKEN')
|
14 |
+
if not HF_TOKEN:
|
15 |
+
raise ValueError("HF_TOKEN environment variable is not set")
|
16 |
|
17 |
# Log the token for debugging (remove this in production)
|
18 |
logging.basicConfig(level=logging.INFO)
|
19 |
+
logging.info("Using HF_TOKEN")
|
20 |
|
21 |
# Load model and pipeline
|
22 |
self.controlnet = FluxControlNetModel.from_pretrained(
|
23 |
+
model_dir, torch_dtype=torch.bfloat16, use_auth_token=HF_TOKEN
|
24 |
)
|
25 |
self.pipe = FluxControlNetPipeline.from_pretrained(
|
26 |
"black-forest-labs/FLUX.1-dev",
|
27 |
controlnet=self.controlnet,
|
28 |
torch_dtype=torch.bfloat16,
|
29 |
+
use_auth_token=HF_TOKEN
|
30 |
)
|
31 |
self.pipe.to("cuda")
|
32 |
|
|
|
67 |
data = {'control_image': 'path/to/your/image.png', 'prompt': 'Your prompt here'}
|
68 |
handler = EndpointHandler()
|
69 |
output = handler.inference(data)
|
70 |
+
print(output)
|