diff --git a/storage/redis/redis.go b/storage/redis/redis.go index 2cdadad..519369c 100644 --- a/storage/redis/redis.go +++ b/storage/redis/redis.go @@ -19,41 +19,47 @@ func (d *driver) New(conf *config.Storage) (storage.Conn, error) { return &Conn{ conf: conf, pool: &redis.Pool{ - MaxIdle: 3, - IdleTimeout: 240 * time.Second, - Dial: func() (redis.Conn, error) { - var ( - conn redis.Conn - err error - ) - - if conf.ConnectTimeout != nil && - conf.ReadTimeout != nil && - conf.WriteTimeout != nil { - - conn, err = redis.DialTimeout( - conf.Network, - conf.Addr, - conf.ConnectTimeout.Duration, - conf.ReadTimeout.Duration, - conf.WriteTimeout.Duration, - ) - } else { - conn, err = redis.Dial(conf.Network, conf.Addr) - } - if err != nil { - return nil, err - } - return conn, nil - }, - TestOnBorrow: func(c redis.Conn, t time.Time) error { - _, err := c.Do("PING") - return err - }, + MaxIdle: 3, + IdleTimeout: 240 * time.Second, + Dial: makeDialFunc(conf), + TestOnBorrow: testOnBorrow, }, }, nil } +func makeDialFunc(conf *config.Storage) func() (redis.Conn, error) { + return func() (redis.Conn, error) { + var ( + conn redis.Conn + err error + ) + + if conf.ConnectTimeout != nil && + conf.ReadTimeout != nil && + conf.WriteTimeout != nil { + + conn, err = redis.DialTimeout( + conf.Network, + conf.Addr, + conf.ConnectTimeout.Duration, + conf.ReadTimeout.Duration, + conf.WriteTimeout.Duration, + ) + } else { + conn, err = redis.Dial(conf.Network, conf.Addr) + } + if err != nil { + return nil, err + } + return conn, nil + } +} + +func testOnBorrow(c redis.Conn, t time.Time) error { + _, err := c.Do("PING") + return err +} + type Conn struct { conf *config.Storage pool *redis.Pool