返回

PyTorch MNIST生成噪声?告别踩坑用GAN生成数字

Ai

生成模型只输出噪声?MNIST数字生成踩坑指南

写了个生成网络,想让它学着 MNIST 数据集画数字,结果不管怎么调,输出的都是一堆乱七八糟的噪声,根本看不出数字的形状。试了换损失函数、换优化器,都没啥用。感觉是对这类网络的工作原理理解不到位,但具体错在哪儿了呢?

看看下面这段出问题的代码,是不是你也觉得眼熟?

# (代码片段如问题中所示,为简洁起见此处省略,
# 主要包含Net定义、Data/DataSet函数、训练循环)
import numpy as np
import torch.utils
import torch.utils.data
import torch.utils.data.dataloader
from tqdm import tqdm
import torch
import torch.nn as nn #OOP
import torch.nn.functional as F #Functions (not oop based)
import torch.optim as optim
import matplotlib.pyplot as plt
from matplotlib import style
import torchvision
from torchvision import transforms, datasets

device = torch.device("cuda:0") #GPU

class Net(nn.Module):
    def __init__(self):
        super().__init__() #Super = nn.Module, inherit the methods and modules from __init__
        self.fc1 = nn.Linear(10, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 128)
        self.fc4 = nn.Linear(128, 28*28)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        # 问题点:使用了 Softmax
        return (F.softmax(x, dim=1))

net = Net().to(device)

#Data
# 1 hot vector, index represents the number i want it to generate
# Generate a batch of 10 images
# Pass these images through the Number_Recogniser
# Compare the outputs using mean square loss
# With the loss update the gradients
def Data():
    # 生成随机数字标签 (0-9)
    return(np.random.randint(0, 10, size=64))

#Data sets
def DataSet(Data):
    One_Hot = []
    for data in Data:
        Ones = [0] * 10
        Ones[data] = 1
        # 问题点:每次只处理一个样本,效率低
        Ones_tensor = torch.Tensor(Ones).view(1, 10) # 改为 (1, 10) 形状
        Ones_tensor = Ones_tensor.to(device) # 确保移到 GPU
        One_Hot.append(Ones_tensor)
    return(One_Hot) # 返回 Tensor 列表

EPOCHS = 3000 # 300 for 30000 images

# 加载预训练的数字识别器
try:
    Number_Recogniser = torch.jit.load("model.pth")
    Number_Recogniser.eval()
    Number_Recogniser.to(device) # 确保识别器也在 GPU 上
except FileNotFoundError:
    print("错误:找不到 'model.pth'。请确保预训练模型存在。")
    exit() # 如果模型不存在,无法继续

# 问题点:SGD 可能不是最优选择
optimizer = torch.optim.SGD(net.parameters(), lr=1e-3)

# 问题点:损失函数和训练方式可能不适合生成任务
loss_fn = nn.CrossEntropyLoss()

print("开始训练...")
for epoch in tqdm(range(EPOCHS)):
    rand_labels = Data()
    dataset = DataSet(rand_labels) # 获取 one-hot 向量列表

    batch_loss = 0.0
    num_items = 0

    # 问题点:一次处理一个样本,效率低下且不稳定
    for data in dataset: # data 现在是 shape 为 [1, 10] 的 Tensor
        y = data # y shape is [1, 10]

        # 生成器网络
        net.zero_grad() # 清除旧梯度
        X = net(y) # 生成图像数据, X shape [1, 784]
        X_img = X.view(-1, 1, 28, 28) # 调整为识别器期望的形状 [batch, channels, H, W]

        # 使用识别器评估生成图像
        # 确保 X_img 和 Number_Recogniser 在同一设备
        # 注意:识别器可能期望不同范围的输入(例如 [0, 1] 或 [-1, 1])
        # 而 Net 的输出经过 softmax 后是 [0, 1] 分布,但可能与识别器训练时的分布不匹配
        try:
            with torch.no_grad(): # 通常评估识别器时不需要计算梯度
                pred = Number_Recogniser(X_img) # pred shape [1, 10]
        except Exception as e:
             print(f"\n识别器前向传播出错: {e}")
             print(f"输入形状: {X_img.shape}, 输入设备: {X_img.device}")
             # 可能需要检查 Number_Recogniser 的期望输入
             continue # 跳过这个样本

        # 计算损失
        # CrossEntropyLoss 期望的 pred 是 [batch_size, num_classes]
        # 期望的 target 是 [batch_size] 且包含类别索引 (0-9),或者 [batch_size, num_classes] 的 one-hot (但需要转换)
        # 当前 y 是 one-hot [1, 10], 需要转换为类别索引
        target_label = torch.argmax(y, dim=1) # 从 one-hot 获取类别索引, shape [1]

        loss = loss_fn(pred, target_label)

        # 反向传播和优化仅针对生成器 net
        loss.backward()
        optimizer.step()
        # 移到循环开始处 net.zero_grad() 更常见

        batch_loss += loss.item()
        num_items += 1

    if num_items > 0 and epoch % 100 == 0: # 每 100 个 epoch 打印一次损失
        print(f"\nEpoch {epoch}, Average Loss: {batch_loss / num_items:.4f}")


