diff --git a/cmd/root.go b/cmd/root.go index 94e6250..0728c48 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -43,7 +43,7 @@ func NewRoll(ctx context.Context) (*roll.Roll, error) { // Execute executes the root command. func Execute() error { // register subcommands - rootCmd.AddCommand(startCmd) + rootCmd.AddCommand(startCmd()) rootCmd.AddCommand(completeCmd) rootCmd.AddCommand(rollbackCmd) rootCmd.AddCommand(analyzeCmd) diff --git a/cmd/start.go b/cmd/start.go index c0f8d26..32fff1c 100644 --- a/cmd/start.go +++ b/cmd/start.go @@ -10,32 +10,46 @@ import ( "github.com/spf13/cobra" ) -var startCmd = &cobra.Command{ - Use: "start ", - Short: "Start a migration for the operations present in the given file", - Args: cobra.ExactArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - fileName := args[0] +func startCmd() *cobra.Command { + var complete bool - m, err := NewRoll(cmd.Context()) - if err != nil { - return err - } - defer m.Close() + startCmd := &cobra.Command{ + Use: "start ", + Short: "Start a migration for the operations present in the given file", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + fileName := args[0] - migration, err := migrations.ReadMigrationFile(args[0]) - if err != nil { - return fmt.Errorf("reading migration file: %w", err) - } + m, err := NewRoll(cmd.Context()) + if err != nil { + return err + } + defer m.Close() - version := strings.TrimSuffix(filepath.Base(fileName), filepath.Ext(fileName)) + migration, err := migrations.ReadMigrationFile(args[0]) + if err != nil { + return fmt.Errorf("reading migration file: %w", err) + } - err = m.Start(cmd.Context(), migration) - if err != nil { - return err - } + version := strings.TrimSuffix(filepath.Base(fileName), filepath.Ext(fileName)) - fmt.Printf("Migration successful!, new version of the schema available under postgres '%s' schema\n", version) - return nil - }, + err = m.Start(cmd.Context(), migration) + if err != nil { + return err + } + + if complete { + if err = m.Complete(cmd.Context()); err != nil { + return err + } + } + + fmt.Printf("Migration successful!, new version of the schema available under postgres '%s' schema\n", version) + return nil + }, + } + + startCmd.Flags().BoolVarP(&complete, "complete", "c", false, "Mark the migration as complete") + + return startCmd }