torch.gather()



Torch Gather



1. torch.gather()

torch.gather(input, dim, index, out=None, sparse_grad=False) → Tensor

公式ウェブサイトでの説明は、より抽象的な、特定の軸に沿って値を集約することです。例に直接行きましょう。

x = torch.tensor([[1,2,3], [4,5,6]]) y = torch.tensor([[0, 1],[1,2]]) print(x) print(y) x.gather(1,y)

画像
dim = 1であるため、値は列の方向に取得されます。 yの最初の行は0と1なので、xの最初の行の列0と列1の要素を取得し、yの2番目の行は1と2を取得してから、2番目の最初の列と列2を取得します。 x要素の行。これは新しいテンソルを取得します。
dim = 0の場合の状況を見てみましょう。



x = torch.tensor([[1, 2, 3], [4, 5, 6]]) y = torch.tensor([[0, 1, 1]]) print(x) print(y) x.gather(0, y)

画像
出力を読んだ後、誰もが言うまでもなくそれを理解していると思います。 yの値は0、1、1であるため、xの各列の0、1、および1行の要素を取得できます。