发布时间:2023-04-21 文章分类:电脑百科 投稿人:李佳 字号: 默认 | | 超大 打印

CycleGAN(基于PyTorch框架)

  • 0.论文简介
    • 0.1本文主要的工作
    • 0.2引言
    • 0.3方法
  • 1.代码结构
    • 1.1根目录中的文件
      • 1.1.1 train.py文件
      • 1.1.2 test.py文件
    • 1.2根目录中的文件夹
      • 1.2.1 docs文件夹
      • 1.2.2 .git文件夹
      • 1.2.3 data文件夹
        • 1.2.3.1 template_dataset.py
        • 1.2.3.2 __init__.py
        • 1.2.3.3 base_dataset.py
        • 1.2.3.4 image_folder.py
        • 1.2.3.5 aligned_dataset.py
        • 1.2.3.6 unaligned_dataset.py
        • 1.2.3.7 single_dataset.py
        • 1.2.3.8 colorization_dataset.py
      • 1.2.4 imgs文件夹
      • 1.2.5 models文件夹
        • 1.2.5.1 __init__.py
        • 1.2.5.2 base_model.py
        • 1.2.5.3 template_model.py
        • 1.2.5.4 network.py
        • 1.2.5.5 cycle_gan_model.py
        • 1.2.5.6 pix2pix_model.py
        • 1.2.5.7 colorazation_model.py
        • 1.2.5.8 test_model.py
      • 1.2.6 option文件夹
      • 1.2.7 scripts文件夹
      • 1.2.8 util文件夹
  • 2.复现过程
    • 2.1 准备过程
    • 2.2 训练过程
    • 2.3 测试过程

0.论文简介

CycleGAN是一款实现风格迁移的模型,其论文可以在各大平台找到。我们在aixiv上可以找到:https://arxiv.org/pdf/1703.10593.pdf。

我们复现的的代码是来自下面这个github仓库:https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix。

虽然看起来很简单,不过对于刚入门的新手来说难度还真不低,下面让我们来仔细看看代码的结构。

0.1本文主要的工作

  1. 在训练集缺失的情况下,将图片从某一种风格转移到另一种风格。希望学习到一种映射规则G,使得G(X)=Y。
  2. 希望找到G的可逆变换F使得F(G(X))=X。
  3. 本文进行了很多任务方面的尝试,并和之前的方法作对比。

0.2引言

  1. 整体步骤概述:捕获风格1的图像特征,在没有训练集提示的情况下将其转化为风格2的特征。
  2. 研究背景:获取不同风格的成对数据存在一定的困难。
  3. 具体实施:虽无法获取图像级别的监督(缺乏有标注的图像对),但可获取集合级别的监督(X和Y中各有一组图像,我们不知道X中的某张图对应Y中的哪张图,但我们可以知道X和Y这两个集合是彼此对应的)。经过训练,使得y ̂=G(X)与y无法区分,也就是使得y ̂和y的分布尽可能一致。
  4. 遇到的问题:第一是无法确定哪一个才是有意义的配对(可能会有很多组映射G),第二是独立地去优化对抗损失是一件很困难的事(配对程序会将所有输入的映射图像都转换成同一个输出图像)。
  5. 解决措施:添加循环一致性损失,把F(G(X))与x、G(F(Y))与y的损失也都加到对抗网络的损失里。
    【论文笔记】CycleGAN(基于PyTorch框架)

0.3方法

  1. 整体损失项:损失函数一共是4个部分,其中有两个是对抗损失,两个是循环一致性损失。四项分别是:(1)D_Y:用于衡量y ̂=G(x)与y的损失;(2)D_X:用于衡量x ̂=F(y)与x的损失;(3)F(G(x))与x之间的损失;(4)G(F(y))与y的损失。
  2. 对抗损失:基本公式如下。
    【论文笔记】CycleGAN(基于PyTorch框架)
    优化思想是:
    【论文笔记】CycleGAN(基于PyTorch框架)
  3. 循环一致性损失。基本公式如下:
    【论文笔记】CycleGAN(基于PyTorch框架)
  4. 总损失函数
    【论文笔记】CycleGAN(基于PyTorch框架)
    优化目标是
    【论文笔记】CycleGAN(基于PyTorch框架)

1.代码结构

1.1根目录中的文件

