337 lines
8.6 KiB
Go
337 lines
8.6 KiB
Go
package api
|
|
|
|
import (
	"encoding/json"
	"net/http"
	"reflect"
	"sort"
	"strings"

	"github.com/lbryio/lbry.go/v2/extras/errors"
	"github.com/lbryio/lbry.go/v2/extras/util"
	"github.com/lbryio/lbry.go/v2/extras/validator"
	v "github.com/lbryio/ozzo-validation"

	"github.com/spf13/cast"
)
|
|
|
|
// Package-level settings read by Handler.ServeHTTP on every request.
var (
	// ResponseHeaders are copied onto every response (including OPTIONS
	// preflight replies) before the handler runs. It is only read while
	// serving, so it must be fully populated before the server starts.
	ResponseHeaders map[string]string

	// CorsDomains lists the exact Origin values allowed by the CORS policy.
	CorsDomains []string

	// CorsAllowLocalhost, when true, additionally allows any origin that
	// starts with "http://localhost:".
	CorsAllowLocalhost bool
)
|
|
|
|
// Log allows logging of events and errors
|
|
var Log = func(*http.Request, *Response, error) {}
|
|
|
|
// ResponseInfo is the JSON envelope serialised for non-redirect API
// responses. Trace is omitted from the output when it is empty.
//
// Reference: http://choly.ca/post/go-json-marshalling/
type ResponseInfo struct {
	Success bool        `json:"success"`
	Error   *string     `json:"error"`
	Data    interface{} `json:"data"`
	Trace   []string    `json:"_trace,omitempty"`
}

// BuildJSONResponse encodes a ResponseInfo into the bytes written to the
// client. It is a variable so implementers can replace it to control the
// form of the JSON emitted by the API.
var BuildJSONResponse = func(response ResponseInfo) ([]byte, error) {
	encoded, err := json.MarshalIndent(&response, "", " ")
	if err != nil {
		return nil, err
	}
	return encoded, nil
}
|
|
|
|
// TraceEnabled, when true, makes Handler.ServeHTTP include the error's stack
// trace in the "_trace" field of the JSON response. It is off by default.
var TraceEnabled bool
|
|
|
|
// StatusError represents an error with an associated HTTP status code.
// When a handler's Response carries a StatusError and leaves Status at 0,
// Handler.ServeHTTP uses Status as the HTTP status code (500 if it is 0).
type StatusError struct {
	Status int
	Err    error
}

