squeeze(arg)
表示若第arg维的维度值为1,则去掉该维度,否则tensor不变。(即若tensor.shape()[arg] == 1,则去掉该维度)
例如:
一个维度为2x1x2x1x2的tensor,不用去想它长什么样儿,squeeze(0)就是不变,squeeze(1)就是变成2x2x1x2。(0是从最左边的维度算起的)
1 | 2, 1, 2, 1, 2) x = torch.zeros( |
unsqueeze(arg)
与squeeze(arg)函数作用相反,表示在第arg维增加一个维度为1的维度。
啥意思呢?
比如一个tensor的shape为3x3,那么unsqueeze(0)就是变成1x3x3,unsqueeze(1)就是变成3x1x3.
再如下面这个官方的例子,得看好几眼才能看明白怎么回事。
其实可以这样理解:x的shape为:4,unsqueeze(0)就是把shape变成1x4;unsqueeze(1)就是把shape变成4x1。
1 | 1, 2, 3, 4]) x = torch.tensor([ |
参考:
[1] https://pytorch.org/docs/1.11/generated/torch.unsqueeze.html#torch.unsqueeze
[2] https://www.cnblogs.com/sbj123456789/p/9231571.html