#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
#pragma once
#include <c10/core/DispatchKey.h>
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/TypeList.h>
#include <c10/util/llvmMathExtras.h>
#include <array>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <iterator>
#include <ostream>
#include <string>
#include <type_traits>

C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum")

namespace c10 {

struct FunctionalityOffsetAndMask {
  // First slot of this functionality's range in the operator table; the
  // dispatch table index is offset + backend index. It must be wide enough
  // to address every entry of the operator table.
  uint16_t offset = 0;
  // Selects the backend bits relevant to this functionality.
  // See Note [No More Than 16 Backends]: more than 16 backend bits are not
  // expected, so a uint16_t mask covers all of them.
  uint16_t mask = 0;

  // Placeholder constructor: only exists so that a std::array of these can be
  // created before it is populated. It is not meant to be used otherwise.
  FunctionalityOffsetAndMask() = default;

  FunctionalityOffsetAndMask(uint16_t offset, uint16_t mask)
      : offset{offset}, mask{mask} {}
};
// Operator-table offsets are stored as uint16_t (see
// FunctionalityOffsetAndMask::offset), so the number of runtime entries must
// be addressable with 16 bits.
// NOTE(review): if num_runtime_entries is itself declared as a 16-bit type in
// DispatchKey.h, this comparison is always true and the check cannot fire
// (the value would already have wrapped); consider computing the entry count
// in a wider type. Confirm against DispatchKey.h.
static_assert(
    c10::num_runtime_entries < 65536,
    "The dispatcher currently only supports up to 2^16 runtime entries");

// Builds the table holding one (offset, mask) entry per functionality key.
// The result is cached by offsetsAndMasks() below. Defined out of line.
C10_API std::array<FunctionalityOffsetAndMask, num_functionality_keys>
initializeFunctionalityOffsetsAndMasks();

// Returns the functionality -> (offset, mask) table consumed by
// DispatchKeySet::getDispatchTableIndexForDispatchKeySet().
// The table is built by initializeFunctionalityOffsetsAndMasks() on the first
// call. Initialization of a function-local static is thread-safe since C++11.
// NB: because this function is declared `static` in a header, it has internal
// linkage, so every translation unit gets its own copy of the cached table.
C10_ALWAYS_INLINE static const std::array<
    FunctionalityOffsetAndMask,
    num_functionality_keys>&
offsetsAndMasks() {
  static auto cached_table = initializeFunctionalityOffsetsAndMasks();
  return cached_table;
}

// A representation of a set of DispatchKeys. A DispatchKeySet contains both
// "functionality" bits and "backend bits", and every tensor holds its own
// DispatchKeySet. The Dispatcher implements multiple dispatch by grabbing the
// keyset on every input tensor, OR-ing them together, and dispatching to a
// specific piece of functionality. The functionality bits are *ordered*. When
// multiple functionality bits are set, we use the highest priority
// functionality. Similarly, multiple backend bits can theoretically be set if
// you call an operator with multiple tensors from different devices (e.g. CPU
// and CUDA), although support for mixed device dispatch is limited (the only
// kernels that gracefully handle mixed device inputs for now are cuda kernels
// that take in a scalar cpu tensor).

// A representation of a set of DispatchKeys.  A tensor may have multiple
// tensor type ids, e.g., a Variable tensor can also be a CPU tensor; the
// DispatchKeySet specifies what type ids apply.  The internal representation is
// a 64-bit bit set. Because per-backend keys are encoded as a combination of a
// functionality bit and a backend bit, this does not limit us to 64 dispatch
// keys; see Note [DispatchKeySet Internal Representation] below.
//
// As mentioned above, DispatchKeys are ordered; thus, we can ask questions like
// "what is the highest priority DispatchKey in the set"?  (The set itself is
// not ordered; two sets with the same ids will always have the ids ordered in
// the same way.)
//
// Note [DispatchKeySet Internal Representation]
// Internally, dispatch keys are packed into 64-bit DispatchKeySet objects
// that get passed around at runtime.
// However, there isn't necessarily a 1-to-1 mapping between bits in the keyset
// and individual dispatch keys.
//
// First: why do we have this distinction, and why not map every dispatch key
// directly to a bit? This is mostly because we have several types of
// functionalities that different backends would like to customize. For example,
// we have:
// - "Dense":     CPU, CUDA, XLA, ... (~12 keys)
// - "Sparse":    SparseCPU, SparseCUDA, ...
// - "SparseCsr": SparseCsrCPU, SparseCsrCUDA, ...
// - "Quantized": QuantizedCPU, QuantizedCUDA, QuantizedXLA, ...
// - "Autograd":  AutogradCPU, AutogradCUDA, AutogradXLA, ...
// The problem is that total number of keys grows quadratically with [#
// backends] x [# functionalities], making it very difficult to map each key
// directly to a bit in a bitset without dramatically increasing the size of the
// bitset over time.
//
// The two enums (BackendComponent and DispatchKey) can be divided roughly into
// 5 categories.
//
// (1) "Building block" keys
//    (a) backends: Everything in the BackendComponent enum (e.g. CPUBit,
//        CUDABit)
//    (b) functionalities: (per-backend) functionality-bit DispatchKeys
//        (e.g. AutogradFunctionality, SparseCsr, Sparse, Dense)
// (2) "Runtime" keys
//    (a) "non-customizable backends" (e.g. FPGA)
//    (b) "non-customizable functionalities" (e.g. Functionalize)
//    (c) "per-backend instances of customizable functionalities" (e.g. CPU,
//    SparseCPU, AutogradCPU)
// (3) "Alias" DispatchKeys (see Note [Alias Dispatch Keys])
//
// (1) Building block keys always correspond to individual bits in a
// DispatchKeySet. They can also be combined in a DispatchKeySet to form actual
// runtime keys. e.g.
//     auto dense_cpu_ks = DispatchKeySet(BackendComponent::CPUBit) |
//         DispatchKeySet(DispatchKey::Dense);
//     // The keyset has the runtime dense-cpu key.
//     dense_cpu_ks.has(DispatchKey::CPU);
//     // And it contains the building block keys too.
//     dense_cpu_ks.has_backend(BackendComponent::CPUBit);
//     dense_cpu_ks.has(DispatchKey::Dense);
//
// Not every backend and not every functionality counts as a "building block
// key". This is mostly to give us more levers to pull in the design space.
// Backend keys and functionality keys that count as "building blocks" will
// contribute to a full cross product of functionality that can be overridden.
//
// For example, right now we have at least 12 "backend" building
// blocks (CPU, CUDA, XLA, ...) and at least 5 "functionality"
// building blocks (Dense, Sparse, SparseCsr, Quantized,
// AutogradFunctionality, ...). These keys together allow every
// dispatcher operator to be customized in at least 12*5 different
// ways. Each of those requires a slot in the operator table of every
// dispatcher operator.  Not every piece of functionality necessarily
// needs to be customizable per-backend, and not every backend
// necessarily needs to be able to customize every type of
// functionality.
//
//
// (2) Every runtime key corresponds directly to a slot in an operator's runtime
// dispatch table, and you can directly register kernels to a runtime dispatch
// key.
//
// For per-backend functionalities like "Dense" or "AutogradFunctionality",
// you can think of the corresponding runtime dispatch keys as "instances" of
// that functionality, per backend. E.g. "CPU", "CUDA", "XLA", etc. are all
// runtime instances of the "Dense" building block key.

// (2a) and (2b) are represented identically in the DispatchKeySet logic:
// - backend-agnostic functionalities (e.g. FuncTorchBatched) are NOT
// customizable per backend.
//   In order to do so, we'd need to promote it to a per-backend functionality
//   "building block" key.
// - non-customizable backends (e.g. FPGA) can NOT customize existing
// functionality like Sparse, Autograd, etc.
//   In order to do so, we'd need to promote it to a backend "building block"
//   key.
//
// In both cases, these keys directly correspond to runtime slots in the
// operator table.
//
//
// (3) "Alias" keys
// See Note [Alias Dispatch Keys]
//
// Final note: for anyone making future changes to the Dispatcher +
// DispatchKeySet internals, there's a closed PR with a basic
// python-implementation of the Dispatcher that might be useful in quickly
// testing out and validating changes. See it at
// https://github.com/pytorch/pytorch/pull/68743

// An undefined tensor is one with an empty tensor type set (i.e. one for
// which DispatchKeySet::empty() returns true).
class DispatchKeySet final {
 public:
  // Tag types used to select the special constructors below.
  enum Full { FULL };
  enum FullAfter { FULL_AFTER };
  enum Raw { RAW };

