
    i                     v   d Z ddlZddlmZ ddlmZ ddlZddlmZ ddlm	Z	 ddl
mZ dd	lmZ dd
lmZ ddlmZ ddlmZ ddlmZ ddlmZ ddlmZmZmZ ddlmZmZ ddl m!Z!  ejD                  e#      Z$ e       rddl%m&Z& ndZ& G d d      Z' G d dejP                        Z) G d dejP                        Z* G d de      Z+e G d de             Z,e ed       G d  d!e                    Z-e ed"       G d# d$e                    Z.e G d% d&e,             Z/ ed'       G d( d)e,e             Z0g d*Z1y)+zPyTorch MAMBA model.    N)	dataclass)Any)nn)CrossEntropyLoss   )initialization)ACT2FN)PreTrainedConfig)GenerationMixin)lazy_load_kernel)GradientCheckpointingLayer)PreTrainedModel)ModelOutputauto_docstringlogging)is_mambapy_availableis_torchdynamo_compiling   )MambaConfig)pscanc            
           e Zd ZdZdZej                  dfdededej                  dej                  ez  dz  fdZd	ed
ej                  dej                  dej                  fdZd	edej                  fdZd Zy)
MambaCachea.  
    Cache for mamba model which does not have attention mechanism and key value states.

    Arguments:
        config (`PreTrainedConfig):
            The configuration file defining the shape-related attributes required to initialize the static cache.
        max_batch_size (`int`):
            The maximum batch size with which the model will be used. Note that a new instance must be instantiated if
            a smaller batch size is used.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
            The default `dtype` to use when initializing the layer.
        device (`torch.device` or `str`, *optional*):
            The device on which the cache should be initialized. Should be the same as the layer.

    Example:

        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache

        >>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")

        >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> cache_params = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
        >>> cache_position = torch.arange(len(inputs["input_ids"][0]), device=model.device)  # sequence length
        >>> outputs = model(**inputs, cache_params=cache_params, cache_position=cache_position, use_cache=True)
        >>> outputs.cache_params
        ```
    TNconfigmax_batch_sizedtypedevicec                    || _         || _        |j                  | _        |j                  | _        |j
                  | _        g | _        g | _        |t        j                  |      nd }t        |j                        D ]  }t        j                  | j                   | j                  | j                  || j                        }t        j                  | j                   | j                  | j                  || j                        }t        j                  j                  |       t        j                  j                  |       | j                  j!                  |       | j                  j!                  |        y )Nr   r   )r   _dtypeintermediate_size
state_sizessm_state_sizeconv_kernelconv_kernel_sizeconv_states
ssm_statestorchr   rangenum_hidden_layerszeros_dynamomark_static_addressappend)selfr   r   r   r   _
conv_state	ssm_states           r/home/obispo/Crisostomo_bridge/mision_env/lib/python3.12/site-packages/transformers/models/mamba/modeling_mamba.py__init__zMambaCache.__init__U   s/    -!'!9!9$// & 2 2/1.0)/);f%v//0 	.A',{{##&&%%kk(J ',kk##&&##kk'I MM--j9MM--i8##J/OO""9-'	.    	layer_idxnew_conv_statecache_positionreturnc                 "   | j                   |   j                  |j                  k7  r5| j                   |   j                  |j                        | j                   |<   | j                   |   }|j                  d| j                  dz
        }|j                  dd      }|j                  |j                  |j                        |d d d d |f<   | j                   |   j                          | j                   |xx   |z  cc<   | j                   |   S )Nr   r   )shiftsdimsr   )r%   r   toclampr$   rollr   zero_)r.   r5   r6   r7   r0   s        r2   update_conv_statezMambaCache.update_conv_statez   s    
 I&--1F1FF*.*:*:9*E*H*HI^I^*_DY'%%i0
