If max_seq_len is not specified, set it to max_position_embeddings * rotary_factor. Additional checks ensure that max_seq_len, max_input_len, and max_num_tokens have valid values.
(model_config, build_config)
| 809 | |
| 810 | |
| 811 | def _init_max_seq_len(model_config, build_config): |
| 812 | """ |
| 813 | If max_seq_len is not specified, set it to max_position_embeddings * rotary_factor |
| 814 | Additional checks to ensure max_seq_len, max_input_len, and max_num_tokens have valid values. |
| 815 | """ |
| 816 | # Extract rotary scaling which will be used for checks and default value of max_seq_len |
| 817 | rotary_scaling = getattr(model_config, "rotary_scaling", None) |
| 818 | if rotary_scaling is not None: |
| 819 | rotary_type = rotary_scaling.get('type', |
| 820 | rotary_scaling.get('rope_type')) |
| 821 | rotary_factor = rotary_scaling.get( |
| 822 | 'factor', 1.0) if rotary_type not in ("su", "longrope", |
| 823 | "llama3") else 1 |
| 824 | else: |
| 825 | rotary_factor = 1 |
| 826 | |
| 827 | if model_config.architecture == "EncoderModel": |
| 828 | if build_config.max_seq_len is None: |
| 829 | build_config.max_seq_len = build_config.max_input_len |
| 830 | logger.info( |
| 831 | f'max_seq_len is not specified for EncoderModel, using --max_input_len.' |
| 832 | ) |
| 833 | assert build_config.max_input_len == build_config.max_seq_len, f"EncoderModel should have same --max_input_len ({build_config.max_input_len}) and --max_seq_len ({build_config.max_seq_len})." |
| 834 | |
| 835 | if build_config.max_seq_len is None: |
| 836 | # Step 1: Find the upper bound of max_seq_len |
| 837 | deduced_max_seq_len = 2048 |
| 838 | if model_config.max_position_embeddings is not None: |
| 839 | deduced_max_seq_len = model_config.max_position_embeddings |
| 840 | |
| 841 | # Step 2: Scale max_seq_len with rotary scaling |
| 842 | if rotary_factor != 1: |
| 843 | deduced_max_seq_len = math.ceil(deduced_max_seq_len * rotary_factor) |
| 844 | logger.warning( |
| 845 | f'max_seq_len is scaled to {deduced_max_seq_len} by rotary scaling {rotary_factor}' |
| 846 | ) |
| 847 | |
| 848 | # Step 3: Assign the new max_seq_len |
| 849 | build_config.max_seq_len = int(deduced_max_seq_len) |
| 850 | logger.info( |
| 851 | f'max_seq_len is not specified, using deduced value {deduced_max_seq_len}' |
| 852 | ) |
| 853 | else: |
| 854 | if not build_config.plugin_config.streamingllm and model_config.max_position_embeddings is not None \ |
| 855 | and model_config.position_embedding_type != PositionEmbeddingType.relative: |
| 856 | if build_config.max_seq_len > model_config.max_position_embeddings * rotary_factor: |
| 857 | logger.warning( |
| 858 | f'max_seq_len {build_config.max_seq_len} is larger than max_position_embeddings {model_config.max_position_embeddings} * rotary scaling {rotary_factor}, ' |
| 859 | 'the model accuracy might be affected') |
| 860 | |
| 861 | if build_config.max_input_len > build_config.max_seq_len: |
| 862 | logger.warning( |
| 863 | f'max_input_len is {build_config.max_input_len} is larger than max_seq_len {build_config.max_seq_len}, clipping it to max_seq_len' |
| 864 | ) |
| 865 | build_config.max_input_len = build_config.max_seq_len |
| 866 | |
| 867 | # Check and may modify max_num_tokens and opt_num_tokens (need to happen after max_seq_len is deduced) |
| 868 | max_num_tokens, opt_num_tokens = check_max_num_tokens( |
no test coverage detected