(reader, field_name, field_type)
| 24 | return torch.Size(tuple(int(field.parts[part_idx][0]) for part_idx in field.data)) |
| 25 | |
def get_field(reader, field_name, field_type):
    """Read a single scalar metadata value from a GGUF reader.

    Args:
        reader: object exposing ``get_field(name)``, which returns a field
            object (with ``types``, ``parts`` and ``data`` attributes) or
            ``None`` when the key is absent.
        field_name: metadata key to look up.
        field_type: Python type to convert the value to; one of ``str``,
            ``int``, ``float`` or ``bool``.

    Returns:
        The stored value converted to ``field_type``, or ``None`` if the key
        is not present.

    Raises:
        TypeError: if ``field_type`` is not supported, or if the stored GGUF
            value is not a single value of a compatible kind (for example a
            string or an array requested as a number).
    """
    field = reader.get_field(field_name)
    if field is None:
        return None

    if field_type is str:
        # Strict check: this path is used to validate the architecture string,
        # so anything other than a single STRING value is rejected.
        if len(field.types) != 1 or field.types[0] != gguf.GGUFValueType.STRING:
            raise TypeError(f"Bad type for GGUF {field_name} key: expected string, got {field.types!r}")
        # field.data holds indices into field.parts; the last one is used as the value payload.
        return str(field.parts[field.data[-1]], encoding="utf-8")

    if field_type in (int, float, bool):
        # Reject arrays (more than one entry in field.types) and strings.
        # A string's payload is a byte buffer, so .item() would either fail
        # with an opaque ValueError or, for a one-byte string, silently return
        # that byte as a number.
        if len(field.types) != 1 or field.types[0] == gguf.GGUFValueType.STRING:
            raise TypeError(
                f"Bad type for GGUF {field_name} key: expected {field_type.__name__}, got {field.types!r}"
            )
        return field_type(field.parts[field.data[-1]].item())

    raise TypeError(f"Unknown field type {field_type}")
| 39 | |
| 40 | def get_list_field(reader, field_name, field_type): |
| 41 | field = reader.get_field(field_name) |
no outgoing calls
no test coverage detected