/*
 * Copyright 2014-2023 NVIDIA Corporation.  All rights reserved.
 *
 * NOTICE TO LICENSEE:
 *
 * This source code and/or documentation ("Licensed Deliverables") are
 * subject to NVIDIA intellectual property rights under U.S. and
 * international Copyright laws.
 *
 * These Licensed Deliverables contained herein is PROPRIETARY and
 * CONFIDENTIAL to NVIDIA and is being provided under the terms and
 * conditions of a form of NVIDIA software license agreement by and
 * between NVIDIA and Licensee ("License Agreement") or electronically
 * accepted by Licensee.  Notwithstanding any terms or conditions to
 * the contrary in the License Agreement, reproduction or disclosure
 * of the Licensed Deliverables to any third party without the express
 * written consent of NVIDIA is prohibited.
 *
 * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
 * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
 * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE.  IT IS
 * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
 * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
 * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
 * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
 * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
 * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
 * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
 * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
 * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
 * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
 * OF THESE LICENSED DELIVERABLES.
 *
 * U.S. Government End Users.  These Licensed Deliverables are a
 * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
 * 1995), consisting of "commercial computer software" and "commercial
 * computer software documentation" as such terms are used in 48
 * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
 * only as a commercial end item.  Consistent with 48 C.F.R.12.212 and
 * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
 * U.S. Government End Users acquire the Licensed Deliverables with
 * only those rights set forth herein.
 *
 * Any use of the Licensed Deliverables in individual and commercial
 * software must include, in the user documentation and internal
 * comments to the code, the above Disclaimer and U.S. Government End
 * Users Notice.
 */

/*   cudnn_adv : cuDNN's advanced and experimental features.

*/

#if !defined(CUDNN_ADV_H_)
#define CUDNN_ADV_H_

#include <stdint.h>

#include "cudnn_version.h"
#include "cudnn_ops.h"

/* These version numbers are autogenerated, do not edit manually. */
#define CUDNN_ADV_MAJOR 9
#define CUDNN_ADV_MINOR 10
#define CUDNN_ADV_PATCH 2

#if (CUDNN_ADV_MAJOR != CUDNN_MAJOR) || (CUDNN_ADV_MINOR != CUDNN_MINOR) || (CUDNN_ADV_PATCH != CUDNN_PATCHLEVEL)
#error Version mismatch in cuDNN ADV INFER!!!
#endif

