307 lines
7.7 KiB
Go
307 lines
7.7 KiB
Go
package cli
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
"github.com/spf13/cobra"
|
|
|
|
"github.com/user/aevs/internal/archiver"
|
|
"github.com/user/aevs/internal/config"
|
|
"github.com/user/aevs/internal/storage"
|
|
"github.com/user/aevs/internal/types"
|
|
)
|
|
|
|
// Flag values for the pull command; cobra fills them in from the flags
// registered on pullCmd.
var (
	pullConfig string // --config / -c: path to the project config file
	pullForce  bool   // --force / -f: overwrite differing local files without prompting
	pullDryRun bool   // --dry-run: list the remote files without writing anything
)
|
|
|
|
var pullCmd = &cobra.Command{
|
|
Use: "pull [project-name]",
|
|
Short: "Pull env files from storage",
|
|
Long: `Download .env files from cloud storage.`,
|
|
Args: cobra.MaximumNArgs(1),
|
|
RunE: runPull,
|
|
}
|
|
|
|
func init() {
|
|
pullCmd.Flags().StringVarP(&pullConfig, "config", "c", config.DefaultProjectConfigFile, "path to project config")
|
|
pullCmd.Flags().BoolVarP(&pullForce, "force", "f", false, "overwrite files without confirmation")
|
|
pullCmd.Flags().BoolVar(&pullDryRun, "dry-run", false, "show what would be downloaded without downloading")
|
|
}
|
|
|
|
func runPull(cmd *cobra.Command, args []string) error {
|
|
// Load global config
|
|
globalCfg, err := config.LoadGlobalConfig()
|
|
if err != nil {
|
|
if os.IsNotExist(err) {
|
|
return fmt.Errorf("no storage configured; run 'aevs config' first")
|
|
}
|
|
return fmt.Errorf("failed to load global config: %w", err)
|
|
}
|
|
|
|
// Determine project name
|
|
projectName := ""
|
|
if len(args) > 0 {
|
|
// Project name from argument
|
|
projectName = args[0]
|
|
} else {
|
|
// Try to load from local config
|
|
projectCfg, err := config.LoadProjectConfig(pullConfig)
|
|
if err != nil {
|
|
if os.IsNotExist(err) {
|
|
return fmt.Errorf("project name required; usage: aevs pull <project-name>")
|
|
}
|
|
return fmt.Errorf("failed to load project config: %w", err)
|
|
}
|
|
projectName = projectCfg.Project
|
|
}
|
|
|
|
// Create storage client
|
|
s3Storage, err := storage.NewS3Storage(&globalCfg.Storage)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create storage client: %w", err)
|
|
}
|
|
|
|
ctx := context.Background()
|
|
|
|
// Download metadata
|
|
metadataKey := fmt.Sprintf("%s/%s", projectName, config.MetadataFileName)
|
|
metadataReader, err := s3Storage.Download(ctx, metadataKey)
|
|
if err != nil {
|
|
return fmt.Errorf("project %q not found in storage", projectName)
|
|
}
|
|
defer metadataReader.Close()
|
|
|
|
metadataBytes, err := io.ReadAll(metadataReader)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read metadata: %w", err)
|
|
}
|
|
|
|
var metadata types.Metadata
|
|
if err := json.Unmarshal(metadataBytes, &metadata); err != nil {
|
|
return fmt.Errorf("failed to parse metadata: %w", err)
|
|
}
|
|
|
|
// Download archive
|
|
archiveKey := fmt.Sprintf("%s/%s", projectName, config.ArchiveFileName)
|
|
archiveReader, err := s3Storage.Download(ctx, archiveKey)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to download archive: %w", err)
|
|
}
|
|
defer archiveReader.Close()
|
|
|
|
archiveBytes, err := io.ReadAll(archiveReader)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read archive: %w", err)
|
|
}
|
|
|
|
if pullDryRun {
|
|
fmt.Println()
|
|
fmt.Println("Dry run - no changes will be made")
|
|
fmt.Println()
|
|
fmt.Println("Would download:")
|
|
for _, file := range metadata.Files {
|
|
fmt.Printf(" %s\n", file)
|
|
}
|
|
fmt.Println()
|
|
return nil
|
|
}
|
|
|
|
fmt.Println()
|
|
fmt.Printf("Pulling %q from storage...\n", projectName)
|
|
fmt.Println()
|
|
|
|
currentDir, err := os.Getwd()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get current directory: %w", err)
|
|
}
|
|
|
|
// Extract to temp directory first to handle conflicts
|
|
tempDir, err := os.MkdirTemp("", "aevs-pull-*")
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create temp directory: %w", err)
|
|
}
|
|
defer os.RemoveAll(tempDir)
|
|
|
|
extractedFiles, err := archiver.Extract(bytes.NewReader(archiveBytes), tempDir)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to extract archive: %w", err)
|
|
}
|
|
|
|
// Process each file
|
|
overwriteAll := false
|
|
skipAll := false
|
|
created := 0
|
|
updated := 0
|
|
unchanged := 0
|
|
skipped := 0
|
|
|
|
for _, file := range extractedFiles {
|
|
tempPath := filepath.Join(tempDir, file)
|
|
targetPath := filepath.Join(currentDir, file)
|
|
|
|
// Check if file exists
|
|
_, err := os.Stat(targetPath)
|
|
fileExists := err == nil
|
|
|
|
if !fileExists {
|
|
// File doesn't exist - create it
|
|
dir := filepath.Dir(targetPath)
|
|
if err := os.MkdirAll(dir, 0755); err != nil {
|
|
return fmt.Errorf("failed to create directory %s: %w", dir, err)
|
|
}
|
|
|
|
if err := copyFile(tempPath, targetPath); err != nil {
|
|
return fmt.Errorf("failed to create file %s: %w", file, err)
|
|
}
|
|
|
|
fmt.Printf(" ✓ %s (created)\n", file)
|
|
created++
|
|
continue
|
|
}
|
|
|
|
// File exists - check if different
|
|
same, err := filesAreSame(tempPath, targetPath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to compare files: %w", err)
|
|
}
|
|
|
|
if same {
|
|
fmt.Printf(" - %s (unchanged)\n", file)
|
|
unchanged++
|
|
continue
|
|
}
|
|
|
|
// Files differ - handle conflict
|
|
if pullForce || overwriteAll {
|
|
if err := copyFile(tempPath, targetPath); err != nil {
|
|
return fmt.Errorf("failed to overwrite file %s: %w", file, err)
|
|
}
|
|
fmt.Printf(" ✓ %s (overwritten)\n", file)
|
|
updated++
|
|
continue
|
|
}
|
|
|
|
if skipAll {
|
|
fmt.Printf(" - %s (skipped)\n", file)
|
|
skipped++
|
|
continue
|
|
}
|
|
|
|
// Ask user what to do
|
|
fmt.Println()
|
|
fmt.Printf("File %s already exists and differs from remote.\n", file)
|
|
fmt.Print("[o]verwrite / [s]kip / [d]iff / [O]verwrite all / [S]kip all: ")
|
|
|
|
reader := bufio.NewReader(os.Stdin)
|
|
choice, _ := reader.ReadString('\n')
|
|
choice = strings.TrimSpace(strings.ToLower(choice))
|
|
|
|
switch choice {
|
|
case "o":
|
|
if err := copyFile(tempPath, targetPath); err != nil {
|
|
return fmt.Errorf("failed to overwrite file %s: %w", file, err)
|
|
}
|
|
fmt.Printf(" ✓ %s (overwritten)\n", file)
|
|
updated++
|
|
case "s":
|
|
fmt.Printf(" - %s (skipped)\n", file)
|
|
skipped++
|
|
case "d":
|
|
// Show diff (simple version - just show both)
|
|
fmt.Println("\nLocal version:")
|
|
localContent, _ := os.ReadFile(targetPath)
|
|
fmt.Println(string(localContent))
|
|
fmt.Println("\nRemote version:")
|
|
remoteContent, _ := os.ReadFile(tempPath)
|
|
fmt.Println(string(remoteContent))
|
|
fmt.Println()
|
|
// Ask again after showing diff
|
|
fmt.Print("[o]verwrite / [s]kip: ")
|
|
choice2, _ := reader.ReadString('\n')
|
|
choice2 = strings.TrimSpace(strings.ToLower(choice2))
|
|
if choice2 == "o" {
|
|
if err := copyFile(tempPath, targetPath); err != nil {
|
|
return fmt.Errorf("failed to overwrite file %s: %w", file, err)
|
|
}
|
|
fmt.Printf(" ✓ %s (overwritten)\n", file)
|
|
updated++
|
|
} else {
|
|
fmt.Printf(" - %s (skipped)\n", file)
|
|
skipped++
|
|
}
|
|
case "shift+o", "O":
|
|
overwriteAll = true
|
|
if err := copyFile(tempPath, targetPath); err != nil {
|
|
return fmt.Errorf("failed to overwrite file %s: %w", file, err)
|
|
}
|
|
fmt.Printf(" ✓ %s (overwritten)\n", file)
|
|
updated++
|
|
case "shift+s", "S":
|
|
skipAll = true
|
|
fmt.Printf(" - %s (skipped)\n", file)
|
|
skipped++
|
|
default:
|
|
fmt.Printf(" - %s (skipped)\n", file)
|
|
skipped++
|
|
}
|
|
}
|
|
|
|
fmt.Println()
|
|
var summary []string
|
|
if created > 0 {
|
|
summary = append(summary, fmt.Sprintf("%d created", created))
|
|
}
|
|
if updated > 0 {
|
|
summary = append(summary, fmt.Sprintf("%d updated", updated))
|
|
}
|
|
if unchanged > 0 {
|
|
summary = append(summary, fmt.Sprintf("%d unchanged", unchanged))
|
|
}
|
|
if skipped > 0 {
|
|
summary = append(summary, fmt.Sprintf("%d skipped", skipped))
|
|
}
|
|
|
|
if len(summary) > 0 {
|
|
fmt.Printf("Done. %s.\n", strings.Join(summary, ", "))
|
|
} else {
|
|
fmt.Println("Done.")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// copyFile copies the contents of src to dst, creating dst if needed.
//
// Permissions on a newly created dst:
//   - src's permission bits are kept, plus owner read/write, so a restrictive
//     mode on a secret-bearing file (e.g. 0600) is not widened to 0644;
//   - the owner always retains access, so a later pull can overwrite the file.
//
// An existing dst keeps its current mode (os.WriteFile only applies the
// permission on creation). The whole file is read into memory.
func copyFile(src, dst string) error {
	info, err := os.Stat(src)
	if err != nil {
		return err
	}
	data, err := os.ReadFile(src)
	if err != nil {
		return err
	}
	return os.WriteFile(dst, data, info.Mode().Perm()|0600)
}
|
|
|
|
// filesAreSame reports whether the files at path1 and path2 have
// byte-for-byte identical contents. Both files are read fully into memory.
// path1 is read before path2, and the first read error encountered is
// returned (with false).
func filesAreSame(path1, path2 string) (bool, error) {
	var contents [2][]byte
	for i, p := range [...]string{path1, path2} {
		b, err := os.ReadFile(p)
		if err != nil {
			return false, err
		}
		contents[i] = b
	}
	return bytes.Equal(contents[0], contents[1]), nil
}
|