mgocrud/crud.go

299 lines
6.5 KiB
Go

package mgocrud
import (
"fmt"
"reflect"
"runtime/debug"
"strings"
"time"
"gopkg.in/mgo.v2/bson"
)
// ErrorWithStack is an error that also stores a stack trace, as filled in by
// whoever creates it (this package uses runtime/debug.Stack for that).
type ErrorWithStack struct {
	Message    string // text reported by Error
	StackTrace string // text reported by Stack
}

// Compile-time check that *ErrorWithStack satisfies the error interface.
var _ error = (*ErrorWithStack)(nil)

// Error implements the error interface by returning the stored message.
func (e *ErrorWithStack) Error() string { return e.Message }

// Stack returns the stack trace recorded for this error.
func (e *ErrorWithStack) Stack() string { return e.StackTrace }
// CreateDocument inserts m as a new document into the collection named by
// GetCollectionName(m). The model's PrepareInsert hook runs before the insert.
func CreateDocument(db *Database, m ModelInterface) error {
	m.PrepareInsert()
	return db.C(GetCollectionName(m)).Insert(m)
}
// ReadDocument loads one document, looked up by m.GetID() via FindId, into m.
// When selector is non-nil it is used as the field projection.
func ReadDocument(db *Database, m ModelInterface, selector bson.M) error {
	query := db.C(GetCollectionName(m)).FindId(m.GetID())
	if selector == nil {
		return query.One(m)
	}
	return query.Select(selector).One(m)
}
// PipelineModifierFunction receives the aggregation pipeline built by
// ReadCollection and returns the pipeline that is actually executed.
type PipelineModifierFunction func([]bson.M) []bson.M
// convertIDValue returns v with every plain string converted to a
// bson.ObjectId via bson.ObjectIdHex. Maps and slices are walked recursively
// and their elements are replaced in place.
//
// v is first re-wrapped through Interface(): values read from a
// map[string]interface{} (such as bson.M) via MapIndex have Kind Interface,
// and only the re-wrapped value exposes the dynamic kind (string/map/slice).
//
// Values that already are a bson.ObjectId are left untouched: their Kind is
// String, but the type assertion to plain string fails for the named type.
//
// Nil interface values (e.g. {"_id": nil}) are returned unchanged. Returning
// the zero reflect.Value instead would make SetMapIndex delete the key and
// Value.Set panic.
//
// bson.ObjectIdHex panics on a string that is not a valid hex ObjectId; when
// reached from ReadCollection that panic is recovered into an error.
// NOTE(review): elements of typed containers such as []string cannot hold a
// bson.ObjectId, so replacing them makes reflect panic.
func convertIDValue(v reflect.Value) reflect.Value {
	if !v.IsValid() {
		return v
	}
	inner := reflect.ValueOf(v.Interface())
	if !inner.IsValid() {
		// v wraps a nil interface: keep it so callers write nil back, not "delete".
		return v
	}
	switch inner.Kind() {
	case reflect.String:
		if hex, ok := inner.Interface().(string); ok {
			return reflect.ValueOf(bson.ObjectIdHex(hex))
		}
	case reflect.Map:
		for _, key := range inner.MapKeys() {
			inner.SetMapIndex(key, convertIDValue(inner.MapIndex(key)))
		}
	case reflect.Slice:
		for i := 0; i < inner.Len(); i++ {
			elem := inner.Index(i)
			elem.Set(convertIDValue(elem))
		}
	}
	return inner
}
// idToObjectID walks filter recursively. For every map entry keyed "_id", it
// replaces the value with its ObjectId-converted form (see convertIDValue).
//
// Maps are modified in place, so the caller's filter is changed. Slices and
// pointers are followed to reach nested maps; entries under other keys are
// searched recursively.
//
// A nil filter, nil pointers and nil map values are skipped instead of
// panicking. An "_id" entry whose conversion yields no value is left in place,
// because SetMapIndex with a zero reflect.Value would delete the key and thus
// widen the query.
func idToObjectID(filter interface{}) {
	if filter == nil {
		return
	}
	val := reflect.ValueOf(filter)
	switch val.Kind() {
	case reflect.Slice:
		for i := 0; i < val.Len(); i++ {
			idToObjectID(val.Index(i).Interface())
		}
	case reflect.Ptr:
		if val.IsNil() {
			return
		}
		idToObjectID(val.Elem().Interface())
	case reflect.Map:
		for _, key := range val.MapKeys() {
			k, ok := key.Interface().(string)
			if !ok {
				continue
			}
			mapVal := val.MapIndex(key)
			if mapVal.Kind() == reflect.Ptr {
				if mapVal.IsNil() {
					continue
				}
				mapVal = mapVal.Elem()
			}
			if k != "_id" {
				idToObjectID(mapVal.Interface())
				continue
			}
			if converted := convertIDValue(mapVal); converted.IsValid() {
				val.SetMapIndex(key, converted)
			}
		}
	}
}
// ReadCollection loads the documents matching filter into results.
//
// results must be a pointer to a slice of models. Its first element (when the
// slice is non-empty) or a freshly allocated element determines the collection
// via GetCollectionName.
//
// When pipelineModifier is nil, a plain Find query is run. String "_id" values
// in filter are first converted to ObjectIds in place (idToObjectID); then
// selector, sort, offset and limit are applied.
//
// Otherwise an aggregation pipeline is built. It contains $match, $sort,
// $skip, $limit and $project, in that order, each only when set. The pipeline
// is passed through pipelineModifier and then executed.
//
// sort entries are field names, optionally prefixed with "-" for descending
// or "+" for ascending order. offset and limit are ignored when <= 0.
//
// Query errors and any panic raised while building or running the query are
// returned as *ErrorWithStack.
func ReadCollection(db *Database, results interface{}, filter bson.M, selector bson.M, offset int, limit int, sort []string, pipelineModifier PipelineModifierFunction) (err error) {
	defer func() {
		if r := recover(); r != nil {
			err = &ErrorWithStack{Message: fmt.Sprintf("%v", r), StackTrace: string(debug.Stack())}
		}
	}()

	c := db.C(GetCollectionName(collectionModelOf(results)))

	var queryErr error
	if pipelineModifier != nil {
		// NOTE(review): unlike the Find branch, "_id" strings in filter are not
		// converted to ObjectIds here — confirm this asymmetry is intended.
		pipeline := pipelineModifier(readPipeline(filter, selector, offset, limit, sort))
		queryErr = c.Pipe(pipeline).All(results)
	} else {
		// search without pipe is faster
		idToObjectID(filter)
		q := c.Find(filter)
		if selector != nil {
			q = q.Select(selector)
		}
		if len(sort) > 0 {
			q = q.Sort(sort...)
		}
		if offset > 0 {
			q = q.Skip(offset)
		}
		if limit > 0 {
			q = q.Limit(limit)
		}
		queryErr = q.All(results)
	}
	if queryErr != nil {
		return &ErrorWithStack{Message: queryErr.Error(), StackTrace: string(debug.Stack())}
	}
	return nil
}

