(self, max_batch_size=16)
| 2123 | return x |
| 2124 | |
| 2125 | def prepare_inputs(self, max_batch_size=16): |
| 2126 | |
| 2127 | bs_range = [1, (max_batch_size + 1) // 2, max_batch_size] |
| 2128 | min_feat_len, optimal_feat_len = 10, 1000 # 100ms, 10s |
| 2129 | inlen_range = [ |
| 2130 | min_feat_len, optimal_feat_len, self.max_audio_feature_seq_len |
| 2131 | ] |
| 2132 | inlen_range_after_downsample = [ |
| 2133 | min_feat_len // self.downsample_factor, |
| 2134 | optimal_feat_len // self.downsample_factor, |
| 2135 | self.max_audio_feature_seq_len // self.downsample_factor |
| 2136 | ] |
| 2137 | if not default_net().plugin_config.remove_input_padding: |
| 2138 | x = Tensor(name="input_features", |
| 2139 | dtype=self._dtype, |
| 2140 | shape=[-1, self.config.n_mels, -1], |
| 2141 | dim_range=OrderedDict([ |
| 2142 | ("batch_size", [bs_range]), |
| 2143 | ("feature_dim", [self.config.n_mels]), |
| 2144 | ("feature_len_range", [inlen_range]), |
| 2145 | ])) |
| 2146 | position_ids = Tensor( |
| 2147 | name='position_ids', |
| 2148 | dtype=trt.int32, |
| 2149 | shape=[-1, -1], |
| 2150 | dim_range=OrderedDict([('batch_size', [bs_range]), |
| 2151 | ('feature_len_downsample_range', |
| 2152 | [inlen_range_after_downsample])]), |
| 2153 | ) |
| 2154 | else: |
| 2155 | batch_seqlen_range = [ |
| 2156 | 1, |
| 2157 | (self.max_audio_feature_seq_len * max_batch_size + 1) // 2, |
| 2158 | self.max_audio_feature_seq_len * max_batch_size, |
| 2159 | ] |
| 2160 | batch_seqlen_downsample_range = [ |
| 2161 | 1, |
| 2162 | (self.max_audio_feature_seq_len // self.downsample_factor * |
| 2163 | max_batch_size + 1) // 2, |
| 2164 | self.max_audio_feature_seq_len // self.downsample_factor * |
| 2165 | max_batch_size, |
| 2166 | ] |
| 2167 | x = Tensor(name="input_features", |
| 2168 | dtype=self._dtype, |
| 2169 | shape=[-1, self.config.n_mels], |
| 2170 | dim_range=OrderedDict([ |
| 2171 | ("batch_seqlen_range", [batch_seqlen_range]), |
| 2172 | ("feature_dim", [self.config.n_mels]), |
| 2173 | ])) |
| 2174 | position_ids = Tensor( |
| 2175 | name='position_ids', |
| 2176 | dtype=trt.int32, |
| 2177 | shape=[-1], |
| 2178 | dim_range=OrderedDict([('batch_seqlen_downsample_range', |
| 2179 | [batch_seqlen_downsample_range])]), |
| 2180 | ) |
| 2181 | input_lengths = Tensor( |
| 2182 | name="input_lengths", |
nothing calls this directly
no test coverage detected