17 lines
480 B
Python
17 lines
480 B
Python
import torch
|
|
import torch.nn.functional as F
|
|
|
|
|
|
def hinge_d_loss(logits_real, logits_fake):
    """Compute the hinge-style discriminator loss.

    Args:
        logits_real: Tensor of discriminator outputs for real samples.
        logits_fake: Tensor of discriminator outputs for generated samples.

    Returns:
        A scalar tensor equal to
        ``0.5 * (mean(relu(1 - logits_real)) + mean(relu(1 + logits_fake)))``,
        where each mean is taken over all elements of its input.
    """
    # Real logits contribute only while they are below +1.
    real_term = F.relu(1.0 - logits_real).mean()
    # Fake logits contribute only while they are above -1.
    fake_term = F.relu(1.0 + logits_fake).mean()
    return 0.5 * (real_term + fake_term)
|
|
|
|
|
|
def vanilla_d_loss(logits_real, logits_fake):
    """Compute the softplus-based ("vanilla") discriminator loss.

    Args:
        logits_real: Tensor of discriminator outputs for real samples.
        logits_fake: Tensor of discriminator outputs for generated samples.

    Returns:
        A scalar tensor equal to
        ``0.5 * (mean(softplus(-logits_real)) + mean(softplus(logits_fake)))``,
        where each mean is taken over all elements of its input.
    """
    # softplus(-x) shrinks as real logits grow; softplus(x) shrinks as
    # fake logits decrease.
    real_term = F.softplus(-logits_real).mean()
    fake_term = F.softplus(logits_fake).mean()
    return 0.5 * (real_term + fake_term)
|