diff --git a/rpcserver.go b/rpcserver.go index 731f4cd3..fc874992 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -166,7 +166,7 @@ func (s *rpcServer) Start() { ReadTimeout: time.Second * rpcAuthTimeoutSeconds, } rpcServeMux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - if err := s.checkAuth(r); err != nil { + if _, err := s.checkAuth(r, true); err != nil { jsonAuthFail(w, r, s) return } @@ -175,12 +175,15 @@ func (s *rpcServer) Start() { }) rpcServeMux.HandleFunc("/wallet", func(w http.ResponseWriter, r *http.Request) { - if err := s.checkAuth(r); err != nil { + authenticated, err := s.checkAuth(r, false) + if err != nil { http.Error(w, "401 Unauthorized.", http.StatusUnauthorized) return } wsServer := websocket.Server{ - Handler: websocket.Handler(s.walletReqsNotifications), + Handler: websocket.Handler(func(ws *websocket.Conn) { + s.walletReqsNotifications(ws, authenticated) + }), } wsServer.ServeHTTP(w, r) }) @@ -202,20 +205,24 @@ func (s *rpcServer) Start() { // returned. // // This check is time-constant. -func (s *rpcServer) checkAuth(r *http.Request) error { +func (s *rpcServer) checkAuth(r *http.Request, require bool) (bool, error) { authhdr := r.Header["Authorization"] if len(authhdr) <= 0 { - rpcsLog.Warnf("Auth failure.") - return errors.New("auth failure") + if require { + rpcsLog.Warnf("Auth failure.") + return false, errors.New("auth failure") + } + + return false, nil } authsha := sha256.Sum256([]byte(authhdr[0])) cmp := subtle.ConstantTimeCompare(authsha[:], s.authsha[:]) if cmp != 1 { rpcsLog.Warnf("Auth failure.") - return errors.New("auth failure") + return false, errors.New("auth failure") } - return nil + return true, nil } // Stop is used by server.go to stop the rpc listener. diff --git a/rpcwebsocket.go b/rpcwebsocket.go index 81dd404a..36692c52 100644 --- a/rpcwebsocket.go +++ b/rpcwebsocket.go @@ -8,8 +8,12 @@ import ( "bytes" "code.google.com/p/go.net/websocket" "container/list" + "crypto/sha256" + "crypto/subtle" + "encoding/base64" "encoding/hex" "encoding/json" + "errors" "fmt" "github.com/conformal/btcdb" "github.com/conformal/btcjson" @@ -23,6 +27,8 @@ import ( var timeZeroVal time.Time +var ErrBadAuth = errors.New("invalid credentials") + type ntfnChan chan btcjson.Cmd type handlerChans struct { @@ -230,6 +236,15 @@ func newWebsocketContext() *wsContext { // requestContexts holds all requests for a single wallet connection. type requestContexts struct { + // disconnecting indicates the websocket is in the process of + // disconnecting. This is used to prevent trying to handle any more + // commands in the interim. + disconnecting bool + + // authenticated specifies whether a client has been authenticated + // and therefore is allowed to communicated over the websocket. + authenticated bool + // blockUpdates specifies whether a client has requested notifications // for whenever blocks are connected or disconnected from the main // chain. @@ -512,15 +527,41 @@ func handleWalletSendRawTransaction(s *rpcServer, icmd btcjson.Cmd, c handlerCha return result, nil } +// websocketAuthenticate checks the authenticate command for valid credentials. +// An error is returned if the credentials are invalid or if the connection is +// already authenticated. +// +// This function MUST be called with the websocket lock held. +func websocketAuthenticate(icmd btcjson.Cmd, rc *requestContexts, authSha []byte) error { + cmd, ok := icmd.(*btcws.AuthenticateCmd) + if !ok { + return fmt.Errorf("%s", btcjson.ErrInternal.Message) + } + + // Already authenticated? + if rc.authenticated { + rpcsLog.Warnf("Already authenticated") + return ErrBadAuth + } + + // Check credentials. + login := cmd.Username + ":" + cmd.Passphrase + auth := "Basic " + base64.StdEncoding.EncodeToString([]byte(login)) + calcualtedAuthSha := sha256.Sum256([]byte(auth)) + cmp := subtle.ConstantTimeCompare(calcualtedAuthSha[:], authSha) + if cmp != 1 { + rpcsLog.Warnf("Auth failure.") + return ErrBadAuth + } + + rc.authenticated = true + return nil +} + // AddWalletListener adds a channel to listen for new messages from a // wallet. -func (s *rpcServer) AddWalletListener(n ntfnChan) { +func (s *rpcServer) AddWalletListener(n ntfnChan, rc *requestContexts) { s.ws.Lock() - rc := &requestContexts{ - txRequests: make(map[string]struct{}), - spentRequests: make(map[btcwire.OutPoint]struct{}), - minedTxRequests: make(map[btcwire.ShaHash]struct{}), - } s.ws.connections[n] = rc s.ws.Unlock() } @@ -547,7 +588,7 @@ func (s *rpcServer) RemoveWalletListener(n ntfnChan) { // walletReqsNotifications is the handler function for websocket // connections from a btcwallet instance. It reads messages from wallet and // sends back replies, as well as notififying wallets of chain updates. -func (s *rpcServer) walletReqsNotifications(ws *websocket.Conn) { +func (s *rpcServer) walletReqsNotifications(ws *websocket.Conn, authenticated bool) { // Clear the read deadline that was set before the websocket hijacked // the connection. ws.SetReadDeadline(timeZeroVal) @@ -555,7 +596,13 @@ func (s *rpcServer) walletReqsNotifications(ws *websocket.Conn) { // Add wallet notification channel so this handler receives btcd chain // notifications. n := make(ntfnChan) - s.AddWalletListener(n) + rc := &requestContexts{ + authenticated: authenticated, + txRequests: make(map[string]struct{}), + spentRequests: make(map[btcwire.OutPoint]struct{}), + minedTxRequests: make(map[btcwire.ShaHash]struct{}), + } + s.AddWalletListener(n, rc) defer s.RemoveWalletListener(n) // Channel for responses. @@ -622,10 +669,18 @@ func (s *rpcServer) walletReqsNotifications(ws *websocket.Conn) { } case m := <-msgs: - // Spawn new goroutine to handle request. Responses and - // notifications are read by channels in this for-select - // loop. - go s.websocketJSONHandler(r, hc, m) + // This function internally spawns a new goroutine to + // the handle request after validating authentication. + // Responses and notifications are read by channels in + // this for-select loop. + if !rc.disconnecting { + err := s.websocketJSONHandler(r, hc, m) + if err == ErrBadAuth { + rc.disconnecting = true + close(disconnected) + ws.Close() + } + } case response := <-r: // Marshal and send response. @@ -670,10 +725,7 @@ func (s *rpcServer) walletReqsNotifications(ws *websocket.Conn) { // websocketJSONHandler parses and handles a marshalled json message, // sending the marshalled reply to a wallet notification channel. -func (s *rpcServer) websocketJSONHandler(r chan *btcjson.Reply, c handlerChans, msg string) { - s.wg.Add(1) - defer s.wg.Done() - +func (s *rpcServer) websocketJSONHandler(r chan *btcjson.Reply, c handlerChans, msg string) error { var resp *btcjson.Reply cmd, jsonErr := parseCmd([]byte(msg)) @@ -685,24 +737,57 @@ func (s *rpcServer) websocketJSONHandler(r chan *btcjson.Reply, c handlerChans, // should be ignored. id := cmd.Id() if id == nil { - return + return nil } resp.Id = &id } resp.Error = jsonErr } else { - resp = respondToAnyCmd(cmd, s, c) + // The first command must be the "authenticate" command if the + // connection is not already authenticated. + s.ws.Lock() + rc := s.ws.connections[c.n] + if _, ok := cmd.(*btcws.AuthenticateMsg); ok { + // Validate the provided credentials. + err := websocketAuthenticate(cmd, rc, s.authsha[:]) + if err != nil { + s.ws.Unlock() + return err + } + + // Generate an empty response to send for the successful + // authentication. + id := cmd.Id() + resp = &btcjson.Reply{ + Id: &id, + Result: nil, + Error: nil, + } + } else if !rc.authenticated { + rpcsLog.Warnf("Unauthenticated websocket message " + + "received") + s.ws.Unlock() + return ErrBadAuth + } + + s.ws.Unlock() } - // Once response has been processed, only send if the client - // is still connected. - select { - case <-c.disconnected: - return + // Find and run handler in new goroutine. + go func() { + if resp == nil { + resp = respondToAnyCmd(cmd, s, c) + } + select { + case <-c.disconnected: + return - default: - r <- resp - } + default: + r <- resp + } + }() + + return nil } // NotifyBlockConnected creates and marshalls a JSON message to notify