"""SSH/SFTP helpers for fetching prompts/samples and uploading results.

The module keeps one shared SSH connection (``ssh_client``) with one SFTP
channel for JSON uploads (``sftp_client``) and a small pool of SFTP channels
for image uploads (``sftp_client_imgs``).

NOTE: the upload helpers use each SFTP client as a context manager, and
Paramiko closes an SFTP client when its ``with`` block exits. As a result the
next call finds ``is_connected()`` false and reconnects.
"""
import concurrent.futures
import io
import json
import os
import posixpath
import random
import shlex
import stat

import gradio as gr
import numpy as np
import paramiko
import requests
from PIL import Image

from .constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD, SSH_LOG, SSH_VIDEO_LOG, SSH_MSCOCO

ssh_client = None
sftp_client = None
sftp_client_imgs = None


def open_sftp(i=0):
    """Open a new SFTP channel on the shared SSH connection.

    Args:
        i: Unused; accepted so the function can be submitted to an executor
            once per index.

    Returns:
        The SFTP client returned by ``ssh_client.open_sftp()``.
    """
    return ssh_client.open_sftp()


def create_ssh_client(server, port, user, password):
    """(Re)create the shared SSH connection and its SFTP channels.

    Sets the module globals ``ssh_client``, ``sftp_client`` (first channel)
    and ``sftp_client_imgs`` (the remaining four channels).

    Args:
        server: Host to connect to.
        port: SSH port.
        user: Login user name.
        password: Login password.
    """
    global ssh_client, sftp_client, sftp_client_imgs

    # Close the previous connection before replacing it; otherwise every
    # reconnect leaves an SSH session behind.
    if ssh_client is not None:
        try:
            ssh_client.close()
        except Exception as e:
            print(f"Error closing previous SSH connection: {e}")
        ssh_client = None
        sftp_client = None
        sftp_client_imgs = None

    client = paramiko.SSHClient()
    client.load_system_host_keys()
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    client.connect(server, port, user, password)
    client.get_transport().set_keepalive(60)
    # Publish the client only once connected; open_sftp() reads the global.
    ssh_client = client

    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = [executor.submit(open_sftp, i) for i in range(5)]
        channels = [future.result() for future in futures]

    sftp_client = channels[0]
    sftp_client_imgs = channels[1:]


def is_connected():
    """Return True when the shared SSH connection and main SFTP channel work.

    Only ``sftp_client`` is probed; the image channels in
    ``sftp_client_imgs`` are not checked.
    """
    if ssh_client is None or sftp_client is None:
        return False

    # Check that the SSH connection is alive. get_transport() returns None
    # for a client that was never connected or has been closed.
    transport = ssh_client.get_transport()
    if transport is None or not transport.is_active():
        return False

    # Check that the SFTP channel is alive by listing the current directory.
    try:
        sftp_client.listdir('.')
    except Exception as e:
        print(f"Error checking SFTP connection: {e}")
        return False
    return True


def get_image_from_url(image_url, timeout=30):
    """Download an image over HTTP and open it with PIL.

    Args:
        image_url: URL of the image.
        timeout: Seconds passed to ``requests.get`` so a stalled server
            cannot block the caller forever.

    Returns:
        A PIL image opened from the response body.

    Raises:
        requests.HTTPError: If the response status is not successful.
    """
    response = requests.get(image_url, timeout=timeout)
    response.raise_for_status()  # success
    return Image.open(io.BytesIO(response.content))


# Legacy SFTP-based implementation, kept for reference:
# def get_random_mscoco_prompt():
#     global sftp_client
#     if not is_connected():
#         create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
#     num = random.randint(0, 2999)
#     file = "{}.txt".format(num)
#     remote_file_path = os.path.join(SSH_MSCOCO, file)
#     with sftp_client.file(remote_file_path, 'r') as f:
#         content = f.read().decode('utf-8')
#     print(f"Content of {file}:")
#     print("\n")
#     return content


def get_random_mscoco_prompt():
    """Return a random stripped line from ``./coco_prompt.txt``."""
    file_path = './coco_prompt.txt'
    with open(file_path, 'r') as file:
        lines = file.readlines()
    return random.choice(lines).strip()


def _read_first_line(path):
    """Return the stripped first line of a UTF-8 text file.

    Raises:
        NotImplementedError: If the file is empty.
    """
    with open(path, 'r', encoding='utf-8') as file:
        first_line = file.readline()
    if not first_line:
        raise NotImplementedError(f"Prompt file {path} is empty")
    return first_line.strip()


