生成语言模型temperature_topp_topk解析

在大型生成语言模型中,我们经常用到这些超参数temperature,topk,topp,下面根据源码进行分析

在huggingface transformers中,这3个对分数的处理是串联的:先进行temperature,然后topk,最后topp。当然,如果用户自己定义了分数过滤器,用户的过滤器优先。

temperature

temperature越大,随机性越高:scores除以temperature之后彼此更接近,softmax之后的概率分布也就更平坦。

源码如下:
class TemperatureLogitsWarper(LogitsWarper):
    r"""
    [`LogitsWarper`] that rescales the scores by a fixed temperature.

    Every score is divided by `temperature`, so a larger temperature pulls the
    scores closer together and a smaller one spreads them apart.

    Args:
        temperature (`float`):
            Strictly positive divisor applied to the scores. Must be a `float`
            instance; any other type (including `int`) is rejected.

    Raises:
        ValueError: If `temperature` is not a `float` or is not greater than 0.
    """

    def __init__(self, temperature: float):
        is_positive_float = isinstance(temperature, float) and temperature > 0
        if not is_positive_float:
            raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")

        self.temperature = temperature

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor:
        # `input_ids` is part of the call signature but is not read here.
        return scores / self.temperature

topk

top_k越大随机性越高:分数低于第top_k大分数的token,都会被设置成filter_value(默认是-inf)。

源码如下:
class TopKLogitsWarper(LogitsWarper):
    r"""
    [`LogitsWarper`] implementing top-k filtering: only the k highest scores in
    each row survive, the rest are overwritten with `filter_value`.

    Args:
        top_k (`int`):
            How many of the highest-scoring vocabulary tokens to keep.
        filter_value (`float`, *optional*, defaults to `-float("Inf")`):
            Value written over every filtered score.
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Lower bound on the number of tokens kept; the effective k is
            `max(top_k, min_tokens_to_keep)`.

    Raises:
        ValueError: If `top_k` is not an `int` or is not greater than 0.
    """

    def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        if not (isinstance(top_k, int) and top_k > 0):
            raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")

        self.filter_value = filter_value
        self.top_k = max(top_k, min_tokens_to_keep)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Never ask for more entries than the last dimension holds.
        k = min(self.top_k, scores.size(-1))
        # The k-th largest score of each row, with a trailing axis so it
        # broadcasts against `scores`.
        kth_largest = torch.topk(scores, k).values[..., -1, None]
        # Strictly-smaller scores are filtered; ties with the threshold are kept.
        return scores.masked_fill(scores < kth_largest, self.filter_value)

示例:

import torch
class FilterTokens:
    """Stand-alone top-k filter: keep the k highest scores per row, overwrite the rest.

    Args:
        top_k: Number of highest scores to keep in each row.
        filter_value: Value written over every filtered score. A very negative
            number (e.g. ``-1e9`` or ``-float('Inf')``) is the usual choice.
    """

    def __init__(self, top_k: int, filter_value: float = -float('Inf')):
        self.filter_value = filter_value
        self.top_k = top_k

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Clamp k so it never exceeds the size of the last dimension.
        k = min(self.top_k, scores.size(-1))
        # torch.topk returns (values, indices) sorted from largest to smallest.
        # `[..., -1, None]` picks the last (i.e. k-th largest) value of each row
        # and appends a size-1 axis so it broadcasts against `scores`.
        top_values = torch.topk(scores, k)[0]
        threshold = top_values[..., -1, None]
        # Boolean mask: True where a score is strictly below the threshold.
        below_threshold = scores < threshold
        # Overwrite the masked positions with filter_value; others are unchanged.
        return scores.masked_fill(below_threshold, self.filter_value)

# create a FilterTokens object that keeps the 3 highest scores per row and
# overwrites every other score with -1e9
filter_tokens = FilterTokens(top_k=3, filter_value=-1e9)

# create some input ids and scores (two rows of five entries each)
input_ids = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])
scores = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5], [0.5, 0.4, 0.3, 0.2, 0.1]])

# apply the filter (input_ids is passed through the call signature but
# FilterTokens.__call__ does not read it)
filtered_scores = filter_tokens(input_ids, scores)

# print the result: in each row the two lowest scores (0.1 and 0.2) become -1e9,
# while 0.3, 0.4 and 0.5 are left untouched
print(filtered_scores)

topp

只保留概率最高的一批token,使它们的概率之和达到或超过给定的阈值top_p(即从低概率一端开始,累积概率不超过1-top_p的token会被过滤掉)。top_p越大,保留的token越多,随机性越高。

