| 335 | |
| 336 | |
| 337 | class ConvUNetVAE(nn.Module): |
| 338 | def __init__(self) -> None: |
| 339 | super().__init__() |
| 340 | self.embed_image = ImageEmbedding() |
| 341 | self.embed_time = TimestepEmbedding_() |
| 342 | |
| 343 | down_0 = nn.ModuleList( |
| 344 | [ |
| 345 | ConvResblock(320, 320), |
| 346 | ConvResblock(320, 320), |
| 347 | ConvResblock(320, 320), |
| 348 | Downsample(320), |
| 349 | ] |
| 350 | ) |
| 351 | down_1 = nn.ModuleList( |
| 352 | [ |
| 353 | ConvResblock(320, 640), |
| 354 | ConvResblock(640, 640), |
| 355 | ConvResblock(640, 640), |
| 356 | Downsample(640), |
| 357 | ] |
| 358 | ) |
| 359 | down_2 = nn.ModuleList( |
| 360 | [ |
| 361 | ConvResblock(640, 1024), |
| 362 | ConvResblock(1024, 1024), |
| 363 | ConvResblock(1024, 1024), |
| 364 | Downsample(1024), |
| 365 | ] |
| 366 | ) |
| 367 | down_3 = nn.ModuleList( |
| 368 | [ |
| 369 | ConvResblock(1024, 1024), |
| 370 | ConvResblock(1024, 1024), |
| 371 | ConvResblock(1024, 1024), |
| 372 | ] |
| 373 | ) |
| 374 | self.down = nn.ModuleList( |
| 375 | [ |
| 376 | down_0, |
| 377 | down_1, |
| 378 | down_2, |
| 379 | down_3, |
| 380 | ] |
| 381 | ) |
| 382 | |
| 383 | self.mid = nn.ModuleList( |
| 384 | [ |
| 385 | ConvResblock(1024, 1024), |
| 386 | ConvResblock(1024, 1024), |
| 387 | ] |
| 388 | ) |
| 389 | |
| 390 | up_3 = nn.ModuleList( |
| 391 | [ |
| 392 | ConvResblock(1024 * 2, 1024), |
| 393 | ConvResblock(1024 * 2, 1024), |
| 394 | ConvResblock(1024 * 2, 1024), |
no outgoing calls
no test coverage detected
searching dependent graphs…