// File: aenvs/internal/cli/pull.go
// (Listing metadata from the source viewer: 307 lines, 7.7 KiB, Go.)

package cli
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"github.com/spf13/cobra"
"github.com/user/aevs/internal/archiver"
"github.com/user/aevs/internal/config"
"github.com/user/aevs/internal/storage"
"github.com/user/aevs/internal/types"
)
// Flag values for the pull command; they are bound to flags in init.
var (
	// pullConfig is the path of the project config file (--config / -c),
	// consulted only when no project name is given on the command line.
	pullConfig string

	// pullForce overwrites differing local files without prompting (--force / -f).
	// pullDryRun only lists the files recorded in the remote metadata (--dry-run).
	pullForce, pullDryRun bool
)
// pullCmd is the "pull" subcommand. It takes an optional project name as its
// only positional argument and hands off to runPull.
var pullCmd = &cobra.Command{
	Use:  "pull [project-name]",
	Args: cobra.MaximumNArgs(1),
	RunE: runPull,

	Short: "Pull env files from storage",
	Long:  "Download .env files from cloud storage.",
}
// init binds the pull command's flags to their package-level variables.
func init() {
	flags := pullCmd.Flags()
	flags.StringVarP(&pullConfig, "config", "c", config.DefaultProjectConfigFile, "path to project config")
	flags.BoolVarP(&pullForce, "force", "f", false, "overwrite files without confirmation")
	flags.BoolVar(&pullDryRun, "dry-run", false, "show what would be downloaded without downloading")
}
// runPull downloads a project's env files from storage into the current
// working directory.
//
// The project name is taken from args[0] when present, otherwise from the
// project config file named by --config. The project's metadata object is
// fetched first; with --dry-run the file list it contains is printed and the
// archive is never downloaded. Otherwise the archive is extracted into a
// temporary directory and each file is then created if missing, left alone if
// identical, or — when it differs — overwritten or skipped depending on
// --force or the user's interactive choice.
func runPull(cmd *cobra.Command, args []string) error {
	// Load global config
	globalCfg, err := config.LoadGlobalConfig()
	if err != nil {
		if os.IsNotExist(err) {
			return fmt.Errorf("no storage configured; run 'aevs config' first")
		}
		return fmt.Errorf("failed to load global config: %w", err)
	}

	// Determine project name: an explicit argument wins over the local config.
	projectName := ""
	if len(args) > 0 {
		projectName = args[0]
	} else {
		projectCfg, err := config.LoadProjectConfig(pullConfig)
		if err != nil {
			if os.IsNotExist(err) {
				return fmt.Errorf("project name required; usage: aevs pull <project-name>")
			}
			return fmt.Errorf("failed to load project config: %w", err)
		}
		projectName = projectCfg.Project
	}

	// Create storage client
	s3Storage, err := storage.NewS3Storage(&globalCfg.Storage)
	if err != nil {
		return fmt.Errorf("failed to create storage client: %w", err)
	}
	ctx := context.Background()

	// Download metadata
	metadataKey := fmt.Sprintf("%s/%s", projectName, config.MetadataFileName)
	metadataReader, err := s3Storage.Download(ctx, metadataKey)
	if err != nil {
		// Keep the cause: credential or network failures also end up here and
		// would otherwise be indistinguishable from a missing project.
		return fmt.Errorf("project %q not found in storage: %w", projectName, err)
	}
	defer metadataReader.Close()
	metadataBytes, err := io.ReadAll(metadataReader)
	if err != nil {
		return fmt.Errorf("failed to read metadata: %w", err)
	}
	var metadata types.Metadata
	if err := json.Unmarshal(metadataBytes, &metadata); err != nil {
		return fmt.Errorf("failed to parse metadata: %w", err)
	}

	// A dry run only needs the file list from the metadata, so it returns
	// before the archive is downloaded.
	if pullDryRun {
		fmt.Println()
		fmt.Println("Dry run - no changes will be made")
		fmt.Println()
		fmt.Println("Would download:")
		for _, file := range metadata.Files {
			fmt.Printf(" %s\n", file)
		}
		fmt.Println()
		return nil
	}

	// Download archive
	archiveKey := fmt.Sprintf("%s/%s", projectName, config.ArchiveFileName)
	archiveReader, err := s3Storage.Download(ctx, archiveKey)
	if err != nil {
		return fmt.Errorf("failed to download archive: %w", err)
	}
	defer archiveReader.Close()
	archiveBytes, err := io.ReadAll(archiveReader)
	if err != nil {
		return fmt.Errorf("failed to read archive: %w", err)
	}

	fmt.Println()
	fmt.Printf("Pulling %q from storage...\n", projectName)
	fmt.Println()

	currentDir, err := os.Getwd()
	if err != nil {
		return fmt.Errorf("failed to get current directory: %w", err)
	}

	// Extract to temp directory first to handle conflicts
	tempDir, err := os.MkdirTemp("", "aevs-pull-*")
	if err != nil {
		return fmt.Errorf("failed to create temp directory: %w", err)
	}
	defer os.RemoveAll(tempDir)
	extractedFiles, err := archiver.Extract(bytes.NewReader(archiveBytes), tempDir)
	if err != nil {
		return fmt.Errorf("failed to extract archive: %w", err)
	}

	// Process each file
	overwriteAll := false
	skipAll := false
	created, updated, unchanged, skipped := 0, 0, 0, 0
	// One reader for every prompt: a fresh bufio.Reader per prompt can buffer
	// (and then drop) answers intended for later prompts when stdin is piped.
	stdin := bufio.NewReader(os.Stdin)

	for _, file := range extractedFiles {
		tempPath := filepath.Join(tempDir, file)
		targetPath := filepath.Join(currentDir, file)

		// Entry names come from remote storage; never let one resolve to a
		// path outside the working directory (e.g. "../../.bashrc").
		if !pullWithinDir(currentDir, targetPath) {
			return fmt.Errorf("refusing to write %s: path is outside the current directory", file)
		}

		_, statErr := os.Stat(targetPath)
		if os.IsNotExist(statErr) {
			// File doesn't exist - create it
			dir := filepath.Dir(targetPath)
			if err := os.MkdirAll(dir, 0755); err != nil {
				return fmt.Errorf("failed to create directory %s: %w", dir, err)
			}
			if err := copyFile(tempPath, targetPath); err != nil {
				return fmt.Errorf("failed to create file %s: %w", file, err)
			}
			fmt.Printf(" ✓ %s (created)\n", file)
			created++
			continue
		}
		if statErr != nil {
			// e.g. permission denied: not the same as "missing".
			return fmt.Errorf("failed to check file %s: %w", file, statErr)
		}

		// File exists - check if different
		same, err := filesAreSame(tempPath, targetPath)
		if err != nil {
			return fmt.Errorf("failed to compare files: %w", err)
		}
		if same {
			fmt.Printf(" - %s (unchanged)\n", file)
			unchanged++
			continue
		}

		// Files differ: --force or an earlier "overwrite all" decides without
		// asking; an earlier "skip all" skips; otherwise prompt.
		overwrite := pullForce || overwriteAll
		if !overwrite && !skipAll {
			var applyToAll bool
			overwrite, applyToAll, err = pullPromptConflict(stdin, file, tempPath, targetPath)
			if err != nil {
				return err
			}
			if applyToAll {
				overwriteAll = overwrite
				skipAll = !overwrite
			}
		}

		if !overwrite {
			fmt.Printf(" - %s (skipped)\n", file)
			skipped++
			continue
		}
		if err := copyFile(tempPath, targetPath); err != nil {
			return fmt.Errorf("failed to overwrite file %s: %w", file, err)
		}
		fmt.Printf(" ✓ %s (overwritten)\n", file)
		updated++
	}

	fmt.Println()
	var summary []string
	if created > 0 {
		summary = append(summary, fmt.Sprintf("%d created", created))
	}
	if updated > 0 {
		summary = append(summary, fmt.Sprintf("%d updated", updated))
	}
	if unchanged > 0 {
		summary = append(summary, fmt.Sprintf("%d unchanged", unchanged))
	}
	if skipped > 0 {
		summary = append(summary, fmt.Sprintf("%d skipped", skipped))
	}
	if len(summary) > 0 {
		fmt.Printf("Done. %s.\n", strings.Join(summary, ", "))
	} else {
		fmt.Println("Done.")
	}
	return nil
}