#if defined(__cplusplus)
extern "C" {
#endif

/* BASIC RNN API */

/*
 * RNN algorithm selector; this is the type of the 'algo' argument of
 * cudnnSetRNNDescriptor_v8().
 * CUDNN_RNN_ALGO_COUNT equals the number of enumerators that precede it; it
 * looks like a count sentinel rather than a selectable algorithm (confirm).
 */
typedef enum {
    CUDNN_RNN_ALGO_STANDARD               = 0,
    CUDNN_RNN_ALGO_PERSIST_STATIC         = 1,
    CUDNN_RNN_ALGO_PERSIST_DYNAMIC        = 2,
    CUDNN_RNN_ALGO_PERSIST_STATIC_SMALL_H = 3,
    CUDNN_RNN_ALGO_COUNT                  = 4,
} cudnnRNNAlgo_t;

/*
 * Forward-pass mode; this is the type of the 'fwdMode' argument of
 * cudnnGetRNNTempSpaceSizes() and cudnnRNNForward().
 */
typedef enum {
    CUDNN_FWD_MODE_INFERENCE = 0,
    CUDNN_FWD_MODE_TRAINING  = 1,
} cudnnForwardMode_t;

/*
 * RNN cell type; this is the type of the 'cellMode' argument of
 * cudnnSetRNNDescriptor_v8().
 */
typedef enum {
    CUDNN_RNN_RELU = 0, /* basic RNN cell type with ReLU activation */
    CUDNN_RNN_TANH = 1, /* basic RNN cell type with tanh activation */
    CUDNN_LSTM     = 2, /* LSTM with optional recurrent projection and clipping */
    CUDNN_GRU      = 3, /* Using h' = tanh(r * Uh(t-1) + Wx) and h = (1 - z) * h' + z * h(t-1); */
} cudnnRNNMode_t;

/*
 * Bias configuration of the RNN cell formulas; this is the type of the
 * 'biasMode' argument of cudnnSetRNNDescriptor_v8().
 */
typedef enum {
    CUDNN_RNN_NO_BIAS         = 0, /* rnn cell formulas do not use biases */
    CUDNN_RNN_SINGLE_INP_BIAS = 1, /* rnn cell formulas use one input bias in input GEMM */
    CUDNN_RNN_DOUBLE_BIAS     = 2, /* default, rnn cell formulas use two bias vectors */
    CUDNN_RNN_SINGLE_REC_BIAS = 3  /* rnn cell formulas use one recurrent bias in recurrent GEMM */
} cudnnRNNBiasMode_t;

/*
 * Network direction; this is the type of the 'dirMode' argument of
 * cudnnSetRNNDescriptor_v8().
 */
typedef enum {
    CUDNN_UNIDIRECTIONAL = 0, /* single direction network */
    CUDNN_BIDIRECTIONAL  = 1, /* output concatenation at each layer */
} cudnnDirectionMode_t;

/*
 * Treatment of the first-layer input GEMM; this is the type of the
 * 'inputMode' argument of cudnnSetRNNDescriptor_v8().
 */
typedef enum {
    CUDNN_LINEAR_INPUT = 0, /* adjustable weight matrix in first layer input GEMM */
    CUDNN_SKIP_INPUT   = 1, /* fixed identity matrix in the first layer input GEMM */
} cudnnRNNInputMode_t;

/*
 * LSTM cell clipping switch; this is the type of the 'clipMode' argument of
 * the cudnnRNNSetClip_v8/_v9() and cudnnRNNGetClip_v8/_v9() functions.
 */
typedef enum {
    CUDNN_RNN_CLIP_NONE   = 0, /* disables LSTM cell clipping */
    CUDNN_RNN_CLIP_MINMAX = 1, /* enables LSTM cell clipping */
} cudnnRNNClipMode_t;

/*
 * Memory layout of RNN sequence data; this is the type of the 'layout'
 * argument of cudnnSetRNNDataDescriptor() / cudnnGetRNNDataDescriptor().
 */
typedef enum {
    CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED   = 0, /* padded, outer stride from one time-step to the next */
    CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED     = 1, /* sequence length sorted and packed as in basic RNN api */
    CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED = 2, /* padded, outer stride from one batch to the next */
} cudnnRNNDataLayout_t;

/*
 * For auxFlags in cudnnSetRNNDescriptor_v8().
 * auxFlags is a uint32_t; CUDNN_RNN_PADDED_IO_ENABLED occupies bit 0 and
 * CUDNN_RNN_PADDED_IO_DISABLED is the value with that bit clear.
 */
#define CUDNN_RNN_PADDED_IO_DISABLED 0
#define CUDNN_RNN_PADDED_IO_ENABLED (1U << 0)

/*
 * Opaque RNN descriptor handle.  The struct is only forward-declared, so
 * code using this header can hold and pass the pointer but cannot access
 * its members.  See cudnnCreateRNNDescriptor() / cudnnDestroyRNNDescriptor().
 */
struct cudnnRNNStruct;
typedef struct cudnnRNNStruct *cudnnRNNDescriptor_t;

/*
 * Opaque RNN data descriptor handle (struct is forward-declared only).
 * See cudnnCreateRNNDataDescriptor() / cudnnDestroyRNNDataDescriptor().
 */
struct cudnnRNNDataStruct;
typedef struct cudnnRNNDataStruct *cudnnRNNDataDescriptor_t;

/*
 * Prototype of the RNN descriptor constructor.
 *   rnnDesc - address of a handle variable; being a non-const pointer it is
 *             presumably where the new descriptor is stored (confirm).
 * Returns a cudnnStatus_t status code.
 */
cudnnStatus_t CUDNNWINAPI
cudnnCreateRNNDescriptor(cudnnRNNDescriptor_t *rnnDesc);

/*
 * Prototype of the RNN descriptor destructor; takes the handle by value.
 * NOTE(review): presumably the handle must not be used afterwards -- the
 * header does not say; confirm against the cuDNN API reference.
 * Returns a cudnnStatus_t status code.
 */
cudnnStatus_t CUDNNWINAPI
cudnnDestroyRNNDescriptor(cudnnRNNDescriptor_t rnnDesc);

/*
 * mathPrec in cudnnSetRNNDescriptor_v8() specifies compute precision.
 * Compute precision is further modified by mathType that sets the
 * preferred option for using NVIDIA Tensor Cores.  dataType specifies
 * input/output data type and weight/bias type.
 *
 * Every configuration field is passed by value; cudnnGetRNNDescriptor_v8()
 * takes the same list as pointers.  auxFlags accepts the
 * CUDNN_RNN_PADDED_IO_* macros defined above.
 * Returns a cudnnStatus_t status code.
 */

cudnnStatus_t CUDNNWINAPI
cudnnSetRNNDescriptor_v8(cudnnRNNDescriptor_t rnnDesc,
                         cudnnRNNAlgo_t algo,
                         cudnnRNNMode_t cellMode,
                         cudnnRNNBiasMode_t biasMode,
                         cudnnDirectionMode_t dirMode,
                         cudnnRNNInputMode_t inputMode,
                         cudnnDataType_t dataType,
                         cudnnDataType_t mathPrec,
                         cudnnMathType_t mathType,
                         int32_t inputSize,
                         int32_t hiddenSize,
                         int32_t projSize,
                         int32_t numLayers,
                         cudnnDropoutDescriptor_t dropoutDesc,
                         uint32_t auxFlags);

/*
 * Query counterpart of cudnnSetRNNDescriptor_v8(): the parameter list is
 * the same, but every field after rnnDesc is a non-const pointer and is
 * therefore presumably an output.
 * NOTE(review): whether NULL may be passed for fields the caller does not
 * need is not visible in this header -- confirm before relying on it.
 * Returns a cudnnStatus_t status code.
 */
cudnnStatus_t CUDNNWINAPI
cudnnGetRNNDescriptor_v8(cudnnRNNDescriptor_t rnnDesc,
                         cudnnRNNAlgo_t *algo,
                         cudnnRNNMode_t *cellMode,
                         cudnnRNNBiasMode_t *biasMode,
                         cudnnDirectionMode_t *dirMode,
                         cudnnRNNInputMode_t *inputMode,
                         cudnnDataType_t *dataType,
                         cudnnDataType_t *mathPrec,
                         cudnnMathType_t *mathType,
                         int32_t *inputSize,
                         int32_t *hiddenSize,
                         int32_t *projSize,
                         int32_t *numLayers,
                         cudnnDropoutDescriptor_t *dropoutDesc,
                         uint32_t *auxFlags);

/*
 * Marked CUDNN_DEPRECATED.  cudnnRNNSetClip_v9(), declared right after this
 * prototype, is the non-deprecated variant; it takes the same arguments
 * except clipNanOpt.
 * Returns a cudnnStatus_t status code.
 */
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnRNNSetClip_v8(cudnnRNNDescriptor_t rnnDesc,
                   cudnnRNNClipMode_t clipMode,
                   cudnnNanPropagation_t clipNanOpt,
                   double lclip,
                   double rclip);

/*
 * Sets the clipping configuration on an RNN descriptor.
 *   clipMode     - CUDNN_RNN_CLIP_NONE or CUDNN_RNN_CLIP_MINMAX.
 *   lclip, rclip - presumably the lower and upper clipping bounds (confirm).
 * Returns a cudnnStatus_t status code.
 */
cudnnStatus_t CUDNNWINAPI
cudnnRNNSetClip_v9(cudnnRNNDescriptor_t rnnDesc, cudnnRNNClipMode_t clipMode, double lclip, double rclip);

/*
 * Marked CUDNN_DEPRECATED.  cudnnRNNGetClip_v9(), declared right after this
 * prototype, is the non-deprecated variant; it takes the same arguments
 * except clipNanOpt.  All arguments after rnnDesc are non-const pointers.
 * Returns a cudnnStatus_t status code.
 */
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnRNNGetClip_v8(cudnnRNNDescriptor_t rnnDesc,
                   cudnnRNNClipMode_t *clipMode,
                   cudnnNanPropagation_t *clipNanOpt,
                   double *lclip,
                   double *rclip);

/*
 * Query counterpart of cudnnRNNSetClip_v9(): clipMode, lclip and rclip are
 * non-const pointers, presumably filled with the stored values (confirm).
 * Returns a cudnnStatus_t status code.
 */
cudnnStatus_t CUDNNWINAPI
cudnnRNNGetClip_v9(cudnnRNNDescriptor_t rnnDesc, cudnnRNNClipMode_t *clipMode, double *lclip, double *rclip);

/*
 * Takes a library handle, an RNN descriptor and a mini-batch size.
 * NOTE(review): the name suggests a preparation step tied to
 * CUDNN_RNN_ALGO_PERSIST_DYNAMIC for the given miniBatch; this header does
 * not state it -- confirm against the cuDNN API reference.
 * Returns a cudnnStatus_t status code.
 */
cudnnStatus_t CUDNNWINAPI
cudnnBuildRNNDynamic(cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, int miniBatch);

/*
 * Reports two buffer sizes for an RNN configuration.
 *   fwdMode          - CUDNN_FWD_MODE_INFERENCE or CUDNN_FWD_MODE_TRAINING.
 *   xDesc            - RNN data descriptor of the input.
 *   workSpaceSize    - size_t output.
 *   reserveSpaceSize - size_t output.
 * The names match the workSpaceSize / reserveSpaceSize arguments of
 * cudnnRNNForward() and the RNN backward functions.
 * NOTE(review): units are presumably bytes -- confirm.
 * Returns a cudnnStatus_t status code.
 */
cudnnStatus_t CUDNNWINAPI
cudnnGetRNNTempSpaceSizes(cudnnHandle_t handle,
                          cudnnRNNDescriptor_t rnnDesc,
                          cudnnForwardMode_t fwdMode,
                          cudnnRNNDataDescriptor_t xDesc,
                          size_t *workSpaceSize,
                          size_t *reserveSpaceSize);

/*
 * Reports the weight-space size for an RNN descriptor through the size_t
 * output weightSpaceSize.  The name matches the weightSpaceSize argument of
 * cudnnGetRNNWeightParams(), cudnnRNNForward() and the backward functions.
 * NOTE(review): units are presumably bytes -- confirm.
 * Returns a cudnnStatus_t status code.
 */
cudnnStatus_t CUDNNWINAPI
cudnnGetRNNWeightSpaceSize(cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, size_t *weightSpaceSize);

/*
 * Looks up one weight matrix and one bias vector inside a weight space.
 *   pseudoLayer     - int32_t layer selector.
 *   weightSpaceSize - size of the buffer at weightSpace.
 *   weightSpace     - const buffer; not written through this pointer.
 *   linLayerID      - int32_t selector of the linear layer inside the cell.
 *   mDesc, mAddr    - tensor descriptor handle and void** output for the
 *                     matrix.
 *   bDesc, bAddr    - tensor descriptor handle and void** output for the
 *                     bias.
 * NOTE(review): presumably mAddr / bAddr receive addresses located inside
 * weightSpace and mDesc / bDesc are filled with the matching shapes; the
 * numbering schemes of pseudoLayer and linLayerID are not visible here --
 * confirm against the cuDNN API reference.
 * Returns a cudnnStatus_t status code.
 */
cudnnStatus_t CUDNNWINAPI
cudnnGetRNNWeightParams(cudnnHandle_t handle,
                        cudnnRNNDescriptor_t rnnDesc,
                        int32_t pseudoLayer,
                        size_t weightSpaceSize,
                        const void *weightSpace,
                        int32_t linLayerID,
                        cudnnTensorDescriptor_t mDesc,
                        void **mAddr,
                        cudnnTensorDescriptor_t bDesc,
                        void **bAddr);

/*
 * Prototype of the RNN data descriptor constructor.
 *   rnnDataDesc - address of a handle variable; being a non-const pointer
 *                 it is presumably where the new descriptor is stored
 *                 (confirm).
 * Returns a cudnnStatus_t status code.
 */
cudnnStatus_t CUDNNWINAPI
cudnnCreateRNNDataDescriptor(cudnnRNNDataDescriptor_t *rnnDataDesc);

/*
 * Prototype of the RNN data descriptor destructor; takes the handle by
 * value.  Returns a cudnnStatus_t status code.
 */
cudnnStatus_t CUDNNWINAPI
cudnnDestroyRNNDataDescriptor(cudnnRNNDataDescriptor_t rnnDataDesc);

/*
 * Configures an RNN data descriptor.
 *   dataType, layout - element type and one of the cudnnRNNDataLayout_t
 *                      layouts.
 *   maxSeqLength, batchSize, vectorSize - int dimensions.
 *   seqLengthArray   - const int array, see inline comment.
 *   paddingFill      - see inline comment; passed as void *.
 * NOTE(review): seqLengthArray presumably has batchSize entries, and the
 * pointee type of paddingFill presumably follows dataType -- confirm.
 * Returns a cudnnStatus_t status code.
 */
cudnnStatus_t CUDNNWINAPI
cudnnSetRNNDataDescriptor(cudnnRNNDataDescriptor_t rnnDataDesc,
                          cudnnDataType_t dataType,
                          cudnnRNNDataLayout_t layout,
                          int maxSeqLength,
                          int batchSize,
                          int vectorSize,
                          const int seqLengthArray[], /* length of each sequence in the batch */
                          void *paddingFill);         /* symbol for filling padding position in output */

/*
 * Query counterpart of cudnnSetRNNDataDescriptor(): the scalar fields are
 * returned through non-const pointers and seqLengthArray is a writable
 * int array.
 * NOTE(review): arrayLengthRequested is presumably the number of entries
 * the caller's seqLengthArray can hold -- confirm.
 * Returns a cudnnStatus_t status code.
 */
cudnnStatus_t CUDNNWINAPI
cudnnGetRNNDataDescriptor(cudnnRNNDataDescriptor_t rnnDataDesc,
                          cudnnDataType_t *dataType,
                          cudnnRNNDataLayout_t *layout,
                          int *maxSeqLength,
                          int *batchSize,
                          int *vectorSize,
                          int arrayLengthRequested,
                          int seqLengthArray[],
                          void *paddingFill);

/*
 * RNN forward pass.
 *
 * Read-only (const) arguments: devSeqLengths, x, hx, cx, weightSpace.
 * Writable buffers:            y, hy, cy, workSpace, reserveSpace.
 * xDesc and yDesc precede x and y; hDesc precedes the hx/hy pair and cDesc
 * the cx/cy pair, so each appears to describe the buffers that follow it.
 * fwdMode is CUDNN_FWD_MODE_INFERENCE or CUDNN_FWD_MODE_TRAINING.
 *
 * NOTE(review): the 'dev' prefix suggests devSeqLengths resides in device
 * memory, and the three size arguments are presumably the values reported
 * by cudnnGetRNNWeightSpaceSize() / cudnnGetRNNTempSpaceSizes() -- confirm
 * against the cuDNN API reference.
 * Returns a cudnnStatus_t status code.
 */
cudnnStatus_t CUDNNWINAPI
cudnnRNNForward(cudnnHandle_t handle,
                cudnnRNNDescriptor_t rnnDesc,
                cudnnForwardMode_t fwdMode,
                const int32_t devSeqLengths[],
                cudnnRNNDataDescriptor_t xDesc,
                const void *x,
                cudnnRNNDataDescriptor_t yDesc,
                void *y,
                cudnnTensorDescriptor_t hDesc,
                const void *hx,
                void *hy,
                cudnnTensorDescriptor_t cDesc,
                const void *cx,
                void *cy,
                size_t weightSpaceSize,
                const void *weightSpace,
                size_t workSpaceSize,
                void *workSpace,
                size_t reserveSpaceSize,
                void *reserveSpace);

/* Sequence data descriptor */

/*
 * Axis identifiers for sequence data; element type of the 'axes' array in
 * cudnnSetSeqDataDescriptor() / cudnnGetSeqDataDescriptor().  The enum has
 * four values, the same number as CUDNN_SEQDATA_DIM_COUNT.
 */
typedef enum {
    CUDNN_SEQDATA_TIME_DIM  = 0, /* index in time */
    CUDNN_SEQDATA_BATCH_DIM = 1, /* index in batch */
    CUDNN_SEQDATA_BEAM_DIM  = 2, /* index in beam */
    CUDNN_SEQDATA_VECT_DIM  = 3  /* index in vector */
} cudnnSeqDataAxis_t;

/*
 * Opaque sequence-data descriptor handle (struct is forward-declared only).
 * The typedef itself carries CUDNN_DEPRECATED, as do all functions in this
 * header that take it.
 */
struct cudnnSeqDataStruct;
typedef struct cudnnSeqDataStruct *cudnnSeqDataDescriptor_t CUDNN_DEPRECATED;

/* Equals the number of enumerators in cudnnSeqDataAxis_t. */
#define CUDNN_SEQDATA_DIM_COUNT 4 /* dimension count */

/*
 * Marked CUDNN_DEPRECATED.  Prototype of the sequence-data descriptor
 * constructor; seqDataDesc is the address of a handle variable and, being
 * a non-const pointer, presumably receives the new descriptor (confirm).
 * Returns a cudnnStatus_t status code.
 */
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnCreateSeqDataDescriptor(cudnnSeqDataDescriptor_t *seqDataDesc);

/*
 * Marked CUDNN_DEPRECATED.  Prototype of the sequence-data descriptor
 * destructor; takes the handle by value.
 * Returns a cudnnStatus_t status code.
 */
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnDestroySeqDataDescriptor(cudnnSeqDataDescriptor_t seqDataDesc);

/*
 * Marked CUDNN_DEPRECATED.  Configures a sequence-data descriptor.
 *   nbDims             - number of dimensions.
 *   dimA, axes         - const arrays of dimension sizes and
 *                        cudnnSeqDataAxis_t axis identifiers.
 *   seqLengthArraySize - element count of seqLengthArray (size_t).
 *   seqLengthArray     - const int array of sequence lengths.
 *   paddingFill        - passed as void *.
 * NOTE(review): how dimA is indexed, the ordering meaning of axes, and the
 * pointee type of paddingFill are not visible here -- confirm against the
 * cuDNN API reference.
 * Returns a cudnnStatus_t status code.
 */
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnSetSeqDataDescriptor(cudnnSeqDataDescriptor_t seqDataDesc,
                          cudnnDataType_t dataType,
                          int nbDims,
                          const int dimA[],
                          const cudnnSeqDataAxis_t axes[],
                          size_t seqLengthArraySize,
                          const int seqLengthArray[],
                          void *paddingFill);

/*
 * Marked CUDNN_DEPRECATED.  Query counterpart of
 * cudnnSetSeqDataDescriptor().
 *   dataType, nbDims, seqLengthArraySize - non-const pointers (outputs).
 *   dimA, axes, seqLengthArray           - writable caller arrays.
 *   nbDimsRequested, seqLengthSizeRequested - by-value request sizes.
 * NOTE(review): the *Requested values presumably bound how many entries are
 * written into the corresponding arrays -- confirm.
 * Returns a cudnnStatus_t status code.
 */
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnGetSeqDataDescriptor(const cudnnSeqDataDescriptor_t seqDataDesc,
                          cudnnDataType_t *dataType,
                          int *nbDims,
                          int nbDimsRequested,
                          int dimA[],
                          cudnnSeqDataAxis_t axes[],
                          size_t *seqLengthArraySize,
                          size_t seqLengthSizeRequested,
                          int seqLengthArray[],
                          void *paddingFill);

/* Multihead Attention */

/*
 * Multi-head attention options passed via 'attnMode' in cudnnSetAttnDescriptor().
 * Use the bitwise OR operator to combine several settings listed below.  Additional
 * minor options can be added here w/o changing or introducing new API functions.
 *
 * Bit 0 selects the query mapping and bit 1 the projection biases.  The two
 * zero-valued macros name the "bit clear" setting of each option, so OR-ing
 * them into attnMode does not change its value.
 */
#define CUDNN_ATTN_QUERYMAP_ALL_TO_ONE 0         /* multiple Q-s map to a single (K,V) set when beam size > 1 */
#define CUDNN_ATTN_QUERYMAP_ONE_TO_ONE (1U << 0) /* multiple Q-s map to multiple (K,V) sets when beam size > 1 */
#define CUDNN_ATTN_DISABLE_PROJ_BIASES 0         /* no biases in attention input and output projections */
#define CUDNN_ATTN_ENABLE_PROJ_BIASES (1U << 1)  /* use biases in attention input and output projections */

/*
 * Opaque attention descriptor handle (struct is forward-declared only).
 * The typedef itself carries CUDNN_DEPRECATED, as do all functions in this
 * header that take it.
 */
struct cudnnAttnStruct;
typedef struct cudnnAttnStruct *cudnnAttnDescriptor_t CUDNN_DEPRECATED;

/*
 * Marked CUDNN_DEPRECATED.  Prototype of the attention descriptor
 * constructor; attnDesc is the address of a handle variable and, being a
 * non-const pointer, presumably receives the new descriptor (confirm).
 * Returns a cudnnStatus_t status code.
 */
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnCreateAttnDescriptor(cudnnAttnDescriptor_t *attnDesc);

/*
 * Marked CUDNN_DEPRECATED.  Prototype of the attention descriptor
 * destructor; takes the handle by value.
 * Returns a cudnnStatus_t status code.
 */
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnDestroyAttnDescriptor(cudnnAttnDescriptor_t attnDesc);

/*
 * Marked CUDNN_DEPRECATED.  Configures an attention descriptor; every field
 * is passed by value.
 *   attnMode - bitwise OR of the CUDNN_ATTN_* option macros defined above.
 *   nHeads   - number of attention heads.
 *   smScaler - double scaling factor (NOTE(review): presumably applied
 *              before the softmax -- confirm).
 *   dataType, computePrec, mathType - data type, compute precision and
 *              math type selectors.
 *   attnDropoutDesc, postDropoutDesc - two dropout descriptor handles.
 *   qSize, kSize, vSize              - sizes for Q, K and V.
 *   qProjSize, kProjSize, vProjSize, oProjSize - projection sizes.
 *   qoMaxSeqLength, kvMaxSeqLength, maxBatchSize, maxBeamSize - upper
 *              bounds, as their names state.
 * cudnnGetAttnDescriptor() takes the same list as pointers.
 * Returns a cudnnStatus_t status code.
 */
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnSetAttnDescriptor(cudnnAttnDescriptor_t attnDesc,
                       unsigned attnMode,
                       int nHeads,
                       double smScaler,
                       cudnnDataType_t dataType,
                       cudnnDataType_t computePrec,
                       cudnnMathType_t mathType,
                       cudnnDropoutDescriptor_t attnDropoutDesc,
                       cudnnDropoutDescriptor_t postDropoutDesc,
                       int qSize,
                       int kSize,
                       int vSize,
                       int qProjSize,
                       int kProjSize,
                       int vProjSize,
                       int oProjSize,
                       int qoMaxSeqLength,
                       int kvMaxSeqLength,
                       int maxBatchSize,
                       int maxBeamSize);

/*
 * Marked CUDNN_DEPRECATED.  Query counterpart of cudnnSetAttnDescriptor():
 * the parameter list is the same, but every field after attnDesc is a
 * non-const pointer and is therefore presumably an output.
 * NOTE(review): whether NULL may be passed for unwanted fields is not
 * visible in this header -- confirm before relying on it.
 * Returns a cudnnStatus_t status code.
 */
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnGetAttnDescriptor(cudnnAttnDescriptor_t attnDesc,
                       unsigned *attnMode,
                       int *nHeads,
                       double *smScaler,
                       cudnnDataType_t *dataType,
                       cudnnDataType_t *computePrec,
                       cudnnMathType_t *mathType,
                       cudnnDropoutDescriptor_t *attnDropoutDesc,
                       cudnnDropoutDescriptor_t *postDropoutDesc,
                       int *qSize,
                       int *kSize,
                       int *vSize,
                       int *qProjSize,
                       int *kProjSize,
                       int *vProjSize,
                       int *oProjSize,
                       int *qoMaxSeqLength,
                       int *kvMaxSeqLength,
                       int *maxBatchSize,
                       int *maxBeamSize);

/*
 * Marked CUDNN_DEPRECATED.  Reports three buffer sizes, in bytes according
 * to the parameter names, for an attention descriptor: weights, work space
 * and reserve space.  The names match the size arguments of
 * cudnnMultiHeadAttnForward() and the attention backward functions.
 * Returns a cudnnStatus_t status code.
 */
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnGetMultiHeadAttnBuffers(cudnnHandle_t handle,
                             const cudnnAttnDescriptor_t attnDesc,
                             size_t *weightSizeInBytes,
                             size_t *workSpaceSizeInBytes,
                             size_t *reserveSpaceSizeInBytes);

/*
 * Selects one of the attention weight or bias tensors; this is the type of
 * the 'wKind' argument of cudnnGetMultiHeadAttnWeights().  Values 0-3 are
 * the projection weights and 4-7 the projection biases, in Q, K, V, O order.
 */
typedef enum {
    CUDNN_MH_ATTN_Q_WEIGHTS = 0, /* input projection weights for 'queries' */
    CUDNN_MH_ATTN_K_WEIGHTS = 1, /* input projection weights for 'keys' */
    CUDNN_MH_ATTN_V_WEIGHTS = 2, /* input projection weights for 'values' */
    CUDNN_MH_ATTN_O_WEIGHTS = 3, /* output projection weights */
    CUDNN_MH_ATTN_Q_BIASES  = 4, /* input projection bias tensor for 'queries' */
    CUDNN_MH_ATTN_K_BIASES  = 5, /* input projection bias for 'keys' */
    CUDNN_MH_ATTN_V_BIASES  = 6, /* input projection bias for 'values' */
    CUDNN_MH_ATTN_O_BIASES  = 7, /* output projection biases */
} cudnnMultiHeadAttnWeightKind_t;

/* Equals the number of enumerators in cudnnMultiHeadAttnWeightKind_t. */
#define CUDNN_ATTN_WKIND_COUNT 8 /* Number of attention weight/bias tensors */

/*
 * Marked CUDNN_DEPRECATED.  Looks up one weight or bias tensor inside an
 * attention weight buffer.
 *   wKind             - which tensor, see cudnnMultiHeadAttnWeightKind_t.
 *   weightSizeInBytes - size of the buffer at weights.
 *   weights           - const buffer; not written through this pointer.
 *   wDesc             - tensor descriptor handle.
 *   wAddr             - void** output.
 * NOTE(review): presumably wAddr receives an address inside 'weights' and
 * wDesc is filled with the tensor's shape -- confirm against the cuDNN API
 * reference.
 * Returns a cudnnStatus_t status code.
 */
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnGetMultiHeadAttnWeights(cudnnHandle_t handle,
                             const cudnnAttnDescriptor_t attnDesc,
                             cudnnMultiHeadAttnWeightKind_t wKind,
                             size_t weightSizeInBytes,
                             const void *weights,
                             cudnnTensorDescriptor_t wDesc,
                             void **wAddr);

/*
 * Marked CUDNN_DEPRECATED.  Multi-head attention forward pass.
 *
 * Read-only (const) arguments: loWinIdx, hiWinIdx, devSeqLengthsQO,
 *   devSeqLengthsKV, queries, residuals, keys, values, weights.
 * Writable buffers: out, workSpace, reserveSpace.
 * qDesc / kDesc / vDesc / oDesc are sequence-data descriptors placed
 * directly before queries / keys / values / out.
 *
 * NOTE(review): the meaning of currIdx and of the loWinIdx / hiWinIdx
 * window arrays, the memory space implied by the 'dev' prefix, and whether
 * residuals or reserveSpace may be NULL are not stated in this header --
 * confirm against the cuDNN API reference.
 * Returns a cudnnStatus_t status code.
 */
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnMultiHeadAttnForward(cudnnHandle_t handle,
                          const cudnnAttnDescriptor_t attnDesc,
                          int currIdx,
                          const int loWinIdx[],
                          const int hiWinIdx[],
                          const int devSeqLengthsQO[],
                          const int devSeqLengthsKV[],
                          const cudnnSeqDataDescriptor_t qDesc,
                          const void *queries,
                          const void *residuals,
                          const cudnnSeqDataDescriptor_t kDesc,
                          const void *keys,
                          const cudnnSeqDataDescriptor_t vDesc,
                          const void *values,
                          const cudnnSeqDataDescriptor_t oDesc,
                          void *out,
                          size_t weightSizeInBytes,
                          const void *weights,
                          size_t workSpaceSizeInBytes,
                          void *workSpace,
                          size_t reserveSpaceSizeInBytes,
                          void *reserveSpace);

/**
 * \brief Cross-library version checker.
 * This function is implemented differently in each sub-library. Each sublib
 * checks whether its own version matches that of its dependencies.
 * \returns CUDNN_STATUS_SUCCESS if the version check passes,
 *          CUDNN_STATUS_SUBLIBRARY_VERSION_MISMATCH if the versions are inconsistent.
 */
cudnnStatus_t CUDNNWINAPI
cudnnAdvVersionCheck(void);

/*
 * Weight-gradient output mode; this is the type of the 'addGrad' argument
 * of cudnnRNNBackwardWeights_v8() and cudnnMultiHeadAttnBackwardWeights().
 */
typedef enum {
    CUDNN_WGRAD_MODE_ADD = 0, /* add partial gradients to wgrad output buffers */
    CUDNN_WGRAD_MODE_SET = 1, /* write partial gradients to wgrad output buffers */
} cudnnWgradMode_t;

/*
 * RNN backward pass with respect to data.
 *
 * Read-only (const) arguments: devSeqLengths, y, dy, hx, dhy, cx, dcy,
 *   weightSpace.
 * Writable buffers: dx, dhx, dcx, workSpace, reserveSpace.
 * yDesc precedes y/dy and xDesc precedes dx; hDesc precedes hx/dhy/dhx and
 * cDesc precedes cx/dcy/dcx, so each appears to describe the buffers that
 * follow it.
 *
 * NOTE(review): reserveSpace is presumably the buffer produced by a prior
 * cudnnRNNForward() call in training mode, and the 'dev' prefix suggests
 * device memory -- confirm against the cuDNN API reference.
 * Returns a cudnnStatus_t status code.
 */
cudnnStatus_t CUDNNWINAPI
cudnnRNNBackwardData_v8(cudnnHandle_t handle,
                        cudnnRNNDescriptor_t rnnDesc,
                        const int32_t devSeqLengths[],
                        cudnnRNNDataDescriptor_t yDesc,
                        const void *y,
                        const void *dy,
                        cudnnRNNDataDescriptor_t xDesc,
                        void *dx,
                        cudnnTensorDescriptor_t hDesc,
                        const void *hx,
                        const void *dhy,
                        void *dhx,
                        cudnnTensorDescriptor_t cDesc,
                        const void *cx,
                        const void *dcy,
                        void *dcx,
                        size_t weightSpaceSize,
                        const void *weightSpace,
                        size_t workSpaceSize,
                        void *workSpace,
                        size_t reserveSpaceSize,
                        void *reserveSpace);

/*
 * RNN backward pass with respect to weights.
 *
 * Read-only (const) arguments: devSeqLengths, x, hx, y.
 * Writable buffers: dweightSpace, workSpace, reserveSpace.
 * addGrad is CUDNN_WGRAD_MODE_ADD or CUDNN_WGRAD_MODE_SET (add to, or write,
 * the wgrad output buffers, per the enum's comments).
 *
 * NOTE(review): weightSpaceSize presumably gives the size of dweightSpace,
 * and reserveSpace presumably carries state from the preceding forward and
 * backward-data calls -- confirm against the cuDNN API reference.
 * Returns a cudnnStatus_t status code.
 */
cudnnStatus_t CUDNNWINAPI
cudnnRNNBackwardWeights_v8(cudnnHandle_t handle,
                           cudnnRNNDescriptor_t rnnDesc,
                           cudnnWgradMode_t addGrad,
                           const int32_t devSeqLengths[],
                           cudnnRNNDataDescriptor_t xDesc,
                           const void *x,
                           cudnnTensorDescriptor_t hDesc,
                           const void *hx,
                           cudnnRNNDataDescriptor_t yDesc,
                           const void *y,
                           size_t weightSpaceSize,
                           void *dweightSpace,
                           size_t workSpaceSize,
                           void *workSpace,
                           size_t reserveSpaceSize,
                           void *reserveSpace);

/*
 * Marked CUDNN_DEPRECATED.  Multi-head attention backward pass with respect
 * to data.
 *
 * Read-only (const) arguments: loWinIdx, hiWinIdx, devSeqLengthsDQDO,
 *   devSeqLengthsDKDV, dout, queries, keys, values, weights.
 * Writable buffers: dqueries, dkeys, dvalues, workSpace, reserveSpace.
 * doDesc precedes dout; dqDesc / dkDesc / dvDesc each precede a gradient
 * output followed by the matching forward input (e.g. dqueries, queries).
 *
 * NOTE(review): reserveSpace is presumably the buffer produced by
 * cudnnMultiHeadAttnForward(); the window-array semantics are not stated in
 * this header -- confirm against the cuDNN API reference.
 * Returns a cudnnStatus_t status code.
 */
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnMultiHeadAttnBackwardData(cudnnHandle_t handle,
                               const cudnnAttnDescriptor_t attnDesc,
                               const int loWinIdx[],
                               const int hiWinIdx[],
                               const int devSeqLengthsDQDO[],
                               const int devSeqLengthsDKDV[],
                               const cudnnSeqDataDescriptor_t doDesc,
                               const void *dout,
                               const cudnnSeqDataDescriptor_t dqDesc,
                               void *dqueries,
                               const void *queries,
                               const cudnnSeqDataDescriptor_t dkDesc,
                               void *dkeys,
                               const void *keys,
                               const cudnnSeqDataDescriptor_t dvDesc,
                               void *dvalues,
                               const void *values,
                               size_t weightSizeInBytes,
                               const void *weights,
                               size_t workSpaceSizeInBytes,
                               void *workSpace,
                               size_t reserveSpaceSizeInBytes,
                               void *reserveSpace);

/*
 * Marked CUDNN_DEPRECATED.  Multi-head attention backward pass with respect
 * to weights.
 *
 * Read-only (const) arguments: queries, keys, values, dout, weights.
 * Writable buffers: dweights, workSpace, reserveSpace.
 * addGrad is CUDNN_WGRAD_MODE_ADD or CUDNN_WGRAD_MODE_SET (add to, or write,
 * the wgrad output buffers, per the enum's comments).
 *
 * NOTE(review): weightSizeInBytes presumably applies to both weights and
 * dweights -- confirm against the cuDNN API reference.
 * Returns a cudnnStatus_t status code.
 */
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnMultiHeadAttnBackwardWeights(cudnnHandle_t handle,
                                  const cudnnAttnDescriptor_t attnDesc,
                                  cudnnWgradMode_t addGrad,
                                  const cudnnSeqDataDescriptor_t qDesc,
                                  const void *queries,
                                  const cudnnSeqDataDescriptor_t kDesc,
                                  const void *keys,
                                  const cudnnSeqDataDescriptor_t vDesc,
                                  const void *values,
                                  const cudnnSeqDataDescriptor_t doDesc,
                                  const void *dout,
                                  size_t weightSizeInBytes,
                                  const void *weights,
                                  void *dweights,
                                  size_t workSpaceSizeInBytes,
                                  void *workSpace,
                                  size_t reserveSpaceSizeInBytes,
                                  void *reserveSpace);

/*
 * CTC (Connectionist Temporal Classification) loss descriptor create/destroy/set/get functions
 */
/*
 * Input normalization mode for loss function.
 * This is the type of the 'normMode' argument of the
 * cudnnSetCTCLossDescriptorEx/_v8/_v9() and matching Get functions.
 */
typedef enum {
    CUDNN_LOSS_NORMALIZATION_NONE    = 0,
    CUDNN_LOSS_NORMALIZATION_SOFTMAX = 1,
} cudnnLossNormalizationMode_t;

/*
 * Prototype of the CTC loss descriptor constructor.  cudnnCTCLossDescriptor_t
 * is not declared in this header (presumably it comes from "cudnn_ops.h").
 *   ctcLossDesc - address of a handle variable; being a non-const pointer
 *                 it is presumably where the new descriptor is stored
 *                 (confirm).
 * Returns a cudnnStatus_t status code.
 */
cudnnStatus_t CUDNNWINAPI
cudnnCreateCTCLossDescriptor(cudnnCTCLossDescriptor_t *ctcLossDesc);

/*
 * Marked CUDNN_DEPRECATED.  Smallest of the CTC loss descriptor setters: it
 * takes only compType.  cudnnSetCTCLossDescriptor_v9() is the one setter in
 * this header that is not deprecated.
 * Returns a cudnnStatus_t status code.
 */
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnSetCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t compType);

