import numpy as np
import shutil
import torch
import os
import io
import logging
from collections import defaultdict


from torch.nn import BatchNorm2d, SyncBatchNorm

def param_group_no_wd(model):
    """Split the parameters of ``model`` into decayed and non-decayed groups.

    The following parameters are exempt from weight decay:

    * ``bias`` of every ``Conv2d`` and ``Linear`` module;
    * ``weight`` and ``bias`` of every ``BatchNorm2d``, ``BatchNorm1d`` and
      ``SyncBatchNorm`` module.

    Parameters that are ``None`` (e.g. ``bias=False`` or ``affine=False``)
    are skipped.

    Args:
        model: a ``torch.nn.Module`` whose parameters should be grouped.

    Returns:
        A tuple ``(param_groups, type2num)`` where ``param_groups`` is
        ``[{'params': normal}, {'params': no_wd, 'weight_decay': 0.0}]`` and
        ``type2num`` is a ``defaultdict`` mapping ``'<ClassName>.weight'`` /
        ``'<ClassName>.bias'`` to the number of exempted parameters of that
        kind.
    """
    pgroup_no_wd = []
    type2num = defaultdict(int)

    def _exempt(module, attr):
        # Register ``module.<attr>`` as a no-weight-decay parameter if present.
        param = getattr(module, attr)
        if param is None:
            return
        pgroup_no_wd.append(param)
        type2num[module.__class__.__name__ + '.' + attr] += 1

    norm_types = (torch.nn.BatchNorm2d, torch.nn.BatchNorm1d, SyncBatchNorm)
    for m in model.modules():
        if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
            _exempt(m, 'bias')
        elif isinstance(m, norm_types):
            _exempt(m, 'weight')
            _exempt(m, 'bias')

    # Select by parameter identity rather than by reconstructed name: when
    # ``model`` itself is one of the handled layer types its module name is
    # the empty string, so a name such as '.bias' would never match the real
    # parameter name 'bias' and the parameter would end up in both groups.
    no_wd_ids = {id(p) for p in pgroup_no_wd}
    pgroup_normal = [p for p in model.parameters() if id(p) not in no_wd_ids]

    return [{'params': pgroup_normal}, {'params': pgroup_no_wd, 'weight_decay': 0.0}], type2num

def param_group_fc(model):
    """Separate the ``logits`` layer weight from all other parameters.

    Args:
        model: a module exposing ``model.module.logits.weight`` (i.e. the
            network is reached through a ``module`` attribute).

    Returns:
        ``[{'params': fc}, {'params': rest}]`` where ``fc`` holds the
        parameter that is the very same object as
        ``model.module.logits.weight`` and ``rest`` holds every other
        parameter, in ``model.parameters()`` order.
    """
    logits_weight = model.module.logits.weight
    # Index the buckets by the outcome of the identity check.
    buckets = {True: [], False: []}
    for param in model.parameters():
        buckets[param is logits_weight].append(param)
    return [{'params': buckets[True]}, {'params': buckets[False]}]

def param_group_multitask(model):
    """Group parameters by the sub-network they belong to.

    A parameter is assigned by substring match on its name from
    ``model.named_parameters()``: ``'module.backbone_module'``,
    ``'module.neck_module'`` or ``'module.decoder_module'`` (checked in that
    order). Anything else goes to a fourth "other" group.

    Args:
        model: a ``torch.nn.Module`` whose parameters should be grouped.

    Returns:
        A list of optimizer param-group dicts in the order backbone, neck,
        decoder, followed by the "other" group only when it is non-empty.
    """
    backbone_group = []
    neck_group = []
    decoder_group = []
    other_group = []
    for name, p in model.named_parameters():
        if 'module.backbone_module' in name:
            backbone_group.append(p)
        elif 'module.neck_module' in name:
            neck_group.append(p)
        elif 'module.decoder_module' in name:
            decoder_group.append(p)
        else:
            other_group.append(p)

    param_group = [
        {'params': backbone_group},
        {'params': neck_group},
        {'params': decoder_group},
    ]
    if other_group:
        # Must be a dict like the other groups; a set literal here would
        # raise TypeError because lists are unhashable.
        param_group.append({'params': other_group})
    return param_group