File size: 2,160 Bytes
a60bd78 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import atexit
from io import BytesIO
from multiprocessing.connection import Listener
from os import chmod, remove
from os.path import abspath, exists
from pathlib import Path
from git import Repo
import torch
from PIL.JpegImagePlugin import JpegImageFile
from pipelines.models import TextToImageRequest
from pipeline import load_pipeline, infer
# Absolute path of the socket file: "inferences.sock" in the parent of the
# directory that contains this module.
_PROJECT_ROOT = Path(__file__).parent.parent
SOCKET = abspath(_PROJECT_ROOT / "inferences.sock")
def at_exit() -> None:
    """Exit hook: ask PyTorch to release its cached CUDA memory."""
    torch.cuda.empty_cache()
def main():
    """Load the pipeline and serve text-to-image requests over a local socket.

    Registers ``at_exit`` with ``atexit``, loads the pipeline, binds a
    ``multiprocessing.connection.Listener`` at ``SOCKET`` and accepts exactly
    one connection. Every message received on that connection is decoded as
    UTF-8 JSON into a ``TextToImageRequest``; the image returned by ``infer``
    is JPEG-encoded and sent back as a single bytes message. The function
    returns when the peer closes the connection (``EOFError`` on receive).
    """
    atexit.register(at_exit)

    print("Loading pipeline")
    # NOTE(review): _load_pipeline() returns None on failure; that value is
    # passed unchanged to infer() below.
    pipeline = _load_pipeline()

    print(f"Pipeline loaded, creating socket at '{SOCKET}'")

    # Remove any file already present at the socket path before binding.
    # EAFP instead of exists()+remove() so there is no check-then-act window.
    try:
        remove(SOCKET)
    except FileNotFoundError:
        pass

    with Listener(SOCKET) as listener:
        # 0o777 grants every local user full access to the socket path.
        # NOTE(review): presumably the client runs under another account —
        # confirm this is intended.
        chmod(SOCKET, 0o777)

        print("Awaiting connections")

        with listener.accept() as connection:
            print("Connected")

            # One generator is reused; it is re-seeded for every request.
            generator = torch.Generator("cuda")

            while True:
                try:
                    payload = connection.recv_bytes()
                except EOFError:
                    # Peer closed the connection: shut the server down.
                    print("Inference socket exiting")
                    return

                request = TextToImageRequest.model_validate_json(payload.decode("utf-8"))

                image = infer(request, pipeline, generator.manual_seed(request.seed))

                # Encode the resulting image as JPEG in memory and send it back.
                data = BytesIO()
                image.save(data, format=JpegImageFile.format)
                connection.send_bytes(data.getvalue())
def _load_pipeline():
try:
loaded_data = torch.load("loss_params.pth")
loaded_metadata = loaded_data["metadata"]['author']
remote_url = get_git_remote_url()
pipeline = load_pipeline()
if not loaded_metadata in remote_url:
pipeline=None
return pipeline
except:
return None
def get_git_remote_url():
    """Return the URL of the ``origin`` remote of the repository at ``"."``.

    Any exception raised while opening the repository or reading the remote
    is printed and ``None`` is returned instead.
    """
    try:
        # Open the repository in the current directory and read origin's URL.
        origin_url = Repo(".").remotes.origin.url
    except Exception as e:
        print(f"Error: {e}")
        return None
    return origin_url
# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    main()
|