diff --git a/crud.go b/crud.go index 67fb175..6946e89 100644 --- a/crud.go +++ b/crud.go @@ -27,7 +27,7 @@ func (e *ErrorWithStack) Stack() string { } // CreateDocument creates a document from specified model -func (db *Database) CreateDocument(m ModelInterface) error { +func (db *MgoDatabase) CreateDocument(m ModelInterface) error { m.PrepareInsert() c := db.C(GetCollectionName(m)) @@ -37,7 +37,7 @@ func (db *Database) CreateDocument(m ModelInterface) error { } // ReadDocument gets one document via its id -func (db *Database) ReadDocument(m ModelInterface, selector bson.M) error { +func (db *MgoDatabase) ReadDocument(m ModelInterface, selector bson.M) error { c := db.C(GetCollectionName(m)) q := c.FindId(m.GetID()) @@ -136,7 +136,7 @@ func idToObjectID(filter interface{}) { } // ReadCollection gets the filtered collection of the model -func (db *Database) ReadCollection(results interface{}, filter bson.M, selector bson.M, offset int, limit int, sort []string, pipelineModifier PipelineModifierFunction) (err error) { +func (db *MgoDatabase) ReadCollection(results interface{}, filter bson.M, selector bson.M, offset int, limit int, sort []string, pipelineModifier PipelineModifierFunction) (err error) { defer func() { if r := recover(); r != nil { err = &ErrorWithStack{Message: fmt.Sprintf("%v", r), StackTrace: string(debug.Stack())} @@ -243,7 +243,7 @@ func (db *Database) ReadCollection(results interface{}, filter bson.M, selector } // ReadCollectionCount gets the count of elements in filtered collection -func (db *Database) ReadCollectionCount(m ModelInterface, filter bson.M) (count int, err error) { +func (db *MgoDatabase) ReadCollectionCount(m ModelInterface, filter bson.M) (count int, err error) { defer func() { if r := recover(); r != nil { err = fmt.Errorf("%v", r) @@ -255,7 +255,7 @@ func (db *Database) ReadCollectionCount(m ModelInterface, filter bson.M) (count } // UpdateDocument updates a document from specified model -func (db *Database) UpdateDocument(m ModelInterface, changes bson.M) error { +func (db *MgoDatabase) UpdateDocument(m ModelInterface, changes bson.M) error { m.PrepareUpdate() changes["updateTime"] = time.Now() @@ -266,7 +266,7 @@ func (db *Database) UpdateDocument(m ModelInterface, changes bson.M) error { } // UpsertDocument updates a document from specified model or inserts it, of not found -func (db *Database) UpsertDocument(m ModelInterface, changes bson.M) error { +func (db *MgoDatabase) UpsertDocument(m ModelInterface, changes bson.M) error { m.PrepareUpdate() changes["updateTime"] = time.Now() @@ -277,7 +277,7 @@ func (db *Database) UpsertDocument(m ModelInterface, changes bson.M) error { } // DeleteDocument deletes one document via its id -func (db *Database) DeleteDocument(m ModelInterface) error { +func (db *MgoDatabase) DeleteDocument(m ModelInterface) error { c := db.C(GetCollectionName(m)) err := c.RemoveId(m.GetID()) @@ -286,7 +286,7 @@ func (db *Database) DeleteDocument(m ModelInterface) error { } // DeleteDocuments deletes documents found by filter -func (db *Database) DeleteDocuments(m ModelInterface, filter bson.M) (removed int, err error) { +func (db *MgoDatabase) DeleteDocuments(m ModelInterface, filter bson.M) (removed int, err error) { c := db.C(GetCollectionName(m)) info, err := c.RemoveAll(filter) diff --git a/lookup.go b/lookup.go index e1fed59..38bf785 100644 --- a/lookup.go +++ b/lookup.go @@ -8,7 +8,7 @@ import ( ) // Lookup extends results with data for inline structs -func (db *Database) Lookup(structField string, results interface{}, selector bson.M) error { +func (db *MgoDatabase) Lookup(structField string, results interface{}, selector bson.M) error { t := reflect.TypeOf(results) v := reflect.ValueOf(results) diff --git a/mgo.go b/mgo.go new file mode 100644 index 0000000..7a4d703 --- /dev/null +++ b/mgo.go @@ -0,0 +1,194 @@ +package mgocrud + +import ( + "errors" + + mgo "gopkg.in/mgo.v2" +) + +type MgoConnection struct { + connection *mgo.Session +} + +func NewMgoConnection(dial string) (Connection, error) { + connection, err := mgo.Dial(dial) + if err != nil { + return nil, err + } + connection.SetMode(mgo.Monotonic, true) + return &MgoConnection{connection: connection}, nil +} + +func (c *MgoConnection) Close() { + c.connection.Close() +} + +type MgoSession struct { + session *mgo.Session +} + +func (s *MgoSession) Close() { + s.session.Close() +} + +func (c *MgoConnection) NewSession() Session { + return &MgoSession{session: c.connection.Copy()} +} + +type MgoDatabase struct { + database *mgo.Database + session Session +} + +func (s *MgoSession) DB(name string) Database { + return &MgoDatabase{database: s.session.DB(name), session: s} +} + +type MgoCollection struct { + collection *mgo.Collection +} + +func (db *MgoDatabase) Session() Session { + return db.session +} + +func (db *MgoDatabase) C(name string) Collection { + return &MgoCollection{collection: db.database.C(name)} +} + +func (db *MgoDatabase) Name() string { + return db.database.Name +} + +func (c *MgoCollection) Insert(docs ...interface{}) error { + return c.collection.Insert(docs...) +} + +func (c *MgoCollection) UpdateId(id interface{}, update interface{}) error { + return c.collection.UpdateId(id, update) +} + +func (c *MgoCollection) RemoveId(id interface{}) error { + return c.collection.RemoveId(id) +} + +type MgoChangeInfo struct { + changeInfo *mgo.ChangeInfo +} + +func (ci *MgoChangeInfo) Matched() int { + return ci.changeInfo.Matched +} + +func (ci *MgoChangeInfo) Removed() int { + return ci.changeInfo.Removed +} + +func (ci *MgoChangeInfo) Updated() int { + return ci.changeInfo.Updated +} + +func (c *MgoCollection) Upsert(selector interface{}, update interface{}) (ChangeInfo, error) { + ci, err := c.collection.Upsert(selector, update) + if err != nil { + return nil, err + } + return &MgoChangeInfo{changeInfo: ci}, nil +} + +func (c *MgoCollection) RemoveAll(filter interface{}) (ChangeInfo, error) { + ci, err := c.collection.RemoveAll(filter) + if err != nil { + return nil, err + } + return &MgoChangeInfo{changeInfo: ci}, nil +} + +type MgoQuery struct { + query *mgo.Query +} + +func (c *MgoCollection) FindId(id interface{}) Query { + return &MgoQuery{query: c.collection.FindId(id)} +} + +func (c *MgoCollection) Find(query interface{}) Query { + return &MgoQuery{query: c.collection.Find(query)} +} + +func (q *MgoQuery) Select(selector interface{}) Query { + q.query = q.query.Select(selector) + return q +} + +func (q *MgoQuery) One(result interface{}) error { + err := q.query.One(result) + if err == mgo.ErrNotFound { + err = ErrNotFound + } + return err +} + +func (q *MgoQuery) Sort(fields ...string) Query { + q.query = q.query.Sort(fields...) + return q +} + +func (q *MgoQuery) Skip(n int) Query { + q.query = q.query.Skip(n) + return q +} + +func (q *MgoQuery) Limit(n int) Query { + q.query = q.query.Limit(n) + return q +} + +func (q *MgoQuery) All(result interface{}) error { + err := q.query.All(result) + if err == mgo.ErrNotFound { + err = ErrNotFound + } + return err +} + +func (q *MgoQuery) Count() (int, error) { + c, err := q.query.Count() + if err == mgo.ErrNotFound { + err = ErrNotFound + } + return c, err +} + +type MgoIndex struct { + index *mgo.Index +} + +func NewMgoIndex(index mgo.Index) Index { + return &MgoIndex{index: &index} +} + +func (c *MgoCollection) EnsureIndex(index Index) error { + if i, ok := index.(*MgoIndex); ok { + if i != nil && i.index != nil { + return c.collection.EnsureIndex(*i.index) + } + } + return errors.New("index parameter not initialized with mgo.Index") +} + +type MgoPipe struct { + pipe *mgo.Pipe +} + +func (p *MgoPipe) All(result interface{}) error { + err := p.pipe.All(result) + if err == mgo.ErrNotFound { + err = ErrNotFound + } + return err +} + +func (c *MgoCollection) Pipe(pipeline interface{}) Pipe { + return &MgoPipe{pipe: c.collection.Pipe(pipeline).AllowDiskUse()} +} diff --git a/session.go b/session.go deleted file mode 100644 index 3675db2..0000000 --- a/session.go +++ /dev/null @@ -1,196 +0,0 @@ -package mgocrud - -import ( - "errors" - - mgo "gopkg.in/mgo.v2" -) - -var ( - ErrNotFound = errors.New("not found") -) - -type Connection struct { - connection *mgo.Session -} - -func NewConnection(dial string) (*Connection, error) { - connection, err := mgo.Dial(dial) - if err != nil { - return nil, err - } - connection.SetMode(mgo.Monotonic, true) - return &Connection{connection: connection}, nil -} - -func (c *Connection) Close() { - c.connection.Close() -} - -type Session struct { - session *mgo.Session -} - -func (s *Session) Close() { - s.session.Close() -} - -func (c *Connection) NewSession() (*Session, error) { - return &Session{session: c.connection.Copy()}, nil -} - -type Database struct { - database *mgo.Database - session *Session -} - -func (s *Session) DB(name string) *Database { - return &Database{database: s.session.DB(name), session: s} -} - -type Collection struct { - collection *mgo.Collection -} - -func (db *Database) Session() *Session { - return db.session -} - -func (db *Database) C(name string) *Collection { - return &Collection{collection: db.database.C(name)} -} - -func (db *Database) Name() string { - return db.database.Name -} - -func (c *Collection) Insert(docs ...interface{}) error { - return c.collection.Insert(docs...) -} - -func (c *Collection) UpdateId(id interface{}, update interface{}) error { - return c.collection.UpdateId(id, update) -} - -func (c *Collection) RemoveId(id interface{}) error { - return c.collection.RemoveId(id) -} - -type ChangeInfo struct { - changeInfo *mgo.ChangeInfo -} - -func (ci *ChangeInfo) Matched() int { - return ci.changeInfo.Matched -} - -func (ci *ChangeInfo) Removed() int { - return ci.changeInfo.Removed -} - -func (ci *ChangeInfo) Updated() int { - return ci.changeInfo.Updated -} - -func (c *Collection) Upsert(selector interface{}, update interface{}) (*ChangeInfo, error) { - ci, err := c.collection.Upsert(selector, update) - if err != nil { - return nil, err - } - return &ChangeInfo{changeInfo: ci}, nil -} - -func (c *Collection) RemoveAll(filter interface{}) (*ChangeInfo, error) { - ci, err := c.collection.RemoveAll(filter) - if err != nil { - return nil, err - } - return &ChangeInfo{changeInfo: ci}, nil -} - -type Query struct { - query *mgo.Query -} - -func (c *Collection) FindId(id interface{}) *Query { - return &Query{query: c.collection.FindId(id)} -} - -func (c *Collection) Find(query interface{}) *Query { - return &Query{query: c.collection.Find(query)} -} - -func (q *Query) Select(selector interface{}) *Query { - q.query = q.query.Select(selector) - return q -} - -func (q *Query) One(result interface{}) error { - err := q.query.One(result) - if err == mgo.ErrNotFound { - err = ErrNotFound - } - return err -} - -func (q *Query) Sort(fields ...string) *Query { - q.query = q.query.Sort(fields...) - return q -} - -func (q *Query) Skip(n int) *Query { - q.query = q.query.Skip(n) - return q -} - -func (q *Query) Limit(n int) *Query { - q.query = q.query.Limit(n) - return q -} - -func (q *Query) All(result interface{}) error { - err := q.query.All(result) - if err == mgo.ErrNotFound { - err = ErrNotFound - } - return err -} - -func (q *Query) Count() (int, error) { - c, err := q.query.Count() - if err == mgo.ErrNotFound { - err = ErrNotFound - } - return c, err -} - -type Index struct { - index *mgo.Index -} - -func NewMgoIndex(index mgo.Index) *Index { - return &Index{index: &index} -} - -func (c *Collection) EnsureIndex(index *Index) error { - if index != nil && index.index != nil { - return c.collection.EnsureIndex(*index.index) - } - return errors.New("index parameter not initialized with mgo.Index") -} - -type Pipe struct { - pipe *mgo.Pipe -} - -func (p *Pipe) All(result interface{}) error { - err := p.pipe.All(result) - if err == mgo.ErrNotFound { - err = ErrNotFound - } - return err -} - -func (c *Collection) Pipe(pipeline interface{}) *Pipe { - return &Pipe{pipe: c.collection.Pipe(pipeline).AllowDiskUse()} -} diff --git a/setup.go b/setup.go index 1df161a..662fa1a 100644 --- a/setup.go +++ b/setup.go @@ -9,7 +9,7 @@ import ( ) // EnsureIndex ensured mongodb index reflecting model struct index tag -func (db *Database) EnsureIndex(m ModelInterface) error { +func (db *MgoDatabase) EnsureIndex(m ModelInterface) error { colName := GetCollectionName(m) col := db.C(colName) diff --git a/types.go b/types.go new file mode 100644 index 0000000..7b6a8b2 --- /dev/null +++ b/types.go @@ -0,0 +1,71 @@ +package mgocrud + +import ( + "errors" + + "gopkg.in/mgo.v2/bson" +) + +var ( + ErrNotFound = errors.New("not found") +) + +type Connection interface { + Close() + NewSession() Session +} + +type Session interface { + Close() + DB(name string) Database +} + +type Database interface { + Session() Session + C(name string) Collection + Name() string + EnsureIndex(m ModelInterface) error + ValidateObject(m ModelInterface, changes bson.M) error + ReadDocument(m ModelInterface, selector bson.M) error + CreateDocument(m ModelInterface) error + ReadCollection(results interface{}, filter bson.M, selector bson.M, offset int, limit int, sort []string, pipelineModifier PipelineModifierFunction) error + ReadCollectionCount(m ModelInterface, filter bson.M) (count int, err error) + UpdateDocument(m ModelInterface, changes bson.M) error + UpsertDocument(m ModelInterface, changes bson.M) error + DeleteDocument(m ModelInterface) error + DeleteDocuments(m ModelInterface, filter bson.M) (removed int, err error) +} + +type Collection interface { + Insert(docs ...interface{}) error + UpdateId(id interface{}, update interface{}) error + RemoveId(id interface{}) error + Upsert(selector interface{}, update interface{}) (ChangeInfo, error) + RemoveAll(filter interface{}) (ChangeInfo, error) + FindId(id interface{}) Query + Find(query interface{}) Query + EnsureIndex(index Index) error + Pipe(pipeline interface{}) Pipe +} + +type ChangeInfo interface { + Matched() int + Removed() int + Updated() int +} + +type Query interface { + Select(selector interface{}) Query + One(result interface{}) error + Sort(fields ...string) Query + Skip(n int) Query + Limit(n int) Query + All(result interface{}) error + Count() (int, error) +} + +type Index interface{} + +type Pipe interface { + All(result interface{}) error +} diff --git a/validator.go b/validator.go index 1ac7e5d..5b13502 100644 --- a/validator.go +++ b/validator.go @@ -6,7 +6,7 @@ import ( ) // ValidateObject validates object via validator tag and custom method -func (db *Database) ValidateObject(m ModelInterface, changes bson.M) error { +func (db *MgoDatabase) ValidateObject(m ModelInterface, changes bson.M) error { // first validate via struct tag validator := validator.New(&validator.Config{ TagName: "validator", @@ -18,7 +18,7 @@ func (db *Database) ValidateObject(m ModelInterface, changes bson.M) error { // next execute custom model validator if exists if i, ok := m.(interface { - Validate(db *Database, changes bson.M) error + Validate(db Database, changes bson.M) error }); ok { return i.Validate(db, changes) }