getCTEColumnNames resolves a CTE's column names by building a temporary `WITH … SELECT * FROM <cte>` query that carries along the sibling CTEs the target may reference
(ctx context.Context, stmt *ast.SelectStmt, targetCTE *ast.CommonTableExpr)
| 157 | |
| 158 | // getCTEColumnNames gets the column names for a CTE by constructing a query with proper context |
| 159 | func (e *Expander) getCTEColumnNames(ctx context.Context, stmt *ast.SelectStmt, targetCTE *ast.CommonTableExpr) ([]string, error) { |
| 160 | // Build a temporary query: WITH <all CTEs up to and including target> SELECT * FROM <targetCTE> |
| 161 | var ctesToInclude []ast.Node |
| 162 | for _, cteNode := range stmt.WithClause.Ctes.Items { |
| 163 | ctesToInclude = append(ctesToInclude, cteNode) |
| 164 | cte, ok := cteNode.(*ast.CommonTableExpr) |
| 165 | if ok && cte.Ctename != nil && targetCTE.Ctename != nil && *cte.Ctename == *targetCTE.Ctename { |
| 166 | break |
| 167 | } |
| 168 | } |
| 169 | |
| 170 | // Create a SELECT * FROM <ctename> with the relevant CTEs |
| 171 | cteName := "" |
| 172 | if targetCTE.Ctename != nil { |
| 173 | cteName = *targetCTE.Ctename |
| 174 | } |
| 175 | |
| 176 | tempStmt := &ast.SelectStmt{ |
| 177 | WithClause: &ast.WithClause{ |
| 178 | Ctes: &ast.List{Items: ctesToInclude}, |
| 179 | Recursive: stmt.WithClause.Recursive, |
| 180 | }, |
| 181 | TargetList: &ast.List{ |
| 182 | Items: []ast.Node{ |
| 183 | &ast.ResTarget{ |
| 184 | Val: &ast.ColumnRef{ |
| 185 | Fields: &ast.List{ |
| 186 | Items: []ast.Node{&ast.A_Star{}}, |
| 187 | }, |
| 188 | }, |
| 189 | }, |
| 190 | }, |
| 191 | }, |
| 192 | FromClause: &ast.List{ |
| 193 | Items: []ast.Node{ |
| 194 | &ast.RangeVar{ |
| 195 | Relname: &cteName, |
| 196 | }, |
| 197 | }, |
| 198 | }, |
| 199 | } |
| 200 | |
| 201 | tempRaw := &ast.RawStmt{Stmt: tempStmt} |
| 202 | tempQuery := ast.Format(tempRaw, e.dialect) |
| 203 | |
| 204 | return e.getColumnNames(ctx, tempQuery) |
| 205 | } |
| 206 | |
| 207 | // expandInsertStmt expands * in an INSERT statement's RETURNING clause |
| 208 | func (e *Expander) expandInsertStmt(ctx context.Context, stmt *ast.InsertStmt) error { |
no test coverage detected