一、PyTorch小土堆快速入门
:::info
- PyCharm
- 控制台
- Jupyter Notebook
:::
1、PyTorch数据加载
a、Dataset实战
```python from torch.utils.data import Dataset from PIL import Image import os
class ReadData(Dataset): #ReadData继承的是Dataset类,很重要 def init(self, root_dir, label_dir): self.root_dir = root_dir self.label_dir = label_dir self.path = os.path.join(root_dir, label_dir) #生成路径 self.img_path = os.listdir(self.path) #文件名称列表
def __getitem__(self, item):
img_name = self.img_path[item]
img_item_path = os.path.join(self.path, img_name)
img = Image.open(img_item_path)
label = self.label_dir
return img, label
def __len__(self):
return len(self.img_path)
if __name__ == "__main__":
    ant_root_dir = "data_set/train"
    ant_label_dir = "ants"
    ants_dataset = ReadData(ant_root_dir, ant_label_dir)

    bee_root_dir = "data_set/train"
    bee_label_dir = "bees"
    bees_dataset = ReadData(bee_root_dir, bee_label_dir)

    # Dataset.__add__ returns a ConcatDataset: the ant images come first
    # (indices 0 .. len(ants_dataset) - 1), followed by the bee images.
    dataset = ants_dataset + bees_dataset
    print(len(dataset), len(ants_dataset), len(bees_dataset))

    # Show the two images on either side of the ants/bees boundary. Deriving the
    # indices from len(ants_dataset) keeps this correct for any folder size.
    img_ant, ant_label = dataset[len(ants_dataset) - 1]  # last ant
    img_ant.show()
    img_bee, bee_label = dataset[len(ants_dataset)]  # first bee
    img_bee.show()
    print(ant_label, bee_label)
```

<a name="mTB2y"></a>
# 2022.5.22,第一周
```python
from torch.utils.tensorboard import SummaryWriter
# Log the scalar curve y = 2x for steps 0..99 under the tag "y=2x".
# Using the writer as a context manager guarantees it is closed (and its
# events flushed to ./logs) even if logging raises part-way through.
with SummaryWriter("logs") as writer:
    for i in range(100):
        writer.add_scalar("y=2x", 2 * i, i)  # (tag, scalar_value, global_step)
# View the logs with: tensorboard --logdir=logs
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from PIL import Image
# Log one training image to TensorBoard under the tag "train" at global step 1.
image_path = "data_set/train/ants/24335309_c5ea483bb8.jpg"

# Open the file in a context manager so the handle is released. Convert to RGB
# so the array is always (height, width, 3); a grayscale JPEG would otherwise
# give a 2-D array that does not match dataformats="HWC".
with Image.open(image_path) as img:
    img_array = np.array(img.convert("RGB"))  # convert the PIL image to a NumPy array
print(img_array.shape, type(img_array))

with SummaryWriter("logs") as writer:
    # add_image assumes CHW by default; an array built from a PIL image is HWC.
    writer.add_image("train", img_array, 1, dataformats="HWC")  # 1 is the global step