×

PyTorch教程20.2之深度卷积生成对抗网络

消耗积分:2 | 格式:pdf | 大小:0.29 MB | 2023-06-05

fansz

分享资料个

20.1 节中,我们介绍了 GAN 工作原理背后的基本思想。我们展示了他们可以从一些简单的、易于采样的分布中抽取样本,比如均匀分布或正态分布,并将它们转换成看起来与某些数据集的分布相匹配的样本。虽然我们匹配 2D 高斯分布的示例说明了要点,但它并不是特别令人兴奋。

在本节中,我们将演示如何使用 GAN 生成逼真的图像。我们的模型将基于 Radford等人介绍的深度卷积 GAN (DCGAN)。2015 年我们将借用已经证明在判别计算机视觉问题上非常成功的卷积架构,并展示如何通过 GAN 来利用它们来生成逼真的图像。

import warnings
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
from mxnet import gluon, init, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l

npx.set_np()
import tensorflow as tf
from d2l import tensorflow as d2l

20.2.1。口袋妖怪数据集

我们将使用的数据集是从pokemondb获得的 Pokemon 精灵的集合 首先下载、提取和加载此数据集。

#@save
d2l.DATA_HUB['pokemon'] = (d2l.DATA_URL + 'pokemon.zip',
              'c065c0e2593b8b161a2d7873e42418bf6a21106c')

data_dir = d2l.download_extract('pokemon')
pokemon = torchvision.datasets.ImageFolder(data_dir)
Downloading ../data/pokemon.zip from http://d2l-data.s3-accelerate.amazonaws.com/pokemon.zip...
#@save
d2l.DATA_HUB['pokemon'] = (d2l.DATA_URL + 'pokemon.zip',
              'c065c0e2593b8b161a2d7873e42418bf6a21106c')

data_dir = d2l.download_extract('pokemon')
pokemon = gluon.data.vision.datasets.ImageFolderDataset(data_dir)
Downloading ../data/pokemon.zip from http://d2l-data.s3-accelerate.amazonaws.com/pokemon.zip...
#@save
d2l.DATA_HUB['pokemon'] = (d2l.DATA_URL + 'pokemon.zip',
              'c065c0e2593b8b161a2d7873e42418bf6a21106c')

data_dir = d2l.download_extract('pokemon')
batch_size = 256
pokemon = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir, batch_size=batch_size, image_size=(64, 64))
Downloading ../data/pokemon.zip from http://d2l-data.s3-accelerate.amazonaws.com/pokemon.zip...
Found 40597 files belonging to 721 classes.

我们将每个图像调整为64×64. 变换ToTensor 会将像素值投影到[0,1],而我们的生成器将使用 tanh 函数获取输出 [−1,1]. 因此我们用0.5意味着和0.5标准偏差以匹配值范围。

batch_size = 256
transformer = torchvision.transforms.Compose([
  torchvision.transforms.Resize((64, 64)),
  torchvision.transforms.ToTensor(),
  torchvision.transforms.Normalize(0.5, 0.5)
])
pokemon.transform = transformer
data_iter = torch.utils.data.DataLoader(
  pokemon, batch_size=batch_size,
  shuffle=True, num_workers=d2l.get_dataloader_workers())
batch_size = 256
transformer = gluon.data.vision.transforms.Compose([
  gluon.data.vision.transforms.Resize(64),
  gluon.data.vision.transforms.ToTensor(),
  gluon.data.vision.transforms.Normalize(0.5, 0.5)
])
data_iter = gluon.data.DataLoader(
  pokemon.transform_first(transformer), batch_size=batch_size,
  shuffle=True, num_workers=d2l.get_dataloader_workers())
def transform_func(X):
  X = X / 255.
  X = (X - 0.5) / (0.5)
  return X

# For TF>=2.4 use `num_parallel_calls = tf.data.AUTOTUNE`
data_iter = pokemon.map(lambda x, y: (transform_func(x), y),
            num_parallel_calls=tf.data.experimental.AUTOTUNE)
data_iter = data_iter.cache().shuffle(buffer_size=1000).prefetch(
  buffer_size=tf.data.experimental.AUTOTUNE)
WARNING:tensorflow:From /home/d2l-worker/miniconda3/envs/d2l-en-release-1/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089

让我们想象一下前 20 张图像。

warnings.filterwarnings('ignore')
d2l.set_figsize((4, 4))
for X, y in data_iter:
  imgs = X[:20,:,:,:].permute(0, 2, 3, 1)/2+0.5
  d2l.show_images(imgs, num_rows=4, num_cols=5)
  break
 
d2l.set_figsize((4, 4))
for X, y in data_iter:
  imgs = X[:20,:,:,:].transpose(0, 2, 3, 1)/2+0.5
  d2l.show_images(imgs, num_rows=4, num_cols=5)
  break
 
d2l.set_figsize(figsize=(4, 4))
for X, y in data_iter.take(1):
  imgs = X[:20, :, :, :] / 2 + 0.5
  d2l.show_images(imgs, num_rows=4, num_cols=5)
 

20.2.2。发电机

生成器需要映射噪声变量 z∈Rd, 长度-d向量,到具有宽度和高度的 RGB 图像64×64. 14.11 节中我们介绍了全卷积网络,它使用转置卷积层(参考 14.10 节


声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉

评论(0)
发评论

下载排行榜

全部0条评论

快来发表一下你的评论吧 !