// collectionModelOf returns a model used to resolve the collection name for
// results, which must be a pointer to a slice. The first element is used when
// present; otherwise a new element is allocated. For pointer elements (e.g.
// []*Model) the pointed-to type is allocated, because reflect.New(*Model)
// would yield a **Model, which has an empty method set.
func collectionModelOf(results interface{}) ModelInterface {
	slice := reflect.Indirect(reflect.ValueOf(results))
	if slice.Len() > 0 {
		return slice.Index(0).Interface().(ModelInterface)
	}
	elemType := slice.Type().Elem()
	if elemType.Kind() == reflect.Ptr {
		elemType = elemType.Elem()
	}
	return reflect.New(elemType).Interface().(ModelInterface)
}

// readPipeline builds the aggregation stages used by ReadCollection:
// $match (filter), $sort, $skip (offset), $limit and $project (selector), in
// that order, each included only when the corresponding argument is set.
func readPipeline(filter bson.M, selector bson.M, offset int, limit int, sort []string) []bson.M {
	pipeline := []bson.M{}
	if filter != nil {
		pipeline = append(pipeline, bson.M{"$match": filter})
	}
	if len(sort) > 0 {
		// $sort must be an ordered document: a bson.M would be marshalled in
		// random map order, making multi-field sorts nondeterministic.
		sortDoc := make(bson.D, 0, len(sort))
		for _, field := range sort {
			direction := 1
			switch {
			case strings.HasPrefix(field, "-"):
				direction = -1
				field = field[1:]
			case strings.HasPrefix(field, "+"):
				field = field[1:]
			}
			sortDoc = append(sortDoc, bson.DocElem{Name: field, Value: direction})
		}
		pipeline = append(pipeline, bson.M{"$sort": sortDoc})
	}
	if offset > 0 {
		pipeline = append(pipeline, bson.M{"$skip": offset})
	}
	if limit > 0 {
		pipeline = append(pipeline, bson.M{"$limit": limit})
	}
	if selector != nil {
		pipeline = append(pipeline, bson.M{"$project": selector})
	}
	return pipeline
}
// ReadCollectionCount returns the number of documents in m's collection that
// match filter.
//
// Like the Find branch of ReadCollection, string "_id" values in filter are
// first converted to ObjectIds in place (idToObjectID). This makes a read and
// a count issued with the same filter agree regardless of call order.
//
// Panics (e.g. an invalid ObjectId hex string) are recovered and returned as
// *ErrorWithStack, matching ReadCollection.
func ReadCollectionCount(db *Database, m ModelInterface, filter bson.M) (count int, err error) {
	defer func() {
		if r := recover(); r != nil {
			err = &ErrorWithStack{Message: fmt.Sprintf("%v", r), StackTrace: string(debug.Stack())}
		}
	}()
	idToObjectID(filter)
	c := db.C(GetCollectionName(m))
	return c.Find(filter).Count()
}
// UpdateDocument applies changes with $set to the document whose id is
// m.GetID(). m.PrepareUpdate() is called first.
//
// changes["updateTime"] is set to the current time; this writes into the
// caller's map. A nil changes map is allowed (writing to it would panic) and
// results in only updateTime being set.
func UpdateDocument(db *Database, m ModelInterface, changes bson.M) error {
	m.PrepareUpdate()
	if changes == nil {
		changes = bson.M{}
	}
	changes["updateTime"] = time.Now()
	c := db.C(GetCollectionName(m))
	return c.UpdateId(m.GetID(), bson.M{"$set": changes})
}
// UpsertDocument updates the document selected by m, or inserts it if not
// found. m.PrepareUpdate() is called first.
//
// changes["updateTime"] is set to the current time; this writes into the
// caller's map, and a nil map is replaced by a fresh one instead of panicking.
// changes is then applied with $set.
//
// NOTE(review): m itself, not m.GetID(), is passed as the upsert selector, so
// matching presumably depends on every marshalled field of m equalling the
// stored document. If matching by id was intended, UpsertId(m.GetID(), …)
// would be the call — confirm against callers before changing.
func UpsertDocument(db *Database, m ModelInterface, changes bson.M) error {
	m.PrepareUpdate()
	if changes == nil {
		changes = bson.M{}
	}
	changes["updateTime"] = time.Now()
	c := db.C(GetCollectionName(m))
	_, err := c.Upsert(m, bson.M{"$set": changes})
	return err
}
// DeleteDocument removes the document whose id is m.GetID() from the
// collection named by GetCollectionName(m).
func DeleteDocument(db *Database, m ModelInterface) error {
	return db.C(GetCollectionName(m)).RemoveId(m.GetID())
}
// DeleteDocuments removes every document in m's collection that matches
// filter. It reports how many were removed (0 when no change info is
// returned) together with any error from the removal.
func DeleteDocuments(db *Database, m ModelInterface, filter bson.M) (removed int, err error) {
	info, err := db.C(GetCollectionName(m)).RemoveAll(filter)
	if info == nil {
		return 0, err
	}
	return info.Removed(), err
}