'--a1F1F1JK__BR_8
+9+<+<JDUDU]g]m]m+<+n
1a'(#))+#z1#	**r4   new_ssm_statec                     | j                   |   j                          | j                   |xx   |j                  | j                   |   j                        z  cc<   | j                   |   S N)r&   r@   r=   r   )r.   r5   rB   s      r2   update_ssm_statezMambaCache.update_ssm_state   sT    	"((*	"m&6&6ty7Q7X7X&YY"y))r4   c                     t        t        | j                              D ]<  }| j                  |   j                          | j                  |   j                          > y rD   )r(   lenr%   r@   r&   )r.   r5   s     r2   resetzMambaCache.reset   sM    s4#3#345 	/IY'--/OOI&,,.	/r4   )__name__
__module____qualname____doc__is_compileabler'   float16r
   intr   r   strr3   Tensor
LongTensorrA   rE   rH    r4   r2   r   r   0   s    B N #]],0#. #. #. {{	#.
 s"T)#.J++.3ll+LQL\L\+	+"*# *ell *
/r4   r   c            
       F    e Zd ZdZdedef fdZd Z	 	 	 ddej                  de
dz  d	ej                  dz  d
ej                  dz  fdZdde
dz  d	ej                  dz  d
ej                  dz  fdZ	 	 	 dde
dz  d	ej                  dz  d
ej                  dz  fdZ xZS )
MambaMixeru  
    Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
    A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
    ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
    and is why Mamba is called **selective** state spaces)
    r   r5   c           	         t         |           || _        |j                  | _        |j                  | _        |j                  | _        |j                  | _        t        |j                        | _
        || _        |j                  | _        t        j                  | j                  | j                  |j                  |j                  | j                  |j                  dz
        | _        |j                   | _        t$        |j                      | _        |j(                  | _        t        j*                  | j                  | j                  dz  |j,                        | _        t        j*                  | j                  | j                  | j
                  dz  z   d      | _        t        j*                  | j                  | j                  d      | _        t5        j6                  d| j
                  dz   t4        j8                        d d d f   }|j;                  | j                  d      j=                         }t        j>                  t5        j@                  |            | _!        t        j>                  t5        jD                  | j                              | _#        t        j*                  | j                  | j                  |j,                        | _$        |j,                  | _        tK        d	      a&tL         tL        jN                  tL        jP                  fnd
\  a'a(tK        d      a)tR        /tR        jT                  tR        jV                  tR        jX                  fnd\  a*a+a,| j[                          y )Nr   )in_channelsout_channelsbiaskernel_sizegroupspadding   rY   FTr   r:   zcausal-conv1d)NNz	mamba-ssmNNN).superr3   r   hidden_sizer!   r"   r#   r$   r    rO   time_step_rankr5   use_conv_biasr   Conv1dconv1d
hidden_act
activationr	   actuse_mambapyLinearuse_biasin_projx_projdt_projr'   arangefloat32expand
contiguous	ParameterlogA_logonesDout_projr   causal_conv1dcausal_conv1d_updatecausal_conv1d_fn	mamba_ssmselective_state_updateselective_scan_fnmamba_inner_fnwarn_slow_implementation)r.   r   r5   A	__class__s       r2   r3   zMambaMixer.__init__   s   !--$// & 2 2!'!9!9!&"7"78"#11ii..//%%**))&&*
 !++&++,!-- yy!1!143I3IA3MTZTcTcdii 6 68K8KdNaNadeNe8elqryy!4!4d6L6LSWX LLD//!35==I$PQ'RHHT++R0;;=\\%))A,/