# 训练后生成并显示一个样本
print("训练完成,生成图像...")
net.eval() # 设置为评估模式
with torch.no_grad():
    # 创建一个测试用的 one-hot 向量,比如生成数字 '7'
    test_label = torch.zeros(1, 10, device=device)
    test_label[0, 7] = 1

    generated_output = net(test_label)
    image = generated_output.view(1, 28, 28) # Reshape to [batch, H, W]
    image = image.squeeze().cpu().numpy() # 移除 batch 维, 移到 CPU, 转为 numpy

# 可视化
plt.imshow(image, cmap="gray")
plt.title("Generated Image Sample")
plt.show()

别急,这个问题挺常见的。咱们来捋一捋,看看问题到底出在哪儿,以及怎么把它“掰”回来。

为啥我的模型“画”不好?原因分析

生成图像,特别是像 MNIST 这种有结构的数据,和做分类任务不太一样。直接用简单的全连接网络(就是代码里的 Net)从一个类别标签(one-hot 向量)硬生生“算出” 784 个像素点,难度相当大。这相当于让网络凭空想象出数字的笔画结构、空间关系,只给它一个类别提示。这种结构的网络天生缺乏捕捉图像空间信息的能力。

几个关键的可能“病灶”:

  1. 网络结构不给力 : 全连接层处理图像这种二维结构数据时,丢失了像素间的空间关联性。它把 28x28 的图像看成了一个扁平的 784 维向量,哪个像素挨着哪个像素,它并不关心。这使得学习生成连贯的形状变得异常困难。
  2. 输出层激活函数用错了 : 代码里最后一层用了 softmaxSoftmax 通常用在分类任务的输出层,它会把所有输出值归一化,并且让它们的和等于 1。这对于表示概率分布很合适,但对于图像像素值来说,完全不对路!图像的像素值通常是 0 到 255 的整数,或者归一化到 [0, 1] 或 [-1, 1] 的浮点数。Softmax 会强制让所有 784 个像素值的和为 1,结果就是每个像素值都非常小,看起来就像一片噪声或者灰蒙蒙的啥也不是。
  3. 训练策略“非主流” : 使用一个预训练好的、固定的 Number_Recogniser 来计算损失,想法是让生成的图片能被识别器认出来。这种方法有时被称为“感知损失”或特征匹配。理论上可行,但实践中可能很困难。首先,识别器的梯度可能不够“友好”,难以指导生成器学习。其次,识别器关注的是“能不能认出来”,不一定能保证生成图像的“好看”或“像真的一样”。它可能满足于生成一些刚好能触发分类神经元的奇怪图案。
  4. 数据处理效率低 : 代码里的 DataSet 函数一次只处理一个标签,把它变成 one-hot 向量。然后在训练循环里,也是一个一个样本地喂给网络、计算损失、更新梯度。这没法利用 GPU 的并行计算优势,训练速度慢得像蜗牛。更重要的是,每次只用一个样本计算梯度,更新方向会非常不稳定,模型很难收敛到好的状态,容易在原地“抖动”,表现出来就是输出没什么改进。
  5. 优化器可能拖后腿SGD (随机梯度下降) 是个基础的优化器。虽然也能用,但在处理复杂任务,尤其是像生成模型这种可能有复杂损失曲面的情况时,它的收敛速度可能较慢,也更容易卡在局部最优解。

对症下药:搞定噪声,生成图像

知道了问题在哪,咱们就能“对症下药”了。下面提供几个方向,你可以根据情况尝试或者组合使用。

方案一:换个更合适的模型架构(试试 GAN?)

