mindspore打卡第几天 DDPM 之Unet 网络解析markdown版本
mindspore打卡第几天 DDPM 之Unet 网络解析markdown版本
A:
为啥DDPM的unet网络的下采样这部分的channel是从20 32 64 128这样上升的?从U形结构看不应该是下降的
{Block1 --> block2 --> Res(attn)-- >dowmsample}×3
B:
他是在weight和hight上是下降的,通道数是上升
在上采样部分反过来,weight和hight变大,通道数最后回到3
条件U-Net
我们已经定义了所有的构建块(位置嵌入、ResNet/ConvNeXT块、Attention和组归一化),现在需要定义整个神经网络了。请记住,网络 ϵ θ ( x t , t ) \mathbf{\epsilon}_\theta(\mathbf{x}_t, t) ϵθ(xt,t) 的工作是接收一批噪声图像+噪声水平,并输出添加到输入中的噪声。
更具体的:
网络获取了一批(batch_size, num_channels, height, width)形状的噪声图像和一批(batch_size, 1)形状的噪音水平作为输入,并返回(batch_size, num_channels, height, width)形状的张量。
网络构建过程如下:
-
首先,将卷积层应用于噪声图像批上,并计算噪声水平的位置
-
接下来,应用一系列下采样级。每个下采样阶段由2个ResNet/ConvNeXT块 + groupnorm + attention + 残差连接 + 一个下采样操作组成
-
在网络的中间,再次应用ResNet或ConvNeXT块,并与attention交织
-
接下来,应用一系列上采样级。每个上采样级由2个ResNet/ConvNeXT块+ groupnorm + attention + 残差连接 + 一个上采样操作组成
-
最后,应用ResNet/ConvNeXT块,然后应用卷积层
最终,神经网络将层堆叠起来,就像它们是乐高积木一样(但重要的是了解它们是如何工作的)。
class Unet(nn.Cell): def __init__( self, dim, init_dim=None, out_dim=None, dim_mults=(1, 2, 4, 8), channels=3, with_time_emb=True, convnext_mult=2, ): super().__init__() self.channels = channels init_dim = default(init_dim, dim // 3 * 2) self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3, pad_mode="pad", has_bias=True) dims = [init_dim, *map(lambda m: dim * m, dim_mults)] in_out = list(zip(dims[:-1], dims[1:])) block_klass = partial(ConvNextBlock, mult=convnext_mult) if with_time_emb: time_dim = dim * 4 self.time_mlp = nn.SequentialCell( SinusoidalPositionEmbeddings(dim), nn.Dense(dim, time_dim), nn.GELU(), nn.Dense(time_dim, time_dim), ) else: time_dim = None self.time_mlp = None self.downs = nn.CellList([]) self.ups = nn.CellList([]) num_resolutions = len(in_out) for ind, (dim_in, dim_out) in enumerate(in_out): is_last = ind >= (num_resolutions - 1) self.downs.append( nn.CellList( [ block_klass(dim_in, dim_out, time_emb_dim=time_dim), block_klass(dim_out, dim_out, time_emb_dim=time_dim), Residual(PreNorm(dim_out, LinearAttention(dim_out))), Downsample(dim_out) if not is_last else nn.Identity(), ] ) ) mid_dim = dims[-1] self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): is_last = ind >= (num_resolutions - 1) self.ups.append( nn.CellList( [ block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim), block_klass(dim_in, dim_in, time_emb_dim=time_dim), Residual(PreNorm(dim_in, LinearAttention(dim_in))), Upsample(dim_in) if not is_last else nn.Identity(), ] ) ) out_dim = default(out_dim, channels) self.final_conv = nn.SequentialCell( block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1) ) def construct(self, x, time): x = self.init_conv(x) t = self.time_mlp(time) if exists(self.time_mlp) else None h = [] for block1, block2, attn, downsample in self.downs: x = block1(x, t) x = block2(x, t) x = attn(x) h.append(x) x = downsample(x) x = self.mid_block1(x, t) x = self.mid_attn(x) x = self.mid_block2(x, t) len_h = len(h) - 1 for block1, block2, attn, upsample in self.ups: x = ops.concat((x, h[len_h]), 1) len_h -= 1 x = block1(x, t) x = block2(x, t) x = attn(x) x = upsample(x) return self.final_conv(x)
import mindspore as ms from mindspore.common.initializer import Normal # 参数定义 image_side_length = 32 # 图像的宽和高的像素数 channels = 3 # 图像通道数,这里假设处理的是RGB图像 batch_size = 2 # 批次大小 # 定义 Unet模型 # 注意:此处的dim应该根据模型设计具体指定,但基于您的代码,我们保持原样 unet_model = Unet(dim=image_side_length, channels=channels, dim_mults=(1, 2, 4,)) # 构建输入数据 x = ms.Tensor(shape=(batch_size, channels, image_side_length, image_side_length), dtype=ms.float32, init=Normal()) x.shape # 显示数据形状 print(x) # 打印数据(显示初始化后的随机值)
[[[[ 1.22990236e-02 9.65940859e-03 -5.95777121e-04 ... -1.09354462e-02 2.30002552e-02 -5.25823655e-03] [ 1.35805225e-02 1.16471937e-02 -1.20973922e-02 ... -1.13204606e-02 -1.91520341e-02 -1.09745166e-03] [-4.65569133e-03 1.33861918e-02 -1.60518996e-02 ... 4.18792450e-04 9.22567211e-03 4.44417645e-04] ... [ 3.40697076e-03 4.53335233e-03 5.73999388e-03 ... 4.67619160e-03 -8.16432573e-03 -1.39179081e-02] [-9.07978602e-03 -6.43689744e-03 1.32928183e-02 ... 4.21820907e-03 -1.05559649e-02 8.33686162e-03] [ 2.96656298e-03 -7.44550209e-03 5.52403228e-03 ... -2.09826510e-03 2.17068940e-02 2.28530783e-02]] [[-2.34551495e-03 7.68061494e-03 8.63175746e-03 ... -5.62175177e-03 -9.85390134e-03 -4.08322597e-03] [ 1.30044697e-02 -9.87336412e-03 2.55680992e-03 ... 1.21581517e-02 1.10829184e-02 -1.09381862e-02] [-1.09032113e-02 1.25320591e-02 -9.15124733e-03 ... -8.42134352e-04 -3.48115107e-03 -8.12307373e-03] ... [-1.22983279e-02 2.11556954e-03 -1.63072231e-03 ... -8.83890502e-03 2.00234205e-02 -2.91514886e-03] [-4.95374482e-03 -1.51413877e-03 6.57585217e-03 ... 1.93616766e-02 -3.65696964e-03 -1.76955778e-02] [ 8.47856048e-03 9.17020999e-03 -5.66793000e-03 ... -2.92802905e-03 -5.98460436e-03 8.32138583e-03]] [[ 1.00378189e-02 -2.43024575e-03 2.11097375e-02 ... -6.47504721e-03 -1.47426147e-02 7.38033140e-03] [-3.09416349e-03 -3.46184568e-03 -7.74018466e-03 ... 1.19950040e-03 3.14799254e-04 -7.95779750e-03] [ 3.98837449e-03 2.33123749e-02 1.63442008e-02 ... 1.05365906e-02 -1.44729228e-03 1.90633966e-03] ... [-1.76522471e-02 9.42215510e-03 -9.92319733e-03 ... -8.83952528e-03 -1.18930812e-03 -8.53374321e-03] [ 2.51283534e-02 -1.38457380e-02 1.32035371e-02 ... 1.66724548e-02 -9.26751085e-03 1.42328264e-02] [-3.69384699e-03 6.09130273e-03 -2.94976344e-04 ... 7.72336172e-03 -3.75742209e-03 -3.17590404e-03]]] [[[-2.92081665e-03 -1.39991604e-02 -8.93703103e-03 ... 1.51352473e-02 3.90937366e-03 2.66693830e-02] [-2.27847677e-02 3.63694108e-03 2.70780316e-03 ... -8.13330431e-03 -4.17956570e-03 1.22072157e-02] [-1.24624427e-02 4.75015305e-03 2.68556597e-03 ... 6.48784591e-03 -6.09957753e-03 4.85362951e-03] ... [-3.67846363e-03 -9.81856976e-03 -7.40657933e-03 ... 1.95454084e-03 1.80558003e-02 4.30267537e-03] [-2.47061905e-02 1.53471017e-03 -2.55961739e-03 ... -6.16029697e-03 -1.19128199e-02 7.23672146e-03] [-9.77169070e-03 -5.93968621e-03 -1.16010886e-02 ... 1.13449963e-02 7.74116023e-03 -8.25872459e-03]] [[ 2.42574494e-02 -1.59016773e-02 4.60586976e-03 ... -1.27300173e-02 -2.08083801e-02 1.20891845e-02] [ 4.98928130e-03 1.58587005e-02 -1.17553072e-02 ... -4.57813032e-03 2.66204093e-04 -1.80527139e-02] [ 9.97055881e-03 2.07035127e-03 -7.31401029e-04 ... 1.80852767e-02 -2.09929375e-03 4.49541025e-04] ... [-8.71989876e-04 7.75372284e-03 3.14102072e-05 ... 6.37980178e-04 -1.68553423e-02 -4.13572555e-03] [ 6.12246012e-03 -1.88669516e-03 1.50548946e-02 ... 9.18534491e-03 1.46157937e-02 5.96544426e-03] [-5.24167530e-03 2.64895801e-03 7.25612324e-03 ... -5.48065547e-03 -2.98001780e-03 -7.99621455e-03]] [[ 1.18518099e-02 1.00414380e-02 -3.00463289e-03 ... -3.48429219e-03 1.21912286e-02 -8.21612682e-03] [ 9.25556850e-03 -1.57560236e-04 7.71128759e-03 ... 3.91136715e-03 1.56383701e-02 8.09505815e-04] [ 4.79864981e-03 1.88933630e-02 1.73798949e-02 ... 5.97322173e-03 4.30198200e-03 1.52684944e-02] ... [-9.37487371e-03 5.54391975e-03 4.64118691e-03 ... 6.41342625e-03 1.36971334e-03 -1.25444317e-02] [-4.26448090e-03 7.79700419e-03 2.39845295e-03 ... -1.18866842e-02 3.74738523e-03 1.07039241e-02] [-1.02939839e-02 7.36899953e-03 -2.00587343e-02 ... -1.10042403e-02 -1.42604960e-02 -1.37462756e-02]]]]
x.shape # 显示数据形状
(2, 3, 32, 32)
dim=image_side_length channels=channels dim_mults=(1, 2, 4,)
init_dim=None out_dim=None # dim_mults=(1, 2, 4, 8) channels=3 with_time_emb=True convnext_mult=2
init_dim = default(init_dim, dim // 3 * 2) dim,init_dim
(32, 20)
channels
3
init_conv = nn.Conv2d(channels, init_dim, 7, padding=3, pad_mode="pad", has_bias=True) init_conv
Conv2d
dim, dim_mults,init_dim
(32, (1, 2, 4), 20)
(lambda m: dim * m, dim_mults)
(, (1, 2, 4))
dims = [init_dim, *map(lambda m: dim * m, dim_mults)] dims
[20, 32, 64, 128]
zip(dims[:-1], dims[1:])
dims[:-1], dims[1:]
([20, 32, 64], [32, 64, 128])
in_out = list(zip(dims[:-1], dims[1:])) in_out
[(20, 32), (32, 64), (64, 128)]
ConvNextBlock,convnext_mult
(__main__.ConvNextBlock, 2)
block_klass = partial(ConvNextBlock, mult=convnext_mult) ##传入ConvNextBlock的第一个参数mult=convnext_mult block_klass
functools.partial(, mult=2)
with_time_emb
True
time_dim = dim * 4 time_dim,dim
(128, 32)
time_mlp = nn.SequentialCell( SinusoidalPositionEmbeddings(dim), nn.Dense(dim, time_dim), nn.GELU(), nn.Dense(time_dim, time_dim), ) time_mlp
SequentialCell
downs = nn.CellList([]) ups = nn.CellList([]) ups
CellList
num_resolutions = len(in_out) num_resolutions
3
for ind, (dim_in, dim_out) in enumerate(in_out): is_last = ind >= (num_resolutions - 1) print(ind,":",is_last)
0 : False 1 : False 2 : True
dim_in, dim_out, time_dim ###把每个时间步编码为128维度
(64, 128, 128)
in_out
[(20, 32), (32, 64), (64, 128)]
for ind, (dim_in, dim_out) in enumerate(in_out): print(dim_in, dim_out) is_last = ind >= (num_resolutions - 1) downs.append( nn.CellList( [ block_klass(dim_in, dim_out, time_emb_dim=time_dim), block_klass(dim_out, dim_out, time_emb_dim=time_dim), Residual(PreNorm(dim_out, LinearAttention(dim_out))), Downsample(dim_out) if not is_last else nn.Identity(), ] ) )
20 32 32 64 64 128
downs
CellList (ds_conv): Conv2d (net): SequentialCell (res_conv): Conv2d > (1): ConvNextBlock (ds_conv): Conv2d (net): SequentialCell (res_conv): Identity > (2): Residual > (norm): GroupNorm > > (3): Conv2d > (1): CellList (ds_conv): Conv2d (net): SequentialCell (res_conv): Conv2d > (1): ConvNextBlock (ds_conv): Conv2d (net): SequentialCell (res_conv): Identity > (2): Residual > (norm): GroupNorm > > (3): Conv2d > (2): CellList (ds_conv): Conv2d (net): SequentialCell (res_conv): Conv2d > (1): ConvNextBlock (ds_conv): Conv2d (net): SequentialCell (res_conv): Identity > (2): Residual > (norm): GroupNorm > > (3): Identity > >
mid_dim = dims[-1] mid_dim
128
mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
mid_block1
ConvNextBlock (ds_conv): Conv2d (net): SequentialCell (res_conv): Identity >
mid_attn
Residual (norm): GroupNorm > >
mid_block2
ConvNextBlock (ds_conv): Conv2d (net): SequentialCell (res_conv): Identity >
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): print(dim_in, dim_out) is_last = ind >= (num_resolutions - 1) print(is_last)
64 128 False 32 64 False
dim_in
32
LinearAttention(dim_in)
LinearAttention >
class LinearAttention(nn.Cell): def __init__(self, dim, heads=4, dim_head=32): super().__init__() self.scale = dim_head ** -0.5 self.heads = heads hidden_dim = dim_head * heads self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, pad_mode='valid', has_bias=False) self.to_out = nn.SequentialCell( nn.Conv2d(hidden_dim, dim, 1, pad_mode='valid', has_bias=True), LayerNorm(dim) ) self.map = ops.Map() self.partial = ops.Partial() def construct(self, x): b, _, h, w = x.shape qkv = self.to_qkv(x).chunk(3, 1) q, k, v = self.map(self.partial(rearrange, self.heads), qkv) q = ops.softmax(q, -2) k = ops.softmax(k, -1) q = q * self.scale v = v / (h * w) # 'b h d n, b h e n -> b h d e' context = ops.bmm(k, v.swapaxes(2, 3)) # 'b h d e, b h d n -> b h e n' out = ops.bmm(context.swapaxes(2, 3), q) out = out.reshape((b, -1, h, w)) return self.to_out(out)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): is_last = ind >= (num_resolutions - 1) ups.append( nn.CellList( [ block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim), block_klass(dim_in, dim_in, time_emb_dim=time_dim), Residual(PreNorm(dim_in, LinearAttention(dim_in))), Upsample(dim_in) if not is_last else nn.Identity(), ] ) )
ups
CellList (ds_conv): Conv2d (net): SequentialCell (res_conv): Conv2d > (1): ConvNextBlock (ds_conv): Conv2d (net): SequentialCell (res_conv): Identity > (2): Residual > (norm): GroupNorm > > (3): Conv2dTranspose > (1): CellList (ds_conv): Conv2d (net): SequentialCell (res_conv): Conv2d > (1): ConvNextBlock (ds_conv): Conv2d (net): SequentialCell (res_conv): Identity > (2): Residual > (norm): GroupNorm > > (3): Conv2dTranspose > >
out_dim = default(out_dim, channels) out_dim
3
final_conv = nn.SequentialCell( block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1) ) final_conv
SequentialCell (res_conv): Identity > (1): Conv2d >
x time=5
x.shape
(2, 3, 32, 32)
print(x)
[[[[ 1.22990236e-02 9.65940859e-03 -5.95777121e-04 ... -1.09354462e-02 2.30002552e-02 -5.25823655e-03] [ 1.35805225e-02 1.16471937e-02 -1.20973922e-02 ... -1.13204606e-02 -1.91520341e-02 -1.09745166e-03] [-4.65569133e-03 1.33861918e-02 -1.60518996e-02 ... 4.18792450e-04 9.22567211e-03 4.44417645e-04] ... [ 3.40697076e-03 4.53335233e-03 5.73999388e-03 ... 4.67619160e-03 -8.16432573e-03 -1.39179081e-02] [-9.07978602e-03 -6.43689744e-03 1.32928183e-02 ... 4.21820907e-03 -1.05559649e-02 8.33686162e-03] [ 2.96656298e-03 -7.44550209e-03 5.52403228e-03 ... -2.09826510e-03 2.17068940e-02 2.28530783e-02]] [[-2.34551495e-03 7.68061494e-03 8.63175746e-03 ... -5.62175177e-03 -9.85390134e-03 -4.08322597e-03] [ 1.30044697e-02 -9.87336412e-03 2.55680992e-03 ... 1.21581517e-02 1.10829184e-02 -1.09381862e-02] [-1.09032113e-02 1.25320591e-02 -9.15124733e-03 ... -8.42134352e-04 -3.48115107e-03 -8.12307373e-03] ... [-1.22983279e-02 2.11556954e-03 -1.63072231e-03 ... -8.83890502e-03 2.00234205e-02 -2.91514886e-03] [-4.95374482e-03 -1.51413877e-03 6.57585217e-03 ... 1.93616766e-02 -3.65696964e-03 -1.76955778e-02] [ 8.47856048e-03 9.17020999e-03 -5.66793000e-03 ... -2.92802905e-03 -5.98460436e-03 8.32138583e-03]] [[ 1.00378189e-02 -2.43024575e-03 2.11097375e-02 ... -6.47504721e-03 -1.47426147e-02 7.38033140e-03] [-3.09416349e-03 -3.46184568e-03 -7.74018466e-03 ... 1.19950040e-03 3.14799254e-04 -7.95779750e-03] [ 3.98837449e-03 2.33123749e-02 1.63442008e-02 ... 1.05365906e-02 -1.44729228e-03 1.90633966e-03] ... [-1.76522471e-02 9.42215510e-03 -9.92319733e-03 ... -8.83952528e-03 -1.18930812e-03 -8.53374321e-03] [ 2.51283534e-02 -1.38457380e-02 1.32035371e-02 ... 1.66724548e-02 -9.26751085e-03 1.42328264e-02] [-3.69384699e-03 6.09130273e-03 -2.94976344e-04 ... 7.72336172e-03 -3.75742209e-03 -3.17590404e-03]]] [[[-2.92081665e-03 -1.39991604e-02 -8.93703103e-03 ... 1.51352473e-02 3.90937366e-03 2.66693830e-02] [-2.27847677e-02 3.63694108e-03 2.70780316e-03 ... -8.13330431e-03 -4.17956570e-03 1.22072157e-02] [-1.24624427e-02 4.75015305e-03 2.68556597e-03 ... 6.48784591e-03 -6.09957753e-03 4.85362951e-03] ... [-3.67846363e-03 -9.81856976e-03 -7.40657933e-03 ... 1.95454084e-03 1.80558003e-02 4.30267537e-03] [-2.47061905e-02 1.53471017e-03 -2.55961739e-03 ... -6.16029697e-03 -1.19128199e-02 7.23672146e-03] [-9.77169070e-03 -5.93968621e-03 -1.16010886e-02 ... 1.13449963e-02 7.74116023e-03 -8.25872459e-03]] [[ 2.42574494e-02 -1.59016773e-02 4.60586976e-03 ... -1.27300173e-02 -2.08083801e-02 1.20891845e-02] [ 4.98928130e-03 1.58587005e-02 -1.17553072e-02 ... -4.57813032e-03 2.66204093e-04 -1.80527139e-02] [ 9.97055881e-03 2.07035127e-03 -7.31401029e-04 ... 1.80852767e-02 -2.09929375e-03 4.49541025e-04] ... [-8.71989876e-04 7.75372284e-03 3.14102072e-05 ... 6.37980178e-04 -1.68553423e-02 -4.13572555e-03] [ 6.12246012e-03 -1.88669516e-03 1.50548946e-02 ... 9.18534491e-03 1.46157937e-02 5.96544426e-03] [-5.24167530e-03 2.64895801e-03 7.25612324e-03 ... -5.48065547e-03 -2.98001780e-03 -7.99621455e-03]] [[ 1.18518099e-02 1.00414380e-02 -3.00463289e-03 ... -3.48429219e-03 1.21912286e-02 -8.21612682e-03] [ 9.25556850e-03 -1.57560236e-04 7.71128759e-03 ... 3.91136715e-03 1.56383701e-02 8.09505815e-04] [ 4.79864981e-03 1.88933630e-02 1.73798949e-02 ... 5.97322173e-03 4.30198200e-03 1.52684944e-02] ... [-9.37487371e-03 5.54391975e-03 4.64118691e-03 ... 6.41342625e-03 1.36971334e-03 -1.25444317e-02] [-4.26448090e-03 7.79700419e-03 2.39845295e-03 ... -1.18866842e-02 3.74738523e-03 1.07039241e-02] [-1.02939839e-02 7.36899953e-03 -2.00587343e-02 ... -1.10042403e-02 -1.42604960e-02 -1.37462756e-02]]]]
init_conv
Conv2d
x = init_conv(x) print(x)
[[[[ 0.04106301 0.04308875 0.03753628 ... 0.03978373 0.04020362 0.03793241] [ 0.04131693 0.04225756 0.03624208 ... 0.04158005 0.04384746 0.0374498 ] [ 0.03251923 0.04382189 0.03682814 ... 0.04343156 0.03790728 0.03667009] ... [ 0.03987011 0.04261249 0.03721504 ... 0.03754282 0.03530194 0.04190454] [ 0.03995614 0.04259038 0.04231969 ... 0.03937387 0.03802945 0.03542861] [ 0.03724413 0.03895703 0.03808391 ... 0.04210365 0.03843816 0.03887339]] [[-0.07880677 -0.081793 -0.08021648 ... -0.07915598 -0.08803446 -0.07824855] [-0.07975532 -0.0827108 -0.08153103 ... -0.08920732 -0.08202183 -0.07717112] [-0.080375 -0.08304221 -0.07943083 ... -0.08371484 -0.07717931 -0.07678773] ... [-0.07925861 -0.07035945 -0.07607639 ... -0.08380341 -0.08219168 -0.08388805] [-0.07702561 -0.07861231 -0.08642116 ... -0.08342467 -0.07647635 -0.08471077] [-0.08218312 -0.08206419 -0.0820056 ... -0.0710914 -0.08050337 -0.08665174]] [[-0.05770465 -0.05971097 -0.06042907 ... -0.05981689 -0.05351909 -0.06045758] [-0.06280089 -0.06072729 -0.06125656 ... -0.06167236 -0.05607811 -0.06504007] [-0.05974955 -0.06224146 -0.05134789 ... -0.06194806 -0.05703649 -0.05661972] ... [-0.0587245 -0.06006888 -0.06369887 ... -0.0509633 -0.05987025 -0.05689852] [-0.05888586 -0.06178844 -0.06245932 ... -0.06076533 -0.05802548 -0.06169396] [-0.05935856 -0.05726556 -0.05836396 ... -0.06468105 -0.05601557 -0.05411654]] ... [[-0.06885257 -0.06496602 -0.07227325 ... -0.06768468 -0.07973982 -0.06684067] [-0.06921483 -0.07310341 -0.07145415 ... -0.07373261 -0.06769554 -0.06564213] [-0.07235637 -0.08390911 -0.06977317 ... -0.06690352 -0.06286541 -0.06959118] ... [-0.07000594 -0.06508094 -0.06877656 ... -0.07407243 -0.07690564 -0.06396648] [-0.07082649 -0.07268029 -0.07315704 ... -0.06758922 -0.06662212 -0.06855071] [-0.07199463 -0.06999994 -0.07014568 ... -0.06523817 -0.07094447 -0.07466151]] [[ 0.06986891 0.06941634 0.06439675 ... 0.06187101 0.0675493 0.07306495] [ 0.07319107 0.07266034 0.05997508 ... 0.06689761 0.06815154 0.0660945 ] [ 0.07065641 0.05923657 0.06411441 ... 0.06652149 0.07088953 0.07194202] ... [ 0.06848673 0.07591817 0.0726023 ... 0.06602401 0.06890585 0.07259338] [ 0.07433689 0.06679939 0.06691605 ... 0.0667197 0.07184143 0.06983658] [ 0.06431621 0.07212089 0.06723586 ... 0.06868842 0.07140361 0.06901537]] [[ 0.02097934 0.01210438 0.01431934 ... 0.01505992 0.01852277 0.01381299] [ 0.02296696 0.0177606 0.01976403 ... 0.02147305 0.02210259 0.02313221] [ 0.02124698 0.02709681 0.02910981 ... 0.01016832 0.02212639 0.01957588] ... [ 0.02289374 0.01311012 0.01578637 ... 0.01931083 0.01555186 0.0208313 ] [ 0.01390727 0.02096656 0.01745579 ... 0.01781181 0.02211875 0.01568411] [ 0.02439262 0.01495296 0.01968778 ... 0.02193322 0.01783368 0.0176824 ]]] [[[ 0.04223164 0.03378314 0.03601065 ... 0.03959855 0.03485664 0.04071919] [ 0.0318558 0.05363872 0.03783617 ... 0.04385335 0.0496259 0.03691863] [ 0.03818982 0.03180957 0.04072122 ... 0.03430039 0.03384047 0.03837577] ... [ 0.0398777 0.03721025 0.03533046 ... 0.04020133 0.03928016 0.04710523] [ 0.04118172 0.03496882 0.03100736 ... 0.03642647 0.03914004 0.0371574 ] [ 0.04037436 0.04040184 0.04165599 ... 0.04403537 0.03254044 0.04335065]] [[-0.08552387 -0.07319534 -0.08021338 ... -0.07858572 -0.07166487 -0.08406518] [-0.07923919 -0.08566054 -0.08015955 ... -0.08471547 -0.0847266 -0.08085599] [-0.08489675 -0.09258271 -0.08831957 ... -0.09042192 -0.08426952 -0.0808774 ] ... [-0.08036023 -0.07413588 -0.07989521 ... -0.07935498 -0.08571334 -0.08329107] [-0.07644836 -0.07608277 -0.08767064 ... -0.08434241 -0.08071237 -0.0839122 ] [-0.07979399 -0.08087463 -0.08673595 ... -0.08414597 -0.08045428 -0.07299927]] [[-0.06060546 -0.05453672 -0.06102112 ... -0.05194974 -0.0567053 -0.06273571] [-0.06276039 -0.05693425 -0.04725159 ... -0.06214722 -0.06443968 -0.05762123] [-0.05252658 -0.06019294 -0.06137866 ... -0.04910715 -0.06131132 -0.06036767] ... [-0.06173272 -0.05464447 -0.05099018 ... -0.06136036 -0.06400239 -0.06106843] [-0.05803053 -0.05994222 -0.06404369 ... -0.04949801 -0.05738675 -0.06158596] [-0.05899998 -0.06198164 -0.05937162 ... -0.06379396 -0.06430338 -0.06287489]] ... [[-0.07173873 -0.0707745 -0.06975999 ... -0.07155637 -0.06534318 -0.07189398] [-0.06730746 -0.07013785 -0.06751848 ... -0.07264671 -0.07705939 -0.07342067] [-0.07058413 -0.07025788 -0.06871852 ... -0.06887744 -0.06563742 -0.07028291] ... [-0.07809374 -0.06778216 -0.06392691 ... -0.06867532 -0.07118014 -0.07647338] [-0.07219965 -0.07040192 -0.0732589 ... -0.07633238 -0.0752567 -0.0702922 ] [-0.06984755 -0.07723872 -0.06846898 ... -0.06786713 -0.06702175 -0.07062964]] [[ 0.06747851 0.06883495 0.06797507 ... 0.06853593 0.06575806 0.06841848] [ 0.06514458 0.06994057 0.06866109 ... 0.06339982 0.06309478 0.06588745] [ 0.06701669 0.0691862 0.06725767 ... 0.06696404 0.07045414 0.07060774] ... [ 0.07085218 0.0809648 0.06841429 ... 0.06838602 0.06918488 0.07014886] [ 0.07304276 0.07134987 0.07214254 ... 0.07656243 0.07136226 0.06578355] [ 0.06968872 0.07193028 0.06518821 ... 0.07004035 0.06891351 0.06959624]] [[ 0.02295301 0.01347421 0.02212771 ... 0.02214386 0.01323562 0.02334489] [ 0.01578862 0.01825874 0.01307945 ... 0.0216907 0.02719616 0.02306023] [ 0.01491401 0.01406 0.02918804 ... 0.02165697 0.01733657 0.01930147] ... [ 0.01614039 0.019646 0.02148937 ... 0.00664111 0.01888491 0.02413018] [ 0.01757734 0.01567486 0.01912338 ... 0.02099028 0.01717271 0.01547725] [ 0.01756483 0.020161 0.01650484 ... 0.01933268 0.0167334 0.01855144]]]]
x.shape
(2, 20, 32, 32)
#unet_model.init_conv(x) ####调用实例化后类的方法!!!!! 好像是失败的
import numpy as np from mindspore import Tensor # 定义时间步的起始值、步数以及步长(默认为1,即每个时间步增加1) start = 0 num_steps = 10 step = 1 # 生成线性递增的时间步长序列 t = Tensor(np.arange(start, start + num_steps * step, step), dtype=ms.int32) print("线性递增的时间步长序列:", t) # time_mlp # SequentialCell time_mlp(t) ###这里正确了
线性递增的时间步长序列: [0 1 2 3 4 5 6 7 8 9] Tensor(shape=[10, 128], dtype=Float32, value= [[ 1.89717814e-01, 1.14008449e-02, 3.33061777e-02 ... 1.43985003e-01, 3.92933972e-02, -1.06829256e-02], [ 1.93240538e-01, 1.78442001e-02, 6.77158684e-02 ... 1.36301309e-01, 7.64560923e-02, -1.50307640e-02], [ 1.83035284e-01, 2.44393535e-02, 8.70461762e-02 ... 1.38745904e-01, 1.33171901e-01, -3.85175534e-02], ... [ 1.28773689e-01, 1.91335917e-01, -9.48226005e-02 ... 8.54851380e-02, 1.52098373e-01, 2.03581899e-02], [ 1.22549936e-01, 1.48201510e-01, -8.17623138e-02 ... 4.44053262e-02, 9.75183249e-02, 3.97774130e-02], [ 1.07752994e-01, 1.08763084e-01, -7.05250949e-02 ... 4.34711799e-02, 6.16942756e-02, 1.67786255e-02]])
import numpy as np from mindspore import Tensor # 定义时间步的起始值、步数以及步长(默认为1,即每个时间步增加1) start = 0 num_steps = 5 step = 1 # 生成线性递增的时间步长序列 t = Tensor(np.arange(start, start + num_steps * step, step), dtype=ms.int32) print("线性递增的时间步长序列:", t) # time_mlp # SequentialCell time_mlp(t) ###这里正确了
线性递增的时间步长序列: [0 1 2 3 4] Tensor(shape=[5, 128], dtype=Float32, value= [[ 1.89717814e-01, 1.14008449e-02, 3.33061777e-02 ... 1.43985003e-01, 3.92933972e-02, -1.06829256e-02], [ 1.93240538e-01, 1.78442001e-02, 6.77158684e-02 ... 1.36301309e-01, 7.64560923e-02, -1.50307640e-02], [ 1.83035284e-01, 2.44393535e-02, 8.70461762e-02 ... 1.38745904e-01, 1.33171901e-01, -3.85175534e-02], [ 1.60673216e-01, 5.66724911e-02, 6.34887069e-02 ... 1.58708930e-01, 1.91956162e-01, -7.78784081e-02], [ 1.33557051e-01, 1.20526701e-01, 5.27272746e-03 ... 1.78237736e-01, 2.34055102e-01, -9.55123529e-02]])
t.shape
(5,)
看起来你希望构造一个形状为(5,)的张量t,这通常表示一个包含5个元素的一维向量。在Python中使用MindSpore库,你可以很容易地创建这样一个张量。这里有几个创建此形状张量的例子,包括初始化为特定值或随机值:
初始化为零
import mindspore as ms # 创建一个全零的张量,形状为(5,) t = ms.Tensor.zeros(5, dtype=ms.float32) print(t.shape) # 输出: (5,)
初始化为一
t = ms.Tensor.ones(5, dtype=ms.float32) print(t.shape) # 输出: (5,)
初始化为单位序列(例如,[0, 1, 2, 3, 4])
import numpy as np t = ms.Tensor(np.arange(5), dtype=ms.float32) print(t.shape) # 输出: (5,)
初始化为随机值
t = ms.Tensor(np.random.rand(5), dtype=ms.float32) print(t.shape) # 输出: (5,)
使用特定值填充
value = 3.14 # 例如,使用π作为填充值 t = ms.Tensor(np.full(5, value), dtype=ms.float32) print(t.shape) # 输出: (5,)
以上任一代码块都可以创建一个形状为(5,)的张量t,根据你的具体需求选择合适的初始化方式。
h = []
downs # 3#套有3个cellist元素的一个celllist
CellList (ds_conv): Conv2d (net): SequentialCell (res_conv): Conv2d > (1): ConvNextBlock (ds_conv): Conv2d (net): SequentialCell (res_conv): Identity > (2): Residual > (norm): GroupNorm > > (3): Conv2d > (1): CellList (ds_conv): Conv2d (net): SequentialCell (res_conv): Conv2d > (1): ConvNextBlock (ds_conv): Conv2d (net): SequentialCell (res_conv): Identity > (2): Residual > (norm): GroupNorm > > (3): Conv2d > (2): CellList (ds_conv): Conv2d (net): SequentialCell (res_conv): Conv2d > (1): ConvNextBlock (ds_conv): Conv2d (net): SequentialCell (res_conv): Identity > (2): Residual > (norm): GroupNorm > > (3): Identity > >
for downsample in downs: print(downsample)
CellList (ds_conv): Conv2d (net): SequentialCell (res_conv): Conv2d > (1): ConvNextBlock (ds_conv): Conv2d (net): SequentialCell (res_conv): Identity > (2): Residual > (norm): GroupNorm > > (3): Conv2d > CellList (ds_conv): Conv2d (net): SequentialCell (res_conv): Conv2d > (1): ConvNextBlock (ds_conv): Conv2d (net): SequentialCell (res_conv): Identity > (2): Residual > (norm): GroupNorm > > (3): Conv2d > CellList (ds_conv): Conv2d (net): SequentialCell (res_conv): Conv2d > (1): ConvNextBlock (ds_conv): Conv2d (net): SequentialCell (res_conv): Identity > (2): Residual > (norm): GroupNorm > > (3): Identity >
for block1, block2, attn, downsample in downs: print("aaaaaaaaaas11BL1") print(block1) print("aaaaaaaaaasBL2") print( block2) print("aaaaaaaaaasAT 残差 attn") print(attn) print("aaaaaaaaaasAT 下采样") print( downsample)
aaaaaaaaaas11BL1 ConvNextBlock (ds_conv): Conv2d (net): SequentialCell (res_conv): Conv2d > aaaaaaaaaasBL2 ConvNextBlock (ds_conv): Conv2d (net): SequentialCell (res_conv): Identity > aaaaaaaaaasAT 残差 attn Residual > (norm): GroupNorm > > aaaaaaaaaasAT 下采样 Conv2d aaaaaaaaaas11BL1 ConvNextBlock (ds_conv): Conv2d (net): SequentialCell (res_conv): Conv2d > aaaaaaaaaasBL2 ConvNextBlock (ds_conv): Conv2d (net): SequentialCell (res_conv): Identity > aaaaaaaaaasAT 残差 attn Residual > (norm): GroupNorm > > aaaaaaaaaasAT 下采样 Conv2d aaaaaaaaaas11BL1 ConvNextBlock (ds_conv): Conv2d (net): SequentialCell (res_conv): Conv2d > aaaaaaaaaasBL2 ConvNextBlock (ds_conv): Conv2d (net): SequentialCell (res_conv): Identity > aaaaaaaaaasAT 残差 attn Residual > (norm): GroupNorm > > aaaaaaaaaasAT 下采样 Identity
dim_in, dim_out,dim
(32, 64, 32)
因为循环这个[(20, 32), (32, 64), (64, 128)] 所以down 有3个元素 nn.cell
convnext_mult
2
#dims = [init_dim, *map(lambda m: dim * m, dim_mults)] dims
[20, 32, 64, 128]
i=0 for block1, block2, attn, downsample in downs: i=i+1 print("--------",i) print("BL1块1:",block1,"BL2块2:", block2, "ATTT残差注意力:",attn, "DOWN下采样:",downsample) # block_klass(dim_in, dim_out, time_emb_dim=time_dim), #time_dim=128 dim=128 # block_klass(dim_out, dim_out, time_emb_dim=time_dim), # Residual(PreNorm(dim_out, LinearAttention(dim_out))), # Downsample(dim_out) if not is_last else nn.Identity(), block_klass = partial(ConvNextBlock, mult=convnext_mult) ##传入ConvNextBlock的第一个参数mult=convnext_mult block_klass # class ConvNextBlock(nn.Cell): # def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True): # super().__init__() # self.mlp = ( # nn.SequentialCell(nn.GELU(), nn.Dense(time_emb_dim, dim)) # if exists(time_emb_dim) # else None # ) # self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, group=dim, pad_mode="pad") # self.net = nn.SequentialCell( # nn.GroupNorm(1, dim) if norm else nn.Identity(), # nn.Conv2d(dim, dim_out * mult, 3, padding=1, pad_mode="pad"), # nn.GELU(), # nn.GroupNorm(1, dim_out * mult), # nn.Conv2d(dim_out * mult, dim_out, 3, padding=1, pad_mode="pad"), # ) # self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() # def construct(self, x, time_emb=None): # h = self.ds_conv(x) # if exists(self.mlp) and exists(time_emb): # assert exists(time_emb), "time embedding must be passed in" # condition = self.mlp(time_emb) # condition = condition.expand_dims(-1).expand_dims(-1) # h = h + condition # h = self.net(h) # return h + self.res_conv(x) ### 第一层BL1 #input_channels=20, output_channels=32 (20,32) ### 第一层BL2 #Conv2d DOWN下采样: Conv2d -------- 2 BL1块1: ConvNextBlock (ds_conv): Conv2d (net): SequentialCell (res_conv): Conv2d > BL2块2: ConvNextBlock (ds_conv): Conv2d (net): SequentialCell (res_conv): Identity > ATTT残差注意力: Residual > (norm): GroupNorm > > DOWN下采样: Conv2d -------- 3 BL1块1: ConvNextBlock (ds_conv): Conv2d (net): SequentialCell (res_conv): Conv2d > BL2块2: ConvNextBlock (ds_conv): Conv2d (net): SequentialCell (res_conv): Identity > ATTT残差注意力: Residual > (norm): GroupNorm > > DOWN下采样: Identity functools.partial(, mult=2)
x.shape
(2, 20, 32, 32)
t
Tensor(shape=[5], dtype=Int32, value= [0, 1, 2, 3, 4])
t=time_mlp(t) ###这里正确了 t
Tensor(shape=[5, 128], dtype=Float32, value= [[ 1.89717814e-01, 1.14008449e-02, 3.33061777e-02 ... 1.43985003e-01, 3.92933972e-02, -1.06829256e-02], [ 1.93240538e-01, 1.78442001e-02, 6.77158684e-02 ... 1.36301309e-01, 7.64560923e-02, -1.50307640e-02], [ 1.83035284e-01, 2.44393535e-02, 8.70461762e-02 ... 1.38745904e-01, 1.33171901e-01, -3.85175534e-02], [ 1.60673216e-01, 5.66724911e-02, 6.34887069e-02 ... 1.58708930e-01, 1.91956162e-01, -7.78784081e-02], [ 1.33557051e-01, 1.20526701e-01, 5.27272746e-03 ... 1.78237736e-01, 2.34055102e-01, -9.55123529e-02]])
# 选取第一行 # 选取第一行 new_t = t[0:1, :] # 或者直接 t[0:1] 也可以 = t[0:1, :] # 或者直接 t[0:1] 也可以 new_t.shape
(1, 128)
t=new_t
class ConvNextBlock(nn.Cell): def __init__(self, dim=20, dim_out=32, *, time_emb_dim=128, mult=2, norm=True): super().__init__() self.mlp = ( nn.SequentialCell(nn.GELU(), nn.Dense(time_emb_dim, dim)) if exists(time_emb_dim) else None ) self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, group=dim, pad_mode="pad") self.net = nn.SequentialCell( nn.GroupNorm(1, dim) if norm else nn.Identity(), nn.Conv2d(dim, dim_out * mult, 3, padding=1, pad_mode="pad"), nn.GELU(), nn.GroupNorm(1, dim_out * mult), nn.Conv2d(dim_out * mult, dim_out, 3, padding=1, pad_mode="pad"), ) self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() def construct(self, x, time_emb=None): h = self.ds_conv(x) if exists(self.mlp) and exists(time_emb): assert exists(time_emb), "time embedding must be passed in" condition = self.mlp(time_emb) condition = condition.expand_dims(-1).expand_dims(-1) h = h + condition h = self.net(h) return h + self.res_conv(x)
BL1=ConvNextBlock() BL1
ConvNextBlock (ds_conv): Conv2d (net): SequentialCell (res_conv): Conv2d >
x = ms.Tensor(shape=(batch_size, channels, image_side_length, image_side_length), dtype=ms.float32, init=Normal()) x.shape # 显示数据形状 x = init_conv(x) t=new_t x.shape,t.shape
((2, 20, 32, 32), (1, 128))
BL1(x, t)
- Tensor(shape=[2, 32, 32, 32], dtype=Float32, value= [[[[ 4.79678512e-01, 6.18093789e-01, 4.11911160e-01 ... 2.45554969e-01, 2.87007272e-01, 5.84303178e-02], [ 3.35714668e-01, 5.81144929e-01, 3.02089810e-01 ... 7.34893382e-02, -2.16056317e-01, -2.40196183e-01], [ 6.84010506e-01, 1.10967433e+00, 7.40820885e-01 ... 4.42677915e-01, -5.60586452e-02, -1.65627971e-01], ... [ 7.63499856e-01, 1.13718486e+00, 9.87868309e-01 ... 6.83883190e-01, 1.12375900e-01, -3.52420285e-02], [ 7.06175983e-01, 1.14337587e+00, 9.23162937e-01 ... 6.31385088e-01, 1.49176538e-01, -3.98113094e-02], [ 1.76896825e-01, 4.82230633e-01, 4.89957243e-01 ... 4.00457889e-01, -2.55417023e-02, 1.21514685e-01]], [[ 2.81717628e-01, -9.01857391e-04, -8.76490697e-02 ... -7.88121223e-02, -1.26668364e-01, -1.73759460e-01], [ 1.92831263e-01, -1.43926665e-01, -1.48099199e-01 ... -2.68793881e-01, -1.00307748e-01, -1.20211102e-01], [ 1.39493421e-01, 2.17211112e-01, 1.45897210e-01 ... 1.56540543e-01, 1.65525198e-01, 5.83062395e-02], ... [ 6.92536831e-02, 5.68503514e-02, 3.68858390e-02 ... 8.93135369e-02, 1.07637540e-01, 4.47027944e-02], [ 1.12979397e-01, 2.12710798e-01, 5.37276417e-02 ... 1.01731792e-01, -4.49074507e-02, -2.13617496e-02], [ 7.63944685e-02, 1.35763273e-01, 1.16834700e-01 ... 1.85339376e-01, 1.34029865e-01, 2.19782546e-01]], [[ 1.30130202e-01, 3.29991281e-01, 4.25871283e-01 ... 3.08504313e-01, 4.61269379e-01, 1.57225683e-01], [ 4.05929983e-01, 7.34413862e-01, 8.77515614e-01 ... 7.71579146e-01, 9.44401443e-01, 5.34572124e-01], [-3.62766758e-02, 3.63330275e-01, 4.08758491e-01 ... 3.54813248e-01, 5.34208298e-01, 2.87856668e-01], ... [-1.26003683e-01, 2.56464094e-01, 3.78679991e-01 ... 4.59082156e-01, 5.85478425e-01, 2.87693620e-01], [-4.88909632e-02, 2.99566031e-01, 3.99350137e-01 ... 4.61422205e-01, 4.17674005e-01, 7.26606846e-02], [-2.39267662e-01, -3.09484214e-01, -2.71493912e-01 ... -7.24366307e-02, -1.12498447e-01, -1.38472736e-01]], ... [[-1.59212112e-01, 7.95098245e-02, -1.85586754e-02 ... -2.23550811e-01, -2.70033002e-01, -2.44036630e-01], [-1.60980105e-01, 4.35432374e-01, 5.99099815e-01 ... 4.49325353e-01, 4.23938036e-01, 3.12254220e-01], [-3.05827290e-01, 7.60348588e-02, 2.39996284e-01 ... 2.72248350e-02, -1.65684037e-02, -1.06293596e-01], ... [-2.98552692e-01, 5.39370701e-02, 2.43080735e-01 ... 1.28992423e-01, 6.57526404e-02, -7.48077184e-02], [-2.99177408e-01, -1.83835357e-01, -3.95203456e-02 ... -4.43453342e-02, -1.39597371e-01, -2.18513533e-01], [ 1.12659251e-02, 2.71181725e-02, 8.25900063e-02 ... 1.92151055e-01, 2.09751755e-01, 4.28373404e-02]], [[-1.12641305e-01, 2.13623658e-01, 7.53313601e-02 ... -1.21324155e-02, -1.53158829e-01, -4.77597833e-01], [-3.84328097e-01, 6.15204051e-02, 1.25263743e-02 ... -7.10279793e-02, -2.77535737e-01, -3.76104653e-01], [-4.07613993e-01, 1.12187594e-01, 3.72527242e-02 ... -2.06387043e-02, -1.46990225e-01, -2.87585199e-01], ... [-3.61820847e-01, 9.99522805e-02, -9.61808674e-03 ... -4.89163473e-02, -1.65467933e-01, -3.17837149e-01], [-6.31257832e-01, -2.93515027e-01, -3.12220454e-01 ... -1.29600003e-01, -1.40924498e-01, -1.52635127e-01], [-5.44437885e-01, -2.92856932e-01, -2.64693975e-01 ... -1.66876107e-01, -1.02364674e-01, 1.51942633e-02]], [[-4.31133687e-01, -6.47886336e-01, -8.31129193e-01 ... -8.14953923e-01, -8.53494108e-01, -5.33654928e-01], [-1.00464332e+00, -1.09428477e+00, -1.38921976e+00 ... -1.53568864e+00, -1.31930661e+00, -4.56116676e-01], [-1.13990593e+00, -1.03481460e+00, -1.49022770e+00 ... -1.62232399e+00, -1.40410137e+00, -5.24185598e-01], ... [-1.26637757e+00, -1.16348124e+00, -1.51403701e+00 ... -1.62057757e+00, -1.49686539e+00, -5.62209725e-01], [-1.11411476e+00, -9.28978443e-01, -1.17068827e+00 ... -1.24776149e+00, -1.03560197e+00, -3.00330132e-01], [-8.52251232e-01, -7.45468676e-01, -9.34467912e-01 ... -1.00492966e+00, -8.28293085e-01, -2.36020580e-01]]], [[[ 4.66426373e-01, 6.25464499e-01, 3.93321365e-01 ... 2.30096430e-01, 3.03178400e-01, 5.51086180e-02], [ 3.32810491e-01, 6.17641032e-01, 3.06311995e-01 ... 1.02875866e-01, -1.90033853e-01, -2.67078996e-01], [ 6.84333742e-01, 1.09874713e+00, 7.69674480e-01 ... 4.62564647e-01, -6.67065307e-02, -2.36097589e-01], ... [ 7.33127177e-01, 1.17721725e+00, 9.89053011e-01 ... 7.15648472e-01, 1.17136240e-01, -4.71793935e-02], [ 6.87817514e-01, 1.09633350e+00, 8.85757685e-01 ... 5.86604059e-01, 9.45525989e-02, -3.36224921e-02], [ 1.95047542e-01, 4.68445808e-01, 4.64000225e-01 ... 3.75145793e-01, 1.80484354e-03, 1.13696203e-01]], [[ 2.65101492e-01, 2.46687196e-02, -1.07584536e-01 ... -1.03970490e-01, -8.17846432e-02, -1.53097644e-01], [ 1.52385280e-01, -8.83764774e-02, -1.62100300e-01 ... -2.18001902e-01, -9.41001922e-02, -1.19305871e-01], [ 1.61403462e-01, 2.30408147e-01, 1.57331020e-01 ... 1.96940184e-01, 1.30461589e-01, 6.52605519e-02], ... [ 5.33300415e-02, 1.22396260e-01, -1.88096687e-02 ... 1.05915010e-01, 1.53571054e-01, 2.45359484e-02], [ 9.76279452e-02, 1.82655305e-01, 1.09691672e-01 ... 1.40925452e-01, 8.01324844e-04, 1.88996084e-03], [ 7.78941736e-02, 1.48156703e-01, 1.39126211e-01 ... 2.34847367e-01, 1.08238310e-01, 2.09336147e-01]], [[ 1.17656320e-01, 3.43433738e-01, 4.39827234e-01 ... 3.09850901e-01, 4.53984410e-01, 1.49862975e-01], [ 3.95844698e-01, 7.22729802e-01, 8.56524229e-01 ... 8.03788304e-01, 9.29986835e-01, 5.35356879e-01], [-5.63383885e-02, 3.74548256e-01, 3.96855712e-01 ... 3.82491171e-01, 5.69522381e-01, 3.37262630e-01], ... [-1.52335018e-01, 2.06072912e-01, 3.62504959e-01 ... 4.47174758e-01, 5.80285132e-01, 2.62391001e-01], [-3.93561423e-02, 3.46813500e-01, 3.57039988e-01 ... 4.35864031e-01, 4.56840277e-01, 7.46745914e-02], [-2.53383636e-01, -2.85494059e-01, -2.50772089e-01 ... -1.19937055e-01, -9.26215500e-02, -1.42144099e-01]], ... [[-1.72060549e-01, 7.35142902e-02, 1.05164722e-02 ... -2.21164092e-01, -2.66059518e-01, -2.46203467e-01], [-1.51937515e-01, 4.78927851e-01, 5.69894075e-01 ... 4.43681806e-01, 4.60492224e-01, 2.69292653e-01], [-3.13104421e-01, 1.40309155e-01, 2.25837916e-01 ... 3.81425694e-02, 7.96409026e-02, -1.00592285e-01], ... [-3.12582165e-01, 3.55643854e-02, 1.99504092e-01 ... 1.73697099e-01, 5.84969185e-02, -7.21051544e-02], [-2.97579527e-01, -1.40844122e-01, -5.89616746e-02 ... -2.86213737e-02, -1.22039340e-01, -2.13227138e-01], [ 3.26608960e-03, 4.80151782e-03, 5.54511398e-02 ... 1.92409024e-01, 1.99357480e-01, 3.34331095e-02]], [[-1.12870112e-01, 2.09549189e-01, 9.65655223e-02 ... -4.27889600e-02, -1.44986391e-01, -4.36559677e-01], [-3.58288437e-01, 9.11962241e-02, -8.71371478e-04 ... -4.59572896e-02, -2.30747938e-01, -4.01585191e-01], [-3.93885374e-01, 1.43090501e-01, -1.07427724e-02 ... -2.74238884e-02, -1.51127338e-01, -3.17271531e-01], ... [-3.39975446e-01, 6.32377267e-02, -4.78150323e-02 ... -9.10668075e-02, -1.39780402e-01, -2.90815294e-01], [-6.53750360e-01, -2.34141201e-01, -2.83103675e-01 ... -9.91634205e-02, -1.61574319e-01, -1.63588241e-01], [-5.44618607e-01, -2.84289837e-01, -2.75803030e-01 ... -1.71445966e-01, -1.29518926e-01, 9.64298844e-04]], [[-4.26712006e-01, -6.10979497e-01, -8.42772007e-01 ... -8.18627119e-01, -8.19367886e-01, -5.47874212e-01], [-1.00363958e+00, -1.09864676e+00, -1.45736146e+00 ... -1.57554007e+00, -1.27784061e+00, -5.02054691e-01], [-1.13520634e+00, -1.06120992e+00, -1.46519005e+00 ... -1.58347833e+00, -1.37640107e+00, -5.30683100e-01], ... [-1.25580835e+00, -1.18095815e+00, -1.52926147e+00 ... -1.63596940e+00, -1.44870484e+00, -5.40451765e-01], [-1.12258959e+00, -9.22193348e-01, -1.14803100e+00 ... -1.24472821e+00, -1.02205288e+00, -3.08310509e-01], [-8.60808074e-01, -7.58719683e-01, -9.61336493e-01 ... -9.92724419e-01, -8.70862603e-01, -2.35298872e-01]]]])
x = ms.Tensor(shape=(batch_size, channels, image_side_length, image_side_length), dtype=ms.float32, init=Normal()) x.shape # 显示数据形状 x = init_conv(x) t=new_t x.shape,t.shape ##需要放到一个格子里面才能运算成功 for block1, block2, attn, downsample in downs: ###事实循环3次,每次有这四个变量 x = block1(x, t) x = block2(x, t) x = attn(x) h.append(x) x = downsample(x)
/
x.shape
(2, 128, 8, 8)
len(h)
3
h ##[x0=2 32 32 32 x1=2 64 16 16 x3=2 128 8 8]
[Tensor(shape=[2, 32, 32, 32], dtype=Float32, value= Tensor(shape=[2, 64, 16, 16], dtype=Float32, value= Tensor(shape=[2, 128, 8, 8], dtype=Float32, value=
len_h = len(h) - 1 len_h
2
h[len_h].shape,x.shape ##最后的一个downsample的维度没有变化
((2, 128, 8, 8), (2, 128, 8, 8))
ops.concat((x, h[len_h]), 1)
- Tensor(shape=[2, 256, 8, 8], dtype=Float32,
这段代码使用了MindSpore框架中的ops.concat函数来执行张量拼接操作。下面是对该代码片段的详细解析:
-
ops.concat: 这是MindSpore操作库中的一个函数,用于沿着指定维度拼接一个张量列表。这里的"ops"是MindSpore中操作(operations)的简写,用于访问各种数学和数组操作。
-
(x, h[len_h]): 这是一个包含两个张量的元组,它们是要被拼接的输入。其中:
- x 是一个张量。
- h[len_h] 表示从列表或数组h中获取索引为len_h的元素。这通常意味着取h的最后一个元素,如果len_h是h的长度的话。不过,确切的行为依据len_h的具体值而定,如果len_h是动态计算的结果或者代表序列的长度,则它可能不是简单地指最后一个元素,而是某个特定位置的元素。
-
, 1): 这个参数指定了拼接操作应该沿着第1个维度进行。在MindSpore和其他类似的深度学习库中,维度计数通常从0开始,所以1表示第二个维度。这意味着x和h[len_h]将在它们的第二个维度上被连接起来,生成一个新的张量,其中这两个输入张量的相应列被串联在一起。
综上所述,这段代码的作用是将张量x和序列h中的最后一个元素(或索引为len_h的元素)在第二个维度上进行拼接,从而生成一个新的张量。这样的操作常见于循环神经网络(RNNs)等模型中,用于更新隐藏状态或组合不同来源的信息。
x = mid_block1(x, t) x = mid_attn(x) x = mid_block2(x, t) x.shape
\ (2, 128, 8, 8)
len_h = len(h) - 1 len_h
2
i=0 for block1, block2, attn, upsample in ups: ##ups 只有2个元素 i=i+1 print("--------",i) print("BL1块1:",block1,"BL2块2:", block2, "ATTT残差注意力:",attn, "UP上采样:",upsample) # block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim), # block_klass(dim_in, dim_in, time_emb_dim=time_dim), # Residual(PreNorm(dim_in, LinearAttention(dim_in))), # Upsample(dim_in) if not is_last else nn.Identity(), ### 第一层BL1 #(res_conv): Conv2d (norm): GroupNorm > > UP上采样: Conv2dTranspose -------- 2 BL1块1: ConvNextBlock (ds_conv): Conv2d (net): SequentialCell (res_conv): Conv2d > BL2块2: ConvNextBlock (ds_conv): Conv2d (net): SequentialCell (res_conv): Identity > ATTT残差注意力: Residual > (norm): GroupNorm > > UP上采样: Conv2dTranspose
class ConvNextBlock(nn.Cell): def __init__(self, dim=256, dim_out=64, *, time_emb_dim=128, mult=2, norm=True): super().__init__() self.mlp = ( nn.SequentialCell(nn.GELU(), nn.Dense(time_emb_dim, dim)) if exists(time_emb_dim) else None ) self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, group=dim, pad_mode="pad") self.net = nn.SequentialCell( nn.GroupNorm(1, dim) if norm else nn.Identity(), nn.Conv2d(dim, dim_out * mult, 3, padding=1, pad_mode="pad"), nn.GELU(), nn.GroupNorm(1, dim_out * mult), nn.Conv2d(dim_out * mult, dim_out, 3, padding=1, pad_mode="pad"), ) self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() def construct(self, x, time_emb=None): h = self.ds_conv(x) if exists(self.mlp) and exists(time_emb): assert exists(time_emb), "time embedding must be passed in" condition = self.mlp(time_emb) condition = condition.expand_dims(-1).expand_dims(-1) h = h + condition h = self.net(h) return h + self.res_conv(x)
BL1=ConvNextBlock() BL1
ConvNextBlock (ds_conv): Conv2d (net): SequentialCell (res_conv): Conv2d >
for block1, block2, attn, upsample in ups: x = ops.concat((x, h[len_h]), 1) len_h -= 1 x = block1(x, t) x = block2(x, t) x = attn(x) x = upsample(x) x.shape
(2, 32, 32, 32)
rx=final_conv(x) rx.shape
(2, 3, 32, 32)
def construct(self, x, time): x = self.init_conv(x) t = self.time_mlp(time) if exists(self.time_mlp) else None h = [] for block1, block2, attn, downsample in self.downs: x = block1(x, t) x = block2(x, t) x = attn(x) h.append(x) x = downsample(x) x = self.mid_block1(x, t) x = self.mid_attn(x) x = self.mid_block2(x, t) len_h = len(h) - 1 for block1, block2, attn, upsample in self.ups: ###因为up只有2个元素 down 有3个元素 但是我们在这里只是循环2次 并没有取出h[0] 就是downn 最开始的那个(20,32) x = ops.concat((x, h[len_h]), 1) ##这步是有啥作用?就是传说的skip connect 或者所谓的残差? len_h -= 1 x = block1(x, t) x = block2(x, t) x = attn(x) x = upsample(x) return self.final_conv(x)
这段代码定义了一个基于U-Net架构的模型,主要应用于图像处理、图像生成或分割任务中,特别是在需要保留细节信息同时捕捉上下文特征的场景下。此模型通过编码器-解码器结构,结合跳跃连接(skip connections)来实现这一点。下面是逐部分的解析:
初始化与时间嵌入
- self.init_conv(x):对输入x应用初始卷积层,开始特征提取。
- t = self.time_mlp(time):如果模型设计中包含时间相关的处理(常用于时序数据或在生成模型中引入时间条件),则通过多层感知机(MLP)处理时间信号time,得到时间嵌入t。
编码器路径(Downsampling)
- 循环遍历self.downs中的模块(每个模块包含两个卷积块block1, block2、一个注意力模块attn和一个下采样模块downsample):
- 两个卷积块分别应用特征变换,并可选择性地结合时间嵌入t。
- 应用注意力机制模块attn增强特征表示。
- 将当前特征图x添加到列表h中作为跳跃连接的存储。
- 使用下采样模块减小空间尺寸,增加深度。
中间块(Bottleneck)
- 应用一系列中间层变换,包括两个卷积块和一个注意力模块,进一步提炼特征。
解码器路径(Upsampling)
- 反向循环遍历self.ups中的模块,与编码器部分相对应,但包含上采样操作:
- ops.concat((x, h[len_h]), 1):这是关键的跳跃连接步骤,将当前解码器层的输出x与对应编码器层的特征图h[len_h]沿通道维度(维度1)拼接,从而传递并合并局部细节信息。
- 减少len_h以在下一轮迭代中获取上一层的特征图。
- 继续应用两个卷积块和注意力模块,以及上采样操作,逐步增加空间尺寸并整合信息。
输出
- 最后,通过self.final_conv(x)应用最终的卷积层,生成输出特征图或直接预测像素级结果。
跳跃连接(Skip Connections)的作用
- 跳跃连接(在这里体现为特征图的拼接)有助于解决梯度消失问题,使得网络能更有效地学习细节信息。
- 它允许低级特征(保留了更多细节信息)与高级特征(提供了更多上下文信息)在解码阶段融合,这对于恢复输入的精细结构至关重要,尤其是在图像生成和分割任务中。
- 因此,这种设计不仅有助于保持对输入细节的精确重构,还能促进生成内容的高保真度和清晰度。
final_conv ##恢复到3个channel
SequentialCell (res_conv): Identity > (1): Conv2d >
- 最后,通过self.final_conv(x)应用最终的卷积层,生成输出特征图或直接预测像素级结果。
- 反向循环遍历self.ups中的模块,与编码器部分相对应,但包含上采样操作:
- 应用一系列中间层变换,包括两个卷积块和一个注意力模块,进一步提炼特征。
- 循环遍历self.downs中的模块(每个模块包含两个卷积块block1, block2、一个注意力模块attn和一个下采样模块downsample):
-