  // NB: default constructor representation as zero is MANDATORY as
  // use of DispatchKeySet in TLS requires this.
  constexpr DispatchKeySet() = default;

  // Sets every backend bit and every functionality bit. The "- 1" mirrors the
  // per-key bit formula used below: Undefined is counted as a functionality
  // but does not occupy a bit.
  constexpr DispatchKeySet(Full /*unused*/)
      : repr_((1ULL << (num_backends + num_functionality_keys - 1)) - 1) {}

  constexpr DispatchKeySet(FullAfter /*unused*/, DispatchKey t)
      // LSB after t are OK, but not t itself.
      // "functionalities" have a notion of ordering (e.g. Autograd > Sparse >
      // Quantized > Dense). But backends don't really have an ordering.
      // Therefore, we're enforcing that FullAfter can only be used on
      // "functionality" keys.
      // Concretely, the result has every backend bit set, plus every
      // functionality bit strictly below t's functionality bit.
      : repr_(
            (1ULL
             << (num_backends + static_cast<uint8_t>(toFunctionalityKey(t)) -
                 1)) -
            1) {
    // PythonDispatcher is always (re-)added, regardless of where it sits
    // relative to t.
    *this = add(DispatchKey::PythonDispatcher);
  }

  // Public version of DispatchKeySet(uint64_t) API; external users
  // must be explicit when they do this!
  constexpr DispatchKeySet(Raw /*unused*/, uint64_t x) : repr_(x) {}

  // Builds the set containing only the backend bit for k, which lives at bit
  // index (k - 1). BackendComponent::InvalidBit yields the empty set.
  constexpr explicit DispatchKeySet(BackendComponent k) {
    if (k == BackendComponent::InvalidBit) {
      repr_ = 0;
    } else {
      repr_ = 1ULL << (static_cast<uint8_t>(k) - 1);
    }
  }

  // Builds the set representing a single DispatchKey. Depending on the key,
  // this sets:
  // - no bits (Undefined and alias keys),
  // - a single functionality bit (keys up to EndOfFunctionalityKeys), or
  // - a functionality bit plus, when applicable, a backend bit (runtime keys
  //   up to EndOfRuntimeBackendKeys, e.g. CPU).
  constexpr explicit DispatchKeySet(DispatchKey k) {
    // NOLINTNEXTLINE(bugprone-branch-clone)
    if (k == DispatchKey::Undefined) {
      // Case 1: handle Undefined specifically
      repr_ = 0;
    } else if (k <= DispatchKey::EndOfFunctionalityKeys) {
      // Case 2: handle "functionality-only" keys
      // These keys have a functionality bit set, but no backend bits
      // These can technically be either:
      // - valid runtime keys (e.g. DispatchKey::AutogradOther,
      // DispatchKey::FuncTorchBatched, etc)
      // - "building block" keys that aren't actual runtime keys (e.g.
      // DispatchKey::Dense or Sparse)
      uint64_t functionality_val = 1ULL
          << (num_backends + static_cast<uint8_t>(k) - 1);
      repr_ = functionality_val;
    } else if (k <= DispatchKey::EndOfRuntimeBackendKeys) {
      // Case 3: "runtime" keys that have a functionality bit AND a backend bit.
      // First compute which bit to flip for the functionality.
      auto functionality_k = toFunctionalityKey(k);
      // The - 1 is because Undefined is technically a "functionality" that
      // doesn't show up in the bitset. So e.g. Dense is technically the second
      // functionality, but the lowest functionality bit.
      uint64_t functionality_val = 1ULL
          << (num_backends + static_cast<uint8_t>(functionality_k) - 1);

      // Then compute which bit to flip for the backend. For the runtime
      // instances of "per-backend functionality" keys, e.g. given
      // DispatchKey::CPU, we should set:
      // - the Dense functionality bit
      // - the CPUBit backend bit
      // A key whose backend component is InvalidBit contributes no backend
      // bit.
      auto backend_k = toBackendComponent(k);
      uint64_t backend_val = backend_k == BackendComponent::InvalidBit
          ? 0
          : 1ULL << (static_cast<uint8_t>(backend_k) - 1);
      repr_ = functionality_val + backend_val;
    } else {
      // Case 4: at this point, we should have covered every case except for
      // alias keys. Technically it would be possible to add alias dispatch keys
      // to a DispatchKeySet, but the semantics are a little confusing and this
      // currently isn't needed anywhere, so alias keys map to the empty set.
      repr_ = 0;
    }
  }

