How to save all the generated images in a folder in PyTorch

Question:

I am trying to use data augmentation with PyTorch. I want to save all the generated images in a folder (target_dir), numbered differently based on the batch index.

Here is my code. I am using epoch=100 and batch_size=128.

import os

for batch_idx in range(BATCH_SIZE):
    torchvision.utils.save_image(img_grid_fake, f"C:/UserspythonProjectgenerated_image/Fake_image%{batch_idx}d.png", global_step=step)

But I am only getting the last 128 generated images; the previously generated images are overwritten when the next epoch runs.

Asked By: sid1994s


Answers:

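You need to include both the epoch and the batch index in the file name, e.g. f"Fake_image-{epoch}-{batch_idx}.png". If the name depends only on batch_idx, every epoch writes to the same set of file names, so each epoch overwrites the images saved by the previous one.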
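To see why the batch-index-only scheme keeps just one epoch's worth of images, here is a minimal stdlib-only simulation. It is a sketch under stated assumptions: small placeholder text files stand in for torchvision.utils.save_image (which, like open(path, "w"), replaces any existing file at the same path), and the helper names and zero-padded format are illustrative choices, not part of the original answer.

```python
import os
import tempfile


def batch_only_name(epoch, batch_idx):
    # Question's scheme: the epoch is ignored, so every epoch reuses the same names
    return f"Fake_image{batch_idx}.png"


def epoch_batch_name(epoch, batch_idx):
    # Answer's scheme; zero-padding is an optional extra that keeps files sorted
    return f"Fake_image-{epoch:03d}-{batch_idx:04d}.png"


def simulate_saving(name_fn, epochs, batches_per_epoch):
    """Write one placeholder file per (epoch, batch); return {name: contents} of survivors."""
    with tempfile.TemporaryDirectory() as target_dir:
        for epoch in range(epochs):
            for batch_idx in range(batches_per_epoch):
                path = os.path.join(target_dir, name_fn(epoch, batch_idx))
                # Stand-in for torchvision.utils.save_image: mode "w" silently
                # replaces a file that already exists at the same path
                with open(path, "w") as f:
                    f.write(f"epoch={epoch} batch={batch_idx}")
        # Read back what survived before the temporary directory is deleted
        survivors = {}
        for name in sorted(os.listdir(target_dir)):
            with open(os.path.join(target_dir, name)) as f:
                survivors[name] = f.read()
        return survivors


if __name__ == "__main__":
    old = simulate_saving(batch_only_name, epochs=3, batches_per_epoch=4)
    new = simulate_saving(epoch_batch_name, epochs=3, batches_per_epoch=4)
    print(len(old), old["Fake_image0.png"])  # 4 epoch=2 batch=0
    print(len(new))                          # 12
```

With the batch-only names, only 4 files remain and each holds data from the last epoch; with epoch and batch index in the name, all 12 saves survive.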

import os
import torchvision
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

target_dir = r"C:/Users/PycharmProjects/pythonProject/generated/generated_image/"
os.makedirs(target_dir, exist_ok=True)  # saving fails if the folder does not exist

EPOCHS = 10
BATCH_SIZE = 64
GRID_SIZE = 9  # 9 images in each grid
IMAGES_PER_ROW = 3  # make_grid's nrow = images per row; sqrt(GRID_SIZE) gives a square grid

# if you want all the images in a batch to make the image-grid,
# set GRID_SIZE = BATCH_SIZE

# YourFakeImageDataset is a placeholder for your own Dataset class.
# The transform belongs to the Dataset: DataLoader has no transform argument.
train_dataset = YourFakeImageDataset(transform=ToTensor())
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

for epoch in range(EPOCHS):
    for batch_idx, (X, y) in enumerate(train_dataloader):
        # assume X is the fake-image returned by the dataloader
        # and y is some target value for the X, also returned by the dataloader

        # ... do something with your images here
        # B, C, H, W = X.shape
        img_grid_fake = torchvision.utils.make_grid(X[:GRID_SIZE, ...], nrow=IMAGES_PER_ROW)
        # epoch AND batch_idx in the name, so a new epoch does not overwrite older files
        filepath = os.path.join(target_dir, f"Fake_image-{epoch}-{batch_idx}.png")
        torchvision.utils.save_image(img_grid_fake, filepath)
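In torchvision.utils.make_grid, nrow is the number of images per row, not the number of rows, and the default padding is 2 pixels. The helper below is a stdlib-only sketch that mirrors make_grid's layout arithmetic, so you can check what a grid will look like before saving it; the 28x28 image size is a hypothetical example and the function name is illustrative, not part of torchvision.

```python
import math


def make_grid_shape(num_images, height, width, nrow=8, padding=2):
    """Return (rows, cols, grid_height, grid_width) for a make_grid-style layout."""
    if num_images < 1:
        raise ValueError("need at least one image")
    cols = min(nrow, num_images)          # nrow = images per row
    rows = math.ceil(num_images / cols)   # rows needed to hold every image
    # Each cell is the image plus padding, with one extra padding strip at the end
    grid_height = rows * (height + padding) + padding
    grid_width = cols * (width + padding) + padding
    return rows, cols, grid_height, grid_width


if __name__ == "__main__":
    print(make_grid_shape(9, 28, 28, nrow=3))   # (3, 3, 92, 92)
    print(make_grid_shape(64, 28, 28, nrow=3))  # (22, 3, 662, 92)
    print(make_grid_shape(64, 28, 28, nrow=8))  # (8, 8, 242, 242)
```

So if you set GRID_SIZE = BATCH_SIZE = 64 but keep nrow=3, you get a tall 22-row strip; nrow=8 gives a square 8x8 grid.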

NOTE: I cannot give a complete answer, because your question leaves out several details (some of which others have already asked about in the comments).

If you are making a grid of fake images, how are you doing it? With torchvision.utils.make_grid()?


Answered By: CypherX