18 lines
480 B
Python
18 lines
480 B
Python
|
import torch
|
||
|
import torch.nn.functional as F
|
||
|
|
||
|
|
||
|
def hinge_d_loss(logits_real, logits_fake):
    """Return the hinge-loss objective for a GAN discriminator.

    Real logits are penalized by how far they fall below +1, and fake logits
    by how far they rise above -1. Each penalty is averaged over all of its
    elements on its own, so the two tensors need not share a shape. The two
    means are then averaged with equal weight.

    Args:
        logits_real: discriminator outputs on real samples (tensor).
        logits_fake: discriminator outputs on generated samples (tensor).

    Returns:
        A 0-dim tensor: 0.5 * (mean(relu(1 - real)) + mean(relu(1 + fake))).
    """
    # relu(margin) is zero once the logit is on the correct side of its margin.
    real_penalty = F.relu(1.0 - logits_real).mean()
    fake_penalty = F.relu(1.0 + logits_fake).mean()
    return 0.5 * (real_penalty + fake_penalty)
|
||
|
|
||
|
|
||
|
def vanilla_d_loss(logits_real, logits_fake):
    """Return the standard (binary cross-entropy) GAN discriminator loss.

    The loss is computed directly from logits via softplus, without forming
    sigmoid probabilities:

        softplus(-x) == -log(sigmoid(x))       (loss for a "real" target)
        softplus(x)  == -log(1 - sigmoid(x))   (loss for a "fake" target)

    Each term is averaged over all of its elements on its own, so the two
    tensors need not share a shape. The two means are then averaged with
    equal weight.

    Args:
        logits_real: discriminator outputs on real samples (tensor).
        logits_fake: discriminator outputs on generated samples (tensor).

    Returns:
        A 0-dim tensor:
        0.5 * (mean(softplus(-real)) + mean(softplus(fake))).
    """
    # Use the module-level `F` alias, consistent with hinge_d_loss.
    loss_real = torch.mean(F.softplus(-logits_real))
    loss_fake = torch.mean(F.softplus(logits_fake))
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss
|