Training
You can start by cloning my repo and downloading the dataset:
git clone https://github.com/elephantmipt/cycle-gan-distillation.git
bash bin/download_dataset.sh monet2photo

After that you can install the requirements:

pip install -r requierements/requirements.txt

Preprocess the data:

python scripts/preprocess_dataset.py --path ./datasets/monet2photo

Then you can change some configuration in train.py and start training with

python train.py

Alternatively, you can try the config API (I'm not sure it works properly, as I removed some core callbacks to build the multi-phase pipeline). The config file lives in configs/train.yml; change any properties you want and start training with

catalyst-dl run -C configs/train.yml --verbose

Closer look into Notebook API

The snippet below walks through the whole Notebook API training setup step by step:
from itertools import chain
import torch
from catalyst import dl
from src.callbacks.gan import (
    CycleGANLoss,
    GANLoss,
    IdenticalGANLoss,
    PrepareGeneratorPhase,
    GeneratorOptimizerCallback,
    PrepareDiscriminatorPhase,
    DiscriminatorLoss,
    DiscriminatorOptimizerCallback
)
from src.callbacks.visualization import LogImageCallback
from src.dataset import UnpairedDataset
from src.modules.generator import Generator
from src.modules.discriminator import NLayerDiscriminator, PixelDiscriminator
from src.runner import CycleGANRunner
from src.modules.loss import LSGanLoss
from torchvision import transforms as T
from PIL import Image
# defining dataset
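# UnpairedDataset draws unaligned samples from the two domains: trainA (Monet paintings) and trainB (photos)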
train_ds = UnpairedDataset(
    "./datasets/monet2photo/trainA_preprocessed",
    "./datasets/monet2photo/trainB_preprocessed",
    transforms=T.Compose([
        T.Resize((300, 300)),
        T.RandomCrop((256, 256)),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
    ])
)
train_dl = torch.utils.data.DataLoader(
    train_ds,
    batch_size=1,  # CycleGAN is typically trained with batch size 1
    shuffle=True
)
tr = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor(),
])
# loading images for validation
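# these two fixed photos are passed to the LogImageCallback instances below so progress can be inspected visually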
mipt_photo = tr(Image.open("./datasets/mipt.jpg"))
zinger_photo = tr(Image.open("./datasets/vk.jpg"))
# defining model arch
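# standard CycleGAN setup: two generators (A -> B and B -> A) plus one discriminator per domain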
model = {
    "generator_ab": Generator(
        inp_channel_dim=3,
        out_channel_dim=3,
        n_blocks=9
    ),
    "generator_ba": Generator(
        inp_channel_dim=3,
        out_channel_dim=3,
        n_blocks=9
    ),
    "discriminator_a": PixelDiscriminator(input_channels_dim=3),
    "discriminator_b": PixelDiscriminator(input_channels_dim=3),
}
# initializing optimizers
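# one Adam optimizer updates both generators jointly, another updates both discriminators (parameters are chained)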
optimizer = {
    "generator": torch.optim.Adam(
        chain(
            model["generator_ab"].parameters(),
            model["generator_ba"].parameters()
        ),
        lr=0.0002
    ),
    "discriminator": torch.optim.Adam(
        chain(
            model["discriminator_a"].parameters(),
            model["discriminator_b"].parameters()
        ),
        lr=0.0002
    )
}
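# the callbacks below form two phases per batch: a generator phase (adversarial, cycle and
# identity losses plus a generator optimizer step) and a discriminator phase, followed by
# image-logging callbacks for visual monitoring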
callbacks = [
    ##############################################
    PrepareGeneratorPhase(),
    GANLoss(),
    CycleGANLoss(),
    IdenticalGANLoss(),
    GeneratorOptimizerCallback(
        weights=[1, 10, 5],  # loss weights: GAN, cycle, identity
    ),
    ##############################################
    PrepareDiscriminatorPhase(),
    DiscriminatorLoss(),
    DiscriminatorOptimizerCallback(),
    ##############################################
    LogImageCallback(),  # valid images
    LogImageCallback(key="mipt", img=mipt_photo),
    LogImageCallback(key="vk", img=zinger_photo),
]
# criterions for losses
criterion = {
    "gan": LSGanLoss(),
    "cycle": torch.nn.L1Loss(),
    "identical": torch.nn.L1Loss(),
}
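# buffer_size presumably controls the pool of previously generated images shown to the
# discriminators, as in the original CycleGAN paper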
runner = CycleGANRunner(buffer_size=50)
# start training
runner.train(
    model=model,
    optimizer=optimizer,
    loaders={"train": train_dl},
    callbacks=callbacks,
    criterion=criterion,
    num_epochs=100,
    verbose=True,
    logdir="naive_train",
    main_metric="identical_loss"
)
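After training you can run one of the generators directly on an image tensor. Below is a minimal sketch, assuming Generator.forward maps a (N, 3, H, W) tensor to a tensor of the same shape with values save_image can handle (adjust normalization to your setup), and assuming that with the monet2photo layout generator_ba is the photo-to-painting direction:

import torch
from torchvision.utils import save_image

# pick the generator for the direction you want (photo -> painting assumed here)
generator_ba = model["generator_ba"].eval()
with torch.no_grad():
    stylized = generator_ba(mipt_photo.unsqueeze(0))  # add a batch dimension
save_image(stylized, "mipt_stylized.jpg")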