package mgocrud

import (
	"errors"
	"runtime"

	mgo "gopkg.in/mgo.v2"
)

// MgoConnection is the Connection implementation backed by a root
// *mgo.Session obtained from mgo.Dial.
type MgoConnection struct {
	connection *mgo.Session
	closed     bool
}

// NewMgoConnection dials the MongoDB server(s) described by dial, switches
// the resulting session to mgo.Monotonic mode, and returns it wrapped as a
// Connection. A finalizer closes the underlying session if the returned
// value becomes unreachable without Close having been called.
func NewMgoConnection(dial string) (Connection, error) {
	connection, err := mgo.Dial(dial)
	if err != nil {
		return nil, err
	}
	connection.SetMode(mgo.Monotonic, true)

	c := &MgoConnection{connection: connection}
	// Safety net only: callers are still expected to call Close explicitly.
	// The closure takes its own parameter so it does not capture c, which
	// would keep c reachable forever and prevent the finalizer from running.
	runtime.SetFinalizer(c, func(c *MgoConnection) {
		if !c.closed {
			c.Close()
		}
	})
	return c, nil
}

// Close closes the underlying root session once and clears the finalizer.
// Calling Close more than once is a no-op after the first call.
func (c *MgoConnection) Close() {
	if !c.closed {
		c.connection.Close()
		c.closed = true
	}
	runtime.SetFinalizer(c, nil)
}

// NewSession returns a Session backed by a Copy of the root session.
// A finalizer closes the copy if it is dropped without Close being called.
//
// NOTE(review): calling this after c.Close copies an already-closed
// *mgo.Session; mgo is expected to panic in that case — confirm.
func (c *MgoConnection) NewSession() Session {
	s := &MgoSession{
		connection: c,
		session:    c.connection.Copy(),
	}
	runtime.SetFinalizer(s, func(s *MgoSession) {
		if !s.closed {
			s.Close()
		}
	})
	return s
}

// MgoSession is the Session implementation wrapping a copied *mgo.Session.
type MgoSession struct {
	connection *MgoConnection
	session    *mgo.Session
	closed     bool
}

// Connection returns the MgoConnection this session was created from.
func (s *MgoSession) Connection() Connection {
	return s.connection
}

// Close closes the copied session once and clears the finalizer.
// Calling Close more than once is a no-op after the first call.
func (s *MgoSession) Close() {
	if !s.closed {
		s.session.Close()
		s.closed = true
	}
	runtime.SetFinalizer(s, nil)
}

// DB returns a Database handle for the named database on this session.
func (s *MgoSession) DB(name string) Database {
	return &MgoDatabase{database: s.session.DB(name), session: s}
}

// MgoDatabase is the Database implementation wrapping an *mgo.Database
// together with the Session that produced it.
type MgoDatabase struct {
	database *mgo.Database
	session  Session
}

// Session returns the Session this database handle was obtained from.
func (db *MgoDatabase) Session() Session {
	return db.session
}

// C returns a Collection handle for the named collection in this database.
func (db *MgoDatabase) C(name string) Collection {
	return &MgoCollection{
		database:   db,
		collection: db.database.C(name),
	}
}

// Name returns the database name.
func (db *MgoDatabase) Name() string {
	return db.database.Name
}

// MgoCollection is the Collection implementation wrapping an
// *mgo.Collection together with its parent MgoDatabase.
type MgoCollection struct {
	database   *MgoDatabase
	collection *mgo.Collection
}

// DB returns the Database this collection belongs to.
func (c *MgoCollection) DB() Database {
	return c.database
}

// Insert delegates to mgo's Collection.Insert.
func (c *MgoCollection) Insert(docs ...interface{}) error {
	return c.collection.Insert(docs...)
}

// UpdateId delegates to mgo's Collection.UpdateId.
func (c *MgoCollection) UpdateId(id interface{}, update interface{}) error {
	return c.collection.UpdateId(id, update)
}

// RemoveId delegates to mgo's Collection.RemoveId.
func (c *MgoCollection) RemoveId(id interface{}) error {
	return c.collection.RemoveId(id)
}

// Upsert delegates to mgo's Collection.Upsert and wraps the resulting
// change information. On error, the returned ChangeInfo is nil.
func (c *MgoCollection) Upsert(selector interface{}, update interface{}) (ChangeInfo, error) {
	ci, err := c.collection.Upsert(selector, update)
	if err != nil {
		return nil, err
	}
	return &MgoChangeInfo{changeInfo: ci}, nil
}

// RemoveAll delegates to mgo's Collection.RemoveAll and wraps the resulting
// change information. On error, the returned ChangeInfo is nil.
//
// mgo may report success with a nil *mgo.ChangeInfo (e.g. for
// unacknowledged writes); the MgoChangeInfo accessors then report 0.
func (c *MgoCollection) RemoveAll(filter interface{}) (ChangeInfo, error) {
	ci, err := c.collection.RemoveAll(filter)
	if err != nil {
		return nil, err
	}
	return &MgoChangeInfo{changeInfo: ci}, nil
}

