#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
/*
The Python Imaging Library (PIL) is

    Copyright © 1997-2011 by Secret Labs AB
    Copyright © 1995-2011 by Fredrik Lundh

Pillow is the friendly PIL fork. It is

    Copyright © 2010-2022 by Alex Clark and contributors

Like PIL, Pillow is licensed under the open source HPND License
*/

// This code is heavily inspired from PILLOW-SIMD's implementation:
// https://github.com/uploadcare/pillow-simd/blob/simd/master/src/libImaging/Resample.c

#pragma once
#ifdef CPU_CAPABILITY_AVX2
// TODO: This file only supports AVX2. We could split the AVX kernels into
// smaller logical blocks in order to port them into the Vec.h logic. This would
// allow to support other vectorization architectures and perhaps also support
// the non-vectorized fallback (we'd need to make sure it's not slower than the
// current fallback).

#include <ATen/core/Tensor.h>
#include <ATen/cpu/vec/intrinsics.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#endif


namespace {

inline __m128i mm_cvtsi32_si128(const uint8_t* C10_RESTRICT ptr, bool i32_aligned) {
  int32_t v;
  if (i32_aligned) {
    v = *(const int32_t*)ptr;
  } else {
    std::memcpy(&v, ptr, 4);
  }
  return _mm_cvtsi32_si128(v);
}

inline __m128i mm_cvtepu8_epi32(const uint8_t* C10_RESTRICT ptr, bool i32_aligned) {
  return _mm_cvtepu8_epi32(mm_cvtsi32_si128(ptr, i32_aligned));
}

inline void _write_endline_rgb_as_uint32(
    uint8_t* C10_RESTRICT output,
    uint32_t data
) {
  // data is (R G B X), output is (X1 X2 X3 | R1 B1 G1 R2 ...)
  // Here we explicitly set X as R1
  uint8_t* data_ptr = reinterpret_cast<uint8_t*>(&data);
  data_ptr[3] = output[3];
  std::memcpy(output, data_ptr, 4);
}

at::Tensor unpack_rgb(const at::Tensor& packed_tensor) {
  // Convert a "packed" tensor (typically RGBRGBRGB if channels_last) into
  // RGBARGBARGBA format where A is hard-coded to 0. Each pixel is encoded
  // into as 32 bits. This generalizes to num_channels <= 4 and also works for
  // non-channels_last tensors.

  const uint8_t* packed = (const uint8_t*)packed_tensor.const_data_ptr<uint8_t>();
  auto num_pixels = packed_tensor.size(1) * packed_tensor.size(2);
  auto num_channels = packed_tensor.size(0);

  constexpr int rgba_size = 4;
  auto unpacked_tensor = at::empty({rgba_size, packed_tensor.size(1), packed_tensor.size(2)}, at::CPU(at::kByte));
  uint8_t* unpacked = (uint8_t*) unpacked_tensor.data_ptr<uint8_t>();

  auto stride_i = packed_tensor.stride(2);
  auto stride_j = packed_tensor.stride(0);

  for (const auto i : c10::irange(num_pixels)) {
    for (const auto j : c10::irange(rgba_size)) {
      unpacked[rgba_size * i + j] = (j < num_channels) ? packed[stride_i * i + stride_j * j] : 0;
    }
  }
  return unpacked_tensor;
}

void pack_rgb(
    const at::Tensor& unpacked_tensor, // IN
    const at::Tensor& packed_tensor // OUT
) {
  // Convert from unpacked channels last 3-channels or 4-channels tensor into original data layout.

  uint8_t* unpacked = (uint8_t*)unpacked_tensor.data_ptr<uint8_t>();
  uint8_t* packed = (uint8_t*)packed_tensor.data_ptr<uint8_t>();
  auto num_pixels = packed_tensor.size(1) * packed_tensor.size(2);
  auto num_channels = packed_tensor.size(0);

  auto unpacked_increment = unpacked_tensor.size(0);
  auto packed_increment = packed_tensor.stride(2);
  auto packed_stride = packed_tensor.stride(0);

  TORCH_INTERNAL_ASSERT(unpacked_increment == 3 || unpacked_increment == 4);

  for ([[maybe_unused]] const auto i : c10::irange(num_pixels)) {
    for (const auto j : c10::irange(num_channels)) {
      packed[j * packed_stride] = unpacked[j];
    }
    unpacked += unpacked_increment;
    packed += packed_increment;
  }
}

void ImagingResampleHorizontalConvolution8u4x(
    uint8_t* C10_RESTRICT lineOut0,
    uint8_t* C10_RESTRICT lineOut1,
    uint8_t* C10_RESTRICT lineOut2,
    uint8_t* C10_RESTRICT lineOut3,
    int64_t out_xsize,
    const uint8_t* C10_RESTRICT lineIn0,
    const uint8_t* C10_RESTRICT lineIn1,
    const uint8_t* C10_RESTRICT lineIn2,
    const uint8_t* C10_RESTRICT lineIn3,
    int64_t in_xsize,
    const int64_t* idx_ptr_xmin,
    const int64_t* idx_ptr_size,
    const int16_t* kk,
    int kmax,
    unsigned int coefs_precision,
    int64_t num_channels,
    bool is_last_line);

void ImagingResampleHorizontalConvolution8u(
    uint8_t* C10_RESTRICT lineOut,
    int64_t out_xsize,
    const uint8_t* C10_RESTRICT lineIn,
    int64_t in_xsize,
    const int64_t* idx_ptr_xmin,
    const int64_t* idx_ptr_size,
    const int16_t* kk,
    int kmax,
    unsigned int coefs_precision,
    int64_t num_channels,
    bool is_last_line);

void ImagingResampleVerticalConvolution8u(
    uint8_t* C10_RESTRICT lineOut,
    const uint8_t* C10_RESTRICT lineIn,
    int64_t xsize,
    int64_t ids_min,
    int64_t ids_size,
    const int16_t* k,
    unsigned int coefs_precision,
    int64_t num_channels);

template<int num_channels>
void ImagingResampleHorizontal(
    const at::Tensor & unpacked_output,
    const at::Tensor & unpacked_input,
    int ksize,
    const std::vector<at::Tensor>& horiz_indices_weights,
    unsigned int horiz_weights_precision) {

  // Interpolation horizontal pass: we compute x-axis (image width) interpolation outputs.

  // Input data is stored as
  //   input = [r[0], g[0], b[0], a[0], r[1], g[1], b[1], a[1], r[2], g[2], b[2], a[2], ...]
  // Weights are float values computed for each output pixel and rescaled to uint16:
  //   weights[i] = [w[i, 0], w[i, 1], ..., w[i, K-1]]
  // We want to compute the output as following:
  //   output = [oR[0], oG[0], oB[0], oA[0], oR[1], oG[1], oB[1], oA[1], ...]
  // where
  //   oR[yoffset + i] = r[yoffset + xmin[i]] * w[i, 0] + ... + r[yoffset + xmin[i] + K-1] * w[i, K-1]
  //   oG[yoffset + i] = g[yoffset + xmin[i]] * w[i, 0] + ... + g[yoffset + xmin[i] + K-1] * w[i, K-1]
  //   oB[yoffset + i] = b[yoffset + xmin[i]] * w[i, 0] + ... + b[yoffset + xmin[i] + K-1] * w[i, K-1]
  //

  // TODO: we may want to merge that into the fallback code (currently called
  // basic_loop_aa_horizontal<uint8_t>)
  // Although this may not be needed if / when we port all this code to use
  // Vec.h since this would potentially give us another fall-back implem

  const int16_t* kk = (int16_t*)(horiz_indices_weights[3].const_data_ptr<double>());

  auto xout = unpacked_output.size(2);
  auto yout = unpacked_output.size(1);
  auto xin = unpacked_input.size(2);
  TORCH_INTERNAL_ASSERT(num_channels == unpacked_input.size(0));

  const int64_t* idx_ptr_xmin = horiz_indices_weights[0].const_data_ptr<int64_t>();
  const int64_t* idx_ptr_size = horiz_indices_weights[1].const_data_ptr<int64_t>();

  uint8_t* unpacked_output_p = unpacked_output.data_ptr<uint8_t>();
  const uint8_t* unpacked_input_p = unpacked_input.const_data_ptr<uint8_t>();

  int64_t yy = 0;
  auto xout_stride = xout * num_channels;
  auto xin_stride = xin * num_channels;
  for (; yy < yout - 3; yy += 4) {
    ImagingResampleHorizontalConvolution8u4x(
        unpacked_output_p + yy * xout_stride,
        unpacked_output_p + (yy + 1) * xout_stride,
        unpacked_output_p + (yy + 2) * xout_stride,
        unpacked_output_p + (yy + 3) * xout_stride,
        xout,
        unpacked_input_p + yy * xin_stride,
        unpacked_input_p + (yy + 1) * xin_stride,
        unpacked_input_p + (yy + 2) * xin_stride,
        unpacked_input_p + (yy + 3) * xin_stride,
        xin,
        idx_ptr_xmin,
        idx_ptr_size,
        kk,
        ksize,
        horiz_weights_precision,
        num_channels,
        yy + 3 == yout - 1);
  }
  for (; yy < yout; yy++) {
    ImagingResampleHorizontalConvolution8u(
        unpacked_output_p + yy * xout_stride,
        xout,
        unpacked_input_p + yy * xin_stride,
        xin,
        idx_ptr_xmin,
        idx_ptr_size,
        kk,
        ksize,
        horiz_weights_precision,
        num_channels,
        yy == yout - 1);
  }
}

void ImagingResampleVertical(
    const at::Tensor & unpacked_output,
    const at::Tensor & unpacked_input,
    int ksize,
    const std::vector<at::Tensor>& vert_indices_weights,
    unsigned int vert_weights_precision) {

  // Interpolation vertical pass: we compute y-axis interpolation outputs.
  // Input data is stored as
  //   input = [r[0], g[0], b[0], a[0], r[1], g[1], b[1], a[1], r[2], g[2], b[2], a[2], ...]
  // Weights are float values computed for each output pixel and rescaled to uint16:
  //   weights[i] = [w[i, 0], w[i, 1], ..., w[i, K-1]]
  // We want to compute the output as following:
  //   output = [oR[0], oG[0], oB[0], oA[0], oR[1], oG[1], oB[1], oA[1], ...]
  // where
  //   oR[xoffset + i] = r[xoffset + ymin[i]] * w[i, 0] + ... + r[xoffset + ymin[i] + (K-1) * xsize] * w[i, K-1]
  //   oG[xoffset + i] = g[xoffset + ymin[i]] * w[i, 0] + ... + g[xoffset + ymin[i] + (K-1) * xsize] * w[i, K-1]
  //   oB[xoffset + i] = b[xoffset + ymin[i]] * w[i, 0] + ... + b[xoffset + ymin[i] + (K-1) * xsize] * w[i, K-1]

  // TODO: we may want to merge that into the fallback code (currently called
  // basic_loop_aa_vertical<uint8_t>)
  // Although this may not be needed if / when we port all this code to use
  // Vec.h since this would potentially give us another fall-back implem
  const int16_t* kk = (int16_t*)(vert_indices_weights[3].const_data_ptr<double>());

  const int64_t* idx_ptr_xmin = vert_indices_weights[0].const_data_ptr<int64_t>();
  const int64_t* idx_ptr_size = vert_indices_weights[1].const_data_ptr<int64_t>();

  uint8_t* unpacked_output_p = unpacked_output.data_ptr<uint8_t>();
  const uint8_t* unpacked_input_p = unpacked_input.const_data_ptr<uint8_t>();

  auto xout = unpacked_output.size(2);
  auto yout = unpacked_output.size(1);
  const auto num_channels = unpacked_input.size(0);
  TORCH_INTERNAL_ASSERT(num_channels == unpacked_output.size(0));

  auto xout_stride = xout * num_channels;
  for (const auto yy : c10::irange(yout)) {
    const auto* k = &kk[yy * ksize];
    auto ids_min = idx_ptr_xmin[yy];
    auto ids_size = idx_ptr_size[yy];
    ImagingResampleVerticalConvolution8u(
        unpacked_output_p + yy * xout_stride,
        unpacked_input_p,
        xout,
        ids_min,
        ids_size,
        k,
        vert_weights_precision,
        num_channels);
  }
}

// This is the only public entry point in this file.  It supports bilinear or bicubic
// mode for uint8 dtype when C <= 4, with or without antialias. The
// implem is based on PIL-SIMD.
// Its equivalent implementation (fallback) for when AVX isn't supported or when
// C > 4 is separable_upsample_generic_Nd_kernel_impl()  There are a bunch of
// future improvement that can be done: look for the TODOs in this file.
// For details on how the weights are computed and how the multiplications are
// run on int (instead of float weights), see
// [ Weights computation for uint8_t and multiplication trick ]
// For details on how the AVX kernels are implemented, see
// https://gist.github.com/NicolasHug/47c97d731f05eaad5694c173849b86f5
// See also [ Support for antialias=False as a subcase of antialias=True ] to
// learn more about how the antialias=False case is computed. The same holds
// here: all these kernels are general enough to handle an arbitrary number of
// weights, but when aa=False they could be optimized further.
template <typename scale_type, class F>
void upsample_avx_bilinear_bicubic_uint8(
    const at::Tensor& input_,
    const at::Tensor& output,
    bool align_corners,
    const scale_type& scales,
    bool antialias) {
  auto batch_size = input_.size(0);
  auto num_channels = input_.size(1);
  auto xin = input_.size(3);
  auto yin = input_.size(2);
  auto xout = output.size(3);
  auto yout = output.size(2);

  if (xin == xout && yin == yout) {
    output.copy_(input_);
    return;
  }

  at::Tensor input = input_;
  if (!(input.is_contiguous() || input.is_contiguous(at::MemoryFormat::ChannelsLast))) {
    // If input is not contiguous with memory format channels first or channels last,
    // we explicitly convert the input to contiguous channels last memory format.
    // This simplifies the rest of the code and let us assume that the format is only contiguous channels first or channels last,
    // Most tensors going through this `if` block won't need to go through unpacking, but those having C < 3 may
    // have to (this means 2 copies are made). We could avoid the extra copy by handling non-contiguous input
    // directly within unpack_rgb() and pack_rgb(), but initial attempts showed that this is fairly complex.
    input = input.contiguous(at::MemoryFormat::ChannelsLast);
  }

  auto need_horizontal = xout != xin;
  auto need_vertical = yout != yin;

  int ksize_horiz, ksize_vert;
  std::vector<at::Tensor> horiz_indices_weights, vert_indices_weights;
  unsigned int horiz_weights_precision, vert_weights_precision;

  bool skip_unpacking = (num_channels == 3 || num_channels == 4) && input.is_contiguous(at::MemoryFormat::ChannelsLast);
  bool skip_packing = (num_channels == 3 || num_channels == 4) && output.is_contiguous(at::MemoryFormat::ChannelsLast);

  if (need_horizontal) {
    int interp_dim = 3;
    auto stride = skip_unpacking ? num_channels : 4;
    std::tie(horiz_indices_weights, ksize_horiz, horiz_weights_precision) =
        F::compute_index_ranges_int16_weights(
            /*input_size=*/xin,
            /*output_size=*/xout,
            /*stride=*/stride,
            /*ndims=*/4,
            /*reshape_dim=*/interp_dim,
            /*align_corners=*/align_corners,
            /*opt_scale=*/scales[interp_dim - 2],
            /*antialias=*/antialias,
            /*align_i32=*/true);
  }

  if (need_vertical) {
    int interp_dim = 2;
    auto stride = skip_unpacking ? num_channels * xout : 4 * xout;
    std::tie(vert_indices_weights, ksize_vert, vert_weights_precision) =
        F::compute_index_ranges_int16_weights(
            /*input_size=*/yin,
            /*output_size=*/yout,
            /*stride=*/stride,
            /*ndims=*/4,
            /*reshape_dim=*/interp_dim,
            /*align_corners=*/align_corners,
            /*opt_scale=*/scales[interp_dim - 2],
            /*antialias=*/antialias,
            /*align_i32=*/true);
  }

  at::Tensor buffer_horiz, buffer_vert;
  // Minor optimization: we can avoid allocating an extra buffer if we're performing
  // horizontal-only or vertical-only interpolation, and if the tensor doesn't
  // need repacking
  if (need_horizontal && (need_vertical || !skip_packing)) {
    auto c = skip_unpacking ? num_channels : 4;
    buffer_horiz = at::empty({c, yin, xout}, input.options());
  }
  if (need_vertical && !skip_packing) {
    auto c = skip_unpacking ? num_channels : 4;
    buffer_vert = at::empty({c, yout, xout}, input.options());
  }

  for (const auto i : c10::irange(batch_size)) {

    at::Tensor unpacked_input = skip_unpacking ? input[i] : unpack_rgb(input[i]);
    at::Tensor unpacked_output;

    if (need_horizontal) {
      at::Tensor unpacked_output_temp = (need_vertical || !skip_packing) ? buffer_horiz : output[i];

      if (skip_unpacking && num_channels == 3) {
        ImagingResampleHorizontal<3>(
          unpacked_output_temp,
          unpacked_input,
          ksize_horiz,
          horiz_indices_weights,
          horiz_weights_precision);
      } else {
        ImagingResampleHorizontal<4>(
            unpacked_output_temp,
            unpacked_input,
            ksize_horiz,
            horiz_indices_weights,
            horiz_weights_precision);
      }
      unpacked_output = unpacked_input = unpacked_output_temp;
    }
    if (need_vertical) {
      unpacked_output = skip_packing ? output[i] : buffer_vert;

      ImagingResampleVertical(
          unpacked_output,
          unpacked_input,
          ksize_vert,
          vert_indices_weights,
          vert_weights_precision
      );
    }

    TORCH_INTERNAL_ASSERT(unpacked_output.defined());

    if (!skip_packing) {
      pack_rgb(unpacked_output, output[i]);
    }
  }
}

void ImagingResampleHorizontalConvolution8u4x(
    uint8_t* C10_RESTRICT lineOut0,
    uint8_t* C10_RESTRICT lineOut1,
    uint8_t* C10_RESTRICT lineOut2,
    uint8_t* C10_RESTRICT lineOut3,
    int64_t out_xsize,
    const uint8_t* C10_RESTRICT lineIn0,
    const uint8_t* C10_RESTRICT lineIn1,
    const uint8_t* C10_RESTRICT lineIn2,
    const uint8_t* C10_RESTRICT lineIn3,
    int64_t in_xsize,
    const int64_t* idx_ptr_xmin,
    const int64_t* idx_ptr_size,
    const int16_t* kk,
    int kmax,
    unsigned int coefs_precision,
    int64_t num_channels,
    bool is_last_line) {

  // Interpolation horizontal pass processing together 4 vertical lines.
  // - Input data format is RGBA or RGB with R,G,B,A being uint8. In case of RGBA
  //   we can encode 4 values as a single uint32 value.
  // - We split the size of weight vector for a given output index as a sum:
  //   ids_size = num_blocks_4 * 4 + num_blocks_2 * 2 + num_blocks_1.
  // - We load and process 4 weights values in a loop ("block 4") then we process 2 weights values
  // in another loop ("block 2") and finally we process 1 weights value in the final loop ("block 1").

  // Define shuffling masks (low/high) for num_channels 4 and 3
  // Mask low casts lower half of each lane to epi16 and reorder RGBARGBA -> RRGGBBAA:
  //   [r1 g1 b1 a1  r2 g2 b2 a2  ... | R1 G1 B1 A1  R2 G2 B2 A2 ... ] ->
  //   [r1 0 r2 0  g1 0 g2 0  b1 0 b2 0  a1 0 a2 0 | R1 0 R2 0  G1 0 G2 0  B1 0 B2 0  A1 0 A2 0]
  // Mask high casts upper half of each lane to epi16 and reorder RGBARGBA -> RRGGBBAA::
  //   [ ... r3 g3 b3 a3  r4 g4 b4 a4 | ... R3 G3 B3 A3  R4 G4 B4 A4 ] ->
  //   [r3 0 r4 0  g3 0 g4 0  b3 0 b4 0  a3 0 a4 0 | R3 0 R4 0  G3 0 G4 0  B3 0 B4 0  A3 0 A4 0]

  const auto mask_low_c4 = _mm256_set_epi8(
      -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0,
      -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0);
  const auto mask_high_c4 = _mm256_set_epi8(
      -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8,
      -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8);
  const auto mask_low_c3 = _mm256_set_epi8(
      -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0,
      -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0);
  const auto mask_high_c3 = _mm256_set_epi8(
      -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6,
      -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6);

  const auto mask_low = (num_channels == 3) ? mask_low_c3 : mask_low_c4;
  const auto mask_high = (num_channels == 3) ? mask_high_c3 : mask_high_c4;

  const auto stride = num_channels * sizeof(uint8_t);

  TORCH_INTERNAL_ASSERT(stride == 3 || stride == 4);

  // out_xsize = output width, out_x = output x index
  // ids_min is the input offset index corresponding to out_x
  // ids_size is the interpolation size for out_x

  // Let's precompute ids_size limits for block 4 and block 2.
  //
  // In block 4 (4 means we process 4 weight values together), we read input data
  // with _mm_loadu_si128, i.e. 16 bytes, per one line:
  // lineIn0 + stride * (i + ids_min) + 16 <= lineIn0 + stride * (ids_size + ids_min)
  // --> i <= ids_size - 16.0 / stride
  // Strict boundary:
  // --> i < ids_size + 1 - int(ceil(16.0 / stride)) = ids_size - b4_delta
  // Soft boundary for reading inside the buffer except its boundaries:
  // --> i < ids_size + 1 - int(16.0 / stride) = ids_size - b4_delta_soft
  // RGBA: b4_delta = b4_delta_soft = 3
  // RGB : b4_delta = 5
  // RGB : b4_delta_soft = 4
  const auto b4_delta = (stride == 4) ? 3 : (is_last_line ? 5 : 4);

  // In block 2 (2 means we process 2 weights values together), we read input data
  // with _mm_loadl_epi64, i.e. 8 bytes, per one line:
  // lineIn0 + stride * (i + ids_min) + 8 <= lineIn0 + stride * (ids_size + ids_min)
  // --> i <= ids_size - 8.0 / stride
  // Strict boundary:
  // --> i < ids_size + 1 - int(ceil(8.0 / stride)) = ids_size - b2_delta
  // Soft boundary for reading inside the buffer except its boundaries:
  // --> i < ids_size + 1 - int(8.0 / stride) = ids_size - b2_delta_soft
  // RGBA: b2_delta = b2_delta_soft = 1
  // RGB : b2_delta = 2
  // RGB : b2_delta_soft = 1
  const auto b2_delta = (stride == 4) ? 1 : (is_last_line ? 2 : 1);

  const auto max_out_x_strided = out_xsize * stride;
  const auto max_in_x_strided = in_xsize * stride;

  const auto zero = _mm256_setzero_si256();
  const auto initial = _mm256_set1_epi32(1 << (coefs_precision - 1));

  for (const auto out_x : c10::irange(out_xsize)) {
    const auto ids_min = idx_ptr_xmin[out_x];
    const auto ids_size = idx_ptr_size[out_x];
    const auto * k = &kk[out_x * kmax];
    int64_t i = 0;

    auto sss0 = initial;
    auto sss1 = initial;

    const auto * lineIn0_min = lineIn0 + ids_min;
    const auto * lineIn1_min = lineIn1 + ids_min;
    const auto * lineIn2_min = lineIn2 + ids_min;
    const auto * lineIn3_min = lineIn3 + ids_min;

    // block 4
    for (; i < ids_size - b4_delta; i += 4) {
      // Load 4 values from weight vector
      // mmk0 = [wl_0 wh_0 wl_1 wh_1  wl_0 wh_0 wl_1 wh_1  ...]
      // mmk1 = [wl_2 wh_2 wl_3 wh_3  wl_2 wh_2 wl_3 wh_3  ...]
      const auto mmk0 = _mm256_set1_epi32(*(int32_t*)&k[i]);
      const auto mmk1 = _mm256_set1_epi32(*(int32_t*)&k[i + 2]);

      // RGBA: Load 8 pixels (4 per line) from input lines 0 and 1:
      // source = [
      //   r0 g0 b0 a0  r1 g1 b1 a1  r2 g2 b2 a2  r3 g3 b3 a3
      //   R0 G0 B0 A0  R1 G1 B1 A1  R2 G2 B2 A2  R3 G3 B3 A3
      // ]
      // RGB: Load 10 pixels (5 per line)
      // source = [
      //   r0 g0 b0 r1  g1 b1 r2 g2  b2 r3 g3 b3  r4 g4 b4 r5
      //   R0 G0 B0 R1  G1 B1 R2 G2  B2 R3 G3 B3  R4 G4 B4 R5
      // ]
      auto source = _mm256_inserti128_si256(_mm256_castsi128_si256(
          _mm_loadu_si128((__m128i *) (lineIn0_min + stride * i))),
          _mm_loadu_si128((__m128i *) (lineIn1_min + stride * i)), 1);

      // Apply mask_low:
      // RGBA:
      //   [r0 0 r1 0  g0 0 g1 0  b0 0 b1 0  a0 0 a1 0 | R0 0 R1 0  G0 0 G1 0  B0 0 B1 0  A0 0 A1 0]
      // RGB:
      //   [r0 0 r1 0  g0 0 g1 0  b0 0 b1 0  0 0 0 0 | R0 0 R1 0  G0 0 G1 0  B0 0 B1 0  0 0 0 0]
      auto pix1 = _mm256_shuffle_epi8(source, mask_low);
      // Compute output value as C += w0 * C0 + w1 * C1 for each channel in 32-bit precision
      sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk0));

      // Apply mask_high:
      // RGBA:
      //   [r2 0 r3 0  g2 0 g3 0  b2 0 b3 0  a2 0 a3 0 | R2 0 R3 0  G2 0 G3 0  B2 0 B3 0  A2 0 A3 0]
      // RGB:
      //   [r2 0 r3 0  g2 0 g3 0  b2 0 b3 0  0 0 0 0 | R2 0 R3 0  G2 0 G3 0  B2 0 B3 0  0 0 0 0]
      auto pix2 = _mm256_shuffle_epi8(source, mask_high);
      // Compute output value as C += w2 * C2 + w3 * C3 for each channel in 32-bit precision
      sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix2, mmk1));

      // Same as above to next lines 2 and 3:
      auto source2 = _mm256_inserti128_si256(_mm256_castsi128_si256(
          _mm_loadu_si128((__m128i *) (lineIn2_min + stride * i))),
          _mm_loadu_si128((__m128i *) (lineIn3_min + stride * i)), 1);
      auto pix3 = _mm256_shuffle_epi8(source2, mask_low);
      sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix3, mmk0));
      auto pix4 = _mm256_shuffle_epi8(source2, mask_high);
      sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix4, mmk1));
    }

    // block 2
    for (; i < ids_size - b2_delta; i += 2) {
      // Load 2 values from weight vector
      // mmk = [wl_0 wh_0 wl_1 wh_1  wl_0 wh_0 wl_1 wh_1  ...]
      const auto mmk = _mm256_set1_epi32(*(int32_t*)&k[i]);

      // Load 4 pixels (2 per line) from input lines 0 and 1:
      // RGBA: source1 = [
      //   r0 g0 b0 a0  r1 g1 b1 a1  0 0 0 0  0 0 0 0
      //   R0 G0 B0 A0  R1 G1 B1 A1  0 0 0 0  0 0 0 0
      // ]
      // RGB: source1 = [
      //   r0 g0 b0 r1  g1 b1 r2  0 0 0 0  0 0 0 0
      //   R0 G0 B0 R1  G1 B1 R2  0 0 0 0  0 0 0 0
      // ]
      auto source1 = _mm256_inserti128_si256(_mm256_castsi128_si256(
          _mm_loadl_epi64((__m128i *) (lineIn0_min + stride * i))),
          _mm_loadl_epi64((__m128i *) (lineIn1_min + stride * i)), 1);
      // Apply mask_low:
      // RGBA:
      //   [r0 0 r1 0  g0 0 g1 0  b0 0 b1 0  a0 0 a1 0 | R0 0 R1 0  G0 0 G1 0  B0 0 B1 0  A0 0 A1 0]
      // RGB:
      //   [r0 0 r1 0  g0 0 g1 0  b0 0 b1 0  0 0 0 0 | R0 0 R1 0  G0 0 G1 0  B0 0 B1 0  0 0 0 0]
      auto pix1 = _mm256_shuffle_epi8(source1, mask_low);
      // Compute output value as C += w0 * C0 + w1 * C1 for each channel in 32-bit precision
      sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk));

      // Same as above for lines 2 and 3:
      auto source2 = _mm256_inserti128_si256(_mm256_castsi128_si256(
          _mm_loadl_epi64((__m128i *) (lineIn2_min + stride * i))),
          _mm_loadl_epi64((__m128i *) (lineIn3_min + stride * i)), 1);
      auto pix2 = _mm256_shuffle_epi8(source2, mask_low);
      sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk));
    }

    // block 1
    const auto i32_aligned = num_channels == 4;
    for (; i < ids_size - 1; i++) {
      // Load 1 value from weight vector
      // mmk = [wl_0 wh_0 0 0  wl_0 wh_0 0 0  ...]
      const auto mmk = _mm256_set1_epi32(k[i]);

      // Load 2 pixels (one per line) from input lines 0 and 1:
      // RGBA: pix1 = [
      //   r0 0 0 0  g0 0 0 0  b0 0 0 0  a0 0 0 0
      //   R0 0 0 0  G0 0 0 0  B0 0 0 0  A0 0 0 0
      // ]
      // RGB: pix1 = [
      //   r0 0 0 0  g0 0 0 0  b0 0 0 0  r1 0 0 0
      //   R0 0 0 0  G0 0 0 0  B0 0 0 0  R1 0 0 0
      // ]
      auto pix1 = _mm256_inserti128_si256(_mm256_castsi128_si256(
          mm_cvtepu8_epi32(lineIn0_min + stride * i, i32_aligned)),
          mm_cvtepu8_epi32(lineIn1_min + stride * i, i32_aligned), 1);
      // Compute output value as C += w0 * C0 for each channel in 32-bit precision
      sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk));

      // Same as above for lines 2 and 3
      auto pix2 = _mm256_inserti128_si256(_mm256_castsi128_si256(
          mm_cvtepu8_epi32(lineIn2_min + stride * i, i32_aligned)),
          mm_cvtepu8_epi32(lineIn3_min + stride * i, i32_aligned), 1);
      sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk));
    }

    if (i == ids_size - 1) {
      // last element
      auto mmk = _mm256_set1_epi32(k[i]);
      // For num_channels == 3 (3 bytes = one pixel) we tolerate to read 4 bytes
      // lines 0, 1 and 2 won't go out of allocated memory bounds
      auto pix = _mm256_inserti128_si256(_mm256_castsi128_si256(
          mm_cvtepu8_epi32(lineIn0_min + stride * i, i32_aligned)),
          mm_cvtepu8_epi32(lineIn1_min + stride * i, i32_aligned), 1);
      sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix, mmk));

      auto p0 = mm_cvtepu8_epi32(lineIn2_min + stride * i, i32_aligned);
      __m128i p1;
      if (num_channels == 3 && C10_UNLIKELY(is_last_line && ids_min + stride * i + 4 >= max_in_x_strided)) {
        uint8_t input[4];
        std::memcpy(input, lineIn3_min + stride * i, 3);
        p1 = mm_cvtepu8_epi32(input, true);
      } else {
        p1 = mm_cvtepu8_epi32(lineIn3_min + stride * i, i32_aligned);
      }
      auto pix2 = _mm256_inserti128_si256(_mm256_castsi128_si256(p0), p1, 1);
      sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk));
    }

    // Convert fixed point values back to integers (truncating)
    sss0 = _mm256_srai_epi32(sss0, coefs_precision);
    sss1 = _mm256_srai_epi32(sss1, coefs_precision);
    // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
    // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d 0 0 0 0 0 0 0 0)
    sss0 = _mm256_packs_epi32(sss0, zero);
    sss1 = _mm256_packs_epi32(sss1, zero);
    // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
    // (a a b b c c d d) -> (a b c d 0 0 0 0)
    sss0 = _mm256_packus_epi16(sss0, zero);
    sss1 = _mm256_packus_epi16(sss1, zero);

    // Write the output into single uint32
    // (a b c d) -> x_uint32
    auto o0 = _mm_cvtsi128_si32(_mm256_castsi256_si128(sss0));
    auto o1 = _mm_cvtsi128_si32(_mm256_extracti128_si256(sss0, 1));
    auto o2 = _mm_cvtsi128_si32(_mm256_castsi256_si128(sss1));
    auto o3 = _mm_cvtsi128_si32(_mm256_extracti128_si256(sss1, 1));

    const auto out_x_strided = stride * out_x;

    if (num_channels == 3 && C10_UNLIKELY(out_x_strided + 4 >= max_out_x_strided)) {
      // Memcpy 4-bytes is faster than 3-bytes and this is a boundary case when we want to write
      // 4 bytes (R G B | X) to the output buffer (X1 X2 X3 | R1).
      // The 4th byte in the register (X) has a garbage value and 4th byte in the output buffer (R1) has a correct
      // value which was previously computed by another line. In other words, it means that we can not overwrite
      // it by simply writing 4 bytes from the register to the output. We'll do the following:
      //               v----------|
      // Output = [... X1 X2 X3 | R1 G1 B1 R2 ...]
      // First, we write R1 value to the 4th byte of (R G B | X) -> (R G B | R1)
      // Second, we write 4 bytes from the register to the output: (X1 X2 X3 | R1) -> (R G B | R1)
      // Output = [... R G B | R1 G1 B1 R2 ...]

      _write_endline_rgb_as_uint32(lineOut0 + out_x_strided, o0);
      _write_endline_rgb_as_uint32(lineOut1 + out_x_strided, o1);
      _write_endline_rgb_as_uint32(lineOut2 + out_x_strided, o2);

      if (C10_UNLIKELY(is_last_line)) {
        // When we handle the last line, we can not access the next 4 bytes
        // as they are out of memory bounds.
        std::memcpy(lineOut3 + out_x_strided, (uint8_t *) &o3, num_channels);
      } else {
        _write_endline_rgb_as_uint32(lineOut3 + out_x_strided, o3);
      }
    } else if (num_channels == 3) {
      // Memcpy 4-bytes is faster than 3-bytes and here
      // we simply write 4 bytes (... R G B X 0 0 0 0 0 ...) where X is a garbage value
      // that we will overwrite on the next iteration: (... R G B R G B X 0 0 ...)
      std::memcpy(lineOut0 + out_x_strided, (uint8_t *) &o0, 4);
      std::memcpy(lineOut1 + out_x_strided, (uint8_t *) &o1, 4);
      std::memcpy(lineOut2 + out_x_strided, (uint8_t *) &o2, 4);
      std::memcpy(lineOut3 + out_x_strided, (uint8_t *) &o3, 4);
    } else {
      // num_channels = 4 -> lineOutX + out_x_strided should be uint32 aligned
      *(uint32_t *)(lineOut0 + out_x_strided) = o0;
      *(uint32_t *)(lineOut1 + out_x_strided) = o1;
      *(uint32_t *)(lineOut2 + out_x_strided) = o2;
      *(uint32_t *)(lineOut3 + out_x_strided) = o3;
    }
  }
}

