torchvision分类介绍
Torchvision高版本支持各种SOTA的图像分类模型,同时还支持不同数据集分类模型的预训练模型的切换。使用起来十分方便快捷,Pytroch中支持两种迁移学习方式,分别是:
- Finetune模式 基于预训练模型,全链路调优参数 - 冻结特征层模式 这种方式只修改输出层的参数,CNN部分的参数冻结上述两种迁移方式,分别适合大量数据跟少量数据,前一种方式计算跟训练时间会比第二种方式要长点,但是针对大量自定义分类数据效果会比较好。
自定义分类模型修改与训练
加载模型之后,feature_extracting 为true表示冻结模式,否则为finetune模式,相关的代码如下:
def set_parameter_requires_grad(model, feature_extracting): if feature_extracting: for param in model.parameters(): param.requires_grad = False以resnet18为例,修改之后的自定义训练代码如下:
model_ft = models.resnet18(pretrained=True) num_ftrs = model_ft.fc.in_features # Here the size of each output sample is set to 5. # Alternatively, it can be generalized to nn.Linear(num_ftrs, len(class_names)). model_ft.fc = nn.Linear(num_ftrs, 5) model_ft = model_ft.to(device) criterion = nn.CrossEntropyLoss() # Observe that all parameters are being optimized optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9) # Decay LR by a factor of 0.1 every 7 epochs exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1) model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25)
数据集是flowers-dataset,有五个分类分别是:
daisy dandelion roses sunflowers tulips
全链路调优,迁移学习训练CNN部分的权重参数
Epoch 0/24 ---------- train Loss: 1.3993 Acc: 0.5597 valid Loss: 1.8571 Acc: 0.7073 Epoch 1/24 ---------- train Loss: 1.0903 Acc: 0.6580 valid Loss: 0.6150 Acc: 0.7805 Epoch 2/24 ---------- train Loss: 0.9095 Acc: 0.6991 valid Loss: 0.4386 Acc: 0.8049 Epoch 3/24 ---------- train Loss: 0.7628 Acc: 0.7349 valid Loss: 0.9111 Acc: 0.7317 Epoch 4/24 ---------- train Loss: 0.7107 Acc: 0.7669 valid Loss: 0.4854 Acc: 0.8049 Epoch 5/24 ---------- train Loss: 0.6231 Acc: 0.7793 valid Loss: 0.6822 Acc: 0.8049 Epoch 6/24 ---------- train Loss: 0.5768 Acc: 0.8033 valid Loss: 0.2748 Acc: 0.8780 Epoch 7/24 ---------- train Loss: 0.5448 Acc: 0.8110 valid Loss: 0.4440 Acc: 0.7561 Epoch 8/24 ---------- train Loss: 0.5037 Acc: 0.8170 valid Loss: 0.2900 Acc: 0.9268 Epoch 9/24 ---------- train Loss: 0.4836 Acc: 0.8360 valid Loss: 0.7108 Acc: 0.7805 Epoch 10/24 ---------- train Loss: 0.4663 Acc: 0.8369 valid Loss: 0.5868 Acc: 0.8049 Epoch 11/24 ---------- train Loss: 0.4276 Acc: 0.8504 valid Loss: 0.6998 Acc: 0.8293 Epoch 12/24 ---------- train Loss: 0.4299 Acc: 0.8529 valid Loss: 0.6449 Acc: 0.8049 Epoch 13/24 ---------- train Loss: 0.4256 Acc: 0.8567 valid Loss: 0.7897 Acc: 0.7805 Epoch 14/24 ---------- train Loss: 0.4062 Acc: 0.8559 valid Loss: 0.5855 Acc: 0.8293 Epoch 15/24 ---------- train Loss: 0.4030 Acc: 0.8545 valid Loss: 0.7336 Acc: 0.7805 Epoch 16/24 ---------- train Loss: 0.3786 Acc: 0.8730 valid Loss: 1.0429 Acc: 0.7561 Epoch 17/24 ---------- train Loss: 0.3699 Acc: 0.8763 valid Loss: 0.4549 Acc: 0.8293 Epoch 18/24 ---------- train Loss: 0.3394 Acc: 0.8788 valid Loss: 0.2828 Acc: 0.9024 Epoch 19/24 ---------- train Loss: 0.3300 Acc: 0.8834 valid Loss: 0.6766 Acc: 0.8537 Epoch 20/24 ---------- train Loss: 0.3136 Acc: 0.8906 valid Loss: 0.5893 Acc: 0.8537 Epoch 21/24 ---------- train Loss: 0.3110 Acc: 0.8901 valid Loss: 0.4909 Acc: 0.8537 Epoch 22/24 ---------- train Loss: 0.3141 Acc: 0.8931 valid Loss: 0.3930 Acc: 0.9024 Epoch 23/24 ---------- train Loss: 0.3106 Acc: 0.8887 valid Loss: 0.3079 Acc: 0.9024 Epoch 24/24 ---------- train Loss: 0.3143 Acc: 0.8923 valid Loss: 0.5122 Acc: 0.8049 Training complete in 25m 34s Best val Acc: 0.926829
冻结CNN部分,只训练全连接分类权重
Params to learn: fc.weight fc.bias Epoch 0/24 ---------- train Loss: 1.0217 Acc: 0.6465 valid Loss: 1.5317 Acc: 0.8049 Epoch 1/24 ---------- train Loss: 0.9569 Acc: 0.6947 valid Loss: 1.2450 Acc: 0.6829 Epoch 2/24 ---------- train Loss: 1.0280 Acc: 0.6999 valid Loss: 1.5677 Acc: 0.7805 Epoch 3/24 ---------- train Loss: 0.8344 Acc: 0.7426 valid Loss: 1.1053 Acc: 0.7317 Epoch 4/24 ---------- train Loss: 0.9110 Acc: 0.7250 valid Loss: 1.1148 Acc: 0.7561 Epoch 5/24 ---------- train Loss: 0.9049 Acc: 0.7346 valid Loss: 1.1541 Acc: 0.6341 Epoch 6/24 ---------- train Loss: 0.8538 Acc: 0.7465 valid Loss: 1.4098 Acc: 0.8293 Epoch 7/24 ---------- train Loss: 0.9041 Acc: 0.7349 valid Loss: 0.9604 Acc: 0.7561 Epoch 8/24 ---------- train Loss: 0.8885 Acc: 0.7468 valid Loss: 1.2603 Acc: 0.7561 Epoch 9/24 ---------- train Loss: 0.9257 Acc: 0.7333 valid Loss: 1.0751 Acc: 0.7561 Epoch 10/24 ---------- train Loss: 0.8637 Acc: 0.7492 valid Loss: 0.9748 Acc: 0.7317 Epoch 11/24 ---------- train Loss: 0.8686 Acc: 0.7517 valid Loss: 1.0194 Acc: 0.8049 Epoch 12/24 ---------- train Loss: 0.8492 Acc: 0.7572 valid Loss: 1.0378 Acc: 0.7317 Epoch 13/24 ---------- train Loss: 0.8773 Acc: 0.7432 valid Loss: 0.7224 Acc: 0.8049 Epoch 14/24 ---------- train Loss: 0.8919 Acc: 0.7473 valid Loss: 1.3564 Acc: 0.7805 Epoch 15/24 ---------- train Loss: 0.8634 Acc: 0.7490 valid Loss: 0.7822 Acc: 0.7805 Epoch 16/24 ---------- train Loss: 0.8069 Acc: 0.7644 valid Loss: 1.4132 Acc: 0.7561 Epoch 17/24 ---------- train Loss: 0.8589 Acc: 0.7492 valid Loss: 0.9812 Acc: 0.8049 Epoch 18/24 ---------- train Loss: 0.7677 Acc: 0.7688 valid Loss: 0.7176 Acc: 0.8293 Epoch 19/24 ---------- train Loss: 0.8044 Acc: 0.7514 valid Loss: 1.4486 Acc: 0.7561 Epoch 20/24 ---------- train Loss: 0.7916 Acc: 0.7564 valid Loss: 1.0575 Acc: 0.8049 Epoch 21/24 ---------- train Loss: 0.7922 Acc: 0.7647 valid Loss: 1.0406 Acc: 0.7805 Epoch 22/24 ---------- train Loss: 0.8187 Acc: 0.7647 valid Loss: 1.0965 Acc: 0.7561 Epoch 23/24 ---------- train Loss: 0.8443 Acc: 0.7503 valid Loss: 1.6163 Acc: 0.7317 Epoch 24/24 ---------- train Loss: 0.8165 Acc: 0.7583 valid Loss: 1.1680 Acc: 0.7561 Training complete in 20m 7s Best val Acc: 0.829268
测试结果:
零代码训练演示
我已经完成torchvision中分类模型自定义数据集迁移学习的代码封装与开发,支持基于收集到的数据集,零代码训练,生成模型。图示如下:
全部0条评论
快来发表一下你的评论吧 !