Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,27 @@ RequestMigrations introduces a **type-based migration system**. Instead of defin
package main

import (
"log"

rms "github.com/subomi/requestmigrations/v2"
)

func main() {
rm, _ := rms.NewRequestMigration(&rms.RequestMigrationOptions{
VersionHeader: "X-API-Version",
CurrentVersion: "2024-01-01",
CurrentVersion: "2024-06-01",
VersionFormat: rms.DateFormat,
})

// Register migrations for a specific type
rms.Register[User](rm, "2024-01-01", &UserMigration{})
// Register all migrations, then build.
err := rm.Register(
rms.Migration[User]("2024-01-01", &UserV1Migration{}),
rms.Migration[User]("2024-06-01", &UserV2Migration{}),
rms.Migration[Address]("2024-06-01", &AddressMigration{}),
).Build()
if err != nil {
log.Fatal(err)
}
}
```

Expand Down
10 changes: 7 additions & 3 deletions examples/advanced/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,13 @@ func main() {
log.Fatal(err)
}

// Register migrations across versions
rms.Register[User](rm, "2023-06-01", &UserMigrationV20230601{})
rms.Register[Workspace](rm, "2024-01-01", &WorkspaceMigrationV20240101{})
err = rm.Register(
rms.Migration[User]("2023-06-01", &UserMigrationV20230601{}),
rms.Migration[Workspace]("2024-01-01", &WorkspaceMigrationV20240101{}),
).Build()
if err != nil {
log.Fatal(err)
}

// --- Scenario: Backward Migration (Marshal) ---
// Current data structure
Expand Down
10 changes: 7 additions & 3 deletions examples/basic/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,13 @@ func main() {
log.Fatal(err)
}

// Register migrations for the User and profile types
rms.Register[User](rm, "2023-05-01", &UserMigration{})
rms.Register[profile](rm, "2023-05-01", &ProfileMigration{})
err = rm.Register(
rms.Migration[User]("2023-05-01", &UserMigration{}),
rms.Migration[profile]("2023-05-01", &ProfileMigration{}),
).Build()
if err != nil {
log.Fatal(err)
}

api := &API{rm: rm, store: userStore}
backend := http.Server{
Expand Down
178 changes: 125 additions & 53 deletions requestmigrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"reflect"
"sort"
Expand All @@ -27,6 +28,8 @@ var (
ErrInvalidVersionFormat = errors.New("invalid version format")
ErrCurrentVersionCannotBeEmpty = errors.New("current version field cannot be empty")
ErrNativeTypeMigration = errors.New("cannot register migration for native Go type; use a custom type alias instead (e.g., 'type MyString string')")
ErrAlreadyBuilt = errors.New("cannot register migrations after Build() has been called")
ErrNotBuilt = errors.New("must call Build() before using RequestMigration")
)

type userVersionKey struct{}
Expand Down Expand Up @@ -73,11 +76,13 @@ type RequestMigration struct {
metric *prometheus.HistogramVec
iv string

mu *sync.RWMutex
migrations map[reflect.Type]map[string]TypeMigration // type -> version -> migration

graphBuilder *typeGraphBuilder
graphCache sync.Map

built bool
err error
}

func NewRequestMigration(opts *RequestMigrationOptions) (*RequestMigration, error) {
Expand Down Expand Up @@ -110,7 +115,6 @@ func NewRequestMigration(opts *RequestMigrationOptions) (*RequestMigration, erro
metric: me,
iv: iv,
versions: versions,
mu: new(sync.RWMutex),
migrations: make(map[reflect.Type]map[string]TypeMigration),
}

Expand All @@ -121,6 +125,10 @@ func NewRequestMigration(opts *RequestMigrationOptions) (*RequestMigration, erro

// For creates a request-scoped Migrator for performing migrations.
func (rm *RequestMigration) For(r *http.Request) (*Migrator, error) {
if !rm.built {
return nil, ErrNotBuilt
}

if r == nil {
return nil, errors.New("request cannot be nil")
}
Expand Down Expand Up @@ -168,9 +176,6 @@ func (rm *RequestMigration) WriteVersionHeader() func(next http.Handler) http.Ha

// FindMigrationsForType returns all migrations applicable to a type from a given version forward.
func (rm *RequestMigration) FindMigrationsForType(t reflect.Type, userVersion *Version) []TypeMigration {
rm.mu.RLock()
defer rm.mu.RUnlock()

var applicableMigrations []TypeMigration

typeHistory, ok := rm.migrations[t]
Expand All @@ -192,6 +197,57 @@ func (rm *RequestMigration) FindMigrationsForType(t reflect.Type, userVersion *V
return applicableMigrations
}

// Register adds one or more type migrations. Returns rm for chaining.
// Errors are accumulated and surfaced when Build is called.
func (rm *RequestMigration) Register(migrations ...VersionedTypeMigration) *RequestMigration {
if rm.err != nil {
return rm
}

if rm.built {
rm.err = ErrAlreadyBuilt
return rm
}

for _, entry := range migrations {
if !isValidMigrationType(entry.t) {
rm.err = ErrNativeTypeMigration
return rm
}
rm.registerTypeMigration(entry.version, entry.t, entry.migration)
}

return rm
}

// Build sorts versions, eagerly builds type graphs, and marks the instance as
// ready for use. Must be called after all Register calls and before For/Bind.
func (rm *RequestMigration) Build() error {
if rm.err != nil {
return rm.err
}

if rm.built {
return ErrAlreadyBuilt
}

switch rm.opts.VersionFormat {
case SemverFormat:
sort.Slice(rm.versions, semVerSorter(rm.versions))
case DateFormat:
sort.Slice(rm.versions, dateVersionSorter(rm.versions))
default:
return ErrInvalidVersionFormat
}

for t := range rm.migrations {
rm.buildAndCacheGraphsForType(t, rm.versions)
}

rm.built = true
return nil
}

func (rm *RequestMigration) getUserVersion(req *http.Request) (*Version, error) {
var vh = req.Header.Get(rm.opts.VersionHeader)

Expand Down Expand Up @@ -243,17 +299,11 @@ func (rm *RequestMigration) observeRequestLatency(from, to *Version, sT time.Tim
h.Observe(latency.Seconds())
}

func (rm *RequestMigration) registerTypeMigration(version string, t reflect.Type, m TypeMigration) error {
// Copy versions for graph building (done outside the lock)
var versionsCopy []*Version

rm.mu.Lock()

func (rm *RequestMigration) registerTypeMigration(version string, t reflect.Type, m TypeMigration) {
if rm.migrations == nil {
rm.migrations = make(map[reflect.Type]map[string]TypeMigration)
}

// Check if this version is already known
versionKnown := false
for _, v := range rm.versions {
if v.Value == version {
Expand All @@ -264,39 +314,16 @@ func (rm *RequestMigration) registerTypeMigration(version string, t reflect.Type

if !versionKnown {
rm.versions = append(rm.versions, &Version{Format: rm.opts.VersionFormat, Value: version})

switch rm.opts.VersionFormat {
case SemverFormat:
sort.Slice(rm.versions, semVerSorter(rm.versions))
case DateFormat:
sort.Slice(rm.versions, dateVersionSorter(rm.versions))
default:
rm.mu.Unlock()
return ErrInvalidVersionFormat
}
}

// Internal Type-Centric Pivot: map[Type]map[Version]Migration
if _, ok := rm.migrations[t]; !ok {
rm.migrations[t] = make(map[string]TypeMigration)
}
rm.migrations[t][version] = m

// Copy versions for graph building outside the lock
versionsCopy = make([]*Version, len(rm.versions))
copy(versionsCopy, rm.versions)

rm.mu.Unlock()

// Eagerly build and cache graphs for this type across all known versions
// This is done outside the write lock since building only needs read access
rm.buildAndCacheGraphsForType(t, versionsCopy)

return nil
}

// buildAndCacheGraphsForType builds and caches type graphs for all known versions.
// Called during registration to eagerly populate the cache.
// Called during Build to eagerly populate the cache.
// Types with interface fields are skipped - they require runtime value inspection
// and will be built lazily via buildFromValue.
func (rm *RequestMigration) buildAndCacheGraphsForType(t reflect.Type, versions []*Version) {
Expand Down Expand Up @@ -344,16 +371,11 @@ func (m *Migrator) Marshal(v interface{}) ([]byte, error) {

currentVersion := m.rm.getCurrentVersion()

data, err := json.Marshal(v)
intermediate, err := readBody(v)
if err != nil {
return nil, err
}

var intermediate any
if err := json.Unmarshal(data, &intermediate); err != nil {
return nil, err
}

if err := graph.MigrateBackward(m.ctx, &intermediate); err != nil {
return nil, err
}
Expand Down Expand Up @@ -408,12 +430,7 @@ func (m *Migrator) Unmarshal(data []byte, v interface{}) error {
return err
}

data, err := json.Marshal(intermediate)
if err != nil {
return err
}

if err := json.Unmarshal(data, v); err != nil {
if err := writeBody(intermediate, v); err != nil {
return err
}

Expand Down Expand Up @@ -702,12 +719,21 @@ func (b *typeGraphBuilder) walkValue(v reflect.Value, userVersion *Version, visi
return graph, nil
}

func Register[T any](rm *RequestMigration, version string, m TypeMigration) error {
t := reflect.TypeOf((*T)(nil)).Elem()
if !isValidMigrationType(t) {
return ErrNativeTypeMigration
// VersionedTypeMigration pairs a type with a version and its migration logic.
// Construct using the Migration generic helper.
type VersionedTypeMigration struct {
version string
t reflect.Type
migration TypeMigration
}

// Migration creates a VersionedTypeMigration entry for type T.
func Migration[T any](version string, m TypeMigration) VersionedTypeMigration {
return VersionedTypeMigration{
version: version,
t: reflect.TypeOf((*T)(nil)).Elem(),
migration: m,
}
return rm.registerTypeMigration(version, t, m)
}

// isValidMigrationType returns true ONLY if the type is a user-defined named type.
Expand All @@ -730,3 +756,49 @@ func isValidMigrationType(t reflect.Type) bool {

return true
}

// readBody converts v to a generic JSON representation (map/slice/primitive)
// by streaming the encoding directly into the decoder via an io.Pipe,
// avoiding a full intermediate []byte allocation.
func readBody(v any) (any, error) {
pr, pw := io.Pipe()

var result any
errCh := make(chan error, 1)
go func() {
errCh <- json.NewDecoder(pr).Decode(&result)
}()

if err := json.NewEncoder(pw).Encode(v); err != nil {
pw.CloseWithError(err)
<-errCh
return nil, err
}
pw.Close()

if err := <-errCh; err != nil {
return nil, err
}

return result, nil
}

// writeBody streams a generic JSON representation into the typed destination v,
// avoiding a full intermediate []byte allocation.
func writeBody(src, dst any) error {
pr, pw := io.Pipe()

errCh := make(chan error, 1)
go func() {
errCh <- json.NewDecoder(pr).Decode(dst)
}()

if err := json.NewEncoder(pw).Encode(src); err != nil {
pw.CloseWithError(err)
<-errCh
return err
}
pw.Close()

return <-errCh
}
Loading
Loading