/*
 * Marked CUDNN_DEPRECATED.  Extends cudnnSetCTCLossDescriptor() with
 * normMode (cudnnLossNormalizationMode_t) and gradMode
 * (cudnnNanPropagation_t).
 * Returns a cudnnStatus_t status code.
 */
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnSetCTCLossDescriptorEx(cudnnCTCLossDescriptor_t ctcLossDesc,
                            cudnnDataType_t compType,
                            cudnnLossNormalizationMode_t normMode,
                            cudnnNanPropagation_t gradMode);

/*
 * Marked CUDNN_DEPRECATED.  Same arguments as cudnnSetCTCLossDescriptorEx()
 * plus maxLabelLength.  The _v9 variant differs only in the gradient-mode
 * argument, which there has type cudnnCTCGradMode_t instead of
 * cudnnNanPropagation_t.
 * Returns a cudnnStatus_t status code.
 */
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnSetCTCLossDescriptor_v8(cudnnCTCLossDescriptor_t ctcLossDesc,
                             cudnnDataType_t compType,
                             cudnnLossNormalizationMode_t normMode,
                             cudnnNanPropagation_t gradMode,
                             int maxLabelLength);

/*
 * Configures a CTC loss descriptor; the only non-deprecated setter in this
 * header.
 *   compType       - data type selector.
 *   normMode       - CUDNN_LOSS_NORMALIZATION_NONE or
 *                    CUDNN_LOSS_NORMALIZATION_SOFTMAX.
 *   ctcGradMode    - cudnnCTCGradMode_t (declared outside this header).
 *   maxLabelLength - int upper bound on label length, per its name.
 * Returns a cudnnStatus_t status code.
 */
