2026/3/4 7:16:27
网站建设
项目流程
做原型的素材网站,做电子画册的网站,天津建设银行网站,dede多个网站怎么做#x1f368; 本文为#x1f517;365天深度学习训练营 中的学习记录博客 #x1f356; 原作者#xff1a;K同学啊 GAN就是让两个AI互相斗智#xff1a;一个想造假#xff0c;一个想识假。通过不断斗#xff0c;造假… 本文为365天深度学习训练营 中的学习记录博客 原作者K同学啊GAN就是让两个AI互相斗智一个想造假一个想识假。通过不断斗造假的越来越像真的识假的也越来越厉害最后达到一个平衡点造假的几乎能以假乱真。 这就像一个学生和老师的关系学生努力学习老师不断出难题学生通过老师的反馈不断进步老师也通过学生的进步而更了解教学难点。最终学生能解答几乎所有问题老师能出最难的题。 这就是GAN——两个AI互相斗出来的神奇结果# 代码功能说明 # 这是一个生成对抗网络GAN的完整实现用来学习生成手写数字图片MNIST数据集# 生成器Generator把随机噪声变成手写数字图片# 判别器Discriminator判断图片是真实的还是生成的# 两者互相“斗法”直到生成器能造出以假乱真的图片# 1. 准备工作 importargparse# 用来接收命令行参数比如训练轮数importos# 操作系统命令创建文件夹importnumpyasnp# 科学计算库处理数字importtorchvision.transformsastransforms# 图像处理工具fromtorchvision.utilsimportsave_image# 保存图片fromtorch.utils.dataimportDataLoader# 加载数据集fromtorchvisionimportdatasets# MNIST数据集fromtorch.autogradimportVariable# 为Tensor添加梯度计算功能importtorch.nnasnn# 神经网络核心模块importtorch# PyTorch深度学习框架importssl# 解决HTTPS证书问题防止下载数据集时出错ssl._create_default_https_contextssl._create_unverified_context# 关闭SSL证书验证# 创建三个文件夹# - images/保存训练中生成的图片看效果用# - save/保存最终训练好的模型以后能直接用# - datasets/mnist/存放下载的MNIST手写数字数据集os.makedirs(./images/,exist_okTrue)os.makedirs(./save/,exist_okTrue)os.makedirs(./datasets/mnist,exist_okTrue)# 2. 设置训练参数 n_epochs50# 训练50轮每轮遍历所有数据batch_size64# 每次训练用64张图片lr0.0002# 学习率控制模型更新速度b10.5# Adam优化器参数控制梯度衰减b20.999# Adam优化器参数n_cpu2# 使用2个CPU核心加速latent_dim100# 随机噪声的维度100个随机数img_size28# 图片尺寸28x28像素channels1# 图片通道黑白图1通道sample_interval500# 每训练500次保存一次生成的图片# 图片形状(1, 28, 28) → 总像素数784img_shape(channels,img_size,img_size)img_areanp.prod(img_shape)# 1*28*28784# 检查是否能用GPU速度更快cudaTrueiftorch.cuda.is_available()elseFalseprint(是否使用GPU:,cuda)# 打印结果True/False# 3. 下载并处理数据 # 从MNIST下载手写数字数据集28x28黑白图mnistdatasets.MNIST(root./datasets/,# 保存位置trainTrue,# 下载训练集downloadTrue,# 自动下载transformtransforms.Compose([transforms.Resize(img_size),# 缩放到28x28transforms.ToTensor(),# 转成PyTorch张量transforms.Normalize([0.5],[0.5])# 归一化到[-1,1]]),)# 创建数据加载器每次给64张图片dataloaderDataLoader(mnist,batch_sizebatch_size,shuffleTrue,# 打乱顺序防止模型记住顺序)# 4. 构建判别器判断真假 classDiscriminator(nn.Module):# 判别器类def__init__(self):super(Discriminator,self).__init__()# 一个简单的全连接神经网络# 输入784个像素 → 512个神经元 → 256个神经元 → 1个输出0~1概率self.modelnn.Sequential(nn.Linear(img_area,512),# 784→512nn.LeakyReLU(0.2,inplaceTrue),# 激活函数解决梯度消失nn.Linear(512,256),# 512→256nn.LeakyReLU(0.2,inplaceTrue),nn.Linear(256,1),# 256→1输出概率nn.Sigmoid()# 0~1概率1真图0假图)defforward(self,img):# 把图片拉成一维向量64,784img_flatimg.view(img.size(0),-1)# 通过网络得到真假概率validityself.model(img_flat)returnvalidity# 5. 构建生成器生成假图 classGenerator(nn.Module):# 生成器类def__init__(self):super(Generator,self).__init__()# 辅助函数创建一个带正则化的神经网络层defblock(in_feat,out_feat,normalizeTrue):layers[nn.Linear(in_feat,out_feat)]# 线性变换ifnormalize:layers.append(nn.BatchNorm1d(out_feat,0.8))# 正则化加速训练layers.append(nn.LeakyReLU(0.2,inplaceTrue))# 激活函数returnlayers# 生成器网络结构# 100维噪声 → 128 → 256 → 512 → 1024 → 784输出# 最后用Tanh让输出在[-1,1]之间符合归一化后的数据范围self.modelnn.Sequential(*block(latent_dim,128,normalizeFalse),# 100→128不用正则化*block(128,256),# 128→256*block(256,512),# 256→512*block(512,1024),# 512→1024nn.Linear(1024,img_area),# 1024→784nn.Tanh()# 输出归一化到[-1,1])defforward(self,z):# z是100维随机噪声64个样本imgsself.model(z)# 生成图片784维向量# 重塑成(64,1,28,28)PyTorch需要的图片格式imgsimgs.view(imgs.size(0),*img_shape)returnimgs# 6. 初始化模型 generatorGenerator()# 创建生成器discriminatorDiscriminator()# 创建判别器# 损失函数衡量真假判断的准确性二分类交叉熵criteriontorch.nn.BCELoss()# 二分类交叉熵# 优化器Adam优化器比普通梯度下降更快更好optimizer_Gtorch.optim.Adam(generator.parameters(),lrlr,betas(b1,b2))optimizer_Dtorch.optim.Adam(discriminator.parameters(),lrlr,betas(b1,b2))# 如果有GPU把模型搬到GPU上加速ifcuda:generatorgenerator.cuda()discriminatordiscriminator.cuda()criterioncriterion.cuda()# 7. 训练循环 forepochinrange(n_epochs):# 训练50轮fori,(imgs,_)inenumerate(dataloader):# 遍历数据集# 步骤1训练判别器 # 把图片拉成一维64,784imgsimgs.view(imgs.size(0),-1)# 转成可计算张量GPU上real_imgVariable(imgs).cuda()# 真实图片的标签全1表示“这是真图”real_labelVariable(torch.ones(imgs.size(0),1)).cuda()# 假图片的标签全0表示“这是假图”fake_labelVariable(torch.zeros(imgs.size(0),1)).cuda()# 判别器的损失 真图被判断为真 假图被判断为假# 真图输入判别器 → 得到概率 → 计算和标签的差距real_outdiscriminator(real_img)loss_real_Dcriterion(real_out,real_label)# 假图生成器生成假图 → 判别器判断 → 计算差距zVariable(torch.randn(imgs.size(0),latent_dim)).cuda()# 生成随机噪声fake_imggenerator(z).detach()# 生成假图detach不更新生成器参数fake_outdiscriminator(fake_img)loss_fake_Dcriterion(fake_out,fake_label)loss_Dloss_real_Dloss_fake_D# 总损失# 优化判别器反向传播 更新参数optimizer_D.zero_grad()loss_D.backward()optimizer_D.step()# 步骤2训练生成器 # 生成器的目标让判别器把假图判断成真图zVariable(torch.randn(imgs.size(0),latent_dim)).cuda()# 新随机噪声fake_imggenerator(z)# 生成假图outputdiscriminator(fake_img)# 判别器判断假图# 生成器损失希望判别器输出1真图loss_Gcriterion(output,real_label)# 优化生成器反向传播 更新参数optimizer_G.zero_grad()loss_G.backward()optimizer_G.step()# 打印训练进度 if(i1)%3000:# 每300次打印一次print(f[Epoch{epoch}/{n_epochs}] [Batch{i}/{len(dataloader)}] f[D loss:{loss_D.item():.4f}] [G loss:{loss_G.item():.4f}] f[D real:{real_out.data.mean():.4f}] [D fake:{fake_out.data.mean():.4f}])# 保存生成的图片每500次保存一次 batches_doneepoch*len(dataloader)iifbatches_done%sample_interval0:# 保存前25张生成的图片5x5网格save_image(fake_img.data[:25],f./images/{batches_done}.png,nrow5,normalizeTrue)# 8. 保存最终模型 torch.save(generator.state_dict(),./save/generator.pth)# 保存生成器torch.save(discriminator.state_dict(),./save/discriminator.pth)# 保存判别器[Epoch0/50][Batch299/938][D loss:1.108700][G loss:1.494937][D real:0.765423][D fake:0.563390][Epoch0/50][Batch599/938][D loss:0.981047][G loss:2.200819][D real:0.859328][D fake:0.555203][Epoch0/50][Batch899/938][D loss:1.012156][G loss:1.935689][D real:0.728062][D fake:0.476248][Epoch1/50][Batch299/938][D loss:1.188978][G loss:0.676110][D real:0.426300][D fake:0.200765][Epoch1/50][Batch599/938][D loss:1.007571][G loss:1.044460][D real:0.562748][D fake:0.284159][Epoch1/50][Batch899/938][D loss:1.071741][G loss:1.711364][D real:0.720821][D fake:0.483612][Epoch2/50][Batch299/938][D loss:0.910406][G loss:2.151794][D real:0.764064][D fake:0.448280][Epoch2/50][Batch599/938][D loss:0.800963][G loss:1.313761][D real:0.613358][D fake:0.188154][Epoch2/50][Batch899/938][D loss:1.093633][G loss:1.053562][D real:0.531550][D fake:0.230020][Epoch3/50][Batch299/938][D loss:0.963498][G loss:2.506877][D real:0.811666][D fake:0.497298][Epoch3/50][Batch599/938][D loss:1.083450][G loss:0.882004][D real:0.465563][D fake:0.117864][Epoch3/50][Batch899/938][D loss:0.973209][G loss:2.698256][D real:0.809422][D fake:0.502016][Epoch4/50][Batch299/938][D loss:0.817019][G loss:1.351617][D real:0.666476][D fake:0.273635].......