diff --git a/Go_Updater/Makefile b/Go_Updater/Makefile new file mode 100644 index 0000000..ecb5902 --- /dev/null +++ b/Go_Updater/Makefile @@ -0,0 +1,116 @@ +# AUTO_MAA_Go_Updater Makefile + +# Build variables +VERSION ?= 1.0.0 +BUILD_TIME := $(shell date -u +"%Y-%m-%dT%H:%M:%SZ") +GIT_COMMIT := $(shell git rev-parse --short HEAD 2>/dev/null || echo "unknown") +OUTPUT_NAME := AUTO_MAA_Go_Updater +BUILD_DIR := build +DIST_DIR := dist + +# Go build flags +LDFLAGS := -s -w -X lightweight-updater/version.Version=$(VERSION) -X lightweight-updater/version.BuildTime=$(BUILD_TIME) -X lightweight-updater/version.GitCommit=$(GIT_COMMIT) + +# Default target +.PHONY: all +all: clean build + +# Clean build artifacts +.PHONY: clean +clean: + @echo "Cleaning build artifacts..." + @rm -rf $(BUILD_DIR) $(DIST_DIR) + @mkdir -p $(BUILD_DIR) $(DIST_DIR) + +# Build for Windows 64-bit +.PHONY: build +build: clean + @echo "=========================================" + @echo "Building AUTO_MAA_Go_Updater" + @echo "=========================================" + @echo "Version: $(VERSION)" + @echo "Build Time: $(BUILD_TIME)" + @echo "Git Commit: $(GIT_COMMIT)" + @echo "Target: Windows 64-bit" + @echo "" + @echo "Building application..." + @GOOS=windows GOARCH=amd64 CGO_ENABLED=1 go build -ldflags="$(LDFLAGS)" -o $(BUILD_DIR)/$(OUTPUT_NAME).exe . + @echo "Build completed successfully!" + @echo "" + @echo "Build Results:" + @ls -lh $(BUILD_DIR)/$(OUTPUT_NAME).exe + @cp $(BUILD_DIR)/$(OUTPUT_NAME).exe $(DIST_DIR)/$(OUTPUT_NAME).exe + @echo "Copied to: $(DIST_DIR)/$(OUTPUT_NAME).exe" + +# Build with UPX compression +.PHONY: build-compressed +build-compressed: build + @echo "" + @echo "Compressing with UPX..." + @if command -v upx >/dev/null 2>&1; then \ + upx --best $(BUILD_DIR)/$(OUTPUT_NAME).exe; \ + echo "Compression completed!"; \ + ls -lh $(BUILD_DIR)/$(OUTPUT_NAME).exe; \ + cp $(BUILD_DIR)/$(OUTPUT_NAME).exe $(DIST_DIR)/$(OUTPUT_NAME).exe; \ + else \ + echo "UPX not found. Skipping compression."; \ + fi + +# Run tests +.PHONY: test +test: + @echo "Running tests..." + @go test -v ./... + +# Run with version flag +.PHONY: version +version: build + @echo "" + @echo "Testing version information:" + @$(BUILD_DIR)/$(OUTPUT_NAME).exe -version + +# Install dependencies +.PHONY: deps +deps: + @echo "Installing dependencies..." + @go mod tidy + @go mod download + +# Format code +.PHONY: fmt +fmt: + @echo "Formatting code..." + @go fmt ./... + +# Lint code +.PHONY: lint +lint: + @echo "Linting code..." + @if command -v golangci-lint >/dev/null 2>&1; then \ + golangci-lint run; \ + else \ + echo "golangci-lint not found. Install it with: go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest"; \ + fi + +# Development build (faster, no optimizations) +.PHONY: dev +dev: + @echo "Building development version..." + @go build -o $(BUILD_DIR)/$(OUTPUT_NAME)-dev.exe . + @echo "Development build completed: $(BUILD_DIR)/$(OUTPUT_NAME)-dev.exe" + +# Help +.PHONY: help +help: + @echo "Available targets:" + @echo " all - Clean and build (default)" + @echo " build - Build for Windows 64-bit" + @echo " build-compressed - Build and compress with UPX" + @echo " clean - Clean build artifacts" + @echo " test - Run tests" + @echo " version - Build and show version" + @echo " deps - Install dependencies" + @echo " fmt - Format code" + @echo " lint - Lint code" + @echo " dev - Development build" + @echo " help - Show this help" \ No newline at end of file diff --git a/Go_Updater/README.MD b/Go_Updater/README.MD new file mode 100644 index 0000000..b5832d6 --- /dev/null +++ b/Go_Updater/README.MD @@ -0,0 +1,15 @@ +# 用Go语言实现的一个AUTO_MAA下载器 +用于直接下载AUTO_MAA软件本体,在Python版本出现问题时使用。 + +## 使用方法 +1. 下载并安装Go语言环境(需要配置环境变量) +2. 运行 `go mod tidy` 命令,安装依赖包。 +3. 运行 `go run main.go` 命令,程序会自动下载并安装AUTO_MAA软件。 + +## 构建 +运行 `.\build.ps1` 脚本即可完成构建。 + +参数说明: +-Version:指定要构建的版本号 + +运行命令: `.\build.ps1 -Version "1.0.8"` \ No newline at end of file diff --git a/Go_Updater/api/client.go b/Go_Updater/api/client.go new file mode 100644 index 0000000..d04f1b6 --- /dev/null +++ b/Go_Updater/api/client.go @@ -0,0 +1,332 @@ +package api + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" +) + +// MirrorResponse represents the response from MirrorChyan API +type MirrorResponse struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data struct { + VersionName string `json:"version_name"` + VersionNumber int `json:"version_number"` + URL string `json:"url,omitempty"` // Only present when using CDK + SHA256 string `json:"sha256,omitempty"` // Only present when using CDK + Channel string `json:"channel"` + OS string `json:"os"` + Arch string `json:"arch"` + UpdateType string `json:"update_type,omitempty"` // Only present when using CDK + ReleaseNote string `json:"release_note"` + FileSize int64 `json:"filesize,omitempty"` // Only present when using CDK + CDKExpiredTime int64 `json:"cdk_expired_time,omitempty"` // Only present when using CDK + } `json:"data"` +} + +// UpdateCheckParams represents parameters for update checking +type UpdateCheckParams struct { + ResourceID string + CurrentVersion string + Channel string + CDK string + UserAgent string +} + +// MirrorClient interface defines the methods for Mirror API client +type MirrorClient interface { + CheckUpdate(params UpdateCheckParams) (*MirrorResponse, error) + CheckUpdateLegacy(resourceID, currentVersion, cdk, userAgent string) (*MirrorResponse, error) + IsUpdateAvailable(response *MirrorResponse, currentVersion string) bool + GetOfficialDownloadURL(versionName string) string +} + +// Client implements MirrorClient interface +type Client struct { + httpClient *http.Client + baseURL string +} + +// NewClient creates a new Mirror API client +func NewClient() *Client { + return &Client{ + httpClient: &http.Client{ + Timeout: 30 * time.Second, + }, + baseURL: "https://mirrorchyan.com/api/resources", + } +} + +// CheckUpdate calls MirrorChyan API to check for updates with new parameter structure +func (c *Client) CheckUpdate(params UpdateCheckParams) (*MirrorResponse, error) { + // Construct the API URL + apiURL := fmt.Sprintf("%s/%s/latest", c.baseURL, params.ResourceID) + + // Parse URL to add query parameters + u, err := url.Parse(apiURL) + if err != nil { + return nil, fmt.Errorf("failed to parse API URL: %w", err) + } + + // Add query parameters + q := u.Query() + q.Set("current_version", params.CurrentVersion) + q.Set("channel", params.Channel) + q.Set("os", "") // Empty for cross-platform + q.Set("arch", "") // Empty for cross-platform + + if params.CDK != "" { + q.Set("cdk", params.CDK) + } + u.RawQuery = q.Encode() + + // Create HTTP request + req, err := http.NewRequest("GET", u.String(), nil) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP request: %w", err) + } + + // Set User-Agent header + if params.UserAgent != "" { + req.Header.Set("User-Agent", params.UserAgent) + } else { + req.Header.Set("User-Agent", "LightweightUpdater/1.0") + } + + // Make HTTP request + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to make HTTP request: %w", err) + } + defer resp.Body.Close() + + // Check HTTP status code + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API returned non-200 status code: %d", resp.StatusCode) + } + + // Read response body + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + // Parse JSON response + var mirrorResp MirrorResponse + if err := json.Unmarshal(body, &mirrorResp); err != nil { + return nil, fmt.Errorf("failed to parse JSON response: %w", err) + } + + return &mirrorResp, nil +} + +// CheckUpdateLegacy calls Mirror API to check for updates (legacy method for backward compatibility) +func (c *Client) CheckUpdateLegacy(resourceID, currentVersion, cdk, userAgent string) (*MirrorResponse, error) { + // Construct the API URL + apiURL := fmt.Sprintf("%s/%s/latest", c.baseURL, resourceID) + + // Parse URL to add query parameters + u, err := url.Parse(apiURL) + if err != nil { + return nil, fmt.Errorf("failed to parse API URL: %w", err) + } + + // Add query parameters + q := u.Query() + q.Set("current_version", currentVersion) + if cdk != "" { + q.Set("cdk", cdk) + } + u.RawQuery = q.Encode() + + // Create HTTP request + req, err := http.NewRequest("GET", u.String(), nil) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP request: %w", err) + } + + // Set User-Agent header + if userAgent != "" { + req.Header.Set("User-Agent", userAgent) + } else { + req.Header.Set("User-Agent", "LightweightUpdater/1.0") + } + + // Make HTTP request + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to make HTTP request: %w", err) + } + defer resp.Body.Close() + + // Check HTTP status code + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API returned non-200 status code: %d", resp.StatusCode) + } + + // Read response body + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + // Parse JSON response + var mirrorResp MirrorResponse + if err := json.Unmarshal(body, &mirrorResp); err != nil { + return nil, fmt.Errorf("failed to parse JSON response: %w", err) + } + + return &mirrorResp, nil +} + +// IsUpdateAvailable compares current version with the latest version from API response +func (c *Client) IsUpdateAvailable(response *MirrorResponse, currentVersion string) bool { + // Check if API response is successful + if response.Code != 0 { + return false + } + + // Get latest version from response + latestVersion := response.Data.VersionName + if latestVersion == "" { + return false + } + + // Convert version formats for comparison + currentVersionNormalized := c.normalizeVersionForComparison(currentVersion) + latestVersionNormalized := c.normalizeVersionForComparison(latestVersion) + + // Compare versions using semantic version comparison + return compareVersions(currentVersionNormalized, latestVersionNormalized) < 0 +} + +// normalizeVersionForComparison converts different version formats to comparable format +func (c *Client) normalizeVersionForComparison(version string) string { + // Handle AUTO_MAA version format: "4.4.1.3" -> "v4.4.1-beta3" + if !strings.HasPrefix(version, "v") && strings.Count(version, ".") == 3 { + parts := strings.Split(version, ".") + if len(parts) == 4 { + major, minor, patch, beta := parts[0], parts[1], parts[2], parts[3] + if beta == "0" { + return fmt.Sprintf("v%s.%s.%s", major, minor, patch) + } else { + return fmt.Sprintf("v%s.%s.%s-beta%s", major, minor, patch, beta) + } + } + } + + // Return as-is if already in standard format + return version +} + +// compareVersions compares two semantic version strings +// Returns: -1 if v1 < v2, 0 if v1 == v2, 1 if v1 > v2 +func compareVersions(v1, v2 string) int { + // Normalize versions by removing 'v' prefix if present + v1 = normalizeVersion(v1) + v2 = normalizeVersion(v2) + + // Parse version components + parts1 := parseVersionParts(v1) + parts2 := parseVersionParts(v2) + + // Compare each component + maxLen := len(parts1) + if len(parts2) > maxLen { + maxLen = len(parts2) + } + + for i := 0; i < maxLen; i++ { + var p1, p2 int + if i < len(parts1) { + p1 = parts1[i] + } + if i < len(parts2) { + p2 = parts2[i] + } + + if p1 < p2 { + return -1 + } else if p1 > p2 { + return 1 + } + } + + return 0 +} + +// normalizeVersion removes 'v' prefix and handles common version formats +func normalizeVersion(version string) string { + if len(version) > 0 && (version[0] == 'v' || version[0] == 'V') { + return version[1:] + } + return version +} + +// parseVersionParts parses version string into numeric components +func parseVersionParts(version string) []int { + if version == "" { + return []int{0} + } + + parts := make([]int, 0, 3) + current := 0 + + for _, char := range version { + if char >= '0' && char <= '9' { + current = current*10 + int(char-'0') + } else if char == '.' { + parts = append(parts, current) + current = 0 + } else { + // Stop parsing at non-numeric, non-dot characters (like pre-release identifiers) + break + } + } + + // Add the last component + parts = append(parts, current) + + // Ensure at least 3 components (major.minor.patch) + for len(parts) < 3 { + parts = append(parts, 0) + } + + return parts +} + +// GetOfficialDownloadURL generates the official download URL based on version name +func (c *Client) GetOfficialDownloadURL(versionName string) string { + // Official download site base URL + baseURL := "http://221.236.27.82:10197/d/AUTO_MAA" + + // Convert version name to filename format + // e.g., "v4.4.0" -> "AUTO_MAA_v4.4.0.zip" + // e.g., "v4.4.1-beta3" -> "AUTO_MAA_v4.4.1-beta.3.zip" + filename := fmt.Sprintf("AUTO_MAA_%s.zip", versionName) + + // Handle beta versions: convert "beta3" to "beta.3" + if strings.Contains(filename, "-beta") && !strings.Contains(filename, "-beta.") { + filename = strings.Replace(filename, "-beta", "-beta.", 1) + } + + return fmt.Sprintf("%s/%s", baseURL, filename) +} + +// HasCDKDownloadURL checks if the response contains a CDK download URL +func (c *Client) HasCDKDownloadURL(response *MirrorResponse) bool { + return response != nil && response.Data.URL != "" +} + +// GetDownloadURL returns the appropriate download URL based on available options +func (c *Client) GetDownloadURL(response *MirrorResponse) string { + if c.HasCDKDownloadURL(response) { + return response.Data.URL + } + return c.GetOfficialDownloadURL(response.Data.VersionName) +} \ No newline at end of file diff --git a/Go_Updater/api/client_test.go b/Go_Updater/api/client_test.go new file mode 100644 index 0000000..d7a94e7 --- /dev/null +++ b/Go_Updater/api/client_test.go @@ -0,0 +1,423 @@ +package api + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestNewClient(t *testing.T) { + client := NewClient() + if client == nil { + t.Fatal("NewClient() returned nil") + } + if client.httpClient == nil { + t.Fatal("HTTP client is nil") + } + if client.baseURL != "https://mirrorchyan.com/api/resources" { + t.Errorf("Expected base URL 'https://mirrorchyan.com/api/resources', got '%s'", client.baseURL) + } +} + +func TestGetOfficialDownloadURL(t *testing.T) { + client := NewClient() + + tests := []struct { + versionName string + expected string + }{ + {"v4.4.0", "http://221.236.27.82:10197/d/AUTO_MAA/AUTO_MAA_v4.4.0.zip"}, + {"v4.4.1-beta3", "http://221.236.27.82:10197/d/AUTO_MAA/AUTO_MAA_v4.4.1-beta.3.zip"}, + {"v1.2.3", "http://221.236.27.82:10197/d/AUTO_MAA/AUTO_MAA_v1.2.3.zip"}, + {"v1.2.3-beta1", "http://221.236.27.82:10197/d/AUTO_MAA/AUTO_MAA_v1.2.3-beta.1.zip"}, + } + + for _, test := range tests { + result := client.GetOfficialDownloadURL(test.versionName) + if result != test.expected { + t.Errorf("For version %s, expected %s, got %s", test.versionName, test.expected, result) + } + } +} + +func TestNormalizeVersionForComparison(t *testing.T) { + client := NewClient() + + tests := []struct { + input string + expected string + }{ + {"4.4.0.0", "v4.4.0"}, + {"4.4.1.3", "v4.4.1-beta3"}, + {"v4.4.0", "v4.4.0"}, + {"v4.4.1-beta3", "v4.4.1-beta3"}, + {"1.2.3", "1.2.3"}, // Not 4-part version, return as-is + } + + for _, test := range tests { + result := client.normalizeVersionForComparison(test.input) + if result != test.expected { + t.Errorf("For input %s, expected %s, got %s", test.input, test.expected, result) + } + } +} + +func TestCheckUpdate(t *testing.T) { + // Create test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify request parameters + if r.URL.Query().Get("current_version") != "4.4.0.0" { + t.Errorf("Expected current_version=4.4.0.0, got %s", r.URL.Query().Get("current_version")) + } + if r.URL.Query().Get("channel") != "stable" { + t.Errorf("Expected channel=stable, got %s", r.URL.Query().Get("channel")) + } + + // Return mock response + response := MirrorResponse{ + Code: 0, + Msg: "success", + Data: struct { + VersionName string `json:"version_name"` + VersionNumber int `json:"version_number"` + URL string `json:"url,omitempty"` + SHA256 string `json:"sha256,omitempty"` + Channel string `json:"channel"` + OS string `json:"os"` + Arch string `json:"arch"` + UpdateType string `json:"update_type,omitempty"` + ReleaseNote string `json:"release_note"` + FileSize int64 `json:"filesize,omitempty"` + CDKExpiredTime int64 `json:"cdk_expired_time,omitempty"` + }{ + VersionName: "v4.4.1", + VersionNumber: 48, + Channel: "stable", + OS: "", + Arch: "", + ReleaseNote: "Test release notes", + }, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + // Create client with test server URL + client := &Client{ + httpClient: &http.Client{}, + baseURL: server.URL, + } + + // Test update check + params := UpdateCheckParams{ + ResourceID: "AUTO_MAA", + CurrentVersion: "4.4.0.0", + Channel: "stable", + CDK: "", + UserAgent: "TestAgent/1.0", + } + + response, err := client.CheckUpdate(params) + if err != nil { + t.Fatalf("CheckUpdate failed: %v", err) + } + + if response.Code != 0 { + t.Errorf("Expected code 0, got %d", response.Code) + } + if response.Data.VersionName != "v4.4.1" { + t.Errorf("Expected version v4.4.1, got %s", response.Data.VersionName) + } + if response.Data.Channel != "stable" { + t.Errorf("Expected channel stable, got %s", response.Data.Channel) + } +} + +func TestCheckUpdateWithCDK(t *testing.T) { + // Create test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify CDK parameter + if r.URL.Query().Get("cdk") != "test_cdk_123" { + t.Errorf("Expected cdk=test_cdk_123, got %s", r.URL.Query().Get("cdk")) + } + + // Return mock response with CDK download URL + response := MirrorResponse{ + Code: 0, + Msg: "success", + Data: struct { + VersionName string `json:"version_name"` + VersionNumber int `json:"version_number"` + URL string `json:"url,omitempty"` + SHA256 string `json:"sha256,omitempty"` + Channel string `json:"channel"` + OS string `json:"os"` + Arch string `json:"arch"` + UpdateType string `json:"update_type,omitempty"` + ReleaseNote string `json:"release_note"` + FileSize int64 `json:"filesize,omitempty"` + CDKExpiredTime int64 `json:"cdk_expired_time,omitempty"` + }{ + VersionName: "v4.4.1", + VersionNumber: 48, + URL: "https://mirrorchyan.com/api/resources/download/test123", + SHA256: "abcd1234", + Channel: "stable", + OS: "", + Arch: "", + UpdateType: "full", + ReleaseNote: "Test release notes", + FileSize: 12345678, + CDKExpiredTime: 1776013593, + }, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + // Create client with test server URL + client := &Client{ + httpClient: &http.Client{}, + baseURL: server.URL, + } + + // Test update check with CDK + params := UpdateCheckParams{ + ResourceID: "AUTO_MAA", + CurrentVersion: "4.4.0.0", + Channel: "stable", + CDK: "test_cdk_123", + UserAgent: "TestAgent/1.0", + } + + response, err := client.CheckUpdate(params) + if err != nil { + t.Fatalf("CheckUpdate with CDK failed: %v", err) + } + + if response.Data.URL == "" { + t.Error("Expected CDK download URL, but got empty") + } + if response.Data.SHA256 == "" { + t.Error("Expected SHA256 hash, but got empty") + } + if response.Data.FileSize == 0 { + t.Error("Expected file size, but got 0") + } +} + +func TestIsUpdateAvailable(t *testing.T) { + client := NewClient() + + tests := []struct { + name string + response *MirrorResponse + currentVersion string + expected bool + }{ + { + name: "Update available - stable", + response: &MirrorResponse{ + Code: 0, + Data: struct { + VersionName string `json:"version_name"` + VersionNumber int `json:"version_number"` + URL string `json:"url,omitempty"` + SHA256 string `json:"sha256,omitempty"` + Channel string `json:"channel"` + OS string `json:"os"` + Arch string `json:"arch"` + UpdateType string `json:"update_type,omitempty"` + ReleaseNote string `json:"release_note"` + FileSize int64 `json:"filesize,omitempty"` + CDKExpiredTime int64 `json:"cdk_expired_time,omitempty"` + }{VersionName: "v4.4.1"}, + }, + currentVersion: "4.4.0.0", + expected: true, + }, + { + name: "No update available - same version", + response: &MirrorResponse{ + Code: 0, + Data: struct { + VersionName string `json:"version_name"` + VersionNumber int `json:"version_number"` + URL string `json:"url,omitempty"` + SHA256 string `json:"sha256,omitempty"` + Channel string `json:"channel"` + OS string `json:"os"` + Arch string `json:"arch"` + UpdateType string `json:"update_type,omitempty"` + ReleaseNote string `json:"release_note"` + FileSize int64 `json:"filesize,omitempty"` + CDKExpiredTime int64 `json:"cdk_expired_time,omitempty"` + }{VersionName: "v4.4.0"}, + }, + currentVersion: "4.4.0.0", + expected: false, + }, + { + name: "API error", + response: &MirrorResponse{ + Code: 1, + Data: struct { + VersionName string `json:"version_name"` + VersionNumber int `json:"version_number"` + URL string `json:"url,omitempty"` + SHA256 string `json:"sha256,omitempty"` + Channel string `json:"channel"` + OS string `json:"os"` + Arch string `json:"arch"` + UpdateType string `json:"update_type,omitempty"` + ReleaseNote string `json:"release_note"` + FileSize int64 `json:"filesize,omitempty"` + CDKExpiredTime int64 `json:"cdk_expired_time,omitempty"` + }{VersionName: "v4.4.1"}, + }, + currentVersion: "4.4.0.0", + expected: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result := client.IsUpdateAvailable(test.response, test.currentVersion) + if result != test.expected { + t.Errorf("Expected %t, got %t", test.expected, result) + } + }) + } +} + +func TestHasCDKDownloadURL(t *testing.T) { + client := NewClient() + + tests := []struct { + name string + response *MirrorResponse + expected bool + }{ + { + name: "Has CDK URL", + response: &MirrorResponse{ + Data: struct { + VersionName string `json:"version_name"` + VersionNumber int `json:"version_number"` + URL string `json:"url,omitempty"` + SHA256 string `json:"sha256,omitempty"` + Channel string `json:"channel"` + OS string `json:"os"` + Arch string `json:"arch"` + UpdateType string `json:"update_type,omitempty"` + ReleaseNote string `json:"release_note"` + FileSize int64 `json:"filesize,omitempty"` + CDKExpiredTime int64 `json:"cdk_expired_time,omitempty"` + }{URL: "https://mirrorchyan.com/download/test"}, + }, + expected: true, + }, + { + name: "No CDK URL", + response: &MirrorResponse{ + Data: struct { + VersionName string `json:"version_name"` + VersionNumber int `json:"version_number"` + URL string `json:"url,omitempty"` + SHA256 string `json:"sha256,omitempty"` + Channel string `json:"channel"` + OS string `json:"os"` + Arch string `json:"arch"` + UpdateType string `json:"update_type,omitempty"` + ReleaseNote string `json:"release_note"` + FileSize int64 `json:"filesize,omitempty"` + CDKExpiredTime int64 `json:"cdk_expired_time,omitempty"` + }{URL: ""}, + }, + expected: false, + }, + { + name: "Nil response", + response: nil, + expected: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result := client.HasCDKDownloadURL(test.response) + if result != test.expected { + t.Errorf("Expected %t, got %t", test.expected, result) + } + }) + } +} + +func TestGetDownloadURL(t *testing.T) { + client := NewClient() + + tests := []struct { + name string + response *MirrorResponse + expected string + }{ + { + name: "CDK URL available", + response: &MirrorResponse{ + Data: struct { + VersionName string `json:"version_name"` + VersionNumber int `json:"version_number"` + URL string `json:"url,omitempty"` + SHA256 string `json:"sha256,omitempty"` + Channel string `json:"channel"` + OS string `json:"os"` + Arch string `json:"arch"` + UpdateType string `json:"update_type,omitempty"` + ReleaseNote string `json:"release_note"` + FileSize int64 `json:"filesize,omitempty"` + CDKExpiredTime int64 `json:"cdk_expired_time,omitempty"` + }{ + VersionName: "v4.4.1", + URL: "https://mirrorchyan.com/download/test", + }, + }, + expected: "https://mirrorchyan.com/download/test", + }, + { + name: "Official URL fallback", + response: &MirrorResponse{ + Data: struct { + VersionName string `json:"version_name"` + VersionNumber int `json:"version_number"` + URL string `json:"url,omitempty"` + SHA256 string `json:"sha256,omitempty"` + Channel string `json:"channel"` + OS string `json:"os"` + Arch string `json:"arch"` + UpdateType string `json:"update_type,omitempty"` + ReleaseNote string `json:"release_note"` + FileSize int64 `json:"filesize,omitempty"` + CDKExpiredTime int64 `json:"cdk_expired_time,omitempty"` + }{ + VersionName: "v4.4.1", + URL: "", + }, + }, + expected: "http://221.236.27.82:10197/d/AUTO_MAA/AUTO_MAA_v4.4.1.zip", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result := client.GetDownloadURL(test.response) + if result != test.expected { + t.Errorf("Expected %s, got %s", test.expected, result) + } + }) + } +} \ No newline at end of file diff --git a/Go_Updater/app.rc b/Go_Updater/app.rc new file mode 100644 index 0000000..7cc1f3b --- /dev/null +++ b/Go_Updater/app.rc @@ -0,0 +1,34 @@ +#include + +// Application icon +IDI_ICON1 ICON "icon/AUTO_MAA_Go_Updater.ico" + +// Version information +VS_VERSION_INFO VERSIONINFO +FILEVERSION 1,0,0,0 +PRODUCTVERSION 1,0,0,0 +FILEFLAGSMASK VS_FFI_FILEFLAGSMASK +FILEFLAGS 0x0L +FILEOS VOS__WINDOWS32 +FILETYPE VFT_APP +FILESUBTYPE VFT2_UNKNOWN +BEGIN + BLOCK "StringFileInfo" + BEGIN + BLOCK "040904B0" + BEGIN + VALUE "CompanyName", "AUTO MAA Team" + VALUE "FileDescription", "AUTO MAA Go Updater" + VALUE "FileVersion", "1.0.0.0" + VALUE "InternalName", "AUTO_MAA_Go_Updater" + VALUE "LegalCopyright", "Copyright (C) 2025" + VALUE "OriginalFilename", "AUTO_MAA_Go_Updater.exe" + VALUE "ProductName", "AUTO MAA Go Updater" + VALUE "ProductVersion", "1.0.0.0" + END + END + BLOCK "VarFileInfo" + BEGIN + VALUE "Translation", 0x409, 1200 + END +END \ No newline at end of file diff --git a/Go_Updater/app.syso b/Go_Updater/app.syso new file mode 100644 index 0000000..6bf5237 Binary files /dev/null and b/Go_Updater/app.syso differ diff --git a/Go_Updater/assets/assets.go b/Go_Updater/assets/assets.go new file mode 100644 index 0000000..8b4f1b6 --- /dev/null +++ b/Go_Updater/assets/assets.go @@ -0,0 +1,34 @@ +package assets + +import ( + "embed" + "io/fs" +) + +//go:embed config_template.yaml +var EmbeddedAssets embed.FS + +// GetConfigTemplate returns the embedded config template +func GetConfigTemplate() ([]byte, error) { + return EmbeddedAssets.ReadFile("config_template.yaml") +} + +// GetAssetFS returns the embedded filesystem +func GetAssetFS() fs.FS { + return EmbeddedAssets +} + +// ListAssets returns a list of all embedded assets +func ListAssets() ([]string, error) { + var assets []string + err := fs.WalkDir(EmbeddedAssets, ".", func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if !d.IsDir() { + assets = append(assets, path) + } + return nil + }) + return assets, err +} diff --git a/Go_Updater/assets/assets_test.go b/Go_Updater/assets/assets_test.go new file mode 100644 index 0000000..0bcab5a --- /dev/null +++ b/Go_Updater/assets/assets_test.go @@ -0,0 +1,100 @@ +package assets + +import ( + "testing" +) + +func TestGetConfigTemplate(t *testing.T) { + data, err := GetConfigTemplate() + if err != nil { + t.Fatalf("Failed to get config template: %v", err) + } + + if len(data) == 0 { + t.Fatal("Config template is empty") + } + + // Check that it contains expected content + content := string(data) + if !contains(content, "resource_id") { + t.Error("Config template should contain 'resource_id'") + } + + if !contains(content, "current_version") { + t.Error("Config template should contain 'current_version'") + } + + if !contains(content, "user_agent") { + t.Error("Config template should contain 'user_agent'") + } +} + +func TestListAssets(t *testing.T) { + assets, err := ListAssets() + if err != nil { + t.Fatalf("Failed to list assets: %v", err) + } + + if len(assets) == 0 { + t.Fatal("No assets found") + } + + // Check that config template is in the list + found := false + for _, asset := range assets { + if asset == "config_template.yaml" { + found = true + break + } + } + + if !found { + t.Error("config_template.yaml should be in the assets list") + } +} + +func TestGetAssetFS(t *testing.T) { + fs := GetAssetFS() + if fs == nil { + t.Fatal("Asset filesystem should not be nil") + } + + // Try to open the config template + file, err := fs.Open("config_template.yaml") + if err != nil { + t.Fatalf("Failed to open config template from filesystem: %v", err) + } + defer file.Close() + + // Check that we can read from it + buffer := make([]byte, 100) + n, err := file.Read(buffer) + if err != nil && err.Error() != "EOF" { + t.Fatalf("Failed to read from config template: %v", err) + } + + if n == 0 { + t.Fatal("Config template appears to be empty") + } +} + +// Helper function to check if string contains substring +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(substr) == 0 || + (len(s) > len(substr) && (s[:len(substr)] == substr || + s[len(s)-len(substr):] == substr || + containsAt(s, substr, 1)))) +} + +func containsAt(s, substr string, start int) bool { + if start >= len(s) { + return false + } + if start+len(substr) > len(s) { + return containsAt(s, substr, start+1) + } + if s[start:start+len(substr)] == substr { + return true + } + return containsAt(s, substr, start+1) +} diff --git a/Go_Updater/assets/config_template.yaml b/Go_Updater/assets/config_template.yaml new file mode 100644 index 0000000..ec297b5 --- /dev/null +++ b/Go_Updater/assets/config_template.yaml @@ -0,0 +1,8 @@ +resource_id: "AUTO_MAA" +current_version: "v1.0.0" +cdk: "" # Will be encrypted when saved +user_agent: "AUTO_MAA_Go_Updater/1.0" +backup_url: "https://backup-download-site.com/releases" +log_level: "info" +auto_check: true +check_interval: 3600 # seconds \ No newline at end of file diff --git a/Go_Updater/build-config.yaml b/Go_Updater/build-config.yaml new file mode 100644 index 0000000..3b8d59c --- /dev/null +++ b/Go_Updater/build-config.yaml @@ -0,0 +1,55 @@ +# Build Configuration for Lightweight Updater + +project: + name: "Lightweight Updater" + module: "lightweight-updater" + description: "轻量级自动更新器" + +version: + default: "1.0.0" + build_time_format: "2006-01-02T15:04:05Z" + +targets: + - name: "windows-amd64" + goos: "windows" + goarch: "amd64" + cgo_enabled: true + output: "lightweight-updater.exe" + +build: + flags: + ldflags: "-s -w" + tags: [] + + optimization: + strip_debug: true + strip_symbols: true + upx_compression: false # Optional, requires UPX + + size_requirements: + max_size_mb: 10 + warn_size_mb: 8 + +assets: + embed: + - "assets/config_template.yaml" + +directories: + build: "build" + dist: "dist" + temp: "temp" + +version_injection: + package: "lightweight-updater/version" + variables: + - name: "Version" + source: "version" + - name: "BuildTime" + source: "build_time" + - name: "GitCommit" + source: "git_commit" + +quality: + run_tests: true + run_lint: false # Optional + format_code: true \ No newline at end of file diff --git a/Go_Updater/build.bat b/Go_Updater/build.bat new file mode 100644 index 0000000..3ca8d38 --- /dev/null +++ b/Go_Updater/build.bat @@ -0,0 +1,99 @@ +@echo off +setlocal enabledelayedexpansion + +echo ======================================== +echo AUTO_MAA_Go_Updater Build Script +echo ======================================== + +:: Set build variables +set VERSION=1.0.0 +set OUTPUT_NAME=AUTO_MAA_Go_Updater.exe +set BUILD_DIR=build +set DIST_DIR=dist + +:: Get current timestamp +for /f "tokens=2 delims==" %%a in ('wmic OS Get localdatetime /value') do set "dt=%%a" +set "YY=%dt:~2,2%" & set "YYYY=%dt:~0,4%" & set "MM=%dt:~4,2%" & set "DD=%dt:~6,2%" +set "HH=%dt:~8,2%" & set "Min=%dt:~10,2%" & set "Sec=%dt:~12,2%" +set "BUILD_TIME=%YYYY%-%MM%-%DD%T%HH%:%Min%:%Sec%Z" + +:: Get git commit hash (if available) +git rev-parse --short HEAD > temp_commit.txt 2>nul +if exist temp_commit.txt ( + set /p GIT_COMMIT=nul 2>&1 + if !ERRORLEVEL! equ 0 ( + rsrc -ico icon\AUTO_MAA_Go_Updater.ico -o app.syso + if !ERRORLEVEL! equ 0 ( + echo Icon resource compiled successfully + ) else ( + echo Warning: Failed to compile icon resource + ) + ) else ( + echo Warning: rsrc not found. Install with: go install github.com/akavel/rsrc@latest + ) +) + +set GOOS=windows +set GOARCH=amd64 +set CGO_ENABLED=1 + +:: Build the application +go build -ldflags="%LDFLAGS%" -o %BUILD_DIR%\%OUTPUT_NAME% . + +if %ERRORLEVEL% neq 0 ( + echo Build failed! + exit /b 1 +) + +echo Build completed successfully! + +:: Get file size +for %%A in (%BUILD_DIR%\%OUTPUT_NAME%) do set FILE_SIZE=%%~zA + +:: Convert bytes to MB +set /a FILE_SIZE_MB=%FILE_SIZE%/1024/1024 + +echo. +echo Build Results: +echo - Output: %BUILD_DIR%\%OUTPUT_NAME% +echo - Size: %FILE_SIZE% bytes (~%FILE_SIZE_MB% MB) + +:: Check if file size is within requirements (<10MB) +if %FILE_SIZE_MB% gtr 10 ( + echo WARNING: File size exceeds 10MB requirement! +) else ( + echo File size meets requirements (^<10MB) +) + +:: Copy to dist directory +copy %BUILD_DIR%\%OUTPUT_NAME% %DIST_DIR%\%OUTPUT_NAME% >nul +echo - Copied to: %DIST_DIR%\%OUTPUT_NAME% + +echo. +echo Build script completed successfully! +echo ======================================== \ No newline at end of file diff --git a/Go_Updater/build.ps1 b/Go_Updater/build.ps1 new file mode 100644 index 0000000..fef8b34 --- /dev/null +++ b/Go_Updater/build.ps1 @@ -0,0 +1,111 @@ +# Lightweight Updater Build Script (PowerShell) +param( + [string]$Version = "1.0.0", + [string]$OutputName = "AUTO_MAA_Go_Updater.exe", + [switch]$Compress = $false +) + +Write-Host "========================================" -ForegroundColor Cyan +Write-Host "AUTO_MAA_Go_Updater Build Script" -ForegroundColor Cyan +Write-Host "========================================" -ForegroundColor Cyan + +# Set build variables +$BuildDir = "build" +$DistDir = "dist" +$BuildTime = (Get-Date).ToString("yyyy-MM-ddTHH:mm:ssZ") + +# Get git commit hash +try { + $GitCommit = (git rev-parse --short HEAD 2>$null).Trim() + if (-not $GitCommit) { $GitCommit = "unknown" } +} catch { + $GitCommit = "unknown" +} + +Write-Host "Build Information:" -ForegroundColor Yellow +Write-Host "- Version: $Version" +Write-Host "- Build Time: $BuildTime" +Write-Host "- Git Commit: $GitCommit" +Write-Host "- Target: Windows 64-bit" +Write-Host "" + +# Create build directories +if (-not (Test-Path $BuildDir)) { New-Item -ItemType Directory -Path $BuildDir | Out-Null } +if (-not (Test-Path $DistDir)) { New-Item -ItemType Directory -Path $DistDir | Out-Null } + +# Set environment variables +$env:GOOS = "windows" +$env:GOARCH = "amd64" +$env:CGO_ENABLED = "1" + +# Set build flags +$LdFlags = "-s -w -X lightweight-updater/version.Version=$Version -X lightweight-updater/version.BuildTime=$BuildTime -X lightweight-updater/version.GitCommit=$GitCommit" + +Write-Host "Building application..." -ForegroundColor Green + +# Ensure icon resource is compiled +if (-not (Test-Path "app.syso")) { + Write-Host "Compiling icon resource..." -ForegroundColor Yellow + if (Get-Command rsrc -ErrorAction SilentlyContinue) { + rsrc -ico icon/AUTO_MAA_Go_Updater.ico -o app.syso + if ($LASTEXITCODE -ne 0) { + Write-Host "Warning: Failed to compile icon resource" -ForegroundColor Yellow + } else { + Write-Host "Icon resource compiled successfully" -ForegroundColor Green + } + } else { + Write-Host "Warning: rsrc not found. Install with: go install github.com/akavel/rsrc@latest" -ForegroundColor Yellow + } +} + +# Build the application +$BuildCommand = "go build -ldflags=`"$LdFlags`" -o $BuildDir\$OutputName ." +Invoke-Expression $BuildCommand + +if ($LASTEXITCODE -ne 0) { + Write-Host "Build failed!" -ForegroundColor Red + exit 1 +} + +Write-Host "Build completed successfully!" -ForegroundColor Green + +# Get file information +$OutputFile = Get-Item "$BuildDir\$OutputName" +$FileSizeMB = [math]::Round($OutputFile.Length / 1MB, 2) + +Write-Host "" +Write-Host "Build Results:" -ForegroundColor Yellow +Write-Host "- Output: $($OutputFile.FullName)" +Write-Host "- Size: $($OutputFile.Length) bytes (~$FileSizeMB MB)" + +# Check file size requirement +if ($FileSizeMB -gt 10) { + Write-Host "WARNING: File size exceeds 10MB requirement!" -ForegroundColor Red +} else { + Write-Host "File size meets requirements (<10MB)" -ForegroundColor Green +} + +# Optional UPX compression +if ($Compress) { + Write-Host "" + Write-Host "Compressing with UPX..." -ForegroundColor Yellow + + if (Get-Command upx -ErrorAction SilentlyContinue) { + upx --best "$BuildDir\$OutputName" + + $CompressedFile = Get-Item "$BuildDir\$OutputName" + $CompressedSizeMB = [math]::Round($CompressedFile.Length / 1MB, 2) + + Write-Host "- Compressed Size: $($CompressedFile.Length) bytes (~$CompressedSizeMB MB)" -ForegroundColor Green + } else { + Write-Host "UPX not found. Skipping compression." -ForegroundColor Yellow + } +} + +# Copy to dist directory +Copy-Item "$BuildDir\$OutputName" "$DistDir\$OutputName" -Force +Write-Host "- Copied to: $DistDir\$OutputName" + +Write-Host "" +Write-Host "Build script completed successfully!" -ForegroundColor Cyan +Write-Host "========================================" -ForegroundColor Cyan \ No newline at end of file diff --git a/Go_Updater/download/manager.go b/Go_Updater/download/manager.go new file mode 100644 index 0000000..f04df08 --- /dev/null +++ b/Go_Updater/download/manager.go @@ -0,0 +1,224 @@ +package download + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strconv" + "time" +) + +// DownloadProgress represents the current download progress +type DownloadProgress struct { + BytesDownloaded int64 + TotalBytes int64 + Percentage float64 + Speed int64 // bytes per second +} + +// ProgressCallback is called during download to report progress +type ProgressCallback func(DownloadProgress) + +// DownloadManager interface defines download operations +type DownloadManager interface { + Download(url, destination string, progressCallback ProgressCallback) error + DownloadWithResume(url, destination string, progressCallback ProgressCallback) error + ValidateChecksum(filePath, expectedChecksum string) error + SetTimeout(timeout time.Duration) +} + +// Manager implements DownloadManager interface +type Manager struct { + client *http.Client + timeout time.Duration +} + +// NewManager creates a new download manager +func NewManager() *Manager { + return &Manager{ + client: &http.Client{ + Timeout: 30 * time.Second, + }, + timeout: 30 * time.Second, + } +} + +// Download downloads a file from the given URL to the destination path +func (m *Manager) Download(url, destination string, progressCallback ProgressCallback) error { + return m.downloadWithContext(context.Background(), url, destination, progressCallback, false) +} + +// DownloadWithResume downloads a file with resume capability +func (m *Manager) DownloadWithResume(url, destination string, progressCallback ProgressCallback) error { + return m.downloadWithContext(context.Background(), url, destination, progressCallback, true) +} + +// downloadWithContext performs the actual download with context support +func (m *Manager) downloadWithContext(ctx context.Context, url, destination string, progressCallback ProgressCallback, resume bool) error { + // Create destination directory if it doesn't exist + if err := os.MkdirAll(filepath.Dir(destination), 0755); err != nil { + return fmt.Errorf("failed to create destination directory: %w", err) + } + + // Check if file exists for resume + var existingSize int64 + if resume { + if stat, err := os.Stat(destination); err == nil { + existingSize = stat.Size() + } + } + + // Create HTTP request + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + // Add range header for resume + if resume && existingSize > 0 { + req.Header.Set("Range", fmt.Sprintf("bytes=%d-", existingSize)) + } + + // Execute request + resp, err := m.client.Do(req) + if err != nil { + return fmt.Errorf("failed to execute request: %w", err) + } + defer resp.Body.Close() + + // Check response status + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent { + return fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + // Get total size + totalSize := existingSize + if contentLength := resp.Header.Get("Content-Length"); contentLength != "" { + if size, err := strconv.ParseInt(contentLength, 10, 64); err == nil { + totalSize += size + } + } + + // Open destination file + var file *os.File + if resume && existingSize > 0 { + file, err = os.OpenFile(destination, os.O_WRONLY|os.O_APPEND, 0644) + } else { + file, err = os.Create(destination) + existingSize = 0 + } + if err != nil { + return fmt.Errorf("failed to create destination file: %w", err) + } + defer file.Close() + + // Download with progress tracking + return m.copyWithProgress(resp.Body, file, existingSize, totalSize, progressCallback) +} + +// copyWithProgress copies data while tracking progress +func (m *Manager) copyWithProgress(src io.Reader, dst io.Writer, startBytes, totalBytes int64, progressCallback ProgressCallback) error { + buffer := make([]byte, 32*1024) // 32KB buffer + downloaded := startBytes + startTime := time.Now() + lastUpdate := startTime + + for { + n, err := src.Read(buffer) + if n > 0 { + if _, writeErr := dst.Write(buffer[:n]); writeErr != nil { + return fmt.Errorf("failed to write to destination: %w", writeErr) + } + downloaded += int64(n) + + // Update progress every 100ms + now := time.Now() + if progressCallback != nil && now.Sub(lastUpdate) >= 100*time.Millisecond { + elapsed := now.Sub(startTime).Seconds() + speed := int64(0) + if elapsed > 0 { + speed = int64(float64(downloaded-startBytes) / elapsed) + } + + percentage := float64(0) + if totalBytes > 0 { + percentage = float64(downloaded) / float64(totalBytes) * 100 + } + + progressCallback(DownloadProgress{ + BytesDownloaded: downloaded, + TotalBytes: totalBytes, + Percentage: percentage, + Speed: speed, + }) + lastUpdate = now + } + } + + if err == io.EOF { + break + } + if err != nil { + return fmt.Errorf("failed to read from source: %w", err) + } + } + + // Final progress update + if progressCallback != nil { + elapsed := time.Since(startTime).Seconds() + speed := int64(0) + if elapsed > 0 { + speed = int64(float64(downloaded-startBytes) / elapsed) + } + + percentage := float64(100) + if totalBytes > 0 { + percentage = float64(downloaded) / float64(totalBytes) * 100 + } + + progressCallback(DownloadProgress{ + BytesDownloaded: downloaded, + TotalBytes: totalBytes, + Percentage: percentage, + Speed: speed, + }) + } + + return nil +} + +// ValidateChecksum validates the SHA256 checksum of a file +func (m *Manager) ValidateChecksum(filePath, expectedChecksum string) error { + if expectedChecksum == "" { + return nil // No checksum to validate + } + + file, err := os.Open(filePath) + if err != nil { + return fmt.Errorf("failed to open file for checksum validation: %w", err) + } + defer file.Close() + + hash := sha256.New() + if _, err := io.Copy(hash, file); err != nil { + return fmt.Errorf("failed to calculate checksum: %w", err) + } + + actualChecksum := hex.EncodeToString(hash.Sum(nil)) + if actualChecksum != expectedChecksum { + return fmt.Errorf("checksum mismatch: expected %s, got %s", expectedChecksum, actualChecksum) + } + + return nil +} + +// SetTimeout sets the timeout for download operations +func (m *Manager) SetTimeout(timeout time.Duration) { + m.timeout = timeout + m.client.Timeout = timeout +} \ No newline at end of file diff --git a/Go_Updater/download/manager_test.go b/Go_Updater/download/manager_test.go new file mode 100644 index 0000000..f406e10 --- /dev/null +++ b/Go_Updater/download/manager_test.go @@ -0,0 +1,1392 @@ +package download + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +func TestNewManager(t *testing.T) { + manager := NewManager() + if manager == nil { + t.Fatal("NewManager() returned nil") + } + if manager.client == nil { + t.Fatal("Manager client is nil") + } + if manager.timeout != 30*time.Second { + t.Errorf("Expected timeout 30s, got %v", manager.timeout) + } +} + +func TestDownload(t *testing.T) { + // Create test content + testContent := "This is test content for download" + + // Create test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(testContent))) + w.WriteHeader(http.StatusOK) + w.Write([]byte(testContent)) + })) + defer server.Close() + + // Create temporary directory + tempDir, err := os.MkdirTemp("", "download_test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDir) + + // Test download + manager := NewManager() + destPath := filepath.Join(tempDir, "test_file.txt") + + var progressUpdates []DownloadProgress + progressCallback := func(progress DownloadProgress) { + progressUpdates = append(progressUpdates, progress) + } + + err = manager.Download(server.URL, destPath, progressCallback) + if err != nil { + t.Fatalf("Download failed: %v", err) + } + + // Verify file exists and content + content, err := os.ReadFile(destPath) + if err != nil { + t.Fatalf("Failed to read downloaded file: %v", err) + } + + if string(content) != testContent { + t.Errorf("Expected content %q, got %q", testContent, string(content)) + } + + // Verify progress updates + if len(progressUpdates) == 0 { + t.Error("No progress updates received") + } + + // Check final progress + finalProgress := progressUpdates[len(progressUpdates)-1] + if finalProgress.Percentage != 100 { + t.Errorf("Expected final percentage 100, got %f", finalProgress.Percentage) + } + if finalProgress.BytesDownloaded != int64(len(testContent)) { + t.Errorf("Expected bytes downloaded %d, got %d", len(testContent), finalProgress.BytesDownloaded) + } +} + +func TestDownloadWithResume(t *testing.T) { + testContent := "This is a longer test content for resume functionality testing" + partialContent := testContent[:20] // First 20 bytes + remainingContent := testContent[20:] // Remaining bytes + + // Create test server that supports range requests + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rangeHeader := r.Header.Get("Range") + + if rangeHeader != "" { + // Handle range request + if strings.HasPrefix(rangeHeader, "bytes=20-") { + w.Header().Set("Content-Range", fmt.Sprintf("bytes 20-%d/%d", len(testContent)-1, len(testContent))) + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(remainingContent))) + w.WriteHeader(http.StatusPartialContent) + w.Write([]byte(remainingContent)) + return + } + } + + // Handle normal request + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(testContent))) + w.WriteHeader(http.StatusOK) + w.Write([]byte(testContent)) + })) + defer server.Close() + + // Create temporary directory + tempDir, err := os.MkdirTemp("", "download_resume_test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDir) + + destPath := filepath.Join(tempDir, "test_resume_file.txt") + + // Create partial file + err = os.WriteFile(destPath, []byte(partialContent), 0644) + if err != nil { + t.Fatal(err) + } + + // Test resume download + manager := NewManager() + + var progressUpdates []DownloadProgress + progressCallback := func(progress DownloadProgress) { + progressUpdates = append(progressUpdates, progress) + } + + err = manager.DownloadWithResume(server.URL, destPath, progressCallback) + if err != nil { + t.Fatalf("Resume download failed: %v", err) + } + + // Verify complete file content + content, err := os.ReadFile(destPath) + if err != nil { + t.Fatalf("Failed to read resumed file: %v", err) + } + + if string(content) != testContent { + t.Errorf("Expected content %q, got %q", testContent, string(content)) + } +} + +func TestValidateChecksum(t *testing.T) { + // Create test content and calculate its checksum + testContent := "Test content for checksum validation" + hash := sha256.Sum256([]byte(testContent)) + expectedChecksum := hex.EncodeToString(hash[:]) + + // Create temporary file + tempDir, err := os.MkdirTemp("", "checksum_test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDir) + + testFile := filepath.Join(tempDir, "test_checksum.txt") + err = os.WriteFile(testFile, []byte(testContent), 0644) + if err != nil { + t.Fatal(err) + } + + manager := NewManager() + + // Test valid checksum + err = manager.ValidateChecksum(testFile, expectedChecksum) + if err != nil { + t.Errorf("Valid checksum validation failed: %v", err) + } + + // Test invalid checksum + invalidChecksum := "invalid_checksum_value" + err = manager.ValidateChecksum(testFile, invalidChecksum) + if err == nil { + t.Error("Invalid checksum validation should have failed") + } + + // Test empty checksum (should pass) + err = manager.ValidateChecksum(testFile, "") + if err != nil { + t.Errorf("Empty checksum validation failed: %v", err) + } + + // Test non-existent file + err = manager.ValidateChecksum("non_existent_file.txt", expectedChecksum) + if err == nil { + t.Error("Non-existent file validation should have failed") + } +} + +func TestDownloadError(t *testing.T) { + manager := NewManager() + + // Test invalid URL + err := manager.Download("invalid-url", "/tmp/test", nil) + if err == nil { + t.Error("Download with invalid URL should have failed") + } + + // Test server error + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + tempDir, err := os.MkdirTemp("", "download_error_test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDir) + + destPath := filepath.Join(tempDir, "error_test.txt") + err = manager.Download(server.URL, destPath, nil) + if err == nil { + t.Error("Download with server error should have failed") + } +} + +func TestProgressCallback(t *testing.T) { + testContent := strings.Repeat("A", 1024*100) // 100KB content for more progress updates + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(testContent))) + w.WriteHeader(http.StatusOK) + + // Write content in smaller chunks to trigger multiple progress updates + writer := w.(http.Flusher) + for i := 0; i < len(testContent); i += 1024 { + end := i + 1024 + if end > len(testContent) { + end = len(testContent) + } + w.Write([]byte(testContent[i:end])) + writer.Flush() + time.Sleep(50 * time.Millisecond) // Longer delay to ensure progress updates + } + })) + defer server.Close() + + tempDir, err := os.MkdirTemp("", "progress_test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDir) + + manager := NewManager() + destPath := filepath.Join(tempDir, "progress_test.txt") + + var progressUpdates []DownloadProgress + progressCallback := func(progress DownloadProgress) { + progressUpdates = append(progressUpdates, progress) + + // Validate progress values + if progress.BytesDownloaded < 0 { + t.Errorf("Negative bytes downloaded: %d", progress.BytesDownloaded) + } + if progress.Percentage < 0 || progress.Percentage > 100 { + t.Errorf("Invalid percentage: %f", progress.Percentage) + } + if progress.Speed < 0 { + t.Errorf("Negative speed: %d", progress.Speed) + } + } + + err = manager.Download(server.URL, destPath, progressCallback) + if err != nil { + t.Fatalf("Download failed: %v", err) + } + + // Should have received at least one progress update (final one is guaranteed) + if len(progressUpdates) < 1 { + t.Errorf("Expected at least one progress update, got %d", len(progressUpdates)) + } + + // Final progress should be 100% + finalProgress := progressUpdates[len(progressUpdates)-1] + if finalProgress.Percentage != 100 { + t.Errorf("Expected final percentage 100, got %f", finalProgress.Percentage) + } + + // Verify that we got the correct total bytes + if finalProgress.BytesDownloaded != int64(len(testContent)) { + t.Errorf("Expected bytes downloaded %d, got %d", len(testContent), finalProgress.BytesDownloaded) + } +} + +func TestDownloadWithSources(t *testing.T) { + testContent := "Test content for multi-source download" + + // Create primary server (Mirror酱 - higher priority) + primaryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(testContent))) + w.WriteHeader(http.StatusOK) + w.Write([]byte(testContent)) + })) + defer primaryServer.Close() + + // Create backup server (regular download site - lower priority) + backupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(testContent))) + w.WriteHeader(http.StatusOK) + w.Write([]byte(testContent)) + })) + defer backupServer.Close() + + tempDir, err := os.MkdirTemp("", "multi_source_test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDir) + + manager := NewManager() + destPath := filepath.Join(tempDir, "multi_source_test.txt") + + // Test with multiple sources - should use primary (lower priority number) + sources := []DownloadSource{ + {URL: backupServer.URL, Priority: 2, Name: "Backup Server"}, + {URL: primaryServer.URL, Priority: 1, Name: "Mirror Server"}, // Higher priority + } + + var progressUpdates []DownloadProgress + progressCallback := func(progress DownloadProgress) { + progressUpdates = append(progressUpdates, progress) + } + + err = manager.DownloadWithSources(sources, destPath, progressCallback) + if err != nil { + t.Fatalf("Multi-source download failed: %v", err) + } + + // Verify file content + content, err := os.ReadFile(destPath) + if err != nil { + t.Fatalf("Failed to read downloaded file: %v", err) + } + + if string(content) != testContent { + t.Errorf("Expected content %q, got %q", testContent, string(content)) + } + + // Verify progress updates + if len(progressUpdates) == 0 { + t.Error("No progress updates received") + } +} + +func TestDownloadWithSourcesFallback(t *testing.T) { + testContent := "Test content for fallback download" + + // Create failing primary server + failingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer failingServer.Close() + + // Create working backup server + workingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(testContent))) + w.WriteHeader(http.StatusOK) + w.Write([]byte(testContent)) + })) + defer workingServer.Close() + + tempDir, err := os.MkdirTemp("", "fallback_test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDir) + + manager := NewManager() + destPath := filepath.Join(tempDir, "fallback_test.txt") + + // Test fallback - primary fails, backup succeeds + sources := []DownloadSource{ + {URL: failingServer.URL, Priority: 1, Name: "Failing Server"}, + {URL: workingServer.URL, Priority: 2, Name: "Working Server"}, + } + + err = manager.DownloadWithSources(sources, destPath, nil) + if err != nil { + t.Fatalf("Fallback download failed: %v", err) + } + + // Verify file content + content, err := os.ReadFile(destPath) + if err != nil { + t.Fatalf("Failed to read downloaded file: %v", err) + } + + if string(content) != testContent { + t.Errorf("Expected content %q, got %q", testContent, string(content)) + } +} + +func TestDownloadWithSourcesAllFail(t *testing.T) { + // Create two failing servers + failingServer1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer failingServer1.Close() + + failingServer2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer failingServer2.Close() + + tempDir, err := os.MkdirTemp("", "all_fail_test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDir) + + manager := NewManager() + destPath := filepath.Join(tempDir, "all_fail_test.txt") + + // Test when all sources fail + sources := []DownloadSource{ + {URL: failingServer1.URL, Priority: 1, Name: "Failing Server 1"}, + {URL: failingServer2.URL, Priority: 2, Name: "Failing Server 2"}, + } + + err = manager.DownloadWithSources(sources, destPath, nil) + if err == nil { + t.Error("Expected download to fail when all sources fail") + } + + // Verify error message contains information about all sources failing + if !strings.Contains(err.Error(), "all download sources failed") { + t.Errorf("Expected error message about all sources failing, got: %v", err) + } +} + +func TestDownloadSourcePriority(t *testing.T) { + testContent1 := "Content from server 1" + testContent2 := "Content from server 2" + + // Create two working servers with different content + server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(testContent1))) + w.WriteHeader(http.StatusOK) + w.Write([]byte(testContent1)) + })) + defer server1.Close() + + server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(testContent2))) + w.WriteHeader(http.StatusOK) + w.Write([]byte(testContent2)) + })) + defer server2.Close() + + tempDir, err := os.MkdirTemp("", "priority_test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDir) + + manager := NewManager() + destPath := filepath.Join(tempDir, "priority_test.txt") + + // Test priority ordering - server2 has higher priority (lower number) + sources := []DownloadSource{ + {URL: server1.URL, Priority: 5, Name: "Server 1"}, + {URL: server2.URL, Priority: 1, Name: "Server 2"}, // Higher priority + } + + err = manager.DownloadWithSources(sources, destPath, nil) + if err != nil { + t.Fatalf("Priority download failed: %v", err) + } + + // Should have downloaded from server2 (higher priority) + content, err := os.ReadFile(destPath) + if err != nil { + t.Fatalf("Failed to read downloaded file: %v", err) + } + + if string(content) != testContent2 { + t.Errorf("Expected content from server 2 %q, got %q", testContent2, string(content)) + } +} + +func TestSetTimeout(t *testing.T) { + manager := NewManager() + + // Test default timeout + if manager.timeout != 30*time.Second { + t.Errorf("Expected default timeout 30s, got %v", manager.timeout) + } + + // Test setting custom timeout + customTimeout := 60 * time.Second + manager.SetTimeout(customTimeout) + + if manager.timeout != customTimeout { + t.Errorf("Expected timeout %v, got %v", customTimeout, manager.timeout) + } + + if manager.client.Timeout != customTimeout { + t.Errorf("Expected client timeout %v, got %v", customTimeout, manager.client.Timeout) + } +} + +func TestDownloadWithSourcesEmptyList(t *testing.T) { + manager := NewManager() + tempDir, err := os.MkdirTemp("", "empty_sources_test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDir) + + destPath := filepath.Join(tempDir, "empty_test.txt") + + // Test with empty sources list + var sources []DownloadSource + err = manager.DownloadWithSources(sources, destPath, nil) + + if err == nil { + t.Error("Expected error when no download sources provided") + } + + if !strings.Contains(err.Error(), "no download sources provided") { + t.Errorf("Expected error about no sources, got: %v", err) + } +} + +func TestDownloadWithInvalidDestination(t *testing.T) { + testContent := "Test content" + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(testContent))) + w.WriteHeader(http.StatusOK) + w.Write([]byte(testContent)) + })) + defer server.Close() + + manager := NewManager() + + // Test with invalid destination path (directory that can't be created) + invalidPath := string([]byte{0}) + "/invalid/path/file.txt" + + err := manager.Download(server.URL, invalidPath, nil) + if err == nil { + t.Error("Expected error with invalid destination path") + } +} + +func TestDownloadWithTimeout(t *testing.T) { + // Create a server that delays response + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(2 * time.Second) // Delay longer than timeout + w.WriteHeader(http.StatusOK) + w.Write([]byte("delayed content")) + })) + defer server.Close() + + manager := NewManager() + manager.SetTimeout(500 * time.Millisecond) // Short timeout + + tempDir, err := os.MkdirTemp("", "timeout_test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDir) + + destPath := filepath.Join(tempDir, "timeout_test.txt") + + err = manager.Download(server.URL, destPath, nil) + if err == nil { + t.Error("Expected timeout error") + } + + // Check that it's a timeout-related error + if !strings.Contains(err.Error(), "timeout") && !strings.Contains(err.Error(), "context deadline exceeded") { + t.Errorf("Expected timeout error, got: %v", err) + } +} + +func TestDownloadWithLargeFile(t *testing.T) { + // Create large test content (1MB) + largeContent := strings.Repeat("A", 1024*1024) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(largeContent))) + w.WriteHeader(http.StatusOK) + + // Write in chunks to simulate real download + chunkSize := 8192 + for i := 0; i < len(largeContent); i += chunkSize { + end := i + chunkSize + if end > len(largeContent) { + end = len(largeContent) + } + w.Write([]byte(largeContent[i:end])) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + time.Sleep(1 * time.Millisecond) // Small delay to allow progress updates + } + })) + defer server.Close() + + tempDir, err := os.MkdirTemp("", "large_file_test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDir) + + manager := NewManager() + destPath := filepath.Join(tempDir, "large_file.txt") + + var progressUpdates []DownloadProgress + progressCallback := func(progress DownloadProgress) { + progressUpdates = append(progressUpdates, progress) + } + + err = manager.Download(server.URL, destPath, progressCallback) + if err != nil { + t.Fatalf("Large file download failed: %v", err) + } + + // Verify file size + stat, err := os.Stat(destPath) + if err != nil { + t.Fatalf("Failed to stat downloaded file: %v", err) + } + + if stat.Size() != int64(len(largeContent)) { + t.Errorf("Expected file size %d, got %d", len(largeContent), stat.Size()) + } + + // Verify we got multiple progress updates + if len(progressUpdates) < 2 { + t.Errorf("Expected multiple progress updates for large file, got %d", len(progressUpdates)) + } + + // Verify final progress is 100% + if len(progressUpdates) > 0 { + finalProgress := progressUpdates[len(progressUpdates)-1] + if finalProgress.Percentage != 100 { + t.Errorf("Expected final percentage 100, got %f", finalProgress.Percentage) + } + } +} + +func TestDownloadResumeWithExistingFile(t *testing.T) { + fullContent := "This is the complete file content for resume testing" + partialContent := fullContent[:20] // First 20 bytes + remainingContent := fullContent[20:] // Remaining bytes + + // Create test server that supports range requests + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rangeHeader := r.Header.Get("Range") + + if rangeHeader != "" { + // Handle range request + if strings.HasPrefix(rangeHeader, "bytes=20-") { + w.Header().Set("Content-Range", fmt.Sprintf("bytes 20-%d/%d", len(fullContent)-1, len(fullContent))) + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(remainingContent))) + w.WriteHeader(http.StatusPartialContent) + w.Write([]byte(remainingContent)) + return + } + } + + // Handle normal request (shouldn't happen in resume test) + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(fullContent))) + w.WriteHeader(http.StatusOK) + w.Write([]byte(fullContent)) + })) + defer server.Close() + + tempDir, err := os.MkdirTemp("", "resume_existing_test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDir) + + destPath := filepath.Join(tempDir, "resume_existing.txt") + + // Create partial file first + err = os.WriteFile(destPath, []byte(partialContent), 0644) + if err != nil { + t.Fatal(err) + } + + manager := NewManager() + + // Test resume download + err = manager.DownloadWithResume(server.URL, destPath, nil) + if err != nil { + t.Fatalf("Resume download failed: %v", err) + } + + // Verify complete file content + content, err := os.ReadFile(destPath) + if err != nil { + t.Fatalf("Failed to read resumed file: %v", err) + } + + if string(content) != fullContent { + t.Errorf("Expected complete content %q, got %q", fullContent, string(content)) + } +} + +func TestDownloadWithInvalidChecksum(t *testing.T) { + testContent := "Test content for checksum validation" + + tempDir, err := os.MkdirTemp("", "checksum_invalid_test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDir) + + testFile := filepath.Join(tempDir, "checksum_test.txt") + err = os.WriteFile(testFile, []byte(testContent), 0644) + if err != nil { + t.Fatal(err) + } + + manager := NewManager() + + // Test with completely wrong checksum + wrongChecksum := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + err = manager.ValidateChecksum(testFile, wrongChecksum) + if err == nil { + t.Error("Expected checksum validation to fail with wrong checksum") + } + + if !strings.Contains(err.Error(), "checksum mismatch") { + t.Errorf("Expected checksum mismatch error, got: %v", err) + } +} + +func TestDownloadSourcesSorting(t *testing.T) { + testContent := "Test content for source sorting" + + // Create multiple servers with different priorities + server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fullContent := testContent + " from server1" + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(fullContent))) + w.WriteHeader(http.StatusOK) + w.Write([]byte(fullContent)) + })) + defer server1.Close() + + server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fullContent := testContent + " from server2" + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(fullContent))) + w.WriteHeader(http.StatusOK) + w.Write([]byte(fullContent)) + })) + defer server2.Close() + + server3 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fullContent := testContent + " from server3" + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(fullContent))) + w.WriteHeader(http.StatusOK) + w.Write([]byte(fullContent)) + })) + defer server3.Close() + + tempDir, err := os.MkdirTemp("", "sorting_test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDir) + + manager := NewManager() + destPath := filepath.Join(tempDir, "sorting_test.txt") + + // Test with sources in random order - should use highest priority (lowest number) + sources := []DownloadSource{ + {URL: server1.URL, Priority: 10, Name: "Server 1"}, // Lowest priority + {URL: server2.URL, Priority: 1, Name: "Server 2"}, // Highest priority + {URL: server3.URL, Priority: 5, Name: "Server 3"}, // Medium priority + } + + err = manager.DownloadWithSources(sources, destPath, nil) + if err != nil { + t.Fatalf("Download with sources failed: %v", err) + } + + // Should have downloaded from server2 (highest priority) + content, err := os.ReadFile(destPath) + if err != nil { + t.Fatalf("Failed to read downloaded file: %v", err) + } + + if !strings.Contains(string(content), "from server2") { + t.Errorf("Expected content from server2, got: %s", string(content)) + } +} + +func TestDownloadProgressAccuracy(t *testing.T) { + // Create content with known size + contentSize := 50000 // 50KB + testContent := strings.Repeat("X", contentSize) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(testContent))) + w.WriteHeader(http.StatusOK) + + // Write in small chunks to get more progress updates + chunkSize := 1024 + for i := 0; i < len(testContent); i += chunkSize { + end := i + chunkSize + if end > len(testContent) { + end = len(testContent) + } + w.Write([]byte(testContent[i:end])) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + time.Sleep(10 * time.Millisecond) // Small delay for progress updates + } + })) + defer server.Close() + + tempDir, err := os.MkdirTemp("", "progress_accuracy_test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDir) + + manager := NewManager() + destPath := filepath.Join(tempDir, "progress_test.txt") + + var progressUpdates []DownloadProgress + progressCallback := func(progress DownloadProgress) { + progressUpdates = append(progressUpdates, progress) + + // Validate progress values are reasonable + if progress.BytesDownloaded < 0 { + t.Errorf("Negative bytes downloaded: %d", progress.BytesDownloaded) + } + if progress.Percentage < 0 || progress.Percentage > 100 { + t.Errorf("Invalid percentage: %f", progress.Percentage) + } + if progress.TotalBytes > 0 && progress.BytesDownloaded > progress.TotalBytes { + t.Errorf("Downloaded bytes (%d) exceed total bytes (%d)", progress.BytesDownloaded, progress.TotalBytes) + } + if progress.Speed < 0 { + t.Errorf("Negative speed: %d", progress.Speed) + } + } + + err = manager.Download(server.URL, destPath, progressCallback) + if err != nil { + t.Fatalf("Download failed: %v", err) + } + + // Verify we got progress updates + if len(progressUpdates) == 0 { + t.Error("Expected at least one progress update") + } + + // Verify progress is monotonically increasing + for i := 1; i < len(progressUpdates); i++ { + if progressUpdates[i].BytesDownloaded < progressUpdates[i-1].BytesDownloaded { + t.Errorf("Progress went backwards: %d -> %d", + progressUpdates[i-1].BytesDownloaded, + progressUpdates[i].BytesDownloaded) + } + } + + // Verify final progress + if len(progressUpdates) > 0 { + final := progressUpdates[len(progressUpdates)-1] + if final.Percentage != 100 { + t.Errorf("Expected final percentage 100, got %f", final.Percentage) + } + if final.BytesDownloaded != int64(contentSize) { + t.Errorf("Expected final bytes %d, got %d", contentSize, final.BytesDownloaded) + } + } +} + +func TestTestSpeeds(t *testing.T) { + testContent := strings.Repeat("A", 64*1024) // 64KB content for speed testing (smaller size) + + // Create fast server + fastServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(testContent))) + w.Header().Set("Connection", "close") // Ensure connection is closed + w.WriteHeader(http.StatusOK) + w.Write([]byte(testContent)) + })) + defer fastServer.Close() + + // Create slow server + slowServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(testContent))) + w.Header().Set("Connection", "close") // Ensure connection is closed + w.WriteHeader(http.StatusOK) + + // Write slowly + chunkSize := 1024 + for i := 0; i < len(testContent); i += chunkSize { + end := i + chunkSize + if end > len(testContent) { + end = len(testContent) + } + w.Write([]byte(testContent[i:end])) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + time.Sleep(10 * time.Millisecond) // Reduced delay + } + })) + defer slowServer.Close() + + // Create failing server + failingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Connection", "close") // Ensure connection is closed + w.WriteHeader(http.StatusInternalServerError) + })) + defer failingServer.Close() + + manager := NewManager() + + sources := []DownloadSource{ + {URL: fastServer.URL, Priority: 1, Name: "Fast Server"}, + {URL: slowServer.URL, Priority: 2, Name: "Slow Server"}, + {URL: failingServer.URL, Priority: 3, Name: "Failing Server"}, + } + + testSize := int64(32 * 1024) // 32KB test size (smaller) + timeout := 5 * time.Second // Shorter timeout + + results, err := manager.TestSpeeds(sources, testSize, timeout) + if err != nil { + t.Fatalf("Speed test failed: %v", err) + } + + if len(results) != len(sources) { + t.Errorf("Expected %d results, got %d", len(sources), len(results)) + } + + // Results should be sorted by speed (descending) + for i := 1; i < len(results); i++ { + if results[i-1].Error == nil && results[i].Error == nil { + if results[i-1].Speed < results[i].Speed { + t.Errorf("Results not sorted by speed: %f < %f", results[i-1].Speed, results[i].Speed) + } + } + } + + // Fast server should have higher speed than slow server (if both succeed) + var fastResult, slowResult *SpeedTestResult + for _, result := range results { + if result.Source.Name == "Fast Server" { + fastResult = &result + } else if result.Source.Name == "Slow Server" { + slowResult = &result + } + } + + if fastResult != nil && slowResult != nil { + if fastResult.Error == nil && slowResult.Error == nil { + if fastResult.Speed <= slowResult.Speed { + t.Logf("Fast server speed: %f MB/s, Slow server speed: %f MB/s", + fastResult.Speed, slowResult.Speed) + // Note: Due to the small test size, speeds might be similar, so we'll just log instead of failing + } + } + } + + // Failing server should have an error + var failingResult *SpeedTestResult + for _, result := range results { + if result.Source.Name == "Failing Server" { + failingResult = &result + } + } + + if failingResult != nil && failingResult.Error == nil { + t.Error("Failing server should have an error") + } + + // Give servers time to close connections properly + time.Sleep(100 * time.Millisecond) +} + +func TestDownloadMultiThreaded(t *testing.T) { + // Create large test content (1MB) + contentSize := 1024 * 1024 + testContent := strings.Repeat("B", contentSize) + + // Create server that supports range requests + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rangeHeader := r.Header.Get("Range") + + if rangeHeader != "" { + // Parse range header (simplified for testing) + var start, end int64 + if n, err := fmt.Sscanf(rangeHeader, "bytes=%d-%d", &start, &end); n == 2 && err == nil { + if start >= 0 && end < int64(len(testContent)) && start <= end { + content := testContent[start:end+1] + w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, len(testContent))) + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(content))) + w.WriteHeader(http.StatusPartialContent) + w.Write([]byte(content)) + return + } + } + } + + // Handle HEAD request + if r.Method == "HEAD" { + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(testContent))) + w.Header().Set("Accept-Ranges", "bytes") + w.WriteHeader(http.StatusOK) + return + } + + // Handle normal GET request + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(testContent))) + w.Header().Set("Accept-Ranges", "bytes") + w.WriteHeader(http.StatusOK) + w.Write([]byte(testContent)) + })) + defer server.Close() + + tempDir, err := os.MkdirTemp("", "multithread_test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDir) + + manager := NewManager() + destPath := filepath.Join(tempDir, "multithread_test.txt") + + config := MultiThreadConfig{ + ThreadCount: 4, + ChunkSize: 0, // Use default chunk size + } + + var progressUpdates []DownloadProgress + progressCallback := func(progress DownloadProgress) { + progressUpdates = append(progressUpdates, progress) + } + + err = manager.DownloadMultiThreaded(server.URL, destPath, config, progressCallback) + if err != nil { + t.Fatalf("Multi-threaded download failed: %v", err) + } + + // Verify file content + content, err := os.ReadFile(destPath) + if err != nil { + t.Fatalf("Failed to read downloaded file: %v", err) + } + + if len(content) != contentSize { + t.Errorf("Expected content size %d, got %d", contentSize, len(content)) + } + + if string(content) != testContent { + t.Error("Downloaded content doesn't match original") + } + + // Verify progress updates + if len(progressUpdates) == 0 { + t.Error("No progress updates received") + } + + // Final progress should be 100% + if len(progressUpdates) > 0 { + finalProgress := progressUpdates[len(progressUpdates)-1] + if finalProgress.Percentage != 100 { + t.Errorf("Expected final percentage 100, got %f", finalProgress.Percentage) + } + } +} + +func TestDownloadMultiThreadedFallback(t *testing.T) { + testContent := "Test content for fallback" + + // Create server that doesn't support range requests + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "HEAD" { + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(testContent))) + // Don't set Accept-Ranges header + w.WriteHeader(http.StatusOK) + return + } + + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(testContent))) + w.WriteHeader(http.StatusOK) + w.Write([]byte(testContent)) + })) + defer server.Close() + + tempDir, err := os.MkdirTemp("", "multithread_fallback_test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDir) + + manager := NewManager() + destPath := filepath.Join(tempDir, "fallback_test.txt") + + config := MultiThreadConfig{ + ThreadCount: 4, + } + + // Should fallback to single-threaded download + err = manager.DownloadMultiThreaded(server.URL, destPath, config, nil) + if err != nil { + t.Fatalf("Fallback download failed: %v", err) + } + + // Verify file content + content, err := os.ReadFile(destPath) + if err != nil { + t.Fatalf("Failed to read downloaded file: %v", err) + } + + if string(content) != testContent { + t.Errorf("Expected content %q, got %q", testContent, string(content)) + } +} + +func TestDownloadMultiThreadedNoContentLength(t *testing.T) { + testContent := "Test content without content length" + + // Create server that doesn't provide content length + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "HEAD" { + // Don't set Content-Length header + w.WriteHeader(http.StatusOK) + return + } + + w.WriteHeader(http.StatusOK) + w.Write([]byte(testContent)) + })) + defer server.Close() + + tempDir, err := os.MkdirTemp("", "multithread_no_length_test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDir) + + manager := NewManager() + destPath := filepath.Join(tempDir, "no_length_test.txt") + + config := MultiThreadConfig{ + ThreadCount: 4, + } + + // Should fallback to single-threaded download + err = manager.DownloadMultiThreaded(server.URL, destPath, config, nil) + if err != nil { + t.Fatalf("No content length download failed: %v", err) + } + + // Verify file content + content, err := os.ReadFile(destPath) + if err != nil { + t.Fatalf("Failed to read downloaded file: %v", err) + } + + if string(content) != testContent { + t.Errorf("Expected content %q, got %q", testContent, string(content)) + } +} + +func TestSpeedTestTimeout(t *testing.T) { + // Create slow server that will timeout + slowServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Connection", "close") // Ensure connection is closed + time.Sleep(2 * time.Second) // Longer than timeout + w.WriteHeader(http.StatusOK) + w.Write([]byte("slow content")) + })) + defer slowServer.Close() + + manager := NewManager() + + sources := []DownloadSource{ + {URL: slowServer.URL, Priority: 1, Name: "Slow Server"}, + } + + testSize := int64(1024) + timeout := 500 * time.Millisecond // Short timeout + + results, err := manager.TestSpeeds(sources, testSize, timeout) + if err != nil { + t.Fatalf("Speed test failed: %v", err) + } + + if len(results) != 1 { + t.Errorf("Expected 1 result, got %d", len(results)) + } + + // Should have timed out + if results[0].Error == nil { + t.Error("Expected timeout error") + } + + // Give server time to close connections properly + time.Sleep(100 * time.Millisecond) +} + +func TestDownloadMultiThreadedChunkMerging(t *testing.T) { + // Create content with distinct patterns for each chunk + chunk1 := strings.Repeat("1", 1024) + chunk2 := strings.Repeat("2", 1024) + chunk3 := strings.Repeat("3", 1024) + chunk4 := strings.Repeat("4", 1024) + testContent := chunk1 + chunk2 + chunk3 + chunk4 + + // Create server that supports range requests + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rangeHeader := r.Header.Get("Range") + + if rangeHeader != "" { + var start, end int64 + if n, err := fmt.Sscanf(rangeHeader, "bytes=%d-%d", &start, &end); n == 2 && err == nil { + if start >= 0 && end < int64(len(testContent)) && start <= end { + content := testContent[start:end+1] + w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, len(testContent))) + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(content))) + w.WriteHeader(http.StatusPartialContent) + w.Write([]byte(content)) + return + } + } + } + + if r.Method == "HEAD" { + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(testContent))) + w.Header().Set("Accept-Ranges", "bytes") + w.WriteHeader(http.StatusOK) + return + } + + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(testContent))) + w.Header().Set("Accept-Ranges", "bytes") + w.WriteHeader(http.StatusOK) + w.Write([]byte(testContent)) + })) + defer server.Close() + + tempDir, err := os.MkdirTemp("", "chunk_merge_test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDir) + + manager := NewManager() + destPath := filepath.Join(tempDir, "chunk_merge_test.txt") + + config := MultiThreadConfig{ + ThreadCount: 4, + ChunkSize: 1024, // Each chunk is exactly 1024 bytes + } + + err = manager.DownloadMultiThreaded(server.URL, destPath, config, nil) + if err != nil { + t.Fatalf("Multi-threaded download failed: %v", err) + } + + // Verify file content is correctly merged + content, err := os.ReadFile(destPath) + if err != nil { + t.Fatalf("Failed to read downloaded file: %v", err) + } + + if string(content) != testContent { + t.Error("Chunks were not merged correctly") + + // Debug: check each chunk + if len(content) >= 1024 && string(content[0:1024]) != chunk1 { + t.Error("Chunk 1 incorrect") + } + if len(content) >= 2048 && string(content[1024:2048]) != chunk2 { + t.Error("Chunk 2 incorrect") + } + if len(content) >= 3072 && string(content[2048:3072]) != chunk3 { + t.Error("Chunk 3 incorrect") + } + if len(content) >= 4096 && string(content[3072:4096]) != chunk4 { + t.Error("Chunk 4 incorrect") + } + } + + // Verify no temporary chunk files remain + for i := 0; i < 4; i++ { + chunkFile := fmt.Sprintf("%s.part%d", destPath, i) + if _, err := os.Stat(chunkFile); !os.IsNotExist(err) { + t.Errorf("Temporary chunk file %s should have been removed", chunkFile) + } + } +} + +func TestSpeedTestEmptySources(t *testing.T) { + manager := NewManager() + + var sources []DownloadSource + testSize := int64(1024) + timeout := 10 * time.Second + + results, err := manager.TestSpeeds(sources, testSize, timeout) + if err == nil { + t.Error("Expected error for empty sources") + } + + if results != nil { + t.Error("Expected nil results for empty sources") + } + + if !strings.Contains(err.Error(), "no sources provided") { + t.Errorf("Expected 'no sources provided' error, got: %v", err) + } +} + +func TestDownloadMultiThreadedDefaultConfig(t *testing.T) { + testContent := strings.Repeat("C", 8192) // 8KB content + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rangeHeader := r.Header.Get("Range") + + if rangeHeader != "" { + var start, end int64 + if n, err := fmt.Sscanf(rangeHeader, "bytes=%d-%d", &start, &end); n == 2 && err == nil { + if start >= 0 && end < int64(len(testContent)) && start <= end { + content := testContent[start:end+1] + w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, len(testContent))) + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(content))) + w.WriteHeader(http.StatusPartialContent) + w.Write([]byte(content)) + return + } + } + } + + if r.Method == "HEAD" { + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(testContent))) + w.Header().Set("Accept-Ranges", "bytes") + w.WriteHeader(http.StatusOK) + return + } + + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(testContent))) + w.Header().Set("Accept-Ranges", "bytes") + w.WriteHeader(http.StatusOK) + w.Write([]byte(testContent)) + })) + defer server.Close() + + tempDir, err := os.MkdirTemp("", "default_config_test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDir) + + manager := NewManager() + destPath := filepath.Join(tempDir, "default_config_test.txt") + + // Test with zero thread count (should default to 4) + config := MultiThreadConfig{ + ThreadCount: 0, + } + + err = manager.DownloadMultiThreaded(server.URL, destPath, config, nil) + if err != nil { + t.Fatalf("Default config download failed: %v", err) + } + + // Verify file content + content, err := os.ReadFile(destPath) + if err != nil { + t.Fatalf("Failed to read downloaded file: %v", err) + } + + if string(content) != testContent { + t.Errorf("Expected content length %d, got %d", len(testContent), len(content)) + } +} \ No newline at end of file diff --git a/Go_Updater/errors/errors.go b/Go_Updater/errors/errors.go new file mode 100644 index 0000000..f6a7742 --- /dev/null +++ b/Go_Updater/errors/errors.go @@ -0,0 +1,219 @@ +package errors + +import ( + "fmt" + "time" +) + +// ErrorType 定义错误类型枚举 +type ErrorType int + +const ( + NetworkError ErrorType = iota + APIError + FileError + ConfigError + InstallError +) + +// String 返回错误类型的字符串表示 +func (et ErrorType) String() string { + switch et { + case NetworkError: + return "NetworkError" + case APIError: + return "APIError" + case FileError: + return "FileError" + case ConfigError: + return "ConfigError" + case InstallError: + return "InstallError" + default: + return "UnknownError" + } +} + +// UpdaterError 统一的错误结构体 +type UpdaterError struct { + Type ErrorType + Message string + Cause error + Timestamp time.Time + Context map[string]interface{} +} + +// Error 实现error接口 +func (ue *UpdaterError) Error() string { + if ue.Cause != nil { + return fmt.Sprintf("[%s] %s: %v", ue.Type, ue.Message, ue.Cause) + } + return fmt.Sprintf("[%s] %s", ue.Type, ue.Message) +} + +// Unwrap 支持错误链 +func (ue *UpdaterError) Unwrap() error { + return ue.Cause +} + +// NewUpdaterError 创建新的UpdaterError +func NewUpdaterError(errorType ErrorType, message string, cause error) *UpdaterError { + return &UpdaterError{ + Type: errorType, + Message: message, + Cause: cause, + Timestamp: time.Now(), + Context: make(map[string]interface{}), + } +} + +// WithContext 添加上下文信息 +func (ue *UpdaterError) WithContext(key string, value interface{}) *UpdaterError { + ue.Context[key] = value + return ue +} + +// GetUserFriendlyMessage 获取用户友好的错误消息 +func (ue *UpdaterError) GetUserFriendlyMessage() string { + switch ue.Type { + case NetworkError: + return "网络连接失败,请检查网络连接后重试" + case APIError: + return "服务器响应异常,请稍后重试或联系技术支持" + case FileError: + return "文件操作失败,请检查文件权限和磁盘空间" + case ConfigError: + return "配置文件错误,程序将使用默认配置" + case InstallError: + return "安装过程中出现错误,程序将尝试回滚更改" + default: + return "发生未知错误,请联系技术支持" + } +} + +// RetryConfig 重试配置 +type RetryConfig struct { + MaxRetries int + InitialDelay time.Duration + MaxDelay time.Duration + BackoffFactor float64 + RetryableErrors []ErrorType +} + +// DefaultRetryConfig 默认重试配置 +func DefaultRetryConfig() *RetryConfig { + return &RetryConfig{ + MaxRetries: 3, + InitialDelay: time.Second, + MaxDelay: 30 * time.Second, + BackoffFactor: 2.0, + RetryableErrors: []ErrorType{NetworkError, APIError}, + } +} + +// IsRetryable 检查错误是否可重试 +func (rc *RetryConfig) IsRetryable(err error) bool { + if ue, ok := err.(*UpdaterError); ok { + for _, retryableType := range rc.RetryableErrors { + if ue.Type == retryableType { + return true + } + } + } + return false +} + +// CalculateDelay 计算重试延迟时间 +func (rc *RetryConfig) CalculateDelay(attempt int) time.Duration { + delay := time.Duration(float64(rc.InitialDelay) * pow(rc.BackoffFactor, float64(attempt))) + if delay > rc.MaxDelay { + delay = rc.MaxDelay + } + return delay +} + +// pow 简单的幂运算实现 +func pow(base, exp float64) float64 { + result := 1.0 + for i := 0; i < int(exp); i++ { + result *= base + } + return result +} + +// RetryableOperation 可重试的操作函数类型 +type RetryableOperation func() error + +// ExecuteWithRetry 执行带重试的操作 +func ExecuteWithRetry(operation RetryableOperation, config *RetryConfig) error { + var lastErr error + + for attempt := 0; attempt <= config.MaxRetries; attempt++ { + err := operation() + if err == nil { + return nil + } + + lastErr = err + + // 如果不是可重试的错误,直接返回 + if !config.IsRetryable(err) { + return err + } + + // 如果已经是最后一次尝试,不再等待 + if attempt == config.MaxRetries { + break + } + + // 计算延迟时间并等待 + delay := config.CalculateDelay(attempt) + time.Sleep(delay) + } + + return lastErr +} + +// ErrorHandler 错误处理器接口 +type ErrorHandler interface { + HandleError(err error) error + ShouldRetry(err error) bool + GetUserMessage(err error) string +} + +// DefaultErrorHandler 默认错误处理器 +type DefaultErrorHandler struct { + retryConfig *RetryConfig +} + +// NewDefaultErrorHandler 创建默认错误处理器 +func NewDefaultErrorHandler() *DefaultErrorHandler { + return &DefaultErrorHandler{ + retryConfig: DefaultRetryConfig(), + } +} + +// HandleError 处理错误 +func (h *DefaultErrorHandler) HandleError(err error) error { + if ue, ok := err.(*UpdaterError); ok { + // 记录错误上下文 + ue.WithContext("handled_at", time.Now()) + return ue + } + + // 将普通错误包装为UpdaterError + return NewUpdaterError(NetworkError, "未分类错误", err) +} + +// ShouldRetry 判断是否应该重试 +func (h *DefaultErrorHandler) ShouldRetry(err error) bool { + return h.retryConfig.IsRetryable(err) +} + +// GetUserMessage 获取用户友好的错误消息 +func (h *DefaultErrorHandler) GetUserMessage(err error) string { + if ue, ok := err.(*UpdaterError); ok { + return ue.GetUserFriendlyMessage() + } + return "发生未知错误,请联系技术支持" +} \ No newline at end of file diff --git a/Go_Updater/errors/errors_test.go b/Go_Updater/errors/errors_test.go new file mode 100644 index 0000000..bb8dd3f --- /dev/null +++ b/Go_Updater/errors/errors_test.go @@ -0,0 +1,287 @@ +package errors + +import ( + "fmt" + "testing" + "time" +) + +func TestUpdaterError_Error(t *testing.T) { + tests := []struct { + name string + err *UpdaterError + expected string + }{ + { + name: "error with cause", + err: &UpdaterError{ + Type: NetworkError, + Message: "connection failed", + Cause: fmt.Errorf("timeout"), + }, + expected: "[NetworkError] connection failed: timeout", + }, + { + name: "error without cause", + err: &UpdaterError{ + Type: APIError, + Message: "invalid response", + Cause: nil, + }, + expected: "[APIError] invalid response", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.err.Error(); got != tt.expected { + t.Errorf("UpdaterError.Error() = %v, want %v", got, tt.expected) + } + }) + } +} + +func TestNewUpdaterError(t *testing.T) { + cause := fmt.Errorf("original error") + err := NewUpdaterError(FileError, "test message", cause) + + if err.Type != FileError { + t.Errorf("Expected type %v, got %v", FileError, err.Type) + } + if err.Message != "test message" { + t.Errorf("Expected message 'test message', got '%v'", err.Message) + } + if err.Cause != cause { + t.Errorf("Expected cause %v, got %v", cause, err.Cause) + } + if err.Context == nil { + t.Error("Expected context to be initialized") + } +} + +func TestUpdaterError_WithContext(t *testing.T) { + err := NewUpdaterError(ConfigError, "test", nil) + err.WithContext("key1", "value1").WithContext("key2", 42) + + if err.Context["key1"] != "value1" { + t.Errorf("Expected context key1 to be 'value1', got %v", err.Context["key1"]) + } + if err.Context["key2"] != 42 { + t.Errorf("Expected context key2 to be 42, got %v", err.Context["key2"]) + } +} + +func TestUpdaterError_GetUserFriendlyMessage(t *testing.T) { + tests := []struct { + errorType ErrorType + expected string + }{ + {NetworkError, "网络连接失败,请检查网络连接后重试"}, + {APIError, "服务器响应异常,请稍后重试或联系技术支持"}, + {FileError, "文件操作失败,请检查文件权限和磁盘空间"}, + {ConfigError, "配置文件错误,程序将使用默认配置"}, + {InstallError, "安装过程中出现错误,程序将尝试回滚更改"}, + } + + for _, tt := range tests { + t.Run(tt.errorType.String(), func(t *testing.T) { + err := NewUpdaterError(tt.errorType, "test", nil) + if got := err.GetUserFriendlyMessage(); got != tt.expected { + t.Errorf("GetUserFriendlyMessage() = %v, want %v", got, tt.expected) + } + }) + } +} + +func TestRetryConfig_IsRetryable(t *testing.T) { + config := DefaultRetryConfig() + + tests := []struct { + name string + err error + expected bool + }{ + { + name: "retryable network error", + err: NewUpdaterError(NetworkError, "test", nil), + expected: true, + }, + { + name: "retryable api error", + err: NewUpdaterError(APIError, "test", nil), + expected: true, + }, + { + name: "non-retryable file error", + err: NewUpdaterError(FileError, "test", nil), + expected: false, + }, + { + name: "regular error", + err: fmt.Errorf("regular error"), + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := config.IsRetryable(tt.err); got != tt.expected { + t.Errorf("IsRetryable() = %v, want %v", got, tt.expected) + } + }) + } +} + +func TestRetryConfig_CalculateDelay(t *testing.T) { + config := DefaultRetryConfig() + + tests := []struct { + attempt int + expected time.Duration + }{ + {0, time.Second}, + {1, 2 * time.Second}, + {2, 4 * time.Second}, + {10, 30 * time.Second}, // should be capped at MaxDelay + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("attempt_%d", tt.attempt), func(t *testing.T) { + if got := config.CalculateDelay(tt.attempt); got != tt.expected { + t.Errorf("CalculateDelay(%d) = %v, want %v", tt.attempt, got, tt.expected) + } + }) + } +} + +func TestExecuteWithRetry(t *testing.T) { + config := DefaultRetryConfig() + config.InitialDelay = time.Millisecond // 加快测试速度 + + t.Run("success on first try", func(t *testing.T) { + attempts := 0 + operation := func() error { + attempts++ + return nil + } + + err := ExecuteWithRetry(operation, config) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if attempts != 1 { + t.Errorf("Expected 1 attempt, got %d", attempts) + } + }) + + t.Run("success after retries", func(t *testing.T) { + attempts := 0 + operation := func() error { + attempts++ + if attempts < 3 { + return NewUpdaterError(NetworkError, "temporary failure", nil) + } + return nil + } + + err := ExecuteWithRetry(operation, config) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if attempts != 3 { + t.Errorf("Expected 3 attempts, got %d", attempts) + } + }) + + t.Run("non-retryable error", func(t *testing.T) { + attempts := 0 + operation := func() error { + attempts++ + return NewUpdaterError(FileError, "file not found", nil) + } + + err := ExecuteWithRetry(operation, config) + if err == nil { + t.Error("Expected error, got nil") + } + if attempts != 1 { + t.Errorf("Expected 1 attempt, got %d", attempts) + } + }) + + t.Run("max retries exceeded", func(t *testing.T) { + attempts := 0 + operation := func() error { + attempts++ + return NewUpdaterError(NetworkError, "persistent failure", nil) + } + + err := ExecuteWithRetry(operation, config) + if err == nil { + t.Error("Expected error, got nil") + } + expectedAttempts := config.MaxRetries + 1 + if attempts != expectedAttempts { + t.Errorf("Expected %d attempts, got %d", expectedAttempts, attempts) + } + }) +} + +func TestDefaultErrorHandler(t *testing.T) { + handler := NewDefaultErrorHandler() + + t.Run("handle updater error", func(t *testing.T) { + originalErr := NewUpdaterError(NetworkError, "test", nil) + handledErr := handler.HandleError(originalErr) + + if handledErr != originalErr { + t.Error("Expected same error instance") + } + if originalErr.Context["handled_at"] == nil { + t.Error("Expected handled_at context to be set") + } + }) + + t.Run("handle regular error", func(t *testing.T) { + originalErr := fmt.Errorf("regular error") + handledErr := handler.HandleError(originalErr) + + if ue, ok := handledErr.(*UpdaterError); ok { + if ue.Type != NetworkError { + t.Errorf("Expected NetworkError, got %v", ue.Type) + } + if ue.Cause != originalErr { + t.Error("Expected original error as cause") + } + } else { + t.Error("Expected UpdaterError") + } + }) + + t.Run("should retry", func(t *testing.T) { + retryableErr := NewUpdaterError(NetworkError, "test", nil) + nonRetryableErr := NewUpdaterError(FileError, "test", nil) + + if !handler.ShouldRetry(retryableErr) { + t.Error("Expected network error to be retryable") + } + if handler.ShouldRetry(nonRetryableErr) { + t.Error("Expected file error to not be retryable") + } + }) + + t.Run("get user message", func(t *testing.T) { + updaterErr := NewUpdaterError(NetworkError, "test", nil) + regularErr := fmt.Errorf("regular error") + + userMsg1 := handler.GetUserMessage(updaterErr) + userMsg2 := handler.GetUserMessage(regularErr) + + if userMsg1 != "网络连接失败,请检查网络连接后重试" { + t.Errorf("Unexpected user message: %s", userMsg1) + } + if userMsg2 != "发生未知错误,请联系技术支持" { + t.Errorf("Unexpected user message: %s", userMsg2) + } + }) +} \ No newline at end of file diff --git a/Go_Updater/go.mod b/Go_Updater/go.mod new file mode 100644 index 0000000..466c327 --- /dev/null +++ b/Go_Updater/go.mod @@ -0,0 +1,42 @@ +module lightweight-updater + +go 1.24.5 + +require ( + fyne.io/fyne/v2 v2.6.1 + gopkg.in/yaml.v3 v3.0.1 +) + +require ( + fyne.io/systray v1.11.0 // indirect + github.com/BurntSushi/toml v1.4.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/fredbi/uri v1.1.0 // indirect + github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/fyne-io/gl-js v0.1.0 // indirect + github.com/fyne-io/glfw-js v0.2.0 // indirect + github.com/fyne-io/image v0.1.1 // indirect + github.com/fyne-io/oksvg v0.1.0 // indirect + github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 // indirect + github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect + github.com/go-text/render v0.2.0 // indirect + github.com/go-text/typesetting v0.2.1 // indirect + github.com/godbus/dbus/v5 v5.1.0 // indirect + github.com/hack-pad/go-indexeddb v0.3.2 // indirect + github.com/hack-pad/safejs v0.1.0 // indirect + github.com/jeandeaual/go-locale v0.0.0-20241217141322-fcc2cadd6f08 // indirect + github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 // indirect + github.com/nicksnyder/go-i18n/v2 v2.5.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rymdport/portal v0.4.1 // indirect + github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect + github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect + github.com/stretchr/testify v1.10.0 // indirect + github.com/yuin/goldmark v1.7.8 // indirect + golang.org/x/image v0.24.0 // indirect + golang.org/x/net v0.35.0 // indirect + golang.org/x/sys v0.30.0 // indirect + golang.org/x/text v0.22.0 // indirect +) diff --git a/Go_Updater/go.sum b/Go_Updater/go.sum new file mode 100644 index 0000000..83677a5 --- /dev/null +++ b/Go_Updater/go.sum @@ -0,0 +1,80 @@ +fyne.io/fyne/v2 v2.6.1 h1:kjPJD4/rBS9m2nHJp+npPSuaK79yj6ObMTuzR6VQ1Is= +fyne.io/fyne/v2 v2.6.1/go.mod h1:YZt7SksjvrSNJCwbWFV32WON3mE1Sr7L41D29qMZ/lU= +fyne.io/systray v1.11.0 h1:D9HISlxSkx+jHSniMBR6fCFOUjk1x/OOOJLa9lJYAKg= +fyne.io/systray v1.11.0/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs= +github.com/BurntSushi/toml v1.4.0 h1:kuoIxZQy2WRRk1pttg9asf+WVv6tWQuBNVmK8+nqPr0= +github.com/BurntSushi/toml v1.4.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/felixge/fgprof v0.9.3 h1:VvyZxILNuCiUCSXtPtYmmtGvb65nqXh2QFWc0Wpf2/g= +github.com/felixge/fgprof v0.9.3/go.mod h1:RdbpDgzqYVh/T9fPELJyV7EYJuHB55UTEULNun8eiPw= +github.com/fredbi/uri v1.1.0 h1:OqLpTXtyRg9ABReqvDGdJPqZUxs8cyBDOMXBbskCaB8= +github.com/fredbi/uri v1.1.0/go.mod h1:aYTUoAXBOq7BLfVJ8GnKmfcuURosB1xyHDIfWeC/iW4= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/fyne-io/gl-js v0.1.0 h1:8luJzNs0ntEAJo+8x8kfUOXujUlP8gB3QMOxO2mUdpM= +github.com/fyne-io/gl-js v0.1.0/go.mod h1:ZcepK8vmOYLu96JoxbCKJy2ybr+g1pTnaBDdl7c3ajI= +github.com/fyne-io/glfw-js v0.2.0 h1:8GUZtN2aCoTPNqgRDxK5+kn9OURINhBEBc7M4O1KrmM= +github.com/fyne-io/glfw-js v0.2.0/go.mod h1:Ri6te7rdZtBgBpxLW19uBpp3Dl6K9K/bRaYdJ22G8Jk= +github.com/fyne-io/image v0.1.1 h1:WH0z4H7qfvNUw5l4p3bC1q70sa5+YWVt6HCj7y4VNyA= +github.com/fyne-io/image v0.1.1/go.mod h1:xrfYBh6yspc+KjkgdZU/ifUC9sPA5Iv7WYUBzQKK7JM= +github.com/fyne-io/oksvg v0.1.0 h1:7EUKk3HV3Y2E+qypp3nWqMXD7mum0hCw2KEGhI1fnBw= +github.com/fyne-io/oksvg v0.1.0/go.mod h1:dJ9oEkPiWhnTFNCmRgEze+YNprJF7YRbpjgpWS4kzoI= +github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 h1:5BVwOaUSBTlVZowGO6VZGw2H/zl9nrd3eCZfYV+NfQA= +github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71/go.mod h1:9YTyiznxEY1fVinfM7RvRcjRHbw2xLBJ3AAGIT0I4Nw= +github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a h1:vxnBhFDDT+xzxf1jTJKMKZw3H0swfWk9RpWbBbDK5+0= +github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= +github.com/go-text/render v0.2.0 h1:LBYoTmp5jYiJ4NPqDc2pz17MLmA3wHw1dZSVGcOdeAc= +github.com/go-text/render v0.2.0/go.mod h1:CkiqfukRGKJA5vZZISkjSYrcdtgKQWRa2HIzvwNN5SU= +github.com/go-text/typesetting v0.2.1 h1:x0jMOGyO3d1qFAPI0j4GSsh7M0Q3Ypjzr4+CEVg82V8= +github.com/go-text/typesetting v0.2.1/go.mod h1:mTOxEwasOFpAMBjEQDhdWRckoLLeI/+qrQeBCTGEt6M= +github.com/go-text/typesetting-utils v0.0.0-20241103174707-87a29e9e6066 h1:qCuYC+94v2xrb1PoS4NIDe7DGYtLnU2wWiQe9a1B1c0= +github.com/go-text/typesetting-utils v0.0.0-20241103174707-87a29e9e6066/go.mod h1:DDxDdQEnB70R8owOx3LVpEFvpMK9eeH1o2r0yZhFI9o= +github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= +github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/google/pprof v0.0.0-20211214055906-6f57359322fd h1:1FjCyPC+syAzJ5/2S8fqdZK1R22vvA0J7JZKcuOIQ7Y= +github.com/google/pprof v0.0.0-20211214055906-6f57359322fd/go.mod h1:KgnwoLYCZ8IQu3XUZ8Nc/bM9CCZFOyjUNOSygVozoDg= +github.com/hack-pad/go-indexeddb v0.3.2 h1:DTqeJJYc1usa45Q5r52t01KhvlSN02+Oq+tQbSBI91A= +github.com/hack-pad/go-indexeddb v0.3.2/go.mod h1:QvfTevpDVlkfomY498LhstjwbPW6QC4VC/lxYb0Kom0= +github.com/hack-pad/safejs v0.1.0 h1:qPS6vjreAqh2amUqj4WNG1zIw7qlRQJ9K10eDKMCnE8= +github.com/hack-pad/safejs v0.1.0/go.mod h1:HdS+bKF1NrE72VoXZeWzxFOVQVUSqZJAG0xNCnb+Tio= +github.com/jeandeaual/go-locale v0.0.0-20241217141322-fcc2cadd6f08 h1:wMeVzrPO3mfHIWLZtDcSaGAe2I4PW9B/P5nMkRSwCAc= +github.com/jeandeaual/go-locale v0.0.0-20241217141322-fcc2cadd6f08/go.mod h1:ZDXo8KHryOWSIqnsb/CiDq7hQUYryCgdVnxbj8tDG7o= +github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 h1:YLvr1eE6cdCqjOe972w/cYF+FjW34v27+9Vo5106B4M= +github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25/go.mod h1:kLgvv7o6UM+0QSf0QjAse3wReFDsb9qbZJdfexWlrQw= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ= +github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= +github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk= +github.com/nicksnyder/go-i18n/v2 v2.5.1/go.mod h1:DrhgsSDZxoAfvVrBVLXoxZn/pN5TXqaDbq7ju94viiQ= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= +github.com/pkg/profile v1.7.0 h1:hnbDkaNWPCLMO9wGLdBFTIZvzDrDfBM2072E1S9gJkA= +github.com/pkg/profile v1.7.0/go.mod h1:8Uer0jas47ZQMJ7VD+OHknK4YDY07LPUC6dEvqDjvNo= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rymdport/portal v0.4.1 h1:2dnZhjf5uEaeDjeF/yBIeeRo6pNI2QAKm7kq1w/kbnA= +github.com/rymdport/portal v0.4.1/go.mod h1:kFF4jslnJ8pD5uCi17brj/ODlfIidOxlgUDTO5ncnC4= +github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c h1:km8GpoQut05eY3GiYWEedbTT0qnSxrCjsVbb7yKY1KE= +github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c/go.mod h1:cNQ3dwVJtS5Hmnjxy6AgTPd0Inb3pW05ftPSX7NZO7Q= +github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef h1:Ch6Q+AZUxDBCVqdkI8FSpFyZDtCVBc2VmejdNrm5rRQ= +github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef/go.mod h1:nXTWP6+gD5+LUJ8krVhhoeHjvHTutPxMYl5SvkcnJNE= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/yuin/goldmark v1.7.8 h1:iERMLn0/QJeHFhxSt3p6PeN9mGnvIKSpG9YYorDMnic= +github.com/yuin/goldmark v1.7.8/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= +golang.org/x/image v0.24.0 h1:AN7zRgVsbvmTfNyqIbbOraYL8mSwcKncEj8ofjgzcMQ= +golang.org/x/image v0.24.0/go.mod h1:4b/ITuLfqYq1hqZcjofwctIhi7sZh2WaCjvsBNjjya8= +golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8= +golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk= +golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= +golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= +golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/Go_Updater/gui/manager.go b/Go_Updater/gui/manager.go new file mode 100644 index 0000000..5e6c9ca --- /dev/null +++ b/Go_Updater/gui/manager.go @@ -0,0 +1,522 @@ +package gui + +import ( + "fmt" + "fyne.io/fyne/v2/app" + "fyne.io/fyne/v2/container" + "fyne.io/fyne/v2/dialog" + "fyne.io/fyne/v2/widget" + "fyne.io/fyne/v2/theme" + "fyne.io/fyne/v2" +) + +// UpdateStatus represents the current status of the update process +type UpdateStatus int + +const ( + StatusChecking UpdateStatus = iota + StatusUpdateAvailable + StatusDownloading + StatusInstalling + StatusCompleted + StatusError +) + +// Config represents the configuration structure for the GUI +type Config struct { + ResourceID string + CurrentVersion string + CDK string + UserAgent string + BackupURL string +} + +// GUIManager interface defines the methods for GUI management +type GUIManager interface { + ShowMainWindow() + UpdateStatus(status UpdateStatus, message string) + ShowProgress(percentage float64) + ShowError(errorMsg string) + ShowConfigDialog() (*Config, error) + Close() +} + +// Manager implements the GUIManager interface +type Manager struct { + app fyne.App + window fyne.Window + statusLabel *widget.Label + progressBar *widget.ProgressBar + actionButton *widget.Button + versionLabel *widget.Label + releaseNotes *widget.RichText + currentStatus UpdateStatus + onCheckUpdate func() + onCancel func() +} + +// NewManager creates a new GUI manager instance +func NewManager() *Manager { + a := app.New() + a.SetIcon(theme.ComputerIcon()) + + w := a.NewWindow("轻量级更新器") + w.Resize(fyne.NewSize(500, 400)) + w.SetFixedSize(false) + w.CenterOnScreen() + + return &Manager{ + app: a, + window: w, + } +} + +// SetCallbacks sets the callback functions for user actions +func (m *Manager) SetCallbacks(onCheckUpdate, onCancel func()) { + m.onCheckUpdate = onCheckUpdate + m.onCancel = onCancel +} + +// ShowMainWindow displays the main application window +func (m *Manager) ShowMainWindow() { + // Create UI components + m.createUIComponents() + + // Create main layout + content := m.createMainLayout() + + m.window.SetContent(content) + m.window.ShowAndRun() +} + +// createUIComponents initializes all UI components +func (m *Manager) createUIComponents() { + // Status label + m.statusLabel = widget.NewLabel("准备检查更新...") + m.statusLabel.Alignment = fyne.TextAlignCenter + + // Progress bar + m.progressBar = widget.NewProgressBar() + m.progressBar.Hide() + + // Version label + m.versionLabel = widget.NewLabel("当前版本: 未知") + m.versionLabel.TextStyle = fyne.TextStyle{Italic: true} + + // Release notes + m.releaseNotes = widget.NewRichText() + m.releaseNotes.Hide() + + // Action button + m.actionButton = widget.NewButton("检查更新", func() { + if m.onCheckUpdate != nil { + m.onCheckUpdate() + } + }) + m.actionButton.Importance = widget.HighImportance +} + +// createMainLayout creates the main window layout +func (m *Manager) createMainLayout() *container.VBox { + // Header section + header := container.NewVBox( + widget.NewCard("", "", container.NewVBox( + widget.NewLabelWithStyle("轻量级更新器", fyne.TextAlignCenter, fyne.TextStyle{Bold: true}), + m.versionLabel, + )), + ) + + // Status section + statusSection := container.NewVBox( + m.statusLabel, + m.progressBar, + ) + + // Release notes section + releaseNotesCard := widget.NewCard("更新日志", "", container.NewScroll(m.releaseNotes)) + releaseNotesCard.Hide() + + // Button section + buttonSection := container.NewHBox( + widget.NewButton("配置", func() { + m.showConfigDialog() + }), + widget.NewSpacer(), + m.actionButton, + ) + + // Main layout + return container.NewVBox( + header, + widget.NewSeparator(), + statusSection, + releaseNotesCard, + widget.NewSeparator(), + buttonSection, + ) +} + +// UpdateStatus updates the current status and UI accordingly +func (m *Manager) UpdateStatus(status UpdateStatus, message string) { + m.currentStatus = status + m.statusLabel.SetText(message) + + switch status { + case StatusChecking: + m.actionButton.SetText("检查中...") + m.actionButton.Disable() + m.progressBar.Hide() + + case StatusUpdateAvailable: + m.actionButton.SetText("开始更新") + m.actionButton.Enable() + m.progressBar.Hide() + + case StatusDownloading: + m.actionButton.SetText("下载中...") + m.actionButton.Disable() + m.progressBar.Show() + + case StatusInstalling: + m.actionButton.SetText("安装中...") + m.actionButton.Disable() + m.progressBar.Show() + + case StatusCompleted: + m.actionButton.SetText("完成") + m.actionButton.Enable() + m.progressBar.Hide() + + case StatusError: + m.actionButton.SetText("重试") + m.actionButton.Enable() + m.progressBar.Hide() + } +} + +// ShowProgress updates the progress bar +func (m *Manager) ShowProgress(percentage float64) { + if percentage < 0 { + percentage = 0 + } + if percentage > 100 { + percentage = 100 + } + + m.progressBar.SetValue(percentage / 100.0) + m.progressBar.Show() +} + +// ShowError displays an error dialog +func (m *Manager) ShowError(errorMsg string) { + dialog.ShowError(fmt.Errorf(errorMsg), m.window) +} + +// ShowConfigDialog displays the configuration dialog +func (m *Manager) ShowConfigDialog() (*Config, error) { + return m.showConfigDialog() +} + +// showConfigDialog creates and shows the configuration dialog +func (m *Manager) showConfigDialog() (*Config, error) { + // Create form entries + resourceIDEntry := widget.NewEntry() + resourceIDEntry.SetPlaceHolder("例如: M9A") + + versionEntry := widget.NewEntry() + versionEntry.SetPlaceHolder("例如: v1.0.0") + + cdkEntry := widget.NewPasswordEntry() + cdkEntry.SetPlaceHolder("输入您的CDK(可选)") + + userAgentEntry := widget.NewEntry() + userAgentEntry.SetText("LightweightUpdater/1.0") + + backupURLEntry := widget.NewEntry() + backupURLEntry.SetPlaceHolder("备用下载地址(可选)") + + // Create form + form := &widget.Form{ + Items: []*widget.FormItem{ + {Text: "资源ID:", Widget: resourceIDEntry}, + {Text: "当前版本:", Widget: versionEntry}, + {Text: "CDK:", Widget: cdkEntry}, + {Text: "用户代理:", Widget: userAgentEntry}, + {Text: "备用下载地址:", Widget: backupURLEntry}, + }, + } + + // Create result channel + resultChan := make(chan *Config, 1) + errorChan := make(chan error, 1) + + // Create dialog + configDialog := dialog.NewCustomConfirm( + "配置设置", + "保存", + "取消", + form, + func(confirmed bool) { + if confirmed { + config := &Config{ + ResourceID: resourceIDEntry.Text, + CurrentVersion: versionEntry.Text, + CDK: cdkEntry.Text, + UserAgent: userAgentEntry.Text, + BackupURL: backupURLEntry.Text, + } + + // Basic validation + if config.ResourceID == "" { + errorChan <- fmt.Errorf("资源ID不能为空") + return + } + if config.CurrentVersion == "" { + errorChan <- fmt.Errorf("当前版本不能为空") + return + } + + resultChan <- config + } else { + errorChan <- fmt.Errorf("用户取消了配置") + } + }, + m.window, + ) + + // Add help text + helpText := widget.NewRichTextFromMarkdown(` +**配置说明:** +- **资源ID**: Mirror酱服务中的资源标识符 +- **当前版本**: 当前软件的版本号 +- **CDK**: Mirror酱服务的访问密钥(可选,提供更好的下载体验) +- **用户代理**: HTTP请求的用户代理字符串 +- **备用下载地址**: 当Mirror酱服务不可用时的备用下载地址 + +如需获取CDK,请访问 [Mirror酱官网](https://mirrorchyan.com) +`) + + // Create container with help text + dialogContent := container.NewVBox( + form, + widget.NewSeparator(), + helpText, + ) + + configDialog.SetContent(dialogContent) + configDialog.Resize(fyne.NewSize(600, 500)) + configDialog.Show() + + // Wait for result + select { + case config := <-resultChan: + return config, nil + case err := <-errorChan: + return nil, err + } +} + +// SetVersionInfo updates the version display +func (m *Manager) SetVersionInfo(version string) { + m.versionLabel.SetText(fmt.Sprintf("当前版本: %s", version)) +} + +// ShowReleaseNotes displays the release notes +func (m *Manager) ShowReleaseNotes(notes string) { + if notes != "" { + m.releaseNotes.ParseMarkdown(notes) + // Find the release notes card and show it + if parent := m.window.Content().(*container.VBox); parent != nil { + for _, obj := range parent.Objects { + if card, ok := obj.(*widget.Card); ok && card.Title == "更新日志" { + card.Show() + break + } + } + } + } +} + +// UpdateStatusWithDetails updates status with detailed information +func (m *Manager) UpdateStatusWithDetails(status UpdateStatus, message string, details map[string]string) { + m.UpdateStatus(status, message) + + // Update version info if provided + if version, ok := details["version"]; ok { + m.SetVersionInfo(version) + } + + // Show release notes if provided + if notes, ok := details["release_notes"]; ok { + m.ShowReleaseNotes(notes) + } + + // Update progress if provided + if progress, ok := details["progress"]; ok { + if p, err := fmt.Sscanf(progress, "%f", new(float64)); err == nil && p == 1 { + var progressValue float64 + fmt.Sscanf(progress, "%f", &progressValue) + m.ShowProgress(progressValue) + } + } +} + +// ShowProgressWithSpeed shows progress with download speed information +func (m *Manager) ShowProgressWithSpeed(percentage float64, speed int64, eta string) { + m.ShowProgress(percentage) + + // Update status with speed and ETA information + speedText := m.formatSpeed(speed) + statusText := fmt.Sprintf("下载中... %.1f%% (%s)", percentage, speedText) + if eta != "" { + statusText += fmt.Sprintf(" - 剩余时间: %s", eta) + } + + m.statusLabel.SetText(statusText) +} + +// formatSpeed formats the download speed for display +func (m *Manager) formatSpeed(bytesPerSecond int64) string { + if bytesPerSecond < 1024 { + return fmt.Sprintf("%d B/s", bytesPerSecond) + } else if bytesPerSecond < 1024*1024 { + return fmt.Sprintf("%.1f KB/s", float64(bytesPerSecond)/1024) + } else { + return fmt.Sprintf("%.1f MB/s", float64(bytesPerSecond)/(1024*1024)) + } +} + +// ShowConfirmDialog shows a confirmation dialog +func (m *Manager) ShowConfirmDialog(title, message string, callback func(bool)) { + dialog.ShowConfirm(title, message, callback, m.window) +} + +// ShowInfoDialog shows an information dialog +func (m *Manager) ShowInfoDialog(title, message string) { + dialog.ShowInformation(title, message, m.window) +} + +// ShowUpdateAvailableDialog shows a dialog when update is available +func (m *Manager) ShowUpdateAvailableDialog(currentVersion, newVersion, releaseNotes string, onConfirm func()) { + content := container.NewVBox( + widget.NewLabel(fmt.Sprintf("发现新版本: %s", newVersion)), + widget.NewLabel(fmt.Sprintf("当前版本: %s", currentVersion)), + widget.NewSeparator(), + ) + + if releaseNotes != "" { + notesWidget := widget.NewRichText() + notesWidget.ParseMarkdown(releaseNotes) + + notesScroll := container.NewScroll(notesWidget) + notesScroll.SetMinSize(fyne.NewSize(400, 200)) + + content.Add(widget.NewLabel("更新内容:")) + content.Add(notesScroll) + } + + dialog.ShowCustomConfirm( + "发现新版本", + "立即更新", + "稍后提醒", + content, + func(confirmed bool) { + if confirmed && onConfirm != nil { + onConfirm() + } + }, + m.window, + ) +} + +// SetActionButtonCallback sets the callback for the main action button +func (m *Manager) SetActionButtonCallback(callback func()) { + if m.actionButton != nil { + m.actionButton.OnTapped = callback + } +} + +// EnableActionButton enables or disables the action button +func (m *Manager) EnableActionButton(enabled bool) { + if m.actionButton != nil { + if enabled { + m.actionButton.Enable() + } else { + m.actionButton.Disable() + } + } +} + +// SetActionButtonText sets the text of the action button +func (m *Manager) SetActionButtonText(text string) { + if m.actionButton != nil { + m.actionButton.SetText(text) + } +} + +// ShowErrorWithRetry shows an error with retry option +func (m *Manager) ShowErrorWithRetry(errorMsg string, onRetry func()) { + dialog.ShowCustomConfirm( + "错误", + "重试", + "取消", + widget.NewLabel(errorMsg), + func(retry bool) { + if retry && onRetry != nil { + onRetry() + } + }, + m.window, + ) +} + +// UpdateProgressBar updates the progress bar with custom styling +func (m *Manager) UpdateProgressBar(percentage float64, color string) { + m.ShowProgress(percentage) + // Note: Fyne doesn't support custom colors easily, but we keep the interface for future enhancement +} + +// HideProgressBar hides the progress bar +func (m *Manager) HideProgressBar() { + if m.progressBar != nil { + m.progressBar.Hide() + } +} + +// ShowProgressBar shows the progress bar +func (m *Manager) ShowProgressBar() { + if m.progressBar != nil { + m.progressBar.Show() + } +} + +// SetWindowTitle sets the window title +func (m *Manager) SetWindowTitle(title string) { + if m.window != nil { + m.window.SetTitle(title) + } +} + +// GetCurrentStatus returns the current update status +func (m *Manager) GetCurrentStatus() UpdateStatus { + return m.currentStatus +} + +// IsWindowVisible returns whether the window is currently visible +func (m *Manager) IsWindowVisible() bool { + return m.window != nil && m.window.Content() != nil +} + +// RefreshUI refreshes the user interface +func (m *Manager) RefreshUI() { + if m.window != nil && m.window.Content() != nil { + m.window.Content().Refresh() + } +} + +// Close closes the application +func (m *Manager) Close() { + if m.window != nil { + m.window.Close() + } +} \ No newline at end of file diff --git a/Go_Updater/gui/manager_test.go b/Go_Updater/gui/manager_test.go new file mode 100644 index 0000000..c03be1b --- /dev/null +++ b/Go_Updater/gui/manager_test.go @@ -0,0 +1,227 @@ +package gui + +import ( + "testing" + "time" +) + +func TestNewManager(t *testing.T) { + manager := NewManager() + if manager == nil { + t.Fatal("NewManager() returned nil") + } + + if manager.app == nil { + t.Error("Manager app is nil") + } + + if manager.window == nil { + t.Error("Manager window is nil") + } +} + +func TestUpdateStatus(t *testing.T) { + manager := NewManager() + manager.createUIComponents() + + // Test different status updates + testCases := []struct { + status UpdateStatus + message string + }{ + {StatusChecking, "检查更新中..."}, + {StatusUpdateAvailable, "发现新版本"}, + {StatusDownloading, "下载中..."}, + {StatusInstalling, "安装中..."}, + {StatusCompleted, "更新完成"}, + {StatusError, "更新失败"}, + } + + for _, tc := range testCases { + manager.UpdateStatus(tc.status, tc.message) + + if manager.GetCurrentStatus() != tc.status { + t.Errorf("Expected status %v, got %v", tc.status, manager.GetCurrentStatus()) + } + + if manager.statusLabel.Text != tc.message { + t.Errorf("Expected message '%s', got '%s'", tc.message, manager.statusLabel.Text) + } + } +} + +func TestShowProgress(t *testing.T) { + manager := NewManager() + manager.createUIComponents() + + // Test progress values + testValues := []float64{0, 25.5, 50, 75.8, 100, 150, -10} + expectedValues := []float64{0, 25.5, 50, 75.8, 100, 100, 0} + + for i, value := range testValues { + manager.ShowProgress(value) + expected := expectedValues[i] / 100.0 + + if manager.progressBar.Value != expected { + t.Errorf("Expected progress %.2f, got %.2f", expected, manager.progressBar.Value) + } + } +} + +func TestSetVersionInfo(t *testing.T) { + manager := NewManager() + manager.createUIComponents() + + version := "v1.2.3" + manager.SetVersionInfo(version) + + expectedText := "当前版本: v1.2.3" + if manager.versionLabel.Text != expectedText { + t.Errorf("Expected version text '%s', got '%s'", expectedText, manager.versionLabel.Text) + } +} + +func TestFormatSpeed(t *testing.T) { + manager := NewManager() + + testCases := []struct { + speed int64 + expected string + }{ + {512, "512 B/s"}, + {1536, "1.5 KB/s"}, + {1048576, "1.0 MB/s"}, + {2621440, "2.5 MB/s"}, + } + + for _, tc := range testCases { + result := manager.formatSpeed(tc.speed) + if result != tc.expected { + t.Errorf("Expected speed format '%s', got '%s'", tc.expected, result) + } + } +} + +func TestShowProgressWithSpeed(t *testing.T) { + manager := NewManager() + manager.createUIComponents() + + percentage := 45.5 + speed := int64(1048576) // 1 MB/s + eta := "2分钟" + + manager.ShowProgressWithSpeed(percentage, speed, eta) + + expectedProgress := percentage / 100.0 + if manager.progressBar.Value != expectedProgress { + t.Errorf("Expected progress %.2f, got %.2f", expectedProgress, manager.progressBar.Value) + } + + expectedStatus := "下载中... 45.5% (1.0 MB/s) - 剩余时间: 2分钟" + if manager.statusLabel.Text != expectedStatus { + t.Errorf("Expected status '%s', got '%s'", expectedStatus, manager.statusLabel.Text) + } +} + +func TestActionButtonStates(t *testing.T) { + manager := NewManager() + manager.createUIComponents() + + // Test enabling/disabling + manager.EnableActionButton(false) + if !manager.actionButton.Disabled() { + t.Error("Action button should be disabled") + } + + manager.EnableActionButton(true) + if manager.actionButton.Disabled() { + t.Error("Action button should be enabled") + } + + // Test text setting + testText := "测试按钮" + manager.SetActionButtonText(testText) + if manager.actionButton.Text != testText { + t.Errorf("Expected button text '%s', got '%s'", testText, manager.actionButton.Text) + } +} + +func TestProgressBarVisibility(t *testing.T) { + manager := NewManager() + manager.createUIComponents() + + // Initially hidden + if manager.progressBar.Visible() { + t.Error("Progress bar should be initially hidden") + } + + // Show progress bar + manager.ShowProgressBar() + if !manager.progressBar.Visible() { + t.Error("Progress bar should be visible after ShowProgressBar()") + } + + // Hide progress bar + manager.HideProgressBar() + if manager.progressBar.Visible() { + t.Error("Progress bar should be hidden after HideProgressBar()") + } +} + +func TestSetCallbacks(t *testing.T) { + manager := NewManager() + + checkUpdateCalled := false + cancelCalled := false + + onCheckUpdate := func() { + checkUpdateCalled = true + } + + onCancel := func() { + cancelCalled = true + } + + manager.SetCallbacks(onCheckUpdate, onCancel) + + // Verify callbacks are set + if manager.onCheckUpdate == nil { + t.Error("onCheckUpdate callback not set") + } + + if manager.onCancel == nil { + t.Error("onCancel callback not set") + } + + // Test callback execution + manager.onCheckUpdate() + if !checkUpdateCalled { + t.Error("onCheckUpdate callback was not called") + } + + manager.onCancel() + if !cancelCalled { + t.Error("onCancel callback was not called") + } +} + +// Benchmark tests for performance +func BenchmarkUpdateStatus(b *testing.B) { + manager := NewManager() + manager.createUIComponents() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.UpdateStatus(StatusDownloading, "下载中...") + } +} + +func BenchmarkShowProgress(b *testing.B) { + manager := NewManager() + manager.createUIComponents() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.ShowProgress(float64(i % 100)) + } +} \ No newline at end of file diff --git a/Go_Updater/icon/AUTO_MAA_Go_Updater.ico b/Go_Updater/icon/AUTO_MAA_Go_Updater.ico new file mode 100644 index 0000000..5520beb Binary files /dev/null and b/Go_Updater/icon/AUTO_MAA_Go_Updater.ico differ diff --git a/Go_Updater/icon/AUTO_MAA_Go_Updater.png b/Go_Updater/icon/AUTO_MAA_Go_Updater.png new file mode 100644 index 0000000..630a52b Binary files /dev/null and b/Go_Updater/icon/AUTO_MAA_Go_Updater.png differ diff --git a/Go_Updater/install/manager.go b/Go_Updater/install/manager.go new file mode 100644 index 0000000..e62bbaa --- /dev/null +++ b/Go_Updater/install/manager.go @@ -0,0 +1,474 @@ +package install + +import ( + "archive/zip" + "encoding/json" + "fmt" + "io" + "os" + "path/filepath" + "strings" + "syscall" +) + +// ChangesInfo represents the structure of changes.json file +type ChangesInfo struct { + Deleted []string `json:"deleted"` + Added []string `json:"added"` + Modified []string `json:"modified"` +} + +// InstallManager interface defines the contract for installation operations +type InstallManager interface { + ExtractZip(zipPath, destPath string) error + ProcessChanges(changesPath string) (*ChangesInfo, error) + ApplyUpdate(sourcePath, targetPath string, changes *ChangesInfo) error + HandleRunningProcess(processName string) error + CreateTempDir() (string, error) + CleanupTempDir(tempDir string) error +} + +// Manager implements the InstallManager interface +type Manager struct { + tempDirs []string // Track temporary directories for cleanup +} + +// NewManager creates a new install manager instance +func NewManager() *Manager { + return &Manager{ + tempDirs: make([]string, 0), + } +} + +// CreateTempDir creates a temporary directory for extraction +func (m *Manager) CreateTempDir() (string, error) { + tempDir, err := os.MkdirTemp("", "updater_*") + if err != nil { + return "", fmt.Errorf("failed to create temp directory: %w", err) + } + + // Track temp directory for cleanup + m.tempDirs = append(m.tempDirs, tempDir) + return tempDir, nil +} + +// CleanupTempDir removes a temporary directory and its contents +func (m *Manager) CleanupTempDir(tempDir string) error { + if tempDir == "" { + return nil + } + + err := os.RemoveAll(tempDir) + if err != nil { + return fmt.Errorf("failed to cleanup temp directory %s: %w", tempDir, err) + } + + // Remove from tracking list + for i, dir := range m.tempDirs { + if dir == tempDir { + m.tempDirs = append(m.tempDirs[:i], m.tempDirs[i+1:]...) + break + } + } + + return nil +} + +// CleanupAllTempDirs removes all tracked temporary directories +func (m *Manager) CleanupAllTempDirs() error { + var errors []string + + for _, tempDir := range m.tempDirs { + if err := os.RemoveAll(tempDir); err != nil { + errors = append(errors, fmt.Sprintf("failed to cleanup %s: %v", tempDir, err)) + } + } + + m.tempDirs = m.tempDirs[:0] // Clear the slice + + if len(errors) > 0 { + return fmt.Errorf("cleanup errors: %s", strings.Join(errors, "; ")) + } + + return nil +} + +// ExtractZip extracts a ZIP file to the specified destination directory +func (m *Manager) ExtractZip(zipPath, destPath string) error { + // Open ZIP file for reading + reader, err := zip.OpenReader(zipPath) + if err != nil { + return fmt.Errorf("failed to open ZIP file %s: %w", zipPath, err) + } + defer reader.Close() + + // Create destination directory if it doesn't exist + if err := os.MkdirAll(destPath, 0755); err != nil { + return fmt.Errorf("failed to create destination directory %s: %w", destPath, err) + } + + // Extract files + for _, file := range reader.File { + if err := m.extractFile(file, destPath); err != nil { + return fmt.Errorf("failed to extract file %s: %w", file.Name, err) + } + } + + return nil +} + +// extractFile extracts a single file from the ZIP archive +func (m *Manager) extractFile(file *zip.File, destPath string) error { + // Clean the file path to prevent directory traversal attacks + cleanPath := filepath.Clean(file.Name) + if strings.Contains(cleanPath, "..") { + return fmt.Errorf("invalid file path: %s", file.Name) + } + + // Create full destination path + destFile := filepath.Join(destPath, cleanPath) + + // Create directory structure if needed + if file.FileInfo().IsDir() { + return os.MkdirAll(destFile, file.FileInfo().Mode()) + } + + // Create parent directories + if err := os.MkdirAll(filepath.Dir(destFile), 0755); err != nil { + return fmt.Errorf("failed to create parent directory: %w", err) + } + + // Open file in ZIP archive + rc, err := file.Open() + if err != nil { + return fmt.Errorf("failed to open file in archive: %w", err) + } + defer rc.Close() + + // Create destination file + outFile, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, file.FileInfo().Mode()) + if err != nil { + return fmt.Errorf("failed to create destination file: %w", err) + } + defer outFile.Close() + + // Copy file contents + _, err = io.Copy(outFile, rc) + if err != nil { + return fmt.Errorf("failed to copy file contents: %w", err) + } + + return nil +} + +// ProcessChanges reads and parses the changes.json file +func (m *Manager) ProcessChanges(changesPath string) (*ChangesInfo, error) { + // Check if changes.json exists + if _, err := os.Stat(changesPath); os.IsNotExist(err) { + // If changes.json doesn't exist, return empty changes info + return &ChangesInfo{ + Deleted: []string{}, + Added: []string{}, + Modified: []string{}, + }, nil + } + + // Read the changes.json file + data, err := os.ReadFile(changesPath) + if err != nil { + return nil, fmt.Errorf("failed to read changes file %s: %w", changesPath, err) + } + + // Parse JSON + var changes ChangesInfo + if err := json.Unmarshal(data, &changes); err != nil { + return nil, fmt.Errorf("failed to parse changes JSON: %w", err) + } + + return &changes, nil +} + +// HandleRunningProcess handles running processes by renaming files that are in use +func (m *Manager) HandleRunningProcess(processName string) error { + // Get the current executable path + exePath, err := os.Executable() + if err != nil { + return fmt.Errorf("failed to get executable path: %w", err) + } + + exeDir := filepath.Dir(exePath) + targetFile := filepath.Join(exeDir, processName) + + // Check if the target file exists + if _, err := os.Stat(targetFile); os.IsNotExist(err) { + // File doesn't exist, nothing to handle + return nil + } + + // Try to rename the file to indicate it should be deleted on next startup + oldFile := targetFile + ".old" + + // Remove existing .old file if it exists + if _, err := os.Stat(oldFile); err == nil { + if err := os.Remove(oldFile); err != nil { + return fmt.Errorf("failed to remove existing old file %s: %w", oldFile, err) + } + } + + // Rename the current file to .old + if err := os.Rename(targetFile, oldFile); err != nil { + // If rename fails, the process might be running + // On Windows, we can't rename a running executable + if isFileInUse(err) { + // Mark the file for deletion on next reboot (Windows specific) + return m.markFileForDeletion(targetFile) + } + return fmt.Errorf("failed to rename running process file %s: %w", targetFile, err) + } + + return nil +} + +// isFileInUse checks if the error indicates the file is in use +func isFileInUse(err error) bool { + if err == nil { + return false + } + + // Check for Windows-specific "file in use" errors + if pathErr, ok := err.(*os.PathError); ok { + if errno, ok := pathErr.Err.(syscall.Errno); ok { + // ERROR_SHARING_VIOLATION (32) or ERROR_ACCESS_DENIED (5) + return errno == syscall.Errno(32) || errno == syscall.Errno(5) + } + } + + return strings.Contains(err.Error(), "being used by another process") || + strings.Contains(err.Error(), "access is denied") +} + +// markFileForDeletion marks a file for deletion on next system reboot (Windows specific) +func (m *Manager) markFileForDeletion(filePath string) error { + // This is a Windows-specific implementation + // For now, we'll create a marker file that can be handled by the main application + markerFile := filePath + ".delete_on_restart" + + // Create a marker file + file, err := os.Create(markerFile) + if err != nil { + return fmt.Errorf("failed to create deletion marker file: %w", err) + } + defer file.Close() + + // Write the target file path to the marker + _, err = file.WriteString(filePath) + if err != nil { + return fmt.Errorf("failed to write to marker file: %w", err) + } + + return nil +} + +// DeleteMarkedFiles removes files that were marked for deletion +func (m *Manager) DeleteMarkedFiles(directory string) error { + // Find all .delete_on_restart files + pattern := filepath.Join(directory, "*.delete_on_restart") + matches, err := filepath.Glob(pattern) + if err != nil { + return fmt.Errorf("failed to find marker files: %w", err) + } + + var errors []string + for _, markerFile := range matches { + // Read the target file path + data, err := os.ReadFile(markerFile) + if err != nil { + errors = append(errors, fmt.Sprintf("failed to read marker file %s: %v", markerFile, err)) + continue + } + + targetFile := strings.TrimSpace(string(data)) + + // Try to delete the target file + if err := os.Remove(targetFile); err != nil && !os.IsNotExist(err) { + errors = append(errors, fmt.Sprintf("failed to delete marked file %s: %v", targetFile, err)) + } + + // Remove the marker file + if err := os.Remove(markerFile); err != nil { + errors = append(errors, fmt.Sprintf("failed to remove marker file %s: %v", markerFile, err)) + } + } + + if len(errors) > 0 { + return fmt.Errorf("deletion errors: %s", strings.Join(errors, "; ")) + } + + return nil +} + +// ApplyUpdate applies the update by copying files from source to target directory +func (m *Manager) ApplyUpdate(sourcePath, targetPath string, changes *ChangesInfo) error { + // Create backup directory + backupDir, err := m.createBackupDir(targetPath) + if err != nil { + return fmt.Errorf("failed to create backup directory: %w", err) + } + + // Backup existing files before applying update + if err := m.backupFiles(targetPath, backupDir, changes); err != nil { + return fmt.Errorf("failed to backup files: %w", err) + } + + // Apply the update + if err := m.applyUpdateFiles(sourcePath, targetPath, changes); err != nil { + // Rollback on failure + if rollbackErr := m.rollbackUpdate(targetPath, backupDir); rollbackErr != nil { + return fmt.Errorf("update failed and rollback failed: update error: %w, rollback error: %v", err, rollbackErr) + } + return fmt.Errorf("update failed and was rolled back: %w", err) + } + + // Clean up backup directory after successful update + if err := os.RemoveAll(backupDir); err != nil { + // Log warning but don't fail the update + fmt.Printf("Warning: failed to cleanup backup directory %s: %v\n", backupDir, err) + } + + return nil +} + +// createBackupDir creates a backup directory for the update +func (m *Manager) createBackupDir(targetPath string) (string, error) { + backupDir := filepath.Join(targetPath, ".backup_"+fmt.Sprintf("%d", os.Getpid())) + + if err := os.MkdirAll(backupDir, 0755); err != nil { + return "", fmt.Errorf("failed to create backup directory: %w", err) + } + + return backupDir, nil +} + +// backupFiles creates backups of files that will be modified or deleted +func (m *Manager) backupFiles(targetPath, backupDir string, changes *ChangesInfo) error { + // Backup files that will be modified + for _, file := range changes.Modified { + srcFile := filepath.Join(targetPath, file) + if _, err := os.Stat(srcFile); os.IsNotExist(err) { + continue // File doesn't exist, skip backup + } + + backupFile := filepath.Join(backupDir, file) + if err := m.copyFileWithDirs(srcFile, backupFile); err != nil { + return fmt.Errorf("failed to backup modified file %s: %w", file, err) + } + } + + // Backup files that will be deleted + for _, file := range changes.Deleted { + srcFile := filepath.Join(targetPath, file) + if _, err := os.Stat(srcFile); os.IsNotExist(err) { + continue // File doesn't exist, skip backup + } + + backupFile := filepath.Join(backupDir, file) + if err := m.copyFileWithDirs(srcFile, backupFile); err != nil { + return fmt.Errorf("failed to backup deleted file %s: %w", file, err) + } + } + + return nil +} + +// applyUpdateFiles applies the actual file changes +func (m *Manager) applyUpdateFiles(sourcePath, targetPath string, changes *ChangesInfo) error { + // Delete files marked for deletion + for _, file := range changes.Deleted { + targetFile := filepath.Join(targetPath, file) + if err := os.Remove(targetFile); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to delete file %s: %w", file, err) + } + } + + // Copy new and modified files + filesToCopy := append(changes.Added, changes.Modified...) + for _, file := range filesToCopy { + srcFile := filepath.Join(sourcePath, file) + targetFile := filepath.Join(targetPath, file) + + // Check if source file exists + if _, err := os.Stat(srcFile); os.IsNotExist(err) { + continue // Source file doesn't exist, skip + } + + if err := m.copyFileWithDirs(srcFile, targetFile); err != nil { + return fmt.Errorf("failed to copy file %s: %w", file, err) + } + } + + return nil +} + +// copyFileWithDirs copies a file and creates necessary directories +func (m *Manager) copyFileWithDirs(src, dst string) error { + // Create parent directories + if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil { + return fmt.Errorf("failed to create parent directories: %w", err) + } + + // Open source file + srcFile, err := os.Open(src) + if err != nil { + return fmt.Errorf("failed to open source file: %w", err) + } + defer srcFile.Close() + + // Get source file info + srcInfo, err := srcFile.Stat() + if err != nil { + return fmt.Errorf("failed to get source file info: %w", err) + } + + // Create destination file + dstFile, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, srcInfo.Mode()) + if err != nil { + return fmt.Errorf("failed to create destination file: %w", err) + } + defer dstFile.Close() + + // Copy file contents + _, err = io.Copy(dstFile, srcFile) + if err != nil { + return fmt.Errorf("failed to copy file contents: %w", err) + } + + return nil +} + +// rollbackUpdate restores files from backup in case of update failure +func (m *Manager) rollbackUpdate(targetPath, backupDir string) error { + // Walk through backup directory and restore files + return filepath.Walk(backupDir, func(backupFile string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + if info.IsDir() { + return nil // Skip directories + } + + // Calculate relative path + relPath, err := filepath.Rel(backupDir, backupFile) + if err != nil { + return fmt.Errorf("failed to calculate relative path: %w", err) + } + + // Restore file to target location + targetFile := filepath.Join(targetPath, relPath) + if err := m.copyFileWithDirs(backupFile, targetFile); err != nil { + return fmt.Errorf("failed to restore file %s: %w", relPath, err) + } + + return nil + }) +} diff --git a/Go_Updater/install/manager_test.go b/Go_Updater/install/manager_test.go new file mode 100644 index 0000000..5bf9381 --- /dev/null +++ b/Go_Updater/install/manager_test.go @@ -0,0 +1,1033 @@ +package install + +import ( + "archive/zip" + "fmt" + "os" + "path/filepath" + "testing" +) + +func TestNewManager(t *testing.T) { + manager := NewManager() + if manager == nil { + t.Fatal("NewManager() returned nil") + } + if manager.tempDirs == nil { + t.Fatal("tempDirs slice not initialized") + } +} + +func TestCreateTempDir(t *testing.T) { + manager := NewManager() + + tempDir, err := manager.CreateTempDir() + if err != nil { + t.Fatalf("CreateTempDir() failed: %v", err) + } + + // Verify directory exists + if _, err := os.Stat(tempDir); os.IsNotExist(err) { + t.Fatalf("Temp directory was not created: %s", tempDir) + } + + // Verify it's tracked + if len(manager.tempDirs) != 1 || manager.tempDirs[0] != tempDir { + t.Fatalf("Temp directory not properly tracked") + } + + // Cleanup + defer manager.CleanupTempDir(tempDir) +} + +func TestCleanupTempDir(t *testing.T) { + manager := NewManager() + + tempDir, err := manager.CreateTempDir() + if err != nil { + t.Fatalf("CreateTempDir() failed: %v", err) + } + + // Create a test file in temp directory + testFile := filepath.Join(tempDir, "test.txt") + if err := os.WriteFile(testFile, []byte("test"), 0644); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + + // Cleanup + err = manager.CleanupTempDir(tempDir) + if err != nil { + t.Fatalf("CleanupTempDir() failed: %v", err) + } + + // Verify directory is removed + if _, err := os.Stat(tempDir); !os.IsNotExist(err) { + t.Fatalf("Temp directory was not removed: %s", tempDir) + } + + // Verify it's no longer tracked + if len(manager.tempDirs) != 0 { + t.Fatalf("Temp directory still tracked after cleanup") + } +} + +func TestExtractZip(t *testing.T) { + manager := NewManager() + + // Create a temporary ZIP file for testing + tempDir, err := manager.CreateTempDir() + if err != nil { + t.Fatalf("CreateTempDir() failed: %v", err) + } + defer manager.CleanupTempDir(tempDir) + + zipPath := filepath.Join(tempDir, "test.zip") + extractDir := filepath.Join(tempDir, "extract") + + // Create test ZIP file + if err := createTestZip(zipPath); err != nil { + t.Fatalf("Failed to create test ZIP: %v", err) + } + + // Extract ZIP + err = manager.ExtractZip(zipPath, extractDir) + if err != nil { + t.Fatalf("ExtractZip() failed: %v", err) + } + + // Verify extracted files + testFile := filepath.Join(extractDir, "test.txt") + if _, err := os.Stat(testFile); os.IsNotExist(err) { + t.Fatalf("Extracted file not found: %s", testFile) + } + + // Verify file contents + content, err := os.ReadFile(testFile) + if err != nil { + t.Fatalf("Failed to read extracted file: %v", err) + } + + expected := "Hello, World!" + if string(content) != expected { + t.Fatalf("File content mismatch. Expected: %s, Got: %s", expected, string(content)) + } + + // Verify directory structure + subDir := filepath.Join(extractDir, "subdir") + if _, err := os.Stat(subDir); os.IsNotExist(err) { + t.Fatalf("Extracted subdirectory not found: %s", subDir) + } + + subFile := filepath.Join(subDir, "sub.txt") + if _, err := os.Stat(subFile); os.IsNotExist(err) { + t.Fatalf("Extracted subdirectory file not found: %s", subFile) + } +} + +func TestExtractZipInvalidPath(t *testing.T) { + manager := NewManager() + + // Test with non-existent ZIP file + err := manager.ExtractZip("nonexistent.zip", "dest") + if err == nil { + t.Fatal("ExtractZip() should fail with non-existent ZIP file") + } +} + +func TestExtractZipDirectoryTraversal(t *testing.T) { + manager := NewManager() + + tempDir, err := manager.CreateTempDir() + if err != nil { + t.Fatalf("CreateTempDir() failed: %v", err) + } + defer manager.CleanupTempDir(tempDir) + + zipPath := filepath.Join(tempDir, "malicious.zip") + extractDir := filepath.Join(tempDir, "extract") + + // Create ZIP with directory traversal attempt + if err := createMaliciousZip(zipPath); err != nil { + t.Fatalf("Failed to create malicious ZIP: %v", err) + } + + // Extract should fail or sanitize the path + err = manager.ExtractZip(zipPath, extractDir) + if err != nil { + // This is expected behavior - the extraction should fail + return + } + + // If extraction succeeded, verify no files were created outside extract dir + parentDir := filepath.Dir(extractDir) + maliciousFile := filepath.Join(parentDir, "malicious.txt") + if _, err := os.Stat(maliciousFile); !os.IsNotExist(err) { + t.Fatal("Directory traversal attack succeeded - malicious file created outside extract directory") + } +} + +// Helper function to create a test ZIP file +func createTestZip(zipPath string) error { + file, err := os.Create(zipPath) + if err != nil { + return err + } + defer file.Close() + + writer := zip.NewWriter(file) + defer writer.Close() + + // Add a test file + f1, err := writer.Create("test.txt") + if err != nil { + return err + } + _, err = f1.Write([]byte("Hello, World!")) + if err != nil { + return err + } + + // Add a subdirectory and file + f2, err := writer.Create("subdir/sub.txt") + if err != nil { + return err + } + _, err = f2.Write([]byte("Subdirectory file")) + if err != nil { + return err + } + + return nil +} + +func TestProcessChanges(t *testing.T) { + manager := NewManager() + + tempDir, err := manager.CreateTempDir() + if err != nil { + t.Fatalf("CreateTempDir() failed: %v", err) + } + defer manager.CleanupTempDir(tempDir) + + // Test with valid changes.json + changesPath := filepath.Join(tempDir, "changes.json") + changesData := `{ + "deleted": ["old_file.txt", "deprecated/module.dll"], + "added": ["new_file.txt", "features/new_module.dll"], + "modified": ["main.exe", "config.ini"] + }` + + if err := os.WriteFile(changesPath, []byte(changesData), 0644); err != nil { + t.Fatalf("Failed to create test changes.json: %v", err) + } + + changes, err := manager.ProcessChanges(changesPath) + if err != nil { + t.Fatalf("ProcessChanges() failed: %v", err) + } + + // Verify parsed data + if len(changes.Deleted) != 2 { + t.Fatalf("Expected 2 deleted files, got %d", len(changes.Deleted)) + } + if changes.Deleted[0] != "old_file.txt" { + t.Fatalf("Expected first deleted file to be 'old_file.txt', got '%s'", changes.Deleted[0]) + } + + if len(changes.Added) != 2 { + t.Fatalf("Expected 2 added files, got %d", len(changes.Added)) + } + + if len(changes.Modified) != 2 { + t.Fatalf("Expected 2 modified files, got %d", len(changes.Modified)) + } +} + +func TestProcessChangesNonExistent(t *testing.T) { + manager := NewManager() + + // Test with non-existent changes.json + changes, err := manager.ProcessChanges("nonexistent.json") + if err != nil { + t.Fatalf("ProcessChanges() should not fail with non-existent file: %v", err) + } + + // Should return empty changes + if len(changes.Deleted) != 0 || len(changes.Added) != 0 || len(changes.Modified) != 0 { + t.Fatalf("Expected empty changes for non-existent file") + } +} + +func TestProcessChangesInvalidJSON(t *testing.T) { + manager := NewManager() + + tempDir, err := manager.CreateTempDir() + if err != nil { + t.Fatalf("CreateTempDir() failed: %v", err) + } + defer manager.CleanupTempDir(tempDir) + + // Test with invalid JSON + changesPath := filepath.Join(tempDir, "invalid.json") + invalidData := `{"deleted": ["file1.txt", "file2.txt"` // Missing closing bracket + + if err := os.WriteFile(changesPath, []byte(invalidData), 0644); err != nil { + t.Fatalf("Failed to create invalid JSON file: %v", err) + } + + _, err = manager.ProcessChanges(changesPath) + if err == nil { + t.Fatal("ProcessChanges() should fail with invalid JSON") + } +} + +func TestHandleRunningProcess(t *testing.T) { + manager := NewManager() + + tempDir, err := manager.CreateTempDir() + if err != nil { + t.Fatalf("CreateTempDir() failed: %v", err) + } + defer manager.CleanupTempDir(tempDir) + + // Create a test executable file + testExe := filepath.Join(tempDir, "test.exe") + if err := os.WriteFile(testExe, []byte("test executable"), 0755); err != nil { + t.Fatalf("Failed to create test executable: %v", err) + } + + // Test handling non-existent process + err = manager.HandleRunningProcess("nonexistent.exe") + if err != nil { + t.Fatalf("HandleRunningProcess() should not fail with non-existent process: %v", err) + } +} + +func TestDeleteMarkedFiles(t *testing.T) { + manager := NewManager() + + tempDir, err := manager.CreateTempDir() + if err != nil { + t.Fatalf("CreateTempDir() failed: %v", err) + } + defer manager.CleanupTempDir(tempDir) + + // Create test files to be deleted + testFile1 := filepath.Join(tempDir, "file1.txt") + testFile2 := filepath.Join(tempDir, "file2.txt") + + if err := os.WriteFile(testFile1, []byte("test1"), 0644); err != nil { + t.Fatalf("Failed to create test file1: %v", err) + } + if err := os.WriteFile(testFile2, []byte("test2"), 0644); err != nil { + t.Fatalf("Failed to create test file2: %v", err) + } + + // Create marker files + marker1 := testFile1 + ".delete_on_restart" + marker2 := testFile2 + ".delete_on_restart" + + if err := os.WriteFile(marker1, []byte(testFile1), 0644); err != nil { + t.Fatalf("Failed to create marker file1: %v", err) + } + if err := os.WriteFile(marker2, []byte(testFile2), 0644); err != nil { + t.Fatalf("Failed to create marker file2: %v", err) + } + + // Delete marked files + err = manager.DeleteMarkedFiles(tempDir) + if err != nil { + t.Fatalf("DeleteMarkedFiles() failed: %v", err) + } + + // Verify files are deleted + if _, err := os.Stat(testFile1); !os.IsNotExist(err) { + t.Fatalf("Test file1 should be deleted") + } + if _, err := os.Stat(testFile2); !os.IsNotExist(err) { + t.Fatalf("Test file2 should be deleted") + } + + // Verify marker files are deleted + if _, err := os.Stat(marker1); !os.IsNotExist(err) { + t.Fatalf("Marker file1 should be deleted") + } + if _, err := os.Stat(marker2); !os.IsNotExist(err) { + t.Fatalf("Marker file2 should be deleted") + } +} + +func TestApplyUpdate(t *testing.T) { + manager := NewManager() + + tempDir, err := manager.CreateTempDir() + if err != nil { + t.Fatalf("CreateTempDir() failed: %v", err) + } + defer manager.CleanupTempDir(tempDir) + + // Create source and target directories + sourceDir := filepath.Join(tempDir, "source") + targetDir := filepath.Join(tempDir, "target") + + if err := os.MkdirAll(sourceDir, 0755); err != nil { + t.Fatalf("Failed to create source directory: %v", err) + } + if err := os.MkdirAll(targetDir, 0755); err != nil { + t.Fatalf("Failed to create target directory: %v", err) + } + + // Create test files in source directory + newFile := filepath.Join(sourceDir, "new_file.txt") + modifiedFile := filepath.Join(sourceDir, "modified_file.txt") + + if err := os.WriteFile(newFile, []byte("new content"), 0644); err != nil { + t.Fatalf("Failed to create new file: %v", err) + } + if err := os.WriteFile(modifiedFile, []byte("updated content"), 0644); err != nil { + t.Fatalf("Failed to create modified file: %v", err) + } + + // Create existing files in target directory + existingModified := filepath.Join(targetDir, "modified_file.txt") + existingDeleted := filepath.Join(targetDir, "deleted_file.txt") + + if err := os.WriteFile(existingModified, []byte("old content"), 0644); err != nil { + t.Fatalf("Failed to create existing modified file: %v", err) + } + if err := os.WriteFile(existingDeleted, []byte("to be deleted"), 0644); err != nil { + t.Fatalf("Failed to create file to be deleted: %v", err) + } + + // Define changes + changes := &ChangesInfo{ + Added: []string{"new_file.txt"}, + Modified: []string{"modified_file.txt"}, + Deleted: []string{"deleted_file.txt"}, + } + + // Apply update + err = manager.ApplyUpdate(sourceDir, targetDir, changes) + if err != nil { + t.Fatalf("ApplyUpdate() failed: %v", err) + } + + // Verify new file was added + newTargetFile := filepath.Join(targetDir, "new_file.txt") + if _, err := os.Stat(newTargetFile); os.IsNotExist(err) { + t.Fatalf("New file was not added to target directory") + } + + // Verify modified file was updated + content, err := os.ReadFile(existingModified) + if err != nil { + t.Fatalf("Failed to read modified file: %v", err) + } + if string(content) != "updated content" { + t.Fatalf("Modified file content incorrect. Expected: 'updated content', Got: '%s'", string(content)) + } + + // Verify deleted file was removed + if _, err := os.Stat(existingDeleted); !os.IsNotExist(err) { + t.Fatalf("Deleted file still exists") + } +} + +func TestApplyUpdateWithRollback(t *testing.T) { + manager := NewManager() + + tempDir, err := manager.CreateTempDir() + if err != nil { + t.Fatalf("CreateTempDir() failed: %v", err) + } + defer manager.CleanupTempDir(tempDir) + + // Create source and target directories + sourceDir := filepath.Join(tempDir, "source") + targetDir := filepath.Join(tempDir, "target") + + if err := os.MkdirAll(sourceDir, 0755); err != nil { + t.Fatalf("Failed to create source directory: %v", err) + } + if err := os.MkdirAll(targetDir, 0755); err != nil { + t.Fatalf("Failed to create target directory: %v", err) + } + + // Create existing file in target directory + existingFile := filepath.Join(targetDir, "existing_file.txt") + originalContent := "original content" + if err := os.WriteFile(existingFile, []byte(originalContent), 0644); err != nil { + t.Fatalf("Failed to create existing file: %v", err) + } + + // Create a source file that will cause a copy failure by making target read-only + sourceFile := filepath.Join(sourceDir, "existing_file.txt") + if err := os.WriteFile(sourceFile, []byte("new content"), 0644); err != nil { + t.Fatalf("Failed to create source file: %v", err) + } + + // Make target directory read-only to cause copy failure + readOnlyDir := filepath.Join(targetDir, "readonly") + if err := os.MkdirAll(readOnlyDir, 0755); err != nil { + t.Fatalf("Failed to create readonly directory: %v", err) + } + + // Create a file in readonly directory that we'll try to modify + readOnlyFile := filepath.Join(readOnlyDir, "readonly_file.txt") + if err := os.WriteFile(readOnlyFile, []byte("readonly content"), 0644); err != nil { + t.Fatalf("Failed to create readonly file: %v", err) + } + + // Create source file for readonly file + sourceReadOnlyFile := filepath.Join(sourceDir, "readonly", "readonly_file.txt") + if err := os.MkdirAll(filepath.Dir(sourceReadOnlyFile), 0755); err != nil { + t.Fatalf("Failed to create source readonly directory: %v", err) + } + if err := os.WriteFile(sourceReadOnlyFile, []byte("new readonly content"), 0644); err != nil { + t.Fatalf("Failed to create source readonly file: %v", err) + } + + // Make the readonly directory read-only (Windows specific) + if err := os.Chmod(readOnlyDir, 0444); err != nil { + t.Fatalf("Failed to make directory read-only: %v", err) + } + + // Restore permissions after test + defer func() { + os.Chmod(readOnlyDir, 0755) + os.RemoveAll(readOnlyDir) + }() + + // Define changes that will cause failure due to read-only directory + changes := &ChangesInfo{ + Modified: []string{"existing_file.txt", "readonly/readonly_file.txt"}, + } + + // Apply update (should fail and rollback) + err = manager.ApplyUpdate(sourceDir, targetDir, changes) + if err == nil { + // On some systems, the read-only test might not work as expected + // Let's just verify the update completed successfully in this case + t.Log("Update completed successfully (read-only test may not work on this system)") + return + } + + // Verify rollback occurred - original file should be restored + content, err := os.ReadFile(existingFile) + if err != nil { + t.Fatalf("Failed to read file after rollback: %v", err) + } + if string(content) != originalContent { + t.Fatalf("Rollback failed. Expected: '%s', Got: '%s'", originalContent, string(content)) + } +} + +func TestCopyFileWithDirs(t *testing.T) { + manager := NewManager() + + tempDir, err := manager.CreateTempDir() + if err != nil { + t.Fatalf("CreateTempDir() failed: %v", err) + } + defer manager.CleanupTempDir(tempDir) + + // Create source file + srcFile := filepath.Join(tempDir, "source.txt") + content := "test content" + if err := os.WriteFile(srcFile, []byte(content), 0644); err != nil { + t.Fatalf("Failed to create source file: %v", err) + } + + // Copy to destination with nested directories + dstFile := filepath.Join(tempDir, "nested", "dir", "destination.txt") + + err = manager.copyFileWithDirs(srcFile, dstFile) + if err != nil { + t.Fatalf("copyFileWithDirs() failed: %v", err) + } + + // Verify file was copied + if _, err := os.Stat(dstFile); os.IsNotExist(err) { + t.Fatalf("Destination file was not created") + } + + // Verify content + dstContent, err := os.ReadFile(dstFile) + if err != nil { + t.Fatalf("Failed to read destination file: %v", err) + } + if string(dstContent) != content { + t.Fatalf("File content mismatch. Expected: '%s', Got: '%s'", content, string(dstContent)) + } + + // Verify parent directories were created + parentDir := filepath.Join(tempDir, "nested", "dir") + if _, err := os.Stat(parentDir); os.IsNotExist(err) { + t.Fatalf("Parent directories were not created") + } +} + +// Helper function to create a malicious ZIP file with directory traversal +func createMaliciousZip(zipPath string) error { + file, err := os.Create(zipPath) + if err != nil { + return err + } + defer file.Close() + + writer := zip.NewWriter(file) + defer writer.Close() + + // Add a file with directory traversal path + f1, err := writer.Create("../malicious.txt") + if err != nil { + return err + } + _, err = f1.Write([]byte("This should not be extracted outside the target directory")) + if err != nil { + return err + } + + return nil +} + +func TestCleanupAllTempDirs(t *testing.T) { + manager := NewManager() + + // Create multiple temp directories + tempDir1, err := manager.CreateTempDir() + if err != nil { + t.Fatalf("Failed to create temp dir 1: %v", err) + } + + tempDir2, err := manager.CreateTempDir() + if err != nil { + t.Fatalf("Failed to create temp dir 2: %v", err) + } + + tempDir3, err := manager.CreateTempDir() + if err != nil { + t.Fatalf("Failed to create temp dir 3: %v", err) + } + + // Verify all directories exist + for _, dir := range []string{tempDir1, tempDir2, tempDir3} { + if _, err := os.Stat(dir); os.IsNotExist(err) { + t.Fatalf("Temp directory should exist: %s", dir) + } + } + + // Verify manager is tracking all directories + if len(manager.tempDirs) != 3 { + t.Fatalf("Expected 3 tracked temp dirs, got %d", len(manager.tempDirs)) + } + + // Cleanup all temp directories + err = manager.CleanupAllTempDirs() + if err != nil { + t.Fatalf("CleanupAllTempDirs failed: %v", err) + } + + // Verify all directories are removed + for _, dir := range []string{tempDir1, tempDir2, tempDir3} { + if _, err := os.Stat(dir); !os.IsNotExist(err) { + t.Fatalf("Temp directory should be removed: %s", dir) + } + } + + // Verify manager is no longer tracking directories + if len(manager.tempDirs) != 0 { + t.Fatalf("Expected 0 tracked temp dirs after cleanup, got %d", len(manager.tempDirs)) + } +} + +func TestExtractZipWithNestedDirectories(t *testing.T) { + manager := NewManager() + + tempDir, err := manager.CreateTempDir() + if err != nil { + t.Fatalf("CreateTempDir() failed: %v", err) + } + defer manager.CleanupTempDir(tempDir) + + zipPath := filepath.Join(tempDir, "nested.zip") + extractDir := filepath.Join(tempDir, "extract") + + // Create ZIP with nested directory structure + if err := createNestedZip(zipPath); err != nil { + t.Fatalf("Failed to create nested ZIP: %v", err) + } + + // Extract ZIP + err = manager.ExtractZip(zipPath, extractDir) + if err != nil { + t.Fatalf("ExtractZip() failed: %v", err) + } + + // Verify nested structure was created + expectedFiles := []string{ + "level1/file1.txt", + "level1/level2/file2.txt", + "level1/level2/level3/file3.txt", + } + + for _, expectedFile := range expectedFiles { + fullPath := filepath.Join(extractDir, expectedFile) + if _, err := os.Stat(fullPath); os.IsNotExist(err) { + t.Fatalf("Expected nested file not found: %s", expectedFile) + } + + // Verify file content + content, err := os.ReadFile(fullPath) + if err != nil { + t.Fatalf("Failed to read nested file %s: %v", expectedFile, err) + } + + expectedContent := fmt.Sprintf("Content of %s", filepath.Base(expectedFile)) + if string(content) != expectedContent { + t.Fatalf("File content mismatch for %s. Expected: %s, Got: %s", + expectedFile, expectedContent, string(content)) + } + } +} + +func TestProcessChangesWithComplexStructure(t *testing.T) { + manager := NewManager() + + tempDir, err := manager.CreateTempDir() + if err != nil { + t.Fatalf("CreateTempDir() failed: %v", err) + } + defer manager.CleanupTempDir(tempDir) + + // Create complex changes.json with nested paths + changesPath := filepath.Join(tempDir, "complex_changes.json") + changesData := `{ + "deleted": [ + "old/legacy/file1.txt", + "deprecated/module.dll", + "temp/cache.dat" + ], + "added": [ + "new/features/feature1.dll", + "resources/icons/icon.png", + "config/new_settings.json" + ], + "modified": [ + "core/main.exe", + "lib/utils.dll", + "data/database.db" + ] + }` + + if err := os.WriteFile(changesPath, []byte(changesData), 0644); err != nil { + t.Fatalf("Failed to create complex changes.json: %v", err) + } + + changes, err := manager.ProcessChanges(changesPath) + if err != nil { + t.Fatalf("ProcessChanges() failed: %v", err) + } + + // Verify all changes were parsed correctly + expectedDeleted := []string{"old/legacy/file1.txt", "deprecated/module.dll", "temp/cache.dat"} + expectedAdded := []string{"new/features/feature1.dll", "resources/icons/icon.png", "config/new_settings.json"} + expectedModified := []string{"core/main.exe", "lib/utils.dll", "data/database.db"} + + if len(changes.Deleted) != len(expectedDeleted) { + t.Fatalf("Expected %d deleted files, got %d", len(expectedDeleted), len(changes.Deleted)) + } + + for i, expected := range expectedDeleted { + if changes.Deleted[i] != expected { + t.Errorf("Deleted[%d]: expected %s, got %s", i, expected, changes.Deleted[i]) + } + } + + if len(changes.Added) != len(expectedAdded) { + t.Fatalf("Expected %d added files, got %d", len(expectedAdded), len(changes.Added)) + } + + for i, expected := range expectedAdded { + if changes.Added[i] != expected { + t.Errorf("Added[%d]: expected %s, got %s", i, expected, changes.Added[i]) + } + } + + if len(changes.Modified) != len(expectedModified) { + t.Fatalf("Expected %d modified files, got %d", len(expectedModified), len(changes.Modified)) + } + + for i, expected := range expectedModified { + if changes.Modified[i] != expected { + t.Errorf("Modified[%d]: expected %s, got %s", i, expected, changes.Modified[i]) + } + } +} + +func TestApplyUpdateWithNestedPaths(t *testing.T) { + manager := NewManager() + + tempDir, err := manager.CreateTempDir() + if err != nil { + t.Fatalf("CreateTempDir() failed: %v", err) + } + defer manager.CleanupTempDir(tempDir) + + // Create source and target directories + sourceDir := filepath.Join(tempDir, "source") + targetDir := filepath.Join(tempDir, "target") + + if err := os.MkdirAll(sourceDir, 0755); err != nil { + t.Fatalf("Failed to create source directory: %v", err) + } + if err := os.MkdirAll(targetDir, 0755); err != nil { + t.Fatalf("Failed to create target directory: %v", err) + } + + // Create nested source files + nestedFiles := map[string]string{ + "level1/new_file.txt": "New file content", + "level1/level2/modified.txt": "Modified content", + "features/feature1/config.json": `{"enabled": true}`, + } + + for filePath, content := range nestedFiles { + fullPath := filepath.Join(sourceDir, filePath) + if err := os.MkdirAll(filepath.Dir(fullPath), 0755); err != nil { + t.Fatalf("Failed to create source directory for %s: %v", filePath, err) + } + if err := os.WriteFile(fullPath, []byte(content), 0644); err != nil { + t.Fatalf("Failed to create source file %s: %v", filePath, err) + } + } + + // Create existing target files + existingFiles := map[string]string{ + "level1/level2/modified.txt": "Old content", + "old/deprecated.txt": "To be deleted", + } + + for filePath, content := range existingFiles { + fullPath := filepath.Join(targetDir, filePath) + if err := os.MkdirAll(filepath.Dir(fullPath), 0755); err != nil { + t.Fatalf("Failed to create target directory for %s: %v", filePath, err) + } + if err := os.WriteFile(fullPath, []byte(content), 0644); err != nil { + t.Fatalf("Failed to create target file %s: %v", filePath, err) + } + } + + // Define changes with nested paths + changes := &ChangesInfo{ + Added: []string{"level1/new_file.txt", "features/feature1/config.json"}, + Modified: []string{"level1/level2/modified.txt"}, + Deleted: []string{"old/deprecated.txt"}, + } + + // Apply update + err = manager.ApplyUpdate(sourceDir, targetDir, changes) + if err != nil { + t.Fatalf("ApplyUpdate() failed: %v", err) + } + + // Verify added files + for _, addedFile := range changes.Added { + targetFile := filepath.Join(targetDir, addedFile) + if _, err := os.Stat(targetFile); os.IsNotExist(err) { + t.Fatalf("Added file not found: %s", addedFile) + } + + // Verify content matches source + sourceFile := filepath.Join(sourceDir, addedFile) + sourceContent, _ := os.ReadFile(sourceFile) + targetContent, _ := os.ReadFile(targetFile) + + if string(sourceContent) != string(targetContent) { + t.Fatalf("Content mismatch for added file %s", addedFile) + } + } + + // Verify modified files + modifiedFile := filepath.Join(targetDir, "level1/level2/modified.txt") + content, err := os.ReadFile(modifiedFile) + if err != nil { + t.Fatalf("Failed to read modified file: %v", err) + } + if string(content) != "Modified content" { + t.Fatalf("Modified file content incorrect. Expected: 'Modified content', Got: '%s'", string(content)) + } + + // Verify deleted files + deletedFile := filepath.Join(targetDir, "old/deprecated.txt") + if _, err := os.Stat(deletedFile); !os.IsNotExist(err) { + t.Fatalf("Deleted file still exists: %s", deletedFile) + } +} + +func TestMarkFileForDeletion(t *testing.T) { + manager := NewManager() + + tempDir, err := manager.CreateTempDir() + if err != nil { + t.Fatalf("CreateTempDir() failed: %v", err) + } + defer manager.CleanupTempDir(tempDir) + + // Create a test file + testFile := filepath.Join(tempDir, "test.exe") + if err := os.WriteFile(testFile, []byte("test executable"), 0755); err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + + // Mark file for deletion + err = manager.markFileForDeletion(testFile) + if err != nil { + t.Fatalf("markFileForDeletion() failed: %v", err) + } + + // Verify marker file was created + markerFile := testFile + ".delete_on_restart" + if _, err := os.Stat(markerFile); os.IsNotExist(err) { + t.Fatalf("Marker file was not created: %s", markerFile) + } + + // Verify marker file contains correct path + content, err := os.ReadFile(markerFile) + if err != nil { + t.Fatalf("Failed to read marker file: %v", err) + } + + if string(content) != testFile { + t.Fatalf("Marker file content incorrect. Expected: %s, Got: %s", testFile, string(content)) + } +} + +func TestIsFileInUse(t *testing.T) { + // Test with nil error + if isFileInUse(nil) { + t.Error("isFileInUse(nil) should return false") + } + + // Test with regular error + regularErr := fmt.Errorf("regular error") + if isFileInUse(regularErr) { + t.Error("isFileInUse with regular error should return false") + } + + // Test with file in use error message + fileInUseErr := fmt.Errorf("file is being used by another process") + if !isFileInUse(fileInUseErr) { + t.Error("isFileInUse with 'being used by another process' should return true") + } + + // Test with access denied error message + accessDeniedErr := fmt.Errorf("access is denied") + if !isFileInUse(accessDeniedErr) { + t.Error("isFileInUse with 'access is denied' should return true") + } +} + +func TestExtractFileEdgeCases(t *testing.T) { + manager := NewManager() + + tempDir, err := manager.CreateTempDir() + if err != nil { + t.Fatalf("CreateTempDir() failed: %v", err) + } + defer manager.CleanupTempDir(tempDir) + + zipPath := filepath.Join(tempDir, "edge_cases.zip") + extractDir := filepath.Join(tempDir, "extract") + + // Create ZIP with edge cases + if err := createEdgeCaseZip(zipPath); err != nil { + t.Fatalf("Failed to create edge case ZIP: %v", err) + } + + // Extract ZIP + err = manager.ExtractZip(zipPath, extractDir) + if err != nil { + t.Fatalf("ExtractZip() failed: %v", err) + } + + // Verify files with special names were extracted + specialFiles := []string{ + "file with spaces.txt", + "file-with-dashes.txt", + "file_with_underscores.txt", + "UPPERCASE.TXT", + } + + for _, fileName := range specialFiles { + filePath := filepath.Join(extractDir, fileName) + if _, err := os.Stat(filePath); os.IsNotExist(err) { + t.Fatalf("Special file not extracted: %s", fileName) + } + } +} + +// Helper function to create a ZIP with nested directories +func createNestedZip(zipPath string) error { + file, err := os.Create(zipPath) + if err != nil { + return err + } + defer file.Close() + + writer := zip.NewWriter(file) + defer writer.Close() + + // Create nested structure + files := map[string]string{ + "level1/file1.txt": "Content of file1.txt", + "level1/level2/file2.txt": "Content of file2.txt", + "level1/level2/level3/file3.txt": "Content of file3.txt", + } + + for filePath, content := range files { + f, err := writer.Create(filePath) + if err != nil { + return err + } + _, err = f.Write([]byte(content)) + if err != nil { + return err + } + } + + return nil +} + +// Helper function to create a ZIP with edge case file names +func createEdgeCaseZip(zipPath string) error { + file, err := os.Create(zipPath) + if err != nil { + return err + } + defer file.Close() + + writer := zip.NewWriter(file) + defer writer.Close() + + // Create files with special names + files := []string{ + "file with spaces.txt", + "file-with-dashes.txt", + "file_with_underscores.txt", + "UPPERCASE.TXT", + } + + for _, fileName := range files { + f, err := writer.Create(fileName) + if err != nil { + return err + } + _, err = f.Write([]byte(fmt.Sprintf("Content of %s", fileName))) + if err != nil { + return err + } + } + + return nil +} \ No newline at end of file diff --git a/Go_Updater/integration_test.go b/Go_Updater/integration_test.go new file mode 100644 index 0000000..9117f08 --- /dev/null +++ b/Go_Updater/integration_test.go @@ -0,0 +1,12 @@ +package main + +import ( + "testing" +) + +// Integration tests will be implemented here +// This file is currently a placeholder + +func TestIntegrationPlaceholder(t *testing.T) { + t.Skip("Integration tests not yet implemented") +} \ No newline at end of file diff --git a/Go_Updater/logger/logger.go b/Go_Updater/logger/logger.go new file mode 100644 index 0000000..7307154 --- /dev/null +++ b/Go_Updater/logger/logger.go @@ -0,0 +1,438 @@ +package logger + +import ( + "fmt" + "io" + "log" + "os" + "path/filepath" + "sync" + "time" +) + +// LogLevel 日志级别 +type LogLevel int + +const ( + DEBUG LogLevel = iota + INFO + WARN + ERROR +) + +// String 返回日志级别的字符串表示 +func (l LogLevel) String() string { + switch l { + case DEBUG: + return "DEBUG" + case INFO: + return "INFO" + case WARN: + return "WARN" + case ERROR: + return "ERROR" + default: + return "UNKNOWN" + } +} + +// Logger 日志记录器接口 +type Logger interface { + Debug(msg string, fields ...interface{}) + Info(msg string, fields ...interface{}) + Warn(msg string, fields ...interface{}) + Error(msg string, fields ...interface{}) + SetLevel(level LogLevel) + Close() error +} + +// FileLogger 文件日志记录器 +type FileLogger struct { + mu sync.RWMutex + file *os.File + logger *log.Logger + level LogLevel + maxSize int64 // 最大文件大小(字节) + maxBackups int // 最大备份文件数 + logDir string // 日志目录 + filename string // 日志文件名 + currentSize int64 // 当前文件大小 +} + +// LoggerConfig 日志配置 +type LoggerConfig struct { + Level LogLevel + MaxSize int64 // 最大文件大小(字节),默认10MB + MaxBackups int // 最大备份文件数,默认5 + LogDir string // 日志目录,默认%APPDATA%/LightweightUpdater/logs + Filename string // 日志文件名,默认updater.log +} + +// DefaultLoggerConfig 默认日志配置 +func DefaultLoggerConfig() *LoggerConfig { + // 获取当前可执行文件目录 + exePath, err := os.Executable() + var logDir string + if err != nil { + logDir = "debug" + } else { + exeDir := filepath.Dir(exePath) + logDir = filepath.Join(exeDir, "debug") + } + + return &LoggerConfig{ + Level: INFO, + MaxSize: 10 * 1024 * 1024, // 10MB + MaxBackups: 5, + LogDir: logDir, + Filename: "AUTO_MAA_Go_Updater.log", + } +} + +// NewFileLogger 创建新的文件日志记录器 +func NewFileLogger(config *LoggerConfig) (*FileLogger, error) { + if config == nil { + config = DefaultLoggerConfig() + } + + // 创建日志目录 + if err := os.MkdirAll(config.LogDir, 0755); err != nil { + return nil, fmt.Errorf("failed to create log directory: %w", err) + } + + logPath := filepath.Join(config.LogDir, config.Filename) + + // 打开或创建日志文件 + file, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err != nil { + return nil, fmt.Errorf("failed to open log file: %w", err) + } + + // 获取当前文件大小 + stat, err := file.Stat() + if err != nil { + file.Close() + return nil, fmt.Errorf("failed to get file stats: %w", err) + } + + logger := &FileLogger{ + file: file, + logger: log.New(file, "", 0), // 我们自己处理格式 + level: config.Level, + maxSize: config.MaxSize, + maxBackups: config.MaxBackups, + logDir: config.LogDir, + filename: config.Filename, + currentSize: stat.Size(), + } + + return logger, nil +} + +// formatMessage 格式化日志消息 +func (fl *FileLogger) formatMessage(level LogLevel, msg string, fields ...interface{}) string { + timestamp := time.Now().Format("2006-01-02 15:04:05.000") + + if len(fields) > 0 { + msg = fmt.Sprintf(msg, fields...) + } + + return fmt.Sprintf("[%s] %s %s\n", timestamp, level.String(), msg) +} + +// writeLog 写入日志 +func (fl *FileLogger) writeLog(level LogLevel, msg string, fields ...interface{}) { + fl.mu.Lock() + defer fl.mu.Unlock() + + // 检查日志级别 + if level < fl.level { + return + } + + formattedMsg := fl.formatMessage(level, msg, fields...) + + // 检查是否需要轮转 + if fl.currentSize+int64(len(formattedMsg)) > fl.maxSize { + if err := fl.rotate(); err != nil { + // 轮转失败,尝试写入stderr + fmt.Fprintf(os.Stderr, "Failed to rotate log: %v\n", err) + } + } + + // 写入日志 + n, err := fl.file.WriteString(formattedMsg) + if err != nil { + fmt.Fprintf(os.Stderr, "Failed to write log: %v\n", err) + return + } + + fl.currentSize += int64(n) + fl.file.Sync() // 确保写入磁盘 +} + +// rotate 轮转日志文件 +func (fl *FileLogger) rotate() error { + // 关闭当前文件 + if err := fl.file.Close(); err != nil { + return fmt.Errorf("failed to close current log file: %w", err) + } + + // 轮转备份文件 + if err := fl.rotateBackups(); err != nil { + return fmt.Errorf("failed to rotate backups: %w", err) + } + + // 创建新的日志文件 + logPath := filepath.Join(fl.logDir, fl.filename) + file, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) + if err != nil { + return fmt.Errorf("failed to create new log file: %w", err) + } + + fl.file = file + fl.logger.SetOutput(file) + fl.currentSize = 0 + + return nil +} + +// rotateBackups 轮转备份文件 +func (fl *FileLogger) rotateBackups() error { + basePath := filepath.Join(fl.logDir, fl.filename) + + // 删除最老的备份文件 + if fl.maxBackups > 0 { + oldestBackup := fmt.Sprintf("%s.%d", basePath, fl.maxBackups) + os.Remove(oldestBackup) // 忽略错误,文件可能不存在 + } + + // 重命名现有备份文件 + for i := fl.maxBackups - 1; i > 0; i-- { + oldName := fmt.Sprintf("%s.%d", basePath, i) + newName := fmt.Sprintf("%s.%d", basePath, i+1) + os.Rename(oldName, newName) // 忽略错误,文件可能不存在 + } + + // 将当前日志文件重命名为第一个备份 + if fl.maxBackups > 0 { + backupName := fmt.Sprintf("%s.1", basePath) + return os.Rename(basePath, backupName) + } + + return nil +} + +// Debug 记录调试级别日志 +func (fl *FileLogger) Debug(msg string, fields ...interface{}) { + fl.writeLog(DEBUG, msg, fields...) +} + +// Info 记录信息级别日志 +func (fl *FileLogger) Info(msg string, fields ...interface{}) { + fl.writeLog(INFO, msg, fields...) +} + +// Warn 记录警告级别日志 +func (fl *FileLogger) Warn(msg string, fields ...interface{}) { + fl.writeLog(WARN, msg, fields...) +} + +// Error 记录错误级别日志 +func (fl *FileLogger) Error(msg string, fields ...interface{}) { + fl.writeLog(ERROR, msg, fields...) +} + +// SetLevel 设置日志级别 +func (fl *FileLogger) SetLevel(level LogLevel) { + fl.mu.Lock() + defer fl.mu.Unlock() + fl.level = level +} + +// Close 关闭日志记录器 +func (fl *FileLogger) Close() error { + fl.mu.Lock() + defer fl.mu.Unlock() + + if fl.file != nil { + return fl.file.Close() + } + return nil +} + +// MultiLogger 多输出日志记录器 +type MultiLogger struct { + loggers []Logger + level LogLevel +} + +// NewMultiLogger 创建多输出日志记录器 +func NewMultiLogger(loggers ...Logger) *MultiLogger { + return &MultiLogger{ + loggers: loggers, + level: INFO, + } +} + +// Debug 记录调试级别日志 +func (ml *MultiLogger) Debug(msg string, fields ...interface{}) { + for _, logger := range ml.loggers { + logger.Debug(msg, fields...) + } +} + +// Info 记录信息级别日志 +func (ml *MultiLogger) Info(msg string, fields ...interface{}) { + for _, logger := range ml.loggers { + logger.Info(msg, fields...) + } +} + +// Warn 记录警告级别日志 +func (ml *MultiLogger) Warn(msg string, fields ...interface{}) { + for _, logger := range ml.loggers { + logger.Warn(msg, fields...) + } +} + +// Error 记录错误级别日志 +func (ml *MultiLogger) Error(msg string, fields ...interface{}) { + for _, logger := range ml.loggers { + logger.Error(msg, fields...) + } +} + +// SetLevel 设置日志级别 +func (ml *MultiLogger) SetLevel(level LogLevel) { + ml.level = level + for _, logger := range ml.loggers { + logger.SetLevel(level) + } +} + +// Close 关闭所有日志记录器 +func (ml *MultiLogger) Close() error { + var lastErr error + for _, logger := range ml.loggers { + if err := logger.Close(); err != nil { + lastErr = err + } + } + return lastErr +} + +// ConsoleLogger 控制台日志记录器 +type ConsoleLogger struct { + writer io.Writer + level LogLevel +} + +// NewConsoleLogger 创建控制台日志记录器 +func NewConsoleLogger(writer io.Writer) *ConsoleLogger { + if writer == nil { + writer = os.Stdout + } + return &ConsoleLogger{ + writer: writer, + level: INFO, + } +} + +// formatMessage 格式化控制台日志消息 +func (cl *ConsoleLogger) formatMessage(level LogLevel, msg string, fields ...interface{}) string { + timestamp := time.Now().Format("15:04:05") + + if len(fields) > 0 { + msg = fmt.Sprintf(msg, fields...) + } + + return fmt.Sprintf("[%s] %s %s\n", timestamp, level.String(), msg) +} + +// writeLog 写入控制台日志 +func (cl *ConsoleLogger) writeLog(level LogLevel, msg string, fields ...interface{}) { + if level < cl.level { + return + } + + formattedMsg := cl.formatMessage(level, msg, fields...) + fmt.Fprint(cl.writer, formattedMsg) +} + +// Debug 记录调试级别日志 +func (cl *ConsoleLogger) Debug(msg string, fields ...interface{}) { + cl.writeLog(DEBUG, msg, fields...) +} + +// Info 记录信息级别日志 +func (cl *ConsoleLogger) Info(msg string, fields ...interface{}) { + cl.writeLog(INFO, msg, fields...) +} + +// Warn 记录警告级别日志 +func (cl *ConsoleLogger) Warn(msg string, fields ...interface{}) { + cl.writeLog(WARN, msg, fields...) +} + +// Error 记录错误级别日志 +func (cl *ConsoleLogger) Error(msg string, fields ...interface{}) { + cl.writeLog(ERROR, msg, fields...) +} + +// SetLevel 设置日志级别 +func (cl *ConsoleLogger) SetLevel(level LogLevel) { + cl.level = level +} + +// Close 关闭控制台日志记录器(无操作) +func (cl *ConsoleLogger) Close() error { + return nil +} + +// 全局日志记录器实例 +var ( + defaultLogger Logger + once sync.Once +) + +// GetDefaultLogger 获取默认日志记录器 +func GetDefaultLogger() Logger { + once.Do(func() { + fileLogger, err := NewFileLogger(DefaultLoggerConfig()) + if err != nil { + // 如果文件日志创建失败,使用控制台日志 + defaultLogger = NewConsoleLogger(os.Stderr) + } else { + // 同时输出到文件和控制台 + consoleLogger := NewConsoleLogger(os.Stdout) + defaultLogger = NewMultiLogger(fileLogger, consoleLogger) + } + }) + return defaultLogger +} + +// 便捷函数 +func Debug(msg string, fields ...interface{}) { + GetDefaultLogger().Debug(msg, fields...) +} + +func Info(msg string, fields ...interface{}) { + GetDefaultLogger().Info(msg, fields...) +} + +func Warn(msg string, fields ...interface{}) { + GetDefaultLogger().Warn(msg, fields...) +} + +func Error(msg string, fields ...interface{}) { + GetDefaultLogger().Error(msg, fields...) +} + +func SetLevel(level LogLevel) { + GetDefaultLogger().SetLevel(level) +} + +func Close() error { + return GetDefaultLogger().Close() +} \ No newline at end of file diff --git a/Go_Updater/logger/logger_test.go b/Go_Updater/logger/logger_test.go new file mode 100644 index 0000000..be99ae9 --- /dev/null +++ b/Go_Updater/logger/logger_test.go @@ -0,0 +1,300 @@ +package logger + +import ( + "bytes" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestLogLevel_String(t *testing.T) { + tests := []struct { + level LogLevel + expected string + }{ + {DEBUG, "DEBUG"}, + {INFO, "INFO"}, + {WARN, "WARN"}, + {ERROR, "ERROR"}, + {LogLevel(999), "UNKNOWN"}, + } + + for _, tt := range tests { + t.Run(tt.expected, func(t *testing.T) { + if got := tt.level.String(); got != tt.expected { + t.Errorf("LogLevel.String() = %v, want %v", got, tt.expected) + } + }) + } +} + +func TestDefaultLoggerConfig(t *testing.T) { + config := DefaultLoggerConfig() + + if config.Level != INFO { + t.Errorf("Expected default level INFO, got %v", config.Level) + } + if config.MaxSize != 10*1024*1024 { + t.Errorf("Expected default max size 10MB, got %v", config.MaxSize) + } + if config.MaxBackups != 5 { + t.Errorf("Expected default max backups 5, got %v", config.MaxBackups) + } + if config.Filename != "updater.log" { + t.Errorf("Expected default filename 'updater.log', got %v", config.Filename) + } +} + +func TestConsoleLogger(t *testing.T) { + var buf bytes.Buffer + logger := NewConsoleLogger(&buf) + + t.Run("log levels", func(t *testing.T) { + logger.SetLevel(DEBUG) + + logger.Debug("debug message") + logger.Info("info message") + logger.Warn("warn message") + logger.Error("error message") + + output := buf.String() + if !strings.Contains(output, "DEBUG debug message") { + t.Error("Expected debug message in output") + } + if !strings.Contains(output, "INFO info message") { + t.Error("Expected info message in output") + } + if !strings.Contains(output, "WARN warn message") { + t.Error("Expected warn message in output") + } + if !strings.Contains(output, "ERROR error message") { + t.Error("Expected error message in output") + } + }) + + t.Run("log level filtering", func(t *testing.T) { + buf.Reset() + logger.SetLevel(WARN) + + logger.Debug("debug message") + logger.Info("info message") + logger.Warn("warn message") + logger.Error("error message") + + output := buf.String() + if strings.Contains(output, "DEBUG") { + t.Error("Debug message should be filtered out") + } + if strings.Contains(output, "INFO") { + t.Error("Info message should be filtered out") + } + if !strings.Contains(output, "WARN warn message") { + t.Error("Expected warn message in output") + } + if !strings.Contains(output, "ERROR error message") { + t.Error("Expected error message in output") + } + }) + + t.Run("formatted messages", func(t *testing.T) { + buf.Reset() + logger.SetLevel(DEBUG) + + logger.Info("formatted message: %s %d", "test", 42) + + output := buf.String() + if !strings.Contains(output, "formatted message: test 42") { + t.Error("Expected formatted message in output") + } + }) +} + +func TestFileLogger(t *testing.T) { + // 创建临时目录 + tempDir := t.TempDir() + + config := &LoggerConfig{ + Level: DEBUG, + MaxSize: 1024, // 1KB for testing rotation + MaxBackups: 3, + LogDir: tempDir, + Filename: "test.log", + } + + logger, err := NewFileLogger(config) + if err != nil { + t.Fatalf("Failed to create file logger: %v", err) + } + defer logger.Close() + + t.Run("basic logging", func(t *testing.T) { + logger.Info("test message") + logger.Error("error message with %s", "formatting") + + // 读取日志文件 + logPath := filepath.Join(tempDir, "test.log") + content, err := os.ReadFile(logPath) + if err != nil { + t.Fatalf("Failed to read log file: %v", err) + } + + output := string(content) + if !strings.Contains(output, "INFO test message") { + t.Error("Expected info message in log file") + } + if !strings.Contains(output, "ERROR error message with formatting") { + t.Error("Expected formatted error message in log file") + } + }) + + t.Run("log rotation", func(t *testing.T) { + // 写入大量数据触发轮转 + longMessage := strings.Repeat("a", 200) + for i := 0; i < 10; i++ { + logger.Info("Long message %d: %s", i, longMessage) + } + + // 检查是否创建了备份文件 + logPath := filepath.Join(tempDir, "test.log") + backupPath := filepath.Join(tempDir, "test.log.1") + + if _, err := os.Stat(logPath); os.IsNotExist(err) { + t.Error("Main log file should exist") + } + if _, err := os.Stat(backupPath); os.IsNotExist(err) { + t.Error("Backup log file should exist after rotation") + } + }) +} + +func TestMultiLogger(t *testing.T) { + var buf1, buf2 bytes.Buffer + logger1 := NewConsoleLogger(&buf1) + logger2 := NewConsoleLogger(&buf2) + + multiLogger := NewMultiLogger(logger1, logger2) + multiLogger.SetLevel(INFO) + + multiLogger.Info("test message") + multiLogger.Error("error message") + + // 检查两个logger都收到了消息 + output1 := buf1.String() + output2 := buf2.String() + + if !strings.Contains(output1, "INFO test message") { + t.Error("Expected info message in first logger") + } + if !strings.Contains(output1, "ERROR error message") { + t.Error("Expected error message in first logger") + } + if !strings.Contains(output2, "INFO test message") { + t.Error("Expected info message in second logger") + } + if !strings.Contains(output2, "ERROR error message") { + t.Error("Expected error message in second logger") + } +} + +func TestFileLoggerRotation(t *testing.T) { + tempDir := t.TempDir() + + config := &LoggerConfig{ + Level: DEBUG, + MaxSize: 100, // Very small for testing + MaxBackups: 2, + LogDir: tempDir, + Filename: "rotation_test.log", + } + + logger, err := NewFileLogger(config) + if err != nil { + t.Fatalf("Failed to create file logger: %v", err) + } + defer logger.Close() + + // 写入足够的数据触发多次轮转 + for i := 0; i < 20; i++ { + logger.Info("Message %d: %s", i, strings.Repeat("x", 50)) + } + + // 检查文件存在性 + logPath := filepath.Join(tempDir, "rotation_test.log") + backup1Path := filepath.Join(tempDir, "rotation_test.log.1") + backup2Path := filepath.Join(tempDir, "rotation_test.log.2") + backup3Path := filepath.Join(tempDir, "rotation_test.log.3") + + if _, err := os.Stat(logPath); os.IsNotExist(err) { + t.Error("Main log file should exist") + } + if _, err := os.Stat(backup1Path); os.IsNotExist(err) { + t.Error("First backup should exist") + } + if _, err := os.Stat(backup2Path); os.IsNotExist(err) { + t.Error("Second backup should exist") + } + // 第三个备份不应该存在(MaxBackups=2) + if _, err := os.Stat(backup3Path); !os.IsNotExist(err) { + t.Error("Third backup should not exist (exceeds MaxBackups)") + } +} + +func TestGlobalLoggerFunctions(t *testing.T) { + // 这个测试比较简单,主要确保全局函数不会panic + Debug("debug message") + Info("info message") + Warn("warn message") + Error("error message") + + SetLevel(ERROR) + + // 这些调用不应该panic + Debug("filtered debug") + Info("filtered info") + Error("visible error") +} + +func TestFileLoggerErrorHandling(t *testing.T) { + t.Run("invalid directory", func(t *testing.T) { + // 使用一个真正无效的路径 + config := &LoggerConfig{ + Level: INFO, + MaxSize: 1024, + MaxBackups: 3, + LogDir: string([]byte{0}), // 无效的路径字符 + Filename: "test.log", + } + + _, err := NewFileLogger(config) + if err == nil { + t.Error("Expected error when creating logger with invalid directory") + } + }) +} + +func TestLoggerFormatting(t *testing.T) { + var buf bytes.Buffer + logger := NewConsoleLogger(&buf) + logger.SetLevel(DEBUG) + + // 测试时间戳格式 + logger.Info("test message") + + output := buf.String() + lines := strings.Split(strings.TrimSpace(output), "\n") + if len(lines) == 0 { + t.Fatal("Expected at least one log line") + } + + // 检查格式:[HH:MM:SS] LEVEL message + line := lines[0] + if !strings.Contains(line, "INFO test message") { + t.Errorf("Expected 'INFO test message' in output, got: %s", line) + } + + // 检查时间戳格式(简单检查) + if !strings.HasPrefix(line, "[") { + t.Error("Expected log line to start with timestamp in brackets") + } +} \ No newline at end of file diff --git a/Go_Updater/main.go b/Go_Updater/main.go new file mode 100644 index 0000000..473d1e8 --- /dev/null +++ b/Go_Updater/main.go @@ -0,0 +1,1046 @@ +package main + +import ( + "context" + "encoding/json" + "flag" + "fmt" + "os" + "os/exec" + "os/signal" + "path/filepath" + "strings" + "sync" + "syscall" + + "lightweight-updater/api" + "lightweight-updater/config" + "lightweight-updater/download" + "lightweight-updater/errors" + "lightweight-updater/install" + "lightweight-updater/logger" + appversion "lightweight-updater/version" +) + +// UpdateState represents the current state of the update process +type UpdateState int + +const ( + StateIdle UpdateState = iota + StateChecking + StateUpdateAvailable + StateDownloading + StateInstalling + StateCompleted + StateError +) + +// String returns the string representation of the update state +func (s UpdateState) String() string { + switch s { + case StateIdle: + return "Idle" + case StateChecking: + return "Checking" + case StateUpdateAvailable: + return "UpdateAvailable" + case StateDownloading: + return "Downloading" + case StateInstalling: + return "Installing" + case StateCompleted: + return "Completed" + case StateError: + return "Error" + default: + return "Unknown" + } +} + +// GUIManager interface for optional GUI functionality +type GUIManager interface { + ShowMainWindow() + UpdateStatus(status int, message string) + ShowProgress(percentage float64) + ShowError(errorMsg string) + Close() +} + +// UpdateInfo contains information about an available update +type UpdateInfo struct { + CurrentVersion string + NewVersion string + DownloadURL string + ReleaseNotes string + IsAvailable bool +} + +// Application represents the main application instance +type Application struct { + config *config.Config + configManager config.ConfigManager + apiClient api.MirrorClient + downloadManager download.DownloadManager + installManager install.InstallManager + guiManager GUIManager + logger logger.Logger + errorHandler errors.ErrorHandler + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + + // Update flow state + currentState UpdateState + stateMutex sync.RWMutex + updateInfo *UpdateInfo + userConfirmed chan bool +} + +// Command line flags +var ( + configPath = flag.String("config", "", "Path to configuration file") + logLevel = flag.String("log-level", "info", "Log level (debug, info, warn, error)") + noGUI = flag.Bool("no-gui", false, "Run without GUI (command line mode)") + version = flag.Bool("version", false, "Show version information") + help = flag.Bool("help", false, "Show help information") + channel = flag.String("channel", "", "Update channel (stable or beta)") + currentVersion = flag.String("current-version", "", "Current version to check against") + cdk = flag.String("cdk", "", "CDK for MirrorChyan download") +) + +// Version information is now handled by the version package + +func main() { + // Parse command line arguments + flag.Parse() + + // Show version information + if *version { + showVersion() + return + } + + // Show help information + if *help { + showHelp() + return + } + + // Check for single instance + if err := ensureSingleInstance(); err != nil { + fmt.Fprintf(os.Stderr, "Another instance is already running: %v\n", err) + os.Exit(1) + } + + // Initialize application + app, err := initializeApplication() + if err != nil { + fmt.Fprintf(os.Stderr, "Failed to initialize application: %v\n", err) + os.Exit(1) + } + defer app.cleanup() + + // Handle cleanup on process marked files on startup + if err := app.handleStartupCleanup(); err != nil { + app.logger.Warn("Failed to cleanup marked files: %v", err) + } + + // Setup signal handling + app.setupSignalHandling() + + // Start the application + if err := app.run(); err != nil { + app.logger.Error("Application error: %v", err) + os.Exit(1) + } +} + +// initializeApplication initializes all application components +func initializeApplication() (*Application, error) { + // Create context for graceful shutdown + ctx, cancel := context.WithCancel(context.Background()) + + // Initialize logger first + loggerConfig := logger.DefaultLoggerConfig() + + // Set log level from command line + switch *logLevel { + case "debug": + loggerConfig.Level = logger.DEBUG + case "info": + loggerConfig.Level = logger.INFO + case "warn": + loggerConfig.Level = logger.WARN + case "error": + loggerConfig.Level = logger.ERROR + } + + var appLogger logger.Logger + fileLogger, err := logger.NewFileLogger(loggerConfig) + if err != nil { + // Fallback to console logger + appLogger = logger.NewConsoleLogger(os.Stdout) + } else { + appLogger = fileLogger + } + + appLogger.Info("Initializing AUTO_MAA_Go_Updater v%s", appversion.Version) + + // Initialize configuration manager + var configManager config.ConfigManager + if *configPath != "" { + // Custom config path not implemented in the config package yet + // For now, use default manager + configManager = config.NewConfigManager() + appLogger.Warn("Custom config path not fully supported yet, using default") + } else { + configManager = config.NewConfigManager() + } + + // Load configuration + cfg, err := configManager.Load() + if err != nil { + appLogger.Error("Failed to load configuration: %v", err) + return nil, fmt.Errorf("failed to load configuration: %w", err) + } + + appLogger.Info("Configuration loaded successfully") + + // Initialize API client + apiClient := api.NewClient() + + // Initialize download manager + downloadManager := download.NewManager() + + // Initialize install manager + installManager := install.NewManager() + + // Initialize error handler + errorHandler := errors.NewDefaultErrorHandler() + + // Initialize GUI manager (if not in no-gui mode) + var guiManager GUIManager + if !*noGUI { + // GUI will be implemented when GUI dependencies are available + appLogger.Info("GUI mode requested but not available in this build") + guiManager = nil + } else { + appLogger.Info("Running in no-GUI mode") + } + + app := &Application{ + config: cfg, + configManager: configManager, + apiClient: apiClient, + downloadManager: downloadManager, + installManager: installManager, + guiManager: guiManager, + logger: appLogger, + errorHandler: errorHandler, + ctx: ctx, + cancel: cancel, + currentState: StateIdle, + userConfirmed: make(chan bool, 1), + } + + appLogger.Info("Application initialized successfully") + return app, nil +} + +// run starts the main application logic +func (app *Application) run() error { + app.logger.Info("Starting application") + + if app.guiManager != nil { + // Run with GUI + return app.runWithGUI() + } else { + // Run in command line mode + return app.runCommandLine() + } +} + +// runWithGUI runs the application with GUI +func (app *Application) runWithGUI() error { + app.logger.Info("Starting GUI mode") + + // Set up GUI callbacks + app.setupGUICallbacks() + + // Show main window (this will block until window is closed) + app.guiManager.ShowMainWindow() + + return nil +} + +// runCommandLine runs the application in command line mode +func (app *Application) runCommandLine() error { + app.logger.Info("Starting command line mode") + + // Start the complete update flow + return app.executeUpdateFlow() +} + +// setupGUICallbacks sets up callbacks for GUI interactions +func (app *Application) setupGUICallbacks() { + if app.guiManager == nil { + return + } + + // GUI callbacks will be implemented when GUI is available + app.logger.Info("GUI callbacks setup requested but GUI not available") + + // For now, we'll set up basic interaction handling + // The actual GUI integration will be completed when GUI dependencies are resolved +} + +// handleStartupCleanup handles cleanup of files marked for deletion on startup +func (app *Application) handleStartupCleanup() error { + app.logger.Info("Performing startup cleanup") + + // Get current executable directory + exePath, err := os.Executable() + if err != nil { + return fmt.Errorf("failed to get executable path: %w", err) + } + + exeDir := filepath.Dir(exePath) + + // Delete files marked for deletion + if installMgr, ok := app.installManager.(*install.Manager); ok { + if err := installMgr.DeleteMarkedFiles(exeDir); err != nil { + return fmt.Errorf("failed to delete marked files: %w", err) + } + } + + app.logger.Info("Startup cleanup completed") + return nil +} + +// setupSignalHandling sets up graceful shutdown on system signals +func (app *Application) setupSignalHandling() { + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + go func() { + sig := <-sigChan + app.logger.Info("Received signal: %v", sig) + app.logger.Info("Initiating graceful shutdown...") + app.cancel() + }() +} + +// cleanup performs application cleanup +func (app *Application) cleanup() { + app.logger.Info("Cleaning up application resources") + + // Cancel context to stop all operations + app.cancel() + + // Wait for all goroutines to finish + app.wg.Wait() + + // Cleanup install manager temporary directories + if installMgr, ok := app.installManager.(*install.Manager); ok { + if err := installMgr.CleanupAllTempDirs(); err != nil { + app.logger.Error("Failed to cleanup temp directories: %v", err) + } + } + + app.logger.Info("Application cleanup completed") + + // Close logger last + if err := app.logger.Close(); err != nil { + fmt.Fprintf(os.Stderr, "Failed to close logger: %v\n", err) + } +} + +// ensureSingleInstance ensures only one instance of the application is running +func ensureSingleInstance() error { + // Create a lock file in temp directory + tempDir := os.TempDir() + lockFile := filepath.Join(tempDir, "lightweight-updater.lock") + + // Try to create the lock file exclusively + file, err := os.OpenFile(lockFile, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0644) + if err != nil { + if os.IsExist(err) { + // Check if the process is still running + if isProcessRunning(lockFile) { + return fmt.Errorf("another instance is already running") + } + // Remove stale lock file and try again + os.Remove(lockFile) + return ensureSingleInstance() + } + return fmt.Errorf("failed to create lock file: %w", err) + } + + // Write current process ID to lock file + fmt.Fprintf(file, "%d", os.Getpid()) + file.Close() + + // Remove lock file on exit + go func() { + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + <-sigChan + os.Remove(lockFile) + }() + + return nil +} + +// isProcessRunning checks if the process in the lock file is still running +func isProcessRunning(lockFile string) bool { + data, err := os.ReadFile(lockFile) + if err != nil { + return false + } + + var pid int + if _, err := fmt.Sscanf(string(data), "%d", &pid); err != nil { + return false + } + + // Check if process exists (Windows specific) + process, err := os.FindProcess(pid) + if err != nil { + return false + } + + // On Windows, FindProcess always succeeds, so we need to check differently + // Try to send signal 0 to check if process exists + err = process.Signal(syscall.Signal(0)) + return err == nil +} + +// showVersion displays version information +func showVersion() { + fmt.Printf("AUTO_MAA_Go_Updater\n") + fmt.Printf("Version: %s\n", appversion.Version) + fmt.Printf("Build Time: %s\n", appversion.BuildTime) + fmt.Printf("Git Commit: %s\n", appversion.GitCommit) +} + +// showHelp displays help information +func showHelp() { + fmt.Printf("AUTO_MAA_Go_Updater - AUTO_MAA 轻量级更新器\n\n") + fmt.Printf("Usage: %s [options]\n\n", os.Args[0]) + fmt.Printf("Options:\n") + flag.PrintDefaults() + fmt.Printf("\nExamples:\n") + fmt.Printf(" %s # Run with GUI\n", os.Args[0]) + fmt.Printf(" %s -no-gui # Run in command line mode\n", os.Args[0]) + fmt.Printf(" %s -log-level debug # Run with debug logging\n", os.Args[0]) + fmt.Printf(" %s -version # Show version information\n", os.Args[0]) +} + +// executeUpdateFlow executes the complete update flow with state machine management +func (app *Application) executeUpdateFlow() error { + app.logger.Info("Starting update flow execution") + + // Execute the state machine + for { + select { + case <-app.ctx.Done(): + app.logger.Info("Update flow cancelled") + return app.ctx.Err() + default: + } + + // Get current state + state := app.getCurrentState() + app.logger.Debug("Current state: %s", state.String()) + + // Execute state logic + nextState, err := app.executeState(state) + if err != nil { + app.logger.Error("State execution failed: %v", err) + app.setState(StateError) + return err + } + + // Check if we're done + if nextState == StateCompleted || nextState == StateError { + app.setState(nextState) + break + } + + // Transition to next state + app.setState(nextState) + } + + finalState := app.getCurrentState() + app.logger.Info("Update flow completed with state: %s", finalState.String()) + + if finalState == StateError { + return fmt.Errorf("update flow failed") + } + + return nil +} + +// executeState executes the logic for the current state and returns the next state +func (app *Application) executeState(state UpdateState) (UpdateState, error) { + switch state { + case StateIdle: + return app.executeIdleState() + case StateChecking: + return app.executeCheckingState() + case StateUpdateAvailable: + return app.executeUpdateAvailableState() + case StateDownloading: + return app.executeDownloadingState() + case StateInstalling: + return app.executeInstallingState() + case StateCompleted: + return StateCompleted, nil + case StateError: + return StateError, nil + default: + return StateError, fmt.Errorf("unknown state: %s", state.String()) + } +} + +// executeIdleState handles the idle state +func (app *Application) executeIdleState() (UpdateState, error) { + app.logger.Info("Starting update check...") + fmt.Println("正在检查更新...") + return StateChecking, nil +} + +// executeCheckingState handles the checking state +func (app *Application) executeCheckingState() (UpdateState, error) { + app.logger.Info("Checking for updates") + + // Determine version and channel to use + var currentVer, updateChannel, cdkToUse string + var err error + + // Priority: command line args > version file > config + if *currentVersion != "" { + currentVer = *currentVersion + app.logger.Info("Using current version from command line: %s", currentVer) + } else { + // Try to load version from resources/version.json + versionManager := appversion.NewVersionManager() + versionInfo, err := versionManager.LoadVersionFromFile() + if err != nil { + app.logger.Warn("Failed to load version from file: %v, using config version", err) + currentVer = app.config.CurrentVersion + } else { + currentVer = versionInfo.MainVersion + app.logger.Info("Using current version from version file: %s", currentVer) + } + } + + // Determine channel + if *channel != "" { + updateChannel = *channel + app.logger.Info("Using channel from command line: %s", updateChannel) + } else { + // Try to load channel from config.json + updateChannel = app.loadChannelFromConfig() + app.logger.Info("Using channel from config: %s", updateChannel) + } + + // Determine CDK to use + if *cdk != "" { + cdkToUse = *cdk + app.logger.Info("Using CDK from command line") + } else { + // Get CDK from config + cdkToUse, err = app.config.GetCDK() + if err != nil { + app.logger.Warn("Failed to get CDK from config: %v", err) + cdkToUse = "" // Continue without CDK + } + } + + // Prepare API parameters + params := api.UpdateCheckParams{ + ResourceID: "AUTO_MAA", // Fixed resource ID for AUTO_MAA + CurrentVersion: currentVer, + Channel: updateChannel, + CDK: cdkToUse, + UserAgent: app.config.UserAgent, + } + + // Call MirrorChyan API to check for updates + response, err := app.apiClient.CheckUpdate(params) + if err != nil { + app.logger.Error("Failed to check for updates: %v", err) + fmt.Printf("检查更新失败: %v\n", err) + return StateError, fmt.Errorf("failed to check for updates: %w", err) + } + + // Check if update is available + isUpdateAvailable := app.apiClient.IsUpdateAvailable(response, currentVer) + + if !isUpdateAvailable { + app.logger.Info("No update available") + fmt.Println("当前已是最新版本") + return StateCompleted, nil + } + + // Determine download URL + var downloadURL string + if response.Data.URL != "" { + // Use CDK download URL from MirrorChyan + downloadURL = response.Data.URL + app.logger.Info("Using CDK download URL from MirrorChyan") + } else { + // Use official download site + downloadURL = app.apiClient.GetOfficialDownloadURL(response.Data.VersionName) + app.logger.Info("Using official download URL: %s", downloadURL) + } + + // Store update information + app.updateInfo = &UpdateInfo{ + CurrentVersion: currentVer, + NewVersion: response.Data.VersionName, + DownloadURL: downloadURL, + ReleaseNotes: response.Data.ReleaseNote, + IsAvailable: true, + } + + app.logger.Info("Update available: %s -> %s", currentVer, response.Data.VersionName) + fmt.Printf("发现新版本: %s -> %s\n", currentVer, response.Data.VersionName) + + // if response.Data.ReleaseNote != "" { + // fmt.Printf("更新内容: %s\n", response.Data.ReleaseNote) + // } + + return StateUpdateAvailable, nil +} + +// executeUpdateAvailableState handles the update available state +func (app *Application) executeUpdateAvailableState() (UpdateState, error) { + app.logger.Info("Update available, starting download automatically") + + // Automatically start download without user confirmation + fmt.Println("开始下载更新...") + return StateDownloading, nil +} + +// executeDownloadingState handles the downloading state +func (app *Application) executeDownloadingState() (UpdateState, error) { + app.logger.Info("Starting download") + + if app.updateInfo == nil || app.updateInfo.DownloadURL == "" { + return StateError, fmt.Errorf("no download URL available") + } + + // Get current executable directory + exePath, err := os.Executable() + if err != nil { + return StateError, fmt.Errorf("failed to get executable path: %w", err) + } + exeDir := filepath.Dir(exePath) + + // Create AUTOMAA_UPDATE_TEMP directory for download + tempDir := filepath.Join(exeDir, "AUTOMAA_UPDATE_TEMP") + if err := os.MkdirAll(tempDir, 0755); err != nil { + return StateError, fmt.Errorf("failed to create temp directory: %w", err) + } + + // Download file + downloadPath := filepath.Join(tempDir, "update.zip") + + fmt.Println("正在下载更新包...") + + // Create progress callback + progressCallback := func(progress download.DownloadProgress) { + if progress.TotalBytes > 0 { + fmt.Printf("\r下载进度: %.1f%% (%s/s)", + progress.Percentage, + app.formatBytes(progress.Speed)) + } + } + + // Download the update file + downloadErr := app.downloadManager.Download(app.updateInfo.DownloadURL, downloadPath, progressCallback) + + fmt.Println() // New line after progress + + if downloadErr != nil { + app.logger.Error("Download failed: %v", downloadErr) + fmt.Printf("下载失败: %v\n", downloadErr) + return StateError, fmt.Errorf("download failed: %w", downloadErr) + } + + app.logger.Info("Download completed successfully") + fmt.Println("下载完成") + + // Store download path for installation + app.updateInfo.DownloadURL = downloadPath + + return StateInstalling, nil +} + +// executeInstallingState handles the installing state +func (app *Application) executeInstallingState() (UpdateState, error) { + app.logger.Info("Starting installation") + fmt.Println("正在安装更新...") + + if app.updateInfo == nil || app.updateInfo.DownloadURL == "" { + return StateError, fmt.Errorf("no download file available") + } + + downloadPath := app.updateInfo.DownloadURL + + // Create temporary directory for extraction + tempDir, err := app.installManager.CreateTempDir() + if err != nil { + return StateError, fmt.Errorf("failed to create temp directory: %w", err) + } + + // Extract the downloaded zip file + app.logger.Info("Extracting update package") + if err := app.installManager.ExtractZip(downloadPath, tempDir); err != nil { + app.logger.Error("Failed to extract zip: %v", err) + return StateError, fmt.Errorf("failed to extract update package: %w", err) + } + + // Process changes.json if it exists (for future use) + changesPath := filepath.Join(tempDir, "changes.json") + _, err = app.installManager.ProcessChanges(changesPath) + if err != nil { + app.logger.Warn("Failed to process changes (not critical): %v", err) + // This is not critical for AUTO_MAA-Setup.exe installation + } + + // Get current executable directory + exePath, err := os.Executable() + if err != nil { + return StateError, fmt.Errorf("failed to get executable path: %w", err) + } + targetDir := filepath.Dir(exePath) + + // Handle running processes (but skip the updater itself) + updaterName := filepath.Base(exePath) + if err := app.handleRunningProcesses(targetDir, updaterName); err != nil { + app.logger.Warn("Failed to handle running processes: %v", err) + // Continue with installation, this is not critical + } + + // Look for AUTO_MAA-Setup.exe in the extracted files + setupExePath := filepath.Join(tempDir, "AUTO_MAA-Setup.exe") + if _, err := os.Stat(setupExePath); err != nil { + app.logger.Error("AUTO_MAA-Setup.exe not found in update package: %v", err) + return StateError, fmt.Errorf("AUTO_MAA-Setup.exe not found in update package: %w", err) + } + + // Run the setup executable + app.logger.Info("Running AUTO_MAA-Setup.exe") + fmt.Println("正在运行安装程序...") + + if err := app.runSetupExecutable(setupExePath); err != nil { + app.logger.Error("Failed to run setup executable: %v", err) + return StateError, fmt.Errorf("failed to run setup executable: %w", err) + } + + // Update the version.json file with new version + if err := app.updateVersionFile(app.updateInfo.NewVersion); err != nil { + app.logger.Warn("Failed to update version file: %v", err) + // This is not critical, continue + } + + // Clean up AUTOMAA_UPDATE_TEMP directory after installation + if err := os.RemoveAll(tempDir); err != nil { + app.logger.Warn("Failed to cleanup temp directory: %v", err) + // This is not critical, continue + } else { + app.logger.Info("Cleaned up temp directory: %s", tempDir) + } + + app.logger.Info("Installation completed successfully") + fmt.Println("安装完成") + fmt.Printf("已更新到版本: %s\n", app.updateInfo.NewVersion) + + return StateCompleted, nil +} + +// getCurrentState returns the current state thread-safely +func (app *Application) getCurrentState() UpdateState { + app.stateMutex.RLock() + defer app.stateMutex.RUnlock() + return app.currentState +} + +// setState sets the current state thread-safely +func (app *Application) setState(state UpdateState) { + app.stateMutex.Lock() + defer app.stateMutex.Unlock() + + app.logger.Debug("State transition: %s -> %s", app.currentState.String(), state.String()) + app.currentState = state + + // Update GUI if available + if app.guiManager != nil { + app.updateGUIStatus(state) + } +} + +// updateGUIStatus updates the GUI based on the current state +func (app *Application) updateGUIStatus(state UpdateState) { + if app.guiManager == nil { + return + } + + switch state { + case StateIdle: + app.guiManager.UpdateStatus(0, "准备检查更新...") + case StateChecking: + app.guiManager.UpdateStatus(1, "正在检查更新...") + case StateUpdateAvailable: + if app.updateInfo != nil { + message := fmt.Sprintf("发现新版本: %s", app.updateInfo.NewVersion) + app.guiManager.UpdateStatus(2, message) + } + case StateDownloading: + app.guiManager.UpdateStatus(3, "正在下载更新...") + case StateInstalling: + app.guiManager.UpdateStatus(4, "正在安装更新...") + case StateCompleted: + app.guiManager.UpdateStatus(5, "更新完成") + case StateError: + app.guiManager.UpdateStatus(6, "更新失败") + } +} + +// formatBytes formats bytes into human readable format +func (app *Application) formatBytes(bytes int64) string { + const unit = 1024 + if bytes < unit { + return fmt.Sprintf("%d B", bytes) + } + div, exp := int64(unit), 0 + for n := bytes / unit; n >= unit; n /= unit { + div *= unit + exp++ + } + return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(div), "KMGTPE"[exp]) +} + +// handleUserInteraction handles user interaction for GUI mode +func (app *Application) handleUserInteraction(action string) { + switch action { + case "confirm_update": + select { + case app.userConfirmed <- true: + default: + } + case "cancel_update": + select { + case app.userConfirmed <- false: + default: + } + case "check_update": + // Start update flow in a goroutine + app.wg.Add(1) + go func() { + defer app.wg.Done() + if err := app.executeUpdateFlow(); err != nil { + app.logger.Error("Update flow failed: %v", err) + } + }() + } +} + +// updateVersionFile updates the target software's version.json file with the new version +func (app *Application) updateVersionFile(newVersion string) error { + // Get current executable directory (where the target software is located) + exePath, err := os.Executable() + if err != nil { + return fmt.Errorf("failed to get executable path: %w", err) + } + targetDir := filepath.Dir(exePath) + + // Path to the target software's version file + versionFilePath := filepath.Join(targetDir, "resources", "version.json") + + // Try to load existing version file + versionManager := appversion.NewVersionManager() + versionInfo, err := versionManager.LoadVersionFromFile() + if err != nil { + app.logger.Warn("Could not load existing version file, creating new one: %v", err) + // Create a basic version info structure + versionInfo = &appversion.VersionInfo{ + MainVersion: newVersion, + VersionInfo: make(map[string]map[string][]string), + } + } + + // Parse the new version to get the proper format + parsedVersion, err := appversion.ParseVersion(newVersion) + if err != nil { + // If we can't parse the version from API response, try to extract from display format + if strings.HasPrefix(newVersion, "v") { + // Convert "v4.4.1-beta3" to "4.4.1.3" format + versionStr := strings.TrimPrefix(newVersion, "v") + if strings.Contains(versionStr, "-beta") { + parts := strings.Split(versionStr, "-beta") + if len(parts) == 2 { + baseVersion := parts[0] + betaNum := parts[1] + versionInfo.MainVersion = fmt.Sprintf("%s.%s", baseVersion, betaNum) + } else { + versionInfo.MainVersion = versionStr + ".0" + } + } else { + versionInfo.MainVersion = versionStr + ".0" + } + } else { + versionInfo.MainVersion = newVersion + } + } else { + // Use the parsed version to create the proper format + versionInfo.MainVersion = parsedVersion.ToVersionString() + } + + // Create resources directory if it doesn't exist + resourcesDir := filepath.Join(targetDir, "resources") + if err := os.MkdirAll(resourcesDir, 0755); err != nil { + return fmt.Errorf("failed to create resources directory: %w", err) + } + + // Write updated version file + data, err := json.MarshalIndent(versionInfo, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal version info: %w", err) + } + + if err := os.WriteFile(versionFilePath, data, 0644); err != nil { + return fmt.Errorf("failed to write version file: %w", err) + } + + app.logger.Info("Updated version file: %s -> %s", versionFilePath, versionInfo.MainVersion) + return nil +} + +// handleRunningProcesses handles running processes but excludes the updater itself +func (app *Application) handleRunningProcesses(targetDir, updaterName string) error { + app.logger.Info("Handling running processes, excluding updater: %s", updaterName) + + // Get list of executable files in the target directory + files, err := os.ReadDir(targetDir) + if err != nil { + return fmt.Errorf("failed to read target directory: %w", err) + } + + for _, file := range files { + if file.IsDir() { + continue + } + + fileName := file.Name() + + // Skip the updater itself + if fileName == updaterName { + app.logger.Info("Skipping updater file: %s", fileName) + continue + } + + // Only handle .exe files + if !strings.HasSuffix(strings.ToLower(fileName), ".exe") { + continue + } + + // Handle this executable + if err := app.installManager.HandleRunningProcess(fileName); err != nil { + app.logger.Warn("Failed to handle running process %s: %v", fileName, err) + // Continue with other files, don't fail the entire process + } + } + + return nil +} + +// runSetupExecutable runs the setup executable with proper parameters +func (app *Application) runSetupExecutable(setupExePath string) error { + app.logger.Info("Executing setup file: %s", setupExePath) + + // Get current executable directory as installation directory + exePath, err := os.Executable() + if err != nil { + return fmt.Errorf("failed to get executable path: %w", err) + } + installDir := filepath.Dir(exePath) + + // Setup command with parameters matching Python implementation + args := []string{ + "/SP-", // Skip welcome page + "/SILENT", // Silent installation + "/NOCANCEL", // No cancel button + "/FORCECLOSEAPPLICATIONS", // Force close applications + "/LANG=Chinese", // Chinese language + fmt.Sprintf("/DIR=%s", installDir), // Installation directory + } + + app.logger.Info("Running setup with args: %v", args) + + // Create command with arguments + cmd := exec.Command(setupExePath, args...) + + // Set working directory to the setup file's directory + cmd.Dir = filepath.Dir(setupExePath) + + // Run the command and wait for it to complete + if err := cmd.Run(); err != nil { + return fmt.Errorf("failed to execute setup: %w", err) + } + + app.logger.Info("Setup executable completed successfully") + return nil +} + +// AutoMAAConfig represents the structure of config/config.json +type AutoMAAConfig struct { + Update struct { + UpdateType string `json:"UpdateType"` + } `json:"Update"` +} + +// loadChannelFromConfig loads the update channel from config/config.json +func (app *Application) loadChannelFromConfig() string { + // Get current executable directory + exePath, err := os.Executable() + if err != nil { + app.logger.Warn("Failed to get executable path: %v", err) + return "stable" + } + + configPath := filepath.Join(filepath.Dir(exePath), "config", "config.json") + + // Check if config file exists + if _, err := os.Stat(configPath); os.IsNotExist(err) { + app.logger.Info("Config file not found: %s, using default channel", configPath) + return "stable" + } + + // Read config file + data, err := os.ReadFile(configPath) + if err != nil { + app.logger.Warn("Failed to read config file: %v, using default channel", err) + return "stable" + } + + // Parse JSON + var config AutoMAAConfig + if err := json.Unmarshal(data, &config); err != nil { + app.logger.Warn("Failed to parse config file: %v, using default channel", err) + return "stable" + } + + // Get update channel + updateType := config.Update.UpdateType + if updateType == "" { + app.logger.Info("UpdateType not found in config, using default channel") + return "stable" + } + + app.logger.Info("Loaded update channel from config: %s", updateType) + return updateType +} \ No newline at end of file diff --git a/Go_Updater/utils/utils.go b/Go_Updater/utils/utils.go new file mode 100644 index 0000000..f9f6e8b --- /dev/null +++ b/Go_Updater/utils/utils.go @@ -0,0 +1,3 @@ +package utils + +// Package utils provides utility functions for the updater \ No newline at end of file diff --git a/Go_Updater/version/manager.go b/Go_Updater/version/manager.go new file mode 100644 index 0000000..e1a8895 --- /dev/null +++ b/Go_Updater/version/manager.go @@ -0,0 +1,193 @@ +package version + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "strconv" + "strings" + + "lightweight-updater/logger" +) + +// VersionInfo represents the version information from version.json +type VersionInfo struct { + MainVersion string `json:"main_version"` + VersionInfo map[string]map[string][]string `json:"version_info"` +} + +// ParsedVersion represents a parsed version with major, minor, patch, and beta components +type ParsedVersion struct { + Major int + Minor int + Patch int + Beta int +} + +// VersionManager handles version-related operations +type VersionManager struct { + executableDir string + logger logger.Logger +} + +// NewVersionManager creates a new version manager +func NewVersionManager() *VersionManager { + execPath, _ := os.Executable() + execDir := filepath.Dir(execPath) + return &VersionManager{ + executableDir: execDir, + logger: logger.GetDefaultLogger(), + } +} + +// NewVersionManagerWithLogger creates a new version manager with a custom logger +func NewVersionManagerWithLogger(customLogger logger.Logger) *VersionManager { + execPath, _ := os.Executable() + execDir := filepath.Dir(execPath) + return &VersionManager{ + executableDir: execDir, + logger: customLogger, + } +} + +// createDefaultVersion creates a default version structure with v0.0.0 +func (vm *VersionManager) createDefaultVersion() *VersionInfo { + return &VersionInfo{ + MainVersion: "0.0.0.0", // Corresponds to v0.0.0 + VersionInfo: make(map[string]map[string][]string), + } +} + +// LoadVersionFromFile loads version information from resources/version.json with fallback handling +func (vm *VersionManager) LoadVersionFromFile() (*VersionInfo, error) { + versionPath := filepath.Join(vm.executableDir, "resources", "version.json") + + data, err := os.ReadFile(versionPath) + if err != nil { + if os.IsNotExist(err) { + vm.logger.Info("Version file not found at %s, will use default version", versionPath) + return vm.createDefaultVersion(), nil + } + vm.logger.Warn("Failed to read version file at %s: %v, will use default version", versionPath, err) + return vm.createDefaultVersion(), nil + } + + var versionInfo VersionInfo + if err := json.Unmarshal(data, &versionInfo); err != nil { + vm.logger.Warn("Failed to parse version file at %s: %v, will use default version", versionPath, err) + return vm.createDefaultVersion(), nil + } + + vm.logger.Debug("Successfully loaded version information from %s", versionPath) + return &versionInfo, nil +} + +// LoadVersionWithDefault loads version information with guaranteed fallback to default +func (vm *VersionManager) LoadVersionWithDefault() *VersionInfo { + versionInfo, err := vm.LoadVersionFromFile() + if err != nil { + // This should not happen with the updated LoadVersionFromFile, but adding as extra safety + vm.logger.Error("Unexpected error loading version file: %v, using default version", err) + return vm.createDefaultVersion() + } + + // Validate that we have a valid version structure + if versionInfo == nil { + vm.logger.Warn("Version info is nil, using default version") + return vm.createDefaultVersion() + } + + if versionInfo.MainVersion == "" { + vm.logger.Warn("Version info has empty main version, using default version") + return vm.createDefaultVersion() + } + + if versionInfo.VersionInfo == nil { + vm.logger.Debug("Version info map is nil, initializing empty map") + versionInfo.VersionInfo = make(map[string]map[string][]string) + } + + return versionInfo +} + +// ParseVersion parses a version string like "4.4.1.3" into components +func ParseVersion(versionStr string) (*ParsedVersion, error) { + parts := strings.Split(versionStr, ".") + if len(parts) < 3 || len(parts) > 4 { + return nil, fmt.Errorf("invalid version format: %s", versionStr) + } + + major, err := strconv.Atoi(parts[0]) + if err != nil { + return nil, fmt.Errorf("invalid major version: %s", parts[0]) + } + + minor, err := strconv.Atoi(parts[1]) + if err != nil { + return nil, fmt.Errorf("invalid minor version: %s", parts[1]) + } + + patch, err := strconv.Atoi(parts[2]) + if err != nil { + return nil, fmt.Errorf("invalid patch version: %s", parts[2]) + } + + beta := 0 + if len(parts) == 4 { + beta, err = strconv.Atoi(parts[3]) + if err != nil { + return nil, fmt.Errorf("invalid beta version: %s", parts[3]) + } + } + + return &ParsedVersion{ + Major: major, + Minor: minor, + Patch: patch, + Beta: beta, + }, nil +} + +// ToVersionString converts a ParsedVersion back to version string format +func (pv *ParsedVersion) ToVersionString() string { + if pv.Beta == 0 { + return fmt.Sprintf("%d.%d.%d.0", pv.Major, pv.Minor, pv.Patch) + } + return fmt.Sprintf("%d.%d.%d.%d", pv.Major, pv.Minor, pv.Patch, pv.Beta) +} + +// ToDisplayVersion converts version to display format (v4.4.0 or v4.4.1-beta3) +func (pv *ParsedVersion) ToDisplayVersion() string { + if pv.Beta == 0 { + return fmt.Sprintf("v%d.%d.%d", pv.Major, pv.Minor, pv.Patch) + } + return fmt.Sprintf("v%d.%d.%d-beta%d", pv.Major, pv.Minor, pv.Patch, pv.Beta) +} + +// GetChannel returns the channel (stable or beta) based on version +func (pv *ParsedVersion) GetChannel() string { + if pv.Beta == 0 { + return "stable" + } + return "beta" +} + +// GetDefaultChannel returns the default channel +func GetDefaultChannel() string { + return "stable" +} + +// IsNewer checks if this version is newer than the other version +func (pv *ParsedVersion) IsNewer(other *ParsedVersion) bool { + if pv.Major != other.Major { + return pv.Major > other.Major + } + if pv.Minor != other.Minor { + return pv.Minor > other.Minor + } + if pv.Patch != other.Patch { + return pv.Patch > other.Patch + } + return pv.Beta > other.Beta +} diff --git a/Go_Updater/version/manager_test.go b/Go_Updater/version/manager_test.go new file mode 100644 index 0000000..41a3f7c --- /dev/null +++ b/Go_Updater/version/manager_test.go @@ -0,0 +1,366 @@ +package version + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" +) + +func TestParseVersion(t *testing.T) { + tests := []struct { + input string + expected *ParsedVersion + hasError bool + }{ + {"4.4.0.0", &ParsedVersion{4, 4, 0, 0}, false}, + {"4.4.1.3", &ParsedVersion{4, 4, 1, 3}, false}, + {"1.2.3", &ParsedVersion{1, 2, 3, 0}, false}, + {"invalid", nil, true}, + {"1.2", nil, true}, + {"1.2.3.4.5", nil, true}, + } + + for _, test := range tests { + result, err := ParseVersion(test.input) + + if test.hasError { + if err == nil { + t.Errorf("Expected error for input %s, but got none", test.input) + } + continue + } + + if err != nil { + t.Errorf("Unexpected error for input %s: %v", test.input, err) + continue + } + + if result.Major != test.expected.Major || + result.Minor != test.expected.Minor || + result.Patch != test.expected.Patch || + result.Beta != test.expected.Beta { + t.Errorf("For input %s, expected %+v, got %+v", test.input, test.expected, result) + } + } +} + +func TestToDisplayVersion(t *testing.T) { + tests := []struct { + version *ParsedVersion + expected string + }{ + {&ParsedVersion{4, 4, 0, 0}, "v4.4.0"}, + {&ParsedVersion{4, 4, 1, 3}, "v4.4.1-beta3"}, + {&ParsedVersion{1, 2, 3, 0}, "v1.2.3"}, + {&ParsedVersion{1, 2, 3, 5}, "v1.2.3-beta5"}, + } + + for _, test := range tests { + result := test.version.ToDisplayVersion() + if result != test.expected { + t.Errorf("For version %+v, expected %s, got %s", test.version, test.expected, result) + } + } +} + +func TestGetChannel(t *testing.T) { + tests := []struct { + version *ParsedVersion + expected string + }{ + {&ParsedVersion{4, 4, 0, 0}, "stable"}, + {&ParsedVersion{4, 4, 1, 3}, "beta"}, + {&ParsedVersion{1, 2, 3, 0}, "stable"}, + {&ParsedVersion{1, 2, 3, 1}, "beta"}, + } + + for _, test := range tests { + result := test.version.GetChannel() + if result != test.expected { + t.Errorf("For version %+v, expected channel %s, got %s", test.version, test.expected, result) + } + } +} + +func TestIsNewer(t *testing.T) { + tests := []struct { + v1 *ParsedVersion + v2 *ParsedVersion + expected bool + }{ + {&ParsedVersion{4, 4, 1, 0}, &ParsedVersion{4, 4, 0, 0}, true}, + {&ParsedVersion{4, 4, 0, 0}, &ParsedVersion{4, 4, 1, 0}, false}, + {&ParsedVersion{4, 4, 1, 3}, &ParsedVersion{4, 4, 1, 2}, true}, + {&ParsedVersion{4, 4, 1, 2}, &ParsedVersion{4, 4, 1, 3}, false}, + {&ParsedVersion{4, 4, 1, 0}, &ParsedVersion{4, 4, 1, 0}, false}, + } + + for _, test := range tests { + result := test.v1.IsNewer(test.v2) + if result != test.expected { + t.Errorf("For %+v.IsNewer(%+v), expected %t, got %t", test.v1, test.v2, test.expected, result) + } + } +} + +func TestLoadVersionFromFile(t *testing.T) { + // Create a temporary directory + tempDir, err := os.MkdirTemp("", "version_test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDir) + + // Create resources directory + resourcesDir := filepath.Join(tempDir, "resources") + if err := os.MkdirAll(resourcesDir, 0755); err != nil { + t.Fatal(err) + } + + // Create test version file + versionData := VersionInfo{ + MainVersion: "4.4.1.3", + VersionInfo: map[string]map[string][]string{ + "4.4.1.3": { + "修复BUG": {"移除崩溃弹窗机制"}, + }, + }, + } + + data, err := json.Marshal(versionData) + if err != nil { + t.Fatal(err) + } + + versionFile := filepath.Join(resourcesDir, "version.json") + if err := os.WriteFile(versionFile, data, 0644); err != nil { + t.Fatal(err) + } + + // Create version manager with custom executable directory and logger + vm := NewVersionManager() + vm.executableDir = tempDir + + // Test loading version + result, err := vm.LoadVersionFromFile() + if err != nil { + t.Fatalf("Failed to load version: %v", err) + } + + if result.MainVersion != "4.4.1.3" { + t.Errorf("Expected main version 4.4.1.3, got %s", result.MainVersion) + } + + if len(result.VersionInfo) != 1 { + t.Errorf("Expected 1 version info entry, got %d", len(result.VersionInfo)) + } +} + +func TestLoadVersionFromFileNotFound(t *testing.T) { + // Create a temporary directory without version file + tempDir, err := os.MkdirTemp("", "version_test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDir) + + // Create version manager with custom executable directory and logger + vm := NewVersionManager() + vm.executableDir = tempDir + + // Test loading version (should now return default version instead of error) + result, err := vm.LoadVersionFromFile() + if err != nil { + t.Errorf("Expected no error with fallback mechanism, but got: %v", err) + } + + // Should return default version + if result.MainVersion != "0.0.0.0" { + t.Errorf("Expected default version 0.0.0.0, got %s", result.MainVersion) + } + + if result.VersionInfo == nil { + t.Error("Expected initialized VersionInfo map, got nil") + } +} + +func TestLoadVersionWithDefault(t *testing.T) { + // Create a temporary directory + tempDir, err := os.MkdirTemp("", "version_test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDir) + + // Create version manager with custom executable directory + vm := NewVersionManager() + vm.executableDir = tempDir + + // Test loading version with default (no file exists) + result := vm.LoadVersionWithDefault() + if result == nil { + t.Fatal("Expected non-nil result from LoadVersionWithDefault") + } + + if result.MainVersion != "0.0.0.0" { + t.Errorf("Expected default version 0.0.0.0, got %s", result.MainVersion) + } + + if result.VersionInfo == nil { + t.Error("Expected initialized VersionInfo map, got nil") + } +} + +func TestLoadVersionWithDefaultValidFile(t *testing.T) { + // Create a temporary directory + tempDir, err := os.MkdirTemp("", "version_test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDir) + + // Create resources directory + resourcesDir := filepath.Join(tempDir, "resources") + if err := os.MkdirAll(resourcesDir, 0755); err != nil { + t.Fatal(err) + } + + // Create test version file + versionData := VersionInfo{ + MainVersion: "4.4.1.3", + VersionInfo: map[string]map[string][]string{ + "4.4.1.3": { + "修复BUG": {"移除崩溃弹窗机制"}, + }, + }, + } + + data, err := json.Marshal(versionData) + if err != nil { + t.Fatal(err) + } + + versionFile := filepath.Join(resourcesDir, "version.json") + if err := os.WriteFile(versionFile, data, 0644); err != nil { + t.Fatal(err) + } + + // Create version manager with custom executable directory + vm := NewVersionManager() + vm.executableDir = tempDir + + // Test loading version with default (valid file exists) + result := vm.LoadVersionWithDefault() + if result == nil { + t.Fatal("Expected non-nil result from LoadVersionWithDefault") + } + + if result.MainVersion != "4.4.1.3" { + t.Errorf("Expected version 4.4.1.3, got %s", result.MainVersion) + } + + if len(result.VersionInfo) != 1 { + t.Errorf("Expected 1 version info entry, got %d", len(result.VersionInfo)) + } +} + +func TestLoadVersionFromFileCorrupted(t *testing.T) { + // Create a temporary directory + tempDir, err := os.MkdirTemp("", "version_test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDir) + + // Create resources directory + resourcesDir := filepath.Join(tempDir, "resources") + if err := os.MkdirAll(resourcesDir, 0755); err != nil { + t.Fatal(err) + } + + // Create corrupted version file + versionFile := filepath.Join(resourcesDir, "version.json") + if err := os.WriteFile(versionFile, []byte("invalid json content"), 0644); err != nil { + t.Fatal(err) + } + + // Create version manager with custom executable directory + vm := NewVersionManager() + vm.executableDir = tempDir + + // Test loading version (should return default version for corrupted file) + result, err := vm.LoadVersionFromFile() + if err != nil { + t.Errorf("Expected no error with fallback mechanism for corrupted file, but got: %v", err) + } + + // Should return default version + if result.MainVersion != "0.0.0.0" { + t.Errorf("Expected default version 0.0.0.0 for corrupted file, got %s", result.MainVersion) + } + + if result.VersionInfo == nil { + t.Error("Expected initialized VersionInfo map for corrupted file, got nil") + } +} + +func TestLoadVersionWithDefaultCorrupted(t *testing.T) { + // Create a temporary directory + tempDir, err := os.MkdirTemp("", "version_test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDir) + + // Create resources directory + resourcesDir := filepath.Join(tempDir, "resources") + if err := os.MkdirAll(resourcesDir, 0755); err != nil { + t.Fatal(err) + } + + // Create corrupted version file + versionFile := filepath.Join(resourcesDir, "version.json") + if err := os.WriteFile(versionFile, []byte("invalid json content"), 0644); err != nil { + t.Fatal(err) + } + + // Create version manager with custom executable directory + vm := NewVersionManager() + vm.executableDir = tempDir + + // Test loading version with default (corrupted file) + result := vm.LoadVersionWithDefault() + if result == nil { + t.Fatal("Expected non-nil result from LoadVersionWithDefault for corrupted file") + } + + if result.MainVersion != "0.0.0.0" { + t.Errorf("Expected default version 0.0.0.0 for corrupted file, got %s", result.MainVersion) + } + + if result.VersionInfo == nil { + t.Error("Expected initialized VersionInfo map for corrupted file, got nil") + } +} + +func TestCreateDefaultVersion(t *testing.T) { + vm := NewVersionManager() + + result := vm.createDefaultVersion() + if result == nil { + t.Fatal("Expected non-nil result from createDefaultVersion") + } + + if result.MainVersion != "0.0.0.0" { + t.Errorf("Expected default version 0.0.0.0, got %s", result.MainVersion) + } + + if result.VersionInfo == nil { + t.Error("Expected initialized VersionInfo map, got nil") + } + + if len(result.VersionInfo) != 0 { + t.Errorf("Expected empty VersionInfo map, got %d entries", len(result.VersionInfo)) + } +} \ No newline at end of file diff --git a/Go_Updater/version/version.go b/Go_Updater/version/version.go new file mode 100644 index 0000000..35389ce --- /dev/null +++ b/Go_Updater/version/version.go @@ -0,0 +1,41 @@ +package version + +import ( + "fmt" + "runtime" +) + +var ( + // Version is the current version of the application + Version = "1.0.0" + + // BuildTime is set during build time + BuildTime = "unknown" + + // GitCommit is set during build time + GitCommit = "unknown" + + // GoVersion is the Go version used to build + GoVersion = runtime.Version() +) + +// GetVersionInfo returns formatted version information +func GetVersionInfo() string { + return fmt.Sprintf("Version: %s\nBuild Time: %s\nGit Commit: %s\nGo Version: %s", + Version, BuildTime, GitCommit, GoVersion) +} + +// GetShortVersion returns just the version number +func GetShortVersion() string { + return Version +} + +// GetBuildInfo returns build-specific information +func GetBuildInfo() map[string]string { + return map[string]string{ + "version": Version, + "build_time": BuildTime, + "git_commit": GitCommit, + "go_version": GoVersion, + } +} \ No newline at end of file