import torch


class TopPLogitsWarper:
    """Top-p (nucleus) filtering: keeps the smallest set of tokens whose
    cumulative probability mass reaches at least `top_p`, masking the rest."""

    def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        top_p = float(top_p)
        if top_p < 0 or top_p > 1.0:
            raise ValueError(f"`top_p` has to be a float >= 0 and <= 1, but is {top_p}")

        self.top_p = top_p
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Sort ascending so the lowest-probability tokens come first
        sorted_logits, sorted_indices = torch.sort(scores, descending=False)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

        # Mark tokens in the bottom (1 - top_p) of cumulative probability mass for removal
        sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
        if self.min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep tokens (the highest-probability ones)
            sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0

        # Scatter the sorted mask back to the original vocabulary indexing
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores
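# A minimal sketch (not part of the warper) walking the cutoff logic on one
# row: sort ascending, take softmax then cumulative sum, and mark everything
# in the bottom (1 - top_p) of probability mass for removal. The values are
# toy numbers, assuming a 5-token vocabulary.
example_row = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5])     # already ascending
example_cum = example_row.softmax(dim=-1).cumsum(dim=-1)  # ~[0.16, 0.34, 0.54, 0.76, 1.00]
print(example_cum <= (1 - 0.4))  # tensor([ True,  True,  True, False, False])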
# Test: with top_p = 0.4, only the tokens covering at least the top 40% of
# probability mass survive; the three smallest logits in each row become -inf.
top_p = 0.4
filter_value = -float("Inf")
min_tokens_to_keep = 1
input_ids = torch.LongTensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])
scores = torch.FloatTensor([[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]])

warper = TopPLogitsWarper(top_p, filter_value, min_tokens_to_keep)
filtered_scores = warper(input_ids, scores)
print(filtered_scores)  # [[-inf, -inf, -inf, 0.4, 0.5], [-inf, -inf, -inf, 0.9, 1.0]]
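# A minimal sketch of how the warped scores would feed into sampling, reusing
# filtered_scores from the test above; masked entries get probability 0 under
# softmax, so only the surviving tokens can ever be drawn.
probs = torch.softmax(filtered_scores, dim=-1)         # -inf entries -> 0.0
next_tokens = torch.multinomial(probs, num_samples=1)  # one token index per row
print(next_tokens)  # each entry is 3 or 4: only the two kept tokens remain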