package mgocrud import ( "fmt" "reflect" "runtime/debug" "strings" "time" mgo "gopkg.in/mgo.v2" "gopkg.in/mgo.v2/bson" ) // ErrorWithStack is error which stored its stack trace type ErrorWithStack struct { Message string StackTrace string } // Error returns error message func (e *ErrorWithStack) Error() string { return e.Message } // Stack returns strack trace to error origin func (e *ErrorWithStack) Stack() string { return e.StackTrace } // CreateDocument creates a document from specified model func CreateDocument(db *mgo.Database, m ModelInterface) error { m.PrepareInsert() c := db.C(GetCollectionName(m)) err := c.Insert(m) return err } // ReadDocument gets one document via its id func ReadDocument(db *mgo.Database, m ModelInterface, selector bson.M) error { c := db.C(GetCollectionName(m)) q := c.FindId(m.GetID()) if selector != nil { q = q.Select(selector) } err := q.One(m) return err } // PipelineModifierFunction is a function to modify mongodb query type PipelineModifierFunction func(pipeline []bson.M) []bson.M func convertIDValue(v reflect.Value) reflect.Value { v = reflect.ValueOf(v.Interface()) vKind := v.Kind() // mapVal.Kind() does not work, no idea why ;( //spew.Dump(v.Interface()) //spew.Dump(vKind) switch vKind { case reflect.String: if hex, ok := v.Interface().(string); ok { return reflect.ValueOf(bson.ObjectIdHex(hex)) } case reflect.Map: for _, key := range v.MapKeys() { v.SetMapIndex(key, convertIDValue(v.MapIndex(key))) } case reflect.Slice: for i := 0; i < v.Len(); i++ { v.Index(i).Set(convertIDValue(v.Index(i))) } } return v } func idToObjectID(filter interface{}) { //spew.Dump(filter) val := reflect.ValueOf(filter) switch reflect.TypeOf(filter).Kind() { case reflect.Slice: for i := 0; i < val.Len(); i++ { idToObjectID(val.Index(i).Interface()) } case reflect.Ptr: idToObjectID(reflect.Indirect(val).Interface()) case reflect.Map: for _, key := range val.MapKeys() { if k, ok := key.Interface().(string); ok { mapVal := val.MapIndex(key) if mapVal.Type().Kind() == reflect.Ptr { mapVal = reflect.Indirect(mapVal) } if k == "_id" { val.SetMapIndex(key, convertIDValue(mapVal)) } else { idToObjectID(mapVal.Interface()) } } } } /* for key, data := range filter { if key == "_id" { if hex, ok := data.(string); ok { filter[key] = bson.ObjectIdHex(hex) } } else { switch d := data.(type) { case map[string]interface{}: idToObjectID(d) case bson.M: idToObjectID(d) case []map[string]interface{}: for _, s := range d { idToObjectID(s) } case []bson.M: for _, s := range d { idToObjectID(s) } case []interface{}: for _, s := range d { isToObjectID(s) } } } } */ } // ReadCollection gets the filtered collection of the model func ReadCollection(db *mgo.Database, results interface{}, filter bson.M, selector bson.M, offset int, limit int, sort []string, pipelineModifier PipelineModifierFunction) (err error) { defer func() { if r := recover(); r != nil { err = &ErrorWithStack{Message: fmt.Sprintf("%v", r), StackTrace: string(debug.Stack())} } }() mSlice := reflect.Indirect(reflect.ValueOf(results)) // results is *[]ModelInterface, so we need to indirect var m ModelInterface if mSlice.Len() > 0 { // use existing first element m = mSlice.Index(0).Interface().(ModelInterface) } else { // create new element to get collection name m = reflect.New(reflect.TypeOf(mSlice.Interface()).Elem()).Interface().(ModelInterface) } /* // get pointer to model (element of slice in results) to get collection name m := reflect.New(reflect.TypeOf( reflect.Indirect(reflect.ValueOf(results)).Interface(), // get indirection of slice pointer ).Elem()).Interface().(ModelInterface) // it must be a ModelInterface here */ c := db.C(GetCollectionName(m)) var _err error if pipelineModifier != nil { // search via pipeline pipeline := []bson.M{} if filter != nil { pipeline = append(pipeline, bson.M{ "$match": filter, }) } if len(sort) > 0 { sortM := bson.M{} for _, s := range sort { if strings.HasPrefix(s, "-") { s = s[1:] sortM[s] = -1 } else { sortM[s] = 1 } } // spew.Dump(sortM) pipeline = append(pipeline, bson.M{ "$sort": sortM, }) } if offset > 0 { pipeline = append(pipeline, bson.M{ "$skip": offset, }) } if limit > 0 { pipeline = append(pipeline, bson.M{ "$limit": limit, }) } if selector != nil { pipeline = append(pipeline, bson.M{ "$project": selector, }) } if pipelineModifier != nil { pipeline = pipelineModifier(pipeline) } q := c.Pipe(pipeline).AllowDiskUse().Iter() _err = q.All(results) } else { // search without pipe is faster idToObjectID(filter) // spew.Dump(filter) q := c.Find(filter) if selector != nil { q = q.Select(selector) } if len(sort) > 0 { q = q.Sort(sort...) } if offset > 0 { q = q.Skip(offset) } if limit > 0 { q = q.Limit(limit) } _err = q.All(results) } if _err != nil { return &ErrorWithStack{Message: _err.Error(), StackTrace: string(debug.Stack())} } return nil } // ReadCollectionCount gets the count of elements in filtered collection func ReadCollectionCount(db *mgo.Database, m ModelInterface, filter bson.M) (count int, err error) { defer func() { if r := recover(); r != nil { err = fmt.Errorf("%v", r) } }() c := db.C(GetCollectionName(m)) return c.Find(filter).Count() } // UpdateDocument updates a document from specified model func UpdateDocument(db *mgo.Database, m ModelInterface, changes bson.M) error { m.PrepareUpdate() changes["updateTime"] = time.Now() c := db.C(GetCollectionName(m)) err := c.UpdateId(m.GetID(), bson.M{"$set": changes}) return err } // UpsertDocument updates a document from specified model or inserts it, of not found func UpsertDocument(db *mgo.Database, m ModelInterface, changes bson.M) error { m.PrepareUpdate() changes["updateTime"] = time.Now() c := db.C(GetCollectionName(m)) _, err := c.Upsert(m, bson.M{"$set": changes}) return err } // DeleteDocument deletes one document via its id func DeleteDocument(db *mgo.Database, m ModelInterface) error { c := db.C(GetCollectionName(m)) err := c.RemoveId(m.GetID()) return err } // DeleteDocuments deletes documents found by filter func DeleteDocuments(db *mgo.Database, m ModelInterface, filter bson.M) (removed int, err error) { c := db.C(GetCollectionName(m)) info, err := c.RemoveAll(filter) if info != nil { removed = info.Removed } return removed, err }