#!/usr/bin/env python
# -*- coding: utf-8 -*-

import paramiko
import socket
import sys
import time
import argparse
import getpass
import logging
import select
import json
import os
from threading import Thread

# Configure root logging once at import time: INFO level, timestamped messages.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger used by every function and class in this file.
logger = logging.getLogger(__name__)

class SSHClient:
    """
    Paramiko-based SSH client that sets up remote (server -> client) port forwarding.

    The SSH server is asked to listen on ``remote_host:remote_port``; each
    connection it forwards back over the SSH session is relayed to
    ``127.0.0.1:local_port`` on this machine.
    """

    # Bytes read per recv() when relaying data between the SSH channel and the local socket.
    _BUFFER_SIZE = 32768
    # Seconds each Transport.accept() call blocks before the accept loop re-checks
    # whether the transport is still alive.
    _ACCEPT_TIMEOUT = 1.0

    def __init__(self, server, port, username, password=None, key_file=None, timeout=30):
        """
        Initialize SSH client (no network activity happens here).

        Args:
            server (str): SSH server hostname or IP address
            port (int): SSH server port
            username (str): SSH username
            password (str, optional): SSH password
            key_file (str, optional): Path to private key file
            timeout (int, optional): Connection timeout in seconds
        """
        self.server = server
        self.port = port
        self.username = username
        self.password = password
        self.key_file = key_file
        self.client = None
        self.timeout = timeout
        self.transport = None

    def connect(self):
        """
        Establish connection to SSH server.

        Any client left over from a previous call is closed first, so this
        method can be used to reconnect without leaking the old connection.

        Returns:
            bool: True if connection successful, False otherwise
        """
        try:
            if self.client is not None:
                # Reconnect path: release the previous client before replacing it.
                self.client.close()
                self.client = None
                self.transport = None

            logger.info(f"尝试连接到 {self.server}:{self.port}...")
            self.client = paramiko.SSHClient()
            # NOTE(review): AutoAddPolicy accepts any unknown host key without
            # verification, leaving connections open to man-in-the-middle attacks.
            # Consider load_system_host_keys() together with RejectPolicy/WarningPolicy.
            self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

            connect_kwargs = {
                'hostname': self.server,
                'port': self.port,
                'username': self.username,
                'timeout': self.timeout,
                'allow_agent': False,
                'look_for_keys': False
            }

            # A password takes precedence over a key file when both are configured.
            if self.password:
                connect_kwargs['password'] = self.password
            elif self.key_file:
                connect_kwargs['key_filename'] = self.key_file

            self.client.connect(**connect_kwargs)
            self.transport = self.client.get_transport()

            if self.transport:
                # Send a keepalive packet every 60 seconds.
                self.transport.set_keepalive(60)

            logger.info(f"成功连接到 {self.server}:{self.port} 用户名: {self.username}")
            return True

        except paramiko.AuthenticationException:
            logger.error("认证失败,请检查用户名和密码")
            return False
        except paramiko.SSHException as e:
            logger.error(f"SSH连接错误: {e}")
            return False
        except socket.timeout:
            logger.error(f"连接到 {self.server}:{self.port} 超时。请检查服务器地址和防火墙设置。")
            return False
        except socket.error as e:
            logger.error(f"socket错误: {e}")
            return False
        except Exception as e:
            logger.error(f"连接到SSH服务器时出错: {e}")
            import traceback
            logger.debug(traceback.format_exc())
            return False

    def setup_port_forward(self, remote_host, remote_port, local_port):
        """
        Set up port forwarding from server to client (remote to local).

        Asks the SSH server (Transport.request_port_forward) to listen on
        remote_host:remote_port, then starts a daemon thread that relays every
        forwarded connection to 127.0.0.1:local_port.

        Args:
            remote_host (str): Address the SSH server should bind its listener to
            remote_port (int): Port the SSH server should listen on
            local_port (int): Local port on 127.0.0.1 that forwarded connections are relayed to

        Returns:
            bool: True if the forward was granted and the relay thread started, False otherwise
        """
        try:
            # The transport must be connected and alive before a forward can be requested.
            if not self.transport or not self.transport.is_active():
                logger.error("SSH传输层未激活,无法设置端口转发")
                return False

            # Ask the server to open the listening socket on its side.
            try:
                logger.info(f"尝试请求端口转发: {remote_host}:{remote_port}")
                self.transport.request_port_forward(remote_host, remote_port)
            except paramiko.SSHException as e:
                error_msg = str(e).lower()
                if "forwarding request denied" in error_msg or "addressnotpermitted" in error_msg:
                    logger.error(f"端口转发请求被拒绝: {e}")
                    logger.info("BvSshServer端口转发问题排查: ")
                    logger.info("1. 检查用户权限: 确认SSH用户账户是否有端口转发权限")
                    logger.info("2. 尝试使用不同的远程端口: 有些端口可能被禁止转发")
                    logger.info("3. 查看服务器日志: 可能有更多关于拒绝原因的信息")
                    logger.info("4. 检查是否有其他应用已经占用了该端口")
                    logger.info("5. 尝试使用其他绑定地址,如 'localhost' 而不是 '127.0.0.1'")
                    return False
                else:
                    raise

            logger.info(f"设置端口转发: {remote_host}:{remote_port} -> localhost:{local_port}")

            class ForwardServer(Thread):
                """Daemon thread: accepts forwarded channels and spawns one relay thread per channel."""

                def __init__(self, transport, remote_host, remote_port, local_port,
                             accept_timeout, buffer_size):
                    super().__init__(daemon=True)
                    self.transport = transport
                    self.remote_host = remote_host
                    self.remote_port = remote_port
                    self.local_port = local_port
                    self.accept_timeout = accept_timeout
                    self.buffer_size = buffer_size

                def run(self):
                    # Transport.accept() takes its timeout in SECONDS. A short timeout
                    # lets this loop notice a dead transport (e.g. one replaced by a
                    # reconnect) and let the thread finish instead of polling forever.
                    while self.transport.is_active():
                        try:
                            chan = self.transport.accept(self.accept_timeout)
                            if chan is None:
                                continue  # timed out with nothing to accept
                            relay = Thread(target=self.handler, args=(chan,), daemon=True)
                            relay.start()
                        except Exception as e:
                            if not self.transport.is_active():
                                break
                            logger.error(f"转发通道接收错误: {e}")

                def handler(self, chan):
                    """Pump bytes both ways between *chan* and a new socket to 127.0.0.1:local_port."""
                    sock = None
                    try:
                        sock = socket.socket()
                        try:
                            sock.connect(('127.0.0.1', self.local_port))
                        except ConnectionRefusedError:
                            logger.error(f"连接本地端口 {self.local_port} 被拒绝,请确保本地服务正在运行")
                            return  # the finally clause closes both ends

                        logger.info(f"转发连接 {self.remote_host}:{self.remote_port} -> localhost:{self.local_port}")

                        while True:
                            readable, _, _ = select.select([sock, chan], [], [])
                            if sock in readable:
                                data = sock.recv(self.buffer_size)
                                if not data:
                                    break  # local side closed
                                # send() may write only part of the buffer; sendall() writes all of it.
                                chan.sendall(data)
                            if chan in readable:
                                data = chan.recv(self.buffer_size)
                                if not data:
                                    break  # remote side closed
                                sock.sendall(data)
                    except Exception as e:
                        logger.error(f"转发处理错误: {e}")
                    finally:
                        # Close each end independently so a failure on one cannot leak the other.
                        for endpoint in (sock, chan):
                            if endpoint is None:
                                continue
                            try:
                                endpoint.close()
                            except Exception:
                                pass  # best-effort cleanup

            forward_server = ForwardServer(
                self.transport, remote_host, remote_port, local_port,
                self._ACCEPT_TIMEOUT, self._BUFFER_SIZE,
            )
            forward_server.start()

            return True

        except Exception as e:
            logger.error(f"设置端口转发时出错: {e}")
            # Look for the error format this code associates with BvSshServer.
            error_str = str(e)
            if "AddressNotPermitted" in error_str or "<parameters" in error_str:
                logger.error("检测到 BvSshServer 特有的错误格式")
                logger.info("你需要检查 BvSshServer 的用户配置,确认你的用户有权限进行端口转发")
                logger.info("如果你有访问 BvSshServer 配置的权限,请检查用户配置中的端口转发设置")
            import traceback
            logger.debug(traceback.format_exc())
            return False

    def close(self):
        """Close the SSH connection (no-op when not connected) and drop the client/transport references."""
        if self.client:
            self.client.close()
            self.client = None
            self.transport = None
            logger.info("SSH连接已关闭")