// pullPromptConflict asks the user how to resolve a local file that differs
// from its remote copy. It returns whether to overwrite the local file and
// whether that decision should also apply to all remaining conflicts. Any
// answer that matches no option (including empty input at EOF) means "skip".
func pullPromptConflict(in *bufio.Reader, file, remotePath, localPath string) (overwrite, applyToAll bool, err error) {
	fmt.Println()
	fmt.Printf("File %s already exists and differs from remote.\n", file)
	fmt.Print("[o]verwrite / [s]kip / [d]iff / [O]verwrite all / [S]kip all: ")
	choice := pullReadLine(in)

	switch {
	// The "all" variants differ from the single-file ones only by case, so
	// they are matched exactly, before any case folding. "shift+o"/"shift+s"
	// remain accepted for compatibility with earlier behavior.
	case choice == "O" || strings.EqualFold(choice, "shift+o"):
		return true, true, nil
	case choice == "S" || strings.EqualFold(choice, "shift+s"):
		return false, true, nil
	case choice == "o":
		return true, false, nil
	case strings.EqualFold(choice, "d"):
		// Show diff (simple version - just show both)
		localContent, err := os.ReadFile(localPath)
		if err != nil {
			return false, false, fmt.Errorf("failed to read local file %s: %w", file, err)
		}
		remoteContent, err := os.ReadFile(remotePath)
		if err != nil {
			return false, false, fmt.Errorf("failed to read remote file %s: %w", file, err)
		}
		fmt.Println("\nLocal version:")
		fmt.Println(string(localContent))
		fmt.Println("\nRemote version:")
		fmt.Println(string(remoteContent))
		fmt.Println()
		// Ask again after showing diff
		fmt.Print("[o]verwrite / [s]kip: ")
		return strings.ToLower(pullReadLine(in)) == "o", false, nil
	default:
		return false, false, nil
	}
}

