gorm -> orm
This commit is contained in:
commit
a5ec17a1d3
30
.gitignore
vendored
Normal file
30
.gitignore
vendored
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
# Created by .ignore support plugin (hsz.mobi)
|
||||||
|
### Go template
|
||||||
|
# Compiled Object files, Static and Dynamic libs (Shared Objects)
|
||||||
|
*.o
|
||||||
|
*.a
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Folders
|
||||||
|
_obj
|
||||||
|
_test
|
||||||
|
|
||||||
|
# Architecture specific extensions/prefixes
|
||||||
|
*.[568vq]
|
||||||
|
[568vq].out
|
||||||
|
|
||||||
|
*.cgo1.go
|
||||||
|
*.cgo2.c
|
||||||
|
_cgo_defun.c
|
||||||
|
_cgo_gotypes.go
|
||||||
|
_cgo_export.*
|
||||||
|
|
||||||
|
_testmain.go
|
||||||
|
|
||||||
|
*.exe
|
||||||
|
*.test
|
||||||
|
*.prof
|
||||||
|
|
||||||
|
.gitignore 제거할 예정
|
||||||
|
.idea/ 제거할 예정
|
||||||
|
|
21
inflection/LICENSE
Normal file
21
inflection/LICENSE
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
The MIT License (MIT)
|
||||||
|
|
||||||
|
Copyright (c) 2015 - Jinzhu
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
39
inflection/README.md
Normal file
39
inflection/README.md
Normal file
|
@ -0,0 +1,39 @@
|
||||||
|
Inflection
|
||||||
|
=========
|
||||||
|
|
||||||
|
Inflection pluralizes and singularizes English nouns
|
||||||
|
|
||||||
|
## Basic Usage
|
||||||
|
|
||||||
|
```go
|
||||||
|
inflection.Plural("person") => "people"
|
||||||
|
inflection.Plural("Person") => "People"
|
||||||
|
inflection.Plural("PERSON") => "PEOPLE"
|
||||||
|
inflection.Plural("bus") => "buses"
|
||||||
|
inflection.Plural("BUS") => "BUSES"
|
||||||
|
inflection.Plural("Bus") => "Buses"
|
||||||
|
|
||||||
|
inflection.Singularize("people") => "person"
|
||||||
|
inflection.Singularize("People") => "Person"
|
||||||
|
inflection.Singularize("PEOPLE") => "PERSON"
|
||||||
|
inflection.Singularize("buses") => "bus"
|
||||||
|
inflection.Singularize("BUSES") => "BUS"
|
||||||
|
inflection.Singularize("Buses") => "Bus"
|
||||||
|
|
||||||
|
inflection.Plural("FancyPerson") => "FancyPeople"
|
||||||
|
inflection.Singularize("FancyPeople") => "FancyPerson"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Register Rules
|
||||||
|
|
||||||
|
Standard rules are from Rails's ActiveSupport (https://github.com/rails/rails/blob/master/activesupport/lib/active_support/inflections.rb)
|
||||||
|
|
||||||
|
If you want to register more rules, follow:
|
||||||
|
|
||||||
|
```
|
||||||
|
inflection.AddUncountable("fish")
|
||||||
|
inflection.AddIrregular("person", "people")
|
||||||
|
inflection.AddPlural("(bu)s$", "${1}ses") # "bus" => "buses" / "BUS" => "BUSES" / "Bus" => "Buses"
|
||||||
|
inflection.AddSingular("(bus)(es)?$", "${1}") # "buses" => "bus" / "Buses" => "Bus" / "BUSES" => "BUS"
|
||||||
|
```
|
||||||
|
|
273
inflection/inflections.go
Normal file
273
inflection/inflections.go
Normal file
|
@ -0,0 +1,273 @@
|
||||||
|
/*
|
||||||
|
Package inflection pluralizes and singularizes English nouns.
|
||||||
|
|
||||||
|
inflection.Plural("person") => "people"
|
||||||
|
inflection.Plural("Person") => "People"
|
||||||
|
inflection.Plural("PERSON") => "PEOPLE"
|
||||||
|
|
||||||
|
inflection.Singularize("people") => "person"
|
||||||
|
inflection.Singularize("People") => "Person"
|
||||||
|
inflection.Singularize("PEOPLE") => "PERSON"
|
||||||
|
|
||||||
|
inflection.Plural("FancyPerson") => "FancydPeople"
|
||||||
|
inflection.Singularize("FancyPeople") => "FancydPerson"
|
||||||
|
|
||||||
|
Standard rules are from Rails's ActiveSupport (https://github.com/rails/rails/blob/master/activesupport/lib/active_support/inflections.rb)
|
||||||
|
|
||||||
|
If you want to register more rules, follow:
|
||||||
|
|
||||||
|
inflection.AddUncountable("fish")
|
||||||
|
inflection.AddIrregular("person", "people")
|
||||||
|
inflection.AddPlural("(bu)s$", "${1}ses") # "bus" => "buses" / "BUS" => "BUSES" / "Bus" => "Buses"
|
||||||
|
inflection.AddSingular("(bus)(es)?$", "${1}") # "buses" => "bus" / "Buses" => "Bus" / "BUSES" => "BUS"
|
||||||
|
*/
|
||||||
|
package inflection
|
||||||
|
|
||||||
|
import (
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type inflection struct {
|
||||||
|
regexp *regexp.Regexp
|
||||||
|
replace string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Regular is a regexp find replace inflection
|
||||||
|
type Regular struct {
|
||||||
|
find string
|
||||||
|
replace string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Irregular is a hard replace inflection,
|
||||||
|
// containing both singular and plural forms
|
||||||
|
type Irregular struct {
|
||||||
|
singular string
|
||||||
|
plural string
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegularSlice is a slice of Regular inflections
|
||||||
|
type RegularSlice []Regular
|
||||||
|
|
||||||
|
// IrregularSlice is a slice of Irregular inflections
|
||||||
|
type IrregularSlice []Irregular
|
||||||
|
|
||||||
|
var pluralInflections = RegularSlice{
|
||||||
|
{"([a-z])$", "${1}s"},
|
||||||
|
{"s$", "s"},
|
||||||
|
{"^(ax|test)is$", "${1}es"},
|
||||||
|
{"(octop|vir)us$", "${1}i"},
|
||||||
|
{"(octop|vir)i$", "${1}i"},
|
||||||
|
{"(alias|status)$", "${1}es"},
|
||||||
|
{"(bu)s$", "${1}ses"},
|
||||||
|
{"(buffal|tomat)o$", "${1}oes"},
|
||||||
|
{"([ti])um$", "${1}a"},
|
||||||
|
{"([ti])a$", "${1}a"},
|
||||||
|
{"sis$", "ses"},
|
||||||
|
{"(?:([^f])fe|([lr])f)$", "${1}${2}ves"},
|
||||||
|
{"(hive)$", "${1}s"},
|
||||||
|
{"([^aeiouy]|qu)y$", "${1}ies"},
|
||||||
|
{"(x|ch|ss|sh)$", "${1}es"},
|
||||||
|
{"(matr|vert|ind)(?:ix|ex)$", "${1}ices"},
|
||||||
|
{"^(m|l)ouse$", "${1}ice"},
|
||||||
|
{"^(m|l)ice$", "${1}ice"},
|
||||||
|
{"^(ox)$", "${1}en"},
|
||||||
|
{"^(oxen)$", "${1}"},
|
||||||
|
{"(quiz)$", "${1}zes"},
|
||||||
|
}
|
||||||
|
|
||||||
|
var singularInflections = RegularSlice{
|
||||||
|
{"s$", ""},
|
||||||
|
{"(ss)$", "${1}"},
|
||||||
|
{"(n)ews$", "${1}ews"},
|
||||||
|
{"([ti])a$", "${1}um"},
|
||||||
|
{"((a)naly|(b)a|(d)iagno|(p)arenthe|(p)rogno|(s)ynop|(t)he)(sis|ses)$", "${1}sis"},
|
||||||
|
{"(^analy)(sis|ses)$", "${1}sis"},
|
||||||
|
{"([^f])ves$", "${1}fe"},
|
||||||
|
{"(hive)s$", "${1}"},
|
||||||
|
{"(tive)s$", "${1}"},
|
||||||
|
{"([lr])ves$", "${1}f"},
|
||||||
|
{"([^aeiouy]|qu)ies$", "${1}y"},
|
||||||
|
{"(s)eries$", "${1}eries"},
|
||||||
|
{"(m)ovies$", "${1}ovie"},
|
||||||
|
{"(c)ookies$", "${1}ookie"},
|
||||||
|
{"(x|ch|ss|sh)es$", "${1}"},
|
||||||
|
{"^(m|l)ice$", "${1}ouse"},
|
||||||
|
{"(bus)(es)?$", "${1}"},
|
||||||
|
{"(o)es$", "${1}"},
|
||||||
|
{"(shoe)s$", "${1}"},
|
||||||
|
{"(cris|test)(is|es)$", "${1}is"},
|
||||||
|
{"^(a)x[ie]s$", "${1}xis"},
|
||||||
|
{"(octop|vir)(us|i)$", "${1}us"},
|
||||||
|
{"(alias|status)(es)?$", "${1}"},
|
||||||
|
{"^(ox)en", "${1}"},
|
||||||
|
{"(vert|ind)ices$", "${1}ex"},
|
||||||
|
{"(matr)ices$", "${1}ix"},
|
||||||
|
{"(quiz)zes$", "${1}"},
|
||||||
|
{"(database)s$", "${1}"},
|
||||||
|
}
|
||||||
|
|
||||||
|
var irregularInflections = IrregularSlice{
|
||||||
|
{"person", "people"},
|
||||||
|
{"man", "men"},
|
||||||
|
{"child", "children"},
|
||||||
|
{"sex", "sexes"},
|
||||||
|
{"move", "moves"},
|
||||||
|
{"mombie", "mombies"},
|
||||||
|
}
|
||||||
|
|
||||||
|
var uncountableInflections = []string{"equipment", "information", "rice", "money", "species", "series", "fish", "sheep", "jeans", "police"}
|
||||||
|
|
||||||
|
var compiledPluralMaps []inflection
|
||||||
|
var compiledSingularMaps []inflection
|
||||||
|
|
||||||
|
func compile() {
|
||||||
|
compiledPluralMaps = []inflection{}
|
||||||
|
compiledSingularMaps = []inflection{}
|
||||||
|
for _, uncountable := range uncountableInflections {
|
||||||
|
inf := inflection{
|
||||||
|
regexp: regexp.MustCompile("^(?i)(" + uncountable + ")$"),
|
||||||
|
replace: "${1}",
|
||||||
|
}
|
||||||
|
compiledPluralMaps = append(compiledPluralMaps, inf)
|
||||||
|
compiledSingularMaps = append(compiledSingularMaps, inf)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, value := range irregularInflections {
|
||||||
|
infs := []inflection{
|
||||||
|
inflection{regexp: regexp.MustCompile(strings.ToUpper(value.singular) + "$"), replace: strings.ToUpper(value.plural)},
|
||||||
|
inflection{regexp: regexp.MustCompile(strings.Title(value.singular) + "$"), replace: strings.Title(value.plural)},
|
||||||
|
inflection{regexp: regexp.MustCompile(value.singular + "$"), replace: value.plural},
|
||||||
|
}
|
||||||
|
compiledPluralMaps = append(compiledPluralMaps, infs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, value := range irregularInflections {
|
||||||
|
infs := []inflection{
|
||||||
|
inflection{regexp: regexp.MustCompile(strings.ToUpper(value.plural) + "$"), replace: strings.ToUpper(value.singular)},
|
||||||
|
inflection{regexp: regexp.MustCompile(strings.Title(value.plural) + "$"), replace: strings.Title(value.singular)},
|
||||||
|
inflection{regexp: regexp.MustCompile(value.plural + "$"), replace: value.singular},
|
||||||
|
}
|
||||||
|
compiledSingularMaps = append(compiledSingularMaps, infs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := len(pluralInflections) - 1; i >= 0; i-- {
|
||||||
|
value := pluralInflections[i]
|
||||||
|
infs := []inflection{
|
||||||
|
inflection{regexp: regexp.MustCompile(strings.ToUpper(value.find)), replace: strings.ToUpper(value.replace)},
|
||||||
|
inflection{regexp: regexp.MustCompile(value.find), replace: value.replace},
|
||||||
|
inflection{regexp: regexp.MustCompile("(?i)" + value.find), replace: value.replace},
|
||||||
|
}
|
||||||
|
compiledPluralMaps = append(compiledPluralMaps, infs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := len(singularInflections) - 1; i >= 0; i-- {
|
||||||
|
value := singularInflections[i]
|
||||||
|
infs := []inflection{
|
||||||
|
inflection{regexp: regexp.MustCompile(strings.ToUpper(value.find)), replace: strings.ToUpper(value.replace)},
|
||||||
|
inflection{regexp: regexp.MustCompile(value.find), replace: value.replace},
|
||||||
|
inflection{regexp: regexp.MustCompile("(?i)" + value.find), replace: value.replace},
|
||||||
|
}
|
||||||
|
compiledSingularMaps = append(compiledSingularMaps, infs...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
compile()
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddPlural adds a plural inflection
|
||||||
|
func AddPlural(find, replace string) {
|
||||||
|
pluralInflections = append(pluralInflections, Regular{find, replace})
|
||||||
|
compile()
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddSingular adds a singular inflection
|
||||||
|
func AddSingular(find, replace string) {
|
||||||
|
singularInflections = append(singularInflections, Regular{find, replace})
|
||||||
|
compile()
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddIrregular adds an irregular inflection
|
||||||
|
func AddIrregular(singular, plural string) {
|
||||||
|
irregularInflections = append(irregularInflections, Irregular{singular, plural})
|
||||||
|
compile()
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddUncountable adds an uncountable inflection
|
||||||
|
func AddUncountable(values ...string) {
|
||||||
|
uncountableInflections = append(uncountableInflections, values...)
|
||||||
|
compile()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPlural retrieves the plural inflection values
|
||||||
|
func GetPlural() RegularSlice {
|
||||||
|
plurals := make(RegularSlice, len(pluralInflections))
|
||||||
|
copy(plurals, pluralInflections)
|
||||||
|
return plurals
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSingular retrieves the singular inflection values
|
||||||
|
func GetSingular() RegularSlice {
|
||||||
|
singulars := make(RegularSlice, len(singularInflections))
|
||||||
|
copy(singulars, singularInflections)
|
||||||
|
return singulars
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetIrregular retrieves the irregular inflection values
|
||||||
|
func GetIrregular() IrregularSlice {
|
||||||
|
irregular := make(IrregularSlice, len(irregularInflections))
|
||||||
|
copy(irregular, irregularInflections)
|
||||||
|
return irregular
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUncountable retrieves the uncountable inflection values
|
||||||
|
func GetUncountable() []string {
|
||||||
|
uncountables := make([]string, len(uncountableInflections))
|
||||||
|
copy(uncountables, uncountableInflections)
|
||||||
|
return uncountables
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetPlural sets the plural inflections slice
|
||||||
|
func SetPlural(inflections RegularSlice) {
|
||||||
|
pluralInflections = inflections
|
||||||
|
compile()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSingular sets the singular inflections slice
|
||||||
|
func SetSingular(inflections RegularSlice) {
|
||||||
|
singularInflections = inflections
|
||||||
|
compile()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetIrregular sets the irregular inflections slice
|
||||||
|
func SetIrregular(inflections IrregularSlice) {
|
||||||
|
irregularInflections = inflections
|
||||||
|
compile()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUncountable sets the uncountable inflections slice
|
||||||
|
func SetUncountable(inflections []string) {
|
||||||
|
uncountableInflections = inflections
|
||||||
|
compile()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Plural converts a word to its plural form
|
||||||
|
func Plural(str string) string {
|
||||||
|
for _, inflection := range compiledPluralMaps {
|
||||||
|
if inflection.regexp.MatchString(str) {
|
||||||
|
return inflection.regexp.ReplaceAllString(str, inflection.replace)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return str
|
||||||
|
}
|
||||||
|
|
||||||
|
// Singular converts a word to its singular form
|
||||||
|
func Singular(str string) string {
|
||||||
|
for _, inflection := range compiledSingularMaps {
|
||||||
|
if inflection.regexp.MatchString(str) {
|
||||||
|
return inflection.regexp.ReplaceAllString(str, inflection.replace)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return str
|
||||||
|
}
|
213
inflection/inflections_test.go
Normal file
213
inflection/inflections_test.go
Normal file
|
@ -0,0 +1,213 @@
|
||||||
|
package inflection
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
var inflections = map[string]string{
|
||||||
|
"star": "stars",
|
||||||
|
"STAR": "STARS",
|
||||||
|
"Star": "Stars",
|
||||||
|
"bus": "buses",
|
||||||
|
"fish": "fish",
|
||||||
|
"mouse": "mice",
|
||||||
|
"query": "queries",
|
||||||
|
"ability": "abilities",
|
||||||
|
"agency": "agencies",
|
||||||
|
"movie": "movies",
|
||||||
|
"archive": "archives",
|
||||||
|
"index": "indices",
|
||||||
|
"wife": "wives",
|
||||||
|
"safe": "saves",
|
||||||
|
"half": "halves",
|
||||||
|
"move": "moves",
|
||||||
|
"salesperson": "salespeople",
|
||||||
|
"person": "people",
|
||||||
|
"spokesman": "spokesmen",
|
||||||
|
"man": "men",
|
||||||
|
"woman": "women",
|
||||||
|
"basis": "bases",
|
||||||
|
"diagnosis": "diagnoses",
|
||||||
|
"diagnosis_a": "diagnosis_as",
|
||||||
|
"datum": "data",
|
||||||
|
"medium": "media",
|
||||||
|
"stadium": "stadia",
|
||||||
|
"analysis": "analyses",
|
||||||
|
"node_child": "node_children",
|
||||||
|
"child": "children",
|
||||||
|
"experience": "experiences",
|
||||||
|
"day": "days",
|
||||||
|
"comment": "comments",
|
||||||
|
"foobar": "foobars",
|
||||||
|
"newsletter": "newsletters",
|
||||||
|
"old_news": "old_news",
|
||||||
|
"news": "news",
|
||||||
|
"series": "series",
|
||||||
|
"species": "species",
|
||||||
|
"quiz": "quizzes",
|
||||||
|
"perspective": "perspectives",
|
||||||
|
"ox": "oxen",
|
||||||
|
"photo": "photos",
|
||||||
|
"buffalo": "buffaloes",
|
||||||
|
"tomato": "tomatoes",
|
||||||
|
"dwarf": "dwarves",
|
||||||
|
"elf": "elves",
|
||||||
|
"information": "information",
|
||||||
|
"equipment": "equipment",
|
||||||
|
"criterion": "criteria",
|
||||||
|
}
|
||||||
|
|
||||||
|
// storage is used to restore the state of the global variables
|
||||||
|
// on each test execution, to ensure no global state pollution
|
||||||
|
type storage struct {
|
||||||
|
singulars RegularSlice
|
||||||
|
plurals RegularSlice
|
||||||
|
irregulars IrregularSlice
|
||||||
|
uncountables []string
|
||||||
|
}
|
||||||
|
|
||||||
|
var backup = storage{}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
AddIrregular("criterion", "criteria")
|
||||||
|
copy(backup.singulars, singularInflections)
|
||||||
|
copy(backup.plurals, pluralInflections)
|
||||||
|
copy(backup.irregulars, irregularInflections)
|
||||||
|
copy(backup.uncountables, uncountableInflections)
|
||||||
|
}
|
||||||
|
|
||||||
|
func restore() {
|
||||||
|
copy(singularInflections, backup.singulars)
|
||||||
|
copy(pluralInflections, backup.plurals)
|
||||||
|
copy(irregularInflections, backup.irregulars)
|
||||||
|
copy(uncountableInflections, backup.uncountables)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPlural(t *testing.T) {
|
||||||
|
for key, value := range inflections {
|
||||||
|
if v := Plural(strings.ToUpper(key)); v != strings.ToUpper(value) {
|
||||||
|
t.Errorf("%v's plural should be %v, but got %v", strings.ToUpper(key), strings.ToUpper(value), v)
|
||||||
|
}
|
||||||
|
|
||||||
|
if v := Plural(strings.Title(key)); v != strings.Title(value) {
|
||||||
|
t.Errorf("%v's plural should be %v, but got %v", strings.Title(key), strings.Title(value), v)
|
||||||
|
}
|
||||||
|
|
||||||
|
if v := Plural(key); v != value {
|
||||||
|
t.Errorf("%v's plural should be %v, but got %v", key, value, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSingular(t *testing.T) {
|
||||||
|
for key, value := range inflections {
|
||||||
|
if v := Singular(strings.ToUpper(value)); v != strings.ToUpper(key) {
|
||||||
|
t.Errorf("%v's singular should be %v, but got %v", strings.ToUpper(value), strings.ToUpper(key), v)
|
||||||
|
}
|
||||||
|
|
||||||
|
if v := Singular(strings.Title(value)); v != strings.Title(key) {
|
||||||
|
t.Errorf("%v's singular should be %v, but got %v", strings.Title(value), strings.Title(key), v)
|
||||||
|
}
|
||||||
|
|
||||||
|
if v := Singular(value); v != key {
|
||||||
|
t.Errorf("%v's singular should be %v, but got %v", value, key, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddPlural(t *testing.T) {
|
||||||
|
defer restore()
|
||||||
|
ln := len(pluralInflections)
|
||||||
|
AddPlural("", "")
|
||||||
|
if ln+1 != len(pluralInflections) {
|
||||||
|
t.Errorf("Expected len %d, got %d", ln+1, len(pluralInflections))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddSingular(t *testing.T) {
|
||||||
|
defer restore()
|
||||||
|
ln := len(singularInflections)
|
||||||
|
AddSingular("", "")
|
||||||
|
if ln+1 != len(singularInflections) {
|
||||||
|
t.Errorf("Expected len %d, got %d", ln+1, len(singularInflections))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddIrregular(t *testing.T) {
|
||||||
|
defer restore()
|
||||||
|
ln := len(irregularInflections)
|
||||||
|
AddIrregular("", "")
|
||||||
|
if ln+1 != len(irregularInflections) {
|
||||||
|
t.Errorf("Expected len %d, got %d", ln+1, len(irregularInflections))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddUncountable(t *testing.T) {
|
||||||
|
defer restore()
|
||||||
|
ln := len(uncountableInflections)
|
||||||
|
AddUncountable("", "")
|
||||||
|
if ln+2 != len(uncountableInflections) {
|
||||||
|
t.Errorf("Expected len %d, got %d", ln+2, len(uncountableInflections))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPlural(t *testing.T) {
|
||||||
|
plurals := GetPlural()
|
||||||
|
if len(plurals) != len(pluralInflections) {
|
||||||
|
t.Errorf("Expected len %d, got %d", len(plurals), len(pluralInflections))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetSingular(t *testing.T) {
|
||||||
|
singular := GetSingular()
|
||||||
|
if len(singular) != len(singularInflections) {
|
||||||
|
t.Errorf("Expected len %d, got %d", len(singular), len(singularInflections))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetIrregular(t *testing.T) {
|
||||||
|
irregular := GetIrregular()
|
||||||
|
if len(irregular) != len(irregularInflections) {
|
||||||
|
t.Errorf("Expected len %d, got %d", len(irregular), len(irregularInflections))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUncountable(t *testing.T) {
|
||||||
|
uncountables := GetUncountable()
|
||||||
|
if len(uncountables) != len(uncountableInflections) {
|
||||||
|
t.Errorf("Expected len %d, got %d", len(uncountables), len(uncountableInflections))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetPlural(t *testing.T) {
|
||||||
|
defer restore()
|
||||||
|
SetPlural(RegularSlice{{}, {}})
|
||||||
|
if len(pluralInflections) != 2 {
|
||||||
|
t.Errorf("Expected len 2, got %d", len(pluralInflections))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetSingular(t *testing.T) {
|
||||||
|
defer restore()
|
||||||
|
SetSingular(RegularSlice{{}, {}})
|
||||||
|
if len(singularInflections) != 2 {
|
||||||
|
t.Errorf("Expected len 2, got %d", len(singularInflections))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetIrregular(t *testing.T) {
|
||||||
|
defer restore()
|
||||||
|
SetIrregular(IrregularSlice{{}, {}})
|
||||||
|
if len(irregularInflections) != 2 {
|
||||||
|
t.Errorf("Expected len 2, got %d", len(irregularInflections))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetUncountable(t *testing.T) {
|
||||||
|
defer restore()
|
||||||
|
SetUncountable([]string{"", ""})
|
||||||
|
if len(uncountableInflections) != 2 {
|
||||||
|
t.Errorf("Expected len 2, got %d", len(uncountableInflections))
|
||||||
|
}
|
||||||
|
}
|
3
now/Guardfile
Normal file
3
now/Guardfile
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
guard 'gotest' do
|
||||||
|
watch(%r{\.go$})
|
||||||
|
end
|
104
now/README.md
Normal file
104
now/README.md
Normal file
|
@ -0,0 +1,104 @@
|
||||||
|
## Now
|
||||||
|
|
||||||
|
Now is a time toolkit for golang
|
||||||
|
|
||||||
|
#### Why the project named `Now`?
|
||||||
|
|
||||||
|
```go
|
||||||
|
now.BeginningOfDay()
|
||||||
|
```
|
||||||
|
`now` is quite readable, aha?
|
||||||
|
|
||||||
|
#### But `now` is so common I can't search the project with my favorite search engine
|
||||||
|
|
||||||
|
* Star it in github [https://github.com/jinzhu/now](https://github.com/jinzhu/now)
|
||||||
|
* Search it with [http://godoc.org](http://godoc.org)
|
||||||
|
|
||||||
|
## Install
|
||||||
|
|
||||||
|
```
|
||||||
|
go get -u github.com/jinzhu/now
|
||||||
|
```
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
|
```go
|
||||||
|
import "github.com/jinzhu/now"
|
||||||
|
|
||||||
|
time.Now() // 2013-11-18 17:51:49.123456789 Mon
|
||||||
|
|
||||||
|
now.BeginningOfMinute() // 2013-11-18 17:51:00 Mon
|
||||||
|
now.BeginningOfHour() // 2013-11-18 17:00:00 Mon
|
||||||
|
now.BeginningOfDay() // 2013-11-18 00:00:00 Mon
|
||||||
|
now.BeginningOfWeek() // 2013-11-17 00:00:00 Sun
|
||||||
|
now.FirstDayMonday = true // Set Monday as first day, default is Sunday
|
||||||
|
now.BeginningOfWeek() // 2013-11-18 00:00:00 Mon
|
||||||
|
now.BeginningOfMonth() // 2013-11-01 00:00:00 Fri
|
||||||
|
now.BeginningOfQuarter() // 2013-10-01 00:00:00 Tue
|
||||||
|
now.BeginningOfYear() // 2013-01-01 00:00:00 Tue
|
||||||
|
|
||||||
|
now.EndOfMinute() // 2013-11-18 17:51:59.999999999 Mon
|
||||||
|
now.EndOfHour() // 2013-11-18 17:59:59.999999999 Mon
|
||||||
|
now.EndOfDay() // 2013-11-18 23:59:59.999999999 Mon
|
||||||
|
now.EndOfWeek() // 2013-11-23 23:59:59.999999999 Sat
|
||||||
|
now.FirstDayMonday = true // Set Monday as first day, default is Sunday
|
||||||
|
now.EndOfWeek() // 2013-11-24 23:59:59.999999999 Sun
|
||||||
|
now.EndOfMonth() // 2013-11-30 23:59:59.999999999 Sat
|
||||||
|
now.EndOfQuarter() // 2013-12-31 23:59:59.999999999 Tue
|
||||||
|
now.EndOfYear() // 2013-12-31 23:59:59.999999999 Tue
|
||||||
|
|
||||||
|
|
||||||
|
// Use another time
|
||||||
|
t := time.Date(2013, 02, 18, 17, 51, 49, 123456789, time.Now().Location())
|
||||||
|
now.New(t).EndOfMonth() // 2013-02-28 23:59:59.999999999 Thu
|
||||||
|
|
||||||
|
|
||||||
|
// Don't want be bothered with the First Day setting, Use Monday, Sunday
|
||||||
|
now.Monday() // 2013-11-18 00:00:00 Mon
|
||||||
|
now.Sunday() // 2013-11-24 00:00:00 Sun (Next Sunday)
|
||||||
|
now.EndOfSunday() // 2013-11-24 23:59:59.999999999 Sun (End of next Sunday)
|
||||||
|
|
||||||
|
t := time.Date(2013, 11, 24, 17, 51, 49, 123456789, time.Now().Location()) // 2013-11-24 17:51:49.123456789 Sun
|
||||||
|
now.New(t).Monday() // 2013-11-18 00:00:00 Sun (Last Monday if today is Sunday)
|
||||||
|
now.New(t).Sunday() // 2013-11-24 00:00:00 Sun (Beginning Of Today if today is Sunday)
|
||||||
|
now.New(t).EndOfSunday() // 2013-11-24 23:59:59.999999999 Sun (End of Today if today is Sunday)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Parse String
|
||||||
|
|
||||||
|
```go
|
||||||
|
time.Now() // 2013-11-18 17:51:49.123456789 Mon
|
||||||
|
|
||||||
|
// Parse(string) (time.Time, error)
|
||||||
|
t, err := now.Parse("12:20") // 2013-11-18 12:20:00, nil
|
||||||
|
t, err := now.Parse("1999-12-12 12:20") // 1999-12-12 12:20:00, nil
|
||||||
|
t, err := now.Parse("99:99") // 2013-11-18 12:20:00, Can't parse string as time: 99:99
|
||||||
|
|
||||||
|
// MustParse(string) time.Time
|
||||||
|
now.MustParse("2013-01-13") // 2013-01-13 00:00:00
|
||||||
|
now.MustParse("02-17") // 2013-02-17 00:00:00
|
||||||
|
now.MustParse("2-17") // 2013-02-17 00:00:00
|
||||||
|
now.MustParse("8") // 2013-11-18 08:00:00
|
||||||
|
now.MustParse("2002-10-12 22:14") // 2002-10-12 22:14:00
|
||||||
|
now.MustParse("99:99") // panic: Can't parse string as time: 99:99
|
||||||
|
```
|
||||||
|
|
||||||
|
Extend `now` to support more formats is quite easy, just update `TimeFormats` variable with `time.Format` like time layout
|
||||||
|
|
||||||
|
```go
|
||||||
|
now.TimeFormats = append(now.TimeFormats, "02 Jan 2006 15:04")
|
||||||
|
```
|
||||||
|
|
||||||
|
Please send me pull requests if you want a format to be supported officially
|
||||||
|
|
||||||
|
# Author
|
||||||
|
|
||||||
|
**jinzhu**
|
||||||
|
|
||||||
|
* <http://github.com/jinzhu>
|
||||||
|
* <wosmvp@gmail.com>
|
||||||
|
* <http://twitter.com/zhangjinzhu>
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
Released under the [MIT License](http://www.opensource.org/licenses/MIT).
|
103
now/main.go
Normal file
103
now/main.go
Normal file
|
@ -0,0 +1,103 @@
|
||||||
|
// Package now is a time toolkit for golang.
|
||||||
|
//
|
||||||
|
// More details README here: https://github.com/jinzhu/now
|
||||||
|
//
|
||||||
|
// import "github.com/jinzhu/now"
|
||||||
|
//
|
||||||
|
// now.BeginningOfMinute() // 2013-11-18 17:51:00 Mon
|
||||||
|
// now.BeginningOfDay() // 2013-11-18 00:00:00 Mon
|
||||||
|
// now.EndOfDay() // 2013-11-18 23:59:59.999999999 Mon
|
||||||
|
package now
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
var FirstDayMonday bool
|
||||||
|
var TimeFormats = []string{"1/2/2006", "1/2/2006 15:4:5", "2006-1-2 15:4:5", "2006-1-2 15:4", "2006-1-2", "1-2", "15:4:5", "15:4", "15", "15:4:5 Jan 2, 2006 MST", "2006-01-02 15:04:05.999999999 -0700 MST"}
|
||||||
|
|
||||||
|
type Now struct {
|
||||||
|
time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(t time.Time) *Now {
|
||||||
|
return &Now{t}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BeginningOfMinute() time.Time {
|
||||||
|
return New(time.Now()).BeginningOfMinute()
|
||||||
|
}
|
||||||
|
|
||||||
|
func BeginningOfHour() time.Time {
|
||||||
|
return New(time.Now()).BeginningOfHour()
|
||||||
|
}
|
||||||
|
|
||||||
|
func BeginningOfDay() time.Time {
|
||||||
|
return New(time.Now()).BeginningOfDay()
|
||||||
|
}
|
||||||
|
|
||||||
|
func BeginningOfWeek() time.Time {
|
||||||
|
return New(time.Now()).BeginningOfWeek()
|
||||||
|
}
|
||||||
|
|
||||||
|
func BeginningOfMonth() time.Time {
|
||||||
|
return New(time.Now()).BeginningOfMonth()
|
||||||
|
}
|
||||||
|
|
||||||
|
func BeginningOfQuarter() time.Time {
|
||||||
|
return New(time.Now()).BeginningOfQuarter()
|
||||||
|
}
|
||||||
|
|
||||||
|
func BeginningOfYear() time.Time {
|
||||||
|
return New(time.Now()).BeginningOfYear()
|
||||||
|
}
|
||||||
|
|
||||||
|
func EndOfMinute() time.Time {
|
||||||
|
return New(time.Now()).EndOfMinute()
|
||||||
|
}
|
||||||
|
|
||||||
|
func EndOfHour() time.Time {
|
||||||
|
return New(time.Now()).EndOfHour()
|
||||||
|
}
|
||||||
|
|
||||||
|
func EndOfDay() time.Time {
|
||||||
|
return New(time.Now()).EndOfDay()
|
||||||
|
}
|
||||||
|
|
||||||
|
func EndOfWeek() time.Time {
|
||||||
|
return New(time.Now()).EndOfWeek()
|
||||||
|
}
|
||||||
|
|
||||||
|
func EndOfMonth() time.Time {
|
||||||
|
return New(time.Now()).EndOfMonth()
|
||||||
|
}
|
||||||
|
|
||||||
|
func EndOfQuarter() time.Time {
|
||||||
|
return New(time.Now()).EndOfQuarter()
|
||||||
|
}
|
||||||
|
|
||||||
|
func EndOfYear() time.Time {
|
||||||
|
return New(time.Now()).EndOfYear()
|
||||||
|
}
|
||||||
|
|
||||||
|
func Monday() time.Time {
|
||||||
|
return New(time.Now()).Monday()
|
||||||
|
}
|
||||||
|
|
||||||
|
func Sunday() time.Time {
|
||||||
|
return New(time.Now()).Sunday()
|
||||||
|
}
|
||||||
|
|
||||||
|
func EndOfSunday() time.Time {
|
||||||
|
return New(time.Now()).EndOfSunday()
|
||||||
|
}
|
||||||
|
|
||||||
|
func Parse(strs ...string) (time.Time, error) {
|
||||||
|
return New(time.Now()).Parse(strs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func MustParse(strs ...string) time.Time {
|
||||||
|
return New(time.Now()).MustParse(strs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Between(time1, time2 string) bool {
|
||||||
|
return New(time.Now()).Between(time1, time2)
|
||||||
|
}
|
182
now/now.go
Normal file
182
now/now.go
Normal file
|
@ -0,0 +1,182 @@
|
||||||
|
package now
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"regexp"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (now *Now) BeginningOfMinute() time.Time {
|
||||||
|
return now.Truncate(time.Minute)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (now *Now) BeginningOfHour() time.Time {
|
||||||
|
return now.Truncate(time.Hour)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (now *Now) BeginningOfDay() time.Time {
|
||||||
|
d := time.Duration(-now.Hour()) * time.Hour
|
||||||
|
return now.BeginningOfHour().Add(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (now *Now) BeginningOfWeek() time.Time {
|
||||||
|
t := now.BeginningOfDay()
|
||||||
|
weekday := int(t.Weekday())
|
||||||
|
if FirstDayMonday {
|
||||||
|
if weekday == 0 {
|
||||||
|
weekday = 7
|
||||||
|
}
|
||||||
|
weekday = weekday - 1
|
||||||
|
}
|
||||||
|
|
||||||
|
d := time.Duration(-weekday) * 24 * time.Hour
|
||||||
|
return t.Add(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (now *Now) BeginningOfMonth() time.Time {
|
||||||
|
t := now.BeginningOfDay()
|
||||||
|
d := time.Duration(-int(t.Day())+1) * 24 * time.Hour
|
||||||
|
return t.Add(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (now *Now) BeginningOfQuarter() time.Time {
|
||||||
|
month := now.BeginningOfMonth()
|
||||||
|
offset := (int(month.Month()) - 1) % 3
|
||||||
|
return month.AddDate(0, -offset, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (now *Now) BeginningOfYear() time.Time {
|
||||||
|
t := now.BeginningOfDay()
|
||||||
|
d := time.Duration(-int(t.YearDay())+1) * 24 * time.Hour
|
||||||
|
return t.Truncate(time.Hour).Add(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (now *Now) EndOfMinute() time.Time {
|
||||||
|
return now.BeginningOfMinute().Add(time.Minute - time.Nanosecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (now *Now) EndOfHour() time.Time {
|
||||||
|
return now.BeginningOfHour().Add(time.Hour - time.Nanosecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (now *Now) EndOfDay() time.Time {
|
||||||
|
return now.BeginningOfDay().Add(24*time.Hour - time.Nanosecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (now *Now) EndOfWeek() time.Time {
|
||||||
|
return now.BeginningOfWeek().AddDate(0, 0, 7).Add(-time.Nanosecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (now *Now) EndOfMonth() time.Time {
|
||||||
|
return now.BeginningOfMonth().AddDate(0, 1, 0).Add(-time.Nanosecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (now *Now) EndOfQuarter() time.Time {
|
||||||
|
return now.BeginningOfQuarter().AddDate(0, 3, 0).Add(-time.Nanosecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (now *Now) EndOfYear() time.Time {
|
||||||
|
return now.BeginningOfYear().AddDate(1, 0, 0).Add(-time.Nanosecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (now *Now) Monday() time.Time {
|
||||||
|
t := now.BeginningOfDay()
|
||||||
|
weekday := int(t.Weekday())
|
||||||
|
if weekday == 0 {
|
||||||
|
weekday = 7
|
||||||
|
}
|
||||||
|
d := time.Duration(-weekday+1) * 24 * time.Hour
|
||||||
|
return t.Truncate(time.Hour).Add(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (now *Now) Sunday() time.Time {
|
||||||
|
t := now.BeginningOfDay()
|
||||||
|
weekday := int(t.Weekday())
|
||||||
|
if weekday == 0 {
|
||||||
|
return t
|
||||||
|
} else {
|
||||||
|
d := time.Duration(7-weekday) * 24 * time.Hour
|
||||||
|
return t.Truncate(time.Hour).Add(d)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (now *Now) EndOfSunday() time.Time {
|
||||||
|
return now.Sunday().Add(24*time.Hour - time.Nanosecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseWithFormat(str string) (t time.Time, err error) {
|
||||||
|
for _, format := range TimeFormats {
|
||||||
|
t, err = time.Parse(format, str)
|
||||||
|
if err == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err = errors.New("Can't parse string as time: " + str)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (now *Now) Parse(strs ...string) (t time.Time, err error) {
|
||||||
|
var setCurrentTime bool
|
||||||
|
parseTime := []int{}
|
||||||
|
currentTime := []int{now.Second(), now.Minute(), now.Hour(), now.Day(), int(now.Month()), now.Year()}
|
||||||
|
currentLocation := now.Location()
|
||||||
|
|
||||||
|
for _, str := range strs {
|
||||||
|
onlyTime := regexp.MustCompile(`^\s*\d+(:\d+)*\s*$`).MatchString(str) // match 15:04:05, 15
|
||||||
|
|
||||||
|
t, err = parseWithFormat(str)
|
||||||
|
location := t.Location()
|
||||||
|
if location.String() == "UTC" {
|
||||||
|
location = currentLocation
|
||||||
|
}
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
parseTime = []int{t.Second(), t.Minute(), t.Hour(), t.Day(), int(t.Month()), t.Year()}
|
||||||
|
onlyTime = onlyTime && (parseTime[3] == 1) && (parseTime[4] == 1)
|
||||||
|
|
||||||
|
for i, v := range parseTime {
|
||||||
|
// Don't reset hour, minute, second if it is a time only string
|
||||||
|
if onlyTime && i <= 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fill up missed information with current time
|
||||||
|
if v == 0 {
|
||||||
|
if setCurrentTime {
|
||||||
|
parseTime[i] = currentTime[i]
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
setCurrentTime = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default day and month is 1, fill up it if missing it
|
||||||
|
if onlyTime {
|
||||||
|
if i == 3 || i == 4 {
|
||||||
|
parseTime[i] = currentTime[i]
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(parseTime) > 0 {
|
||||||
|
t = time.Date(parseTime[5], time.Month(parseTime[4]), parseTime[3], parseTime[2], parseTime[1], parseTime[0], 0, location)
|
||||||
|
currentTime = []int{t.Second(), t.Minute(), t.Hour(), t.Day(), int(t.Month()), t.Year()}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (now *Now) MustParse(strs ...string) (t time.Time) {
|
||||||
|
t, err := now.Parse(strs...)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
func (now *Now) Between(time1, time2 string) bool {
|
||||||
|
restime := now.MustParse(time1)
|
||||||
|
restime2 := now.MustParse(time2)
|
||||||
|
return now.After(restime) && now.Before(restime2)
|
||||||
|
}
|
272
now/now_test.go
Normal file
272
now/now_test.go
Normal file
|
@ -0,0 +1,272 @@
|
||||||
|
package now
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var format = "2006-01-02 15:04:05.999999999"
|
||||||
|
|
||||||
|
func TestBeginningOf(t *testing.T) {
|
||||||
|
n := time.Date(2013, 11, 18, 17, 51, 49, 123456789, time.UTC)
|
||||||
|
|
||||||
|
if New(n).BeginningOfMinute().Format(format) != "2013-11-18 17:51:00" {
|
||||||
|
t.Errorf("BeginningOfMinute")
|
||||||
|
}
|
||||||
|
|
||||||
|
if New(n).BeginningOfHour().Format(format) != "2013-11-18 17:00:00" {
|
||||||
|
t.Errorf("BeginningOfHour")
|
||||||
|
}
|
||||||
|
|
||||||
|
if New(n).BeginningOfDay().Format(format) != "2013-11-18 00:00:00" {
|
||||||
|
t.Errorf("BeginningOfDay")
|
||||||
|
}
|
||||||
|
|
||||||
|
location, _ := time.LoadLocation("Japan")
|
||||||
|
beginningOfDay := time.Date(2015, 05, 01, 0, 0, 0, 0, location)
|
||||||
|
if New(beginningOfDay).BeginningOfDay().Format(format) != "2015-05-01 00:00:00" {
|
||||||
|
t.Errorf("BeginningOfDay")
|
||||||
|
}
|
||||||
|
|
||||||
|
if New(n).BeginningOfWeek().Format(format) != "2013-11-17 00:00:00" {
|
||||||
|
t.Errorf("BeginningOfWeek")
|
||||||
|
}
|
||||||
|
|
||||||
|
FirstDayMonday = true
|
||||||
|
if New(n).BeginningOfWeek().Format(format) != "2013-11-18 00:00:00" {
|
||||||
|
t.Errorf("BeginningOfWeek, FirstDayMonday")
|
||||||
|
}
|
||||||
|
FirstDayMonday = false
|
||||||
|
|
||||||
|
if New(n).BeginningOfMonth().Format(format) != "2013-11-01 00:00:00" {
|
||||||
|
t.Errorf("BeginningOfMonth")
|
||||||
|
}
|
||||||
|
|
||||||
|
if New(n).BeginningOfQuarter().Format(format) != "2013-10-01 00:00:00" {
|
||||||
|
t.Error("BeginningOfQuarter")
|
||||||
|
}
|
||||||
|
|
||||||
|
if New(n.AddDate(0, -1, 0)).BeginningOfQuarter().Format(format) != "2013-10-01 00:00:00" {
|
||||||
|
t.Error("BeginningOfQuarter")
|
||||||
|
}
|
||||||
|
|
||||||
|
if New(n.AddDate(0, 1, 0)).BeginningOfQuarter().Format(format) != "2013-10-01 00:00:00" {
|
||||||
|
t.Error("BeginningOfQuarter")
|
||||||
|
}
|
||||||
|
|
||||||
|
if New(n).BeginningOfYear().Format(format) != "2013-01-01 00:00:00" {
|
||||||
|
t.Errorf("BeginningOfYear")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEndOf(t *testing.T) {
|
||||||
|
n := time.Date(2013, 11, 18, 17, 51, 49, 123456789, time.UTC)
|
||||||
|
|
||||||
|
if New(n).EndOfMinute().Format(format) != "2013-11-18 17:51:59.999999999" {
|
||||||
|
t.Errorf("EndOfMinute")
|
||||||
|
}
|
||||||
|
|
||||||
|
if New(n).EndOfHour().Format(format) != "2013-11-18 17:59:59.999999999" {
|
||||||
|
t.Errorf("EndOfHour")
|
||||||
|
}
|
||||||
|
|
||||||
|
if New(n).EndOfDay().Format(format) != "2013-11-18 23:59:59.999999999" {
|
||||||
|
t.Errorf("EndOfDay")
|
||||||
|
}
|
||||||
|
|
||||||
|
FirstDayMonday = true
|
||||||
|
if New(n).EndOfWeek().Format(format) != "2013-11-24 23:59:59.999999999" {
|
||||||
|
t.Errorf("EndOfWeek, FirstDayMonday")
|
||||||
|
}
|
||||||
|
|
||||||
|
FirstDayMonday = false
|
||||||
|
if New(n).EndOfWeek().Format(format) != "2013-11-23 23:59:59.999999999" {
|
||||||
|
t.Errorf("EndOfWeek")
|
||||||
|
}
|
||||||
|
|
||||||
|
if New(n).EndOfMonth().Format(format) != "2013-11-30 23:59:59.999999999" {
|
||||||
|
t.Errorf("EndOfMonth")
|
||||||
|
}
|
||||||
|
|
||||||
|
if New(n).EndOfQuarter().Format(format) != "2013-12-31 23:59:59.999999999" {
|
||||||
|
t.Errorf("EndOfQuarter")
|
||||||
|
}
|
||||||
|
|
||||||
|
if New(n.AddDate(0, -1, 0)).EndOfQuarter().Format(format) != "2013-12-31 23:59:59.999999999" {
|
||||||
|
t.Errorf("EndOfQuarter")
|
||||||
|
}
|
||||||
|
|
||||||
|
if New(n.AddDate(0, 1, 0)).EndOfQuarter().Format(format) != "2013-12-31 23:59:59.999999999" {
|
||||||
|
t.Errorf("EndOfQuarter")
|
||||||
|
}
|
||||||
|
|
||||||
|
if New(n).EndOfYear().Format(format) != "2013-12-31 23:59:59.999999999" {
|
||||||
|
t.Errorf("EndOfYear")
|
||||||
|
}
|
||||||
|
|
||||||
|
n1 := time.Date(2013, 02, 18, 17, 51, 49, 123456789, time.UTC)
|
||||||
|
if New(n1).EndOfMonth().Format(format) != "2013-02-28 23:59:59.999999999" {
|
||||||
|
t.Errorf("EndOfMonth for 2013/02")
|
||||||
|
}
|
||||||
|
|
||||||
|
n2 := time.Date(1900, 02, 18, 17, 51, 49, 123456789, time.UTC)
|
||||||
|
if New(n2).EndOfMonth().Format(format) != "1900-02-28 23:59:59.999999999" {
|
||||||
|
t.Errorf("EndOfMonth")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMondayAndSunday(t *testing.T) {
|
||||||
|
n := time.Date(2013, 11, 19, 17, 51, 49, 123456789, time.UTC)
|
||||||
|
n2 := time.Date(2013, 11, 24, 17, 51, 49, 123456789, time.UTC)
|
||||||
|
|
||||||
|
if New(n).Monday().Format(format) != "2013-11-18 00:00:00" {
|
||||||
|
t.Errorf("Monday")
|
||||||
|
}
|
||||||
|
|
||||||
|
if New(n2).Monday().Format(format) != "2013-11-18 00:00:00" {
|
||||||
|
t.Errorf("Monday")
|
||||||
|
}
|
||||||
|
|
||||||
|
if New(n).Sunday().Format(format) != "2013-11-24 00:00:00" {
|
||||||
|
t.Errorf("Sunday")
|
||||||
|
}
|
||||||
|
|
||||||
|
if New(n2).Sunday().Format(format) != "2013-11-24 00:00:00" {
|
||||||
|
t.Errorf("Sunday")
|
||||||
|
}
|
||||||
|
|
||||||
|
if New(n).EndOfSunday().Format(format) != "2013-11-24 23:59:59.999999999" {
|
||||||
|
t.Errorf("Sunday")
|
||||||
|
}
|
||||||
|
|
||||||
|
if New(n).BeginningOfWeek().Format(format) != "2013-11-17 00:00:00" {
|
||||||
|
t.Errorf("BeginningOfWeek, FirstDayMonday")
|
||||||
|
}
|
||||||
|
|
||||||
|
FirstDayMonday = true
|
||||||
|
if New(n).BeginningOfWeek().Format(format) != "2013-11-18 00:00:00" {
|
||||||
|
t.Errorf("BeginningOfWeek, FirstDayMonday")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParse(t *testing.T) {
|
||||||
|
n := time.Date(2013, 11, 18, 17, 51, 49, 123456789, time.UTC)
|
||||||
|
if New(n).MustParse("10-12").Format(format) != "2013-10-12 00:00:00" {
|
||||||
|
t.Errorf("Parse 10-12")
|
||||||
|
}
|
||||||
|
|
||||||
|
if New(n).MustParse("2013-12-19 23:28:09.999999999 +0800 CST").Format(format) != "2013-12-19 23:28:09" {
|
||||||
|
t.Errorf("Parse two strings 2013-12-19 23:28:09.999999999 +0800 CST")
|
||||||
|
}
|
||||||
|
|
||||||
|
if New(n).MustParse("2002-10-12 22:14").Format(format) != "2002-10-12 22:14:00" {
|
||||||
|
t.Errorf("Parse 2002-10-12 22:14")
|
||||||
|
}
|
||||||
|
|
||||||
|
if New(n).MustParse("2002-10-12 2:4").Format(format) != "2002-10-12 02:04:00" {
|
||||||
|
t.Errorf("Parse 2002-10-12 2:4")
|
||||||
|
}
|
||||||
|
|
||||||
|
if New(n).MustParse("2002-10-12 02:04").Format(format) != "2002-10-12 02:04:00" {
|
||||||
|
t.Errorf("Parse 2002-10-12 02:04")
|
||||||
|
}
|
||||||
|
|
||||||
|
if New(n).MustParse("2002-10-12 22:14:56").Format(format) != "2002-10-12 22:14:56" {
|
||||||
|
t.Errorf("Parse 2002-10-12 22:14:56")
|
||||||
|
}
|
||||||
|
|
||||||
|
if New(n).MustParse("2002-10-12").Format(format) != "2002-10-12 00:00:00" {
|
||||||
|
t.Errorf("Parse 2002-10-12")
|
||||||
|
}
|
||||||
|
|
||||||
|
if New(n).MustParse("18").Format(format) != "2013-11-18 18:00:00" {
|
||||||
|
t.Errorf("Parse 18 as hour")
|
||||||
|
}
|
||||||
|
|
||||||
|
if New(n).MustParse("18:20").Format(format) != "2013-11-18 18:20:00" {
|
||||||
|
t.Errorf("Parse 18:20")
|
||||||
|
}
|
||||||
|
|
||||||
|
if New(n).MustParse("00:01").Format(format) != "2013-11-18 00:01:00" {
|
||||||
|
t.Errorf("Parse 00:01")
|
||||||
|
}
|
||||||
|
|
||||||
|
if New(n).MustParse("18:20:39").Format(format) != "2013-11-18 18:20:39" {
|
||||||
|
t.Errorf("Parse 18:20:39")
|
||||||
|
}
|
||||||
|
|
||||||
|
if New(n).MustParse("18:20:39", "2011-01-01").Format(format) != "2011-01-01 18:20:39" {
|
||||||
|
t.Errorf("Parse two strings 18:20:39, 2011-01-01")
|
||||||
|
}
|
||||||
|
|
||||||
|
if New(n).MustParse("2011-1-1", "18:20:39").Format(format) != "2011-01-01 18:20:39" {
|
||||||
|
t.Errorf("Parse two strings 2011-01-01, 18:20:39")
|
||||||
|
}
|
||||||
|
|
||||||
|
if New(n).MustParse("2011-01-01", "18").Format(format) != "2011-01-01 18:00:00" {
|
||||||
|
t.Errorf("Parse two strings 2011-01-01, 18")
|
||||||
|
}
|
||||||
|
|
||||||
|
TimeFormats = append(TimeFormats, "02 Jan 15:04")
|
||||||
|
if New(n).MustParse("04 Feb 12:09").Format(format) != "2013-02-04 12:09:00" {
|
||||||
|
t.Errorf("Parse 04 Feb 12:09 with specified format")
|
||||||
|
}
|
||||||
|
|
||||||
|
if New(n).MustParse("23:28:9 Dec 19, 2013 PST").Format(format) != "2013-12-19 23:28:09" {
|
||||||
|
t.Errorf("Parse 23:28:9 Dec 19, 2013 PST")
|
||||||
|
}
|
||||||
|
|
||||||
|
if New(n).MustParse("23:28:9 Dec 19, 2013 PST").Location().String() != "PST" {
|
||||||
|
t.Errorf("Parse 23:28:9 Dec 19, 2013 PST shouldn't lose time zone")
|
||||||
|
}
|
||||||
|
|
||||||
|
n2 := New(n).MustParse("23:28:9 Dec 19, 2013 PST")
|
||||||
|
if New(n2).MustParse("10:20").Location().String() != "PST" {
|
||||||
|
t.Errorf("Parse 10:20 shouldn't change time zone")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBetween(t *testing.T) {
|
||||||
|
tm := time.Date(2015, 06, 30, 17, 51, 49, 123456789, time.Now().Location())
|
||||||
|
if !New(tm).Between("23:28:9 Dec 19, 2013 PST", "23:28:9 Dec 19, 2015 PST") {
|
||||||
|
t.Errorf("Between")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !New(tm).Between("2015-05-12 12:20", "2015-06-30 17:51:50") {
|
||||||
|
t.Errorf("Between")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Example() {
|
||||||
|
time.Now() // 2013-11-18 17:51:49.123456789 Mon
|
||||||
|
|
||||||
|
BeginningOfMinute() // 2013-11-18 17:51:00 Mon
|
||||||
|
BeginningOfHour() // 2013-11-18 17:00:00 Mon
|
||||||
|
BeginningOfDay() // 2013-11-18 00:00:00 Mon
|
||||||
|
BeginningOfWeek() // 2013-11-17 00:00:00 Sun
|
||||||
|
|
||||||
|
FirstDayMonday = true // Set Monday as first day
|
||||||
|
BeginningOfWeek() // 2013-11-18 00:00:00 Mon
|
||||||
|
BeginningOfMonth() // 2013-11-01 00:00:00 Fri
|
||||||
|
BeginningOfQuarter() // 2013-10-01 00:00:00 Tue
|
||||||
|
BeginningOfYear() // 2013-01-01 00:00:00 Tue
|
||||||
|
|
||||||
|
EndOfMinute() // 2013-11-18 17:51:59.999999999 Mon
|
||||||
|
EndOfHour() // 2013-11-18 17:59:59.999999999 Mon
|
||||||
|
EndOfDay() // 2013-11-18 23:59:59.999999999 Mon
|
||||||
|
EndOfWeek() // 2013-11-23 23:59:59.999999999 Sat
|
||||||
|
|
||||||
|
FirstDayMonday = true // Set Monday as first day
|
||||||
|
EndOfWeek() // 2013-11-24 23:59:59.999999999 Sun
|
||||||
|
EndOfMonth() // 2013-11-30 23:59:59.999999999 Sat
|
||||||
|
EndOfQuarter() // 2013-12-31 23:59:59.999999999 Tue
|
||||||
|
EndOfYear() // 2013-12-31 23:59:59.999999999 Tue
|
||||||
|
|
||||||
|
// Use another time
|
||||||
|
t := time.Date(2013, 02, 18, 17, 51, 49, 123456789, time.UTC)
|
||||||
|
New(t).EndOfMonth() // 2013-02-28 23:59:59.999999999 Thu
|
||||||
|
|
||||||
|
Monday() // 2013-11-18 00:00:00 Mon
|
||||||
|
Sunday() // 2013-11-24 00:00:00 Sun
|
||||||
|
EndOfSunday() // 2013-11-24 23:59:59.999999999 Sun
|
||||||
|
}
|
11
orm/.codeclimate.yml
Normal file
11
orm/.codeclimate.yml
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
---
|
||||||
|
engines:
|
||||||
|
gofmt:
|
||||||
|
enabled: true
|
||||||
|
govet:
|
||||||
|
enabled: true
|
||||||
|
golint:
|
||||||
|
enabled: true
|
||||||
|
ratings:
|
||||||
|
paths:
|
||||||
|
- "**.go"
|
2
orm/.gitignore
vendored
Normal file
2
orm/.gitignore
vendored
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
documents
|
||||||
|
_book
|
52
orm/CONTRIBUTING.md
Normal file
52
orm/CONTRIBUTING.md
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
# How to Contribute
|
||||||
|
|
||||||
|
## Bug Report
|
||||||
|
|
||||||
|
- Do a search on GitHub under Issues in case it has already been reported
|
||||||
|
- Submit __executable script__ or failing test pull request that could demonstrates the issue is *MUST HAVE*
|
||||||
|
|
||||||
|
## Feature Request
|
||||||
|
|
||||||
|
- Feature request with pull request is welcome
|
||||||
|
- Or it won't be implemented until I (other developers) find it is helpful for my (their) daily work
|
||||||
|
|
||||||
|
## Pull Request
|
||||||
|
|
||||||
|
- Prefer single commit pull request, that make the git history can be a bit easier to follow.
|
||||||
|
- New features need to be covered with tests to make sure your code works as expected, and won't be broken by others in future
|
||||||
|
|
||||||
|
## Contributing to Documentation
|
||||||
|
|
||||||
|
- You are welcome ;)
|
||||||
|
- You can help improve the README by making them more coherent, consistent or readable, and add more godoc documents to make people easier to follow.
|
||||||
|
- Blogs & Usage Guides & PPT also welcome, please add them to https://github.com/jinzhu/gorm/wiki/Guides
|
||||||
|
|
||||||
|
### Executable script template
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
_ "github.com/go-sql-driver/mysql"
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
_ "github.com/lib/pq"
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
)
|
||||||
|
|
||||||
|
var db *gorm.DB
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
var err error
|
||||||
|
db, err = gorm.Open("sqlite3", "test.db")
|
||||||
|
// db, err = gorm.Open("postgres", "user=username dbname=password sslmode=disable")
|
||||||
|
// db, err = gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
db.LogMode(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Your code
|
||||||
|
}
|
||||||
|
```
|
21
orm/License
Normal file
21
orm/License
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
The MIT License (MIT)
|
||||||
|
|
||||||
|
Copyright (c) 2013-NOW Jinzhu <wosmvp@gmail.com>
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in
|
||||||
|
all copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||||
|
THE SOFTWARE.
|
46
orm/README.md
Normal file
46
orm/README.md
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
# GORM
|
||||||
|
|
||||||
|
The fantastic ORM library for Golang, aims to be developer friendly.
|
||||||
|
|
||||||
|
[![Join the chat at https://gitter.im/jinzhu/gorm](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
|
||||||
|
[![wercker status](https://app.wercker.com/status/0cb7bb1039e21b74f8274941428e0921/s/master "wercker status")](https://app.wercker.com/project/bykey/0cb7bb1039e21b74f8274941428e0921)
|
||||||
|
[![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm)
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
* Full-Featured ORM (almost)
|
||||||
|
* Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism)
|
||||||
|
* Callbacks (Before/After Create/Save/Update/Delete/Find)
|
||||||
|
* Preloading (eager loading)
|
||||||
|
* Transactions
|
||||||
|
* Composite Primary Key
|
||||||
|
* SQL Builder
|
||||||
|
* Auto Migrations
|
||||||
|
* Logger
|
||||||
|
* Extendable, write Plugins based on GORM callbacks
|
||||||
|
* Every feature comes with tests
|
||||||
|
* Developer Friendly
|
||||||
|
|
||||||
|
## Getting Started
|
||||||
|
|
||||||
|
* GORM Guides [jinzhu.github.com/gorm](http://jinzhu.github.io/gorm)
|
||||||
|
|
||||||
|
## Upgrading To V1.0
|
||||||
|
|
||||||
|
* [CHANGELOG](http://jinzhu.github.io/gorm/changelog.html)
|
||||||
|
|
||||||
|
# Author
|
||||||
|
|
||||||
|
**jinzhu**
|
||||||
|
|
||||||
|
* <http://github.com/jinzhu>
|
||||||
|
* <wosmvp@gmail.com>
|
||||||
|
* <http://twitter.com/zhangjinzhu>
|
||||||
|
|
||||||
|
# Contributors
|
||||||
|
|
||||||
|
https://github.com/jinzhu/gorm/graphs/contributors
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
Released under the [MIT License](https://github.com/jinzhu/gorm/blob/master/License).
|
374
orm/association.go
Normal file
374
orm/association.go
Normal file
|
@ -0,0 +1,374 @@
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"github.com/revel/modules/db/app"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Association Mode contains some helper methods to handle relationship things easily.
|
||||||
|
type Association struct {
|
||||||
|
Error error
|
||||||
|
scope *Scope
|
||||||
|
column string
|
||||||
|
field *Field
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find find out all related associations
|
||||||
|
func (association *Association) Find(value interface{}) *Association {
|
||||||
|
association.scope.related(value, association.column)
|
||||||
|
return association.setErr(association.scope.db.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append append new associations for many2many, has_many, replace current association for has_one, belongs_to
|
||||||
|
func (association *Association) Append(values ...interface{}) *Association {
|
||||||
|
if association.Error != nil {
|
||||||
|
return association
|
||||||
|
}
|
||||||
|
|
||||||
|
if relationship := association.field.Relationship; relationship.Kind == "has_one" {
|
||||||
|
return association.Replace(values...)
|
||||||
|
}
|
||||||
|
return association.saveAssociations(values...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace replace current associations with new one
|
||||||
|
func (association *Association) Replace(values ...interface{}) *Association {
|
||||||
|
if association.Error != nil {
|
||||||
|
return association
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
relationship = association.field.Relationship
|
||||||
|
scope = association.scope
|
||||||
|
field = association.field.Field
|
||||||
|
newDB = scope.NewDB()
|
||||||
|
)
|
||||||
|
|
||||||
|
// Append new values
|
||||||
|
association.field.Set(reflect.Zero(association.field.Field.Type()))
|
||||||
|
association.saveAssociations(values...)
|
||||||
|
|
||||||
|
// Belongs To
|
||||||
|
if relationship.Kind == "belongs_to" {
|
||||||
|
// Set foreign key to be null when clearing value (length equals 0)
|
||||||
|
if len(values) == 0 {
|
||||||
|
// Set foreign key to be nil
|
||||||
|
var foreignKeyMap = map[string]interface{}{}
|
||||||
|
for _, foreignKey := range relationship.ForeignDBNames {
|
||||||
|
foreignKeyMap[foreignKey] = nil
|
||||||
|
}
|
||||||
|
association.setErr(newDB.Model(scope.Value).UpdateColumn(foreignKeyMap).Error)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Polymorphic Relations
|
||||||
|
if relationship.PolymorphicDBName != "" {
|
||||||
|
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete Relations except new created
|
||||||
|
if len(values) > 0 {
|
||||||
|
var associationForeignFieldNames, associationForeignDBNames []string
|
||||||
|
if relationship.Kind == "many_to_many" {
|
||||||
|
// if many to many relations, get association fields name from association foreign keys
|
||||||
|
associationScope := scope.New(reflect.New(field.Type()).Interface())
|
||||||
|
for idx, dbName := range relationship.AssociationForeignFieldNames {
|
||||||
|
if field, ok := associationScope.FieldByName(dbName); ok {
|
||||||
|
associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
|
||||||
|
associationForeignDBNames = append(associationForeignDBNames, relationship.AssociationForeignDBNames[idx])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// If has one/many relations, use primary keys
|
||||||
|
for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() {
|
||||||
|
associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
|
||||||
|
associationForeignDBNames = append(associationForeignDBNames, field.DBName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
newPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, field.Interface())
|
||||||
|
|
||||||
|
if len(newPrimaryKeys) > 0 {
|
||||||
|
sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, associationForeignDBNames), toQueryMarks(newPrimaryKeys))
|
||||||
|
newDB = newDB.Where(sql, toQueryValues(newPrimaryKeys)...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if relationship.Kind == "many_to_many" {
|
||||||
|
// if many to many relations, delete related relations from join table
|
||||||
|
var sourceForeignFieldNames []string
|
||||||
|
|
||||||
|
for _, dbName := range relationship.ForeignFieldNames {
|
||||||
|
if field, ok := scope.FieldByName(dbName); ok {
|
||||||
|
sourceForeignFieldNames = append(sourceForeignFieldNames, field.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 {
|
||||||
|
newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...)
|
||||||
|
|
||||||
|
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
|
||||||
|
}
|
||||||
|
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
|
||||||
|
// has_one or has_many relations, set foreign key to be nil (TODO or delete them?)
|
||||||
|
var foreignKeyMap = map[string]interface{}{}
|
||||||
|
for idx, foreignKey := range relationship.ForeignDBNames {
|
||||||
|
foreignKeyMap[foreignKey] = nil
|
||||||
|
if field, ok := scope.FieldByName(relationship.AssociationForeignFieldNames[idx]); ok {
|
||||||
|
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fieldValue := reflect.New(association.field.Field.Type()).Interface()
|
||||||
|
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return association
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete remove relationship between source & passed arguments, but won't delete those arguments
|
||||||
|
func (association *Association) Delete(values ...interface{}) *Association {
|
||||||
|
if association.Error != nil {
|
||||||
|
return association
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
relationship = association.field.Relationship
|
||||||
|
scope = association.scope
|
||||||
|
field = association.field.Field
|
||||||
|
newDB = scope.NewDB()
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(values) == 0 {
|
||||||
|
return association
|
||||||
|
}
|
||||||
|
|
||||||
|
var deletingResourcePrimaryFieldNames, deletingResourcePrimaryDBNames []string
|
||||||
|
for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() {
|
||||||
|
deletingResourcePrimaryFieldNames = append(deletingResourcePrimaryFieldNames, field.Name)
|
||||||
|
deletingResourcePrimaryDBNames = append(deletingResourcePrimaryDBNames, field.DBName)
|
||||||
|
}
|
||||||
|
|
||||||
|
deletingPrimaryKeys := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, values...)
|
||||||
|
|
||||||
|
if relationship.Kind == "many_to_many" {
|
||||||
|
// source value's foreign keys
|
||||||
|
for idx, foreignKey := range relationship.ForeignDBNames {
|
||||||
|
if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok {
|
||||||
|
newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// get association's foreign fields name
|
||||||
|
var associationScope = scope.New(reflect.New(field.Type()).Interface())
|
||||||
|
var associationForeignFieldNames []string
|
||||||
|
for _, associationDBName := range relationship.AssociationForeignFieldNames {
|
||||||
|
if field, ok := associationScope.FieldByName(associationDBName); ok {
|
||||||
|
associationForeignFieldNames = append(associationForeignFieldNames, field.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// association value's foreign keys
|
||||||
|
deletingPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, values...)
|
||||||
|
sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys))
|
||||||
|
newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...)
|
||||||
|
|
||||||
|
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship))
|
||||||
|
} else {
|
||||||
|
var foreignKeyMap = map[string]interface{}{}
|
||||||
|
for _, foreignKey := range relationship.ForeignDBNames {
|
||||||
|
foreignKeyMap[foreignKey] = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if relationship.Kind == "belongs_to" {
|
||||||
|
// find with deleting relation's foreign keys
|
||||||
|
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, values...)
|
||||||
|
newDB = newDB.Where(
|
||||||
|
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
|
||||||
|
toQueryValues(primaryKeys)...,
|
||||||
|
)
|
||||||
|
|
||||||
|
// set foreign key to be null if there are some records affected
|
||||||
|
modelValue := reflect.New(scope.GetModelStruct().ModelType).Interface()
|
||||||
|
if results := newDB.Model(modelValue).UpdateColumn(foreignKeyMap); results.Error == nil {
|
||||||
|
if results.RowsAffected > 0 {
|
||||||
|
scope.updatedAttrsWithValues(foreignKeyMap)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
association.setErr(results.Error)
|
||||||
|
}
|
||||||
|
} else if relationship.Kind == "has_one" || relationship.Kind == "has_many" {
|
||||||
|
// find all relations
|
||||||
|
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
|
||||||
|
newDB = newDB.Where(
|
||||||
|
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
|
||||||
|
toQueryValues(primaryKeys)...,
|
||||||
|
)
|
||||||
|
|
||||||
|
// only include those deleting relations
|
||||||
|
newDB = newDB.Where(
|
||||||
|
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, deletingResourcePrimaryDBNames), toQueryMarks(deletingPrimaryKeys)),
|
||||||
|
toQueryValues(deletingPrimaryKeys)...,
|
||||||
|
)
|
||||||
|
|
||||||
|
// set matched relation's foreign key to be null
|
||||||
|
fieldValue := reflect.New(association.field.Field.Type()).Interface()
|
||||||
|
association.setErr(newDB.Model(fieldValue).UpdateColumn(foreignKeyMap).Error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove deleted records from source's field
|
||||||
|
if association.Error == nil {
|
||||||
|
if field.Kind() == reflect.Slice {
|
||||||
|
leftValues := reflect.Zero(field.Type())
|
||||||
|
|
||||||
|
for i := 0; i < field.Len(); i++ {
|
||||||
|
reflectValue := field.Index(i)
|
||||||
|
primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, reflectValue.Interface())[0]
|
||||||
|
var isDeleted = false
|
||||||
|
for _, pk := range deletingPrimaryKeys {
|
||||||
|
if equalAsString(primaryKey, pk) {
|
||||||
|
isDeleted = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !isDeleted {
|
||||||
|
leftValues = reflect.Append(leftValues, reflectValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
association.field.Set(leftValues)
|
||||||
|
} else if field.Kind() == reflect.Struct {
|
||||||
|
primaryKey := scope.getColumnAsArray(deletingResourcePrimaryFieldNames, field.Interface())[0]
|
||||||
|
for _, pk := range deletingPrimaryKeys {
|
||||||
|
if equalAsString(primaryKey, pk) {
|
||||||
|
association.field.Set(reflect.Zero(field.Type()))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return association
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear remove relationship between source & current associations, won't delete those associations
|
||||||
|
func (association *Association) Clear() *Association {
|
||||||
|
return association.Replace()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count return the count of current associations
|
||||||
|
func (association *Association) Count() int {
|
||||||
|
var (
|
||||||
|
count = 0
|
||||||
|
relationship = association.field.Relationship
|
||||||
|
scope = association.scope
|
||||||
|
fieldValue = association.field.Field.Interface()
|
||||||
|
query = scope.DB()
|
||||||
|
)
|
||||||
|
|
||||||
|
if relationship.Kind == "many_to_many" {
|
||||||
|
query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value)
|
||||||
|
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
|
||||||
|
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
|
||||||
|
query = query.Where(
|
||||||
|
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
|
||||||
|
toQueryValues(primaryKeys)...,
|
||||||
|
)
|
||||||
|
} else if relationship.Kind == "belongs_to" {
|
||||||
|
primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value)
|
||||||
|
query = query.Where(
|
||||||
|
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)),
|
||||||
|
toQueryValues(primaryKeys)...,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if relationship.PolymorphicType != "" {
|
||||||
|
query = query.Where(
|
||||||
|
fmt.Sprintf("%v.%v = ?", scope.New(fieldValue).QuotedTableName(), scope.Quote(relationship.PolymorphicDBName)),
|
||||||
|
relationship.PolymorphicValue,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
query.Model(fieldValue).Count(&count)
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
|
||||||
|
// saveAssociations save passed values as associations
|
||||||
|
func (association *Association) saveAssociations(values ...interface{}) *Association {
|
||||||
|
var (
|
||||||
|
scope = association.scope
|
||||||
|
field = association.field
|
||||||
|
relationship = field.Relationship
|
||||||
|
)
|
||||||
|
|
||||||
|
saveAssociation := func(reflectValue reflect.Value) {
|
||||||
|
// value has to been pointer
|
||||||
|
if reflectValue.Kind() != reflect.Ptr {
|
||||||
|
reflectPtr := reflect.New(reflectValue.Type())
|
||||||
|
reflectPtr.Elem().Set(reflectValue)
|
||||||
|
reflectValue = reflectPtr
|
||||||
|
}
|
||||||
|
|
||||||
|
// value has to been saved for many2many
|
||||||
|
if relationship.Kind == "many_to_many" {
|
||||||
|
if scope.New(reflectValue.Interface()).PrimaryKeyZero() {
|
||||||
|
association.setErr(scope.NewDB().Save(reflectValue.Interface()).Error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assign Fields
|
||||||
|
var fieldType = field.Field.Type()
|
||||||
|
var setFieldBackToValue, setSliceFieldBackToValue bool
|
||||||
|
if reflectValue.Type().AssignableTo(fieldType) {
|
||||||
|
field.Set(reflectValue)
|
||||||
|
} else if reflectValue.Type().Elem().AssignableTo(fieldType) {
|
||||||
|
// if field's type is struct, then need to set value back to argument after save
|
||||||
|
setFieldBackToValue = true
|
||||||
|
field.Set(reflectValue.Elem())
|
||||||
|
} else if fieldType.Kind() == reflect.Slice {
|
||||||
|
if reflectValue.Type().AssignableTo(fieldType.Elem()) {
|
||||||
|
field.Set(reflect.Append(field.Field, reflectValue))
|
||||||
|
} else if reflectValue.Type().Elem().AssignableTo(fieldType.Elem()) {
|
||||||
|
// if field's type is slice of struct, then need to set value back to argument after save
|
||||||
|
setSliceFieldBackToValue = true
|
||||||
|
field.Set(reflect.Append(field.Field, reflectValue.Elem()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if relationship.Kind == "many_to_many" {
|
||||||
|
association.setErr(relationship.JoinTableHandler.Add(relationship.JoinTableHandler, scope.NewDB(), scope.Value, reflectValue.Interface()))
|
||||||
|
} else {
|
||||||
|
association.setErr(scope.NewDB().Select(field.Name).Save(scope.Value).Error)
|
||||||
|
|
||||||
|
if setFieldBackToValue {
|
||||||
|
reflectValue.Elem().Set(field.Field)
|
||||||
|
} else if setSliceFieldBackToValue {
|
||||||
|
reflectValue.Elem().Set(field.Field.Index(field.Field.Len() - 1))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, value := range values {
|
||||||
|
reflectValue := reflect.ValueOf(value)
|
||||||
|
indirectReflectValue := reflect.Indirect(reflectValue)
|
||||||
|
if indirectReflectValue.Kind() == reflect.Struct {
|
||||||
|
saveAssociation(reflectValue)
|
||||||
|
} else if indirectReflectValue.Kind() == reflect.Slice {
|
||||||
|
for i := 0; i < indirectReflectValue.Len(); i++ {
|
||||||
|
saveAssociation(indirectReflectValue.Index(i))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
association.setErr(errors.New("invalid value type"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return association
|
||||||
|
}
|
||||||
|
|
||||||
|
func (association *Association) setErr(err error) *Association {
|
||||||
|
if err != nil {
|
||||||
|
association.Error = err
|
||||||
|
}
|
||||||
|
return association
|
||||||
|
}
|
907
orm/association_test.go
Normal file
907
orm/association_test.go
Normal file
|
@ -0,0 +1,907 @@
|
||||||
|
package orm_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"reflect"
|
||||||
|
"sort"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBelongsTo(t *testing.T) {
|
||||||
|
post := Post{
|
||||||
|
Title: "post belongs to",
|
||||||
|
Body: "body belongs to",
|
||||||
|
Category: Category{Name: "Category 1"},
|
||||||
|
MainCategory: Category{Name: "Main Category 1"},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Save(&post).Error; err != nil {
|
||||||
|
t.Error("Got errors when save post", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if post.Category.ID == 0 || post.MainCategory.ID == 0 {
|
||||||
|
t.Errorf("Category's primary key should be updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
if post.CategoryId.Int64 == 0 || post.MainCategoryId == 0 {
|
||||||
|
t.Errorf("post's foreign key should be updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query
|
||||||
|
var category1 Category
|
||||||
|
DB.Model(&post).Association("Category").Find(&category1)
|
||||||
|
if category1.Name != "Category 1" {
|
||||||
|
t.Errorf("Query belongs to relations with Association")
|
||||||
|
}
|
||||||
|
|
||||||
|
var mainCategory1 Category
|
||||||
|
DB.Model(&post).Association("MainCategory").Find(&mainCategory1)
|
||||||
|
if mainCategory1.Name != "Main Category 1" {
|
||||||
|
t.Errorf("Query belongs to relations with Association")
|
||||||
|
}
|
||||||
|
|
||||||
|
var category11 Category
|
||||||
|
DB.Model(&post).Related(&category11)
|
||||||
|
if category11.Name != "Category 1" {
|
||||||
|
t.Errorf("Query belongs to relations with Related")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&post).Association("Category").Count() != 1 {
|
||||||
|
t.Errorf("Post's category count should be 1")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&post).Association("MainCategory").Count() != 1 {
|
||||||
|
t.Errorf("Post's main category count should be 1")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append
|
||||||
|
var category2 = Category{
|
||||||
|
Name: "Category 2",
|
||||||
|
}
|
||||||
|
DB.Model(&post).Association("Category").Append(&category2)
|
||||||
|
|
||||||
|
if category2.ID == 0 {
|
||||||
|
t.Errorf("Category should has ID when created with Append")
|
||||||
|
}
|
||||||
|
|
||||||
|
var category21 Category
|
||||||
|
DB.Model(&post).Related(&category21)
|
||||||
|
|
||||||
|
if category21.Name != "Category 2" {
|
||||||
|
t.Errorf("Category should be updated with Append")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&post).Association("Category").Count() != 1 {
|
||||||
|
t.Errorf("Post's category count should be 1")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace
|
||||||
|
var category3 = Category{
|
||||||
|
Name: "Category 3",
|
||||||
|
}
|
||||||
|
DB.Model(&post).Association("Category").Replace(&category3)
|
||||||
|
|
||||||
|
if category3.ID == 0 {
|
||||||
|
t.Errorf("Category should has ID when created with Replace")
|
||||||
|
}
|
||||||
|
|
||||||
|
var category31 Category
|
||||||
|
DB.Model(&post).Related(&category31)
|
||||||
|
if category31.Name != "Category 3" {
|
||||||
|
t.Errorf("Category should be updated with Replace")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&post).Association("Category").Count() != 1 {
|
||||||
|
t.Errorf("Post's category count should be 1")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete
|
||||||
|
DB.Model(&post).Association("Category").Delete(&category2)
|
||||||
|
if DB.Model(&post).Related(&Category{}).RecordNotFound() {
|
||||||
|
t.Errorf("Should not delete any category when Delete a unrelated Category")
|
||||||
|
}
|
||||||
|
|
||||||
|
if post.Category.Name == "" {
|
||||||
|
t.Errorf("Post's category should not be reseted when Delete a unrelated Category")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&post).Association("Category").Delete(&category3)
|
||||||
|
|
||||||
|
if post.Category.Name != "" {
|
||||||
|
t.Errorf("Post's category should be reseted after Delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
var category41 Category
|
||||||
|
DB.Model(&post).Related(&category41)
|
||||||
|
if category41.Name != "" {
|
||||||
|
t.Errorf("Category should be deleted with Delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
if count := DB.Model(&post).Association("Category").Count(); count != 0 {
|
||||||
|
t.Errorf("Post's category count should be 0 after Delete, but got %v", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear
|
||||||
|
DB.Model(&post).Association("Category").Append(&Category{
|
||||||
|
Name: "Category 2",
|
||||||
|
})
|
||||||
|
|
||||||
|
if DB.Model(&post).Related(&Category{}).RecordNotFound() {
|
||||||
|
t.Errorf("Should find category after append")
|
||||||
|
}
|
||||||
|
|
||||||
|
if post.Category.Name == "" {
|
||||||
|
t.Errorf("Post's category should has value after Append")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&post).Association("Category").Clear()
|
||||||
|
|
||||||
|
if post.Category.Name != "" {
|
||||||
|
t.Errorf("Post's category should be cleared after Clear")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !DB.Model(&post).Related(&Category{}).RecordNotFound() {
|
||||||
|
t.Errorf("Should not find any category after Clear")
|
||||||
|
}
|
||||||
|
|
||||||
|
if count := DB.Model(&post).Association("Category").Count(); count != 0 {
|
||||||
|
t.Errorf("Post's category count should be 0 after Clear, but got %v", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check Association mode with soft delete
|
||||||
|
category6 := Category{
|
||||||
|
Name: "Category 6",
|
||||||
|
}
|
||||||
|
DB.Model(&post).Association("Category").Append(&category6)
|
||||||
|
|
||||||
|
if count := DB.Model(&post).Association("Category").Count(); count != 1 {
|
||||||
|
t.Errorf("Post's category count should be 1 after Append, but got %v", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Delete(&category6)
|
||||||
|
|
||||||
|
if count := DB.Model(&post).Association("Category").Count(); count != 0 {
|
||||||
|
t.Errorf("Post's category count should be 0 after the category has been deleted, but got %v", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Model(&post).Association("Category").Find(&Category{}).Error; err == nil {
|
||||||
|
t.Errorf("Post's category is not findable after Delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
if count := DB.Unscoped().Model(&post).Association("Category").Count(); count != 1 {
|
||||||
|
t.Errorf("Post's category count should be 1 when query with Unscoped, but got %v", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Unscoped().Model(&post).Association("Category").Find(&Category{}).Error; err != nil {
|
||||||
|
t.Errorf("Post's category should be findable when query with Unscoped, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBelongsToOverrideForeignKey1(t *testing.T) {
|
||||||
|
type Profile struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
gorm.Model
|
||||||
|
Profile Profile `gorm:"ForeignKey:ProfileRefer"`
|
||||||
|
ProfileRefer int
|
||||||
|
}
|
||||||
|
|
||||||
|
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
|
||||||
|
if relation.Relationship.Kind != "belongs_to" ||
|
||||||
|
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileRefer"}) ||
|
||||||
|
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) {
|
||||||
|
t.Errorf("Override belongs to foreign key with tag")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBelongsToOverrideForeignKey2(t *testing.T) {
|
||||||
|
type Profile struct {
|
||||||
|
gorm.Model
|
||||||
|
Refer string
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
gorm.Model
|
||||||
|
Profile Profile `gorm:"ForeignKey:ProfileID;AssociationForeignKey:Refer"`
|
||||||
|
ProfileID int
|
||||||
|
}
|
||||||
|
|
||||||
|
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
|
||||||
|
if relation.Relationship.Kind != "belongs_to" ||
|
||||||
|
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"ProfileID"}) ||
|
||||||
|
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) {
|
||||||
|
t.Errorf("Override belongs to foreign key with tag")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasOne(t *testing.T) {
|
||||||
|
user := User{
|
||||||
|
Name: "has one",
|
||||||
|
CreditCard: CreditCard{Number: "411111111111"},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Save(&user).Error; err != nil {
|
||||||
|
t.Error("Got errors when save user", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.CreditCard.UserId.Int64 == 0 {
|
||||||
|
t.Errorf("CreditCard's foreign key should be updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query
|
||||||
|
var creditCard1 CreditCard
|
||||||
|
DB.Model(&user).Related(&creditCard1)
|
||||||
|
|
||||||
|
if creditCard1.Number != "411111111111" {
|
||||||
|
t.Errorf("Query has one relations with Related")
|
||||||
|
}
|
||||||
|
|
||||||
|
var creditCard11 CreditCard
|
||||||
|
DB.Model(&user).Association("CreditCard").Find(&creditCard11)
|
||||||
|
|
||||||
|
if creditCard11.Number != "411111111111" {
|
||||||
|
t.Errorf("Query has one relations with Related")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&user).Association("CreditCard").Count() != 1 {
|
||||||
|
t.Errorf("User's credit card count should be 1")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append
|
||||||
|
var creditcard2 = CreditCard{
|
||||||
|
Number: "411111111112",
|
||||||
|
}
|
||||||
|
DB.Model(&user).Association("CreditCard").Append(&creditcard2)
|
||||||
|
|
||||||
|
if creditcard2.ID == 0 {
|
||||||
|
t.Errorf("Creditcard should has ID when created with Append")
|
||||||
|
}
|
||||||
|
|
||||||
|
var creditcard21 CreditCard
|
||||||
|
DB.Model(&user).Related(&creditcard21)
|
||||||
|
if creditcard21.Number != "411111111112" {
|
||||||
|
t.Errorf("CreditCard should be updated with Append")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&user).Association("CreditCard").Count() != 1 {
|
||||||
|
t.Errorf("User's credit card count should be 1")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace
|
||||||
|
var creditcard3 = CreditCard{
|
||||||
|
Number: "411111111113",
|
||||||
|
}
|
||||||
|
DB.Model(&user).Association("CreditCard").Replace(&creditcard3)
|
||||||
|
|
||||||
|
if creditcard3.ID == 0 {
|
||||||
|
t.Errorf("Creditcard should has ID when created with Replace")
|
||||||
|
}
|
||||||
|
|
||||||
|
var creditcard31 CreditCard
|
||||||
|
DB.Model(&user).Related(&creditcard31)
|
||||||
|
if creditcard31.Number != "411111111113" {
|
||||||
|
t.Errorf("CreditCard should be updated with Replace")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&user).Association("CreditCard").Count() != 1 {
|
||||||
|
t.Errorf("User's credit card count should be 1")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete
|
||||||
|
DB.Model(&user).Association("CreditCard").Delete(&creditcard2)
|
||||||
|
var creditcard4 CreditCard
|
||||||
|
DB.Model(&user).Related(&creditcard4)
|
||||||
|
if creditcard4.Number != "411111111113" {
|
||||||
|
t.Errorf("Should not delete credit card when Delete a unrelated CreditCard")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&user).Association("CreditCard").Count() != 1 {
|
||||||
|
t.Errorf("User's credit card count should be 1")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&user).Association("CreditCard").Delete(&creditcard3)
|
||||||
|
if !DB.Model(&user).Related(&CreditCard{}).RecordNotFound() {
|
||||||
|
t.Errorf("Should delete credit card with Delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&user).Association("CreditCard").Count() != 0 {
|
||||||
|
t.Errorf("User's credit card count should be 0 after Delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear
|
||||||
|
var creditcard5 = CreditCard{
|
||||||
|
Number: "411111111115",
|
||||||
|
}
|
||||||
|
DB.Model(&user).Association("CreditCard").Append(&creditcard5)
|
||||||
|
|
||||||
|
if DB.Model(&user).Related(&CreditCard{}).RecordNotFound() {
|
||||||
|
t.Errorf("Should added credit card with Append")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&user).Association("CreditCard").Count() != 1 {
|
||||||
|
t.Errorf("User's credit card count should be 1")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&user).Association("CreditCard").Clear()
|
||||||
|
if !DB.Model(&user).Related(&CreditCard{}).RecordNotFound() {
|
||||||
|
t.Errorf("Credit card should be deleted with Clear")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&user).Association("CreditCard").Count() != 0 {
|
||||||
|
t.Errorf("User's credit card count should be 0 after Clear")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check Association mode with soft delete
|
||||||
|
var creditcard6 = CreditCard{
|
||||||
|
Number: "411111111116",
|
||||||
|
}
|
||||||
|
DB.Model(&user).Association("CreditCard").Append(&creditcard6)
|
||||||
|
|
||||||
|
if count := DB.Model(&user).Association("CreditCard").Count(); count != 1 {
|
||||||
|
t.Errorf("User's credit card count should be 1 after Append, but got %v", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Delete(&creditcard6)
|
||||||
|
|
||||||
|
if count := DB.Model(&user).Association("CreditCard").Count(); count != 0 {
|
||||||
|
t.Errorf("User's credit card count should be 0 after credit card deleted, but got %v", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Model(&user).Association("CreditCard").Find(&CreditCard{}).Error; err == nil {
|
||||||
|
t.Errorf("User's creditcard is not findable after Delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
if count := DB.Unscoped().Model(&user).Association("CreditCard").Count(); count != 1 {
|
||||||
|
t.Errorf("User's credit card count should be 1 when query with Unscoped, but got %v", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Unscoped().Model(&user).Association("CreditCard").Find(&CreditCard{}).Error; err != nil {
|
||||||
|
t.Errorf("User's creditcard should be findable when query with Unscoped, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasOneOverrideForeignKey1(t *testing.T) {
|
||||||
|
type Profile struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
UserRefer uint
|
||||||
|
}
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
gorm.Model
|
||||||
|
Profile Profile `gorm:"ForeignKey:UserRefer"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
|
||||||
|
if relation.Relationship.Kind != "has_one" ||
|
||||||
|
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) ||
|
||||||
|
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) {
|
||||||
|
t.Errorf("Override belongs to foreign key with tag")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasOneOverrideForeignKey2(t *testing.T) {
|
||||||
|
type Profile struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
UserID uint
|
||||||
|
}
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
gorm.Model
|
||||||
|
Refer string
|
||||||
|
Profile Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
|
||||||
|
if relation.Relationship.Kind != "has_one" ||
|
||||||
|
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) ||
|
||||||
|
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) {
|
||||||
|
t.Errorf("Override belongs to foreign key with tag")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasMany(t *testing.T) {
|
||||||
|
post := Post{
|
||||||
|
Title: "post has many",
|
||||||
|
Body: "body has many",
|
||||||
|
Comments: []*Comment{{Content: "Comment 1"}, {Content: "Comment 2"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Save(&post).Error; err != nil {
|
||||||
|
t.Error("Got errors when save post", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, comment := range post.Comments {
|
||||||
|
if comment.PostId == 0 {
|
||||||
|
t.Errorf("comment's PostID should be updated")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var compareComments = func(comments []Comment, contents []string) bool {
|
||||||
|
var commentContents []string
|
||||||
|
for _, comment := range comments {
|
||||||
|
commentContents = append(commentContents, comment.Content)
|
||||||
|
}
|
||||||
|
sort.Strings(commentContents)
|
||||||
|
sort.Strings(contents)
|
||||||
|
return reflect.DeepEqual(commentContents, contents)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query
|
||||||
|
if DB.First(&Comment{}, "content = ?", "Comment 1").Error != nil {
|
||||||
|
t.Errorf("Comment 1 should be saved")
|
||||||
|
}
|
||||||
|
|
||||||
|
var comments1 []Comment
|
||||||
|
DB.Model(&post).Association("Comments").Find(&comments1)
|
||||||
|
if !compareComments(comments1, []string{"Comment 1", "Comment 2"}) {
|
||||||
|
t.Errorf("Query has many relations with Association")
|
||||||
|
}
|
||||||
|
|
||||||
|
var comments11 []Comment
|
||||||
|
DB.Model(&post).Related(&comments11)
|
||||||
|
if !compareComments(comments11, []string{"Comment 1", "Comment 2"}) {
|
||||||
|
t.Errorf("Query has many relations with Related")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&post).Association("Comments").Count() != 2 {
|
||||||
|
t.Errorf("Post's comments count should be 2")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append
|
||||||
|
DB.Model(&post).Association("Comments").Append(&Comment{Content: "Comment 3"})
|
||||||
|
|
||||||
|
var comments2 []Comment
|
||||||
|
DB.Model(&post).Related(&comments2)
|
||||||
|
if !compareComments(comments2, []string{"Comment 1", "Comment 2", "Comment 3"}) {
|
||||||
|
t.Errorf("Append new record to has many relations")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&post).Association("Comments").Count() != 3 {
|
||||||
|
t.Errorf("Post's comments count should be 3 after Append")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete
|
||||||
|
DB.Model(&post).Association("Comments").Delete(comments11)
|
||||||
|
|
||||||
|
var comments3 []Comment
|
||||||
|
DB.Model(&post).Related(&comments3)
|
||||||
|
if !compareComments(comments3, []string{"Comment 3"}) {
|
||||||
|
t.Errorf("Delete an existing resource for has many relations")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&post).Association("Comments").Count() != 1 {
|
||||||
|
t.Errorf("Post's comments count should be 1 after Delete 2")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace
|
||||||
|
DB.Model(&Post{Id: 999}).Association("Comments").Replace()
|
||||||
|
|
||||||
|
var comments4 []Comment
|
||||||
|
DB.Model(&post).Related(&comments4)
|
||||||
|
if len(comments4) == 0 {
|
||||||
|
t.Errorf("Replace for other resource should not clear all comments")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&post).Association("Comments").Replace(&Comment{Content: "Comment 4"}, &Comment{Content: "Comment 5"})
|
||||||
|
|
||||||
|
var comments41 []Comment
|
||||||
|
DB.Model(&post).Related(&comments41)
|
||||||
|
if !compareComments(comments41, []string{"Comment 4", "Comment 5"}) {
|
||||||
|
t.Errorf("Replace has many relations")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear
|
||||||
|
DB.Model(&Post{Id: 999}).Association("Comments").Clear()
|
||||||
|
|
||||||
|
var comments5 []Comment
|
||||||
|
DB.Model(&post).Related(&comments5)
|
||||||
|
if len(comments5) == 0 {
|
||||||
|
t.Errorf("Clear should not clear all comments")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&post).Association("Comments").Clear()
|
||||||
|
|
||||||
|
var comments51 []Comment
|
||||||
|
DB.Model(&post).Related(&comments51)
|
||||||
|
if len(comments51) != 0 {
|
||||||
|
t.Errorf("Clear has many relations")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check Association mode with soft delete
|
||||||
|
var comment6 = Comment{
|
||||||
|
Content: "comment 6",
|
||||||
|
}
|
||||||
|
DB.Model(&post).Association("Comments").Append(&comment6)
|
||||||
|
|
||||||
|
if count := DB.Model(&post).Association("Comments").Count(); count != 1 {
|
||||||
|
t.Errorf("post's comments count should be 1 after Append, but got %v", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Delete(&comment6)
|
||||||
|
|
||||||
|
if count := DB.Model(&post).Association("Comments").Count(); count != 0 {
|
||||||
|
t.Errorf("post's comments count should be 0 after comment been deleted, but got %v", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
var comments6 []Comment
|
||||||
|
if DB.Model(&post).Association("Comments").Find(&comments6); len(comments6) != 0 {
|
||||||
|
t.Errorf("post's comments count should be 0 when find with Find, but got %v", len(comments6))
|
||||||
|
}
|
||||||
|
|
||||||
|
if count := DB.Unscoped().Model(&post).Association("Comments").Count(); count != 1 {
|
||||||
|
t.Errorf("post's comments count should be 1 when query with Unscoped, but got %v", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
var comments61 []Comment
|
||||||
|
if DB.Unscoped().Model(&post).Association("Comments").Find(&comments61); len(comments61) != 1 {
|
||||||
|
t.Errorf("post's comments count should be 1 when query with Unscoped, but got %v", len(comments61))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasManyOverrideForeignKey1(t *testing.T) {
|
||||||
|
type Profile struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
UserRefer uint
|
||||||
|
}
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
gorm.Model
|
||||||
|
Profile []Profile `gorm:"ForeignKey:UserRefer"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
|
||||||
|
if relation.Relationship.Kind != "has_many" ||
|
||||||
|
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserRefer"}) ||
|
||||||
|
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"ID"}) {
|
||||||
|
t.Errorf("Override belongs to foreign key with tag")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasManyOverrideForeignKey2(t *testing.T) {
|
||||||
|
type Profile struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
UserID uint
|
||||||
|
}
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
gorm.Model
|
||||||
|
Refer string
|
||||||
|
Profile []Profile `gorm:"ForeignKey:UserID;AssociationForeignKey:Refer"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if relation, ok := DB.NewScope(&User{}).FieldByName("Profile"); ok {
|
||||||
|
if relation.Relationship.Kind != "has_many" ||
|
||||||
|
!reflect.DeepEqual(relation.Relationship.ForeignFieldNames, []string{"UserID"}) ||
|
||||||
|
!reflect.DeepEqual(relation.Relationship.AssociationForeignFieldNames, []string{"Refer"}) {
|
||||||
|
t.Errorf("Override belongs to foreign key with tag")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManyToMany(t *testing.T) {
|
||||||
|
DB.Raw("delete from languages")
|
||||||
|
var languages = []Language{{Name: "ZH"}, {Name: "EN"}}
|
||||||
|
user := User{Name: "Many2Many", Languages: languages}
|
||||||
|
DB.Save(&user)
|
||||||
|
|
||||||
|
// Query
|
||||||
|
var newLanguages []Language
|
||||||
|
DB.Model(&user).Related(&newLanguages, "Languages")
|
||||||
|
if len(newLanguages) != len([]string{"ZH", "EN"}) {
|
||||||
|
t.Errorf("Query many to many relations")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&user).Association("Languages").Find(&newLanguages)
|
||||||
|
if len(newLanguages) != len([]string{"ZH", "EN"}) {
|
||||||
|
t.Errorf("Should be able to find many to many relations")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&user).Association("Languages").Count() != len([]string{"ZH", "EN"}) {
|
||||||
|
t.Errorf("Count should return correct result")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append
|
||||||
|
DB.Model(&user).Association("Languages").Append(&Language{Name: "DE"})
|
||||||
|
if DB.Where("name = ?", "DE").First(&Language{}).RecordNotFound() {
|
||||||
|
t.Errorf("New record should be saved when append")
|
||||||
|
}
|
||||||
|
|
||||||
|
languageA := Language{Name: "AA"}
|
||||||
|
DB.Save(&languageA)
|
||||||
|
DB.Model(&User{Id: user.Id}).Association("Languages").Append(&languageA)
|
||||||
|
|
||||||
|
languageC := Language{Name: "CC"}
|
||||||
|
DB.Save(&languageC)
|
||||||
|
DB.Model(&user).Association("Languages").Append(&[]Language{{Name: "BB"}, languageC})
|
||||||
|
|
||||||
|
DB.Model(&User{Id: user.Id}).Association("Languages").Append(&[]Language{{Name: "DD"}, {Name: "EE"}})
|
||||||
|
|
||||||
|
totalLanguages := []string{"ZH", "EN", "DE", "AA", "BB", "CC", "DD", "EE"}
|
||||||
|
|
||||||
|
if DB.Model(&user).Association("Languages").Count() != len(totalLanguages) {
|
||||||
|
t.Errorf("All appended languages should be saved")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete
|
||||||
|
user.Languages = []Language{}
|
||||||
|
DB.Model(&user).Association("Languages").Find(&user.Languages)
|
||||||
|
|
||||||
|
var language Language
|
||||||
|
DB.Where("name = ?", "EE").First(&language)
|
||||||
|
DB.Model(&user).Association("Languages").Delete(language, &language)
|
||||||
|
|
||||||
|
if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-1 || len(user.Languages) != len(totalLanguages)-1 {
|
||||||
|
t.Errorf("Relations should be deleted with Delete")
|
||||||
|
}
|
||||||
|
if DB.Where("name = ?", "EE").First(&Language{}).RecordNotFound() {
|
||||||
|
t.Errorf("Language EE should not be deleted")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Where("name IN (?)", []string{"CC", "DD"}).Find(&languages)
|
||||||
|
|
||||||
|
user2 := User{Name: "Many2Many_User2", Languages: languages}
|
||||||
|
DB.Save(&user2)
|
||||||
|
|
||||||
|
DB.Model(&user).Association("Languages").Delete(languages, &languages)
|
||||||
|
if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-3 || len(user.Languages) != len(totalLanguages)-3 {
|
||||||
|
t.Errorf("Relations should be deleted with Delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&user2).Association("Languages").Count() == 0 {
|
||||||
|
t.Errorf("Other user's relations should not be deleted")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace
|
||||||
|
var languageB Language
|
||||||
|
DB.Where("name = ?", "BB").First(&languageB)
|
||||||
|
DB.Model(&user).Association("Languages").Replace(languageB)
|
||||||
|
if len(user.Languages) != 1 || DB.Model(&user).Association("Languages").Count() != 1 {
|
||||||
|
t.Errorf("Relations should be replaced")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&user).Association("Languages").Replace()
|
||||||
|
if len(user.Languages) != 0 || DB.Model(&user).Association("Languages").Count() != 0 {
|
||||||
|
t.Errorf("Relations should be replaced with empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&user).Association("Languages").Replace(&[]Language{{Name: "FF"}, {Name: "JJ"}})
|
||||||
|
if len(user.Languages) != 2 || DB.Model(&user).Association("Languages").Count() != len([]string{"FF", "JJ"}) {
|
||||||
|
t.Errorf("Relations should be replaced")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear
|
||||||
|
DB.Model(&user).Association("Languages").Clear()
|
||||||
|
if len(user.Languages) != 0 || DB.Model(&user).Association("Languages").Count() != 0 {
|
||||||
|
t.Errorf("Relations should be cleared")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check Association mode with soft delete
|
||||||
|
var language6 = Language{
|
||||||
|
Name: "language 6",
|
||||||
|
}
|
||||||
|
DB.Model(&user).Association("Languages").Append(&language6)
|
||||||
|
|
||||||
|
if count := DB.Model(&user).Association("Languages").Count(); count != 1 {
|
||||||
|
t.Errorf("user's languages count should be 1 after Append, but got %v", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Delete(&language6)
|
||||||
|
|
||||||
|
if count := DB.Model(&user).Association("Languages").Count(); count != 0 {
|
||||||
|
t.Errorf("user's languages count should be 0 after language been deleted, but got %v", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
var languages6 []Language
|
||||||
|
if DB.Model(&user).Association("Languages").Find(&languages6); len(languages6) != 0 {
|
||||||
|
t.Errorf("user's languages count should be 0 when find with Find, but got %v", len(languages6))
|
||||||
|
}
|
||||||
|
|
||||||
|
if count := DB.Unscoped().Model(&user).Association("Languages").Count(); count != 1 {
|
||||||
|
t.Errorf("user's languages count should be 1 when query with Unscoped, but got %v", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
var languages61 []Language
|
||||||
|
if DB.Unscoped().Model(&user).Association("Languages").Find(&languages61); len(languages61) != 1 {
|
||||||
|
t.Errorf("user's languages count should be 1 when query with Unscoped, but got %v", len(languages61))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelated(t *testing.T) {
|
||||||
|
user := User{
|
||||||
|
Name: "jinzhu",
|
||||||
|
BillingAddress: Address{Address1: "Billing Address - Address 1"},
|
||||||
|
ShippingAddress: Address{Address1: "Shipping Address - Address 1"},
|
||||||
|
Emails: []Email{{Email: "jinzhu@example.com"}, {Email: "jinzhu-2@example@example.com"}},
|
||||||
|
CreditCard: CreditCard{Number: "1234567890"},
|
||||||
|
Company: Company{Name: "company1"},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Save(&user).Error; err != nil {
|
||||||
|
t.Errorf("No error should happen when saving user")
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.CreditCard.ID == 0 {
|
||||||
|
t.Errorf("After user save, credit card should have id")
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.BillingAddress.ID == 0 {
|
||||||
|
t.Errorf("After user save, billing address should have id")
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.Emails[0].Id == 0 {
|
||||||
|
t.Errorf("After user save, billing address should have id")
|
||||||
|
}
|
||||||
|
|
||||||
|
var emails []Email
|
||||||
|
DB.Model(&user).Related(&emails)
|
||||||
|
if len(emails) != 2 {
|
||||||
|
t.Errorf("Should have two emails")
|
||||||
|
}
|
||||||
|
|
||||||
|
var emails2 []Email
|
||||||
|
DB.Model(&user).Where("email = ?", "jinzhu@example.com").Related(&emails2)
|
||||||
|
if len(emails2) != 1 {
|
||||||
|
t.Errorf("Should have two emails")
|
||||||
|
}
|
||||||
|
|
||||||
|
var emails3 []*Email
|
||||||
|
DB.Model(&user).Related(&emails3)
|
||||||
|
if len(emails3) != 2 {
|
||||||
|
t.Errorf("Should have two emails")
|
||||||
|
}
|
||||||
|
|
||||||
|
var user1 User
|
||||||
|
DB.Model(&user).Related(&user1.Emails)
|
||||||
|
if len(user1.Emails) != 2 {
|
||||||
|
t.Errorf("Should have only one email match related condition")
|
||||||
|
}
|
||||||
|
|
||||||
|
var address1 Address
|
||||||
|
DB.Model(&user).Related(&address1, "BillingAddressId")
|
||||||
|
if address1.Address1 != "Billing Address - Address 1" {
|
||||||
|
t.Errorf("Should get billing address from user correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
user1 = User{}
|
||||||
|
DB.Model(&address1).Related(&user1, "BillingAddressId")
|
||||||
|
if DB.NewRecord(user1) {
|
||||||
|
t.Errorf("Should get user from address correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
var user2 User
|
||||||
|
DB.Model(&emails[0]).Related(&user2)
|
||||||
|
if user2.Id != user.Id || user2.Name != user.Name {
|
||||||
|
t.Errorf("Should get user from email correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
var creditcard CreditCard
|
||||||
|
var user3 User
|
||||||
|
DB.First(&creditcard, "number = ?", "1234567890")
|
||||||
|
DB.Model(&creditcard).Related(&user3)
|
||||||
|
if user3.Id != user.Id || user3.Name != user.Name {
|
||||||
|
t.Errorf("Should get user from credit card correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !DB.Model(&CreditCard{}).Related(&User{}).RecordNotFound() {
|
||||||
|
t.Errorf("RecordNotFound for Related")
|
||||||
|
}
|
||||||
|
|
||||||
|
var company Company
|
||||||
|
if DB.Model(&user).Related(&company, "Company").RecordNotFound() || company.Name != "company1" {
|
||||||
|
t.Errorf("RecordNotFound for Related")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestForeignKey(t *testing.T) {
|
||||||
|
for _, structField := range DB.NewScope(&User{}).GetStructFields() {
|
||||||
|
for _, foreignKey := range []string{"BillingAddressID", "ShippingAddressId", "CompanyID"} {
|
||||||
|
if structField.Name == foreignKey && !structField.IsForeignKey {
|
||||||
|
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, structField := range DB.NewScope(&Email{}).GetStructFields() {
|
||||||
|
for _, foreignKey := range []string{"UserId"} {
|
||||||
|
if structField.Name == foreignKey && !structField.IsForeignKey {
|
||||||
|
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, structField := range DB.NewScope(&Post{}).GetStructFields() {
|
||||||
|
for _, foreignKey := range []string{"CategoryId", "MainCategoryId"} {
|
||||||
|
if structField.Name == foreignKey && !structField.IsForeignKey {
|
||||||
|
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, structField := range DB.NewScope(&Comment{}).GetStructFields() {
|
||||||
|
for _, foreignKey := range []string{"PostId"} {
|
||||||
|
if structField.Name == foreignKey && !structField.IsForeignKey {
|
||||||
|
t.Errorf(fmt.Sprintf("%v should be foreign key", foreignKey))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testForeignKey(t *testing.T, source interface{}, sourceFieldName string, target interface{}, targetFieldName string) {
|
||||||
|
if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" {
|
||||||
|
// sqlite does not support ADD CONSTRAINT in ALTER TABLE
|
||||||
|
return
|
||||||
|
}
|
||||||
|
targetScope := DB.NewScope(target)
|
||||||
|
targetTableName := targetScope.TableName()
|
||||||
|
modelScope := DB.NewScope(source)
|
||||||
|
modelField, ok := modelScope.FieldByName(sourceFieldName)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf(fmt.Sprintf("Failed to get field by name: %v", sourceFieldName))
|
||||||
|
}
|
||||||
|
targetField, ok := targetScope.FieldByName(targetFieldName)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf(fmt.Sprintf("Failed to get field by name: %v", targetFieldName))
|
||||||
|
}
|
||||||
|
dest := fmt.Sprintf("%v(%v)", targetTableName, targetField.DBName)
|
||||||
|
err := DB.Model(source).AddForeignKey(modelField.DBName, dest, "CASCADE", "CASCADE").Error
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf(fmt.Sprintf("Failed to create foreign key: %v", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLongForeignKey(t *testing.T) {
|
||||||
|
testForeignKey(t, &NotSoLongTableName{}, "ReallyLongThingID", &ReallyLongTableNameToTestMySQLNameLengthLimit{}, "ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLongForeignKeyWithShortDest(t *testing.T) {
|
||||||
|
testForeignKey(t, &ReallyLongThingThatReferencesShort{}, "ShortID", &Short{}, "ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasManyChildrenWithOneStruct(t *testing.T) {
|
||||||
|
category := Category{
|
||||||
|
Name: "main",
|
||||||
|
Categories: []Category{
|
||||||
|
{Name: "sub1"},
|
||||||
|
{Name: "sub2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Save(&category)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSkipSaveAssociation(t *testing.T) {
|
||||||
|
type Company struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
CompanyID uint
|
||||||
|
Company Company `gorm:"save_associations:false"`
|
||||||
|
}
|
||||||
|
DB.AutoMigrate(&Company{}, &User{})
|
||||||
|
|
||||||
|
DB.Save(&User{Name: "jinzhu", Company: Company{Name: "skip_save_association"}})
|
||||||
|
|
||||||
|
if !DB.Where("name = ?", "skip_save_association").First(&Company{}).RecordNotFound() {
|
||||||
|
t.Errorf("Company skip_save_association should not been saved")
|
||||||
|
}
|
||||||
|
}
|
237
orm/callback.go
Normal file
237
orm/callback.go
Normal file
|
@ -0,0 +1,237 @@
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultCallback default callbacks defined by gorm
|
||||||
|
var DefaultCallback = &Callback{}
|
||||||
|
|
||||||
|
// Callback is a struct that contains all CURD callbacks
|
||||||
|
// Field `creates` contains callbacks will be call when creating object
|
||||||
|
// Field `updates` contains callbacks will be call when updating object
|
||||||
|
// Field `deletes` contains callbacks will be call when deleting object
|
||||||
|
// Field `queries` contains callbacks will be call when querying object with query methods like Find, First, Related, Association...
|
||||||
|
// Field `rowQueries` contains callbacks will be call when querying object with Row, Rows...
|
||||||
|
// Field `processors` contains all callback processors, will be used to generate above callbacks in order
|
||||||
|
type Callback struct {
|
||||||
|
creates []*func(scope *Scope)
|
||||||
|
updates []*func(scope *Scope)
|
||||||
|
deletes []*func(scope *Scope)
|
||||||
|
queries []*func(scope *Scope)
|
||||||
|
rowQueries []*func(scope *Scope)
|
||||||
|
processors []*CallbackProcessor
|
||||||
|
}
|
||||||
|
|
||||||
|
// CallbackProcessor contains callback informations
|
||||||
|
type CallbackProcessor struct {
|
||||||
|
name string // current callback's name
|
||||||
|
before string // register current callback before a callback
|
||||||
|
after string // register current callback after a callback
|
||||||
|
replace bool // replace callbacks with same name
|
||||||
|
remove bool // delete callbacks with same name
|
||||||
|
kind string // callback type: create, update, delete, query, row_query
|
||||||
|
processor *func(scope *Scope) // callback handler
|
||||||
|
parent *Callback
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Callback) clone() *Callback {
|
||||||
|
return &Callback{
|
||||||
|
creates: c.creates,
|
||||||
|
updates: c.updates,
|
||||||
|
deletes: c.deletes,
|
||||||
|
queries: c.queries,
|
||||||
|
rowQueries: c.rowQueries,
|
||||||
|
processors: c.processors,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create could be used to register callbacks for creating object
|
||||||
|
// db.Callback().Create().After("gorm:create").Register("plugin:run_after_create", func(*Scope) {
|
||||||
|
// // business logic
|
||||||
|
// ...
|
||||||
|
//
|
||||||
|
// // set error if some thing wrong happened, will rollback the creating
|
||||||
|
// scope.Err(errors.New("error"))
|
||||||
|
// })
|
||||||
|
func (c *Callback) Create() *CallbackProcessor {
|
||||||
|
return &CallbackProcessor{kind: "create", parent: c}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update could be used to register callbacks for updating object, refer `Create` for usage
|
||||||
|
func (c *Callback) Update() *CallbackProcessor {
|
||||||
|
return &CallbackProcessor{kind: "update", parent: c}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete could be used to register callbacks for deleting object, refer `Create` for usage
|
||||||
|
func (c *Callback) Delete() *CallbackProcessor {
|
||||||
|
return &CallbackProcessor{kind: "delete", parent: c}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`...
|
||||||
|
// Refer `Create` for usage
|
||||||
|
func (c *Callback) Query() *CallbackProcessor {
|
||||||
|
return &CallbackProcessor{kind: "query", parent: c}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage
|
||||||
|
func (c *Callback) RowQuery() *CallbackProcessor {
|
||||||
|
return &CallbackProcessor{kind: "row_query", parent: c}
|
||||||
|
}
|
||||||
|
|
||||||
|
// After insert a new callback after callback `callbackName`, refer `Callbacks.Create`
|
||||||
|
func (cp *CallbackProcessor) After(callbackName string) *CallbackProcessor {
|
||||||
|
cp.after = callbackName
|
||||||
|
return cp
|
||||||
|
}
|
||||||
|
|
||||||
|
// Before insert a new callback before callback `callbackName`, refer `Callbacks.Create`
|
||||||
|
func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor {
|
||||||
|
cp.before = callbackName
|
||||||
|
return cp
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register a new callback, refer `Callbacks.Create`
|
||||||
|
func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) {
|
||||||
|
cp.name = callbackName
|
||||||
|
cp.processor = &callback
|
||||||
|
cp.parent.processors = append(cp.parent.processors, cp)
|
||||||
|
cp.parent.reorder()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove a registered callback
|
||||||
|
// db.Callback().Create().Remove("gorm:update_time_stamp_when_create")
|
||||||
|
func (cp *CallbackProcessor) Remove(callbackName string) {
|
||||||
|
fmt.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum())
|
||||||
|
cp.name = callbackName
|
||||||
|
cp.remove = true
|
||||||
|
cp.parent.processors = append(cp.parent.processors, cp)
|
||||||
|
cp.parent.reorder()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace a registered callback with new callback
|
||||||
|
// db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) {
|
||||||
|
// scope.SetColumn("Created", now)
|
||||||
|
// scope.SetColumn("Updated", now)
|
||||||
|
// })
|
||||||
|
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
|
||||||
|
fmt.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum())
|
||||||
|
cp.name = callbackName
|
||||||
|
cp.processor = &callback
|
||||||
|
cp.replace = true
|
||||||
|
cp.parent.processors = append(cp.parent.processors, cp)
|
||||||
|
cp.parent.reorder()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get registered callback
|
||||||
|
// db.Callback().Create().Get("gorm:create")
|
||||||
|
func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) {
|
||||||
|
for _, p := range cp.parent.processors {
|
||||||
|
if p.name == callbackName && p.kind == cp.kind && !cp.remove {
|
||||||
|
return *p.processor
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getRIndex get right index from string slice
|
||||||
|
func getRIndex(strs []string, str string) int {
|
||||||
|
for i := len(strs) - 1; i >= 0; i-- {
|
||||||
|
if strs[i] == str {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
// sortProcessors sort callback processors based on its before, after, remove, replace
|
||||||
|
func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {
|
||||||
|
var (
|
||||||
|
allNames, sortedNames []string
|
||||||
|
sortCallbackProcessor func(c *CallbackProcessor)
|
||||||
|
)
|
||||||
|
|
||||||
|
for _, cp := range cps {
|
||||||
|
// show warning message the callback name already exists
|
||||||
|
if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove {
|
||||||
|
fmt.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum())
|
||||||
|
}
|
||||||
|
allNames = append(allNames, cp.name)
|
||||||
|
}
|
||||||
|
|
||||||
|
sortCallbackProcessor = func(c *CallbackProcessor) {
|
||||||
|
if getRIndex(sortedNames, c.name) == -1 { // if not sorted
|
||||||
|
if c.before != "" { // if defined before callback
|
||||||
|
if index := getRIndex(sortedNames, c.before); index != -1 {
|
||||||
|
// if before callback already sorted, append current callback just after it
|
||||||
|
sortedNames = append(sortedNames[:index], append([]string{c.name}, sortedNames[index:]...)...)
|
||||||
|
} else if index := getRIndex(allNames, c.before); index != -1 {
|
||||||
|
// if before callback exists but haven't sorted, append current callback to last
|
||||||
|
sortedNames = append(sortedNames, c.name)
|
||||||
|
sortCallbackProcessor(cps[index])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.after != "" { // if defined after callback
|
||||||
|
if index := getRIndex(sortedNames, c.after); index != -1 {
|
||||||
|
// if after callback already sorted, append current callback just before it
|
||||||
|
sortedNames = append(sortedNames[:index+1], append([]string{c.name}, sortedNames[index+1:]...)...)
|
||||||
|
} else if index := getRIndex(allNames, c.after); index != -1 {
|
||||||
|
// if after callback exists but haven't sorted
|
||||||
|
cp := cps[index]
|
||||||
|
// set after callback's before callback to current callback
|
||||||
|
if cp.before == "" {
|
||||||
|
cp.before = c.name
|
||||||
|
}
|
||||||
|
sortCallbackProcessor(cp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// if current callback haven't been sorted, append it to last
|
||||||
|
if getRIndex(sortedNames, c.name) == -1 {
|
||||||
|
sortedNames = append(sortedNames, c.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, cp := range cps {
|
||||||
|
sortCallbackProcessor(cp)
|
||||||
|
}
|
||||||
|
|
||||||
|
var sortedFuncs []*func(scope *Scope)
|
||||||
|
for _, name := range sortedNames {
|
||||||
|
if index := getRIndex(allNames, name); !cps[index].remove {
|
||||||
|
sortedFuncs = append(sortedFuncs, cps[index].processor)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return sortedFuncs
|
||||||
|
}
|
||||||
|
|
||||||
|
// reorder all registered processors, and reset CURD callbacks
|
||||||
|
func (c *Callback) reorder() {
|
||||||
|
var creates, updates, deletes, queries, rowQueries []*CallbackProcessor
|
||||||
|
|
||||||
|
for _, processor := range c.processors {
|
||||||
|
if processor.name != "" {
|
||||||
|
switch processor.kind {
|
||||||
|
case "create":
|
||||||
|
creates = append(creates, processor)
|
||||||
|
case "update":
|
||||||
|
updates = append(updates, processor)
|
||||||
|
case "delete":
|
||||||
|
deletes = append(deletes, processor)
|
||||||
|
case "query":
|
||||||
|
queries = append(queries, processor)
|
||||||
|
case "row_query":
|
||||||
|
rowQueries = append(rowQueries, processor)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.creates = sortProcessors(creates)
|
||||||
|
c.updates = sortProcessors(updates)
|
||||||
|
c.deletes = sortProcessors(deletes)
|
||||||
|
c.queries = sortProcessors(queries)
|
||||||
|
c.rowQueries = sortProcessors(rowQueries)
|
||||||
|
}
|
149
orm/callback_create.go
Normal file
149
orm/callback_create.go
Normal file
|
@ -0,0 +1,149 @@
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Define callbacks for creating
|
||||||
|
func init() {
|
||||||
|
DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback)
|
||||||
|
DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback)
|
||||||
|
DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
|
||||||
|
DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback)
|
||||||
|
DefaultCallback.Create().Register("gorm:create", createCallback)
|
||||||
|
DefaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback)
|
||||||
|
DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
|
||||||
|
DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback)
|
||||||
|
DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
|
||||||
|
}
|
||||||
|
|
||||||
|
// beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating
|
||||||
|
func beforeCreateCallback(scope *Scope) {
|
||||||
|
if !scope.HasError() {
|
||||||
|
scope.CallMethod("BeforeSave")
|
||||||
|
}
|
||||||
|
if !scope.HasError() {
|
||||||
|
scope.CallMethod("BeforeCreate")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating
|
||||||
|
func updateTimeStampForCreateCallback(scope *Scope) {
|
||||||
|
if !scope.HasError() {
|
||||||
|
now := NowFunc()
|
||||||
|
scope.SetColumn("CreatedAt", now)
|
||||||
|
scope.SetColumn("UpdatedAt", now)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// createCallback the callback used to insert data into database
|
||||||
|
func createCallback(scope *Scope) {
|
||||||
|
if !scope.HasError() {
|
||||||
|
defer scope.trace(NowFunc())
|
||||||
|
|
||||||
|
var (
|
||||||
|
columns, placeholders []string
|
||||||
|
blankColumnsWithDefaultValue []string
|
||||||
|
)
|
||||||
|
|
||||||
|
for _, field := range scope.Fields() {
|
||||||
|
if scope.changeableField(field) {
|
||||||
|
if field.IsNormal {
|
||||||
|
if field.IsBlank && field.HasDefaultValue {
|
||||||
|
blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName))
|
||||||
|
scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
|
||||||
|
} else if !field.IsPrimaryKey || !field.IsBlank {
|
||||||
|
columns = append(columns, scope.Quote(field.DBName))
|
||||||
|
placeholders = append(placeholders, scope.AddToVars(field.Field.Interface()))
|
||||||
|
}
|
||||||
|
} else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" {
|
||||||
|
for _, foreignKey := range field.Relationship.ForeignDBNames {
|
||||||
|
if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
|
||||||
|
columns = append(columns, scope.Quote(foreignField.DBName))
|
||||||
|
placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
returningColumn = "*"
|
||||||
|
quotedTableName = scope.QuotedTableName()
|
||||||
|
primaryField = scope.PrimaryField()
|
||||||
|
extraOption string
|
||||||
|
)
|
||||||
|
|
||||||
|
if str, ok := scope.Get("gorm:insert_option"); ok {
|
||||||
|
extraOption = fmt.Sprint(str)
|
||||||
|
}
|
||||||
|
|
||||||
|
if primaryField != nil {
|
||||||
|
returningColumn = scope.Quote(primaryField.DBName)
|
||||||
|
}
|
||||||
|
|
||||||
|
lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
|
||||||
|
|
||||||
|
if len(columns) == 0 {
|
||||||
|
scope.Raw(fmt.Sprintf(
|
||||||
|
"INSERT INTO %v DEFAULT VALUES%v%v",
|
||||||
|
quotedTableName,
|
||||||
|
addExtraSpaceIfExist(extraOption),
|
||||||
|
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
scope.Raw(fmt.Sprintf(
|
||||||
|
"INSERT INTO %v (%v) VALUES (%v)%v%v",
|
||||||
|
scope.QuotedTableName(),
|
||||||
|
strings.Join(columns, ","),
|
||||||
|
strings.Join(placeholders, ","),
|
||||||
|
addExtraSpaceIfExist(extraOption),
|
||||||
|
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
// execute create sql
|
||||||
|
if lastInsertIDReturningSuffix == "" || primaryField == nil {
|
||||||
|
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
||||||
|
// set rows affected count
|
||||||
|
scope.db.RowsAffected, _ = result.RowsAffected()
|
||||||
|
|
||||||
|
// set primary value to primary field
|
||||||
|
if primaryField != nil && primaryField.IsBlank {
|
||||||
|
if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
|
||||||
|
scope.Err(primaryField.Set(primaryValue))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
|
||||||
|
primaryField.IsBlank = false
|
||||||
|
scope.db.RowsAffected = 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object
|
||||||
|
func forceReloadAfterCreateCallback(scope *Scope) {
|
||||||
|
if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok {
|
||||||
|
db := scope.DB().New().Table(scope.TableName()).Select(blankColumnsWithDefaultValue.([]string))
|
||||||
|
for _, field := range scope.Fields() {
|
||||||
|
if field.IsPrimaryKey && !field.IsBlank {
|
||||||
|
db = db.Where(fmt.Sprintf("%v = ?", field.DBName), field.Field.Interface())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
db.Scan(scope.Value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating
|
||||||
|
func afterCreateCallback(scope *Scope) {
|
||||||
|
if !scope.HasError() {
|
||||||
|
scope.CallMethod("AfterCreate")
|
||||||
|
}
|
||||||
|
if !scope.HasError() {
|
||||||
|
scope.CallMethod("AfterSave")
|
||||||
|
}
|
||||||
|
}
|
53
orm/callback_delete.go
Normal file
53
orm/callback_delete.go
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
// Define callbacks for deleting
|
||||||
|
func init() {
|
||||||
|
DefaultCallback.Delete().Register("gorm:begin_transaction", beginTransactionCallback)
|
||||||
|
DefaultCallback.Delete().Register("gorm:before_delete", beforeDeleteCallback)
|
||||||
|
DefaultCallback.Delete().Register("gorm:delete", deleteCallback)
|
||||||
|
DefaultCallback.Delete().Register("gorm:after_delete", afterDeleteCallback)
|
||||||
|
DefaultCallback.Delete().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
|
||||||
|
}
|
||||||
|
|
||||||
|
// beforeDeleteCallback will invoke `BeforeDelete` method before deleting
|
||||||
|
func beforeDeleteCallback(scope *Scope) {
|
||||||
|
if !scope.HasError() {
|
||||||
|
scope.CallMethod("BeforeDelete")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// deleteCallback used to delete data from database or set deleted_at to current time (when using with soft delete)
|
||||||
|
func deleteCallback(scope *Scope) {
|
||||||
|
if !scope.HasError() {
|
||||||
|
var extraOption string
|
||||||
|
if str, ok := scope.Get("gorm:delete_option"); ok {
|
||||||
|
extraOption = fmt.Sprint(str)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !scope.Search.Unscoped && scope.HasColumn("DeletedAt") {
|
||||||
|
scope.Raw(fmt.Sprintf(
|
||||||
|
"UPDATE %v SET deleted_at=%v%v%v",
|
||||||
|
scope.QuotedTableName(),
|
||||||
|
scope.AddToVars(NowFunc()),
|
||||||
|
addExtraSpaceIfExist(scope.CombinedConditionSql()),
|
||||||
|
addExtraSpaceIfExist(extraOption),
|
||||||
|
)).Exec()
|
||||||
|
} else {
|
||||||
|
scope.Raw(fmt.Sprintf(
|
||||||
|
"DELETE FROM %v%v%v",
|
||||||
|
scope.QuotedTableName(),
|
||||||
|
addExtraSpaceIfExist(scope.CombinedConditionSql()),
|
||||||
|
addExtraSpaceIfExist(extraOption),
|
||||||
|
)).Exec()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// afterDeleteCallback will invoke `AfterDelete` method after deleting
|
||||||
|
func afterDeleteCallback(scope *Scope) {
|
||||||
|
if !scope.HasError() {
|
||||||
|
scope.CallMethod("AfterDelete")
|
||||||
|
}
|
||||||
|
}
|
93
orm/callback_query.go
Normal file
93
orm/callback_query.go
Normal file
|
@ -0,0 +1,93 @@
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Define callbacks for querying
|
||||||
|
func init() {
|
||||||
|
DefaultCallback.Query().Register("gorm:query", queryCallback)
|
||||||
|
DefaultCallback.Query().Register("gorm:preload", preloadCallback)
|
||||||
|
DefaultCallback.Query().Register("gorm:after_query", afterQueryCallback)
|
||||||
|
}
|
||||||
|
|
||||||
|
// queryCallback used to query data from database
|
||||||
|
func queryCallback(scope *Scope) {
|
||||||
|
defer scope.trace(NowFunc())
|
||||||
|
|
||||||
|
var (
|
||||||
|
isSlice, isPtr bool
|
||||||
|
resultType reflect.Type
|
||||||
|
results = scope.IndirectValue()
|
||||||
|
)
|
||||||
|
|
||||||
|
if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok {
|
||||||
|
if primaryField := scope.PrimaryField(); primaryField != nil {
|
||||||
|
scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryField.DBName), orderBy))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if value, ok := scope.Get("gorm:query_destination"); ok {
|
||||||
|
results = reflect.Indirect(reflect.ValueOf(value))
|
||||||
|
}
|
||||||
|
|
||||||
|
if kind := results.Kind(); kind == reflect.Slice {
|
||||||
|
isSlice = true
|
||||||
|
resultType = results.Type().Elem()
|
||||||
|
results.Set(reflect.MakeSlice(results.Type(), 0, 0))
|
||||||
|
|
||||||
|
if resultType.Kind() == reflect.Ptr {
|
||||||
|
isPtr = true
|
||||||
|
resultType = resultType.Elem()
|
||||||
|
}
|
||||||
|
} else if kind != reflect.Struct {
|
||||||
|
scope.Err(errors.New("unsupported destination, should be slice or struct"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
scope.prepareQuerySQL()
|
||||||
|
|
||||||
|
if !scope.HasError() {
|
||||||
|
scope.db.RowsAffected = 0
|
||||||
|
if str, ok := scope.Get("gorm:query_option"); ok {
|
||||||
|
scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
columns, _ := rows.Columns()
|
||||||
|
for rows.Next() {
|
||||||
|
scope.db.RowsAffected++
|
||||||
|
|
||||||
|
elem := results
|
||||||
|
if isSlice {
|
||||||
|
elem = reflect.New(resultType).Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields())
|
||||||
|
|
||||||
|
if isSlice {
|
||||||
|
if isPtr {
|
||||||
|
results.Set(reflect.Append(results, elem.Addr()))
|
||||||
|
} else {
|
||||||
|
results.Set(reflect.Append(results, elem))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if scope.db.RowsAffected == 0 && !isSlice {
|
||||||
|
scope.Err(ErrRecordNotFound)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// afterQueryCallback will invoke `AfterFind` method after querying
|
||||||
|
func afterQueryCallback(scope *Scope) {
|
||||||
|
if !scope.HasError() {
|
||||||
|
scope.CallMethod("AfterFind")
|
||||||
|
}
|
||||||
|
}
|
346
orm/callback_query_preload.go
Normal file
346
orm/callback_query_preload.go
Normal file
|
@ -0,0 +1,346 @@
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// preloadCallback used to preload associations
|
||||||
|
func preloadCallback(scope *Scope) {
|
||||||
|
if scope.Search.preload == nil || scope.HasError() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
preloadedMap = map[string]bool{}
|
||||||
|
fields = scope.Fields()
|
||||||
|
)
|
||||||
|
|
||||||
|
for _, preload := range scope.Search.preload {
|
||||||
|
var (
|
||||||
|
preloadFields = strings.Split(preload.schema, ".")
|
||||||
|
currentScope = scope
|
||||||
|
currentFields = fields
|
||||||
|
)
|
||||||
|
|
||||||
|
for idx, preloadField := range preloadFields {
|
||||||
|
var currentPreloadConditions []interface{}
|
||||||
|
|
||||||
|
if currentScope == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// if not preloaded
|
||||||
|
if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] {
|
||||||
|
|
||||||
|
// assign search conditions to last preload
|
||||||
|
if idx == len(preloadFields)-1 {
|
||||||
|
currentPreloadConditions = preload.conditions
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, field := range currentFields {
|
||||||
|
if field.Name != preloadField || field.Relationship == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch field.Relationship.Kind {
|
||||||
|
case "has_one":
|
||||||
|
currentScope.handleHasOnePreload(field, currentPreloadConditions)
|
||||||
|
case "has_many":
|
||||||
|
currentScope.handleHasManyPreload(field, currentPreloadConditions)
|
||||||
|
case "belongs_to":
|
||||||
|
currentScope.handleBelongsToPreload(field, currentPreloadConditions)
|
||||||
|
case "many_to_many":
|
||||||
|
currentScope.handleManyToManyPreload(field, currentPreloadConditions)
|
||||||
|
default:
|
||||||
|
scope.Err(errors.New("unsupported relation"))
|
||||||
|
}
|
||||||
|
|
||||||
|
preloadedMap[preloadKey] = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if !preloadedMap[preloadKey] {
|
||||||
|
scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// preload next level
|
||||||
|
if idx < len(preloadFields)-1 {
|
||||||
|
currentScope = currentScope.getColumnAsScope(preloadField)
|
||||||
|
if currentScope != nil {
|
||||||
|
currentFields = currentScope.Fields()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) {
|
||||||
|
var (
|
||||||
|
preloadDB = scope.NewDB()
|
||||||
|
preloadConditions []interface{}
|
||||||
|
)
|
||||||
|
|
||||||
|
for _, condition := range conditions {
|
||||||
|
if scopes, ok := condition.(func(*DB) *DB); ok {
|
||||||
|
preloadDB = scopes(preloadDB)
|
||||||
|
} else {
|
||||||
|
preloadConditions = append(preloadConditions, condition)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return preloadDB, preloadConditions
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleHasOnePreload used to preload has one associations
|
||||||
|
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
|
||||||
|
relation := field.Relationship
|
||||||
|
|
||||||
|
// get relations's primary keys
|
||||||
|
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
|
||||||
|
if len(primaryKeys) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// preload conditions
|
||||||
|
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
|
||||||
|
|
||||||
|
// find relations
|
||||||
|
query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
|
||||||
|
values := toQueryValues(primaryKeys)
|
||||||
|
if relation.PolymorphicType != "" {
|
||||||
|
query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
|
||||||
|
values = append(values, relation.PolymorphicValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
results := makeSlice(field.Struct.Type)
|
||||||
|
scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
|
||||||
|
|
||||||
|
// assign find results
|
||||||
|
var (
|
||||||
|
resultsValue = indirect(reflect.ValueOf(results))
|
||||||
|
indirectScopeValue = scope.IndirectValue()
|
||||||
|
)
|
||||||
|
|
||||||
|
if indirectScopeValue.Kind() == reflect.Slice {
|
||||||
|
for j := 0; j < indirectScopeValue.Len(); j++ {
|
||||||
|
for i := 0; i < resultsValue.Len(); i++ {
|
||||||
|
result := resultsValue.Index(i)
|
||||||
|
foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
|
||||||
|
if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) {
|
||||||
|
indirectValue.FieldByName(field.Name).Set(result)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for i := 0; i < resultsValue.Len(); i++ {
|
||||||
|
result := resultsValue.Index(i)
|
||||||
|
scope.Err(field.Set(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleHasManyPreload used to preload has many associations
|
||||||
|
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
|
||||||
|
relation := field.Relationship
|
||||||
|
|
||||||
|
// get relations's primary keys
|
||||||
|
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
|
||||||
|
if len(primaryKeys) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// preload conditions
|
||||||
|
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
|
||||||
|
|
||||||
|
// find relations
|
||||||
|
query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
|
||||||
|
values := toQueryValues(primaryKeys)
|
||||||
|
if relation.PolymorphicType != "" {
|
||||||
|
query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
|
||||||
|
values = append(values, relation.PolymorphicValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
results := makeSlice(field.Struct.Type)
|
||||||
|
scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
|
||||||
|
|
||||||
|
// assign find results
|
||||||
|
var (
|
||||||
|
resultsValue = indirect(reflect.ValueOf(results))
|
||||||
|
indirectScopeValue = scope.IndirectValue()
|
||||||
|
)
|
||||||
|
|
||||||
|
if indirectScopeValue.Kind() == reflect.Slice {
|
||||||
|
preloadMap := make(map[string][]reflect.Value)
|
||||||
|
for i := 0; i < resultsValue.Len(); i++ {
|
||||||
|
result := resultsValue.Index(i)
|
||||||
|
foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
|
||||||
|
preloadMap[toString(foreignValues)] = append(preloadMap[toString(foreignValues)], result)
|
||||||
|
}
|
||||||
|
|
||||||
|
for j := 0; j < indirectScopeValue.Len(); j++ {
|
||||||
|
object := indirect(indirectScopeValue.Index(j))
|
||||||
|
objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames)
|
||||||
|
f := object.FieldByName(field.Name)
|
||||||
|
if results, ok := preloadMap[toString(objectRealValue)]; ok {
|
||||||
|
f.Set(reflect.Append(f, results...))
|
||||||
|
} else {
|
||||||
|
f.Set(reflect.MakeSlice(f.Type(), 0, 0))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
scope.Err(field.Set(resultsValue))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleBelongsToPreload used to preload belongs to associations
|
||||||
|
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
|
||||||
|
relation := field.Relationship
|
||||||
|
|
||||||
|
// preload conditions
|
||||||
|
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
|
||||||
|
|
||||||
|
// get relations's primary keys
|
||||||
|
primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value)
|
||||||
|
if len(primaryKeys) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// find relations
|
||||||
|
results := makeSlice(field.Struct.Type)
|
||||||
|
scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
|
||||||
|
|
||||||
|
// assign find results
|
||||||
|
var (
|
||||||
|
resultsValue = indirect(reflect.ValueOf(results))
|
||||||
|
indirectScopeValue = scope.IndirectValue()
|
||||||
|
)
|
||||||
|
|
||||||
|
for i := 0; i < resultsValue.Len(); i++ {
|
||||||
|
result := resultsValue.Index(i)
|
||||||
|
if indirectScopeValue.Kind() == reflect.Slice {
|
||||||
|
value := getValueFromFields(result, relation.AssociationForeignFieldNames)
|
||||||
|
for j := 0; j < indirectScopeValue.Len(); j++ {
|
||||||
|
object := indirect(indirectScopeValue.Index(j))
|
||||||
|
if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) {
|
||||||
|
object.FieldByName(field.Name).Set(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
scope.Err(field.Set(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleManyToManyPreload used to preload many to many associations
|
||||||
|
func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) {
|
||||||
|
var (
|
||||||
|
relation = field.Relationship
|
||||||
|
joinTableHandler = relation.JoinTableHandler
|
||||||
|
fieldType = field.Struct.Type.Elem()
|
||||||
|
foreignKeyValue interface{}
|
||||||
|
foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type()
|
||||||
|
linkHash = map[string][]reflect.Value{}
|
||||||
|
isPtr bool
|
||||||
|
)
|
||||||
|
|
||||||
|
if fieldType.Kind() == reflect.Ptr {
|
||||||
|
isPtr = true
|
||||||
|
fieldType = fieldType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
var sourceKeys = []string{}
|
||||||
|
for _, key := range joinTableHandler.SourceForeignKeys() {
|
||||||
|
sourceKeys = append(sourceKeys, key.DBName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// preload conditions
|
||||||
|
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
|
||||||
|
|
||||||
|
// generate query with join table
|
||||||
|
newScope := scope.New(reflect.New(fieldType).Interface())
|
||||||
|
preloadDB = preloadDB.Table(newScope.TableName()).Model(newScope.Value).Select("*")
|
||||||
|
preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value)
|
||||||
|
|
||||||
|
// preload inline conditions
|
||||||
|
if len(preloadConditions) > 0 {
|
||||||
|
preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := preloadDB.Rows()
|
||||||
|
|
||||||
|
if scope.Err(err) != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
columns, _ := rows.Columns()
|
||||||
|
for rows.Next() {
|
||||||
|
var (
|
||||||
|
elem = reflect.New(fieldType).Elem()
|
||||||
|
fields = scope.New(elem.Addr().Interface()).Fields()
|
||||||
|
)
|
||||||
|
|
||||||
|
// register foreign keys in join tables
|
||||||
|
var joinTableFields []*Field
|
||||||
|
for _, sourceKey := range sourceKeys {
|
||||||
|
joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()})
|
||||||
|
}
|
||||||
|
|
||||||
|
scope.scan(rows, columns, append(fields, joinTableFields...))
|
||||||
|
|
||||||
|
var foreignKeys = make([]interface{}, len(sourceKeys))
|
||||||
|
// generate hashed forkey keys in join table
|
||||||
|
for idx, joinTableField := range joinTableFields {
|
||||||
|
if !joinTableField.Field.IsNil() {
|
||||||
|
foreignKeys[idx] = joinTableField.Field.Elem().Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
hashedSourceKeys := toString(foreignKeys)
|
||||||
|
|
||||||
|
if isPtr {
|
||||||
|
linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr())
|
||||||
|
} else {
|
||||||
|
linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// assign find results
|
||||||
|
var (
|
||||||
|
indirectScopeValue = scope.IndirectValue()
|
||||||
|
fieldsSourceMap = map[string][]reflect.Value{}
|
||||||
|
foreignFieldNames = []string{}
|
||||||
|
)
|
||||||
|
|
||||||
|
for _, dbName := range relation.ForeignFieldNames {
|
||||||
|
if field, ok := scope.FieldByName(dbName); ok {
|
||||||
|
foreignFieldNames = append(foreignFieldNames, field.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if indirectScopeValue.Kind() == reflect.Slice {
|
||||||
|
for j := 0; j < indirectScopeValue.Len(); j++ {
|
||||||
|
object := indirect(indirectScopeValue.Index(j))
|
||||||
|
key := toString(getValueFromFields(object, foreignFieldNames))
|
||||||
|
fieldsSourceMap[key] = append(fieldsSourceMap[key], object.FieldByName(field.Name))
|
||||||
|
}
|
||||||
|
} else if indirectScopeValue.IsValid() {
|
||||||
|
key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames))
|
||||||
|
fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name))
|
||||||
|
}
|
||||||
|
for source, link := range linkHash {
|
||||||
|
for i, field := range fieldsSourceMap[source] {
|
||||||
|
//If not 0 this means Value is a pointer and we already added preloaded models to it
|
||||||
|
if fieldsSourceMap[source][i].Len() != 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
field.Set(reflect.Append(fieldsSourceMap[source][i], link...))
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
99
orm/callback_save.go
Normal file
99
orm/callback_save.go
Normal file
|
@ -0,0 +1,99 @@
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import "reflect"
|
||||||
|
|
||||||
|
func beginTransactionCallback(scope *Scope) {
|
||||||
|
scope.Begin()
|
||||||
|
}
|
||||||
|
|
||||||
|
func commitOrRollbackTransactionCallback(scope *Scope) {
|
||||||
|
scope.CommitOrRollback()
|
||||||
|
}
|
||||||
|
|
||||||
|
func saveFieldAsAssociation(scope *Scope, field *Field) (bool, *Relationship) {
|
||||||
|
if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
|
||||||
|
if value, ok := field.TagSettings["SAVE_ASSOCIATIONS"]; !ok || (value != "false" && value != "skip") {
|
||||||
|
if relationship := field.Relationship; relationship != nil {
|
||||||
|
return true, relationship
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func saveBeforeAssociationsCallback(scope *Scope) {
|
||||||
|
if !scope.shouldSaveAssociations() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, field := range scope.Fields() {
|
||||||
|
if ok, relationship := saveFieldAsAssociation(scope, field); ok && relationship.Kind == "belongs_to" {
|
||||||
|
fieldValue := field.Field.Addr().Interface()
|
||||||
|
scope.Err(scope.NewDB().Save(fieldValue).Error)
|
||||||
|
if len(relationship.ForeignFieldNames) != 0 {
|
||||||
|
// set value's foreign key
|
||||||
|
for idx, fieldName := range relationship.ForeignFieldNames {
|
||||||
|
associationForeignName := relationship.AssociationForeignDBNames[idx]
|
||||||
|
if foreignField, ok := scope.New(fieldValue).FieldByName(associationForeignName); ok {
|
||||||
|
scope.Err(scope.SetColumn(fieldName, foreignField.Field.Interface()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func saveAfterAssociationsCallback(scope *Scope) {
|
||||||
|
if !scope.shouldSaveAssociations() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, field := range scope.Fields() {
|
||||||
|
if ok, relationship := saveFieldAsAssociation(scope, field); ok &&
|
||||||
|
(relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") {
|
||||||
|
value := field.Field
|
||||||
|
|
||||||
|
switch value.Kind() {
|
||||||
|
case reflect.Slice:
|
||||||
|
for i := 0; i < value.Len(); i++ {
|
||||||
|
newDB := scope.NewDB()
|
||||||
|
elem := value.Index(i).Addr().Interface()
|
||||||
|
newScope := newDB.NewScope(elem)
|
||||||
|
|
||||||
|
if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 {
|
||||||
|
for idx, fieldName := range relationship.ForeignFieldNames {
|
||||||
|
associationForeignName := relationship.AssociationForeignDBNames[idx]
|
||||||
|
if f, ok := scope.FieldByName(associationForeignName); ok {
|
||||||
|
scope.Err(newScope.SetColumn(fieldName, f.Field.Interface()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if relationship.PolymorphicType != "" {
|
||||||
|
scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue))
|
||||||
|
}
|
||||||
|
|
||||||
|
scope.Err(newDB.Save(elem).Error)
|
||||||
|
|
||||||
|
if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil {
|
||||||
|
scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
elem := value.Addr().Interface()
|
||||||
|
newScope := scope.New(elem)
|
||||||
|
if len(relationship.ForeignFieldNames) != 0 {
|
||||||
|
for idx, fieldName := range relationship.ForeignFieldNames {
|
||||||
|
associationForeignName := relationship.AssociationForeignDBNames[idx]
|
||||||
|
if f, ok := scope.FieldByName(associationForeignName); ok {
|
||||||
|
scope.Err(newScope.SetColumn(fieldName, f.Field.Interface()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if relationship.PolymorphicType != "" {
|
||||||
|
scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue))
|
||||||
|
}
|
||||||
|
scope.Err(scope.NewDB().Save(elem).Error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
112
orm/callback_system_test.go
Normal file
112
orm/callback_system_test.go
Normal file
|
@ -0,0 +1,112 @@
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func equalFuncs(funcs []*func(s *Scope), fnames []string) bool {
|
||||||
|
var names []string
|
||||||
|
for _, f := range funcs {
|
||||||
|
fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".")
|
||||||
|
names = append(names, fnames[len(fnames)-1])
|
||||||
|
}
|
||||||
|
return reflect.DeepEqual(names, fnames)
|
||||||
|
}
|
||||||
|
|
||||||
|
func create(s *Scope) {}
|
||||||
|
func beforeCreate1(s *Scope) {}
|
||||||
|
func beforeCreate2(s *Scope) {}
|
||||||
|
func afterCreate1(s *Scope) {}
|
||||||
|
func afterCreate2(s *Scope) {}
|
||||||
|
|
||||||
|
func TestRegisterCallback(t *testing.T) {
|
||||||
|
var callback = &Callback{}
|
||||||
|
|
||||||
|
callback.Create().Register("before_create1", beforeCreate1)
|
||||||
|
callback.Create().Register("before_create2", beforeCreate2)
|
||||||
|
callback.Create().Register("create", create)
|
||||||
|
callback.Create().Register("after_create1", afterCreate1)
|
||||||
|
callback.Create().Register("after_create2", afterCreate2)
|
||||||
|
|
||||||
|
if !equalFuncs(callback.creates, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
|
||||||
|
t.Errorf("register callback")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegisterCallbackWithOrder(t *testing.T) {
|
||||||
|
var callback1 = &Callback{}
|
||||||
|
callback1.Create().Register("before_create1", beforeCreate1)
|
||||||
|
callback1.Create().Register("create", create)
|
||||||
|
callback1.Create().Register("after_create1", afterCreate1)
|
||||||
|
callback1.Create().Before("after_create1").Register("after_create2", afterCreate2)
|
||||||
|
if !equalFuncs(callback1.creates, []string{"beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
|
||||||
|
t.Errorf("register callback with order")
|
||||||
|
}
|
||||||
|
|
||||||
|
var callback2 = &Callback{}
|
||||||
|
|
||||||
|
callback2.Update().Register("create", create)
|
||||||
|
callback2.Update().Before("create").Register("before_create1", beforeCreate1)
|
||||||
|
callback2.Update().After("after_create2").Register("after_create1", afterCreate1)
|
||||||
|
callback2.Update().Before("before_create1").Register("before_create2", beforeCreate2)
|
||||||
|
callback2.Update().Register("after_create2", afterCreate2)
|
||||||
|
|
||||||
|
if !equalFuncs(callback2.updates, []string{"beforeCreate2", "beforeCreate1", "create", "afterCreate2", "afterCreate1"}) {
|
||||||
|
t.Errorf("register callback with order")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegisterCallbackWithComplexOrder(t *testing.T) {
|
||||||
|
var callback1 = &Callback{}
|
||||||
|
|
||||||
|
callback1.Query().Before("after_create1").After("before_create1").Register("create", create)
|
||||||
|
callback1.Query().Register("before_create1", beforeCreate1)
|
||||||
|
callback1.Query().Register("after_create1", afterCreate1)
|
||||||
|
|
||||||
|
if !equalFuncs(callback1.queries, []string{"beforeCreate1", "create", "afterCreate1"}) {
|
||||||
|
t.Errorf("register callback with order")
|
||||||
|
}
|
||||||
|
|
||||||
|
var callback2 = &Callback{}
|
||||||
|
|
||||||
|
callback2.Delete().Before("after_create1").After("before_create1").Register("create", create)
|
||||||
|
callback2.Delete().Before("create").Register("before_create1", beforeCreate1)
|
||||||
|
callback2.Delete().After("before_create1").Register("before_create2", beforeCreate2)
|
||||||
|
callback2.Delete().Register("after_create1", afterCreate1)
|
||||||
|
callback2.Delete().After("after_create1").Register("after_create2", afterCreate2)
|
||||||
|
|
||||||
|
if !equalFuncs(callback2.deletes, []string{"beforeCreate1", "beforeCreate2", "create", "afterCreate1", "afterCreate2"}) {
|
||||||
|
t.Errorf("register callback with order")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func replaceCreate(s *Scope) {}
|
||||||
|
|
||||||
|
func TestReplaceCallback(t *testing.T) {
|
||||||
|
var callback = &Callback{}
|
||||||
|
|
||||||
|
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
|
||||||
|
callback.Create().Register("before_create1", beforeCreate1)
|
||||||
|
callback.Create().Register("after_create1", afterCreate1)
|
||||||
|
callback.Create().Replace("create", replaceCreate)
|
||||||
|
|
||||||
|
if !equalFuncs(callback.creates, []string{"beforeCreate1", "replaceCreate", "afterCreate1"}) {
|
||||||
|
t.Errorf("replace callback")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRemoveCallback(t *testing.T) {
|
||||||
|
var callback = &Callback{}
|
||||||
|
|
||||||
|
callback.Create().Before("after_create1").After("before_create1").Register("create", create)
|
||||||
|
callback.Create().Register("before_create1", beforeCreate1)
|
||||||
|
callback.Create().Register("after_create1", afterCreate1)
|
||||||
|
callback.Create().Remove("create")
|
||||||
|
|
||||||
|
if !equalFuncs(callback.creates, []string{"beforeCreate1", "afterCreate1"}) {
|
||||||
|
t.Errorf("remove callback")
|
||||||
|
}
|
||||||
|
}
|
104
orm/callback_update.go
Normal file
104
orm/callback_update.go
Normal file
|
@ -0,0 +1,104 @@
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Define callbacks for updating
|
||||||
|
func init() {
|
||||||
|
DefaultCallback.Update().Register("gorm:assign_updating_attributes", assignUpdatingAttributesCallback)
|
||||||
|
DefaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback)
|
||||||
|
DefaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback)
|
||||||
|
DefaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
|
||||||
|
DefaultCallback.Update().Register("gorm:update_time_stamp", updateTimeStampForUpdateCallback)
|
||||||
|
DefaultCallback.Update().Register("gorm:update", updateCallback)
|
||||||
|
DefaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
|
||||||
|
DefaultCallback.Update().Register("gorm:after_update", afterUpdateCallback)
|
||||||
|
DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
|
||||||
|
}
|
||||||
|
|
||||||
|
// assignUpdatingAttributesCallback assign updating attributes to model
|
||||||
|
func assignUpdatingAttributesCallback(scope *Scope) {
|
||||||
|
if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
|
||||||
|
if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate {
|
||||||
|
scope.InstanceSet("gorm:update_attrs", updateMaps)
|
||||||
|
} else {
|
||||||
|
scope.SkipLeft()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
|
||||||
|
func beforeUpdateCallback(scope *Scope) {
|
||||||
|
if _, ok := scope.Get("gorm:update_column"); !ok {
|
||||||
|
if !scope.HasError() {
|
||||||
|
scope.CallMethod("BeforeSave")
|
||||||
|
}
|
||||||
|
if !scope.HasError() {
|
||||||
|
scope.CallMethod("BeforeUpdate")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateTimeStampForUpdateCallback will set `UpdatedAt` when updating
|
||||||
|
func updateTimeStampForUpdateCallback(scope *Scope) {
|
||||||
|
if _, ok := scope.Get("gorm:update_column"); !ok {
|
||||||
|
scope.SetColumn("UpdatedAt", NowFunc())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateCallback the callback used to update data to database
|
||||||
|
func updateCallback(scope *Scope) {
|
||||||
|
if !scope.HasError() {
|
||||||
|
var sqls []string
|
||||||
|
|
||||||
|
if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
|
||||||
|
for column, value := range updateAttrs.(map[string]interface{}) {
|
||||||
|
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value)))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for _, field := range scope.Fields() {
|
||||||
|
if scope.changeableField(field) {
|
||||||
|
if !field.IsPrimaryKey && field.IsNormal {
|
||||||
|
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
||||||
|
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
|
||||||
|
for _, foreignKey := range relationship.ForeignDBNames {
|
||||||
|
if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
|
||||||
|
sqls = append(sqls,
|
||||||
|
fmt.Sprintf("%v = %v", scope.Quote(foreignField.DBName), scope.AddToVars(foreignField.Field.Interface())))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var extraOption string
|
||||||
|
if str, ok := scope.Get("gorm:update_option"); ok {
|
||||||
|
extraOption = fmt.Sprint(str)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(sqls) > 0 {
|
||||||
|
scope.Raw(fmt.Sprintf(
|
||||||
|
"UPDATE %v SET %v%v%v",
|
||||||
|
scope.QuotedTableName(),
|
||||||
|
strings.Join(sqls, ", "),
|
||||||
|
addExtraSpaceIfExist(scope.CombinedConditionSql()),
|
||||||
|
addExtraSpaceIfExist(extraOption),
|
||||||
|
)).Exec()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating
|
||||||
|
func afterUpdateCallback(scope *Scope) {
|
||||||
|
if _, ok := scope.Get("gorm:update_column"); !ok {
|
||||||
|
if !scope.HasError() {
|
||||||
|
scope.CallMethod("AfterUpdate")
|
||||||
|
}
|
||||||
|
if !scope.HasError() {
|
||||||
|
scope.CallMethod("AfterSave")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
177
orm/callbacks_test.go
Normal file
177
orm/callbacks_test.go
Normal file
|
@ -0,0 +1,177 @@
|
||||||
|
package orm_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Product) BeforeCreate() (err error) {
|
||||||
|
if s.Code == "Invalid" {
|
||||||
|
err = errors.New("invalid product")
|
||||||
|
}
|
||||||
|
s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Product) BeforeUpdate() (err error) {
|
||||||
|
if s.Code == "dont_update" {
|
||||||
|
err = errors.New("can't update")
|
||||||
|
}
|
||||||
|
s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Product) BeforeSave() (err error) {
|
||||||
|
if s.Code == "dont_save" {
|
||||||
|
err = errors.New("can't save")
|
||||||
|
}
|
||||||
|
s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Product) AfterFind() {
|
||||||
|
s.AfterFindCallTimes = s.AfterFindCallTimes + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Product) AfterCreate(tx *gorm.DB) {
|
||||||
|
tx.Model(s).UpdateColumn(Product{AfterCreateCallTimes: s.AfterCreateCallTimes + 1})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Product) AfterUpdate() {
|
||||||
|
s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Product) AfterSave() (err error) {
|
||||||
|
if s.Code == "after_save_error" {
|
||||||
|
err = errors.New("can't save")
|
||||||
|
}
|
||||||
|
s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Product) BeforeDelete() (err error) {
|
||||||
|
if s.Code == "dont_delete" {
|
||||||
|
err = errors.New("can't delete")
|
||||||
|
}
|
||||||
|
s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Product) AfterDelete() (err error) {
|
||||||
|
if s.Code == "after_delete_error" {
|
||||||
|
err = errors.New("can't delete")
|
||||||
|
}
|
||||||
|
s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Product) GetCallTimes() []int64 {
|
||||||
|
return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunCallbacks(t *testing.T) {
|
||||||
|
p := Product{Code: "unique_code", Price: 100}
|
||||||
|
DB.Save(&p)
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) {
|
||||||
|
t.Errorf("Callbacks should be invoked successfully, %v", p.GetCallTimes())
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Where("Code = ?", "unique_code").First(&p)
|
||||||
|
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) {
|
||||||
|
t.Errorf("After callbacks values are not saved, %v", p.GetCallTimes())
|
||||||
|
}
|
||||||
|
|
||||||
|
p.Price = 200
|
||||||
|
DB.Save(&p)
|
||||||
|
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) {
|
||||||
|
t.Errorf("After update callbacks should be invoked successfully, %v", p.GetCallTimes())
|
||||||
|
}
|
||||||
|
|
||||||
|
var products []Product
|
||||||
|
DB.Find(&products, "code = ?", "unique_code")
|
||||||
|
if products[0].AfterFindCallTimes != 2 {
|
||||||
|
t.Errorf("AfterFind callbacks should work with slice")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Where("Code = ?", "unique_code").First(&p)
|
||||||
|
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) {
|
||||||
|
t.Errorf("After update callbacks values are not saved, %v", p.GetCallTimes())
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Delete(&p)
|
||||||
|
if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) {
|
||||||
|
t.Errorf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes())
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Where("Code = ?", "unique_code").First(&p).Error == nil {
|
||||||
|
t.Errorf("Can't find a deleted record")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCallbacksWithErrors(t *testing.T) {
|
||||||
|
p := Product{Code: "Invalid", Price: 100}
|
||||||
|
if DB.Save(&p).Error == nil {
|
||||||
|
t.Errorf("An error from before create callbacks happened when create with invalid value")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Where("code = ?", "Invalid").First(&Product{}).Error == nil {
|
||||||
|
t.Errorf("Should not save record that have errors")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Save(&Product{Code: "dont_save", Price: 100}).Error == nil {
|
||||||
|
t.Errorf("An error from after create callbacks happened when create with invalid value")
|
||||||
|
}
|
||||||
|
|
||||||
|
p2 := Product{Code: "update_callback", Price: 100}
|
||||||
|
DB.Save(&p2)
|
||||||
|
|
||||||
|
p2.Code = "dont_update"
|
||||||
|
if DB.Save(&p2).Error == nil {
|
||||||
|
t.Errorf("An error from before update callbacks happened when update with invalid value")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Where("code = ?", "update_callback").First(&Product{}).Error != nil {
|
||||||
|
t.Errorf("Record Should not be updated due to errors happened in before update callback")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Where("code = ?", "dont_update").First(&Product{}).Error == nil {
|
||||||
|
t.Errorf("Record Should not be updated due to errors happened in before update callback")
|
||||||
|
}
|
||||||
|
|
||||||
|
p2.Code = "dont_save"
|
||||||
|
if DB.Save(&p2).Error == nil {
|
||||||
|
t.Errorf("An error from before save callbacks happened when update with invalid value")
|
||||||
|
}
|
||||||
|
|
||||||
|
p3 := Product{Code: "dont_delete", Price: 100}
|
||||||
|
DB.Save(&p3)
|
||||||
|
if DB.Delete(&p3).Error == nil {
|
||||||
|
t.Errorf("An error from before delete callbacks happened when delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil {
|
||||||
|
t.Errorf("An error from before delete callbacks happened")
|
||||||
|
}
|
||||||
|
|
||||||
|
p4 := Product{Code: "after_save_error", Price: 100}
|
||||||
|
DB.Save(&p4)
|
||||||
|
if err := DB.First(&Product{}, "code = ?", "after_save_error").Error; err == nil {
|
||||||
|
t.Errorf("Record should be reverted if get an error in after save callback")
|
||||||
|
}
|
||||||
|
|
||||||
|
p5 := Product{Code: "after_delete_error", Price: 100}
|
||||||
|
DB.Save(&p5)
|
||||||
|
if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
|
||||||
|
t.Errorf("Record should be found")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Delete(&p5)
|
||||||
|
if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil {
|
||||||
|
t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback")
|
||||||
|
}
|
||||||
|
}
|
180
orm/create_test.go
Normal file
180
orm/create_test.go
Normal file
|
@ -0,0 +1,180 @@
|
||||||
|
package orm_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCreate(t *testing.T) {
|
||||||
|
float := 35.03554004971999
|
||||||
|
now := time.Now()
|
||||||
|
user := User{Name: "CreateUser", Age: 18, Birthday: &now, UserNum: Num(111), PasswordHash: []byte{'f', 'a', 'k', '4'}, Latitude: float}
|
||||||
|
|
||||||
|
if !DB.NewRecord(user) || !DB.NewRecord(&user) {
|
||||||
|
t.Error("User should be new record before create")
|
||||||
|
}
|
||||||
|
|
||||||
|
if count := DB.Save(&user).RowsAffected; count != 1 {
|
||||||
|
t.Error("There should be one record be affected when create record")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.NewRecord(user) || DB.NewRecord(&user) {
|
||||||
|
t.Error("User should not new record after save")
|
||||||
|
}
|
||||||
|
|
||||||
|
var newUser User
|
||||||
|
DB.First(&newUser, user.Id)
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(newUser.PasswordHash, []byte{'f', 'a', 'k', '4'}) {
|
||||||
|
t.Errorf("User's PasswordHash should be saved ([]byte)")
|
||||||
|
}
|
||||||
|
|
||||||
|
if newUser.Age != 18 {
|
||||||
|
t.Errorf("User's Age should be saved (int)")
|
||||||
|
}
|
||||||
|
|
||||||
|
if newUser.UserNum != Num(111) {
|
||||||
|
t.Errorf("User's UserNum should be saved (custom type)")
|
||||||
|
}
|
||||||
|
|
||||||
|
if newUser.Latitude != float {
|
||||||
|
t.Errorf("Float64 should not be changed after save")
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.CreatedAt.IsZero() {
|
||||||
|
t.Errorf("Should have created_at after create")
|
||||||
|
}
|
||||||
|
|
||||||
|
if newUser.CreatedAt.IsZero() {
|
||||||
|
t.Errorf("Should have created_at after create")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(user).Update("name", "create_user_new_name")
|
||||||
|
DB.First(&user, user.Id)
|
||||||
|
if user.CreatedAt.Format(time.RFC3339Nano) != newUser.CreatedAt.Format(time.RFC3339Nano) {
|
||||||
|
t.Errorf("CreatedAt should not be changed after update")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateWithAutoIncrement(t *testing.T) {
|
||||||
|
if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" {
|
||||||
|
t.Skip("Skipping this because only postgres properly support auto_increment on a non-primary_key column")
|
||||||
|
}
|
||||||
|
user1 := User{}
|
||||||
|
user2 := User{}
|
||||||
|
|
||||||
|
DB.Create(&user1)
|
||||||
|
DB.Create(&user2)
|
||||||
|
|
||||||
|
if user2.Sequence-user1.Sequence != 1 {
|
||||||
|
t.Errorf("Auto increment should apply on Sequence")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateWithNoGORMPrimayKey(t *testing.T) {
|
||||||
|
if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" {
|
||||||
|
t.Skip("Skipping this because MSSQL will return identity only if the table has an Id column")
|
||||||
|
}
|
||||||
|
|
||||||
|
jt := JoinTable{From: 1, To: 2}
|
||||||
|
err := DB.Create(&jt).Error
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("No error should happen when create a record without a GORM primary key. But in the database this primary key exists and is the union of 2 or more fields\n But got: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) {
|
||||||
|
animal := Animal{Name: "Ferdinand"}
|
||||||
|
if DB.Save(&animal).Error != nil {
|
||||||
|
t.Errorf("No error should happen when create a record without std primary key")
|
||||||
|
}
|
||||||
|
|
||||||
|
if animal.Counter == 0 {
|
||||||
|
t.Errorf("No std primary key should be filled value after create")
|
||||||
|
}
|
||||||
|
|
||||||
|
if animal.Name != "Ferdinand" {
|
||||||
|
t.Errorf("Default value should be overrided")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test create with default value not overrided
|
||||||
|
an := Animal{From: "nerdz"}
|
||||||
|
|
||||||
|
if DB.Save(&an).Error != nil {
|
||||||
|
t.Errorf("No error should happen when create an record without std primary key")
|
||||||
|
}
|
||||||
|
|
||||||
|
// We must fetch the value again, to have the default fields updated
|
||||||
|
// (We can't do this in the update statements, since sql default can be expressions
|
||||||
|
// And be different from the fields' type (eg. a time.Time fields has a default value of "now()"
|
||||||
|
DB.Model(Animal{}).Where(&Animal{Counter: an.Counter}).First(&an)
|
||||||
|
|
||||||
|
if an.Name != "galeone" {
|
||||||
|
t.Errorf("Default value should fill the field. But got %v", an.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAnonymousScanner(t *testing.T) {
|
||||||
|
user := User{Name: "anonymous_scanner", Role: Role{Name: "admin"}}
|
||||||
|
DB.Save(&user)
|
||||||
|
|
||||||
|
var user2 User
|
||||||
|
DB.First(&user2, "name = ?", "anonymous_scanner")
|
||||||
|
if user2.Role.Name != "admin" {
|
||||||
|
t.Errorf("Should be able to get anonymous scanner")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !user2.IsAdmin() {
|
||||||
|
t.Errorf("Should be able to get anonymous scanner")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAnonymousField(t *testing.T) {
|
||||||
|
user := User{Name: "anonymous_field", Company: Company{Name: "company"}}
|
||||||
|
DB.Save(&user)
|
||||||
|
|
||||||
|
var user2 User
|
||||||
|
DB.First(&user2, "name = ?", "anonymous_field")
|
||||||
|
DB.Model(&user2).Related(&user2.Company)
|
||||||
|
if user2.Company.Name != "company" {
|
||||||
|
t.Errorf("Should be able to get anonymous field")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectWithCreate(t *testing.T) {
|
||||||
|
user := getPreparedUser("select_user", "select_with_create")
|
||||||
|
DB.Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(user)
|
||||||
|
|
||||||
|
var queryuser User
|
||||||
|
DB.Preload("BillingAddress").Preload("ShippingAddress").
|
||||||
|
Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryuser, user.Id)
|
||||||
|
|
||||||
|
if queryuser.Name != user.Name || queryuser.Age == user.Age {
|
||||||
|
t.Errorf("Should only create users with name column")
|
||||||
|
}
|
||||||
|
|
||||||
|
if queryuser.BillingAddressID.Int64 == 0 || queryuser.ShippingAddressId != 0 ||
|
||||||
|
queryuser.CreditCard.ID == 0 || len(queryuser.Emails) == 0 {
|
||||||
|
t.Errorf("Should only create selected relationships")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOmitWithCreate(t *testing.T) {
|
||||||
|
user := getPreparedUser("omit_user", "omit_with_create")
|
||||||
|
DB.Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Create(user)
|
||||||
|
|
||||||
|
var queryuser User
|
||||||
|
DB.Preload("BillingAddress").Preload("ShippingAddress").
|
||||||
|
Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryuser, user.Id)
|
||||||
|
|
||||||
|
if queryuser.Name == user.Name || queryuser.Age != user.Age {
|
||||||
|
t.Errorf("Should only create users with age column")
|
||||||
|
}
|
||||||
|
|
||||||
|
if queryuser.BillingAddressID.Int64 != 0 || queryuser.ShippingAddressId == 0 ||
|
||||||
|
queryuser.CreditCard.ID != 0 || len(queryuser.Emails) != 0 {
|
||||||
|
t.Errorf("Should not create omited relationships")
|
||||||
|
}
|
||||||
|
}
|
281
orm/customize_column_test.go
Normal file
281
orm/customize_column_test.go
Normal file
|
@ -0,0 +1,281 @@
|
||||||
|
package orm_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CustomizeColumn struct {
|
||||||
|
ID int64 `gorm:"column:mapped_id; primary_key:yes"`
|
||||||
|
Name string `gorm:"column:mapped_name"`
|
||||||
|
Date *time.Time `gorm:"column:mapped_time"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure an ignored field does not interfere with another field's custom
|
||||||
|
// column name that matches the ignored field.
|
||||||
|
type CustomColumnAndIgnoredFieldClash struct {
|
||||||
|
Body string `sql:"-"`
|
||||||
|
RawBody string `gorm:"column:body"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCustomizeColumn(t *testing.T) {
|
||||||
|
col := "mapped_name"
|
||||||
|
DB.DropTable(&CustomizeColumn{})
|
||||||
|
DB.AutoMigrate(&CustomizeColumn{})
|
||||||
|
|
||||||
|
scope := DB.NewScope(&CustomizeColumn{})
|
||||||
|
if !scope.Dialect().HasColumn(scope.TableName(), col) {
|
||||||
|
t.Errorf("CustomizeColumn should have column %s", col)
|
||||||
|
}
|
||||||
|
|
||||||
|
col = "mapped_id"
|
||||||
|
if scope.PrimaryKey() != col {
|
||||||
|
t.Errorf("CustomizeColumn should have primary key %s, but got %q", col, scope.PrimaryKey())
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := "foo"
|
||||||
|
now := time.Now()
|
||||||
|
cc := CustomizeColumn{ID: 666, Name: expected, Date: &now}
|
||||||
|
|
||||||
|
if count := DB.Create(&cc).RowsAffected; count != 1 {
|
||||||
|
t.Error("There should be one record be affected when create record")
|
||||||
|
}
|
||||||
|
|
||||||
|
var cc1 CustomizeColumn
|
||||||
|
DB.First(&cc1, 666)
|
||||||
|
|
||||||
|
if cc1.Name != expected {
|
||||||
|
t.Errorf("Failed to query CustomizeColumn")
|
||||||
|
}
|
||||||
|
|
||||||
|
cc.Name = "bar"
|
||||||
|
DB.Save(&cc)
|
||||||
|
|
||||||
|
var cc2 CustomizeColumn
|
||||||
|
DB.First(&cc2, 666)
|
||||||
|
if cc2.Name != "bar" {
|
||||||
|
t.Errorf("Failed to query CustomizeColumn")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCustomColumnAndIgnoredFieldClash(t *testing.T) {
|
||||||
|
DB.DropTable(&CustomColumnAndIgnoredFieldClash{})
|
||||||
|
if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}).Error; err != nil {
|
||||||
|
t.Errorf("Should not raise error: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type CustomizePerson struct {
|
||||||
|
IdPerson string `gorm:"column:idPerson;primary_key:true"`
|
||||||
|
Accounts []CustomizeAccount `gorm:"many2many:PersonAccount;associationforeignkey:idAccount;foreignkey:idPerson"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CustomizeAccount struct {
|
||||||
|
IdAccount string `gorm:"column:idAccount;primary_key:true"`
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManyToManyWithCustomizedColumn(t *testing.T) {
|
||||||
|
DB.DropTable(&CustomizePerson{}, &CustomizeAccount{}, "PersonAccount")
|
||||||
|
DB.AutoMigrate(&CustomizePerson{}, &CustomizeAccount{})
|
||||||
|
|
||||||
|
account := CustomizeAccount{IdAccount: "account", Name: "id1"}
|
||||||
|
person := CustomizePerson{
|
||||||
|
IdPerson: "person",
|
||||||
|
Accounts: []CustomizeAccount{account},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Create(&account).Error; err != nil {
|
||||||
|
t.Errorf("no error should happen, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Create(&person).Error; err != nil {
|
||||||
|
t.Errorf("no error should happen, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var person1 CustomizePerson
|
||||||
|
scope := DB.NewScope(nil)
|
||||||
|
if err := DB.Preload("Accounts").First(&person1, scope.Quote("idPerson")+" = ?", person.IdPerson).Error; err != nil {
|
||||||
|
t.Errorf("no error should happen when preloading customized column many2many relations, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(person1.Accounts) != 1 || person1.Accounts[0].IdAccount != "account" {
|
||||||
|
t.Errorf("should preload correct accounts")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type CustomizeUser struct {
|
||||||
|
gorm.Model
|
||||||
|
Email string `sql:"column:email_address"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CustomizeInvitation struct {
|
||||||
|
gorm.Model
|
||||||
|
Address string `sql:"column:invitation"`
|
||||||
|
Person *CustomizeUser `gorm:"foreignkey:Email;associationforeignkey:invitation"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOneToOneWithCustomizedColumn(t *testing.T) {
|
||||||
|
DB.DropTable(&CustomizeUser{}, &CustomizeInvitation{})
|
||||||
|
DB.AutoMigrate(&CustomizeUser{}, &CustomizeInvitation{})
|
||||||
|
|
||||||
|
user := CustomizeUser{
|
||||||
|
Email: "hello@example.com",
|
||||||
|
}
|
||||||
|
invitation := CustomizeInvitation{
|
||||||
|
Address: "hello@example.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Create(&user)
|
||||||
|
DB.Create(&invitation)
|
||||||
|
|
||||||
|
var invitation2 CustomizeInvitation
|
||||||
|
if err := DB.Preload("Person").Find(&invitation2, invitation.ID).Error; err != nil {
|
||||||
|
t.Errorf("no error should happen, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if invitation2.Person.Email != user.Email {
|
||||||
|
t.Errorf("Should preload one to one relation with customize foreign keys")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type PromotionDiscount struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
Coupons []*PromotionCoupon `gorm:"ForeignKey:discount_id"`
|
||||||
|
Rule *PromotionRule `gorm:"ForeignKey:discount_id"`
|
||||||
|
Benefits []PromotionBenefit `gorm:"ForeignKey:promotion_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type PromotionBenefit struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
PromotionID uint
|
||||||
|
Discount PromotionDiscount `gorm:"ForeignKey:promotion_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type PromotionCoupon struct {
|
||||||
|
gorm.Model
|
||||||
|
Code string
|
||||||
|
DiscountID uint
|
||||||
|
Discount PromotionDiscount
|
||||||
|
}
|
||||||
|
|
||||||
|
type PromotionRule struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
Begin *time.Time
|
||||||
|
End *time.Time
|
||||||
|
DiscountID uint
|
||||||
|
Discount *PromotionDiscount
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOneToManyWithCustomizedColumn(t *testing.T) {
|
||||||
|
DB.DropTable(&PromotionDiscount{}, &PromotionCoupon{})
|
||||||
|
DB.AutoMigrate(&PromotionDiscount{}, &PromotionCoupon{})
|
||||||
|
|
||||||
|
discount := PromotionDiscount{
|
||||||
|
Name: "Happy New Year",
|
||||||
|
Coupons: []*PromotionCoupon{
|
||||||
|
{Code: "newyear1"},
|
||||||
|
{Code: "newyear2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Create(&discount).Error; err != nil {
|
||||||
|
t.Errorf("no error should happen but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var discount1 PromotionDiscount
|
||||||
|
if err := DB.Preload("Coupons").First(&discount1, "id = ?", discount.ID).Error; err != nil {
|
||||||
|
t.Errorf("no error should happen but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(discount.Coupons) != 2 {
|
||||||
|
t.Errorf("should find two coupons")
|
||||||
|
}
|
||||||
|
|
||||||
|
var coupon PromotionCoupon
|
||||||
|
if err := DB.Preload("Discount").First(&coupon, "code = ?", "newyear1").Error; err != nil {
|
||||||
|
t.Errorf("no error should happen but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if coupon.Discount.Name != "Happy New Year" {
|
||||||
|
t.Errorf("should preload discount from coupon")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasOneWithPartialCustomizedColumn(t *testing.T) {
|
||||||
|
DB.DropTable(&PromotionDiscount{}, &PromotionRule{})
|
||||||
|
DB.AutoMigrate(&PromotionDiscount{}, &PromotionRule{})
|
||||||
|
|
||||||
|
var begin = time.Now()
|
||||||
|
var end = time.Now().Add(24 * time.Hour)
|
||||||
|
discount := PromotionDiscount{
|
||||||
|
Name: "Happy New Year 2",
|
||||||
|
Rule: &PromotionRule{
|
||||||
|
Name: "time_limited",
|
||||||
|
Begin: &begin,
|
||||||
|
End: &end,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Create(&discount).Error; err != nil {
|
||||||
|
t.Errorf("no error should happen but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var discount1 PromotionDiscount
|
||||||
|
if err := DB.Preload("Rule").First(&discount1, "id = ?", discount.ID).Error; err != nil {
|
||||||
|
t.Errorf("no error should happen but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if discount.Rule.Begin.Format(time.RFC3339Nano) != begin.Format(time.RFC3339Nano) {
|
||||||
|
t.Errorf("Should be able to preload Rule")
|
||||||
|
}
|
||||||
|
|
||||||
|
var rule PromotionRule
|
||||||
|
if err := DB.Preload("Discount").First(&rule, "name = ?", "time_limited").Error; err != nil {
|
||||||
|
t.Errorf("no error should happen but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rule.Discount.Name != "Happy New Year 2" {
|
||||||
|
t.Errorf("should preload discount from rule")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBelongsToWithPartialCustomizedColumn(t *testing.T) {
|
||||||
|
DB.DropTable(&PromotionDiscount{}, &PromotionBenefit{})
|
||||||
|
DB.AutoMigrate(&PromotionDiscount{}, &PromotionBenefit{})
|
||||||
|
|
||||||
|
discount := PromotionDiscount{
|
||||||
|
Name: "Happy New Year 3",
|
||||||
|
Benefits: []PromotionBenefit{
|
||||||
|
{Name: "free cod"},
|
||||||
|
{Name: "free shipping"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Create(&discount).Error; err != nil {
|
||||||
|
t.Errorf("no error should happen but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var discount1 PromotionDiscount
|
||||||
|
if err := DB.Preload("Benefits").First(&discount1, "id = ?", discount.ID).Error; err != nil {
|
||||||
|
t.Errorf("no error should happen but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(discount.Benefits) != 2 {
|
||||||
|
t.Errorf("should find two benefits")
|
||||||
|
}
|
||||||
|
|
||||||
|
var benefit PromotionBenefit
|
||||||
|
if err := DB.Preload("Discount").First(&benefit, "name = ?", "free cod").Error; err != nil {
|
||||||
|
t.Errorf("no error should happen but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if benefit.Discount.Name != "Happy New Year 3" {
|
||||||
|
t.Errorf("should preload discount from coupon")
|
||||||
|
}
|
||||||
|
}
|
68
orm/delete_test.go
Normal file
68
orm/delete_test.go
Normal file
|
@ -0,0 +1,68 @@
|
||||||
|
package orm_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDelete(t *testing.T) {
|
||||||
|
user1, user2 := User{Name: "delete1"}, User{Name: "delete2"}
|
||||||
|
DB.Save(&user1)
|
||||||
|
DB.Save(&user2)
|
||||||
|
|
||||||
|
if err := DB.Delete(&user1).Error; err != nil {
|
||||||
|
t.Errorf("No error should happen when delete a record, err=%s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() {
|
||||||
|
t.Errorf("User can't be found after delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() {
|
||||||
|
t.Errorf("Other users that not deleted should be found-able")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInlineDelete(t *testing.T) {
|
||||||
|
user1, user2 := User{Name: "inline_delete1"}, User{Name: "inline_delete2"}
|
||||||
|
DB.Save(&user1)
|
||||||
|
DB.Save(&user2)
|
||||||
|
|
||||||
|
if DB.Delete(&User{}, user1.Id).Error != nil {
|
||||||
|
t.Errorf("No error should happen when delete a record")
|
||||||
|
} else if !DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() {
|
||||||
|
t.Errorf("User can't be found after delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Delete(&User{}, "name = ?", user2.Name).Error; err != nil {
|
||||||
|
t.Errorf("No error should happen when delete a record, err=%s", err)
|
||||||
|
} else if !DB.Where("name = ?", user2.Name).First(&User{}).RecordNotFound() {
|
||||||
|
t.Errorf("User can't be found after delete")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSoftDelete(t *testing.T) {
|
||||||
|
type User struct {
|
||||||
|
Id int64
|
||||||
|
Name string
|
||||||
|
DeletedAt *time.Time
|
||||||
|
}
|
||||||
|
DB.AutoMigrate(&User{})
|
||||||
|
|
||||||
|
user := User{Name: "soft_delete"}
|
||||||
|
DB.Save(&user)
|
||||||
|
DB.Delete(&user)
|
||||||
|
|
||||||
|
if DB.First(&User{}, "name = ?", user.Name).Error == nil {
|
||||||
|
t.Errorf("Can't find a soft deleted record")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; err != nil {
|
||||||
|
t.Errorf("Should be able to find soft deleted record with Unscoped, but err=%s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Unscoped().Delete(&user)
|
||||||
|
if !DB.Unscoped().First(&User{}, "name = ?", user.Name).RecordNotFound() {
|
||||||
|
t.Errorf("Can't find permanently deleted record")
|
||||||
|
}
|
||||||
|
}
|
106
orm/dialect.go
Normal file
106
orm/dialect.go
Normal file
|
@ -0,0 +1,106 @@
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Dialect interface contains behaviors that differ across SQL database
|
||||||
|
type Dialect interface {
|
||||||
|
// GetName get dialect's name
|
||||||
|
GetName() string
|
||||||
|
|
||||||
|
// SetDB set db for dialect
|
||||||
|
SetDB(db *sql.DB)
|
||||||
|
|
||||||
|
// BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1
|
||||||
|
BindVar(i int) string
|
||||||
|
// Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name
|
||||||
|
Quote(key string) string
|
||||||
|
// DataTypeOf return data's sql type
|
||||||
|
DataTypeOf(field *StructField) string
|
||||||
|
|
||||||
|
// HasIndex check has index or not
|
||||||
|
HasIndex(tableName string, indexName string) bool
|
||||||
|
// HasForeignKey check has foreign key or not
|
||||||
|
HasForeignKey(tableName string, foreignKeyName string) bool
|
||||||
|
// RemoveIndex remove index
|
||||||
|
RemoveIndex(tableName string, indexName string) error
|
||||||
|
// HasTable check has table or not
|
||||||
|
HasTable(tableName string) bool
|
||||||
|
// HasColumn check has column or not
|
||||||
|
HasColumn(tableName string, columnName string) bool
|
||||||
|
|
||||||
|
// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
|
||||||
|
LimitAndOffsetSQL(limit, offset interface{}) string
|
||||||
|
// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
|
||||||
|
SelectFromDummyTable() string
|
||||||
|
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
|
||||||
|
LastInsertIDReturningSuffix(tableName, columnName string) string
|
||||||
|
|
||||||
|
// BuildForeignKeyName returns a foreign key name for the given table, field and reference
|
||||||
|
BuildForeignKeyName(tableName, field, dest string) string
|
||||||
|
|
||||||
|
// CurrentDatabase return current database name
|
||||||
|
CurrentDatabase() string
|
||||||
|
}
|
||||||
|
|
||||||
|
var dialectsMap = map[string]Dialect{}
|
||||||
|
|
||||||
|
func newDialect(name string, db *sql.DB) Dialect {
|
||||||
|
if value, ok := dialectsMap[name]; ok {
|
||||||
|
dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect)
|
||||||
|
dialect.SetDB(db)
|
||||||
|
return dialect
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("`%v` is not officially supported, running under compatibility mode.\n", name)
|
||||||
|
commontDialect := &commonDialect{}
|
||||||
|
commontDialect.SetDB(db)
|
||||||
|
return commontDialect
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterDialect register new dialect
|
||||||
|
func RegisterDialect(name string, dialect Dialect) {
|
||||||
|
dialectsMap[name] = dialect
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseFieldStructForDialect parse field struct for dialect
|
||||||
|
func ParseFieldStructForDialect(field *StructField) (fieldValue reflect.Value, sqlType string, size int, additionalType string) {
|
||||||
|
// Get redirected field type
|
||||||
|
var reflectType = field.Struct.Type
|
||||||
|
for reflectType.Kind() == reflect.Ptr {
|
||||||
|
reflectType = reflectType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get redirected field value
|
||||||
|
fieldValue = reflect.Indirect(reflect.New(reflectType))
|
||||||
|
|
||||||
|
// Get scanner's real value
|
||||||
|
var getScannerValue func(reflect.Value)
|
||||||
|
getScannerValue = func(value reflect.Value) {
|
||||||
|
fieldValue = value
|
||||||
|
if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); isScanner && fieldValue.Kind() == reflect.Struct {
|
||||||
|
getScannerValue(fieldValue.Field(0))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
getScannerValue(fieldValue)
|
||||||
|
|
||||||
|
// Default Size
|
||||||
|
if num, ok := field.TagSettings["SIZE"]; ok {
|
||||||
|
size, _ = strconv.Atoi(num)
|
||||||
|
} else {
|
||||||
|
size = 255
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default type from tag setting
|
||||||
|
additionalType = field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
|
||||||
|
if value, ok := field.TagSettings["DEFAULT"]; ok {
|
||||||
|
additionalType = additionalType + " DEFAULT " + value
|
||||||
|
}
|
||||||
|
|
||||||
|
return fieldValue, field.TagSettings["TYPE"], size, strings.TrimSpace(additionalType)
|
||||||
|
}
|
152
orm/dialect_common.go
Normal file
152
orm/dialect_common.go
Normal file
|
@ -0,0 +1,152 @@
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultForeignKeyNamer contains the default foreign key name generator method
|
||||||
|
type DefaultForeignKeyNamer struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
type commonDialect struct {
|
||||||
|
db *sql.DB
|
||||||
|
DefaultForeignKeyNamer
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
RegisterDialect("common", &commonDialect{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (commonDialect) GetName() string {
|
||||||
|
return "common"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *commonDialect) SetDB(db *sql.DB) {
|
||||||
|
s.db = db
|
||||||
|
}
|
||||||
|
|
||||||
|
func (commonDialect) BindVar(i int) string {
|
||||||
|
return "$$" // ?
|
||||||
|
}
|
||||||
|
|
||||||
|
func (commonDialect) Quote(key string) string {
|
||||||
|
return fmt.Sprintf(`"%s"`, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (commonDialect) DataTypeOf(field *StructField) string {
|
||||||
|
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field)
|
||||||
|
|
||||||
|
if sqlType == "" {
|
||||||
|
switch dataValue.Kind() {
|
||||||
|
case reflect.Bool:
|
||||||
|
sqlType = "BOOLEAN"
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
|
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
|
||||||
|
sqlType = "INTEGER AUTO_INCREMENT"
|
||||||
|
} else {
|
||||||
|
sqlType = "INTEGER"
|
||||||
|
}
|
||||||
|
case reflect.Int64, reflect.Uint64:
|
||||||
|
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
|
||||||
|
sqlType = "BIGINT AUTO_INCREMENT"
|
||||||
|
} else {
|
||||||
|
sqlType = "BIGINT"
|
||||||
|
}
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
sqlType = "FLOAT"
|
||||||
|
case reflect.String:
|
||||||
|
if size > 0 && size < 65532 {
|
||||||
|
sqlType = fmt.Sprintf("VARCHAR(%d)", size)
|
||||||
|
} else {
|
||||||
|
sqlType = "VARCHAR(65532)"
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||||
|
sqlType = "TIMESTAMP"
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if _, ok := dataValue.Interface().([]byte); ok {
|
||||||
|
if size > 0 && size < 65532 {
|
||||||
|
sqlType = fmt.Sprintf("BINARY(%d)", size)
|
||||||
|
} else {
|
||||||
|
sqlType = "BINARY(65532)"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if sqlType == "" {
|
||||||
|
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", dataValue.Type().Name(), dataValue.Kind().String()))
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.TrimSpace(additionalType) == "" {
|
||||||
|
return sqlType
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s commonDialect) HasIndex(tableName string, indexName string) bool {
|
||||||
|
var count int
|
||||||
|
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", s.CurrentDatabase(), tableName, indexName).Scan(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
|
||||||
|
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s commonDialect) HasTable(tableName string) bool {
|
||||||
|
var count int
|
||||||
|
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", s.CurrentDatabase(), tableName).Scan(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s commonDialect) HasColumn(tableName string, columnName string) bool {
|
||||||
|
var count int
|
||||||
|
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s commonDialect) CurrentDatabase() (name string) {
|
||||||
|
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
|
||||||
|
if limit != nil {
|
||||||
|
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit > 0 {
|
||||||
|
sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if offset != nil {
|
||||||
|
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset > 0 {
|
||||||
|
sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (commonDialect) SelectFromDummyTable() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (DefaultForeignKeyNamer) BuildForeignKeyName(tableName, field, dest string) string {
|
||||||
|
keyName := fmt.Sprintf("%s_%s_%s_foreign", tableName, field, dest)
|
||||||
|
keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_")
|
||||||
|
return keyName
|
||||||
|
}
|
135
orm/dialect_h2.go
Normal file
135
orm/dialect_h2.go
Normal file
|
@ -0,0 +1,135 @@
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type h2 struct {
|
||||||
|
commonDialect
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
RegisterDialect("h2", &h2{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h2) GetName() string {
|
||||||
|
return "h2"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h2) BindVar(i int) string {
|
||||||
|
return fmt.Sprintf("$%v", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h2) DataTypeOf(field *StructField) string {
|
||||||
|
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field)
|
||||||
|
|
||||||
|
if sqlType == "" {
|
||||||
|
switch dataValue.Kind() {
|
||||||
|
case reflect.Bool:
|
||||||
|
sqlType = "boolean"
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
|
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||||
|
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
||||||
|
sqlType = "serial"
|
||||||
|
} else {
|
||||||
|
sqlType = "integer"
|
||||||
|
}
|
||||||
|
case reflect.Int64, reflect.Uint64:
|
||||||
|
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||||
|
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
||||||
|
sqlType = "bigserial"
|
||||||
|
} else {
|
||||||
|
sqlType = "bigint"
|
||||||
|
}
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
sqlType = "numeric"
|
||||||
|
case reflect.String:
|
||||||
|
if _, ok := field.TagSettings["SIZE"]; !ok {
|
||||||
|
size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different
|
||||||
|
}
|
||||||
|
|
||||||
|
if size > 0 && size < 65532 {
|
||||||
|
sqlType = fmt.Sprintf("varchar(%d)", size)
|
||||||
|
} else {
|
||||||
|
sqlType = "text"
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||||
|
sqlType = "timestamp with time zone"
|
||||||
|
}
|
||||||
|
case reflect.Map:
|
||||||
|
if dataValue.Type().Name() == "Hstore" {
|
||||||
|
sqlType = "hstore"
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if isByteArrayOrSlice(dataValue) {
|
||||||
|
sqlType = "bytea"
|
||||||
|
} else if isUUID(dataValue) {
|
||||||
|
sqlType = "uuid"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if sqlType == "" {
|
||||||
|
panic(fmt.Sprintf("invalid sql type %s (%s) for h2", dataValue.Type().Name(), dataValue.Kind().String()))
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.TrimSpace(additionalType) == "" {
|
||||||
|
return sqlType
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s h2) HasIndex(tableName string, indexName string) bool {
|
||||||
|
var count int
|
||||||
|
s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2", tableName, indexName).Scan(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s h2) HasForeignKey(tableName string, foreignKeyName string) bool {
|
||||||
|
var count int
|
||||||
|
s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s h2) HasTable(tableName string) bool {
|
||||||
|
var count int
|
||||||
|
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE'", tableName).Scan(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s h2) HasColumn(tableName string, columnName string) bool {
|
||||||
|
var count int
|
||||||
|
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2", tableName, columnName).Scan(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s h2) CurrentDatabase() (name string) {
|
||||||
|
s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s h2) LastInsertIDReturningSuffix(tableName, key string) string {
|
||||||
|
//return fmt.Sprintf("RETURNING %v.%v", tableName, key)
|
||||||
|
return fmt.Sprintf("")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h2) SupportLastInsertID() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
//
|
||||||
|
//func isByteArrayOrSlice(value reflect.Value) bool {
|
||||||
|
// return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0))
|
||||||
|
//}
|
||||||
|
//
|
||||||
|
//func isUUID(value reflect.Value) bool {
|
||||||
|
// if value.Kind() != reflect.Array || value.Type().Len() != 16 {
|
||||||
|
// return false
|
||||||
|
// }
|
||||||
|
// typename := value.Type().Name()
|
||||||
|
// lower := strings.ToLower(typename)
|
||||||
|
// return "uuid" == lower || "guid" == lower
|
||||||
|
//}
|
146
orm/dialect_mysql.go
Normal file
146
orm/dialect_mysql.go
Normal file
|
@ -0,0 +1,146 @@
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha1"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
"unicode/utf8"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mysql struct {
|
||||||
|
commonDialect
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
RegisterDialect("mysql", &mysql{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mysql) GetName() string {
|
||||||
|
return "mysql"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mysql) Quote(key string) string {
|
||||||
|
return fmt.Sprintf("`%s`", key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get Data Type for MySQL Dialect
|
||||||
|
func (mysql) DataTypeOf(field *StructField) string {
|
||||||
|
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field)
|
||||||
|
|
||||||
|
// MySQL allows only one auto increment column per table, and it must
|
||||||
|
// be a KEY column.
|
||||||
|
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok {
|
||||||
|
if _, ok = field.TagSettings["INDEX"]; !ok && !field.IsPrimaryKey {
|
||||||
|
delete(field.TagSettings, "AUTO_INCREMENT")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if sqlType == "" {
|
||||||
|
switch dataValue.Kind() {
|
||||||
|
case reflect.Bool:
|
||||||
|
sqlType = "boolean"
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
|
||||||
|
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||||
|
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
||||||
|
sqlType = "int AUTO_INCREMENT"
|
||||||
|
} else {
|
||||||
|
sqlType = "int"
|
||||||
|
}
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
|
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||||
|
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
||||||
|
sqlType = "int unsigned AUTO_INCREMENT"
|
||||||
|
} else {
|
||||||
|
sqlType = "int unsigned"
|
||||||
|
}
|
||||||
|
case reflect.Int64:
|
||||||
|
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||||
|
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
||||||
|
sqlType = "bigint AUTO_INCREMENT"
|
||||||
|
} else {
|
||||||
|
sqlType = "bigint"
|
||||||
|
}
|
||||||
|
case reflect.Uint64:
|
||||||
|
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||||
|
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
||||||
|
sqlType = "bigint unsigned AUTO_INCREMENT"
|
||||||
|
} else {
|
||||||
|
sqlType = "bigint unsigned"
|
||||||
|
}
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
sqlType = "double"
|
||||||
|
case reflect.String:
|
||||||
|
if size > 0 && size < 65532 {
|
||||||
|
sqlType = fmt.Sprintf("varchar(%d)", size)
|
||||||
|
} else {
|
||||||
|
sqlType = "longtext"
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||||
|
if _, ok := field.TagSettings["NOT NULL"]; ok {
|
||||||
|
sqlType = "timestamp"
|
||||||
|
} else {
|
||||||
|
sqlType = "timestamp NULL"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if _, ok := dataValue.Interface().([]byte); ok {
|
||||||
|
if size > 0 && size < 65532 {
|
||||||
|
sqlType = fmt.Sprintf("varbinary(%d)", size)
|
||||||
|
} else {
|
||||||
|
sqlType = "longblob"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if sqlType == "" {
|
||||||
|
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String()))
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.TrimSpace(additionalType) == "" {
|
||||||
|
return sqlType
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s mysql) RemoveIndex(tableName string, indexName string) error {
|
||||||
|
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
|
||||||
|
var count int
|
||||||
|
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.CurrentDatabase(), tableName, foreignKeyName).Scan(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s mysql) CurrentDatabase() (name string) {
|
||||||
|
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mysql) SelectFromDummyTable() string {
|
||||||
|
return "FROM DUAL"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s mysql) BuildForeignKeyName(tableName, field, dest string) string {
|
||||||
|
keyName := s.commonDialect.BuildForeignKeyName(tableName, field, dest)
|
||||||
|
if utf8.RuneCountInString(keyName) <= 64 {
|
||||||
|
return keyName
|
||||||
|
}
|
||||||
|
h := sha1.New()
|
||||||
|
h.Write([]byte(keyName))
|
||||||
|
bs := h.Sum(nil)
|
||||||
|
|
||||||
|
// sha1 is 40 digits, keep first 24 characters of destination
|
||||||
|
destRunes := []rune(regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(dest, "_"))
|
||||||
|
if len(destRunes) > 24 {
|
||||||
|
destRunes = destRunes[:24]
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s%x", string(destRunes), bs)
|
||||||
|
}
|
135
orm/dialect_postgres.go
Normal file
135
orm/dialect_postgres.go
Normal file
|
@ -0,0 +1,135 @@
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
)
|
||||||
|
|
||||||
|
type postgres struct {
|
||||||
|
commonDialect
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
RegisterDialect("postgres", &postgres{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (postgres) GetName() string {
|
||||||
|
return "postgres"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (postgres) BindVar(i int) string {
|
||||||
|
return fmt.Sprintf("$%v", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (postgres) DataTypeOf(field *StructField) string {
|
||||||
|
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field)
|
||||||
|
|
||||||
|
if sqlType == "" {
|
||||||
|
switch dataValue.Kind() {
|
||||||
|
case reflect.Bool:
|
||||||
|
sqlType = "boolean"
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
|
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||||
|
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
||||||
|
sqlType = "serial"
|
||||||
|
} else {
|
||||||
|
sqlType = "integer"
|
||||||
|
}
|
||||||
|
case reflect.Int64, reflect.Uint64:
|
||||||
|
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||||
|
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
||||||
|
sqlType = "bigserial"
|
||||||
|
} else {
|
||||||
|
sqlType = "bigint"
|
||||||
|
}
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
sqlType = "numeric"
|
||||||
|
case reflect.String:
|
||||||
|
if _, ok := field.TagSettings["SIZE"]; !ok {
|
||||||
|
size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different
|
||||||
|
}
|
||||||
|
|
||||||
|
if size > 0 && size < 65532 {
|
||||||
|
sqlType = fmt.Sprintf("varchar(%d)", size)
|
||||||
|
} else {
|
||||||
|
sqlType = "text"
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||||
|
sqlType = "timestamp with time zone"
|
||||||
|
}
|
||||||
|
case reflect.Map:
|
||||||
|
if dataValue.Type().Name() == "Hstore" {
|
||||||
|
sqlType = "hstore"
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if isByteArrayOrSlice(dataValue) {
|
||||||
|
sqlType = "bytea"
|
||||||
|
} else if isUUID(dataValue) {
|
||||||
|
sqlType = "uuid"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if sqlType == "" {
|
||||||
|
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", dataValue.Type().Name(), dataValue.Kind().String()))
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.TrimSpace(additionalType) == "" {
|
||||||
|
return sqlType
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s postgres) HasIndex(tableName string, indexName string) bool {
|
||||||
|
var count int
|
||||||
|
s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2", tableName, indexName).Scan(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool {
|
||||||
|
var count int
|
||||||
|
s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s postgres) HasTable(tableName string) bool {
|
||||||
|
var count int
|
||||||
|
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE'", tableName).Scan(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s postgres) HasColumn(tableName string, columnName string) bool {
|
||||||
|
var count int
|
||||||
|
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2", tableName, columnName).Scan(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s postgres) CurrentDatabase() (name string) {
|
||||||
|
s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string {
|
||||||
|
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (postgres) SupportLastInsertID() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func isByteArrayOrSlice(value reflect.Value) bool {
|
||||||
|
return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func isUUID(value reflect.Value) bool {
|
||||||
|
if value.Kind() != reflect.Array || value.Type().Len() != 16 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
typename := value.Type().Name()
|
||||||
|
lower := strings.ToLower(typename)
|
||||||
|
return "uuid" == lower || "guid" == lower
|
||||||
|
}
|
108
orm/dialect_sqlite3.go
Normal file
108
orm/dialect_sqlite3.go
Normal file
|
@ -0,0 +1,108 @@
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type sqlite3 struct {
|
||||||
|
commonDialect
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
RegisterDialect("sqlite", &sqlite3{})
|
||||||
|
RegisterDialect("sqlite3", &sqlite3{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sqlite3) GetName() string {
|
||||||
|
return "sqlite3"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get Data Type for Sqlite Dialect
|
||||||
|
func (sqlite3) DataTypeOf(field *StructField) string {
|
||||||
|
var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field)
|
||||||
|
|
||||||
|
if sqlType == "" {
|
||||||
|
switch dataValue.Kind() {
|
||||||
|
case reflect.Bool:
|
||||||
|
sqlType = "bool"
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
|
if field.IsPrimaryKey {
|
||||||
|
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
||||||
|
sqlType = "integer primary key autoincrement"
|
||||||
|
} else {
|
||||||
|
sqlType = "integer"
|
||||||
|
}
|
||||||
|
case reflect.Int64, reflect.Uint64:
|
||||||
|
if field.IsPrimaryKey {
|
||||||
|
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
|
||||||
|
sqlType = "integer primary key autoincrement"
|
||||||
|
} else {
|
||||||
|
sqlType = "bigint"
|
||||||
|
}
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
sqlType = "real"
|
||||||
|
case reflect.String:
|
||||||
|
if size > 0 && size < 65532 {
|
||||||
|
sqlType = fmt.Sprintf("varchar(%d)", size)
|
||||||
|
} else {
|
||||||
|
sqlType = "text"
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||||
|
sqlType = "datetime"
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if _, ok := dataValue.Interface().([]byte); ok {
|
||||||
|
sqlType = "blob"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if sqlType == "" {
|
||||||
|
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String()))
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.TrimSpace(additionalType) == "" {
|
||||||
|
return sqlType
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s sqlite3) HasIndex(tableName string, indexName string) bool {
|
||||||
|
var count int
|
||||||
|
s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s sqlite3) HasTable(tableName string) bool {
|
||||||
|
var count int
|
||||||
|
s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s sqlite3) HasColumn(tableName string, columnName string) bool {
|
||||||
|
var count int
|
||||||
|
s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s sqlite3) CurrentDatabase() (name string) {
|
||||||
|
var (
|
||||||
|
ifaces = make([]interface{}, 3)
|
||||||
|
pointers = make([]*string, 3)
|
||||||
|
i int
|
||||||
|
)
|
||||||
|
for i = 0; i < 3; i++ {
|
||||||
|
ifaces[i] = &pointers[i]
|
||||||
|
}
|
||||||
|
if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if pointers[1] != nil {
|
||||||
|
name = *pointers[1]
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
54
orm/dialects/h2/h2.go
Normal file
54
orm/dialects/h2/h2.go
Normal file
|
@ -0,0 +1,54 @@
|
||||||
|
package h2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
|
|
||||||
|
_ "github.com/lib/pq"
|
||||||
|
"github.com/lib/pq/hstore"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Hstore map[string]*string
|
||||||
|
|
||||||
|
// Value get value of Hstore
|
||||||
|
func (h Hstore) Value() (driver.Value, error) {
|
||||||
|
hstore := hstore.Hstore{Map: map[string]sql.NullString{}}
|
||||||
|
if len(h) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, value := range h {
|
||||||
|
var s sql.NullString
|
||||||
|
if value != nil {
|
||||||
|
s.String = *value
|
||||||
|
s.Valid = true
|
||||||
|
}
|
||||||
|
hstore.Map[key] = s
|
||||||
|
}
|
||||||
|
return hstore.Value()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan scan value into Hstore
|
||||||
|
func (h *Hstore) Scan(value interface{}) error {
|
||||||
|
hstore := hstore.Hstore{}
|
||||||
|
|
||||||
|
if err := hstore.Scan(value); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(hstore.Map) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
*h = Hstore{}
|
||||||
|
for k := range hstore.Map {
|
||||||
|
if hstore.Map[k].Valid {
|
||||||
|
s := hstore.Map[k].String
|
||||||
|
(*h)[k] = &s
|
||||||
|
} else {
|
||||||
|
(*h)[k] = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
151
orm/dialects/mssql/mssql.go
Normal file
151
orm/dialects/mssql/mssql.go
Normal file
|
@ -0,0 +1,151 @@
|
||||||
|
package mssql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
_ "github.com/denisenkom/go-mssqldb"
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setIdentityInsert(scope *gorm.Scope) {
|
||||||
|
if scope.Dialect().GetName() == "mssql" {
|
||||||
|
scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
gorm.DefaultCallback.Create().After("gorm:begin_transaction").Register("mssql:set_identity_insert", setIdentityInsert)
|
||||||
|
gorm.RegisterDialect("mssql", &mssql{})
|
||||||
|
}
|
||||||
|
|
||||||
|
type mssql struct {
|
||||||
|
db *sql.DB
|
||||||
|
gorm.DefaultForeignKeyNamer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mssql) GetName() string {
|
||||||
|
return "mssql"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *mssql) SetDB(db *sql.DB) {
|
||||||
|
s.db = db
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mssql) BindVar(i int) string {
|
||||||
|
return "$$" // ?
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mssql) Quote(key string) string {
|
||||||
|
return fmt.Sprintf(`"%s"`, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mssql) DataTypeOf(field *gorm.StructField) string {
|
||||||
|
var dataValue, sqlType, size, additionalType = gorm.ParseFieldStructForDialect(field)
|
||||||
|
|
||||||
|
if sqlType == "" {
|
||||||
|
switch dataValue.Kind() {
|
||||||
|
case reflect.Bool:
|
||||||
|
sqlType = "bit"
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
|
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||||
|
sqlType = "int IDENTITY(1,1)"
|
||||||
|
} else {
|
||||||
|
sqlType = "int"
|
||||||
|
}
|
||||||
|
case reflect.Int64, reflect.Uint64:
|
||||||
|
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
|
||||||
|
sqlType = "bigint IDENTITY(1,1)"
|
||||||
|
} else {
|
||||||
|
sqlType = "bigint"
|
||||||
|
}
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
sqlType = "float"
|
||||||
|
case reflect.String:
|
||||||
|
if size > 0 && size < 65532 {
|
||||||
|
sqlType = fmt.Sprintf("nvarchar(%d)", size)
|
||||||
|
} else {
|
||||||
|
sqlType = "text"
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||||
|
sqlType = "datetime2"
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if _, ok := dataValue.Interface().([]byte); ok {
|
||||||
|
if size > 0 && size < 65532 {
|
||||||
|
sqlType = fmt.Sprintf("varchar(%d)", size)
|
||||||
|
} else {
|
||||||
|
sqlType = "text"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if sqlType == "" {
|
||||||
|
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", dataValue.Type().Name(), dataValue.Kind().String()))
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.TrimSpace(additionalType) == "" {
|
||||||
|
return sqlType
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s mssql) HasIndex(tableName string, indexName string) bool {
|
||||||
|
var count int
|
||||||
|
s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s mssql) RemoveIndex(tableName string, indexName string) error {
|
||||||
|
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s mssql) HasTable(tableName string) bool {
|
||||||
|
var count int
|
||||||
|
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.CurrentDatabase()).Scan(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s mssql) HasColumn(tableName string, columnName string) bool {
|
||||||
|
var count int
|
||||||
|
s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.CurrentDatabase(), tableName, columnName).Scan(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s mssql) CurrentDatabase() (name string) {
|
||||||
|
s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
|
||||||
|
if limit != nil {
|
||||||
|
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit > 0 {
|
||||||
|
sql += fmt.Sprintf(" FETCH NEXT %d ROWS ONLY", parsedLimit)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if offset != nil {
|
||||||
|
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset > 0 {
|
||||||
|
sql += fmt.Sprintf(" OFFSET %d ROWS", parsedOffset)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mssql) SelectFromDummyTable() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
|
||||||
|
return ""
|
||||||
|
}
|
3
orm/dialects/mysql/mysql.go
Normal file
3
orm/dialects/mysql/mysql.go
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
package mysql
|
||||||
|
|
||||||
|
import _ "github.com/go-sql-driver/mysql"
|
54
orm/dialects/postgres/postgres.go
Normal file
54
orm/dialects/postgres/postgres.go
Normal file
|
@ -0,0 +1,54 @@
|
||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
|
|
||||||
|
_ "github.com/lib/pq"
|
||||||
|
"github.com/lib/pq/hstore"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Hstore map[string]*string
|
||||||
|
|
||||||
|
// Value get value of Hstore
|
||||||
|
func (h Hstore) Value() (driver.Value, error) {
|
||||||
|
hstore := hstore.Hstore{Map: map[string]sql.NullString{}}
|
||||||
|
if len(h) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, value := range h {
|
||||||
|
var s sql.NullString
|
||||||
|
if value != nil {
|
||||||
|
s.String = *value
|
||||||
|
s.Valid = true
|
||||||
|
}
|
||||||
|
hstore.Map[key] = s
|
||||||
|
}
|
||||||
|
return hstore.Value()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan scan value into Hstore
|
||||||
|
func (h *Hstore) Scan(value interface{}) error {
|
||||||
|
hstore := hstore.Hstore{}
|
||||||
|
|
||||||
|
if err := hstore.Scan(value); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(hstore.Map) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
*h = Hstore{}
|
||||||
|
for k := range hstore.Map {
|
||||||
|
if hstore.Map[k].Valid {
|
||||||
|
s := hstore.Map[k].String
|
||||||
|
(*h)[k] = &s
|
||||||
|
} else {
|
||||||
|
(*h)[k] = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
3
orm/dialects/sqlite/sqlite.go
Normal file
3
orm/dialects/sqlite/sqlite.go
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
package sqlite
|
||||||
|
|
||||||
|
import _ "github.com/mattn/go-sqlite3"
|
66
orm/embedded_struct_test.go
Normal file
66
orm/embedded_struct_test.go
Normal file
|
@ -0,0 +1,66 @@
|
||||||
|
package orm_test
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
type BasePost struct {
|
||||||
|
Id int64
|
||||||
|
Title string
|
||||||
|
URL string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Author struct {
|
||||||
|
Name string
|
||||||
|
Email string
|
||||||
|
}
|
||||||
|
|
||||||
|
type HNPost struct {
|
||||||
|
BasePost
|
||||||
|
Author `gorm:"embedded_prefix:user_"` // Embedded struct
|
||||||
|
Upvotes int32
|
||||||
|
}
|
||||||
|
|
||||||
|
type EngadgetPost struct {
|
||||||
|
BasePost BasePost `gorm:"embedded"`
|
||||||
|
Author Author `gorm:"embedded;embedded_prefix:author_"` // Embedded struct
|
||||||
|
ImageUrl string
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrefixColumnNameForEmbeddedStruct(t *testing.T) {
|
||||||
|
dialect := DB.NewScope(&EngadgetPost{}).Dialect()
|
||||||
|
if !dialect.HasColumn(DB.NewScope(&EngadgetPost{}).TableName(), "author_name") || !dialect.HasColumn(DB.NewScope(&EngadgetPost{}).TableName(), "author_email") {
|
||||||
|
t.Errorf("should has prefix for embedded columns")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !dialect.HasColumn(DB.NewScope(&HNPost{}).TableName(), "user_name") || !dialect.HasColumn(DB.NewScope(&HNPost{}).TableName(), "user_email") {
|
||||||
|
t.Errorf("should has prefix for embedded columns")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSaveAndQueryEmbeddedStruct(t *testing.T) {
|
||||||
|
DB.Save(&HNPost{BasePost: BasePost{Title: "news"}})
|
||||||
|
DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}})
|
||||||
|
var news HNPost
|
||||||
|
if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil {
|
||||||
|
t.Errorf("no error should happen when query with embedded struct, but got %v", err)
|
||||||
|
} else if news.Title != "hn_news" {
|
||||||
|
t.Errorf("embedded struct's value should be scanned correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}})
|
||||||
|
var egNews EngadgetPost
|
||||||
|
if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil {
|
||||||
|
t.Errorf("no error should happen when query with embedded struct, but got %v", err)
|
||||||
|
} else if egNews.BasePost.Title != "engadget_news" {
|
||||||
|
t.Errorf("embedded struct's value should be scanned correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.NewScope(&HNPost{}).PrimaryField() == nil {
|
||||||
|
t.Errorf("primary key with embedded struct should works")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, field := range DB.NewScope(&HNPost{}).Fields() {
|
||||||
|
if field.Name == "BasePost" {
|
||||||
|
t.Errorf("scope Fields should not contain embedded struct")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
56
orm/errors.go
Normal file
56
orm/errors.go
Normal file
|
@ -0,0 +1,56 @@
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrRecordNotFound record not found error, happens when haven't find any matched data when looking up with a struct
|
||||||
|
ErrRecordNotFound = errors.New("record not found")
|
||||||
|
// ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL
|
||||||
|
ErrInvalidSQL = errors.New("invalid SQL")
|
||||||
|
// ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback`
|
||||||
|
ErrInvalidTransaction = errors.New("no valid transaction")
|
||||||
|
// ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin`
|
||||||
|
ErrCantStartTransaction = errors.New("can't start transaction")
|
||||||
|
// ErrUnaddressable unaddressable value
|
||||||
|
ErrUnaddressable = errors.New("using unaddressable value")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Errors contains all happened errors
|
||||||
|
type Errors []error
|
||||||
|
|
||||||
|
// GetErrors gets all happened errors
|
||||||
|
func (errs Errors) GetErrors() []error {
|
||||||
|
return errs
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add adds an error
|
||||||
|
func (errs Errors) Add(newErrors ...error) Errors {
|
||||||
|
for _, err := range newErrors {
|
||||||
|
if errors, ok := err.(Errors); ok {
|
||||||
|
errs = errs.Add(errors...)
|
||||||
|
} else {
|
||||||
|
ok = true
|
||||||
|
for _, e := range errs {
|
||||||
|
if err == e {
|
||||||
|
ok = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
errs = append(errs, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return errs
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error format happened errors
|
||||||
|
func (errs Errors) Error() string {
|
||||||
|
var errors = []string{}
|
||||||
|
for _, e := range errs {
|
||||||
|
errors = append(errors, e.Error())
|
||||||
|
}
|
||||||
|
return strings.Join(errors, "; ")
|
||||||
|
}
|
20
orm/errors_test.go
Normal file
20
orm/errors_test.go
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
package orm_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestErrorsCanBeUsedOutsideGorm(t *testing.T) {
|
||||||
|
errs := []error{errors.New("First"), errors.New("Second")}
|
||||||
|
|
||||||
|
gErrs := gorm.Errors(errs)
|
||||||
|
gErrs = gErrs.Add(errors.New("Third"))
|
||||||
|
gErrs = gErrs.Add(gErrs)
|
||||||
|
|
||||||
|
if gErrs.Error() != "First; Second; Third" {
|
||||||
|
t.Fatalf("Gave wrong error, got %s", gErrs.Error())
|
||||||
|
}
|
||||||
|
}
|
58
orm/field.go
Normal file
58
orm/field.go
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Field model field definition
|
||||||
|
type Field struct {
|
||||||
|
*StructField
|
||||||
|
IsBlank bool
|
||||||
|
Field reflect.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set set a value to the field
|
||||||
|
func (field *Field) Set(value interface{}) (err error) {
|
||||||
|
if !field.Field.IsValid() {
|
||||||
|
return errors.New("field value not valid")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !field.Field.CanAddr() {
|
||||||
|
return ErrUnaddressable
|
||||||
|
}
|
||||||
|
|
||||||
|
reflectValue, ok := value.(reflect.Value)
|
||||||
|
if !ok {
|
||||||
|
reflectValue = reflect.ValueOf(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
fieldValue := field.Field
|
||||||
|
if reflectValue.IsValid() {
|
||||||
|
if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
|
||||||
|
fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
|
||||||
|
} else {
|
||||||
|
if fieldValue.Kind() == reflect.Ptr {
|
||||||
|
if fieldValue.IsNil() {
|
||||||
|
fieldValue.Set(reflect.New(field.Struct.Type.Elem()))
|
||||||
|
}
|
||||||
|
fieldValue = fieldValue.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
|
||||||
|
fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
|
||||||
|
} else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
|
||||||
|
err = scanner.Scan(reflectValue.Interface())
|
||||||
|
} else {
|
||||||
|
err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
field.Field.Set(reflect.Zero(field.Field.Type()))
|
||||||
|
}
|
||||||
|
|
||||||
|
field.IsBlank = isBlank(field.Field)
|
||||||
|
return err
|
||||||
|
}
|
49
orm/field_test.go
Normal file
49
orm/field_test.go
Normal file
|
@ -0,0 +1,49 @@
|
||||||
|
package orm_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CalculateField struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
Children []CalculateFieldChild
|
||||||
|
Category CalculateFieldCategory
|
||||||
|
EmbeddedField
|
||||||
|
}
|
||||||
|
|
||||||
|
type EmbeddedField struct {
|
||||||
|
EmbeddedName string `sql:"NOT NULL;DEFAULT:'hello'"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CalculateFieldChild struct {
|
||||||
|
gorm.Model
|
||||||
|
CalculateFieldID uint
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
type CalculateFieldCategory struct {
|
||||||
|
gorm.Model
|
||||||
|
CalculateFieldID uint
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateField(t *testing.T) {
|
||||||
|
var field CalculateField
|
||||||
|
var scope = DB.NewScope(&field)
|
||||||
|
if field, ok := scope.FieldByName("Children"); !ok || field.Relationship == nil {
|
||||||
|
t.Errorf("Should calculate fields correctly for the first time")
|
||||||
|
}
|
||||||
|
|
||||||
|
if field, ok := scope.FieldByName("Category"); !ok || field.Relationship == nil {
|
||||||
|
t.Errorf("Should calculate fields correctly for the first time")
|
||||||
|
}
|
||||||
|
|
||||||
|
if field, ok := scope.FieldByName("embedded_name"); !ok {
|
||||||
|
t.Errorf("should find embedded field")
|
||||||
|
} else if _, ok := field.TagSettings["NOT NULL"]; !ok {
|
||||||
|
t.Errorf("should find embedded field's tag settings")
|
||||||
|
}
|
||||||
|
}
|
19
orm/interface.go
Normal file
19
orm/interface.go
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import "database/sql"
|
||||||
|
|
||||||
|
type sqlCommon interface {
|
||||||
|
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||||
|
Prepare(query string) (*sql.Stmt, error)
|
||||||
|
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||||
|
QueryRow(query string, args ...interface{}) *sql.Row
|
||||||
|
}
|
||||||
|
|
||||||
|
type sqlDb interface {
|
||||||
|
Begin() (*sql.Tx, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type sqlTx interface {
|
||||||
|
Commit() error
|
||||||
|
Rollback() error
|
||||||
|
}
|
204
orm/join_table_handler.go
Normal file
204
orm/join_table_handler.go
Normal file
|
@ -0,0 +1,204 @@
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// JoinTableHandlerInterface is an interface for how to handle many2many relations
|
||||||
|
type JoinTableHandlerInterface interface {
|
||||||
|
// initialize join table handler
|
||||||
|
Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type)
|
||||||
|
// Table return join table's table name
|
||||||
|
Table(db *DB) string
|
||||||
|
// Add create relationship in join table for source and destination
|
||||||
|
Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error
|
||||||
|
// Delete delete relationship in join table for sources
|
||||||
|
Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
|
||||||
|
// JoinWith query with `Join` conditions
|
||||||
|
JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
|
||||||
|
// SourceForeignKeys return source foreign keys
|
||||||
|
SourceForeignKeys() []JoinTableForeignKey
|
||||||
|
// DestinationForeignKeys return destination foreign keys
|
||||||
|
DestinationForeignKeys() []JoinTableForeignKey
|
||||||
|
}
|
||||||
|
|
||||||
|
// JoinTableForeignKey join table foreign key struct
|
||||||
|
type JoinTableForeignKey struct {
|
||||||
|
DBName string
|
||||||
|
AssociationDBName string
|
||||||
|
}
|
||||||
|
|
||||||
|
// JoinTableSource is a struct that contains model type and foreign keys
|
||||||
|
type JoinTableSource struct {
|
||||||
|
ModelType reflect.Type
|
||||||
|
ForeignKeys []JoinTableForeignKey
|
||||||
|
}
|
||||||
|
|
||||||
|
// JoinTableHandler default join table handler
|
||||||
|
type JoinTableHandler struct {
|
||||||
|
TableName string `sql:"-"`
|
||||||
|
Source JoinTableSource `sql:"-"`
|
||||||
|
Destination JoinTableSource `sql:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SourceForeignKeys return source foreign keys
|
||||||
|
func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey {
|
||||||
|
return s.Source.ForeignKeys
|
||||||
|
}
|
||||||
|
|
||||||
|
// DestinationForeignKeys return destination foreign keys
|
||||||
|
func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey {
|
||||||
|
return s.Destination.ForeignKeys
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup initialize a default join table handler
|
||||||
|
func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) {
|
||||||
|
s.TableName = tableName
|
||||||
|
|
||||||
|
s.Source = JoinTableSource{ModelType: source}
|
||||||
|
for idx, dbName := range relationship.ForeignFieldNames {
|
||||||
|
s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{
|
||||||
|
DBName: relationship.ForeignDBNames[idx],
|
||||||
|
AssociationDBName: dbName,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
s.Destination = JoinTableSource{ModelType: destination}
|
||||||
|
for idx, dbName := range relationship.AssociationForeignFieldNames {
|
||||||
|
s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{
|
||||||
|
DBName: relationship.AssociationForeignDBNames[idx],
|
||||||
|
AssociationDBName: dbName,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Table return join table's table name
|
||||||
|
func (s JoinTableHandler) Table(db *DB) string {
|
||||||
|
return s.TableName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s JoinTableHandler) getSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
|
||||||
|
values := map[string]interface{}{}
|
||||||
|
|
||||||
|
for _, source := range sources {
|
||||||
|
scope := db.NewScope(source)
|
||||||
|
modelType := scope.GetModelStruct().ModelType
|
||||||
|
|
||||||
|
if s.Source.ModelType == modelType {
|
||||||
|
for _, foreignKey := range s.Source.ForeignKeys {
|
||||||
|
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
|
||||||
|
values[foreignKey.DBName] = field.Field.Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if s.Destination.ModelType == modelType {
|
||||||
|
for _, foreignKey := range s.Destination.ForeignKeys {
|
||||||
|
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
|
||||||
|
values[foreignKey.DBName] = field.Field.Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return values
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add create relationship in join table for source and destination
|
||||||
|
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error {
|
||||||
|
scope := db.NewScope("")
|
||||||
|
searchMap := s.getSearchMap(db, source, destination)
|
||||||
|
|
||||||
|
var assignColumns, binVars, conditions []string
|
||||||
|
var values []interface{}
|
||||||
|
for key, value := range searchMap {
|
||||||
|
assignColumns = append(assignColumns, scope.Quote(key))
|
||||||
|
binVars = append(binVars, `?`)
|
||||||
|
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
|
||||||
|
values = append(values, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, value := range values {
|
||||||
|
values = append(values, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
quotedTable := scope.Quote(handler.Table(db))
|
||||||
|
sql := fmt.Sprintf(
|
||||||
|
"INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)",
|
||||||
|
quotedTable,
|
||||||
|
strings.Join(assignColumns, ","),
|
||||||
|
strings.Join(binVars, ","),
|
||||||
|
scope.Dialect().SelectFromDummyTable(),
|
||||||
|
quotedTable,
|
||||||
|
strings.Join(conditions, " AND "),
|
||||||
|
)
|
||||||
|
|
||||||
|
return db.Exec(sql, values...).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete delete relationship in join table for sources
|
||||||
|
func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error {
|
||||||
|
var (
|
||||||
|
scope = db.NewScope(nil)
|
||||||
|
conditions []string
|
||||||
|
values []interface{}
|
||||||
|
)
|
||||||
|
|
||||||
|
for key, value := range s.getSearchMap(db, sources...) {
|
||||||
|
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
|
||||||
|
values = append(values, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// JoinWith query with `Join` conditions
|
||||||
|
func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
|
||||||
|
var (
|
||||||
|
scope = db.NewScope(source)
|
||||||
|
tableName = handler.Table(db)
|
||||||
|
quotedTableName = scope.Quote(tableName)
|
||||||
|
joinConditions []string
|
||||||
|
values []interface{}
|
||||||
|
)
|
||||||
|
|
||||||
|
if s.Source.ModelType == scope.GetModelStruct().ModelType {
|
||||||
|
destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName()
|
||||||
|
for _, foreignKey := range s.Destination.ForeignKeys {
|
||||||
|
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTableName, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
|
||||||
|
}
|
||||||
|
|
||||||
|
var foreignDBNames []string
|
||||||
|
var foreignFieldNames []string
|
||||||
|
|
||||||
|
for _, foreignKey := range s.Source.ForeignKeys {
|
||||||
|
foreignDBNames = append(foreignDBNames, foreignKey.DBName)
|
||||||
|
if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
|
||||||
|
foreignFieldNames = append(foreignFieldNames, field.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value)
|
||||||
|
|
||||||
|
var condString string
|
||||||
|
if len(foreignFieldValues) > 0 {
|
||||||
|
var quotedForeignDBNames []string
|
||||||
|
for _, dbName := range foreignDBNames {
|
||||||
|
quotedForeignDBNames = append(quotedForeignDBNames, tableName+"."+dbName)
|
||||||
|
}
|
||||||
|
|
||||||
|
condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues))
|
||||||
|
|
||||||
|
keys := scope.getColumnAsArray(foreignFieldNames, scope.Value)
|
||||||
|
values = append(values, toQueryValues(keys))
|
||||||
|
} else {
|
||||||
|
condString = fmt.Sprintf("1 <> 1")
|
||||||
|
}
|
||||||
|
|
||||||
|
return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))).
|
||||||
|
Where(condString, toQueryValues(foreignFieldValues)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
db.Error = errors.New("wrong source type for join table handler")
|
||||||
|
return db
|
||||||
|
}
|
72
orm/join_table_test.go
Normal file
72
orm/join_table_test.go
Normal file
|
@ -0,0 +1,72 @@
|
||||||
|
package orm_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Person struct {
|
||||||
|
Id int
|
||||||
|
Name string
|
||||||
|
Addresses []*Address `gorm:"many2many:person_addresses;"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type PersonAddress struct {
|
||||||
|
gorm.JoinTableHandler
|
||||||
|
PersonID int
|
||||||
|
AddressID int
|
||||||
|
DeletedAt *time.Time
|
||||||
|
CreatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*PersonAddress) Add(handler gorm.JoinTableHandlerInterface, db *gorm.DB, foreignValue interface{}, associationValue interface{}) error {
|
||||||
|
return db.Where(map[string]interface{}{
|
||||||
|
"person_id": db.NewScope(foreignValue).PrimaryKeyValue(),
|
||||||
|
"address_id": db.NewScope(associationValue).PrimaryKeyValue(),
|
||||||
|
}).Assign(map[string]interface{}{
|
||||||
|
"person_id": foreignValue,
|
||||||
|
"address_id": associationValue,
|
||||||
|
"deleted_at": gorm.Expr("NULL"),
|
||||||
|
}).FirstOrCreate(&PersonAddress{}).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*PersonAddress) Delete(handler gorm.JoinTableHandlerInterface, db *gorm.DB, sources ...interface{}) error {
|
||||||
|
return db.Delete(&PersonAddress{}).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pa *PersonAddress) JoinWith(handler gorm.JoinTableHandlerInterface, db *gorm.DB, source interface{}) *gorm.DB {
|
||||||
|
table := pa.Table(db)
|
||||||
|
return db.Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJoinTable(t *testing.T) {
|
||||||
|
DB.Exec("drop table person_addresses;")
|
||||||
|
DB.AutoMigrate(&Person{})
|
||||||
|
DB.SetJoinTableHandler(&Person{}, "Addresses", &PersonAddress{})
|
||||||
|
|
||||||
|
address1 := &Address{Address1: "address 1"}
|
||||||
|
address2 := &Address{Address1: "address 2"}
|
||||||
|
person := &Person{Name: "person", Addresses: []*Address{address1, address2}}
|
||||||
|
DB.Save(person)
|
||||||
|
|
||||||
|
DB.Model(person).Association("Addresses").Delete(address1)
|
||||||
|
|
||||||
|
if DB.Find(&[]PersonAddress{}, "person_id = ?", person.Id).RowsAffected != 1 {
|
||||||
|
t.Errorf("Should found one address")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(person).Association("Addresses").Count() != 1 {
|
||||||
|
t.Errorf("Should found one address")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Unscoped().Find(&[]PersonAddress{}, "person_id = ?", person.Id).RowsAffected != 2 {
|
||||||
|
t.Errorf("Found two addresses with Unscoped")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(person).Association("Addresses").Clear(); DB.Model(person).Association("Addresses").Count() != 0 {
|
||||||
|
t.Errorf("Should deleted all addresses")
|
||||||
|
}
|
||||||
|
}
|
99
orm/logger.go
Normal file
99
orm/logger.go
Normal file
|
@ -0,0 +1,99 @@
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"reflect"
|
||||||
|
"regexp"
|
||||||
|
"time"
|
||||||
|
"unicode"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)}
|
||||||
|
sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`)
|
||||||
|
)
|
||||||
|
|
||||||
|
type logger interface {
|
||||||
|
Print(v ...interface{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// LogWriter log writer interface
|
||||||
|
type LogWriter interface {
|
||||||
|
Println(v ...interface{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Logger default logger
|
||||||
|
type Logger struct {
|
||||||
|
LogWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print format & print log
|
||||||
|
func (logger Logger) Print(values ...interface{}) {
|
||||||
|
if len(values) > 1 {
|
||||||
|
level := values[0]
|
||||||
|
currentTime := "\n\033[33m[" + NowFunc().Format("2006-01-02 15:04:05") + "]\033[0m"
|
||||||
|
source := fmt.Sprintf("\033[35m(%v)\033[0m", values[1])
|
||||||
|
messages := []interface{}{source, currentTime}
|
||||||
|
|
||||||
|
if level == "sql" {
|
||||||
|
// duration
|
||||||
|
messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0))
|
||||||
|
// sql
|
||||||
|
var sql string
|
||||||
|
var formattedValues []string
|
||||||
|
|
||||||
|
for _, value := range values[4].([]interface{}) {
|
||||||
|
indirectValue := reflect.Indirect(reflect.ValueOf(value))
|
||||||
|
if indirectValue.IsValid() {
|
||||||
|
value = indirectValue.Interface()
|
||||||
|
if t, ok := value.(time.Time); ok {
|
||||||
|
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format(time.RFC3339)))
|
||||||
|
} else if b, ok := value.([]byte); ok {
|
||||||
|
if str := string(b); isPrintable(str) {
|
||||||
|
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str))
|
||||||
|
} else {
|
||||||
|
formattedValues = append(formattedValues, "'<binary>'")
|
||||||
|
}
|
||||||
|
} else if r, ok := value.(driver.Valuer); ok {
|
||||||
|
if value, err := r.Value(); err == nil && value != nil {
|
||||||
|
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value))
|
||||||
|
} else {
|
||||||
|
formattedValues = append(formattedValues, "NULL")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var formattedValuesLength = len(formattedValues)
|
||||||
|
for index, value := range sqlRegexp.Split(values[3].(string), -1) {
|
||||||
|
sql += value
|
||||||
|
if index < formattedValuesLength {
|
||||||
|
sql += formattedValues[index]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
messages = append(messages, sql)
|
||||||
|
} else {
|
||||||
|
messages = append(messages, "\033[31;1m")
|
||||||
|
messages = append(messages, values[2:]...)
|
||||||
|
messages = append(messages, "\033[0m")
|
||||||
|
}
|
||||||
|
logger.Println(messages...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isPrintable(s string) bool {
|
||||||
|
for _, r := range s {
|
||||||
|
if !unicode.IsPrint(r) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
715
orm/main.go
Normal file
715
orm/main.go
Normal file
|
@ -0,0 +1,715 @@
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DB contains information for current db connection
|
||||||
|
type DB struct {
|
||||||
|
Value interface{}
|
||||||
|
Error error
|
||||||
|
RowsAffected int64
|
||||||
|
callbacks *Callback
|
||||||
|
db sqlCommon
|
||||||
|
parent *DB
|
||||||
|
search *search
|
||||||
|
logMode int
|
||||||
|
logger logger
|
||||||
|
dialect Dialect
|
||||||
|
singularTable bool
|
||||||
|
source string
|
||||||
|
values map[string]interface{}
|
||||||
|
joinTableHandlers map[string]JoinTableHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open initialize a new db connection, need to import driver first, e.g:
|
||||||
|
//
|
||||||
|
// import _ "github.com/go-sql-driver/mysql"
|
||||||
|
// func main() {
|
||||||
|
// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local")
|
||||||
|
// }
|
||||||
|
// GORM has wrapped some drivers, for easier to remember driver's import path, so you could import the mysql driver with
|
||||||
|
// import _ "github.com/jinzhu/gorm/dialects/mysql"
|
||||||
|
// // import _ "github.com/jinzhu/gorm/dialects/postgres"
|
||||||
|
// // import _ "github.com/jinzhu/gorm/dialects/sqlite"
|
||||||
|
// // import _ "github.com/jinzhu/gorm/dialects/mssql"
|
||||||
|
func Open(dialect string, args ...interface{}) (*DB, error) {
|
||||||
|
var db DB
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if len(args) == 0 {
|
||||||
|
err = errors.New("invalid database source")
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var source string
|
||||||
|
var dbSQL sqlCommon
|
||||||
|
|
||||||
|
switch value := args[0].(type) {
|
||||||
|
case string:
|
||||||
|
var driver = dialect
|
||||||
|
if len(args) == 1 {
|
||||||
|
source = value
|
||||||
|
} else if len(args) >= 2 {
|
||||||
|
driver = value
|
||||||
|
source = args[1].(string)
|
||||||
|
}
|
||||||
|
dbSQL, err = sql.Open(driver, source)
|
||||||
|
case sqlCommon:
|
||||||
|
source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String()
|
||||||
|
dbSQL = value
|
||||||
|
}
|
||||||
|
|
||||||
|
db = DB{
|
||||||
|
dialect: newDialect(dialect, dbSQL.(*sql.DB)),
|
||||||
|
logger: defaultLogger,
|
||||||
|
callbacks: DefaultCallback,
|
||||||
|
source: source,
|
||||||
|
values: map[string]interface{}{},
|
||||||
|
db: dbSQL,
|
||||||
|
}
|
||||||
|
db.parent = &db
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
err = db.DB().Ping() // Send a ping to make sure the database connection is alive.
|
||||||
|
if err != nil {
|
||||||
|
db.DB().Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &db, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close close current db connection
|
||||||
|
func (s *DB) Close() error {
|
||||||
|
return s.parent.db.(*sql.DB).Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// DB get `*sql.DB` from current connection
|
||||||
|
func (s *DB) DB() *sql.DB {
|
||||||
|
return s.db.(*sql.DB)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dialect get dialect
|
||||||
|
func (s *DB) Dialect() Dialect {
|
||||||
|
return s.parent.dialect
|
||||||
|
}
|
||||||
|
|
||||||
|
// New clone a new db connection without search conditions
|
||||||
|
func (s *DB) New() *DB {
|
||||||
|
clone := s.clone()
|
||||||
|
clone.search = nil
|
||||||
|
clone.Value = nil
|
||||||
|
return clone
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewScope create a scope for current operation
|
||||||
|
func (s *DB) NewScope(value interface{}) *Scope {
|
||||||
|
dbClone := s.clone()
|
||||||
|
dbClone.Value = value
|
||||||
|
return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code.
|
||||||
|
func (s *DB) CommonDB() sqlCommon {
|
||||||
|
return s.db
|
||||||
|
}
|
||||||
|
|
||||||
|
// Callback return `Callbacks` container, you could add/change/delete callbacks with it
|
||||||
|
// db.Callback().Create().Register("update_created_at", updateCreated)
|
||||||
|
// Refer https://jinzhu.github.io/gorm/development.html#callbacks
|
||||||
|
func (s *DB) Callback() *Callback {
|
||||||
|
s.parent.callbacks = s.parent.callbacks.clone()
|
||||||
|
return s.parent.callbacks
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLogger replace default logger
|
||||||
|
func (s *DB) SetLogger(log logger) {
|
||||||
|
s.logger = log
|
||||||
|
}
|
||||||
|
|
||||||
|
// LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs
|
||||||
|
func (s *DB) LogMode(enable bool) *DB {
|
||||||
|
if enable {
|
||||||
|
s.logMode = 2
|
||||||
|
} else {
|
||||||
|
s.logMode = 1
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// SingularTable use singular table by default
|
||||||
|
func (s *DB) SingularTable(enable bool) {
|
||||||
|
modelStructsMap = newModelStructsMap()
|
||||||
|
s.parent.singularTable = enable
|
||||||
|
}
|
||||||
|
|
||||||
|
// Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/curd.html#query
|
||||||
|
func (s *DB) Where(query interface{}, args ...interface{}) *DB {
|
||||||
|
return s.clone().search.Where(query, args...).db
|
||||||
|
}
|
||||||
|
|
||||||
|
// Or filter records that match before conditions or this one, similar to `Where`
|
||||||
|
func (s *DB) Or(query interface{}, args ...interface{}) *DB {
|
||||||
|
return s.clone().search.Or(query, args...).db
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not filter records that don't match current conditions, similar to `Where`
|
||||||
|
func (s *DB) Not(query interface{}, args ...interface{}) *DB {
|
||||||
|
return s.clone().search.Not(query, args...).db
|
||||||
|
}
|
||||||
|
|
||||||
|
// Limit specify the number of records to be retrieved
|
||||||
|
func (s *DB) Limit(limit interface{}) *DB {
|
||||||
|
return s.clone().search.Limit(limit).db
|
||||||
|
}
|
||||||
|
|
||||||
|
// Offset specify the number of records to skip before starting to return the records
|
||||||
|
func (s *DB) Offset(offset interface{}) *DB {
|
||||||
|
return s.clone().search.Offset(offset).db
|
||||||
|
}
|
||||||
|
|
||||||
|
// Order specify order when retrieve records from database, set reorder to `true` to overwrite defined conditions
|
||||||
|
// db.Order("name DESC")
|
||||||
|
// db.Order("name DESC", true) // reorder
|
||||||
|
// db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression
|
||||||
|
func (s *DB) Order(value interface{}, reorder ...bool) *DB {
|
||||||
|
return s.clone().search.Order(value, reorder...).db
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select specify fields that you want to retrieve from database when querying, by default, will select all fields;
|
||||||
|
// When creating/updating, specify fields that you want to save to database
|
||||||
|
func (s *DB) Select(query interface{}, args ...interface{}) *DB {
|
||||||
|
return s.clone().search.Select(query, args...).db
|
||||||
|
}
|
||||||
|
|
||||||
|
// Omit specify fields that you want to ignore when saving to database for creating, updating
|
||||||
|
func (s *DB) Omit(columns ...string) *DB {
|
||||||
|
return s.clone().search.Omit(columns...).db
|
||||||
|
}
|
||||||
|
|
||||||
|
// Group specify the group method on the find
|
||||||
|
func (s *DB) Group(query string) *DB {
|
||||||
|
return s.clone().search.Group(query).db
|
||||||
|
}
|
||||||
|
|
||||||
|
// Having specify HAVING conditions for GROUP BY
|
||||||
|
func (s *DB) Having(query string, values ...interface{}) *DB {
|
||||||
|
return s.clone().search.Having(query, values...).db
|
||||||
|
}
|
||||||
|
|
||||||
|
// Joins specify Joins conditions
|
||||||
|
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
|
||||||
|
func (s *DB) Joins(query string, args ...interface{}) *DB {
|
||||||
|
return s.clone().search.Joins(query, args...).db
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scopes pass current database connection to arguments `func(*DB) *DB`, which could be used to add conditions dynamically
|
||||||
|
// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
|
||||||
|
// return db.Where("amount > ?", 1000)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
|
||||||
|
// return func (db *gorm.DB) *gorm.DB {
|
||||||
|
// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
|
||||||
|
// Refer https://jinzhu.github.io/gorm/curd.html#scopes
|
||||||
|
func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB {
|
||||||
|
for _, f := range funcs {
|
||||||
|
s = f(s)
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unscoped return all record including deleted record, refer Soft Delete https://jinzhu.github.io/gorm/curd.html#soft-delete
|
||||||
|
func (s *DB) Unscoped() *DB {
|
||||||
|
return s.clone().search.unscoped().db
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attrs initialize struct with argument if record not found with `FirstOrInit` https://jinzhu.github.io/gorm/curd.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/curd.html#firstorcreate
|
||||||
|
func (s *DB) Attrs(attrs ...interface{}) *DB {
|
||||||
|
return s.clone().search.Attrs(attrs...).db
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assign assign result with argument regardless it is found or not with `FirstOrInit` https://jinzhu.github.io/gorm/curd.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/curd.html#firstorcreate
|
||||||
|
func (s *DB) Assign(attrs ...interface{}) *DB {
|
||||||
|
return s.clone().search.Assign(attrs...).db
|
||||||
|
}
|
||||||
|
|
||||||
|
// First find first record that match given conditions, order by primary key
|
||||||
|
func (s *DB) First(out interface{}, where ...interface{}) *DB {
|
||||||
|
newScope := s.clone().NewScope(out)
|
||||||
|
newScope.Search.Limit(1)
|
||||||
|
return newScope.Set("gorm:order_by_primary_key", "ASC").
|
||||||
|
inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
|
||||||
|
}
|
||||||
|
|
||||||
|
// Last find last record that match given conditions, order by primary key
|
||||||
|
func (s *DB) Last(out interface{}, where ...interface{}) *DB {
|
||||||
|
newScope := s.clone().NewScope(out)
|
||||||
|
newScope.Search.Limit(1)
|
||||||
|
return newScope.Set("gorm:order_by_primary_key", "DESC").
|
||||||
|
inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find find records that match given conditions
|
||||||
|
func (s *DB) Find(out interface{}, where ...interface{}) *DB {
|
||||||
|
return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan scan value to a struct
|
||||||
|
func (s *DB) Scan(dest interface{}) *DB {
|
||||||
|
return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db
|
||||||
|
}
|
||||||
|
|
||||||
|
// Row return `*sql.Row` with given conditions
|
||||||
|
func (s *DB) Row() *sql.Row {
|
||||||
|
return s.NewScope(s.Value).row()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rows return `*sql.Rows` with given conditions
|
||||||
|
func (s *DB) Rows() (*sql.Rows, error) {
|
||||||
|
return s.NewScope(s.Value).rows()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScanRows scan `*sql.Rows` to give struct
|
||||||
|
func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error {
|
||||||
|
var (
|
||||||
|
clone = s.clone()
|
||||||
|
scope = clone.NewScope(result)
|
||||||
|
columns, err = rows.Columns()
|
||||||
|
)
|
||||||
|
|
||||||
|
if clone.AddError(err) == nil {
|
||||||
|
scope.scan(rows, columns, scope.Fields())
|
||||||
|
}
|
||||||
|
|
||||||
|
return clone.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pluck used to query single column from a model as a map
|
||||||
|
// var ages []int64
|
||||||
|
// db.Find(&users).Pluck("age", &ages)
|
||||||
|
func (s *DB) Pluck(column string, value interface{}) *DB {
|
||||||
|
return s.NewScope(s.Value).pluck(column, value).db
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count get how many records for a model
|
||||||
|
func (s *DB) Count(value interface{}) *DB {
|
||||||
|
return s.NewScope(s.Value).count(value).db
|
||||||
|
}
|
||||||
|
|
||||||
|
// Related get related associations
|
||||||
|
func (s *DB) Related(value interface{}, foreignKeys ...string) *DB {
|
||||||
|
return s.clone().NewScope(s.Value).related(value, foreignKeys...).db
|
||||||
|
}
|
||||||
|
|
||||||
|
// FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions)
|
||||||
|
// https://jinzhu.github.io/gorm/curd.html#firstorinit
|
||||||
|
func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
|
||||||
|
c := s.clone()
|
||||||
|
if result := c.First(out, where...); result.Error != nil {
|
||||||
|
if !result.RecordNotFound() {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
c.NewScope(out).inlineCondition(where...).initialize()
|
||||||
|
} else {
|
||||||
|
c.NewScope(out).updatedAttrsWithValues(c.search.assignAttrs)
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// FirstOrCreate find first matched record or create a new one with given conditions (only works with struct, map conditions)
|
||||||
|
// https://jinzhu.github.io/gorm/curd.html#firstorcreate
|
||||||
|
func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
|
||||||
|
c := s.clone()
|
||||||
|
if result := s.First(out, where...); result.Error != nil {
|
||||||
|
if !result.RecordNotFound() {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
return c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callbacks.creates).db
|
||||||
|
} else if len(c.search.assignAttrs) > 0 {
|
||||||
|
return c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callbacks.updates).db
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update
|
||||||
|
func (s *DB) Update(attrs ...interface{}) *DB {
|
||||||
|
return s.Updates(toSearchableMap(attrs...), true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update
|
||||||
|
func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB {
|
||||||
|
return s.clone().NewScope(s.Value).
|
||||||
|
Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0).
|
||||||
|
InstanceSet("gorm:update_interface", values).
|
||||||
|
callCallbacks(s.parent.callbacks.updates).db
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateColumn update attributes without callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update
|
||||||
|
func (s *DB) UpdateColumn(attrs ...interface{}) *DB {
|
||||||
|
return s.UpdateColumns(toSearchableMap(attrs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update
|
||||||
|
func (s *DB) UpdateColumns(values interface{}) *DB {
|
||||||
|
return s.clone().NewScope(s.Value).
|
||||||
|
Set("gorm:update_column", true).
|
||||||
|
Set("gorm:save_associations", false).
|
||||||
|
InstanceSet("gorm:update_interface", values).
|
||||||
|
callCallbacks(s.parent.callbacks.updates).db
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save update value in database, if the value doesn't have primary key, will insert it
|
||||||
|
func (s *DB) Save(value interface{}) *DB {
|
||||||
|
scope := s.clone().NewScope(value)
|
||||||
|
if !scope.PrimaryKeyZero() {
|
||||||
|
newDB := scope.callCallbacks(s.parent.callbacks.updates).db
|
||||||
|
if newDB.Error == nil && newDB.RowsAffected == 0 {
|
||||||
|
return s.New().FirstOrCreate(value)
|
||||||
|
}
|
||||||
|
return newDB
|
||||||
|
}
|
||||||
|
return scope.callCallbacks(s.parent.callbacks.creates).db
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create insert the value into database
|
||||||
|
func (s *DB) Create(value interface{}) *DB {
|
||||||
|
scope := s.clone().NewScope(value)
|
||||||
|
return scope.callCallbacks(s.parent.callbacks.creates).db
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
|
||||||
|
func (s *DB) Delete(value interface{}, where ...interface{}) *DB {
|
||||||
|
return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db
|
||||||
|
}
|
||||||
|
|
||||||
|
// Raw use raw sql as conditions, won't run it unless invoked by other methods
|
||||||
|
// db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result)
|
||||||
|
func (s *DB) Raw(sql string, values ...interface{}) *DB {
|
||||||
|
return s.clone().search.Raw(true).Where(sql, values...).db
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec execute raw sql
|
||||||
|
func (s *DB) Exec(sql string, values ...interface{}) *DB {
|
||||||
|
scope := s.clone().NewScope(nil)
|
||||||
|
generatedSQL := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values})
|
||||||
|
generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")")
|
||||||
|
scope.Raw(generatedSQL)
|
||||||
|
return scope.Exec().db
|
||||||
|
}
|
||||||
|
|
||||||
|
// Model specify the model you would like to run db operations
|
||||||
|
// // update all users's name to `hello`
|
||||||
|
// db.Model(&User{}).Update("name", "hello")
|
||||||
|
// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello`
|
||||||
|
// db.Model(&user).Update("name", "hello")
|
||||||
|
func (s *DB) Model(value interface{}) *DB {
|
||||||
|
c := s.clone()
|
||||||
|
c.Value = value
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// Table specify the table you would like to run db operations
|
||||||
|
func (s *DB) Table(name string) *DB {
|
||||||
|
clone := s.clone()
|
||||||
|
clone.search.Table(name)
|
||||||
|
clone.Value = nil
|
||||||
|
return clone
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debug start debug mode
|
||||||
|
func (s *DB) Debug() *DB {
|
||||||
|
return s.clone().LogMode(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Begin begin a transaction
|
||||||
|
func (s *DB) Begin() *DB {
|
||||||
|
c := s.clone()
|
||||||
|
if db, ok := c.db.(sqlDb); ok {
|
||||||
|
tx, err := db.Begin()
|
||||||
|
c.db = interface{}(tx).(sqlCommon)
|
||||||
|
c.AddError(err)
|
||||||
|
} else {
|
||||||
|
c.AddError(ErrCantStartTransaction)
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// Commit commit a transaction
|
||||||
|
func (s *DB) Commit() *DB {
|
||||||
|
if db, ok := s.db.(sqlTx); ok {
|
||||||
|
s.AddError(db.Commit())
|
||||||
|
} else {
|
||||||
|
s.AddError(ErrInvalidTransaction)
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rollback rollback a transaction
|
||||||
|
func (s *DB) Rollback() *DB {
|
||||||
|
if db, ok := s.db.(sqlTx); ok {
|
||||||
|
s.AddError(db.Rollback())
|
||||||
|
} else {
|
||||||
|
s.AddError(ErrInvalidTransaction)
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRecord check if value's primary key is blank
|
||||||
|
func (s *DB) NewRecord(value interface{}) bool {
|
||||||
|
return s.clone().NewScope(value).PrimaryKeyZero()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordNotFound check if returning ErrRecordNotFound error
|
||||||
|
func (s *DB) RecordNotFound() bool {
|
||||||
|
for _, err := range s.GetErrors() {
|
||||||
|
if err == ErrRecordNotFound {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateTable create table for models
|
||||||
|
func (s *DB) CreateTable(models ...interface{}) *DB {
|
||||||
|
db := s.Unscoped()
|
||||||
|
for _, model := range models {
|
||||||
|
db = db.NewScope(model).createTable().db
|
||||||
|
}
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
// DropTable drop table for models
|
||||||
|
func (s *DB) DropTable(values ...interface{}) *DB {
|
||||||
|
db := s.clone()
|
||||||
|
for _, value := range values {
|
||||||
|
if tableName, ok := value.(string); ok {
|
||||||
|
db = db.Table(tableName)
|
||||||
|
}
|
||||||
|
|
||||||
|
db = db.NewScope(value).dropTable().db
|
||||||
|
}
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
// DropTableIfExists drop table if it is exist
|
||||||
|
func (s *DB) DropTableIfExists(values ...interface{}) *DB {
|
||||||
|
db := s.clone()
|
||||||
|
for _, value := range values {
|
||||||
|
if s.HasTable(value) {
|
||||||
|
db.AddError(s.DropTable(value).Error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasTable check has table or not
|
||||||
|
func (s *DB) HasTable(value interface{}) bool {
|
||||||
|
var (
|
||||||
|
scope = s.clone().NewScope(value)
|
||||||
|
tableName string
|
||||||
|
)
|
||||||
|
|
||||||
|
if name, ok := value.(string); ok {
|
||||||
|
tableName = name
|
||||||
|
} else {
|
||||||
|
tableName = scope.TableName()
|
||||||
|
}
|
||||||
|
|
||||||
|
has := scope.Dialect().HasTable(tableName)
|
||||||
|
s.AddError(scope.db.Error)
|
||||||
|
return has
|
||||||
|
}
|
||||||
|
|
||||||
|
// AutoMigrate run auto migration for given models, will only add missing fields, won't delete/change current data
|
||||||
|
func (s *DB) AutoMigrate(values ...interface{}) *DB {
|
||||||
|
db := s.Unscoped()
|
||||||
|
for _, value := range values {
|
||||||
|
db = db.NewScope(value).autoMigrate().db
|
||||||
|
}
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModifyColumn modify column to type
|
||||||
|
func (s *DB) ModifyColumn(column string, typ string) *DB {
|
||||||
|
scope := s.clone().NewScope(s.Value)
|
||||||
|
scope.modifyColumn(column, typ)
|
||||||
|
return scope.db
|
||||||
|
}
|
||||||
|
|
||||||
|
// DropColumn drop a column
|
||||||
|
func (s *DB) DropColumn(column string) *DB {
|
||||||
|
scope := s.clone().NewScope(s.Value)
|
||||||
|
scope.dropColumn(column)
|
||||||
|
return scope.db
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddIndex add index for columns with given name
|
||||||
|
func (s *DB) AddIndex(indexName string, columns ...string) *DB {
|
||||||
|
scope := s.Unscoped().NewScope(s.Value)
|
||||||
|
scope.addIndex(false, indexName, columns...)
|
||||||
|
return scope.db
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddUniqueIndex add unique index for columns with given name
|
||||||
|
func (s *DB) AddUniqueIndex(indexName string, columns ...string) *DB {
|
||||||
|
scope := s.Unscoped().NewScope(s.Value)
|
||||||
|
scope.addIndex(true, indexName, columns...)
|
||||||
|
return scope.db
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveIndex remove index with name
|
||||||
|
func (s *DB) RemoveIndex(indexName string) *DB {
|
||||||
|
scope := s.clone().NewScope(s.Value)
|
||||||
|
scope.removeIndex(indexName)
|
||||||
|
return scope.db
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddForeignKey Add foreign key to the given scope, e.g:
|
||||||
|
// db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
|
||||||
|
func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB {
|
||||||
|
scope := s.clone().NewScope(s.Value)
|
||||||
|
scope.addForeignKey(field, dest, onDelete, onUpdate)
|
||||||
|
return scope.db
|
||||||
|
}
|
||||||
|
|
||||||
|
// Association start `Association Mode` to handler relations things easir in that mode, refer: https://jinzhu.github.io/gorm/associations.html#association-mode
|
||||||
|
func (s *DB) Association(column string) *Association {
|
||||||
|
var err error
|
||||||
|
scope := s.clone().NewScope(s.Value)
|
||||||
|
|
||||||
|
if primaryField := scope.PrimaryField(); primaryField.IsBlank {
|
||||||
|
err = errors.New("primary key can't be nil")
|
||||||
|
} else {
|
||||||
|
if field, ok := scope.FieldByName(column); ok {
|
||||||
|
if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 {
|
||||||
|
err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type())
|
||||||
|
} else {
|
||||||
|
return &Association{scope: scope, column: column, field: field}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Association{Error: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Preload preload associations with given conditions
|
||||||
|
// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
|
||||||
|
func (s *DB) Preload(column string, conditions ...interface{}) *DB {
|
||||||
|
return s.clone().search.Preload(column, conditions...).db
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set set setting by name, which could be used in callbacks, will clone a new db, and update its setting
|
||||||
|
func (s *DB) Set(name string, value interface{}) *DB {
|
||||||
|
return s.clone().InstantSet(name, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// InstantSet instant set setting, will affect current db
|
||||||
|
func (s *DB) InstantSet(name string, value interface{}) *DB {
|
||||||
|
s.values[name] = value
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get get setting by name
|
||||||
|
func (s *DB) Get(name string) (value interface{}, ok bool) {
|
||||||
|
value, ok = s.values[name]
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetJoinTableHandler set a model's join table handler for a relation
|
||||||
|
func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) {
|
||||||
|
scope := s.NewScope(source)
|
||||||
|
for _, field := range scope.GetModelStruct().StructFields {
|
||||||
|
if field.Name == column || field.DBName == column {
|
||||||
|
if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
|
||||||
|
source := (&Scope{Value: source}).GetModelStruct().ModelType
|
||||||
|
destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
|
||||||
|
handler.Setup(field.Relationship, many2many, source, destination)
|
||||||
|
field.Relationship.JoinTableHandler = handler
|
||||||
|
if table := handler.Table(s); scope.Dialect().HasTable(table) {
|
||||||
|
s.Table(table).AutoMigrate(handler)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddError add error to the db
|
||||||
|
func (s *DB) AddError(err error) error {
|
||||||
|
if err != nil {
|
||||||
|
if err != ErrRecordNotFound {
|
||||||
|
if s.logMode == 0 {
|
||||||
|
go s.print(fileWithLineNum(), err)
|
||||||
|
} else {
|
||||||
|
s.log(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
errors := Errors(s.GetErrors())
|
||||||
|
errors.Add(err)
|
||||||
|
if len(errors) > 1 {
|
||||||
|
err = errors
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.Error = err
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetErrors get happened errors from the db
|
||||||
|
func (s *DB) GetErrors() (errors []error) {
|
||||||
|
if errs, ok := s.Error.(Errors); ok {
|
||||||
|
return errs
|
||||||
|
} else if s.Error != nil {
|
||||||
|
return []error{s.Error}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Private Methods For *gorm.DB
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
func (s *DB) clone() *DB {
|
||||||
|
db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error}
|
||||||
|
|
||||||
|
for key, value := range s.values {
|
||||||
|
db.values[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.search == nil {
|
||||||
|
db.search = &search{limit: -1, offset: -1}
|
||||||
|
} else {
|
||||||
|
db.search = s.search.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
db.search.db = &db
|
||||||
|
return &db
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) print(v ...interface{}) {
|
||||||
|
s.logger.(logger).Print(v...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) log(v ...interface{}) {
|
||||||
|
if s != nil && s.logMode == 2 {
|
||||||
|
s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DB) slog(sql string, t time.Time, vars ...interface{}) {
|
||||||
|
if s.logMode == 2 {
|
||||||
|
s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars)
|
||||||
|
}
|
||||||
|
}
|
820
orm/main_test.go
Normal file
820
orm/main_test.go
Normal file
|
@ -0,0 +1,820 @@
|
||||||
|
package orm_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/erikstmartin/go-testdb"
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
_ "github.com/jinzhu/gorm/dialects/mssql"
|
||||||
|
_ "github.com/jinzhu/gorm/dialects/mysql"
|
||||||
|
"github.com/jinzhu/gorm/dialects/postgres"
|
||||||
|
_ "github.com/jinzhu/gorm/dialects/sqlite"
|
||||||
|
"github.com/jinzhu/now"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
DB *gorm.DB
|
||||||
|
t1, t2, t3, t4, t5 time.Time
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if DB, err = OpenTestConnection(); err != nil {
|
||||||
|
panic(fmt.Sprintf("No error should happen when connecting to test database, but got err=%+v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
runMigration()
|
||||||
|
}
|
||||||
|
|
||||||
|
func OpenTestConnection() (db *gorm.DB, err error) {
|
||||||
|
switch os.Getenv("GORM_DIALECT") {
|
||||||
|
case "mysql":
|
||||||
|
// CREATE USER 'gorm'@'localhost' IDENTIFIED BY 'gorm';
|
||||||
|
// CREATE DATABASE gorm;
|
||||||
|
// GRANT ALL ON gorm.* TO 'gorm'@'localhost';
|
||||||
|
fmt.Println("testing mysql...")
|
||||||
|
dbhost := os.Getenv("GORM_DBADDRESS")
|
||||||
|
if dbhost != "" {
|
||||||
|
dbhost = fmt.Sprintf("tcp(%v)", dbhost)
|
||||||
|
}
|
||||||
|
db, err = gorm.Open("mysql", fmt.Sprintf("gorm:gorm@%v/gorm?charset=utf8&parseTime=True", dbhost))
|
||||||
|
case "postgres":
|
||||||
|
fmt.Println("testing postgres...")
|
||||||
|
dbhost := os.Getenv("GORM_DBHOST")
|
||||||
|
if dbhost != "" {
|
||||||
|
dbhost = fmt.Sprintf("host=%v ", dbhost)
|
||||||
|
}
|
||||||
|
db, err = gorm.Open("postgres", fmt.Sprintf("%vuser=gorm password=gorm DB.name=gorm sslmode=disable", dbhost))
|
||||||
|
case "foundation":
|
||||||
|
fmt.Println("testing foundation...")
|
||||||
|
db, err = gorm.Open("foundation", "dbname=gorm port=15432 sslmode=disable")
|
||||||
|
case "mssql":
|
||||||
|
fmt.Println("testing mssql...")
|
||||||
|
db, err = gorm.Open("mssql", "server=SERVER_HERE;database=rogue;user id=USER_HERE;password=PW_HERE;port=1433")
|
||||||
|
default:
|
||||||
|
fmt.Println("testing sqlite3...")
|
||||||
|
db, err = gorm.Open("sqlite3", filepath.Join(os.TempDir(), "gorm.db"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// db.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)})
|
||||||
|
// db.SetLogger(log.New(os.Stdout, "\r\n", 0))
|
||||||
|
if os.Getenv("DEBUG") == "true" {
|
||||||
|
db.LogMode(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
db.DB().SetMaxIdleConns(10)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStringPrimaryKey(t *testing.T) {
|
||||||
|
type UUIDStruct struct {
|
||||||
|
ID string `gorm:"primary_key"`
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
DB.DropTable(&UUIDStruct{})
|
||||||
|
DB.AutoMigrate(&UUIDStruct{})
|
||||||
|
|
||||||
|
data := UUIDStruct{ID: "uuid", Name: "hello"}
|
||||||
|
if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" || data.Name != "hello" {
|
||||||
|
t.Errorf("string primary key should not be populated")
|
||||||
|
}
|
||||||
|
|
||||||
|
data = UUIDStruct{ID: "uuid", Name: "hello world"}
|
||||||
|
if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" || data.Name != "hello world" {
|
||||||
|
t.Errorf("string primary key should not be populated")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExceptionsWithInvalidSql(t *testing.T) {
|
||||||
|
var columns []string
|
||||||
|
if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil {
|
||||||
|
t.Errorf("Should got error with invalid SQL")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&User{}).Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil {
|
||||||
|
t.Errorf("Should got error with invalid SQL")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Find(&User{}).Error == nil {
|
||||||
|
t.Errorf("Should got error with invalid SQL")
|
||||||
|
}
|
||||||
|
|
||||||
|
var count1, count2 int64
|
||||||
|
DB.Model(&User{}).Count(&count1)
|
||||||
|
if count1 <= 0 {
|
||||||
|
t.Errorf("Should find some users")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Where("name = ?", "jinzhu; delete * from users").First(&User{}).Error == nil {
|
||||||
|
t.Errorf("Should got error with invalid SQL")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&User{}).Count(&count2)
|
||||||
|
if count1 != count2 {
|
||||||
|
t.Errorf("No user should not be deleted by invalid SQL")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetTable(t *testing.T) {
|
||||||
|
DB.Create(getPreparedUser("pluck_user1", "pluck_user"))
|
||||||
|
DB.Create(getPreparedUser("pluck_user2", "pluck_user"))
|
||||||
|
DB.Create(getPreparedUser("pluck_user3", "pluck_user"))
|
||||||
|
|
||||||
|
if err := DB.Table("users").Where("role = ?", "pluck_user").Pluck("age", &[]int{}).Error; err != nil {
|
||||||
|
t.Error("No errors should happen if set table for pluck", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var users []User
|
||||||
|
if DB.Table("users").Find(&[]User{}).Error != nil {
|
||||||
|
t.Errorf("No errors should happen if set table for find")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Table("invalid_table").Find(&users).Error == nil {
|
||||||
|
t.Errorf("Should got error when table is set to an invalid table")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Exec("drop table deleted_users;")
|
||||||
|
if DB.Table("deleted_users").CreateTable(&User{}).Error != nil {
|
||||||
|
t.Errorf("Create table with specified table")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Table("deleted_users").Save(&User{Name: "DeletedUser"})
|
||||||
|
|
||||||
|
var deletedUsers []User
|
||||||
|
DB.Table("deleted_users").Find(&deletedUsers)
|
||||||
|
if len(deletedUsers) != 1 {
|
||||||
|
t.Errorf("Query from specified table")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Save(getPreparedUser("normal_user", "reset_table"))
|
||||||
|
DB.Table("deleted_users").Save(getPreparedUser("deleted_user", "reset_table"))
|
||||||
|
var user1, user2, user3 User
|
||||||
|
DB.Where("role = ?", "reset_table").First(&user1).Table("deleted_users").First(&user2).Table("").First(&user3)
|
||||||
|
if (user1.Name != "normal_user") || (user2.Name != "deleted_user") || (user3.Name != "normal_user") {
|
||||||
|
t.Errorf("unset specified table with blank string")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Order struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
type Cart struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Cart) TableName() string {
|
||||||
|
return "shopping_cart"
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasTable(t *testing.T) {
|
||||||
|
type Foo struct {
|
||||||
|
Id int
|
||||||
|
Stuff string
|
||||||
|
}
|
||||||
|
DB.DropTable(&Foo{})
|
||||||
|
|
||||||
|
// Table should not exist at this point, HasTable should return false
|
||||||
|
if ok := DB.HasTable("foos"); ok {
|
||||||
|
t.Errorf("Table should not exist, but does")
|
||||||
|
}
|
||||||
|
if ok := DB.HasTable(&Foo{}); ok {
|
||||||
|
t.Errorf("Table should not exist, but does")
|
||||||
|
}
|
||||||
|
|
||||||
|
// We create the table
|
||||||
|
if err := DB.CreateTable(&Foo{}).Error; err != nil {
|
||||||
|
t.Errorf("Table should be created")
|
||||||
|
}
|
||||||
|
|
||||||
|
// And now it should exits, and HasTable should return true
|
||||||
|
if ok := DB.HasTable("foos"); !ok {
|
||||||
|
t.Errorf("Table should exist, but HasTable informs it does not")
|
||||||
|
}
|
||||||
|
if ok := DB.HasTable(&Foo{}); !ok {
|
||||||
|
t.Errorf("Table should exist, but HasTable informs it does not")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTableName(t *testing.T) {
|
||||||
|
DB := DB.Model("")
|
||||||
|
if DB.NewScope(Order{}).TableName() != "orders" {
|
||||||
|
t.Errorf("Order's table name should be orders")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.NewScope(&Order{}).TableName() != "orders" {
|
||||||
|
t.Errorf("&Order's table name should be orders")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.NewScope([]Order{}).TableName() != "orders" {
|
||||||
|
t.Errorf("[]Order's table name should be orders")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.NewScope(&[]Order{}).TableName() != "orders" {
|
||||||
|
t.Errorf("&[]Order's table name should be orders")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.SingularTable(true)
|
||||||
|
if DB.NewScope(Order{}).TableName() != "order" {
|
||||||
|
t.Errorf("Order's singular table name should be order")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.NewScope(&Order{}).TableName() != "order" {
|
||||||
|
t.Errorf("&Order's singular table name should be order")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.NewScope([]Order{}).TableName() != "order" {
|
||||||
|
t.Errorf("[]Order's singular table name should be order")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.NewScope(&[]Order{}).TableName() != "order" {
|
||||||
|
t.Errorf("&[]Order's singular table name should be order")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.NewScope(&Cart{}).TableName() != "shopping_cart" {
|
||||||
|
t.Errorf("&Cart's singular table name should be shopping_cart")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.NewScope(Cart{}).TableName() != "shopping_cart" {
|
||||||
|
t.Errorf("Cart's singular table name should be shopping_cart")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.NewScope(&[]Cart{}).TableName() != "shopping_cart" {
|
||||||
|
t.Errorf("&[]Cart's singular table name should be shopping_cart")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.NewScope([]Cart{}).TableName() != "shopping_cart" {
|
||||||
|
t.Errorf("[]Cart's singular table name should be shopping_cart")
|
||||||
|
}
|
||||||
|
DB.SingularTable(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNullValues(t *testing.T) {
|
||||||
|
DB.DropTable(&NullValue{})
|
||||||
|
DB.AutoMigrate(&NullValue{})
|
||||||
|
|
||||||
|
if err := DB.Save(&NullValue{
|
||||||
|
Name: sql.NullString{String: "hello", Valid: true},
|
||||||
|
Gender: &sql.NullString{String: "M", Valid: true},
|
||||||
|
Age: sql.NullInt64{Int64: 18, Valid: true},
|
||||||
|
Male: sql.NullBool{Bool: true, Valid: true},
|
||||||
|
Height: sql.NullFloat64{Float64: 100.11, Valid: true},
|
||||||
|
AddedAt: NullTime{Time: time.Now(), Valid: true},
|
||||||
|
}).Error; err != nil {
|
||||||
|
t.Errorf("Not error should raise when test null value")
|
||||||
|
}
|
||||||
|
|
||||||
|
var nv NullValue
|
||||||
|
DB.First(&nv, "name = ?", "hello")
|
||||||
|
|
||||||
|
if nv.Name.String != "hello" || nv.Gender.String != "M" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true {
|
||||||
|
t.Errorf("Should be able to fetch null value")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Save(&NullValue{
|
||||||
|
Name: sql.NullString{String: "hello-2", Valid: true},
|
||||||
|
Gender: &sql.NullString{String: "F", Valid: true},
|
||||||
|
Age: sql.NullInt64{Int64: 18, Valid: false},
|
||||||
|
Male: sql.NullBool{Bool: true, Valid: true},
|
||||||
|
Height: sql.NullFloat64{Float64: 100.11, Valid: true},
|
||||||
|
AddedAt: NullTime{Time: time.Now(), Valid: false},
|
||||||
|
}).Error; err != nil {
|
||||||
|
t.Errorf("Not error should raise when test null value")
|
||||||
|
}
|
||||||
|
|
||||||
|
var nv2 NullValue
|
||||||
|
DB.First(&nv2, "name = ?", "hello-2")
|
||||||
|
if nv2.Name.String != "hello-2" || nv2.Gender.String != "F" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false {
|
||||||
|
t.Errorf("Should be able to fetch null value")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Save(&NullValue{
|
||||||
|
Name: sql.NullString{String: "hello-3", Valid: false},
|
||||||
|
Gender: &sql.NullString{String: "M", Valid: true},
|
||||||
|
Age: sql.NullInt64{Int64: 18, Valid: false},
|
||||||
|
Male: sql.NullBool{Bool: true, Valid: true},
|
||||||
|
Height: sql.NullFloat64{Float64: 100.11, Valid: true},
|
||||||
|
AddedAt: NullTime{Time: time.Now(), Valid: false},
|
||||||
|
}).Error; err == nil {
|
||||||
|
t.Errorf("Can't save because of name can't be null")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNullValuesWithFirstOrCreate(t *testing.T) {
|
||||||
|
var nv1 = NullValue{
|
||||||
|
Name: sql.NullString{String: "first_or_create", Valid: true},
|
||||||
|
Gender: &sql.NullString{String: "M", Valid: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
var nv2 NullValue
|
||||||
|
result := DB.Where(nv1).FirstOrCreate(&nv2)
|
||||||
|
|
||||||
|
if result.RowsAffected != 1 {
|
||||||
|
t.Errorf("RowsAffected should be 1 after create some record")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Error != nil {
|
||||||
|
t.Errorf("Should not raise any error, but got %v", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if nv2.Name.String != "first_or_create" || nv2.Gender.String != "M" {
|
||||||
|
t.Errorf("first or create with nullvalues")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Where(nv1).Assign(NullValue{Age: sql.NullInt64{Int64: 18, Valid: true}}).FirstOrCreate(&nv2).Error; err != nil {
|
||||||
|
t.Errorf("Should not raise any error, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if nv2.Age.Int64 != 18 {
|
||||||
|
t.Errorf("should update age to 18")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransaction(t *testing.T) {
|
||||||
|
tx := DB.Begin()
|
||||||
|
u := User{Name: "transcation"}
|
||||||
|
if err := tx.Save(&u).Error; err != nil {
|
||||||
|
t.Errorf("No error should raise")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil {
|
||||||
|
t.Errorf("Should find saved record")
|
||||||
|
}
|
||||||
|
|
||||||
|
if sqlTx, ok := tx.CommonDB().(*sql.Tx); !ok || sqlTx == nil {
|
||||||
|
t.Errorf("Should return the underlying sql.Tx")
|
||||||
|
}
|
||||||
|
|
||||||
|
tx.Rollback()
|
||||||
|
|
||||||
|
if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil {
|
||||||
|
t.Errorf("Should not find record after rollback")
|
||||||
|
}
|
||||||
|
|
||||||
|
tx2 := DB.Begin()
|
||||||
|
u2 := User{Name: "transcation-2"}
|
||||||
|
if err := tx2.Save(&u2).Error; err != nil {
|
||||||
|
t.Errorf("No error should raise")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx2.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
|
||||||
|
t.Errorf("Should find saved record")
|
||||||
|
}
|
||||||
|
|
||||||
|
tx2.Commit()
|
||||||
|
|
||||||
|
if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
|
||||||
|
t.Errorf("Should be able to find committed record")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRow(t *testing.T) {
|
||||||
|
user1 := User{Name: "RowUser1", Age: 1, Birthday: parseTime("2000-1-1")}
|
||||||
|
user2 := User{Name: "RowUser2", Age: 10, Birthday: parseTime("2010-1-1")}
|
||||||
|
user3 := User{Name: "RowUser3", Age: 20, Birthday: parseTime("2020-1-1")}
|
||||||
|
DB.Save(&user1).Save(&user2).Save(&user3)
|
||||||
|
|
||||||
|
row := DB.Table("users").Where("name = ?", user2.Name).Select("age").Row()
|
||||||
|
var age int64
|
||||||
|
row.Scan(&age)
|
||||||
|
if age != 10 {
|
||||||
|
t.Errorf("Scan with Row")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRows(t *testing.T) {
|
||||||
|
user1 := User{Name: "RowsUser1", Age: 1, Birthday: parseTime("2000-1-1")}
|
||||||
|
user2 := User{Name: "RowsUser2", Age: 10, Birthday: parseTime("2010-1-1")}
|
||||||
|
user3 := User{Name: "RowsUser3", Age: 20, Birthday: parseTime("2020-1-1")}
|
||||||
|
DB.Save(&user1).Save(&user2).Save(&user3)
|
||||||
|
|
||||||
|
rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Not error should happen, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
count := 0
|
||||||
|
for rows.Next() {
|
||||||
|
var name string
|
||||||
|
var age int64
|
||||||
|
rows.Scan(&name, &age)
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
|
||||||
|
if count != 2 {
|
||||||
|
t.Errorf("Should found two records")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScanRows(t *testing.T) {
|
||||||
|
user1 := User{Name: "ScanRowsUser1", Age: 1, Birthday: parseTime("2000-1-1")}
|
||||||
|
user2 := User{Name: "ScanRowsUser2", Age: 10, Birthday: parseTime("2010-1-1")}
|
||||||
|
user3 := User{Name: "ScanRowsUser3", Age: 20, Birthday: parseTime("2020-1-1")}
|
||||||
|
DB.Save(&user1).Save(&user2).Save(&user3)
|
||||||
|
|
||||||
|
rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Not error should happen, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Result struct {
|
||||||
|
Name string
|
||||||
|
Age int
|
||||||
|
}
|
||||||
|
|
||||||
|
var results []Result
|
||||||
|
for rows.Next() {
|
||||||
|
var result Result
|
||||||
|
if err := DB.ScanRows(rows, &result); err != nil {
|
||||||
|
t.Errorf("should get no error, but got %v", err)
|
||||||
|
}
|
||||||
|
results = append(results, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) {
|
||||||
|
t.Errorf("Should find expected results")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScan(t *testing.T) {
|
||||||
|
user1 := User{Name: "ScanUser1", Age: 1, Birthday: parseTime("2000-1-1")}
|
||||||
|
user2 := User{Name: "ScanUser2", Age: 10, Birthday: parseTime("2010-1-1")}
|
||||||
|
user3 := User{Name: "ScanUser3", Age: 20, Birthday: parseTime("2020-1-1")}
|
||||||
|
DB.Save(&user1).Save(&user2).Save(&user3)
|
||||||
|
|
||||||
|
type result struct {
|
||||||
|
Name string
|
||||||
|
Age int
|
||||||
|
}
|
||||||
|
|
||||||
|
var res result
|
||||||
|
DB.Table("users").Select("name, age").Where("name = ?", user3.Name).Scan(&res)
|
||||||
|
if res.Name != user3.Name {
|
||||||
|
t.Errorf("Scan into struct should work")
|
||||||
|
}
|
||||||
|
|
||||||
|
var doubleAgeRes result
|
||||||
|
DB.Table("users").Select("age + age as age").Where("name = ?", user3.Name).Scan(&doubleAgeRes)
|
||||||
|
if doubleAgeRes.Age != res.Age*2 {
|
||||||
|
t.Errorf("Scan double age as age")
|
||||||
|
}
|
||||||
|
|
||||||
|
var ress []result
|
||||||
|
DB.Table("users").Select("name, age").Where("name in (?)", []string{user2.Name, user3.Name}).Scan(&ress)
|
||||||
|
if len(ress) != 2 || ress[0].Name != user2.Name || ress[1].Name != user3.Name {
|
||||||
|
t.Errorf("Scan into struct map")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRaw(t *testing.T) {
|
||||||
|
user1 := User{Name: "ExecRawSqlUser1", Age: 1, Birthday: parseTime("2000-1-1")}
|
||||||
|
user2 := User{Name: "ExecRawSqlUser2", Age: 10, Birthday: parseTime("2010-1-1")}
|
||||||
|
user3 := User{Name: "ExecRawSqlUser3", Age: 20, Birthday: parseTime("2020-1-1")}
|
||||||
|
DB.Save(&user1).Save(&user2).Save(&user3)
|
||||||
|
|
||||||
|
type result struct {
|
||||||
|
Name string
|
||||||
|
Email string
|
||||||
|
}
|
||||||
|
|
||||||
|
var ress []result
|
||||||
|
DB.Raw("SELECT name, age FROM users WHERE name = ? or name = ?", user2.Name, user3.Name).Scan(&ress)
|
||||||
|
if len(ress) != 2 || ress[0].Name != user2.Name || ress[1].Name != user3.Name {
|
||||||
|
t.Errorf("Raw with scan")
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, _ := DB.Raw("select name, age from users where name = ?", user3.Name).Rows()
|
||||||
|
count := 0
|
||||||
|
for rows.Next() {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
if count != 1 {
|
||||||
|
t.Errorf("Raw with Rows should find one record with name 3")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name})
|
||||||
|
if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.ErrRecordNotFound {
|
||||||
|
t.Error("Raw sql to update records")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGroup(t *testing.T) {
|
||||||
|
rows, err := DB.Select("name").Table("users").Group("name").Rows()
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
defer rows.Close()
|
||||||
|
for rows.Next() {
|
||||||
|
var name string
|
||||||
|
rows.Scan(&name)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.Errorf("Should not raise any error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJoins(t *testing.T) {
|
||||||
|
var user = User{
|
||||||
|
Name: "joins",
|
||||||
|
CreditCard: CreditCard{Number: "411111111111"},
|
||||||
|
Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}},
|
||||||
|
}
|
||||||
|
DB.Save(&user)
|
||||||
|
|
||||||
|
var users1 []User
|
||||||
|
DB.Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").Find(&users1)
|
||||||
|
if len(users1) != 2 {
|
||||||
|
t.Errorf("should find two users using left join")
|
||||||
|
}
|
||||||
|
|
||||||
|
var users2 []User
|
||||||
|
DB.Joins("left join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Where("name = ?", "joins").First(&users2)
|
||||||
|
if len(users2) != 1 {
|
||||||
|
t.Errorf("should find one users using left join with conditions")
|
||||||
|
}
|
||||||
|
|
||||||
|
var users3 []User
|
||||||
|
DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "411111111111").Where("name = ?", "joins").First(&users3)
|
||||||
|
if len(users3) != 1 {
|
||||||
|
t.Errorf("should find one users using multiple left join conditions")
|
||||||
|
}
|
||||||
|
|
||||||
|
var users4 []User
|
||||||
|
DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "422222222222").Where("name = ?", "joins").First(&users4)
|
||||||
|
if len(users4) != 0 {
|
||||||
|
t.Errorf("should find no user when searching with unexisting credit card")
|
||||||
|
}
|
||||||
|
|
||||||
|
var users5 []User
|
||||||
|
db5 := DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "411111111111").Where(User{Id: 1}).Where(Email{Id: 1}).Not(Email{Id: 10}).First(&users5)
|
||||||
|
if db5.Error != nil {
|
||||||
|
t.Errorf("Should not raise error for join where identical fields in different tables. Error: %s", db5.Error.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJoinsWithSelect(t *testing.T) {
|
||||||
|
type result struct {
|
||||||
|
Name string
|
||||||
|
Email string
|
||||||
|
}
|
||||||
|
|
||||||
|
user := User{
|
||||||
|
Name: "joins_with_select",
|
||||||
|
Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}},
|
||||||
|
}
|
||||||
|
DB.Save(&user)
|
||||||
|
|
||||||
|
var results []result
|
||||||
|
DB.Table("users").Select("name, emails.email").Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins_with_select").Scan(&results)
|
||||||
|
if len(results) != 2 || results[0].Email != "join1@example.com" || results[1].Email != "join2@example.com" {
|
||||||
|
t.Errorf("Should find all two emails with Join select")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHaving(t *testing.T) {
|
||||||
|
rows, err := DB.Select("name, count(*) as total").Table("users").Group("name").Having("name IN (?)", []string{"2", "3"}).Rows()
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
defer rows.Close()
|
||||||
|
for rows.Next() {
|
||||||
|
var name string
|
||||||
|
var total int64
|
||||||
|
rows.Scan(&name, &total)
|
||||||
|
|
||||||
|
if name == "2" && total != 1 {
|
||||||
|
t.Errorf("Should have one user having name 2")
|
||||||
|
}
|
||||||
|
if name == "3" && total != 2 {
|
||||||
|
t.Errorf("Should have two users having name 3")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.Errorf("Should not raise any error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func DialectHasTzSupport() bool {
|
||||||
|
// NB: mssql and FoundationDB do not support time zones.
|
||||||
|
if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" || dialect == "foundation" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTimeWithZone(t *testing.T) {
|
||||||
|
var format = "2006-01-02 15:04:05 -0700"
|
||||||
|
var times []time.Time
|
||||||
|
GMT8, _ := time.LoadLocation("Asia/Shanghai")
|
||||||
|
times = append(times, time.Date(2013, 02, 19, 1, 51, 49, 123456789, GMT8))
|
||||||
|
times = append(times, time.Date(2013, 02, 18, 17, 51, 49, 123456789, time.UTC))
|
||||||
|
|
||||||
|
for index, vtime := range times {
|
||||||
|
name := "time_with_zone_" + strconv.Itoa(index)
|
||||||
|
user := User{Name: name, Birthday: &vtime}
|
||||||
|
|
||||||
|
if !DialectHasTzSupport() {
|
||||||
|
// If our driver dialect doesn't support TZ's, just use UTC for everything here.
|
||||||
|
utcBirthday := user.Birthday.UTC()
|
||||||
|
user.Birthday = &utcBirthday
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Save(&user)
|
||||||
|
expectedBirthday := "2013-02-18 17:51:49 +0000"
|
||||||
|
foundBirthday := user.Birthday.UTC().Format(format)
|
||||||
|
if foundBirthday != expectedBirthday {
|
||||||
|
t.Errorf("User's birthday should not be changed after save for name=%s, expected bday=%+v but actual value=%+v", name, expectedBirthday, foundBirthday)
|
||||||
|
}
|
||||||
|
|
||||||
|
var findUser, findUser2, findUser3 User
|
||||||
|
DB.First(&findUser, "name = ?", name)
|
||||||
|
foundBirthday = findUser.Birthday.UTC().Format(format)
|
||||||
|
if foundBirthday != expectedBirthday {
|
||||||
|
t.Errorf("User's birthday should not be changed after find for name=%s, expected bday=%+v but actual value=%+v", name, expectedBirthday, foundBirthday)
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(-time.Minute)).First(&findUser2).RecordNotFound() {
|
||||||
|
t.Errorf("User should be found")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(time.Minute)).First(&findUser3).RecordNotFound() {
|
||||||
|
t.Errorf("User should not be found")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHstore(t *testing.T) {
|
||||||
|
type Details struct {
|
||||||
|
Id int64
|
||||||
|
Bulk postgres.Hstore
|
||||||
|
}
|
||||||
|
|
||||||
|
if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" {
|
||||||
|
t.Skip()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS hstore").Error; err != nil {
|
||||||
|
fmt.Println("\033[31mHINT: Must be superuser to create hstore extension (ALTER USER gorm WITH SUPERUSER;)\033[0m")
|
||||||
|
panic(fmt.Sprintf("No error should happen when create hstore extension, but got %+v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Exec("drop table details")
|
||||||
|
|
||||||
|
if err := DB.CreateTable(&Details{}).Error; err != nil {
|
||||||
|
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
bankAccountId, phoneNumber, opinion := "123456", "14151321232", "sharkbait"
|
||||||
|
bulk := map[string]*string{
|
||||||
|
"bankAccountId": &bankAccountId,
|
||||||
|
"phoneNumber": &phoneNumber,
|
||||||
|
"opinion": &opinion,
|
||||||
|
}
|
||||||
|
d := Details{Bulk: bulk}
|
||||||
|
DB.Save(&d)
|
||||||
|
|
||||||
|
var d2 Details
|
||||||
|
if err := DB.First(&d2).Error; err != nil {
|
||||||
|
t.Errorf("Got error when tried to fetch details: %+v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for k := range bulk {
|
||||||
|
if r, ok := d2.Bulk[k]; ok {
|
||||||
|
if res, _ := bulk[k]; *res != *r {
|
||||||
|
t.Errorf("Details should be equal")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.Errorf("Details should be existed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetAndGet(t *testing.T) {
|
||||||
|
if value, ok := DB.Set("hello", "world").Get("hello"); !ok {
|
||||||
|
t.Errorf("Should be able to get setting after set")
|
||||||
|
} else {
|
||||||
|
if value.(string) != "world" {
|
||||||
|
t.Errorf("Setted value should not be changed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := DB.Get("non_existing"); ok {
|
||||||
|
t.Errorf("Get non existing key should return error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompatibilityMode(t *testing.T) {
|
||||||
|
DB, _ := gorm.Open("testdb", "")
|
||||||
|
testdb.SetQueryFunc(func(query string) (driver.Rows, error) {
|
||||||
|
columns := []string{"id", "name", "age"}
|
||||||
|
result := `
|
||||||
|
1,Tim,20
|
||||||
|
2,Joe,25
|
||||||
|
3,Bob,30
|
||||||
|
`
|
||||||
|
return testdb.RowsFromCSVString(columns, result), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
var users []User
|
||||||
|
DB.Find(&users)
|
||||||
|
if (users[0].Name != "Tim") || len(users) != 3 {
|
||||||
|
t.Errorf("Unexcepted result returned")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenExistingDB(t *testing.T) {
|
||||||
|
DB.Save(&User{Name: "jnfeinstein"})
|
||||||
|
dialect := os.Getenv("GORM_DIALECT")
|
||||||
|
|
||||||
|
db, err := gorm.Open(dialect, DB.DB())
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Should have wrapped the existing DB connection")
|
||||||
|
}
|
||||||
|
|
||||||
|
var user User
|
||||||
|
if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.ErrRecordNotFound {
|
||||||
|
t.Errorf("Should have found existing record")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDdlErrors(t *testing.T) {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if err = DB.Close(); err != nil {
|
||||||
|
t.Errorf("Closing DDL test db connection err=%s", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
// Reopen DB connection.
|
||||||
|
if DB, err = OpenTestConnection(); err != nil {
|
||||||
|
t.Fatalf("Failed re-opening db connection: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := DB.Find(&User{}).Error; err == nil {
|
||||||
|
t.Errorf("Expected operation on closed db to produce an error, but err was nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenWithOneParameter(t *testing.T) {
|
||||||
|
db, err := gorm.Open("dialect")
|
||||||
|
if db != nil {
|
||||||
|
t.Error("Open with one parameter returned non nil for db")
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Open with one parameter returned err as nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkGorm(b *testing.B) {
|
||||||
|
b.N = 2000
|
||||||
|
for x := 0; x < b.N; x++ {
|
||||||
|
e := strconv.Itoa(x) + "benchmark@example.org"
|
||||||
|
now := time.Now()
|
||||||
|
email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: &now}
|
||||||
|
// Insert
|
||||||
|
DB.Save(&email)
|
||||||
|
// Query
|
||||||
|
DB.First(&BigEmail{}, "email = ?", e)
|
||||||
|
// Update
|
||||||
|
DB.Model(&email).UpdateColumn("email", "new-"+e)
|
||||||
|
// Delete
|
||||||
|
DB.Delete(&email)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRawSql(b *testing.B) {
|
||||||
|
DB, _ := sql.Open("postgres", "user=gorm DB.ame=gorm sslmode=disable")
|
||||||
|
DB.SetMaxIdleConns(10)
|
||||||
|
insertSql := "INSERT INTO emails (user_id,email,user_agent,registered_at,created_at,updated_at) VALUES ($1,$2,$3,$4,$5,$6) RETURNING id"
|
||||||
|
querySql := "SELECT * FROM emails WHERE email = $1 ORDER BY id LIMIT 1"
|
||||||
|
updateSql := "UPDATE emails SET email = $1, updated_at = $2 WHERE id = $3"
|
||||||
|
deleteSql := "DELETE FROM orders WHERE id = $1"
|
||||||
|
|
||||||
|
b.N = 2000
|
||||||
|
for x := 0; x < b.N; x++ {
|
||||||
|
var id int64
|
||||||
|
e := strconv.Itoa(x) + "benchmark@example.org"
|
||||||
|
now := time.Now()
|
||||||
|
email := BigEmail{Email: e, UserAgent: "pc", RegisteredAt: &now}
|
||||||
|
// Insert
|
||||||
|
DB.QueryRow(insertSql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id)
|
||||||
|
// Query
|
||||||
|
rows, _ := DB.Query(querySql, email.Email)
|
||||||
|
rows.Close()
|
||||||
|
// Update
|
||||||
|
DB.Exec(updateSql, "new-"+e, time.Now(), id)
|
||||||
|
// Delete
|
||||||
|
DB.Exec(deleteSql, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseTime(str string) *time.Time {
|
||||||
|
t := now.MustParse(str)
|
||||||
|
return &t
|
||||||
|
}
|
438
orm/migration_test.go
Normal file
438
orm/migration_test.go
Normal file
|
@ -0,0 +1,438 @@
|
||||||
|
package orm_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
Id int64
|
||||||
|
Age int64
|
||||||
|
UserNum Num
|
||||||
|
Name string `sql:"size:255"`
|
||||||
|
Email string
|
||||||
|
Birthday *time.Time // Time
|
||||||
|
CreatedAt time.Time // CreatedAt: Time of record is created, will be insert automatically
|
||||||
|
UpdatedAt time.Time // UpdatedAt: Time of record is updated, will be updated automatically
|
||||||
|
Emails []Email // Embedded structs
|
||||||
|
BillingAddress Address // Embedded struct
|
||||||
|
BillingAddressID sql.NullInt64 // Embedded struct's foreign key
|
||||||
|
ShippingAddress Address // Embedded struct
|
||||||
|
ShippingAddressId int64 // Embedded struct's foreign key
|
||||||
|
CreditCard CreditCard
|
||||||
|
Latitude float64
|
||||||
|
Languages []Language `gorm:"many2many:user_languages;"`
|
||||||
|
CompanyID *int
|
||||||
|
Company Company
|
||||||
|
Role
|
||||||
|
PasswordHash []byte
|
||||||
|
Sequence uint `gorm:"AUTO_INCREMENT"`
|
||||||
|
IgnoreMe int64 `sql:"-"`
|
||||||
|
IgnoreStringSlice []string `sql:"-"`
|
||||||
|
Ignored struct{ Name string } `sql:"-"`
|
||||||
|
IgnoredPointer *User `sql:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type NotSoLongTableName struct {
|
||||||
|
Id int64
|
||||||
|
ReallyLongThingID int64
|
||||||
|
ReallyLongThing ReallyLongTableNameToTestMySQLNameLengthLimit
|
||||||
|
}
|
||||||
|
|
||||||
|
type ReallyLongTableNameToTestMySQLNameLengthLimit struct {
|
||||||
|
Id int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type ReallyLongThingThatReferencesShort struct {
|
||||||
|
Id int64
|
||||||
|
ShortID int64
|
||||||
|
Short Short
|
||||||
|
}
|
||||||
|
|
||||||
|
type Short struct {
|
||||||
|
Id int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type CreditCard struct {
|
||||||
|
ID int8
|
||||||
|
Number string
|
||||||
|
UserId sql.NullInt64
|
||||||
|
CreatedAt time.Time `sql:"not null"`
|
||||||
|
UpdatedAt time.Time
|
||||||
|
DeletedAt *time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type Email struct {
|
||||||
|
Id int16
|
||||||
|
UserId int
|
||||||
|
Email string `sql:"type:varchar(100);"`
|
||||||
|
CreatedAt time.Time
|
||||||
|
UpdatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type Address struct {
|
||||||
|
ID int
|
||||||
|
Address1 string
|
||||||
|
Address2 string
|
||||||
|
Post string
|
||||||
|
CreatedAt time.Time
|
||||||
|
UpdatedAt time.Time
|
||||||
|
DeletedAt *time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type Language struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
Users []User `gorm:"many2many:user_languages;"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Product struct {
|
||||||
|
Id int64
|
||||||
|
Code string
|
||||||
|
Price int64
|
||||||
|
CreatedAt time.Time
|
||||||
|
UpdatedAt time.Time
|
||||||
|
AfterFindCallTimes int64
|
||||||
|
BeforeCreateCallTimes int64
|
||||||
|
AfterCreateCallTimes int64
|
||||||
|
BeforeUpdateCallTimes int64
|
||||||
|
AfterUpdateCallTimes int64
|
||||||
|
BeforeSaveCallTimes int64
|
||||||
|
AfterSaveCallTimes int64
|
||||||
|
BeforeDeleteCallTimes int64
|
||||||
|
AfterDeleteCallTimes int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type Company struct {
|
||||||
|
Id int64
|
||||||
|
Name string
|
||||||
|
Owner *User `sql:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Role struct {
|
||||||
|
Name string `gorm:"size:256"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (role *Role) Scan(value interface{}) error {
|
||||||
|
if b, ok := value.([]uint8); ok {
|
||||||
|
role.Name = string(b)
|
||||||
|
} else {
|
||||||
|
role.Name = value.(string)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (role Role) Value() (driver.Value, error) {
|
||||||
|
return role.Name, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (role Role) IsAdmin() bool {
|
||||||
|
return role.Name == "admin"
|
||||||
|
}
|
||||||
|
|
||||||
|
type Num int64
|
||||||
|
|
||||||
|
func (i *Num) Scan(src interface{}) error {
|
||||||
|
switch s := src.(type) {
|
||||||
|
case []byte:
|
||||||
|
case int64:
|
||||||
|
*i = Num(s)
|
||||||
|
default:
|
||||||
|
return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type Animal struct {
|
||||||
|
Counter uint64 `gorm:"primary_key:yes"`
|
||||||
|
Name string `sql:"DEFAULT:'galeone'"`
|
||||||
|
From string //test reserved sql keyword as field name
|
||||||
|
Age time.Time `sql:"DEFAULT:current_timestamp"`
|
||||||
|
unexported string // unexported value
|
||||||
|
CreatedAt time.Time
|
||||||
|
UpdatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type JoinTable struct {
|
||||||
|
From uint64
|
||||||
|
To uint64
|
||||||
|
Time time.Time `sql:"default: null"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Post struct {
|
||||||
|
Id int64
|
||||||
|
CategoryId sql.NullInt64
|
||||||
|
MainCategoryId int64
|
||||||
|
Title string
|
||||||
|
Body string
|
||||||
|
Comments []*Comment
|
||||||
|
Category Category
|
||||||
|
MainCategory Category
|
||||||
|
}
|
||||||
|
|
||||||
|
type Category struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string
|
||||||
|
|
||||||
|
Categories []Category
|
||||||
|
CategoryID *uint
|
||||||
|
}
|
||||||
|
|
||||||
|
type Comment struct {
|
||||||
|
gorm.Model
|
||||||
|
PostId int64
|
||||||
|
Content string
|
||||||
|
Post Post
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scanner
|
||||||
|
type NullValue struct {
|
||||||
|
Id int64
|
||||||
|
Name sql.NullString `sql:"not null"`
|
||||||
|
Gender *sql.NullString `sql:"not null"`
|
||||||
|
Age sql.NullInt64
|
||||||
|
Male sql.NullBool
|
||||||
|
Height sql.NullFloat64
|
||||||
|
AddedAt NullTime
|
||||||
|
}
|
||||||
|
|
||||||
|
type NullTime struct {
|
||||||
|
Time time.Time
|
||||||
|
Valid bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (nt *NullTime) Scan(value interface{}) error {
|
||||||
|
if value == nil {
|
||||||
|
nt.Valid = false
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
nt.Time, nt.Valid = value.(time.Time), true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (nt NullTime) Value() (driver.Value, error) {
|
||||||
|
if !nt.Valid {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nt.Time, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getPreparedUser(name string, role string) *User {
|
||||||
|
var company Company
|
||||||
|
DB.Where(Company{Name: role}).FirstOrCreate(&company)
|
||||||
|
|
||||||
|
return &User{
|
||||||
|
Name: name,
|
||||||
|
Age: 20,
|
||||||
|
Role: Role{role},
|
||||||
|
BillingAddress: Address{Address1: fmt.Sprintf("Billing Address %v", name)},
|
||||||
|
ShippingAddress: Address{Address1: fmt.Sprintf("Shipping Address %v", name)},
|
||||||
|
CreditCard: CreditCard{Number: fmt.Sprintf("123456%v", name)},
|
||||||
|
Emails: []Email{
|
||||||
|
{Email: fmt.Sprintf("user_%v@example1.com", name)}, {Email: fmt.Sprintf("user_%v@example2.com", name)},
|
||||||
|
},
|
||||||
|
Company: company,
|
||||||
|
Languages: []Language{
|
||||||
|
{Name: fmt.Sprintf("lang_1_%v", name)},
|
||||||
|
{Name: fmt.Sprintf("lang_2_%v", name)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func runMigration() {
|
||||||
|
if err := DB.DropTableIfExists(&User{}).Error; err != nil {
|
||||||
|
fmt.Printf("Got error when try to delete table users, %+v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, table := range []string{"animals", "user_languages"} {
|
||||||
|
DB.Exec(fmt.Sprintf("drop table %v;", table))
|
||||||
|
}
|
||||||
|
|
||||||
|
values := []interface{}{&Short{}, &ReallyLongThingThatReferencesShort{}, &ReallyLongTableNameToTestMySQLNameLengthLimit{}, &NotSoLongTableName{}, &Product{}, &Email{}, &Address{}, &CreditCard{}, &Company{}, &Role{}, &Language{}, &HNPost{}, &EngadgetPost{}, &Animal{}, &User{}, &JoinTable{}, &Post{}, &Category{}, &Comment{}, &Cat{}, &Dog{}, &Hamster{}, &Toy{}, &ElementWithIgnoredField{}}
|
||||||
|
for _, value := range values {
|
||||||
|
DB.DropTable(value)
|
||||||
|
}
|
||||||
|
if err := DB.AutoMigrate(values...).Error; err != nil {
|
||||||
|
panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIndexes(t *testing.T) {
|
||||||
|
if err := DB.Model(&Email{}).AddIndex("idx_email_email", "email").Error; err != nil {
|
||||||
|
t.Errorf("Got error when tried to create index: %+v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
scope := DB.NewScope(&Email{})
|
||||||
|
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
|
||||||
|
t.Errorf("Email should have index idx_email_email")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Model(&Email{}).RemoveIndex("idx_email_email").Error; err != nil {
|
||||||
|
t.Errorf("Got error when tried to remove index: %+v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email") {
|
||||||
|
t.Errorf("Email's index idx_email_email should be deleted")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Model(&Email{}).AddIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil {
|
||||||
|
t.Errorf("Got error when tried to create index: %+v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
|
||||||
|
t.Errorf("Email should have index idx_email_email_and_user_id")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil {
|
||||||
|
t.Errorf("Got error when tried to remove index: %+v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
|
||||||
|
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Model(&Email{}).AddUniqueIndex("idx_email_email_and_user_id", "user_id", "email").Error; err != nil {
|
||||||
|
t.Errorf("Got error when tried to create index: %+v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
|
||||||
|
t.Errorf("Email should have index idx_email_email_and_user_id")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.comiii"}, {Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error == nil {
|
||||||
|
t.Errorf("Should get to create duplicate record when having unique index")
|
||||||
|
}
|
||||||
|
|
||||||
|
var user = User{Name: "sample_user"}
|
||||||
|
DB.Save(&user)
|
||||||
|
if DB.Model(&user).Association("Emails").Append(Email{Email: "not-1duplicated@gmail.com"}, Email{Email: "not-duplicated2@gmail.com"}).Error != nil {
|
||||||
|
t.Errorf("Should get no error when append two emails for user")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&user).Association("Emails").Append(Email{Email: "duplicated@gmail.com"}, Email{Email: "duplicated@gmail.com"}).Error == nil {
|
||||||
|
t.Errorf("Should get no duplicated email error when insert duplicated emails for a user")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Model(&Email{}).RemoveIndex("idx_email_email_and_user_id").Error; err != nil {
|
||||||
|
t.Errorf("Got error when tried to remove index: %+v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if scope.Dialect().HasIndex(scope.TableName(), "idx_email_email_and_user_id") {
|
||||||
|
t.Errorf("Email's index idx_email_email_and_user_id should be deleted")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Save(&User{Name: "unique_indexes", Emails: []Email{{Email: "user1@example.com"}, {Email: "user1@example.com"}}}).Error != nil {
|
||||||
|
t.Errorf("Should be able to create duplicated emails after remove unique index")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type BigEmail struct {
|
||||||
|
Id int64
|
||||||
|
UserId int64
|
||||||
|
Email string `sql:"index:idx_email_agent"`
|
||||||
|
UserAgent string `sql:"index:idx_email_agent"`
|
||||||
|
RegisteredAt *time.Time `sql:"unique_index"`
|
||||||
|
CreatedAt time.Time
|
||||||
|
UpdatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b BigEmail) TableName() string {
|
||||||
|
return "emails"
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAutoMigration(t *testing.T) {
|
||||||
|
DB.AutoMigrate(&Address{})
|
||||||
|
if err := DB.Table("emails").AutoMigrate(&BigEmail{}).Error; err != nil {
|
||||||
|
t.Errorf("Auto Migrate should not raise any error")
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
DB.Save(&BigEmail{Email: "jinzhu@example.org", UserAgent: "pc", RegisteredAt: &now})
|
||||||
|
|
||||||
|
scope := DB.NewScope(&BigEmail{})
|
||||||
|
if !scope.Dialect().HasIndex(scope.TableName(), "idx_email_agent") {
|
||||||
|
t.Errorf("Failed to create index")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !scope.Dialect().HasIndex(scope.TableName(), "uix_emails_registered_at") {
|
||||||
|
t.Errorf("Failed to create index")
|
||||||
|
}
|
||||||
|
|
||||||
|
var bigemail BigEmail
|
||||||
|
DB.First(&bigemail, "user_agent = ?", "pc")
|
||||||
|
if bigemail.Email != "jinzhu@example.org" || bigemail.UserAgent != "pc" || bigemail.RegisteredAt.IsZero() {
|
||||||
|
t.Error("Big Emails should be saved and fetched correctly")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type MultipleIndexes struct {
|
||||||
|
ID int64
|
||||||
|
UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"`
|
||||||
|
Name string `sql:"unique_index:uix_multipleindexes_user_name"`
|
||||||
|
Email string `sql:"unique_index:,uix_multipleindexes_user_email"`
|
||||||
|
Other string `sql:"index:,idx_multipleindexes_user_other"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultipleIndexes(t *testing.T) {
|
||||||
|
if err := DB.DropTableIfExists(&MultipleIndexes{}).Error; err != nil {
|
||||||
|
fmt.Printf("Got error when try to delete table multiple_indexes, %+v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.AutoMigrate(&MultipleIndexes{})
|
||||||
|
if err := DB.AutoMigrate(&BigEmail{}).Error; err != nil {
|
||||||
|
t.Errorf("Auto Migrate should not raise any error")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Save(&MultipleIndexes{UserID: 1, Name: "jinzhu", Email: "jinzhu@example.org", Other: "foo"})
|
||||||
|
|
||||||
|
scope := DB.NewScope(&MultipleIndexes{})
|
||||||
|
if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_name") {
|
||||||
|
t.Errorf("Failed to create index")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !scope.Dialect().HasIndex(scope.TableName(), "uix_multipleindexes_user_email") {
|
||||||
|
t.Errorf("Failed to create index")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !scope.Dialect().HasIndex(scope.TableName(), "uix_multiple_indexes_email") {
|
||||||
|
t.Errorf("Failed to create index")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !scope.Dialect().HasIndex(scope.TableName(), "idx_multipleindexes_user_other") {
|
||||||
|
t.Errorf("Failed to create index")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !scope.Dialect().HasIndex(scope.TableName(), "idx_multiple_indexes_other") {
|
||||||
|
t.Errorf("Failed to create index")
|
||||||
|
}
|
||||||
|
|
||||||
|
var mutipleIndexes MultipleIndexes
|
||||||
|
DB.First(&mutipleIndexes, "name = ?", "jinzhu")
|
||||||
|
if mutipleIndexes.Email != "jinzhu@example.org" || mutipleIndexes.Name != "jinzhu" {
|
||||||
|
t.Error("MutipleIndexes should be saved and fetched correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check unique constraints
|
||||||
|
if err := DB.Save(&MultipleIndexes{UserID: 1, Name: "name1", Email: "jinzhu@example.org", Other: "foo"}).Error; err == nil {
|
||||||
|
t.Error("MultipleIndexes unique index failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Save(&MultipleIndexes{UserID: 1, Name: "name1", Email: "foo@example.org", Other: "foo"}).Error; err != nil {
|
||||||
|
t.Error("MultipleIndexes unique index failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Save(&MultipleIndexes{UserID: 2, Name: "name1", Email: "jinzhu@example.org", Other: "foo"}).Error; err == nil {
|
||||||
|
t.Error("MultipleIndexes unique index failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Save(&MultipleIndexes{UserID: 2, Name: "name1", Email: "foo2@example.org", Other: "foo"}).Error; err != nil {
|
||||||
|
t.Error("MultipleIndexes unique index failed")
|
||||||
|
}
|
||||||
|
}
|
14
orm/model.go
Normal file
14
orm/model.go
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
// Model base model definition, including fields `ID`, `CreatedAt`, `UpdatedAt`, `DeletedAt`, which could be embedded in your models
|
||||||
|
// type User struct {
|
||||||
|
// gorm.Model
|
||||||
|
// }
|
||||||
|
type Model struct {
|
||||||
|
ID uint `gorm:"primary_key"`
|
||||||
|
CreatedAt time.Time
|
||||||
|
UpdatedAt time.Time
|
||||||
|
DeletedAt *time.Time `sql:"index"`
|
||||||
|
}
|
575
orm/model_struct.go
Normal file
575
orm/model_struct.go
Normal file
|
@ -0,0 +1,575 @@
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"go/ast"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jinzhu/inflection"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultTableNameHandler default table name handler
|
||||||
|
var DefaultTableNameHandler = func(db *DB, defaultTableName string) string {
|
||||||
|
return defaultTableName
|
||||||
|
}
|
||||||
|
|
||||||
|
type safeModelStructsMap struct {
|
||||||
|
m map[reflect.Type]*ModelStruct
|
||||||
|
l *sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *safeModelStructsMap) Set(key reflect.Type, value *ModelStruct) {
|
||||||
|
s.l.Lock()
|
||||||
|
defer s.l.Unlock()
|
||||||
|
s.m[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *safeModelStructsMap) Get(key reflect.Type) *ModelStruct {
|
||||||
|
s.l.RLock()
|
||||||
|
defer s.l.RUnlock()
|
||||||
|
return s.m[key]
|
||||||
|
}
|
||||||
|
|
||||||
|
func newModelStructsMap() *safeModelStructsMap {
|
||||||
|
return &safeModelStructsMap{l: new(sync.RWMutex), m: make(map[reflect.Type]*ModelStruct)}
|
||||||
|
}
|
||||||
|
|
||||||
|
var modelStructsMap = newModelStructsMap()
|
||||||
|
|
||||||
|
// ModelStruct model definition
|
||||||
|
type ModelStruct struct {
|
||||||
|
PrimaryFields []*StructField
|
||||||
|
StructFields []*StructField
|
||||||
|
ModelType reflect.Type
|
||||||
|
defaultTableName string
|
||||||
|
}
|
||||||
|
|
||||||
|
// TableName get model's table name
|
||||||
|
func (s *ModelStruct) TableName(db *DB) string {
|
||||||
|
return DefaultTableNameHandler(db, s.defaultTableName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// StructField model field's struct definition
|
||||||
|
type StructField struct {
|
||||||
|
DBName string
|
||||||
|
Name string
|
||||||
|
Names []string
|
||||||
|
IsPrimaryKey bool
|
||||||
|
IsNormal bool
|
||||||
|
IsIgnored bool
|
||||||
|
IsScanner bool
|
||||||
|
HasDefaultValue bool
|
||||||
|
Tag reflect.StructTag
|
||||||
|
TagSettings map[string]string
|
||||||
|
Struct reflect.StructField
|
||||||
|
IsForeignKey bool
|
||||||
|
Relationship *Relationship
|
||||||
|
}
|
||||||
|
|
||||||
|
func (structField *StructField) clone() *StructField {
|
||||||
|
clone := &StructField{
|
||||||
|
DBName: structField.DBName,
|
||||||
|
Name: structField.Name,
|
||||||
|
Names: structField.Names,
|
||||||
|
IsPrimaryKey: structField.IsPrimaryKey,
|
||||||
|
IsNormal: structField.IsNormal,
|
||||||
|
IsIgnored: structField.IsIgnored,
|
||||||
|
IsScanner: structField.IsScanner,
|
||||||
|
HasDefaultValue: structField.HasDefaultValue,
|
||||||
|
Tag: structField.Tag,
|
||||||
|
TagSettings: map[string]string{},
|
||||||
|
Struct: structField.Struct,
|
||||||
|
IsForeignKey: structField.IsForeignKey,
|
||||||
|
Relationship: structField.Relationship,
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, value := range structField.TagSettings {
|
||||||
|
clone.TagSettings[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
return clone
|
||||||
|
}
|
||||||
|
|
||||||
|
// Relationship described the relationship between models
|
||||||
|
type Relationship struct {
|
||||||
|
Kind string
|
||||||
|
PolymorphicType string
|
||||||
|
PolymorphicDBName string
|
||||||
|
PolymorphicValue string
|
||||||
|
ForeignFieldNames []string
|
||||||
|
ForeignDBNames []string
|
||||||
|
AssociationForeignFieldNames []string
|
||||||
|
AssociationForeignDBNames []string
|
||||||
|
JoinTableHandler JoinTableHandlerInterface
|
||||||
|
}
|
||||||
|
|
||||||
|
func getForeignField(column string, fields []*StructField) *StructField {
|
||||||
|
for _, field := range fields {
|
||||||
|
if field.Name == column || field.DBName == column || field.DBName == ToDBName(column) {
|
||||||
|
return field
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetModelStruct get value's model struct, relationships based on struct and tag definition
|
||||||
|
func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
|
var modelStruct ModelStruct
|
||||||
|
// Scope value can't be nil
|
||||||
|
if scope.Value == nil {
|
||||||
|
return &modelStruct
|
||||||
|
}
|
||||||
|
|
||||||
|
reflectType := reflect.ValueOf(scope.Value).Type()
|
||||||
|
for reflectType.Kind() == reflect.Slice || reflectType.Kind() == reflect.Ptr {
|
||||||
|
reflectType = reflectType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scope value need to be a struct
|
||||||
|
if reflectType.Kind() != reflect.Struct {
|
||||||
|
return &modelStruct
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get Cached model struct
|
||||||
|
if value := modelStructsMap.Get(reflectType); value != nil {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
modelStruct.ModelType = reflectType
|
||||||
|
|
||||||
|
// Set default table name
|
||||||
|
if tabler, ok := reflect.New(reflectType).Interface().(tabler); ok {
|
||||||
|
modelStruct.defaultTableName = tabler.TableName()
|
||||||
|
} else {
|
||||||
|
tableName := ToDBName(reflectType.Name())
|
||||||
|
if scope.db == nil || !scope.db.parent.singularTable {
|
||||||
|
tableName = inflection.Plural(tableName)
|
||||||
|
}
|
||||||
|
modelStruct.defaultTableName = tableName
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get all fields
|
||||||
|
for i := 0; i < reflectType.NumField(); i++ {
|
||||||
|
if fieldStruct := reflectType.Field(i); ast.IsExported(fieldStruct.Name) {
|
||||||
|
field := &StructField{
|
||||||
|
Struct: fieldStruct,
|
||||||
|
Name: fieldStruct.Name,
|
||||||
|
Names: []string{fieldStruct.Name},
|
||||||
|
Tag: fieldStruct.Tag,
|
||||||
|
TagSettings: parseTagSetting(fieldStruct.Tag),
|
||||||
|
}
|
||||||
|
|
||||||
|
// is ignored field
|
||||||
|
if _, ok := field.TagSettings["-"]; ok {
|
||||||
|
field.IsIgnored = true
|
||||||
|
} else {
|
||||||
|
if _, ok := field.TagSettings["PRIMARY_KEY"]; ok {
|
||||||
|
field.IsPrimaryKey = true
|
||||||
|
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := field.TagSettings["DEFAULT"]; ok {
|
||||||
|
field.HasDefaultValue = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok && !field.IsPrimaryKey {
|
||||||
|
field.HasDefaultValue = true
|
||||||
|
}
|
||||||
|
|
||||||
|
indirectType := fieldStruct.Type
|
||||||
|
for indirectType.Kind() == reflect.Ptr {
|
||||||
|
indirectType = indirectType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
fieldValue := reflect.New(indirectType).Interface()
|
||||||
|
if _, isScanner := fieldValue.(sql.Scanner); isScanner {
|
||||||
|
// is scanner
|
||||||
|
field.IsScanner, field.IsNormal = true, true
|
||||||
|
if indirectType.Kind() == reflect.Struct {
|
||||||
|
for i := 0; i < indirectType.NumField(); i++ {
|
||||||
|
for key, value := range parseTagSetting(indirectType.Field(i).Tag) {
|
||||||
|
field.TagSettings[key] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if _, isTime := fieldValue.(*time.Time); isTime {
|
||||||
|
// is time
|
||||||
|
field.IsNormal = true
|
||||||
|
} else if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous {
|
||||||
|
// is embedded struct
|
||||||
|
for _, subField := range scope.New(fieldValue).GetStructFields() {
|
||||||
|
subField = subField.clone()
|
||||||
|
subField.Names = append([]string{fieldStruct.Name}, subField.Names...)
|
||||||
|
if prefix, ok := field.TagSettings["EMBEDDED_PREFIX"]; ok {
|
||||||
|
subField.DBName = prefix + subField.DBName
|
||||||
|
}
|
||||||
|
if subField.IsPrimaryKey {
|
||||||
|
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, subField)
|
||||||
|
}
|
||||||
|
modelStruct.StructFields = append(modelStruct.StructFields, subField)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
// build relationships
|
||||||
|
switch indirectType.Kind() {
|
||||||
|
case reflect.Slice:
|
||||||
|
defer func(field *StructField) {
|
||||||
|
var (
|
||||||
|
relationship = &Relationship{}
|
||||||
|
toScope = scope.New(reflect.New(field.Struct.Type).Interface())
|
||||||
|
foreignKeys []string
|
||||||
|
associationForeignKeys []string
|
||||||
|
elemType = field.Struct.Type
|
||||||
|
)
|
||||||
|
|
||||||
|
if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" {
|
||||||
|
foreignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
|
||||||
|
associationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr {
|
||||||
|
elemType = elemType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if elemType.Kind() == reflect.Struct {
|
||||||
|
if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
|
||||||
|
relationship.Kind = "many_to_many"
|
||||||
|
|
||||||
|
// if no foreign keys defined with tag
|
||||||
|
if len(foreignKeys) == 0 {
|
||||||
|
for _, field := range modelStruct.PrimaryFields {
|
||||||
|
foreignKeys = append(foreignKeys, field.DBName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, foreignKey := range foreignKeys {
|
||||||
|
if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil {
|
||||||
|
// source foreign keys (db names)
|
||||||
|
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName)
|
||||||
|
// join table foreign keys for source
|
||||||
|
joinTableDBName := ToDBName(reflectType.Name()) + "_" + foreignField.DBName
|
||||||
|
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// if no association foreign keys defined with tag
|
||||||
|
if len(associationForeignKeys) == 0 {
|
||||||
|
for _, field := range toScope.PrimaryFields() {
|
||||||
|
associationForeignKeys = append(associationForeignKeys, field.DBName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, name := range associationForeignKeys {
|
||||||
|
if field, ok := toScope.FieldByName(name); ok {
|
||||||
|
// association foreign keys (db names)
|
||||||
|
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName)
|
||||||
|
// join table foreign keys for association
|
||||||
|
joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName
|
||||||
|
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
joinTableHandler := JoinTableHandler{}
|
||||||
|
joinTableHandler.Setup(relationship, many2many, reflectType, elemType)
|
||||||
|
relationship.JoinTableHandler = &joinTableHandler
|
||||||
|
field.Relationship = relationship
|
||||||
|
} else {
|
||||||
|
// User has many comments, associationType is User, comment use UserID as foreign key
|
||||||
|
var associationType = reflectType.Name()
|
||||||
|
var toFields = toScope.GetStructFields()
|
||||||
|
relationship.Kind = "has_many"
|
||||||
|
|
||||||
|
if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
|
||||||
|
// Dog has many toys, tag polymorphic is Owner, then associationType is Owner
|
||||||
|
// Toy use OwnerID, OwnerType ('dogs') as foreign key
|
||||||
|
if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil {
|
||||||
|
associationType = polymorphic
|
||||||
|
relationship.PolymorphicType = polymorphicType.Name
|
||||||
|
relationship.PolymorphicDBName = polymorphicType.DBName
|
||||||
|
// if Dog has multiple set of toys set name of the set (instead of default 'dogs')
|
||||||
|
if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok {
|
||||||
|
relationship.PolymorphicValue = value
|
||||||
|
} else {
|
||||||
|
relationship.PolymorphicValue = scope.TableName()
|
||||||
|
}
|
||||||
|
polymorphicType.IsForeignKey = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// if no foreign keys defined with tag
|
||||||
|
if len(foreignKeys) == 0 {
|
||||||
|
// if no association foreign keys defined with tag
|
||||||
|
if len(associationForeignKeys) == 0 {
|
||||||
|
for _, field := range modelStruct.PrimaryFields {
|
||||||
|
foreignKeys = append(foreignKeys, associationType+field.Name)
|
||||||
|
associationForeignKeys = append(associationForeignKeys, field.Name)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// generate foreign keys from defined association foreign keys
|
||||||
|
for _, scopeFieldName := range associationForeignKeys {
|
||||||
|
if foreignField := getForeignField(scopeFieldName, modelStruct.StructFields); foreignField != nil {
|
||||||
|
foreignKeys = append(foreignKeys, associationType+foreignField.Name)
|
||||||
|
associationForeignKeys = append(associationForeignKeys, foreignField.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// generate association foreign keys from foreign keys
|
||||||
|
if len(associationForeignKeys) == 0 {
|
||||||
|
for _, foreignKey := range foreignKeys {
|
||||||
|
if strings.HasPrefix(foreignKey, associationType) {
|
||||||
|
associationForeignKey := strings.TrimPrefix(foreignKey, associationType)
|
||||||
|
if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil {
|
||||||
|
associationForeignKeys = append(associationForeignKeys, associationForeignKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
|
||||||
|
associationForeignKeys = []string{scope.PrimaryKey()}
|
||||||
|
}
|
||||||
|
} else if len(foreignKeys) != len(associationForeignKeys) {
|
||||||
|
scope.Err(errors.New("invalid foreign keys, should have same length"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for idx, foreignKey := range foreignKeys {
|
||||||
|
if foreignField := getForeignField(foreignKey, toFields); foreignField != nil {
|
||||||
|
if associationField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); associationField != nil {
|
||||||
|
// source foreign keys
|
||||||
|
foreignField.IsForeignKey = true
|
||||||
|
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name)
|
||||||
|
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName)
|
||||||
|
|
||||||
|
// association foreign keys
|
||||||
|
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
|
||||||
|
relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(relationship.ForeignFieldNames) != 0 {
|
||||||
|
field.Relationship = relationship
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
field.IsNormal = true
|
||||||
|
}
|
||||||
|
}(field)
|
||||||
|
case reflect.Struct:
|
||||||
|
defer func(field *StructField) {
|
||||||
|
var (
|
||||||
|
// user has one profile, associationType is User, profile use UserID as foreign key
|
||||||
|
// user belongs to profile, associationType is Profile, user use ProfileID as foreign key
|
||||||
|
associationType = reflectType.Name()
|
||||||
|
relationship = &Relationship{}
|
||||||
|
toScope = scope.New(reflect.New(field.Struct.Type).Interface())
|
||||||
|
toFields = toScope.GetStructFields()
|
||||||
|
tagForeignKeys []string
|
||||||
|
tagAssociationForeignKeys []string
|
||||||
|
)
|
||||||
|
|
||||||
|
if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" {
|
||||||
|
tagForeignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
|
||||||
|
tagAssociationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
|
||||||
|
// Cat has one toy, tag polymorphic is Owner, then associationType is Owner
|
||||||
|
// Toy use OwnerID, OwnerType ('cats') as foreign key
|
||||||
|
if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil {
|
||||||
|
associationType = polymorphic
|
||||||
|
relationship.PolymorphicType = polymorphicType.Name
|
||||||
|
relationship.PolymorphicDBName = polymorphicType.DBName
|
||||||
|
// if Cat has several different types of toys set name for each (instead of default 'cats')
|
||||||
|
if value, ok := field.TagSettings["POLYMORPHIC_VALUE"]; ok {
|
||||||
|
relationship.PolymorphicValue = value
|
||||||
|
} else {
|
||||||
|
relationship.PolymorphicValue = scope.TableName()
|
||||||
|
}
|
||||||
|
polymorphicType.IsForeignKey = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Has One
|
||||||
|
{
|
||||||
|
var foreignKeys = tagForeignKeys
|
||||||
|
var associationForeignKeys = tagAssociationForeignKeys
|
||||||
|
// if no foreign keys defined with tag
|
||||||
|
if len(foreignKeys) == 0 {
|
||||||
|
// if no association foreign keys defined with tag
|
||||||
|
if len(associationForeignKeys) == 0 {
|
||||||
|
for _, primaryField := range modelStruct.PrimaryFields {
|
||||||
|
foreignKeys = append(foreignKeys, associationType+primaryField.Name)
|
||||||
|
associationForeignKeys = append(associationForeignKeys, primaryField.Name)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// generate foreign keys form association foreign keys
|
||||||
|
for _, associationForeignKey := range tagAssociationForeignKeys {
|
||||||
|
if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil {
|
||||||
|
foreignKeys = append(foreignKeys, associationType+foreignField.Name)
|
||||||
|
associationForeignKeys = append(associationForeignKeys, foreignField.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// generate association foreign keys from foreign keys
|
||||||
|
if len(associationForeignKeys) == 0 {
|
||||||
|
for _, foreignKey := range foreignKeys {
|
||||||
|
if strings.HasPrefix(foreignKey, associationType) {
|
||||||
|
associationForeignKey := strings.TrimPrefix(foreignKey, associationType)
|
||||||
|
if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil {
|
||||||
|
associationForeignKeys = append(associationForeignKeys, associationForeignKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
|
||||||
|
associationForeignKeys = []string{scope.PrimaryKey()}
|
||||||
|
}
|
||||||
|
} else if len(foreignKeys) != len(associationForeignKeys) {
|
||||||
|
scope.Err(errors.New("invalid foreign keys, should have same length"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for idx, foreignKey := range foreignKeys {
|
||||||
|
if foreignField := getForeignField(foreignKey, toFields); foreignField != nil {
|
||||||
|
if scopeField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); scopeField != nil {
|
||||||
|
foreignField.IsForeignKey = true
|
||||||
|
// source foreign keys
|
||||||
|
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name)
|
||||||
|
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName)
|
||||||
|
|
||||||
|
// association foreign keys
|
||||||
|
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
|
||||||
|
relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(relationship.ForeignFieldNames) != 0 {
|
||||||
|
relationship.Kind = "has_one"
|
||||||
|
field.Relationship = relationship
|
||||||
|
} else {
|
||||||
|
var foreignKeys = tagForeignKeys
|
||||||
|
var associationForeignKeys = tagAssociationForeignKeys
|
||||||
|
|
||||||
|
if len(foreignKeys) == 0 {
|
||||||
|
// generate foreign keys & association foreign keys
|
||||||
|
if len(associationForeignKeys) == 0 {
|
||||||
|
for _, primaryField := range toScope.PrimaryFields() {
|
||||||
|
foreignKeys = append(foreignKeys, field.Name+primaryField.Name)
|
||||||
|
associationForeignKeys = append(associationForeignKeys, primaryField.Name)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// generate foreign keys with association foreign keys
|
||||||
|
for _, associationForeignKey := range associationForeignKeys {
|
||||||
|
if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil {
|
||||||
|
foreignKeys = append(foreignKeys, field.Name+foreignField.Name)
|
||||||
|
associationForeignKeys = append(associationForeignKeys, foreignField.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// generate foreign keys & association foreign keys
|
||||||
|
if len(associationForeignKeys) == 0 {
|
||||||
|
for _, foreignKey := range foreignKeys {
|
||||||
|
if strings.HasPrefix(foreignKey, field.Name) {
|
||||||
|
associationForeignKey := strings.TrimPrefix(foreignKey, field.Name)
|
||||||
|
if foreignField := getForeignField(associationForeignKey, toFields); foreignField != nil {
|
||||||
|
associationForeignKeys = append(associationForeignKeys, associationForeignKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 {
|
||||||
|
associationForeignKeys = []string{toScope.PrimaryKey()}
|
||||||
|
}
|
||||||
|
} else if len(foreignKeys) != len(associationForeignKeys) {
|
||||||
|
scope.Err(errors.New("invalid foreign keys, should have same length"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for idx, foreignKey := range foreignKeys {
|
||||||
|
if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil {
|
||||||
|
if associationField := getForeignField(associationForeignKeys[idx], toFields); associationField != nil {
|
||||||
|
foreignField.IsForeignKey = true
|
||||||
|
|
||||||
|
// association foreign keys
|
||||||
|
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name)
|
||||||
|
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName)
|
||||||
|
|
||||||
|
// source foreign keys
|
||||||
|
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
|
||||||
|
relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(relationship.ForeignFieldNames) != 0 {
|
||||||
|
relationship.Kind = "belongs_to"
|
||||||
|
field.Relationship = relationship
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}(field)
|
||||||
|
default:
|
||||||
|
field.IsNormal = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Even it is ignored, also possible to decode db value into the field
|
||||||
|
if value, ok := field.TagSettings["COLUMN"]; ok {
|
||||||
|
field.DBName = value
|
||||||
|
} else {
|
||||||
|
field.DBName = ToDBName(fieldStruct.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
modelStruct.StructFields = append(modelStruct.StructFields, field)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(modelStruct.PrimaryFields) == 0 {
|
||||||
|
if field := getForeignField("id", modelStruct.StructFields); field != nil {
|
||||||
|
field.IsPrimaryKey = true
|
||||||
|
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
modelStructsMap.Set(reflectType, &modelStruct)
|
||||||
|
|
||||||
|
return &modelStruct
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStructFields get model's field structs
|
||||||
|
func (scope *Scope) GetStructFields() (fields []*StructField) {
|
||||||
|
return scope.GetModelStruct().StructFields
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseTagSetting(tags reflect.StructTag) map[string]string {
|
||||||
|
setting := map[string]string{}
|
||||||
|
for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} {
|
||||||
|
tags := strings.Split(str, ";")
|
||||||
|
for _, value := range tags {
|
||||||
|
v := strings.Split(value, ":")
|
||||||
|
k := strings.TrimSpace(strings.ToUpper(v[0]))
|
||||||
|
if len(v) >= 2 {
|
||||||
|
setting[k] = strings.Join(v[1:], ":")
|
||||||
|
} else {
|
||||||
|
setting[k] = k
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return setting
|
||||||
|
}
|
381
orm/multi_primary_keys_test.go
Normal file
381
orm/multi_primary_keys_test.go
Normal file
|
@ -0,0 +1,381 @@
|
||||||
|
package orm_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"reflect"
|
||||||
|
"sort"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Blog struct {
|
||||||
|
ID uint `gorm:"primary_key"`
|
||||||
|
Locale string `gorm:"primary_key"`
|
||||||
|
Subject string
|
||||||
|
Body string
|
||||||
|
Tags []Tag `gorm:"many2many:blog_tags;"`
|
||||||
|
SharedTags []Tag `gorm:"many2many:shared_blog_tags;ForeignKey:id;AssociationForeignKey:id"`
|
||||||
|
LocaleTags []Tag `gorm:"many2many:locale_blog_tags;ForeignKey:id,locale;AssociationForeignKey:id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Tag struct {
|
||||||
|
ID uint `gorm:"primary_key"`
|
||||||
|
Locale string `gorm:"primary_key"`
|
||||||
|
Value string
|
||||||
|
Blogs []*Blog `gorm:"many2many:blogs_tags"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func compareTags(tags []Tag, contents []string) bool {
|
||||||
|
var tagContents []string
|
||||||
|
for _, tag := range tags {
|
||||||
|
tagContents = append(tagContents, tag.Value)
|
||||||
|
}
|
||||||
|
sort.Strings(tagContents)
|
||||||
|
sort.Strings(contents)
|
||||||
|
return reflect.DeepEqual(tagContents, contents)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManyToManyWithMultiPrimaryKeys(t *testing.T) {
|
||||||
|
if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" {
|
||||||
|
DB.DropTable(&Blog{}, &Tag{})
|
||||||
|
DB.DropTable("blog_tags")
|
||||||
|
DB.CreateTable(&Blog{}, &Tag{})
|
||||||
|
blog := Blog{
|
||||||
|
Locale: "ZH",
|
||||||
|
Subject: "subject",
|
||||||
|
Body: "body",
|
||||||
|
Tags: []Tag{
|
||||||
|
{Locale: "ZH", Value: "tag1"},
|
||||||
|
{Locale: "ZH", Value: "tag2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Save(&blog)
|
||||||
|
if !compareTags(blog.Tags, []string{"tag1", "tag2"}) {
|
||||||
|
t.Errorf("Blog should has two tags")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append
|
||||||
|
var tag3 = &Tag{Locale: "ZH", Value: "tag3"}
|
||||||
|
DB.Model(&blog).Association("Tags").Append([]*Tag{tag3})
|
||||||
|
if !compareTags(blog.Tags, []string{"tag1", "tag2", "tag3"}) {
|
||||||
|
t.Errorf("Blog should has three tags after Append")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&blog).Association("Tags").Count() != 3 {
|
||||||
|
t.Errorf("Blog should has three tags after Append")
|
||||||
|
}
|
||||||
|
|
||||||
|
var tags []Tag
|
||||||
|
DB.Model(&blog).Related(&tags, "Tags")
|
||||||
|
if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) {
|
||||||
|
t.Errorf("Should find 3 tags with Related")
|
||||||
|
}
|
||||||
|
|
||||||
|
var blog1 Blog
|
||||||
|
DB.Preload("Tags").Find(&blog1)
|
||||||
|
if !compareTags(blog1.Tags, []string{"tag1", "tag2", "tag3"}) {
|
||||||
|
t.Errorf("Preload many2many relations")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace
|
||||||
|
var tag5 = &Tag{Locale: "ZH", Value: "tag5"}
|
||||||
|
var tag6 = &Tag{Locale: "ZH", Value: "tag6"}
|
||||||
|
DB.Model(&blog).Association("Tags").Replace(tag5, tag6)
|
||||||
|
var tags2 []Tag
|
||||||
|
DB.Model(&blog).Related(&tags2, "Tags")
|
||||||
|
if !compareTags(tags2, []string{"tag5", "tag6"}) {
|
||||||
|
t.Errorf("Should find 2 tags after Replace")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&blog).Association("Tags").Count() != 2 {
|
||||||
|
t.Errorf("Blog should has three tags after Replace")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete
|
||||||
|
DB.Model(&blog).Association("Tags").Delete(tag5)
|
||||||
|
var tags3 []Tag
|
||||||
|
DB.Model(&blog).Related(&tags3, "Tags")
|
||||||
|
if !compareTags(tags3, []string{"tag6"}) {
|
||||||
|
t.Errorf("Should find 1 tags after Delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&blog).Association("Tags").Count() != 1 {
|
||||||
|
t.Errorf("Blog should has three tags after Delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&blog).Association("Tags").Delete(tag3)
|
||||||
|
var tags4 []Tag
|
||||||
|
DB.Model(&blog).Related(&tags4, "Tags")
|
||||||
|
if !compareTags(tags4, []string{"tag6"}) {
|
||||||
|
t.Errorf("Tag should not be deleted when Delete with a unrelated tag")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear
|
||||||
|
DB.Model(&blog).Association("Tags").Clear()
|
||||||
|
if DB.Model(&blog).Association("Tags").Count() != 0 {
|
||||||
|
t.Errorf("All tags should be cleared")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManyToManyWithCustomizedForeignKeys(t *testing.T) {
|
||||||
|
if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" {
|
||||||
|
DB.DropTable(&Blog{}, &Tag{})
|
||||||
|
DB.DropTable("shared_blog_tags")
|
||||||
|
DB.CreateTable(&Blog{}, &Tag{})
|
||||||
|
blog := Blog{
|
||||||
|
Locale: "ZH",
|
||||||
|
Subject: "subject",
|
||||||
|
Body: "body",
|
||||||
|
SharedTags: []Tag{
|
||||||
|
{Locale: "ZH", Value: "tag1"},
|
||||||
|
{Locale: "ZH", Value: "tag2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
DB.Save(&blog)
|
||||||
|
|
||||||
|
blog2 := Blog{
|
||||||
|
ID: blog.ID,
|
||||||
|
Locale: "EN",
|
||||||
|
}
|
||||||
|
DB.Create(&blog2)
|
||||||
|
|
||||||
|
if !compareTags(blog.SharedTags, []string{"tag1", "tag2"}) {
|
||||||
|
t.Errorf("Blog should has two tags")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append
|
||||||
|
var tag3 = &Tag{Locale: "ZH", Value: "tag3"}
|
||||||
|
DB.Model(&blog).Association("SharedTags").Append([]*Tag{tag3})
|
||||||
|
if !compareTags(blog.SharedTags, []string{"tag1", "tag2", "tag3"}) {
|
||||||
|
t.Errorf("Blog should has three tags after Append")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&blog).Association("SharedTags").Count() != 3 {
|
||||||
|
t.Errorf("Blog should has three tags after Append")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&blog2).Association("SharedTags").Count() != 3 {
|
||||||
|
t.Errorf("Blog should has three tags after Append")
|
||||||
|
}
|
||||||
|
|
||||||
|
var tags []Tag
|
||||||
|
DB.Model(&blog).Related(&tags, "SharedTags")
|
||||||
|
if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) {
|
||||||
|
t.Errorf("Should find 3 tags with Related")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&blog2).Related(&tags, "SharedTags")
|
||||||
|
if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) {
|
||||||
|
t.Errorf("Should find 3 tags with Related")
|
||||||
|
}
|
||||||
|
|
||||||
|
var blog1 Blog
|
||||||
|
DB.Preload("SharedTags").Find(&blog1)
|
||||||
|
if !compareTags(blog1.SharedTags, []string{"tag1", "tag2", "tag3"}) {
|
||||||
|
t.Errorf("Preload many2many relations")
|
||||||
|
}
|
||||||
|
|
||||||
|
var tag4 = &Tag{Locale: "ZH", Value: "tag4"}
|
||||||
|
DB.Model(&blog2).Association("SharedTags").Append(tag4)
|
||||||
|
|
||||||
|
DB.Model(&blog).Related(&tags, "SharedTags")
|
||||||
|
if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) {
|
||||||
|
t.Errorf("Should find 3 tags with Related")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&blog2).Related(&tags, "SharedTags")
|
||||||
|
if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) {
|
||||||
|
t.Errorf("Should find 3 tags with Related")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace
|
||||||
|
var tag5 = &Tag{Locale: "ZH", Value: "tag5"}
|
||||||
|
var tag6 = &Tag{Locale: "ZH", Value: "tag6"}
|
||||||
|
DB.Model(&blog2).Association("SharedTags").Replace(tag5, tag6)
|
||||||
|
var tags2 []Tag
|
||||||
|
DB.Model(&blog).Related(&tags2, "SharedTags")
|
||||||
|
if !compareTags(tags2, []string{"tag5", "tag6"}) {
|
||||||
|
t.Errorf("Should find 2 tags after Replace")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&blog2).Related(&tags2, "SharedTags")
|
||||||
|
if !compareTags(tags2, []string{"tag5", "tag6"}) {
|
||||||
|
t.Errorf("Should find 2 tags after Replace")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&blog).Association("SharedTags").Count() != 2 {
|
||||||
|
t.Errorf("Blog should has three tags after Replace")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete
|
||||||
|
DB.Model(&blog).Association("SharedTags").Delete(tag5)
|
||||||
|
var tags3 []Tag
|
||||||
|
DB.Model(&blog).Related(&tags3, "SharedTags")
|
||||||
|
if !compareTags(tags3, []string{"tag6"}) {
|
||||||
|
t.Errorf("Should find 1 tags after Delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&blog).Association("SharedTags").Count() != 1 {
|
||||||
|
t.Errorf("Blog should has three tags after Delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&blog2).Association("SharedTags").Delete(tag3)
|
||||||
|
var tags4 []Tag
|
||||||
|
DB.Model(&blog).Related(&tags4, "SharedTags")
|
||||||
|
if !compareTags(tags4, []string{"tag6"}) {
|
||||||
|
t.Errorf("Tag should not be deleted when Delete with a unrelated tag")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear
|
||||||
|
DB.Model(&blog2).Association("SharedTags").Clear()
|
||||||
|
if DB.Model(&blog).Association("SharedTags").Count() != 0 {
|
||||||
|
t.Errorf("All tags should be cleared")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) {
|
||||||
|
if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" {
|
||||||
|
DB.DropTable(&Blog{}, &Tag{})
|
||||||
|
DB.DropTable("locale_blog_tags")
|
||||||
|
DB.CreateTable(&Blog{}, &Tag{})
|
||||||
|
blog := Blog{
|
||||||
|
Locale: "ZH",
|
||||||
|
Subject: "subject",
|
||||||
|
Body: "body",
|
||||||
|
LocaleTags: []Tag{
|
||||||
|
{Locale: "ZH", Value: "tag1"},
|
||||||
|
{Locale: "ZH", Value: "tag2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
DB.Save(&blog)
|
||||||
|
|
||||||
|
blog2 := Blog{
|
||||||
|
ID: blog.ID,
|
||||||
|
Locale: "EN",
|
||||||
|
}
|
||||||
|
DB.Create(&blog2)
|
||||||
|
|
||||||
|
// Append
|
||||||
|
var tag3 = &Tag{Locale: "ZH", Value: "tag3"}
|
||||||
|
DB.Model(&blog).Association("LocaleTags").Append([]*Tag{tag3})
|
||||||
|
if !compareTags(blog.LocaleTags, []string{"tag1", "tag2", "tag3"}) {
|
||||||
|
t.Errorf("Blog should has three tags after Append")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&blog).Association("LocaleTags").Count() != 3 {
|
||||||
|
t.Errorf("Blog should has three tags after Append")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&blog2).Association("LocaleTags").Count() != 0 {
|
||||||
|
t.Errorf("EN Blog should has 0 tags after ZH Blog Append")
|
||||||
|
}
|
||||||
|
|
||||||
|
var tags []Tag
|
||||||
|
DB.Model(&blog).Related(&tags, "LocaleTags")
|
||||||
|
if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) {
|
||||||
|
t.Errorf("Should find 3 tags with Related")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&blog2).Related(&tags, "LocaleTags")
|
||||||
|
if len(tags) != 0 {
|
||||||
|
t.Errorf("Should find 0 tags with Related for EN Blog")
|
||||||
|
}
|
||||||
|
|
||||||
|
var blog1 Blog
|
||||||
|
DB.Preload("LocaleTags").Find(&blog1, "locale = ? AND id = ?", "ZH", blog.ID)
|
||||||
|
if !compareTags(blog1.LocaleTags, []string{"tag1", "tag2", "tag3"}) {
|
||||||
|
t.Errorf("Preload many2many relations")
|
||||||
|
}
|
||||||
|
|
||||||
|
var tag4 = &Tag{Locale: "ZH", Value: "tag4"}
|
||||||
|
DB.Model(&blog2).Association("LocaleTags").Append(tag4)
|
||||||
|
|
||||||
|
DB.Model(&blog).Related(&tags, "LocaleTags")
|
||||||
|
if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) {
|
||||||
|
t.Errorf("Should find 3 tags with Related for EN Blog")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&blog2).Related(&tags, "LocaleTags")
|
||||||
|
if !compareTags(tags, []string{"tag4"}) {
|
||||||
|
t.Errorf("Should find 1 tags with Related for EN Blog")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace
|
||||||
|
var tag5 = &Tag{Locale: "ZH", Value: "tag5"}
|
||||||
|
var tag6 = &Tag{Locale: "ZH", Value: "tag6"}
|
||||||
|
DB.Model(&blog2).Association("LocaleTags").Replace(tag5, tag6)
|
||||||
|
|
||||||
|
var tags2 []Tag
|
||||||
|
DB.Model(&blog).Related(&tags2, "LocaleTags")
|
||||||
|
if !compareTags(tags2, []string{"tag1", "tag2", "tag3"}) {
|
||||||
|
t.Errorf("CN Blog's tags should not be changed after EN Blog Replace")
|
||||||
|
}
|
||||||
|
|
||||||
|
var blog11 Blog
|
||||||
|
DB.Preload("LocaleTags").First(&blog11, "id = ? AND locale = ?", blog.ID, blog.Locale)
|
||||||
|
if !compareTags(blog11.LocaleTags, []string{"tag1", "tag2", "tag3"}) {
|
||||||
|
t.Errorf("CN Blog's tags should not be changed after EN Blog Replace")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&blog2).Related(&tags2, "LocaleTags")
|
||||||
|
if !compareTags(tags2, []string{"tag5", "tag6"}) {
|
||||||
|
t.Errorf("Should find 2 tags after Replace")
|
||||||
|
}
|
||||||
|
|
||||||
|
var blog21 Blog
|
||||||
|
DB.Preload("LocaleTags").First(&blog21, "id = ? AND locale = ?", blog2.ID, blog2.Locale)
|
||||||
|
if !compareTags(blog21.LocaleTags, []string{"tag5", "tag6"}) {
|
||||||
|
t.Errorf("EN Blog's tags should be changed after Replace")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&blog).Association("LocaleTags").Count() != 3 {
|
||||||
|
t.Errorf("ZH Blog should has three tags after Replace")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&blog2).Association("LocaleTags").Count() != 2 {
|
||||||
|
t.Errorf("EN Blog should has two tags after Replace")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete
|
||||||
|
DB.Model(&blog).Association("LocaleTags").Delete(tag5)
|
||||||
|
|
||||||
|
if DB.Model(&blog).Association("LocaleTags").Count() != 3 {
|
||||||
|
t.Errorf("ZH Blog should has three tags after Delete with EN's tag")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&blog2).Association("LocaleTags").Count() != 2 {
|
||||||
|
t.Errorf("EN Blog should has two tags after ZH Blog Delete with EN's tag")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&blog2).Association("LocaleTags").Delete(tag5)
|
||||||
|
|
||||||
|
if DB.Model(&blog).Association("LocaleTags").Count() != 3 {
|
||||||
|
t.Errorf("ZH Blog should has three tags after EN Blog Delete with EN's tag")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&blog2).Association("LocaleTags").Count() != 1 {
|
||||||
|
t.Errorf("EN Blog should has 1 tags after EN Blog Delete with EN's tag")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear
|
||||||
|
DB.Model(&blog2).Association("LocaleTags").Clear()
|
||||||
|
if DB.Model(&blog).Association("LocaleTags").Count() != 3 {
|
||||||
|
t.Errorf("ZH Blog's tags should not be cleared when clear EN Blog's tags")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&blog2).Association("LocaleTags").Count() != 0 {
|
||||||
|
t.Errorf("EN Blog's tags should be cleared when clear EN Blog's tags")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&blog).Association("LocaleTags").Clear()
|
||||||
|
if DB.Model(&blog).Association("LocaleTags").Count() != 0 {
|
||||||
|
t.Errorf("ZH Blog's tags should be cleared when clear ZH Blog's tags")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&blog2).Association("LocaleTags").Count() != 0 {
|
||||||
|
t.Errorf("EN Blog's tags should be cleared")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
84
orm/pointer_test.go
Normal file
84
orm/pointer_test.go
Normal file
|
@ -0,0 +1,84 @@
|
||||||
|
package orm_test
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
type PointerStruct struct {
|
||||||
|
ID int64
|
||||||
|
Name *string
|
||||||
|
Num *int
|
||||||
|
}
|
||||||
|
|
||||||
|
type NormalStruct struct {
|
||||||
|
ID int64
|
||||||
|
Name string
|
||||||
|
Num int
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPointerFields(t *testing.T) {
|
||||||
|
DB.DropTable(&PointerStruct{})
|
||||||
|
DB.AutoMigrate(&PointerStruct{})
|
||||||
|
var name = "pointer struct 1"
|
||||||
|
var num = 100
|
||||||
|
pointerStruct := PointerStruct{Name: &name, Num: &num}
|
||||||
|
if DB.Create(&pointerStruct).Error != nil {
|
||||||
|
t.Errorf("Failed to save pointer struct")
|
||||||
|
}
|
||||||
|
|
||||||
|
var pointerStructResult PointerStruct
|
||||||
|
if err := DB.First(&pointerStructResult, "id = ?", pointerStruct.ID).Error; err != nil || *pointerStructResult.Name != name || *pointerStructResult.Num != num {
|
||||||
|
t.Errorf("Failed to query saved pointer struct")
|
||||||
|
}
|
||||||
|
|
||||||
|
var tableName = DB.NewScope(&PointerStruct{}).TableName()
|
||||||
|
|
||||||
|
var normalStruct NormalStruct
|
||||||
|
DB.Table(tableName).First(&normalStruct)
|
||||||
|
if normalStruct.Name != name || normalStruct.Num != num {
|
||||||
|
t.Errorf("Failed to query saved Normal struct")
|
||||||
|
}
|
||||||
|
|
||||||
|
var nilPointerStruct = PointerStruct{}
|
||||||
|
if err := DB.Create(&nilPointerStruct).Error; err != nil {
|
||||||
|
t.Error("Failed to save nil pointer struct", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var pointerStruct2 PointerStruct
|
||||||
|
if err := DB.First(&pointerStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil {
|
||||||
|
t.Error("Failed to query saved nil pointer struct", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var normalStruct2 NormalStruct
|
||||||
|
if err := DB.Table(tableName).First(&normalStruct2, "id = ?", nilPointerStruct.ID).Error; err != nil {
|
||||||
|
t.Error("Failed to query saved nil pointer struct", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var partialNilPointerStruct1 = PointerStruct{Num: &num}
|
||||||
|
if err := DB.Create(&partialNilPointerStruct1).Error; err != nil {
|
||||||
|
t.Error("Failed to save partial nil pointer struct", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var pointerStruct3 PointerStruct
|
||||||
|
if err := DB.First(&pointerStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || *pointerStruct3.Num != num {
|
||||||
|
t.Error("Failed to query saved partial nil pointer struct", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var normalStruct3 NormalStruct
|
||||||
|
if err := DB.Table(tableName).First(&normalStruct3, "id = ?", partialNilPointerStruct1.ID).Error; err != nil || normalStruct3.Num != num {
|
||||||
|
t.Error("Failed to query saved partial pointer struct", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var partialNilPointerStruct2 = PointerStruct{Name: &name}
|
||||||
|
if err := DB.Create(&partialNilPointerStruct2).Error; err != nil {
|
||||||
|
t.Error("Failed to save partial nil pointer struct", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var pointerStruct4 PointerStruct
|
||||||
|
if err := DB.First(&pointerStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || *pointerStruct4.Name != name {
|
||||||
|
t.Error("Failed to query saved partial nil pointer struct", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var normalStruct4 NormalStruct
|
||||||
|
if err := DB.Table(tableName).First(&normalStruct4, "id = ?", partialNilPointerStruct2.ID).Error; err != nil || normalStruct4.Name != name {
|
||||||
|
t.Error("Failed to query saved partial pointer struct", err)
|
||||||
|
}
|
||||||
|
}
|
366
orm/polymorphic_test.go
Normal file
366
orm/polymorphic_test.go
Normal file
|
@ -0,0 +1,366 @@
|
||||||
|
package orm_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"sort"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Cat struct {
|
||||||
|
Id int
|
||||||
|
Name string
|
||||||
|
Toy Toy `gorm:"polymorphic:Owner;"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Dog struct {
|
||||||
|
Id int
|
||||||
|
Name string
|
||||||
|
Toys []Toy `gorm:"polymorphic:Owner;"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Hamster struct {
|
||||||
|
Id int
|
||||||
|
Name string
|
||||||
|
PreferredToy Toy `gorm:"polymorphic:Owner;polymorphic_value:hamster_preferred"`
|
||||||
|
OtherToy Toy `gorm:"polymorphic:Owner;polymorphic_value:hamster_other"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Toy struct {
|
||||||
|
Id int
|
||||||
|
Name string
|
||||||
|
OwnerId int
|
||||||
|
OwnerType string
|
||||||
|
}
|
||||||
|
|
||||||
|
var compareToys = func(toys []Toy, contents []string) bool {
|
||||||
|
var toyContents []string
|
||||||
|
for _, toy := range toys {
|
||||||
|
toyContents = append(toyContents, toy.Name)
|
||||||
|
}
|
||||||
|
sort.Strings(toyContents)
|
||||||
|
sort.Strings(contents)
|
||||||
|
return reflect.DeepEqual(toyContents, contents)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolymorphic(t *testing.T) {
|
||||||
|
cat := Cat{Name: "Mr. Bigglesworth", Toy: Toy{Name: "cat toy"}}
|
||||||
|
dog := Dog{Name: "Pluto", Toys: []Toy{{Name: "dog toy 1"}, {Name: "dog toy 2"}}}
|
||||||
|
DB.Save(&cat).Save(&dog)
|
||||||
|
|
||||||
|
if DB.Model(&cat).Association("Toy").Count() != 1 {
|
||||||
|
t.Errorf("Cat's toys count should be 1")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&dog).Association("Toys").Count() != 2 {
|
||||||
|
t.Errorf("Dog's toys count should be 2")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query
|
||||||
|
var catToys []Toy
|
||||||
|
if DB.Model(&cat).Related(&catToys, "Toy").RecordNotFound() {
|
||||||
|
t.Errorf("Did not find any has one polymorphic association")
|
||||||
|
} else if len(catToys) != 1 {
|
||||||
|
t.Errorf("Should have found only one polymorphic has one association")
|
||||||
|
} else if catToys[0].Name != cat.Toy.Name {
|
||||||
|
t.Errorf("Should have found the proper has one polymorphic association")
|
||||||
|
}
|
||||||
|
|
||||||
|
var dogToys []Toy
|
||||||
|
if DB.Model(&dog).Related(&dogToys, "Toys").RecordNotFound() {
|
||||||
|
t.Errorf("Did not find any polymorphic has many associations")
|
||||||
|
} else if len(dogToys) != len(dog.Toys) {
|
||||||
|
t.Errorf("Should have found all polymorphic has many associations")
|
||||||
|
}
|
||||||
|
|
||||||
|
var catToy Toy
|
||||||
|
DB.Model(&cat).Association("Toy").Find(&catToy)
|
||||||
|
if catToy.Name != cat.Toy.Name {
|
||||||
|
t.Errorf("Should find has one polymorphic association")
|
||||||
|
}
|
||||||
|
|
||||||
|
var dogToys1 []Toy
|
||||||
|
DB.Model(&dog).Association("Toys").Find(&dogToys1)
|
||||||
|
if !compareToys(dogToys1, []string{"dog toy 1", "dog toy 2"}) {
|
||||||
|
t.Errorf("Should find has many polymorphic association")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append
|
||||||
|
DB.Model(&cat).Association("Toy").Append(&Toy{
|
||||||
|
Name: "cat toy 2",
|
||||||
|
})
|
||||||
|
|
||||||
|
var catToy2 Toy
|
||||||
|
DB.Model(&cat).Association("Toy").Find(&catToy2)
|
||||||
|
if catToy2.Name != "cat toy 2" {
|
||||||
|
t.Errorf("Should update has one polymorphic association with Append")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&cat).Association("Toy").Count() != 1 {
|
||||||
|
t.Errorf("Cat's toys count should be 1 after Append")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&dog).Association("Toys").Count() != 2 {
|
||||||
|
t.Errorf("Should return two polymorphic has many associations")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&dog).Association("Toys").Append(&Toy{
|
||||||
|
Name: "dog toy 3",
|
||||||
|
})
|
||||||
|
|
||||||
|
var dogToys2 []Toy
|
||||||
|
DB.Model(&dog).Association("Toys").Find(&dogToys2)
|
||||||
|
if !compareToys(dogToys2, []string{"dog toy 1", "dog toy 2", "dog toy 3"}) {
|
||||||
|
t.Errorf("Dog's toys should be updated with Append")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&dog).Association("Toys").Count() != 3 {
|
||||||
|
t.Errorf("Should return three polymorphic has many associations")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace
|
||||||
|
DB.Model(&cat).Association("Toy").Replace(&Toy{
|
||||||
|
Name: "cat toy 3",
|
||||||
|
})
|
||||||
|
|
||||||
|
var catToy3 Toy
|
||||||
|
DB.Model(&cat).Association("Toy").Find(&catToy3)
|
||||||
|
if catToy3.Name != "cat toy 3" {
|
||||||
|
t.Errorf("Should update has one polymorphic association with Replace")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&cat).Association("Toy").Count() != 1 {
|
||||||
|
t.Errorf("Cat's toys count should be 1 after Replace")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&dog).Association("Toys").Count() != 3 {
|
||||||
|
t.Errorf("Should return three polymorphic has many associations")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&dog).Association("Toys").Replace(&Toy{
|
||||||
|
Name: "dog toy 4",
|
||||||
|
}, []Toy{
|
||||||
|
{Name: "dog toy 5"}, {Name: "dog toy 6"}, {Name: "dog toy 7"},
|
||||||
|
})
|
||||||
|
|
||||||
|
var dogToys3 []Toy
|
||||||
|
DB.Model(&dog).Association("Toys").Find(&dogToys3)
|
||||||
|
if !compareToys(dogToys3, []string{"dog toy 4", "dog toy 5", "dog toy 6", "dog toy 7"}) {
|
||||||
|
t.Errorf("Dog's toys should be updated with Replace")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&dog).Association("Toys").Count() != 4 {
|
||||||
|
t.Errorf("Should return three polymorphic has many associations")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete
|
||||||
|
DB.Model(&cat).Association("Toy").Delete(&catToy2)
|
||||||
|
|
||||||
|
var catToy4 Toy
|
||||||
|
DB.Model(&cat).Association("Toy").Find(&catToy4)
|
||||||
|
if catToy4.Name != "cat toy 3" {
|
||||||
|
t.Errorf("Should not update has one polymorphic association when Delete a unrelated Toy")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&cat).Association("Toy").Count() != 1 {
|
||||||
|
t.Errorf("Cat's toys count should be 1")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&dog).Association("Toys").Count() != 4 {
|
||||||
|
t.Errorf("Dog's toys count should be 4")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&cat).Association("Toy").Delete(&catToy3)
|
||||||
|
|
||||||
|
if !DB.Model(&cat).Related(&Toy{}, "Toy").RecordNotFound() {
|
||||||
|
t.Errorf("Toy should be deleted with Delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&cat).Association("Toy").Count() != 0 {
|
||||||
|
t.Errorf("Cat's toys count should be 0 after Delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&dog).Association("Toys").Count() != 4 {
|
||||||
|
t.Errorf("Dog's toys count should not be changed when delete cat's toy")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&dog).Association("Toys").Delete(&dogToys2)
|
||||||
|
|
||||||
|
if DB.Model(&dog).Association("Toys").Count() != 4 {
|
||||||
|
t.Errorf("Dog's toys count should not be changed when delete unrelated toys")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&dog).Association("Toys").Delete(&dogToys3)
|
||||||
|
|
||||||
|
if DB.Model(&dog).Association("Toys").Count() != 0 {
|
||||||
|
t.Errorf("Dog's toys count should be deleted with Delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear
|
||||||
|
DB.Model(&cat).Association("Toy").Append(&Toy{
|
||||||
|
Name: "cat toy 2",
|
||||||
|
})
|
||||||
|
|
||||||
|
if DB.Model(&cat).Association("Toy").Count() != 1 {
|
||||||
|
t.Errorf("Cat's toys should be added with Append")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&cat).Association("Toy").Clear()
|
||||||
|
|
||||||
|
if DB.Model(&cat).Association("Toy").Count() != 0 {
|
||||||
|
t.Errorf("Cat's toys should be cleared with Clear")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&dog).Association("Toys").Append(&Toy{
|
||||||
|
Name: "dog toy 8",
|
||||||
|
})
|
||||||
|
|
||||||
|
if DB.Model(&dog).Association("Toys").Count() != 1 {
|
||||||
|
t.Errorf("Dog's toys should be added with Append")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&dog).Association("Toys").Clear()
|
||||||
|
|
||||||
|
if DB.Model(&dog).Association("Toys").Count() != 0 {
|
||||||
|
t.Errorf("Dog's toys should be cleared with Clear")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNamedPolymorphic(t *testing.T) {
|
||||||
|
hamster := Hamster{Name: "Mr. Hammond", PreferredToy: Toy{Name: "bike"}, OtherToy: Toy{Name: "treadmill"}}
|
||||||
|
DB.Save(&hamster)
|
||||||
|
|
||||||
|
hamster2 := Hamster{}
|
||||||
|
DB.Preload("PreferredToy").Preload("OtherToy").Find(&hamster2, hamster.Id)
|
||||||
|
if hamster2.PreferredToy.Id != hamster.PreferredToy.Id || hamster2.PreferredToy.Name != hamster.PreferredToy.Name {
|
||||||
|
t.Errorf("Hamster's preferred toy couldn't be preloaded")
|
||||||
|
}
|
||||||
|
if hamster2.OtherToy.Id != hamster.OtherToy.Id || hamster2.OtherToy.Name != hamster.OtherToy.Name {
|
||||||
|
t.Errorf("Hamster's other toy couldn't be preloaded")
|
||||||
|
}
|
||||||
|
|
||||||
|
// clear to omit Toy.Id in count
|
||||||
|
hamster2.PreferredToy = Toy{}
|
||||||
|
hamster2.OtherToy = Toy{}
|
||||||
|
|
||||||
|
if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 {
|
||||||
|
t.Errorf("Hamster's preferred toy count should be 1")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&hamster2).Association("OtherToy").Count() != 1 {
|
||||||
|
t.Errorf("Hamster's other toy count should be 1")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query
|
||||||
|
var hamsterToys []Toy
|
||||||
|
if DB.Model(&hamster).Related(&hamsterToys, "PreferredToy").RecordNotFound() {
|
||||||
|
t.Errorf("Did not find any has one polymorphic association")
|
||||||
|
} else if len(hamsterToys) != 1 {
|
||||||
|
t.Errorf("Should have found only one polymorphic has one association")
|
||||||
|
} else if hamsterToys[0].Name != hamster.PreferredToy.Name {
|
||||||
|
t.Errorf("Should have found the proper has one polymorphic association")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&hamster).Related(&hamsterToys, "OtherToy").RecordNotFound() {
|
||||||
|
t.Errorf("Did not find any has one polymorphic association")
|
||||||
|
} else if len(hamsterToys) != 1 {
|
||||||
|
t.Errorf("Should have found only one polymorphic has one association")
|
||||||
|
} else if hamsterToys[0].Name != hamster.OtherToy.Name {
|
||||||
|
t.Errorf("Should have found the proper has one polymorphic association")
|
||||||
|
}
|
||||||
|
|
||||||
|
hamsterToy := Toy{}
|
||||||
|
DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy)
|
||||||
|
if hamsterToy.Name != hamster.PreferredToy.Name {
|
||||||
|
t.Errorf("Should find has one polymorphic association")
|
||||||
|
}
|
||||||
|
hamsterToy = Toy{}
|
||||||
|
DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy)
|
||||||
|
if hamsterToy.Name != hamster.OtherToy.Name {
|
||||||
|
t.Errorf("Should find has one polymorphic association")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append
|
||||||
|
DB.Model(&hamster).Association("PreferredToy").Append(&Toy{
|
||||||
|
Name: "bike 2",
|
||||||
|
})
|
||||||
|
DB.Model(&hamster).Association("OtherToy").Append(&Toy{
|
||||||
|
Name: "treadmill 2",
|
||||||
|
})
|
||||||
|
|
||||||
|
hamsterToy = Toy{}
|
||||||
|
DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy)
|
||||||
|
if hamsterToy.Name != "bike 2" {
|
||||||
|
t.Errorf("Should update has one polymorphic association with Append")
|
||||||
|
}
|
||||||
|
|
||||||
|
hamsterToy = Toy{}
|
||||||
|
DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy)
|
||||||
|
if hamsterToy.Name != "treadmill 2" {
|
||||||
|
t.Errorf("Should update has one polymorphic association with Append")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 {
|
||||||
|
t.Errorf("Hamster's toys count should be 1 after Append")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&hamster2).Association("OtherToy").Count() != 1 {
|
||||||
|
t.Errorf("Hamster's toys count should be 1 after Append")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace
|
||||||
|
DB.Model(&hamster).Association("PreferredToy").Replace(&Toy{
|
||||||
|
Name: "bike 3",
|
||||||
|
})
|
||||||
|
DB.Model(&hamster).Association("OtherToy").Replace(&Toy{
|
||||||
|
Name: "treadmill 3",
|
||||||
|
})
|
||||||
|
|
||||||
|
hamsterToy = Toy{}
|
||||||
|
DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy)
|
||||||
|
if hamsterToy.Name != "bike 3" {
|
||||||
|
t.Errorf("Should update has one polymorphic association with Replace")
|
||||||
|
}
|
||||||
|
|
||||||
|
hamsterToy = Toy{}
|
||||||
|
DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy)
|
||||||
|
if hamsterToy.Name != "treadmill 3" {
|
||||||
|
t.Errorf("Should update has one polymorphic association with Replace")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 {
|
||||||
|
t.Errorf("hamster's toys count should be 1 after Replace")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&hamster2).Association("OtherToy").Count() != 1 {
|
||||||
|
t.Errorf("hamster's toys count should be 1 after Replace")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear
|
||||||
|
DB.Model(&hamster).Association("PreferredToy").Append(&Toy{
|
||||||
|
Name: "bike 2",
|
||||||
|
})
|
||||||
|
DB.Model(&hamster).Association("OtherToy").Append(&Toy{
|
||||||
|
Name: "treadmill 2",
|
||||||
|
})
|
||||||
|
|
||||||
|
if DB.Model(&hamster).Association("PreferredToy").Count() != 1 {
|
||||||
|
t.Errorf("Hamster's toys should be added with Append")
|
||||||
|
}
|
||||||
|
if DB.Model(&hamster).Association("OtherToy").Count() != 1 {
|
||||||
|
t.Errorf("Hamster's toys should be added with Append")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&hamster).Association("PreferredToy").Clear()
|
||||||
|
|
||||||
|
if DB.Model(&hamster2).Association("PreferredToy").Count() != 0 {
|
||||||
|
t.Errorf("Hamster's preferred toy should be cleared with Clear")
|
||||||
|
}
|
||||||
|
if DB.Model(&hamster2).Association("OtherToy").Count() != 1 {
|
||||||
|
t.Errorf("Hamster's other toy should be still available")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&hamster).Association("OtherToy").Clear()
|
||||||
|
if DB.Model(&hamster).Association("OtherToy").Count() != 0 {
|
||||||
|
t.Errorf("Hamster's other toy should be cleared with Clear")
|
||||||
|
}
|
||||||
|
}
|
1606
orm/preload_test.go
Normal file
1606
orm/preload_test.go
Normal file
File diff suppressed because it is too large
Load Diff
661
orm/query_test.go
Normal file
661
orm/query_test.go
Normal file
|
@ -0,0 +1,661 @@
|
||||||
|
package orm_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFirstAndLast(t *testing.T) {
|
||||||
|
DB.Save(&User{Name: "user1", Emails: []Email{{Email: "user1@example.com"}}})
|
||||||
|
DB.Save(&User{Name: "user2", Emails: []Email{{Email: "user2@example.com"}}})
|
||||||
|
|
||||||
|
var user1, user2, user3, user4 User
|
||||||
|
DB.First(&user1)
|
||||||
|
DB.Order("id").Limit(1).Find(&user2)
|
||||||
|
|
||||||
|
DB.Last(&user3)
|
||||||
|
DB.Order("id desc").Limit(1).Find(&user4)
|
||||||
|
if user1.Id != user2.Id || user3.Id != user4.Id {
|
||||||
|
t.Errorf("First and Last should by order by primary key")
|
||||||
|
}
|
||||||
|
|
||||||
|
var users []User
|
||||||
|
DB.First(&users)
|
||||||
|
if len(users) != 1 {
|
||||||
|
t.Errorf("Find first record as slice")
|
||||||
|
}
|
||||||
|
|
||||||
|
var user User
|
||||||
|
if DB.Joins("left join emails on emails.user_id = users.id").First(&user).Error != nil {
|
||||||
|
t.Errorf("Should not raise any error when order with Join table")
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.Email != "" {
|
||||||
|
t.Errorf("User's Email should be blank as no one set it")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFirstAndLastWithNoStdPrimaryKey(t *testing.T) {
|
||||||
|
DB.Save(&Animal{Name: "animal1"})
|
||||||
|
DB.Save(&Animal{Name: "animal2"})
|
||||||
|
|
||||||
|
var animal1, animal2, animal3, animal4 Animal
|
||||||
|
DB.First(&animal1)
|
||||||
|
DB.Order("counter").Limit(1).Find(&animal2)
|
||||||
|
|
||||||
|
DB.Last(&animal3)
|
||||||
|
DB.Order("counter desc").Limit(1).Find(&animal4)
|
||||||
|
if animal1.Counter != animal2.Counter || animal3.Counter != animal4.Counter {
|
||||||
|
t.Errorf("First and Last should work correctly")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFirstAndLastWithRaw(t *testing.T) {
|
||||||
|
user1 := User{Name: "user", Emails: []Email{{Email: "user1@example.com"}}}
|
||||||
|
user2 := User{Name: "user", Emails: []Email{{Email: "user2@example.com"}}}
|
||||||
|
DB.Save(&user1)
|
||||||
|
DB.Save(&user2)
|
||||||
|
|
||||||
|
var user3, user4 User
|
||||||
|
DB.Raw("select * from users WHERE name = ?", "user").First(&user3)
|
||||||
|
if user3.Id != user1.Id {
|
||||||
|
t.Errorf("Find first record with raw")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Raw("select * from users WHERE name = ?", "user").Last(&user4)
|
||||||
|
if user4.Id != user2.Id {
|
||||||
|
t.Errorf("Find last record with raw")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUIntPrimaryKey(t *testing.T) {
|
||||||
|
var animal Animal
|
||||||
|
DB.First(&animal, uint64(1))
|
||||||
|
if animal.Counter != 1 {
|
||||||
|
t.Errorf("Fetch a record from with a non-int primary key should work, but failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(Animal{}).Where(Animal{Counter: uint64(2)}).Scan(&animal)
|
||||||
|
if animal.Counter != 2 {
|
||||||
|
t.Errorf("Fetch a record from with a non-int primary key should work, but failed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) {
|
||||||
|
type AddressByZipCode struct {
|
||||||
|
ZipCode string `gorm:"primary_key"`
|
||||||
|
Address string
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.AutoMigrate(&AddressByZipCode{})
|
||||||
|
DB.Create(&AddressByZipCode{ZipCode: "00501", Address: "Holtsville"})
|
||||||
|
|
||||||
|
var address AddressByZipCode
|
||||||
|
DB.First(&address, "00501")
|
||||||
|
if address.ZipCode != "00501" {
|
||||||
|
t.Errorf("Fetch a record from with a string primary key for a numeric value starting with zero should work, but failed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFindAsSliceOfPointers(t *testing.T) {
|
||||||
|
DB.Save(&User{Name: "user"})
|
||||||
|
|
||||||
|
var users []User
|
||||||
|
DB.Find(&users)
|
||||||
|
|
||||||
|
var userPointers []*User
|
||||||
|
DB.Find(&userPointers)
|
||||||
|
|
||||||
|
if len(users) == 0 || len(users) != len(userPointers) {
|
||||||
|
t.Errorf("Find slice of pointers")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSearchWithPlainSQL(t *testing.T) {
|
||||||
|
user1 := User{Name: "PlainSqlUser1", Age: 1, Birthday: parseTime("2000-1-1")}
|
||||||
|
user2 := User{Name: "PlainSqlUser2", Age: 10, Birthday: parseTime("2010-1-1")}
|
||||||
|
user3 := User{Name: "PlainSqlUser3", Age: 20, Birthday: parseTime("2020-1-1")}
|
||||||
|
DB.Save(&user1).Save(&user2).Save(&user3)
|
||||||
|
scopedb := DB.Where("name LIKE ?", "%PlainSqlUser%")
|
||||||
|
|
||||||
|
if DB.Where("name = ?", user1.Name).First(&User{}).RecordNotFound() {
|
||||||
|
t.Errorf("Search with plain SQL")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Where("name LIKE ?", "%"+user1.Name+"%").First(&User{}).RecordNotFound() {
|
||||||
|
t.Errorf("Search with plan SQL (regexp)")
|
||||||
|
}
|
||||||
|
|
||||||
|
var users []User
|
||||||
|
DB.Find(&users, "name LIKE ? and age > ?", "%PlainSqlUser%", 1)
|
||||||
|
if len(users) != 2 {
|
||||||
|
t.Errorf("Should found 2 users that age > 1, but got %v", len(users))
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Where("name LIKE ?", "%PlainSqlUser%").Where("age >= ?", 1).Find(&users)
|
||||||
|
if len(users) != 3 {
|
||||||
|
t.Errorf("Should found 3 users that age >= 1, but got %v", len(users))
|
||||||
|
}
|
||||||
|
|
||||||
|
scopedb.Where("age <> ?", 20).Find(&users)
|
||||||
|
if len(users) != 2 {
|
||||||
|
t.Errorf("Should found 2 users age != 20, but got %v", len(users))
|
||||||
|
}
|
||||||
|
|
||||||
|
scopedb.Where("birthday > ?", parseTime("2000-1-1")).Find(&users)
|
||||||
|
if len(users) != 2 {
|
||||||
|
t.Errorf("Should found 2 users's birthday > 2000-1-1, but got %v", len(users))
|
||||||
|
}
|
||||||
|
|
||||||
|
scopedb.Where("birthday > ?", "2002-10-10").Find(&users)
|
||||||
|
if len(users) != 2 {
|
||||||
|
t.Errorf("Should found 2 users's birthday >= 2002-10-10, but got %v", len(users))
|
||||||
|
}
|
||||||
|
|
||||||
|
scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users)
|
||||||
|
if len(users) != 1 {
|
||||||
|
t.Errorf("Should found 1 users's birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users))
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users)
|
||||||
|
if len(users) != 2 {
|
||||||
|
t.Errorf("Should found 2 users, but got %v", len(users))
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Where("id in (?)", []int64{user1.Id, user2.Id, user3.Id}).Find(&users)
|
||||||
|
if len(users) != 3 {
|
||||||
|
t.Errorf("Should found 3 users, but got %v", len(users))
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Where("id in (?)", user1.Id).Find(&users)
|
||||||
|
if len(users) != 1 {
|
||||||
|
t.Errorf("Should found 1 users, but got %v", len(users))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Where("id IN (?)", []string{}).Find(&users).Error; err != nil {
|
||||||
|
t.Error("no error should happen when query with empty slice, but got: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Not("id IN (?)", []string{}).Find(&users).Error; err != nil {
|
||||||
|
t.Error("no error should happen when query with empty slice, but got: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Where("name = ?", "none existing").Find(&[]User{}).RecordNotFound() {
|
||||||
|
t.Errorf("Should not get RecordNotFound error when looking for none existing records")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSearchWithStruct(t *testing.T) {
|
||||||
|
user1 := User{Name: "StructSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")}
|
||||||
|
user2 := User{Name: "StructSearchUser2", Age: 10, Birthday: parseTime("2010-1-1")}
|
||||||
|
user3 := User{Name: "StructSearchUser3", Age: 20, Birthday: parseTime("2020-1-1")}
|
||||||
|
DB.Save(&user1).Save(&user2).Save(&user3)
|
||||||
|
|
||||||
|
if DB.Where(user1.Id).First(&User{}).RecordNotFound() {
|
||||||
|
t.Errorf("Search with primary key")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.First(&User{}, user1.Id).RecordNotFound() {
|
||||||
|
t.Errorf("Search with primary key as inline condition")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.First(&User{}, fmt.Sprintf("%v", user1.Id)).RecordNotFound() {
|
||||||
|
t.Errorf("Search with primary key as inline condition")
|
||||||
|
}
|
||||||
|
|
||||||
|
var users []User
|
||||||
|
DB.Where([]int64{user1.Id, user2.Id, user3.Id}).Find(&users)
|
||||||
|
if len(users) != 3 {
|
||||||
|
t.Errorf("Should found 3 users when search with primary keys, but got %v", len(users))
|
||||||
|
}
|
||||||
|
|
||||||
|
var user User
|
||||||
|
DB.First(&user, &User{Name: user1.Name})
|
||||||
|
if user.Id == 0 || user.Name != user1.Name {
|
||||||
|
t.Errorf("Search first record with inline pointer of struct")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.First(&user, User{Name: user1.Name})
|
||||||
|
if user.Id == 0 || user.Name != user.Name {
|
||||||
|
t.Errorf("Search first record with inline struct")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Where(&User{Name: user1.Name}).First(&user)
|
||||||
|
if user.Id == 0 || user.Name != user1.Name {
|
||||||
|
t.Errorf("Search first record with where struct")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Find(&users, &User{Name: user2.Name})
|
||||||
|
if len(users) != 1 {
|
||||||
|
t.Errorf("Search all records with inline struct")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSearchWithMap(t *testing.T) {
|
||||||
|
companyID := 1
|
||||||
|
user1 := User{Name: "MapSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")}
|
||||||
|
user2 := User{Name: "MapSearchUser2", Age: 10, Birthday: parseTime("2010-1-1")}
|
||||||
|
user3 := User{Name: "MapSearchUser3", Age: 20, Birthday: parseTime("2020-1-1")}
|
||||||
|
user4 := User{Name: "MapSearchUser4", Age: 30, Birthday: parseTime("2020-1-1"), CompanyID: &companyID}
|
||||||
|
DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4)
|
||||||
|
|
||||||
|
var user User
|
||||||
|
DB.First(&user, map[string]interface{}{"name": user1.Name})
|
||||||
|
if user.Id == 0 || user.Name != user1.Name {
|
||||||
|
t.Errorf("Search first record with inline map")
|
||||||
|
}
|
||||||
|
|
||||||
|
user = User{}
|
||||||
|
DB.Where(map[string]interface{}{"name": user2.Name}).First(&user)
|
||||||
|
if user.Id == 0 || user.Name != user2.Name {
|
||||||
|
t.Errorf("Search first record with where map")
|
||||||
|
}
|
||||||
|
|
||||||
|
var users []User
|
||||||
|
DB.Where(map[string]interface{}{"name": user3.Name}).Find(&users)
|
||||||
|
if len(users) != 1 {
|
||||||
|
t.Errorf("Search all records with inline map")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Find(&users, map[string]interface{}{"name": user3.Name})
|
||||||
|
if len(users) != 1 {
|
||||||
|
t.Errorf("Search all records with inline map")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Find(&users, map[string]interface{}{"name": user4.Name, "company_id": nil})
|
||||||
|
if len(users) != 0 {
|
||||||
|
t.Errorf("Search all records with inline map containing null value finding 0 records")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Find(&users, map[string]interface{}{"name": user1.Name, "company_id": nil})
|
||||||
|
if len(users) != 1 {
|
||||||
|
t.Errorf("Search all records with inline map containing null value finding 1 record")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Find(&users, map[string]interface{}{"name": user4.Name, "company_id": companyID})
|
||||||
|
if len(users) != 1 {
|
||||||
|
t.Errorf("Search all records with inline multiple value map")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSearchWithEmptyChain(t *testing.T) {
|
||||||
|
user1 := User{Name: "ChainSearchUser1", Age: 1, Birthday: parseTime("2000-1-1")}
|
||||||
|
user2 := User{Name: "ChainearchUser2", Age: 10, Birthday: parseTime("2010-1-1")}
|
||||||
|
user3 := User{Name: "ChainearchUser3", Age: 20, Birthday: parseTime("2020-1-1")}
|
||||||
|
DB.Save(&user1).Save(&user2).Save(&user3)
|
||||||
|
|
||||||
|
if DB.Where("").Where("").First(&User{}).Error != nil {
|
||||||
|
t.Errorf("Should not raise any error if searching with empty strings")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Where(&User{}).Where("name = ?", user1.Name).First(&User{}).Error != nil {
|
||||||
|
t.Errorf("Should not raise any error if searching with empty struct")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Where(map[string]interface{}{}).Where("name = ?", user1.Name).First(&User{}).Error != nil {
|
||||||
|
t.Errorf("Should not raise any error if searching with empty map")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelect(t *testing.T) {
|
||||||
|
user1 := User{Name: "SelectUser1"}
|
||||||
|
DB.Save(&user1)
|
||||||
|
|
||||||
|
var user User
|
||||||
|
DB.Where("name = ?", user1.Name).Select("name").Find(&user)
|
||||||
|
if user.Id != 0 {
|
||||||
|
t.Errorf("Should not have ID because only selected name, %+v", user.Id)
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.Name != user1.Name {
|
||||||
|
t.Errorf("Should have user Name when selected it")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOrderAndPluck(t *testing.T) {
|
||||||
|
user1 := User{Name: "OrderPluckUser1", Age: 1}
|
||||||
|
user2 := User{Name: "OrderPluckUser2", Age: 10}
|
||||||
|
user3 := User{Name: "OrderPluckUser3", Age: 20}
|
||||||
|
DB.Save(&user1).Save(&user2).Save(&user3)
|
||||||
|
scopedb := DB.Model(&User{}).Where("name like ?", "%OrderPluckUser%")
|
||||||
|
|
||||||
|
var user User
|
||||||
|
scopedb.Order(gorm.Expr("name = ? DESC", "OrderPluckUser2")).First(&user)
|
||||||
|
if user.Name != "OrderPluckUser2" {
|
||||||
|
t.Errorf("Order with sql expression")
|
||||||
|
}
|
||||||
|
|
||||||
|
var ages []int64
|
||||||
|
scopedb.Order("age desc").Pluck("age", &ages)
|
||||||
|
if ages[0] != 20 {
|
||||||
|
t.Errorf("The first age should be 20 when order with age desc")
|
||||||
|
}
|
||||||
|
|
||||||
|
var ages1, ages2 []int64
|
||||||
|
scopedb.Order("age desc").Pluck("age", &ages1).Pluck("age", &ages2)
|
||||||
|
if !reflect.DeepEqual(ages1, ages2) {
|
||||||
|
t.Errorf("The first order is the primary order")
|
||||||
|
}
|
||||||
|
|
||||||
|
var ages3, ages4 []int64
|
||||||
|
scopedb.Model(&User{}).Order("age desc").Pluck("age", &ages3).Order("age", true).Pluck("age", &ages4)
|
||||||
|
if reflect.DeepEqual(ages3, ages4) {
|
||||||
|
t.Errorf("Reorder should work")
|
||||||
|
}
|
||||||
|
|
||||||
|
var names []string
|
||||||
|
var ages5 []int64
|
||||||
|
scopedb.Model(User{}).Order("name").Order("age desc").Pluck("age", &ages5).Pluck("name", &names)
|
||||||
|
if names != nil && ages5 != nil {
|
||||||
|
if !(names[0] == user1.Name && names[1] == user2.Name && names[2] == user3.Name && ages5[2] == 20) {
|
||||||
|
t.Errorf("Order with multiple orders")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.Errorf("Order with multiple orders")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(User{}).Select("name, age").Find(&[]User{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLimit(t *testing.T) {
|
||||||
|
user1 := User{Name: "LimitUser1", Age: 1}
|
||||||
|
user2 := User{Name: "LimitUser2", Age: 10}
|
||||||
|
user3 := User{Name: "LimitUser3", Age: 20}
|
||||||
|
user4 := User{Name: "LimitUser4", Age: 10}
|
||||||
|
user5 := User{Name: "LimitUser5", Age: 20}
|
||||||
|
DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4).Save(&user5)
|
||||||
|
|
||||||
|
var users1, users2, users3 []User
|
||||||
|
DB.Order("age desc").Limit(3).Find(&users1).Limit(5).Find(&users2).Limit(-1).Find(&users3)
|
||||||
|
|
||||||
|
if len(users1) != 3 || len(users2) != 5 || len(users3) <= 5 {
|
||||||
|
t.Errorf("Limit should works")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOffset(t *testing.T) {
|
||||||
|
for i := 0; i < 20; i++ {
|
||||||
|
DB.Save(&User{Name: fmt.Sprintf("OffsetUser%v", i)})
|
||||||
|
}
|
||||||
|
var users1, users2, users3, users4 []User
|
||||||
|
DB.Limit(100).Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4)
|
||||||
|
|
||||||
|
if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) {
|
||||||
|
t.Errorf("Offset should work")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOr(t *testing.T) {
|
||||||
|
user1 := User{Name: "OrUser1", Age: 1}
|
||||||
|
user2 := User{Name: "OrUser2", Age: 10}
|
||||||
|
user3 := User{Name: "OrUser3", Age: 20}
|
||||||
|
DB.Save(&user1).Save(&user2).Save(&user3)
|
||||||
|
|
||||||
|
var users []User
|
||||||
|
DB.Where("name = ?", user1.Name).Or("name = ?", user2.Name).Find(&users)
|
||||||
|
if len(users) != 2 {
|
||||||
|
t.Errorf("Find users with or")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCount(t *testing.T) {
|
||||||
|
user1 := User{Name: "CountUser1", Age: 1}
|
||||||
|
user2 := User{Name: "CountUser2", Age: 10}
|
||||||
|
user3 := User{Name: "CountUser3", Age: 20}
|
||||||
|
|
||||||
|
DB.Save(&user1).Save(&user2).Save(&user3)
|
||||||
|
var count, count1, count2 int64
|
||||||
|
var users []User
|
||||||
|
|
||||||
|
if err := DB.Where("name = ?", user1.Name).Or("name = ?", user3.Name).Find(&users).Count(&count).Error; err != nil {
|
||||||
|
t.Errorf(fmt.Sprintf("Count should work, but got err %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if count != int64(len(users)) {
|
||||||
|
t.Errorf("Count() method should get correct value")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&User{}).Where("name = ?", user1.Name).Count(&count1).Or("name in (?)", []string{user2.Name, user3.Name}).Count(&count2)
|
||||||
|
if count1 != 1 || count2 != 3 {
|
||||||
|
t.Errorf("Multiple count in chain")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNot(t *testing.T) {
|
||||||
|
DB.Create(getPreparedUser("user1", "not"))
|
||||||
|
DB.Create(getPreparedUser("user2", "not"))
|
||||||
|
DB.Create(getPreparedUser("user3", "not"))
|
||||||
|
|
||||||
|
user4 := getPreparedUser("user4", "not")
|
||||||
|
user4.Company = Company{}
|
||||||
|
DB.Create(user4)
|
||||||
|
|
||||||
|
DB := DB.Where("role = ?", "not")
|
||||||
|
|
||||||
|
var users1, users2, users3, users4, users5, users6, users7, users8, users9 []User
|
||||||
|
if DB.Find(&users1).RowsAffected != 4 {
|
||||||
|
t.Errorf("should find 4 not users")
|
||||||
|
}
|
||||||
|
DB.Not(users1[0].Id).Find(&users2)
|
||||||
|
|
||||||
|
if len(users1)-len(users2) != 1 {
|
||||||
|
t.Errorf("Should ignore the first users with Not")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Not([]int{}).Find(&users3)
|
||||||
|
if len(users1)-len(users3) != 0 {
|
||||||
|
t.Errorf("Should find all users with a blank condition")
|
||||||
|
}
|
||||||
|
|
||||||
|
var name3Count int64
|
||||||
|
DB.Table("users").Where("name = ?", "user3").Count(&name3Count)
|
||||||
|
DB.Not("name", "user3").Find(&users4)
|
||||||
|
if len(users1)-len(users4) != int(name3Count) {
|
||||||
|
t.Errorf("Should find all users's name not equal 3")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Not("name = ?", "user3").Find(&users4)
|
||||||
|
if len(users1)-len(users4) != int(name3Count) {
|
||||||
|
t.Errorf("Should find all users's name not equal 3")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Not("name <> ?", "user3").Find(&users4)
|
||||||
|
if len(users4) != int(name3Count) {
|
||||||
|
t.Errorf("Should find all users's name not equal 3")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Not(User{Name: "user3"}).Find(&users5)
|
||||||
|
|
||||||
|
if len(users1)-len(users5) != int(name3Count) {
|
||||||
|
t.Errorf("Should find all users's name not equal 3")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Not(map[string]interface{}{"name": "user3"}).Find(&users6)
|
||||||
|
if len(users1)-len(users6) != int(name3Count) {
|
||||||
|
t.Errorf("Should find all users's name not equal 3")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7)
|
||||||
|
if len(users1)-len(users7) != 2 { // not user3 or user4
|
||||||
|
t.Errorf("Should find all user's name not equal to 3 who do not have company id")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Not("name", []string{"user3"}).Find(&users8)
|
||||||
|
if len(users1)-len(users8) != int(name3Count) {
|
||||||
|
t.Errorf("Should find all users's name not equal 3")
|
||||||
|
}
|
||||||
|
|
||||||
|
var name2Count int64
|
||||||
|
DB.Table("users").Where("name = ?", "user2").Count(&name2Count)
|
||||||
|
DB.Not("name", []string{"user3", "user2"}).Find(&users9)
|
||||||
|
if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) {
|
||||||
|
t.Errorf("Should find all users's name not equal 3")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFillSmallerStruct(t *testing.T) {
|
||||||
|
user1 := User{Name: "SmallerUser", Age: 100}
|
||||||
|
DB.Save(&user1)
|
||||||
|
type SimpleUser struct {
|
||||||
|
Name string
|
||||||
|
Id int64
|
||||||
|
UpdatedAt time.Time
|
||||||
|
CreatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
var simpleUser SimpleUser
|
||||||
|
DB.Table("users").Where("name = ?", user1.Name).First(&simpleUser)
|
||||||
|
|
||||||
|
if simpleUser.Id == 0 || simpleUser.Name == "" {
|
||||||
|
t.Errorf("Should fill data correctly into smaller struct")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFindOrInitialize(t *testing.T) {
|
||||||
|
var user1, user2, user3, user4, user5, user6 User
|
||||||
|
DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1)
|
||||||
|
if user1.Name != "find or init" || user1.Id != 0 || user1.Age != 33 {
|
||||||
|
t.Errorf("user should be initialized with search value")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Where(User{Name: "find or init", Age: 33}).FirstOrInit(&user2)
|
||||||
|
if user2.Name != "find or init" || user2.Id != 0 || user2.Age != 33 {
|
||||||
|
t.Errorf("user should be initialized with search value")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.FirstOrInit(&user3, map[string]interface{}{"name": "find or init 2"})
|
||||||
|
if user3.Name != "find or init 2" || user3.Id != 0 {
|
||||||
|
t.Errorf("user should be initialized with inline search value")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Where(&User{Name: "find or init"}).Attrs(User{Age: 44}).FirstOrInit(&user4)
|
||||||
|
if user4.Name != "find or init" || user4.Id != 0 || user4.Age != 44 {
|
||||||
|
t.Errorf("user should be initialized with search value and attrs")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Where(&User{Name: "find or init"}).Assign("age", 44).FirstOrInit(&user4)
|
||||||
|
if user4.Name != "find or init" || user4.Id != 0 || user4.Age != 44 {
|
||||||
|
t.Errorf("user should be initialized with search value and assign attrs")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Save(&User{Name: "find or init", Age: 33})
|
||||||
|
DB.Where(&User{Name: "find or init"}).Attrs("age", 44).FirstOrInit(&user5)
|
||||||
|
if user5.Name != "find or init" || user5.Id == 0 || user5.Age != 33 {
|
||||||
|
t.Errorf("user should be found and not initialized by Attrs")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user6)
|
||||||
|
if user6.Name != "find or init" || user6.Id == 0 || user6.Age != 33 {
|
||||||
|
t.Errorf("user should be found with FirstOrInit")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Where(&User{Name: "find or init"}).Assign(User{Age: 44}).FirstOrInit(&user6)
|
||||||
|
if user6.Name != "find or init" || user6.Id == 0 || user6.Age != 44 {
|
||||||
|
t.Errorf("user should be found and updated with assigned attrs")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFindOrCreate(t *testing.T) {
|
||||||
|
var user1, user2, user3, user4, user5, user6, user7, user8 User
|
||||||
|
DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user1)
|
||||||
|
if user1.Name != "find or create" || user1.Id == 0 || user1.Age != 33 {
|
||||||
|
t.Errorf("user should be created with search value")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user2)
|
||||||
|
if user1.Id != user2.Id || user2.Name != "find or create" || user2.Id == 0 || user2.Age != 33 {
|
||||||
|
t.Errorf("user should be created with search value")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.FirstOrCreate(&user3, map[string]interface{}{"name": "find or create 2"})
|
||||||
|
if user3.Name != "find or create 2" || user3.Id == 0 {
|
||||||
|
t.Errorf("user should be created with inline search value")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Where(&User{Name: "find or create 3"}).Attrs("age", 44).FirstOrCreate(&user4)
|
||||||
|
if user4.Name != "find or create 3" || user4.Id == 0 || user4.Age != 44 {
|
||||||
|
t.Errorf("user should be created with search value and attrs")
|
||||||
|
}
|
||||||
|
|
||||||
|
updatedAt1 := user4.UpdatedAt
|
||||||
|
DB.Where(&User{Name: "find or create 3"}).Assign("age", 55).FirstOrCreate(&user4)
|
||||||
|
if updatedAt1.Format(time.RFC3339Nano) == user4.UpdatedAt.Format(time.RFC3339Nano) {
|
||||||
|
t.Errorf("UpdateAt should be changed when update values with assign")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Where(&User{Name: "find or create 4"}).Assign(User{Age: 44}).FirstOrCreate(&user4)
|
||||||
|
if user4.Name != "find or create 4" || user4.Id == 0 || user4.Age != 44 {
|
||||||
|
t.Errorf("user should be created with search value and assigned attrs")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Where(&User{Name: "find or create"}).Attrs("age", 44).FirstOrInit(&user5)
|
||||||
|
if user5.Name != "find or create" || user5.Id == 0 || user5.Age != 33 {
|
||||||
|
t.Errorf("user should be found and not initialized by Attrs")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Where(&User{Name: "find or create"}).Assign(User{Age: 44}).FirstOrCreate(&user6)
|
||||||
|
if user6.Name != "find or create" || user6.Id == 0 || user6.Age != 44 {
|
||||||
|
t.Errorf("user should be found and updated with assigned attrs")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Where(&User{Name: "find or create"}).Find(&user7)
|
||||||
|
if user7.Name != "find or create" || user7.Id == 0 || user7.Age != 44 {
|
||||||
|
t.Errorf("user should be found and updated with assigned attrs")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Where(&User{Name: "find or create embedded struct"}).Assign(User{Age: 44, CreditCard: CreditCard{Number: "1231231231"}, Emails: []Email{{Email: "jinzhu@assign_embedded_struct.com"}, {Email: "jinzhu-2@assign_embedded_struct.com"}}}).FirstOrCreate(&user8)
|
||||||
|
if DB.Where("email = ?", "jinzhu-2@assign_embedded_struct.com").First(&Email{}).RecordNotFound() {
|
||||||
|
t.Errorf("embedded struct email should be saved")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Where("email = ?", "1231231231").First(&CreditCard{}).RecordNotFound() {
|
||||||
|
t.Errorf("embedded struct credit card should be saved")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectWithEscapedFieldName(t *testing.T) {
|
||||||
|
user1 := User{Name: "EscapedFieldNameUser", Age: 1}
|
||||||
|
user2 := User{Name: "EscapedFieldNameUser", Age: 10}
|
||||||
|
user3 := User{Name: "EscapedFieldNameUser", Age: 20}
|
||||||
|
DB.Save(&user1).Save(&user2).Save(&user3)
|
||||||
|
|
||||||
|
var names []string
|
||||||
|
DB.Model(User{}).Where(&User{Name: "EscapedFieldNameUser"}).Pluck("\"name\"", &names)
|
||||||
|
|
||||||
|
if len(names) != 3 {
|
||||||
|
t.Errorf("Expected 3 name, but got: %d", len(names))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectWithVariables(t *testing.T) {
|
||||||
|
DB.Save(&User{Name: "jinzhu"})
|
||||||
|
|
||||||
|
rows, _ := DB.Table("users").Select("? as fake", gorm.Expr("name")).Rows()
|
||||||
|
|
||||||
|
if !rows.Next() {
|
||||||
|
t.Errorf("Should have returned at least one row")
|
||||||
|
} else {
|
||||||
|
columns, _ := rows.Columns()
|
||||||
|
if !reflect.DeepEqual(columns, []string{"fake"}) {
|
||||||
|
t.Errorf("Should only contains one column")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rows.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectWithArrayInput(t *testing.T) {
|
||||||
|
DB.Save(&User{Name: "jinzhu", Age: 42})
|
||||||
|
|
||||||
|
var user User
|
||||||
|
DB.Select([]string{"name", "age"}).Where("age = 42 AND name = 'jinzhu'").First(&user)
|
||||||
|
|
||||||
|
if user.Name != "jinzhu" || user.Age != 42 {
|
||||||
|
t.Errorf("Should have selected both age and name")
|
||||||
|
}
|
||||||
|
}
|
85
orm/scaner_test.go
Normal file
85
orm/scaner_test.go
Normal file
|
@ -0,0 +1,85 @@
|
||||||
|
package orm_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestScannableSlices(t *testing.T) {
|
||||||
|
if err := DB.AutoMigrate(&RecordWithSlice{}).Error; err != nil {
|
||||||
|
t.Errorf("Should create table with slice values correctly: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r1 := RecordWithSlice{
|
||||||
|
Strings: ExampleStringSlice{"a", "b", "c"},
|
||||||
|
Structs: ExampleStructSlice{
|
||||||
|
{"name1", "value1"},
|
||||||
|
{"name2", "value2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.Save(&r1).Error; err != nil {
|
||||||
|
t.Errorf("Should save record with slice values")
|
||||||
|
}
|
||||||
|
|
||||||
|
var r2 RecordWithSlice
|
||||||
|
|
||||||
|
if err := DB.Find(&r2).Error; err != nil {
|
||||||
|
t.Errorf("Should fetch record with slice values")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(r2.Strings) != 3 || r2.Strings[0] != "a" || r2.Strings[1] != "b" || r2.Strings[2] != "c" {
|
||||||
|
t.Errorf("Should have serialised and deserialised a string array")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(r2.Structs) != 2 || r2.Structs[0].Name != "name1" || r2.Structs[0].Value != "value1" || r2.Structs[1].Name != "name2" || r2.Structs[1].Value != "value2" {
|
||||||
|
t.Errorf("Should have serialised and deserialised a struct array")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type RecordWithSlice struct {
|
||||||
|
ID uint64
|
||||||
|
Strings ExampleStringSlice `sql:"type:text"`
|
||||||
|
Structs ExampleStructSlice `sql:"type:text"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ExampleStringSlice []string
|
||||||
|
|
||||||
|
func (l ExampleStringSlice) Value() (driver.Value, error) {
|
||||||
|
return json.Marshal(l)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *ExampleStringSlice) Scan(input interface{}) error {
|
||||||
|
switch value := input.(type) {
|
||||||
|
case string:
|
||||||
|
return json.Unmarshal([]byte(value), l)
|
||||||
|
case []byte:
|
||||||
|
return json.Unmarshal(value, l)
|
||||||
|
default:
|
||||||
|
return errors.New("not supported")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ExampleStruct struct {
|
||||||
|
Name string
|
||||||
|
Value string
|
||||||
|
}
|
||||||
|
|
||||||
|
type ExampleStructSlice []ExampleStruct
|
||||||
|
|
||||||
|
func (l ExampleStructSlice) Value() (driver.Value, error) {
|
||||||
|
return json.Marshal(l)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *ExampleStructSlice) Scan(input interface{}) error {
|
||||||
|
switch value := input.(type) {
|
||||||
|
case string:
|
||||||
|
return json.Unmarshal([]byte(value), l)
|
||||||
|
case []byte:
|
||||||
|
return json.Unmarshal(value, l)
|
||||||
|
default:
|
||||||
|
return errors.New("not supported")
|
||||||
|
}
|
||||||
|
}
|
1282
orm/scope.go
Normal file
1282
orm/scope.go
Normal file
File diff suppressed because it is too large
Load Diff
43
orm/scope_test.go
Normal file
43
orm/scope_test.go
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
package orm_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NameIn1And2(d *gorm.DB) *gorm.DB {
|
||||||
|
return d.Where("name in (?)", []string{"ScopeUser1", "ScopeUser2"})
|
||||||
|
}
|
||||||
|
|
||||||
|
func NameIn2And3(d *gorm.DB) *gorm.DB {
|
||||||
|
return d.Where("name in (?)", []string{"ScopeUser2", "ScopeUser3"})
|
||||||
|
}
|
||||||
|
|
||||||
|
func NameIn(names []string) func(d *gorm.DB) *gorm.DB {
|
||||||
|
return func(d *gorm.DB) *gorm.DB {
|
||||||
|
return d.Where("name in (?)", names)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScopes(t *testing.T) {
|
||||||
|
user1 := User{Name: "ScopeUser1", Age: 1}
|
||||||
|
user2 := User{Name: "ScopeUser2", Age: 1}
|
||||||
|
user3 := User{Name: "ScopeUser3", Age: 2}
|
||||||
|
DB.Save(&user1).Save(&user2).Save(&user3)
|
||||||
|
|
||||||
|
var users1, users2, users3 []User
|
||||||
|
DB.Scopes(NameIn1And2).Find(&users1)
|
||||||
|
if len(users1) != 2 {
|
||||||
|
t.Errorf("Should found two users's name in 1, 2")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Scopes(NameIn1And2, NameIn2And3).Find(&users2)
|
||||||
|
if len(users2) != 1 {
|
||||||
|
t.Errorf("Should found one user's name is 2")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Scopes(NameIn([]string{user1.Name, user3.Name})).Find(&users3)
|
||||||
|
if len(users3) != 2 {
|
||||||
|
t.Errorf("Should found two users's name in 1, 3")
|
||||||
|
}
|
||||||
|
}
|
147
orm/search.go
Normal file
147
orm/search.go
Normal file
|
@ -0,0 +1,147 @@
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
type search struct {
|
||||||
|
db *DB
|
||||||
|
whereConditions []map[string]interface{}
|
||||||
|
orConditions []map[string]interface{}
|
||||||
|
notConditions []map[string]interface{}
|
||||||
|
havingConditions []map[string]interface{}
|
||||||
|
joinConditions []map[string]interface{}
|
||||||
|
initAttrs []interface{}
|
||||||
|
assignAttrs []interface{}
|
||||||
|
selects map[string]interface{}
|
||||||
|
omits []string
|
||||||
|
orders []interface{}
|
||||||
|
preload []searchPreload
|
||||||
|
offset interface{}
|
||||||
|
limit interface{}
|
||||||
|
group string
|
||||||
|
tableName string
|
||||||
|
raw bool
|
||||||
|
Unscoped bool
|
||||||
|
countingQuery bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type searchPreload struct {
|
||||||
|
schema string
|
||||||
|
conditions []interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *search) clone() *search {
|
||||||
|
clone := *s
|
||||||
|
return &clone
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *search) Where(query interface{}, values ...interface{}) *search {
|
||||||
|
s.whereConditions = append(s.whereConditions, map[string]interface{}{"query": query, "args": values})
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *search) Not(query interface{}, values ...interface{}) *search {
|
||||||
|
s.notConditions = append(s.notConditions, map[string]interface{}{"query": query, "args": values})
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *search) Or(query interface{}, values ...interface{}) *search {
|
||||||
|
s.orConditions = append(s.orConditions, map[string]interface{}{"query": query, "args": values})
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *search) Attrs(attrs ...interface{}) *search {
|
||||||
|
s.initAttrs = append(s.initAttrs, toSearchableMap(attrs...))
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *search) Assign(attrs ...interface{}) *search {
|
||||||
|
s.assignAttrs = append(s.assignAttrs, toSearchableMap(attrs...))
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *search) Order(value interface{}, reorder ...bool) *search {
|
||||||
|
if len(reorder) > 0 && reorder[0] {
|
||||||
|
s.orders = []interface{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if value != nil {
|
||||||
|
s.orders = append(s.orders, value)
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *search) Select(query interface{}, args ...interface{}) *search {
|
||||||
|
s.selects = map[string]interface{}{"query": query, "args": args}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *search) Omit(columns ...string) *search {
|
||||||
|
s.omits = columns
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *search) Limit(limit interface{}) *search {
|
||||||
|
s.limit = limit
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *search) Offset(offset interface{}) *search {
|
||||||
|
s.offset = offset
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *search) Group(query string) *search {
|
||||||
|
s.group = s.getInterfaceAsSQL(query)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *search) Having(query string, values ...interface{}) *search {
|
||||||
|
s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values})
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *search) Joins(query string, values ...interface{}) *search {
|
||||||
|
s.joinConditions = append(s.joinConditions, map[string]interface{}{"query": query, "args": values})
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *search) Preload(schema string, values ...interface{}) *search {
|
||||||
|
var preloads []searchPreload
|
||||||
|
for _, preload := range s.preload {
|
||||||
|
if preload.schema != schema {
|
||||||
|
preloads = append(preloads, preload)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
preloads = append(preloads, searchPreload{schema, values})
|
||||||
|
s.preload = preloads
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *search) Raw(b bool) *search {
|
||||||
|
s.raw = b
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *search) unscoped() *search {
|
||||||
|
s.Unscoped = true
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *search) Table(name string) *search {
|
||||||
|
s.tableName = name
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *search) getInterfaceAsSQL(value interface{}) (str string) {
|
||||||
|
switch value.(type) {
|
||||||
|
case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||||
|
str = fmt.Sprintf("%v", value)
|
||||||
|
default:
|
||||||
|
s.db.AddError(ErrInvalidSQL)
|
||||||
|
}
|
||||||
|
|
||||||
|
if str == "-1" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
30
orm/search_test.go
Normal file
30
orm/search_test.go
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCloneSearch(t *testing.T) {
|
||||||
|
s := new(search)
|
||||||
|
s.Where("name = ?", "jinzhu").Order("name").Attrs("name", "jinzhu").Select("name, age")
|
||||||
|
|
||||||
|
s1 := s.clone()
|
||||||
|
s1.Where("age = ?", 20).Order("age").Attrs("email", "a@e.org").Select("email")
|
||||||
|
|
||||||
|
if reflect.DeepEqual(s.whereConditions, s1.whereConditions) {
|
||||||
|
t.Errorf("Where should be copied")
|
||||||
|
}
|
||||||
|
|
||||||
|
if reflect.DeepEqual(s.orders, s1.orders) {
|
||||||
|
t.Errorf("Order should be copied")
|
||||||
|
}
|
||||||
|
|
||||||
|
if reflect.DeepEqual(s.initAttrs, s1.initAttrs) {
|
||||||
|
t.Errorf("InitAttrs should be copied")
|
||||||
|
}
|
||||||
|
|
||||||
|
if reflect.DeepEqual(s.Select, s1.Select) {
|
||||||
|
t.Errorf("selectStr should be copied")
|
||||||
|
}
|
||||||
|
}
|
5
orm/test_all.sh
Executable file
5
orm/test_all.sh
Executable file
|
@ -0,0 +1,5 @@
|
||||||
|
dialects=("postgres" "mysql" "sqlite")
|
||||||
|
|
||||||
|
for dialect in "${dialects[@]}" ; do
|
||||||
|
GORM_DIALECT=${dialect} go test
|
||||||
|
done
|
465
orm/update_test.go
Normal file
465
orm/update_test.go
Normal file
|
@ -0,0 +1,465 @@
|
||||||
|
package orm_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUpdate(t *testing.T) {
|
||||||
|
product1 := Product{Code: "product1code"}
|
||||||
|
product2 := Product{Code: "product2code"}
|
||||||
|
|
||||||
|
DB.Save(&product1).Save(&product2).Update("code", "product2newcode")
|
||||||
|
|
||||||
|
if product2.Code != "product2newcode" {
|
||||||
|
t.Errorf("Record should be updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.First(&product1, product1.Id)
|
||||||
|
DB.First(&product2, product2.Id)
|
||||||
|
updatedAt1 := product1.UpdatedAt
|
||||||
|
|
||||||
|
if DB.First(&Product{}, "code = ?", product1.Code).RecordNotFound() {
|
||||||
|
t.Errorf("Product1 should not be updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !DB.First(&Product{}, "code = ?", "product2code").RecordNotFound() {
|
||||||
|
t.Errorf("Product2's code should be updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.First(&Product{}, "code = ?", "product2newcode").RecordNotFound() {
|
||||||
|
t.Errorf("Product2's code should be updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Table("products").Where("code in (?)", []string{"product1code"}).Update("code", "product1newcode")
|
||||||
|
|
||||||
|
var product4 Product
|
||||||
|
DB.First(&product4, product1.Id)
|
||||||
|
if updatedAt1.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) {
|
||||||
|
t.Errorf("updatedAt should be updated if something changed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !DB.First(&Product{}, "code = 'product1code'").RecordNotFound() {
|
||||||
|
t.Errorf("Product1's code should be updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.First(&Product{}, "code = 'product1newcode'").RecordNotFound() {
|
||||||
|
t.Errorf("Product should not be changed to 789")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(product2).Update("CreatedAt", time.Now().Add(time.Hour)).Error != nil {
|
||||||
|
t.Error("No error should raise when update with CamelCase")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.Model(&product2).UpdateColumn("CreatedAt", time.Now().Add(time.Hour)).Error != nil {
|
||||||
|
t.Error("No error should raise when update_column with CamelCase")
|
||||||
|
}
|
||||||
|
|
||||||
|
var products []Product
|
||||||
|
DB.Find(&products)
|
||||||
|
if count := DB.Model(Product{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(products)) {
|
||||||
|
t.Error("RowsAffected should be correct when do batch update")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.First(&product4, product4.Id)
|
||||||
|
updatedAt4 := product4.UpdatedAt
|
||||||
|
DB.Model(&product4).Update("price", gorm.Expr("price + ? - ?", 100, 50))
|
||||||
|
var product5 Product
|
||||||
|
DB.First(&product5, product4.Id)
|
||||||
|
if product5.Price != product4.Price+100-50 {
|
||||||
|
t.Errorf("Update with expression")
|
||||||
|
}
|
||||||
|
if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) {
|
||||||
|
t.Errorf("Update with expression should update UpdatedAt")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) {
|
||||||
|
animal := Animal{Name: "Ferdinand"}
|
||||||
|
DB.Save(&animal)
|
||||||
|
updatedAt1 := animal.UpdatedAt
|
||||||
|
|
||||||
|
DB.Save(&animal).Update("name", "Francis")
|
||||||
|
|
||||||
|
if updatedAt1.Format(time.RFC3339Nano) == animal.UpdatedAt.Format(time.RFC3339Nano) {
|
||||||
|
t.Errorf("updatedAt should not be updated if nothing changed")
|
||||||
|
}
|
||||||
|
|
||||||
|
var animals []Animal
|
||||||
|
DB.Find(&animals)
|
||||||
|
if count := DB.Model(Animal{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(animals)) {
|
||||||
|
t.Error("RowsAffected should be correct when do batch update")
|
||||||
|
}
|
||||||
|
|
||||||
|
animal = Animal{From: "somewhere"} // No name fields, should be filled with the default value (galeone)
|
||||||
|
DB.Save(&animal).Update("From", "a nice place") // The name field shoul be untouched
|
||||||
|
DB.First(&animal, animal.Counter)
|
||||||
|
if animal.Name != "galeone" {
|
||||||
|
t.Errorf("Name fiels shouldn't be changed if untouched, but got %v", animal.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// When changing a field with a default value, the change must occur
|
||||||
|
animal.Name = "amazing horse"
|
||||||
|
DB.Save(&animal)
|
||||||
|
DB.First(&animal, animal.Counter)
|
||||||
|
if animal.Name != "amazing horse" {
|
||||||
|
t.Errorf("Update a filed with a default value should occur. But got %v\n", animal.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// When changing a field with a default value with blank value
|
||||||
|
animal.Name = ""
|
||||||
|
DB.Save(&animal)
|
||||||
|
DB.First(&animal, animal.Counter)
|
||||||
|
if animal.Name != "" {
|
||||||
|
t.Errorf("Update a filed to blank with a default value should occur. But got %v\n", animal.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdates(t *testing.T) {
|
||||||
|
product1 := Product{Code: "product1code", Price: 10}
|
||||||
|
product2 := Product{Code: "product2code", Price: 10}
|
||||||
|
DB.Save(&product1).Save(&product2)
|
||||||
|
DB.Model(&product1).Updates(map[string]interface{}{"code": "product1newcode", "price": 100})
|
||||||
|
if product1.Code != "product1newcode" || product1.Price != 100 {
|
||||||
|
t.Errorf("Record should be updated also with map")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.First(&product1, product1.Id)
|
||||||
|
DB.First(&product2, product2.Id)
|
||||||
|
updatedAt2 := product2.UpdatedAt
|
||||||
|
|
||||||
|
if DB.First(&Product{}, "code = ? and price = ?", product2.Code, product2.Price).RecordNotFound() {
|
||||||
|
t.Errorf("Product2 should not be updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.First(&Product{}, "code = ?", "product1newcode").RecordNotFound() {
|
||||||
|
t.Errorf("Product1 should be updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Table("products").Where("code in (?)", []string{"product2code"}).Updates(Product{Code: "product2newcode"})
|
||||||
|
if !DB.First(&Product{}, "code = 'product2code'").RecordNotFound() {
|
||||||
|
t.Errorf("Product2's code should be updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
var product4 Product
|
||||||
|
DB.First(&product4, product2.Id)
|
||||||
|
if updatedAt2.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) {
|
||||||
|
t.Errorf("updatedAt should be updated if something changed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if DB.First(&Product{}, "code = ?", "product2newcode").RecordNotFound() {
|
||||||
|
t.Errorf("product2's code should be updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
updatedAt4 := product4.UpdatedAt
|
||||||
|
DB.Model(&product4).Updates(map[string]interface{}{"price": gorm.Expr("price + ?", 100)})
|
||||||
|
var product5 Product
|
||||||
|
DB.First(&product5, product4.Id)
|
||||||
|
if product5.Price != product4.Price+100 {
|
||||||
|
t.Errorf("Updates with expression")
|
||||||
|
}
|
||||||
|
// product4's UpdatedAt will be reset when updating
|
||||||
|
if product4.UpdatedAt.Format(time.RFC3339Nano) == updatedAt4.Format(time.RFC3339Nano) {
|
||||||
|
t.Errorf("Updates with expression should update UpdatedAt")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateColumn(t *testing.T) {
|
||||||
|
product1 := Product{Code: "product1code", Price: 10}
|
||||||
|
product2 := Product{Code: "product2code", Price: 20}
|
||||||
|
DB.Save(&product1).Save(&product2).UpdateColumn(map[string]interface{}{"code": "product2newcode", "price": 100})
|
||||||
|
if product2.Code != "product2newcode" || product2.Price != 100 {
|
||||||
|
t.Errorf("product 2 should be updated with update column")
|
||||||
|
}
|
||||||
|
|
||||||
|
var product3 Product
|
||||||
|
DB.First(&product3, product1.Id)
|
||||||
|
if product3.Code != "product1code" || product3.Price != 10 {
|
||||||
|
t.Errorf("product 1 should not be updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.First(&product2, product2.Id)
|
||||||
|
updatedAt2 := product2.UpdatedAt
|
||||||
|
DB.Model(product2).UpdateColumn("code", "update_column_new")
|
||||||
|
var product4 Product
|
||||||
|
DB.First(&product4, product2.Id)
|
||||||
|
if updatedAt2.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) {
|
||||||
|
t.Errorf("updatedAt should not be updated with update column")
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Model(&product4).UpdateColumn("price", gorm.Expr("price + 100 - 50"))
|
||||||
|
var product5 Product
|
||||||
|
DB.First(&product5, product4.Id)
|
||||||
|
if product5.Price != product4.Price+100-50 {
|
||||||
|
t.Errorf("UpdateColumn with expression")
|
||||||
|
}
|
||||||
|
if product5.UpdatedAt.Format(time.RFC3339Nano) != product4.UpdatedAt.Format(time.RFC3339Nano) {
|
||||||
|
t.Errorf("UpdateColumn with expression should not update UpdatedAt")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectWithUpdate(t *testing.T) {
|
||||||
|
user := getPreparedUser("select_user", "select_with_update")
|
||||||
|
DB.Create(user)
|
||||||
|
|
||||||
|
var reloadUser User
|
||||||
|
DB.First(&reloadUser, user.Id)
|
||||||
|
reloadUser.Name = "new_name"
|
||||||
|
reloadUser.Age = 50
|
||||||
|
reloadUser.BillingAddress = Address{Address1: "New Billing Address"}
|
||||||
|
reloadUser.ShippingAddress = Address{Address1: "New ShippingAddress Address"}
|
||||||
|
reloadUser.CreditCard = CreditCard{Number: "987654321"}
|
||||||
|
reloadUser.Emails = []Email{
|
||||||
|
{Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"},
|
||||||
|
}
|
||||||
|
reloadUser.Company = Company{Name: "new company"}
|
||||||
|
|
||||||
|
DB.Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Save(&reloadUser)
|
||||||
|
|
||||||
|
var queryUser User
|
||||||
|
DB.Preload("BillingAddress").Preload("ShippingAddress").
|
||||||
|
Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id)
|
||||||
|
|
||||||
|
if queryUser.Name == user.Name || queryUser.Age != user.Age {
|
||||||
|
t.Errorf("Should only update users with name column")
|
||||||
|
}
|
||||||
|
|
||||||
|
if queryUser.BillingAddressID.Int64 == user.BillingAddressID.Int64 ||
|
||||||
|
queryUser.ShippingAddressId != user.ShippingAddressId ||
|
||||||
|
queryUser.CreditCard.ID == user.CreditCard.ID ||
|
||||||
|
len(queryUser.Emails) == len(user.Emails) || queryUser.Company.Id == user.Company.Id {
|
||||||
|
t.Errorf("Should only update selected relationships")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectWithUpdateWithMap(t *testing.T) {
|
||||||
|
user := getPreparedUser("select_user", "select_with_update_map")
|
||||||
|
DB.Create(user)
|
||||||
|
|
||||||
|
updateValues := map[string]interface{}{
|
||||||
|
"Name": "new_name",
|
||||||
|
"Age": 50,
|
||||||
|
"BillingAddress": Address{Address1: "New Billing Address"},
|
||||||
|
"ShippingAddress": Address{Address1: "New ShippingAddress Address"},
|
||||||
|
"CreditCard": CreditCard{Number: "987654321"},
|
||||||
|
"Emails": []Email{
|
||||||
|
{Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"},
|
||||||
|
},
|
||||||
|
"Company": Company{Name: "new company"},
|
||||||
|
}
|
||||||
|
|
||||||
|
var reloadUser User
|
||||||
|
DB.First(&reloadUser, user.Id)
|
||||||
|
DB.Model(&reloadUser).Select("Name", "BillingAddress", "CreditCard", "Company", "Emails").Update(updateValues)
|
||||||
|
|
||||||
|
var queryUser User
|
||||||
|
DB.Preload("BillingAddress").Preload("ShippingAddress").
|
||||||
|
Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id)
|
||||||
|
|
||||||
|
if queryUser.Name == user.Name || queryUser.Age != user.Age {
|
||||||
|
t.Errorf("Should only update users with name column")
|
||||||
|
}
|
||||||
|
|
||||||
|
if queryUser.BillingAddressID.Int64 == user.BillingAddressID.Int64 ||
|
||||||
|
queryUser.ShippingAddressId != user.ShippingAddressId ||
|
||||||
|
queryUser.CreditCard.ID == user.CreditCard.ID ||
|
||||||
|
len(queryUser.Emails) == len(user.Emails) || queryUser.Company.Id == user.Company.Id {
|
||||||
|
t.Errorf("Should only update selected relationships")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOmitWithUpdate(t *testing.T) {
|
||||||
|
user := getPreparedUser("omit_user", "omit_with_update")
|
||||||
|
DB.Create(user)
|
||||||
|
|
||||||
|
var reloadUser User
|
||||||
|
DB.First(&reloadUser, user.Id)
|
||||||
|
reloadUser.Name = "new_name"
|
||||||
|
reloadUser.Age = 50
|
||||||
|
reloadUser.BillingAddress = Address{Address1: "New Billing Address"}
|
||||||
|
reloadUser.ShippingAddress = Address{Address1: "New ShippingAddress Address"}
|
||||||
|
reloadUser.CreditCard = CreditCard{Number: "987654321"}
|
||||||
|
reloadUser.Emails = []Email{
|
||||||
|
{Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"},
|
||||||
|
}
|
||||||
|
reloadUser.Company = Company{Name: "new company"}
|
||||||
|
|
||||||
|
DB.Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Save(&reloadUser)
|
||||||
|
|
||||||
|
var queryUser User
|
||||||
|
DB.Preload("BillingAddress").Preload("ShippingAddress").
|
||||||
|
Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id)
|
||||||
|
|
||||||
|
if queryUser.Name != user.Name || queryUser.Age == user.Age {
|
||||||
|
t.Errorf("Should only update users with name column")
|
||||||
|
}
|
||||||
|
|
||||||
|
if queryUser.BillingAddressID.Int64 != user.BillingAddressID.Int64 ||
|
||||||
|
queryUser.ShippingAddressId == user.ShippingAddressId ||
|
||||||
|
queryUser.CreditCard.ID != user.CreditCard.ID ||
|
||||||
|
len(queryUser.Emails) != len(user.Emails) || queryUser.Company.Id != user.Company.Id {
|
||||||
|
t.Errorf("Should only update relationships that not omited")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOmitWithUpdateWithMap(t *testing.T) {
|
||||||
|
user := getPreparedUser("select_user", "select_with_update_map")
|
||||||
|
DB.Create(user)
|
||||||
|
|
||||||
|
updateValues := map[string]interface{}{
|
||||||
|
"Name": "new_name",
|
||||||
|
"Age": 50,
|
||||||
|
"BillingAddress": Address{Address1: "New Billing Address"},
|
||||||
|
"ShippingAddress": Address{Address1: "New ShippingAddress Address"},
|
||||||
|
"CreditCard": CreditCard{Number: "987654321"},
|
||||||
|
"Emails": []Email{
|
||||||
|
{Email: "new_user_1@example1.com"}, {Email: "new_user_2@example2.com"}, {Email: "new_user_3@example2.com"},
|
||||||
|
},
|
||||||
|
"Company": Company{Name: "new company"},
|
||||||
|
}
|
||||||
|
|
||||||
|
var reloadUser User
|
||||||
|
DB.First(&reloadUser, user.Id)
|
||||||
|
DB.Model(&reloadUser).Omit("Name", "BillingAddress", "CreditCard", "Company", "Emails").Update(updateValues)
|
||||||
|
|
||||||
|
var queryUser User
|
||||||
|
DB.Preload("BillingAddress").Preload("ShippingAddress").
|
||||||
|
Preload("CreditCard").Preload("Emails").Preload("Company").First(&queryUser, user.Id)
|
||||||
|
|
||||||
|
if queryUser.Name != user.Name || queryUser.Age == user.Age {
|
||||||
|
t.Errorf("Should only update users with name column")
|
||||||
|
}
|
||||||
|
|
||||||
|
if queryUser.BillingAddressID.Int64 != user.BillingAddressID.Int64 ||
|
||||||
|
queryUser.ShippingAddressId == user.ShippingAddressId ||
|
||||||
|
queryUser.CreditCard.ID != user.CreditCard.ID ||
|
||||||
|
len(queryUser.Emails) != len(user.Emails) || queryUser.Company.Id != user.Company.Id {
|
||||||
|
t.Errorf("Should only update relationships not omited")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectWithUpdateColumn(t *testing.T) {
|
||||||
|
user := getPreparedUser("select_user", "select_with_update_map")
|
||||||
|
DB.Create(user)
|
||||||
|
|
||||||
|
updateValues := map[string]interface{}{"Name": "new_name", "Age": 50}
|
||||||
|
|
||||||
|
var reloadUser User
|
||||||
|
DB.First(&reloadUser, user.Id)
|
||||||
|
DB.Model(&reloadUser).Select("Name").UpdateColumn(updateValues)
|
||||||
|
|
||||||
|
var queryUser User
|
||||||
|
DB.First(&queryUser, user.Id)
|
||||||
|
|
||||||
|
if queryUser.Name == user.Name || queryUser.Age != user.Age {
|
||||||
|
t.Errorf("Should only update users with name column")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOmitWithUpdateColumn(t *testing.T) {
|
||||||
|
user := getPreparedUser("select_user", "select_with_update_map")
|
||||||
|
DB.Create(user)
|
||||||
|
|
||||||
|
updateValues := map[string]interface{}{"Name": "new_name", "Age": 50}
|
||||||
|
|
||||||
|
var reloadUser User
|
||||||
|
DB.First(&reloadUser, user.Id)
|
||||||
|
DB.Model(&reloadUser).Omit("Name").UpdateColumn(updateValues)
|
||||||
|
|
||||||
|
var queryUser User
|
||||||
|
DB.First(&queryUser, user.Id)
|
||||||
|
|
||||||
|
if queryUser.Name != user.Name || queryUser.Age == user.Age {
|
||||||
|
t.Errorf("Should omit name column when update user")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateColumnsSkipsAssociations(t *testing.T) {
|
||||||
|
user := getPreparedUser("update_columns_user", "special_role")
|
||||||
|
user.Age = 99
|
||||||
|
address1 := "first street"
|
||||||
|
user.BillingAddress = Address{Address1: address1}
|
||||||
|
DB.Save(user)
|
||||||
|
|
||||||
|
// Update a single field of the user and verify that the changed address is not stored.
|
||||||
|
newAge := int64(100)
|
||||||
|
user.BillingAddress.Address1 = "second street"
|
||||||
|
db := DB.Model(user).UpdateColumns(User{Age: newAge})
|
||||||
|
if db.RowsAffected != 1 {
|
||||||
|
t.Errorf("Expected RowsAffected=1 but instead RowsAffected=%v", DB.RowsAffected)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that Age now=`newAge`.
|
||||||
|
freshUser := &User{Id: user.Id}
|
||||||
|
DB.First(freshUser)
|
||||||
|
if freshUser.Age != newAge {
|
||||||
|
t.Errorf("Expected freshly queried user to have Age=%v but instead found Age=%v", newAge, freshUser.Age)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that user's BillingAddress.Address1 is not changed and is still "first street".
|
||||||
|
DB.First(&freshUser.BillingAddress, freshUser.BillingAddressID)
|
||||||
|
if freshUser.BillingAddress.Address1 != address1 {
|
||||||
|
t.Errorf("Expected user's BillingAddress.Address1=%s to remain unchanged after UpdateColumns invocation, but BillingAddress.Address1=%s", address1, freshUser.BillingAddress.Address1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdatesWithBlankValues(t *testing.T) {
|
||||||
|
product := Product{Code: "product1", Price: 10}
|
||||||
|
DB.Save(&product)
|
||||||
|
|
||||||
|
DB.Model(&Product{Id: product.Id}).Updates(&Product{Price: 100})
|
||||||
|
|
||||||
|
var product1 Product
|
||||||
|
DB.First(&product1, product.Id)
|
||||||
|
|
||||||
|
if product1.Code != "product1" || product1.Price != 100 {
|
||||||
|
t.Errorf("product's code should not be updated")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ElementWithIgnoredField struct {
|
||||||
|
Id int64
|
||||||
|
Value string
|
||||||
|
IgnoredField int64 `sql:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e ElementWithIgnoredField) TableName() string {
|
||||||
|
return "element_with_ignored_field"
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdatesTableWithIgnoredValues(t *testing.T) {
|
||||||
|
elem := ElementWithIgnoredField{Value: "foo", IgnoredField: 10}
|
||||||
|
DB.Save(&elem)
|
||||||
|
|
||||||
|
DB.Table(elem.TableName()).
|
||||||
|
Where("id = ?", elem.Id).
|
||||||
|
// DB.Model(&ElementWithIgnoredField{Id: elem.Id}).
|
||||||
|
Updates(&ElementWithIgnoredField{Value: "bar", IgnoredField: 100})
|
||||||
|
|
||||||
|
var elem1 ElementWithIgnoredField
|
||||||
|
err := DB.First(&elem1, elem.Id).Error
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("error getting an element from database: %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
if elem1.IgnoredField != 0 {
|
||||||
|
t.Errorf("element's ignored field should not be updated")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateDecodeVirtualAttributes(t *testing.T) {
|
||||||
|
var user = User{
|
||||||
|
Name: "jinzhu",
|
||||||
|
IgnoreMe: 88,
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Save(&user)
|
||||||
|
|
||||||
|
DB.Model(&user).Updates(User{Name: "jinzhu2", IgnoreMe: 100})
|
||||||
|
|
||||||
|
if user.IgnoreMe != 100 {
|
||||||
|
t.Errorf("should decode virtual attributes to struct, so it could be used in callbacks")
|
||||||
|
}
|
||||||
|
}
|
264
orm/utils.go
Normal file
264
orm/utils.go
Normal file
|
@ -0,0 +1,264 @@
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"database/sql/driver"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"regexp"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NowFunc returns current time, this function is exported in order to be able
|
||||||
|
// to give the flexibility to the developer to customize it according to their
|
||||||
|
// needs, e.g:
|
||||||
|
// gorm.NowFunc = func() time.Time {
|
||||||
|
// return time.Now().UTC()
|
||||||
|
// }
|
||||||
|
var NowFunc = func() time.Time {
|
||||||
|
return time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copied from golint
|
||||||
|
var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
|
||||||
|
var commonInitialismsReplacer *strings.Replacer
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
var commonInitialismsForReplacer []string
|
||||||
|
for _, initialism := range commonInitialisms {
|
||||||
|
commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism)))
|
||||||
|
}
|
||||||
|
commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...)
|
||||||
|
}
|
||||||
|
|
||||||
|
type safeMap struct {
|
||||||
|
m map[string]string
|
||||||
|
l *sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *safeMap) Set(key string, value string) {
|
||||||
|
s.l.Lock()
|
||||||
|
defer s.l.Unlock()
|
||||||
|
s.m[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *safeMap) Get(key string) string {
|
||||||
|
s.l.RLock()
|
||||||
|
defer s.l.RUnlock()
|
||||||
|
return s.m[key]
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSafeMap() *safeMap {
|
||||||
|
return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)}
|
||||||
|
}
|
||||||
|
|
||||||
|
var smap = newSafeMap()
|
||||||
|
|
||||||
|
type strCase bool
|
||||||
|
|
||||||
|
const (
|
||||||
|
lower strCase = false
|
||||||
|
upper strCase = true
|
||||||
|
)
|
||||||
|
|
||||||
|
// ToDBName convert string to db name
|
||||||
|
func ToDBName(name string) string {
|
||||||
|
if v := smap.Get(name); v != "" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
if name == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
value = commonInitialismsReplacer.Replace(name)
|
||||||
|
buf = bytes.NewBufferString("")
|
||||||
|
lastCase, currCase, nextCase strCase
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, v := range value[:len(value)-1] {
|
||||||
|
nextCase = strCase(value[i+1] >= 'A' && value[i+1] <= 'Z')
|
||||||
|
if i > 0 {
|
||||||
|
if currCase == upper {
|
||||||
|
if lastCase == upper && nextCase == upper {
|
||||||
|
buf.WriteRune(v)
|
||||||
|
} else {
|
||||||
|
if value[i-1] != '_' && value[i+1] != '_' {
|
||||||
|
buf.WriteRune('_')
|
||||||
|
}
|
||||||
|
buf.WriteRune(v)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
buf.WriteRune(v)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
currCase = upper
|
||||||
|
buf.WriteRune(v)
|
||||||
|
}
|
||||||
|
lastCase = currCase
|
||||||
|
currCase = nextCase
|
||||||
|
}
|
||||||
|
|
||||||
|
buf.WriteByte(value[len(value)-1])
|
||||||
|
|
||||||
|
s := strings.ToLower(buf.String())
|
||||||
|
smap.Set(name, s)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// SQL expression
|
||||||
|
type expr struct {
|
||||||
|
expr string
|
||||||
|
args []interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expr generate raw SQL expression, for example:
|
||||||
|
// DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100))
|
||||||
|
func Expr(expression string, args ...interface{}) *expr {
|
||||||
|
return &expr{expr: expression, args: args}
|
||||||
|
}
|
||||||
|
|
||||||
|
func indirect(reflectValue reflect.Value) reflect.Value {
|
||||||
|
for reflectValue.Kind() == reflect.Ptr {
|
||||||
|
reflectValue = reflectValue.Elem()
|
||||||
|
}
|
||||||
|
return reflectValue
|
||||||
|
}
|
||||||
|
|
||||||
|
func toQueryMarks(primaryValues [][]interface{}) string {
|
||||||
|
var results []string
|
||||||
|
|
||||||
|
for _, primaryValue := range primaryValues {
|
||||||
|
var marks []string
|
||||||
|
for _,_ = range primaryValue {
|
||||||
|
marks = append(marks, "?")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(marks) > 1 {
|
||||||
|
results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ",")))
|
||||||
|
} else {
|
||||||
|
results = append(results, strings.Join(marks, ""))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.Join(results, ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
func toQueryCondition(scope *Scope, columns []string) string {
|
||||||
|
var newColumns []string
|
||||||
|
for _, column := range columns {
|
||||||
|
newColumns = append(newColumns, scope.Quote(column))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(columns) > 1 {
|
||||||
|
return fmt.Sprintf("(%v)", strings.Join(newColumns, ","))
|
||||||
|
}
|
||||||
|
return strings.Join(newColumns, ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
func toQueryValues(values [][]interface{}) (results []interface{}) {
|
||||||
|
for _, value := range values {
|
||||||
|
for _, v := range value {
|
||||||
|
results = append(results, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileWithLineNum() string {
|
||||||
|
for i := 2; i < 15; i++ {
|
||||||
|
_, file, line, ok := runtime.Caller(i)
|
||||||
|
if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) {
|
||||||
|
return fmt.Sprintf("%v:%v", file, line)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func isBlank(value reflect.Value) bool {
|
||||||
|
return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface())
|
||||||
|
}
|
||||||
|
|
||||||
|
func toSearchableMap(attrs ...interface{}) (result interface{}) {
|
||||||
|
if len(attrs) > 1 {
|
||||||
|
if str, ok := attrs[0].(string); ok {
|
||||||
|
result = map[string]interface{}{str: attrs[1]}
|
||||||
|
}
|
||||||
|
} else if len(attrs) == 1 {
|
||||||
|
if attr, ok := attrs[0].(map[string]interface{}); ok {
|
||||||
|
result = attr
|
||||||
|
}
|
||||||
|
|
||||||
|
if attr, ok := attrs[0].(interface{}); ok {
|
||||||
|
result = attr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func equalAsString(a interface{}, b interface{}) bool {
|
||||||
|
return toString(a) == toString(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func toString(str interface{}) string {
|
||||||
|
if values, ok := str.([]interface{}); ok {
|
||||||
|
var results []string
|
||||||
|
for _, value := range values {
|
||||||
|
results = append(results, toString(value))
|
||||||
|
}
|
||||||
|
return strings.Join(results, "_")
|
||||||
|
} else if bytes, ok := str.([]byte); ok {
|
||||||
|
return string(bytes)
|
||||||
|
} else if reflectValue := reflect.Indirect(reflect.ValueOf(str)); reflectValue.IsValid() {
|
||||||
|
return fmt.Sprintf("%v", reflectValue.Interface())
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeSlice(elemType reflect.Type) interface{} {
|
||||||
|
if elemType.Kind() == reflect.Slice {
|
||||||
|
elemType = elemType.Elem()
|
||||||
|
}
|
||||||
|
sliceType := reflect.SliceOf(elemType)
|
||||||
|
slice := reflect.New(sliceType)
|
||||||
|
slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0))
|
||||||
|
return slice.Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
func strInSlice(a string, list []string) bool {
|
||||||
|
for _, b := range list {
|
||||||
|
if b == a {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// getValueFromFields return given fields's value
|
||||||
|
func getValueFromFields(value reflect.Value, fieldNames []string) (results []interface{}) {
|
||||||
|
// If value is a nil pointer, Indirect returns a zero Value!
|
||||||
|
// Therefor we need to check for a zero value,
|
||||||
|
// as FieldByName could panic
|
||||||
|
if indirectValue := reflect.Indirect(value); indirectValue.IsValid() {
|
||||||
|
for _, fieldName := range fieldNames {
|
||||||
|
if fieldValue := indirectValue.FieldByName(fieldName); fieldValue.IsValid() {
|
||||||
|
result := fieldValue.Interface()
|
||||||
|
if r, ok := result.(driver.Valuer); ok {
|
||||||
|
result, _ = r.Value()
|
||||||
|
}
|
||||||
|
results = append(results, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func addExtraSpaceIfExist(str string) string {
|
||||||
|
if str != "" {
|
||||||
|
return " " + str
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
30
orm/utils_test.go
Normal file
30
orm/utils_test.go
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
package orm_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jinzhu/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestToDBNameGenerateFriendlyName(t *testing.T) {
|
||||||
|
var maps = map[string]string{
|
||||||
|
"": "",
|
||||||
|
"ThisIsATest": "this_is_a_test",
|
||||||
|
"PFAndESI": "pf_and_esi",
|
||||||
|
"AbcAndJkl": "abc_and_jkl",
|
||||||
|
"EmployeeID": "employee_id",
|
||||||
|
"SKU_ID": "sku_id",
|
||||||
|
"HTTPAndSMTP": "http_and_smtp",
|
||||||
|
"HTTPServerHandlerForURLID": "http_server_handler_for_url_id",
|
||||||
|
"UUID": "uuid",
|
||||||
|
"HTTPURL": "http_url",
|
||||||
|
"HTTP_URL": "http_url",
|
||||||
|
"ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id",
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, value := range maps {
|
||||||
|
if gorm.ToDBName(key) != value {
|
||||||
|
t.Errorf("%v ToDBName should equal %v, but got %v", key, value, gorm.ToDBName(key))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
53
orm/wercker.yml
Normal file
53
orm/wercker.yml
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
# use the default golang container from Docker Hub
|
||||||
|
box: golang
|
||||||
|
|
||||||
|
services:
|
||||||
|
- id: mariadb:10.0
|
||||||
|
env:
|
||||||
|
MYSQL_DATABASE: gorm
|
||||||
|
MYSQL_USER: gorm
|
||||||
|
MYSQL_PASSWORD: gorm
|
||||||
|
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
|
||||||
|
- id: postgres
|
||||||
|
env:
|
||||||
|
POSTGRES_USER: gorm
|
||||||
|
POSTGRES_PASSWORD: gorm
|
||||||
|
POSTGRES_DB: gorm
|
||||||
|
|
||||||
|
# The steps that will be executed in the build pipeline
|
||||||
|
build:
|
||||||
|
# The steps that will be executed on build
|
||||||
|
steps:
|
||||||
|
# Sets the go workspace and places you package
|
||||||
|
# at the right place in the workspace tree
|
||||||
|
- setup-go-workspace
|
||||||
|
|
||||||
|
# Gets the dependencies
|
||||||
|
- script:
|
||||||
|
name: go get
|
||||||
|
code: |
|
||||||
|
cd $WERCKER_SOURCE_DIR
|
||||||
|
go version
|
||||||
|
go get -t ./...
|
||||||
|
|
||||||
|
# Build the project
|
||||||
|
- script:
|
||||||
|
name: go build
|
||||||
|
code: |
|
||||||
|
go build ./...
|
||||||
|
|
||||||
|
# Test the project
|
||||||
|
- script:
|
||||||
|
name: test sqlite
|
||||||
|
code: |
|
||||||
|
go test ./...
|
||||||
|
|
||||||
|
- script:
|
||||||
|
name: test mysql
|
||||||
|
code: |
|
||||||
|
GORM_DIALECT=mysql GORM_DBADDRESS=mariadb:3306 go test ./...
|
||||||
|
|
||||||
|
- script:
|
||||||
|
name: test postgres
|
||||||
|
code: |
|
||||||
|
GORM_DIALECT=postgres GORM_DBHOST=postgres go test ./...
|
Loading…
Reference in New Issue
Block a user