# CM2000112 / internals/util/lora_style.py
# Uploaded by jayparmr ("Upload 118 files"), commit 19b3da3, 5.3 kB.
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Union
import boto3
import torch
from lora_diffusion import patch_pipe, tune_lora_scale
from pydash import chain
from internals.data.dataAccessor import getStyles
from internals.util.commons import download_file
class LoraStyle:
    """Registry of LoRA "styles" and factory for the patchers that apply them.

    Call :meth:`load` before any other public method: the style table
    (``self.__styles``) is only populated by :meth:`fetch_styles`.
    Each style entry is a dict with the keys ``path``, ``weight``, ``type``,
    ``text`` (list of prompt fragments) and ``negativePrompt``.
    """

    class LoraPatcher:
        """Applies a single LoRA style to a diffusion pipeline."""

        def __init__(self, pipe, style: Dict[str, Any]):
            """
            Args:
                pipe: Pipeline whose ``unet`` and ``text_encoder`` are tuned.
                style: Style entry; ``"path"`` and ``"weight"`` are read here.
            """
            self.__style = style
            self.pipe = pipe

        @torch.inference_mode()
        def patch(self):
            """Inject the LoRA weights (when supported) and set their scale."""
            path = self.__style["path"]
            # Only ".pt" / ".safetensors" files are injected via patch_pipe;
            # any other file (e.g. the ".bin" default styles) only has its
            # scale tuned. NOTE(review): confirm ".bin" LoRAs are injected
            # elsewhere, otherwise tuning their scale has no effect.
            if str(path).endswith((".pt", ".safetensors")):
                patch_pipe(self.pipe, path)
            weight = self.__style["weight"]
            tune_lora_scale(self.pipe.unet, weight)
            tune_lora_scale(self.pipe.text_encoder, weight)

        def kwargs(self):
            """Return extra pipeline-call kwargs for this style (none)."""
            return {}

        def cleanup(self):
            """Set the LoRA scale back to 0 on the unet and text encoder."""
            tune_lora_scale(self.pipe.unet, 0.0)
            tune_lora_scale(self.pipe.text_encoder, 0.0)

    class EmptyLoraPatcher:
        """Patcher used when no style matches; neutralises any active LoRA."""

        def __init__(self, pipe):
            """
            Args:
                pipe: Pipeline whose ``unet`` and ``text_encoder`` are tuned.
            """
            self.pipe = pipe

        def patch(self):
            """Patch will act as cleanup, to tune down any corrupted lora."""
            self.cleanup()

        def kwargs(self):
            """Return extra pipeline-call kwargs (none)."""
            return {}

        def cleanup(self):
            """Set the LoRA scale to 0 on the unet and text encoder."""
            tune_lora_scale(self.pipe.unet, 0.0)
            tune_lora_scale(self.pipe.text_encoder, 0.0)

    def load(self, model_dir: str):
        """Remember *model_dir* and populate the style table.

        Raises:
            FileNotFoundError: if any style's weight file is missing.
        """
        self.model = model_dir
        self.fetch_styles()

    def fetch_styles(self):
        """(Re)build the style table and verify every weight file exists.

        Uses the records returned by ``getStyles()``; falls back to the
        built-in defaults (rooted at ``self.model``) when it returns ``None``.

        Raises:
            FileNotFoundError: if any style's weight file is missing.
        """
        model_dir = self.model
        result = getStyles()
        if result is not None:
            self.__styles = self.__parse_styles(model_dir, result["data"])
        else:
            self.__styles = self.__get_default_styles(model_dir)
        self.__verify()

    def prepend_style_to_prompt(self, prompt: str, key: str) -> str:
        """Return *prompt* prefixed with style *key*'s text fragments.

        The fragments are joined with ", " and placed before the prompt.
        Unknown keys return *prompt* unchanged.
        """
        if key in self.__styles:
            style = self.__styles[key]
            return f"{', '.join(style['text'])}, {prompt}"
        return prompt

    def get_patcher(self, pipe, key: str) -> Union[LoraPatcher, EmptyLoraPatcher]:
        """Return a patcher for style *key*.

        Unknown keys get an :class:`EmptyLoraPatcher`, which zeroes LoRA
        scales.
        """
        if key in self.__styles:
            style = self.__styles[key]
            return self.LoraPatcher(pipe, style)
        return self.EmptyLoraPatcher(pipe)

    def __parse_styles(self, model_dir: str, data: List[Dict]) -> Dict:
        """Build the style table from ``getStyles()`` records.

        Each record is expected to carry a ``"tag"`` and an ``"attributes"``
        JSON string. Records are first de-duplicated by tag (pydash
        ``uniq_by``). Only records whose attributes contain a ``"path"``
        become styles. The referenced file is downloaded into
        ``~/.cache/lora`` unless a file with the same name is already cached.
        If no record yields a style, the built-in defaults are returned.
        """
        styles = {}
        download_dir = Path.home() / ".cache" / "lora"
        # parents=True: ~/.cache itself may not exist yet (e.g. fresh container).
        download_dir.mkdir(parents=True, exist_ok=True)

        data = chain(data).uniq_by(lambda x: x["tag"]).value()
        for item in data:
            if item["attributes"] is None:
                continue
            attr = json.loads(item["attributes"])
            if "path" not in attr:
                # No weight file: nothing to load. Skipping (instead of
                # falling through) avoids a NameError or reusing the previous
                # record's file_path.
                continue
            # The cache is keyed by file name only (last segment of the URI).
            file_path = download_dir / attr["path"].split("/")[-1]
            if not file_path.exists():
                download_file(attr["path"], file_path)
            styles[item["tag"]] = {
                "path": str(file_path),
                "weight": attr["weight"],
                "type": attr["type"],
                "text": attr["text"],
                "negativePrompt": attr["negativePrompt"],
            }
        if not styles:
            return self.__get_default_styles(model_dir)
        return styles

    def __get_default_styles(self, model_dir: str) -> Dict:
        """Return the built-in style table, with paths rooted at *model_dir*."""
        return {
            "nq6akX1CIp": {
                "path": model_dir + "/laur_style/nq6akX1CIp/final_lora.safetensors",
                "text": ["nq6akX1CIp style"],
                "weight": 0.5,
                "negativePrompt": [""],
                "type": "custom",
            },
            "ghibli": {
                "path": model_dir + "/laur_style/nq6akX1CIp/ghibli.bin",
                "text": ["ghibli style"],
                "weight": 1,
                "negativePrompt": [""],
                "type": "custom",
            },
            "eQAmnK2kB2": {
                "path": model_dir + "/laur_style/eQAmnK2kB2/final_lora.safetensors",
                "text": ["eQAmnK2kB2 style"],
                "weight": 0.5,
                "negativePrompt": [""],
                "type": "custom",
            },
            "to8contrast": {
                "path": model_dir + "/laur_style/rpjgusOgqD/final_lora.bin",
                "text": ["to8contrast style"],
                "weight": 0.5,
                "negativePrompt": [""],
                "type": "custom",
            },
            "sfrrfz8vge": {
                "path": model_dir + "/laur_style/replicate/sfrrfz8vge.safetensors",
                "text": ["sfrrfz8vge style"],
                "weight": 1.2,
                "negativePrompt": [""],
                "type": "custom",
            },
        }

    def __verify(self):
        """Ensure every style's weight file exists on disk.

        Raises:
            FileNotFoundError: for the first style whose ``"path"`` does not
                exist. It is a subclass of ``Exception``, so callers that
                caught the previous generic ``Exception`` still catch it.
        """
        for key, style in self.__styles.items():
            if not os.path.exists(style["path"]):
                raise FileNotFoundError(
                    f"Lora style model {key} not found at path: {style['path']}"
                )