

PyTorch torch.index_reduce Function

torch.index_reduce is a PyTorch function that aggregates values from a source tensor into selected positions of an input tensor. Along dimension dim, slice i of source is combined into the slice of input at position index[i] using the reduction named by reduce. When several entries of index point to the same position, all of the corresponding source slices are reduced together into that position.
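To make the description above concrete, here is a slow, loop-based sketch of the same semantics. The helper index_reduce_reference is hypothetical (it is not part of PyTorch and is not how PyTorch implements the operation); it groups the source slices by their target position, optionally adds the original input slice, and reduces each group. The comparison at the end checks it against torch.index_reduce on random data.

```python
import torch

def index_reduce_reference(x, dim, index, source, reduce, include_self=True):
    """Hypothetical, illustrative re-implementation of torch.index_reduce."""
    # Move `dim` to the front so slices along it can be addressed as out[i].
    out = x.movedim(dim, 0).clone()
    src = source.movedim(dim, 0)
    # Collect every source slice under the position it is sent to.
    groups = {}
    for i, target in enumerate(index.tolist()):
        groups.setdefault(target, []).append(src[i])
    reducers = {
        "prod": lambda t: t.prod(dim=0),
        "mean": lambda t: t.mean(dim=0),
        "amax": lambda t: t.amax(dim=0),
        "amin": lambda t: t.amin(dim=0),
    }
    for target, slices in groups.items():
        # include_self decides whether the original slice joins the group.
        values = ([out[target]] if include_self else []) + slices
        out[target] = reducers[reduce](torch.stack(values))
    # Positions never named in `index` keep their original values.
    return out.movedim(0, dim)

# Compare the sketch with PyTorch along dim=1 for every reduction.
torch.manual_seed(0)
x = torch.randn(3, 4)
idx = torch.tensor([0, 3, 3, 1, 0])  # column 2 is never targeted
src = torch.randn(3, 5)
for mode in ("prod", "mean", "amax", "amin"):
    for keep in (True, False):
        ref = index_reduce_reference(x, 1, idx, src, mode, include_self=keep)
        lib = torch.index_reduce(x, 1, idx, src, mode, include_self=keep)
        print(mode, keep, torch.allclose(ref, lib))
```

The grouping step is what handles duplicate indices: every source slice aimed at the same position ends up in one list and is reduced at once.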

Function Definition

torch.index_reduce(input, dim, index, source, reduce, *, include_self=True)
Parameters:

  • input (Tensor): The input tensor.
  • dim (int): The dimension along which to index.
  • index (Tensor): A 1-D integer tensor; index[i] is the position along dim in input that receives slice i of source.
  • source (Tensor): Tensor containing the values to aggregate. Its size along dim must equal the length of index, and all of its other dimensions must match input.
  • reduce (str): The reduction to apply: 'prod', 'mean', 'amax', or 'amin'. This argument is required and has no default.
  • include_self (bool, optional): Whether the original values of input at the indexed positions take part in the reduction. Default: True.
Returns:

  • torch.Tensor: A new tensor with the same shape as input, containing the aggregated values.
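All of the examples below reduce along dim=0. As a small sketch of the shape rule for source when dim=1, assuming hand-picked values so the result can be checked by eye: source needs one column per entry of index, and its remaining dimension must match input.

```python
import torch

x = torch.zeros(2, 3)
index = torch.tensor([0, 0, 2, 2])           # source columns 0,1 -> column 0; 2,3 -> column 2
source = torch.tensor([[1.0, 5.0, 2.0, 7.0],
                       [3.0, 0.0, 4.0, 1.0]])
dim = 1

# Shape rule: source.size(dim) == len(index); every other dimension matches x.
assert source.size(dim) == index.numel()
assert source.shape[:dim] + source.shape[dim + 1:] == x.shape[:dim] + x.shape[dim + 1:]

output = torch.index_reduce(x, dim, index, source, 'amax', include_self=False)
print(output)
# tensor([[5., 0., 7.],
#         [3., 0., 4.]])
```

Column 1 is never named in index, so it keeps the zeros from x.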

Examples


Example

import torch

# Create the input tensor (deterministic values so the output is reproducible)
input = torch.arange(9, dtype=torch.float32).reshape(3, 3)

# Create the index and source: source needs one row per entry of index,
# and each row must have the same width as input
index = torch.tensor([0, 0, 0])
source = torch.tensor([[1.0, 1.0, 1.0],
                       [2.0, 2.0, 2.0],
                       [3.0, 3.0, 3.0]])

# Use mean aggregation: all three source rows are reduced into row 0
output = torch.index_reduce(input, dim=0, index=index, source=source, reduce='mean')

print("Input:")
print(input)
print("\nIndex:", index)
print("Source:", source)
print("\nmean Aggregated Result:")
print(output)

Output:

Input:
tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]])

Index: tensor([0, 0, 0])
Source: tensor([[1., 1., 1.],
        [2., 2., 2.],
        [3., 3., 3.]])

mean Aggregated Result:
tensor([[1.5000, 1.7500, 2.0000],
        [3.0000, 4.0000, 5.0000],
        [6.0000, 7.0000, 8.0000]])

Row 0 is the mean of the original row and the three source rows, for example (0 + 1 + 2 + 3) / 4 = 1.5 in the first column. Rows 1 and 2 are not referenced by index and are unchanged.

Example

import torch

# Test different aggregation methods
input = torch.ones(3)
index = torch.tensor([0, 0, 0])
source = torch.tensor([2.0, 4.0, 6.0])

# prod aggregation
output_prod = torch.index_reduce(input, 0, index, source, reduce='prod')
print("prod Aggregation:", output_prod)

# amax aggregation
output_max = torch.index_reduce(input, 0, index, source, reduce='amax')
print("amax Aggregation:", output_max)

# amin aggregation
output_min = torch.index_reduce(input, 0, index, source, reduce='amin')
print("amin Aggregation:", output_min)

Output:

prod Aggregation: tensor([48.,  1.,  1.])
amax Aggregation: tensor([6., 1., 1.])
amin Aggregation: tensor([1., 1., 1.])

Because include_self defaults to True, the original 1. at position 0 takes part in each reduction: prod gives 1 * 2 * 4 * 6 = 48, amax gives 6, and amin gives 1.

Example

import torch

# include_self parameter
input = torch.tensor([1.0, 10.0, 100.0])
index = torch.tensor([0, 0])
source = torch.tensor([5.0, 5.0])

# Include self values (default)
output1 = torch.index_reduce(input, 0, index, source, reduce='mean', include_self=True)
print("include_self=True:", output1)

# Exclude self values
output2 = torch.index_reduce(input, 0, index, source, reduce='mean', include_self=False)
print("include_self=False:", output2)

Output:

include_self=True: tensor([  3.6667,  10.0000, 100.0000])
include_self=False: tensor([  5.,  10., 100.])

With include_self=True, position 0 becomes (1 + 5 + 5) / 3 = 3.6667; with include_self=False, the original 1 is ignored and the result is (5 + 5) / 2 = 5.

Example

import torch

# Application to a 2D tensor
input = torch.zeros(3, 4)
index = torch.tensor([0, 2, 2])
source = torch.randn(3, 4)

# mean aggregation: source row 0 goes to row 0, source rows 1 and 2 go to row 2;
# row 1 is not referenced by index and keeps its zeros
output = torch.index_reduce(input, dim=0, index=index, source=source, reduce='mean')

print("Input Shape:", input.shape)
print("Index:", index)
print("Source Shape:", source.shape)
print("Result Shape:", output.shape)
print("\nResult:")
print(output)
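Because include_self defaults to True, a mean computed into a tensor of zeros, as in the 2D example, also averages in those zeros. The sketch below uses hypothetical feature rows and group labels (features, group) to compare both settings; with include_self=False each touched row holds the plain mean of its group, while group 1, which receives no rows, keeps its initial 0.

```python
import torch

# Hypothetical data: 5 feature rows, each assigned to one of 3 groups.
features = torch.tensor([[1.0, 2.0],
                         [3.0, 4.0],
                         [5.0, 6.0],
                         [7.0, 8.0],
                         [9.0, 10.0]])
group = torch.tensor([0, 2, 2, 0, 2])
base = torch.zeros(3, 2)

# Default include_self=True: the 0 in `base` counts as one extra value per group.
biased = torch.index_reduce(base, 0, group, features, 'mean')
# include_self=False: only the rows of `features` are averaged.
group_mean = torch.index_reduce(base, 0, group, features, 'mean', include_self=False)

print(biased)      # group 0: [8/3, 10/3], group 2: [17/4, 20/4]
print(group_mean)  # group 0: [4, 5],      group 2: [17/3, 20/3]
```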

Note: torch.index_reduce does not modify input; it returns a new tensor (the in-place variant is Tensor.index_reduce_). Several entries of index may point to the same position, in which case all of the corresponding source slices are combined with the chosen reduction, while positions that no entry of index points to keep their original values. include_self=False is useful when the original values should be ignored and only the values coming from source should be aggregated.
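A minimal sketch of the difference between the out-of-place function and the in-place Tensor.index_reduce_ method, reusing the values from the include_self example with reduce='amax':

```python
import torch

x = torch.tensor([1.0, 10.0, 100.0])
index = torch.tensor([0, 0])
source = torch.tensor([5.0, 5.0])

# Out-of-place: returns a new tensor and leaves x untouched.
out = torch.index_reduce(x, 0, index, source, 'amax')
x_unchanged = torch.equal(x, torch.tensor([1.0, 10.0, 100.0]))
print(x_unchanged)  # True
print(out)          # tensor([  5.,  10., 100.])

# In-place: Tensor.index_reduce_ (trailing underscore) writes into x itself.
x.index_reduce_(0, index, source, 'amax')
print(torch.equal(x, out))  # True
```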