def load_config_from_json(config_file):
    """
    Read and validate SSH port-forwarding settings from a JSON file.

    Args:
        config_file (str): Path to the JSON configuration file

    Returns:
        dict: The parsed settings, with 'port', 'remote_port' and 'local_port'
        coerced to int and defaults filled in for port (22), timeout (30),
        remote_host ('localhost') and verbose (False). None if the file does
        not exist, is not valid JSON, lacks 'server' or 'username', or any
        other error occurs (each case is logged).
    """
    try:
        if not os.path.exists(config_file):
            logger.error(f"配置文件 {config_file} 不存在")
            return None

        with open(config_file, 'r', encoding='utf-8') as handle:
            settings = json.load(handle)

        absent = [name for name in ('server', 'username') if name not in settings]
        if absent:
            logger.error(f"配置文件缺少必要的参数: {', '.join(absent)}")
            return None

        # Port values may be written as JSON strings; normalise them to int.
        for name in ('port', 'remote_port', 'local_port'):
            if name in settings:
                settings[name] = int(settings[name])

        fallback_values = (
            ('port', 22),
            ('timeout', 30),
            ('remote_host', 'localhost'),
            ('verbose', False),
        )
        for name, value in fallback_values:
            settings.setdefault(name, value)

        # Forwarding ports are only warned about here; main() enforces them.
        if not ('remote_port' in settings and 'local_port' in settings):
            logger.warning("配置文件缺少端口转发设置 (remote_port 和/或 local_port)")

        return settings
    except json.JSONDecodeError as exc:
        logger.error(f"JSON配置文件解析错误: {exc}")
        return None
    except Exception as exc:
        logger.error(f"无法加载配置文件: {exc}")
        import traceback
        logger.debug(traceback.format_exc())
        return None