生成图像这活儿,现在最常用的“工具”之一是 生成对抗网络 (Generative Adversarial Networks, GANs) 。简单说,GANs 有两个主要部分:

  • 生成器 (Generator) : 它的任务就是“造假”,接收一个随机噪声向量(通常是从正态分布或均匀分布采样),或者像你这里一样,可以加入类别信息作为条件,然后输出一张图像。目标是让输出的图像越来越逼真,能骗过判别器。你的 Net 可以改造一下作为生成器。
  • 判别器 (Discriminator) : 它的任务是“打假”,接收一张图像(可能是真实图像,也可能是生成器造的假图像),然后判断这张图像是真实的还是假的。它本身通常是个标准的图像分类网络。

它俩互相“博弈”:生成器努力提高造假水平,判别器努力提高打假能力。最终达到一个平衡时,生成器就能造出非常逼真的图像了。

怎么改?

  1. 改造你的 Net 作为生成器 (Generator) :

    • 输入 :通常 GAN 的生成器输入是一个随机噪声向量 z(比如 100 维)。如果你想生成特定数字,可以把 one-hot 标签 yz 连接起来,或者用更高级的条件 GAN 技术(比如 Embedding 层、Conditional BatchNorm)。单纯只用 10 维的 one-hot 标签作为输入,信息量可能太少了,网络很难“无中生有”。
    • 结构 :与其用全连接层直接怼出 784 个像素,不如用 转置卷积 (Transposed Convolution / Deconvolution) 。这玩意儿可以把低维特征图逐步放大,同时学习空间结构,更适合生成图像。
    import torch.nn as nn
    import torch.nn.functional as F
    
    class Generator(nn.Module):
        def __init__(self, noise_dim=100, label_dim=10, img_channels=1):
            super().__init__()
            self.label_emb = nn.Embedding(label_dim, label_dim) # 把标签转成嵌入向量
            self.init_size = 28 // 4 # 初始尺寸,取决于后续上采样层数
    
            # 将噪声和标签嵌入向量组合后的维度
            combined_dim = noise_dim + label_dim
    
            self.l1 = nn.Sequential(
                nn.Linear(combined_dim, 128 * self.init_size ** 2)
            )
    
            self.conv_blocks = nn.Sequential(
                nn.BatchNorm2d(128),
                nn.Upsample(scale_factor=2), # 放大一倍
                nn.Conv2d(128, 128, 3, stride=1, padding=1),
                nn.BatchNorm2d(128, 0.8),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Upsample(scale_factor=2), # 再放大一倍
                nn.Conv2d(128, 64, 3, stride=1, padding=1),
                nn.BatchNorm2d(64, 0.8),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(64, img_channels, 3, stride=1, padding=1),
                # 输出层激活函数用 tanh,对应像素值范围 [-1, 1]
                # 如果想用 [0, 1],这里换成 nn.Sigmoid()
                nn.Tanh()
            )
    
        def forward(self, noise, labels):
            # noise: [batch_size, noise_dim]
            # labels: [batch_size] (数字 0-9)
            label_embedding = self.label_emb(labels) # [batch_size, label_dim]
            gen_input = torch.cat((label_embedding, noise), -1) # [batch_size, combined_dim]
    
            out = self.l1(gen_input)
            out = out.view(out.shape[0], 128, self.init_size, self.init_size)
            img = self.conv_blocks(out)
            return img # 输出 [batch_size, img_channels, 28, 28]
    
  2. 创建一个判别器 (Discriminator) :

    • 它就是一个普通的卷积神经网络,用来做二分类(判断图像真假)。输入是 28x28 的图像,输出是一个概率值。
    class Discriminator(nn.Module):
        def __init__(self, img_channels=1):
            super().__init__()
    
            def discriminator_block(in_filters, out_filters, bn=True):
                block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1),
                         nn.LeakyReLU(0.2, inplace=True),
                         nn.Dropout2d(0.25)]
                if bn:
                    block.append(nn.BatchNorm2d(out_filters, 0.8))
                return block
    
            self.model = nn.Sequential(
                *discriminator_block(img_channels, 16, bn=False),
                *discriminator_block(16, 32),
                *discriminator_block(32, 64),
                *discriminator_block(64, 128),
            )
    
            # 计算经过卷积后的特征图尺寸
            ds_size = 28 // (2**4) # 28x28经过4次 stride=2 的卷积
            self.adv_layer = nn.Sequential(
                nn.Linear(128 * ds_size ** 2, 1),
                # 输出一个 logit,后续可以用 Sigmoid 转为概率
                # 但通常配合 BCEWithLogitsLoss,这里就不加 Sigmoid
            )
    
        def forward(self, img):
            # img: [batch_size, img_channels, 28, 28]
            out = self.model(img)
            out = out.view(out.shape[0], -1) # 展平
            validity = self.adv_layer(out) # [batch_size, 1]
            return validity
    