void ImagingResampleHorizontalConvolution8u(
    uint8_t* C10_RESTRICT lineOut,
    int64_t out_xsize,
    const uint8_t* C10_RESTRICT lineIn,
    int64_t in_xsize,
    const int64_t* idx_ptr_xmin,
    const int64_t* idx_ptr_size,
    const int16_t* kk,
    int kmax,
    unsigned int coefs_precision,
    int64_t num_channels,
    bool is_last_line) {

  // Interpolation horizontal pass processing only one vertical line.
  // - Input data format is RGBA or RGB with R,G,B,A being uint8. In case of RGBA
  //   we can encode 4 values as a single uint32 value.
  // - We split the size of weight vector for a given output index as a sum:
  //   ids_size = num_blocks_8 * 8 + num_blocks_4 * 4 + num_blocks_2 * 2 + num_blocks_1
  // - We load and process 8 weights values in a loop ("block 8") then 4 weights and 2 weights values in
  // in another loops ("block 4" and "block 2") and finally we process 1 weight value in the final loop ("block 1").

  // Define various shuffling masks
  const auto kmask_low = _mm256_set_epi8(
      11, 10, 9, 8, 11, 10, 9, 8, 11, 10, 9, 8, 11, 10, 9, 8,
      3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0);
  const auto kmask_high = _mm256_set_epi8(
      15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12,
      7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4);
  const auto kmask_hl = _mm256_set_epi8(
      7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4,
      3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0);

  const auto mask_low_c4 = _mm256_set_epi8(
      -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0,
      -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0);
  const auto mask_high_c4 = _mm256_set_epi8(
      -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8,
      -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8);
  const auto mask_low_c3 = _mm256_set_epi8(
      -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0,
      -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0);
  const auto mask_high_c3 = _mm256_set_epi8(
      -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6,
      -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6);
  const auto mask_hl_c3 = _mm256_set_epi8(
      -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6,
      -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0);
  const auto mask_hl_c4 = _mm256_set_epi8(
      -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8,
      -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0);

  const auto mask_low128_c3 = _mm_set_epi8(
      -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0);
  const auto mask_low128_c4 = _mm_set_epi8(
      -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0);

  const auto mask_low = (num_channels == 3) ? mask_low_c3 : mask_low_c4;
  const auto mask_high = (num_channels == 3) ? mask_high_c3 : mask_high_c4;
  const auto mask_hl = (num_channels == 3) ? mask_hl_c3 : mask_hl_c4;
  const auto mask_low128 = (num_channels == 3) ? mask_low128_c3 : mask_low128_c4;

  // out_xsize = output width, out_x = output x index
  // ids_min is the input offset index corresponding to out_x
  // ids_size is the interpolation size for out_x

  const auto stride = num_channels * sizeof(uint8_t);
  const auto zero = _mm_setzero_si128();

  TORCH_INTERNAL_ASSERT(stride == 3 || stride == 4);

  // Let's precompute ids_size limits for block 8, block 4 and block 2
  //
  // In block 8 (8 means we process 8 weight values together), we read at
  // most 32 bytes input data (16 + 16 bytes for RGBA and 12 + 16 bytes for RGB)
  // lineIn + stride * (i + ids_min) + 32 <= lineIn + stride * (ids_size + ids_min)
  // --> i <= ids_size - 32.0 / stride
  // Strict boundary:
  // --> i < ids_size + 1 - int(ceil(32.0 / stride)) = ids_size - b8_delta
  // Soft boundary for reading inside the buffer except its boundaries:
  // --> i < ids_size + 1 - int(32.0 / stride) = ids_size - b8_delta_soft
  // RGBA: b8_delta = b8_delta_soft = 7
  // RGB : b8_delta = 10
  // RGB : b8_delta_soft = 9
  const auto b8_delta = (stride == 4) ? 7 : (is_last_line ? 10 : 9);

  // In block 4 (4 means we process 4 weight values together), we read
  // 16 bytes of input data.
  // lineIn + stride * (i + ids_min) + 16 <= lineIn0 + stride * (ids_size + ids_min)
  // --> i <= ids_size - 16.0 / stride
  // Strict boundary:
  // --> i < ids_size + 1 - int(ceil(16.0 / stride)) = ids_size - b4_delta
  // Soft boundary for reading inside the buffer except its boundaries:
  // --> i < ids_size + 1 - int(16.0 / stride) = ids_size - b4_delta_soft
  // RGBA: b4_delta = b4_delta_soft = 3
  // RGB : b4_delta = 5
  // RGB : b4_delta_soft = 4
  const auto b4_delta = (stride == 4) ? 3 : (is_last_line ? 5 : 4);

  // In block 2 (2 means we process 2 weight values together), we read
  // 8 bytes of input data.
  // lineIn0 + stride * (i + ids_min) + 8 <= lineIn0 + stride * (ids_size + ids_min)
  // --> i <= ids_size - 8.0 / stride
  // Strict boundary:
  // --> i < ids_size + 1 - int(ceil(8.0 / stride)) = ids_size - b2_delta
  // Soft boundary for reading inside the buffer except its boundaries:
  // --> i < ids_size + 1 - int(8.0 / stride) = ids_size - b2_delta_soft
  // RGBA: b2_delta = b2_delta_soft = 1
  // RGB : b2_delta = 2
  // RGB : b2_delta_soft = 1
  const auto b2_delta = (stride == 4) ? 1 : (is_last_line ? 2 : 1);

  const auto max_out_x_strided = out_xsize * stride;
  const auto max_in_x_strided = in_xsize * stride;

  for (const auto out_x : c10::irange(out_xsize)) {
    __m128i sss;
    const auto ids_min = idx_ptr_xmin[out_x];
    const auto ids_size = idx_ptr_size[out_x];
    const auto * k = &kk[out_x * kmax];
    int64_t i = 0;

    const auto * lineIn_min = lineIn + ids_min;

    if (ids_size < 8) {
      sss = _mm_set1_epi32(1 << (coefs_precision - 1));
    } else {
      // Lower part will be added to higher, use only half of the error
      auto sss256 = _mm256_set1_epi32(1 << (coefs_precision - 2));

      // block 8
      for (; i < ids_size - b8_delta; i += 8) {
        // Load 8 values from weight vector
        auto tmp = _mm_loadu_si128((__m128i*)&k[i]);
        // ksource = [
        //    wl_0 wh_0 wl_1 wh_1  wl_2 wh_2 wl_3 wh_3  wl_4 wh_4 wl_5 wh_5  wl_6 wh_6 wl_7 wh_7
        //    wl_0 wh_0 wl_1 wh_1  wl_2 wh_2 wl_3 wh_3  wl_4 wh_4 wl_5 wh_5  wl_6 wh_6 wl_7 wh_7
        // ]
        auto ksource = _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1);

        // RGBA: Load 8 pixels from input:
        // source = [
        //    r0 g0 b0 a0  r1 g1 b1 a1  r2 g2 b2 a2  r3 g3 b3 a3
        //    r4 g4 b4 a4  r5 g5 b5 a5  r6 g6 b6 a6  r7 g7 b7 a7
        // ]
        // RGB: Load 10 pixels from input (however we can process only 8 pixels):
        // source = [
        //    r0 g0 b0 r1  g1 b1 r2 g2  b2 r3 g3 b3  r4 g4 b4 r5
        //    r4 g4 b4 r5  g5 b5 r6 g6  b6 r7 g7 b7  r8 g8 b8 r9
        // ]
        auto source = _mm256_inserti128_si256(_mm256_castsi128_si256(
            _mm_loadu_si128((__m128i *) (lineIn_min + stride * i))),
            _mm_loadu_si128((__m128i *) (lineIn_min + stride * (i + 4))), 1);

        // Extract lower part of each lane, cast to epi16 and reorder RGBARGBA -> RRGGBBAA
        // RGBA: pix1 = [
        //   r0 0 r1 0  g0 0 g1 0  b0 0 b1 0  a0 0 a1 0
        //   r4 0 r5 0  g4 0 g5 0  b4 0 b5 0  a4 0 a5 0
        // ]
        // RGB: pix1 = [
        //   r0 0 r1 0  g0 0 g1 0  b0 0 b1 0  0 0 0 0
        //   r4 0 r5 0  g4 0 g5 0  b4 0 b5 0  0 0 0 0
        // ]
        auto pix1 = _mm256_shuffle_epi8(source, mask_low);
        // mmk1 = [
        //   wl_0 wh_0 wl_1 wh_1  wl_0 wh_0 wl_1 wh_1  ...  ...
        //   wl_4 wh_4 wl_5 wh_5  wl_4 wh_4 wl_5 wh_5  ...  ...
        // ]
        auto mmk1 = _mm256_shuffle_epi8(ksource, kmask_low);
        // Compute output value as
        //   C += w0 * C0 + w1 * C1
        //   C += w4 * C4 + w5 * C5 for each channel in 32-bit precision
        sss256 = _mm256_add_epi32(sss256, _mm256_madd_epi16(pix1, mmk1));

        // Same as above for higher part of each lane
        auto pix2 = _mm256_shuffle_epi8(source, mask_high);
        auto mmk2 = _mm256_shuffle_epi8(ksource, kmask_high);
        // Compute output value as
        //    C += w2 * C2 + w3 * C3
        //    C += w6 * C6 + w7 * C7 for each channel in 32-bit precision
        sss256 = _mm256_add_epi32(sss256, _mm256_madd_epi16(pix2, mmk2));
      }

      // block 4
      for (; i < ids_size - b4_delta; i += 4) {
        // Load 4 values from weight vector
        auto tmp = _mm_loadl_epi64((__m128i *) &k[i]);
        // ksource = [
        //    wl_0 wh_0 wl_1 wh_1  wl_2 wh_2 wl_3 wh_3  0 0 0 0  0 0 0 0
        //    wl_0 wh_0 wl_1 wh_1  wl_2 wh_2 wl_3 wh_3  0 0 0 0  0 0 0 0
        // ]
        auto ksource = _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1);

        // Load pixels from input line
        tmp = _mm_loadu_si128((__m128i *) (lineIn_min + stride * i));
        // RGBA: source = [
        //   r0 g0 b0 a0  r1 g1 b1 a1  r2 g2 b2 a2  r3 g3 b3 a3
        //   r0 g0 b0 a0  r1 g1 b1 a1  r2 g2 b2 a2  r3 g3 b3 a3
        // ]
        // RGB: source = [
        //   r0 g0 b0 r1  g1 b1 r2 g2  b2 r3 g3 b3  r4 g4 b4 r5
        //   r0 g0 b0 r1  g1 b1 r2 g2  b2 r3 g3 b3  r4 g4 b4 r5
        // ]
        auto source = _mm256_insertf128_si256(_mm256_castsi128_si256(tmp), tmp, 1);

        // Cast source to epi16 and reorder RGBARGBA -> RRGGBBAA
        // RGBA: pix = [
        //   r0 0 r1 0  g0 0 g1 0  b0 0 b1 0  a0 0 a1 0
        //   r2 0 r3 0  g2 0 g3 0  b2 0 b3 0  a2 0 a3 0
        // ]
        // RGB: pix = [
        //   r0 0 r1 0  g0 0 g1 0  b0 0 b1 0  0 0 0 0
        //   r2 0 r3 0  g2 0 g3 0  b2 0 b3 0  0 0 0 0
        // ]
        auto pix = _mm256_shuffle_epi8(source, mask_hl);
        // mmk = [
        //   wl_0 wh_0 wl_1 wh_1  wl_0 wh_0 wl_1 wh_1  ... ...
        //   wl_2 wh_2 wl_3 wh_3  wl_2 wh_2 wl_3 wh_3  ... ...
        // ]
        auto mmk = _mm256_shuffle_epi8(ksource, kmask_hl);
        // Compute output value as
        //   C += w0 * C0 + w1 * C1
        //   C += w2 * C2 + w3 * C3 for each channel in 32-bit precision
        sss256 = _mm256_add_epi32(sss256, _mm256_madd_epi16(pix, mmk));
      }

      // Sum results between the lanes
      sss = _mm_add_epi32(
          _mm256_extracti128_si256(sss256, 0),
          _mm256_extracti128_si256(sss256, 1));
    }

    // block 2
    for (; i < ids_size - b2_delta; i += 2) {
      // Load 2 values from weight vector
      // mmk = [wl_0 wh_0 wl_1 wh_1  wl_0 wh_0 wl_1 wh_1  ...]
      auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]);
      // Load pixels from input line
      // RGBA: source = [
      //   r0 g0 b0 a0  r1 g1 b1 a1  0 0 0 0  0 0 0 0
      // ]
      // RGB: source = [
      //   r0 g0 b0 r1  g1 b1 r2 g2  0 0 0 0  0 0 0 0
      // ]
      auto source = _mm_loadl_epi64((__m128i *) (lineIn_min + stride * i));
      // Cast source to epi16 and reorder RGBARGBA -> RRGGBBAA
      auto pix = _mm_shuffle_epi8(source, mask_low128);
      // Compute output value as C += w0 * C0 + w1 * C1 for each channel in 32-bit precision
      sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
    }

    // block 1
    const auto i32_aligned = num_channels == 4;
    for (; i < ids_size - 1; i++) {
      // Load 1 value from weight vector
      // mmk = [wl_0 wh_0 0 0  wl_0 wh_0 0 0  ...]
      auto mmk = _mm_set1_epi32(k[i]);
      // Load one pixel from input line
      // RGBA: pix = [
      //   r0 0 0 0  g0 0 0 0  b0 0 0 0  a0 0 0 0
      // ]
      // RGB: pix = [
      //   r0 0 0 0  g0 0 0 0  b0 0 0 0  r1 0 0 0
      // ]
      auto pix = mm_cvtepu8_epi32(lineIn_min + stride * i, i32_aligned);
      // Compute output value as C += w0 * C0 for each channel in 32-bit precision
      sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
    }

    if (i == ids_size - 1) {
      // last element
      auto mmk = _mm_set1_epi32(k[i]);
      __m128i pix;
      auto p = lineIn_min + stride * i;
      if (num_channels == 3 && C10_UNLIKELY(is_last_line && ids_min + stride * i + 4 >= max_in_x_strided)) {
        uint8_t input[4];
        std::memcpy(input, p, 3);
        pix = mm_cvtepu8_epi32(input, true);
      } else {
        pix = mm_cvtepu8_epi32(p, i32_aligned);
      }
      sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
    }

    // Convert fixed point values back to integers (truncating)
    sss = _mm_srai_epi32(sss, coefs_precision);
    // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
    // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d 0 0 0 0 0 0 0 0)
    sss = _mm_packs_epi32(sss, zero);
    // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
    // (a a b b c c d d) -> (a b c d 0 0 0 0)
    sss = _mm_packus_epi16(sss, zero);
    // Write the output into single uint32
    // (a b c d) -> x_uint32
    auto o = _mm_cvtsi128_si32(sss);
    const auto out_x_strided = stride * out_x;
    if (num_channels == 3 && C10_UNLIKELY(out_x_strided + 4 >= max_out_x_strided)) {
      if (C10_UNLIKELY(is_last_line)) {
        // When we handle the last line, we can not access the next 4 bytes
        // as they are out of memory bounds.
        std::memcpy(lineOut + out_x_strided, (uint8_t *) &o, 3);
      } else {
        // Memcpy 4-bytes is faster than 3-bytes and this is a boundary case when we want to write
        // 4 bytes (R G B | X) to the output buffer (X1 X2 X3 | R1).
        // The 4th byte in the register (X) has a garbage value and 4th byte in the output buffer (R1) has a correct
        // value which was previously computed by another line. In other words, it means that we can not overwrite
        // it by simply writing 4 bytes from the register to the output. We'll do the following:
        //               v----------|
        // Output = [... X1 X2 X3 | R1 G1 B1 R2 ...]
        // First, we write R1 value to the 4th byte of (R G B | X) -> (R G B | R1)
        // Second, we write 4 bytes from the register to the output: (X1 X2 X3 | R1) -> (R G B | R1)
        // Output = [... R G B | R1 G1 B1 R2 ...]
        _write_endline_rgb_as_uint32(lineOut + out_x_strided, o);
      }
    } else if (num_channels == 3) {
      // Memcpy 4-bytes is faster than 3-bytes and here
      // we simply write 4 bytes (... R G B X 0 0 0 0 0 ...) where X is a garbage value
      // that we will overwrite on the next iteration: (... R G B R G B X 0 0 ...)
      std::memcpy(lineOut + out_x_strided, (uint8_t *) &o, 4);
    } else {
      // num_channels = 4 -> lineOut + out_x_strided should be uint32 aligned
      *(uint32_t *)(lineOut + out_x_strided) = o;
    }
  }
}

