tiny_diff.models.vae.gan.VAEGAN
- class tiny_diff.models.vae.gan.VAEGAN(vae: VAE, discriminator: Pix2PixDiscriminator, disc_weight: float = 1.0, gen_warmup: int = 60)
Bases: object
VAE + GAN.
Model that coordinates a VAE and a discriminator, trained as a GAN pair, to improve the VAE's reconstruction quality.
- Parameters:
vae – VAE model; acts as the generator in the GAN pair.
discriminator – Discriminator in the GAN pair.
disc_weight – Weight applied to the discriminator loss term.
gen_warmup – Last epoch in which the generator (the VAE) is trained alone, without the discriminator.
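Taken together, `disc_weight` and `gen_warmup` suggest the adversarial term is held off during warmup and scaled afterwards. The following is a minimal sketch of that gating. It assumes the term is simply zeroed through epoch `gen_warmup` and multiplied by `disc_weight` from then on. `warmup_gated_weight` is a hypothetical helper, not part of tiny_diff. The real `disc_loss_weight(loss, epoch)` also receives the loss value and may compute its weight differently.

```python
def warmup_gated_weight(epoch: int, disc_weight: float = 1.0, gen_warmup: int = 60) -> float:
    """Hypothetical helper (not tiny_diff's API): adversarial-term weight at `epoch`.

    Assumption: the generator trains alone up to and including `gen_warmup`,
    so the adversarial term is zeroed; afterwards it is scaled by `disc_weight`.
    """
    if epoch <= gen_warmup:
        return 0.0
    return disc_weight


# The weight stays at zero through the last warmup epoch, then switches on.
for epoch in (0, 60, 61):
    print(epoch, warmup_gated_weight(epoch))
```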
- __init__(vae: VAE, discriminator: Pix2PixDiscriminator, disc_weight: float = 1.0, gen_warmup: int = 60)
Methods
- __init__(vae, discriminator[, disc_weight, ...])
- disc_loss_weight(loss, epoch) – Discriminator loss term weight.
- gen_loss(x[, epoch]) – Generator loss.
- loss(x[, epoch, mode]) – Computes the VAEGAN loss.
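The methods split the objective into a generator side (`gen_loss`) and a combined entry point (`loss`) that selects a `mode`. The sketch below shows one common way such an objective can be assembled. It operates on plain floats so it runs without a deep-learning framework. Everything in it is an assumption, since the source does not document how VAEGAN computes these terms:

- the hinge adversarial loss;
- the `"generator"`/`"discriminator"` mode names;
- the reconstruction loss passed in as a precomputed number;
- the warmup gating.

`sketch_vaegan_loss`, `hinge_disc_loss` and `hinge_gen_loss` are hypothetical names.

```python
from statistics import fmean
from typing import Sequence


def hinge_disc_loss(logits_real: Sequence[float], logits_fake: Sequence[float]) -> float:
    # Hinge loss: penalize real logits below +1 and fake logits above -1.
    real = fmean(max(0.0, 1.0 - r) for r in logits_real)
    fake = fmean(max(0.0, 1.0 + f) for f in logits_fake)
    return 0.5 * (real + fake)


def hinge_gen_loss(logits_fake: Sequence[float]) -> float:
    # The generator is rewarded when the discriminator scores its reconstructions highly.
    return -fmean(logits_fake)


def sketch_vaegan_loss(
    rec_loss: float,
    logits_real: Sequence[float],
    logits_fake: Sequence[float],
    epoch: int,
    mode: str = "generator",
    disc_weight: float = 1.0,
    gen_warmup: int = 60,
) -> float:
    """Hypothetical stand-in for VAEGAN.loss; not the library's implementation."""
    # Assumed warmup gating: no adversarial signal through epoch `gen_warmup`.
    weight = 0.0 if epoch <= gen_warmup else disc_weight
    if mode == "generator":
        # VAE reconstruction objective plus the weighted adversarial term.
        return rec_loss + weight * hinge_gen_loss(logits_fake)
    if mode == "discriminator":
        # Discriminator objective, also gated by the warmup weight (assumption).
        return weight * hinge_disc_loss(logits_real, logits_fake)
    raise ValueError(f"unknown mode: {mode!r}")


real, fake = [2.0, 0.5], [-2.0, 0.0]
print(sketch_vaegan_loss(0.3, real, fake, epoch=10))  # warmup: reconstruction term only
print(sketch_vaegan_loss(0.3, real, fake, epoch=61, mode="discriminator"))
```

Hinge losses are only one option. A non-saturating binary cross-entropy loss is another common choice, and the actual method may use either or something else.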