lavis和transformers的冲突

salesforce的LAVIS包,https://github.com/salesforce/LAVIS是一个多模态的库,其中的BLIP2模块部分和huggingface的trasformers==4.27.1不兼容,不兼容部分是,原因是query_embeds被repeat了维度0两次,所以造成torch.cat拼接时冲突。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
lavis/models/blip2_models/blip2_opt.py
的第217218
# else:
# query_embeds = inputs_opt.repeat_interleave(num_beams, dim=0)
和trasformers==4.27.1679683冲突
def _expand_dict_for_generation(dict_to_expand):
for key in dict_to_expand:
if dict_to_expand[key] is not None and isinstance(dict_to_expand[key], torch.Tensor):
dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
return dict_to_expand

造成 lavis/models/blip2_models/modeling_opt.py 的703705行 cat拼接时维度不一致
if query_embeds is not None:
inputs_embeds = torch.cat([query_embeds, inputs_embeds], dim=1)
input_shape = inputs_embeds.size()[:-1]

repeat_interleave()是PyTorch中的一个函数,用于将张量中的元素沿某一维度复制n次,即复制后的张量沿该维度.
这个函数有两个参数,第一个参数是重复的次数,第二个参数是重复的维度 (pytorch.org).
例如,如果你有一个形状为(3, 4)的张量,你可以使用repeat_interleave()函数将其中的每个元素沿着第0维重复2次,如下所示:

1
2
3
4
5
6
7
8
9
import torch

x = torch.tensor([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]])

y = torch.repeat_interleave(x, repeats=2, dim=0)

print(y)

输出结果为:

tensor([[ 1, 2, 3, 4],
[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12],
[ 9, 10, 11, 12]])
在这个例子中,我们将x沿着第0维重复了2次,因此输出结果中每个元素都被重复了2次


lavis和transformers的冲突
https://johnson7788.github.io/2023/03/27/lavis%E5%92%8Ctransformers%E7%9A%84%E5%86%B2%E7%AA%81/
作者
Johnson
发布于
2023年3月27日
许可协议