def get_random_video_prompt(root_dir):
    """Pick a random subdirectory of ``root_dir`` and read its prompt.

    Args:
        root_dir: Local directory whose subdirectories each hold a
            ``prompt.txt``.

    Returns:
        Tuple ``(selected_dir, prompt)`` where ``prompt`` is the stripped
        first line of ``prompt.txt``.

    Raises:
        NotImplementedError: If there are no subdirectories, the chosen one
            has no ``prompt.txt``, or the prompt file is empty.
    """
    subdirs = [
        os.path.join(root_dir, d)
        for d in os.listdir(root_dir)
        if os.path.isdir(os.path.join(root_dir, d))
    ]
    if not subdirs:
        raise NotImplementedError

    selected_dir = random.choice(subdirs)
    prompt_path = os.path.join(selected_dir, 'prompt.txt')
    if not os.path.exists(prompt_path):
        raise NotImplementedError
    return selected_dir, _read_first_line(prompt_path)


def _is_remote_directory(sftp, path):
    """Return True if ``path`` is a directory on the SFTP server."""
    try:
        return stat.S_ISDIR(sftp.stat(path).st_mode)
    except IOError:
        return False


def _download_random_sample(root_dir, local_dir, model_names, extension):
    """Download ``prompt.txt`` and one file per model from a random remote subdir.

    Opens a dedicated SSH connection, which is always closed before
    returning or raising.

    Args:
        root_dir: Remote directory containing sample subdirectories.
        local_dir: Local directory the files are downloaded into.
        model_names: Names of the form ``source_name_type``; the middle part
            selects the remote file ``<name>.<extension>``.
        extension: File extension (without the dot) of the per-model files.

    Returns:
        Tuple ``(prompt, local_paths)``: the stripped first line of the
        downloaded prompt file and the local paths of the model files.

    Raises:
        NotImplementedError: On any failure; the original error is chained.
    """
    ssh = paramiko.SSHClient()
    ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    sftp = None
    try:
        ssh.connect(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
        sftp = ssh.open_sftp()

        # Remote paths always use forward slashes, whatever the local OS is.
        remote_subdirs = [
            d for d in sftp.listdir(root_dir)
            if _is_remote_directory(sftp, posixpath.join(root_dir, d))
        ]
        if not remote_subdirs:
            print(f"No subdirectories found in {root_dir}")
            raise NotImplementedError

        chosen_subdir_path = posixpath.join(root_dir, random.choice(remote_subdirs))
        print(f"Chosen subdirectory: {chosen_subdir_path}")

        file_names = ['prompt.txt']
        for name in model_names:
            model_source, model_name, model_type = name.split("_")
            sample_name = f'{model_name}.{extension}'
            print(sample_name)
            file_names.append(sample_name)

        local_paths = []
        for file_name in file_names:
            remote_file_path = posixpath.join(chosen_subdir_path, file_name)
            local_file_path = os.path.join(local_dir, file_name)
            sftp.get(remote_file_path, local_file_path)
            local_paths.append(local_file_path)
            print(f"Downloaded {remote_file_path} to {local_file_path}")

        if not os.path.exists(local_paths[0]):
            raise NotImplementedError
        prompt = _read_first_line(local_paths[0])
    except Exception as e:
        print(f"An error occurred: {e}")
        raise NotImplementedError from e
    finally:
        # Release the connection on success and on failure alike.
        if sftp is not None:
            sftp.close()
        ssh.close()

    return prompt, local_paths[1:]


def get_ssh_random_video_prompt(root_dir, local_dir, model_names):
    """Fetch a random remote video sample.

    Args:
        root_dir: Remote directory containing sample subdirectories.
        local_dir: Local directory the files are downloaded into.
        model_names: Names of the form ``source_name_type``; ``<name>.mp4``
            is downloaded for each.

    Returns:
        Tuple ``(prompt, video_paths)`` with the local paths of the videos.

    Raises:
        NotImplementedError: On any failure.
    """
    return _download_random_sample(root_dir, local_dir, model_names, 'mp4')


def get_ssh_random_image_prompt(root_dir, local_dir, model_names):
    """Fetch a random remote image sample.

    Args:
        root_dir: Remote directory containing sample subdirectories.
        local_dir: Local directory the files are downloaded into.
        model_names: Names of the form ``source_name_type``; ``<name>.jpg``
            is downloaded for each.

    Returns:
        Tuple ``(prompt, images)`` where ``images`` are PIL images opened
        from the downloaded files.

    Raises:
        NotImplementedError: On any download failure.
    """
    prompt, image_paths = _download_random_sample(root_dir, local_dir, model_names, 'jpg')
    return prompt, [Image.open(path) for path in image_paths]


def create_remote_directory(remote_directory, video=False):
    """Create a log directory on the remote host with ``mkdir -p``.

    Args:
        remote_directory: Directory name, relative to the log root.
        video: Use ``SSH_VIDEO_LOG`` as the root instead of ``SSH_LOG``.

    Returns:
        The remote path ``<root>/<remote_directory>``.
    """
    if not is_connected():
        create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)

    base_dir = SSH_VIDEO_LOG if video else SSH_LOG
    log_dir = f'{base_dir}/{remote_directory}'

    # Quote only the caller-supplied component so shell metacharacters in it
    # cannot run extra commands; the configured base is left untouched so a
    # leading "~" in it would still expand.
    command = f'mkdir -p {base_dir}/{shlex.quote(str(remote_directory))}'
    _, _, stderr = ssh_client.exec_command(command)
    error = stderr.read().decode('utf-8')
    if error:
        print(f"Error: {error}")
    else:
        print(f"Directory {remote_directory} created successfully.")
    return log_dir


