From 813fd4590a97cd4804932a0222b8bf446ea88c21 Mon Sep 17 00:00:00 2001 From: Jonathan Moody <103143855+moodyjon@users.noreply.github.com> Date: Wed, 28 Sep 2022 14:49:32 -0500 Subject: [PATCH] Add --max-sessions, --session-timeout args. Enforce max sessions. --- server/args.go | 8 +++++++ server/jsonrpc_service.go | 5 +++-- server/server.go | 2 +- server/session.go | 44 ++++++++++++++++++++++++++------------- 4 files changed, 41 insertions(+), 18 deletions(-) diff --git a/server/args.go b/server/args.go index 27c740a..95ecce0 100644 --- a/server/args.go +++ b/server/args.go @@ -30,6 +30,8 @@ type Args struct { NotifierPort string JSONRPCPort int JSONRPCHTTPPort int + MaxSessions int + SessionTimeout int EsIndex string RefreshDelta int CacheTTL int @@ -61,6 +63,8 @@ const ( DefaultPrometheusPort = "2112" DefaultNotifierPort = "18080" DefaultJSONRPCPort = 50001 + DefaultMaxSessions = 10000 + DefaultSessionTimeout = 300 DefaultRefreshDelta = 5 DefaultCacheTTL = 5 DefaultPeerFile = "peers.txt" @@ -129,6 +133,8 @@ func ParseArgs(searchRequest *pb.SearchRequest) *Args { notifierPort := parser.String("", "notifier-port", &argparse.Options{Required: false, Help: "notifier port", Default: DefaultNotifierPort}) jsonRPCPort := parser.Int("", "json-rpc-port", &argparse.Options{Required: false, Help: "JSON RPC port", Validate: validatePort}) jsonRPCHTTPPort := parser.Int("", "json-rpc-http-port", &argparse.Options{Required: false, Help: "JSON RPC over HTTP port", Validate: validatePort}) + maxSessions := parser.Int("", "max-sessions", &argparse.Options{Required: false, Help: "Maximum number of electrum clients that can be connected", Default: DefaultMaxSessions}) + sessionTimeout := parser.Int("", "session-timeout", &argparse.Options{Required: false, Help: "Session inactivity timeout (seconds)", Default: DefaultSessionTimeout}) esIndex := parser.String("", "esindex", &argparse.Options{Required: false, Help: "elasticsearch index name", Default: DefaultEsIndex}) refreshDelta := parser.Int("", "refresh-delta", &argparse.Options{Required: false, Help: "elasticsearch index refresh delta in seconds", Default: DefaultRefreshDelta}) cacheTTL := parser.Int("", "cachettl", &argparse.Options{Required: false, Help: "Cache TTL in minutes", Default: DefaultCacheTTL}) @@ -183,6 +189,8 @@ func ParseArgs(searchRequest *pb.SearchRequest) *Args { NotifierPort: *notifierPort, JSONRPCPort: *jsonRPCPort, JSONRPCHTTPPort: *jsonRPCHTTPPort, + MaxSessions: *maxSessions, + SessionTimeout: *sessionTimeout, EsIndex: *esIndex, RefreshDelta: *refreshDelta, CacheTTL: *cacheTTL, diff --git a/server/jsonrpc_service.go b/server/jsonrpc_service.go index ee73de7..df12f3d 100644 --- a/server/jsonrpc_service.go +++ b/server/jsonrpc_service.go @@ -11,6 +11,7 @@ import ( gorilla_rpc "github.com/gorilla/rpc" gorilla_json "github.com/gorilla/rpc/json" log "github.com/sirupsen/logrus" + "golang.org/x/net/netutil" ) type gorillaRpcCodec struct { @@ -79,7 +80,7 @@ func (s *Server) StartJsonRPC() error { s.sessionManager.addSession(conn) } } - go acceptConnections(listener) + go acceptConnections(netutil.LimitListener(listener, s.sessionManager.sessionsMax)) } fail1: @@ -109,7 +110,7 @@ fail1: log.Errorf("RegisterTCPService: %v\n", err) goto fail2 } - err = s1.RegisterTCPService(&BlockchainAddressService{s.DB, s.Chain, nil, nil}, "blockchain_address") + err = s1.RegisterTCPService(&BlockchainAddressService{s.DB, s.Chain, nil, nil}, "blockchain_address") if err != nil { log.Errorf("RegisterTCPService: %v\n", err) goto fail2 diff --git a/server/server.go b/server/server.go index 0fa8057..4bb1bb6 100644 --- a/server/server.go +++ b/server/server.go @@ -333,7 +333,7 @@ func MakeHubServer(ctx context.Context, args *Args) *Server { HeightSubs: make(map[net.Addr]net.Conn), HeightSubsMut: sync.RWMutex{}, NotifierChan: make(chan interface{}), - sessionManager: newSessionManager(myDB, &chain), + sessionManager: newSessionManager(myDB, &chain, args.MaxSessions, args.SessionTimeout), } // Start up our background services diff --git a/server/session.go b/server/session.go index 180c45c..9c8f9ea 100644 --- a/server/session.go +++ b/server/session.go @@ -18,8 +18,6 @@ import ( log "github.com/sirupsen/logrus" ) -var SESSION_INACTIVE_TIMEOUT = 2 * time.Minute - type headerNotification struct { internal.HeightHash blockHeader [HEADER_SIZE]byte @@ -116,23 +114,28 @@ type sessionMap map[net.Addr]*session type sessionManager struct { // sessionsMut protects sessions, headerSubs, hashXSubs state - sessionsMut sync.RWMutex - sessions sessionMap - db *db.ReadOnlyDBColumnFamily - chain *chaincfg.Params + sessionsMut sync.RWMutex + sessions sessionMap + sessionsWait sync.WaitGroup + sessionsMax int + sessionTimeout time.Duration + db *db.ReadOnlyDBColumnFamily + chain *chaincfg.Params // headerSubs are sessions subscribed via 'blockchain.headers.subscribe' headerSubs sessionMap // hashXSubs are sessions subscribed via 'blockchain.{address,scripthash}.subscribe' hashXSubs map[[HASHX_LEN]byte]sessionMap } -func newSessionManager(db *db.ReadOnlyDBColumnFamily, chain *chaincfg.Params) *sessionManager { +func newSessionManager(db *db.ReadOnlyDBColumnFamily, chain *chaincfg.Params, sessionsMax, sessionTimeout int) *sessionManager { return &sessionManager{ - sessions: make(sessionMap), - db: db, - chain: chain, - headerSubs: make(sessionMap), - hashXSubs: make(map[[HASHX_LEN]byte]sessionMap), + sessions: make(sessionMap), + sessionsMax: sessionsMax, + sessionTimeout: time.Duration(sessionTimeout) * time.Second, + db: db, + chain: chain, + headerSubs: make(sessionMap), + hashXSubs: make(map[[HASHX_LEN]byte]sessionMap), } } @@ -153,12 +156,14 @@ func (sm *sessionManager) stop() { } func (sm *sessionManager) manage() { + sm.sessionsMut.Lock() for _, sess := range sm.sessions { - if time.Since(sess.lastRecv) > SESSION_INACTIVE_TIMEOUT { - sm.removeSession(sess) + if time.Since(sess.lastRecv) > sm.sessionTimeout { + sm.removeSessionLocked(sess) log.Infof("session %v timed out", sess.addr.String()) } } + sm.sessionsMut.Unlock() // TEMPORARY TESTING: Send fake notification for specific address. address, _ := lbcutil.DecodeAddress("bNe63fYgYNA85ZQ56p7MwBtuCL7MXPRfrm", sm.chain) @@ -220,7 +225,12 @@ func (sm *sessionManager) addSession(conn net.Conn) { goto fail } - go s1.ServeCodec(&SessionServerCodec{jsonrpc.NewServerCodec(conn), sess}) + sm.sessionsWait.Add(1) + go func() { + s1.ServeCodec(&SessionServerCodec{jsonrpc.NewServerCodec(conn), sess}) + log.Infof("session %v goroutine exit", sess.addr.String()) + sm.sessionsWait.Done() + }() return fail: @@ -230,6 +240,10 @@ fail: func (sm *sessionManager) removeSession(sess *session) { sm.sessionsMut.Lock() defer sm.sessionsMut.Unlock() + sm.removeSessionLocked(sess) +} + +func (sm *sessionManager) removeSessionLocked(sess *session) { if sess.headersSub { delete(sm.headerSubs, sess.addr) }