ejj)?)?@A		$"8"8$:J:JQWQ`Q`a )9 ( //1O1OP 	/. %[1	 $ --y/J/JILdLde# 	B 1> 	%%'r4   c                     t        t        t        t        t        t
        f      }|sM| j                  r+t               rt        j                  d       y t        d      t        j                  d       y y )Na  The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation for mamba-ssm and install the kernels library using `pip install kernels` or https://github.com/Dao-AILab/causal-conv1d for causal-conv1dzuse_mambapy is set to True but the mambapy package is not installed. To install it follow https://github.com/alxndrTL/mamba.py.a  The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation for mamba-ssm and install the kernels library using `pip install kernels` or https://github.com/Dao-AILab/causal-conv1d for causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py.)allr~   r   r|   r{   r   rj   r   loggerwarning_onceImportError)r.   is_fast_path_availables     r2   r   z#MambaMixer.warn_slow_implementation   sw    !$#%68HJ^`no"
 &')''S & Z  ##W &r4   Nhidden_statescache_paramsr7   attention_maskc                 	   | j                  |      j                  dd      }| j                  r%|"t        || j                  j
                  | j                  r| j                  j                  nd | j                  j
                  | j                  j
                  | j                  j
                  | j                  r$| j                  j                  j                         nd t        j                  | j                  j                                d d | j                   j                         | j                  j                  j                         d      }|S |j#                  dd      \  }}|||j%                  d      z  }| j                  j
                  j'                  | j                  j
                  j)                  d      | j                  j
                  j)                  d            }|m|d   dkD  ret+        |j-                  d      |j.                  | j0                     || j                  j                  | j2                        }|j%                  d      }n|Yt4        j6                  j9                  || j:                  |j<                  d   z
  df      }	|j?                  | j0                  |	|       tA        ||| j                  j                  | j2                        }|||j%                  d      z  }| j                  |j                  dd            }
t        jB                  |
| jD                  | jF                  | jF                  gd      \  }}}| j                  j
                  |j                  dd      z  }t        j                  | j                  j                                }tI        | j                  d	      r$| j                  j                  j                         nd }|e|d   dkD  r]tK        |jL                  | j0                     |d
   |d
   ||d d df   |d d df   | j                   |d
   |d
      j%                  d      }nptO        ||||j                  dd      |j                  dd      | j                   j                         ||dd
      \  }}|||jQ                  | j0                  |       | j                  |j                  dd            }|S )Nr   r]   T)
delta_biasdelta_softplusdimr   r:   )rh   rY   ).r   )dt_softplus)r   return_last_state))rm   	transposetrainingr   rf   weightrd   rY   rn   ro   ry   rl   floatr'   exprv   rx   chunk	unsqueezeviewsizer{   squeezer%   r5   rh   r   
functionalpadr$   shaperA   r|   splitrc   r"   hasattrr~   r&   r   rE   )r.   r   r   r7   r   projected_statescontextualized_statesgateconv_weightsr%   ssm_parameters	time_stepBCdiscrete_time_stepr   time_proj_biasscan_outputsr1   s                      r2   cuda_kernels_forwardzMambaMixer.cuda_kernels_forward   s7     <<6@@AF==\1$2 ""$($6$6  D""##$$.2mm""((*4::++-..<<,,224#%!p %$O #3"8"8"8"BM4) -0H0H0K K  ;;--224;;3E3E3J3J13Mt{{OaOaOfOfghOijL'N1,=,A 4!))"- ,,T^^< KK$$OO! !. 7 7 ;+"$--"3"3%(=(=@S@STV@W(WYZ'[#K !224>>;P^_ 0!<1A1Adoo! ) -0H0H0K K "[[)@)@A)FGN#kk!4!4d6I6I4K^K^ _egOIq! "&!4!4y7J7J1a7P!P4::++-..A:A$,,PV:WT\\..446]aN'N1,=,A5 ++DNN;!&)&v.adGadGFFL" $  )B-  +<!&KK1%KK1%FFLLN"#'&*+'i (\-E 11$..)L %)MM,2H2HA2N$O!$$r4   c           	      X   |j                   \  }}}|j                  }| j                  |      j                  dd      }	|	j	                  dd      \  }
}||
|j                  d      z  }
||j                  | j                     j                         }|j                  |
j                        }|j                   d   | j                  k(  rt        j                  j                  |
| j                  |
j                   d   z
  df      }|j                  | j                  ||       | j!                  | j#                  |
      dd |f         }
n9|j                  | j                  |
|      }|j                  | j"                  j$                  j                        }t'        j(                  || j"                  j$                  d d dd d f   z  d      }
| j*                  r|
| j"                  j,                  z  }
| j!                  |
      j                  |      j                  d      }
