diff --git a/sockets.go b/sockets.go index e7bb547..c781dd2 100644 --- a/sockets.go +++ b/sockets.go @@ -43,6 +43,10 @@ import ( ) var ( + // ErrBadAuth represents an error where a request is denied due to + // a missing, incorrect, or duplicate authentication request. + ErrBadAuth = errors.New("bad auth") + // ErrConnRefused represents an error where a connection to another // process cannot be established. ErrConnRefused = errors.New("connection refused") @@ -67,8 +71,8 @@ type server struct { } type clientContext struct { - send chan []byte - disconnected chan struct{} // closed on disconnect + send chan []byte + quit chan struct{} // closed on disconnect } // parseListeners splits the list of listen addresses passed in addrs into @@ -214,20 +218,48 @@ func ParseRequest(msg []byte) (btcjson.Cmd, *btcjson.Error) { // ReplyToFrontend responds to a marshaled JSON-RPC request with a // marshaled JSON-RPC response for both standard and extension -// (websocket) clients. -func ReplyToFrontend(msg []byte, ws bool) []byte { +// (websocket) clients. The returned error is ErrBadAuth if a +// missing, incorrect, or duplicate authentication request is +// received. +func (s *server) ReplyToFrontend(msg []byte, ws, authenticated bool) ([]byte, error) { cmd, jsonErr := ParseRequest(msg) var id interface{} if cmd != nil { id = cmd.Id() } + + // If client is not already authenticated, the parsed request must + // be for authentication. + authCmd, ok := cmd.(*btcws.AuthenticateCmd) + if authenticated { + if ok { + // Duplicate auth request. + return nil, ErrBadAuth + } + } else { + if !ok { + // The first unauthenticated request must be an auth request. + return nil, ErrBadAuth + } + + // Check credentials. + login := authCmd.Username + ":" + authCmd.Passphrase + auth := "Basic " + base64.StdEncoding.EncodeToString([]byte(login)) + authSha := sha256.Sum256([]byte(auth)) + cmp := subtle.ConstantTimeCompare(authSha[:], s.authsha[:]) + if cmp != 1 { + return nil, ErrBadAuth + } + return nil, nil + } + if jsonErr != nil { response := btcjson.Reply{ Id: &id, Error: jsonErr, } mresponse, _ := json.Marshal(response) - return mresponse + return mresponse, nil } cReq := NewClientRequest(cmd, ws) @@ -248,7 +280,7 @@ func ReplyToFrontend(msg []byte, ws bool) []byte { mresponse, _ = json.Marshal(&response) } - return mresponse + return mresponse, nil } // ServeRPCRequest processes and replies to a JSON-RPC client request. @@ -258,7 +290,11 @@ func (s *server) ServeRPCRequest(w http.ResponseWriter, r *http.Request) { log.Errorf("RPCS: Error getting JSON message: %v", err) } - resp := ReplyToFrontend(body, false) + resp, err := s.ReplyToFrontend(body, false, true) + if err == ErrBadAuth { + http.Error(w, "401 Unauthorized.", http.StatusUnauthorized) + return + } if _, err := w.Write(resp); err != nil { log.Warnf("RPCS: could not respond to RPC request: %v", err) } @@ -277,7 +313,7 @@ func clientResponseDuplicator() { case n := <-allClients: for cc := range clients { select { - case <-cc.disconnected: + case <-cc.quit: delete(clients, cc) case cc.send <- n: } @@ -297,76 +333,176 @@ func NotifyBtcdConnection(reply chan []byte) { } +// stringQueue manages a queue of strings, reading from in and sending +// the oldest unsent to out. This handler closes out and returns after +// in is closed and any queued items are sent. Any reads on quit result +// in immediate shutdown of the handler. +func stringQueue(in <-chan string, out chan<- string, quit <-chan struct{}) { + var q []string + var dequeue chan<- string + skipQueue := out + var next string +out: + for { + select { + case n, ok := <-in: + if !ok { + // Sender closed input channel. Nil channel + // and continue so the remaining queued + // items may be sent. If the queue is empty, + // break out of the loop. + in = nil + if dequeue == nil { + break out + } + continue + } + + // Either send to out immediately if skipQueue is + // non-nil (queue is empty) and reader is ready, + // or append to the queue and send later. + select { + case skipQueue <- n: + default: + q = append(q, n) + dequeue = out + skipQueue = nil + next = q[0] + } + + case dequeue <- next: + copy(q, q[1:]) + q[len(q)-1] = "" // avoid leak + q = q[:len(q)-1] + if len(q) == 0 { + // If the input chan was closed and nil'd, + // break out of the loop. + if in == nil { + break out + } + dequeue = nil + skipQueue = out + } else { + next = q[0] + } + + case <-quit: + break out + } + } + close(out) +} + // WSSendRecv is the handler for websocket client connections. It loops // forever (until disconnected), reading JSON-RPC requests and sending // sending responses and notifications. -func WSSendRecv(ws *websocket.Conn) { +func (s *server) WSSendRecv(ws *websocket.Conn, authenticated bool) { // Add client context so notifications duplicated to each // client are received by this client. + recvQuit := make(chan struct{}) + sendQuit := make(chan struct{}) cc := clientContext{ - send: make(chan []byte, 1), // buffer size is number of initial notifications - disconnected: make(chan struct{}), + send: make(chan []byte, 1), // buffer size is number of initial notifications + quit: make(chan struct{}), } + go func() { + select { + case <-recvQuit: + case <-sendQuit: + } + close(cc.quit) + }() NotifyBtcdConnection(cc.send) // TODO(jrick): clients should explicitly request this. addClient <- cc - defer close(cc.disconnected) // received passes all received messages from the currently connected // frontend to the for-select loop. It is closed when reading a // message from the websocket connection fails (presumably due to // a disconnected client). - received := make(chan []byte) + recvQueueIn := make(chan string) // Receive messages from websocket and send across jsonMsgs until // connection is lost go func() { for { - var m []byte + var m string if err := websocket.Message.Receive(ws, &m); err != nil { - // Log warning if the client did not disconnect. - if err != io.EOF { - log.Warnf("Cannot receive client websocket message: %v", - err) + select { + case <-sendQuit: + // Do not log error. + + default: + // Log warning if the client did not disconnect. + if err != io.EOF { + log.Warnf("Cannot receive client websocket message: %v", + err) + } } - close(received) + close(recvQueueIn) + close(recvQuit) return } - received <- m + recvQueueIn <- m } }() + // Manage queue of received messages for LIFO processing. + recvQueueOut := make(chan string) + go stringQueue(recvQueueIn, recvQueueOut, cc.quit) + + badAuth := make(chan struct{}) + go func() { + for m := range recvQueueOut { + resp, err := s.ReplyToFrontend([]byte(m), true, authenticated) + if err == ErrBadAuth { + select { + case badAuth <- struct{}{}: + case <-cc.quit: + } + return + } + + // Authentication passed. + authenticated = true + + select { + case cc.send <- resp: + case <-cc.quit: + } + } + close(cc.send) + }() + const deadline time.Duration = 2 * time.Second +out: for { select { - case m, ok := <-received: + case <-badAuth: + // Bad auth. Disconnect. + log.Warnf("Disconnecting improperly authorized websocket client") + ws.Close() + break out + + case m, ok := <-cc.send: if !ok { - // client disconnected. - return + // Nothing left to send. Return so the handler exits. + break out } - // Handle request here. - go func(m []byte) { - resp := ReplyToFrontend(m, true) - select { - case cc.send <- resp: - case <-cc.disconnected: - } - }(m) - - case m := <-cc.send: err := ws.SetWriteDeadline(time.Now().Add(deadline)) if err != nil { log.Errorf("Cannot set write deadline: %v", err) - return + break out } - err = websocket.Message.Send(ws, m) + err = websocket.Message.Send(ws, string(m)) if err != nil { log.Infof("Cannot complete client websocket send: %v", err) - return + break out } } } + close(sendQuit) } // NotifyNewBlockChainHeight notifies all frontends of a new @@ -396,17 +532,27 @@ func (s *server) Start() { httpServer := &http.Server{Handler: serveMux} serveMux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { if err := s.checkAuth(r); err != nil { + log.Warnf("Unauthorized client connection attempt") http.Error(w, "401 Unauthorized.", http.StatusUnauthorized) return } s.ServeRPCRequest(w, r) }) serveMux.HandleFunc("/frontend", func(w http.ResponseWriter, r *http.Request) { - if err := s.checkAuth(r); err != nil { - http.Error(w, "401 Unauthorized.", http.StatusUnauthorized) - return + authenticated := false + if err := s.checkAuth(r); err == nil { + authenticated = true } - websocket.Handler(WSSendRecv).ServeHTTP(w, r) + + // A new Server instance is created rather than just creating the + // handler closure since the default server will disconnect the + // client if the origin is unset. + wsServer := websocket.Server{ + Handler: websocket.Handler(func(ws *websocket.Conn) { + s.WSSendRecv(ws, authenticated) + }), + } + wsServer.ServeHTTP(w, r) }) for _, listener := range s.listeners { s.wg.Add(1) @@ -428,15 +574,13 @@ func (s *server) Start() { func (s *server) checkAuth(r *http.Request) error { authhdr := r.Header["Authorization"] if len(authhdr) <= 0 { - log.Infof("Frontend did not supply authentication.") - return errors.New("auth failure") + return ErrBadAuth } authsha := sha256.Sum256([]byte(authhdr[0])) cmp := subtle.ConstantTimeCompare(authsha[:], s.authsha[:]) if cmp != 1 { - log.Infof("Frontend did not supply correct authentication.") - return errors.New("auth failure") + return ErrBadAuth } return nil }