您的当前位置:首页正文

队服撞衫?如何让AI区分相似球衣?

来源:华拓网

问题

笔者最近在学人工智能领域的深度学习技术,在 fast.ai的课程里,我学会了:复用简洁的示例代码来训练卷积神经网络CNN。它能够精准地让机器对图片进行分类。(课程里的例子是识别猫狗图片的二元分类任务,在2000张测试图片中,它的识别准确率达到了99%。)

我认为,真正的学习,是带着好奇心,自己来解决实际问题,探究输入-输出之间的关系,从而构建自己的知识体系。

因为自己是多年的足球迷,从兴趣出发,我“凿开”了一个脑洞:能否训练一个卷积神经网络训练,识别区分两件相似足球队衣图片?

其中,我发现,大名鼎鼎的巴塞罗那队服是红蓝箭条衫,而瑞士巴赛尔队的队服也是红蓝箭条衫,两队的主场战袍相似度很高。因此,我打算用 fast.ai的模块来以及预训练的 CNN 卷积神经网络(比如,resnet34)来进行迁移学习,看看我们的 AI 图片分类器,能否较好地识别两个队的队服。

下面是两队的队服示例:

巴塞罗那队服 巴塞尔队服

任务定义:区分图片是「巴塞罗那队队服」还是「巴赛尔队队服」?

数据集

  • 通过开源的 Python 脚本 ,根据关键词来批量谷歌图片里的图像。每队批量下载了140张图片,训练集每队约83张图片,验证集每队约60张图片。
  • 巴塞罗那红蓝箭条衫队服
巴萨球衣图片下载过程 巴萨球衣图片概述
  • 巴赛尔红蓝箭条衫队服
巴赛尔球衣图片下载过程 巴赛尔球衣图片概述

训练

预训练模型

定义数据路径及图片大小(324*324)

PATH = "./data/basel_or_barcelona/"
sz=324

设置预训练模型为resnet34,然后生成模型,学习率设为0.01,并训练10次

arch=resnet34
data = ImageClassifierData.from_paths(PATH, tfms=tfms_from_model(arch, sz)) 
learn = ConvLearner.pretrained(arch, data, precompute=True) 
learn.fit(0.01, 10)

100% 10/10 [00:01<00:00, 5.81it/s]
epoch      trn_loss   val_loss   accuracy                
    0      0.773914   0.711084   0.586777  
    1      0.694298   0.670647   0.644628                
    2      0.597124   0.62849    0.710744                
    3      0.505373   0.600698   0.735537                
    4      0.450973   0.583853   0.760331        
    5      0.42152    0.595698   0.768595        
    6      0.373794   0.610405   0.785124        
    7      0.33496    0.624578   0.77686                 
    8      0.304601   0.634755   0.77686         
    9      0.277069   0.641988   0.77686     

采用预训练的模型,120多张的测试数据,准确率达到了77.6%,还不错。这是最简单的办法,fast.ai课程还讲授了其他的提示准确率的办法,对于这个小型的足球图片数据集,我来实验一下,看看准确率能否有明显提升。

寻找最佳学习率

使用了一个寻找最佳学习率的函数,但是数据图里是空的,没有曲线,暂时无解,先跳过,我决定沿用0.01的学习率。

数据扩充(Data Argumentation)

tfms = tfms_from_model(resnet34, sz, aug_tfms=transforms_side_on, max_zoom=1.5)

随机对图片进行水平旋转,并放大1.5倍

image.png
data = ImageClassifierData.from_paths(PATH, tfms=tfms)
learn = ConvLearner.pretrained(arch, data, precompute=True)
learn.fit(1e-2, 1)

用了 数据扩充(Data Argumentation) 之后,第一次训练,准确率较低,只有52.8%。

learn.precompute=False
learn.fit(1e-2, 3, cycle_len=1)
epoch      trn_loss   val_loss   accuracy                
    0      0.711067   0.799083   0.561983  
    1      0.673241   0.670506   0.652893                
    2      0.632643   0.606632   0.644628

解冻

之前训练的是最后一层,通过 Unfreeze函数,“解冻”所有神经层,进一步做 Fine-Tuning 参数微调,并且,不同深浅的神经层,采用不同的学习率。

