11 lines
No EOL
305 B
Python
11 lines
No EOL
305 B
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch import Tensor
|
|
|
|
class SharedEmbedding(nn.Embedding):
    """Embedding table that can also be applied in reverse ("unembedding").

    By default this behaves exactly like :class:`torch.nn.Embedding`. With
    ``unembed=True`` the same ``weight`` matrix, of shape
    ``(num_embeddings, embedding_dim)``, is used as a bias-free linear
    projection from ``embedding_dim`` features back to ``num_embeddings``
    scores. This is presumably meant to tie an input embedding to an output
    head; confirm against callers.
    """

    def forward(self, input: Tensor, unembed: bool = False) -> Tensor:
        """Embed indices, or project features onto the embedding table.

        Args:
            input: When ``unembed`` is False, an index tensor for a standard
                ``nn.Embedding`` lookup. Otherwise, a floating-point tensor
                whose last dimension is ``embedding_dim``.
            unembed: If True, apply the reverse projection instead of a
                lookup.

        Returns:
            The looked-up embeddings. If ``unembed`` is True, instead returns
            ``input @ self.weight.T``, whose last dimension is
            ``num_embeddings``.
        """
        if not unembed:
            # Ordinary embedding lookup, delegated to nn.Embedding.
            return super().forward(input)
        # F.linear(x, W) computes x @ W.T with no bias, so the embedding
        # matrix doubles as the output projection weight.
        return F.linear(input, self.weight)