トリプレットロスを用いたReIDのPyTorch実装



ReID PyTorch Implementation with Triplet Loss



ReIDの実装にトリプレットロスを追加しました。プロジェクトのディレクトリ構造は次の図のとおりです。



loss.pyのコードは次のとおりです。

import torch import torch.nn as nn import torch.nn.functional as F def euclidean_dist(x,y): m,n = x.size(0),y.size(0) xx = torch.pow(x,2).sum(1,keepdim=True).expand(m,n) yy = torch.pow(y,2).sum(dim=1,keepdim=True).expand(n,m).t() dist = xx + yy dist.addmm_(1,-2,x,y.t()) dist = dist.clamp(min=1e-12).sqrt() return dist def cosine_dist(x,y): bs1, bs2 = x.size(0),y.size(0) frac_up = torch.matmul(x,y.transpose(0,1)) frac_down = (torch.sqrt( torch.pow(x,2).sum(dim=1) ).view(bs1,1).repeat(1,bs2)) * (torch.sqrt( torch.pow(y,2).sum(dim=1).view(1,bs2).repeat(bs1,1) ) ) cosine = frac_up/frac_down cos_d = 1 - cosine return cos_d def _batch_hard(mat_distance,mat_similarity,indice=False): sorted_mat_distance, positive_indices = torch.sort(mat_distance + (-100000.0)*(1 - mat_similarity),dim=1, descending=True) hard_p = sorted_mat_distance[:,0] hard_p_indice = positive_indices[:,0] sorted_mat_distance, negative_indices = torch.sort( mat_distance + 100000.0 * mat_similarity,dim = 1,descending=False ) hard_n = sorted_mat_distance[:,0] hard_n_indice = negative_indices[:,0] if(indice): return hard_p, hard_n, hard_p_indice, hard_n_indice return hard_p, hard_n class TripletLoss(nn.Module): def __init__(self, margin=0.5, normalize_feature = True): super(TripletLoss, self).__init__() self.margin = margin self.normalize_feature = normalize_feature self.margin_loss = nn.MarginRankingLoss(margin = margin) def forward(self, emb, label): if self.normalize_feature: emb = F.normalize(emb) #print('emb') #print(emb) mat_dist = euclidean_dist(emb, emb) #print('mat_dist') assert mat_dist.size(0) == mat_dist.size(1) N = mat_dist.size(0) mat_sim = label.expand(N,N).eq(label.expand(N,N).t()).float() #print(mat_dist) #print(mat_sim) dist_ap, dist_an = _batch_hard(mat_dist, mat_sim) assert dist_an.size(0) == dist_ap.size(0) y = torch.ones_like(dist_ap) loss = self.margin_loss(dist_an, dist_ap, y) prec = (dist_an.data > dist_ap.data).sum() * 1.0 / y.size(0) return loss, prec # loss = nn.CrossEntropyLoss() 
# an = torch.randn(4,3) # y = torch.ones(4).long() # print(an) # print(y) # l = loss(an,y) # print(l) # l.backward() # print(an.grad)

トリプレットロスはloss.pyで定義されています。考え方自体は簡単に説明できますが、実装には慎重に検討すべき点がいくつかあります。コードは読むだけでなく、実際に自分で打ち込んでみることをお勧めします。手を動かしながら考えることで、細かなテクニックも身につきます。以前トリプレットロスを調べたとき、データセットをPyTorchの形式（ImageFolder＋DataLoader）で読み込む場合、どれがアンカーで、どれがポジティブ、どれがネガティブなのかをどう判断するのかがずっと疑問でした。これは実際にコードを書いてみるとすぐに分かります。DataLoaderから読み込んだバッチのラベルを比較すればよく、バッチ内の各サンプルをアンカーとしたとき、同じラベルのサンプルがポジティブ、異なるラベルのサンプルがネガティブになります。具体的には、TripletLossのforwardでラベル同士の一致行列mat_simを作り、_batch_hardでアンカーごとに最も遠いポジティブと最も近いネガティブを選んでいます。



