torch和paddlepaddle的部分API对比

paddlepaddle代码和torch的模型代码相互转换, 其实只需要关注这些api的不同,进行相应替换即可

paddlepaddle 封装了类似torch, huggingface transformers 和datasets的接口形式

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
torch包和paddle包的对比, 只列出不同的地方
PyTorch PaddlePaddle 说明
torch.nn paddle.nn 包括了神经网络相关的大部分函数
nn.Module nn.Layer 搭建网络时集成的父类,包含了初始化等基本功能
torch.optim paddle.optimizer 训练优化器
torch.optim.AdamW paddle.optimizer.AdamW 参数也不一样
torchvision.transforms paddle.vision.transforms 数据预处理、图片处理
torchvision.datasets paddle.vision.datasets 数据集的加载与处理
nn.Conv2d nn.Conv2D 2维卷积层
nn.BatchNorm2d nn.BatchNorm2D Batch Normalization 归一化
nn.MaxPool2d nn.MaxPool2D 二维最大池化层
nn.AdaptiveAvgPool2d nn.AdaptiveAvgPool2D 自适应二维平均池化(只用给定输出形状即可)
torch.flatten paddle.flatten 展平处理
torch.softmax paddle.softmax softmax层
datasets.ImageFolder datasets.DatasetFolder 指定数据集文件夹
torch.utils.data.DataLoader paddle.io.DataLoader 加载数据集, 参数也不一样
(optimizer).no_grad (optimizer).zero_grad 梯度清零
torch.save paddle.jit.save 说实话,这两个还是有点区别的,使用请看官方文档
torch.device paddle.set_device 指定设备
torch.utils.data.Dataset paddle.io.Dataset 数据集
torch.utils.data.RandomSampler paddle.io.RandomSampler
torch.utils.data.BatchSampler paddle.io.BatchSampler 参数也不一样
torch.utils.data.DataLoader paddle.io.DataLoader 数据集加载
tensor.type(torch.float32) paddle.cast(mask, 'float32') 数据类型变更
tensor.cpu().item() tensor.numpy().item() # 取出数据
transformers.get_linear_schedule_with_warmup paddlenlp.transformers.LinearDecayWithWarmup # warmup 函数
也不一样

torch和paddlepaddle的部分API对比
https://johnson7788.github.io/2022/02/23/torch-paddle/
作者
Johnson
发布于
2022年2月23日
许可协议