A bug in the transformers llama2 model

A bug in transformers 4.31.0 affecting the llama2 model

When fine-tuning the llama2 model with LoRA, you may hit an error like RuntimeError: expected scalar type Float but found Half. This post walks through the cause of this error and how to fix it.

Loading the model

The code below loads the LoRA model. Note that we use QLoRA-style quantization: NF4 format, quantized to 4-bit.

import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GenerationConfig
from peft import PeftModel


def load_lora_model(ckpt_path):
    model_name_or_path = "xxx/llama2"  # local llama2 model
    print("Loading model:", model_name_or_path)
    print("Loading fine-tuned weights:", ckpt_path)
    assert os.path.exists(ckpt_path), "Fine-tuned weights do not exist!"
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        llm_int8_threshold=6.0,
        llm_int8_has_fp16_weight=False,
    )

    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path, trust_remote_code=True)

    model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
                                                 quantization_config=bnb_config,
                                                 trust_remote_code=True)
    model.generation_config = GenerationConfig.from_pretrained(model_name_or_path)
    model = PeftModel.from_pretrained(model, ckpt_path)
    return tokenizer, model
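
For reference, a typical call to this loader might look like the sketch below; the checkpoint path and prompt are placeholders, and the generation arguments are only illustrative.

# Hypothetical usage of load_lora_model; the checkpoint path and prompt are placeholders.
tokenizer, model = load_lora_model("output/checkpoint-1000")
model.eval()

prompt = "Hello, who are you?"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")  # the 4-bit model is assumed to sit on a GPU
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))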

Configuration notes

bnb_4bit_compute_dtype=torch.float16 means the actual matrix multiplications are done in float16: the input is converted to float16 before the computation. You can see this in the forward function of bitsandbytes' Linear4bit (the layer that PEFT wraps for the 4-bit weights): x = x.to(self.compute_dtype). Note also that at the end of forward the output is cast back to the original input dtype (out = out.to(inp_dtype)), so the dtype that flows on through the rest of the model still follows the incoming activations.

class Linear4bit(nn.Linear):
    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', device=None):
        super().__init__(input_features, output_features, bias, device)
        self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
        self.compute_dtype = compute_dtype
        self.compute_type_is_set = False

    def set_compute_type(self, x):
        if x.dtype in [torch.float32, torch.bfloat16]:
            # the input is in a dtype that is safe to compute in, we switch
            # to this type for speed and stability
            self.compute_dtype = x.dtype
        elif x.dtype == torch.float16:
            # we take the compoute dtype passed into the layer
            if self.compute_dtype == torch.float32 and (x.numel() == x.shape[-1]):
                # single batch inference with input torch.float16 and compute_dtype float32 -> slow inference when it could be fast
                # warn the user about this
                warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_type=torch.float32 (default). This will lead to slow inference.')
                warnings.filterwarnings('ignore', message='.*inference.')
            if self.compute_dtype == torch.float32 and (x.numel() != x.shape[-1]):
                warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_type=torch.float32 (default). This will lead to slow inference or training speed.')
                warnings.filterwarnings('ignore', message='.*inference or training')

    def forward(self, x: torch.Tensor):
        # weights are cast automatically as Int8Params, but the bias has to be cast manually
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        if getattr(self.weight, 'quant_state', None) is None:
            print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
        if not self.compute_type_is_set:
            self.set_compute_type(x)
            self.compute_type_is_set = True

        inp_dtype = x.dtype
        if self.compute_dtype is not None:
            x = x.to(self.compute_dtype)

        bias = None if self.bias is None else self.bias.to(self.compute_dtype)
        out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)

        out = out.to(inp_dtype)

        return out

Full error traceback

  File "/home/wac/johnson/anaconda3/envs/qlora/lib/python3.10/site-packages/peft/peft_model.py", line 977, in generate
outputs = self.base_model.generate(**kwargs)
File "/home/wac/johnson/anaconda3/envs/qlora/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/home/wac/johnson/anaconda3/envs/qlora/lib/python3.10/site-packages/transformers/generation/utils.py", line 1588, in generate
return self.sample(
File "/home/wac/johnson/anaconda3/envs/qlora/lib/python3.10/site-packages/transformers/generation/utils.py", line 2642, in sample
outputs = self(
File "/home/wac/johnson/anaconda3/envs/qlora/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/wac/johnson/anaconda3/envs/qlora/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
output = old_forward(*args, **kwargs)
File "/home/wac/johnson/anaconda3/envs/qlora/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 824, in forward
logits = self.lm_head(hidden_states)
File "/home/wac/johnson/anaconda3/envs/qlora/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/wac/johnson/anaconda3/envs/qlora/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
output = old_forward(*args, **kwargs)
File "/home/wac/johnson/anaconda3/envs/qlora/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: expected scalar type Float but found Half

Analysis

From the line logits = self.lm_head(hidden_states) above, we can see that the dtype problem occurs at the llama2 model's lm_head layer: the hidden states passed in do not have the same dtype as the lm_head weights.
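
To make the mismatch concrete, here is a minimal standalone sketch (not taken from llama2) that feeds float32 activations into a half-precision linear layer. The shapes are arbitrary and the exact error wording depends on the PyTorch version and device, but it fails with the same kind of dtype-mismatch RuntimeError.

import torch
import torch.nn.functional as F

# Minimal sketch of the failure mode: float32 activations hitting a float16 weight matrix.
hidden_states = torch.randn(1, 4096, dtype=torch.float32)       # e.g. output of a float32 RMSNorm
lm_head_weight = torch.randn(32000, 4096, dtype=torch.float16)   # e.g. a half-precision lm_head

try:
    F.linear(hidden_states, lm_head_weight)
except RuntimeError as e:
    print(e)  # a dtype-mismatch error; the exact message depends on the PyTorch version/device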

Looking at the transformers/models/llama/modeling_llama.py source, hidden_states comes from LlamaModel's forward function. While iterating over the decoder layers (for idx, decoder_layer in enumerate(self.layers):), it does no harm for hidden_states to be float32, but after the final normalization, hidden_states = self.norm(hidden_states), the result goes straight to lm_head; making hidden_states consistent with the norm layer's self.weight dtype is enough to fix the error.

class LlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
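
Before patching anything, a quick sanity check is to print the dtypes on both sides of that final hop. This is only a diagnostic sketch: the attribute nesting below assumes the PeftModel returned by load_lora_model above (PeftModel -> LoraModel -> LlamaForCausalLM).

# Diagnostic sketch: compare the dtype the final norm works with against the lm_head weights.
llama = model.base_model.model        # unwrap PeftModel -> LoraModel -> LlamaForCausalLM (assumed nesting)
print("norm.weight dtype   :", llama.model.norm.weight.dtype)
print("lm_head.weight dtype:", llama.lm_head.weight.dtype)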

Locating and patching LlamaRMSNorm

class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(self.weight.dtype)

Patch the forward function as follows (the key change is how the result is cast back at the end, so that the output only moves to half precision when the norm's own weight is half precision), and the error goes away:

    def forward(self, hidden_states):
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # convert into half-precision if necessary
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states
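
Editing the installed transformers sources works, but the change is lost on upgrade. An alternative is to monkey-patch LlamaRMSNorm.forward at runtime before loading the model; the sketch below applies the same logic as the patched forward above (the function name _patched_rmsnorm_forward is ours, not part of any API).

import torch
from transformers.models.llama import modeling_llama

def _patched_rmsnorm_forward(self, hidden_states):
    # Same logic as the patched forward above: follow the dtype of the norm's weight.
    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
    if self.weight.dtype in [torch.float16, torch.bfloat16]:
        hidden_states = hidden_states.to(self.weight.dtype)
    return self.weight * hidden_states

# Apply the patch before calling load_lora_model / from_pretrained.
modeling_llama.LlamaRMSNorm.forward = _patched_rmsnorm_forward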