进阶使用技巧 :条件 GAN (Conditional GAN, cGAN) 有很多变种,比如可以将标签信息通过 Embedding 后与特征图在通道维度拼接,或者使用 Conditional Batch Normalization 等方式,更有效地利用标签信息指导生成过程。

方案二:修正输出层激活函数

这是个必须改的点!把 forward 函数最后那行的 F.softmax(x, dim=1) 换掉。

  • 如果你希望生成的图像像素值在 [0, 1] 范围内(对应归一化后的 MNIST 数据),用 torch.sigmoid()

    # 在 Net 的 forward 方法里
    # ... 前面的层 ...
    x = self.fc4(x)
    # return F.softmax(x, dim=1) # 注释掉或者删掉这行
    return torch.sigmoid(x) # 输出像素值范围 [0, 1]
    

    或者作为模块使用 nn.Sigmoid()

  • 如果你希望生成的图像像素值在 [-1, 1] 范围内(另一种常见的归一化方式),用 torch.tanh()

    # 在 Net 的 forward 方法里
    # ... 前面的层 ...
    x = self.fc4(x)
    # return F.softmax(x, dim=1) # 注释掉或者删掉这行
    return torch.tanh(x) # 输出像素值范围 [-1, 1]
    

    或者作为模块使用 nn.Tanh()

重要提醒 :选择哪种激活函数,要和你后续使用的 判别器 (无论是预训练的识别器还是 GAN 的判别器) 所期望的输入范围一致 。如果判别器是在 [0, 1] 范围的 MNIST 图像上训练的,那生成器最好也输出 [0, 1]。同理对应 [-1, 1]。训练 GAN 时,真实图像也要做同样的归一化处理。

方案三:优化训练流程与损失函数