cudnnStatus_t CUDNNWINAPI
cudnnSetCTCLossDescriptor_v9(cudnnCTCLossDescriptor_t ctcLossDesc,
                             cudnnDataType_t compType,
                             cudnnLossNormalizationMode_t normMode,
                             cudnnCTCGradMode_t ctcGradMode,
                             int maxLabelLength);

/*
 * Marked CUDNN_DEPRECATED.  Query counterpart of cudnnSetCTCLossDescriptor():
 * compType is a non-const pointer (output).
 * Returns a cudnnStatus_t status code.
 */
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnGetCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t *compType);

/*
 * Marked CUDNN_DEPRECATED.  Query counterpart of
 * cudnnSetCTCLossDescriptorEx(): compType, normMode and gradMode are
 * non-const pointers (outputs).
 * Returns a cudnnStatus_t status code.
 */
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnGetCTCLossDescriptorEx(cudnnCTCLossDescriptor_t ctcLossDesc,
                            cudnnDataType_t *compType,
                            cudnnLossNormalizationMode_t *normMode,
                            cudnnNanPropagation_t *gradMode);

/*
 * Marked CUDNN_DEPRECATED.  Query counterpart of
 * cudnnSetCTCLossDescriptor_v8(): all arguments after ctcLossDesc are
 * non-const pointers (outputs).
 * Returns a cudnnStatus_t status code.
 */
CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI
cudnnGetCTCLossDescriptor_v8(cudnnCTCLossDescriptor_t ctcLossDesc,
                             cudnnDataType_t *compType,
                             cudnnLossNormalizationMode_t *normMode,
                             cudnnNanPropagation_t *gradMode,
                             int *maxLabelLength);

/*
 * Query counterpart of cudnnSetCTCLossDescriptor_v9(): all arguments after
 * ctcLossDesc are non-const pointers, presumably filled with the stored
 * values (confirm).
 * Returns a cudnnStatus_t status code.
 */
cudnnStatus_t CUDNNWINAPI
cudnnGetCTCLossDescriptor_v9(cudnnCTCLossDescriptor_t ctcLossDesc,
                             cudnnDataType_t *compType,
                             cudnnLossNormalizationMode_t *normMode,
                             cudnnCTCGradMode_t *ctcGradMode,
                             int *maxLabelLength);

/*
 * Prototype of the CTC loss descriptor destructor; takes the handle by
 * value.  Returns a cudnnStatus_t status code.
 */
cudnnStatus_t CUDNNWINAPI
cudnnDestroyCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc);

/*
 * return the ctc costs and gradients, given the probabilities and labels
 *
 * In this variant the label and length arrays (hostLabels, hostLabelLengths,
 * hostInputLengths) are in CPU memory, while probs, costs, gradients and
 * workspace are in GPU memory (see the per-parameter comments).
 * cudnnCTCLoss_v8() takes the label and length arrays in GPU memory and
 * orders its parameters differently.
 * Returns a cudnnStatus_t status code.
 */