  // Helper for the initializer_list<DispatchKey> constructor: the union of
  // the single-key representations of every key in ks.
  constexpr uint64_t keys_to_repr(std::initializer_list<DispatchKey> ks) {
    uint64_t repr = 0;
    for (auto k : ks) {
      repr |= DispatchKeySet(k).repr_;
    }
    return repr;
  }

  // Helper for the initializer_list<BackendComponent> constructor: the union
  // of the backend bits of every component in ks.
  constexpr uint64_t backend_bits_to_repr(
      std::initializer_list<BackendComponent> ks) {
    uint64_t repr = 0;
    for (auto k : ks) {
      repr |= DispatchKeySet(k).repr_;
    }
    return repr;
  }

  explicit constexpr DispatchKeySet(std::initializer_list<DispatchKey> ks)
      : repr_(keys_to_repr(ks)) {}

  explicit constexpr DispatchKeySet(std::initializer_list<BackendComponent> ks)
      // Note: for some reason, putting this logic directly in the constructor
      // appears to fail to compile on CUDA 10.1.
      // See an example internal failure at
      // https://www.internalfb.com/intern/skycastle/run/76561193669136035/artifact/actionlog.76561193742069401.stderr
      : repr_(backend_bits_to_repr(ks)) {}

  // Test if a DispatchKey is in the set, i.e. whether every bit of
  // DispatchKeySet(t) is set here.
  // NB: keys that map to the empty set (e.g. alias keys, see the
  // DispatchKey constructor) trivially pass this check.
  inline bool has(DispatchKey t) const {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(t != DispatchKey::Undefined);
    return has_all(DispatchKeySet(t));
  }
  // Test if the backend bit for t is set.
  constexpr bool has_backend(BackendComponent t) const {
    return has_all(DispatchKeySet(t));
  }

  // Given a DispatchKeySet of functionality keys and (potentially) backend
  // keys, tests if all of them are in the current set.
  constexpr bool has_all(DispatchKeySet ks) const {
    return static_cast<bool>((repr_ & ks.repr_) == ks.repr_);
  }

  // Given a DispatchKeySet of functionality keys and (potentially) backend
  // keys, tests if any of them are in the current set. This could technically
  // be pretty easily implemented using has(). It is strictly a perf
  // optimization though. There are many places in the code base where we want
  // to test for multiple functionality keys together. HOWEVER, runtime
  // per-backend functionality keys aren't allowed to be used with this
  // function, because you can end up with weird results. e.g.
  // DispatchKeySet(DispatchKey::AutogradCPU).has_any(DispatchKeySet(DispatchKey::CPU))
  // would return true.
  inline bool has_any(DispatchKeySet ks) const {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        // Either there are no backend bits in the input keyset
        ((ks.repr_ & full_backend_mask) == 0) ||
        // or there are no per-backend-functionality bits
        // See [Note: Per-Backend Functionality Dispatch Keys]
        ((ks &
          DispatchKeySet({
                             DispatchKey::Dense,
                             DispatchKey::Quantized,
                             DispatchKey::Sparse,
                             DispatchKey::SparseCsr,
                             DispatchKey::AutogradFunctionality,
                         })
              .repr_) == 0));
    return static_cast<bool>((repr_ & ks.repr_) != 0);
  }
  // Test if DispatchKeySet is a superset of ks (same check as has_all()).
  bool isSupersetOf(DispatchKeySet ks) const {
    return (repr_ & ks.repr_) == ks.repr_;
  }
  // Perform set union
  constexpr DispatchKeySet operator|(DispatchKeySet other) const {
    return DispatchKeySet(repr_ | other.repr_);
  }
  // Perform set intersection
  constexpr DispatchKeySet operator&(DispatchKeySet other) const {
    return DispatchKeySet(repr_ & other.repr_);
  }
  // Compute the set difference self - other,
  // but ONLY for the functionality keys.
  // Any backend bits set on self will remain unchanged.
  // See Note [Removing keys from DispatchKeySet Only Affects Functionality
  // Keys]
  constexpr DispatchKeySet operator-(DispatchKeySet other) const {
    return DispatchKeySet(repr_ & (full_backend_mask | ~other.repr_));
  }

  // Compute self ^ other (applies to backend bits as well).
  constexpr DispatchKeySet operator^(DispatchKeySet other) const {
    return DispatchKeySet(repr_ ^ other.repr_);
  }
  // Two sets are equal iff their raw bit representations are equal.
  bool operator==(DispatchKeySet other) const {
    return repr_ == other.repr_;
  }
  bool operator!=(DispatchKeySet other) const {
    return repr_ != other.repr_;
  }
  // Add a DispatchKey to the DispatchKey set.  Does NOT mutate,
  // returns the extended DispatchKeySet!
  [[nodiscard]] constexpr DispatchKeySet add(DispatchKey t) const {
    return *this | DispatchKeySet(t);
  }
  [[nodiscard]] constexpr DispatchKeySet add(DispatchKeySet ks) const {
    return *this | ks;
  }