n`t'        j.                  || j0                  | j2                  f|
j                  |      }| j!                  | j#                  |
      dd |f         }
||
|j                  d      z  }
| j5                  |
j                  dd            }t'        j6                  || j8                  | j2                  | j2                  gd      \  }}}| j;                  |      }t        j                  j=                  |      j                  dd      }t'        j>                  | j@                  jC                                }t'        j>                  |d d d d d d f   |d d d d d d d f   z        }|d d d d d d d f   |d d d d d d d f   jC                         z  }||
d d d d d d d f   jC                         z  }| jD                  r| jF                  r|tI        |j                  dd      |j                  dd            }||j                  d      z  jK                  d      j                  dd      }||
| jL                  d d d d f   z  z   }|| j!                  |      z  }ng }tO        |      D ]}  }|d d d d |d d f   |z  |d d d d |d d f   z   }t'        jP                  |j                  |      |d d |d d f   j                  d            }|jS                  |d d d d df           t'        jT                  |d      }||
| jL                  d d d d f   z  z   }|| j!                  |      z  }|(|j                  | j                     jW                  |       | jY                  |j                  dd            }|S )	Nr   r]   r   r   r:   .r   r   )-r   r   rm   r   r   r   r&   r5   cloner=   r   r$   r   r   r   rA   ri   rf   r   r'   sumrd   rY   r*   r    r"   rn   r   rc   ro   softplusr   rv   r   rj   r   r   r   rx   r(   matmulr-   stackcopy_ry   )r.   input_statesr   r7   r   
batch_sizeseq_lenr/   r   r   r   r   r1   r0   r   r   r   r   r   r   
discrete_A
discrete_BdeltaB_uhsscan_outputr   ir   s                               r2   slow_forwardzMambaMixer.slow_forwardV  s   !-!3!3
GQ""<<5??1E.44QA4>t%)N,D,DQ,GGM #$//?EEGI!]%9%9:I ##A&$*?*??]]..!**]-@-@-DDaH

 ..t~~z>Z $])CC'M)R S);;DNNM[ij
']]4;;+=+=+D+DE
 %		*t{{7I7I!QPQ'7R*RXZ [%%!T[[%5%55M $ 7 : :5 A K KB OT33T5H5HI$++5I !HHT[[%?XgX%NOM%)N,D,DQ,GGM ]%<%<Q%BC++T00$2E2EtGZGZ[ac
	1a "\\)4]]334FGQQRSUVW YYtzz'')**YYqq$!125G1aQU5VVW
'1a61dAq=9I9O9O9QQ
aAtm < B B DD ,2Fz++Aq183E3Ea3KLBB/88;EEaKK%tQ}8M(MMK%6KL7^ :&q!Qz2Y>!QPQST*AUU	#ll9<<+>!Q'
@T@TUW@XY##K1a$89:  ++l;K%a9N)NOK&$7K'''7==iH !%k.C.CAq.I J$$r4   c                 
   t        t        t        t        t        t
        f      }|rJd| j                  j                  j                  j                  v rt               s| j                  ||||      S | j                  ||||      S )Ncuda)r   r~   r   r|   r{   r   rn   r   r   typer   r   r   )r.   r   r   r7   r   r   s         r2   forwardzMambaMixer.forward  sx     "%#%68HJ^`no"
 "f0B0B0I0I0N0N&NWoWq,,]L.Zhii  nn]]r4   r`   )rI   rJ   rK   rL   r   rO   r3   r   r'   rQ   r   rR   r   r   r   __classcell__r   s   @r2   rU   rU      s   8({ 8(s 8(t4 +/2626c%||c% !4'c% ((4/	c%
 ((4/c%LO%zD7H O%^c^n^nqu^u O%  MR  M]  M]  `d  Md O%j +/2626^ !4'^ ((4/	^
 ((4/^r4   rU   c                   ,     e Zd Zd fd	Zd Zd Z xZS )MambaRMSNormc                     t         |           t        j                  t	        j
                  |            | _        || _        y)zL
        MambaRMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm
        N)ra   r3   r   rt   r'   rw   r   variance_epsilon)r.   rb   epsr   s      r2   r3   zMambaRMSNorm.__init__  s1     	ll5::k#:; #r4   c                 "   |j                   }|j                  t        j                        }|j	                  d      j                  dd      }|t        j                  || j                  z         z  }| j                  |j                  |      z  S )Nr]   r:   T)keepdim)	r   r=   r'   rq   powmeanrsqrtr   r   )r.   r   input_dtypevariances       r2   r   zMambaRMSNorm.forward  sy    #))%((7 $$Q',,R,>%Ht?T?T4T(UU{{]--k:::r4   c                 R    | j                   j                  d    d| j                   S )Nr   z, eps=)r   r   r   r.   s    r2   
extra_reprzMambaRMSNorm.extra_repr  s*    ++##A&'vd.C.C-DEEr4   )gư>)rI   rJ   rK   r3   r   r   r   r   s   @r2   r   r     s    $;Fr4   r   c                   t     e Zd Z fdZ	 	 	 ddedz  dej                  dz  dej                  dz  fdZ xZS )
MambaBlockc                     t         |           || _        || _        |j                  | _        t        |j                  |j                        | _        t        ||      | _
        y )Nr   r5   )ra   r3   r   r5   residual_in_fp32r   rb   layer_norm_epsilonnormrU   mixer)r.   r   r5   r   s      r2   r3   zMambaBlock.__init__  sR    " & 7 7 !3!39R9RS	)<
r4   Nr   r7   r   c                    |}| j                  |j                  | j                   j                  j                              }| j                  r|j                  t
        j                        }| j                  ||||      }||z   }|S )Nr_   r   r7   r   )r   r=   r   r   r   r'   rq   r   )r.   r   r   r7   r   residuals         r2   r   zMambaBlock.forward  s     !		-"2"29I9I9O9O"2"PQ  {{5==1H

^dr # 
 !=0r4   r`   )	rI   rJ   rK   r3   r   r'   rR   r   r   r   s   @r2   r   r     sQ    = +/2626 !4' ((4/	
 ((4/r4   r   c                   Z    e Zd ZU eed<   dZddgZdZdZ e	j                         d        Zy)MambaPreTrainedModelr   backboner   rU   Tc                 h	   | j                   j                  }t        |t              rt	        j
                  d|j                  dz   t        j                        dddf   }|j                  |j                  d      j                         }t        j                  |j                  t	        j                  |             t        j                  |j                          | j                   j"                  dz  | j                   j$                  z  }| j                   j&                  dk(  r+t        j(                  |j*                  j,                  |       nE| j                   j&                  dk(  r,t        j.                  |j*                  j,                  | |       t	        j0                  t	        j2                  | j                   j                        t5        j                  | j                   j6                        t5        j                  | j                   j8                        z
  z  t5        j                  | j                   j8                        z         j;                  | j                   j<                        }|t	        j                  t	        j>                  |              z   }t        j                  |j*                  j@                  |       t        jB                  |jD                  j,                  t5        jF                  d	      
       |jD                  j@                  )t        jH                  |jD                  j@                         t        jB                  |jJ                  j,                  t5        jF                  d	      
       | j                   jL                  rB|jJ                  j,                  }|t5        jF                  | j                   jN                        z  }t        |tP        jR                        rNt        jT                  |j,                  |       |j@                   t        jH                  |j@                         yyt        |tV              r t        j                  |j,                         yt        |tP        jX                        r"t        jT                  |j,                  |       yy)zInitialize the weights.r   r_   Nr:   g      constantrandom)min   )a)std)-r   initializer_range