// FindId returns a Query selecting the document with the given _id.
func (c *MgoCollection) FindId(id interface{}) Query {
	return &MgoQuery{query: c.collection.FindId(id)}
}

// Find returns a Query for documents matching query.
func (c *MgoCollection) Find(query interface{}) Query {
	return &MgoQuery{query: c.collection.Find(query)}
}

// EnsureIndex creates the index described by index. index must have been
// built with NewMgoIndex; any other value (including nil) yields an error.
func (c *MgoCollection) EnsureIndex(index Index) error {
	if i, ok := index.(*MgoIndex); ok && i != nil && i.index != nil {
		return c.collection.EnsureIndex(*i.index)
	}
	return errors.New("index parameter not initialized with mgo.Index")
}

// Pipe returns an aggregation Pipe for pipeline. allowDiskUse is always
// enabled on the underlying mgo pipe.
func (c *MgoCollection) Pipe(pipeline interface{}) Pipe {
	return &MgoPipe{pipe: c.collection.Pipe(pipeline).AllowDiskUse()}
}

// MgoChangeInfo is the ChangeInfo implementation wrapping *mgo.ChangeInfo.
// A nil wrapped value is tolerated and reported as zero counts.
type MgoChangeInfo struct {
	changeInfo *mgo.ChangeInfo
}

// Matched reports the number of matched documents, or 0 when no change
// information was reported.
func (ci *MgoChangeInfo) Matched() int {
	if ci == nil || ci.changeInfo == nil {
		return 0
	}
	return ci.changeInfo.Matched
}

// Removed reports the number of removed documents, or 0 when no change
// information was reported.
func (ci *MgoChangeInfo) Removed() int {
	if ci == nil || ci.changeInfo == nil {
		return 0
	}
	return ci.changeInfo.Removed
}

// Updated reports the number of updated documents, or 0 when no change
// information was reported.
func (ci *MgoChangeInfo) Updated() int {
	if ci == nil || ci.changeInfo == nil {
		return 0
	}
	return ci.changeInfo.Updated
}

// MgoQuery is the Query implementation wrapping *mgo.Query. The builder
// methods (Select, Sort, Skip, Limit) update the receiver in place and
// return it, so calls can be chained.
type MgoQuery struct {
	query *mgo.Query
}

// Select restricts the returned fields to those described by selector.
func (q *MgoQuery) Select(selector interface{}) Query {
	q.query = q.query.Select(selector)
	return q
}

// Sort orders results by the given field names.
func (q *MgoQuery) Sort(fields ...string) Query {
	q.query = q.query.Sort(fields...)
	return q
}

// Skip skips the first n results.
func (q *MgoQuery) Skip(n int) Query {
	q.query = q.query.Skip(n)
	return q
}

// Limit caps the number of results at n.
func (q *MgoQuery) Limit(n int) Query {
	q.query = q.query.Limit(n)
	return q
}

// One unmarshals the first matching document into result. mgo.ErrNotFound
// is translated to this package's ErrNotFound.
func (q *MgoQuery) One(result interface{}) error {
	err := q.query.One(result)
	if err == mgo.ErrNotFound {
		err = ErrNotFound
	}
	return err
}

// All unmarshals every matching document into result. mgo.ErrNotFound is
// translated to this package's ErrNotFound.
func (q *MgoQuery) All(result interface{}) error {
	err := q.query.All(result)
	if err == mgo.ErrNotFound {
		err = ErrNotFound
	}
	return err
}

// Count returns the number of matching documents. mgo.ErrNotFound is
// translated to this package's ErrNotFound for consistency with One/All.
func (q *MgoQuery) Count() (int, error) {
	n, err := q.query.Count()
	if err == mgo.ErrNotFound {
		err = ErrNotFound
	}
	return n, err
}

// MgoIndex is the Index implementation holding an mgo.Index definition.
type MgoIndex struct {
	index *mgo.Index
}

// NewMgoIndex wraps a copy of index for use with Collection.EnsureIndex.
func NewMgoIndex(index mgo.Index) Index {
	return &MgoIndex{index: &index}
}

// MgoPipe is the Pipe implementation wrapping *mgo.Pipe.
type MgoPipe struct {
	pipe *mgo.Pipe
}

// All runs the pipeline and unmarshals every result into result.
// mgo.ErrNotFound is translated to this package's ErrNotFound.
func (p *MgoPipe) All(result interface{}) error {
	err := p.pipe.All(result)
	if err == mgo.ErrNotFound {
		err = ErrNotFound
	}
	return err
}