// Package mhayaMongo provides a mhaya component that opens the MongoDB
// databases described in the `mongo` section of the profile file and exposes
// them by database id or by group id.
package mhayaMongo

import (
	"context"
	"crypto/tls"
	"fmt"
	"time"

	cfacade "github.com/mhaya/facade"
	clog "github.com/mhaya/logger"
	cprofile "github.com/mhaya/profile"
	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"
	"go.mongodb.org/mongo-driver/mongo/readpref"
)

const (
	// Name is the name this component reports through Component.Name.
	Name = "mongo_component"
)

type (
	// Component holds the MongoDB databases opened during Init, keyed first
	// by group id (a key of the profile's `mongo` section) and then by db_id.
	Component struct {
		cfacade.Component
		dbMap map[string]map[string]*mongo.Database
	}

	// HashDb hash by group id: given the databases of one group, it returns
	// the db_id to use. It is consumed by GetHashDb.
	HashDb func(dbMaps map[string]*mongo.Database) string
)

// NewComponent returns a Component with an empty database registry.
func NewComponent() *Component {
	return &Component{
		dbMap: make(map[string]map[string]*mongo.Database),
	}
}

// Name returns the component name (the Name constant).
func (*Component) Name() string {
	return Name
}

// Init connects every database whose id is listed in the node's `db_id_list`
// setting.
//
// Database definitions come from the profile's `mongo` section, which maps a
// group id to a list of entries (db_id, db_name, uri, tls, pool and timeout
// settings). Only entries whose db_id appears in `db_id_list` are connected;
// each connected database is stored under its group id and db_id.
// If `db_id_list` is missing or empty, Init logs a warning and returns.
// Init panics if the `mongo` section is missing, if a requested entry has
// enable=false, or if a connection cannot be established.
func (s *Component) Init() {
	// load only the databases contained in `db_id_list`
	mongoIdList := s.App().Settings().Get("db_id_list")
	if mongoIdList.LastError() != nil || mongoIdList.Size() < 1 {
		clog.Warnf("[nodeId = %s] `db_id_list` property not exists.", s.App().NodeId())
		return
	}

	// Collect the requested ids once instead of rescanning the setting for
	// every profile entry; this also keeps a duplicated id from being
	// connected (and its first client leaked) twice.
	wanted := make(map[string]struct{}, mongoIdList.Size())
	for _, key := range mongoIdList.Keys() {
		wanted[mongoIdList.Get(key).ToString()] = struct{}{}
	}

	mongoConfig := cprofile.GetConfig("mongo")
	if mongoConfig.LastError() != nil {
		panic("`mongo` property not exists in profile file.")
	}

	for _, groupId := range mongoConfig.Keys() {
		s.dbMap[groupId] = make(map[string]*mongo.Database)
		dbGroup := mongoConfig.GetConfig(groupId)

		for i := 0; i < dbGroup.Size(); i++ {
			item := dbGroup.GetConfig(i)

			id := item.GetString("db_id")
			if _, ok := wanted[id]; !ok {
				continue
			}

			var (
				enable          = item.GetBool("enable", true)
				dbName          = item.GetString("db_name")
				uri             = item.GetString("uri")
				timeout         = time.Duration(item.GetInt64("timeout", 10)) * time.Second
				tlsEnable       = item.GetInt("tls")
				maxPoolSize     = item.GetInt("maxPoolSize")
				minPoolSize     = item.GetInt("minPoolSize")
				maxConnIdleTime = item.GetInt("maxConnIdleTime")
				connectTimeout  = item.GetInt("connectTimeout")
				socketTimeout   = item.GetInt("socketTimeout")
				setReplicaSet   = item.GetString("setReplicaSet")
			)

			if !enable {
				panic(fmt.Sprintf("[dbName = %s] is disabled!", dbName))
			}

			db, err := CreateDatabase(uri, setReplicaSet, dbName, tlsEnable,
				uint64(maxPoolSize), uint64(minPoolSize),
				maxConnIdleTime, connectTimeout, socketTimeout, timeout)
			if err != nil {
				panic(fmt.Sprintf("[dbName = %s] create mongodb fail. error = %s", dbName, err))
			}

			s.dbMap[groupId][id] = db
			clog.Infof("[dbGroup =%s, dbName = %s] is connected.", groupId, id)
		}
	}
}