// Error returns the wrapped error's message. When no error is wrapped it
// falls back to the standard text for Status, so a bare
// StatusError{Status: 404} is still a usable error instead of panicking.
func (se StatusError) Error() string {
	if se.Err != nil {
		return se.Err.Error()
	}
	if text := http.StatusText(se.Status); text != "" {
		return text
	}
	return "unknown status error"
}
|
|
|
|
// Response is returned by API handlers
|
|
type Response struct {
|
|
Status int
|
|
Data interface{}
|
|
RedirectURL string
|
|
Error error
|
|
}
|
|
|
|
// Handler handles API requests
|
|
type Handler func(r *http.Request) Response
|
|
|
|
func (h Handler) callHandlerSafely(r *http.Request) (rsp Response) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
err, ok := r.(error)
|
|
if !ok {
|
|
err = errors.Err("%v", r)
|
|
}
|
|
rsp = Response{Error: errors.Wrap(err, 2)}
|
|
}
|
|
}()
|
|
|
|
return h(r)
|
|
}
|
|
|
|
func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
// Set header settings
|
|
if ResponseHeaders != nil {
|
|
//Multiple readers, no writers is okay
|
|
for key, value := range ResponseHeaders {
|
|
w.Header().Set(key, value)
|
|
}
|
|
}
|
|
origin := r.Header.Get("origin")
|
|
for _, d := range CorsDomains {
|
|
if d == origin {
|
|
w.Header().Set("Access-Control-Allow-Origin", d)
|
|
vary := w.Header().Get("Vary")
|
|
if vary != "*" {
|
|
if vary != "" {
|
|
vary += ", "
|
|
}
|
|
vary += "Origin"
|
|
}
|
|
w.Header().Set("Vary", vary)
|
|
}
|
|
}
|
|
|
|
if CorsAllowLocalhost && strings.HasPrefix(origin, "http://localhost:") {
|
|
w.Header().Set("Access-Control-Allow-Origin", origin)
|
|
vary := w.Header().Get("Vary")
|
|
if vary != "*" {
|
|
if vary != "" {
|
|
vary += ", "
|
|
}
|
|
vary += "Origin"
|
|
}
|
|
w.Header().Set("Vary", vary)
|
|
}
|
|
|
|
// Stop here if its a preflighted OPTIONS request
|
|
if r.Method == "OPTIONS" {
|
|
return
|
|
}
|
|
|
|
rsp := h.callHandlerSafely(r)
|
|
|
|
if rsp.Status == 0 {
|
|
if rsp.Error != nil {
|
|
ogErr := errors.Unwrap(rsp.Error)
|
|
if statusError, ok := ogErr.(StatusError); ok {
|
|
if statusError.Status == 0 {
|
|
statusError.Status = http.StatusInternalServerError
|
|
}
|
|
rsp.Status = statusError.Status
|
|
} else {
|
|
rsp.Status = http.StatusInternalServerError
|
|
}
|
|
} else if rsp.RedirectURL != "" {
|
|
rsp.Status = http.StatusFound
|
|
} else {
|
|
rsp.Status = http.StatusOK
|
|
}
|
|
}
|
|
|
|
success := rsp.Status < http.StatusBadRequest
|
|
if success {
|
|
Log(r, &rsp, nil)
|
|
} else {
|
|
Log(r, &rsp, rsp.Error)
|
|
}
|
|
|
|
// redirect
|
|
if rsp.Status >= http.StatusMultipleChoices && rsp.Status < http.StatusBadRequest {
|
|
http.Redirect(w, r, rsp.RedirectURL, rsp.Status)
|
|
return
|
|
} else if rsp.RedirectURL != "" {
|
|
Log(r, &rsp, errors.Base(
|
|
"status code %d does not indicate a redirect, but RedirectURL is non-empty '%s'",
|
|
rsp.Status, rsp.RedirectURL,
|
|
))
|
|
}
|
|
|
|
var errorString *string
|
|
if rsp.Error != nil {
|
|
errorStringRaw := rsp.Error.Error()
|
|
errorString = &errorStringRaw
|
|
}
|
|
|
|
var trace []string
|
|
if TraceEnabled && errors.HasTrace(rsp.Error) {
|
|
trace = getTraceFromError(rsp.Error)
|
|
}
|
|
|
|
jsonResponse, err := BuildJSONResponse(ResponseInfo{
|
|
Success: success,
|
|
Error: errorString,
|
|
Data: rsp.Data,
|
|
Trace: trace,
|
|
})
|
|
if err != nil {
|
|
Log(r, &rsp, errors.Prefix("Error encoding JSON response: ", err))
|
|
jsonResponse, err = BuildJSONResponse(ResponseInfo{
|
|
Success: false,
|
|
Error: util.PtrToString(err.Error()),
|
|
Data: nil,
|
|
Trace: getTraceFromError(err),
|
|
})
|
|
if err != nil {
|
|
Log(r, &rsp, errors.Prefix("Error encoding JSON response: ", err))
|
|
}
|
|
}
|
|
|
|
w.WriteHeader(rsp.Status)
|
|
_, _ = w.Write(jsonResponse)
|
|
}
|
|
|
|
func getTraceFromError(err error) []string {
|
|
trace := strings.Split(errors.Trace(err), "\n")
|
|
for index, element := range trace {
|
|
if strings.HasPrefix(element, "\t") {
|
|
trace[index] = strings.Replace(element, "\t", " ", 1)
|
|
}
|
|
}
|
|
return trace
|
|
}
|
|
|
|
// IgnoredFormFields names form parameters that FormValues tolerates even
// though the params struct has no matching field; any other unmatched
// parameter makes FormValues return an "Extraneous params" error.
var IgnoredFormFields []string
|
|
|
|
func FormValues(r *http.Request, params interface{}, validationRules []*v.FieldRules) error {
|
|
ref := reflect.ValueOf(params)
|
|
if !ref.IsValid() || ref.Kind() != reflect.Ptr || ref.Elem().Kind() != reflect.Struct {
|
|
return errors.Err("'params' must be a pointer to a struct")
|
|
}
|
|
|
|
structType := ref.Elem().Type()
|
|
structValue := ref.Elem()
|
|
fields := map[string]bool{}
|
|
for i := 0; i < structType.NumField(); i++ {
|
|
fieldName := structType.Field(i).Name
|
|
formattedName := util.Underscore(fieldName)
|
|
jsonName, ok := structType.Field(i).Tag.Lookup("json")
|
|
if ok {
|
|
formattedName = jsonName
|
|
}
|
|
value := strings.TrimSpace(r.FormValue(formattedName))
|
|
|
|
// if param is not set at all, continue
|
|
// comes after call to r.FormValue so form values get parsed internally (if they arent already)
|
|
if len(r.Form[formattedName]) == 0 {
|
|
continue
|
|
}
|
|
|
|
fields[formattedName] = true
|
|
isPtr := false
|
|
var finalValue reflect.Value
|
|
|
|
structField := structValue.FieldByName(fieldName)
|
|
structFieldKind := structField.Kind()
|
|
if structFieldKind == reflect.Ptr {
|
|
isPtr = true
|
|
structFieldKind = structField.Type().Elem().Kind()
|
|
}
|
|
|
|
switch structFieldKind {
|
|
case reflect.String:
|
|
finalValue = reflect.ValueOf(value)
|
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
|
if value == "" {
|
|
continue
|
|
}
|
|
castVal, err := cast.ToInt64E(value)
|
|
if err != nil {
|
|
return errors.Err("%s: must be an integer", formattedName)
|
|
}
|
|
switch structFieldKind {
|
|
case reflect.Int:
|
|
finalValue = reflect.ValueOf(int(castVal))
|
|
case reflect.Int8:
|
|
finalValue = reflect.ValueOf(int8(castVal))
|
|
case reflect.Int16:
|
|
finalValue = reflect.ValueOf(int16(castVal))
|
|
case reflect.Int32:
|
|
finalValue = reflect.ValueOf(int32(castVal))
|
|
case reflect.Int64:
|
|
finalValue = reflect.ValueOf(castVal)
|
|
}
|
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
|
if value == "" {
|
|
continue
|
|
}
|
|
castVal, err := cast.ToUint64E(value)
|
|
if err != nil {
|
|
return errors.Err("%s: must be an unsigned integer", formattedName)
|
|
}
|
|
switch structFieldKind {
|
|
case reflect.Uint:
|
|
finalValue = reflect.ValueOf(uint(castVal))
|
|
case reflect.Uint8:
|
|
finalValue = reflect.ValueOf(uint8(castVal))
|
|
case reflect.Uint16:
|
|
finalValue = reflect.ValueOf(uint16(castVal))
|
|
case reflect.Uint32:
|
|
finalValue = reflect.ValueOf(uint32(castVal))
|
|
case reflect.Uint64:
|
|
finalValue = reflect.ValueOf(castVal)
|
|
}
|
|
case reflect.Bool:
|
|
if value == "" {
|
|
continue
|
|
}
|
|
if !validator.IsBoolString(value) {
|
|
return errors.Err("%s: must be one of the following values: %s",
|
|
formattedName, strings.Join(validator.GetBoolStringValues(), ", "))
|
|
}
|
|
finalValue = reflect.ValueOf(validator.IsTruthy(value))
|
|
|
|
case reflect.Float32, reflect.Float64:
|
|
if value == "" {
|
|
continue
|
|
}
|
|
castVal, err := cast.ToFloat64E(value)
|
|
if err != nil {
|
|
return errors.Err("%s: must be a floating point number", formattedName)
|
|
}
|
|
switch structFieldKind {
|
|
case reflect.Float32:
|
|
finalValue = reflect.ValueOf(float32(castVal))
|
|
case reflect.Float64:
|
|
finalValue = reflect.ValueOf(float64(castVal))
|
|
}
|
|
default:
|
|
return errors.Err("field %s is an unsupported type", fieldName)
|
|
}
|
|
|
|
if isPtr {
|
|
if structField.IsNil() {
|
|
structField.Set(reflect.New(structField.Type().Elem()))
|
|
}
|
|
structField.Elem().Set(finalValue)
|
|
} else {
|
|
structField.Set(finalValue)
|
|
}
|
|
}
|
|
|
|
var extraParams []string
|
|
for k := range r.Form {
|
|
if _, ok := fields[k]; !ok && !util.InSlice(k, IgnoredFormFields) {
|
|
extraParams = append(extraParams, k)
|
|
}
|
|
}
|
|
if len(extraParams) > 0 {
|
|
return errors.Err("Extraneous params: " + strings.Join(extraParams, ", "))
|
|
}
|
|
|
|
if len(validationRules) > 0 {
|
|
validationErr := v.ValidateStruct(params, validationRules...)
|
|
if validationErr != nil {
|
|
return errors.Err(validationErr)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|