model.pyのコードは次のとおりです。

import torch
import torch.nn as nn
from torchvision import models
from torch.nn import functional as F


class resnet_model(nn.Module):
    """ResNet-50 re-ID network: backbone, global average pooling, BN neck and classifier.

    The stride of layer4's first block is set to 1 (removing its
    down-sampling), so the final feature map keeps layer3's spatial size.

    Args:
        cut_at_pooling: if True, only the backbone + pooling are used and
            forward() returns the pooled features; no head layers are built.
        num_features: width of an extra Linear embedding layer; 0 disables it
            and the BN neck works directly on the pooled backbone features.
        norm: if True, L2-normalize the BN features during training.
        dropout: dropout probability applied before the classifier (0 disables).
        num_classes: number of identity classes for the classifier (0 disables).
        pretrained_path: path of a ResNet-50 ``state_dict`` loaded into the
            backbone; pass None to skip loading and keep random weights.
    """

    def __init__(self, cut_at_pooling=False, num_features=0, norm=False,
                 dropout=0, num_classes=0,
                 pretrained_path='./pretrain_model/resnet50.pth'):
        super().__init__()
        self.cut_at_pooling = cut_at_pooling

        # No-argument constructor = randomly initialised weights on every
        # torchvision version (avoids the deprecated `pretrained=` kwarg).
        resnet = models.resnet50()
        if pretrained_path is not None:
            # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only
            # machine; the caller moves the model to its device afterwards.
            resnet.load_state_dict(torch.load(pretrained_path, map_location='cpu'))

        # Remove the last down-sampling step (stride 2 -> 1) in layer4.
        resnet.layer4[0].conv2.stride = (1, 1)
        resnet.layer4[0].downsample[0].stride = (1, 1)
        # NOTE(review): resnet.relu is not placed between bn1 and maxpool, unlike
        # torchvision's own ResNet forward. Kept as in the original — confirm
        # whether this omission is intentional.
        self.base = nn.Sequential(
            resnet.conv1, resnet.bn1, resnet.maxpool,
            resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4,
        )
        self.gap = nn.AdaptiveAvgPool2d(1)

        if not self.cut_at_pooling:
            self.num_features = num_features
            self.norm = norm
            self.dropout = dropout
            self.has_embedding = num_features > 0
            self.num_classes = num_classes

            out_planes = resnet.fc.in_features

            if self.has_embedding:
                self.feat = nn.Linear(out_planes, self.num_features)
                self.feat_bn = nn.BatchNorm1d(self.num_features)
                nn.init.kaiming_normal_(self.feat.weight, mode='fan_out')
                nn.init.constant_(self.feat.bias, 0)
            else:
                # Without an embedding layer the BN neck is as wide as the backbone output.
                self.num_features = out_planes
                self.feat_bn = nn.BatchNorm1d(self.num_features)
            # The BN shift is frozen (and initialised to 0 below), so it never trains.
            self.feat_bn.bias.requires_grad_(False)
            if self.dropout > 0:
                self.drop = nn.Dropout(self.dropout)
            if self.num_classes > 0:
                self.classifier = nn.Linear(self.num_features, self.num_classes, bias=False)
                nn.init.normal_(self.classifier.weight, std=0.001)
            nn.init.constant_(self.feat_bn.weight, 1)
            nn.init.constant_(self.feat_bn.bias, 0)

    def forward(self, x, feature_withbn=False):
        """Run the network on an image batch.

        Returns:
            - cut_at_pooling: the pooled backbone features.
            - eval mode: the L2-normalized BN features.
            - training, num_classes == 0: (pooled features, BN features).
            - training otherwise: (pooled features, classifier logits), or
              (BN features, classifier logits) when ``feature_withbn`` is True.
              The logits are the raw Linear output (no softmax).
        """
        x = self.base(x)
        x = self.gap(x)
        x = x.view(x.size(0), -1)

        if self.cut_at_pooling:
            return x

        if self.has_embedding:
            bn_x = self.feat_bn(self.feat(x))
        else:
            bn_x = self.feat_bn(x)

        if not self.training:
            return F.normalize(bn_x)

        if self.norm:
            bn_x = F.normalize(bn_x)
        elif self.has_embedding:
            bn_x = F.relu(bn_x)

        if self.dropout > 0:
            bn_x = self.drop(bn_x)

        if self.num_classes <= 0:
            return x, bn_x
        prob = self.classifier(bn_x)

        if feature_withbn:
            return bn_x, prob
        return x, prob

