package main import ( "bufio" "bytes" "crypto/sha1" "errors" "fmt" "io" "os" "os/exec" "path/filepath" "runtime" "strings" "unicode" "github.com/pelletier/go-toml/v2" ) const ( defaultServiceDirTemplate = "/root/pb/{service}" defaultEnvFileTemplate = "/root/pb/{service}/.env" totalSetupSteps = 5 ) type pbToml struct { Server serverConfig `toml:"server"` PocketBase pocketBaseConfig `toml:"pocketbase"` } type serverConfig struct { IP string `toml:"ip"` Port int `toml:"port"` Domain string `toml:"domain"` } type pocketBaseConfig struct { Version string `toml:"version"` ServiceName string `toml:"service_name"` } type deploymentContext struct { serverIP string domain string port int serviceName string version string serviceDir string envFile string unitServiceDir string unitEnvFile string } func buildDeploymentContext() (*deploymentContext, error) { cwd, err := os.Getwd() if err != nil { return nil, err } configPath := filepath.Join(cwd, "pb.toml") cfg, err := loadPBConfig(configPath) if err != nil { return nil, err } serviceName := cfg.PocketBase.ServiceName if serviceName == "" { return nil, fmt.Errorf("pb.toml missing [pocketbase].service_name") } serverIP := cfg.Server.IP if serverIP == "" { return nil, fmt.Errorf("pb.toml missing [server].ip") } domain := cfg.Server.Domain if domain == "" { return nil, fmt.Errorf("pb.toml missing [server].domain") } port := cfg.Server.Port if port <= 0 { return nil, fmt.Errorf("pb.toml server.port must be greater than zero") } version := cfg.PocketBase.Version if version == "" { version = defaultPocketbaseVersion } serviceDir := renderServiceTemplate(defaultServiceDirTemplate, serviceName) envFile := renderServiceTemplate(defaultEnvFileTemplate, serviceName) unitServiceDir := renderServiceTemplate(defaultServiceDirTemplate, "%i") unitEnvFile := renderServiceTemplate(defaultEnvFileTemplate, "%i") return &deploymentContext{ serverIP: serverIP, domain: domain, port: port, serviceName: serviceName, version: version, serviceDir: serviceDir, envFile: envFile, unitServiceDir: unitServiceDir, unitEnvFile: unitEnvFile, }, nil } func runSetup() error { ctx, err := buildDeploymentContext() if err != nil { return err } defer closeSSHControlMaster(ctx.serverIP) if err := performSetup(ctx); err != nil { return err } fmt.Printf("\nSetup complete; PocketBase should be reachable at https://%s\n", ctx.domain) return nil } func performSetup(ctx *deploymentContext) error { step := 1 printStep(step, totalSetupSteps, "validating configuration") remoteOS, err := runSSHOutput(ctx.serverIP, "uname -s") if err != nil { return fmt.Errorf("failed to determine remote OS: %w", err) } if !strings.EqualFold(remoteOS, "linux") { return fmt.Errorf("unsupported remote OS %q", remoteOS) } arch, err := detectRemoteArch(ctx.serverIP) if err != nil { return err } assetName := pocketbaseAsset(ctx.version, "linux", arch) assetURL := fmt.Sprintf("https://github.com/pocketbase/pocketbase/releases/download/v%s/%s", ctx.version, assetName) step++ printStep(step, totalSetupSteps, "configuring firewall") if err := runSSHCommand(ctx.serverIP, firewallScript(ctx.port)); err != nil { return fmt.Errorf("firewall setup failed: %w", err) } step++ printStep(step, totalSetupSteps, "installing caddy") if err := runSSHCommand(ctx.serverIP, caddyScript(ctx.domain, ctx.port, ctx.serviceName)); err != nil { return fmt.Errorf("caddy setup failed: %w", err) } step++ printStep(step, totalSetupSteps, "deploying PocketBase binary") if err := runSSHCommand(ctx.serverIP, pocketbaseSetupScript(ctx.serviceDir, ctx.envFile, ctx.version, assetURL, ctx.port)); err != nil { return fmt.Errorf("PocketBase setup failed: %w", err) } step++ printStep(step, totalSetupSteps, "configuring systemd service") if err := runSSHCommand(ctx.serverIP, systemdScript(ctx.unitServiceDir, ctx.unitEnvFile, ctx.serviceName)); err != nil { return fmt.Errorf("systemd setup failed: %w", err) } return nil } func runDeploy() error { ctx, err := buildDeploymentContext() if err != nil { return err } defer closeSSHControlMaster(ctx.serverIP) binaryPath := filepath.Join(ctx.serviceDir, "pocketbase") exists, err := remoteBinaryExists(ctx.serverIP, binaryPath) if err != nil { return err } if !exists { fmt.Println("PocketBase binary missing on remote; running setup") if err := performSetup(ctx); err != nil { return err } } if err := syncLocalDirectories(ctx.serverIP, ctx.serviceDir, []string{"pb_migrations", "pb_hooks", "pb_public"}); err != nil { return fmt.Errorf("asset sync failed: %w", err) } if err := runSSHCommand(ctx.serverIP, systemdScript(ctx.unitServiceDir, ctx.unitEnvFile, ctx.serviceName)); err != nil { return fmt.Errorf("systemd restart failed: %w", err) } fmt.Printf("\nDeployment complete; PocketBase should be reachable at https://%s\n", ctx.domain) return nil } func remoteBinaryExists(server, path string) (bool, error) { script := fmt.Sprintf(`if [ -f %q ]; then printf yes; else printf no; fi`, path) output, err := runSSHOutput(server, script) if err != nil { return false, err } return strings.TrimSpace(output) == "yes", nil } func loadPBConfig(path string) (*pbToml, error) { data, err := os.ReadFile(path) if err != nil { return nil, err } var cfg pbToml if err := toml.Unmarshal(data, &cfg); err != nil { return nil, fmt.Errorf("failed to parse pb.toml: %w", err) } return &cfg, nil } func printStep(idx, total int, message string) { fmt.Printf("Step %d/%d: %s\n", idx, total, message) } func renderServiceTemplate(tpl, serviceName string) string { return strings.ReplaceAll(tpl, "{service}", serviceName) } func translateMachineArch(value string) (string, error) { machine := strings.TrimSpace(strings.ToLower(value)) switch machine { case "x86_64", "amd64": return "amd64", nil case "i386", "i486", "i586", "i686": return "386", nil case "armv7l": return "armv7", nil case "armv6l": return "arm", nil case "aarch64", "arm64": return "arm64", nil } return "", fmt.Errorf("unsupported remote architecture %q", value) } func detectRemoteArch(server string) (string, error) { probes := []string{ "uname -m", "arch", } var lastErr error for _, probe := range probes { output, err := runSSHOutput(server, probe) if err != nil { lastErr = fmt.Errorf("%s failed: %w", probe, err) continue } arch, err := translateMachineArch(output) if err != nil { lastErr = fmt.Errorf("%s -> %w", probe, err) continue } return arch, nil } if lastErr != nil { return "", fmt.Errorf("failed to determine remote architecture: %w", lastErr) } return "", fmt.Errorf("failed to determine remote architecture") } func pocketbaseAsset(version, osName, arch string) string { return fmt.Sprintf("pocketbase_%s_%s_%s.zip", version, osName, arch) } func firewallScript(port int) string { return fmt.Sprintf(`set -euo pipefail if ! command -v ufw >/dev/null; then apt-get update -y apt-get install -y ufw fi ufw allow OpenSSH ufw allow %d/tcp ufw allow 80/tcp ufw allow 443/tcp ufw --force enable `, port) } func caddyScript(domain string, port int, serviceName string) string { return fmt.Sprintf(`set -euo pipefail if ! command -v caddy >/dev/null; then apt-get update -y apt-get install -y curl gnupg2 curl -1sLf 'https://dl.cloudsmith.io/public/caddy/stable/gpg.key' | gpg --dearmor -o /usr/share/keyrings/caddy-stable-archive-keyring.gpg curl -1sLf 'https://dl.cloudsmith.io/public/caddy/stable/debian.deb.txt' > /etc/apt/sources.list.d/caddy-stable.list apt-get update -y apt-get install -y caddy fi mkdir -p /etc/caddy/sites cat <<'EOF' > /etc/caddy/Caddyfile import /etc/caddy/sites/*.caddy EOF cat <<'EOF' > /etc/caddy/sites/pb-%s.caddy %s { request_body { max_size 10MB } encode gzip reverse_proxy 127.0.0.1:%d { transport http { read_timeout 360s } } } EOF systemctl daemon-reload systemctl enable --now caddy.service systemctl reload-or-restart caddy.service `, serviceName, domain, port) } func pocketbaseSetupScript(serviceDir, envFile, version, assetURL string, port int) string { return fmt.Sprintf(`set -euo pipefail service_dir="%s" mkdir -p "$service_dir" apt-get update -y apt-get install -y curl unzip binary="%s/pocketbase" if [ ! -x "$binary" ]; then tmp=$(mktemp) curl -fsSL -o "$tmp" "%s" unzip -p "$tmp" pocketbase > "$binary" chmod +x "$binary" rm -f "$tmp" fi env_file="%s" current_port="" if [ -f "$env_file" ]; then current_port=$(grep '^PORT=' "$env_file" | head -n 1 | cut -d= -f2) fi if [ "$current_port" != "%d" ]; then cat <<'EOF' > "$env_file" PORT=%d EOF fi `, serviceDir, serviceDir, assetURL, envFile, port, port) } func systemdScript(serviceDir, envFile, serviceName string) string { return fmt.Sprintf(`set -euo pipefail cat <<'EOF' > /etc/systemd/system/pb@.service [Unit] Description=PocketBase instance %%i After=network.target [Service] User=root Group=root WorkingDirectory=%s EnvironmentFile=%s ExecStart=%s/pocketbase serve --http="127.0.0.1:${PORT}" Restart=on-failure LimitNOFILE=65535 [Install] WantedBy=multi-user.target EOF systemctl daemon-reload systemctl enable --now pb@%s systemctl restart pb@%s `, serviceDir, envFile, serviceDir, serviceName, serviceName) } func runSSHCommand(server, script string) error { cmd := exec.Command("ssh", append(sshArgs(server), "bash", "--noprofile", "--norc", "-c", script)...) stdoutPipe, err := cmd.StdoutPipe() if err != nil { return err } cmd.Stderr = os.Stderr if err := cmd.Start(); err != nil { return err } done := make(chan error, 1) go func() { done <- filterEnvStream(stdoutPipe, os.Stdout) }() waitErr := cmd.Wait() pipeErr := <-done if waitErr != nil { return waitErr } return pipeErr } func runSSHOutput(server, script string) (string, error) { cmd := exec.Command("ssh", append(sshArgs(server), "bash", "--noprofile", "--norc", "-c", script)...) var out bytes.Buffer cmd.Stdout = &out cmd.Stderr = os.Stderr if err := cmd.Run(); err != nil { return "", err } filtered := stripLeadingEnvLines(out.String()) return strings.TrimSpace(filtered), nil } func syncLocalDirectories(server, remoteBase string, dirs []string) error { for _, dir := range dirs { localPath := filepath.Join(".", dir) info, err := os.Stat(localPath) if err != nil { if errors.Is(err, os.ErrNotExist) { fmt.Printf("- %s does not exist locally, skipping\n", dir) continue } return err } if !info.IsDir() { fmt.Printf("- %s exists but is not a directory, skipping\n", dir) continue } remotePath := fmt.Sprintf("root@%s:%s/%s", server, remoteBase, dir) rsyncCmd := rsyncSSHCommand(server) cmd := exec.Command("rsync", "-e", rsyncCmd, "-az", "--delete", localPath+"/", remotePath) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr if err := cmd.Run(); err != nil { return err } } return nil } func rsyncSSHCommand(server string) string { args := sshSharedArgs(server) return fmt.Sprintf("ssh %s", strings.Join(args, " ")) } func sshArgs(server string) []string { args := append([]string(nil), sshSharedArgs(server)...) return append(args, fmt.Sprintf("root@%s", server)) } func sshSharedArgs(server string) []string { return []string{ "-o", "BatchMode=yes", "-o", "StrictHostKeyChecking=accept-new", "-o", "ControlMaster=auto", "-o", fmt.Sprintf("ControlPath=%s", sshControlPath(server)), "-o", "ControlPersist=5m", } } func sshControlPath(server string) string { sum := sha1.Sum([]byte(server)) return filepath.Join(sshControlDir(), fmt.Sprintf("pb-ssh-%x.sock", sum)) } func sshControlDir() string { if dir := os.Getenv("PB_SSH_CONTROL_DIR"); dir != "" { return dir } if runtime.GOOS == "windows" { return os.TempDir() } return "/tmp" } func closeSSHControlMaster(server string) { args := append([]string(nil), sshSharedArgs(server)...) args = append(args, "-O", "exit", fmt.Sprintf("root@%s", server)) _ = exec.Command("ssh", args...).Run() } func filterEnvStream(r io.Reader, w io.Writer) error { scanner := bufio.NewScanner(r) scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024) skipping := true for scanner.Scan() { line := scanner.Text() if skipping && isEnvLine(line) { continue } if skipping { skipping = false } if _, err := fmt.Fprintln(w, line); err != nil { return err } } if err := scanner.Err(); err != nil { return err } return nil } func stripLeadingEnvLines(input string) string { scanner := bufio.NewScanner(strings.NewReader(input)) scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024) skipping := true var builder strings.Builder addedLine := false for scanner.Scan() { line := scanner.Text() if skipping && isEnvLine(line) { continue } if skipping { skipping = false } if addedLine { builder.WriteByte('\n') } builder.WriteString(line) addedLine = true } if err := scanner.Err(); err != nil { return input } return builder.String() } func isEnvLine(line string) bool { trimmed := strings.TrimSpace(line) if trimmed == "" { return true } idx := strings.IndexRune(trimmed, '=') if idx <= 0 { return false } key := trimmed[:idx] for i, r := range key { if i == 0 { if r != '_' && !unicode.IsLetter(r) { return false } continue } if r != '_' && !unicode.IsLetter(r) && !unicode.IsDigit(r) { return false } } return true }