基础篇
安装PyTorch:
使用nvidia-smi 查看自己CUDA的版本
然后前往官网查看自己对应CUDA版本的安装方式
https://pytorch.org/get-started/locally/
安装好后使用import torch和torch.cuda.is_available()来验证是否安装成功,并且是否成功连接上GPU
安装 jupyter
使用pip install jupyter和pip install jupyterlab进行安装
jupyter Notebook 是以网页的形式打开,可以在网页页面中直接编写代码和运行代码,代码的运行结果也会直接在代码块下显示。比如在编程过程中需要编写说明文档,可以在同一个页面中直接编写,便于及时说明、解释。
而 Jupyter Lab 可以看做是 Jupyter Notebook 的终极进化版,它不但包含了 Jupyter Notebook 所有功能,并且集成了操作终端、打开交互模式、查看 csv 文件及图片等功能。
安装好后使用:
jupyter notebook或者jupyter lab## shift+回车 快速运行
来打开相关网页,打开后点击右上角新建python3,然后输入以下命令进行测试
import torchtorch.__version__
学习NumPy
NumPy 是用于 Python 中科学计算的一个基础包。它提供了一个多维度的数组对象,以及针对数组对象的各种快速操作,例如排序、变换,选择等
首先安装它的包:pip install numpy
numpy的核心在于它的数组:ndarray,全程为N-dimensional array,N 是一个数字,指代维度,例如你常常能听到的 1-D 数组、2-D 数组或者更高维度的数组。
特性1,创建时指定多大,那么它的大小就是多大无法再进行修改
特性2,ndarray中的数组元素类型必须一致
特性3,Ndarray针对数组进行了特殊优化,效率大幅提升,相比python原生列表要快很多
创建ndarray
import numpy as np## asarray是浅拷贝, array是深拷贝ndarray_1 = np.asarray([1,2,3])ndarray_2 = np.asarray([[1, 2, 3], [3, 4, 5]])ndarray_3 = np.array([1,2,3])
ndarray的常用方法
## ndim:数组的轴(维度)ndarray_1.ndimndarray_2.ndim## shape:数组的形状print(ndarray_2.shape)print(ndarray_2.shape[1])## reshape:变换数组形状,变换前与变换后数组的元素个数需要是一样的ndarray_2.reshape((1,6))## size:数组元素的总数print(ndarray_2.size)## dtype:数组内的元素类型(不要意图直接修改dtype来修改数组元素类型)print(ndarray_2.dtype)## 也可以在创建数组时,直接指定数组内元素的类型ndarray_3 = np.asarray([[1, 2], [3, 4]], dtype=float)print(ndarray_3)## astype:转换数组类型(是创建一个新的数组而不是修改以前的数组)ndarray_3.astype('int32')
## np.ones() 与 np.zeros()来创建特殊数组,只需要传入数组形状即可nparray5 = np.ones(shape=(2,3))nparray6 = np.zeros(shape=(3,2))## nparray数组的乘法是乘以所有的成员print(nparray5 * 0.5)## np.arange([start, ]stop, [step, ]dtype=None)来创建一个顺序的数组## 左闭右开区间,不包括endnparray7 = np.arange(0, 10, 1,dtype = float)print(nparray7)print(nparray7.reshape(5, 2))## np.linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None)创建一个数组## 具体就是创建一个从开始数值到结束数值的等差数列。nparray8 = np.linspace(1, 100, 4)print(nparray8)
arrange和linspace十分适合用于做图时的X轴,分别适用于等段长的X轴或者只需要有多少段的情况
import numpy as npimport matplotlib.pyplot as plt## X = np.arange(-50, 51, 2)X = np.linspace(-50, 51, 20)Y = X ** 2plt.plot(X, Y, color='blue')plt.legend()plt.show()
数组的轴
## 数组的轴interest_score = np.random.randint(10, size=(4, 3))print(interest_score)##那么在使用np.sum()、np.max()时候就要先明白0轴,1轴位于那列print(np.sum(interest_score, axis=0))print(np.average(interest_score, axis=0))

二维数组还是比较好理解的,那多维数据该怎么办呢?你有没有发现,其实当 axis=i 时,就是按照第 i 个轴的方向进行计算的,或者可以理解为第 i 个轴的数据将会被折叠或聚合到一起。形状为 (a, b, c) 的数组,沿着 0 轴聚合后,形状变为 (b, c);沿着 1 轴聚合后,形状变为 (a, c);沿着 2 轴聚合后,形状变为 (a, b);更高维数组以此类推。
同理min、mean、argmin(求最小值下标)、argmax(求最大值下标)都有这方便的思想
如何读取一张图片(OpenCV和Pillow)
Pillow读取出来是一个数组,所以需要NumPy转换一下
from PIL import Imageimport numpy as npim = Image.open('jk.jpg')im.sizeim_pillow = np.asarray(im)im_pillow.shape
而OpenCV读取出来本身就是NumPy格式
import cv2im_cv2 = cv2.imread('jk.jpg')type(im_cv2)输出:numpy.ndarrayim_cv2.shape输出:(116, 318, 3)
因为一般的图片都是三信道,分别为RGB,所以数组的最后一位为3
值得注意的是Pillow读取时是按照RGB的顺序读取的,而OpenCV是按照BGR读取的
NumPy中可以使用:来获得数据,“:”代表全部选中的意思
im_pillow[:, :, 0]
所以这句话的意思是取第0层的所有数据
同理可得
im_pillow_c1 = im_pillow[:, :, 0]im_pillow_c2 = im_pillow[:, :, 1]im_pillow_c3 = im_pillow[:, :, 2]
NumPy 数组为我们提供了 np.concatenate((a1, a2, …), axis=0) 方法进行数组拼接,但合并的前提是数组的维度相同才能进行合并
如何修改维度
方法1:np.newaxis配合np.concatenate
np.newaxis 让数组增加一个维度,然后在使用concatenate进行拼接
im_pillow_c1 = im_pillow_c1[:, :, np.newaxis]im_pillow_c1.shape输出:(116, 318, 1)im_pillow_c1_3ch = np.concatenate((im_pillow_c1, zeros), axis=2)im_pillow_c1_3ch.shape
方法二:直接赋值
im_pillow_c2_3ch = np.zeros(im_pillow.shape)im_pillow_c2_3ch[:,:,1] = im_pillow_c2im_pillow_c3_3ch = np.zeros(im_pillow.shape)im_pillow_c3_3ch[:,:,2] = im_pillow_c3
zeros再配合原本数组的shape则可以快速创建为0的数组,再替换为原有的数据即可
最后再进行打印即可
from matplotlib import pyplot as pltplt.subplot(2, 2, 1)plt.title('Origin Image')plt.imshow(im_pillow)plt.axis('off')plt.subplot(2, 2, 2)plt.title('Red Channel')plt.imshow(im_pillow_c1_3ch.astype(np.uint8))plt.axis('off')plt.subplot(2, 2, 3)plt.title('Green Channel')plt.imshow(im_pillow_c2_3ch.astype(np.uint8))plt.axis('off')plt.subplot(2, 2, 4)plt.title('Blue Channel')plt.imshow(im_pillow_c3_3ch.astype(np.uint8))plt.axis('off')plt.savefig('./rgb_pillow.png', dpi=150)

