##############################################################################
    class EqualLR:
        """Forward pre-hook that applies an equalized learning rate to one parameter.

        The raw parameter is stored on the module as ``<name>_orig``; that is the
        tensor the optimizer updates. Before every forward pass the hook rebuilds
        ``module.<name>`` as ``<name>_orig * sqrt(2 / fan_in)``, where ``fan_in``
        is the number of elements in one output slice of the weight (e.g.
        ``in_features`` for ``nn.Linear``, ``in_channels * kh * kw`` for
        ``nn.Conv2d``).
        """

        def __init__(self, name):
            # Name of the parameter being rescaled (e.g. 'weight').
            self.name = name

        def compute_weight(self, module):
            """Return ``<name>_orig * sqrt(2 / fan_in)`` for ``module``.

            Args:
                module: a module previously passed to :meth:`apply`.

            Returns:
                A tensor shaped like ``<name>_orig`` that stays connected to the
                autograd graph of ``<name>_orig``.
            """
            weight = getattr(module, self.name + '_orig')
            # fan_in = product of every dim except the first (output) dim; read
            # from the shape so no autograd view of the parameter is created.
            fan_in = weight.shape[1:].numel()
            return weight * sqrt(2 / fan_in)

        @staticmethod
        def apply(module, name):
            """Re-register ``module.<name>`` so it is rescaled on every forward.

            Args:
                module: the ``nn.Module`` that owns the parameter.
                name: name of a parameter of ``module`` with at least 2 dims.

            Returns:
                The :class:`EqualLR` hook instance registered on ``module``.

            Raises:
                ValueError: if the parameter has fewer than 2 dimensions, since no
                    fan-in can be derived from it. The module is left unchanged.
                AttributeError or KeyError: if ``name`` is not a registered
                    parameter of ``module``.
            """
            weight = getattr(module, name)
            # Validate before mutating the module so a failure leaves it intact.
            if weight.dim() < 2:
                raise ValueError(
                    f"EqualLR needs a parameter with at least 2 dims to derive "
                    f"fan_in; '{name}' has shape {tuple(weight.shape)}"
                )

            fn = EqualLR(name)
            del module._parameters[name]
            # Keep the caller's requires_grad so a frozen weight stays frozen.
            module.register_parameter(
                name + '_orig',
                nn.Parameter(weight.data, requires_grad=weight.requires_grad),
            )
            # Expose ``module.<name>`` immediately rather than only after the first
            # forward pass. It is detached so the module can still be deep-copied
            # before its first forward; the hook replaces it with a
            # graph-connected tensor on every forward.
            setattr(module, name, fn.compute_weight(module).detach())
            module.register_forward_pre_hook(fn)
            return fn

        def __call__(self, module, input):
            """Forward pre-hook: refresh ``module.<name>`` from ``<name>_orig``."""
            setattr(module, self.name, self.compute_weight(module))
    def equal_lr(module, name='weight'):
        """Enable the equalized learning rate on ``module.<name>`` in place.

        Delegates to ``EqualLR.apply``, which moves the parameter to
        ``<name>_orig`` and registers a forward pre-hook that rebuilds the
        rescaled ``<name>`` before each forward pass.

        Args:
            module: the module whose parameter should be rescaled.
            name: name of the parameter to rescale; defaults to ``'weight'``.

        Returns:
            The same ``module`` object, so the call can wrap a constructor
            inline, e.g. ``equal_lr(nn.Linear(...))``.
        """
        EqualLR.apply(module, name=name)  # the returned hook is not needed here
        return module