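insertLastIDs invokes the batch insert query on the transaction and assigns the ID of each inserted entity to its node: read from the RETURNING rows, or derived from LAST_INSERT_ID on MySQL (which does not support RETURNING). It returns only an error.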
(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder)
| 1950 | |
| 1951 | // insertLastIDs invokes the batch insert query on the transaction and returns the LastInsertID of all entities. |
| 1952 | func (c *batchCreator) insertLastIDs(ctx context.Context, tx dialect.ExecQuerier, insert *sql.InsertBuilder) error { |
| 1953 | query, args, err := insert.QueryErr() |
| 1954 | if err != nil { |
| 1955 | return err |
| 1956 | } |
| 1957 | // MySQL does not support the "RETURNING" clause. |
| 1958 | if insert.Dialect() != dialect.MySQL { |
| 1959 | rows := &sql.Rows{} |
| 1960 | if err := tx.Query(ctx, query, args, rows); err != nil { |
| 1961 | return err |
| 1962 | } |
| 1963 | defer rows.Close() |
| 1964 | for i := 0; rows.Next(); i++ { |
| 1965 | node := c.Nodes[i] |
| 1966 | switch _, ok := node.ID.Value.(field.ValueScanner); { |
| 1967 | case ok: |
| 1968 | // If the ID implements the sql.Scanner |
| 1969 | // interface it should be a pointer type. |
| 1970 | if err := rows.Scan(node.ID.Value); err != nil { |
| 1971 | return err |
| 1972 | } |
| 1973 | case node.ID.Type.Numeric(): |
| 1974 | // Normalize the type to int64 to make it looks |
| 1975 | // like LastInsertId. |
| 1976 | var id int64 |
| 1977 | if err := rows.Scan(&id); err != nil { |
| 1978 | return err |
| 1979 | } |
| 1980 | node.ID.Value = id |
| 1981 | default: |
| 1982 | if err := rows.Scan(&node.ID.Value); err != nil { |
| 1983 | return err |
| 1984 | } |
| 1985 | } |
| 1986 | } |
| 1987 | return rows.Err() |
| 1988 | } |
| 1989 | // MySQL. |
| 1990 | var res sql.Result |
| 1991 | if err := tx.Exec(ctx, query, args, &res); err != nil { |
| 1992 | return err |
| 1993 | } |
| 1994 | // If the ID field is not numeric (e.g. string), |
| 1995 | // there is no way to scan the LAST_INSERT_ID. |
| 1996 | if len(c.Nodes) > 0 && c.Nodes[0].ID.Type.Numeric() { |
| 1997 | id, err := res.LastInsertId() |
| 1998 | if err != nil { |
| 1999 | return err |
| 2000 | } |
| 2001 | affected, err := res.RowsAffected() |
| 2002 | if err != nil { |
| 2003 | return err |
| 2004 | } |
| 2005 | // Assume the ID field is AUTO_INCREMENT |
| 2006 | // if its type is numeric. |
| 2007 | for i := 0; int64(i) < affected && i < len(c.Nodes); i++ { |
| 2008 | c.Nodes[i].ID.Value = id + int64(i) |
| 2009 | } |
no test coverage detected