如果你决定尝试 GAN,训练流程也要跟着变。不再是用固定的识别器算损失,而是要交替训练生成器和判别器。

  1. 损失函数 :常用的是 二元交叉熵损失 (Binary Cross Entropy Loss) 。PyTorch 提供了 nn.BCEWithLogitsLoss,这个函数更稳定,因为它结合了 Sigmoid 和 BCE Loss,能处理原始的 logit 输出。

    adversarial_loss = nn.BCEWithLogitsLoss()
    
  2. 训练循环 :大致思路是这样(这里是伪代码,具体实现需要 PyTorch Tensor 操作):

    # 假设 optimizer_G 是生成器的优化器, optimizer_D 是判别器的优化器
    # generator 和 discriminator 是实例化的 Generator 和 Discriminator 类
    
    for epoch in range(EPOCHS):
        for i, (real_imgs, labels) in enumerate(dataloader): # 从 DataLoader 获取真实图像和标签
    
            batch_size = real_imgs.size(0)
            real_imgs = real_imgs.to(device) # 真实图像
            labels = labels.to(device) # 对应的标签
    
            # 创建真假标签 (用于判别器损失)
            valid = torch.ones(batch_size, 1, device=device, dtype=torch.float)
            fake = torch.zeros(batch_size, 1, device=device, dtype=torch.float)
    
            # --- 训练判别器 ---
            optimizer_D.zero_grad()
    
            # 判别器对真实图像的损失
            real_pred = discriminator(real_imgs)
            d_real_loss = adversarial_loss(real_pred, valid)
    
            # 生成假图像
            z = torch.randn(batch_size, noise_dim, device=device) # 随机噪声
            gen_labels = torch.randint(0, 10, (batch_size,), device=device) # 随机生成标签
            gen_imgs = generator(z, gen_labels) # 生成假图像
    
            # 判别器对假图像的损失 (用 .detach() 阻止梯度流向生成器)
            fake_pred = discriminator(gen_imgs.detach())
            d_fake_loss = adversarial_loss(fake_pred, fake)
    
            # 总判别器损失并反向传播
            d_loss = (d_real_loss + d_fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()
    
            # --- 训练生成器 ---
            optimizer_G.zero_grad()
    
            # 生成假图像 (重新生成或使用刚才的 gen_imgs)
            # 如果想用条件,这里的 gen_labels 需要和你希望生成器生成的数字对应
            # z = torch.randn(batch_size, noise_dim, device=device)
            # gen_labels = torch.randint(0, 10, (batch_size,), device=device)
            # gen_imgs = generator(z, gen_labels)
    
            # 生成器希望判别器将其误判为真
            fake_pred_for_G = discriminator(gen_imgs)
            g_loss = adversarial_loss(fake_pred_for_G, valid) # 注意这里目标是 'valid'
    
            # 生成器损失反向传播
            g_loss.backward()
            optimizer_G.step()
    
        # --- 打印损失,保存模型,生成样本图像等 ---
    

安全建议 :虽然 MNIST 生成相对无害,但对于更复杂的生成模型(比如人脸、文本),要意识到可能的滥用风险。生成的内容易于被误认为是真实的,需要负责任地使用。

方案四:高效数据处理:拥抱批处理

别再一个样本一个样本地处理了!用 PyTorch 的 DataLoader 来组织数据,实现批处理 (Batch Processing)。

  1. 准备数据集 :你需要一个包含真实 MNIST 图像和对应标签的数据集。torchvision.datasets.MNIST 可以直接用。

    from torchvision import transforms, datasets
    from torch.utils.data import DataLoader
    
    # 定义数据预处理:转换为 Tensor,并归一化到 [-1, 1] 或 [0, 1]
    # 这里以 [-1, 1] 为例,对应上面生成器用 tanh
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)) # Mean=0.5, Std=0.5 -> [-1, 1]
        # 如果用 sigmoid,归一化到 [0, 1] 保持默认即可,或者用 transforms.Normalize((0.1307,), (0.3081,)) MNIST 标准值
    ])
    
    # 下载/加载 MNIST 数据集
    mnist_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    
    # 创建 DataLoader
    batch_size = 64 # 设置你想要的批大小
    dataloader = DataLoader(mnist_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    
  2. 修改训练循环 :你的训练循环就需要改成从 dataloader 里迭代获取 (real_imgs, labels) 批次数据,就像上面 GAN 训练伪代码里那样。

好处

  • 快! GPU 并行计算能力被充分利用。
  • 稳! 基于一批样本计算的梯度更加稳定,有助于模型收敛。

方案五:选用更强大的优化器

torch.optim.SGD 换成 torch.optim.Adam 或者 torch.optim.AdamW。它们通常能更快、更好地找到解决方案,特别是在 GAN 训练中。

# Adam 优化器,lr 是学习率,betas 是 Adam 的超参数,通常用默认值
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

学习率(lr) 很关键,需要调整。0.0002 是 GAN 训练中一个常见的起始值,但不一定是最优的。betas=(0.5, 0.999) 也是 GAN 论文中常用的设置。

方案六(进阶技巧):改进输入 - 随机噪声 + 条件

正如方案一提到的,只给一个 10 维的 one-hot 标签,让网络生成图像,信息量太少了。标准的 GAN 通常从一个更高维度的 随机噪声向量 z 开始(比如 100 维,从标准正态分布采样 torch.randn)。

如果你想控制生成的数字类别(条件生成),可以这样做:

  1. 标签嵌入 (Label Embedding) :把数字标签 (0-9) 通过一个 nn.Embedding 层转换成向量。
  2. 拼接 (Concatenation) :把噪声向量 z 和标签嵌入向量在某个维度上拼接起来,形成最终的输入向量,再喂给生成器的后续层。

看看上面方案一 Generator 示例代码里的 forward 方法,就是用了这种方式。这给了生成器更多可以“发挥”的基础,同时也受到了类别标签的引导。

一些额外的建议

  • 监控损失 :训练 GAN 时,要同时观察生成器损失 (G Loss) 和判别器损失 (D Loss)。理想情况下,它们会达到某种平衡。如果 D Loss 很快降到 0,说明判别器太强了,生成器学不到东西;如果 G Loss 降得很快但 D Loss 很高,说明生成器随便生成点啥都能骗过判别器。这俩的平衡是个技术活,可能需要调整学习率、网络结构等。
  • 可视化 :训练过程中,定时从固定的噪声向量和标签生成一些样本图像看看。光看损失数字不够直观,亲眼看到图像从噪声逐渐变得清晰(或者没有变化),能帮你判断训练进展。
  • 超参数调整 :学习率、批大小、噪声维度、网络层数和宽度...这些都可能影响结果。没有万能药,需要耐心尝试和调整。
  • 小心模式崩溃 (Mode Collapse) :有时候 GAN 会陷入一个状态,无论输入什么噪声,生成器都只输出少数几种甚至一种图像。这是 GAN 训练中的常见难题,有多种技术尝试缓解它(比如 WGAN-GP, mini-batch discrimination 等)。

把这些思路和代码片段组合起来,应该能帮你解决生成模型只输出噪声的问题,让它开始“画”出像样的 MNIST 数字来。动手试试看吧!