  // Remove a DispatchKey from the DispatchKey set.
  // This is generally not an operation you should be doing
  // (it's used to implement the printing overload, operator<<)
  //
  // Note [Removing keys from DispatchKeySet Only Affects Functionality Keys]
  // Only functionality bits are allowed to be removed from a keyset.
  // For now, we're only allowing removal of "functionality bits" from the
  // keyset, which is specifically needed by the fallthrough key calculation
  // logic. Why is removing backend bits problematic? Consider this example:
  //
  // DispatchKeySet([DispatchKey.CPU, DispatchKey.AutogradCUDA,
  // DispatchKey.CUDA]).remove(DispatchKey.AutogradCUDA)
  // DispatchKeySet([DispatchKey.CPU,
  // DispatchKey.AutogradCUDA]).remove(DispatchKey.AutogradCUDA)
  //
  // What do we want to happen?
  // Technically, we'd like it to be true that after removal,
  // the first keyset still has the CUDA dispatch key while the second doesn't.
  // Unfortunately there's no way to represent that, because the two keysets are
  // represented the same way internally: functionality bits: Autograd, Dense
  // backend bits: CPU, CUDA
  //
  // Instead, remove(DispatchKey.AutogradCUDA) will only remove the "Autograd"
  // bit from the bitset.
  [[nodiscard]] constexpr DispatchKeySet remove(DispatchKey t) const {
    return DispatchKeySet(
        repr_ & ~(DispatchKeySet(t).repr_ & ~full_backend_mask));
  }
  // You're allowed to remove a backend bit from a DispatchKeySet,
  // but you have to be explicit about it (remove_backend() instead of
  // remove()).
  constexpr DispatchKeySet remove_backend(BackendComponent b) const {
    return DispatchKeySet(repr_ & ~(DispatchKeySet(b).repr_));
  }
  // Is the set empty?  (AKA undefined tensor)
  bool empty() const {
    return repr_ == 0;
  }
  // The raw 64-bit representation (backend bits + functionality bits).
  uint64_t raw_repr() const {
    return repr_;
  }

  // Inverse of raw_repr(); equivalent to DispatchKeySet(RAW, x).
  static DispatchKeySet from_raw_repr(uint64_t x) {
    return DispatchKeySet(RAW, x);
  }

  // Returns the highest-priority functionality key in the set, ignoring
  // backend bits, or DispatchKey::Undefined if no functionality bit is set.
  DispatchKey highestFunctionalityKey() const {
    auto functionality_idx = indexOfHighestBit();
    // This means that none of the functionality bits were set.
    if (functionality_idx < num_backends)
      return DispatchKey::Undefined;
    // The first num_backend bits in the keyset don't correspond to real
    // dispatch keys.
    return static_cast<DispatchKey>(functionality_idx - num_backends);
  }

  // This is similar to toBackendComponent(DispatchKey), but less restrictive.
  // toBackendComponent() errors out if the key that it was passed has no
  // backend bits, which is useful for error checking. We need a version of that
  // here that can also handle "fake" backends like FPGA, because they need to
  // map to the AutogradOther key. For those backends, we return
  // BackendComponent::InvalidBit.
  BackendComponent highestBackendKey() const {
    // mask to mask out functionality bits
    auto backend_idx =
        DispatchKeySet(repr_ & full_backend_mask).indexOfHighestBit();
    // all zeros across the backend bits means that no backend bits are set.
    if (backend_idx == 0)
      return BackendComponent::InvalidBit;
    return static_cast<BackendComponent>(backend_idx);
  }

  // returns the DispatchKey of highest priority in the set.
  // For per-backend functionalities this combines the highest functionality
  // key with the highest backend bit (e.g. Dense + CUDABit -> CUDA).
  DispatchKey highestPriorityTypeId() const {
    auto functionality_k = highestFunctionalityKey();
    if (isPerBackendFunctionalityKey(functionality_k)) {
      return toRuntimePerBackendFunctionalityKey(
          functionality_k, highestBackendKey());
    }
    return functionality_k;
  }

  // Returns the 1-based index of the most-significant set bit in the keyset,
  // or 0 if no bit is set.
  // This is used as part of the calculation into the operator table to get:
  // - the highest "functionality" bit in the keyset.
  // - the highest "backend" bit in the keyset.
  uint8_t indexOfHighestBit() const {
    return 64 - llvm::countLeadingZeros(repr_);
  }

#if defined(C10_MOBILE_TRIM_DISPATCH_KEYS)
  // [Note: Trimmed Mobile Dispatch Keys]
  /**
   * The method below maps the dispatch key in the enum DispatchKey to an
   * integer index in the dispatchTable_ array in OperatorEntry. The array
   * is trimmed for mobile to reduce peak memory usage since it's
   * unnecessary to reserve additional space for dispatch keys that will
   * never be used on mobile. Keys not listed here map to -1.
   */
  int getDispatchTableIndexForDispatchKeySet() const {
    auto dk = highestPriorityTypeId();
    switch (dk) {
      case DispatchKey::Undefined:
        return 0;
      case DispatchKey::CPU:
        return 1;
      case DispatchKey::QuantizedCPU:
        return 2;
      case DispatchKey::SparseCPU:
        return 3;
      case DispatchKey::BackendSelect:
        return 4;
      case DispatchKey::ADInplaceOrView:
        return 5;
      case DispatchKey::AutogradOther:
        return 6;
      case DispatchKey::AutogradCPU:
        return 7;
      default:
        return -1;
    }
  }
#else
  // Returns the index in the operator table of the highest priority key in the
  // keyset. Note that we could in theory implement this using
  // highestPriorityTypeId(), but this code is very hotpath and we can do it
  // faster without it. The index is the functionality's base offset plus the
  // index of the highest backend bit selected by that functionality's mask.
  int getDispatchTableIndexForDispatchKeySet() const {
    auto functionality_idx =
        DispatchKeySet(repr_ >> num_backends).indexOfHighestBit();
    auto offset_and_mask = offsetsAndMasks()[functionality_idx];
    // Mask the functionality bits out first, then right-shift by 1.
    // right-shifting by 1 because everything is zero-indexed.
    // E.g. 000001 (CPU) should give us an offset of 0, 000010 (CUDA) should
    // give us an offset of 1, etc.
    auto backend_idx =
        DispatchKeySet((repr_ & offset_and_mask.mask) >> 1).indexOfHighestBit();
    return offset_and_mask.offset + backend_idx;
  }
#endif

  // returns the "index" of the highest priority backend in the keyset.
  // This is pretty similar to highestBackendKey(), but:
  // - It's hotpath code (part of the runtime bitset calculation)
  // - It returns an integer index, not an enum value
  // - Everything is shifted down by 1: BackendComponent::InvalidBit is
  //   technically the lowest enum value (so CPUBit = 1, CUDABit = 2, etc.),
  //   but it isn't included in the runtime table, so the returned index is
  //   CPUBit -> 0, CUDABit -> 1, etc.
  // NB: a keyset with no backend bits also yields 0.
  uint64_t getBackendIndex() const {
    return DispatchKeySet((repr_ & full_backend_mask) >> 1).indexOfHighestBit();
  }