源码如下:
class TopPLogitsWarper(LogitsWarper):
    """
    [`LogitsWarper`] that performs top-p (nucleus) filtering: the least probable
    tokens whose cumulative probability is at most `1 - top_p` are overwritten
    with `filter_value`, so the tokens that remain carry at least `top_p` of the
    probability mass.

    Args:
        top_p (`float`):
            If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
            higher are kept for generation. Accepted range is 0 to 1 inclusive.
        filter_value (`float`, *optional*, defaults to `-float("Inf")`):
            All filtered values will be set to this float value.
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens that cannot be filtered. Values below 1 are
            treated as 1 so that a row can never be filtered completely.

    Raises:
        ValueError: If `top_p` is smaller than 0 or larger than 1.
    """

    def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        top_p = float(top_p)
        if top_p < 0 or top_p > 1.0:
            raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")

        self.top_p = top_p
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Sort ASCENDING along the last dim, so the most probable tokens end up
        # at the end of each row.
        sorted_logits, sorted_indices = torch.sort(scores, descending=False)
        # Running sum of probabilities, starting from the least probable token.
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

        # Mark the low-probability tail whose total mass is at most 1 - top_p.
        sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
        # Always protect the last (most probable) entries. Doing this
        # unconditionally matters for the default min_tokens_to_keep=1: with
        # top_p=0 the comparison above would otherwise mark every token and the
        # whole row would become filter_value.
        keep = max(self.min_tokens_to_keep, 1)
        sorted_indices_to_remove[..., -keep:] = False

        # Map the mask from sorted order back to the original token positions.
        # The last dim is used, consistent with sort/softmax/cumsum above.
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        # Overwrite the filtered positions; the other scores are unchanged.
        return scores.masked_fill(indices_to_remove, self.filter_value)

测试代码:
top_p=0.4时的结果是
tensor([[ -inf, -inf, -inf, 0.4000, 0.5000],
[ -inf, -inf, -inf, 0.9000, 1.0000]])
top_p=0.99时的结果是:
tensor([[0.1000, 0.2000, 0.3000, 0.4000, 0.5000],
[0.6000, 0.7000, 0.8000, 0.9000, 1.0000]])

完整的测试代码如下:
import torch

class TopPLogitsWarper:
    """Stand-alone top-p (nucleus) filter.

    The least probable tokens whose cumulative probability is at most
    ``1 - top_p`` are overwritten with ``filter_value``, so the tokens that
    remain carry at least ``top_p`` of the probability mass.

    Args:
        top_p: Probability mass to keep. Accepted range is 0 to 1 inclusive.
        filter_value: Value written over every filtered score.
        min_tokens_to_keep: Minimum number of highest-scoring tokens that are
            never filtered. Values below 1 are treated as 1 so that a row can
            never be filtered completely.

    Raises:
        ValueError: If ``top_p`` is smaller than 0 or larger than 1.
    """

    def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        top_p = float(top_p)
        if top_p < 0 or top_p > 1.0:
            raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")

        self.top_p = top_p
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Sort ascending along the last dim, so the most probable tokens end up
        # at the end of each row.
        sorted_logits, sorted_indices = torch.sort(scores, descending=False)
        # Running sum of probabilities, starting from the least probable token.
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

        # Mark the low-probability tail whose total mass is at most 1 - top_p.
        sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
        # Always protect the last (most probable) entries. Doing this
        # unconditionally matters for the default min_tokens_to_keep=1: with
        # top_p=0 the comparison above would otherwise mark every token and the
        # whole row would become filter_value.
        keep = max(self.min_tokens_to_keep, 1)
        sorted_indices_to_remove[..., -keep:] = False

        # Scatter the mask from sorted order back to the original indexing.
        # The last dim is used, consistent with sort/softmax/cumsum above.
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        # Overwrite the filtered positions; the other scores are unchanged.
        return scores.masked_fill(indices_to_remove, self.filter_value)

# Test: run the top-p filter on two rows of five scores each.
# Keep the most probable tokens covering at least 40% of the probability mass.
top_p = 0.4
# Filtered positions are overwritten with negative infinity.
filter_value = -float("Inf")
# Never filter out the single highest-scoring token.
min_tokens_to_keep = 1
# input_ids is passed through the call signature; the warper's __call__ does not read it.
input_ids = torch.LongTensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])
scores = torch.FloatTensor([[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]])
warper = TopPLogitsWarper(top_p, filter_value, min_tokens_to_keep)
filtered_scores = warper(input_ids, scores)
# With top_p = 0.4 the three lowest scores of each row become -inf and the two
# highest scores are kept unchanged.
print(filtered_scores)

生成语言模型temperature_topp_topk解析
https://johnson7788.github.io/2023/04/28/%E7%94%9F%E6%88%90%E8%AF%AD%E8%A8%80%E6%A8%A1%E5%9E%8Btemperature-topp-topk%E8%A7%A3%E6%9E%90/
作者
Johnson
发布于
2023年4月28日
许可协议