If the tensor is one-dimensional, shapeSize is nil.
(pr *pb.PredictRequest, tensorName string, dataType framework.DataType, tensor interface{},
shapeSize []int64, shapeName []string)
| 23 | |
| 24 | // if tensor is one dim, shapeSize is nil |
| 25 | func addInput(pr *pb.PredictRequest, tensorName string, dataType framework.DataType, tensor interface{}, |
| 26 | shapeSize []int64, shapeName []string) (err error) { |
| 27 | v := reflect.ValueOf(tensor) |
| 28 | if v.Kind() != reflect.Slice { |
| 29 | return errors.New("tensor must be slice") |
| 30 | } |
| 31 | size := v.Len() |
| 32 | tp := &framework.TensorProto{ |
| 33 | Dtype: dataType, |
| 34 | } |
| 35 | |
| 36 | var ok bool |
| 37 | switch dataType { |
| 38 | case framework.DataType_DT_HALF: |
| 39 | tp.HalfVal, ok = tensor.([]int32) |
| 40 | case framework.DataType_DT_FLOAT: |
| 41 | tp.FloatVal, ok = tensor.([]float32) |
| 42 | case framework.DataType_DT_DOUBLE: |
| 43 | tp.DoubleVal, ok = tensor.([]float64) |
| 44 | case framework.DataType_DT_INT16, framework.DataType_DT_INT32, |
| 45 | framework.DataType_DT_INT8, framework.DataType_DT_UINT8: |
| 46 | tp.IntVal, ok = tensor.([]int32) |
| 47 | case framework.DataType_DT_STRING: |
| 48 | tp.StringVal, ok = tensor.([][]byte) |
| 49 | case framework.DataType_DT_COMPLEX64: |
| 50 | tp.ScomplexVal, ok = tensor.([]float32) |
| 51 | case framework.DataType_DT_INT64: |
| 52 | tp.Int64Val, ok = tensor.([]int64) |
| 53 | case framework.DataType_DT_BOOL: |
| 54 | tp.BoolVal, ok = tensor.([]bool) |
| 55 | case framework.DataType_DT_COMPLEX128: |
| 56 | tp.DcomplexVal, ok = tensor.([]float64) |
| 57 | case framework.DataType_DT_RESOURCE: |
| 58 | tp.ResourceHandleVal, ok = tensor.([]*framework.ResourceHandle) |
| 59 | default: |
| 60 | err = errors.New("Unknown data type") |
| 61 | } |
| 62 | |
| 63 | if !ok { |
| 64 | if err != nil { |
| 65 | err = errors.New("Type switch failed") |
| 66 | } |
| 67 | return |
| 68 | } |
| 69 | |
| 70 | if shapeSize == nil { |
| 71 | name := "" |
| 72 | if len(shapeName) != 0 { |
| 73 | name = shapeName[0] |
| 74 | } |
| 75 | tp.TensorShape = &framework.TensorShapeProto{ |
| 76 | Dim: []*framework.TensorShapeProto_Dim{ |
| 77 | &framework.TensorShapeProto_Dim{ |
| 78 | Size: int64(size), |
| 79 | Name: name, |
| 80 | }, |
| 81 | }, |
| 82 | } |
no outgoing calls
no test coverage detected