pytorch学习(十三)torch维度变换
包含了flatten、view reshape transpose permute squeeze unsqueeze cat stack;在琢磨一遍之后就比较好理解了。
(图片来源网络,侵删)
1.代码
import torch import numpy as np #生成一个2组3行4列的数据 data = torch.randn((2,3,4)) print("data.shape:\n",data.shape) print("data:\n",data) #对数据进行压平,行以及行以后的维度全部展平 #具体的就是一行一行的排列起来 flatten_data = torch.flatten(data,start_dim=1) print("flatten_data:\n",flatten_data) #可以这样想,原来的数据展成一行,然后每4个数据构成一行,每两行构成一组,最后构成三组 view_data = data.view((3,2,4)) print("view_data:\n",view_data) #和view基本一致的用法和功能 reshape_data = torch.reshape(data,(3,2,4)) print("reshape_data:\n",reshape_data) #transpose每次只能转换两个维度 #表示沿着组的方向, 变成了行,沿组方向也就是[0,0,:] 和[1,0,:]变成了一个组 transpose_data1 = torch.transpose(data,1,0) print("transpose_data1:\n",transpose_data1) #表示沿组的方向变成了列,也就是[:,0,0]变成了一行数据,结合着[:,1,0],[:,2,0]变成了一组数据 transpose_data2 = torch.transpose(data,2,0) print("transpose_data2:\n",transpose_data2) #这个可以同时变换多个维度 #这个要怎么理解呢? #x轴变成了y轴,y轴变成了z轴,z轴变成了x轴 #无论怎么变化,记清楚 x方向是一行,y方向是一列,z方向是一组就可以了 permute_data = data.permute(2,0,1) print("permute_data:\n",permute_data) #在指定的维度上升一个维度 squeeze_data = torch.squeeze(data,dim=0) print("squeeze_data:\n",squeeze_data) #去掉维度为1的维度 unsqueeze_data = torch.unsqueeze(squeeze_data,dim=0) print("unsqueeze_data:\n",unsqueeze_data) # x1 = torch.tensor([[11, 21, 31], [21, 31, 41]], dtype=torch.int) x1.shape # torch.Size([2, 3]) print(x1) # x2 x2 = torch.tensor([[12, 22, 32], [22, 32, 42]], dtype=torch.int) x2.shape # torch.Size([2, 3]) print(x2) x3 =torch.cat((x1,x2),0) #按维数0(行)拼接 print("按行拼接:\n",x3) x4 =torch.cat((x1,x2),1) #按维数0(行)拼接 print("按列拼接:\n",x4) #0的时候是升维度 #非0的时候就是升维度后,0和 设置到dim中,其他的按顺序前移动一个 #比如 0 1 2 现在dim=1 那么0跑到1位置,1往前移动一个,变成了 1 0 2 #比如 0 1 2 现在dim=2, 那么0跑到了2位置,1 2 顺势前移动一个位置,变成了 1 2 0 a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]]) c = torch.stack([a, b], 0) d = torch.stack([a, b], 1) #看来就是把 e = torch.stack([a, b], 2) print("升维度拼接:\n",c) print("拼接-类比与c.permute(1,0,2):\n",d) print("拼接-类似于c.permute(1,2,0):\n",e) print("c11",c.permute(1,0,2)) print("c11",c.permute(1,2,0))
文章版权声明:除非注明,否则均为主机测评原创文章,转载或复制请以超链接形式并注明出处。