validateQueryTables checks if the query accesses only allowed tables.
(db *sqlx.DB, query string, allowedTables map[string]struct{}, args ...any)
| 602 | |
| 603 | // validateQueryTables checks if the query accesses only allowed tables. |
| 604 | func validateQueryTables(db *sqlx.DB, query string, allowedTables map[string]struct{}, args ...any) error { |
| 605 | // Get the EXPLAIN (FORMAT JSON) output. |
| 606 | tx, err := db.BeginTxx(context.Background(), &sql.TxOptions{ReadOnly: true}) |
| 607 | if err != nil { |
| 608 | return err |
| 609 | } |
| 610 | defer tx.Rollback() |
| 611 | |
| 612 | var plan string |
| 613 | if err = tx.QueryRow("EXPLAIN (FORMAT JSON) "+query, args...).Scan(&plan); err != nil { |
| 614 | return err |
| 615 | } |
| 616 | |
| 617 | // Extract all relation names from the JSON plan. |
| 618 | tables, err := getTablesFromQueryPlan(plan) |
| 619 | if err != nil { |
| 620 | return fmt.Errorf("error getting tables from query: %v", err) |
| 621 | } |
| 622 | |
| 623 | // Validate against allowed tables. |
| 624 | for _, table := range tables { |
| 625 | if _, ok := allowedTables[table]; !ok { |
| 626 | return fmt.Errorf("table '%s' is not allowed", table) |
| 627 | } |
| 628 | } |
| 629 | |
| 630 | return nil |
| 631 | } |
| 632 | |
| 633 | // getTablesFromQueryPlan parses the EXPLAIN JSON to find all "Relation Name" entries. |
| 634 | func getTablesFromQueryPlan(explainJSON string) ([]string, error) { |
no test coverage detected