Extract classes, functions, and imports from a .py file via tree-sitter AST.
(path: Path)
| 15 | |
| 16 | |
| 17 | def extract_python(path: Path) -> dict: |
| 18 | """Extract classes, functions, and imports from a .py file via tree-sitter AST.""" |
| 19 | try: |
| 20 | import tree_sitter_python as tspython |
| 21 | from tree_sitter import Language, Parser |
| 22 | except ImportError: |
| 23 | return {"nodes": [], "edges": [], "error": "tree-sitter-python not installed"} |
| 24 | |
| 25 | try: |
| 26 | language = Language(tspython.language()) |
| 27 | parser = Parser(language) |
| 28 | source = path.read_bytes() |
| 29 | tree = parser.parse(source) |
| 30 | root = tree.root_node |
| 31 | except Exception as e: |
| 32 | return {"nodes": [], "edges": [], "error": str(e)} |
| 33 | |
| 34 | stem = path.stem |
| 35 | str_path = str(path) |
| 36 | nodes: list[dict] = [] |
| 37 | edges: list[dict] = [] |
| 38 | seen_ids: set[str] = set() |
| 39 | |
| 40 | def add_node(nid: str, label: str, line: int) -> None: |
| 41 | if nid not in seen_ids: |
| 42 | seen_ids.add(nid) |
| 43 | nodes.append({ |
| 44 | "id": nid, |
| 45 | "label": label, |
| 46 | "file_type": "code", |
| 47 | "source_file": str_path, |
| 48 | "source_location": f"L{line}", |
| 49 | }) |
| 50 | |
| 51 | def add_edge(src: str, tgt: str, relation: str, line: int) -> None: |
| 52 | # Only add edge if both endpoints exist or src is the file node |
| 53 | edges.append({ |
| 54 | "source": src, |
| 55 | "target": tgt, |
| 56 | "relation": relation, |
| 57 | "confidence": "EXTRACTED", |
| 58 | "source_file": str_path, |
| 59 | "source_location": f"L{line}", |
| 60 | "weight": 1.0, |
| 61 | }) |
| 62 | |
| 63 | # File-level node - stable ID based on stem only |
| 64 | file_nid = _make_id(stem) |
| 65 | add_node(file_nid, path.name, 1) |
| 66 | |
| 67 | def walk(node, parent_class_nid: str | None = None) -> None: |
| 68 | t = node.type |
| 69 | |
| 70 | if t == "import_statement": |
| 71 | for child in node.children: |
| 72 | if child.type in ("dotted_name", "aliased_import"): |
| 73 | raw = source[child.start_byte:child.end_byte].decode("utf-8", errors="replace") |
| 74 | module_name = raw.split(" as ")[0].strip().lstrip(".") |