"""Helpers for talking to a remote host over SSH/SFTP.

The module keeps one shared SSH connection and one SFTP session in module
globals. Every public helper reconnects lazily via ``_ensure_connected``
when the cached connection is missing or dead.
"""
import io
import json
import os
import posixpath
import random
import shlex

import numpy as np
import paramiko
import requests
from PIL import Image

from .constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD, SSH_LOG, SSH_MSCOCO

# Shared connection state, (re)initialised by create_ssh_client().
ssh_client = None
sftp_client = None

# Image modes Pillow's JPEG encoder can write directly; anything else
# (e.g. "RGBA", "P", "LA") must be converted before saving as JPEG.
_JPEG_MODES = frozenset({"1", "L", "RGB", "RGBX", "CMYK", "YCbCr"})


def _close_clients():
    """Close and forget any cached SFTP/SSH clients."""
    global ssh_client, sftp_client
    # Close SFTP first: it runs on top of the SSH transport.
    for client in (sftp_client, ssh_client):
        if client is not None:
            client.close()
    ssh_client = None
    sftp_client = None


def create_ssh_client(server, port, user, password):
    """Open a new SSH connection and SFTP session and store them globally.

    Any previously cached clients are closed first so repeated reconnects do
    not leak sockets. The globals are only updated once the new connection
    and SFTP session are fully established.

    Args:
        server: Hostname or IP address of the remote host.
        port: SSH port.
        user: Login user name.
        password: Login password.

    Raises:
        Whatever ``paramiko`` raises on connection/auth failure; in that case
        the globals are left as ``None``.
    """
    global ssh_client, sftp_client
    _close_clients()

    client = paramiko.SSHClient()
    client.load_system_host_keys()
    # NOTE(review): AutoAddPolicy silently trusts unknown host keys, which
    # allows man-in-the-middle attacks on first contact. Consider
    # RejectPolicy with a managed known_hosts file.
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    try:
        client.connect(server, port, user, password)
        # Keepalives stop idle NAT/firewall timeouts from silently killing
        # the long-lived connection.
        client.get_transport().set_keepalive(60)
        sftp = client.open_sftp()
    except Exception:
        client.close()
        raise

    ssh_client = client
    sftp_client = sftp


def is_connected():
    """Return True if both the cached SSH transport and SFTP session are usable."""
    if ssh_client is None or sftp_client is None:
        return False

    # get_transport() returns None once the client has been closed or never
    # connected, so guard before calling is_active().
    transport = ssh_client.get_transport()
    if transport is None or not transport.is_active():
        return False

    # Cheap round trip on the SFTP channel (stat of the current remote
    # directory) to confirm the session still answers.
    try:
        sftp_client.stat('.')
    except Exception as e:
        print(f"Error checking SFTP connection: {e}")
        return False
    return True


def _ensure_connected():
    """Reconnect with the configured credentials if the cached connection is unusable."""
    if not is_connected():
        create_ssh_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)


def get_image_from_url(image_url, timeout=30):
    """Download an image over HTTP(S) and return it as a PIL image.

    Args:
        image_url: URL of the image.
        timeout: Seconds to wait for the server (passed to ``requests.get``)
            so a stalled server cannot hang the caller forever.

    Raises:
        requests.HTTPError: If the server returns an error status.
    """
    response = requests.get(image_url, timeout=timeout)
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))


def get_random_mscoco_prompt(num_prompts=3000):
    """Read a randomly chosen prompt file from the remote MSCOCO directory.

    Args:
        num_prompts: Number of prompt files. A name ``{k}.txt`` is picked
            uniformly for ``k`` in ``[0, num_prompts - 1]`` under
            ``SSH_MSCOCO``. The default reproduces the previous fixed range
            0..2999.

    Returns:
        The file contents decoded as UTF-8.
    """
    _ensure_connected()
    num = random.randint(0, num_prompts - 1)
    file_name = f"{num}.txt"
    # Remote paths are POSIX regardless of the local OS.
    remote_file_path = posixpath.join(SSH_MSCOCO, file_name)
    with sftp_client.file(remote_file_path, 'r') as f:
        content = f.read().decode('utf-8')
    # Only a header is printed here; the content itself is returned.
    print(f"Content of {file_name}:")
    print("\n")
    return content


def create_remote_directory(remote_directory):
    """Create ``SSH_LOG/remote_directory`` on the remote host (like ``mkdir -p``).

    Args:
        remote_directory: Directory name or relative path under ``SSH_LOG``.

    Returns:
        The remote path ``f"{SSH_LOG}/{remote_directory}"``. It is returned
        even if creation failed; failures are only printed.
    """
    _ensure_connected()
    # Quote the caller-supplied part so spaces or shell metacharacters cannot
    # break or inject into the command. SSH_LOG is left as-is so the remote
    # shell interprets it exactly as before (e.g. a leading "~").
    command = f'mkdir -p {SSH_LOG}/{shlex.quote(remote_directory)}'
    _stdin, stdout, stderr = ssh_client.exec_command(command)
    error = stderr.read().decode('utf-8')
    exit_status = stdout.channel.recv_exit_status()
    if error:
        print(f"Error: {error}")
    elif exit_status != 0:
        print(f"Error: mkdir exited with status {exit_status}")
    else:
        print(f"Directory {remote_directory} created successfully.")
    return f'{SSH_LOG}/{remote_directory}'


def _upload_jpeg(image, remote_path):
    """Encode a PIL image as JPEG in memory and upload it to *remote_path*."""
    # JPEG cannot store alpha or palette modes. Converting to RGB discards
    # the alpha channel instead of letting Image.save raise.
    if image.mode not in _JPEG_MODES:
        image = image.convert("RGB")
    with io.BytesIO() as image_byte_stream:
        image.save(image_byte_stream, format='JPEG')
        image_byte_stream.seek(0)
        sftp_client.putfo(image_byte_stream, remote_path)


def upload_ssh_all(states, output_dir, data, data_path):
    """Upload one JPEG per state plus a JSON document over SFTP.

    Args:
        states: Iterable of objects with an ``output`` attribute that is
            either a PIL image or a URL string (downloaded first).
        output_dir: Remote directory; image ``i`` is written to
            ``output_dir/{i}.jpg``.
        data: JSON-serialisable object written to *data_path*.
        data_path: Remote path of the JSON file.
    """
    _ensure_connected()

    for i, state in enumerate(states):
        remote_path = posixpath.join(output_dir, f"{i}.jpg")
        image = state.output
        if isinstance(image, str):
            print("get url image")
            image = get_image_from_url(image)
        _upload_jpeg(image, remote_path)
        print(f"Successfully uploaded image to {remote_path}")

    json_data = json.dumps(data, indent=4)
    with io.BytesIO(json_data.encode('utf-8')) as json_byte_stream:
        sftp_client.putfo(json_byte_stream, data_path)
    print(f"Successfully uploaded JSON data to {data_path}")