 private:
  // Implicit conversion from a raw representation is only available
  // internally; external users must go through the RAW tag constructor.
  constexpr DispatchKeySet(uint64_t repr) : repr_(repr) {}
  // Low num_backends bits: backend bits. Bits above them: functionality bits.
  uint64_t repr_ = 0;

 public:
  // STL iterator for DispatchKeySet. Iterates through all runtime DispatchKeys
  // in the set. The iterator is only invalidated by the destruction of the
  // underlying DispatchKeySet as the iterator stores a pointer to the raw
  // representation of the DispatchKeySet. Note: When we encounter a per-backend
  // functionality (e.g. Dense or Sparse), we will iterate through EVERY backend
  // in the keyset, for that functionality. For example, if the next
  // functionality key to iterate over is Autograd, and the backend bits in the
  // keyset correspond to [BackendComponent::CPUBit, BackendComponent::CUDABit],
  // then the next two keys we return will be DispatchKey::AutogradCPU,
  // DispatchKey::AutogradCUDA (CPU first because it has lower precedence than
  // CUDA in DispatchKey.h).
  class iterator {
   public:
    using self_type = iterator;
    using iterator_category = std::input_iterator_tag;
    using value_type = DispatchKey;
    using difference_type = ptrdiff_t;
    using reference = value_type&;
    using pointer = value_type*;
    // final mask value should mask out the entire keyset
    static const uint8_t end_iter_mask_val =
        num_backends + num_functionality_keys;
    // final key value should be the last DispatchKey
    static const uint8_t end_iter_key_val = num_functionality_keys;

    // current_dispatchkey_idx_ will iterate through all functionality bits.
    // current_backendcomponent_idx_ will iterate through all backend bits.
    explicit iterator(
        const uint64_t* data_ptr,
        uint8_t next_functionality = num_backends,
        uint8_t next_backend = 0)
        : data_ptr_(data_ptr),
          next_functionality_(next_functionality),
          next_backend_(next_backend),
          // These are in an invalid state at construction time, and set by the
          // first increment call
          current_dispatchkey_idx_(end_iter_key_val),
          current_backendcomponent_idx_(end_iter_key_val) {
      // Functionality iteration must start at or above the backend bits.
      TORCH_INTERNAL_ASSERT(
          next_functionality_ >= num_backends,
          "num_backends=",
          static_cast<uint32_t>(num_backends),
          "next_functionality_=",
          static_cast<uint32_t>(next_functionality_));
      // Go to the first key in the set
      ++(*this);
    }

    // Advances to the next runtime key; defined out of line.
    C10_API self_type& operator++();

    self_type operator++(int) {
      self_type previous_iterator = *this;
      ++(*this);
      return previous_iterator;
    }

    bool operator==(const self_type& rhs) const {
      return next_functionality_ == rhs.next_functionality_ &&
          current_dispatchkey_idx_ == rhs.current_dispatchkey_idx_ &&
          next_backend_ == rhs.next_backend_ &&
          current_backendcomponent_idx_ == rhs.current_backendcomponent_idx_;
    }
    bool operator!=(const self_type& rhs) const {
      return next_functionality_ != rhs.next_functionality_ ||
          current_dispatchkey_idx_ != rhs.current_dispatchkey_idx_ ||
          next_backend_ != rhs.next_backend_ ||
          current_backendcomponent_idx_ != rhs.current_backendcomponent_idx_;
    }
    // Returns the current runtime key (by value). Per-backend functionalities
    // are combined with the current backend component.
    DispatchKey operator*() const {
      auto functionality_key =
          static_cast<DispatchKey>(current_dispatchkey_idx_);
      if (isPerBackendFunctionalityKey(functionality_key)) {
        auto next_key = toRuntimePerBackendFunctionalityKey(
            functionality_key,
            static_cast<BackendComponent>(current_backendcomponent_idx_));
        // We expect all of the Dense, Sparse, Quantized, and Autograd keys to
        // be ordered the same way with respect to their backends
        TORCH_INTERNAL_ASSERT(
            toBackendComponent(next_key) ==
                static_cast<BackendComponent>(current_backendcomponent_idx_),
            "Tried to map functionality key ",
            toString(functionality_key),
            " and backend bit ",
            toString(
                static_cast<BackendComponent>(current_backendcomponent_idx_)),
            " to a runtime key, but ended up with ",
            toString(next_key),
            ". This can happen if the order of the backend dispatch keys in DispatchKey.h isn't consistent.",
            " Please double check that enum for inconsistencies.");
        return next_key;
      } else {
        return functionality_key;
      }
    }

   private:
    const uint64_t* data_ptr_;
    uint8_t next_functionality_;
    uint8_t next_backend_;
    uint8_t current_dispatchkey_idx_;
    uint8_t current_backendcomponent_idx_;
  };

 public:
  // Returns iterator to the first key in the set. If no keys are in the
  // set, then will return the end iterator.
  iterator begin() const {
    return iterator(&repr_);
  }

  // We do not need to iterate beyond EndOfFunctionalityKeys so we will treat
  // this as the end iterator.
  iterator end() const {
    return iterator(&repr_, iterator::end_iter_mask_val);
  }
};

// String / stream formatting for DispatchKeySet; both are defined out of line.
C10_API std::string toString(DispatchKeySet /*ts*/);
C10_API std::ostream& operator<<(std::ostream& /*os*/, DispatchKeySet /*ts*/);

// Maps a single DispatchKey to its slot in an operator's runtime dispatch
// table by wrapping it in a one-key DispatchKeySet and reusing the keyset
// lookup.
inline int getDispatchTableIndexForDispatchKey(DispatchKey k) {
  const DispatchKeySet single_key_ks(k);
  return single_key_ks.getDispatchTableIndexForDispatchKeySet();
}

