mgocrud/mgo.go

218 lines
4.1 KiB
Go

package mgocrud
import (
"errors"
"runtime"
mgo "gopkg.in/mgo.v2"
)
// MgoConnection adapts the *mgo.Session dialed by NewMgoConnection to the
// package's Connection interface. closed records whether Close has already
// released the underlying session, so it is never closed twice.
type MgoConnection struct {
	connection *mgo.Session
	closed     bool
}

// NewMgoConnection dials dial with mgo.Dial, switches the resulting session to
// mgo.Monotonic consistency (refreshing it), and wraps it in an MgoConnection.
// A finalizer is attached so the session is still closed if the wrapper is
// garbage-collected without an explicit Close; Close removes that finalizer.
// A dial failure is returned unchanged together with a nil Connection.
func NewMgoConnection(dial string) (Connection, error) {
	session, err := mgo.Dial(dial)
	if err != nil {
		return nil, err
	}
	session.SetMode(mgo.Monotonic, true)

	conn := &MgoConnection{connection: session}
	// The finalizer takes the object as a parameter instead of capturing conn,
	// otherwise the closure itself would keep conn reachable forever.
	runtime.SetFinalizer(conn, func(orphan *MgoConnection) {
		if orphan.closed {
			return
		}
		orphan.Close()
	})
	return conn, nil
}

// Close releases the underlying mgo session on the first call. Later calls do
// not touch the session again. It always clears the finalizer installed by
// NewMgoConnection.
func (c *MgoConnection) Close() {
	if alreadyClosed := c.closed; !alreadyClosed {
		c.connection.Close()
		c.closed = true
	}
	// Nothing is left for the finalizer to clean up.
	runtime.SetFinalizer(c, nil)
}
// MgoSession adapts an *mgo.Session produced by MgoConnection.NewSession to the
// package's Session interface. closed records whether Close has already run.
type MgoSession struct {
	session *mgo.Session
	closed  bool
}

// Close releases the wrapped mgo session on the first call; subsequent calls
// leave it alone. It always clears the finalizer installed by NewSession.
func (s *MgoSession) Close() {
	if stillOpen := !s.closed; stillOpen {
		s.session.Close()
		s.closed = true
	}
	runtime.SetFinalizer(s, nil)
}
// NewSession returns a Session backed by a Copy of the connection's mgo
// session. A finalizer closes that copy if the returned value is
// garbage-collected before Close is called; Close removes the finalizer.
func (c *MgoConnection) NewSession() Session {
	copied := c.connection.Copy()
	sess := &MgoSession{session: copied}
	// Use the finalizer's parameter rather than capturing sess, so the
	// closure does not keep the session reachable.
	runtime.SetFinalizer(sess, func(orphan *MgoSession) {
		if orphan.closed {
			return
		}
		orphan.Close()
	})
	return sess
}
// MgoDatabase adapts an *mgo.Database to the package's Database interface and
// remembers the Session it was obtained from.
type MgoDatabase struct {
	database *mgo.Database
	session  Session
}

// DB returns a handle to the database called name on this session.
func (s *MgoSession) DB(name string) Database {
	handle := s.session.DB(name)
	return &MgoDatabase{session: s, database: handle}
}

// MgoCollection adapts an *mgo.Collection to the package's Collection interface.
type MgoCollection struct {
	collection *mgo.Collection
}

// Session returns the Session this database handle was created from.
func (db *MgoDatabase) Session() Session {
	owner := db.session
	return owner
}

// C returns a handle to the collection called name inside this database.
func (db *MgoDatabase) C(name string) Collection {
	coll := db.database.C(name)
	return &MgoCollection{collection: coll}
}

// Name reports the name of the wrapped database.
func (db *MgoDatabase) Name() string {
	name := db.database.Name
	return name
}
// Insert inserts docs into the collection; mgo's error is returned unchanged.
func (c *MgoCollection) Insert(docs ...interface{}) error {
	return c.collection.Insert(docs...)
}

// UpdateId applies update to the document whose _id is id.
// If mgo reports that no document matched (mgo.ErrNotFound), the package's own
// ErrNotFound is returned instead. This is the same translation MgoQuery.One
// performs, so callers never need to compare against mgo's sentinel.
// Any other error is passed through unchanged.
func (c *MgoCollection) UpdateId(id interface{}, update interface{}) error {
	err := c.collection.UpdateId(id, update)
	if err == mgo.ErrNotFound {
		return ErrNotFound
	}
	return err
}

// RemoveId deletes the document whose _id is id.
// If mgo reports that no document matched (mgo.ErrNotFound), the package's own
// ErrNotFound is returned, consistent with MgoQuery.One. Any other error is
// passed through unchanged.
func (c *MgoCollection) RemoveId(id interface{}) error {
	err := c.collection.RemoveId(id)
	if err == mgo.ErrNotFound {
		return ErrNotFound
	}
	return err
}
// MgoChangeInfo adapts an *mgo.ChangeInfo to the package's ChangeInfo interface.
//
// changeInfo may be nil. mgo's Upsert and RemoveAll only build a ChangeInfo
// when they get a write result back from the server, so they can return a nil
// *ChangeInfo together with a nil error (e.g. for unacknowledged writes).
// Every accessor therefore treats a nil value as "no documents affected"
// instead of dereferencing it.
type MgoChangeInfo struct {
	changeInfo *mgo.ChangeInfo
}

// Matched reports how many documents matched the operation's selector, or 0
// when no change information is available.
func (ci *MgoChangeInfo) Matched() int {
	if ci == nil || ci.changeInfo == nil {
		return 0
	}
	return ci.changeInfo.Matched
}

