Generative adversarial networks (GANs) are widely used in machine learning to generate images, and conditional variants can generate images from text descriptions. A GAN consists of a generator that converts random noise into images and a discriminator that learns to distinguish real images from images produced by the generator. Through continued adversarial training, the generator gradually learns to produce realistic images that the discriminator has difficulty telling apart from real ones. This technique has broad applications in image generation, image enhancement, and other fields.
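A simple example is using a GAN to generate images of handwritten digits (the MNIST dataset). Note that this example is unconditional: the generator receives only random noise, not text, so generating images from text would additionally require conditioning the generator and discriminator on a text embedding. The following is sample code in PyTorch: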
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

# torchvision (datasets, transforms, save_image) is imported lazily inside the
# functions that use it, so the model classes below only require torch.


# Generator network
class Generator(nn.Module):
    """Map latent noise vectors to single-channel 28x28 images in [-1, 1].

    Args:
        latent_dim: size of the input noise vector (default 100).

    Input shape: (N, latent_dim). Output shape: (N, 1, 28, 28).
    """

    def __init__(self, latent_dim=100):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        # Project the noise to a 256x7x7 feature map; the transposed
        # convolutions then upsample 7 -> 7 -> 14 -> 28 to match MNIST.
        self.fc = nn.Linear(latent_dim, 256 * 7 * 7)
        self.main = nn.Sequential(
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 5, stride=1, padding=2),  # 7x7 -> 7x7
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # 7x7 -> 14x14
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, 4, stride=2, padding=1),  # 14x14 -> 28x28
            # Tanh keeps outputs in [-1, 1], the range Normalize((0.5,), (0.5,))
            # in main() maps real images into.
            nn.Tanh(),
        )

    def forward(self, x):
        x = self.fc(x)
        x = x.view(-1, 256, 7, 7)
        return self.main(x)


# Discriminator network
class Discriminator(nn.Module):
    """Score single-channel 28x28 images with a probability of being real.

    Input shape: (N, 1, 28, 28). Output shape: (N, 1), values in (0, 1).
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(1, 64, 4, stride=2, padding=1),  # 28x28 -> 14x14
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),  # 14x14 -> 7x7
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),  # 7x7 -> 3x3
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # A 3x3 kernel collapses the 3x3 map to 1x1; a 4x4 kernel here
            # would yield an empty output and raise at runtime.
            nn.Conv2d(256, 1, 3, stride=1, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.main(x).view(-1, 1)


# Training loop
def train(generator, discriminator, dataloader, optimizer_G, optimizer_D, device,
          num_epochs=200, fixed_noise=None):
    """Train the GAN with binary cross-entropy (non-saturating generator loss).

    Each iteration first updates the discriminator on a real batch and a
    generated batch, then updates the generator so the discriminator labels
    its samples as real. After every epoch, images generated from
    ``fixed_noise`` are saved to ``generated_images_<epoch>.png``.

    Args:
        generator: the Generator network (already on ``device``).
        discriminator: the Discriminator network (already on ``device``).
        dataloader: iterable of ``(images, labels)`` batches; labels are unused.
        optimizer_G: optimizer over the generator's parameters.
        optimizer_D: optimizer over the discriminator's parameters.
        device: torch device for data, labels and noise.
        num_epochs: number of passes over ``dataloader`` (default 200).
        fixed_noise: latent batch used for the per-epoch sample images. If
            None, 64 noise vectors are drawn once before training.
    """
    from torchvision.utils import save_image

    latent_dim = getattr(generator, "latent_dim", 100)
    if fixed_noise is None:
        fixed_noise = torch.randn(64, latent_dim, device=device)

    criterion = nn.BCELoss()
    # BCELoss requires floating-point targets; integer fill values would make
    # torch.full produce an int64 tensor and the loss call would fail.
    real_label = 1.0
    fake_label = 0.0

    for epoch in range(num_epochs):
        for i, (data, _) in enumerate(dataloader):
            # (1) Discriminator step: real batch labelled 1, fake batch labelled 0.
            discriminator.zero_grad()
            real_data = data.to(device)
            batch_size = real_data.size(0)
            label = torch.full((batch_size,), real_label, dtype=torch.float, device=device)
            output = discriminator(real_data).view(-1)
            errD_real = criterion(output, label)
            errD_real.backward()
            D_x = output.mean().item()

            noise = torch.randn(batch_size, latent_dim, device=device)
            fake_data = generator(noise)
            label.fill_(fake_label)
            # detach() so the discriminator loss does not backprop into the generator.
            output = discriminator(fake_data.detach()).view(-1)
            errD_fake = criterion(output, label)
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            errD = errD_real + errD_fake
            optimizer_D.step()

            # (2) Generator step: push the discriminator to label fakes as real.
            generator.zero_grad()
            label.fill_(real_label)
            output = discriminator(fake_data).view(-1)
            errG = criterion(output, label)
            errG.backward()
            D_G_z2 = output.mean().item()
            optimizer_G.step()

            if i % 100 == 0:
                print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
                      % (epoch + 1, num_epochs, i, len(dataloader),
                         errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save generated images once per epoch; no_grad skips building an
        # autograd graph for this inference-only forward pass.
        with torch.no_grad():
            fake = generator(fixed_noise)
        save_image(fake.detach(), 'generated_images_%03d.png' % epoch, normalize=True)


def main():
    """Load MNIST, build both networks and optimizers, and run training."""
    from torchvision import datasets, transforms

    # Load the MNIST dataset, scaling pixels to [-1, 1].
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    dataset = datasets.MNIST(root='./数据集', train=True, transform=transform, download=True)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

    # Choose the device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize the generator and discriminator.
    generator = Generator().to(device)
    discriminator = Discriminator().to(device)

    # Define the optimizers.
    optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    # Fixed noise so the saved images are comparable across epochs.
    fixed_noise = torch.randn(64, 100, device=device)

    # Start training.
    train(generator, discriminator, dataloader, optimizer_G, optimizer_D, device,
          fixed_noise=fixed_noise)


if __name__ == "__main__":
    main()
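Running this code downloads the MNIST dataset and trains the GAN for 200 epochs to generate images of handwritten digits. After each epoch, it saves a grid of images generated from a fixed noise batch (generated_images_000.png, generated_images_001.png, …).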
The above is the detailed content of "How to implement machine learning to convert text into images (with sample code)". For more information, please follow other related articles on the PHP Chinese website!