def upload_images(i, image_list, output_file_list, sftp_client):
    """Resize image ``i`` to 512x512 and upload it as JPEG.

    ``image_list[i]`` is replaced in place by the processed image. The SFTP
    client is used as a context manager, so it is closed afterwards.

    Args:
        i: Index into ``image_list`` and ``output_file_list``.
        image_list: PIL images or image URLs (URLs are downloaded first).
        output_file_list: Remote destination paths.
        sftp_client: SFTP client to upload with.
    """
    with sftp_client as sftp:
        image = image_list[i]
        if isinstance(image, str):
            print("get url image")
            image = get_image_from_url(image)

        # JPEG has no alpha channel or palette; convert such modes first so
        # the save below does not fail.
        if image.mode in ("RGBA", "LA", "P"):
            image = image.convert("RGB")
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        image = image.resize((512, 512), Image.LANCZOS)
        image_list[i] = image

        with io.BytesIO() as image_byte_stream:
            image.save(image_byte_stream, format='JPEG')
            image_byte_stream.seek(0)
            sftp.putfo(image_byte_stream, output_file_list[i])
        print(f"Successfully uploaded image to {output_file_list[i]}")


def _upload_json(data, data_path):
    """Serialize ``data`` as indented JSON and upload it via ``sftp_client``.

    The shared ``sftp_client`` is used as a context manager, so it is closed
    afterwards.
    """
    with sftp_client as sftp:
        json_data = json.dumps(data, indent=4)
        with io.BytesIO(json_data.encode('utf-8')) as json_byte_stream:
            sftp.putfo(json_byte_stream, data_path)
        print(f"Successfully uploaded JSON data to {data_path}")


def upload_ssh_all(states, output_dir, data, data_path):
    """Upload every state's output image plus a JSON record.

    Images are uploaded concurrently to ``<output_dir>/<index>.jpg``. Image
    uploads are best-effort: a failure is printed and the remaining uploads
    and the JSON upload still run.

    Args:
        states: Objects whose ``output`` attribute is a PIL image or a URL.
        output_dir: Remote directory for the images.
        data: JSON-serializable record.
        data_path: Remote path for the JSON file.
    """
    if not is_connected():
        create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)

    image_list = [state.output for state in states]
    output_file_list = [
        posixpath.join(output_dir, f"{i}.jpg") for i in range(len(image_list))
    ]

    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = []
        for i in range(len(output_file_list)):
            # The pool holds a fixed number of channels; open an extra one
            # when there are more images than pooled channels.
            if i < len(sftp_client_imgs):
                channel = sftp_client_imgs[i]
            else:
                channel = open_sftp()
            futures.append(
                executor.submit(upload_images, i, image_list, output_file_list, channel)
            )

        # Surface failures that would otherwise stay hidden in the futures.
        for i, future in enumerate(futures):
            try:
                future.result()
            except Exception as e:
                print(f"Error uploading image to {output_file_list[i]}: {e}")

    _upload_json(data, data_path)


def upload_ssh_data(data, data_path):
    """Upload a JSON record to ``data_path`` on the remote host.

    Args:
        data: JSON-serializable record.
        data_path: Remote path for the JSON file.
    """
    if not is_connected():
        create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)
    _upload_json(data, data_path)