// Removed reports how many documents were removed, or 0 when no change
// information is available.
func (ci *MgoChangeInfo) Removed() int {
	if ci == nil || ci.changeInfo == nil {
		return 0
	}
	return ci.changeInfo.Removed
}

// Updated reports how many documents were updated, or 0 when no change
// information is available.
func (ci *MgoChangeInfo) Updated() int {
	if ci == nil || ci.changeInfo == nil {
		return 0
	}
	return ci.changeInfo.Updated
}
// Upsert delegates to mgo's Collection.Upsert: it updates the document matching
// selector with update, or inserts one if nothing matches. On success the
// outcome is wrapped in an MgoChangeInfo; on failure the ChangeInfo is nil and
// mgo's error is returned unchanged.
func (c *MgoCollection) Upsert(selector interface{}, update interface{}) (ChangeInfo, error) {
	info, err := c.collection.Upsert(selector, update)
	if err == nil {
		return &MgoChangeInfo{changeInfo: info}, nil
	}
	return nil, err
}

// RemoveAll deletes every document matching filter via mgo's
// Collection.RemoveAll. On success the outcome is wrapped in an MgoChangeInfo;
// on failure the ChangeInfo is nil and mgo's error is returned unchanged.
func (c *MgoCollection) RemoveAll(filter interface{}) (ChangeInfo, error) {
	info, err := c.collection.RemoveAll(filter)
	if err == nil {
		return &MgoChangeInfo{changeInfo: info}, nil
	}
	return nil, err
}
// MgoQuery adapts an *mgo.Query to the package's Query interface. Its builder
// methods store mgo's returned query back into the receiver and return the
// receiver itself, so calls can be chained.
type MgoQuery struct {
	query *mgo.Query
}

// FindId starts a query for the document whose _id equals id.
func (c *MgoCollection) FindId(id interface{}) Query {
	q := c.collection.FindId(id)
	return &MgoQuery{query: q}
}

// Find starts a query for the documents matching the query filter.
func (c *MgoCollection) Find(query interface{}) Query {
	q := c.collection.Find(query)
	return &MgoQuery{query: q}
}

// Select restricts which fields of the matching documents are returned.
func (q *MgoQuery) Select(selector interface{}) Query {
	narrowed := q.query.Select(selector)
	q.query = narrowed
	return q
}
// One decodes the first matching document into result. mgo.ErrNotFound is
// reported as the package's ErrNotFound; any other error passes through.
func (q *MgoQuery) One(result interface{}) error {
	switch err := q.query.One(result); err {
	case mgo.ErrNotFound:
		return ErrNotFound
	default:
		return err
	}
}
// Sort orders the results by fields, following mgo's Query.Sort conventions.
func (q *MgoQuery) Sort(fields ...string) Query {
	ordered := q.query.Sort(fields...)
	q.query = ordered
	return q
}

// Skip makes the query skip the first n matching documents.
func (q *MgoQuery) Skip(n int) Query {
	offset := q.query.Skip(n)
	q.query = offset
	return q
}

// Limit caps the query at n result documents, per mgo's Query.Limit.
func (q *MgoQuery) Limit(n int) Query {
	capped := q.query.Limit(n)
	q.query = capped
	return q
}
// All decodes every matching document into result. mgo.ErrNotFound is
// reported as the package's ErrNotFound; any other error passes through.
func (q *MgoQuery) All(result interface{}) error {
	if err := q.query.All(result); err != mgo.ErrNotFound {
		return err
	}
	return ErrNotFound
}

// Count returns the number of documents the query matches. mgo.ErrNotFound is
// reported as the package's ErrNotFound; the count is returned alongside the
// error exactly as mgo produced it.
func (q *MgoQuery) Count() (int, error) {
	total, err := q.query.Count()
	if err != mgo.ErrNotFound {
		return total, err
	}
	return total, ErrNotFound
}
// MgoIndex carries an mgo.Index through the package's Index interface so it
// can be handed to MgoCollection.EnsureIndex.
type MgoIndex struct {
	index *mgo.Index
}

// NewMgoIndex wraps a copy of index as an Index. The copy is shallow: slice
// fields such as Key still share the caller's backing array.
func NewMgoIndex(index mgo.Index) Index {
	held := index
	return &MgoIndex{index: &held}
}

// EnsureIndex delegates to mgo's Collection.EnsureIndex with the wrapped
// mgo.Index. index must be a non-nil *MgoIndex holding a non-nil mgo.Index
// (as built by NewMgoIndex); anything else yields an error without contacting
// the server.
func (c *MgoCollection) EnsureIndex(index Index) error {
	wrapped, ok := index.(*MgoIndex)
	if !ok || wrapped == nil || wrapped.index == nil {
		return errors.New("index parameter not initialized with mgo.Index")
	}
	return c.collection.EnsureIndex(*wrapped.index)
}
// MgoPipe adapts an *mgo.Pipe (an aggregation pipeline) to the package's Pipe
// interface.
type MgoPipe struct {
	pipe *mgo.Pipe
}

// All runs the pipeline and decodes every result document into result.
// mgo.ErrNotFound is reported as the package's ErrNotFound; any other error
// passes through.
func (p *MgoPipe) All(result interface{}) error {
	err := p.pipe.All(result)
	switch err {
	case mgo.ErrNotFound:
		return ErrNotFound
	}
	return err
}

// Pipe prepares the aggregation pipeline on this collection. AllowDiskUse is
// always enabled (MongoDB's allowDiskUse option) before the pipe is wrapped.
func (c *MgoCollection) Pipe(pipeline interface{}) Pipe {
	prepared := c.collection.Pipe(pipeline)
	return &MgoPipe{pipe: prepared.AllowDiskUse()}
}