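import torch is usually the first line of any PyTorch program. Under the hood, this statement executes the package's __init__.py.
This article walks through what torch/__init__.py does (see torch/__init__.py in the v1.7.0 source).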
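The claim that importing a package runs its __init__.py can be checked without PyTorch. The sketch below builds a throwaway package in a temporary directory and imports it twice. The package name demo_pkg is hypothetical. The first import executes __init__.py. The second import returns the cached module object from sys.modules without re-running the file, which is also why the work done in torch/__init__.py happens only once per process.

```python
import importlib
import sys
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as root:
    # Build a throwaway package: demo_pkg/__init__.py (hypothetical name)
    pkg_dir = Path(root) / "demo_pkg"
    pkg_dir.mkdir()
    (pkg_dir / "__init__.py").write_text(
        "print('executing demo_pkg/__init__.py')\n"
        "LOADED = True\n"
    )
    sys.path.insert(0, root)
    importlib.invalidate_caches()
    try:
        # First import: Python locates the package and runs __init__.py
        first = importlib.import_module("demo_pkg")
        # Second import: served from sys.modules; __init__.py is not re-run
        second = importlib.import_module("demo_pkg")
    finally:
        sys.path.remove(root)

same = first is second
print(first.LOADED, same)
```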

Load extension modules

```python
# https://github.com/pytorch/pytorch/blob/v1.7.0/torch/__init__.py#L41
if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv('TORCH_USE_RTLD_GLOBAL')) and \
        platform.system() != 'Windows':
    # Do it the hard way.  You might want to load libtorch with RTLD_GLOBAL in a
    # few circumstances:
    #
    #   1. You're in a build environment (e.g., fbcode) where
    #      libtorch_global_deps is not available, but you still need
    #      to get mkl to link in with RTLD_GLOBAL or it will just
    #      not work.
    #
    #   2. You're trying to run PyTorch under UBSAN and you need
    #      to ensure that only one copy of libtorch is loaded, so
    #      vptr checks work properly
    #
    # If you're using this setting, you must verify that all the libraries
    # you load consistently use the same libstdc++, or you may have
    # mysterious segfaults.
    #
    import os as _dl_flags
    if not hasattr(_dl_flags, 'RTLD_GLOBAL') or not hasattr(_dl_flags, 'RTLD_LAZY'):
        try:
            # next try if DLFCN exists
            import DLFCN as _dl_flags  # type: ignore
        except ImportError:
            # as a last attempt, use compile-time constants
            import torch._dl as _dl_flags  # type: ignore
    old_flags = sys.getdlopenflags()
    sys.setdlopenflags(_dl_flags.RTLD_GLOBAL | _dl_flags.RTLD_LAZY)
    from torch._C import *
    sys.setdlopenflags(old_flags)
    del old_flags
    del _dl_flags

else:
    # Easy way.  You want this most of the time, because it will prevent
    # C++ symbols from libtorch clobbering C++ symbols from other
    # libraries, leading to mysterious segfaults.
    #
    # If building in an environment where libtorch_global_deps isn't available
    # like parts of fbsource, but where RTLD_GLOBAL causes segfaults, you will
    # want USE_RTLD_GLOBAL_WITH_LIBTORCH = False and USE_GLOBAL_DEPS = False
    #
    # See Note [Global dependencies]
    if USE_GLOBAL_DEPS:
        _load_global_deps()
    from torch._C import *

# Appease the type checker; ordinarily this binding is inserted by the
# torch._C module initialization code in C
if TYPE_CHECKING:
    import torch._C as _C

# Check to see if we can load C extensions, and if not provide some guidance
# on what the problem might be.
try:
    # _initExtension is chosen (arbitrarily) as a sentinel.
    from torch._C import _initExtension
except ImportError:
    import torch._C as _C_for_compiled_check
    ...
    raise  # If __file__ is not None the cause is unknown, so just re-raise.
```

torch._C is PyTorch's C/C++ extension module; it will be covered in detail in a later article.
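The two branches of the loading code differ mainly in which dlopen flags are in effect while torch._C is imported. The save/set/restore pattern around from torch._C import * can be sketched as a small context manager, assuming a Unix-like system (sys.getdlopenflags and sys.setdlopenflags do not exist on Windows). The helper name dlopen_flags is hypothetical. Unlike the original code, it restores the old flags in a finally clause, so a failed import does not leave RTLD_GLOBAL in effect for later imports.

```python
import os
import sys
from contextlib import contextmanager


@contextmanager
def dlopen_flags(flags):
    """Temporarily set the flags Python passes to dlopen() for C extensions.

    Hypothetical helper; sys.getdlopenflags/setdlopenflags are Unix-only.
    """
    old_flags = sys.getdlopenflags()
    sys.setdlopenflags(flags)
    try:
        yield
    finally:
        # Restore the previous flags even if the import inside the block raised
        sys.setdlopenflags(old_flags)


if hasattr(sys, "setdlopenflags"):
    before = sys.getdlopenflags()
    with dlopen_flags(os.RTLD_GLOBAL | os.RTLD_LAZY):
        # A C extension imported here (e.g. `from torch._C import *`)
        # would be loaded with RTLD_GLOBAL | RTLD_LAZY.
        inside = sys.getdlopenflags()
    after = sys.getdlopenflags()
    print(inside == os.RTLD_GLOBAL | os.RTLD_LAZY, after == before)
```

RTLD_GLOBAL makes the loaded library's symbols available for resolving libraries loaded afterwards, which is exactly why the source comments warn that it can let libtorch's C++ symbols clash with those of other libraries.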

Define basic utilities

```python
# https://github.com/pytorch/pytorch/blob/v1.7.0/torch/__init__.py#L226
################################################################################
# Define basic utilities
################################################################################

def typename(o):
    """Return the type name of o."""
    ...

def is_tensor(obj):
    """Return True if obj is a PyTorch tensor."""
    ...

def is_storage(obj):
    """Return True if obj is a PyTorch storage object."""
    ...

def set_default_tensor_type(t):
    """Set the default torch.Tensor type."""
    ...

def set_default_dtype(d):
    """Set the default floating-point dtype."""
    ...

def set_deterministic(d):
    """Set whether PyTorch operations must use deterministic algorithms."""
    ...

def is_deterministic():
    """Return True if the global deterministic flag is turned on."""
    ...
```
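To give a feel for what these helpers do, here is a torch-free sketch of typename for non-tensor objects. It roughly follows the v1.7.0 logic as we understand it: prefix the object's __module__ (unless it is a builtin) to its __qualname__, falling back to __name__ or the class name. The tensor branch, which returns o.type() for a torch.Tensor, is deliberately omitted, so treat this as an illustration rather than the exact source.

```python
import collections


def typename(o):
    """Return a readable type name for o (non-tensor branch only)."""
    # The real torch.typename first checks isinstance(o, torch.Tensor)
    # and returns o.type(); that branch is left out of this sketch.
    module = ''
    if getattr(o, '__module__', None) not in (None, 'builtins', '__builtin__'):
        module = o.__module__ + '.'
    if hasattr(o, '__qualname__'):
        class_name = o.__qualname__
    elif hasattr(o, '__name__'):
        class_name = o.__name__
    else:
        # Plain instances usually expose neither, so use their class name
        class_name = o.__class__.__name__
    return module + class_name


print(typename(3))                        # int
print(typename(collections.OrderedDict))  # collections.OrderedDict
```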

Define Storage and Tensor classes

```python
# https://github.com/pytorch/pytorch/blob/v1.7.0/torch/__init__.py#L408
################################################################################
# Define Storage and Tensor classes
################################################################################

class DoubleStorage(_C.DoubleStorageBase, _StorageBase):
    pass

class FloatStorage(_C.FloatStorageBase, _StorageBase):
    pass

...
```

Initialize extensions