def create_default_config():
    """
    Build the template settings written out by --create-config.

    For key-based authentication, the "password" entry can be replaced by a
    "key_file" entry such as "/path/to/your/private_key.pem".

    Returns:
        dict: A new dict on every call; keys are in the order they should
        appear in the generated JSON file.
    """
    template = dict(
        server="ssh.example.com",
        port=22,
        username="your_username",
        password="your_password",
        timeout=30,
        remote_host="localhost",
        remote_port=8080,
        local_port=8080,
        verbose=False,
    )
    return template

def save_default_config(file_path):
    """
    Write the default configuration template to disk as indented JSON.

    Args:
        file_path (str): Destination path; an existing file is overwritten

    Returns:
        bool: True when the template was written, False if anything failed
        (the failure is logged)
    """
    try:
        template = create_default_config()
        with open(file_path, 'w', encoding='utf-8') as out:
            json.dump(template, out, indent=4)
        logger.info(f"默认配置模板已保存到 {file_path}")
    except Exception as err:
        logger.error(f"保存默认配置失败: {err}")
        return False
    return True

def _build_arg_parser():
    """
    Return the argparse parser for this tool.

    --remote_host and --timeout intentionally have no argparse default. A
    default is indistinguishable from an explicit value and would silently
    override the same setting from the JSON config file. Their documented
    defaults are applied in main() after the two sources are merged.
    """
    parser = argparse.ArgumentParser(description='SSH Client with Port Forwarding')
    parser.add_argument('-s', '--server', help='SSH server hostname or IP')
    parser.add_argument('-p', '--port', type=int, help='SSH server port (default: 22)')
    parser.add_argument('-u', '--username', help='SSH username')
    # NOTE(review): a password passed on the command line is visible in the process
    # list and shell history; prefer the interactive prompt or a key file.
    parser.add_argument('-pw', '--password', help='SSH password (will prompt if not provided)')
    parser.add_argument('-k', '--key_file', help='Path to private key file')
    parser.add_argument('-rh', '--remote_host',
                        help='Remote host to connect to from SSH server (default: localhost)')
    parser.add_argument('-rp', '--remote_port', type=int,
                        help='Remote port to forward from')
    parser.add_argument('-lp', '--local_port', type=int,
                        help='Local port to forward to')
    parser.add_argument('-t', '--timeout', type=int,
                        help='Connection timeout in seconds (default: 30)')
    parser.add_argument('-v', '--verbose', action='store_true',
                        help='Enable verbose logging')
    parser.add_argument('-c', '--config', help='JSON configuration file path')
    parser.add_argument('--create-config', help='Create default configuration template and save to the specified path')
    return parser


