Spaces:
Sleeping
Sleeping
import paramiko | |
import numpy as np | |
import io, os, stat | |
import gradio as gr | |
from PIL import Image | |
import requests | |
import json | |
import random | |
import concurrent.futures | |
from .constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD, SSH_LOG, SSH_VIDEO_LOG, SSH_MSCOCO | |
# Shared connection state, populated by create_ssh_client():
#   ssh_client       - the paramiko SSH connection
#   sftp_client      - primary SFTP session (JSON uploads, health checks)
#   sftp_client_imgs - extra SFTP sessions used for parallel image uploads
ssh_client = sftp_client = sftp_client_imgs = None
def open_sftp(i=0):
    """Open and return a new SFTP session on the module-level SSH connection.

    ``i`` is ignored; it exists so the function can be submitted once per
    worker index from a thread pool.
    """
    return ssh_client.open_sftp()
def create_ssh_client(server, port, user, password):
    """Connect to ``server`` and publish the shared SSH/SFTP handles.

    Opens one SSH connection and five SFTP sessions on it, in parallel. The
    first session becomes the module-level ``sftp_client``; the other four
    become ``sftp_client_imgs``.

    The module globals are replaced only after every step has succeeded. A
    failed attempt therefore leaves the previous handles in place rather than
    an unconnected client that later health checks would choke on.

    Args:
        server: Hostname or address of the SSH server.
        port: SSH port.
        user: Login user name.
        password: Login password.

    Raises:
        Propagates any exception raised while connecting or opening the SFTP
        sessions. The partially built client is closed first.
    """
    global ssh_client, sftp_client, sftp_client_imgs
    client = paramiko.SSHClient()
    client.load_system_host_keys()
    # NOTE(review): AutoAddPolicy trusts unknown host keys (no MITM protection).
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    try:
        client.connect(server, port, user, password)
        client.get_transport().set_keepalive(60)
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [executor.submit(client.open_sftp) for _ in range(5)]
            sessions = [future.result() for future in futures]
    except Exception:
        # Don't leak the socket/transport thread of a failed attempt.
        client.close()
        raise
    # NOTE(review): the previous client is intentionally not closed here,
    # because other in-flight requests may still be using its sessions.
    ssh_client = client
    sftp_client = sessions[0]
    sftp_client_imgs = sessions[1:]
def is_connected():
    """Return True if the shared SSH connection and primary SFTP session work.

    Checks that both handles exist and that the SSH transport is active. It
    then does a cheap ``listdir('.')`` round-trip on the SFTP session. Any
    failure of that probe is printed and reported as "not connected".
    """
    if ssh_client is None or sftp_client is None:
        return False
    transport = ssh_client.get_transport()
    # get_transport() is None for a client that never connected or was closed;
    # calling is_active() on it would raise instead of reporting "not connected".
    if transport is None or not transport.is_active():
        return False
    try:
        sftp_client.listdir('.')
    except Exception as e:
        print(f"Error checking SFTP connection: {e}")
        return False
    return True
