package cli import ( "bufio" "bytes" "context" "encoding/json" "fmt" "io" "os" "path/filepath" "strings" "github.com/spf13/cobra" "github.com/user/aevs/internal/archiver" "github.com/user/aevs/internal/config" "github.com/user/aevs/internal/storage" "github.com/user/aevs/internal/types" ) var ( pullConfig string pullForce bool pullDryRun bool ) var pullCmd = &cobra.Command{ Use: "pull [project-name]", Short: "Pull env files from storage", Long: `Download .env files from cloud storage.`, Args: cobra.MaximumNArgs(1), RunE: runPull, } func init() { pullCmd.Flags().StringVarP(&pullConfig, "config", "c", config.DefaultProjectConfigFile, "path to project config") pullCmd.Flags().BoolVarP(&pullForce, "force", "f", false, "overwrite files without confirmation") pullCmd.Flags().BoolVar(&pullDryRun, "dry-run", false, "show what would be downloaded without downloading") } func runPull(cmd *cobra.Command, args []string) error { // Load global config globalCfg, err := config.LoadGlobalConfig() if err != nil { if os.IsNotExist(err) { return fmt.Errorf("no storage configured; run 'aevs config' first") } return fmt.Errorf("failed to load global config: %w", err) } // Determine project name projectName := "" if len(args) > 0 { // Project name from argument projectName = args[0] } else { // Try to load from local config projectCfg, err := config.LoadProjectConfig(pullConfig) if err != nil { if os.IsNotExist(err) { return fmt.Errorf("project name required; usage: aevs pull ") } return fmt.Errorf("failed to load project config: %w", err) } projectName = projectCfg.Project } // Create storage client s3Storage, err := storage.NewS3Storage(&globalCfg.Storage) if err != nil { return fmt.Errorf("failed to create storage client: %w", err) } ctx := context.Background() // Download metadata metadataKey := fmt.Sprintf("%s/%s", projectName, config.MetadataFileName) metadataReader, err := s3Storage.Download(ctx, metadataKey) if err != nil { return fmt.Errorf("project %q not found in storage", projectName) } defer metadataReader.Close() metadataBytes, err := io.ReadAll(metadataReader) if err != nil { return fmt.Errorf("failed to read metadata: %w", err) } var metadata types.Metadata if err := json.Unmarshal(metadataBytes, &metadata); err != nil { return fmt.Errorf("failed to parse metadata: %w", err) } // Download archive archiveKey := fmt.Sprintf("%s/%s", projectName, config.ArchiveFileName) archiveReader, err := s3Storage.Download(ctx, archiveKey) if err != nil { return fmt.Errorf("failed to download archive: %w", err) } defer archiveReader.Close() archiveBytes, err := io.ReadAll(archiveReader) if err != nil { return fmt.Errorf("failed to read archive: %w", err) } if pullDryRun { fmt.Println() fmt.Println("Dry run - no changes will be made") fmt.Println() fmt.Println("Would download:") for _, file := range metadata.Files { fmt.Printf(" %s\n", file) } fmt.Println() return nil } fmt.Println() fmt.Printf("Pulling %q from storage...\n", projectName) fmt.Println() currentDir, err := os.Getwd() if err != nil { return fmt.Errorf("failed to get current directory: %w", err) } // Extract to temp directory first to handle conflicts tempDir, err := os.MkdirTemp("", "aevs-pull-*") if err != nil { return fmt.Errorf("failed to create temp directory: %w", err) } defer os.RemoveAll(tempDir) extractedFiles, err := archiver.Extract(bytes.NewReader(archiveBytes), tempDir) if err != nil { return fmt.Errorf("failed to extract archive: %w", err) } // Process each file overwriteAll := false skipAll := false created := 0 updated := 0 unchanged := 0 skipped := 0 for _, file := range extractedFiles { tempPath := filepath.Join(tempDir, file) targetPath := filepath.Join(currentDir, file) // Check if file exists _, err := os.Stat(targetPath) fileExists := err == nil if !fileExists { // File doesn't exist - create it dir := filepath.Dir(targetPath) if err := os.MkdirAll(dir, 0755); err != nil { return fmt.Errorf("failed to create directory %s: %w", dir, err) } if err := copyFile(tempPath, targetPath); err != nil { return fmt.Errorf("failed to create file %s: %w", file, err) } fmt.Printf(" ✓ %s (created)\n", file) created++ continue } // File exists - check if different same, err := filesAreSame(tempPath, targetPath) if err != nil { return fmt.Errorf("failed to compare files: %w", err) } if same { fmt.Printf(" - %s (unchanged)\n", file) unchanged++ continue } // Files differ - handle conflict if pullForce || overwriteAll { if err := copyFile(tempPath, targetPath); err != nil { return fmt.Errorf("failed to overwrite file %s: %w", file, err) } fmt.Printf(" ✓ %s (overwritten)\n", file) updated++ continue } if skipAll { fmt.Printf(" - %s (skipped)\n", file) skipped++ continue } // Ask user what to do fmt.Println() fmt.Printf("File %s already exists and differs from remote.\n", file) fmt.Print("[o]verwrite / [s]kip / [d]iff / [O]verwrite all / [S]kip all: ") reader := bufio.NewReader(os.Stdin) choice, _ := reader.ReadString('\n') choice = strings.TrimSpace(strings.ToLower(choice)) switch choice { case "o": if err := copyFile(tempPath, targetPath); err != nil { return fmt.Errorf("failed to overwrite file %s: %w", file, err) } fmt.Printf(" ✓ %s (overwritten)\n", file) updated++ case "s": fmt.Printf(" - %s (skipped)\n", file) skipped++ case "d": // Show diff (simple version - just show both) fmt.Println("\nLocal version:") localContent, _ := os.ReadFile(targetPath) fmt.Println(string(localContent)) fmt.Println("\nRemote version:") remoteContent, _ := os.ReadFile(tempPath) fmt.Println(string(remoteContent)) fmt.Println() // Ask again after showing diff fmt.Print("[o]verwrite / [s]kip: ") choice2, _ := reader.ReadString('\n') choice2 = strings.TrimSpace(strings.ToLower(choice2)) if choice2 == "o" { if err := copyFile(tempPath, targetPath); err != nil { return fmt.Errorf("failed to overwrite file %s: %w", file, err) } fmt.Printf(" ✓ %s (overwritten)\n", file) updated++ } else { fmt.Printf(" - %s (skipped)\n", file) skipped++ } case "shift+o", "O": overwriteAll = true if err := copyFile(tempPath, targetPath); err != nil { return fmt.Errorf("failed to overwrite file %s: %w", file, err) } fmt.Printf(" ✓ %s (overwritten)\n", file) updated++ case "shift+s", "S": skipAll = true fmt.Printf(" - %s (skipped)\n", file) skipped++ default: fmt.Printf(" - %s (skipped)\n", file) skipped++ } } fmt.Println() var summary []string if created > 0 { summary = append(summary, fmt.Sprintf("%d created", created)) } if updated > 0 { summary = append(summary, fmt.Sprintf("%d updated", updated)) } if unchanged > 0 { summary = append(summary, fmt.Sprintf("%d unchanged", unchanged)) } if skipped > 0 { summary = append(summary, fmt.Sprintf("%d skipped", skipped)) } if len(summary) > 0 { fmt.Printf("Done. %s.\n", strings.Join(summary, ", ")) } else { fmt.Println("Done.") } return nil } // copyFile copies a file from src to dst func copyFile(src, dst string) error { data, err := os.ReadFile(src) if err != nil { return err } return os.WriteFile(dst, data, 0644) } // filesAreSame checks if two files have the same content func filesAreSame(path1, path2 string) (bool, error) { data1, err := os.ReadFile(path1) if err != nil { return false, err } data2, err := os.ReadFile(path2) if err != nil { return false, err } return bytes.Equal(data1, data2), nil }