首先,我们知道pytorch的任何网络net,都是torch.nn.Module的子类,都算是module,也就是模块。
    pytorch中的model.apply(fn)会递归地将函数fn应用到父模块的每个子模块submodule,也包括model这个父模块自身。
    比如下面的网络例子中。net这个模块有两个子模块,分别为Linear(2,4)和Linear(4,8)。函数首先对Linear(2,4)和Linear(4,8)两个子模块调用init_weights函数,即print(m)打印Linear(2,4)和Linear(4,8)两个子模块。然后再对net模块进行同样的操作。如此完成递归地调用。从而完成model.apply(fn)或者net.apply(fn)。

    1. import torch.nn as nn
    2. @torch.no_grad()
    3. def init_weights(m):
    4. print(m)
    5. net = nn.Sequential(nn.Linear(2,4), nn.Linear(4, 8))
    6. print(net)
    7. print('isinstance torch.nn.Module',isinstance(net,torch.nn.Module))
    8. print(' ')
    9. net.apply(init_weights)

    输出

    1. Sequential(
    2. (0): Linear(in_features=2, out_features=4, bias=True)
    3. (1): Linear(in_features=4, out_features=8, bias=True)
    4. )
    5. isinstance torch.nn.Module True
    6. Linear(in_features=2, out_features=4, bias=True)
    7. Linear(in_features=4, out_features=8, bias=True)
    8. Sequential(
    9. (0): Linear(in_features=2, out_features=4, bias=True)
    10. (1): Linear(in_features=4, out_features=8, bias=True)
    11. )