@brief: Prepare input Tensors for the model; the given sizes are used to determine the ranges of the dimensions when using TRT dynamic shapes. @return: a list containing values that can be fed into self.forward()
(self, max_batch_size, **kwargs)
| 1278 | return output |
| 1279 | |
| 1280 | def prepare_inputs(self, max_batch_size, **kwargs): |
| 1281 | '''@brief: Prepare inputs Tensors for the model, the given sizes are used to determine the |
| 1282 | ranges of the dimensions of when using TRT dynamic shapes. |
| 1283 | |
| 1284 | @return: a list contains values which can be fed into the self.forward() |
| 1285 | ''' |
| 1286 | |
| 1287 | mapping = self.config.mapping |
| 1288 | if mapping.tp_size > 1: |
| 1289 | current_all_reduce_helper().set_workspace_tensor(mapping, 1) |
| 1290 | |
| 1291 | def stdit_default_batch_range(max_batch_size): |
| 1292 | return [max_batch_size, max_batch_size, max_batch_size] |
| 1293 | |
| 1294 | default_range = stdit_default_batch_range |
| 1295 | # [NOTE] For now only static batch size is supported, so we run the model with max_batch_size. |
| 1296 | batch_size = max_batch_size |
| 1297 | |
| 1298 | x = Tensor(name='x', |
| 1299 | dtype=self.dtype, |
| 1300 | shape=[batch_size, self.in_channels, *self.latent_size], |
| 1301 | dim_range=OrderedDict([ |
| 1302 | ('batch_size', [default_range(max_batch_size)]), |
| 1303 | ('in_channels', [[self.in_channels] * 3]), |
| 1304 | ('latent_frames', [[self.latent_size[0]] * 3]), |
| 1305 | ('latent_height', [[self.latent_size[1]] * 3]), |
| 1306 | ('latent_width', [[self.latent_size[2]] * 3]), |
| 1307 | ])) |
| 1308 | timestep = Tensor(name='timestep', |
| 1309 | dtype=self.dtype, |
| 1310 | shape=[batch_size], |
| 1311 | dim_range=OrderedDict([ |
| 1312 | ('batch_size', [default_range(max_batch_size)]), |
| 1313 | ])) |
| 1314 | y = Tensor( |
| 1315 | name='y', |
| 1316 | dtype=self.dtype, |
| 1317 | shape=[batch_size, 1, self.model_max_length, self.caption_channels], |
| 1318 | dim_range=OrderedDict([ |
| 1319 | ('batch_size', [default_range(max_batch_size)]), |
| 1320 | ('mask_batch_size', [[1, 1, 1]]), |
| 1321 | ('num_tokens', [[self.model_max_length] * 3]), |
| 1322 | ('caption_channels', [[self.caption_channels] * 3]), |
| 1323 | ])) |
| 1324 | mask = Tensor(name='mask', |
| 1325 | dtype=trt.int32, |
| 1326 | shape=[1, self.model_max_length], |
| 1327 | dim_range=OrderedDict([ |
| 1328 | ('mask_batch_size', [[1, 1, 1]]), |
| 1329 | ('num_tokens', [[self.model_max_length] * 3]), |
| 1330 | ])) |
| 1331 | x_mask = Tensor(name='x_mask', |
| 1332 | dtype=tensorrt_llm.str_dtype_to_trt('bool'), |
| 1333 | shape=[batch_size, self.latent_size[0]], |
| 1334 | dim_range=OrderedDict([ |
| 1335 | ('batch_size', [default_range(max_batch_size)]), |
| 1336 | ('latent_frames', [[self.latent_size[0]] * 3]), |
| 1337 | ])) |
nothing calls this directly
no test coverage detected