model.pyはResNet-50をバックボーンとして使い、その上に線形分類器を追加しただけなので、特に説明することはあまりありません。

reid.pyのコードは次のとおりです。



import torch
import torch.nn as nn
from torchvision import datasets, transforms
from model import resnet_model
from torch.optim import lr_scheduler
import loss

# Training script: ResNet-50 re-ID model trained with triplet + cross-entropy loss
# on an ImageFolder dataset (one sub-directory per identity).

transform_list = [
    transforms.Resize((256, 128), interpolation=3),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.Pad(10),
    transforms.RandomCrop((256, 128)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
transform_compose = transforms.Compose(transform_list)

try_dataset1 = datasets.ImageFolder('./try_data1', transform_compose)
# NOTE(review): with plain shuffling a batch may hold identities that have a
# single image, which gives batch-hard triplet mining little to work with;
# an identity-balanced (P x K) sampler is the usual choice — consider adding one.
try_dataloader1 = torch.utils.data.DataLoader(
    try_dataset1, batch_size=32, shuffle=True
)
try_data1_len = len(try_dataset1)
try_data1_class_name = try_dataset1.classes

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# One classifier output per identity (ImageFolder class), not per image.
net = resnet_model(num_classes=len(try_data1_class_name))
net.to(device)
net.train()

params = [p for p in net.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(params, lr=0.00035, weight_decay=5e-4)
# Decay the learning rate by 10x every 10 epochs (stepped once per epoch below).
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

criterion_ce = nn.CrossEntropyLoss()
criterion_triple = loss.TripletLoss()

# Per-batch history over the whole run.
triplet_loss_list = []
pre_loss_list = []
loss_list = []
acc_list = []

for epoch in range(30):
    print('epoch: {} / 30'.format(epoch + 1))
    epoch_start = len(loss_list)
    for inputs, labels in try_dataloader1:
        inputs = inputs.to(device)
        labels = labels.to(device)

        features, pres = net(inputs)
        tri_loss, _ = criterion_triple(features, labels)
        # Cross-entropy is computed on the classifier logits, not on the pooled features.
        ce_loss = criterion_ce(pres, labels)
        # Named total_loss so the imported `loss` module is not shadowed.
        total_loss = tri_loss + ce_loss

        triplet_loss_list.append(tri_loss.item())
        pre_loss_list.append(ce_loss.item())
        loss_list.append(total_loss.item())

        pid = pres.detach().argmax(dim=1)
        acc = (pid == labels).float().mean()
        acc_list.append(acc.item())

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
    exp_lr_scheduler.step()

    # Averages over the batches of this epoch only.
    n_batches = len(loss_list) - epoch_start
    all_acc = sum(acc_list[epoch_start:]) / n_batches
    all_triple_loss = sum(triplet_loss_list[epoch_start:]) / n_batches
    all_pre_loss = sum(pre_loss_list[epoch_start:]) / n_batches
    all_loss = sum(loss_list[epoch_start:]) / n_batches
    print('accuracy: {:.4f}'.format(all_acc))
    print('triplet loss: {:.4f}'.format(all_triple_loss))
    print('predict loss: {:.4f}'.format(all_pre_loss))
    print('loss : {:.4f}'.format(all_loss))

以上、記録として残しておきます。質問があればコメントしてください。できるだけシンプルで分かりやすく、かつネット上で最も完全なコードを示すよう心がけました。