`import torch` is usually the first line of any PyTorch program. Executing it actually runs the package's `__init__.py`.
This article walks through what `torch/__init__.py` does, using the v1.7.0 source.
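This minimal sketch shows that importing a package executes its `__init__.py`, and only on the first import. `demo_pkg` is a hypothetical package written to a temporary directory. It has nothing to do with PyTorch.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Create a throwaway package whose __init__.py records that it was executed.
tmp_dir = tempfile.mkdtemp()
pkg_dir = Path(tmp_dir) / "demo_pkg"  # hypothetical package name
pkg_dir.mkdir()
(pkg_dir / "__init__.py").write_text("INIT_RAN = True\n")

sys.path.insert(0, tmp_dir)
importlib.invalidate_caches()  # refresh finder caches before importing fresh files

# importlib.import_module("demo_pkg") is the functional form of `import demo_pkg`.
first = importlib.import_module("demo_pkg")   # executes demo_pkg/__init__.py
second = importlib.import_module("demo_pkg")  # returned from sys.modules, no re-run
```

The module object is cached in `sys.modules`. As a result, repeated `import torch` statements in a program run `torch/__init__.py` only once.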
Load extension modules
Source: https://github.com/pytorch/pytorch/blob/v1.7.0/torch/__init__.py#L41

```python
if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv('TORCH_USE_RTLD_GLOBAL')) and \
        platform.system() != 'Windows':
    # Do it the hard way. You might want to load libtorch with RTLD_GLOBAL in a
    # few circumstances:
    #
    #   1. You're in a build environment (e.g., fbcode) where
    #      libtorch_global_deps is not available, but you still need
    #      to get mkl to link in with RTLD_GLOBAL or it will just
    #      not work.
    #
    #   2. You're trying to run PyTorch under UBSAN and you need
    #      to ensure that only one copy of libtorch is loaded, so
    #      vptr checks work properly
    #
    # If you're using this setting, you must verify that all the libraries
    # you load consistently use the same libstdc++, or you may have
    # mysterious segfaults.
    #
    import os as _dl_flags
    if not hasattr(_dl_flags, 'RTLD_GLOBAL') or not hasattr(_dl_flags, 'RTLD_LAZY'):
        try:
            # next try if DLFCN exists
            import DLFCN as _dl_flags  # type: ignore
        except ImportError:
            # as a last attempt, use compile-time constants
            import torch._dl as _dl_flags  # type: ignore
    old_flags = sys.getdlopenflags()
    sys.setdlopenflags(_dl_flags.RTLD_GLOBAL | _dl_flags.RTLD_LAZY)
    from torch._C import *
    sys.setdlopenflags(old_flags)
    del old_flags
    del _dl_flags

else:
    # Easy way. You want this most of the time, because it will prevent
    # C++ symbols from libtorch clobbering C++ symbols from other
    # libraries, leading to mysterious segfaults.
    #
    # If building in an environment where libtorch_global_deps isn't available
    # like parts of fbsource, but where RTLD_GLOBAL causes segfaults, you will
    # want USE_RTLD_GLOBAL_WITH_LIBTORCH = False and USE_GLOBAL_DEPS = False
    #
    # See Note [Global dependencies]
    if USE_GLOBAL_DEPS:
        _load_global_deps()
    from torch._C import *

# Appease the type checker; ordinarily this binding is inserted by the
# torch._C module initialization code in C
if TYPE_CHECKING:
    import torch._C as _C

# Check to see if we can load C extensions, and if not provide some guidance
# on what the problem might be.
try:
    # _initExtension is chosen (arbitrarily) as a sentinel.
    from torch._C import _initExtension
except ImportError:
    import torch._C as _C_for_compiled_check
    ...
    raise  # If __file__ is not None the cause is unknown, so just re-raise.
```
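The "hard way" branch saves the interpreter's `dlopen()` flags and switches to `RTLD_GLOBAL | RTLD_LAZY` for `from torch._C import *`. It then restores the old flags. The same save/set/restore pattern can be wrapped in a context manager. This is a minimal, Unix-only sketch: `dlopen_flags` is a hypothetical helper, not a PyTorch API, and `some_extension` in the comment is a placeholder module name.

```python
import os
import sys
from contextlib import contextmanager


@contextmanager
def dlopen_flags(flags):
    """Temporarily change the dlopen() flags used when importing C extensions.

    Unix-only: sys.getdlopenflags / sys.setdlopenflags do not exist on Windows.
    """
    old_flags = sys.getdlopenflags()
    sys.setdlopenflags(flags)
    try:
        yield
    finally:
        # Restore the previous flags even if the import inside the block raises.
        sys.setdlopenflags(old_flags)


flags_before = sys.getdlopenflags()
with dlopen_flags(os.RTLD_GLOBAL | os.RTLD_LAZY):
    # A real caller would do `import some_extension` here (hypothetical name).
    flags_inside = sys.getdlopenflags()
flags_after = sys.getdlopenflags()
```

PyTorch restores the flags on the line after the import. The `finally` clause here also restores them when the import itself fails.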
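`torch._C` is PyTorch's C/C++ extension module. The `from torch._C import *` lines above pull its bindings into the `torch` namespace. It will be covered in detail in a later article.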
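The `try`/`except` around `from torch._C import _initExtension` uses one known symbol as a sentinel. If that symbol cannot be imported, the C extension is treated as missing or broken. A generic version of that check might look like this sketch. `require_extension` is a hypothetical helper, and the standard-library `math` module stands in for a compiled extension.

```python
import importlib
from types import ModuleType


def require_extension(module_name: str, sentinel: str) -> ModuleType:
    """Import module_name and verify that it exposes the attribute `sentinel`."""
    try:
        module = importlib.import_module(module_name)
    except ImportError as exc:
        raise ImportError(
            f"could not import {module_name!r}; was the extension built?"
        ) from exc
    if not hasattr(module, sentinel):
        # The module loaded, but not the one we expected (e.g. a stale build).
        raise ImportError(
            f"{module_name!r} imported, but sentinel {sentinel!r} is missing"
        )
    return module


math_mod = require_extension("math", "sqrt")  # math is implemented in C in CPython

try:
    require_extension("math", "no_such_symbol")
    missing_detected = False
except ImportError:
    missing_detected = True
```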
Define basic utilities
Source: https://github.com/pytorch/pytorch/blob/v1.7.0/torch/__init__.py#L226

```python
################################################################################
# Define basic utilities
################################################################################

def typename(o):
    """Return the type name of o."""
    ...

def is_tensor(obj):
    """Return True if obj is a PyTorch tensor."""
    ...

def is_storage(obj):
    """Return True if obj is a PyTorch storage object."""
    ...

def set_default_tensor_type(t):
    """Set the default tensor type."""
    ...

def set_default_dtype(d):
    ...

def set_deterministic(d):
    """Set whether PyTorch operations must be deterministic."""
    ...

def is_deterministic():
    """Return True if the global deterministic flag is set."""
    ...
```
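The stubs above only show signatures. As an illustration of what a helper like `typename` must do for objects that are not tensors, here is a hypothetical, torch-free sketch. It builds a dotted name from `__module__` and `__qualname__`. It is a simplification, not a copy of the v1.7.0 implementation, which additionally special-cases tensors.

```python
import fractions


def typename_sketch(o):
    """Hypothetical sketch: return a 'module.QualName' string describing o."""
    module = ""
    # Omit the module prefix for builtins such as int or str.
    mod_name = getattr(o, "__module__", None)
    if mod_name not in (None, "builtins"):
        module = mod_name + "."

    if hasattr(o, "__qualname__"):      # classes and functions
        class_name = o.__qualname__
    elif hasattr(o, "__name__"):
        class_name = o.__name__
    else:                               # plain instances: fall back to their class
        class_name = type(o).__name__
    return module + class_name


print(typename_sketch(3))                          # int
print(typename_sketch(fractions.Fraction(1, 2)))   # fractions.Fraction
```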
Define Storage and Tensor classes
Source: https://github.com/pytorch/pytorch/blob/v1.7.0/torch/__init__.py#L408

```python
################################################################################
# Define Storage and Tensor classes
################################################################################

class DoubleStorage(_C.DoubleStorageBase, _StorageBase):
    pass

class FloatStorage(_C.FloatStorageBase, _StorageBase):
    pass

...
```
