Training

You can start by cloning my repo and downloading the dataset:

git clone https://github.com/elephantmipt/cycle-gan-distillation.git
bash bin/download_dataset.sh monet2photo

After that, you can install the requirements:

pip install -r requierements/requirements.txt

Preprocess the data:

python scripts/preprocess_dataset.py --path ./datasets/monet2photo

Then you can adjust the configuration in train.py and start training with

python train.py

Alternatively, you can try the config API (I'm not sure it works properly, as I removed some core callbacks to build the multi-phase pipeline). The config file can be found in configs/train.yml; change any properties you want and then start training with

catalyst-dl run -C configs/train.yml --verbose

A closer look at the Notebook API

from itertools import chain

import torch

from catalyst import dl

from src.callbacks.gan import (
    CycleGANLoss,
    GANLoss,
    IdenticalGANLoss,
    PrepareGeneratorPhase,
    GeneratorOptimizerCallback,
    PrepareDiscriminatorPhase,
    DiscriminatorLoss,
    DiscriminatorOptimizerCallback
)
from src.callbacks.visualization import LogImageCallback
from src.dataset import UnpairedDataset
from src.modules.generator import Generator
from src.modules.discriminator import NLayerDiscriminator, PixelDiscriminator
from src.runner import CycleGANRunner
from src.modules.loss import LSGanLoss

from torchvision import transforms as T
from PIL import Image

# defining dataset

train_ds = UnpairedDataset(
    "./datasets/monet2photo/trainA_preprocessed",
    "./datasets/monet2photo/trainB_preprocessed",
    transforms=T.Compose([
        T.Resize((300, 300)),
        T.RandomCrop((256, 256)),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
    ]),
)
train_dl = torch.utils.data.DataLoader(
    train_ds,
    batch_size=1,
    shuffle=True
)

tr = T.Compose([
    T.Resize((256,256)),
    T.ToTensor(),
])

# loading images for validation
mipt_photo = tr(Image.open("./datasets/mipt.jpg"))
zinger_photo = tr(Image.open("./datasets/vk.jpg"))
# defining model arch
model = {
    "generator_ab": Generator(
        inp_channel_dim=3,
        out_channel_dim=3,
        n_blocks=9,
    ),
    "generator_ba": Generator(
        inp_channel_dim=3,
        out_channel_dim=3,
        n_blocks=9,
    ),
    "discriminator_a": PixelDiscriminator(input_channels_dim=3),
    "discriminator_b": PixelDiscriminator(input_channels_dim=3),
}

# initializing optimizers

optimizer = {
    "generator": torch.optim.Adam(
        chain(
            model["generator_ab"].parameters(),
            model["generator_ba"].parameters()
        ),
        lr=0.0002
    ),
    "discriminator": torch.optim.Adam(
        chain(
            model["discriminator_a"].parameters(),
            model["discriminator_b"].parameters()
        ),
        lr=0.0002
    )
}
callbacks = [
    # generator phase: compute the adversarial, cycle and identity losses, then step the generator optimizer
    PrepareGeneratorPhase(),
    GANLoss(),
    CycleGANLoss(),
    IdenticalGANLoss(),
    GeneratorOptimizerCallback(
        weights=[1, 10, 5],  # loss weights, presumably in the callback order above: gan, cycle, identity
    ),
    # discriminator phase: compute the discriminator loss and step its optimizer
    PrepareDiscriminatorPhase(),
    DiscriminatorLoss(),
    DiscriminatorOptimizerCallback(),
    # logging: translated validation images plus two fixed photos tracked across epochs
    LogImageCallback(),  # valid images
    LogImageCallback(key="mipt", img=mipt_photo),
    LogImageCallback(key="vk", img=zinger_photo),
]

# criterions for losses

criterion = {
    "gan": LSGanLoss(),
    "cycle": torch.nn.L1Loss(),
    "identical": torch.nn.L1Loss(),
}

runner = CycleGANRunner(buffer_size=50)
# start training
runner.train(
    model=model,
    optimizer=optimizer,
    loaders={"train": train_dl},
    callbacks=callbacks,
    criterion=criterion,
    num_epochs=100,
    verbose=True,
    logdir="naive_train",
    main_metric="identical_loss"
)
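
Once training has finished, you can apply the trained generators to new photos. The snippet below is a minimal inference sketch, not part of the repo: the checkpoint path and state-dict key are assumptions about how Catalyst saves a dict of models, and it assumes generator_ba is the photo-to-painting direction (as the A/B naming of monet2photo suggests), so adjust it to your run.

import torch
from PIL import Image
from torchvision import transforms as T
from torchvision.utils import save_image

from src.modules.generator import Generator

# rebuild the photo->painting generator with the same architecture as in training
generator_ba = Generator(inp_channel_dim=3, out_channel_dim=3, n_blocks=9)

# assumed checkpoint location and key; inspect checkpoint.keys() if they differ for your Catalyst version
checkpoint = torch.load("naive_train/checkpoints/best.pth", map_location="cpu")
generator_ba.load_state_dict(checkpoint["model_generator_ba_state_dict"])
generator_ba.eval()

tr = T.Compose([T.Resize((256, 256)), T.ToTensor()])
photo = tr(Image.open("./datasets/mipt.jpg")).unsqueeze(0)  # add a batch dimension

with torch.no_grad():
    fake_painting = generator_ba(photo)

# clamp to [0, 1] as a safeguard, since the generator's output range is not shown here
save_image(fake_painting.clamp(0, 1), "mipt_painting.jpg")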
