PyTorch入门:探索Tensor的基本操作(2)
torch.cat()
a = torch.zeros((2,4)) b = torch.ones((2,4)) out = torch.cat((a,b), dim=1) print(out) 运行结果如下: tensor([[0., 0., 0., 0., 1., 1., 1., 1.], [0., 0., 0., 0., 1., 1., 1., 1.]])
torch.stack():在新的维度进行拼接
a = torch.linspace(1,12,12).view(4,3) b = torch.linspace(12,24,12).view(4,3) print(a,b) out = torch.stack((a,b), dim=1) print(out) print(out.shape) print(out[:,0,:]) #a print(out[:,1,:]) #b 运行结果如下: tensor([[ 1., 2., 3.], [ 4., 5., 6.], [ 7., 8., 9.], [10., 11., 12.]]) tensor([[12.0000, 13.0909, 14.1818], [15.2727, 16.3636, 17.4545], [18.5455, 19.6364, 20.7273], [21.8182, 22.9091, 24.0000]]) tensor([[[ 1.0000, 2.0000, 3.0000], [12.0000, 13.0909, 14.1818]], [[ 4.0000, 5.0000, 6.0000], [15.2727, 16.3636, 17.4545]], [[ 7.0000, 8.0000, 9.0000], [18.5455, 19.6364, 20.7273]], [[10.0000, 11.0000, 12.0000], [21.8182, 22.9091, 24.0000]]]) torch.Size([4, 2, 3]) tensor([[ 1., 2., 3.], [ 4., 5., 6.], [ 7., 8., 9.], [10., 11., 12.]]) tensor([[12.0000, 13.0909, 14.1818], [15.2727, 16.3636, 17.4545], [18.5455, 19.6364, 20.7273], [21.8182, 22.9091, 24.0000]])
torch.chunk():无法平均的话最后一块小于平均值
a = torch.rand(3,4) print(a) out = torch.chunk(a, 2, dim = 0) #零维度上平均切成两块 print(out) 运行结果如下: tensor([[0.8622, 0.8813, 0.8506, 0.4134], [0.0779, 0.7562, 0.1696, 0.2435], [0.4689, 0.9242, 0.3887, 0.0454]]) (tensor([[0.8622, 0.8813, 0.8506, 0.4134], [0.0779, 0.7562, 0.1696, 0.2435]]), tensor([[0.4689, 0.9242, 0.3887, 0.0454]]))
torch.split():两种切分方式如下
a = torch.rand(5,4) out = torch.split(a, 2, dim=0) # 2组切一刀 print(out) 结果如下: (tensor([[0.5005, 0.7890, 0.9882, 0.3623], [0.4507, 0.7575, 0.3662, 0.1380]]), tensor([[0.5239, 0.0081, 0.1808, 0.1356], [0.5778, 0.3228, 0.8367, 0.5824]]), tensor([[0.2967, 0.7542, 0.2384, 0.3248]])) out = torch.split(a, [1,2,3], dim=0) #按照list清单来切 结果如下: (tensor([[0.3039, 0.0142, 0.6083, 0.0831]]), tensor([[0.2401, 0.0204, 0.2778, 0.6313]]), tensor([[0.1951, 0.9309, 0.3008, 0.0065], [0.2678, 0.8499, 0.7622, 0.3273], [0.1647, 0.1284, 0.4363, 0.3052]]))
以下是几个易混淆的实例
out = torch.unsqueeze(a,0) #-1表示把维度增加在最后面 out = torch.unbind(a, dim=2) #去除某个维度有点类似于在这个维度上切块操作 print(torch.flip(a, dims=[1,2])) #可以同时对多个维度进行翻转 print(torch.rot90(a,2,dims=[0,1])) #传入2表示逆时针2*90°的旋转,若负数表示顺时针的90°旋转
以下是把一张图片数据转换的实例:
im_data = cv2.imread("test.jpg") print(im_data) cv2.imshow("1", im_data) out = torch.from_numpy(im_data) #把numpy转换成Tensor数据 print(out) out = torch.flip(out, dims=[0]) #翻转0维度即图片的高 data = out.numpy() #把tensor转换成numpy cv2.imshow("2", data) cv2.waitKey(0)
文章版权声明:除非注明,否则均为主机测评原创文章,转载或复制请以超链接形式并注明出处。