我们将目光聚集到根目录这个位置。
【论文笔记】CycleGAN(基于PyTorch框架)
这里面,我们先来看根目录下的文件:

  1. README.md就是说明书。
  2. requirements.txt是说明这样的一个仓库所需要的各种包的版本。
  3. .gitignore文件是那些上传的时候要忽略的东西(不是所有的数据都需要被commit到仓库里,有时候可能只需要交源码)。
  4. LICENSE文件是许可证文件,会告知我们有什么样的权限(比如这个项目的代码我们可以下载并在本地中修改,但是不能修改远程仓库的内容)。
  5. .replit文件提供了所使用的信息,便于在浏览器中运行代码,这样一来就无需在本地配置环境。这是在使用云编辑器repl.it的时候可能会用到的设置。
  6. environment.yml这个文件相当于Python+requirement.txt,我们可以直接使用conda env create -f environment.yml来创建一个environment.yml文件里指定的环境(里头什么样的包、环境名是什么、Python版本是多少)。当然,如果你现在就有这样的一个环境,想要导出这个环境所对应的environment.yml,只需要下面的命令就可以conda env export | grep -v "^prefix: " > environment.yml
  7. train.py是用于训练的主脚本,可以指定使用什么样的数据集和模型。我们使用–model选项可以指定使用什么样的模型(例如:pix2pix, cyclegan, colorization),通过–dataset_mode指定数据模式(例如:aligned, unaligned, single, colorization),通过–dataroot指定数据集路径,通过–name指定实验名称。这里可以列举一个命令供参考python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan
  8. test.py是用于测试的主脚本,通过–checkpoints_dir可以设置模型读取的路径,通过–results_dir可以设置结果的保存路径,通过–dataroot可以设置数据集的路径,通过–name可以设置任务名称,通过–model设置所采用的模型。对于CycleGAN双向检验,可以利用命令python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan来实现,其中的–model cycle_gan会将数据导入的模式变为双向。对于CycleGAN的单项检验,可以利用命令python test.py --dataroot datasets/horse2zebra/testA --name horse2zebra_pretrained --model test --no_dropout来实现。–no_dropout指的是不需要dropout;–model test指的是单向验证CycleGAN模型,这将使得–dataset_mode自动变成single,也就是导入单一集合的数据。
  9. CycleGAN.ipynb和pix2pix.ipynb里头是两个模型的运行教程(在jupyter notebook)。

1.1.1 train.py文件

21-25行不说了,都是导入一些基本的类。

第27行的意思是,如果这个脚本作为主脚本使用,那么就运行下方的东西。28行是先把TrainOption实例化成对象,然后用parse进行解析,这样形成一个结果,赋给opt,也就是说,opt解析出来的结果。29行是根据这个结果去创建数据集。30行获取数据集中样本的数量。31行不说了。

33行是创建模型。34行是根据opt创建合适的学习率调整策略、导入网络并打印。第35行是根据opt创建可视化实例。36行是训练迭代次数。

第38行是迭代过程的开始,opt.epoch_count是从哪个epoch开始,opt.n_epochs_decay是持续多少epoch。39-40行不说了,获取这个epoch开始的时间和本轮epoch导入数据的时间。41行是在本轮epoch当中的第几次迭代。42行是可视化机器的重置,保证在每个epoch里它至少有一次保存图片。43行是在每次epoch之前率先更新一下学习率。

第44行就是每个epoch内部的循环了,enumerate函数的作用是同时列出数据和下标,这个无需多说,注意这里的i是batch的编号,而data也不是一张图,而是一个batch的图。45行是本次iteration开始的时间。46-47行是说如果总迭代次数total_iters到了opt.print_freq的整倍数,就计算t_data,也就是本轮iteration开始的时刻到本轮epoch导入数据的时间已经过去了多久。49-50行指的是,一共多少个数据参与了训练以及本轮epoch里有多少数据参与了迭代。51行是把每一个数据解包,52行是参数优化,这些都是在model当中的basemodel.py当中定义的。

54-57行是数据可视化的部分,如果本轮epoch当中已经参与迭代的样本总数是opt.display_freq的整数倍,那么执行55-57行的操作。55行是返回一个叫save_result的布尔值,用于判定是否需要存出结果到html文件里。第56行是只有在着色任务中才有用,是展示图片的命令,其他的模型中compute_visuals函数只有一个命令,那就是pass。第57行则是存储到html文件里的命令,其中save_result就决定本行是否执行,可以参见util文件夹里的visualizer.py。