learn.unfreeze()
lr=np.array([1e-4,1e-3,1e-2])
learn.fit(lr, 3, cycle_len=1, cycle_mult=2)
epoch      trn_loss   val_loss   accuracy                
    0      0.577406   0.585321   0.652893  
    1      0.574191   0.495774   0.727273                
    2      0.505914   0.456579   0.752066                 
    3      0.456115   0.393354   0.801653                
    4      0.408209   0.363914   0.818182                
    5      0.370777   0.353467   0.834711                
    6      0.342305   0.350941   0.842975 

learn.fit(lr, 6, cycle_len=1, cycle_mult=2)
epoch      trn_loss   val_loss   accuracy                
    0      0.224649   0.351165   0.842975  
    1      0.201187   0.343334   0.867769                
    2      0.192907   0.339028   0.867769                
    3      0.187865   0.32258    0.867769                
    4      0.167363   0.309566   0.867769                
    5      0.165586   0.302056   0.859504                
    6      0.162697   0.303541   0.859504                
    7      0.160678   0.301258   0.867769                
    8      0.150494   0.315228   0.876033                
    9      0.149012   0.333124   0.876033                
    10     0.140085   0.341198   0.884298                
    11     0.133135   0.343363   0.884298                
    12     0.125423   0.339807   0.884298                
    13     0.1168     0.33534    0.884298                
    14     0.109273   0.33899    0.884298                
    15     0.10801    0.32313    0.884298                
    16     0.101856   0.313022   0.892562                
    17     0.099126   0.29976    0.884298                 
    18     0.094342   0.293067   0.884298                 
    19     0.090559   0.291236   0.884298                 
    20     0.089409   0.294657   0.884298                 
    21     0.085061   0.292772   0.884298                 
    22     0.080943   0.294916   0.884298                 
    23     0.077687   0.289591   0.884298                 
    24     0.073291   0.290864   0.884298                  
    25     0.070136   0.289896   0.884298                  
    26     0.071041   0.291556   0.876033                 
    27     0.067812   0.287182   0.884298                  
    28     0.06408    0.287058   0.884298                 
    29     0.062913   0.288546   0.884298                  
    30     0.060431   0.286025   0.884298                 
    31     0.060556   0.284598   0.884298                 
    32     0.058281   0.290405   0.884298                 
    33     0.056919   0.295286   0.892562                  
    34     0.054588   0.29396    0.900826                 
    35     0.052521   0.292013   0.892562                 
    36     0.051742   0.281408   0.892562                  
    37     0.050915   0.274899   0.892562                  
    38     0.048206   0.265904   0.892562                 
    39     0.048089   0.2706     0.892562                 
    40     0.045976   0.286523   0.892562                 
    41     0.046062   0.297521   0.892562                  
    42     0.043593   0.297748   0.900826                 
    43     0.041344   0.293669   0.900826                 
    44     0.040934   0.297795   0.900826                 
    45     0.039792   0.300095   0.900826                 
    46     0.038014   0.300615   0.900826                 
    47     0.037288   0.298393   0.892562                 
    48     0.037229   0.294151   0.900826                 
    49     0.035728   0.290464   0.884298                 
    50     0.034198   0.286805   0.884298                 
    51     0.033855   0.286428   0.876033                 
    52     0.032238   0.282124   0.884298                 
    53     0.031438   0.276031   0.892562                 
    54     0.031037   0.282506   0.876033                 
    55     0.02959    0.285489   0.884298                 
    56     0.028276   0.282334   0.892562                 
    57     0.027054   0.278635   0.892562                 
    58     0.025802   0.281353   0.892562                 
    59     0.024861   0.282644   0.892562                 
    60     0.023592   0.281519   0.892562                 
    61     0.023124   0.283864   0.884298                 
    62     0.022198   0.284992   0.892562
log_preds,y = learn.TTA()
probs = np.mean(np.exp(log_preds),0)
accuracy_np(probs, y)
0.8925619834710744

采用了 TTA(测试集也使用 Data Argumentation) 之后,最后的准确率达到了89.2%。

结果

混淆矩阵(Confusion Matrix)

画出一个混淆矩阵——查看不同分类,识别错误的图片都有哪些?

Confusion Matrix

看图可知

巴塞罗那队服:54张正确,6张错误。
巴塞尔队服:54张正确,7张错误。

下面我们看看具体哪些图片识别错了:

第一行几张图片,判断错误还是有点奇怪的,因为图片本身有着巴塞罗那队的标志,标志人物梅西和队徽。也许是训练集还太小的缘故吧。

第二行判断错误的巴赛尔队服,第二张是异常图片,呈黑白色。其他三张与训练集里典型的巴赛尔队服还是有点区别的。这样也许解释得过去。

参考资料: