// File: mgoapi/filter.go (256 lines, 6.0 KiB, Go).

package mgoapi
import (
"fmt"
"reflect"
"strings"
"time"
"gopkg.in/mgo.v2/bson"
)
// structFieldFromJSON returns the Go name of the struct field that is serialized
// under the JSON name fieldname.
//
// t may be a struct, a pointer to a struct, a slice of structs or a slice of
// pointers to structs (one slice level, then one pointer level, is unwrapped).
// A field matches when its json tag name equals fieldname or, for a field without
// a json tag name, when its Go name equals fieldname.
//
// These fields are never matched because encoding/json never emits them:
//   - fields tagged `json:"-"`, including everything promoted from an embedded
//     struct tagged that way;
//   - unexported, non-embedded fields.
//
// Embedded structs are searched recursively, in field order, before the embedded
// field itself is compared. An error is returned when t is not a struct (after
// unwrapping) or no field matches. An empty fieldname never matches.
func structFieldFromJSON(t reflect.Type, fieldname string) (sfield string, err error) {
	if t.Kind() == reflect.Slice {
		t = t.Elem()
	}
	if t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	if t.Kind() != reflect.Struct {
		return "", fmt.Errorf("cannot find field from json name %s, %+v is no struct", fieldname, t)
	}
	for i := 0; i < t.NumField(); i++ {
		f := t.Field(i)
		jsonName := strings.Split(f.Tag.Get("json"), ",")[0]
		if jsonName == "-" {
			// Hidden from JSON; this also hides fields promoted from an embedded struct.
			continue
		}
		if f.Anonymous {
			// Embedded struct: its fields are promoted into this one.
			if embedded, embErr := structFieldFromJSON(f.Type, fieldname); embErr == nil {
				return embedded, nil
			}
		} else if f.PkgPath != "" {
			// Unexported field: not serialized by encoding/json (nor by mgo/bson).
			continue
		}
		if jsonName == "" {
			// No tag name: JSON uses the Go field name. Go field names are never
			// empty, so an empty fieldname cannot match here.
			jsonName = f.Name
		}
		if jsonName == fieldname {
			return f.Name, nil
		}
	}
	return "", fmt.Errorf("json field %s not found in %+v", fieldname, t)
}
// bsonFieldFromStruct returns the BSON document key for the Go struct field
// fieldname of t.
//
// t may be a struct, *struct, []struct or []*struct (one slice level, then one
// pointer level, is unwrapped). The field is located with reflect's FieldByName,
// so promoted fields of embedded structs are found as well.
//
// The key is the name part of the field's `bson` tag. When the tag has no name,
// the key is the lowercased Go field name, mirroring mgo/bson's default. An error
// is returned when t is not a struct, when the field does not exist, or when the
// field is tagged `bson:"-"`.
func bsonFieldFromStruct(t reflect.Type, fieldname string) (bfield string, err error) {
	owner := t
	if owner.Kind() == reflect.Slice {
		owner = owner.Elem()
	}
	if owner.Kind() == reflect.Ptr {
		owner = owner.Elem()
	}
	if owner.Kind() != reflect.Struct {
		return "", fmt.Errorf("cannot find struct field from bson name %s, %+v is no struct", fieldname, owner)
	}

	field, found := owner.FieldByName(fieldname)
	if !found {
		return "", fmt.Errorf("struct field %s not found in %+v", fieldname, owner)
	}

	switch key := strings.Split(field.Tag.Get("bson"), ",")[0]; key {
	case "-":
		return "", fmt.Errorf("struct field %s from %+v is no bson field", fieldname, owner)
	case "":
		return strings.ToLower(fieldname), nil
	default:
		return key, nil
	}
}
// typeOfStructFieldByName returns the Go type of the struct field fieldname of t.
//
// t may be a struct, *struct, []struct or []*struct (one slice level, then one
// pointer level, is unwrapped). The field is located with reflect's FieldByName,
// so promoted fields of embedded structs are found as well. A pointer-typed field
// is reported as its element type (e.g. *Address yields Address). An error is
// returned when t is not a struct or the field does not exist.
func typeOfStructFieldByName(t reflect.Type, fieldname string) (reflect.Type, error) {
	owner := t
	if owner.Kind() == reflect.Slice {
		owner = owner.Elem()
	}
	if owner.Kind() == reflect.Ptr {
		owner = owner.Elem()
	}
	if owner.Kind() != reflect.Struct {
		return nil, fmt.Errorf("cannot get field type of %s, %+v is no struct", fieldname, owner)
	}

	field, found := owner.FieldByName(fieldname)
	if !found {
		return nil, fmt.Errorf("struct field %s not found in %+v", fieldname, owner)
	}
	if ft := field.Type; ft.Kind() != reflect.Ptr {
		return ft, nil
	}
	return field.Type.Elem(), nil
}
// _checkValue checks a filter value (as decoded from JSON) against the Go type
// fieldT of a model field, and converts it to the representation that field uses.
//
// fieldT is unwrapped by one pointer or slice level, then one further pointer
// level, so T, *T, []T and []*T are all checked against T. Conversions:
//   - string  -> bson.ObjectId for ObjectId fields (must be a valid ObjectId hex string)
//   - string  -> time.Time for time.Time fields (must be RFC3339)
//   - float64 -> int for int fields (must be integral; JSON numbers decode as float64)
//   - bson.RegEx is accepted as-is for string fields
//
// A nil value (JSON null) is returned unchanged. A value whose type already equals
// the unwrapped field type is returned unchanged. Otherwise, a nested object
// (map[string]interface{} or bson.M), e.g. {"$gt": ...}, is validated recursively
// as a filter on the unwrapped field type with value checking enabled. Anything
// else yields an error.
func _checkValue(fieldT reflect.Type, val interface{}) (newVal interface{}, err error) {
	if val == nil {
		// JSON null: nothing to convert, and reflect.TypeOf(nil) has no Kind to inspect.
		return nil, nil
	}
	fieldTE := fieldT
	if fieldTE.Kind() == reflect.Ptr || fieldTE.Kind() == reflect.Slice {
		fieldTE = fieldTE.Elem()
	}
	if fieldTE.Kind() == reflect.Ptr {
		fieldTE = fieldTE.Elem()
	}

	switch value := val.(type) {
	case string:
		switch fieldTE {
		case reflect.TypeOf(bson.ObjectId("")): // struct field is objectid
			// bson.ObjectIdHex panics on malformed input, and filter strings come from clients.
			if !bson.IsObjectIdHex(value) {
				return nil, fmt.Errorf("string %q from filter is not a valid ObjectId for type %s", value, fieldT)
			}
			return bson.ObjectIdHex(value), nil
		case reflect.TypeOf(time.Time{}): // struct field is time
			parsed, perr := time.Parse(time.RFC3339, value)
			if perr != nil {
				return nil, fmt.Errorf("date string from filter is not compatible with type %s (date must be RFC3339)", fieldT)
			}
			return parsed, nil
		}
	case float64:
		if fieldTE == reflect.TypeOf(int(0)) {
			// JSON numbers are always float64. Truncating 1.5 to 1 would silently
			// change the meaning of operators such as $lt, so require an integer.
			asInt := int(value)
			if float64(asInt) != value {
				return nil, fmt.Errorf("number %v from filter is not compatible with type %s (must be an integer)", value, fieldT)
			}
			return asInt, nil
		}
	case bson.RegEx:
		if fieldTE == reflect.TypeOf("") {
			// regex match on a string field: leave as is
			return value, nil
		}
	}

	valT := reflect.TypeOf(val)
	if valT == fieldTE {
		return val, nil
	}

	// Sub-object (operator document or embedded document): validate it against the field type.
	var sub bson.M
	switch m := val.(type) {
	case map[string]interface{}:
		sub = m
	case bson.M:
		sub = m
	default:
		return nil, fmt.Errorf("type %s from filter is not compatible with type %s", valT, fieldT)
	}
	validated, verr := _validateFilter(fieldTE, sub, true)
	if verr != nil {
		return nil, verr
	}
	return validated, nil
}
// _validateFilter translates a MongoDB filter written with the JSON field names of
// the model type structT into the same filter using the model's BSON field names.
// It also checks and converts filter values against the Go field types.
// A new filter is returned; f is not modified.
//
// Keys are handled as follows:
//   - "" is rejected with an error.
//   - "$text" is copied through untouched.
//   - Any other "$"-prefixed key is an operator ($and, $or, $gt, $in, ...). Its
//     operand, or each element of an operand array, is validated against structT:
//     sub-documents recursively as filters, JSON null passed through, other values
//     via _checkValue.
//   - Every other key is a dotted path of JSON field names (e.g. "address.city").
//     Each segment is resolved to its BSON name. When checkValues is true, the value
//     is checked against the type of the final field.
//
// NOTE(review): when checkValues is false, field values are passed through
// unchanged, including nested operator documents such as {"$gt": ...}. Operator
// operands are always checked. Confirm this asymmetry is what callers expect.
func _validateFilter(structT reflect.Type, f bson.M, checkValues bool) (bson.M, error) {
	newF := make(bson.M, len(f))
	for key, val := range f {
		var (
			newKey string
			newVal interface{}
			err    error
		)
		switch {
		case key == "":
			return nil, fmt.Errorf("empty key in filter for model %+v", structT)
		case key == "$text":
			// full-text search is neither renamed nor checked
			newKey, newVal = key, val
		case strings.HasPrefix(key, "$"):
			// operator, not a field name
			newKey = key
			newVal, err = validateOperatorValue(structT, key, val, checkValues)
		default:
			newKey, newVal, err = translateFieldPath(structT, key, val, checkValues)
		}
		if err != nil {
			return nil, err
		}
		newF[newKey] = newVal
	}
	return newF, nil
}

// validateOperatorValue validates the operand of operator key. An operand array
// (as used by $and, $or, $in, ...) is validated element by element into a new slice.
func validateOperatorValue(structT reflect.Type, key string, val interface{}, checkValues bool) (interface{}, error) {
	list, isList := val.([]interface{})
	if !isList {
		return validateOperand(structT, key, val, checkValues)
	}
	newList := make([]interface{}, len(list))
	for i, elem := range list {
		newElem, err := validateOperand(structT, key, elem, checkValues)
		if err != nil {
			return nil, err
		}
		newList[i] = newElem
	}
	return newList, nil
}

// validateOperand validates a single operator operand against structT.
// Sub-documents are validated recursively as filters, JSON null is passed through,
// and any other value is checked with _checkValue.
func validateOperand(structT reflect.Type, key string, val interface{}, checkValues bool) (interface{}, error) {
	if sub, ok := asFilter(val); ok {
		newSub, err := _validateFilter(structT, sub, checkValues)
		if err != nil {
			return nil, err
		}
		return newSub, nil
	}
	if val == nil {
		return nil, nil
	}
	newVal, err := _checkValue(structT, val)
	if err != nil {
		return nil, fmt.Errorf("%s from %+v for field %s", err.Error(), structT, key)
	}
	return newVal, nil
}

