多标签分类的2种简单实现

模型的损失函数

torch.nn.BCEWithLogitsLoss
先进行了sigmoid,然后进行了二分类交叉熵损失函数

评估metric

hamming score
sklearn.metrics.multilabel_confusion_matrix #多标签混淆矩阵

#hamming score
示例

1
2
3
4
5
6
7
8
9
10
11
12
13
import numpy as np
truth = [0, 1, 1, 0, 0] # 真实值
prediction = [1, 1, 1, 0, 0] #预测值
truth = np.array(truth)
prediction = np.array(prediction)
num_classes = len(truth)
num_samples = 1
numerator = float(sum(truth & prediction))
denominator = float(sum(truth | prediction))
hamming_score = numerator / denominator
------>
0.667
如果是按照准确率计算accuracy_score(truth, prediction),那么是0.8

我们的输出是一个批次的,所以是二维的,计算的方法是: score=((predicts & labels).sum(axis=1) / (predicts | labels).sum(axis=1)).mean()

实现方式1: 使用CLS进行分类

假设我们的示例是商品的购买意向,模型的基本输入是:CLS+句子1+SEP+商品+SEP
对CLS进行求sigmoid和二分类交叉熵
模型实现逻辑:

  1. Bert模型编码, last_hidden_state, all_hidden_states = self.encode(input_ids, token_type_ids, attention_mask)
  2. 取最后一个隐藏层的CLS向量, first_token_tensor = hidden_states[:, 0]
  3. 进行dropout, self.dropout(first_token_tensor)
  4. 加个全连接层和激活
  5. 线性层映射到标签个数,得到logits, nn.Linear(hidden_size, num_labels)
  6. 计算损失, 反向传播,更新参数

实现方式2: 每个标签类别向量进行分类

假设我们的示例是商品的购买意向, 模型的基本输入是: CLS+句子1+SEP+商品+每个标签的id+SEP
取出每个标签的向量
对每个标签向量进行二分类交叉熵损失
模型实现逻辑:

  1. 注意我们在输入的末尾添加了每个标签的id, 所以需要用特殊的token表示这些id, 所以需要增加词表,tokenizer.add_special_tokens({‘additional_special_tokens’:’opinion1’}) 我们用opinion1代表第一个标签,opinion2代表第二个标签,分别都加入到vocab中
  2. 注意在处理数据时,我们还有生成label_mask参数,告知我们关注的label的位置在哪里
  3. Bert模型编码, last_hidden_state, all_hidden_states = self.encode(input_ids, token_type_ids, attention_mask)
  4. hidden_states 形状 [batch_size, seq_len, last_hidden_size] –> [batch_size, labels_num, last_hidden_size] 加一个维度,然后扩充到hidden_states形状,方便后面取出label_mask需要的维度数据 label_mask_expand = label_mask.unsqueeze(-1).expand(hidden_states.size())
    labels_token_tensor_1d = torch.masked_select(hidden_states, (label_mask_expand == 1))
    labels_token_tensor = labels_token_tensor_1d.view(batch_size, -1, last_hidden_size)
  5. 加个droput, 全连接层和激活
  6. 分类层, 最后变成1维度,nn.Linear(hidden_size, 1)
  7. logits = logits.squeeze(-1), 去掉最后一次,得到[batch_size, num_labels] 形状
  8. 计算损失, 反向传播,更新参数

多标签分类的2种简单实现
https://johnson7788.github.io/2022/02/28/multilabel/
作者
Johnson
发布于
2022年2月28日
许可协议