cudnnStatus_t CUDNNWINAPI
cudnnCTCLoss(
    cudnnHandle_t handle,
    const cudnnTensorDescriptor_t probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the
                                                timing steps, N is the mini batch size, A is the alphabet size)  */
    const void *probs,                       /* probabilities after softmax, in GPU memory */
    const int hostLabels[],                  /* labels, in CPU memory */
    const int hostLabelLengths[],            /* the length of each label, in CPU memory */
    const int hostInputLengths[],            /* the lengths of timing steps in each batch, in CPU memory */
    void *costs,                             /* the returned costs of CTC, in GPU memory */
    const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the dimensions are T,N,A */
    void *gradients,         /* the returned CTC gradients, in GPU memory, to compute costs only, set it to NULL */
    cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */
    cudnnCTCLossDescriptor_t ctcLossDesc,
    void *workspace,              /* pointer to the workspace, in GPU memory */
    size_t workSpaceSizeInBytes); /* size of the workspace */

/*
 * return the ctc costs and gradients, given the probabilities and labels
 *
 * Differences from cudnnCTCLoss(), as visible in the two prototypes:
 *  - labels, labelLengths and inputLengths are in GPU memory here;
 *  - algo and ctcLossDesc come directly after handle;
 *  - the workspace size precedes the workspace pointer.
 * Returns a cudnnStatus_t status code.
 */
