package mgoapi

import (
	"fmt"
	"reflect"
	"strings"
	"time"

	"gopkg.in/mgo.v2/bson"
)

// _derefType strips one slice level and then one pointer level from t, so that
// T, *T, []T and []*T all resolve to T. Other types are returned unchanged.
func _derefType(t reflect.Type) reflect.Type {
	if t.Kind() == reflect.Slice {
		t = t.Elem()
	}
	if t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	return t
}

// structFieldFromJSON returns the Go name of the struct field that is encoded
// under the JSON key fieldname in t (t may also be *T, []T or []*T).
//
// The lookup follows the encoding/json naming rules: a field tagged exactly
// `json:"-"` is ignored, unexported non-embedded fields are ignored, a tag name
// takes precedence over the Go name, and the fields of an embedded struct that
// has no JSON name are promoted, with fields declared directly on t winning
// over promoted ones.
//
// NOTE(review): the returned name is resolved again with FieldByName by the
// bson helpers, so a promoted field maps to its own bson key without the
// embedding field's key in front. mgo only flattens embedded structs tagged
// `bson:",inline"` — confirm the embedded models used here are inlined.
func structFieldFromJSON(t reflect.Type, fieldname string) (sfield string, err error) {
	t = _derefType(t)
	if t.Kind() != reflect.Struct {
		return "", fmt.Errorf("cannot find field from json name %s, %+v is no struct", fieldname, t)
	}
	// An empty segment (e.g. from "a." or "a..b") would otherwise match the
	// first untagged field, because that field's empty tag name equals "".
	if fieldname == "" {
		return "", fmt.Errorf("empty json field name in %+v", t)
	}

	var promoted []reflect.StructField
	for i := 0; i < t.NumField(); i++ {
		f := t.Field(i)
		tag := f.Tag.Get("json")
		if tag == "-" {
			continue // explicitly excluded from JSON
		}
		if f.PkgPath != "" && !f.Anonymous {
			continue // unexported field, never encoded
		}
		name := strings.Split(tag, ",")[0]
		if f.Anonymous && name == "" {
			// embedded struct without a JSON name: its fields are promoted,
			// but only searched after the fields declared directly on t
			promoted = append(promoted, f)
		}
		if name == fieldname || (name == "" && f.Name == fieldname) {
			return f.Name, nil
		}
	}
	for _, f := range promoted {
		if inner, innerErr := structFieldFromJSON(f.Type, fieldname); innerErr == nil {
			return inner, nil
		}
	}
	return "", fmt.Errorf("json field %s not found in %+v", fieldname, t)
}

// bsonFieldFromStruct returns the bson key of the Go struct field fieldname in
// t (t may also be *T, []T or []*T): the name from its `bson` tag, or the
// lowercased Go name when the tag gives none. A field tagged exactly
// `bson:"-"` is reported as an error because it is never stored.
func bsonFieldFromStruct(t reflect.Type, fieldname string) (bfield string, err error) {
	t = _derefType(t)
	if t.Kind() != reflect.Struct {
		return "", fmt.Errorf("cannot find struct field from bson name %s, %+v is no struct", fieldname, t)
	}
	f, found := t.FieldByName(fieldname)
	if !found {
		return "", fmt.Errorf("struct field %s not found in %+v", fieldname, t)
	}
	tag := f.Tag.Get("bson")
	if tag == "-" {
		return "", fmt.Errorf("struct field %s from %+v is no bson field", fieldname, t)
	}
	if name := strings.Split(tag, ",")[0]; name != "" {
		return name, nil
	}
	return strings.ToLower(fieldname), nil
}

// typeOfStructFieldByName returns the type of the Go struct field fieldname in
// t (t may also be *T, []T or []*T), with one pointer level removed.
func typeOfStructFieldByName(t reflect.Type, fieldname string) (reflect.Type, error) {
	t = _derefType(t)
	if t.Kind() != reflect.Struct {
		return nil, fmt.Errorf("cannot get field type of %s, %+v is no struct", fieldname, t)
	}
	f, found := t.FieldByName(fieldname)
	if !found {
		return nil, fmt.Errorf("struct field %s not found in %+v", fieldname, t)
	}
	if f.Type.Kind() == reflect.Ptr {
		return f.Type.Elem(), nil
	}
	return f.Type, nil
}

// _asFilterMap reports whether v is a (sub-)filter document and returns it as
// bson.M. Both plain map[string]interface{} (as produced by encoding/json) and
// bson.M are accepted; other map types are not.
func _asFilterMap(v interface{}) (bson.M, bool) {
	switch m := v.(type) {
	case bson.M:
		return m, true
	case map[string]interface{}:
		return m, true
	}
	return nil, false
}

// _checkValue checks that val, a value taken from a filter, is compatible with
// the struct field type fieldT and converts it to the value passed on to mgo.
// fieldT is first unwrapped through *T, []T and []*T.
//
//   - nil (JSON null) is returned as nil;
//   - a string becomes a bson.ObjectId for ObjectId fields (it must be a valid
//     hex id) or a time.Time for time fields (it must be RFC3339);
//   - a float64 becomes an int for fields of type int (JSON numbers decode to
//     float64; any fraction is truncated);
//   - a bson.RegEx is accepted for string fields;
//   - a value of exactly the field type is returned unchanged;
//   - a nested document is validated against the field type with
//     _validateFilter (e.g. {"$gt": 5} for a numeric field).
//
// Anything else yields an error.
func _checkValue(fieldT reflect.Type, val interface{}) (newVal interface{}, err error) {
	// null is a meaningful filter value ({"f": null}, {"$ne": null}); it used
	// to crash on reflect.TypeOf(nil).Kind().
	if val == nil {
		return nil, nil
	}

	fieldTE := fieldT
	if k := fieldTE.Kind(); k == reflect.Ptr || k == reflect.Slice {
		fieldTE = fieldTE.Elem()
	}
	if fieldTE.Kind() == reflect.Ptr {
		fieldTE = fieldTE.Elem()
	}

	switch value := val.(type) {
	case string:
		switch fieldTE {
		case reflect.TypeOf(bson.ObjectId("")):
			// bson.ObjectIdHex panics on malformed input and value comes
			// straight from the filter, so validate it first.
			if !bson.IsObjectIdHex(value) {
				return nil, fmt.Errorf("id string from filter is not compatible with type %s (id must be a 24 digit hex string)", fieldT)
			}
			return bson.ObjectIdHex(value), nil
		case reflect.TypeOf(time.Time{}):
			parsed, parseErr := time.Parse(time.RFC3339, value)
			if parseErr != nil {
				return nil, fmt.Errorf("date string from filter is not compatible with type %s (date must be RFC3339)", fieldT)
			}
			return parsed, nil
		}
	case float64:
		// json is always float64 for numbers
		if fieldTE == reflect.TypeOf(int(0)) {
			return int(value), nil
		}
	case bson.RegEx:
		if fieldTE == reflect.TypeOf("") {
			return value, nil
		}
	}

	valT := reflect.TypeOf(val)
	if fieldTE == valT {
		return val, nil
	}
	if sub, ok := _asFilterMap(val); ok {
		// go into sub object and validate
		validated, subErr := _validateFilter(fieldTE, sub, true)
		if subErr != nil {
			return nil, subErr
		}
		return validated, nil
	}
	return nil, fmt.Errorf("type %s from filter is not compatible with type %s", valT, fieldT)
}

// _checkOperand validates the operand of the filter operator op (e.g. $and,
// $or, $in, $gt) against t, the type the operator applies to. A slice operand
// is validated element by element and returned as []interface{}; any other
// operand (including nil) is validated as a single element.
func _checkOperand(t reflect.Type, op string, val interface{}, checkValues bool) (interface{}, error) {
	rv := reflect.ValueOf(val) // Kind() is Invalid (not Slice) for nil
	if rv.Kind() != reflect.Slice {
		return _checkOperandElem(t, op, val, checkValues)
	}
	out := make([]interface{}, rv.Len())
	for i := range out {
		elem, err := _checkOperandElem(t, op, rv.Index(i).Interface(), checkValues)
		if err != nil {
			return nil, err
		}
		out[i] = elem
	}
	return out, nil
}

// _checkOperandElem validates one operand element of op: a nested filter
// document is validated recursively with _validateFilter, anything else is
// checked and converted by _checkValue against t.
func _checkOperandElem(t reflect.Type, op string, val interface{}, checkValues bool) (interface{}, error) {
	if sub, ok := _asFilterMap(val); ok {
		validated, err := _validateFilter(t, sub, checkValues)
		if err != nil {
			return nil, err
		}
		return validated, nil
	}
	checked, err := _checkValue(t, val)
	if err != nil {
		return nil, fmt.Errorf("%s from %+v for field %s", err.Error(), t, op)
	}
	return checked, nil
}

// _bsonPathFromJSON translates key, a dotted path of JSON field names starting
// at structT (e.g. "author.name"), into the matching dotted bson key path. It
// also returns the type the last segment was looked up in (ownerT) and that
// segment's field type (fieldT, with one pointer level removed).
func _bsonPathFromJSON(structT reflect.Type, key string) (string, reflect.Type, reflect.Type, error) {
	segments := strings.Split(key, ".")
	bsonSegments := make([]string, 0, len(segments))
	ownerT, fieldT := structT, structT
	for _, segment := range segments {
		ownerT = fieldT
		structField, err := structFieldFromJSON(ownerT, segment)
		if err != nil {
			return "", nil, nil, err
		}
		bsonField, err := bsonFieldFromStruct(ownerT, structField)
		if err != nil {
			return "", nil, nil, err
		}
		bsonSegments = append(bsonSegments, bsonField)
		if fieldT, err = typeOfStructFieldByName(ownerT, structField); err != nil {
			return "", nil, nil, err
		}
	}
	return strings.Join(bsonSegments, "."), ownerT, fieldT, nil
}

// _validateFilter validates filter f against the model type structT and
// returns a rewritten copy of it:
//
//   - "$text" is copied unchanged;
//   - any other "$"-prefixed operator key keeps its name, and its operand is
//     validated against structT itself (see _checkOperand);
//   - every other key is a dotted path of JSON field names and is rewritten to
//     the matching bson path; when checkValues is true its value is also
//     checked and converted with _checkValue against the field's type.
//
// An empty key, an unknown field or an incompatible value yields an error.
func _validateFilter(structT reflect.Type, f bson.M, checkValues bool) (bson.M, error) {
	newF := make(bson.M, len(f))
	for key, val := range f {
		switch {
		case key == "":
			return nil, fmt.Errorf("empty key in filter for model %+v", structT)
		case key == "$text":
			// don't modify or check full text search
			newF[key] = val
		case strings.HasPrefix(key, "$"):
			// filter operator, not a field name
			newVal, err := _checkOperand(structT, key, val, checkValues)
			if err != nil {
				return nil, err
			}
			newF[key] = newVal
		default:
			bsonKey, ownerT, fieldT, err := _bsonPathFromJSON(structT, key)
			if err != nil {
				return nil, err
			}
			newVal := val
			if checkValues {
				if newVal, err = _checkValue(fieldT, val); err != nil {
					return nil, fmt.Errorf("%s from %+v for field %s", err.Error(), ownerT, bsonKey)
				}
			}
			newF[bsonKey] = newVal
		}
	}
	return newF, nil
}