MCPcopy
hub / github.com/yerfor/GeneFacePlusPlus / __init__

Method __init__

modules/commons/vqvae_taming.py:120–148  ·  view source on GitHub ↗
(self, num_hiddens, embedding_dim, n_embed, straight_through=True,
                 kl_weight=5e-4, temp_init=1.0, use_vqinterface=True,
                 remap=None, unknown_index="random")

Source from the content-addressed store, hash-verified

118 """
119
def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True,
             kl_weight=5e-4, temp_init=1.0, use_vqinterface=True,
             remap=None, unknown_index="random"):
    """Build the projection layer, the codebook and the optional index remap.

    Args:
        num_hiddens: Number of input channels of the 1x1 ``proj`` convolution.
        embedding_dim: Width of each row of the ``embed`` table.
        n_embed: Number of codebook entries. This is both the output channel
            count of ``proj`` and the number of rows of ``embed``.
        straight_through: Stored as ``self.straight_through``. It is consumed
            by other methods of this class.
        kl_weight: Stored as ``self.kl_weight``.
        temp_init: Initial value stored in ``self.temperature``.
        use_vqinterface: Stored as ``self.use_vqinterface``.
        remap: ``None`` to disable remapping, or a path handed to ``np.load``.
            NOTE(review): presumably the loaded array lists the codebook
            indices actually in use. Confirm against the remap helpers.
        unknown_index: ``"random"``, ``"extra"`` or an integer. It is only
            consulted when ``remap`` is given. With ``"extra"``, unknown
            indices get a dedicated slot appended after the used ones.
    """
    super().__init__()

    # Plain configuration values; none of these are parameters or buffers.
    self.embedding_dim = embedding_dim
    self.n_embed = n_embed
    self.straight_through = straight_through
    self.temperature = temp_init
    self.kl_weight = kl_weight

    # Keep this creation order (proj, then embed): both layers draw random
    # initial weights, and the order also fixes parameter/state_dict ordering.
    self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
    self.embed = nn.Embedding(n_embed, embedding_dim)

    self.use_vqinterface = use_vqinterface
    self.remap = remap

    if remap is None:
        # No remapping: indices range over the full codebook.
        self.re_embed = n_embed
        return

    # Registered as a buffer so the table follows the module across devices
    # and is saved with its state.
    used_codes = torch.tensor(np.load(remap))
    self.register_buffer("used", used_codes)
    remapped_size = used_codes.shape[0]

    self.unknown_index = unknown_index  # "random" or "extra" or integer
    if unknown_index == "extra":
        # Reserve one extra slot, just past the used codes, for unknown ids.
        self.unknown_index = remapped_size
        remapped_size += 1
    self.re_embed = remapped_size

    print(f"Remapping {n_embed} indices to {remapped_size} indices. "
          f"Using {self.unknown_index} for unknown indices.")
149
150 def remap_to_used(self, inds):
151 ishape = inds.shape

Callers 3

__init__Method · 0.45
__init__Method · 0.45
__init__Method · 0.45

Calls 1

loadMethod · 0.80

Tested by

no test coverage detected