From b171fd42ba33a6139275c7227bac68d01b989d28 Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Tue, 2 Aug 2016 12:55:40 +1000 Subject: [PATCH] Add smart quote helper --- strmangle/strmangle.go | 47 +++++++++++--------------- strmangle/strmangle_test.go | 67 ++++--------------------------------- 2 files changed, 26 insertions(+), 88 deletions(-) diff --git a/strmangle/strmangle.go b/strmangle/strmangle.go index 917bcd5..55abd65 100644 --- a/strmangle/strmangle.go +++ b/strmangle/strmangle.go @@ -7,6 +7,7 @@ package strmangle import ( "fmt" "math" + "regexp" "strings" "github.com/jinzhu/inflection" @@ -15,37 +16,29 @@ import ( var ( idAlphabet = []byte("abcdefghijklmnopqrstuvwxyz") uppercaseWords = []string{"id", "uid", "uuid", "guid", "ssn", "tz"} + smartQuoteRgx = regexp.MustCompile(`^(?i)"?[a-z_][_a-z0-9]*"?(\."?[_a-z][_a-z0-9]*"?)*$`) ) -type state int -type smartStack []state - -const ( - stateNothing = iota - stateSubExpression - stateFunction - stateIdentifier -) - -func (stack *smartStack) push(s state) { - *stack = append(*stack, s) -} - -func (stack *smartStack) pop() state { - l := len(*stack) - if l == 0 { - return stateNothing - } - - v := (*stack)[l-1] - *stack = (*stack)[:l-1] - return v -} - // SmartQuote intelligently quotes identifiers in sql statements func SmartQuote(s string) string { - // split on comma, treat as individual thing - return s + if s == "null" { + return s + } + + if m := smartQuoteRgx.MatchString(s); m != true { + return s + } + + splits := strings.Split(s, ".") + for i, split := range splits { + if strings.HasPrefix(split, `"`) || strings.HasSuffix(split, `"`) { + continue + } + + splits[i] = fmt.Sprintf(`"%s"`, split) + } + + return strings.Join(splits, ".") } // Identifier is a base conversion from Base 10 integers to Base 26 diff --git a/strmangle/strmangle_test.go b/strmangle/strmangle_test.go index aba89b8..a379f2b 100644 --- a/strmangle/strmangle_test.go +++ b/strmangle/strmangle_test.go @@ -5,60 +5,6 @@ import ( "testing" ) -func TestStack(t *testing.T) { - t.Parallel() - - stack := smartStack{} - - if stack.pop() != stateNothing { - t.Errorf("Expected state nothing for empty stack") - } - - stack.push(stateFunction) - stack.push(stateSubExpression) - stack.push(stateNothing) - - if len(stack) != 3 { - t.Errorf("Expected 3 state on stack, got %d", len(stack)) - } - - if r := stack.pop(); r != stateNothing { - t.Errorf("Expected stateNothing, got %v", r) - } - if len(stack) != 2 { - t.Errorf("Expected 2 state on stack, got %d", len(stack)) - } - - if r := stack.pop(); r != stateSubExpression { - t.Errorf("Expected stateSubExpression, got %v", r) - } - if len(stack) != 1 { - t.Errorf("Expected 1 state on stack, got %d", len(stack)) - } - - stack.push(stateSubExpression) - if len(stack) != 2 { - t.Errorf("Expected 2 state on stack, got %d", len(stack)) - } - - if r := stack.pop(); r != stateSubExpression { - t.Errorf("Expected stateSubExpression, got %v", r) - } - if len(stack) != 1 { - t.Errorf("Expected 1 state on stack, got %d", len(stack)) - } - - if r := stack.pop(); r != stateFunction { - t.Errorf("Expected stateFunction, got %v", r) - } - if len(stack) != 0 { - t.Errorf("Expected 0 state on stack, got %d", len(stack)) - } - if stack.pop() != stateNothing { - t.Errorf("Expected state nothing for empty stack") - } -} - func TestSmartQuote(t *testing.T) { t.Parallel() @@ -66,15 +12,14 @@ func TestSmartQuote(t *testing.T) { In string Out string }{ - {In: `count(*) as thing, thing as stuff`, Out: `count(*) as thing, "thing" as stuff`}, - {In: `select (select 1) as thing, thing as stuff`, Out: `select (select 1) as thing, "thing" as stuff`}, - { - In: `select (select stuff as thing from thing where id=1 or name="thing") as stuff`, - Out: `select (select "stuff" as thing from thing where id=1 or name="thing") as stuff`, - }, {In: `thing`, Out: `"thing"`}, - {In: `thing, stuff`, Out: `"thing", "stuff"`}, + {In: `null`, Out: `null`}, + {In: `"thing"`, Out: `"thing"`}, {In: `*`, Out: `*`}, + {In: `thing.thing`, Out: `"thing"."thing"`}, + {In: `"thing"."thing"`, Out: `"thing"."thing"`}, + {In: `thing.thing.thing.thing`, Out: `"thing"."thing"."thing"."thing"`}, + {In: `thing."thing".thing."thing"`, Out: `"thing"."thing"."thing"."thing"`}, } for _, test := range tests {