| 11 | ) |
| 12 | |
| 13 | func TestApply(t *testing.T) { |
| 14 | p := NewParser() |
| 15 | |
| 16 | input, err := p.Parse(strings.NewReader("SELECT sqlc.arg(name)")) |
| 17 | if err != nil { |
| 18 | t.Fatal(err) |
| 19 | } |
| 20 | output, err := p.Parse(strings.NewReader("SELECT $1")) |
| 21 | if err != nil { |
| 22 | t.Fatal(err) |
| 23 | } |
| 24 | |
| 25 | expect := &output[0] |
| 26 | actual := astutils.Apply(&input[0], func(cr *astutils.Cursor) bool { |
| 27 | fun, ok := cr.Node().(*ast.FuncCall) |
| 28 | if !ok { |
| 29 | return true |
| 30 | } |
| 31 | if astutils.Join(fun.Funcname, ".") == "sqlc.arg" { |
| 32 | cr.Replace(&ast.ParamRef{ |
| 33 | Dollar: true, |
| 34 | Number: 1, |
| 35 | Location: fun.Location, |
| 36 | }) |
| 37 | return false |
| 38 | } |
| 39 | return true |
| 40 | }, nil) |
| 41 | |
| 42 | if diff := cmp.Diff(expect, actual); diff != "" { |
| 43 | t.Errorf("rewrite mismatch:\n%s", diff) |
| 44 | } |
| 45 | } |