diff --git a/sockets.go b/sockets.go index 29d8afe..07c578e 100644 --- a/sockets.go +++ b/sockets.go @@ -47,6 +47,10 @@ var ( // a missing, incorrect, or duplicate authentication request. ErrBadAuth = errors.New("bad auth") + // ErrNoAuth represents an error where authentication could not succeed + // due to a missing Authorization HTTP header. + ErrNoAuth = errors.New("no auth") + // ErrConnRefused represents an error where a connection to another // process cannot be established. ErrConnRefused = errors.New("connection refused") @@ -403,6 +407,10 @@ out: // forever (until disconnected), reading JSON-RPC requests and sending // sending responses and notifications. func (s *server) WSSendRecv(ws *websocket.Conn, remoteAddr string, authenticated bool) { + // Clear the read deadline set before the websocket hijacked + // the connection. + ws.SetReadDeadline(time.Time{}) + // Add client context so notifications duplicated to each // client are received by this client. recvQuit := make(chan struct{}) @@ -546,7 +554,14 @@ func (s *server) Start() { log.Trace("Starting RPC server") serveMux := http.NewServeMux() - httpServer := &http.Server{Handler: serveMux} + const rpcAuthTimeoutSeconds = 10 + httpServer := &http.Server{ + Handler: serveMux, + + // Timeout connections which don't complete the initial + // handshake within the allowed timeframe. + ReadTimeout: time.Second * rpcAuthTimeoutSeconds, + } serveMux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { if err := s.checkAuth(r); err != nil { log.Warnf("Unauthorized client connection attempt") @@ -557,7 +572,15 @@ func (s *server) Start() { }) serveMux.HandleFunc("/frontend", func(w http.ResponseWriter, r *http.Request) { authenticated := false - if err := s.checkAuth(r); err == nil { + if err := s.checkAuth(r); err != nil { + // If auth was supplied but incorrect, rather than simply being + // missing, immediately terminate the connection. + if err != ErrNoAuth { + log.Warnf("Disconnecting improperly authorized websocket client") + http.Error(w, "401 Unauthorized.", http.StatusUnauthorized) + return + } + } else { authenticated = true } @@ -590,8 +613,8 @@ func (s *server) Start() { // This check is time-constant. func (s *server) checkAuth(r *http.Request) error { authhdr := r.Header["Authorization"] - if len(authhdr) <= 0 { - return ErrBadAuth + if len(authhdr) == 0 { + return ErrNoAuth } authsha := sha256.Sum256([]byte(authhdr[0]))