#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
#pragma once
#include <ATen/core/Tensor.h>
#include <c10/util/irange.h>

namespace at::native {
// Calls `func` once for every 1-D slice of `self` taken along dimension
// `dim`, together with the matching slices of `values` and `indices`.
//
// Preconditions: input tensors are non-zero dim and non-empty.
// NOTE(review): the walk over the remaining dimensions is bounded by
// `self.size(...)` only, so `values` and `indices` are presumably expected to
// have the same sizes as `self` -- confirm against callers.
//
// Template parameters:
//   T1       element type used to access `self` and `values`.
//   T2       element type used to access `indices`.
//   Function callable invoked as
//              func(self_ptr, values_ptr, indices_ptr,
//                   dim_size, self_stride, values_stride, indices_stride)
//            where the pointers address the first element of the current
//            slice, `dim_size` is `self.size(dim)`, and the strides are the
//            strides of the respective tensors along `dim`. The size and the
//            strides are passed as `int64_t`.
//
// Parameters:
//   self     tensor that is only read through a const pointer here.
//   values   tensor handed to `func` through a mutable `T1*`.
//   indices  tensor handed to `func` through a mutable `T2*`.
//   dim      dimension the slices run along. A negative value is interpreted
//            relative to the number of dimensions (dim + self.dim()).
//   func     the callable described above.
template <typename T1, typename T2, typename Function>
void tensor_dim_apply3(
    const Tensor& self,
    Tensor& values,
    Tensor& indices,
    int64_t dim,
    Function func) {
  // Keep sizes and dimension counts in 64 bits: narrowing them to `int` would
  // truncate dimensions longer than INT_MAX before `func` ever sees them.
  const int64_t ndims = self.dim();
  std::vector<int64_t> counter(ndims, 0);
  const T1* self_data = self.const_data_ptr<T1>();
  T1* values_data = values.data_ptr<T1>();
  T2* indices_data = indices.data_ptr<T2>();
  const int64_t self_stride = self.stride(dim);
  const int64_t values_stride = values.stride(dim);
  const int64_t indices_stride = indices.stride(dim);
  const int64_t self_dim_size = self.size(dim);

  // The loop below compares loop indices in [0, ndims) against the applied
  // dimension, so a negative `dim` has to be normalized first; otherwise the
  // applied dimension would never be skipped while stepping.
  const int64_t apply_dim = dim < 0 ? dim + ndims : dim;

  bool finished = false;
  while (!finished) {
    func(
        self_data,
        values_data,
        indices_data,
        self_dim_size,
        self_stride,
        values_stride,
        indices_stride);
    if (ndims == 1) {
      // A 1-D tensor consists of exactly one slice.
      break;
    }

    // Advance to the next slice: treat `counter` as a multi-digit counter
    // over every dimension except `apply_dim`, with dimension 0 as the
    // fastest-moving digit. The data pointers are moved in lock-step.
    for (const auto dim_i : c10::irange(ndims)) {
      const bool is_last_dim = (dim_i == ndims - 1);

      if (dim_i == apply_dim) {
        // The applied dimension is never stepped. Reaching it as the last
        // dimension means every other digit has just wrapped around.
        if (is_last_dim) {
          finished = true;
          break;
        }
        continue;
      }

      const int64_t self_step = self.stride(dim_i);
      const int64_t values_step = values.stride(dim_i);
      const int64_t indices_step = indices.stride(dim_i);

      ++counter[dim_i];
      self_data += self_step;
      values_data += values_step;
      indices_data += indices_step;

      if (counter[dim_i] != self.size(dim_i)) {
        // No carry needed: the pointers now address the next slice.
        break;
      }
      if (is_last_dim) {
        // The most significant digit overflowed: all slices were visited.
        finished = true;
        break;
      }

      // Carry: rewind this dimension to its start and let the next
      // dimension take the increment.
      self_data -= counter[dim_i] * self_step;
      values_data -= counter[dim_i] * values_step;
      indices_data -= counter[dim_i] * indices_step;
      counter[dim_i] = 0;
    }
  }
}
} // namespace at::native

#else
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