isinstancerU   r'   rp   r"   rq   rr   r    rs   initr   rv   ru   ones_rx   rc   time_step_scaletime_step_init_scheme	constant_ro   r   uniform_r   randmathtime_step_maxtime_step_minr>   time_step_floorexpm1rY   kaiming_uniform_rf   sqrtzeros_ry   rescale_prenorm_residualr)   r   rk   normal_r   	Embedding)r.   moduler   r   dt_init_stddtinv_dtps           r2   _init_weightsz"MambaPreTrainedModel._init_weights  s    kk++fj) Q 5 5 9OPTVWPWXA1126AACAJJv||UYYq\2JJvxx ++44d:T[[=X=XXK{{00J>v~~44kB22h>fnn33k\;O

4;;88988DKK556$++B[B[9\\^((4;;4456 e33e4	  %))U[["%5$566FJJv~~**F3!!&--"6"6$))A,G}}!!-FMM../!!&//"8"8DIIaLI{{33 OO**TYYt{{<<==fbii(LLC0{{&FKK( '-JJv}}%-LLC0 .r4   N)rI   rJ   rK   r   __annotations__base_model_prefix_no_split_modulessupports_gradient_checkpointing_is_statefulr'   no_gradr  rS   r4   r2   r   r     s>    "%|4&*#LU]]_41 41r4   r   z,
    Class for the MAMBA model outputs.
    )custom_introc                   |    e Zd ZU dZdZej                  dz  ed<   dZe	dz  ed<   dZ
eej                     dz  ed<   y)MambaOutputa9  
    cache_params (`MambaCache`):
        The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
        avoid providing the old `input_ids`.

        Includes both the State space model state matrices after the selective scan, and the Convolutional states
    Nlast_hidden_stater   r   )rI   rJ   rK   rL   r  r'   FloatTensorr  r   r   r   tuplerS   r4   r2   r  r  '  sH     37u((4/6&*L*t#*59M5**+d29r4   r  zK
    Base class for causal language model (or autoregressive) outputs.
    c                       e Zd ZU dZdZej                  dz  ed<   dZej                  dz  ed<   dZ	e
dz  ed<   dZeej                     dz  ed<   y)MambaCausalLMOutputa  
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    cache_params (`MambaCache`):
        The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
        avoid providing the old `input_ids`.

        Includes both the State space model state matrices after the selective scan, and the Convolutional states
    Nlosslogitsr   r   )rI   rJ   rK   rL   r  r'   r  r  r  r   r   r   r  rS   r4   r2   r  r  ;  s\    
 &*D%

d
")'+FE$+&*L*t#*59M5**+d29r4   r  c                        e Zd Z fdZd Zd Zd Ze	 	 	 	 	 	 	 	 ddej                  dz  dej                  dz  de
dz  d	edz  d
edz  dedz  dej                  dz  dej                  dz  deez  fd       Z xZS )
MambaModelc           	         t         |   |       t        j                  |j                  |j
                        | _        t        j                  t        |j                        D cg c]  }t        ||       c}      | _        d| _        t        |j
                  |j                        | _        | j!                  | j"                         | j%                          y c c}w )Nr   Fr   )ra   r3   r   r   