// asFilter reports whether val is a sub-document and returns it as a bson.M.
// Both plain JSON objects (map[string]interface{}) and bson.M are accepted.
func asFilter(val interface{}) (bson.M, bool) {
	switch m := val.(type) {
	case map[string]interface{}:
		return bson.M(m), true
	case bson.M:
		return m, true
	}
	return nil, false
}

// translateFieldPath resolves the dotted JSON field path key on structT to the
// corresponding dotted BSON path. When checkValues is true and val is not nil,
// val is checked against the type of the final field.
func translateFieldPath(structT reflect.Type, key string, val interface{}, checkValues bool) (string, interface{}, error) {
	segments := strings.Split(key, ".")
	bsonSegments := make([]string, len(segments))
	parentT := structT
	var fieldT reflect.Type
	for i, segment := range segments {
		if i > 0 {
			parentT = fieldT // descend into the previous segment's type
		}
		structField, err := structFieldFromJSON(parentT, segment)
		if err != nil {
			return "", nil, err
		}
		if bsonSegments[i], err = bsonFieldFromStruct(parentT, structField); err != nil {
			return "", nil, err
		}
		if fieldT, err = typeOfStructFieldByName(parentT, structField); err != nil {
			return "", nil, err
		}
	}
	newKey := strings.Join(bsonSegments, ".")
	if !checkValues || val == nil {
		// JSON null needs no conversion (it matches null or missing fields in MongoDB).
		return newKey, val, nil
	}
	newVal, err := _checkValue(fieldT, val)
	if err != nil {
		return "", nil, fmt.Errorf("%s from %+v for field %s", err.Error(), parentT, newKey)
	}
	return newKey, newVal, nil
}