expandInsertStmt expands * in an INSERT statement's RETURNING clause
(ctx context.Context, stmt *ast.InsertStmt)
| 206 | |
| 207 | // expandInsertStmt expands * in an INSERT statement's RETURNING clause |
| 208 | func (e *Expander) expandInsertStmt(ctx context.Context, stmt *ast.InsertStmt) error { |
| 209 | // Expand CTEs first |
| 210 | if stmt.WithClause != nil && stmt.WithClause.Ctes != nil { |
| 211 | for _, cte := range stmt.WithClause.Ctes.Items { |
| 212 | if err := e.expandNode(ctx, cte); err != nil { |
| 213 | return err |
| 214 | } |
| 215 | } |
| 216 | } |
| 217 | |
| 218 | // Expand the SELECT part if present |
| 219 | if stmt.SelectStmt != nil { |
| 220 | if err := e.expandNode(ctx, stmt.SelectStmt); err != nil { |
| 221 | return err |
| 222 | } |
| 223 | } |
| 224 | |
| 225 | // Expand RETURNING clause |
| 226 | if hasStarInList(stmt.ReturningList) { |
| 227 | tempRaw := &ast.RawStmt{Stmt: stmt} |
| 228 | tempQuery := ast.Format(tempRaw, e.dialect) |
| 229 | columns, err := e.getColumnNames(ctx, tempQuery) |
| 230 | if err != nil { |
| 231 | return fmt.Errorf("failed to get column names: %w", err) |
| 232 | } |
| 233 | stmt.ReturningList = rewriteTargetList(stmt.ReturningList, columns) |
| 234 | } |
| 235 | |
| 236 | return nil |
| 237 | } |
| 238 | |
| 239 | // expandUpdateStmt expands * in an UPDATE statement's RETURNING clause |
| 240 | func (e *Expander) expandUpdateStmt(ctx context.Context, stmt *ast.UpdateStmt) error { |
no test coverage detected