()
| 110 | # ------------------------------------------------------------------ # |
| 111 | |
| 112 | def main() -> None: |
| 113 | manifest = json.loads((EXT_DIR / "manifest.json").read_text(encoding="utf-8")) |
| 114 | model_id = manifest["id"] |
| 115 | |
| 116 | try: |
| 117 | GenClass = load_generator(manifest) |
| 118 | except Exception: |
| 119 | send({"type": "error", "id": None, |
| 120 | "message": "Failed to load generator class", |
| 121 | "traceback": traceback.format_exc()}) |
| 122 | return |
| 123 | |
| 124 | # Support both flat manifest (legacy) and nodes[] format. |
| 125 | # Use MODEL_DIR to find the correct node for multi-node extensions: |
| 126 | # MODEL_DIR is set by ExtensionProcess to MODELS_DIR/ext_id/node_id, |
| 127 | # so its last component matches the node id. |
| 128 | node = _select_node(manifest, _MODEL_DIR_OVERRIDE) |
| 129 | |
| 130 | # Announce readiness and send params_schema so ExtensionProcess |
| 131 | # can serve it without needing to query the subprocess later. |
| 132 | # We try to get it from the generator class (may be a classmethod), |
| 133 | # falling back to the selected node, then to the top-level manifest. |
| 134 | send({"type": "ready", "params_schema": _resolve_ready_schema(GenClass, node, manifest)}) |
| 135 | |
| 136 | # Use MODEL_DIR env var (set by ExtensionProcess) when available so the |
| 137 | # generator uses the exact same path that is_downloaded() checks against. |
| 138 | # Falls back to MODELS_DIR/manifest_id for legacy / standalone use. |
| 139 | model_dir = Path(_MODEL_DIR_OVERRIDE) if _MODEL_DIR_OVERRIDE else MODELS_DIR / model_id |
| 140 | gen = GenClass(model_dir, WORKSPACE_DIR) |
| 141 | _apply_manifest_metadata(gen, manifest, node) |
| 142 | |
| 143 | # Active cancel events keyed by request id |
| 144 | _cancel: dict[str, threading.Event] = {} |
| 145 | |
| 146 | for msg in recv(): |
| 147 | action = msg.get("action") |
| 148 | rid = msg.get("id") |
| 149 | |
| 150 | try: |
| 151 | # ---- load ------------------------------------------------ |
| 152 | if action == "load": |
| 153 | gen.load() |
| 154 | send({"type": "loaded"}) |
| 155 | |
| 156 | # ---- generate -------------------------------------------- |
| 157 | elif action == "generate": |
| 158 | cancel_evt = threading.Event() |
| 159 | _cancel[rid] = cancel_evt |
| 160 | image_bytes = base64.b64decode(msg["image_b64"]) |
| 161 | params = msg.get("params", {}) |
| 162 | if msg.get("outputs_dir"): |
| 163 | gen.outputs_dir = Path(msg["outputs_dir"]) |
| 164 | gen.outputs_dir.mkdir(parents=True, exist_ok=True) |
| 165 | |
| 166 | def progress_cb(pct: int, step: str = "") -> None: |
| 167 | send({"type": "progress", "id": rid, "pct": pct, "step": step}) |
| 168 | |
| 169 | try: |
no test coverage detected