pax_global_header00006660000000000000000000000064150152663170014520gustar00rootroot0000000000000052 comment=1ffb5a75de20776ce5d266d62436e23ec0cfe60e postgres-1.6.0/000077500000000000000000000000001501526631700133725ustar00rootroot00000000000000postgres-1.6.0/.github/000077500000000000000000000000001501526631700147325ustar00rootroot00000000000000postgres-1.6.0/.github/dependabot.yml000066400000000000000000000010201501526631700175530ustar00rootroot00000000000000# To get started with Dependabot version updates, you'll need to specify which # package ecosystems to update and where the package manifests are located. # Please see the documentation for all configuration options: # https://help.github.com/github/administering-a-repository/configuration-options-for-dependency-updates version: 2 updates: - package-ecosystem: "github-actions" directory: "/" schedule: interval: "daily" - package-ecosystem: "gomod" directory: "/" schedule: interval: "daily" postgres-1.6.0/.github/workflows/000077500000000000000000000000001501526631700167675ustar00rootroot00000000000000postgres-1.6.0/.github/workflows/tests.yml000066400000000000000000000010561501526631700206560ustar00rootroot00000000000000name: tests on: push: pull_request: permissions: contents: read jobs: run-tests: strategy: matrix: go: ['1.20'] platform: [ubuntu-latest] runs-on: ubuntu-latest steps: - name: Set up Go 1.x uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory uses: actions/checkout@v4 # Run build of the application - name: Run build run: go build . - name: Run tests run: go test -race -count=1 -v ./... postgres-1.6.0/.gitignore000066400000000000000000000000061501526631700153560ustar00rootroot00000000000000.idea postgres-1.6.0/License000066400000000000000000000021111501526631700146720ustar00rootroot00000000000000The MIT License (MIT) Copyright (c) 2013-NOW Jinzhu Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. postgres-1.6.0/README.md000066400000000000000000000014661501526631700146600ustar00rootroot00000000000000# GORM PostgreSQL Driver ## Quick Start ```go import ( "gorm.io/driver/postgres" "gorm.io/gorm" ) // https://github.com/jackc/pgx dsn := "host=localhost user=gorm password=gorm dbname=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai" db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{}) ``` ## Configuration ```go import ( "gorm.io/driver/postgres" "gorm.io/gorm" ) db, err := gorm.Open(postgres.New(postgres.Config{ DSN: "host=localhost user=gorm password=gorm dbname=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai", // data source name, refer https://github.com/jackc/pgx PreferSimpleProtocol: true, // disables implicit prepared statement usage. By default pgx automatically uses the extended protocol }), &gorm.Config{}) ``` Checkout [https://gorm.io](https://gorm.io) for details. postgres-1.6.0/error_translator.go000066400000000000000000000024051501526631700173240ustar00rootroot00000000000000package postgres import ( "encoding/json" "gorm.io/gorm" "github.com/jackc/pgx/v5/pgconn" ) // The error codes to map PostgreSQL errors to gorm errors, here is the PostgreSQL error codes reference https://www.postgresql.org/docs/current/errcodes-appendix.html. var errCodes = map[string]error{ "23505": gorm.ErrDuplicatedKey, "23503": gorm.ErrForeignKeyViolated, "42703": gorm.ErrInvalidField, "23514": gorm.ErrCheckConstraintViolated, } type ErrMessage struct { Code string Severity string Message string } // Translate it will translate the error to native gorm errors. // Since currently gorm supporting both pgx and pg drivers, only checking for pgx PgError types is not enough for translating errors, so we have additional error json marshal fallback. func (dialector Dialector) Translate(err error) error { if pgErr, ok := err.(*pgconn.PgError); ok { if translatedErr, found := errCodes[pgErr.Code]; found { return translatedErr } return err } parsedErr, marshalErr := json.Marshal(err) if marshalErr != nil { return err } var errMsg ErrMessage unmarshalErr := json.Unmarshal(parsedErr, &errMsg) if unmarshalErr != nil { return err } if translatedErr, found := errCodes[errMsg.Code]; found { return translatedErr } return err } postgres-1.6.0/error_translator_test.go000066400000000000000000000024701501526631700203650ustar00rootroot00000000000000package postgres import ( "errors" "testing" "github.com/jackc/pgx/v5/pgconn" "gorm.io/gorm" ) func TestDialector_Translate(t *testing.T) { type fields struct { Config *Config } type args struct { err error } tests := []struct { name string fields fields args args want error }{ { name: "it should return ErrDuplicatedKey error if the status code is 23505", args: args{err: &pgconn.PgError{Code: "23505"}}, want: gorm.ErrDuplicatedKey, }, { name: "it should return ErrForeignKeyViolated error if the status code is 23503", args: args{err: &pgconn.PgError{Code: "23503"}}, want: gorm.ErrForeignKeyViolated, }, { name: "it should return gorm.ErrInvalidField error if the status code is 42703", args: args{err: &pgconn.PgError{Code: "42703"}}, want: gorm.ErrInvalidField, }, { name: "it should return gorm.ErrCheckConstraintViolated error if the status code is 23514", args: args{err: &pgconn.PgError{Code: "23514"}}, want: gorm.ErrCheckConstraintViolated, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { dialector := Dialector{ Config: tt.fields.Config, } if err := dialector.Translate(tt.args.err); !errors.Is(err, tt.want) { t.Errorf("Translate() expected error = %v, got error %v", err, tt.want) } }) } } postgres-1.6.0/go.mod000066400000000000000000000010441501526631700144770ustar00rootroot00000000000000module gorm.io/driver/postgres go 1.20 require ( github.com/jackc/pgx/v5 v5.6.0 gorm.io/gorm v1.25.10 ) require ( github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect golang.org/x/crypto v0.31.0 // indirect golang.org/x/sync v0.10.0 // indirect golang.org/x/text v0.21.0 // indirect ) retract v1.5.5 // Published accidentally. postgres-1.6.0/go.sum000066400000000000000000000051711501526631700145310ustar00rootroot00000000000000github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY= github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gorm.io/gorm v1.25.10 h1:dQpO+33KalOA+aFYGlK+EfxcI5MbO7EP2yYygwh9h+s= gorm.io/gorm v1.25.10/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= postgres-1.6.0/migrator.go000066400000000000000000000660041501526631700155530ustar00rootroot00000000000000package postgres import ( "database/sql" "fmt" "github.com/jackc/pgx/v5" "regexp" "strings" "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/migrator" "gorm.io/gorm/schema" ) // See https://stackoverflow.com/questions/2204058/list-columns-with-indexes-in-postgresql // Here are some changes: // - use `LEFT JOIN` instead of `CROSS JOIN` // - exclude indexes used to support constraints (they are auto-generated) const indexSql = ` SELECT ct.relname AS table_name, ci.relname AS index_name, i.indisunique AS non_unique, i.indisprimary AS primary, a.attname AS column_name FROM pg_index i LEFT JOIN pg_class ct ON ct.oid = i.indrelid LEFT JOIN pg_class ci ON ci.oid = i.indexrelid LEFT JOIN pg_attribute a ON a.attrelid = ct.oid LEFT JOIN pg_constraint con ON con.conindid = i.indexrelid WHERE a.attnum = ANY(i.indkey) AND con.oid IS NULL AND ct.relkind = 'r' AND ct.relname = ? ` var typeAliasMap = map[string][]string{ "int": {"integer"}, "int2": {"smallint"}, "int4": {"integer"}, "int8": {"bigint"}, "smallint": {"int2"}, "integer": {"int4"}, "bigint": {"int8"}, "date": {"date"}, "decimal": {"numeric"}, "numeric": {"decimal"}, "timestamp": {"timestamp"}, "timestamptz": {"timestamp with time zone"}, "timestamp without time zone": {"timestamp"}, "timestamp with time zone": {"timestamptz"}, "bool": {"boolean"}, "boolean": {"bool"}, "serial2": {"smallserial"}, "serial4": {"serial"}, "serial8": {"bigserial"}, "varbit": {"bit varying"}, "char": {"character"}, "varchar": {"character varying"}, "float4": {"real"}, "float8": {"double precision"}, "time": {"time"}, "timetz": {"time with time zone"}, "time without time zone": {"time"}, "time with time zone": {"timetz"}, } type Migrator struct { migrator.Migrator } // select querys ignore dryrun func (m Migrator) queryRaw(sql string, values ...interface{}) (tx *gorm.DB) { queryTx := m.DB if m.DB.DryRun { queryTx = m.DB.Session(&gorm.Session{}) queryTx.DryRun = false } return queryTx.Raw(sql, values...) } func (m Migrator) CurrentDatabase() (name string) { m.queryRaw("SELECT CURRENT_DATABASE()").Scan(&name) return } func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { for _, opt := range opts { str := stmt.Quote(opt.DBName) if opt.Expression != "" { str = opt.Expression } if opt.Collate != "" { str += " COLLATE " + opt.Collate } if opt.Sort != "" { str += " " + opt.Sort } results = append(results, clause.Expr{SQL: str}) } return } func (m Migrator) HasIndex(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { if stmt.Schema != nil { if idx := stmt.Schema.LookIndex(name); idx != nil { name = idx.Name } } currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table) return m.queryRaw( "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = ?", curTable, name, currentSchema, ).Scan(&count).Error }) return count > 0 } func (m Migrator) CreateIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if stmt.Schema != nil { if idx := stmt.Schema.LookIndex(name); idx != nil { opts := m.BuildIndexOptions(idx.Fields, stmt) values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts} createIndexSQL := "CREATE " if idx.Class != "" { createIndexSQL += idx.Class + " " } createIndexSQL += "INDEX " hasConcurrentOption := strings.TrimSpace(strings.ToUpper(idx.Option)) == "CONCURRENTLY" if hasConcurrentOption { createIndexSQL += "CONCURRENTLY " } createIndexSQL += "IF NOT EXISTS ? ON ?" if idx.Type != "" { createIndexSQL += " USING " + idx.Type + "(?)" } else { createIndexSQL += " ?" } if idx.Option != "" && !hasConcurrentOption { createIndexSQL += " " + idx.Option } if idx.Where != "" { createIndexSQL += " WHERE " + idx.Where } return m.DB.Exec(createIndexSQL, values...).Error } } return fmt.Errorf("failed to create index with name %v", name) }) } func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.DB.Exec( "ALTER INDEX ? RENAME TO ?", clause.Column{Name: oldName}, clause.Column{Name: newName}, ).Error }) } func (m Migrator) DropIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if stmt.Schema != nil { if idx := stmt.Schema.LookIndex(name); idx != nil { name = idx.Name } } return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error }) } func (m Migrator) GetTables() (tableList []string, err error) { currentSchema, _ := m.CurrentSchema(m.DB.Statement, "") return tableList, m.queryRaw("SELECT table_name FROM information_schema.tables WHERE table_schema = ? AND table_type = ?", currentSchema, "BASE TABLE").Scan(&tableList).Error } func (m Migrator) CreateTable(values ...interface{}) (err error) { if err = m.Migrator.CreateTable(values...); err != nil { return } for _, value := range m.ReorderModels(values, false) { if err = m.RunWithValue(value, func(stmt *gorm.Statement) error { if stmt.Schema != nil { for _, fieldName := range stmt.Schema.DBNames { field := stmt.Schema.FieldsByDBName[fieldName] if field.Comment != "" { if err := m.DB.Exec( "COMMENT ON COLUMN ?.? IS ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)), ).Error; err != nil { return err } } } } return nil }); err != nil { return } } return } func (m Migrator) HasTable(value interface{}) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table) return m.queryRaw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentSchema, curTable, "BASE TABLE").Scan(&count).Error }) return count > 0 } func (m Migrator) DropTable(values ...interface{}) error { values = m.ReorderModels(values, false) tx := m.DB.Session(&gorm.Session{}) for i := len(values) - 1; i >= 0; i-- { if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", m.CurrentTable(stmt)).Error }); err != nil { return err } } return nil } func (m Migrator) AddColumn(value interface{}, field string) error { if err := m.Migrator.AddColumn(value, field); err != nil { return err } m.resetPreparedStmts() return m.RunWithValue(value, func(stmt *gorm.Statement) error { if stmt.Schema != nil { if field := stmt.Schema.LookUpField(field); field != nil { if field.Comment != "" { if err := m.DB.Exec( "COMMENT ON COLUMN ?.? IS ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)), ).Error; err != nil { return err } } } } return nil }) } func (m Migrator) HasColumn(value interface{}, field string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { name := field if stmt.Schema != nil { if field := stmt.Schema.LookUpField(field); field != nil { name = field.DBName } } currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table) return m.queryRaw( "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentSchema, curTable, name, ).Scan(&count).Error }) return count > 0 } func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { // skip primary field if !field.PrimaryKey { if err := m.Migrator.MigrateColumn(value, field, columnType); err != nil { return err } } return m.RunWithValue(value, func(stmt *gorm.Statement) error { var description string currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table) values := []interface{}{currentSchema, curTable, field.DBName, stmt.Table, currentSchema} checkSQL := "SELECT description FROM pg_catalog.pg_description " checkSQL += "WHERE objsubid = (SELECT ordinal_position FROM information_schema.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?) " checkSQL += "AND objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = ? AND relnamespace = " checkSQL += "(SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?))" m.queryRaw(checkSQL, values...).Scan(&description) comment := strings.Trim(field.Comment, "'") comment = strings.Trim(comment, `"`) if field.Comment != "" && comment != description { if err := m.DB.Exec( "COMMENT ON COLUMN ?.? IS ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)), ).Error; err != nil { return err } } return nil }) } // AlterColumn alter value's `field` column' type based on schema definition func (m Migrator) AlterColumn(value interface{}, field string) error { err := m.RunWithValue(value, func(stmt *gorm.Statement) error { if stmt.Schema != nil { if field := stmt.Schema.LookUpField(field); field != nil { var ( columnTypes, _ = m.DB.Migrator().ColumnTypes(value) fieldColumnType *migrator.ColumnType ) for _, columnType := range columnTypes { if columnType.Name() == field.DBName { fieldColumnType, _ = columnType.(*migrator.ColumnType) } } fileType := clause.Expr{SQL: m.DataTypeOf(field)} // check for typeName and SQL name isSameType := true if !strings.EqualFold(fieldColumnType.DatabaseTypeName(), fileType.SQL) { isSameType = false // if different, also check for aliases aliases := m.GetTypeAliases(fieldColumnType.DatabaseTypeName()) for _, alias := range aliases { if strings.HasPrefix(fileType.SQL, alias) { isSameType = true break } } } // not same, migrate if !isSameType { filedColumnAutoIncrement, _ := fieldColumnType.AutoIncrement() if field.AutoIncrement && filedColumnAutoIncrement { // update serialDatabaseType, _ := getSerialDatabaseType(fileType.SQL) if t, _ := fieldColumnType.ColumnType(); t != serialDatabaseType { if err := m.UpdateSequence(m.DB, stmt, field, serialDatabaseType); err != nil { return err } } } else if field.AutoIncrement && !filedColumnAutoIncrement { // create serialDatabaseType, _ := getSerialDatabaseType(fileType.SQL) if err := m.CreateSequence(m.DB, stmt, field, serialDatabaseType); err != nil { return err } } else if !field.AutoIncrement && filedColumnAutoIncrement { // delete if err := m.DeleteSequence(m.DB, stmt, field, fileType); err != nil { return err } } else { if err := m.modifyColumn(stmt, field, fileType, fieldColumnType); err != nil { return err } } } if null, _ := fieldColumnType.Nullable(); null == field.NotNull { if field.NotNull { if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET NOT NULL", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil { return err } } else { if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP NOT NULL", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil { return err } } } if v, ok := fieldColumnType.DefaultValue(); (field.DefaultValueInterface == nil && ok) || v != field.DefaultValue { if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") { if field.DefaultValueInterface != nil { defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}} m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface) if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface)}).Error; err != nil { return err } } else if field.DefaultValue != "(-)" { if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DefaultValue}).Error; err != nil { return err } } else { if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil { return err } } } else if !field.HasDefaultValue { // case - as-is column has default value and to-be column has no default value // need to drop default if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil { return err } } } return nil } } return fmt.Errorf("failed to look up field with name: %s", field) }) if err != nil { return err } m.resetPreparedStmts() return nil } func (m Migrator) modifyColumn(stmt *gorm.Statement, field *schema.Field, targetType clause.Expr, existingColumn *migrator.ColumnType) error { alterSQL := "ALTER TABLE ? ALTER COLUMN ? TYPE ? USING ?::?" isUncastableDefaultValue := false if targetType.SQL == "boolean" { switch existingColumn.DatabaseTypeName() { case "int2", "int8", "numeric": alterSQL = "ALTER TABLE ? ALTER COLUMN ? TYPE ? USING ?::int::?" } isUncastableDefaultValue = true } if dv, _ := existingColumn.DefaultValue(); dv != "" && isUncastableDefaultValue { if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil { return err } } if err := m.DB.Exec(alterSQL, m.CurrentTable(stmt), clause.Column{Name: field.DBName}, targetType, clause.Column{Name: field.DBName}, targetType).Error; err != nil { return err } return nil } func (m Migrator) HasConstraint(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) if constraint != nil { name = constraint.GetName() } currentSchema, curTable := m.CurrentSchema(stmt, table) return m.queryRaw( "SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE table_schema = ? AND table_name = ? AND constraint_name = ?", currentSchema, curTable, name, ).Scan(&count).Error }) return count > 0 } func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, err error) { columnTypes = make([]gorm.ColumnType, 0) err = m.RunWithValue(value, func(stmt *gorm.Statement) error { var ( currentDatabase = m.DB.Migrator().CurrentDatabase() currentSchema, table = m.CurrentSchema(stmt, stmt.Table) columns, err = m.queryRaw( "SELECT c.column_name, c.is_nullable = 'YES', c.udt_name, c.character_maximum_length, c.numeric_precision, c.numeric_precision_radix, c.numeric_scale, c.datetime_precision, 8 * typlen, c.column_default, pd.description, c.identity_increment FROM information_schema.columns AS c JOIN pg_type AS pgt ON c.udt_name = pgt.typname LEFT JOIN pg_catalog.pg_description as pd ON pd.objsubid = c.ordinal_position AND pd.objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = c.table_name AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = c.table_schema)) where table_catalog = ? AND table_schema = ? AND table_name = ?", currentDatabase, currentSchema, table).Rows() ) if err != nil { return err } for columns.Next() { var ( column = &migrator.ColumnType{ PrimaryKeyValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, } datetimePrecision sql.NullInt64 radixValue sql.NullInt64 typeLenValue sql.NullInt64 identityIncrement sql.NullString ) err = columns.Scan( &column.NameValue, &column.NullableValue, &column.DataTypeValue, &column.LengthValue, &column.DecimalSizeValue, &radixValue, &column.ScaleValue, &datetimePrecision, &typeLenValue, &column.DefaultValueValue, &column.CommentValue, &identityIncrement, ) if err != nil { return err } if typeLenValue.Valid && typeLenValue.Int64 > 0 { column.LengthValue = typeLenValue } autoIncrementValuePattern := regexp.MustCompile(`^nextval\('"?[^']+seq"?'::regclass\)$`) if autoIncrementValuePattern.MatchString(column.DefaultValueValue.String) || (identityIncrement.Valid && identityIncrement.String != "") { column.AutoIncrementValue = sql.NullBool{Bool: true, Valid: true} column.DefaultValueValue = sql.NullString{} } if column.DefaultValueValue.Valid { column.DefaultValueValue.String = parseDefaultValueValue(column.DefaultValueValue.String) } if datetimePrecision.Valid { column.DecimalSizeValue = datetimePrecision } columnTypes = append(columnTypes, column) } columns.Close() // assign sql column type { rows, rowsErr := m.GetRows(currentSchema, table) if rowsErr != nil { return rowsErr } rawColumnTypes, err := rows.ColumnTypes() if err != nil { return err } for _, columnType := range columnTypes { for _, c := range rawColumnTypes { if c.Name() == columnType.Name() { columnType.(*migrator.ColumnType).SQLColumnType = c break } } } rows.Close() } // check primary, unique field { columnTypeRows, err := m.queryRaw("SELECT constraint_name FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ? AND constraint_type = ?", currentDatabase, currentSchema, table, "UNIQUE").Rows() if err != nil { return err } uniqueContraints := map[string]int{} for columnTypeRows.Next() { var constraintName string columnTypeRows.Scan(&constraintName) uniqueContraints[constraintName]++ } columnTypeRows.Close() columnTypeRows, err = m.queryRaw("SELECT c.column_name, constraint_name, constraint_type FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ?", currentDatabase, currentSchema, table).Rows() if err != nil { return err } for columnTypeRows.Next() { var name, constraintName, columnType string columnTypeRows.Scan(&name, &constraintName, &columnType) for _, c := range columnTypes { mc := c.(*migrator.ColumnType) if mc.NameValue.String == name { switch columnType { case "PRIMARY KEY": mc.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true} case "UNIQUE": if uniqueContraints[constraintName] == 1 { mc.UniqueValue = sql.NullBool{Bool: true, Valid: true} } } break } } } columnTypeRows.Close() } // check column type { dataTypeRows, err := m.queryRaw(`SELECT a.attname as column_name, format_type(a.atttypid, a.atttypmod) AS data_type FROM pg_attribute a JOIN pg_class b ON a.attrelid = b.oid AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?) WHERE a.attnum > 0 -- hide internal columns AND NOT a.attisdropped -- hide deleted columns AND b.relname = ?`, currentSchema, table).Rows() if err != nil { return err } for dataTypeRows.Next() { var name, dataType string dataTypeRows.Scan(&name, &dataType) for _, c := range columnTypes { mc := c.(*migrator.ColumnType) if mc.NameValue.String == name { mc.ColumnTypeValue = sql.NullString{String: dataType, Valid: true} // Handle array type: _text -> text[] , _int4 -> integer[] // Not support array size limits and array size limits because: // https://www.postgresql.org/docs/current/arrays.html#ARRAYS-DECLARATION if strings.HasPrefix(mc.DataTypeValue.String, "_") { mc.DataTypeValue = sql.NullString{String: dataType, Valid: true} } break } } } dataTypeRows.Close() } return err }) return } func (m Migrator) GetRows(currentSchema interface{}, table interface{}) (*sql.Rows, error) { name := table.(string) if _, ok := currentSchema.(string); ok { name = fmt.Sprintf("%v.%v", currentSchema, table) } return m.DB.Session(&gorm.Session{}).Table(name).Limit(1).Scopes(func(d *gorm.DB) *gorm.DB { dialector, _ := m.Dialector.(Dialector) // use simple protocol if !m.DB.PrepareStmt && (dialector.Config != nil && (dialector.Config.DriverName == "" || dialector.Config.DriverName == "pgx")) { d.Statement.Vars = append([]interface{}{pgx.QueryExecModeSimpleProtocol}, d.Statement.Vars...) } return d }).Rows() } func (m Migrator) CurrentSchema(stmt *gorm.Statement, table string) (interface{}, interface{}) { if strings.Contains(table, ".") { if tables := strings.Split(table, `.`); len(tables) == 2 { return tables[0], tables[1] } } if stmt.TableExpr != nil { if tables := strings.Split(stmt.TableExpr.SQL, `"."`); len(tables) == 2 { return strings.TrimPrefix(tables[0], `"`), table } } return clause.Expr{SQL: "CURRENT_SCHEMA()"}, table } func (m Migrator) CreateSequence(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field, serialDatabaseType string) (err error) { _, table := m.CurrentSchema(stmt, stmt.Table) tableName := table.(string) sequenceName := strings.Join([]string{tableName, field.DBName, "seq"}, "_") if err = tx.Exec(`CREATE SEQUENCE IF NOT EXISTS ? AS ?`, clause.Expr{SQL: sequenceName}, clause.Expr{SQL: serialDatabaseType}).Error; err != nil { return err } if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT nextval('?')", clause.Expr{SQL: tableName}, clause.Expr{SQL: field.DBName}, clause.Expr{SQL: sequenceName}).Error; err != nil { return err } if err := tx.Exec("ALTER SEQUENCE ? OWNED BY ?.?", clause.Expr{SQL: sequenceName}, clause.Expr{SQL: tableName}, clause.Expr{SQL: field.DBName}).Error; err != nil { return err } return } func (m Migrator) UpdateSequence(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field, serialDatabaseType string) (err error) { sequenceName, err := m.getColumnSequenceName(tx, stmt, field) if err != nil { return err } if err = tx.Exec(`ALTER SEQUENCE IF EXISTS ? AS ?`, clause.Expr{SQL: sequenceName}, clause.Expr{SQL: serialDatabaseType}).Error; err != nil { return err } if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? TYPE ?", m.CurrentTable(stmt), clause.Expr{SQL: field.DBName}, clause.Expr{SQL: serialDatabaseType}).Error; err != nil { return err } return } func (m Migrator) DeleteSequence(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field, fileType clause.Expr) (err error) { sequenceName, err := m.getColumnSequenceName(tx, stmt, field) if err != nil { return err } if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? TYPE ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType).Error; err != nil { return err } if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Expr{SQL: field.DBName}).Error; err != nil { return err } if err = tx.Exec(`DROP SEQUENCE IF EXISTS ?`, clause.Expr{SQL: sequenceName}).Error; err != nil { return err } return } func (m Migrator) getColumnSequenceName(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field) ( sequenceName string, err error) { _, table := m.CurrentSchema(stmt, stmt.Table) // DefaultValueValue is reset by ColumnTypes, search again. var columnDefault string err = tx.Raw( `SELECT column_default FROM information_schema.columns WHERE table_name = ? AND column_name = ?`, table, field.DBName).Scan(&columnDefault).Error if err != nil { return } sequenceName = strings.TrimSuffix( strings.TrimPrefix(columnDefault, `nextval('`), `'::regclass)`, ) return } func (m Migrator) GetIndexes(value interface{}) ([]gorm.Index, error) { indexes := make([]gorm.Index, 0) err := m.RunWithValue(value, func(stmt *gorm.Statement) error { result := make([]*Index, 0) scanErr := m.queryRaw(indexSql, stmt.Table).Scan(&result).Error if scanErr != nil { return scanErr } indexMap := groupByIndexName(result) for _, idx := range indexMap { tempIdx := &migrator.Index{ TableName: idx[0].TableName, NameValue: idx[0].IndexName, PrimaryKeyValue: sql.NullBool{ Bool: idx[0].Primary, Valid: true, }, UniqueValue: sql.NullBool{ Bool: idx[0].NonUnique, Valid: true, }, } for _, x := range idx { tempIdx.ColumnList = append(tempIdx.ColumnList, x.ColumnName) } indexes = append(indexes, tempIdx) } return nil }) return indexes, err } // Index table index info type Index struct { TableName string `gorm:"column:table_name"` ColumnName string `gorm:"column:column_name"` IndexName string `gorm:"column:index_name"` NonUnique bool `gorm:"column:non_unique"` Primary bool `gorm:"column:primary"` } func groupByIndexName(indexList []*Index) map[string][]*Index { columnIndexMap := make(map[string][]*Index, len(indexList)) for _, idx := range indexList { columnIndexMap[idx.IndexName] = append(columnIndexMap[idx.IndexName], idx) } return columnIndexMap } func (m Migrator) GetTypeAliases(databaseTypeName string) []string { return typeAliasMap[databaseTypeName] } // should reset prepared stmts when table changed func (m Migrator) resetPreparedStmts() { if m.DB.PrepareStmt { if pdb, ok := m.DB.ConnPool.(*gorm.PreparedStmtDB); ok { pdb.Reset() } } } func (m Migrator) DropColumn(dst interface{}, field string) error { if err := m.Migrator.DropColumn(dst, field); err != nil { return err } m.resetPreparedStmts() return nil } func (m Migrator) RenameColumn(dst interface{}, oldName, field string) error { if err := m.Migrator.RenameColumn(dst, oldName, field); err != nil { return err } m.resetPreparedStmts() return nil } func parseDefaultValueValue(defaultValue string) string { value := regexp.MustCompile(`^(.*?)(?:::.*)?$`).ReplaceAllString(defaultValue, "$1") return strings.Trim(value, "'") } postgres-1.6.0/migrator_test.go000066400000000000000000000046671501526631700166210ustar00rootroot00000000000000package postgres import "testing" func Test_parseDefaultValueValue(t *testing.T) { type args struct { defaultValue string } tests := []struct { name string args args want string }{ { name: "it should works with number without colons", args: args{defaultValue: "0"}, want: "0", }, { name: "it should works with number and two colons", args: args{defaultValue: "0::int8"}, want: "0", }, { name: "it should works with number and three colons", args: args{defaultValue: "0:::int8"}, want: "0", }, { name: "it should works with empty string without colons", args: args{defaultValue: "''"}, want: "", }, { name: "it should works with empty string with two colons", args: args{defaultValue: "''::character varying"}, want: "", }, { name: "it should works with empty string with three colons", args: args{defaultValue: "'':::character varying"}, want: "", }, { name: "it should works with string without colons", args: args{defaultValue: "'field'"}, want: "field", }, { name: "it should works with string with two colons", args: args{defaultValue: "'field'::character varying"}, want: "field", }, { name: "it should works with string with three colons", args: args{defaultValue: "'field':::character varying"}, want: "field", }, { name: "it should works with value with two colons", args: args{defaultValue: "field"}, want: "field", }, { name: "it should works with value without colons", args: args{defaultValue: "field::character varying"}, want: "field", }, { name: "it should works with value with three colons", args: args{defaultValue: "field:::character varying"}, want: "field", }, { name: "it should works with function without colons", args: args{defaultValue: "now()"}, want: "now()", }, { name: "it should works with function with two colons", args: args{defaultValue: "now()::timestamp without time zone"}, want: "now()", }, { name: "it should works with json without colons", args: args{defaultValue: "{}"}, want: "{}", }, { name: "it should works with json with two colons", args: args{defaultValue: "{}::jsonb"}, want: "{}", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := parseDefaultValueValue(tt.args.defaultValue); got != tt.want { t.Errorf("parseDefaultValueValue() = %v, want %v", got, tt.want) } }) } } postgres-1.6.0/postgres.go000066400000000000000000000160711501526631700155740ustar00rootroot00000000000000package postgres import ( "context" "database/sql" "fmt" "regexp" "strconv" "strings" "time" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/stdlib" "gorm.io/gorm" "gorm.io/gorm/callbacks" "gorm.io/gorm/clause" "gorm.io/gorm/logger" "gorm.io/gorm/migrator" "gorm.io/gorm/schema" ) type Dialector struct { *Config } type Config struct { DriverName string DSN string WithoutQuotingCheck bool PreferSimpleProtocol bool WithoutReturning bool Conn gorm.ConnPool } var ( timeZoneMatcher = regexp.MustCompile("(time_zone|TimeZone|timezone)=(.*?)($|&| )") defaultIdentifierLength = 63 //maximum identifier length for postgres ) func Open(dsn string) gorm.Dialector { return &Dialector{&Config{DSN: dsn}} } func New(config Config) gorm.Dialector { return &Dialector{Config: &config} } func (dialector Dialector) Name() string { return "postgres" } func (dialector Dialector) Apply(config *gorm.Config) error { if config.NamingStrategy == nil { config.NamingStrategy = schema.NamingStrategy{ IdentifierMaxLength: defaultIdentifierLength, } return nil } switch v := config.NamingStrategy.(type) { case *schema.NamingStrategy: if v.IdentifierMaxLength <= 0 { v.IdentifierMaxLength = defaultIdentifierLength } case schema.NamingStrategy: if v.IdentifierMaxLength <= 0 { v.IdentifierMaxLength = defaultIdentifierLength config.NamingStrategy = v } } return nil } func (dialector Dialector) Initialize(db *gorm.DB) (err error) { callbackConfig := &callbacks.Config{ CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT"}, UpdateClauses: []string{"UPDATE", "SET", "FROM", "WHERE"}, DeleteClauses: []string{"DELETE", "FROM", "WHERE"}, } // register callbacks if !dialector.WithoutReturning { callbackConfig.CreateClauses = append(callbackConfig.CreateClauses, "RETURNING") callbackConfig.UpdateClauses = append(callbackConfig.UpdateClauses, "RETURNING") callbackConfig.DeleteClauses = append(callbackConfig.DeleteClauses, "RETURNING") } callbacks.RegisterDefaultCallbacks(db, callbackConfig) if dialector.Conn != nil { db.ConnPool = dialector.Conn } else if dialector.DriverName != "" { db.ConnPool, err = sql.Open(dialector.DriverName, dialector.Config.DSN) } else { var config *pgx.ConnConfig config, err = pgx.ParseConfig(dialector.Config.DSN) if err != nil { return } if dialector.Config.PreferSimpleProtocol { config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol } result := timeZoneMatcher.FindStringSubmatch(dialector.Config.DSN) var options []stdlib.OptionOpenDB if len(result) > 2 { config.RuntimeParams["timezone"] = result[2] options = append(options, stdlib.OptionAfterConnect(func(ctx context.Context, conn *pgx.Conn) error { loc, tzErr := time.LoadLocation(result[2]) if tzErr != nil { return tzErr } conn.TypeMap().RegisterType(&pgtype.Type{ Name: "timestamp", OID: pgtype.TimestampOID, Codec: &pgtype.TimestampCodec{ScanLocation: loc}, }) return nil })) } db.ConnPool = stdlib.OpenDB(*config, options...) } return } func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { return Migrator{migrator.Migrator{Config: migrator.Config{ DB: db, Dialector: dialector, CreateIndexAfterCreateTable: true, }}} } func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression { return clause.Expr{SQL: "DEFAULT"} } func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { writer.WriteByte('$') index := 0 varLen := len(stmt.Vars) if varLen > 0 { switch stmt.Vars[0].(type) { case pgx.QueryExecMode: index++ } } writer.WriteString(strconv.Itoa(varLen - index)) } func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { if dialector.WithoutQuotingCheck { writer.WriteString(str) return } var ( underQuoted, selfQuoted bool continuousBacktick int8 shiftDelimiter int8 ) for _, v := range []byte(str) { switch v { case '"': continuousBacktick++ if continuousBacktick == 2 { writer.WriteString(`""`) continuousBacktick = 0 } case '.': if continuousBacktick > 0 || !selfQuoted { shiftDelimiter = 0 underQuoted = false continuousBacktick = 0 writer.WriteByte('"') } writer.WriteByte(v) continue default: if shiftDelimiter-continuousBacktick <= 0 && !underQuoted { writer.WriteByte('"') underQuoted = true if selfQuoted = continuousBacktick > 0; selfQuoted { continuousBacktick -= 1 } } for ; continuousBacktick > 0; continuousBacktick -= 1 { writer.WriteString(`""`) } writer.WriteByte(v) } shiftDelimiter++ } if continuousBacktick > 0 && !selfQuoted { writer.WriteString(`""`) } writer.WriteByte('"') } var numericPlaceholder = regexp.MustCompile(`\$(\d+)`) func (dialector Dialector) Explain(sql string, vars ...interface{}) string { return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...) } func (dialector Dialector) DataTypeOf(field *schema.Field) string { switch field.DataType { case schema.Bool: return "boolean" case schema.Int, schema.Uint: size := field.Size if field.DataType == schema.Uint { size++ } if field.AutoIncrement { switch { case size <= 16: return "smallserial" case size <= 32: return "serial" default: return "bigserial" } } else { switch { case size <= 16: return "smallint" case size <= 32: return "integer" default: return "bigint" } } case schema.Float: if field.Precision > 0 { if field.Scale > 0 { return fmt.Sprintf("numeric(%d, %d)", field.Precision, field.Scale) } return fmt.Sprintf("numeric(%d)", field.Precision) } return "decimal" case schema.String: if field.Size > 0 && field.Size <= 10485760 { return fmt.Sprintf("varchar(%d)", field.Size) } return "text" case schema.Time: if field.Precision > 0 { return fmt.Sprintf("timestamptz(%d)", field.Precision) } return "timestamptz" case schema.Bytes: return "bytea" default: return dialector.getSchemaCustomType(field) } } func (dialector Dialector) getSchemaCustomType(field *schema.Field) string { sqlType := string(field.DataType) if field.AutoIncrement && !strings.Contains(strings.ToLower(sqlType), "serial") { size := field.Size if field.GORMDataType == schema.Uint { size++ } switch { case size <= 16: sqlType = "smallserial" case size <= 32: sqlType = "serial" default: sqlType = "bigserial" } } return sqlType } func (dialector Dialector) SavePoint(tx *gorm.DB, name string) error { tx.Exec("SAVEPOINT " + name) return nil } func (dialector Dialector) RollbackTo(tx *gorm.DB, name string) error { tx.Exec("ROLLBACK TO SAVEPOINT " + name) return nil } func getSerialDatabaseType(s string) (dbType string, ok bool) { switch s { case "smallserial": return "smallint", true case "serial": return "integer", true case "bigserial": return "bigint", true default: return "", false } } postgres-1.6.0/postgres_test.go000066400000000000000000000023551501526631700166330ustar00rootroot00000000000000package postgres import ( "testing" "gorm.io/gorm/schema" ) func Test_DataTypeOf(t *testing.T) { type fields struct { Config *Config } type args struct { field *schema.Field } tests := []struct { name string fields fields args args want string } { { name: "it should return boolean", args: args{field: &schema.Field{DataType: schema.Bool}}, want: "boolean", }, { name: "it should return text -1", args: args{field: &schema.Field{DataType: schema.String, Size: -1}}, want: "text", }, { name: "it should return text > 10485760", args: args{field: &schema.Field{DataType: schema.String, Size: 12345678}}, want: "text", }, { name: "it should return varchar(100)", args: args{field: &schema.Field{DataType: schema.String, Size: 100}}, want: "varchar(100)", }, { name: "it should return varchar(10485760)", args: args{field: &schema.Field{DataType: schema.String, Size: 10485760}}, want: "varchar(10485760)", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { dialector := Dialector{ Config: tt.fields.Config, } if got := dialector.DataTypeOf(tt.args.field); got != tt.want { t.Errorf("DataTypeOf() = %v, want %v", got, tt.want) } }) } }