import torch
通常是我们编写PyTorch代码的第一行，这行代码实际上会执行torch包中的`__init__.py`文件。
本文将介绍`torch/__init__.py`都做了什么（见`torch/__init__.py`源码）。
加载扩展模块
源码：[torch/\_\_init\_\_.py#L41（v1.7.0）](https://github.com/pytorch/pytorch/blob/v1.7.0/torch/__init__.py#L41)
# Load the torch._C extension module. The branch below chooses how its shared
# libraries are dlopen'ed. (os, sys, platform, USE_RTLD_GLOBAL_WITH_LIBTORCH,
# USE_GLOBAL_DEPS and _load_global_deps are defined earlier in the file,
# outside this excerpt.)
if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv('TORCH_USE_RTLD_GLOBAL')) and \
        platform.system() != 'Windows':
    # Taken when USE_RTLD_GLOBAL_WITH_LIBTORCH is truthy or the
    # TORCH_USE_RTLD_GLOBAL environment variable is set to a non-empty value,
    # and the platform is not Windows.
    #
    # Do it the hard way. You might want to load libtorch with RTLD_GLOBAL in a
    # few circumstances:
    #
    # 1. You're in a build environment (e.g., fbcode) where
    # libtorch_global_deps is not available, but you still need
    # to get mkl to link in with RTLD_GLOBAL or it will just
    # not work.
    #
    # 2. You're trying to run PyTorch under UBSAN and you need
    # to ensure that only one copy of libtorch is loaded, so
    # vptr checks work properly
    #
    # If you're using this setting, you must verify that all the libraries
    # you load consistently use the same libstdc++, or you may have
    # mysterious segfaults.
    #
    # Find the RTLD_GLOBAL / RTLD_LAZY constants. Prefer the os module, then
    # the DLFCN module, then the constants provided by torch._dl.
    import os as _dl_flags
    if not hasattr(_dl_flags, 'RTLD_GLOBAL') or not hasattr(_dl_flags, 'RTLD_LAZY'):
        try:
            # next try if DLFCN exists
            import DLFCN as _dl_flags  # type: ignore
        except ImportError:
            # as a last attempt, use compile-time constants
            import torch._dl as _dl_flags  # type: ignore
    # Temporarily switch the interpreter's dlopen() flags to
    # RTLD_GLOBAL | RTLD_LAZY while importing torch._C, then restore the
    # previous flags.
    # NOTE(review): there is no try/finally, so if the import below raises,
    # the modified dlopen flags are left in place.
    old_flags = sys.getdlopenflags()
    sys.setdlopenflags(_dl_flags.RTLD_GLOBAL | _dl_flags.RTLD_LAZY)
    from torch._C import *
    sys.setdlopenflags(old_flags)
    # Drop the temporary names so they do not remain in the module namespace.
    del old_flags
    del _dl_flags
else:
    # Easy way. You want this most of the time, because it will prevent
    # C++ symbols from libtorch clobbering C++ symbols from other
    # libraries, leading to mysterious segfaults.
    #
    # If building in an environment where libtorch_global_deps isn't available
    # like parts of fbsource, but where RTLD_GLOBAL causes segfaults, you will
    # want USE_RTLD_GLOBAL_WITH_LIBTORCH = False and USE_GLOBAL_DEPS = False
    #
    # See Note [Global dependencies]
    #
    # Optionally preload the global dependencies, then import torch._C with
    # the interpreter's current (unchanged) dlopen flags.
    if USE_GLOBAL_DEPS:
        _load_global_deps()
    from torch._C import *
# Appease the type checker; ordinarily this binding is inserted by the
# torch._C module initialization code in C
# typing.TYPE_CHECKING is False at runtime, so this import is seen only by
# static type checkers.
if TYPE_CHECKING:
    import torch._C as _C
# Check to see if we can load C extensions, and if not provide some guidance
# on what the problem might be.
try:
    # _initExtension is chosen (arbitrarily) as a sentinel.
    from torch._C import _initExtension
except ImportError:
    import torch._C as _C_for_compiled_check
    # Elided in this excerpt. Judging by the re-raise comment below, this is
    # presumably the code that raises a more descriptive error when the
    # module's __file__ is None. Confirm against upstream torch/__init__.py.
    ...
    raise  # If __file__ is not None the cause is unknown, so just re-raise.
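上面的代码负责加载扩展模块`torch._C`。在非Windows平台上，若`USE_RTLD_GLOBAL_WITH_LIBTORCH`为真，或设置了非空的环境变量`TORCH_USE_RTLD_GLOBAL`，则先把dlopen标志临时设为`RTLD_GLOBAL | RTLD_LAZY`，再导入`torch._C`。否则（默认情况），在`USE_GLOBAL_DEPS`为真时先调用`_load_global_deps()`，然后直接导入。最后通过导入`_initExtension`检查扩展模块是否加载成功。`torch._C`是PyTorch的C/C++扩展模块，将在后续文章中详细介绍。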
定义基本工具函数
源码：[torch/\_\_init\_\_.py#L226（v1.7.0）](https://github.com/pytorch/pytorch/blob/v1.7.0/torch/__init__.py#L226)
################################################################################
# Define basic utilities
################################################################################
def typename(o):
    """Return the type of ``o``.

    The implementation is elided (``...``) in this excerpt.
    """
    ...
def is_tensor(obj):
    """Check whether ``obj`` is a PyTorch tensor.

    The implementation is elided (``...``) in this excerpt.
    """
    ...
def is_storage(obj):
    """Check whether ``obj`` is a PyTorch storage object.

    The implementation is elided (``...``) in this excerpt.
    """
    ...
def set_default_tensor_type(t):
    """Set the default tensor type.

    ``t`` presumably names or is the tensor type to use (confirm upstream).
    The implementation is elided (``...``) in this excerpt.
    """
    ...
def set_default_dtype(d):
    """Set the default dtype.

    NOTE(review): inferred from the function name only; presumably makes
    ``d`` the global default dtype. Confirm against upstream.
    The implementation is elided (``...``) in this excerpt.
    """
    ...
def set_deterministic(d):
    """Set whether PyTorch operations must be deterministic.

    ``d`` is presumably a boolean flag (confirm upstream).
    The implementation is elided (``...``) in this excerpt.
    """
    ...
def is_deterministic():
    """Check whether the global deterministic flag is True.

    The implementation is elided (``...``) in this excerpt.
    """
    ...
定义Storage和Tensor类
源码：[torch/\_\_init\_\_.py#L408（v1.7.0）](https://github.com/pytorch/pytorch/blob/v1.7.0/torch/__init__.py#L408)
################################################################################
# Define Storage and Tensor classes
################################################################################
# Each concrete Storage class combines a base class from the torch._C
# extension (e.g. _C.DoubleStorageBase) with _StorageBase. _StorageBase is
# defined outside this excerpt; it presumably provides shared Python-side
# storage helpers (confirm). The _C base is listed first, so it precedes
# _StorageBase in the method resolution order.
class DoubleStorage(_C.DoubleStorageBase, _StorageBase):
    pass


class FloatStorage(_C.FloatStorageBase, _StorageBase):
    pass

# The remaining definitions of this section are elided in this excerpt.
...