Extract functions, structs, enums, traits, impl methods, and use declarations from a .rs file.
Signature: `extract_rust(path: Path) -> dict`
| 586 | |
| 587 | |
| 588 | def extract_rust(path: Path) -> dict: |
| 589 | """Extract functions, structs, enums, traits, impl methods, and use declarations from a .rs file.""" |
| 590 | try: |
| 591 | import tree_sitter_rust as tsrust |
| 592 | from tree_sitter import Language, Parser |
| 593 | except ImportError: |
| 594 | return {"nodes": [], "edges": [], "error": "tree-sitter-rust not installed"} |
| 595 | |
| 596 | try: |
| 597 | language = Language(tsrust.language()) |
| 598 | parser = Parser(language) |
| 599 | source = path.read_bytes() |
| 600 | tree = parser.parse(source) |
| 601 | root = tree.root_node |
| 602 | except Exception as e: |
| 603 | return {"nodes": [], "edges": [], "error": str(e)} |
| 604 | |
| 605 | stem = path.stem |
| 606 | str_path = str(path) |
| 607 | nodes: list[dict] = [] |
| 608 | edges: list[dict] = [] |
| 609 | seen_ids: set[str] = set() |
| 610 | |
| 611 | def add_node(nid: str, label: str, line: int) -> None: |
| 612 | if nid not in seen_ids: |
| 613 | seen_ids.add(nid) |
| 614 | nodes.append({ |
| 615 | "id": nid, |
| 616 | "label": label, |
| 617 | "file_type": "code", |
| 618 | "source_file": str_path, |
| 619 | "source_location": f"L{line}", |
| 620 | }) |
| 621 | |
| 622 | def add_edge(src: str, tgt: str, relation: str, line: int, confidence: str = "EXTRACTED", weight: float = 1.0) -> None: |
| 623 | edges.append({ |
| 624 | "source": src, |
| 625 | "target": tgt, |
| 626 | "relation": relation, |
| 627 | "confidence": confidence, |
| 628 | "source_file": str_path, |
| 629 | "source_location": f"L{line}", |
| 630 | "weight": weight, |
| 631 | }) |
| 632 | |
| 633 | file_nid = _make_id(stem) |
| 634 | add_node(file_nid, path.name, 1) |
| 635 | |
| 636 | function_bodies: list[tuple[str, object]] = [] |
| 637 | |
| 638 | def walk(node, parent_impl_nid: str | None = None) -> None: |
| 639 | t = node.type |
| 640 | |
| 641 | if t == "function_item": |
| 642 | name_node = node.child_by_field_name("name") |
| 643 | if name_node: |
| 644 | func_name = source[name_node.start_byte:name_node.end_byte].decode("utf-8", errors="replace") |
| 645 | line = node.start_point[0] + 1 |