// Alias key DispatchKey::Autograd maps to
// (autograd_dispatch_keyset x full_backend_mask)
// NB: keys in this set also get associated with CompositeImplicitAutograd
//
// Note [autograd_dispatch_keyset Does Not Include Backend Bits]
// We don't want to include any backend bits (BackendComponent::CPUBit, etc)
// directly in autograd_dispatch_keyset.
// Why? keysets like autograd_dispatch_keyset are commonly used to remove
// autograd keys from a DispatchKeySet throughout the code base. However, you
// are only allowed to remove functionality bits from a keyset, not backend
// bits. See Note [Removing keys from DispatchKeySet Only Affects Functionality
// Keys] for details. To be consistent and avoid confusion, we're explicitly
// setting up autograd_dispatch_keyset to not have any backend bits.
//
// Contents: the AutogradFunctionality building-block key plus the
// AutogradOther and AutogradNestedTensor keys.
constexpr DispatchKeySet autograd_dispatch_keyset = DispatchKeySet({
    DispatchKey::AutogradFunctionality,
    DispatchKey::AutogradOther,
    DispatchKey::AutogradNestedTensor,
});

// The set of Autocast* dispatch keys.
// NB: default_excluded_set below currently contains exactly these keys.
constexpr DispatchKeySet autocast_dispatch_keyset = DispatchKeySet({
    DispatchKey::AutocastCPU,
    DispatchKey::AutocastMPS,
    DispatchKey::AutocastCUDA,
    DispatchKey::AutocastXPU,
    DispatchKey::AutocastIPU,
    DispatchKey::AutocastHPU,
    DispatchKey::AutocastXLA,
    DispatchKey::AutocastPrivateUse1,
    DispatchKey::AutocastMTIA,
    DispatchKey::AutocastMAIA,
});

// See Note [TLS Initialization]
// Keys that are included by default: BackendSelect and ADInplaceOrView.
constexpr DispatchKeySet default_included_set = DispatchKeySet({
    DispatchKey::BackendSelect,
    DispatchKey::ADInplaceOrView,
});

// Keys that are excluded by default (see Note [TLS Initialization]).
// This is currently exactly the set of Autocast keys. It is defined in terms
// of autocast_dispatch_keyset, rather than repeating the same list, so the two
// cannot silently drift apart when a new Autocast* key is added. If a
// non-autocast key ever needs to be excluded by default, union it in here
// explicitly.
constexpr DispatchKeySet default_excluded_set = autocast_dispatch_keyset;

// autograd_dispatch_keyset extended with the ADInplaceOrView key.
constexpr DispatchKeySet autograd_dispatch_keyset_with_ADInplaceOrView =
    autograd_dispatch_keyset.add(DispatchKey::ADInplaceOrView);

// The Python and PythonTLSSnapshot keys, as a single set.
constexpr DispatchKeySet python_ks = DispatchKeySet(DispatchKey::Python) |
    DispatchKeySet(DispatchKey::PythonTLSSnapshot);

// Single-key sets for the Sparse and SparseCsr functionality building-block
// keys.
constexpr DispatchKeySet sparse_ks = DispatchKeySet(DispatchKey::Sparse);

constexpr DispatchKeySet sparse_csr_ks = DispatchKeySet(DispatchKey::SparseCsr);

// Single-key set for MkldnnCPU.
constexpr DispatchKeySet mkldnn_ks = DispatchKeySet(DispatchKey::MkldnnCPU);

// backend dispatch keys that map to DispatchKey::AutogradOther
// NB: keys in this set also get associated with CompositeImplicitAutograd
// NB: unlike autograd_dispatch_keyset, this set also has every backend bit
// set (see the trailing union below).
constexpr DispatchKeySet autogradother_backends =
    DispatchKeySet(
        // HIP and VE aren't in this list: they now have their own backend bits
        // which means that they can now have their own Autograd keys.
        // Technically, HIP will now redispatch to its own custom AutogradHIP
        // slot in the runtime table.
        {DispatchKey::FPGA,
         DispatchKey::Vulkan,
         DispatchKey::Metal,
         DispatchKey::CustomRNGKeyId,
         DispatchKey::MkldnnCPU,
         // Sparse and Quantized backends also live here.
         DispatchKey::Sparse,
         DispatchKey::SparseCsr,
         DispatchKey::Quantized})
    // Including the backend bits because this keyset is used during op
    // registration, which requires looping over all runtime autogradother
    // backend keys.
    | DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);

// The set of dispatch keys that come after autograd
// n.b. this relies on the fact that AutogradOther is currently the lowest
// Autograd key
// NB: like every FULL_AFTER keyset, this contains all backend bits and also
// PythonDispatcher (see the DispatchKeySet(FullAfter, DispatchKey)
// constructor).
constexpr DispatchKeySet after_autograd_keyset =
    DispatchKeySet(DispatchKeySet::FULL_AFTER, c10::DispatchKey::AutogradOther);

// The set of dispatch keys that come after ADInplaceOrView (ADInplaceOrView
// itself excluded). As with every FULL_AFTER keyset, all backend bits and
// PythonDispatcher are included.
constexpr DispatchKeySet after_ADInplaceOrView_keyset = DispatchKeySet(
    DispatchKeySet::FULL_AFTER,
    c10::DispatchKey::ADInplaceOrView);

// The set of dispatch keys that come after Functionalize
constexpr DispatchKeySet after_func_keyset =
    DispatchKeySet(DispatchKeySet::FULL_AFTER, c10::DispatchKey::Functionalize)
        .remove(
            // NOTE: we also need to remove ADInplaceOrView from the keyset when
            // redispatching after the func kernels. This is because we're not
            // calling the same op; we originally called an inplace op, and now
            // we aren't. The original key calculation figured out which keys
            // were Fallthrough based on the inplace op. That means that it did
            // not include the ADInPlaceOrView kernel as a fallthrough key.
            // However, we WANT the ADInPlaceOrView kernel to be ignored now
            // that we're calling an out-of-place op. Re-invoking
            // Dispatcher::call would re-run the Fallthrough key calculation and
            // get us that, But at::redispatch is more performant. We can get
            // away with it by explicitly removing the key here.
            c10::DispatchKey::ADInplaceOrView);

constexpr DispatchKeySet backend_bitset_mask =
    DispatchKeySet(DispatchKeySet::RAW, (1ULL << num_backends) - 1);

constexpr auto inplace_or_view_ks =
    DispatchKeySet(DispatchKey::ADInplaceOrView);