def get_image_from_url(image_url, timeout=30):
    """Download an image over HTTP(S) and return it as a PIL image.

    Args:
        image_url: URL of the image.
        timeout: Seconds to wait for the server, passed to ``requests.get``.
            Without a timeout, requests never times out, so a stalled server
            would block the calling thread indefinitely.

    Raises:
        requests.HTTPError: if the response status is not successful.
        requests.RequestException: on connection errors or timeouts.
    """
    response = requests.get(image_url, timeout=timeout)
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))
# def get_random_mscoco_prompt(): | |
# global sftp_client | |
# if not is_connected(): | |
# create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD) | |
# num = random.randint(0, 2999) | |
# file = "{}.txt".format(num) | |
# remote_file_path = os.path.join(SSH_MSCOCO, file) | |
# with sftp_client.file(remote_file_path, 'r') as f: | |
# content = f.read().decode('utf-8') | |
# print(f"Content of {file}:") | |
# print("\n") | |
# return content | |
def get_random_mscoco_prompt(file_path='./coco_prompt.txt'):
    """Return a random non-blank prompt from a newline-delimited text file.

    Args:
        file_path: UTF-8 text file with one prompt per line. Defaults to the
            bundled ``./coco_prompt.txt``.

    Returns:
        One line of the file, with surrounding whitespace stripped. Blank
        lines are never returned.

    Raises:
        IndexError: if the file has no non-blank lines. This is the same
            exception type ``random.choice`` raised on an empty file before.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        prompts = [line.strip() for line in file if line.strip()]
    if not prompts:
        raise IndexError(f"No prompts found in {file_path}")
    return random.choice(prompts)
def get_random_video_prompt(root_dir):
    """Pick a random sample directory under ``root_dir`` and return its prompt.

    Args:
        root_dir: Local directory whose immediate subdirectories are samples,
            each expected to contain a ``prompt.txt``.

    Returns:
        ``(selected_dir, prompt)``. ``prompt`` is the first line of
        ``selected_dir/prompt.txt`` with surrounding whitespace stripped.

    Raises:
        NotImplementedError: if ``root_dir`` has no subdirectories, or the
            chosen one has a missing or empty ``prompt.txt``. The exception
            type is kept unchanged in case callers catch it.
    """
    subdirs = [
        os.path.join(root_dir, d)
        for d in os.listdir(root_dir)
        if os.path.isdir(os.path.join(root_dir, d))
    ]
    if not subdirs:
        raise NotImplementedError(f"No subdirectories found in {root_dir}")
    selected_dir = random.choice(subdirs)
    prompt_path = os.path.join(selected_dir, 'prompt.txt')
    if not os.path.exists(prompt_path):
        raise NotImplementedError(f"Missing prompt file: {prompt_path}")
    # Only the first line is needed; don't read the rest of the file.
    with open(prompt_path, 'r', encoding='utf-8') as file:
        first_line = file.readline()
    if not first_line:
        # readline() returns '' only at EOF, i.e. the file is empty.
        raise NotImplementedError(f"Empty prompt file: {prompt_path}")
    return selected_dir, first_line.strip()
def _is_remote_directory(sftp, path):
    """Return True if ``path`` exists on the SFTP server and is a directory."""
    try:
        return stat.S_ISDIR(sftp.stat(path).st_mode)
    except IOError:
        return False


def _read_first_line(path):
    """Return the first line of a local UTF-8 text file, stripped.

    Raises:
        NotImplementedError: if the file is empty.
    """
    with open(path, 'r', encoding='utf-8') as file:
        first_line = file.readline()
    if not first_line:
        raise NotImplementedError(f"Empty prompt file: {path}")
    return first_line.strip()


def _download_random_remote_sample(root_dir, local_dir, model_names, extension):
    """Download ``prompt.txt`` and one file per model from a random remote sample.

    Opens a dedicated SSH/SFTP connection. It picks a random subdirectory of
    the remote ``root_dir`` and downloads ``prompt.txt`` plus
    ``<model_name><extension>`` for every entry of ``model_names`` into
    ``local_dir``. The connection is always closed, including on errors.

    Args:
        root_dir: Remote directory whose subdirectories are samples.
        local_dir: Local directory to download into. Existing files with the
            same names are overwritten.
        model_names: Names of the form ``<source>_<name>_<type>``. Only
            ``<name>`` is used to build the remote file name.
        extension: File extension (including the dot) of the per-model files.

    Returns:
        ``(prompt, local_paths)``. ``prompt`` is the stripped first line of
        ``prompt.txt``. ``local_paths`` lists the downloaded per-model files
        in ``model_names`` order.

    Raises:
        NotImplementedError: on any failure. The cause is printed and chained.
    """
    import posixpath  # remote paths are POSIX regardless of the local OS

    ssh = paramiko.SSHClient()
    # NOTE(review): AutoAddPolicy trusts unknown host keys (no MITM protection).
    ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    sftp = None
    try:
        ssh.connect(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
        sftp = ssh.open_sftp()
        remote_subdirs = [
            d for d in sftp.listdir(root_dir)
            if _is_remote_directory(sftp, posixpath.join(root_dir, d))
        ]
        if not remote_subdirs:
            print(f"No subdirectories found in {root_dir}")
            raise NotImplementedError
        chosen_subdir_path = posixpath.join(root_dir, random.choice(remote_subdirs))
        print(f"Chosen subdirectory: {chosen_subdir_path}")

        # Resolve every file name before downloading, so a malformed model
        # name fails before any transfer starts.
        file_names = ['prompt.txt']
        for name in model_names:
            _model_source, model_name, _model_type = name.split("_")
            file_name = f'{model_name}{extension}'
            print(file_name)
            file_names.append(file_name)

        local_paths = []
        for file_name in file_names:
            remote_file_path = posixpath.join(chosen_subdir_path, file_name)
            local_file_path = os.path.join(local_dir, file_name)
            sftp.get(remote_file_path, local_file_path)
            local_paths.append(local_file_path)
            print(f"Downloaded {remote_file_path} to {local_file_path}")

        prompt = _read_first_line(local_paths[0])
    except Exception as e:
        print(f"An error occurred: {e}")
        raise NotImplementedError from e
    finally:
        # Previously the handles were closed only on success and leaked on
        # every error path.
        if sftp is not None:
            sftp.close()
        ssh.close()
    return prompt, local_paths[1:]


def get_ssh_random_video_prompt(root_dir, local_dir, model_names):
    """Download a random remote video sample and return its prompt.

    Args:
        root_dir: Remote directory whose subdirectories are samples.
        local_dir: Local directory to download ``prompt.txt`` and videos into.
        model_names: ``<source>_<name>_<type>`` strings. ``<name>.mp4`` is
            fetched for each.

    Returns:
        ``(prompt, video_paths)``. ``video_paths`` are local ``.mp4`` paths in
        ``model_names`` order.

    Raises:
        NotImplementedError: on any connection, listing or download failure.
    """
    return _download_random_remote_sample(root_dir, local_dir, model_names, '.mp4')


def get_ssh_random_image_prompt(root_dir, local_dir, model_names):
    """Download a random remote image sample and return its prompt and images.

    Args:
        root_dir: Remote directory whose subdirectories are samples.
        local_dir: Local directory to download ``prompt.txt`` and images into.
        model_names: ``<source>_<name>_<type>`` strings. ``<name>.jpg`` is
            fetched for each.

    Returns:
        ``(prompt, images)``. ``images`` are PIL images opened from the
        downloaded files, in ``model_names`` order.

    Raises:
        NotImplementedError: on any connection, listing or download failure.
    """
    prompt, image_paths = _download_random_remote_sample(
        root_dir, local_dir, model_names, '.jpg'
    )
    return prompt, [Image.open(path) for path in image_paths]
def create_remote_directory(remote_directory, video=False):
    """Create a log directory on the remote host and return its path.

    Reconnects first if is_connected() reports the shared connection as
    unusable. Then runs ``mkdir -p`` over SSH. A failure is printed, not
    raised.

    Args:
        remote_directory: Sub-path to create under the configured log root.
        video: If True, create it under ``SSH_VIDEO_LOG``; otherwise under
            ``SSH_LOG``.

    Returns:
        The remote path ``<log root>/<remote_directory>`` (unquoted).
    """
    import shlex

    if not is_connected():
        create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
    base_dir = SSH_VIDEO_LOG if video else SSH_LOG
    log_dir = f'{base_dir}/{remote_directory}'
    # Quote only the caller-supplied part: spaces or shell metacharacters in it
    # would otherwise split or inject into the command. The configured base is
    # left unquoted so any shell expansion it relies on keeps working.
    command = f'mkdir -p {base_dir}/{shlex.quote(str(remote_directory))}'
    _stdin, _stdout, stderr = ssh_client.exec_command(command)
    error = stderr.read().decode('utf-8')
    if error:
        print(f"Error: {error}")
    else:
        print(f"Directory {remote_directory} created successfully.")
    return log_dir
def upload_images(i, image_list, output_file_list, sftp_client):
    """Resize ``image_list[i]`` to 512x512 and upload it as JPEG.

    Args:
        i: Index into ``image_list`` and ``output_file_list``.
        image_list: PIL images or image URLs (strings). Entry ``i`` is
            replaced in place by the resized PIL image.
        output_file_list: Remote destination paths, parallel to ``image_list``.
        sftp_client: SFTP session to upload through. It is left open so the
            caller's shared session stays reusable.
    """
    if isinstance(image_list[i], str):
        print("get url image")
        image_list[i] = get_image_from_url(image_list[i])
    # Image.ANTIALIAS was removed in Pillow 10; it was an alias for LANCZOS.
    image_list[i] = image_list[i].resize((512, 512), Image.LANCZOS)
    image = image_list[i]
    if image.mode not in ('RGB', 'L'):
        # JPEG cannot store alpha or palette modes (e.g. RGBA, P).
        image = image.convert('RGB')
    with io.BytesIO() as image_byte_stream:
        image.save(image_byte_stream, format='JPEG')
        image_byte_stream.seek(0)
        # Deliberately not ``with sftp_client``: paramiko's SFTPClient closes
        # itself on __exit__, which would kill the shared session.
        sftp_client.putfo(image_byte_stream, output_file_list[i])
    print(f"Successfully uploaded image to {output_file_list[i]}")


def upload_ssh_all(states, output_dir, data, data_path):
    """Upload each state's output image, then ``data`` as a JSON file.

    Reconnects first if is_connected() reports the shared connection as
    unusable. Images are written concurrently to ``output_dir/<index>.jpg``,
    one SFTP session per image. The pre-opened ``sftp_client_imgs`` sessions
    are used first; temporary sessions are opened for any surplus and closed
    afterwards. A failed image upload is printed and does not stop the JSON
    upload.

    Args:
        states: Sequence of objects with an ``output`` attribute (a PIL image
            or an image URL).
        output_dir: Remote directory for the images.
        data: JSON-serialisable record.
        data_path: Remote path of the JSON file.
    """
    if not is_connected():
        create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
    output_file_list = [os.path.join(output_dir, f"{i}.jpg") for i in range(len(states))]
    image_list = [state.output for state in states]

    # The pool holds a fixed number of sessions; indexing past it used to raise
    # IndexError, so open extra ones for any additional images.
    channels = list(sftp_client_imgs)
    extra_channels = [
        ssh_client.open_sftp() for _ in range(len(output_file_list) - len(channels))
    ]
    channels.extend(extra_channels)
    try:
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = {
                executor.submit(upload_images, i, image_list, output_file_list, channels[i]): i
                for i in range(len(output_file_list))
            }
            # Surface worker exceptions instead of silently discarding them.
            for future in concurrent.futures.as_completed(futures):
                error = future.exception()
                if error is not None:
                    print(f"Failed to upload image to {output_file_list[futures[future]]}: {error}")
    finally:
        for channel in extra_channels:
            channel.close()

    json_data = json.dumps(data, indent=4)
    with io.BytesIO(json_data.encode('utf-8')) as json_byte_stream:
        # Primary session is left open (no ``with``) so it stays reusable.
        sftp_client.putfo(json_byte_stream, data_path)
    print(f"Successfully uploaded JSON data to {data_path}")
# create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD) | |
def upload_ssh_data(data, data_path):
    """Serialise ``data`` as indented JSON and upload it to ``data_path``.

    Reconnects first if is_connected() reports the shared connection as
    unusable. The shared SFTP session is deliberately not used as a context
    manager. paramiko's SFTPClient closes itself on ``__exit__``, which
    would force a full reconnect on the next call.

    Args:
        data: JSON-serialisable object.
        data_path: Remote destination path.
    """
    if not is_connected():
        create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
    json_data = json.dumps(data, indent=4)
    with io.BytesIO(json_data.encode('utf-8')) as json_byte_stream:
        sftp_client.putfo(json_byte_stream, data_path)
    print(f"Successfully uploaded JSON data to {data_path}")