diff --git a/main.go b/main.go index 223dbbc..608c26a 100644 --- a/main.go +++ b/main.go @@ -18,6 +18,7 @@ import ( "path/filepath" "runtime" "sort" + "strconv" "strings" "time" "unicode" @@ -333,6 +334,103 @@ func resolveServiceName(defaultName string) (string, error) { return promptServiceName(defaultName) } +func shouldConfirmServerConfig(ctx *deploymentContext) (bool, error) { + exists, err := remoteDirExists(ctx.serverIP, ctx.serviceDir) + if err != nil { + fmt.Fprintf(os.Stderr, "warning: unable to verify remote service at %s: %v\n", ctx.serviceDir, err) + return true, nil + } + return !exists, nil +} + +func confirmServerConfig(pbPath string) error { + cfg, err := loadPBConfig(pbPath) + if err != nil { + return err + } + + fmt.Println("Confirm remote server settings (press Enter to keep the current value):") + reader := bufio.NewReader(os.Stdin) + + ipDefault := cfg.Server.IP + if ipDefault == "" { + ipDefault = "127.0.0.1" + } + ip, err := promptWithDefault(reader, "Server IP", ipDefault) + if err != nil { + return err + } + + portDefault := cfg.Server.Port + if portDefault <= 0 { + portDefault = 8090 + } + port, err := promptPort(reader, "Server port", portDefault) + if err != nil { + return err + } + + domainDefault := cfg.Server.Domain + if domainDefault == "" { + domainDefault = "example.com" + } + domain, err := promptWithDefault(reader, "Server domain", domainDefault) + if err != nil { + return err + } + + cfg.Server.IP = ip + cfg.Server.Port = port + cfg.Server.Domain = domain + if err := savePBConfig(pbPath, cfg); err != nil { + return err + } + + return nil +} + +func promptWithDefault(reader *bufio.Reader, label, defaultValue string) (string, error) { + if defaultValue != "" { + fmt.Printf("%s [%s]: ", label, defaultValue) + } else { + fmt.Printf("%s: ", label) + } + + input, err := reader.ReadString('\n') + if err != nil { + return "", err + } + input = strings.TrimSpace(input) + if input == "" { + return defaultValue, nil + } + return input, nil +} + +func promptPort(reader *bufio.Reader, label string, defaultValue int) (int, error) { + if defaultValue <= 0 { + defaultValue = 8090 + } + + for { + fmt.Printf("%s [%d]: ", label, defaultValue) + input, err := reader.ReadString('\n') + if err != nil { + return 0, err + } + value := strings.TrimSpace(input) + if value == "" { + return defaultValue, nil + } + port, err := strconv.Atoi(value) + if err != nil || port <= 0 || port > 65535 { + fmt.Println(" please enter a valid port between 1 and 65535") + continue + } + return port, nil + } +} + func pocketbaseBinaryName(goos string) string { if goos == "windows" { return "pocketbase.exe" @@ -630,6 +728,7 @@ type deploymentContext struct { envFile string unitServiceDir string unitEnvFile string + configPath string } func buildDeploymentContext() (*deploymentContext, error) { @@ -684,6 +783,7 @@ func buildDeploymentContext() (*deploymentContext, error) { envFile: envFile, unitServiceDir: unitServiceDir, unitEnvFile: unitEnvFile, + configPath: configPath, }, nil } @@ -692,6 +792,21 @@ func runSetup() error { if err != nil { return err } + + needsConfirm, err := shouldConfirmServerConfig(ctx) + if err != nil { + return err + } + if needsConfirm { + if err := confirmServerConfig(ctx.configPath); err != nil { + return err + } + ctx, err = buildDeploymentContext() + if err != nil { + return err + } + } + defer closeSSHControlMaster(ctx.serverIP) start := time.Now() @@ -1056,6 +1171,15 @@ func remoteBinaryExists(server, path string) (bool, error) { return strings.TrimSpace(output) == "yes", nil } +func remoteDirExists(server, path string) (bool, error) { + script := fmt.Sprintf(`if [ -d %q ]; then printf yes; else printf no; fi`, path) + output, err := runSSHOutput(server, script) + if err != nil { + return false, err + } + return strings.TrimSpace(output) == "yes", nil +} + func loadPBConfig(path string) (*pbToml, error) { data, err := os.ReadFile(path) if err != nil { @@ -1069,6 +1193,14 @@ func loadPBConfig(path string) (*pbToml, error) { return &cfg, nil } +func savePBConfig(path string, cfg *pbToml) error { + data, err := toml.Marshal(cfg) + if err != nil { + return err + } + return os.WriteFile(path, data, 0o644) +} + func printStep(idx, total int, message string) { fmt.Printf("Step %d/%d: %s\n", idx, total, message) }