tiny_diff.models.vae.gan.VAEGAN

class tiny_diff.models.vae.gan.VAEGAN(vae: VAE, discriminator: Pix2PixDiscriminator, disc_weight: float = 1.0, gen_warmup: int = 60)

Bases: object

VAE + GAN.

Model that coordinates a VAE and a discriminator to improve reconstruction quality.

Parameters:
  • vae – VAE model.

  • discriminator – Discriminator in the GAN pair.

  • disc_weight – discriminator weight for the discriminator loss term.

  • gen_warmup – Last epoch in which the generator (the VAE) is trained alone, without the discriminator.
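One plausible reading of how disc_weight and gen_warmup interact is a step schedule: the adversarial term is switched off up to and including epoch gen_warmup, then scaled by disc_weight. The sketch below assumes that behavior; the function name scheduled_disc_weight is hypothetical and is not part of tiny_diff.

```python
def scheduled_disc_weight(epoch: int, disc_weight: float = 1.0, gen_warmup: int = 60) -> float:
    """Hypothetical schedule: 0 during generator warmup, disc_weight afterwards."""
    # Assumption: epochs 0..gen_warmup (inclusive) train the generator alone.
    if epoch <= gen_warmup:
        return 0.0
    # After warmup, the adversarial term is scaled by disc_weight.
    return disc_weight


for epoch in (0, 60, 61, 100):
    print(epoch, scheduled_disc_weight(epoch, disc_weight=0.5))
```

The real disc_loss_weight(loss, epoch) also receives the loss, so it may compute something richer than this fixed step (for example a loss-dependent weight); the sketch only illustrates the warmup gate.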

__init__(vae: VAE, discriminator: Pix2PixDiscriminator, disc_weight: float = 1.0, gen_warmup: int = 60)

Methods

  • __init__(vae, discriminator[, disc_weight, ...])

  • disc_loss_weight(loss, epoch) – Discriminator loss term weight.

  • gen_loss(x[, epoch]) – Generator loss.

  • loss(x[, epoch, mode]) – Computes the VAEGAN loss.
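This reference does not say which adversarial formulation gen_loss and loss use. As an illustration only, the sketch below assumes a hinge loss (the formulation used by default in VQGAN-style autoencoders) and plain Python floats in place of tensors. The names hinge_disc_loss and generator_loss are hypothetical, and the idea that mode selects between the generator and discriminator objectives is likewise an assumption.

```python
from statistics import fmean


def hinge_disc_loss(real_logits, fake_logits):
    """Hypothetical discriminator loss (hinge formulation)."""
    # Penalize real logits below +1 and fake logits above -1.
    real_term = fmean(max(0.0, 1.0 - r) for r in real_logits)
    fake_term = fmean(max(0.0, 1.0 + f) for f in fake_logits)
    return 0.5 * (real_term + fake_term)


def generator_loss(recon_loss, kl_loss, fake_logits, adv_weight):
    """Hypothetical generator loss: VAE objective plus weighted adversarial term."""
    # Minimizing -mean(logits) pushes the discriminator's scores on fakes upward.
    adv_term = -fmean(fake_logits)
    return recon_loss + kl_loss + adv_weight * adv_term


real, fake = [2.0, 0.5], [-2.0, 0.0]
print(hinge_disc_loss(real, fake))            # 0.375
print(generator_loss(0.1, 0.01, fake, 0.0))   # warmup: adversarial term disabled
print(generator_loss(0.1, 0.01, fake, 0.5))   # adversarial term weighted by 0.5
```

Passing adv_weight=0.0 reproduces the "generator trained alone" phase; afterwards, the weight would come from whatever disc_loss_weight returns for the current epoch.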