barreloflube commited on
Commit
c7554bf
1 Parent(s): 350c67c

Refactor get_pipe function to handle different request types and pipeline configurations

Browse files
Files changed (1) hide show
  1. src/tasks/images/sd.py +15 -7
src/tasks/images/sd.py CHANGED
@@ -111,14 +111,22 @@ def get_pipe(request: SDReq | SDImg2ImgReq | SDInpaintReq):
111
  if request.controlnet_config:
112
  pipe_args["controlnet"] = controlnet
113
 
114
- if isinstance(request, SDReq):
115
- pipe_args['pipeline'] = AutoPipelineForText2Image.from_pipe(**pipe_args)
116
- elif isinstance(request, SDImg2ImgReq):
117
- pipe_args['pipeline'] = AutoPipelineForImage2Image.from_pipe(**pipe_args)
118
- elif isinstance(request, SDInpaintReq):
119
- pipe_args['pipeline'] = AutoPipelineForInpainting.from_pipe(**pipe_args)
 
 
 
 
 
 
 
 
120
  else:
121
- raise ValueError(f"Unknown request type: {type(request)}")
122
 
123
  return pipe_args
124
 
 
111
  if request.controlnet_config:
112
  pipe_args["controlnet"] = controlnet
113
 
114
+ if not request.photomaker_images:
115
+ if isinstance(request, SDReq):
116
+ pipe_args['pipeline'] = AutoPipelineForText2Image.from_pipe(**pipe_args)
117
+ elif isinstance(request, SDImg2ImgReq):
118
+ pipe_args['pipeline'] = AutoPipelineForImage2Image.from_pipe(**pipe_args)
119
+ elif isinstance(request, SDInpaintReq):
120
+ pipe_args['pipeline'] = AutoPipelineForInpainting.from_pipe(**pipe_args)
121
+ else:
122
+ raise ValueError(f"Unknown request type: {type(request)}")
123
+ elif isinstance(request, any([PhotoMakerStableDiffusionXLPipeline, PhotoMakerStableDiffusionXLControlNetPipeline])):
124
+ if request.controlnet_config:
125
+ pipe_args['pipeline'] = PhotoMakerStableDiffusionXLControlNetPipeline.from_pipe(**pipe_args)
126
+ else:
127
+ pipe_args['pipeline'] = PhotoMakerStableDiffusionXLPipeline.from_pipe(**pipe_args)
128
  else:
129
+ raise ValueError(f"Invalid request type: {type(request)}")
130
 
131
  return pipe_args
132