59-64行是打印的部分。如果本轮epoch当中已经参与迭代的样本总数是opt.print_freq的整数倍,那么执行60-64行的操作。第60行是获取当前的损失函数。第61行是计算每个图片所用的时间。第62行是输出当前的损失值,后面的参数含义大家可以点击util文件夹下的visuallizer.py去查看相应的函数。63-64行是损失值可视化的部分,如果window id of the web display这个值大于0,那么就利用plot_current_losses函数输出,参数的含义可以点击util文件夹下的visuallizer.py去查看相应的函数。

66-69行是保存权重文件的部分,如果本轮epoch当中已经参与迭代的样本总数是opt.save_latest_freq的整数倍,那么执行67-69行的操作。67行不说了,68行是设置保存后缀,69行是保存模型。

71行是重新获取时间。72-75行也是在保存模型,不过这次是在每个epoch结束的时候。

77行是输出,无需多言。

1.1.2 test.py文件

它将从’–checkpoints_dir’加载保存的模型,并将结果保存到’–results_dir’。它首先在给定opt选项的情况下创建模型和数据集。它将硬编码一些参数。然后,它对“–num_test”图像运行推断,并将结果保存到HTML文件中。

29-34行不用多说了,导入包。

36-39行是导入wandb包,可以帮我们记录超参数指标。

42-43行不说了,和上面的train.py有异曲同工之妙。45行,测试模式仅能使用单线程,至于哪一个线程,你可以去自己指定;46行,batch_size只能为1;47行是确定数据需不需要打乱;48行则是是否翻转;49行是放弃展示图片;50-52行不说了,上面的train.py里解释过是什么含义。

54-57行,没太看懂。不太熟悉wandb这个包的含义。

59-行,是在创建网站。60行是在确定地址。61-62行是要根据本轮迭代来确定网址的域名(整体)。63行不说了,就是打印一下结果。64行是确定网页的地址和标题(这块可能得在util文件夹里头找html.py)。68-69行是评估模式开启。

70-72行比较简单,不再重复。73-74行也比较容易理解,分别是解包数据、测试。75-78行可以参考注释,获取图像结果、获取图像路径,每隔5个图片打印一次。79行是保存图片到html中,参数的含义可以点击util文件夹下的visuallizer.py去查看相应的函数。最后80行,保存html。

1.2根目录中的文件夹

之后,我们再来看看各个文件夹都是怎么回事。

1.2.1 docs文件夹

docs文件夹不多说了,里头是各种说明文档。

1.2.2 .git文件夹

.git文件夹也不多说了,这是用于分布式版本管理的工具,具体什么是git请自行百度。在我【教程搬运】的专栏下也有专门介绍git的博文。

1.2.3 data文件夹

data文件夹,里头是各种和数据加载、处理的模块。里头的__init__.py是一个接口文件,basedataset.py是一个基础文件(包含一些常见的转换功能,有点相当于公用的“基类”,不知道怎么描述了),template_dataset.py这是一个模板,相当于示例文件。其它的都是具体的数据集对应的文件了。

1.2.3.1 template_dataset.py

首先让我们聚焦一个模板文件,也就是template_dataset.py,在这里我们仅仅给出一些说明,读完之后觉得抽象也没关系,我们后面还有例子(1.2.3.2节之后),慢慢体会,慢慢读就可以了。

这个文件主要起到一个模板的作用,是一个参考,具体说明如下:

好了,刚刚我们已经说明了这个模板函数的作用,下面让我们详细地说一下要实现的是个具体功能:

为了便于大家理解__init__函数,我列举了single_dataset.py这个脚本里的内容进行举例。

    def __init__(self, opt):
        """Initialize this dataset class.
        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        # 调用BaseDataset.__init__方法,将创建好的对象self和你在训练命令里的opt传入。
        BaseDataset.__init__(self, opt)
        # 用opt.dataroot解析出数据路径,opt.max_dataset_size解析出最大允许数据集大小。
        # make_dataset函数是用来制作数据集,返回值是一个图片组成的列表。
        # 最后使用sorted函数对图片进行一下排序。
        self.A_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size))
        # 这是对输入进行处理的部分。
        input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
        self.transform = get_transform(opt, grayscale=(input_nc == 1))

读完上面我写的,你可能一头雾水,没事,我们马上就利用示例来分析。

1.2.3.2 init.py

这个脚本主要是提供接口。里头分成两部分,第一部分有三个函数,第二部分是一个类,里头也有几个函数。

让我们先进入第一部分。

来让我们进入第二部分,CustomDatasetDataLoader这个类。

总而言之,这个脚本就是围绕着create_dataset这样一个函数展开的,目的就是根据opt制作数据集,只是一个接口而已。

1.2.3.3 base_dataset.py

这个文件实现了一些基础的数据读取、转换功能。在BaseDataset类中,所有的函数基本都留空,这个是要根据具体的数据集来确定的,所以base_dataset.py里的这块是空的,可以在template_dataset.py查到每个函数对应的用法示例,而实际的应用就是single_dataset.py、colorizationo_dataset.py等文件里的用法。

我们来看看在BaseDataset类之后有什么函数。

1.2.3.4 image_folder.py

这是一个和图片读取有关的脚本。1-16行不说了,很容易读懂,我们从后面说起。

1.2.3.5 aligned_dataset.py

先插一句,为什么要准备对齐的数据,这是因为pix2pix模型要用这个。

aligned_dataset.py包含一个可以加载图像对的数据集类。它设置好了一个图像目录/path/to/data/train,其中包含 {A,B} 形式的图像对。在测试期间,您需要准备一个目录/path/to/data/test作为测试数据。那么如何准备对齐的数据集呢,方法在这里。您也可以参阅/pytorch-CycleGAN-and-pix2pix/datasets/combine_A_and_B.py这个脚本,它就是我们在准备对齐数据的时候需要执行的脚本。我们在这里权且先不提,后面还会再说。

让我们来逐个分析在Aligned_dataset类里面的3个函数。

1.2.3.6 unaligned_dataset.py

unaligned_dataset.py包含一个可以加载未对齐/未配对数据集的数据集类。我们可以使用数据集标志训练模型--dataroot /path/to/data

依然是同样的三个函数,让我们来一一解读。

1.2.3.7 single_dataset.py

single_dataset.py包含一个数据集类,可以加载由path指定的一组单个图像–dataroot /path/to/data。它只能用于使用模型选项为一侧生成CycleGAN结果-model test。

里面的三个函数完全就是前面1.2.3.6的翻版,我就不再一一赘述了。

1.2.3.8 colorization_dataset.py

colorization_dataset.py实现了一个数据集类,可以加载一组 RGB 的自然图像,并将 RGB 格式转换为Lab颜色空间中的 (L, ab) 对。基于 pix2pix 的着色模型 ( --model colorization) 需要它。

1.2.4 imgs文件夹

这里是两个示例图片,也可以被用来存放效果图。

1.2.5 models文件夹

模型目录包含与目标函数、优化和网络架构相关的模块。如果要添加一个名为的自定义模型类dummy,那么必须要添加一个名为的文件dummy_model.py并定义一个DummyModel类,这个类继承父类BaseModel。您需要实现四个功能:__init__初始化类(您需要先调用BaseModel.init(self, opt))、set_input(从数据集中解包数据并应用预处理)、forward(生成中间结果)、optimize_parameters(计算损失、梯度和更新网络权重),以及可选的modify_commandline_options(添加特定于模型的选项并设置默认选项)。现在您可以通过指定 flag 来使用模型类–model dummy。有关示例,请参见我们的模板模型类。下面我们详细解释每个文件。

1.2.5.1 init.py

_init_.py 实现了这个包与训练和测试脚本之间的接口。 train.py并在给定选项的情况下test.py调用from models import create_modelandmodel = create_model(opt)创建模型opt。您还需要调用model.setup(opt)以正确初始化模型。

1.2.5.2 base_model.py

base_model.py为模型实现了一个抽象基类 ( ABC )。它还包括常用的辅助函数(例如 , setup, test, update_learning_rate, ) save_networks,load_networks以后可以在子类中使用。

我们来看看BaseModel这个类里包含了什么东西。

1.2.5.3 template_model.py

这个脚本主要是用来做模板之用。该模块为用户提供了一个模板来实现自定义模型,可以指定“–model template”来使用此模型,类名应与文件名及其模型选项一致。文件名应该是_dataset.py,类名应该是Dataset.py。它实现了一个简单的基于回归损失的图像到图像的转换baseline。给定输入输出对(data_A,data_B),它学习可以最小化以下L1损失的网络netG,使得:

min_<netG> ||netG(data_A) - data_B||_1

1.2.5.4 network.py

  1. 第1-5行是导入一些基础的包。前两行比较简单,第三行的作用是让我们可以使用torch.nn.init进行初始化参数,第四行这个包是一个针对函数进行操作的函数,第五行是导入优化器。

  2. 下面介绍Identity这个类。

  1. 下面介绍get_norm_layer这个函数。
  1. 下面介绍get_scheduler这个函数。
  1. 下面介绍init_weights这个函数。
  1. 下面介绍init_net函数。
  1. 下面介绍define_G函数。
  1. 下面介绍define_D函数。
  1. GANLoss类

这个类是用来创建不同的GAN对象的,按照注释内所说的,他把“创建与输入大小相同的目标标签张量”这一任务抽象化了。换言之这是用来创建标签用的。

  1. cal_gradient_penalty函数。
  1. 类ResnetGenerator:
    这个类是用来定义残差结构的。
  1. 类ResnetBlock:
    这是用来定义残差块结构的。
  1. 类UnetGenerator:
    这个是用来定义UNet生成器结构的。
  1. 类UnetSkipConnectionBlock:
    这个类用于定义UNet里面的小结构。
  1. 类NLayerDiscriminator
    这个类是用来创建判别器的。
  1. 类PixelDiscriminator:

至此,network.py长达616行的脚本解读完毕。

1.2.5.5 cycle_gan_model.py

第2行的itertool是一个Python提供的迭代工具箱,第3行image_pool.py实现了一个存储先前生成的图像的图像缓冲区。这个缓冲区使我们能够使用生成图像的历史而不是最新生成器生成的图像来更新鉴别器。后面就是CycleGANModel这个类,第12-15行有非常明确的概述,dataset模式要使用“非对齐、未配对”,它会使用带有9个残差块的生成器网络结构,并使用PatchGAN这样的判别器结构,以及一个最小方根的GANs对象(就是说损失函数使用LSGAN,平方损失),我们来逐个分析其中的函数。

1.2.5.6 pix2pix_model.py

该任务和我们的CycleGAN任务暂时无关,是pix2pix里的任务,我们暂时不谈。

1.2.5.7 colorazation_model.py

该任务和我们的CycleGAN任务暂时无关,是pix2pix里的任务,我们暂时不谈。

1.2.5.8 test_model.py

将于后续补充。

1.2.6 option文件夹

这个文件夹里的4个文件都是用来规定训练命令里的选项之用。init.py是一个接口类型的文件,没有什么用。base_option.py是一些基础性的选项,它还实现了一些辅助功能,例如解析、打印和保存选项。它还收集modify_commandline_options数据集类和模型类的函数中定义的附加选项;而剩下两个文件则分别对应着训练时、测试时的一些选项,原脚本已经把里面的参数解释的明明白白了。

这一部分不是重点,建议大家不要过分拘泥在这里。

1.2.7 scripts文件夹

这里存放了一些.sh脚本,关于shell的教程可以参见这里。

这一部分不是重点,建议大家不要过分拘泥在这里。

1.2.8 util文件夹

这个文件夹内是一些辅助功能。

这一部分不是重点,建议大家不要过分拘泥在这里。

2.复现过程

2.1 准备过程

  1. 先要克隆远程仓库git clone https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix 并将当前目录调整到项目根目录下:cd pytorch-CycleGAN-and-pix2pix
  2. 安装所需要的包pip install -r requirements.txt
  3. 下载数据,这里要用到的命令是bash ./datasets/download_cyclegan_dataset.sh maps,当然你可以将maps换成其它数据集的名字,也可以自己制作数据集。下载之后的数据集被自动存储在了datasets文件夹下面相应的文件夹里,例如:map。然后,这个文件夹中有train、test、val、trainA、trainB、testA、testB、valA、valB九个文件夹,A、B分别象征着风格A和风格B,而不带AB的则是两个风格合并到一起的图片。

2.2 训练过程

  1. 下面转入训练部分,你可以用python脚本训练,也可以用.sh脚本训练,原github里提供了两个训练命令。这里可以看看一个训练博客。这里我们选用的命令是python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan

训练结果是这样的:
【论文笔记】CycleGAN(基于PyTorch框架)
里面的含义是这样的:第几个epoch、第几次迭代、总用时、训练一个data的用时,后面就是各种损失。

训练后的结果(包括模型和示例图片)都保存在根目录下一个叫checkpoint的文件夹里。大家可以查看,image文件夹下是图片,其它.pth文件都是模型权重文件。如下图所示:
【论文笔记】CycleGAN(基于PyTorch框架)

2.3 测试过程

  1. 测试的命令是python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan。测试结束后,结果可以在下面标蓝的文件夹里找到。
    【论文笔记】CycleGAN(基于PyTorch框架)
    参考:https://blog.csdn.net/Joe9800/article/details/103224383