// CreateDatabase connects to the MongoDB deployment at uri, verifies the
// connection by pinging the primary, and returns a handle to dbName.
//
// maxPoolSize and minPoolSize bound the connection pool. maxConnIdleTime,
// connectTimeout and socketTimeout are in seconds. When tlsEnable == 1, TLS
// is enabled and setReplicaSet (if non-empty) is applied as the replica set
// name; otherwise both settings are ignored. The optional timeout bounds
// both connecting and the initial ping; when it is absent or not greater
// than 3 seconds, 5 seconds is used.
//
// On any failure the returned database is nil and no client is left open.
func CreateDatabase(uri, setReplicaSet, dbName string, tlsEnable int, maxPoolSize uint64, minPoolSize uint64, maxConnIdleTime int, connectTimeout, socketTimeout int, timeout ...time.Duration) (*mongo.Database, error) {
	tt := 5 * time.Second
	if len(timeout) > 0 && timeout[0].Seconds() > 3 {
		tt = timeout[0]
	}

	o := options.Client().ApplyURI(uri)
	// Maximum and minimum number of pooled connections.
	o.SetMaxPoolSize(maxPoolSize)
	o.SetMinPoolSize(minPoolSize)
	// How long an idle pooled connection is kept before it is closed.
	o.SetMaxConnIdleTime(time.Duration(maxConnIdleTime) * time.Second)
	// Timeout for establishing a connection.
	o.SetConnectTimeout(time.Duration(connectTimeout) * time.Second)
	// Timeout for socket reads and writes.
	o.SetSocketTimeout(time.Duration(socketTimeout) * time.Second)

	if tlsEnable == 1 {
		// Only override the replica set when one is configured: calling
		// SetReplicaSet("") would replace a replicaSet given in the URI with
		// an empty name.
		if setReplicaSet != "" {
			o.SetReplicaSet(setReplicaSet)
		}
		// NOTE(review): certificate verification is disabled, so the server's
		// identity is not checked. Consider supplying a RootCAs pool instead.
		o.SetTLSConfig(&tls.Config{
			//MinVersion: tls.VersionTLS12,
			//PreferServerCipherSuites: true,
			InsecureSkipVerify: true,
		})
	}

	if err := o.Validate(); err != nil {
		return nil, err
	}

	// mongo.Connect does not wait for the server, so the same deadline must
	// also cover the ping; otherwise an unreachable server blocks for the
	// driver's server-selection timeout instead of tt.
	ctx, cancel := context.WithTimeout(context.Background(), tt)
	defer cancel()

	client, err := mongo.Connect(ctx, o)
	if err != nil {
		return nil, fmt.Errorf("connect mongodb: %w", err)
	}

	if err = client.Ping(ctx, readpref.Primary()); err != nil {
		// The caller never receives this client, so release its pools and
		// monitoring goroutines here. The ping error is the one reported.
		_ = client.Disconnect(context.Background())
		return nil, fmt.Errorf("ping mongodb primary: %w", err)
	}

	// NOTE(review): uri may embed credentials; consider redacting it here.
	clog.Infof("ping database [uri = %s] is ok", uri)
	return client.Database(dbName), nil
}

// GetDb returns the database registered under id in any group, or nil if no
// such database was connected. If the same id exists in several groups, the
// one returned is unspecified (map iteration order).
func (s *Component) GetDb(id string) *mongo.Database {
	for _, group := range s.dbMap {
		if db, found := group[id]; found {
			return db
		}
	}
	return nil
}

// GetHashDb selects a database from group groupId using hashFn, which maps
// the group's databases to a db_id. It reports false when the group does not
// exist, hashFn is nil, or hashFn returns an id not present in the group.
func (s *Component) GetHashDb(groupId string, hashFn HashDb) (*mongo.Database, bool) {
	dbGroup, found := s.GetDbMap(groupId)
	if !found {
		clog.Warnf("groupId = %s not found.", groupId)
		return nil, false
	}

	if hashFn == nil {
		clog.Warnf("groupId = %s hash function is nil.", groupId)
		return nil, false
	}

	dbId := hashFn(dbGroup)
	db, found := dbGroup[dbId]
	return db, found
}

// GetDbMap returns the databases of group groupId keyed by db_id, and whether
// the group exists. The returned map is the component's internal map, not a
// copy; callers should not modify it.
func (s *Component) GetDbMap(groupId string) (map[string]*mongo.Database, bool) {
	dbGroup, found := s.dbMap[groupId]
	return dbGroup, found
}