cudnnStatus_t CUDNNWINAPI
cudnnCTCLoss_v8(
    cudnnHandle_t handle,
    cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */
    cudnnCTCLossDescriptor_t ctcLossDesc,
    const cudnnTensorDescriptor_t probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the
                                                timing steps, N is the mini batch size, A is the alphabet size)  */
    const void *probs,                       /* probabilities after softmax, in GPU memory */
    const int labels[],                      /* labels, in GPU memory */
    const int labelLengths[],                /* the length of each label, in GPU memory */
    const int inputLengths[],                /* the lengths of timing steps in each batch, in GPU memory */
    void *costs,                             /* the returned costs of CTC, in GPU memory */
    const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the dimensions are T,N,A */
    void *gradients,             /* the returned CTC gradients, in GPU memory, to compute costs only, set it to NULL */
    size_t workSpaceSizeInBytes, /* size of the workspace */
    void *workspace);            /* pointer to the workspace, in GPU memory */

/*
 * return the workspace size needed for ctc
 *
 * Companion of cudnnCTCLoss(): takes the same descriptors, CPU-memory label
 * and length arrays, algo and ctcLossDesc, and reports the size through
 * sizeInBytes.
 * Returns a cudnnStatus_t status code.
 */
cudnnStatus_t CUDNNWINAPI
cudnnGetCTCLossWorkspaceSize(
    cudnnHandle_t handle,
    const cudnnTensorDescriptor_t probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the
                                                timing steps, N is the mini batch size, A is the alphabet size) */
    const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the
                                                    dimensions are T,N,A. To compute costs
                                                    only, set it to NULL */
    const int *labels,                           /* labels, in CPU memory */
    const int *labelLengths,                     /* the length of each label, in CPU memory */
    const int *inputLengths,                     /* the lengths of timing steps in each batch, in CPU memory */
    cudnnCTCLossAlgo_t algo,                     /* algorithm selected, supported now 0 and 1 */
    cudnnCTCLossDescriptor_t ctcLossDesc,
    size_t *sizeInBytes); /* pointer to the returned workspace size */

/*
 * return the workspace size needed for ctc
 *
 * Companion of cudnnCTCLoss_v8().  Unlike cudnnGetCTCLossWorkspaceSize(), it
 * takes no label or length arrays; only algo, ctcLossDesc and the two tensor
 * descriptors are passed.
 * Returns a cudnnStatus_t status code.
 */
cudnnStatus_t CUDNNWINAPI
cudnnGetCTCLossWorkspaceSize_v8(
    cudnnHandle_t handle,
    cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */
    cudnnCTCLossDescriptor_t ctcLossDesc,
    const cudnnTensorDescriptor_t probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the
                                                timing steps, N is the mini batch size, A is the alphabet size) */
    const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the
                                                    dimensions are T,N,A. To compute costs
                                                    only, set it to NULL */
    size_t *sizeInBytes);                        /* pointer to the returned workspace size */

#if defined(__cplusplus)
}
#endif

#endif /* CUDNN_ADV_H_ */