// pullReadLine reads one line from in with surrounding whitespace removed.
// The read error is deliberately ignored: at EOF (stdin not interactive) the
// partial or empty input is returned, and callers treat unrecognized input as
// "skip".
func pullReadLine(in *bufio.Reader) string {
	line, _ := in.ReadString('\n')
	return strings.TrimSpace(line)
}

// pullWithinDir reports whether path is dir itself or lies underneath it.
func pullWithinDir(dir, path string) bool {
	rel, err := filepath.Rel(dir, path)
	if err != nil {
		return false
	}
	return rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator))
}
// copyFile replaces the contents of dst with the full contents of src.
// A newly created dst gets mode 0644 (subject to umask); an existing dst is
// truncated and keeps its current permissions, per os.WriteFile semantics.
// If src cannot be read, dst is left untouched.
func copyFile(src, dst string) error {
	contents, err := os.ReadFile(src)
	if err == nil {
		err = os.WriteFile(dst, contents, 0o644)
	}
	return err
}
// filesAreSame reports whether the files at path1 and path2 have
// byte-identical contents. Both files are read fully into memory, path1
// first; the first read error encountered is returned with false.
func filesAreSame(path1, path2 string) (bool, error) {
	var contents [2][]byte
	for i, p := range [...]string{path1, path2} {
		data, err := os.ReadFile(p)
		if err != nil {
			return false, err
		}
		contents[i] = data
	}
	return bytes.Equal(contents[0], contents[1]), nil
}