diff --git a/db/db.go b/db/db.go index 164a9d0..00cc9a7 100644 --- a/db/db.go +++ b/db/db.go @@ -650,7 +650,7 @@ func (db *ReadOnlyDBColumnFamily) Shutdown() { // RunDetectChanges Go routine the runs continuously while the hub is active // to keep the db readonly view up to date and handle reorgs on the // blockchain. -func (db *ReadOnlyDBColumnFamily) RunDetectChanges(notifCh chan *internal.HeightHash) { +func (db *ReadOnlyDBColumnFamily) RunDetectChanges(notifCh chan<- interface{}) { go func() { lastPrint := time.Now() for { @@ -674,7 +674,7 @@ func (db *ReadOnlyDBColumnFamily) RunDetectChanges(notifCh chan *internal.Height } // DetectChanges keep the rocksdb db in sync and handle reorgs -func (db *ReadOnlyDBColumnFamily) detectChanges(notifCh chan *internal.HeightHash) error { +func (db *ReadOnlyDBColumnFamily) detectChanges(notifCh chan<- interface{}) error { err := db.DB.TryCatchUpWithPrimary() if err != nil { return err diff --git a/db/db_get.go b/db/db_get.go index 54c3fe8..d630c90 100644 --- a/db/db_get.go +++ b/db/db_get.go @@ -3,6 +3,7 @@ package db // db_get.go contains the basic access functions to the database. import ( + "crypto/sha256" "encoding/hex" "fmt" "log" @@ -266,6 +267,72 @@ func (db *ReadOnlyDBColumnFamily) GetHistory(hashX []byte) ([]TxInfo, error) { return results, nil } +func (db *ReadOnlyDBColumnFamily) GetStatus(hashX []byte) ([]byte, error) { + // Lookup in HashXMempoolStatus first. + status, err := db.getMempoolStatus(hashX) + if err == nil && status != nil { + return status, err + } + + // No indexed mempool status. Lookup in HashXStatus second. + handle, err := db.EnsureHandle(prefixes.HashXStatus) + if err != nil { + return nil, err + } + key := &prefixes.HashXStatusKey{ + Prefix: []byte{prefixes.HashXStatus}, + HashX: hashX, + } + rawKey := key.PackKey() + slice, err := db.DB.GetCF(db.Opts, handle, rawKey) + defer slice.Free() + if err == nil && slice.Size() > 0 { + rawValue := make([]byte, len(slice.Data())) + copy(rawValue, slice.Data()) + value := prefixes.HashXStatusValue{} + value.UnpackValue(rawValue) + return value.Status, nil + } + + // No indexed status. Fall back to enumerating HashXHistory. + txs, err := db.GetHistory(hashX) + if err != nil { + return nil, err + } + hash := sha256.New() + for _, tx := range txs { + hash.Write([]byte(fmt.Sprintf("%s:%d:", tx.TxHash.String(), tx.Height))) + } + // TODO: Mempool history + return hash.Sum(nil), err +} + +func (db *ReadOnlyDBColumnFamily) getMempoolStatus(hashX []byte) ([]byte, error) { + handle, err := db.EnsureHandle(prefixes.HashXMempoolStatus) + if err != nil { + return nil, err + } + + key := &prefixes.HashXMempoolStatusKey{ + Prefix: []byte{prefixes.HashXMempoolStatus}, + HashX: hashX, + } + rawKey := key.PackKey() + slice, err := db.DB.GetCF(db.Opts, handle, rawKey) + defer slice.Free() + if err != nil { + return nil, err + } else if slice.Size() == 0 { + return nil, nil + } + + rawValue := make([]byte, len(slice.Data())) + copy(rawValue, slice.Data()) + value := prefixes.HashXMempoolStatusValue{} + value.UnpackValue(rawValue) + return value.Status, nil +} + // GetStreamsAndChannelRepostedByChannelHashes returns a map of streams and channel hashes that are reposted by the given channel hashes. func (db *ReadOnlyDBColumnFamily) GetStreamsAndChannelRepostedByChannelHashes(reposterChannelHashes [][]byte) (map[string][]byte, map[string][]byte, error) { handle, err := db.EnsureHandle(prefixes.ChannelToClaim) diff --git a/server/args.go b/server/args.go index 29a64d3..95ecce0 100644 --- a/server/args.go +++ b/server/args.go @@ -3,6 +3,7 @@ package server import ( "log" "os" + "strconv" "strings" "github.com/akamensky/argparse" @@ -27,7 +28,10 @@ type Args struct { EsPort string PrometheusPort string NotifierPort string - JSONRPCPort string + JSONRPCPort int + JSONRPCHTTPPort int + MaxSessions int + SessionTimeout int EsIndex string RefreshDelta int CacheTTL int @@ -58,7 +62,9 @@ const ( DefaultEsPort = "9200" DefaultPrometheusPort = "2112" DefaultNotifierPort = "18080" - DefaultJSONRPCPort = "50001" + DefaultJSONRPCPort = 50001 + DefaultMaxSessions = 10000 + DefaultSessionTimeout = 300 DefaultRefreshDelta = 5 DefaultCacheTTL = 5 DefaultPeerFile = "peers.txt" @@ -111,6 +117,11 @@ func ParseArgs(searchRequest *pb.SearchRequest) *Args { searchCmd := parser.NewCommand("search", "claim search") dbCmd := parser.NewCommand("db", "db testing") + validatePort := func(arg []string) error { + _, err := strconv.ParseUint(arg[0], 10, 16) + return err + } + host := parser.String("", "rpchost", &argparse.Options{Required: false, Help: "RPC host", Default: DefaultHost}) port := parser.String("", "rpcport", &argparse.Options{Required: false, Help: "RPC port", Default: DefaultPort}) dbPath := parser.String("", "db-path", &argparse.Options{Required: false, Help: "RocksDB path", Default: DefaultDBPath}) @@ -120,7 +131,10 @@ func ParseArgs(searchRequest *pb.SearchRequest) *Args { esPort := parser.String("", "esport", &argparse.Options{Required: false, Help: "elasticsearch port", Default: DefaultEsPort}) prometheusPort := parser.String("", "prometheus-port", &argparse.Options{Required: false, Help: "prometheus port", Default: DefaultPrometheusPort}) notifierPort := parser.String("", "notifier-port", &argparse.Options{Required: false, Help: "notifier port", Default: DefaultNotifierPort}) - jsonRPCPort := parser.String("", "json-rpc-port", &argparse.Options{Required: false, Help: "JSON RPC port", Default: DefaultJSONRPCPort}) + jsonRPCPort := parser.Int("", "json-rpc-port", &argparse.Options{Required: false, Help: "JSON RPC port", Validate: validatePort}) + jsonRPCHTTPPort := parser.Int("", "json-rpc-http-port", &argparse.Options{Required: false, Help: "JSON RPC over HTTP port", Validate: validatePort}) + maxSessions := parser.Int("", "max-sessions", &argparse.Options{Required: false, Help: "Maximum number of electrum clients that can be connected", Default: DefaultMaxSessions}) + sessionTimeout := parser.Int("", "session-timeout", &argparse.Options{Required: false, Help: "Session inactivity timeout (seconds)", Default: DefaultSessionTimeout}) esIndex := parser.String("", "esindex", &argparse.Options{Required: false, Help: "elasticsearch index name", Default: DefaultEsIndex}) refreshDelta := parser.Int("", "refresh-delta", &argparse.Options{Required: false, Help: "elasticsearch index refresh delta in seconds", Default: DefaultRefreshDelta}) cacheTTL := parser.Int("", "cachettl", &argparse.Options{Required: false, Help: "Cache TTL in minutes", Default: DefaultCacheTTL}) @@ -158,6 +172,11 @@ func ParseArgs(searchRequest *pb.SearchRequest) *Args { log.Fatalln(parser.Usage(err)) } + // Use default JSON RPC port only if *neither* JSON RPC arg is specified. + if *jsonRPCPort == 0 && *jsonRPCHTTPPort == 0 { + *jsonRPCPort = DefaultJSONRPCPort + } + args := &Args{ CmdType: SearchCmd, Host: *host, @@ -169,6 +188,9 @@ func ParseArgs(searchRequest *pb.SearchRequest) *Args { PrometheusPort: *prometheusPort, NotifierPort: *notifierPort, JSONRPCPort: *jsonRPCPort, + JSONRPCHTTPPort: *jsonRPCHTTPPort, + MaxSessions: *maxSessions, + SessionTimeout: *sessionTimeout, EsIndex: *esIndex, RefreshDelta: *refreshDelta, CacheTTL: *cacheTTL, diff --git a/server/jsonrpc_blockchain.go b/server/jsonrpc_blockchain.go index 41b6809..de5c457 100644 --- a/server/jsonrpc_blockchain.go +++ b/server/jsonrpc_blockchain.go @@ -9,10 +9,7 @@ import ( "encoding/hex" "errors" "fmt" - "net/http" - "strings" - "github.com/gorilla/rpc" "github.com/lbryio/herald.go/db" "github.com/lbryio/herald.go/internal" "github.com/lbryio/lbcd/chaincfg" @@ -23,56 +20,37 @@ import ( "golang.org/x/exp/constraints" ) -type BlockchainCodec struct { - rpc.Codec -} - -func (c *BlockchainCodec) NewRequest(r *http.Request) rpc.CodecRequest { - return &BlockchainCodecRequest{c.Codec.NewRequest(r)} -} - -// BlockchainCodecRequest provides ability to rewrite the incoming -// request "method" field. For example: -// blockchain.block.get_header -> blockchain_block.Get_header -// blockchain.address.listunspent -> blockchain_address.Listunspent -// This makes the "method" string compatible with Gorilla/RPC -// requirements. -type BlockchainCodecRequest struct { - rpc.CodecRequest -} - -func (cr *BlockchainCodecRequest) Method() (string, error) { - rawMethod, err := cr.CodecRequest.Method() - if err != nil { - return rawMethod, err - } - parts := strings.Split(rawMethod, ".") - if len(parts) < 2 { - return rawMethod, fmt.Errorf("blockchain rpc: service/method ill-formed: %q", rawMethod) - } - service := strings.Join(parts[0:len(parts)-1], "_") - method := parts[len(parts)-1] - if len(method) < 1 { - return rawMethod, fmt.Errorf("blockchain rpc: method ill-formed: %q", method) - } - method = strings.ToUpper(string(method[0])) + string(method[1:]) - return service + "." + method, err -} - -// BlockchainService methods handle "blockchain.block.*" RPCs -type BlockchainService struct { +// BlockchainBlockService methods handle "blockchain.block.*" RPCs +type BlockchainBlockService struct { DB *db.ReadOnlyDBColumnFamily Chain *chaincfg.Params } +// BlockchainBlockService methods handle "blockchain.headers.*" RPCs +type BlockchainHeadersService struct { + DB *db.ReadOnlyDBColumnFamily + Chain *chaincfg.Params + // needed for subscribe/unsubscribe + sessionMgr *sessionManager + session *session +} + // BlockchainAddressService methods handle "blockchain.address.*" RPCs type BlockchainAddressService struct { - BlockchainService + DB *db.ReadOnlyDBColumnFamily + Chain *chaincfg.Params + // needed for subscribe/unsubscribe + sessionMgr *sessionManager + session *session } // BlockchainScripthashService methods handle "blockchain.scripthash.*" RPCs type BlockchainScripthashService struct { - BlockchainService + DB *db.ReadOnlyDBColumnFamily + Chain *chaincfg.Params + // needed for subscribe/unsubscribe + sessionMgr *sessionManager + session *session } const CHUNK_SIZE = 96 @@ -87,10 +65,45 @@ func min[Ord constraints.Ordered](x, y Ord) Ord { return y } +func max[Ord constraints.Ordered](x, y Ord) Ord { + if x > y { + return x + } + return y +} + +type BlockHeaderElectrum struct { + Version uint32 `json:"version"` + PrevBlockHash string `json:"prev_block_hash"` + MerkleRoot string `json:"merkle_root"` + ClaimTrieRoot string `json:"claim_trie_root"` + Timestamp uint32 `json:"timestamp"` + Bits uint32 `json:"bits"` + Nonce uint32 `json:"nonce"` + BlockHeight uint32 `json:"block_height"` +} + +func newBlockHeaderElectrum(header *[HEADER_SIZE]byte, height uint32) *BlockHeaderElectrum { + var h1, h2, h3 chainhash.Hash + h1.SetBytes(header[4:36]) + h2.SetBytes(header[36:68]) + h3.SetBytes(header[68:100]) + return &BlockHeaderElectrum{ + Version: binary.LittleEndian.Uint32(header[0:]), + PrevBlockHash: h1.String(), + MerkleRoot: h2.String(), + ClaimTrieRoot: h3.String(), + Timestamp: binary.LittleEndian.Uint32(header[100:]), + Bits: binary.LittleEndian.Uint32(header[104:]), + Nonce: binary.LittleEndian.Uint32(header[108:]), + BlockHeight: height, + } +} + type BlockGetServerHeightReq struct{} type BlockGetServerHeightResp uint32 -func (s *BlockchainService) Get_server_height(r *http.Request, req *BlockGetServerHeightReq, resp **BlockGetServerHeightResp) error { +func (s *BlockchainBlockService) Get_server_height(req *BlockGetServerHeightReq, resp **BlockGetServerHeightResp) error { if s.DB == nil || s.DB.LastState == nil { return fmt.Errorf("unknown height") } @@ -103,7 +116,7 @@ type BlockGetChunkReq uint32 type BlockGetChunkResp string // 'blockchain.block.get_chunk' -func (s *BlockchainService) Get_chunk(r *http.Request, req *BlockGetChunkReq, resp **BlockGetChunkResp) error { +func (s *BlockchainBlockService) Get_chunk(req *BlockGetChunkReq, resp **BlockGetChunkResp) error { index := uint32(*req) db_headers, err := s.DB.GetHeaders(index*CHUNK_SIZE, CHUNK_SIZE) if err != nil { @@ -120,18 +133,11 @@ func (s *BlockchainService) Get_chunk(r *http.Request, req *BlockGetChunkReq, re type BlockGetHeaderReq uint32 type BlockGetHeaderResp struct { - Version uint32 `json:"version"` - PrevBlockHash string `json:"prev_block_hash"` - MerkleRoot string `json:"merkle_root"` - ClaimTrieRoot string `json:"claim_trie_root"` - Timestamp uint32 `json:"timestamp"` - Bits uint32 `json:"bits"` - Nonce uint32 `json:"nonce"` - BlockHeight uint32 `json:"block_height"` + BlockHeaderElectrum } // 'blockchain.block.get_header' -func (s *BlockchainService) Get_header(r *http.Request, req *BlockGetHeaderReq, resp **BlockGetHeaderResp) error { +func (s *BlockchainBlockService) Get_header(req *BlockGetHeaderReq, resp **BlockGetHeaderResp) error { height := uint32(*req) headers, err := s.DB.GetHeaders(height, 1) if err != nil { @@ -140,23 +146,7 @@ func (s *BlockchainService) Get_header(r *http.Request, req *BlockGetHeaderReq, if len(headers) < 1 { return errors.New("not found") } - decode := func(header *[HEADER_SIZE]byte, height uint32) *BlockGetHeaderResp { - var h1, h2, h3 chainhash.Hash - h1.SetBytes(header[4:36]) - h2.SetBytes(header[36:68]) - h3.SetBytes(header[68:100]) - return &BlockGetHeaderResp{ - Version: binary.LittleEndian.Uint32(header[0:]), - PrevBlockHash: h1.String(), - MerkleRoot: h2.String(), - ClaimTrieRoot: h3.String(), - Timestamp: binary.LittleEndian.Uint32(header[100:]), - Bits: binary.LittleEndian.Uint32(header[104:]), - Nonce: binary.LittleEndian.Uint32(header[108:]), - BlockHeight: height, - } - } - *resp = decode(&headers[0], height) + *resp = &BlockGetHeaderResp{*newBlockHeaderElectrum(&headers[0], height)} return err } @@ -177,7 +167,7 @@ type BlockHeadersResp struct { } // 'blockchain.block.headers' -func (s *BlockchainService) Headers(r *http.Request, req *BlockHeadersReq, resp **BlockHeadersResp) error { +func (s *BlockchainBlockService) Headers(req *BlockHeadersReq, resp **BlockHeadersResp) error { count := min(req.Count, MAX_CHUNK_SIZE) db_headers, err := s.DB.GetHeaders(req.StartHeight, count) if err != nil { @@ -209,6 +199,47 @@ func (s *BlockchainService) Headers(r *http.Request, req *BlockHeadersReq, resp return err } +type HeadersSubscribeReq struct { + Raw bool `json:"raw"` +} + +type HeadersSubscribeResp struct { + BlockHeaderElectrum +} +type HeadersSubscribeRawResp struct { + Hex string `json:"hex"` + Height uint32 `json:"height"` +} + +// 'blockchain.headers.subscribe' +func (s *BlockchainHeadersService) Subscribe(req *HeadersSubscribeReq, resp *interface{}) error { + if s.sessionMgr == nil || s.session == nil { + return errors.New("no session, rpc not supported") + } + s.sessionMgr.headersSubscribe(s.session, req.Raw, true /*subscribe*/) + height := s.DB.Height + if s.DB.LastState != nil { + height = s.DB.LastState.Height + } + headers, err := s.DB.GetHeaders(height, 1) + if err != nil { + s.sessionMgr.headersSubscribe(s.session, req.Raw, false /*subscribe*/) + return err + } + if len(headers) < 1 { + return errors.New("not found") + } + if req.Raw { + *resp = &HeadersSubscribeRawResp{ + Hex: hex.EncodeToString(headers[0][:]), + Height: height, + } + } else { + *resp = &HeadersSubscribeResp{*newBlockHeaderElectrum(&headers[0], height)} + } + return err +} + func decodeScriptHash(scripthash string) ([]byte, error) { sh, err := hex.DecodeString(scripthash) if err != nil { @@ -249,7 +280,7 @@ type AddressGetBalanceResp struct { } // 'blockchain.address.get_balance' -func (s *BlockchainAddressService) Get_balance(r *http.Request, req *AddressGetBalanceReq, resp **AddressGetBalanceResp) error { +func (s *BlockchainAddressService) Get_balance(req *AddressGetBalanceReq, resp **AddressGetBalanceResp) error { address, err := lbcutil.DecodeAddress(req.Address, s.Chain) if err != nil { return err @@ -276,7 +307,7 @@ type ScripthashGetBalanceResp struct { } // 'blockchain.scripthash.get_balance' -func (s *BlockchainScripthashService) Get_balance(r *http.Request, req *scripthashGetBalanceReq, resp **ScripthashGetBalanceResp) error { +func (s *BlockchainScripthashService) Get_balance(req *scripthashGetBalanceReq, resp **ScripthashGetBalanceResp) error { scripthash, err := decodeScriptHash(req.ScriptHash) if err != nil { return err @@ -307,7 +338,7 @@ type AddressGetHistoryResp struct { } // 'blockchain.address.get_history' -func (s *BlockchainAddressService) Get_history(r *http.Request, req *AddressGetHistoryReq, resp **AddressGetHistoryResp) error { +func (s *BlockchainAddressService) Get_history(req *AddressGetHistoryReq, resp **AddressGetHistoryResp) error { address, err := lbcutil.DecodeAddress(req.Address, s.Chain) if err != nil { return err @@ -346,7 +377,7 @@ type ScripthashGetHistoryResp struct { } // 'blockchain.scripthash.get_history' -func (s *BlockchainScripthashService) Get_history(r *http.Request, req *ScripthashGetHistoryReq, resp **ScripthashGetHistoryResp) error { +func (s *BlockchainScripthashService) Get_history(req *ScripthashGetHistoryReq, resp **ScripthashGetHistoryResp) error { scripthash, err := decodeScriptHash(req.ScriptHash) if err != nil { return err @@ -378,7 +409,7 @@ type AddressGetMempoolReq struct { type AddressGetMempoolResp []TxInfoFee // 'blockchain.address.get_mempool' -func (s *BlockchainAddressService) Get_mempool(r *http.Request, req *AddressGetMempoolReq, resp **AddressGetMempoolResp) error { +func (s *BlockchainAddressService) Get_mempool(req *AddressGetMempoolReq, resp **AddressGetMempoolResp) error { address, err := lbcutil.DecodeAddress(req.Address, s.Chain) if err != nil { return err @@ -402,7 +433,7 @@ type ScripthashGetMempoolReq struct { type ScripthashGetMempoolResp []TxInfoFee // 'blockchain.scripthash.get_mempool' -func (s *BlockchainScripthashService) Get_mempool(r *http.Request, req *ScripthashGetMempoolReq, resp **ScripthashGetMempoolResp) error { +func (s *BlockchainScripthashService) Get_mempool(req *ScripthashGetMempoolReq, resp **ScripthashGetMempoolResp) error { scripthash, err := decodeScriptHash(req.ScriptHash) if err != nil { return err @@ -428,7 +459,7 @@ type TXOInfo struct { type AddressListUnspentResp []TXOInfo // 'blockchain.address.listunspent' -func (s *BlockchainAddressService) Listunspent(r *http.Request, req *AddressListUnspentReq, resp **AddressListUnspentResp) error { +func (s *BlockchainAddressService) Listunspent(req *AddressListUnspentReq, resp **AddressListUnspentResp) error { address, err := lbcutil.DecodeAddress(req.Address, s.Chain) if err != nil { return err @@ -460,7 +491,7 @@ type ScripthashListUnspentReq struct { type ScripthashListUnspentResp []TXOInfo // 'blockchain.scripthash.listunspent' -func (s *BlockchainScripthashService) Listunspent(r *http.Request, req *ScripthashListUnspentReq, resp **ScripthashListUnspentResp) error { +func (s *BlockchainScripthashService) Listunspent(req *ScripthashListUnspentReq, resp **ScripthashListUnspentResp) error { scripthash, err := decodeScriptHash(req.ScriptHash) if err != nil { return err @@ -481,3 +512,94 @@ func (s *BlockchainScripthashService) Listunspent(r *http.Request, req *Scriptha *resp = &result return err } + +type AddressSubscribeReq []string +type AddressSubscribeResp []string + +// 'blockchain.address.subscribe' +func (s *BlockchainAddressService) Subscribe(req *AddressSubscribeReq, resp **AddressSubscribeResp) error { + if s.sessionMgr == nil || s.session == nil { + return errors.New("no session, rpc not supported") + } + result := make([]string, 0, len(*req)) + for _, addr := range *req { + address, err := lbcutil.DecodeAddress(addr, s.Chain) + if err != nil { + return err + } + script, err := txscript.PayToAddrScript(address) + if err != nil { + return err + } + hashX := hashXScript(script, s.Chain) + s.sessionMgr.hashXSubscribe(s.session, hashX, addr, true /*subscribe*/) + status, err := s.DB.GetStatus(hashX) + if err != nil { + return err + } + result = append(result, hex.EncodeToString(status)) + } + *resp = (*AddressSubscribeResp)(&result) + return nil +} + +// 'blockchain.address.unsubscribe' +func (s *BlockchainAddressService) Unsubscribe(req *AddressSubscribeReq, resp **AddressSubscribeResp) error { + if s.sessionMgr == nil || s.session == nil { + return errors.New("no session, rpc not supported") + } + for _, addr := range *req { + address, err := lbcutil.DecodeAddress(addr, s.Chain) + if err != nil { + return err + } + script, err := txscript.PayToAddrScript(address) + if err != nil { + return err + } + hashX := hashXScript(script, s.Chain) + s.sessionMgr.hashXSubscribe(s.session, hashX, addr, false /*subscribe*/) + } + *resp = (*AddressSubscribeResp)(nil) + return nil +} + +type ScripthashSubscribeReq string +type ScripthashSubscribeResp string + +// 'blockchain.scripthash.subscribe' +func (s *BlockchainScripthashService) Subscribe(req *ScripthashSubscribeReq, resp **ScripthashSubscribeResp) error { + if s.sessionMgr == nil || s.session == nil { + return errors.New("no session, rpc not supported") + } + var result string + scripthash, err := decodeScriptHash(string(*req)) + if err != nil { + return err + } + hashX := hashX(scripthash) + s.sessionMgr.hashXSubscribe(s.session, hashX, string(*req), true /*subscribe*/) + + status, err := s.DB.GetStatus(hashX) + if err != nil { + return err + } + result = hex.EncodeToString(status) + *resp = (*ScripthashSubscribeResp)(&result) + return nil +} + +// 'blockchain.scripthash.unsubscribe' +func (s *BlockchainScripthashService) Unsubscribe(req *ScripthashSubscribeReq, resp **ScripthashSubscribeResp) error { + if s.sessionMgr == nil || s.session == nil { + return errors.New("no session, rpc not supported") + } + scripthash, err := decodeScriptHash(string(*req)) + if err != nil { + return err + } + hashX := hashX(scripthash) + s.sessionMgr.hashXSubscribe(s.session, hashX, string(*req), false /*subscribe*/) + *resp = (*ScripthashSubscribeResp)(nil) + return nil +} diff --git a/server/jsonrpc_blockchain_test.go b/server/jsonrpc_blockchain_test.go index a57bc35..2c94ebb 100644 --- a/server/jsonrpc_blockchain_test.go +++ b/server/jsonrpc_blockchain_test.go @@ -1,12 +1,18 @@ package server import ( + "encoding/hex" "encoding/json" + "net" "strconv" + "sync" "testing" "github.com/lbryio/herald.go/db" + "github.com/lbryio/herald.go/internal" "github.com/lbryio/lbcd/chaincfg" + "github.com/lbryio/lbcd/txscript" + "github.com/lbryio/lbcutil" ) // Source: test_variety_of_transactions_and_longish_history (lbry-sdk/tests/integration/transactions) @@ -58,14 +64,14 @@ func TestServerGetHeight(t *testing.T) { return } - s := &BlockchainService{ + s := &BlockchainBlockService{ DB: db, Chain: &chaincfg.RegressionNetParams, } req := BlockGetServerHeightReq{} var resp *BlockGetServerHeightResp - err = s.Get_server_height(nil, &req, &resp) + err = s.Get_server_height(&req, &resp) if err != nil { t.Errorf("handler err: %v", err) } @@ -88,7 +94,7 @@ func TestGetChunk(t *testing.T) { return } - s := &BlockchainService{ + s := &BlockchainBlockService{ DB: db, Chain: &chaincfg.RegressionNetParams, } @@ -96,7 +102,7 @@ func TestGetChunk(t *testing.T) { for index := 0; index < 10; index++ { req := BlockGetChunkReq(index) var resp *BlockGetChunkResp - err := s.Get_chunk(nil, &req, &resp) + err := s.Get_chunk(&req, &resp) if err != nil { t.Errorf("index: %v handler err: %v", index, err) } @@ -131,7 +137,7 @@ func TestGetHeader(t *testing.T) { return } - s := &BlockchainService{ + s := &BlockchainBlockService{ DB: db, Chain: &chaincfg.RegressionNetParams, } @@ -139,7 +145,7 @@ func TestGetHeader(t *testing.T) { for height := 0; height < 700; height += 100 { req := BlockGetHeaderReq(height) var resp *BlockGetHeaderResp - err := s.Get_header(nil, &req, &resp) + err := s.Get_header(&req, &resp) if err != nil && height <= 500 { t.Errorf("height: %v handler err: %v", height, err) } @@ -151,6 +157,128 @@ func TestGetHeader(t *testing.T) { } } +func TestHeaders(t *testing.T) { + secondaryPath := "asdf" + db, toDefer, err := db.GetProdDB(regTestDBPath, secondaryPath) + defer toDefer() + if err != nil { + t.Error(err) + return + } + + s := &BlockchainBlockService{ + DB: db, + Chain: &chaincfg.RegressionNetParams, + } + + for height := uint32(0); height < 700; height += 100 { + req := BlockHeadersReq{ + StartHeight: height, + Count: 1, + CpHeight: 0, + B64: false, + } + var resp *BlockHeadersResp + err := s.Headers(&req, &resp) + marshalled, err := json.MarshalIndent(resp, "", " ") + if err != nil { + t.Errorf("height: %v unmarshal err: %v", height, err) + } + t.Logf("height: %v resp: %v", height, string(marshalled)) + } +} + +func TestHeadersSubscribe(t *testing.T) { + secondaryPath := "asdf" + db, toDefer, err := db.GetProdDB(regTestDBPath, secondaryPath) + defer toDefer() + if err != nil { + t.Error(err) + return + } + + sm := newSessionManager(db, &chaincfg.RegressionNetParams, DefaultMaxSessions, DefaultSessionTimeout) + sm.start() + defer sm.stop() + + client1, server1 := net.Pipe() + sess1 := sm.addSession(server1) + client2, server2 := net.Pipe() + sess2 := sm.addSession(server2) + + // Set up logic to read a notification. + var received sync.WaitGroup + recv := func(client net.Conn) { + buf := make([]byte, 1024) + len, err := client.Read(buf) + if err != nil { + t.Errorf("read err: %v", err) + } + t.Logf("len: %v notification: %v", len, string(buf)) + received.Done() + } + received.Add(2) + go recv(client1) + go recv(client2) + + s1 := &BlockchainHeadersService{ + DB: db, + Chain: &chaincfg.RegressionNetParams, + sessionMgr: sm, + session: sess1, + } + s2 := &BlockchainHeadersService{ + DB: db, + Chain: &chaincfg.RegressionNetParams, + sessionMgr: sm, + session: sess2, + } + + // Subscribe with Raw: false. + req1 := HeadersSubscribeReq{Raw: false} + var r any + err = s1.Subscribe(&req1, &r) + if err != nil { + t.Errorf("handler err: %v", err) + } + resp1 := r.(*HeadersSubscribeResp) + marshalled1, err := json.MarshalIndent(resp1, "", " ") + if err != nil { + t.Errorf("unmarshal err: %v", err) + } + // Subscribe with Raw: true. + t.Logf("resp: %v", string(marshalled1)) + req2 := HeadersSubscribeReq{Raw: true} + err = s2.Subscribe(&req2, &r) + if err != nil { + t.Errorf("handler err: %v", err) + } + resp2 := r.(*HeadersSubscribeRawResp) + marshalled2, err := json.MarshalIndent(resp2, "", " ") + if err != nil { + t.Errorf("unmarshal err: %v", err) + } + t.Logf("resp: %v", string(marshalled2)) + + // Now send a notification. + header500, err := hex.DecodeString("00000020e9537f98ae80a3aa0936dd424439b2b9305e5e9d9d5c7aa571b4422c447741e739b3109304ed4f0330d6854271db17da221559a46b68db4ceecfebd9f0c75dbe0100000000000000000000000000000000000000000000000000000000000000b3e02063ffff7f2001000000") + if err != nil { + t.Errorf("decode err: %v", err) + } + note1 := headerNotification{ + HeightHash: internal.HeightHash{Height: 500}, + blockHeader: [112]byte{}, + blockHeaderElectrum: nil, + blockHeaderStr: "", + } + copy(note1.blockHeader[:], header500) + t.Logf("sending notification") + sm.doNotify(note1) + + t.Logf("waiting to receive notification(s)...") + received.Wait() +} + func TestGetBalance(t *testing.T) { secondaryPath := "asdf" db, toDefer, err := db.GetProdDB(regTestDBPath, secondaryPath) @@ -161,16 +289,14 @@ func TestGetBalance(t *testing.T) { } s := &BlockchainAddressService{ - BlockchainService{ - DB: db, - Chain: &chaincfg.RegressionNetParams, - }, + DB: db, + Chain: &chaincfg.RegressionNetParams, } for _, addr := range regTestAddrs { req := AddressGetBalanceReq{addr} var resp *AddressGetBalanceResp - err := s.Get_balance(nil, &req, &resp) + err := s.Get_balance(&req, &resp) if err != nil { t.Errorf("address: %v handler err: %v", addr, err) } @@ -192,16 +318,14 @@ func TestGetHistory(t *testing.T) { } s := &BlockchainAddressService{ - BlockchainService{ - DB: db, - Chain: &chaincfg.RegressionNetParams, - }, + DB: db, + Chain: &chaincfg.RegressionNetParams, } for _, addr := range regTestAddrs { req := AddressGetHistoryReq{addr} var resp *AddressGetHistoryResp - err := s.Get_history(nil, &req, &resp) + err := s.Get_history(&req, &resp) if err != nil { t.Errorf("address: %v handler err: %v", addr, err) } @@ -223,16 +347,14 @@ func TestListUnspent(t *testing.T) { } s := &BlockchainAddressService{ - BlockchainService{ - DB: db, - Chain: &chaincfg.RegressionNetParams, - }, + DB: db, + Chain: &chaincfg.RegressionNetParams, } for _, addr := range regTestAddrs { req := AddressListUnspentReq{addr} var resp *AddressListUnspentResp - err := s.Listunspent(nil, &req, &resp) + err := s.Listunspent(&req, &resp) if err != nil { t.Errorf("address: %v handler err: %v", addr, err) } @@ -243,3 +365,92 @@ func TestListUnspent(t *testing.T) { t.Logf("address: %v resp: %v", addr, string(marshalled)) } } + +func TestAddressSubscribe(t *testing.T) { + secondaryPath := "asdf" + db, toDefer, err := db.GetProdDB(regTestDBPath, secondaryPath) + defer toDefer() + if err != nil { + t.Error(err) + return + } + + sm := newSessionManager(db, &chaincfg.RegressionNetParams, DefaultMaxSessions, DefaultSessionTimeout) + sm.start() + defer sm.stop() + + client1, server1 := net.Pipe() + sess1 := sm.addSession(server1) + client2, server2 := net.Pipe() + sess2 := sm.addSession(server2) + + // Set up logic to read a notification. + var received sync.WaitGroup + recv := func(client net.Conn) { + buf := make([]byte, 1024) + len, err := client.Read(buf) + if err != nil { + t.Errorf("read err: %v", err) + } + t.Logf("len: %v notification: %v", len, string(buf)) + received.Done() + } + received.Add(2) + go recv(client1) + go recv(client2) + + s1 := &BlockchainAddressService{ + DB: db, + Chain: &chaincfg.RegressionNetParams, + sessionMgr: sm, + session: sess1, + } + s2 := &BlockchainAddressService{ + DB: db, + Chain: &chaincfg.RegressionNetParams, + sessionMgr: sm, + session: sess2, + } + + addr1, addr2 := regTestAddrs[1], regTestAddrs[2] + // Subscribe to addr1 and addr2. + req1 := AddressSubscribeReq{addr1, addr2} + var resp1 *AddressSubscribeResp + err = s1.Subscribe(&req1, &resp1) + if err != nil { + t.Errorf("handler err: %v", err) + } + marshalled1, err := json.MarshalIndent(resp1, "", " ") + if err != nil { + t.Errorf("unmarshal err: %v", err) + } + // Subscribe to addr2 only. + t.Logf("resp: %v", string(marshalled1)) + req2 := AddressSubscribeReq{addr2} + var resp2 *AddressSubscribeResp + err = s2.Subscribe(&req2, &resp2) + if err != nil { + t.Errorf("handler err: %v", err) + } + marshalled2, err := json.MarshalIndent(resp2, "", " ") + if err != nil { + t.Errorf("unmarshal err: %v", err) + } + t.Logf("resp: %v", string(marshalled2)) + + // Now send a notification for addr2. + address, _ := lbcutil.DecodeAddress(addr2, sm.chain) + script, _ := txscript.PayToAddrScript(address) + note := hashXNotification{} + copy(note.hashX[:], hashXScript(script, sm.chain)) + status, err := hex.DecodeString((*resp1)[1]) + if err != nil { + t.Errorf("decode err: %v", err) + } + note.status = append(note.status, []byte(status)...) + t.Logf("sending notification") + sm.doNotify(note) + + t.Logf("waiting to receive notification(s)...") + received.Wait() +} diff --git a/server/jsonrpc_claimtrie.go b/server/jsonrpc_claimtrie.go new file mode 100644 index 0000000..a8ef42b --- /dev/null +++ b/server/jsonrpc_claimtrie.go @@ -0,0 +1,27 @@ +package server + +import ( + "github.com/lbryio/herald.go/db" + pb "github.com/lbryio/herald.go/protobuf/go" + log "github.com/sirupsen/logrus" +) + +type ClaimtrieService struct { + DB *db.ReadOnlyDBColumnFamily +} + +type ResolveData struct { + Data []string `json:"data"` +} + +type Result struct { + Data string `json:"data"` +} + +// Resolve is the json rpc endpoint for 'blockchain.claimtrie.resolve'. +func (t *ClaimtrieService) Resolve(args *ResolveData, result **pb.Outputs) error { + log.Println("Resolve") + res, err := InternalResolve(args.Data, t.DB) + *result = res + return err +} diff --git a/server/jsonrpc_service.go b/server/jsonrpc_service.go index 130d2d3..df12f3d 100644 --- a/server/jsonrpc_service.go +++ b/server/jsonrpc_service.go @@ -1,69 +1,133 @@ package server import ( + "fmt" + "net" "net/http" + "strconv" + "strings" - "github.com/gorilla/mux" - "github.com/gorilla/rpc" - "github.com/gorilla/rpc/json" - "github.com/lbryio/herald.go/db" - pb "github.com/lbryio/herald.go/protobuf/go" + gorilla_mux "github.com/gorilla/mux" + gorilla_rpc "github.com/gorilla/rpc" + gorilla_json "github.com/gorilla/rpc/json" log "github.com/sirupsen/logrus" + "golang.org/x/net/netutil" ) -type ClaimtrieService struct { - DB *db.ReadOnlyDBColumnFamily +type gorillaRpcCodec struct { + gorilla_rpc.Codec } -type ResolveData struct { - Data []string `json:"data"` +func (c *gorillaRpcCodec) NewRequest(r *http.Request) gorilla_rpc.CodecRequest { + return &gorillaRpcCodecRequest{c.Codec.NewRequest(r)} } -type Result struct { - Data string `json:"data"` +// gorillaRpcCodecRequest provides ability to rewrite the incoming +// request "method" field. For example: +// blockchain.block.get_header -> blockchain_block.Get_header +// blockchain.address.listunspent -> blockchain_address.Listunspent +// This makes the "method" string compatible with Gorilla/RPC +// requirements. +type gorillaRpcCodecRequest struct { + gorilla_rpc.CodecRequest } -// Resolve is the json rpc endpoint for 'blockchain.claimtrie.resolve'. -func (t *ClaimtrieService) Resolve(r *http.Request, args *ResolveData, result **pb.Outputs) error { - log.Println("Resolve") - res, err := InternalResolve(args.Data, t.DB) - *result = res - return err +func (cr *gorillaRpcCodecRequest) Method() (string, error) { + rawMethod, err := cr.CodecRequest.Method() + if err != nil { + return rawMethod, err + } + parts := strings.Split(rawMethod, ".") + if len(parts) < 2 { + return rawMethod, fmt.Errorf("blockchain rpc: service/method ill-formed: %q", rawMethod) + } + service := strings.Join(parts[0:len(parts)-1], "_") + method := parts[len(parts)-1] + if len(method) < 1 { + return rawMethod, fmt.Errorf("blockchain rpc: method ill-formed: %q", method) + } + method = strings.ToUpper(string(method[0])) + string(method[1:]) + return service + "." + method, err } // StartJsonRPC starts the json rpc server and registers the endpoints. func (s *Server) StartJsonRPC() error { - port := ":" + s.Args.JSONRPCPort + s.sessionManager.start() + defer s.sessionManager.stop() - s1 := rpc.NewServer() // Create a new RPC server - // Register the type of data requested as JSON, with custom codec. - s1.RegisterCodec(&BlockchainCodec{json.NewCodec()}, "application/json") - - // Register "blockchain.claimtrie.*"" handlers. - claimtrieSvc := &ClaimtrieService{s.DB} - err := s1.RegisterService(claimtrieSvc, "blockchain_claimtrie") - if err != nil { - log.Errorf("RegisterService: %v\n", err) + // Set up the pure JSONRPC server with persistent connections/sessions. + if s.Args.JSONRPCPort != 0 { + port := ":" + strconv.FormatUint(uint64(s.Args.JSONRPCPort), 10) + laddr, err := net.ResolveTCPAddr("tcp", port) + if err != nil { + log.Errorf("ResoveIPAddr: %v\n", err) + goto fail1 + } + listener, err := net.ListenTCP("tcp", laddr) + if err != nil { + log.Errorf("ListenTCP: %v\n", err) + goto fail1 + } + log.Infof("JSONRPC server listening on %s", listener.Addr().String()) + acceptConnections := func(listener net.Listener) { + for { + conn, err := listener.Accept() + if err != nil { + log.Errorf("Accept: %v\n", err) + break + } + log.Infof("Accepted: %v", conn.RemoteAddr()) + s.sessionManager.addSession(conn) + } + } + go acceptConnections(netutil.LimitListener(listener, s.sessionManager.sessionsMax)) } - // Register other "blockchain.{block,address,scripthash}.*" handlers. - blockchainSvc := &BlockchainService{s.DB, s.Chain} - err = s1.RegisterService(blockchainSvc, "blockchain_block") - if err != nil { - log.Errorf("RegisterService: %v\n", err) - } - err = s1.RegisterService(&BlockchainAddressService{*blockchainSvc}, "blockchain_address") - if err != nil { - log.Errorf("RegisterService: %v\n", err) - } - err = s1.RegisterService(&BlockchainScripthashService{*blockchainSvc}, "blockchain_scripthash") - if err != nil { - log.Errorf("RegisterService: %v\n", err) +fail1: + // Set up the JSONRPC over HTTP server. + if s.Args.JSONRPCHTTPPort != 0 { + s1 := gorilla_rpc.NewServer() // Create a new RPC server + // Register the type of data requested as JSON, with custom codec. + s1.RegisterCodec(&gorillaRpcCodec{gorilla_json.NewCodec()}, "application/json") + + // Register "blockchain.claimtrie.*"" handlers. + claimtrieSvc := &ClaimtrieService{s.DB} + err := s1.RegisterTCPService(claimtrieSvc, "blockchain_claimtrie") + if err != nil { + log.Errorf("RegisterTCPService: %v\n", err) + goto fail2 + } + + // Register other "blockchain.{block,address,scripthash}.*" handlers. + blockchainSvc := &BlockchainBlockService{s.DB, s.Chain} + err = s1.RegisterTCPService(blockchainSvc, "blockchain_block") + if err != nil { + log.Errorf("RegisterTCPService: %v\n", err) + goto fail2 + } + err = s1.RegisterTCPService(&BlockchainHeadersService{s.DB, s.Chain, nil, nil}, "blockchain_headers") + if err != nil { + log.Errorf("RegisterTCPService: %v\n", err) + goto fail2 + } + err = s1.RegisterTCPService(&BlockchainAddressService{s.DB, s.Chain, nil, nil}, "blockchain_address") + if err != nil { + log.Errorf("RegisterTCPService: %v\n", err) + goto fail2 + } + err = s1.RegisterTCPService(&BlockchainScripthashService{s.DB, s.Chain, nil, nil}, "blockchain_scripthash") + if err != nil { + log.Errorf("RegisterTCPService: %v\n", err) + goto fail2 + } + + r := gorilla_mux.NewRouter() + r.Handle("/rpc", s1) + port := ":" + strconv.FormatUint(uint64(s.Args.JSONRPCHTTPPort), 10) + log.Infof("HTTP JSONRPC server listening on %s", port) + log.Fatal(http.ListenAndServe(port, r)) } - r := mux.NewRouter() - r.Handle("/rpc", s1) - log.Fatal(http.ListenAndServe(port, r)) - +fail2: return nil } diff --git a/server/notifier.go b/server/notifier.go index 8d8a367..1f4a847 100644 --- a/server/notifier.go +++ b/server/notifier.go @@ -52,8 +52,13 @@ func (s *Server) DoNotify(heightHash *internal.HeightHash) error { // RunNotifier Runs the notfying action forever func (s *Server) RunNotifier() error { - for heightHash := range s.NotifierChan { - s.DoNotify(heightHash) + for notification := range s.NotifierChan { + switch notification.(type) { + case internal.HeightHash: + heightHash, _ := notification.(internal.HeightHash) + s.DoNotify(&heightHash) + } + s.sessionManager.doNotify(notification) } return nil } diff --git a/server/notifier_test.go b/server/notifier_test.go index e642e62..3e820cc 100644 --- a/server/notifier_test.go +++ b/server/notifier_test.go @@ -80,7 +80,7 @@ func TestNotifierServer(t *testing.T) { hash, _ := hex.DecodeString("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") logrus.Warn("sending hash") - hub.NotifierChan <- &internal.HeightHash{Height: 1, BlockHash: hash} + hub.NotifierChan <- internal.HeightHash{Height: 1, BlockHash: hash} res := <-resCh logrus.Info(string(res)) diff --git a/server/server.go b/server/server.go index 821d5c6..4bb1bb6 100644 --- a/server/server.go +++ b/server/server.go @@ -18,7 +18,6 @@ import ( "github.com/ReneKroon/ttlcache/v2" "github.com/lbryio/herald.go/db" - "github.com/lbryio/herald.go/internal" "github.com/lbryio/herald.go/internal/metrics" "github.com/lbryio/herald.go/meta" pb "github.com/lbryio/herald.go/protobuf/go" @@ -53,7 +52,8 @@ type Server struct { ExternalIP net.IP HeightSubs map[net.Addr]net.Conn HeightSubsMut sync.RWMutex - NotifierChan chan *internal.HeightHash + NotifierChan chan interface{} + sessionManager *sessionManager pb.UnimplementedHubServer } @@ -332,7 +332,8 @@ func MakeHubServer(ctx context.Context, args *Args) *Server { ExternalIP: net.IPv4(127, 0, 0, 1), HeightSubs: make(map[net.Addr]net.Conn), HeightSubsMut: sync.RWMutex{}, - NotifierChan: make(chan *internal.HeightHash), + NotifierChan: make(chan interface{}), + sessionManager: newSessionManager(myDB, &chain, args.MaxSessions, args.SessionTimeout), } // Start up our background services diff --git a/server/session.go b/server/session.go new file mode 100644 index 0000000..9b9fbb7 --- /dev/null +++ b/server/session.go @@ -0,0 +1,383 @@ +package server + +import ( + "encoding/hex" + "fmt" + "net" + "net/rpc" + "net/rpc/jsonrpc" + "strings" + "sync" + "time" + "unsafe" + + "github.com/lbryio/herald.go/db" + "github.com/lbryio/herald.go/internal" + "github.com/lbryio/lbcd/chaincfg" + log "github.com/sirupsen/logrus" +) + +type headerNotification struct { + internal.HeightHash + blockHeader [HEADER_SIZE]byte + blockHeaderElectrum *BlockHeaderElectrum + blockHeaderStr string +} + +type hashXNotification struct { + hashX [HASHX_LEN]byte + status []byte + statusStr string +} + +type session struct { + id uintptr + addr net.Addr + conn net.Conn + // hashXSubs maps hashX to the original subscription key (address or scripthash) + hashXSubs map[[HASHX_LEN]byte]string + // headersSub indicates header subscription + headersSub bool + // headersSubRaw indicates the header subscription mode + headersSubRaw bool + // client provides the ability to send notifications + client rpc.ClientCodec + clientSeq uint64 + // lastRecv records time of last incoming data + lastRecv time.Time + // lastSend records time of last outgoing data + lastSend time.Time +} + +func (s *session) doNotify(notification interface{}) { + var method string + var params interface{} + switch notification.(type) { + case headerNotification: + if !s.headersSub { + return + } + note, _ := notification.(headerNotification) + heightHash := note.HeightHash + method = "blockchain.headers.subscribe" + if s.headersSubRaw { + header := note.blockHeaderStr + if len(header) == 0 { + header = hex.EncodeToString(note.blockHeader[:]) + } + params = &HeadersSubscribeRawResp{ + Hex: header, + Height: uint32(heightHash.Height), + } + } else { + header := note.blockHeaderElectrum + if header == nil { // not initialized + header = newBlockHeaderElectrum(¬e.blockHeader, uint32(heightHash.Height)) + } + params = header + } + case hashXNotification: + note, _ := notification.(hashXNotification) + orig, ok := s.hashXSubs[note.hashX] + if !ok { + return + } + if len(orig) == 64 { + method = "blockchain.scripthash.subscribe" + } else { + method = "blockchain.address.subscribe" + } + status := note.statusStr + if len(status) == 0 { + status = hex.EncodeToString(note.status) + } + params = []string{orig, status} + default: + log.Warnf("unknown notification type: %v", notification) + return + } + // Send the notification. + s.clientSeq += 1 + req := &rpc.Request{ + ServiceMethod: method, + Seq: s.clientSeq, + } + err := s.client.WriteRequest(req, params) + if err != nil { + log.Warnf("error: %v", err) + } + // Bump last send time. + s.lastSend = time.Now() +} + +type sessionMap map[uintptr]*session + +type sessionManager struct { + // sessionsMut protects sessions, headerSubs, hashXSubs state + sessionsMut sync.RWMutex + sessions sessionMap + sessionsWait sync.WaitGroup + sessionsMax int + sessionTimeout time.Duration + manageTicker *time.Ticker + db *db.ReadOnlyDBColumnFamily + chain *chaincfg.Params + // headerSubs are sessions subscribed via 'blockchain.headers.subscribe' + headerSubs sessionMap + // hashXSubs are sessions subscribed via 'blockchain.{address,scripthash}.subscribe' + hashXSubs map[[HASHX_LEN]byte]sessionMap +} + +func newSessionManager(db *db.ReadOnlyDBColumnFamily, chain *chaincfg.Params, sessionsMax, sessionTimeout int) *sessionManager { + return &sessionManager{ + sessions: make(sessionMap), + sessionsMax: sessionsMax, + sessionTimeout: time.Duration(sessionTimeout) * time.Second, + manageTicker: time.NewTicker(time.Duration(max(5, sessionTimeout/20)) * time.Second), + db: db, + chain: chain, + headerSubs: make(sessionMap), + hashXSubs: make(map[[HASHX_LEN]byte]sessionMap), + } +} + +func (sm *sessionManager) start() { + go sm.manage() +} + +func (sm *sessionManager) stop() { + sm.sessionsMut.Lock() + defer sm.sessionsMut.Unlock() + sm.headerSubs = make(sessionMap) + sm.hashXSubs = make(map[[HASHX_LEN]byte]sessionMap) + for _, sess := range sm.sessions { + sess.client.Close() + sess.conn.Close() + } + sm.sessions = make(sessionMap) +} + +func (sm *sessionManager) manage() { + for { + sm.sessionsMut.Lock() + for _, sess := range sm.sessions { + if time.Since(sess.lastRecv) > sm.sessionTimeout { + sm.removeSessionLocked(sess) + log.Infof("session %v timed out", sess.addr.String()) + } + } + sm.sessionsMut.Unlock() + // Wait for next management clock tick. + <-sm.manageTicker.C + } +} + +func (sm *sessionManager) addSession(conn net.Conn) *session { + sm.sessionsMut.Lock() + sess := &session{ + addr: conn.RemoteAddr(), + conn: conn, + hashXSubs: make(map[[11]byte]string), + client: jsonrpc.NewClientCodec(conn), + lastRecv: time.Now(), + } + sess.id = uintptr(unsafe.Pointer(sess)) + sm.sessions[sess.id] = sess + sm.sessionsMut.Unlock() + + // Create a new RPC server. These services are linked to the + // session, which allows RPC handlers to know the session for + // each request and update subscriptions. + s1 := rpc.NewServer() + + // Register "blockchain.claimtrie.*"" handlers. + claimtrieSvc := &ClaimtrieService{sm.db} + err := s1.RegisterName("blockchain.claimtrie", claimtrieSvc) + if err != nil { + log.Errorf("RegisterService: %v\n", err) + } + + // Register other "blockchain.{block,address,scripthash}.*" handlers. + blockchainSvc := &BlockchainBlockService{sm.db, sm.chain} + err = s1.RegisterName("blockchain.block", blockchainSvc) + if err != nil { + log.Errorf("RegisterName: %v\n", err) + goto fail + } + err = s1.RegisterName("blockchain.headers", &BlockchainHeadersService{sm.db, sm.chain, sm, sess}) + if err != nil { + log.Errorf("RegisterName: %v\n", err) + goto fail + } + err = s1.RegisterName("blockchain.address", &BlockchainAddressService{sm.db, sm.chain, sm, sess}) + if err != nil { + log.Errorf("RegisterName: %v\n", err) + goto fail + } + err = s1.RegisterName("blockchain.scripthash", &BlockchainScripthashService{sm.db, sm.chain, sm, sess}) + if err != nil { + log.Errorf("RegisterName: %v\n", err) + goto fail + } + + sm.sessionsWait.Add(1) + go func() { + s1.ServeCodec(&SessionServerCodec{jsonrpc.NewServerCodec(conn), sess}) + log.Infof("session %v goroutine exit", sess.addr.String()) + sm.sessionsWait.Done() + }() + return sess + +fail: + sm.removeSession(sess) + return nil +} + +func (sm *sessionManager) removeSession(sess *session) { + sm.sessionsMut.Lock() + defer sm.sessionsMut.Unlock() + sm.removeSessionLocked(sess) +} + +func (sm *sessionManager) removeSessionLocked(sess *session) { + if sess.headersSub { + delete(sm.headerSubs, sess.id) + } + for hashX := range sess.hashXSubs { + subs, ok := sm.hashXSubs[hashX] + if !ok { + continue + } + delete(subs, sess.id) + } + delete(sm.sessions, sess.id) + sess.client.Close() + sess.conn.Close() +} + +func (sm *sessionManager) headersSubscribe(sess *session, raw bool, subscribe bool) { + sm.sessionsMut.Lock() + defer sm.sessionsMut.Unlock() + if subscribe { + sm.headerSubs[sess.id] = sess + sess.headersSub = true + sess.headersSubRaw = raw + return + } + delete(sm.headerSubs, sess.id) + sess.headersSub = false + sess.headersSubRaw = false +} + +func (sm *sessionManager) hashXSubscribe(sess *session, hashX []byte, original string, subscribe bool) { + sm.sessionsMut.Lock() + defer sm.sessionsMut.Unlock() + var key [HASHX_LEN]byte + copy(key[:], hashX) + subs, ok := sm.hashXSubs[key] + if subscribe { + if !ok { + subs = make(sessionMap) + sm.hashXSubs[key] = subs + } + subs[sess.id] = sess + sess.hashXSubs[key] = original + return + } + if ok { + delete(subs, sess.id) + if len(subs) == 0 { + delete(sm.hashXSubs, key) + } + } + delete(sess.hashXSubs, key) +} + +func (sm *sessionManager) doNotify(notification interface{}) { + sm.sessionsMut.RLock() + var subsCopy sessionMap + switch notification.(type) { + case headerNotification: + note, _ := notification.(headerNotification) + subsCopy = sm.headerSubs + if len(subsCopy) > 0 { + note.blockHeaderElectrum = newBlockHeaderElectrum(¬e.blockHeader, uint32(note.Height)) + note.blockHeaderStr = hex.EncodeToString(note.blockHeader[:]) + } + case hashXNotification: + note, _ := notification.(hashXNotification) + hashXSubs, ok := sm.hashXSubs[note.hashX] + if ok { + subsCopy = hashXSubs + } + if len(subsCopy) > 0 { + note.statusStr = hex.EncodeToString(note.status) + } + default: + log.Warnf("unknown notification type: %v", notification) + } + sm.sessionsMut.RUnlock() + + // Deliver notification to relevant sessions. + for _, sess := range subsCopy { + sess.doNotify(notification) + } +} + +type SessionServerCodec struct { + rpc.ServerCodec + sess *session +} + +// ReadRequestHeader provides ability to rewrite the incoming +// request "method" field. For example: +// blockchain.block.get_header -> blockchain.block.Get_header +// blockchain.address.listunspent -> blockchain.address.Listunspent +// This makes the "method" string compatible with rpc.Server +// requirements. +func (c *SessionServerCodec) ReadRequestHeader(req *rpc.Request) error { + log.Infof("receive header from %v", c.sess.addr.String()) + err := c.ServerCodec.ReadRequestHeader(req) + if err != nil { + log.Warnf("error: %v", err) + return err + } + rawMethod := req.ServiceMethod + parts := strings.Split(rawMethod, ".") + if len(parts) < 2 { + return fmt.Errorf("blockchain rpc: service/method ill-formed: %q", rawMethod) + } + service := strings.Join(parts[0:len(parts)-1], ".") + method := parts[len(parts)-1] + if len(method) < 1 { + return fmt.Errorf("blockchain rpc: method ill-formed: %q", method) + } + method = strings.ToUpper(string(method[0])) + string(method[1:]) + req.ServiceMethod = service + "." + method + return err +} + +// ReadRequestBody wraps the regular implementation, but updates session stats too. +func (c *SessionServerCodec) ReadRequestBody(params any) error { + err := c.ServerCodec.ReadRequestBody(params) + if err != nil { + log.Warnf("error: %v", err) + return err + } + log.Infof("receive body from %v", c.sess.addr.String()) + // Bump last receive time. + c.sess.lastRecv = time.Now() + return err +} + +// WriteResponse wraps the regular implementation, but updates session stats too. +func (c *SessionServerCodec) WriteResponse(resp *rpc.Response, reply any) error { + log.Infof("respond to %v", c.sess.addr.String()) + err := c.ServerCodec.WriteResponse(resp, reply) + if err != nil { + return err + } + // Bump last send time. + c.sess.lastSend = time.Now() + return err +}