文档段,这个是给大模型上下文的最小单位
| 47 | |
| 48 | @dataclass |
| 49 | class MarkdownBlock: |
| 50 | """文档段,这个是给大模型上下文的最小单位""" |
| 51 | |
| 52 | # 文件名 |
| 53 | file_name: str |
| 54 | # 文件标题 |
| 55 | title: str = "" |
| 56 | # 二级或三级标题 |
| 57 | header: str = "" |
| 58 | # 内容,可能是文本或代码段 |
| 59 | content: list[ContentType] = field(default_factory=list) |
| 60 | |
def gen_text(self, max_length: int = 500, include_code: bool = True) -> str:
    """Render this block as plain text (header followed by paragraphs).

    The header, when non-empty, is emitted first followed by a blank line.
    Paragraphs from ``self.content`` are then appended in order until adding
    the next one would push the accumulated paragraph length past
    ``max_length``; rendering stops at that paragraph, so later paragraphs
    are dropped even if they are shorter.

    Args:
        max_length: Budget, in characters, for paragraph text. Only the raw
            paragraph text is counted; the header, the trailing newlines and
            the ``` fences wrapped around code paragraphs are not counted.
        include_code: When True, code paragraphs are emitted wrapped in
            ``` fences. When False, code paragraphs are omitted entirely and
            do not consume any of the length budget.

    Returns:
        The rendered text.
    """
    parts: list[str] = [self.header + "\n\n"] if self.header else []
    used = 0
    for para in self.content:
        is_code = para.type == ContentType.Code
        if is_code and not include_code:
            # Caller asked for text only: skip code without spending budget.
            continue
        text = para.text
        # Stop at the first paragraph that does not fit. NOTE: the extra
        # characters added by the code fences are deliberately not counted.
        if used + len(text) > max_length:
            break
        parts.append(f"\n```\n{text}\n```\n" if is_code else text + "\n")
        used += len(text)
    return "".join(parts)
| 77 | |
| 78 | def get_text_blocks(self) -> list[str]: |
| 79 | """获取用于生成嵌入的文本段落列表""" |
| 80 | blocks: list[str] = [] |
| 81 | header = self.header.replace("#", "") if self.header else "" |
| 82 | if header != "": |
| 83 | if len(header) < 4: |
| 84 | blocks.append(self.title + header) |
| 85 | else: |
| 86 | blocks.append(header) |
| 87 | all_text = "" |
| 88 | for para in self.content: |
| 89 | if para.type == ContentType.Text: |
| 90 | # 去掉各种样式及图片避免影响 |
| 91 | text = unmark(para.text) |
| 92 | all_text += text |
| 93 | blocks.append(self.title + header + text) |
| 94 | blocks.append(text) |
| 95 | # 对于太长的段落,拆分一下 |
| 96 | if len(text) > LONG_CONTENT_LENGTH: |
| 97 | for line in text.split(","): |
| 98 | blocks.append(line) |
| 99 | |
| 100 | if len(all_text) < LONG_CONTENT_LENGTH: |
| 101 | blocks.append(header + all_text) |
| 102 | |
| 103 | # 删掉重复的和避免空字符 |
| 104 | output_blocks = set() |
| 105 | for block in blocks: |
| 106 | block = block.strip() |