vocab_sizerb   
embeddings
ModuleListr(   r)   r   layersgradient_checkpointingr   r   norm_f"_register_load_state_dict_pre_hook	load_hook	post_init)r.   r   idxr   s      r2   r3   zMambaModel.__init__V  s     ,,v'8'8&:L:LMmmRWX^XpXpRq$r3Z#%F$rs&+#"6#5#56;T;TU//? %ss   &Cc                 f    |D ],  }d|v s|j                  |      ||j                  dd      <    y  y )Nz
embedding.zembeddings.)popreplace)r.   
state_dictprefixargsks        r2   r  zMambaModel.load_hookb  s;     	Aq EO^^TUEV
199\=AB	r4   c                     | j                   S rD   r  r   s    r2   get_input_embeddingszMambaModel.get_input_embeddingsh  s    r4   c                     || _         y rD   r*  r.   new_embeddingss     r2   set_input_embeddingszMambaModel.set_input_embeddingsk  s	    (r4   N	input_idsinputs_embedsr   	use_cacheoutput_hidden_statesreturn_dictr7   r   r8   c	                 8   ||n| j                   j                  }||n#| j                  s| j                   j                  nd}||n| j                   j                  }|du |duz  rt        d      || j                  |      }| j                  r| j                  r|rd}|r|st        | j                   |j                  d      |j                  |j                        }t        j                  d| j                   j                  |j                        }n|t        d      d}|}
|rdnd}| j                  D ]  } ||
|||	      }
|s||
fz   } | j!                  |
      }
