Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results

Semi-Supervised 3D Medical Image Segmentation (Part 1): Mean Teacher

The pipeline of the mean-teacher framework for classification

Background


Self-supervised learning first pretrains on a large unlabeled dataset, building loss functions from pretext tasks such as contrastive learning and image reconstruction, and then fine-tunes on a labeled dataset for the target task.

Semi-supervised learning instead feeds a small amount of labeled data and a large amount of unlabeled data into the network together, using a consistency loss or multi-task learning to reach better results than training on the labeled data alone.

Network Architecture

Below is the network diagram I drew for segmentation, based on the method in the Mean Teacher paper.


The pipeline of the mean-teacher framework for semi-supervised segmentation

The network has two parts, a student network and a teacher network. The teacher network's weights are frozen and are updated from the student network via an exponential moving average.

Labeled and unlabeled images are fed in together; the same image, perturbed with independent random noise, goes into both the student network and the teacher network.

The loss has two parts: a segmentation loss on the labeled data and a consistency loss on the unlabeled images (the labeled images can contribute to the consistency loss as well).

Personally, I see Mean Teacher training as a process of seeking agreement despite small differences: the input images differ slightly and the network weights differ slightly. If training fully converges, the student's and teacher's weights should be very close and both should denoise well, so the consistency loss becomes small; conversely, if the network has not converged, the consistency loss will not converge either.

Exponential Moving Average

Exponential moving average (EMA):

\theta_t' = \alpha \theta_{t-1}' + (1-\alpha)\theta_t
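To make the rule concrete, here is a tiny plain-Python sketch (not from the repository) of how a teacher weight θ' trails a student weight θ with α = 0.99:

```python
# EMA update: theta'_t = alpha * theta'_{t-1} + (1 - alpha) * theta_t
def ema_update(teacher, student, alpha=0.99):
    return alpha * teacher + (1 - alpha) * student

teacher = 0.0
for _ in range(3):               # student weight stays at 1.0 for three steps
    teacher = ema_update(teacher, 1.0)
print(teacher)                   # ~0.0297: the teacher drifts slowly toward 1.0
```

With α this close to 1, the teacher is a heavily smoothed, slowly moving copy of the student.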

Loss Function

\theta^* = \arg\min_{\theta} \sum_{i=1}^{N} L_{seg}(f(x_i;\theta), y_i) + \lambda \sum_{i=N+1}^{N+M} L_{con}(f(x_i;\theta,\eta^s), f(x_i;\theta',\eta^t))

Here the first N samples are labeled and the remaining M are unlabeled; \eta^s and \eta^t are the independent noise applied to the student and teacher inputs, \theta' are the teacher's EMA weights, and \lambda weights the consistency term.

Code Walkthrough

LASeg: 2018 Left Atrium Segmentation (MRI) (github.com)

To run:

python train_mean_teacher.py

For comparison, training with only the labeled data:

python train_sup.py

The dataset is still the Left Atrium (LA) MR dataset; this builds on the previous post, LAHeart2018 left atrium segmentation in practice, and follows https://github.com/yulequan/UA-MT

1. TwoStreamBatchSampler

A question many people surely have: how do we sample from the dataset so that every batch contains both labeled and unlabeled data?

import itertools
import numpy as np
from torch.utils.data.sampler import Sampler
class TwoStreamBatchSampler(Sampler):
    """Iterate two sets of indices
    An 'epoch' is one iteration through the primary indices.
    During the epoch, the secondary indices are iterated through
    as many times as needed.
    """
    def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
        # indices of labeled samples
        self.primary_indices = primary_indices
        # indices of unlabeled samples
        self.secondary_indices = secondary_indices
        self.secondary_batch_size = secondary_batch_size
        self.primary_batch_size = batch_size - secondary_batch_size
        assert len(self.primary_indices) >= self.primary_batch_size > 0
        assert len(self.secondary_indices) >= self.secondary_batch_size > 0
    def __iter__(self):
        # shuffle the index order
        primary_iter = iterate_once(self.primary_indices)
        secondary_iter = iterate_eternally(self.secondary_indices)
        return (
            primary_batch + secondary_batch
            for (primary_batch, secondary_batch)
            in zip(grouper(primary_iter, self.primary_batch_size),
                    grouper(secondary_iter, self.secondary_batch_size))
        )
    def __len__(self):
        return len(self.primary_indices) // self.primary_batch_size
def iterate_once(iterable):
    # print('shuffle labeled_idxs')
    return np.random.permutation(iterable)
def iterate_eternally(indices):
    # print('shuffle unlabeled_idxs')
    def infinite_shuffles():
        while True:
            yield np.random.permutation(indices)
    return itertools.chain.from_iterable(infinite_shuffles())
def grouper(iterable, n):
    "Collect data into fixed-length chunks or blocks"
    # grouper('ABCDEFG', 3) --> ABC DEF
    args = [iter(iterable)] * n
    return zip(*args)
if __name__ == '__main__':
    labeled_idxs = list(range(12))
    unlabeled_idxs = list(range(12,60))
    batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, 4, 2)
    for _ in range(2):
        i = 0
        for x in batch_sampler:
            i += 1
            print('%02d' % i, '\t', x)
shuffle labeled_idxs
shuffle unlabeled_idxs
01 	 (2, 7, 46, 12)
02 	 (9, 3, 25, 50)
03 	 (8, 0, 15, 49)
04 	 (6, 11, 14, 41)
05 	 (1, 10, 37, 19)
06 	 (5, 4, 34, 35)
shuffle labeled_idxs
shuffle unlabeled_idxs
01 	 (0, 1, 22, 17)
02 	 (10, 7, 55, 19)
03 	 (6, 11, 53, 21)
04 	 (2, 4, 49, 27)
05 	 (3, 8, 41, 36)
06 	 (9, 5, 48, 44)

2. Random Noise

The code adds random noise only to the teacher network's input; the student network's input gets no noise.

noise = torch.clamp(torch.randn_like(unlabeled_volume_batch) * 0.1, -0.2, 0.2)

In fact, adding independent random noise to both the student and the teacher input works about as well as adding noise to only one side; either way, the point is just to create a small difference between the two forward passes.
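The noise recipe above (Gaussian with σ = 0.1, clipped to [-0.2, 0.2]) can be sketched in NumPy; one independent draw per branch is all it takes. The names here are illustrative, not from the repository:

```python
import numpy as np

def bounded_noise(shape, sigma=0.1, bound=0.2, rng=None):
    """Gaussian noise with std sigma, clipped to [-bound, bound]."""
    rng = np.random.default_rng(0) if rng is None else rng
    return np.clip(rng.standard_normal(shape) * sigma, -bound, bound)

volume = np.zeros((2, 1, 8, 8, 8))                          # stand-in for an image batch
rng = np.random.default_rng(0)
student_in = volume + bounded_noise(volume.shape, rng=rng)  # independent draw per branch
teacher_in = volume + bounded_noise(volume.shape, rng=rng)
# every perturbed voxel stays within the clipping bounds
```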

3. Exponential Moving Average (EMA)

The student network and the teacher network share the same architecture; the teacher network's parameters are frozen and excluded from backpropagation.

    def create_model(ema=False):
        # Network definition
        net = VNet(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=True)
        model = net.cuda()
        if ema:
            for param in model.parameters():
                param.detach_()  # cut off backpropagation
        return model
    model = create_model()
    ema_model = create_model(ema=True)

Weight transfer

\theta_t' = \alpha \theta_{t-1}' + (1-\alpha)\theta_t

def update_ema_variables(model, ema_model, alpha, global_step):
    # Use the true average until the exponential average is more correct
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)
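One detail worth noting: the min(1 - 1/(global_step + 1), alpha) schedule means the teacher is overwritten with the student at step 0 and only gradually becomes a slow-moving average. A small plain-Python check:

```python
def effective_alpha(global_step, alpha=0.99):
    # early in training 1 - 1/(step + 1) < alpha, so a true running average is used
    return min(1 - 1 / (global_step + 1), alpha)

print(effective_alpha(0))    # 0.0  -> teacher is overwritten with the student
print(effective_alpha(9))    # 0.9
print(effective_alpha(99))   # 0.99 -> fixed EMA decay from here on
```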

4. Loss Function

Segmentation loss

L_{total} = L_{dice} + L_{CE}

loss_seg = F.cross_entropy(outputs[:labeled_bs], label_batch[:labeled_bs])
outputs_soft = F.softmax(outputs, dim=1)
loss_seg_dice = losses.dice_loss(outputs_soft[:labeled_bs, 1, :, :, :], label_batch[:labeled_bs] == 1)
supervised_loss = 0.5 * (loss_seg + loss_seg_dice)
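losses.dice_loss itself is not reproduced in this post; under the standard soft-Dice formulation it looks roughly like the NumPy sketch below. The function name and the ε smoothing term are illustrative assumptions, not the repository's exact code:

```python
import numpy as np

def soft_dice_loss(prob, target, eps=1e-5):
    """prob: predicted foreground probabilities; target: binary ground-truth mask."""
    inter = np.sum(prob * target)
    union = np.sum(prob) + np.sum(target)
    return 1.0 - (2.0 * inter + eps) / (union + eps)

pred = np.array([[0.9, 0.1], [0.8, 0.2]])   # confident, mostly correct prediction
mask = np.array([[1, 0], [1, 0]])
loss = soft_dice_loss(pred, mask)            # small loss, around 0.15
```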

Consistency loss

L_{con} = \| f(x_i;\theta,\eta^s) - f(x_i;\theta',\eta^t) \|^2
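The consistency criterion losses.softmax_mse_loss is not reproduced in this post either; read literally, it softmaxes both sets of logits along the class dimension and takes the mean squared difference. A NumPy sketch under that assumption:

```python
import numpy as np

def softmax(z, axis=1):
    z = z - z.max(axis=axis, keepdims=True)   # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def softmax_mse(student_logits, teacher_logits):
    return np.mean((softmax(student_logits) - softmax(teacher_logits)) ** 2)

logits = np.random.randn(2, 2, 4, 4, 4)       # (batch, classes, D, H, W)
# identical predictions incur zero penalty; any disagreement is penalized
```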

Loss weight



\lambda(t) = \omega_{max} \cdot e^{-5\left(1-\frac{t}{t_{max}}\right)^2}

# update the consistency weight every 150 iterations
consistency_weight = get_current_consistency_weight(iter_num // 150)
consistency_dist = consistency_criterion(outputs[labeled_bs:], ema_output)
consistency_loss = consistency_weight * consistency_dist

The consistency weight increases gradually over training, which keeps the network from being dominated early on by meaningless consistency targets.

def get_current_consistency_weight(epoch):
    # Consistency ramp-up from https://arxiv.org/abs/1610.02242
    return args.consistency * ramps.sigmoid_rampup(epoch, args.consistency_rampup)

ramps.sigmoid_rampup

def sigmoid_rampup(current, rampup_length):
    """Exponential rampup from https://arxiv.org/abs/1610.02242"""
    if rampup_length == 0:
        return 1.0
    else:
        current = np.clip(current, 0.0, rampup_length)
        phase = 1.0 - current / rampup_length
        return float(np.exp(-5.0 * phase * phase))
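With the defaults used above (consistency = 0.1, consistency_rampup = 40), the resulting λ(t) starts near zero and saturates at 0.1:

```python
import numpy as np

def sigmoid_rampup(current, rampup_length):
    """Exponential rampup from https://arxiv.org/abs/1610.02242"""
    if rampup_length == 0:
        return 1.0
    current = np.clip(current, 0.0, rampup_length)
    phase = 1.0 - current / rampup_length
    return float(np.exp(-5.0 * phase * phase))

consistency = 0.1                      # args.consistency default
for epoch in (0, 20, 40):
    print(epoch, consistency * sigmoid_rampup(epoch, 40))
# ramps from ~0.0007 at epoch 0 through ~0.029 at epoch 20 up to 0.1 at epoch 40
```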

Full training code:

import os
import sys
from tqdm import tqdm
from tensorboardX import SummaryWriter
import argparse
import logging
import time
import random
import torch
import torch.optim as optim
from torchvision import transforms
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from networks.vnet import VNet
from utils import ramps, losses
from dataloaders.la_heart import *
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_name', type=str, default='LA', help='dataset_name')
parser.add_argument('--root_path', type=str, default='/***/data_set/LASet/data',
                    help='data root path')
parser.add_argument('--exp', type=str, default='vnet', help='experiment name')
parser.add_argument('--model', type=str, default='MT', help='model name')
parser.add_argument('--max_iterations', type=int, default=6000, help='maximum number of iterations to train')
parser.add_argument('--batch_size', type=int, default=4, help='batch_size per gpu')
parser.add_argument('--labeled_bs', type=int, default=2, help='labeled_batch_size per gpu')
parser.add_argument('--labelnum', type=int, default=25, help='number of labeled samples')
parser.add_argument('--max_samples', type=int, default=123, help='total number of samples')
parser.add_argument('--base_lr', type=float, default=0.01, help='base learning rate')
parser.add_argument('--deterministic', type=int, default=1, help='whether use deterministic training')
parser.add_argument('--seed', type=int, default=1337, help='random seed')
parser.add_argument('--gpu', type=str, default='1', help='GPU to use')
### costs
parser.add_argument('--ema_decay', type=float, default=0.99, help='ema_decay')
parser.add_argument('--consistency_type', type=str, default="mse", help='consistency_type')
parser.add_argument('--consistency', type=float, default=0.1, help='consistency')
parser.add_argument('--consistency_rampup', type=float, default=40.0, help='consistency_rampup')
args = parser.parse_args()
patch_size = (112, 112, 80)
snapshot_path = "model/{}_{}_{}_labeled/{}".format(args.dataset_name, args.exp, args.labelnum, args.model)
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
batch_size = args.batch_size * len(args.gpu.split(','))
max_iterations = args.max_iterations
base_lr = args.base_lr
labeled_bs = args.labeled_bs
if args.deterministic:
    cudnn.benchmark = False
    cudnn.deterministic = True
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
num_classes = 2
def cal_dice(output, target, eps=1e-3):
    output = torch.argmax(output,dim=1)
    inter = torch.sum(output * target) + eps
    union = torch.sum(output) + torch.sum(target) + eps * 2
    dice = 2 * inter / union
    return dice
def get_current_consistency_weight(epoch):
    # Consistency ramp-up from https://arxiv.org/abs/1610.02242
    return args.consistency * ramps.sigmoid_rampup(epoch, args.consistency_rampup)
def update_ema_variables(model, ema_model, alpha, global_step):
    # Use the true average until the exponential average is more correct
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)
if __name__ == "__main__":
    # make logger file
    if not os.path.exists(snapshot_path):
        os.makedirs(snapshot_path)
    logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))
    def create_model(ema=False):
        # Network definition
        net = VNet(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=True)
        model = net.cuda()
        if ema:
            for param in model.parameters():
                param.detach_()
        return model
    model = create_model()
    ema_model = create_model(ema=True)
    db_train = LAHeart(base_dir=args.root_path,
                       split='train',
                       transform=transforms.Compose([
                           RandomRotFlip(),
                           RandomCrop(patch_size),
                           ToTensor(),
                       ]))
    db_test = LAHeart(base_dir=args.root_path,
                      split='test',
                      transform=transforms.Compose([
                          CenterCrop(patch_size),
                          ToTensor()
                      ]))
    labeled_idxs = list(range(args.labelnum))
    unlabeled_idxs = list(range(args.labelnum, args.max_samples))
    batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, batch_size, batch_size - labeled_bs)
    def worker_init_fn(worker_id):
        random.seed(args.seed + worker_id)
    train_loader = DataLoader(db_train, batch_sampler=batch_sampler, num_workers=4, pin_memory=True,
                              worker_init_fn=worker_init_fn)
    test_loader = DataLoader(db_test, batch_size=1,shuffle=False, num_workers=4, pin_memory=True)
    model.train()
    ema_model.train()
    optimizer = optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)
    if args.consistency_type == 'mse':
        consistency_criterion = losses.softmax_mse_loss
    elif args.consistency_type == 'kl':
        consistency_criterion = losses.softmax_kl_loss
    else:
        assert False, args.consistency_type
    writer = SummaryWriter(snapshot_path + '/log')
    logging.info("{} iterations per epoch".format(len(train_loader)))
    iter_num = 0
    best_dice = 0
    max_epoch = max_iterations // len(train_loader) + 1
    lr_ = base_lr
    model.train()
    for epoch_num in tqdm(range(max_epoch), ncols=70):
        time1 = time.time()
        for i_batch, sampled_batch in enumerate(train_loader):
            time2 = time.time()
            # print('fetch data cost {}'.format(time2-time1))
            volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
            unlabeled_volume_batch = volume_batch[labeled_bs:]
            noise = torch.clamp(torch.randn_like(unlabeled_volume_batch) * 0.1, -0.2, 0.2)
            ema_inputs = unlabeled_volume_batch + noise
            outputs = model(volume_batch)
            with torch.no_grad():
                ema_output = ema_model(ema_inputs)
            # calculate the loss
            loss_seg = F.cross_entropy(outputs[:labeled_bs], label_batch[:labeled_bs])
            outputs_soft = F.softmax(outputs, dim=1)
            loss_seg_dice = losses.dice_loss(outputs_soft[:labeled_bs, 1, :, :, :], label_batch[:labeled_bs] == 1)
            supervised_loss = 0.5 * (loss_seg + loss_seg_dice)
            consistency_weight = get_current_consistency_weight(iter_num // 150)
            consistency_dist = consistency_criterion(outputs[labeled_bs:], ema_output) # (batch, 2, 112,112,80)
            consistency_loss = consistency_weight * consistency_dist
            loss = supervised_loss + consistency_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            update_ema_variables(model, ema_model, args.ema_decay, iter_num)
            iter_num = iter_num + 1
            writer.add_scalar('lr', lr_, iter_num)
            writer.add_scalar('loss/loss', loss, iter_num)
            writer.add_scalar('loss/loss_seg', loss_seg, iter_num)
            writer.add_scalar('loss/loss_seg_dice', loss_seg_dice, iter_num)
            writer.add_scalar('train/consistency_weight', consistency_weight, iter_num)
            writer.add_scalar('train/consistency_dist', consistency_dist, iter_num)
            logging.info('iteration %d : loss : %f cons_dist: %f, loss_weight: %f' %
                         (iter_num, loss.item(), consistency_dist.item(), consistency_weight))
            if iter_num >= 800 and iter_num % 200 == 0:
                model.eval()
                with torch.no_grad():
                    dice_sample = 0
                    for sampled_batch in test_loader:
                        img, lbl = sampled_batch['image'].cuda(), sampled_batch['label'].cuda()
                        outputs = model(img)
                        dice_once = cal_dice(outputs,lbl)
                        dice_sample += dice_once
                    dice_sample = dice_sample / len(test_loader)
                    print('Average center dice:{:.3f}'.format(dice_sample))
                if dice_sample > best_dice:
                    best_dice = dice_sample
                    save_mode_path = os.path.join(snapshot_path, 'iter_{}_dice_{}.pth'.format(iter_num, best_dice))
                    save_best_path = os.path.join(snapshot_path, '{}_best_model.pth'.format(args.model))
                    torch.save(model.state_dict(), save_mode_path)
                    torch.save(model.state_dict(), save_best_path)
                    logging.info("save best model to {}".format(save_mode_path))
                writer.add_scalar('Var_dice/Dice', dice_sample, iter_num)
                writer.add_scalar('Var_dice/Best_dice', best_dice, iter_num)
                model.train()
            if iter_num >= max_iterations:
                break
            time1 = time.time()
        if iter_num >= max_iterations:
            break
    save_mode_path = os.path.join(snapshot_path, 'iter_' + str(max_iterations) + '.pth')
    torch.save(model.state_dict(), save_mode_path)
    logging.info("save model to {}".format(save_mode_path))
    writer.close()

Note that the Dice recorded during training is not accurate; the real metric requires sliding-window inference using inference.py.
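inference.py is not included in this post, but the idea behind sliding-window inference can be sketched: slide a patch-sized window across the volume with overlap, accumulate the per-voxel predictions, and normalize by how many windows visited each voxel. Everything below (function names, the identity "model") is illustrative:

```python
import numpy as np

def _starts(size, patch, stride):
    """Window start positions along one axis; last window sits flush with the border."""
    s = list(range(0, max(size - patch, 0) + 1, stride))
    if s[-1] != size - patch:
        s.append(size - patch)
    return s

def sliding_window_predict(volume, model, patch=(112, 112, 80), stride=(56, 56, 40)):
    """volume: (D, H, W); model maps a patch to same-shape per-voxel scores."""
    scores = np.zeros(volume.shape, dtype=np.float64)
    counts = np.zeros(volume.shape, dtype=np.float64)
    for z in _starts(volume.shape[0], patch[0], stride[0]):
        for y in _starts(volume.shape[1], patch[1], stride[1]):
            for x in _starts(volume.shape[2], patch[2], stride[2]):
                sl = (slice(z, z + patch[0]), slice(y, y + patch[1]), slice(x, x + patch[2]))
                scores[sl] += model(volume[sl])   # accumulate overlapping predictions
                counts[sl] += 1.0                 # track how often each voxel was seen
    return scores / counts                        # average where windows overlap

vol = np.random.rand(120, 120, 90)
out = sliding_window_predict(vol, lambda p: p)    # identity model recovers the input
```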

Experimental Results

Reconstructions of the segmentation results: blue is the ground-truth label, red is the model's prediction.


Whether judged by the evaluation metrics or the visualizations, with the same amount of labeled data, semi-supervised training is a clear improvement over supervised-only training.

References:

Tarvainen A, Valpola H. Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results[J]. Advances in neural information processing systems, 2017, 30.

Project:

LASeg: 2018 Left Atrium Segmentation (MRI)
