class LlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
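

# --- Usage sketch (added for illustration; not part of the original modeling file) ---
# A minimal, hedged example of exercising this decoder stack in isolation. It assumes the
# surrounding module already imports `torch` and `LlamaConfig`, as `modeling_llama.py` does.
# The tiny config values and the `_demo_*` names below are arbitrary choices made only to
# keep the forward pass cheap to run; they are not prescribed by the model code above.
if __name__ == "__main__":
    _demo_config = LlamaConfig(
        vocab_size=1000,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
    )
    _demo_model = LlamaModel(_demo_config)

    _demo_input_ids = torch.randint(0, _demo_config.vocab_size, (1, 8))  # (batch, seq_len)
    _demo_outputs = _demo_model(_demo_input_ids, use_cache=True, return_dict=True)

    # last_hidden_state has shape (batch, seq_len, hidden_size) -> torch.Size([1, 8, 64])
    print(_demo_outputs.last_hidden_state.shape)
    # past_key_values holds one (key, value) pair per layer; in this version of the code each
    # tensor has shape (batch, num_attention_heads, seq_len, head_dim) and can be fed back in
    # as `past_key_values` for incremental decoding of the next token.
    print(len(_demo_outputs.past_key_values), _demo_outputs.past_key_values[0][0].shape)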
 