constexpr auto autograd_cpu_ks = DispatchKeySet(DispatchKey::AutogradCPU);
constexpr auto autograd_ipu_ks = DispatchKeySet(DispatchKey::AutogradIPU);
constexpr auto autograd_mtia_ks = DispatchKeySet(DispatchKey::AutogradMTIA);
constexpr auto autograd_maia_ks = DispatchKeySet(DispatchKey::AutogradMAIA);
constexpr auto autograd_xpu_ks = DispatchKeySet(DispatchKey::AutogradXPU);
constexpr auto autograd_cuda_ks = DispatchKeySet(DispatchKey::AutogradCUDA);
constexpr auto autograd_xla_ks = DispatchKeySet(DispatchKey::AutogradXLA);
constexpr auto autograd_lazy_ks = DispatchKeySet(DispatchKey::AutogradLazy);
constexpr auto autograd_meta_ks = DispatchKeySet(DispatchKey::AutogradMeta);
constexpr auto autograd_mps_ks = DispatchKeySet(DispatchKey::AutogradMPS);
constexpr auto autograd_hpu_ks = DispatchKeySet(DispatchKey::AutogradHPU);
constexpr auto autograd_privateuse1_ks =
    DispatchKeySet(DispatchKey::AutogradPrivateUse1);
constexpr auto autograd_privateuse2_ks =
    DispatchKeySet(DispatchKey::AutogradPrivateUse2);
constexpr auto autograd_privateuse3_ks =
    DispatchKeySet(DispatchKey::AutogradPrivateUse3);
constexpr auto autograd_other_ks = DispatchKeySet(DispatchKey::AutogradOther);
constexpr auto autograd_nested =
    DispatchKeySet(DispatchKey::AutogradNestedTensor);
// keyset corresponding to functorch keys that have their own dedicated
// TensorImpl subclass.
constexpr auto functorch_transforms_ks = DispatchKeySet(
    {DispatchKey::FuncTorchBatched,
     DispatchKey::FuncTorchVmapMode,
     DispatchKey::Batched,
     DispatchKey::VmapMode,
     DispatchKey::FuncTorchGradWrapper});

constexpr auto functorch_batched_ks =
    DispatchKeySet({DispatchKey::FuncTorchBatched});

// Contains two kinds of bits:
//   (1) the functionality bits that correspond to backends
//       (Dense, Quantized, Sparse, SparseCsr), and
//   (2) every backend bit.
constexpr DispatchKeySet backend_functionality_keys =
    DispatchKeySet(DispatchKeySet::RAW, full_backend_mask) |
    DispatchKeySet(
        {DispatchKey::Dense,
         DispatchKey::Quantized,
         DispatchKey::Sparse,
         DispatchKey::SparseCsr});

// An operator-table offset paired with a mask over backend bits.
// NOTE(review): field meanings are inferred from the member names; no use of
// this struct is visible in this chunk.
//
// Both members default to 0, matching FunctionalityOffsetAndMask in this
// header, so a default-constructed instance never holds indeterminate values.
// The struct remains an aggregate, so `OpTableOffsetAndMask{offset, mask}`
// still works.
struct OpTableOffsetAndMask {
  uint16_t offset{};
  uint16_t backend_mask{};
};

// OpTableOffsetAndMask::backend_mask is a uint16_t, so it can hold at most 16
// backend bits.
static_assert(
    16 >= num_backends,
    "Right now we expect the number of backends not to exceed 16. "
    "In the (unlikely) event that this changes, the size of "
    "OpTableOffsetAndMask::backend_mask needs to be increased too.");

// Returns true if t is a backend dispatch key.
C10_API bool isBackendDispatchKey(DispatchKey t);

// Resolves alias dispatch key t to the DispatchKeySet it stands for, if
// applicable. (Handling of non-alias keys lives in the implementation, which
// is not visible here.)
C10_API DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t);

// Resolves alias dispatch key t to a DispatchKeySet (if applicable) and
// returns whether k is part of that set.
C10_API bool runtimeDispatchKeySetHas(DispatchKey t, DispatchKey k);

// Returns the DispatchKeySet of all backend keys mapped to Autograd dispatch
// key t. The returned set is empty if t is not an alias of
// DispatchKey::Autograd.
C10_API DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t);

// Maps backend component t to its autograd-related keys: ADInplaceOrView plus
// that backend's Autograd<Backend> key. Anything not listed in the switch
// below (including non-backend inputs) gets AutogradOther instead.
// Note: falling back to a default is convenient and fast compared with (say)
// returning std::optional<DispatchKey> or throwing. The price is that callers
// must either a) only ever pass backend keys, or b) interpret the returned
// set carefully.
inline DispatchKeySet getAutogradRelatedKeySetFromBackend(BackendComponent t) {
  // Only the Autograd<Backend> key varies; ADInplaceOrView is always added.
  DispatchKeySet autograd_key = autograd_other_ks;
  switch (t) {
    case BackendComponent::CPUBit:
      autograd_key = autograd_cpu_ks;
      break;
    case BackendComponent::IPUBit:
      autograd_key = autograd_ipu_ks;
      break;
    case BackendComponent::MTIABit:
      autograd_key = autograd_mtia_ks;
      break;
    case BackendComponent::MAIABit:
      autograd_key = autograd_maia_ks;
      break;
    case BackendComponent::XPUBit:
      autograd_key = autograd_xpu_ks;
      break;
    case BackendComponent::CUDABit:
      autograd_key = autograd_cuda_ks;
      break;
    case BackendComponent::XLABit:
      autograd_key = autograd_xla_ks;
      break;
    case BackendComponent::LazyBit:
      autograd_key = autograd_lazy_ks;
      break;
    case BackendComponent::MetaBit:
      autograd_key = autograd_meta_ks;
      break;
    case BackendComponent::MPSBit:
      autograd_key = autograd_mps_ks;
      break;
    case BackendComponent::HPUBit:
      autograd_key = autograd_hpu_ks;
      break;
    case BackendComponent::PrivateUse1Bit:
      autograd_key = autograd_privateuse1_ks;
      break;
    case BackendComponent::PrivateUse2Bit:
      autograd_key = autograd_privateuse2_ks;
      break;
    case BackendComponent::PrivateUse3Bit:
      autograd_key = autograd_privateuse3_ks;
      break;
    default:
      // No dedicated autograd key: keep the AutogradOther default.
      break;
  }
  return inplace_or_view_ks | autograd_key;
}