|r||
fz   }|st#        d
 |
||fD              S t%        |
|r||      S d|      S )a  
        cache_params (`MambaCache`, *optional*):
            If passed along, the model uses the previous state in all the blocks (which will give the output for the
            `input_ids` provided as if the model add `state_input_ids + input_ids` as context).
        use_cache (`bool`, *optional*):
            If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
        NFz:You must specify exactly one of input_ids or inputs_embedsr   r   r   zYou have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is passed, you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will be initialized for you automaticallyrS   r   c              3   &   K   | ]	  }||  y wrD   rS   ).0vs     r2   	<genexpr>z%MambaModel.forward.<locals>.<genexpr>  s     fqXYXefs   )r  r   r   )r   r3  r   r2  use_return_dict
ValueErrorr  r  r   r   r   r   r'   rp   r#   r  r  r  r  )r.   r0  r1  r   r2  r3  r4  r7   r   kwargsr   all_hidden_statesmixer_blocks                r2   r   zMambaModel.forwardn  s   * %9$D $++JjJj 	 "+!6IZ^ZgZgT[[=R=Rmr	%0%<k$++B]B]-t";<YZZ  OOI6M&&4==YI#)KK!3!3A!6}?S?S[h[n[n  "'a1H1HQ^QeQe!f' !;   L%"6BD;; 		IK')--	M $$58H$H!		I M2 1]4D Df]LBS$Tfff+)2+
 	
8<+
 	
r4   )NNNNNNNN)rI   rJ   rK   r3   r  r+  r/  r   r'   rR   r   boolr  r  r   r   r   s   @r2   r  r  T  s    
)  .215*.!%,0#'2626M
##d*M
 ''$.M
 !4'	M

 $;M
 #TkM
 D[M
 ((4/M
 ((4/M
 
	M
 M
r4   r  z
    The MAMBA Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    c                       e Zd ZddiZ fdZd Zd Z	 ddedee	e
f   ded	ee	e
f   fd
Z	 	 	 	 	 	 ddedz  dej                  dz  dej                  dz  dedz  f fdZe	 	 	 	 	 	 	 	 	 	 ddej                  dz  dej                  dz  dej&                  dz  dedz  dej                  dz  dedz  dedz  dedz  dej(                  dz  deej(                  z  d	eez  fd       Z xZS )MambaForCausalLMzlm_head.weightzbackbone.embeddings.weightc                     t         |   |       t        |      | _        t	        j
                  |j                  |j                  d      | _        | j                          y )NFr^   )
ra   r3   r  r   r   rk   rb   r  lm_headr   )r.   r   r   s     r2   r3   zMambaForCausalLM.__init__  sF     "6*yy!3!3V5F5FUSr4   c                 6    | j                   j                         S rD   )r   r+  r   s    r2   r+  z%MambaForCausalLM.get_input_embeddings  s    }}1133r4   c                 8    | j                   j                  |      S rD   )r   r/  r-  s     r2   r/  z%MambaForCausalLM.set_input_embeddings  s    }}11.AAr4   outputsmodel_kwargsnum_new_tokensr8   c                    |j                  dd       |d<   |j                  dd      rd|v r|d   |d   dd  |z   |d<   d|v r?|d   }t        j                  ||j                  |j                  d   df      gd	      |d<   |S )
Nr   r2  Tr7   r:   r   r   r   r   )getr'   catnew_onesr   )r.   rG  rH  rI  r=  r   s         r2   #_update_model_kwargs_for_generationz4MambaForCausalLM._update_model_kwargs_for_generation  s     (/{{>4'H^$[$/ L0-.:-9:J-KBC-PSa-aL)*|+)*:;N-2YY!8!8.:N:Nq:QST9U!VW]_.L)* r4   Nr   r7   r   is_first_iterationc           
         t        |   |f||||||d|}	|r|t        j                  d| j                  j
                  j                  |j                        |	d<   ||j                  d      }
