diff --git a/server/server.go b/server/server.go index a165172..4f9ffad 100644 --- a/server/server.go +++ b/server/server.go @@ -5,6 +5,7 @@ import ( "fmt" "log" "net/http" + "strings" "github.com/prometheus/client_golang/prometheus/promhttp" @@ -101,16 +102,25 @@ func getPostData(w http.ResponseWriter, req *http.Request, reqStruct PostRequest // people's large wallets and increase the limit than OOM for everybody and // decrease the limit. req.Body = http.MaxBytesReader(w, req.Body, maxBodySize) - err := json.NewDecoder(req.Body).Decode(&reqStruct) + decoder := json.NewDecoder(req.Body) + decoder.DisallowUnknownFields() + err := decoder.Decode(&reqStruct) switch { case err == nil: break case err.Error() == "http: request body too large": errorJson(w, http.StatusRequestEntityTooLarge, "") return false + case strings.HasPrefix(err.Error(), "json: unknown field"): + // The error is coming straight out of the json decoder. I think the prefix + // we check for determines what it is pretty reliably. I'd think it's safe + // to give back to the requesting client (unlike an arbitrary error + // message). + errorJson(w, http.StatusBadRequest, err.Error()) + return false default: - // Maybe we can suss out specific errors later. Need to study what errors - // come from Decode. + // Maybe we can suss out more specific errors later. Need to study what + // errors come from Decode. errorJson(w, http.StatusBadRequest, "Error parsing JSON") return false } diff --git a/server/server_test.go b/server/server_test.go index c72a581..5c26c95 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -417,6 +417,13 @@ func TestServerHelperGetPostDataErrors(t *testing.T) { expectedStatusCode: http.StatusBadRequest, expectedErrorString: http.StatusText(http.StatusBadRequest) + ": Request failed validation: TestReq Error", }, + { + name: "body JSON has unknown field", + method: http.MethodPost, + requestBody: `{"lol": "wut"}`, + expectedStatusCode: http.StatusBadRequest, + expectedErrorString: http.StatusText(http.StatusBadRequest) + `: json: unknown field "lol"`, + }, } for _, tc := range tt { t.Run(tc.name, func(t *testing.T) {