def _apply_cli_overrides(config, args):
    """Copy command-line options that were actually supplied into *config*, overriding file values."""
    # Only override the SSH port when -p was given.
    if args.port is not None:
        config['port'] = args.port
    for key in ('server', 'username', 'password', 'key_file', 'remote_host',
                'remote_port', 'local_port', 'timeout'):
        value = getattr(args, key)
        if value:
            config[key] = value
    if args.verbose:
        config['verbose'] = True


def main():
    """
    Command-line entry point: build settings, connect, and keep the remote
    port forward alive until Ctrl+C.

    Settings precedence is: command-line options, then the JSON config file
    (-c), then built-in defaults (port 22, timeout 30, remote_host 'localhost').
    Exits with status 1 on missing settings, connection failure, or forwarding
    failure.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # --create-config only writes a template and exits.
    if args.create_config:
        if save_default_config(args.create_config):
            logger.info("已创建默认配置文件模板,请根据需要修改该文件")
            sys.exit(0)
        sys.exit(1)

    config = {}
    if args.config:
        config = load_config_from_json(args.config)
        if config is None:
            logger.error("无法加载配置文件,退出程序")
            sys.exit(1)
        logger.info(f"从配置文件加载的配置: server={config['server']}, port={config['port']}, username={config['username']}")

    _apply_cli_overrides(config, args)

    missing_params = [name for name in ('server', 'username', 'remote_port', 'local_port')
                      if name not in config]
    if missing_params:
        logger.error(f"缺少必要的参数: {', '.join(missing_params)}")
        logger.info("请提供这些参数或使用配置文件")
        parser.print_help()
        sys.exit(1)

    # Fill in defaults for settings that neither source supplied. Without this a
    # command-line-only run without -p crashed with KeyError on 'port'.
    config.setdefault('port', 22)
    config.setdefault('timeout', 30)
    config.setdefault('remote_host', 'localhost')

    for key in ('port', 'remote_port', 'local_port'):
        config[key] = int(config[key])

    logger.info(f"最终使用的配置: server={config['server']}, port={config['port']}, username={config['username']}")

    if config.get('verbose', False):
        logger.setLevel(logging.DEBUG)
        logging.getLogger('paramiko').setLevel(logging.DEBUG)

    # With neither a password nor a key file, ask for the password interactively.
    if not config.get('password') and not config.get('key_file'):
        config['password'] = getpass.getpass('SSH Password: ')

    ssh_client = SSHClient(
        server=config['server'],
        port=config['port'],
        username=config['username'],
        password=config.get('password'),
        key_file=config.get('key_file'),
        timeout=config['timeout']
    )

    if not ssh_client.connect():
        logger.error("无法连接到SSH服务器,退出程序")
        sys.exit(1)

    remote_host = config['remote_host']
    remote_port = config['remote_port']
    local_port = config['local_port']

    try:
        if not ssh_client.setup_port_forward(remote_host, remote_port, local_port):
            logger.error("无法设置端口转发,退出程序")
            sys.exit(1)

        logger.info(f"端口转发已建立: {remote_host}:{remote_port} -> localhost:{local_port}")
        logger.info("按 Ctrl+C 退出...")

        # Check the transport every 5 seconds; if it dropped, reconnect once and
        # re-request the forward, exiting if either step fails.
        while True:
            transport = ssh_client.transport
            if not transport or not transport.is_active():
                logger.error("SSH连接已断开,尝试重新连接...")
                # Release the dead connection first so it is not leaked by the reconnect.
                ssh_client.close()
                if not ssh_client.connect():
                    logger.error("无法重新连接,退出程序")
                    sys.exit(1)
                if not ssh_client.setup_port_forward(remote_host, remote_port, local_port):
                    logger.error("无法重新设置端口转发,退出程序")
                    sys.exit(1)
                logger.info("连接和端口转发已恢复")
            time.sleep(5)

    except KeyboardInterrupt:
        logger.info("\n正在退出...")
    except Exception as e:
        logger.error(f"发生错误: {e}")
        import traceback
        logger.debug(traceback.format_exc())
    finally:
        ssh_client.close()

# Run the command-line tool only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()