// Maps backend component t to its Autocast<Backend> key, as a one-key
// DispatchKeySet. Backends not listed in the switch below (e.g. Lazy, Meta,
// PrivateUse2/3) yield an empty DispatchKeySet.
inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
  // constexpr locals: every candidate set is built at compile time.
  constexpr DispatchKeySet cuda_autocast(DispatchKey::AutocastCUDA);
  constexpr DispatchKeySet cpu_autocast(DispatchKey::AutocastCPU);
  constexpr DispatchKeySet xpu_autocast(DispatchKey::AutocastXPU);
  constexpr DispatchKeySet mps_autocast(DispatchKey::AutocastMPS);
  constexpr DispatchKeySet xla_autocast(DispatchKey::AutocastXLA);
  constexpr DispatchKeySet hpu_autocast(DispatchKey::AutocastHPU);
  constexpr DispatchKeySet ipu_autocast(DispatchKey::AutocastIPU);
  constexpr DispatchKeySet mtia_autocast(DispatchKey::AutocastMTIA);
  constexpr DispatchKeySet maia_autocast(DispatchKey::AutocastMAIA);
  constexpr DispatchKeySet privateuse1_autocast(
      DispatchKey::AutocastPrivateUse1);

  switch (t) {
    case BackendComponent::CUDABit:
      return cuda_autocast;
    case BackendComponent::CPUBit:
      return cpu_autocast;
    case BackendComponent::XPUBit:
      return xpu_autocast;
    case BackendComponent::MPSBit:
      return mps_autocast;
    case BackendComponent::XLABit:
      return xla_autocast;
    case BackendComponent::HPUBit:
      return hpu_autocast;
    case BackendComponent::IPUBit:
      return ipu_autocast;
    case BackendComponent::MTIABit:
      return mtia_autocast;
    case BackendComponent::MAIABit:
      return maia_autocast;
    case BackendComponent::PrivateUse1Bit:
      return privateuse1_autocast;
    default:
      // No autocast key for this backend.
      return DispatchKeySet();
  }
}

// Returns the highest-priority "backend" DispatchKey in ks. This is like
// highestBackendKey(), except that some "functionality" bits also correspond
// to backends (Sparse, Quantized). ks is therefore first masked down to
// backend_functionality_keys before picking the top key.
inline DispatchKey highestPriorityBackendTypeId(DispatchKeySet ks) {
  DispatchKeySet backend_part = ks & backend_functionality_keys;
  return backend_part.highestPriorityTypeId();
}

// Membership test of dispatch key k against alias key `alias`.
// This API exists because OperatorEntry.cpp needs to check
// getRuntimeDispatchKeySet(alias).has(DispatchKey::Undefined), but the has()
// API disallows that.
// NOTE(review): presumably equivalent to getRuntimeDispatchKeySet(alias).has(k)
// while also accepting k == Undefined; confirm against the implementation.
C10_API bool isIncludedInAlias(DispatchKey k, DispatchKey alias);

// Legacy helper: reduces a tensor's DispatchKeySet to a single DispatchKey.
// Historically every tensor had exactly one DispatchKey (something like CPU),
// and TLS could not change it. Some legacy code still relies on a single key,
// e.g. for instanceof-style checks. If at all possible, refactor such code to
// stop using DispatchKey in those cases.
inline DispatchKey legacyExtractDispatchKey(DispatchKeySet s) {
  // NB: if you add any extra keys that can be stored in TensorImpl on top of
  // the existing "backend" keys like CPU/CUDA, you need to add them here.
  // Currently the autograd keys, ADInplaceOrView, the autocast keys and the
  // keys below receive this treatment.
  constexpr DispatchKeySet extra_non_backend_keys = DispatchKeySet(
      {DispatchKey::Functionalize,
       DispatchKey::PythonTLSSnapshot,
       DispatchKey::FuncTorchGradWrapper,
       DispatchKey::FuncTorchVmapMode,
       DispatchKey::FuncTorchBatched,
       DispatchKey::Python});

  DispatchKeySet remaining = s - autograd_dispatch_keyset_with_ADInplaceOrView;
  remaining = remaining - autocast_dispatch_keyset;
  remaining = remaining - extra_non_backend_keys;
  return remaining.highestPriorityTypeId();
}

// Unary type predicate: is_not_DispatchKeySet<T>::value is true exactly when
// T is not DispatchKeySet. It is spelled as a std::negation of std::is_same,
// so it is itself a type trait and can be passed wherever a predicate trait
// template is expected. NOTE(review): the use sites are outside this chunk.
// Keep the exact spelling: the aliased type (including is_same's argument
// order) is this alias's interface.
template <class T>
using is_not_DispatchKeySet = std::negation<std::is_same<DispatchKeySet, T>>;

// Given a function type, constructs a function_traits type that drops the first
// parameter type if the first parameter is of type DispatchKeySet. NB:
// DispatchKeySet is currently explicitly hidden from JIT (mainly to avoid
// pushing unnecessary arguments on the stack - see Note [ Plumbing Keys Through
// the Dispatcher] for details). If at any point in the future we need to expose
// this type to JIT, revisit the usage of this type alias.
//
// Shape of the result: the return type is carried over unchanged. The
// parameter list is either the original list minus its first element (when
// that element is DispatchKeySet) or the original list untouched.
// head_with_default_t is given `void` as its default, presumably so that a
// nullary function (empty parameter list) compares unequal to DispatchKeySet
// and is left as-is. TODO confirm against guts::typelist.
template <class FuncType>
using remove_DispatchKeySet_arg_from_func = guts::make_function_traits_t<
    // Return type: unchanged.
    typename guts::infer_function_traits_t<FuncType>::return_type,
    typename std::conditional_t<
        // Is the first parameter (or `void` if there is none) DispatchKeySet?
        std::is_same_v<
            DispatchKeySet,
            typename guts::typelist::head_with_default_t<
                void,
                typename guts::infer_function_traits_t<
                    FuncType>::parameter_types>>,
        // Yes: drop one leading parameter.
        guts::typelist::drop_if_nonempty_t<
            typename guts::infer_function_traits_t<FuncType>::parameter_types,
            1>,
        // No: keep the parameter list as-is.
        typename guts::infer_function_traits_t<FuncType>::parameter_types>>;
} // namespace c10

C10_DIAGNOSTIC_POP()

#else
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
#endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
