Skip to content

Commit 04e12b8

Browse files
authored
feat: Add goose provider (#635)
1 parent 8503d4e commit 04e12b8

32 files changed

Lines changed: 1524 additions & 1689 deletions

database/dialect.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package database
22

33
import (
44
"context"
5+
"database/sql"
56
"errors"
67
"fmt"
78

@@ -100,6 +101,9 @@ func (s *store) GetMigration(
100101
&result.Timestamp,
101102
&result.IsApplied,
102103
); err != nil {
104+
if errors.Is(err, sql.ErrNoRows) {
105+
return nil, fmt.Errorf("%w: %d", ErrVersionNotFound, version)
106+
}
103107
return nil, fmt.Errorf("failed to get migration %d: %w", version, err)
104108
}
105109
return &result, nil

database/store.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,15 @@ package database
22

33
import (
44
"context"
5+
"errors"
56
"time"
67
)
78

9+
var (
10+
// ErrVersionNotFound must be returned by [GetMigration] when a migration version is not found.
11+
ErrVersionNotFound = errors.New("version not found")
12+
)
13+
814
// Store is an interface that defines methods for managing database migrations and versioning. By
915
// defining a Store interface, we can support multiple databases with consistent functionality.
1016
//
@@ -24,8 +30,8 @@ type Store interface {
2430
// Delete deletes a version id from the version table.
2531
Delete(ctx context.Context, db DBTxConn, version int64) error
2632

27-
// GetMigration retrieves a single migration by version id. This method may return the raw sql
28-
// error if the query fails so the caller can assert for errors such as [sql.ErrNoRows].
33+
// GetMigration retrieves a single migration by version id. If the query succeeds, but the
34+
// version is not found, this method must return [ErrVersionNotFound].
2935
GetMigration(ctx context.Context, db DBTxConn, version int64) (*GetMigrationResult, error)
3036

3137
// ListMigrations retrieves all migrations sorted in descending order by id or timestamp. If

database/store_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ func testStore(
205205
err = runConn(ctx, db, func(conn *sql.Conn) error {
206206
_, err := store.GetMigration(ctx, conn, 0)
207207
check.HasError(t, err)
208-
check.Bool(t, errors.Is(err, sql.ErrNoRows), true)
208+
check.Bool(t, errors.Is(err, database.ErrVersionNotFound), true)
209209
return nil
210210
})
211211
check.NoError(t, err)

globals.go

Lines changed: 17 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,20 @@ func ResetGlobalMigrations() {
2222
// [NewGoMigration] function.
2323
//
2424
// Not safe for concurrent use.
25-
func SetGlobalMigrations(migrations ...Migration) error {
26-
for _, migration := range migrations {
27-
m := &migration
25+
func SetGlobalMigrations(migrations ...*Migration) error {
26+
for _, m := range migrations {
2827
if _, ok := registeredGoMigrations[m.Version]; ok {
2928
return fmt.Errorf("go migration with version %d already registered", m.Version)
3029
}
31-
if err := checkMigration(m); err != nil {
30+
if err := checkGoMigration(m); err != nil {
3231
return fmt.Errorf("invalid go migration: %w", err)
3332
}
3433
registeredGoMigrations[m.Version] = m
3534
}
3635
return nil
3736
}
3837

39-
func checkMigration(m *Migration) error {
38+
func checkGoMigration(m *Migration) error {
4039
if !m.construct {
4140
return errors.New("must use NewGoMigration to construct migrations")
4241
}
@@ -63,10 +62,10 @@ func checkMigration(m *Migration) error {
6362
return fmt.Errorf("version:%d does not match numeric component in source %q", m.Version, m.Source)
6463
}
6564
}
66-
if err := setGoFunc(m.goUp); err != nil {
65+
if err := checkGoFunc(m.goUp); err != nil {
6766
return fmt.Errorf("up function: %w", err)
6867
}
69-
if err := setGoFunc(m.goDown); err != nil {
68+
if err := checkGoFunc(m.goDown); err != nil {
7069
return fmt.Errorf("down function: %w", err)
7170
}
7271
if m.UpFnContext != nil && m.UpFnNoTxContext != nil {
@@ -84,47 +83,22 @@ func checkMigration(m *Migration) error {
8483
return nil
8584
}
8685

87-
func setGoFunc(f *GoFunc) error {
88-
if f == nil {
89-
f = &GoFunc{Mode: TransactionEnabled}
90-
return nil
91-
}
86+
func checkGoFunc(f *GoFunc) error {
9287
if f.RunTx != nil && f.RunDB != nil {
9388
return errors.New("must specify exactly one of RunTx or RunDB")
9489
}
95-
if f.RunTx == nil && f.RunDB == nil {
96-
switch f.Mode {
97-
case 0:
98-
// Default to TransactionEnabled ONLY if mode is not set explicitly.
99-
f.Mode = TransactionEnabled
100-
case TransactionEnabled, TransactionDisabled:
101-
// No functions but mode is set. This is not an error. It means the user wants to record
102-
// a version with the given mode but not run any functions.
103-
default:
104-
return fmt.Errorf("invalid mode: %d", f.Mode)
105-
}
106-
return nil
107-
}
108-
if f.RunDB != nil {
109-
switch f.Mode {
110-
case 0, TransactionDisabled:
111-
f.Mode = TransactionDisabled
112-
default:
113-
return fmt.Errorf("transaction mode must be disabled or unspecified when RunDB is set")
114-
}
90+
switch f.Mode {
91+
case TransactionEnabled, TransactionDisabled:
92+
// No functions, but mode is set. This is not an error. It means the user wants to
93+
// record a version with the given mode but not run any functions.
94+
default:
95+
return fmt.Errorf("invalid mode: %d", f.Mode)
11596
}
116-
if f.RunTx != nil {
117-
switch f.Mode {
118-
case 0, TransactionEnabled:
119-
f.Mode = TransactionEnabled
120-
default:
121-
return fmt.Errorf("transaction mode must be enabled or unspecified when RunTx is set")
122-
}
97+
if f.RunDB != nil && f.Mode != TransactionDisabled {
98+
return fmt.Errorf("transaction mode must be disabled or unspecified when RunDB is set")
12399
}
124-
// This is a defensive check. If the mode is still 0, it means we failed to infer the mode from
125-
// the functions or return an error. This should never happen.
126-
if f.Mode == 0 {
127-
return errors.New("failed to infer transaction mode")
100+
if f.RunTx != nil && f.Mode != TransactionEnabled {
101+
return fmt.Errorf("transaction mode must be enabled or unspecified when RunTx is set")
128102
}
129103
return nil
130104
}

globals_test.go

Lines changed: 41 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -104,13 +104,15 @@ func TestTransactionMode(t *testing.T) {
104104
// reset so we can check the default is set
105105
migration2.goUp.Mode, migration2.goDown.Mode = 0, 0
106106
err = SetGlobalMigrations(migration2)
107-
check.NoError(t, err)
108-
check.Number(t, len(registeredGoMigrations), 2)
109-
registered = registeredGoMigrations[2]
110-
check.Bool(t, registered.goUp != nil, true)
111-
check.Bool(t, registered.goDown != nil, true)
112-
check.Equal(t, registered.goUp.Mode, TransactionEnabled)
113-
check.Equal(t, registered.goDown.Mode, TransactionEnabled)
107+
check.HasError(t, err)
108+
check.Contains(t, err.Error(), "invalid go migration: up function: invalid mode: 0")
109+
110+
migration3 := NewGoMigration(3, nil, nil)
111+
// reset so we can check the default is set
112+
migration3.goDown.Mode = 0
113+
err = SetGlobalMigrations(migration3)
114+
check.HasError(t, err)
115+
check.Contains(t, err.Error(), "invalid go migration: down function: invalid mode: 0")
114116
})
115117
t.Run("unknown_mode", func(t *testing.T) {
116118
m := NewGoMigration(1, nil, nil)
@@ -192,7 +194,7 @@ func TestGlobalRegister(t *testing.T) {
192194
runTx := func(context.Context, *sql.Tx) error { return nil }
193195

194196
// Success.
195-
err := SetGlobalMigrations([]Migration{}...)
197+
err := SetGlobalMigrations([]*Migration{}...)
196198
check.NoError(t, err)
197199
err = SetGlobalMigrations(
198200
NewGoMigration(1, &GoFunc{RunTx: runTx}, nil),
@@ -204,62 +206,79 @@ func TestGlobalRegister(t *testing.T) {
204206
)
205207
check.HasError(t, err)
206208
check.Contains(t, err.Error(), "go migration with version 1 already registered")
207-
err = SetGlobalMigrations(Migration{Registered: true, Version: 2, Type: TypeGo})
209+
err = SetGlobalMigrations(&Migration{Registered: true, Version: 2, Type: TypeGo})
208210
check.HasError(t, err)
209211
check.Contains(t, err.Error(), "must use NewGoMigration to construct migrations")
210212
}
211213

212214
func TestCheckMigration(t *testing.T) {
215+
// Success.
216+
err := checkGoMigration(NewGoMigration(1, nil, nil))
217+
check.NoError(t, err)
213218
// Failures.
214-
err := checkMigration(&Migration{})
219+
err = checkGoMigration(&Migration{})
215220
check.HasError(t, err)
216221
check.Contains(t, err.Error(), "must use NewGoMigration to construct migrations")
217-
err = checkMigration(&Migration{construct: true})
222+
err = checkGoMigration(&Migration{construct: true})
218223
check.HasError(t, err)
219224
check.Contains(t, err.Error(), "must be registered")
220-
err = checkMigration(&Migration{construct: true, Registered: true})
225+
err = checkGoMigration(&Migration{construct: true, Registered: true})
221226
check.HasError(t, err)
222227
check.Contains(t, err.Error(), `type must be "go"`)
223-
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo})
228+
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo})
224229
check.HasError(t, err)
225230
check.Contains(t, err.Error(), "version must be greater than zero")
231+
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, goUp: &GoFunc{}, goDown: &GoFunc{}})
232+
check.HasError(t, err)
233+
check.Contains(t, err.Error(), "up function: invalid mode: 0")
234+
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, goUp: &GoFunc{Mode: TransactionEnabled}, goDown: &GoFunc{}})
235+
check.HasError(t, err)
236+
check.Contains(t, err.Error(), "down function: invalid mode: 0")
226237
// Success.
227-
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1})
238+
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, goUp: &GoFunc{Mode: TransactionEnabled}, goDown: &GoFunc{Mode: TransactionEnabled}})
228239
check.NoError(t, err)
229240
// Failures.
230-
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo"})
241+
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo"})
231242
check.HasError(t, err)
232243
check.Contains(t, err.Error(), `source must have .go extension: "foo"`)
233-
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo.go"})
244+
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo.go"})
234245
check.HasError(t, err)
235246
check.Contains(t, err.Error(), `no filename separator '_' found`)
236-
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.sql"})
247+
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.sql"})
237248
check.HasError(t, err)
238249
check.Contains(t, err.Error(), `source must have .go extension: "00001_foo.sql"`)
239-
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.go"})
250+
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.go"})
240251
check.HasError(t, err)
241252
check.Contains(t, err.Error(), `version:2 does not match numeric component in source "00001_foo.go"`)
242-
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
253+
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
243254
UpFnContext: func(context.Context, *sql.Tx) error { return nil },
244255
UpFnNoTxContext: func(context.Context, *sql.DB) error { return nil },
256+
goUp: &GoFunc{Mode: TransactionEnabled},
257+
goDown: &GoFunc{Mode: TransactionEnabled},
245258
})
246259
check.HasError(t, err)
247260
check.Contains(t, err.Error(), "must specify exactly one of UpFnContext or UpFnNoTxContext")
248-
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
261+
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
249262
DownFnContext: func(context.Context, *sql.Tx) error { return nil },
250263
DownFnNoTxContext: func(context.Context, *sql.DB) error { return nil },
264+
goUp: &GoFunc{Mode: TransactionEnabled},
265+
goDown: &GoFunc{Mode: TransactionEnabled},
251266
})
252267
check.HasError(t, err)
253268
check.Contains(t, err.Error(), "must specify exactly one of DownFnContext or DownFnNoTxContext")
254-
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
269+
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
255270
UpFn: func(*sql.Tx) error { return nil },
256271
UpFnNoTx: func(*sql.DB) error { return nil },
272+
goUp: &GoFunc{Mode: TransactionEnabled},
273+
goDown: &GoFunc{Mode: TransactionEnabled},
257274
})
258275
check.HasError(t, err)
259276
check.Contains(t, err.Error(), "must specify exactly one of UpFn or UpFnNoTx")
260-
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
277+
err = checkGoMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
261278
DownFn: func(*sql.Tx) error { return nil },
262279
DownFnNoTx: func(*sql.DB) error { return nil },
280+
goUp: &GoFunc{Mode: TransactionEnabled},
281+
goDown: &GoFunc{Mode: TransactionEnabled},
263282
})
264283
check.HasError(t, err)
265284
check.Contains(t, err.Error(), "must specify exactly one of DownFn or DownFnNoTx")

0 commit comments

Comments
 (0)