void ImagingResampleVerticalConvolution8u(
    uint8_t* C10_RESTRICT lineOut,
    const uint8_t* C10_RESTRICT lineIn,
    int64_t xsize,
    int64_t ids_min,
    int64_t ids_size,
    const int16_t* k,
    unsigned int coefs_precision,
    int64_t num_channels) {

  // Interpolation vertical pass processing one line.
  // - We process x-axis data with blocks of 8, 2 and 1
  // - We split the size of weight vector for a given output index as a sum: K = n * 2 + m.

  // xsize = output width, also equals to input width
  // ids_size = interpolation size
  // ids_min = input y start index
  const auto stride = num_channels * sizeof(uint8_t);

  TORCH_INTERNAL_ASSERT(stride == 3 || stride == 4);

  const int64_t data_size = xsize * stride;
  const int64_t data_stride = stride;
  constexpr auto vec_size = 256 / 8;

  const auto initial = _mm_set1_epi32(1 << (coefs_precision - 1));
  const auto initial_256 = _mm256_set1_epi32(1 << (coefs_precision - 1));
  const auto zero = _mm_setzero_si128();
  const auto zero_256 = _mm256_setzero_si256();

  int64_t j = 0;
  // block 8
  const auto b8_usable_vec_stride = (vec_size / data_stride) * data_stride;
  for (; j < data_size - vec_size; j += b8_usable_vec_stride) {
    auto sss0 = initial_256;
    auto sss1 = initial_256;
    auto sss2 = initial_256;
    auto sss3 = initial_256;
    int64_t i = 0;
    const auto * lineIn_min = lineIn + j + ids_min;

    for (; i < ids_size - 1; i += 2) {
      // Load 2 values from weight vector
      auto mmk = _mm256_set1_epi32(*(int32_t*)&k[i]);

      // RGBA: Load 8 pixels per line
      // source1 = [
      //    r0 g0 b0 a0  r1 g1 b1 a1  r2 g2 b2 a2  r3 g3 b3 a3
      //    r4 g4 b4 a4  r5 g5 b5 a5  r6 g6 b6 a6  r7 g7 b7 a7
      // ]
      // RGB: Load 10 pixels per line (however we can process only 8 pixels):
      // source1 = [
      //    r0 g0 b0 r1  g1 b1 r2 g2  b2 r3 g3 b3  r4 g4 b4 r5
      //    r4 g4 b4 r5  g5 b5 r6 g6  b6 r7 g7 b7  r8 g8 b8 r9
      // ]
      auto source1 =
          _mm256_loadu_si256((__m256i*)(lineIn_min + data_size * i));
      auto source2 =
          _mm256_loadu_si256((__m256i*)(lineIn_min + data_size * (i + 1)));

      // Interleave source1 and source2 from the low half of each 128-bit lane
      // and cast the result to epi16
      // RGBA: pix1 = [
      //    r0 0 R0 0  g0 0 G0 0  b0 0 B0 0  a0 0 A0 0
      //    r1 0 R1 0  g1 0 G1 0  b1 0 B1 0  a1 0 A1 0
      // ]
      // RGB: pix1 = [
      //    r0 0 R0 0  g0 0 G0 0  b0 0 B0 0  0 0 0 0
      //    r1 0 R1 0  g1 0 G1 0  b1 0 B1 0  0 0 0 0
      // ]
      auto source_lo = _mm256_unpacklo_epi8(source1, source2);
      auto pix1 = _mm256_unpacklo_epi8(source_lo, zero_256);
      // Compute output value as
      //   C += w0 * c0 + w1 * C0
      //   C += w0 * c1 + w1 * C1 for each channel in 32-bit precision
      sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk));

      // RGBA: pix2 = [
      //    r2 0 R2 0  g2 0 G2 0  b2 0 B2 0  a2 0 A2 0
      //    r3 0 R3 0  g3 0 G3 0  b3 0 B3 0  a3 0 A3 0
      // ]
      // RGB: pix2 = [
      //    r2 0 R2 0  g2 0 G2 0  b2 0 B2 0  0 0 0 0
      //    r3 0 R3 0  g3 0 G3 0  b3 0 B3 0  0 0 0 0
      // ]
      auto pix2 = _mm256_unpackhi_epi8(source_lo, zero_256);
      // Compute output value as
      //   C += w0 * c2 + w1 * C2
      //   C += w0 * c3 + w1 * C3 for each channel in 32-bit precision
      sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk));

      // Same as above for the high half of each 128-bit lane
      auto source_hi = _mm256_unpackhi_epi8(source1, source2);
      auto pix3 = _mm256_unpacklo_epi8(source_hi, zero_256);
      sss2 = _mm256_add_epi32(sss2, _mm256_madd_epi16(pix3, mmk));
      auto pix4 = _mm256_unpackhi_epi8(source_hi, zero_256);
      sss3 = _mm256_add_epi32(sss3, _mm256_madd_epi16(pix4, mmk));
    }
    // Same processing as above but with a single weight value
    for (; i < ids_size; i += 1) {
      auto mmk = _mm256_set1_epi32(k[i]);

      auto source1 = _mm256_loadu_si256((__m256i*)(lineIn_min + i * data_size));

      auto source_lo = _mm256_unpacklo_epi8(source1, zero_256);
      auto pix1 = _mm256_unpacklo_epi8(source_lo, zero_256);
      sss0 = _mm256_add_epi32(sss0, _mm256_madd_epi16(pix1, mmk));
      auto pix2 = _mm256_unpackhi_epi8(source_lo, zero_256);
      sss1 = _mm256_add_epi32(sss1, _mm256_madd_epi16(pix2, mmk));

      auto source_hi = _mm256_unpackhi_epi8(source1, zero_256);
      auto pix3 = _mm256_unpacklo_epi8(source_hi, _mm256_setzero_si256());
      sss2 = _mm256_add_epi32(sss2, _mm256_madd_epi16(pix3, mmk));
      auto pix4 = _mm256_unpackhi_epi8(source_hi, _mm256_setzero_si256());
      sss3 = _mm256_add_epi32(sss3, _mm256_madd_epi16(pix4, mmk));
    }
    // Convert fixed point values back to integers (truncating)
    sss0 = _mm256_srai_epi32(sss0, coefs_precision);
    sss1 = _mm256_srai_epi32(sss1, coefs_precision);
    sss2 = _mm256_srai_epi32(sss2, coefs_precision);
    sss3 = _mm256_srai_epi32(sss3, coefs_precision);
    // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
    // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d)
    sss0 = _mm256_packs_epi32(sss0, sss1);
    sss2 = _mm256_packs_epi32(sss2, sss3);
    // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
    // (a a b b c c d d) -> (a b c d)
    sss0 = _mm256_packus_epi16(sss0, sss2);

    // Stores 32 bytes
    _mm256_storeu_si256((__m256i*)(lineOut + j), sss0);
  }

  // TODO: Do we also need block 4 ???
  // block 2
  const auto b2_usable_vec_stride = (8 / data_stride) * data_stride;
  for (; j < data_size - vec_size / 4; j += b2_usable_vec_stride) {
    auto sss0 = initial;
    auto sss1 = initial;
    int64_t i = 0;
    const auto * lineIn_min = lineIn + j + ids_min;

    for (; i < ids_size - 1; i += 2) {
      // Load 2 values from weight vector
      // mmk = [wl_0 wh_0 wl_1 wh_1  wl_0 wh_0 wl_1 wh_1  ... ]
      auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]);

      // Load 2 pixels per line
      // RGBA: source1 = [
      //    r0 g0 b0 a0  r1 g1 b1 a1  0 0 0 0  0 0 0 0
      // ]
      // RGB: source1 = [
      //    r0 g0 b0 r1  g1 b1 r2 g2  0 0 0 0  0 0 0 0
      // ]
      auto source1 = _mm_loadl_epi64((__m128i *) (lineIn_min + i * data_size));
      auto source2 = _mm_loadl_epi64((__m128i *) (lineIn_min + (i + 1) * data_size));
      // Interleave source1 and source2 and cast the result to epi16
      // RGBA: pix = [
      //    r0 0 R0 0  g0 0 G0 0  b0 0 B0 0  a0 0 A0 0
      // ]
      // RGB: pix = [
      //    r0 0 R0 0  g0 0 G0 0  b0 0 B0 0  0 0 0 0
      // ]
      auto source = _mm_unpacklo_epi8(source1, source2);
      auto pix = _mm_unpacklo_epi8(source, zero);
      // Compute output value as C += w0 * c0 + w1 * C0 for each channel in 32-bit precision
      sss0 = _mm_add_epi32(sss0, _mm_madd_epi16(pix, mmk));
      // RGBA: pix = [
      //    r1 0 R1 0  g1 0 G1 0  b1 0 B1 0  a1 0 A1 0
      // ]
      // RGB: pix = [
      //    r1 0 R1 0  g1 0 G1 0  b1 0 B1 0  0 0 0 0
      // ]
      pix = _mm_unpackhi_epi8(source, zero);
      // Compute output value as C += w0 * c1 + w1 * C1 for each channel in 32-bit precision
      sss1 = _mm_add_epi32(sss1, _mm_madd_epi16(pix, mmk));
    }
    // Same processing as above but with a single weight value
    for (; i < ids_size; i += 1) {
      auto mmk = _mm_set1_epi32(k[i]);

      auto source1 = _mm_loadl_epi64((__m128i*) (lineIn_min + i * data_size));

      auto source = _mm_unpacklo_epi8(source1, zero);
      auto pix1 = _mm_unpacklo_epi8(source, zero);
      sss0 = _mm_add_epi32(sss0, _mm_madd_epi16(pix1, mmk));
      auto pix2 = _mm_unpackhi_epi8(source, zero);
      sss1 = _mm_add_epi32(sss1, _mm_madd_epi16(pix2, mmk));
    }
    // Convert fixed point values back to integers (truncating)
    sss0 = _mm_srai_epi32(sss0, coefs_precision);
    sss1 = _mm_srai_epi32(sss1, coefs_precision);
    // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
    // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d)
    sss0 = _mm_packs_epi32(sss0, sss1);
    // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
    // (a a b b c c d d) -> (a b c d)
    sss0 = _mm_packus_epi16(sss0, sss0);
    // Store 2 pixels to the output
    _mm_storel_epi64((__m128i*)(lineOut + j), sss0);
  }

  // block 1
  const auto b1_usable_vec_stride = (4 / data_stride) * data_stride;
  const auto i32_aligned = num_channels == 4;
  for (; j < data_size - 4; j += b1_usable_vec_stride) {
    auto sss = initial;
    int64_t i = 0;
    const auto * lineIn_min = lineIn + j + ids_min;

    for (; i < ids_size - 1; i += 2) {
      // Load 2 values from weight vector
      // mmk = [wl_0 wh_0 wl_1 wh_1  wl_0 wh_0 wl_1 wh_1  ... ]
      auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]);

      // Load one pixel per line
      // RGBA: source1 = [
      //    r0 g0 b0 a0  0 0 0 0  0 0 0 0  0 0 0 0
      // ]
      // RGB: source1 = [
      //    r0 g0 b0 r1  0 0 0 0  0 0 0 0  0 0 0 0
      // ]
      auto source1 = mm_cvtsi32_si128(lineIn_min + i * data_size, i32_aligned);
      auto source2 = mm_cvtsi32_si128(lineIn_min + (i + 1) * data_size, i32_aligned);

      // Interleave source1 and source2 and cast the result to epi16
      // RGBA: pix = [
      //    r0 0 R0 0  g0 0 G0 0  b0 0 B0 0  a0 0 A0 0
      // ]
      // RGB: pix = [
      //    r0 0 R0 0  g0 0 G0 0  b0 0 B0 0  0 0 0 0
      // ]
      auto source = _mm_unpacklo_epi8(source1, source2);
      auto pix = _mm_unpacklo_epi8(source, zero);
      // Compute output value as C += w0 * c0 + w1 * C0 for each channel in 32-bit precision
      sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
    }

    for (; i < ids_size; i++) {
      auto mmk = _mm_set1_epi32(k[i]);
      auto pix = mm_cvtepu8_epi32(lineIn_min + i * data_size, i32_aligned);
      sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
    }
    sss = _mm_srai_epi32(sss, coefs_precision);
    sss = _mm_packs_epi32(sss, zero);
    sss = _mm_packus_epi16(sss, zero);

    auto o = _mm_cvtsi128_si32(sss);

    // Here we write 4 bytes to the output even if num_channels < 4, e.g o = {r,g,b,X} for num_channels=3
    // It is OK to write 4th byte (e.g. X) as on the next step we will overwrite it with new data.
    // We also won't go out of bounds of lineOut memory allocation
    std::memcpy(lineOut + j, (uint8_t *) &o, 4);
  }

  for (; j < data_size; j += data_stride) {
    auto sss = initial;
    int64_t i = 0;
    const auto * lineIn_min = lineIn + j + ids_min;
    // For RGBA we can use (ids_size - 1) as tighter limit but for RGB we can read outside memory boundary
    // for the last remaining line
    for (; i < ids_size - 2; i += 2) {
      // Load two coefficients at once
      auto mmk = _mm_set1_epi32(*(int32_t*)&k[i]);

      // Load 2 lines
      auto source1 = mm_cvtsi32_si128(lineIn_min + i * data_size, i32_aligned);
      auto source2 = mm_cvtsi32_si128(lineIn_min + (i + 1) * data_size, i32_aligned);

      auto source = _mm_unpacklo_epi8(source1, source2);
      auto pix = _mm_unpacklo_epi8(source, zero);
      sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
    }

    // Same processing as above but with a single weight value
    for (; i < ids_size; i++) {
      auto mmk = _mm_set1_epi32(k[i]);

      const uint8_t * p = lineIn_min + i * data_size;
      __m128i pix;
      // There is no much perf gain using more detailed condition like
      // num_channels == 3 && ids_min + j + data_size * i + 4 >= in_max_size
      // const int64_t in_max_size = data_size * in_ysize;
      if (num_channels == 3) {
        uint8_t input[4];
        std::memcpy(input, p, 3);
        pix = mm_cvtepu8_epi32(input, true);
      } else {
        pix = mm_cvtepu8_epi32(p, true);
      }
      sss = _mm_add_epi32(sss, _mm_madd_epi16(pix, mmk));
    }

    // Convert fixed point values back to integers (truncating)
    sss = _mm_srai_epi32(sss, coefs_precision);
    // Convert packed signed 32-bit integers to packed 16-bit integers using signed saturation
    // (a a a a b b b b c c c c d d d d) -> (a a b b c c d d)
    sss = _mm_packs_epi32(sss, zero);
    // Convert packed signed 16-bit integers to packed 8-bit integers using unsigned saturation
    // (a a b b c c d d) -> (a b c d)
    sss = _mm_packus_epi16(sss, zero);
    // Store one pixel to the output
    auto o = _mm_cvtsi128_si32(sss);
    if (num_channels == 3 && C10_UNLIKELY(j + 4 >= data_size)) {
      std::memcpy(lineOut + j, (uint8_t *) &o, 3);
    } else {
      std::memcpy(lineOut + j, (uint8_t *) &o, 4);
    }
  }
}

} // anonymous namespace
#endif // CPU_CAPABILITY_AVX2

#else
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
