2 Commits

Author SHA1 Message Date
98801edbc0 db.Session() 2022-02-09 18:55:58 +01:00
7dd238342c encapsulate ErrNotFound 2022-02-09 15:53:37 +01:00

View File

@@ -6,6 +6,10 @@ import (
mgo "gopkg.in/mgo.v2" mgo "gopkg.in/mgo.v2"
) )
var (
ErrNotFound = errors.New("not found")
)
type Session struct { type Session struct {
session *mgo.Session session *mgo.Session
} }
@@ -29,16 +33,21 @@ func NewSession(dial string) (*Session, error) {
type Database struct { type Database struct {
database *mgo.Database database *mgo.Database
session *Session
} }
func (s *Session) DB(name string) *Database { func (s *Session) DB(name string) *Database {
return &Database{database: s.session.DB(name)} return &Database{database: s.session.DB(name), session: s}
} }
type Collection struct { type Collection struct {
collection *mgo.Collection collection *mgo.Collection
} }
func (db *Database) Session() *Session {
return db.session
}
func (db *Database) C(name string) *Collection { func (db *Database) C(name string) *Collection {
return &Collection{collection: db.database.C(name)} return &Collection{collection: db.database.C(name)}
} }
@@ -111,7 +120,7 @@ func (q *Query) Select(selector interface{}) *Query {
func (q *Query) One(result interface{}) error { func (q *Query) One(result interface{}) error {
err := q.query.One(result) err := q.query.One(result)
if err == mgo.ErrNotFound { if err == mgo.ErrNotFound {
err = &ErrNotFound{msg: err.Error()} err = ErrNotFound
} }
return err return err
} }
@@ -134,7 +143,7 @@ func (q *Query) Limit(n int) *Query {
func (q *Query) All(result interface{}) error { func (q *Query) All(result interface{}) error {
err := q.query.All(result) err := q.query.All(result)
if err == mgo.ErrNotFound { if err == mgo.ErrNotFound {
err = &ErrNotFound{msg: err.Error()} err = ErrNotFound
} }
return err return err
} }
@@ -142,7 +151,7 @@ func (q *Query) All(result interface{}) error {
func (q *Query) Count() (int, error) { func (q *Query) Count() (int, error) {
c, err := q.query.Count() c, err := q.query.Count()
if err == mgo.ErrNotFound { if err == mgo.ErrNotFound {
err = &ErrNotFound{msg: err.Error()} err = ErrNotFound
} }
return c, err return c, err
} }
@@ -169,7 +178,7 @@ type Pipe struct {
func (p *Pipe) All(result interface{}) error { func (p *Pipe) All(result interface{}) error {
err := p.pipe.All(result) err := p.pipe.All(result)
if err == mgo.ErrNotFound { if err == mgo.ErrNotFound {
err = &ErrNotFound{msg: err.Error()} err = ErrNotFound
} }
return err return err
} }
@@ -177,14 +186,3 @@ func (p *Pipe) All(result interface{}) error {
func (c *Collection) Pipe(pipeline interface{}) *Pipe { func (c *Collection) Pipe(pipeline interface{}) *Pipe {
return &Pipe{pipe: c.collection.Pipe(pipeline).AllowDiskUse()} return &Pipe{pipe: c.collection.Pipe(pipeline).AllowDiskUse()}
} }
type ErrNotFound struct {
msg string
}
func (e *ErrNotFound) Error() string {
if e.msg == "" {
return "not found"
}
return e.msg
}