wallet-sync-server/server/server.go

197 lines
5.7 KiB
Go
Raw Normal View History

package server
import (
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"net/http"
	"strings"

	"github.com/prometheus/client_golang/prometheus/promhttp"

	"lbryio/wallet-sync-server/auth"
	"lbryio/wallet-sync-server/env"
	"lbryio/wallet-sync-server/mail"
	"lbryio/wallet-sync-server/server/paths"
	"lbryio/wallet-sync-server/store"
)
2022-08-22 23:44:26 +02:00
// maxBodySize is the largest request body, in bytes, that getPostData will
// read (via http.MaxBytesReader) before answering 413 Request Entity Too Large.
const maxBodySize = 100 * 1000
type Server struct {
auth auth.AuthInterface
store store.StoreInterface
2022-07-24 00:13:56 +02:00
env env.EnvInterface
mail mail.MailInterface
2022-08-01 17:50:16 +02:00
port int
}
func Init(
auth auth.AuthInterface,
store store.StoreInterface,
2022-07-24 00:13:56 +02:00
env env.EnvInterface,
mail mail.MailInterface,
2022-08-01 17:50:16 +02:00
port int,
) *Server {
2022-08-01 17:50:16 +02:00
return &Server{auth, store, env, mail, port}
}
// ErrorResponse is the JSON shape of every error body this server sends,
// e.g. {"error":"Not Found: Unknown Endpoint"}.
type ErrorResponse struct {
	Error string `json:"error"`
}
// errorJson writes a JSON error response with the given HTTP status code. The
// body is {"error": "<status text>"}, or {"error": "<status text>: <extra>"}
// when extra is non-empty. It is written with http.Error, so the status code
// and body are sent exactly once.
func errorJson(w http.ResponseWriter, code int, extra string) {
	errorStr := http.StatusText(code)
	if extra != "" {
		errorStr = errorStr + ": " + extra
	}
	body, err := json.Marshal(ErrorResponse{Error: errorStr})
	if err != nil {
		// In case something really stupid happens. Return right after the
		// fallback; continuing would call http.Error a second time (a
		// superfluous WriteHeader plus an extra, empty body).
		http.Error(w, `{"error": "error when JSON-encoding error message"}`, code)
		return
	}
	http.Error(w, string(body), code)
}
// internalServiceErrorJson answers with a generic 500 Internal Server Error
// JSON body and logs serverErr (prefixed by errContext) on the server side.
// The error details are deliberately kept out of the response so nothing
// internal leaks to the client.
func internalServiceErrorJson(w http.ResponseWriter, serverErr error, errContext string) {
	const status = http.StatusInternalServerError

	payload, encodeErr := json.Marshal(ErrorResponse{Error: http.StatusText(status)})
	if encodeErr == nil {
		http.Error(w, string(payload), status)
		log.Printf("%s: %+v\n", errContext, serverErr)
		return
	}

	// Marshalling a single-string struct should never fail, but fall back to a
	// hand-written body just in case.
	http.Error(w, `{"error": "error when JSON-encoding error message"}`, status)
	log.Printf("error when JSON-encoding error message")
}
//////////////////
// Handler Helpers
//////////////////
// requestOverhead holds the checks that every handler shares. It reports
// whether req uses the expected HTTP method. When it returns false, a 405
// Method Not Allowed JSON error has already been written to w, so callers
// just return. Errors are handled here rather than returned.
func requestOverhead(w http.ResponseWriter, req *http.Request, method string) bool {
	if req.Method == method {
		return true
	}
	errorJson(w, http.StatusMethodNotAllowed, "")
	return false
}
// PostRequest must be implemented by every struct that a JSON request body is
// decoded into. validate checks the decoded fields. Any error it returns is
// included in the 400 Bad Request response sent to the client, so its message
// must be safe to expose publicly.
type PostRequest interface {
	validate() error
}
// Confirm it's a Post request, various overhead, decode the json, validate the struct
func getPostData(w http.ResponseWriter, req *http.Request, reqStruct PostRequest) bool {
if !requestOverhead(w, req, http.MethodPost) {
return false
}
2022-06-19 23:49:05 +02:00
// Make the limit 100k. Increase from there as needed. I'd rather block some
// people's large wallets and increase the limit than OOM for everybody and
// decrease the limit.
2022-08-22 23:44:26 +02:00
req.Body = http.MaxBytesReader(w, req.Body, maxBodySize)
decoder := json.NewDecoder(req.Body)
decoder.DisallowUnknownFields()
err := decoder.Decode(&reqStruct)
2022-06-19 23:49:05 +02:00
switch {
case err == nil:
break
case err.Error() == "http: request body too large":
errorJson(w, http.StatusRequestEntityTooLarge, "")
return false
case strings.HasPrefix(err.Error(), "json: unknown field"):
// The error is coming straight out of the json decoder. I think the prefix
// we check for determines what it is pretty reliably. I'd think it's safe
// to give back to the requesting client (unlike an arbitrary error
// message).
errorJson(w, http.StatusBadRequest, err.Error())
return false
2022-06-19 23:49:05 +02:00
default:
// Maybe we can suss out more specific errors later. Need to study what
// errors come from Decode.
2022-06-19 23:49:05 +02:00
errorJson(w, http.StatusBadRequest, "Error parsing JSON")
return false
}
err = reqStruct.validate()
if err != nil {
errorJson(w, http.StatusBadRequest, "Request failed validation: "+err.Error())
return false
}
return true
}
// getGetData runs the shared checks for GET endpoints. It reports whether req
// is a GET. If it is not, a 405 JSON error has already been written to w.
func getGetData(w http.ResponseWriter, req *http.Request) bool {
	isGet := requestOverhead(w, req, http.MethodGet)
	return isGet
}
// TODO - probably don't return all of authToken since we only need userId and
// deviceId.
func (s *Server) checkAuth(
w http.ResponseWriter,
token auth.AuthTokenString,
scope auth.AuthScope,
) *auth.AuthToken {
authToken, err := s.store.GetToken(token)
2022-07-26 22:36:57 +02:00
if err == store.ErrNoTokenForUserDevice {
errorJson(w, http.StatusUnauthorized, "Token Not Found")
return nil
}
if err != nil {
internalServiceErrorJson(w, err, "Error getting Token")
return nil
}
if !authToken.ScopeValid(scope) {
errorJson(w, http.StatusForbidden, "Scope")
return nil
}
return authToken
}
// TODO - both wallet and token requests should be PUT, not POST.
// PUT = "...creates a new resource or replaces a representation of the target resource with the request payload."
// unknownEndpoint replies with a 404 JSON error. Serve registers it on
// paths.PathUnknownEndpoint (presumably a catch-all pattern — confirm in the
// paths package).
func (s *Server) unknownEndpoint(w http.ResponseWriter, req *http.Request) {
	const reason = "Unknown Endpoint"
	errorJson(w, http.StatusNotFound, reason)
}
func (s *Server) wrongApiVersion(w http.ResponseWriter, req *http.Request) {
2022-08-01 17:50:16 +02:00
errorJson(w, http.StatusNotFound, "Wrong API version. Current version is "+paths.ApiVersion+".")
return
}
func (s *Server) Serve() {
2022-08-01 17:50:16 +02:00
http.HandleFunc(paths.PathAuthToken, s.getAuthToken)
http.HandleFunc(paths.PathWallet, s.handleWallet)
http.HandleFunc(paths.PathRegister, s.register)
http.HandleFunc(paths.PathPassword, s.changePassword)
http.HandleFunc(paths.PathVerify, s.verify)
http.HandleFunc(paths.PathResendVerify, s.resendVerifyEmail)
http.HandleFunc(paths.PathClientSaltSeed, s.getClientSaltSeed)
2022-08-01 17:50:16 +02:00
http.HandleFunc(paths.PathUnknownEndpoint, s.unknownEndpoint)
http.HandleFunc(paths.PathWrongApiVersion, s.wrongApiVersion)
2022-08-01 17:50:16 +02:00
http.Handle(paths.PathPrometheus, promhttp.Handler())
2022-07-23 01:49:30 +02:00
2022-08-01 17:50:16 +02:00
log.Printf("Serving at localhost:%d\n", s.port)
http.ListenAndServe(fmt.Sprintf("localhost:%d", s.port), nil)
}