n|j                  d      }
t        | j                  j
                  |
| j                  | j                        |	d<   |	S |r|d   dkD  rd |	d<   |	S )N)r1  r2  r   r7   r   rO  r   r6  r7   r   r   r   )ra   prepare_inputs_for_generationr'   rp   r   r   r#   r   r   r   r   )r.   r0  r1  r2  r   r7   r   rO  r=  model_inputsr   r   s              r2   rQ  z.MambaForCausalLM.prepare_inputs_for_generation  s     w<	
'%))1	
 	
 -
 .3\\!T]]=Q=Q=]=]fofvfv-wL)*(!.!3!3A!6!*!2+5$$nT[[PTPZPZ,L(  >!,q0-1L)*r4   r0  r1  labelsr3  r4  r2  logits_to_keepc           
         ||n| j                   j                  }| j                  |||||||	|      }|d   }t        |
t              rt        |
 d      n|
}| j                  |dd|ddf   j                  | j                  j                  j                              j                         }d}||j                  |j                        }|dddddf   j                         }|dddf   j                         }t               } ||j                  d|j                  d            |j                  d            }|s|f|dd z   }||f|z   S |S t!        |||j"                  |j$                        S )aS  
        cache_params (`MambaCache`, *optional*):
            If passed along, the model uses the previous state in all the blocks (which will give the output for the
            `input_ids` provided as if the model add `state_input_ids + input_ids` as context).
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        use_cache (`bool`, *optional*):
            If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
        N)r   r1  r3  r4  r2  r7   r   r   .r:   r   )r  r  r   r   )r   r;  r   r   rO   slicerD  r=   r   r   r   r   rs   r   r   r   r  r   r   )r.   r0  r   r1  r   rS  r3  r4  r2  r7   rT  r=  mamba_outputsr   slice_indicesr  r  shift_logitsshift_labelsloss_fctoutputs                        r2   r   zMambaForCausalLM.forward  s   4 &1%<k$++B]B]%'!5#)) & 	
 &a(8B>SV8W~ot4]kmA}a,?@CCDLLDWDWD]D]^_eegYYv}}-F!#ssA+.99;L!#qr'?557L')HL--b,2C2CB2GH,J[J[\^J_`DYqr!22F)-)9TGf$EvE"&33'55	
 	
r4   )r   )NNNNNF)
NNNNNNNNNr   )rI   rJ   rK   _tied_weights_keysr3   r+  r/  r   dictrP   r   rO   rN  r   r'   rR   r@  rQ  r   r  rQ   r  r  r   r   r   s   @r2   rB  rB    s    +,HI4B YZ"26sCx.RU	c3h, *.2626*/'
 !4'' ((4/' ((4/' !4K'R  .22626*.*.,0#'!%.2-.?
##d*?
 ((4/?
 ((4/	?

 !4'?
   4'?
 #Tk?
 D[?
 $;?
 t+?
 ell*?
 
$	$?
 ?
r4   rB  )rB  r  r   r   )2rL   r   dataclassesr   typingr   r'   r   torch.nnr    r   r   activationsr	   configuration_utilsr
   
generationr   integrationsr   modeling_layersr   modeling_utilsr   utilsr   r   r   utils.import_utilsr   r   configuration_mambar   
get_loggerrI   r   mambapy.pscanr   r   ModulerU   r   r   r   r  r  r  rB  __all__rS   r4   r2   <module>rp     sq     !    % & ! 3 ) , 9 - 
 Q , 
		H	%#Ed/ d/N]^ ]^@	F299 F(+ 8 <1? <1 <1~ 
:+ : : 
:+ : :& g
% g
 g
T L
+_ L
L
^ Sr4   