From b0af52178cac1f4df28629ce1ad1e9ad5a6b2c9d Mon Sep 17 00:00:00 2001 From: Sebastian Frank <sebastian@webmakers.de> Date: Wed, 9 Feb 2022 15:11:15 +0100 Subject: [PATCH] db via interfaces --- crud.go | 22 ++++--- lookup.go | 5 +- session.go | 159 +++++++++++++++++++++++++++++++++++++++++++++++++++ setup.go | 14 ++--- validator.go | 5 +- 5 files changed, 180 insertions(+), 25 deletions(-) create mode 100644 session.go diff --git a/crud.go b/crud.go index 1848df4..a7f7ac4 100644 --- a/crud.go +++ b/crud.go @@ -7,7 +7,6 @@ import ( "strings" "time" - mgo "gopkg.in/mgo.v2" "gopkg.in/mgo.v2/bson" ) @@ -28,7 +27,7 @@ func (e *ErrorWithStack) Stack() string { } // CreateDocument creates a document from specified model -func CreateDocument(db *mgo.Database, m ModelInterface) error { +func CreateDocument(db *Database, m ModelInterface) error { m.PrepareInsert() c := db.C(GetCollectionName(m)) @@ -38,7 +37,7 @@ func CreateDocument(db *mgo.Database, m ModelInterface) error { } // ReadDocument gets one document via its id -func ReadDocument(db *mgo.Database, m ModelInterface, selector bson.M) error { +func ReadDocument(db *Database, m ModelInterface, selector bson.M) error { c := db.C(GetCollectionName(m)) q := c.FindId(m.GetID()) @@ -137,7 +136,7 @@ func idToObjectID(filter interface{}) { } // ReadCollection gets the filtered collection of the model -func ReadCollection(db *mgo.Database, results interface{}, filter bson.M, selector bson.M, offset int, limit int, sort []string, pipelineModifier PipelineModifierFunction) (err error) { +func ReadCollection(db *Database, results interface{}, filter bson.M, selector bson.M, offset int, limit int, sort []string, pipelineModifier PipelineModifierFunction) (err error) { defer func() { if r := recover(); r != nil { err = &ErrorWithStack{Message: fmt.Sprintf("%v", r), StackTrace: string(debug.Stack())} @@ -212,8 +211,7 @@ func ReadCollection(db *mgo.Database, results interface{}, filter bson.M, select pipeline = pipelineModifier(pipeline) } - q := c.Pipe(pipeline).AllowDiskUse().Iter() - _err = q.All(results) + _err = c.Pipe(pipeline).All(results) } else { // search without pipe is faster @@ -245,7 +243,7 @@ func ReadCollection(db *mgo.Database, results interface{}, filter bson.M, select } // ReadCollectionCount gets the count of elements in filtered collection -func ReadCollectionCount(db *mgo.Database, m ModelInterface, filter bson.M) (count int, err error) { +func ReadCollectionCount(db *Database, m ModelInterface, filter bson.M) (count int, err error) { defer func() { if r := recover(); r != nil { err = fmt.Errorf("%v", r) @@ -257,7 +255,7 @@ func ReadCollectionCount(db *mgo.Database, m ModelInterface, filter bson.M) (cou } // UpdateDocument updates a document from specified model -func UpdateDocument(db *mgo.Database, m ModelInterface, changes bson.M) error { +func UpdateDocument(db *Database, m ModelInterface, changes bson.M) error { m.PrepareUpdate() changes["updateTime"] = time.Now() @@ -268,7 +266,7 @@ func UpdateDocument(db *mgo.Database, m ModelInterface, changes bson.M) error { } // UpsertDocument updates a document from specified model or inserts it, of not found -func UpsertDocument(db *mgo.Database, m ModelInterface, changes bson.M) error { +func UpsertDocument(db *Database, m ModelInterface, changes bson.M) error { m.PrepareUpdate() changes["updateTime"] = time.Now() @@ -279,7 +277,7 @@ func UpsertDocument(db *mgo.Database, m ModelInterface, changes bson.M) error { } // DeleteDocument deletes one document via its id -func DeleteDocument(db *mgo.Database, m ModelInterface) error { +func DeleteDocument(db *Database, m ModelInterface) error { c := db.C(GetCollectionName(m)) err := c.RemoveId(m.GetID()) @@ -288,12 +286,12 @@ func DeleteDocument(db *mgo.Database, m ModelInterface) error { } // DeleteDocuments deletes documents found by filter -func DeleteDocuments(db *mgo.Database, m ModelInterface, filter bson.M) (removed int, err error) { +func DeleteDocuments(db *Database, m ModelInterface, filter bson.M) (removed int, err error) { c := db.C(GetCollectionName(m)) info, err := c.RemoveAll(filter) if info != nil { - removed = info.Removed + removed = info.Removed() } return removed, err diff --git a/lookup.go b/lookup.go index dfc5f4d..a118993 100644 --- a/lookup.go +++ b/lookup.go @@ -4,12 +4,11 @@ import ( "fmt" "reflect" - mgo "gopkg.in/mgo.v2" "gopkg.in/mgo.v2/bson" ) // Lookup extends results with data for inline structs -func Lookup(db *mgo.Database, structField string, results interface{}, selector bson.M) error { +func Lookup(db *Database, structField string, results interface{}, selector bson.M) error { t := reflect.TypeOf(results) v := reflect.ValueOf(results) @@ -94,7 +93,7 @@ func Lookup(db *mgo.Database, structField string, results interface{}, selector // no entries to map return nil } - sArr := make([]bson.M, lArr, lArr) + sArr := make([]bson.M, lArr) aI := 0 for sID := range objectIDs { sArr[aI] = bson.M{ diff --git a/session.go b/session.go new file mode 100644 index 0000000..cf268a1 --- /dev/null +++ b/session.go @@ -0,0 +1,159 @@ +package mgocrud + +import ( + "errors" + + mgo "gopkg.in/mgo.v2" +) + +type Session struct { + session *mgo.Session +} + +func (s *Session) Close() { + s.session.Close() +} + +func NewSession(dial string) (*Session, error) { + session, err := mgo.Dial(dial) + if err != nil { + return nil, err + } + session.SetMode(mgo.Monotonic, true) + return &Session{session: session}, nil +} + +type Database struct { + database *mgo.Database +} + +func (s *Session) DB(name string) *Database { + return &Database{database: s.session.DB(name)} +} + +type Collection struct { + collection *mgo.Collection +} + +func (db *Database) C(name string) *Collection { + return &Collection{collection: db.database.C(name)} +} + +func (db *Database) Name() string { + return db.database.Name +} + +func (c *Collection) Insert(docs ...interface{}) error { + return c.collection.Insert(docs...) +} + +func (c *Collection) UpdateId(id interface{}, update interface{}) error { + return c.collection.UpdateId(id, update) +} + +func (c *Collection) RemoveId(id interface{}) error { + return c.collection.RemoveId(id) +} + +type ChangeInfo struct { + changeInfo *mgo.ChangeInfo +} + +func (ci *ChangeInfo) Matched() int { + return ci.changeInfo.Matched +} + +func (ci *ChangeInfo) Removed() int { + return ci.changeInfo.Removed +} + +func (ci *ChangeInfo) Updated() int { + return ci.changeInfo.Updated +} + +func (c *Collection) Upsert(selector interface{}, update interface{}) (*ChangeInfo, error) { + ci, err := c.collection.Upsert(selector, update) + if err != nil { + return nil, err + } + return &ChangeInfo{changeInfo: ci}, nil +} + +func (c *Collection) RemoveAll(filter interface{}) (*ChangeInfo, error) { + ci, err := c.collection.RemoveAll(filter) + if err != nil { + return nil, err + } + return &ChangeInfo{changeInfo: ci}, nil +} + +type Query struct { + query *mgo.Query +} + +func (c *Collection) FindId(id interface{}) *Query { + return &Query{query: c.collection.FindId(id)} +} + +func (c *Collection) Find(query interface{}) *Query { + return &Query{query: c.collection.Find(query)} +} + +func (q *Query) Select(selector interface{}) *Query { + q.query = q.query.Select(selector) + return q +} + +func (q *Query) One(result interface{}) error { + return q.query.One(result) +} + +func (q *Query) Sort(fields ...string) *Query { + q.query = q.query.Sort(fields...) + return q +} + +func (q *Query) Skip(n int) *Query { + q.query = q.query.Skip(n) + return q +} + +func (q *Query) Limit(n int) *Query { + q.query = q.query.Limit(n) + return q +} + +func (q *Query) All(result interface{}) error { + return q.query.All(result) +} + +func (q *Query) Count() (int, error) { + return q.query.Count() +} + +type Index struct { + index *mgo.Index +} + +func NewMgoIndex(index mgo.Index) *Index { + return &Index{index: &index} +} + +func (c *Collection) EnsureIndex(index *Index) error { + if index != nil && index.index != nil { + return c.collection.EnsureIndex(*index.index) + } + return errors.New("index parameter not initialized with mgo.Index") +} + +type Pipe struct { + pipe *mgo.Pipe +} + +func (p *Pipe) All(result interface{}) error { + return p.pipe.All(result) +} + +func (c *Collection) Pipe(pipeline interface{}) *Pipe { + return &Pipe{pipe: c.collection.Pipe(pipeline).AllowDiskUse()} +} diff --git a/setup.go b/setup.go index cb6f634..878c8c9 100644 --- a/setup.go +++ b/setup.go @@ -9,7 +9,7 @@ import ( ) // EnsureIndex ensured mongodb index reflecting model struct index tag -func EnsureIndex(db *mgo.Database, m ModelInterface) error { +func EnsureIndex(db *Database, m ModelInterface) error { colName := GetCollectionName(m) col := db.C(colName) @@ -57,14 +57,14 @@ func EnsureIndex(db *mgo.Database, m ModelInterface) error { case indexEl == "text": textFields = append(textFields, "$text:"+fieldbase+bsonField) default: - return fmt.Errorf("invalid index tag on collection %s.%s for field %s%s in model %+v", db.Name, colName, fieldbase, bsonField, t) + return fmt.Errorf("invalid index tag on collection %s.%s for field %s%s in model %+v", db.Name(), colName, fieldbase, bsonField, t) } } if len(index.Key) > 0 { // fmt.Println(bsonField, index) - fmt.Printf("ensure index on collection %s.%s for field %s%s\n", db.Name, colName, fieldbase, bsonField) - err := col.EnsureIndex(index) + fmt.Printf("ensure index on collection %s.%s for field %s%s\n", db.Name(), colName, fieldbase, bsonField) + err := col.EnsureIndex(NewMgoIndex(index)) if err != nil { return err } @@ -93,13 +93,13 @@ func EnsureIndex(db *mgo.Database, m ModelInterface) error { if len(textFields) > 0 { // fmt.Println("$text", textFields) - fmt.Printf("ensure text index on collection %s.%s for fields %v\n", db.Name, GetCollectionName(m), textFields) - err := col.EnsureIndex(mgo.Index{ + fmt.Printf("ensure text index on collection %s.%s for fields %v\n", db.Name(), GetCollectionName(m), textFields) + err := col.EnsureIndex(NewMgoIndex(mgo.Index{ Name: "textindex", Key: textFields, DefaultLanguage: "german", Background: false, - }) + })) if err != nil { return err } diff --git a/validator.go b/validator.go index 174cc7c..e9f321f 100644 --- a/validator.go +++ b/validator.go @@ -2,12 +2,11 @@ package mgocrud import ( validator "gopkg.in/go-playground/validator.v8" - mgo "gopkg.in/mgo.v2" "gopkg.in/mgo.v2/bson" ) // ValidateObject validates object via validator tag and custom method -func ValidateObject(db *mgo.Database, m ModelInterface, changes bson.M) error { +func ValidateObject(db *Database, m ModelInterface, changes bson.M) error { // first validate via struct tag validator := validator.New(&validator.Config{ TagName: "validator", @@ -19,7 +18,7 @@ func ValidateObject(db *mgo.Database, m ModelInterface, changes bson.M) error { // next execute custom model validator if exists if i, ok := m.(interface { - Validate(db *mgo.Database, changes bson.M) error + Validate(db *Database, changes bson.M) error }); ok { return i.Validate(db, changes) }