refactor(updater): 重构 Go 版本更新器

- 更新项目名称为 AUTO_MAA_Go_Updater
- 重构代码结构,优化函数命名和逻辑
- 移除 CDK 相关的冗余代码
- 调整版本号为 git commit hash
- 更新构建配置和脚本
- 优化 API 客户端实现
This commit is contained in:
2025-07-22 21:51:58 +08:00
parent 747ad6387b
commit 6b646378b6
21 changed files with 887 additions and 1673 deletions

View File

@@ -9,7 +9,7 @@ BUILD_DIR := build
DIST_DIR := dist
# Go build flags
LDFLAGS := -s -w -X lightweight-updater/version.Version=$(VERSION) -X lightweight-updater/version.BuildTime=$(BUILD_TIME) -X lightweight-updater/version.GitCommit=$(GIT_COMMIT)
LDFLAGS := -s -w -X AUTO_MAA_Go_Updater/version.Version=$(VERSION) -X AUTO_MAA_Go_Updater/version.BuildTime=$(BUILD_TIME) -X AUTO_MAA_Go_Updater/version.GitCommit=$(GIT_COMMIT)
# Default target
.PHONY: all

View File

@@ -10,204 +10,140 @@ import (
"time"
)
// MirrorResponse represents the response from MirrorChyan API
// MirrorResponse 表示 MirrorChyan API 的响应结构
type MirrorResponse struct {
Code int `json:"code"`
Msg string `json:"msg"`
Data struct {
VersionName string `json:"version_name"`
VersionNumber int `json:"version_number"`
URL string `json:"url,omitempty"` // Only present when using CDK
SHA256 string `json:"sha256,omitempty"` // Only present when using CDK
URL string `json:"url,omitempty"`
SHA256 string `json:"sha256,omitempty"`
Channel string `json:"channel"`
OS string `json:"os"`
Arch string `json:"arch"`
UpdateType string `json:"update_type,omitempty"` // Only present when using CDK
UpdateType string `json:"update_type,omitempty"`
ReleaseNote string `json:"release_note"`
FileSize int64 `json:"filesize,omitempty"` // Only present when using CDK
CDKExpiredTime int64 `json:"cdk_expired_time,omitempty"` // Only present when using CDK
FileSize int64 `json:"filesize,omitempty"`
} `json:"data"`
}
// UpdateCheckParams represents parameters for update checking
// UpdateCheckParams 表示更新检查的参数
type UpdateCheckParams struct {
ResourceID string
CurrentVersion string
Channel string
CDK string
UserAgent string
}
// MirrorClient interface defines the methods for Mirror API client
// MirrorClient 定义 Mirror API 客户端的接口方法
type MirrorClient interface {
CheckUpdate(params UpdateCheckParams) (*MirrorResponse, error)
CheckUpdateLegacy(resourceID, currentVersion, cdk, userAgent string) (*MirrorResponse, error)
IsUpdateAvailable(response *MirrorResponse, currentVersion string) bool
GetOfficialDownloadURL(versionName string) string
GetDownloadURL(versionName string) string
}
// Client implements MirrorClient interface
// Client 实现 MirrorClient 接口
type Client struct {
httpClient *http.Client
baseURL string
downloadURL string
}
// NewClient creates a new Mirror API client
// NewClient 创建新的 Mirror API 客户端
func NewClient() *Client {
return &Client{
httpClient: &http.Client{
Timeout: 30 * time.Second,
},
baseURL: "https://mirrorchyan.com/api/resources",
downloadURL: "http://221.236.27.82:10197/d/AUTO_MAA",
}
}
// CheckUpdate calls MirrorChyan API to check for updates with new parameter structure
// CheckUpdate 调用 MirrorChyan API 检查更新
func (c *Client) CheckUpdate(params UpdateCheckParams) (*MirrorResponse, error) {
// Construct the API URL
// 构建 API URL
apiURL := fmt.Sprintf("%s/%s/latest", c.baseURL, params.ResourceID)
// Parse URL to add query parameters
// 解析 URL 并添加查询参数
u, err := url.Parse(apiURL)
if err != nil {
return nil, fmt.Errorf("failed to parse API URL: %w", err)
return nil, fmt.Errorf("解析 API URL 失败: %w", err)
}
// Add query parameters
// 添加查询参数
q := u.Query()
q.Set("current_version", params.CurrentVersion)
q.Set("channel", params.Channel)
q.Set("os", "") // Empty for cross-platform
q.Set("arch", "") // Empty for cross-platform
if params.CDK != "" {
q.Set("cdk", params.CDK)
}
q.Set("os", "") // 跨平台为空
q.Set("arch", "") // 跨平台为空
u.RawQuery = q.Encode()
// Create HTTP request
// 创建 HTTP 请求
req, err := http.NewRequest("GET", u.String(), nil)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
return nil, fmt.Errorf("创建 HTTP 请求失败: %w", err)
}
// Set User-Agent header
// 设置 User-Agent
if params.UserAgent != "" {
req.Header.Set("User-Agent", params.UserAgent)
} else {
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/138.0.0.0 Safari/537.36")
}
// Make HTTP request
// 发送 HTTP 请求
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to make HTTP request: %w", err)
return nil, fmt.Errorf("发送 HTTP 请求失败: %w", err)
}
defer resp.Body.Close()
// Check HTTP status code
// 检查 HTTP 状态码
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API returned non-200 status code: %d", resp.StatusCode)
return nil, fmt.Errorf("API 返回非 200 状态码: %d", resp.StatusCode)
}
// Read response body
// 读取响应体
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
return nil, fmt.Errorf("读取响应体失败: %w", err)
}
// Parse JSON response
// 解析 JSON 响应
var mirrorResp MirrorResponse
if err := json.Unmarshal(body, &mirrorResp); err != nil {
return nil, fmt.Errorf("failed to parse JSON response: %w", err)
return nil, fmt.Errorf("解析 JSON 响应失败: %w", err)
}
return &mirrorResp, nil
}
// CheckUpdateLegacy calls Mirror API to check for updates (legacy method for backward compatibility)
func (c *Client) CheckUpdateLegacy(resourceID, currentVersion, cdk, userAgent string) (*MirrorResponse, error) {
// Construct the API URL
apiURL := fmt.Sprintf("%s/%s/latest", c.baseURL, resourceID)
// Parse URL to add query parameters
u, err := url.Parse(apiURL)
if err != nil {
return nil, fmt.Errorf("failed to parse API URL: %w", err)
}
// Add query parameters
q := u.Query()
q.Set("current_version", currentVersion)
if cdk != "" {
q.Set("cdk", cdk)
}
u.RawQuery = q.Encode()
// Create HTTP request
req, err := http.NewRequest("GET", u.String(), nil)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
}
// Set User-Agent header
if userAgent != "" {
req.Header.Set("User-Agent", userAgent)
} else {
req.Header.Set("User-Agent", "LightweightUpdater/1.0")
}
// Make HTTP request
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to make HTTP request: %w", err)
}
defer resp.Body.Close()
// Check HTTP status code
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API returned non-200 status code: %d", resp.StatusCode)
}
// Read response body
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
// Parse JSON response
var mirrorResp MirrorResponse
if err := json.Unmarshal(body, &mirrorResp); err != nil {
return nil, fmt.Errorf("failed to parse JSON response: %w", err)
}
return &mirrorResp, nil
}
// IsUpdateAvailable compares current version with the latest version from API response
// IsUpdateAvailable 比较当前版本与 API 响应中的最新版本
func (c *Client) IsUpdateAvailable(response *MirrorResponse, currentVersion string) bool {
// Check if API response is successful
// 检查 API 响应是否成功
if response.Code != 0 {
return false
}
// Get latest version from response
// 从响应中获取最新版本
latestVersion := response.Data.VersionName
if latestVersion == "" {
return false
}
// Convert version formats for comparison
// 转换版本格式以便比较
currentVersionNormalized := c.normalizeVersionForComparison(currentVersion)
latestVersionNormalized := c.normalizeVersionForComparison(latestVersion)
// Compare versions using semantic version comparison
// 使用语义版本比较
return compareVersions(currentVersionNormalized, latestVersionNormalized) < 0
}
// normalizeVersionForComparison converts different version formats to comparable format
// normalizeVersionForComparison 将不同版本格式转换为可比较格式
func (c *Client) normalizeVersionForComparison(version string) string {
// Handle AUTO_MAA version format: "4.4.1.3" -> "v4.4.1-beta3"
// 处理 AUTO_MAA 版本格式: "4.4.1.3" -> "v4.4.1-beta3"
if !strings.HasPrefix(version, "v") && strings.Count(version, ".") == 3 {
parts := strings.Split(version, ".")
if len(parts) == 4 {
@@ -220,22 +156,22 @@ func (c *Client) normalizeVersionForComparison(version string) string {
}
}
// Return as-is if already in standard format
// 如果已经是标准格式则直接返回
return version
}
// compareVersions compares two semantic version strings
// Returns: -1 if v1 < v2, 0 if v1 == v2, 1 if v1 > v2
// compareVersions 比较两个语义版本字符串
// 返回值: -1 如果 v1 < v2, 0 如果 v1 == v2, 1 如果 v1 > v2
func compareVersions(v1, v2 string) int {
// Normalize versions by removing 'v' prefix if present
// 通过移除 'v' 前缀来标准化版本
v1 = normalizeVersion(v1)
v2 = normalizeVersion(v2)
// Parse version components
// 解析版本组件
parts1 := parseVersionParts(v1)
parts2 := parseVersionParts(v2)
// Compare each component
// 比较每个组件
maxLen := len(parts1)
if len(parts2) > maxLen {
maxLen = len(parts2)
@@ -260,7 +196,7 @@ func compareVersions(v1, v2 string) int {
return 0
}
// normalizeVersion removes 'v' prefix and handles common version formats
// normalizeVersion 移除 'v' 前缀并处理常见版本格式
func normalizeVersion(version string) string {
if len(version) > 0 && (version[0] == 'v' || version[0] == 'V') {
return version[1:]
@@ -268,7 +204,7 @@ func normalizeVersion(version string) string {
return version
}
// parseVersionParts parses version string into numeric components
// parseVersionParts 将版本字符串解析为数字组件
func parseVersionParts(version string) []int {
if version == "" {
return []int{0}
@@ -284,15 +220,15 @@ func parseVersionParts(version string) []int {
parts = append(parts, current)
current = 0
} else {
// Stop parsing at non-numeric, non-dot characters (like pre-release identifiers)
// 在非数字、非点字符处停止解析(如预发布标识符)
break
}
}
// Add the last component
// 添加最后一个组件
parts = append(parts, current)
// Ensure at least 3 components (major.minor.patch)
// 确保至少有 3 个组件 (major.minor.patch)
for len(parts) < 3 {
parts = append(parts, 0)
}
@@ -300,33 +236,17 @@ func parseVersionParts(version string) []int {
return parts
}
// GetOfficialDownloadURL generates the official download URL based on version name
func (c *Client) GetOfficialDownloadURL(versionName string) string {
// Official download site base URL
baseURL := "http://221.236.27.82:10197/d/AUTO_MAA"
// Convert version name to filename format
// e.g., "v4.4.0" -> "AUTO_MAA_v4.4.0.zip"
// e.g., "v4.4.1-beta3" -> "AUTO_MAA_v4.4.1-beta.3.zip"
// GetDownloadURL 根据版本名生成下载站的下载 URL
func (c *Client) GetDownloadURL(versionName string) string {
// 将版本名转换为文件名格式
// 例如: "v4.4.0" -> "AUTO_MAA_v4.4.0.zip"
// 例如: "v4.4.1-beta3" -> "AUTO_MAA_v4.4.1-beta.3.zip"
filename := fmt.Sprintf("AUTO_MAA_%s.zip", versionName)
// Handle beta versions: convert "beta3" to "beta.3"
// 处理 beta 版本: 将 "beta3" 转换为 "beta.3"
if strings.Contains(filename, "-beta") && !strings.Contains(filename, "-beta.") {
filename = strings.Replace(filename, "-beta", "-beta.", 1)
}
return fmt.Sprintf("%s/%s", baseURL, filename)
}
// HasCDKDownloadURL checks if the response contains a CDK download URL
func (c *Client) HasCDKDownloadURL(response *MirrorResponse) bool {
return response != nil && response.Data.URL != ""
}
// GetDownloadURL returns the appropriate download URL based on available options
func (c *Client) GetDownloadURL(response *MirrorResponse) string {
if c.HasCDKDownloadURL(response) {
return response.Data.URL
}
return c.GetOfficialDownloadURL(response.Data.VersionName)
return fmt.Sprintf("%s/%s", c.downloadURL, filename)
}

View File

@@ -10,17 +10,20 @@ import (
func TestNewClient(t *testing.T) {
client := NewClient()
if client == nil {
t.Fatal("NewClient() returned nil")
t.Fatal("NewClient() 返回 nil")
}
if client.httpClient == nil {
t.Fatal("HTTP client is nil")
t.Fatal("HTTP 客户端为 nil")
}
if client.baseURL != "https://mirrorchyan.com/api/resources" {
t.Errorf("Expected base URL 'https://mirrorchyan.com/api/resources', got '%s'", client.baseURL)
t.Errorf("期望基础 URL 'https://mirrorchyan.com/api/resources',得到 '%s'", client.baseURL)
}
if client.downloadURL != "http://221.236.27.82:10197/d/AUTO_MAA" {
t.Errorf("期望下载 URL 'http://221.236.27.82:10197/d/AUTO_MAA',得到 '%s'", client.downloadURL)
}
}
func TestGetOfficialDownloadURL(t *testing.T) {
func TestGetDownloadURL(t *testing.T) {
client := NewClient()
tests := []struct {
@@ -30,51 +33,19 @@ func TestGetOfficialDownloadURL(t *testing.T) {
{"v4.4.0", "http://221.236.27.82:10197/d/AUTO_MAA/AUTO_MAA_v4.4.0.zip"},
{"v4.4.1-beta3", "http://221.236.27.82:10197/d/AUTO_MAA/AUTO_MAA_v4.4.1-beta.3.zip"},
{"v1.2.3", "http://221.236.27.82:10197/d/AUTO_MAA/AUTO_MAA_v1.2.3.zip"},
{"v1.2.3-beta1", "http://221.236.27.82:10197/d/AUTO_MAA/AUTO_MAA_v1.2.3-beta.1.zip"},
}
for _, test := range tests {
result := client.GetOfficialDownloadURL(test.versionName)
result := client.GetDownloadURL(test.versionName)
if result != test.expected {
t.Errorf("For version %s, expected %s, got %s", test.versionName, test.expected, result)
}
}
}
func TestNormalizeVersionForComparison(t *testing.T) {
client := NewClient()
tests := []struct {
input string
expected string
}{
{"4.4.0.0", "v4.4.0"},
{"4.4.1.3", "v4.4.1-beta3"},
{"v4.4.0", "v4.4.0"},
{"v4.4.1-beta3", "v4.4.1-beta3"},
{"1.2.3", "1.2.3"}, // Not 4-part version, return as-is
}
for _, test := range tests {
result := client.normalizeVersionForComparison(test.input)
if result != test.expected {
t.Errorf("For input %s, expected %s, got %s", test.input, test.expected, result)
t.Errorf("版本 %s期望 %s得到 %s", test.versionName, test.expected, result)
}
}
}
func TestCheckUpdate(t *testing.T) {
// Create test server
// 创建测试服务器
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify request parameters
if r.URL.Query().Get("current_version") != "4.4.0.0" {
t.Errorf("Expected current_version=4.4.0.0, got %s", r.URL.Query().Get("current_version"))
}
if r.URL.Query().Get("channel") != "stable" {
t.Errorf("Expected channel=stable, got %s", r.URL.Query().Get("channel"))
}
// Return mock response
response := MirrorResponse{
Code: 0,
Msg: "success",
@@ -89,125 +60,47 @@ func TestCheckUpdate(t *testing.T) {
UpdateType string `json:"update_type,omitempty"`
ReleaseNote string `json:"release_note"`
FileSize int64 `json:"filesize,omitempty"`
CDKExpiredTime int64 `json:"cdk_expired_time,omitempty"`
}{
VersionName: "v4.4.1",
VersionNumber: 48,
Channel: "stable",
OS: "",
Arch: "",
ReleaseNote: "Test release notes",
ReleaseNote: "测试发布说明",
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
err := json.NewEncoder(w).Encode(response)
if err != nil {
return
}
}))
defer server.Close()
// Create client with test server URL
// 使用测试服务器 URL 创建客户端
client := &Client{
httpClient: &http.Client{},
baseURL: server.URL,
downloadURL: "http://221.236.27.82:10197/d/AUTO_MAA",
}
// Test update check
// 测试更新检查
params := UpdateCheckParams{
ResourceID: "AUTO_MAA",
CurrentVersion: "4.4.0.0",
Channel: "stable",
CDK: "",
UserAgent: "TestAgent/1.0",
}
response, err := client.CheckUpdate(params)
if err != nil {
t.Fatalf("CheckUpdate failed: %v", err)
t.Fatalf("CheckUpdate 失败: %v", err)
}
if response.Code != 0 {
t.Errorf("Expected code 0, got %d", response.Code)
t.Errorf("期望代码 0得到 %d", response.Code)
}
if response.Data.VersionName != "v4.4.1" {
t.Errorf("Expected version v4.4.1, got %s", response.Data.VersionName)
}
if response.Data.Channel != "stable" {
t.Errorf("Expected channel stable, got %s", response.Data.Channel)
}
}
func TestCheckUpdateWithCDK(t *testing.T) {
// Create test server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify CDK parameter
if r.URL.Query().Get("cdk") != "test_cdk_123" {
t.Errorf("Expected cdk=test_cdk_123, got %s", r.URL.Query().Get("cdk"))
}
// Return mock response with CDK download URL
response := MirrorResponse{
Code: 0,
Msg: "success",
Data: struct {
VersionName string `json:"version_name"`
VersionNumber int `json:"version_number"`
URL string `json:"url,omitempty"`
SHA256 string `json:"sha256,omitempty"`
Channel string `json:"channel"`
OS string `json:"os"`
Arch string `json:"arch"`
UpdateType string `json:"update_type,omitempty"`
ReleaseNote string `json:"release_note"`
FileSize int64 `json:"filesize,omitempty"`
CDKExpiredTime int64 `json:"cdk_expired_time,omitempty"`
}{
VersionName: "v4.4.1",
VersionNumber: 48,
URL: "https://mirrorchyan.com/api/resources/download/test123",
SHA256: "abcd1234",
Channel: "stable",
OS: "",
Arch: "",
UpdateType: "full",
ReleaseNote: "Test release notes",
FileSize: 12345678,
CDKExpiredTime: 1776013593,
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}))
defer server.Close()
// Create client with test server URL
client := &Client{
httpClient: &http.Client{},
baseURL: server.URL,
}
// Test update check with CDK
params := UpdateCheckParams{
ResourceID: "AUTO_MAA",
CurrentVersion: "4.4.0.0",
Channel: "stable",
CDK: "test_cdk_123",
UserAgent: "TestAgent/1.0",
}
response, err := client.CheckUpdate(params)
if err != nil {
t.Fatalf("CheckUpdate with CDK failed: %v", err)
}
if response.Data.URL == "" {
t.Error("Expected CDK download URL, but got empty")
}
if response.Data.SHA256 == "" {
t.Error("Expected SHA256 hash, but got empty")
}
if response.Data.FileSize == 0 {
t.Error("Expected file size, but got 0")
t.Errorf("期望版本 v4.4.1,得到 %s", response.Data.VersionName)
}
}
@@ -221,7 +114,7 @@ func TestIsUpdateAvailable(t *testing.T) {
expected bool
}{
{
name: "Update available - stable",
name: "有可用更新",
response: &MirrorResponse{
Code: 0,
Data: struct {
@@ -235,14 +128,13 @@ func TestIsUpdateAvailable(t *testing.T) {
UpdateType string `json:"update_type,omitempty"`
ReleaseNote string `json:"release_note"`
FileSize int64 `json:"filesize,omitempty"`
CDKExpiredTime int64 `json:"cdk_expired_time,omitempty"`
}{VersionName: "v4.4.1"},
},
currentVersion: "4.4.0.0",
expected: true,
},
{
name: "No update available - same version",
name: "无可用更新",
response: &MirrorResponse{
Code: 0,
Data: struct {
@@ -256,167 +148,18 @@ func TestIsUpdateAvailable(t *testing.T) {
UpdateType string `json:"update_type,omitempty"`
ReleaseNote string `json:"release_note"`
FileSize int64 `json:"filesize,omitempty"`
CDKExpiredTime int64 `json:"cdk_expired_time,omitempty"`
}{VersionName: "v4.4.0"},
},
currentVersion: "4.4.0.0",
expected: false,
},
{
name: "API error",
response: &MirrorResponse{
Code: 1,
Data: struct {
VersionName string `json:"version_name"`
VersionNumber int `json:"version_number"`
URL string `json:"url,omitempty"`
SHA256 string `json:"sha256,omitempty"`
Channel string `json:"channel"`
OS string `json:"os"`
Arch string `json:"arch"`
UpdateType string `json:"update_type,omitempty"`
ReleaseNote string `json:"release_note"`
FileSize int64 `json:"filesize,omitempty"`
CDKExpiredTime int64 `json:"cdk_expired_time,omitempty"`
}{VersionName: "v4.4.1"},
},
currentVersion: "4.4.0.0",
expected: false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
result := client.IsUpdateAvailable(test.response, test.currentVersion)
if result != test.expected {
t.Errorf("Expected %t, got %t", test.expected, result)
}
})
}
}
func TestHasCDKDownloadURL(t *testing.T) {
client := NewClient()
tests := []struct {
name string
response *MirrorResponse
expected bool
}{
{
name: "Has CDK URL",
response: &MirrorResponse{
Data: struct {
VersionName string `json:"version_name"`
VersionNumber int `json:"version_number"`
URL string `json:"url,omitempty"`
SHA256 string `json:"sha256,omitempty"`
Channel string `json:"channel"`
OS string `json:"os"`
Arch string `json:"arch"`
UpdateType string `json:"update_type,omitempty"`
ReleaseNote string `json:"release_note"`
FileSize int64 `json:"filesize,omitempty"`
CDKExpiredTime int64 `json:"cdk_expired_time,omitempty"`
}{URL: "https://mirrorchyan.com/download/test"},
},
expected: true,
},
{
name: "No CDK URL",
response: &MirrorResponse{
Data: struct {
VersionName string `json:"version_name"`
VersionNumber int `json:"version_number"`
URL string `json:"url,omitempty"`
SHA256 string `json:"sha256,omitempty"`
Channel string `json:"channel"`
OS string `json:"os"`
Arch string `json:"arch"`
UpdateType string `json:"update_type,omitempty"`
ReleaseNote string `json:"release_note"`
FileSize int64 `json:"filesize,omitempty"`
CDKExpiredTime int64 `json:"cdk_expired_time,omitempty"`
}{URL: ""},
},
expected: false,
},
{
name: "Nil response",
response: nil,
expected: false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
result := client.HasCDKDownloadURL(test.response)
if result != test.expected {
t.Errorf("Expected %t, got %t", test.expected, result)
}
})
}
}
func TestGetDownloadURL(t *testing.T) {
client := NewClient()
tests := []struct {
name string
response *MirrorResponse
expected string
}{
{
name: "CDK URL available",
response: &MirrorResponse{
Data: struct {
VersionName string `json:"version_name"`
VersionNumber int `json:"version_number"`
URL string `json:"url,omitempty"`
SHA256 string `json:"sha256,omitempty"`
Channel string `json:"channel"`
OS string `json:"os"`
Arch string `json:"arch"`
UpdateType string `json:"update_type,omitempty"`
ReleaseNote string `json:"release_note"`
FileSize int64 `json:"filesize,omitempty"`
CDKExpiredTime int64 `json:"cdk_expired_time,omitempty"`
}{
VersionName: "v4.4.1",
URL: "https://mirrorchyan.com/download/test",
},
},
expected: "https://mirrorchyan.com/download/test",
},
{
name: "Official URL fallback",
response: &MirrorResponse{
Data: struct {
VersionName string `json:"version_name"`
VersionNumber int `json:"version_number"`
URL string `json:"url,omitempty"`
SHA256 string `json:"sha256,omitempty"`
Channel string `json:"channel"`
OS string `json:"os"`
Arch string `json:"arch"`
UpdateType string `json:"update_type,omitempty"`
ReleaseNote string `json:"release_note"`
FileSize int64 `json:"filesize,omitempty"`
CDKExpiredTime int64 `json:"cdk_expired_time,omitempty"`
}{
VersionName: "v4.4.1",
URL: "",
},
},
expected: "http://221.236.27.82:10197/d/AUTO_MAA/AUTO_MAA_v4.4.1.zip",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
result := client.GetDownloadURL(test.response)
if result != test.expected {
t.Errorf("Expected %s, got %s", test.expected, result)
t.Errorf("期望 %t得到 %t", test.expected, result)
}
})
}

View File

@@ -8,17 +8,17 @@ import (
//go:embed config_template.yaml
var EmbeddedAssets embed.FS
// GetConfigTemplate returns the embedded config template
// GetConfigTemplate 返回嵌入的配置模板
func GetConfigTemplate() ([]byte, error) {
return EmbeddedAssets.ReadFile("config_template.yaml")
}
// GetAssetFS returns the embedded filesystem
// GetAssetFS 返回嵌入的文件系统
func GetAssetFS() fs.FS {
return EmbeddedAssets
}
// ListAssets returns a list of all embedded assets
// ListAssets 返回所有嵌入资源的列表
func ListAssets() ([]string, error) {
var assets []string
err := fs.WalkDir(EmbeddedAssets, ".", func(path string, d fs.DirEntry, err error) error {

View File

@@ -1,6 +1,5 @@
resource_id: "AUTO_MAA"
current_version: "v1.0.0"
cdk: "" # Will be encrypted when saved
user_agent: "AUTO_MAA_Go_Updater/1.0"
backup_url: "https://backup-download-site.com/releases"
log_level: "info"

View File

@@ -1,9 +1,9 @@
# Build Configuration for Lightweight Updater
# Build Configuration for AUTO_MAA_Go_Updater
project:
name: "Lightweight Updater"
module: "lightweight-updater"
description: "轻量级自动更新器"
name: "AUTO_MAA_Go_Updater"
module: "AUTO_MAA_Go_Updater"
description: "AUTO_MAA_Go版本更新器"
version:
default: "1.0.0"
@@ -14,7 +14,7 @@ targets:
goos: "windows"
goarch: "amd64"
cgo_enabled: true
output: "lightweight-updater.exe"
output: "AUTO_MAA_Go_Updater.exe"
build:
flags:
@@ -40,7 +40,7 @@ directories:
temp: "temp"
version_injection:
package: "lightweight-updater/version"
package: "AUTO_MAA_Go_Updater/version"
variables:
- name: "Version"
source: "version"

View File

@@ -6,14 +6,13 @@ echo AUTO_MAA_Go_Updater Build Script
echo ========================================
:: Set build variables
set VERSION=1.0.0
set OUTPUT_NAME=AUTO_MAA_Go_Updater.exe
set BUILD_DIR=build
set DIST_DIR=dist
:: Get current timestamp
:: Get current datetime for build time
for /f "tokens=2 delims==" %%a in ('wmic OS Get localdatetime /value') do set "dt=%%a"
set "YY=%dt:~2,2%" & set "YYYY=%dt:~0,4%" & set "MM=%dt:~4,2%" & set "DD=%dt:~6,2%"
set "YYYY=%dt:~0,4%" & set "MM=%dt:~4,2%" & set "DD=%dt:~6,2%"
set "HH=%dt:~8,2%" & set "Min=%dt:~10,2%" & set "Sec=%dt:~12,2%"
set "BUILD_TIME=%YYYY%-%MM%-%DD%T%HH%:%Min%:%Sec%Z"
@@ -26,6 +25,9 @@ if exist temp_commit.txt (
set GIT_COMMIT=unknown
)
:: Use commit hash as version
set VERSION=%GIT_COMMIT%
echo Build Information:
echo - Version: %VERSION%
echo - Build Time: %BUILD_TIME%
@@ -38,7 +40,7 @@ if not exist %BUILD_DIR% mkdir %BUILD_DIR%
if not exist %DIST_DIR% mkdir %DIST_DIR%
:: Set build flags
set LDFLAGS=-s -w -X lightweight-updater/version.Version=%VERSION% -X lightweight-updater/version.BuildTime=%BUILD_TIME% -X lightweight-updater/version.GitCommit=%GIT_COMMIT%
set LDFLAGS=-s -w -X AUTO_MAA_Go_Updater/version.Version=%VERSION% -X AUTO_MAA_Go_Updater/version.BuildTime=%BUILD_TIME% -X AUTO_MAA_Go_Updater/version.GitCommit=%GIT_COMMIT%
echo Building application...
@@ -58,6 +60,7 @@ if not exist app.syso (
)
)
:: Set environment variables for Go build
set GOOS=windows
set GOARCH=amd64
set CGO_ENABLED=1
@@ -74,8 +77,6 @@ echo Build completed successfully!
:: Get file size
for %%A in (%BUILD_DIR%\%OUTPUT_NAME%) do set FILE_SIZE=%%~zA
:: Convert bytes to MB
set /a FILE_SIZE_MB=%FILE_SIZE%/1024/1024
echo.
@@ -83,13 +84,6 @@ echo Build Results:
echo - Output: %BUILD_DIR%\%OUTPUT_NAME%
echo - Size: %FILE_SIZE% bytes (~%FILE_SIZE_MB% MB)
:: Check if file size is within requirements (<10MB)
if %FILE_SIZE_MB% gtr 10 (
echo WARNING: File size exceeds 10MB requirement!
) else (
echo File size meets requirements (^<10MB)
)
:: Copy to dist directory
copy %BUILD_DIR%\%OUTPUT_NAME% %DIST_DIR%\%OUTPUT_NAME% >nul
echo - Copied to: %DIST_DIR%\%OUTPUT_NAME%

View File

@@ -1,6 +1,5 @@
# Lightweight Updater Build Script (PowerShell)
# AUTO_MAA_Go_Updater Build Script (PowerShell)
param(
[string]$Version = "1.0.0",
[string]$OutputName = "AUTO_MAA_Go_Updater.exe",
[switch]$Compress = $false
)
@@ -14,6 +13,7 @@ $BuildDir = "build"
$DistDir = "dist"
$BuildTime = (Get-Date).ToString("yyyy-MM-ddTHH:mm:ssZ")
# Get git commit hash
try {
$GitCommit = (git rev-parse --short HEAD 2>$null).Trim()
@@ -23,7 +23,7 @@ try {
}
Write-Host "Build Information:" -ForegroundColor Yellow
Write-Host "- Version: $Version"
Write-Host "- Version: $GitCommit"
Write-Host "- Build Time: $BuildTime"
Write-Host "- Git Commit: $GitCommit"
Write-Host "- Target: Windows 64-bit"
@@ -39,7 +39,7 @@ $env:GOARCH = "amd64"
$env:CGO_ENABLED = "1"
# Set build flags
$LdFlags = "-s -w -X lightweight-updater/version.Version=$Version -X lightweight-updater/version.BuildTime=$BuildTime -X lightweight-updater/version.GitCommit=$GitCommit"
$LdFlags = "-s -w -X AUTO_MAA_Go_Updater/version.Version=$Version -X AUTO_MAA_Go_Updater/version.BuildTime=$BuildTime -X AUTO_MAA_Go_Updater/version.GitCommit=$GitCommit"
Write-Host "Building application..." -ForegroundColor Green
@@ -78,12 +78,6 @@ Write-Host "Build Results:" -ForegroundColor Yellow
Write-Host "- Output: $($OutputFile.FullName)"
Write-Host "- Size: $($OutputFile.Length) bytes (~$FileSizeMB MB)"
# Check file size requirement
if ($FileSizeMB -gt 10) {
Write-Host "WARNING: File size exceeds 10MB requirement!" -ForegroundColor Red
} else {
Write-Host "File size meets requirements (<10MB)" -ForegroundColor Green
}
# Optional UPX compression
if ($Compress) {

View File

@@ -1,40 +1,38 @@
package config
import (
"encoding/base64"
"fmt"
"os"
"path/filepath"
"AUTO_MAA_Go_Updater/assets"
"gopkg.in/yaml.v3"
"lightweight-updater/assets"
)
// Config represents the application configuration
// Config 表示应用程序配置
type Config struct {
ResourceID string `yaml:"resource_id"`
CurrentVersion string `yaml:"current_version"`
CDK string `yaml:"cdk,omitempty"`
UserAgent string `yaml:"user_agent"`
BackupURL string `yaml:"backup_url"`
LogLevel string `yaml:"log_level"`
AutoCheck bool `yaml:"auto_check"`
CheckInterval int `yaml:"check_interval"` // seconds
CheckInterval int `yaml:"check_interval"` //
}
// ConfigManager interface defines methods for configuration management
// ConfigManager 定义配置管理的接口方法
type ConfigManager interface {
Load() (*Config, error)
Save(config *Config) error
GetConfigPath() string
}
// DefaultConfigManager implements ConfigManager interface
// DefaultConfigManager 实现 ConfigManager 接口
type DefaultConfigManager struct {
configPath string
}
// NewConfigManager creates a new configuration manager
// NewConfigManager 创建新的配置管理器
func NewConfigManager() ConfigManager {
configDir := getConfigDir()
configPath := filepath.Join(configDir, "config.yaml")
@@ -43,77 +41,77 @@ func NewConfigManager() ConfigManager {
}
}
// GetConfigPath returns the path to the configuration file
// GetConfigPath 返回配置文件的路径
func (cm *DefaultConfigManager) GetConfigPath() string {
return cm.configPath
}
// Load reads and parses the configuration file
// Load 读取并解析配置文件
func (cm *DefaultConfigManager) Load() (*Config, error) {
// Create config directory if it doesn't exist
// 如果配置目录不存在则创建
configDir := filepath.Dir(cm.configPath)
if err := os.MkdirAll(configDir, 0755); err != nil {
return nil, fmt.Errorf("failed to create config directory: %w", err)
return nil, fmt.Errorf("创建配置目录失败: %w", err)
}
// If config file doesn't exist, create default config
// 如果配置文件不存在,创建默认配置
if _, err := os.Stat(cm.configPath); os.IsNotExist(err) {
defaultConfig := getDefaultConfig()
if err := cm.Save(defaultConfig); err != nil {
return nil, fmt.Errorf("failed to create default config: %w", err)
return nil, fmt.Errorf("创建默认配置失败: %w", err)
}
return defaultConfig, nil
}
// Read existing config file
// 读取现有配置文件
data, err := os.ReadFile(cm.configPath)
if err != nil {
return nil, fmt.Errorf("failed to read config file: %w", err)
return nil, fmt.Errorf("读取配置文件失败: %w", err)
}
var config Config
if err := yaml.Unmarshal(data, &config); err != nil {
return nil, fmt.Errorf("failed to parse config file: %w", err)
return nil, fmt.Errorf("解析配置文件失败: %w", err)
}
// Validate and apply defaults for missing fields
// 验证并应用缺失字段的默认值
if err := validateAndApplyDefaults(&config); err != nil {
return nil, fmt.Errorf("config validation failed: %w", err)
return nil, fmt.Errorf("配置验证失败: %w", err)
}
return &config, nil
}
// Save writes the configuration to file
// Save 将配置写入文件
func (cm *DefaultConfigManager) Save(config *Config) error {
// Validate config before saving
// 保存前验证配置
if err := validateConfig(config); err != nil {
return fmt.Errorf("config validation failed: %w", err)
return fmt.Errorf("配置验证失败: %w", err)
}
// Create config directory if it doesn't exist
// 如果配置目录不存在则创建
configDir := filepath.Dir(cm.configPath)
if err := os.MkdirAll(configDir, 0755); err != nil {
return fmt.Errorf("failed to create config directory: %w", err)
return fmt.Errorf("创建配置目录失败: %w", err)
}
// Marshal config to YAML
// 将配置序列化为 YAML
data, err := yaml.Marshal(config)
if err != nil {
return fmt.Errorf("failed to marshal config: %w", err)
return fmt.Errorf("序列化配置失败: %w", err)
}
// Write to file
// 写入文件
if err := os.WriteFile(cm.configPath, data, 0644); err != nil {
return fmt.Errorf("failed to write config file: %w", err)
return fmt.Errorf("写入配置文件失败: %w", err)
}
return nil
}
// getDefaultConfig returns a configuration with default values
// getDefaultConfig 返回带有默认值的配置
func getDefaultConfig() *Config {
// Try to load from embedded template first
// 首先尝试从嵌入模板加载
if templateData, err := assets.GetConfigTemplate(); err == nil {
var config Config
if err := yaml.Unmarshal(templateData, &config); err == nil {
@@ -121,35 +119,34 @@ func getDefaultConfig() *Config {
}
}
// Fallback to hardcoded defaults if template loading fails
// 如果模板加载失败则回退到硬编码默认值
return &Config{
ResourceID: "M9A", // Default resource ID
ResourceID: "M9A", // 默认资源 ID
CurrentVersion: "v1.0.0",
CDK: "",
UserAgent: "LightweightUpdater/1.0",
UserAgent: "AUTO_MAA_Go_Updater/1.0",
BackupURL: "",
LogLevel: "info",
AutoCheck: true,
CheckInterval: 3600, // 1 hour
CheckInterval: 3600, // 1 小时
}
}
// validateConfig validates the configuration values
// validateConfig 验证配置值
func validateConfig(config *Config) error {
if config == nil {
return fmt.Errorf("config cannot be nil")
return fmt.Errorf("配置不能为空")
}
if config.ResourceID == "" {
return fmt.Errorf("resource_id cannot be empty")
return fmt.Errorf("resource_id 不能为空")
}
if config.CurrentVersion == "" {
return fmt.Errorf("current_version cannot be empty")
return fmt.Errorf("current_version 不能为空")
}
if config.UserAgent == "" {
return fmt.Errorf("user_agent cannot be empty")
return fmt.Errorf("user_agent 不能为空")
}
validLogLevels := map[string]bool{
@@ -159,21 +156,21 @@ func validateConfig(config *Config) error {
"error": true,
}
if !validLogLevels[config.LogLevel] {
return fmt.Errorf("invalid log_level: %s (must be debug, info, warn, or error)", config.LogLevel)
return fmt.Errorf("无效的 log_level: %s (必须是 debug, info, warn error)", config.LogLevel)
}
if config.CheckInterval < 60 {
return fmt.Errorf("check_interval must be at least 60 seconds")
return fmt.Errorf("check_interval 必须至少为 60 秒")
}
return nil
}
// validateAndApplyDefaults validates config and applies defaults for missing fields
// validateAndApplyDefaults 验证配置并为缺失字段应用默认值
func validateAndApplyDefaults(config *Config) error {
defaults := getDefaultConfig()
// Apply defaults for empty fields
// 为空字段应用默认值
if config.UserAgent == "" {
config.UserAgent = defaults.UserAgent
}
@@ -187,62 +184,15 @@ func validateAndApplyDefaults(config *Config) error {
config.CurrentVersion = defaults.CurrentVersion
}
// Validate after applying defaults
// 应用默认值后进行验证
return validateConfig(config)
}
// getConfigDir returns the configuration directory path
// getConfigDir 返回配置目录路径
func getConfigDir() string {
// Use APPDATA on Windows, fallback to current directory
// 在 Windows 上使用 APPDATA回退到当前目录
if appData := os.Getenv("APPDATA"); appData != "" {
return filepath.Join(appData, "LightweightUpdater")
return filepath.Join(appData, "AUTO_MAA_Go_Updater")
}
return "."
}
// encryptCDK encrypts the CDK using XOR encryption with a static key
func encryptCDK(cdk string) string {
if cdk == "" {
return ""
}
key := []byte("updater-key-2024")
encrypted := make([]byte, len(cdk))
for i, b := range []byte(cdk) {
encrypted[i] = b ^ key[i%len(key)]
}
return base64.StdEncoding.EncodeToString(encrypted)
}
// decryptCDK decrypts the CDK using XOR decryption with a static key
func decryptCDK(encryptedCDK string) (string, error) {
if encryptedCDK == "" {
return "", nil
}
encrypted, err := base64.StdEncoding.DecodeString(encryptedCDK)
if err != nil {
return "", fmt.Errorf("failed to decode encrypted CDK: %w", err)
}
key := []byte("updater-key-2024")
decrypted := make([]byte, len(encrypted))
for i, b := range encrypted {
decrypted[i] = b ^ key[i%len(key)]
}
return string(decrypted), nil
}
// SetCDK sets the CDK in the config with encryption
func (c *Config) SetCDK(cdk string) {
c.CDK = encryptCDK(cdk)
}
// GetCDK returns the decrypted CDK from the config
func (c *Config) GetCDK() (string, error) {
return decryptCDK(c.CDK)
}

View File

@@ -44,7 +44,6 @@
},
"Update": {
"IfAutoUpdate": false,
"MirrorChyanCDK": "",
"ProxyUrlList": [],
"ThreadNumb": 8,
"UpdateType": "stable"

View File

@@ -3,163 +3,55 @@ package config
import (
"os"
"path/filepath"
"strings"
"testing"
)
func TestEncryptDecryptCDK(t *testing.T) {
tests := []struct {
name string
original string
}{
{
name: "Empty CDK",
original: "",
},
{
name: "Simple CDK",
original: "test123",
},
{
name: "Complex CDK",
original: "ABC123-DEF456-GHI789",
},
{
name: "CDK with special characters",
original: "test@#$%^&*()_+-={}[]|\\:;\"'<>?,./",
},
{
name: "Long CDK",
original: "this-is-a-very-long-cdk-key-that-should-still-work-properly-with-encryption-and-decryption",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Test encryption
encrypted := encryptCDK(tt.original)
// Empty string should remain empty
if tt.original == "" {
if encrypted != "" {
t.Errorf("Expected empty string for empty input, got %s", encrypted)
}
return
}
// Encrypted should be different from original (unless original is empty)
if encrypted == tt.original {
t.Errorf("Encrypted CDK should be different from original")
}
// Test decryption
decrypted, err := decryptCDK(encrypted)
if err != nil {
t.Errorf("Decryption failed: %v", err)
}
// Decrypted should match original
if decrypted != tt.original {
t.Errorf("Expected %s, got %s", tt.original, decrypted)
}
})
}
}
func TestConfigSetGetCDK(t *testing.T) {
config := &Config{}
testCDK := "test-cdk-123"
// Set CDK (should encrypt)
config.SetCDK(testCDK)
// CDK field should be encrypted (different from original)
if config.CDK == testCDK {
t.Errorf("CDK should be encrypted in config")
}
// Get CDK (should decrypt)
retrievedCDK, err := config.GetCDK()
if err != nil {
t.Errorf("Failed to get CDK: %v", err)
}
if retrievedCDK != testCDK {
t.Errorf("Expected %s, got %s", testCDK, retrievedCDK)
}
}
func TestDecryptInvalidCDK(t *testing.T) {
// Test with invalid base64
_, err := decryptCDK("invalid-base64!")
if err == nil {
t.Errorf("Expected error for invalid base64")
}
}
func TestConfigManagerLoadSave(t *testing.T) {
// Create temporary directory for test
// 为测试创建临时目录
tempDir := t.TempDir()
// Create config manager with temp path
// 使用临时路径创建配置管理器
cm := &DefaultConfigManager{
configPath: filepath.Join(tempDir, "test-config.yaml"),
}
// Test loading non-existent config (should create default)
// 测试加载不存在的配置(应创建默认配置)
config, err := cm.Load()
if err != nil {
t.Errorf("Failed to load config: %v", err)
t.Errorf("加载配置失败: %v", err)
}
if config == nil {
t.Errorf("Config should not be nil")
t.Errorf("配置不应为 nil")
}
// Verify default values
// 验证默认值
if config.CurrentVersion != "v1.0.0" {
t.Errorf("Expected default version v1.0.0, got %s", config.CurrentVersion)
t.Errorf("期望默认版本 v1.0.0,得到 %s", config.CurrentVersion)
}
if config.UserAgent != "LightweightUpdater/1.0" {
t.Errorf("Expected default user agent, got %s", config.UserAgent)
if config.UserAgent != "AUTO_MAA_Go_Updater/1.0" {
t.Errorf("期望默认用户代理,得到 %s", config.UserAgent)
}
// Set some values including CDK
// 设置一些值
config.ResourceID = "TEST123"
config.SetCDK("secret-cdk-key")
// Save config
// 保存配置
err = cm.Save(config)
if err != nil {
t.Errorf("Failed to save config: %v", err)
t.Errorf("保存配置失败: %v", err)
}
// Load config again
// 再次加载配置
loadedConfig, err := cm.Load()
if err != nil {
t.Errorf("Failed to load saved config: %v", err)
t.Errorf("加载已保存配置失败: %v", err)
}
// Verify values
// 验证值
if loadedConfig.ResourceID != "TEST123" {
t.Errorf("Expected ResourceID TEST123, got %s", loadedConfig.ResourceID)
}
// Verify CDK is properly encrypted/decrypted
retrievedCDK, err := loadedConfig.GetCDK()
if err != nil {
t.Errorf("Failed to get CDK from loaded config: %v", err)
}
if retrievedCDK != "secret-cdk-key" {
t.Errorf("Expected CDK secret-cdk-key, got %s", retrievedCDK)
}
// Verify CDK is encrypted in the config struct
if loadedConfig.CDK == "secret-cdk-key" {
t.Errorf("CDK should be encrypted in config file")
t.Errorf("期望 ResourceID TEST123,得到 %s", loadedConfig.ResourceID)
}
}
@@ -170,12 +62,12 @@ func TestConfigValidation(t *testing.T) {
expectError bool
}{
{
name: "Nil config",
name: "空配置",
config: nil,
expectError: true,
},
{
name: "Empty ResourceID",
name: " ResourceID",
config: &Config{
ResourceID: "",
CurrentVersion: "v1.0.0",
@@ -186,40 +78,7 @@ func TestConfigValidation(t *testing.T) {
expectError: true,
},
{
name: "Empty CurrentVersion",
config: &Config{
ResourceID: "TEST",
CurrentVersion: "",
UserAgent: "Test/1.0",
LogLevel: "info",
CheckInterval: 3600,
},
expectError: true,
},
{
name: "Invalid LogLevel",
config: &Config{
ResourceID: "TEST",
CurrentVersion: "v1.0.0",
UserAgent: "Test/1.0",
LogLevel: "invalid",
CheckInterval: 3600,
},
expectError: true,
},
{
name: "Invalid CheckInterval",
config: &Config{
ResourceID: "TEST",
CurrentVersion: "v1.0.0",
UserAgent: "Test/1.0",
LogLevel: "info",
CheckInterval: 30, // Less than 60
},
expectError: true,
},
{
name: "Valid config",
name: "有效配置",
config: &Config{
ResourceID: "TEST",
CurrentVersion: "v1.0.0",
@@ -235,112 +94,10 @@ func TestConfigValidation(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
err := validateConfig(tt.config)
if tt.expectError && err == nil {
t.Errorf("Expected error but got none")
t.Errorf("期望错误但没有得到")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error but got: %v", err)
}
})
}
}
func TestGetConfigDir(t *testing.T) {
// Save original APPDATA
originalAppData := os.Getenv("APPDATA")
defer os.Setenv("APPDATA", originalAppData)
// Test with APPDATA set
os.Setenv("APPDATA", "C:\\Users\\Test\\AppData\\Roaming")
dir := getConfigDir()
expected := "C:\\Users\\Test\\AppData\\Roaming\\LightweightUpdater"
if dir != expected {
t.Errorf("Expected %s, got %s", expected, dir)
}
// Test without APPDATA
os.Unsetenv("APPDATA")
dir = getConfigDir()
if dir != "." {
t.Errorf("Expected current directory, got %s", dir)
}
}
func TestValidateAndApplyDefaults(t *testing.T) {
tests := []struct {
name string
input *Config
expected *Config
hasError bool
}{
{
name: "Apply defaults to empty config",
input: &Config{
ResourceID: "TEST",
},
expected: &Config{
ResourceID: "TEST",
CurrentVersion: "v1.0.0",
UserAgent: "LightweightUpdater/1.0",
LogLevel: "info",
CheckInterval: 3600,
},
hasError: false,
},
{
name: "Partial config with some defaults needed",
input: &Config{
ResourceID: "TEST",
CurrentVersion: "v2.0.0",
LogLevel: "debug",
},
expected: &Config{
ResourceID: "TEST",
CurrentVersion: "v2.0.0",
UserAgent: "LightweightUpdater/1.0",
LogLevel: "debug",
CheckInterval: 3600,
},
hasError: false,
},
{
name: "Config with invalid values after defaults",
input: &Config{
ResourceID: "", // Invalid - empty
CheckInterval: 30, // Invalid - too small
},
expected: nil,
hasError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateAndApplyDefaults(tt.input)
if tt.hasError {
if err == nil {
t.Errorf("Expected error but got none")
}
return
}
if err != nil {
t.Errorf("Unexpected error: %v", err)
return
}
// Check that defaults were applied correctly
if tt.input.CurrentVersion != tt.expected.CurrentVersion {
t.Errorf("CurrentVersion: expected %s, got %s", tt.expected.CurrentVersion, tt.input.CurrentVersion)
}
if tt.input.UserAgent != tt.expected.UserAgent {
t.Errorf("UserAgent: expected %s, got %s", tt.expected.UserAgent, tt.input.UserAgent)
}
if tt.input.LogLevel != tt.expected.LogLevel {
t.Errorf("LogLevel: expected %s, got %s", tt.expected.LogLevel, tt.input.LogLevel)
}
if tt.input.CheckInterval != tt.expected.CheckInterval {
t.Errorf("CheckInterval: expected %d, got %d", tt.expected.CheckInterval, tt.input.CheckInterval)
t.Errorf("期望无错误但得到: %v", err)
}
})
}
@@ -350,123 +107,47 @@ func TestGetDefaultConfig(t *testing.T) {
config := getDefaultConfig()
if config == nil {
t.Fatal("getDefaultConfig() returned nil")
t.Fatal("getDefaultConfig() 返回 nil")
}
// Verify default values
if config.ResourceID != "PLACEHOLDER" {
t.Errorf("Expected ResourceID 'PLACEHOLDER', got %s", config.ResourceID)
// 验证默认值
if config.ResourceID != "AUTO_MAA" {
t.Errorf("期望 ResourceID 'AUTO_MAA',得到 %s", config.ResourceID)
}
if config.CurrentVersion != "v1.0.0" {
t.Errorf("Expected CurrentVersion 'v1.0.0', got %s", config.CurrentVersion)
t.Errorf("期望 CurrentVersion 'v1.0.0',得到 %s", config.CurrentVersion)
}
if config.UserAgent != "LightweightUpdater/1.0" {
t.Errorf("Expected UserAgent 'LightweightUpdater/1.0', got %s", config.UserAgent)
if config.UserAgent != "AUTO_MAA_Go_Updater/1.0" {
t.Errorf("期望 UserAgent 'AUTO_MAA_Go_Updater/1.0',得到 %s", config.UserAgent)
}
if config.LogLevel != "info" {
t.Errorf("Expected LogLevel 'info', got %s", config.LogLevel)
t.Errorf("期望 LogLevel 'info',得到 %s", config.LogLevel)
}
if config.CheckInterval != 3600 {
t.Errorf("Expected CheckInterval 3600, got %d", config.CheckInterval)
t.Errorf("期望 CheckInterval 3600,得到 %d", config.CheckInterval)
}
if !config.AutoCheck {
t.Errorf("Expected AutoCheck true, got %v", config.AutoCheck)
t.Errorf("期望 AutoCheck true,得到 %v", config.AutoCheck)
}
}
func TestConfigManagerWithCustomPath(t *testing.T) {
tempDir := t.TempDir()
customPath := filepath.Join(tempDir, "custom-config.yaml")
func TestGetConfigDir(t *testing.T) {
// 保存原始 APPDATA
originalAppData := os.Getenv("APPDATA")
defer os.Setenv("APPDATA", originalAppData)
cm := &DefaultConfigManager{
configPath: customPath,
// 测试设置了 APPDATA
os.Setenv("APPDATA", "C:\\Users\\Test\\AppData\\Roaming")
dir := getConfigDir()
expected := "C:\\Users\\Test\\AppData\\Roaming\\AUTO_MAA_Go_Updater"
if dir != expected {
t.Errorf("期望 %s得到 %s", expected, dir)
}
// Test GetConfigPath
if cm.GetConfigPath() != customPath {
t.Errorf("Expected config path %s, got %s", customPath, cm.GetConfigPath())
}
// Test Save and Load with custom path
testConfig := &Config{
ResourceID: "CUSTOM",
CurrentVersion: "v1.5.0",
UserAgent: "CustomUpdater/1.0",
LogLevel: "debug",
CheckInterval: 7200,
AutoCheck: false,
}
// Save config
err := cm.Save(testConfig)
if err != nil {
t.Fatalf("Failed to save config: %v", err)
}
// Load config
loadedConfig, err := cm.Load()
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
// Verify loaded config matches saved config
if loadedConfig.ResourceID != testConfig.ResourceID {
t.Errorf("ResourceID mismatch: expected %s, got %s", testConfig.ResourceID, loadedConfig.ResourceID)
}
if loadedConfig.CurrentVersion != testConfig.CurrentVersion {
t.Errorf("CurrentVersion mismatch: expected %s, got %s", testConfig.CurrentVersion, loadedConfig.CurrentVersion)
}
if loadedConfig.AutoCheck != testConfig.AutoCheck {
t.Errorf("AutoCheck mismatch: expected %v, got %v", testConfig.AutoCheck, loadedConfig.AutoCheck)
}
}
func TestConfigManagerErrorHandling(t *testing.T) {
// Test with invalid directory path
invalidPath := string([]byte{0}) + "/invalid/config.yaml"
cm := &DefaultConfigManager{
configPath: invalidPath,
}
// Load should fail with invalid path
_, err := cm.Load()
if err == nil {
t.Error("Expected error when loading from invalid path")
}
// Save should fail with invalid path
testConfig := getDefaultConfig()
testConfig.ResourceID = "TEST"
err = cm.Save(testConfig)
if err == nil {
t.Error("Expected error when saving to invalid path")
}
}
func TestEncryptDecryptEdgeCases(t *testing.T) {
tests := []struct {
name string
input string
}{
{"Unicode characters", "测试CDK密钥🔑"},
{"Very long string", strings.Repeat("A", 1000)},
{"Binary-like data", string([]byte{0, 1, 2, 3, 255, 254, 253})},
{"Only spaces", " "},
{"Newlines and tabs", "line1\nline2\tindented"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
encrypted := encryptCDK(tt.input)
decrypted, err := decryptCDK(encrypted)
if err != nil {
t.Errorf("Decryption failed: %v", err)
}
if decrypted != tt.input {
t.Errorf("Encryption/decryption mismatch: expected %q, got %q", tt.input, decrypted)
}
})
// 测试没有 APPDATA
os.Unsetenv("APPDATA")
dir = getConfigDir()
if dir != "." {
t.Errorf("期望当前目录,得到 %s", dir)
}
}

View File

@@ -13,18 +13,18 @@ import (
"time"
)
// DownloadProgress represents the current download progress
// DownloadProgress 表示当前下载进度
type DownloadProgress struct {
BytesDownloaded int64
TotalBytes int64
Percentage float64
Speed int64 // bytes per second
Speed int64 // 每秒字节数
}
// ProgressCallback is called during download to report progress
// ProgressCallback 在下载过程中调用以报告进度
type ProgressCallback func(DownloadProgress)
// DownloadManager interface defines download operations
// DownloadManager 定义下载操作的接口
type DownloadManager interface {
Download(url, destination string, progressCallback ProgressCallback) error
DownloadWithResume(url, destination string, progressCallback ProgressCallback) error
@@ -32,13 +32,13 @@ type DownloadManager interface {
SetTimeout(timeout time.Duration)
}
// Manager implements DownloadManager interface
// Manager 实现 DownloadManager 接口
type Manager struct {
client *http.Client
timeout time.Duration
}
// NewManager creates a new download manager
// NewManager 创建新的下载管理器
func NewManager() *Manager {
return &Manager{
client: &http.Client{
@@ -48,24 +48,24 @@ func NewManager() *Manager {
}
}
// Download downloads a file from the given URL to the destination path
// Download 从给定 URL 下载文件到目标路径
func (m *Manager) Download(url, destination string, progressCallback ProgressCallback) error {
return m.downloadWithContext(context.Background(), url, destination, progressCallback, false)
}
// DownloadWithResume downloads a file with resume capability
// DownloadWithResume 下载文件并支持断点续传
func (m *Manager) DownloadWithResume(url, destination string, progressCallback ProgressCallback) error {
return m.downloadWithContext(context.Background(), url, destination, progressCallback, true)
}
// downloadWithContext performs the actual download with context support
// downloadWithContext 执行实际的下载并支持上下文
func (m *Manager) downloadWithContext(ctx context.Context, url, destination string, progressCallback ProgressCallback, resume bool) error {
// Create destination directory if it doesn't exist
// 如果目标目录不存在则创建
if err := os.MkdirAll(filepath.Dir(destination), 0755); err != nil {
return fmt.Errorf("failed to create destination directory: %w", err)
return fmt.Errorf("创建目标目录失败: %w", err)
}
// Check if file exists for resume
// 检查文件是否存在以支持断点续传
var existingSize int64
if resume {
if stat, err := os.Stat(destination); err == nil {
@@ -73,30 +73,30 @@ func (m *Manager) downloadWithContext(ctx context.Context, url, destination stri
}
}
// Create HTTP request
// 创建 HTTP 请求
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
return fmt.Errorf("创建请求失败: %w", err)
}
// Add range header for resume
// 为断点续传添加范围头
if resume && existingSize > 0 {
req.Header.Set("Range", fmt.Sprintf("bytes=%d-", existingSize))
}
// Execute request
// 执行请求
resp, err := m.client.Do(req)
if err != nil {
return fmt.Errorf("failed to execute request: %w", err)
return fmt.Errorf("执行请求失败: %w", err)
}
defer resp.Body.Close()
// Check response status
// 检查响应状态
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
return fmt.Errorf("意外的状态码: %d", resp.StatusCode)
}
// Get total size
// 获取总大小
totalSize := existingSize
if contentLength := resp.Header.Get("Content-Length"); contentLength != "" {
if size, err := strconv.ParseInt(contentLength, 10, 64); err == nil {
@@ -104,7 +104,7 @@ func (m *Manager) downloadWithContext(ctx context.Context, url, destination stri
}
}
// Open destination file
// 打开目标文件
var file *os.File
if resume && existingSize > 0 {
file, err = os.OpenFile(destination, os.O_WRONLY|os.O_APPEND, 0644)
@@ -113,17 +113,17 @@ func (m *Manager) downloadWithContext(ctx context.Context, url, destination stri
existingSize = 0
}
if err != nil {
return fmt.Errorf("failed to create destination file: %w", err)
return fmt.Errorf("创建目标文件失败: %w", err)
}
defer file.Close()
// Download with progress tracking
// 下载并跟踪进度
return m.copyWithProgress(resp.Body, file, existingSize, totalSize, progressCallback)
}
// copyWithProgress copies data while tracking progress
// copyWithProgress 复制数据并跟踪进度
func (m *Manager) copyWithProgress(src io.Reader, dst io.Writer, startBytes, totalBytes int64, progressCallback ProgressCallback) error {
buffer := make([]byte, 32*1024) // 32KB buffer
buffer := make([]byte, 32*1024) // 32KB 缓冲区
downloaded := startBytes
startTime := time.Now()
lastUpdate := startTime
@@ -132,11 +132,11 @@ func (m *Manager) copyWithProgress(src io.Reader, dst io.Writer, startBytes, tot
n, err := src.Read(buffer)
if n > 0 {
if _, writeErr := dst.Write(buffer[:n]); writeErr != nil {
return fmt.Errorf("failed to write to destination: %w", writeErr)
return fmt.Errorf("写入目标失败: %w", writeErr)
}
downloaded += int64(n)
// Update progress every 100ms
// 每 100ms 更新一次进度
now := time.Now()
if progressCallback != nil && now.Sub(lastUpdate) >= 100*time.Millisecond {
elapsed := now.Sub(startTime).Seconds()
@@ -164,11 +164,11 @@ func (m *Manager) copyWithProgress(src io.Reader, dst io.Writer, startBytes, tot
break
}
if err != nil {
return fmt.Errorf("failed to read from source: %w", err)
return fmt.Errorf("从源读取失败: %w", err)
}
}
// Final progress update
// 最终进度更新
if progressCallback != nil {
elapsed := time.Since(startTime).Seconds()
speed := int64(0)
@@ -192,32 +192,32 @@ func (m *Manager) copyWithProgress(src io.Reader, dst io.Writer, startBytes, tot
return nil
}
// ValidateChecksum validates the SHA256 checksum of a file
// ValidateChecksum 验证文件的 SHA256 校验和
func (m *Manager) ValidateChecksum(filePath, expectedChecksum string) error {
if expectedChecksum == "" {
return nil // No checksum to validate
return nil // 没有校验和需要验证
}
file, err := os.Open(filePath)
if err != nil {
return fmt.Errorf("failed to open file for checksum validation: %w", err)
return fmt.Errorf("打开文件进行校验和验证失败: %w", err)
}
defer file.Close()
hash := sha256.New()
if _, err := io.Copy(hash, file); err != nil {
return fmt.Errorf("failed to calculate checksum: %w", err)
return fmt.Errorf("计算校验和失败: %w", err)
}
actualChecksum := hex.EncodeToString(hash.Sum(nil))
if actualChecksum != expectedChecksum {
return fmt.Errorf("checksum mismatch: expected %s, got %s", expectedChecksum, actualChecksum)
return fmt.Errorf("校验和不匹配: 期望 %s得到 %s", expectedChecksum, actualChecksum)
}
return nil
}
// SetTimeout sets the timeout for download operations
// SetTimeout 设置下载操作的超时时间
func (m *Manager) SetTimeout(timeout time.Duration) {
m.timeout = timeout
m.client.Timeout = timeout

View File

@@ -1,4 +1,4 @@
module lightweight-updater
module AUTO_MAA_Go_Updater
go 1.24.5

View File

@@ -2,15 +2,15 @@ package gui
import (
"fmt"
"fyne.io/fyne/v2"
"fyne.io/fyne/v2/app"
"fyne.io/fyne/v2/container"
"fyne.io/fyne/v2/dialog"
"fyne.io/fyne/v2/widget"
"fyne.io/fyne/v2/theme"
"fyne.io/fyne/v2"
"fyne.io/fyne/v2/widget"
)
// UpdateStatus represents the current status of the update process
// UpdateStatus 表示更新过程的当前状态
type UpdateStatus int
const (
@@ -22,16 +22,15 @@ const (
StatusError
)
// Config represents the configuration structure for the GUI
// Config 表示 GUI 的配置结构
type Config struct {
ResourceID string
CurrentVersion string
CDK string
UserAgent string
BackupURL string
}
// GUIManager interface defines the methods for GUI management
// GUIManager 定义 GUI 管理的接口方法
type GUIManager interface {
ShowMainWindow()
UpdateStatus(status UpdateStatus, message string)
@@ -41,7 +40,7 @@ type GUIManager interface {
Close()
}
// Manager implements the GUIManager interface
// Manager 实现 GUIManager 接口
type Manager struct {
app fyne.App
window fyne.Window
@@ -60,7 +59,7 @@ func NewManager() *Manager {
a := app.New()
a.SetIcon(theme.ComputerIcon())
w := a.NewWindow("轻量级更新器")
w := a.NewWindow("AUTO_MAA_Go_Updater")
w.Resize(fyne.NewSize(500, 400))
w.SetFixedSize(false)
w.CenterOnScreen()
@@ -117,11 +116,11 @@ func (m *Manager) createUIComponents() {
}
// createMainLayout creates the main window layout
func (m *Manager) createMainLayout() *container.VBox {
func (m *Manager) createMainLayout() *fyne.Container {
// Header section
header := container.NewVBox(
widget.NewCard("", "", container.NewVBox(
widget.NewLabelWithStyle("轻量级更新器", fyne.TextAlignCenter, fyne.TextStyle{Bold: true}),
widget.NewLabelWithStyle("AUTO_MAA_Go_Updater", fyne.TextAlignCenter, fyne.TextStyle{Bold: true}),
m.versionLabel,
)),
)
@@ -226,11 +225,8 @@ func (m *Manager) showConfigDialog() (*Config, error) {
versionEntry := widget.NewEntry()
versionEntry.SetPlaceHolder("例如: v1.0.0")
cdkEntry := widget.NewPasswordEntry()
cdkEntry.SetPlaceHolder("输入您的CDK可选")
userAgentEntry := widget.NewEntry()
userAgentEntry.SetText("LightweightUpdater/1.0")
userAgentEntry.SetText("AUTO_MAA_Go_Updater/1.0")
backupURLEntry := widget.NewEntry()
backupURLEntry.SetPlaceHolder("备用下载地址(可选)")
@@ -240,7 +236,6 @@ func (m *Manager) showConfigDialog() (*Config, error) {
Items: []*widget.FormItem{
{Text: "资源ID:", Widget: resourceIDEntry},
{Text: "当前版本:", Widget: versionEntry},
{Text: "CDK:", Widget: cdkEntry},
{Text: "用户代理:", Widget: userAgentEntry},
{Text: "备用下载地址:", Widget: backupURLEntry},
},
@@ -261,7 +256,6 @@ func (m *Manager) showConfigDialog() (*Config, error) {
config := &Config{
ResourceID: resourceIDEntry.Text,
CurrentVersion: versionEntry.Text,
CDK: cdkEntry.Text,
UserAgent: userAgentEntry.Text,
BackupURL: backupURLEntry.Text,
}
@@ -289,11 +283,8 @@ func (m *Manager) showConfigDialog() (*Config, error) {
**配置说明:**
- **资源ID**: Mirror酱服务中的资源标识符
- **当前版本**: 当前软件的版本号
- **CDK**: Mirror酱服务的访问密钥可选提供更好的下载体验
- **用户代理**: HTTP请求的用户代理字符串
- **备用下载地址**: 当Mirror酱服务不可用时的备用下载地址
如需获取CDK请访问 [Mirror酱官网](https://mirrorchyan.com)
`)
// Create container with help text

View File

@@ -11,14 +11,14 @@ import (
"syscall"
)
// ChangesInfo represents the structure of changes.json file
// ChangesInfo 表示 changes.json 文件的结构
type ChangesInfo struct {
Deleted []string `json:"deleted"`
Added []string `json:"added"`
Modified []string `json:"modified"`
}
// InstallManager interface defines the contract for installation operations
// InstallManager 定义安装操作的接口契约
type InstallManager interface {
ExtractZip(zipPath, destPath string) error
ProcessChanges(changesPath string) (*ChangesInfo, error)
@@ -28,31 +28,31 @@ type InstallManager interface {
CleanupTempDir(tempDir string) error
}
// Manager implements the InstallManager interface
// Manager 实现 InstallManager 接口
type Manager struct {
tempDirs []string // Track temporary directories for cleanup
tempDirs []string // 跟踪临时目录以便清理
}
// NewManager creates a new install manager instance
// NewManager 创建新的安装管理器实例
func NewManager() *Manager {
return &Manager{
tempDirs: make([]string, 0),
}
}
// CreateTempDir creates a temporary directory for extraction
// CreateTempDir 为解压创建临时目录
func (m *Manager) CreateTempDir() (string, error) {
tempDir, err := os.MkdirTemp("", "updater_*")
if err != nil {
return "", fmt.Errorf("failed to create temp directory: %w", err)
return "", fmt.Errorf("创建临时目录失败: %w", err)
}
// Track temp directory for cleanup
// 跟踪临时目录以便清理
m.tempDirs = append(m.tempDirs, tempDir)
return tempDir, nil
}
// CleanupTempDir removes a temporary directory and its contents
// CleanupTempDir 删除临时目录及其内容
func (m *Manager) CleanupTempDir(tempDir string) error {
if tempDir == "" {
return nil
@@ -60,10 +60,10 @@ func (m *Manager) CleanupTempDir(tempDir string) error {
err := os.RemoveAll(tempDir)
if err != nil {
return fmt.Errorf("failed to cleanup temp directory %s: %w", tempDir, err)
return fmt.Errorf("清理临时目录 %s 失败: %w", tempDir, err)
}
// Remove from tracking list
// 从跟踪列表中删除
for i, dir := range m.tempDirs {
if dir == tempDir {
m.tempDirs = append(m.tempDirs[:i], m.tempDirs[i+1:]...)
@@ -74,98 +74,98 @@ func (m *Manager) CleanupTempDir(tempDir string) error {
return nil
}
// CleanupAllTempDirs removes all tracked temporary directories
// CleanupAllTempDirs 删除所有跟踪的临时目录
func (m *Manager) CleanupAllTempDirs() error {
var errors []string
for _, tempDir := range m.tempDirs {
if err := os.RemoveAll(tempDir); err != nil {
errors = append(errors, fmt.Sprintf("failed to cleanup %s: %v", tempDir, err))
errors = append(errors, fmt.Sprintf("清理 %s 失败: %v", tempDir, err))
}
}
m.tempDirs = m.tempDirs[:0] // Clear the slice
m.tempDirs = m.tempDirs[:0] // 清空切片
if len(errors) > 0 {
return fmt.Errorf("cleanup errors: %s", strings.Join(errors, "; "))
return fmt.Errorf("清理错误: %s", strings.Join(errors, "; "))
}
return nil
}
// ExtractZip extracts a ZIP file to the specified destination directory
// ExtractZip 将 ZIP 文件解压到指定的目标目录
func (m *Manager) ExtractZip(zipPath, destPath string) error {
// Open ZIP file for reading
// 打开 ZIP 文件进行读取
reader, err := zip.OpenReader(zipPath)
if err != nil {
return fmt.Errorf("failed to open ZIP file %s: %w", zipPath, err)
return fmt.Errorf("打开 ZIP 文件 %s 失败: %w", zipPath, err)
}
defer reader.Close()
// Create destination directory if it doesn't exist
// 如果目标目录不存在则创建
if err := os.MkdirAll(destPath, 0755); err != nil {
return fmt.Errorf("failed to create destination directory %s: %w", destPath, err)
return fmt.Errorf("创建目标目录 %s 失败: %w", destPath, err)
}
// Extract files
// 解压文件
for _, file := range reader.File {
if err := m.extractFile(file, destPath); err != nil {
return fmt.Errorf("failed to extract file %s: %w", file.Name, err)
return fmt.Errorf("解压文件 %s 失败: %w", file.Name, err)
}
}
return nil
}
// extractFile extracts a single file from the ZIP archive
// extractFile 从 ZIP 归档中解压单个文件
func (m *Manager) extractFile(file *zip.File, destPath string) error {
// Clean the file path to prevent directory traversal attacks
// 清理文件路径以防止目录遍历攻击
cleanPath := filepath.Clean(file.Name)
if strings.Contains(cleanPath, "..") {
return fmt.Errorf("invalid file path: %s", file.Name)
return fmt.Errorf("无效的文件路径: %s", file.Name)
}
// Create full destination path
// 创建完整的目标路径
destFile := filepath.Join(destPath, cleanPath)
// Create directory structure if needed
// 如果需要则创建目录结构
if file.FileInfo().IsDir() {
return os.MkdirAll(destFile, file.FileInfo().Mode())
}
// Create parent directories
// 创建父目录
if err := os.MkdirAll(filepath.Dir(destFile), 0755); err != nil {
return fmt.Errorf("failed to create parent directory: %w", err)
return fmt.Errorf("创建父目录失败: %w", err)
}
// Open file in ZIP archive
// 打开 ZIP 归档中的文件
rc, err := file.Open()
if err != nil {
return fmt.Errorf("failed to open file in archive: %w", err)
return fmt.Errorf("打开归档中的文件失败: %w", err)
}
defer rc.Close()
// Create destination file
// 创建目标文件
outFile, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, file.FileInfo().Mode())
if err != nil {
return fmt.Errorf("failed to create destination file: %w", err)
return fmt.Errorf("创建目标文件失败: %w", err)
}
defer outFile.Close()
// Copy file contents
// 复制文件内容
_, err = io.Copy(outFile, rc)
if err != nil {
return fmt.Errorf("failed to copy file contents: %w", err)
return fmt.Errorf("复制文件内容失败: %w", err)
}
return nil
}
// ProcessChanges reads and parses the changes.json file
// ProcessChanges 读取并解析 changes.json 文件
func (m *Manager) ProcessChanges(changesPath string) (*ChangesInfo, error) {
// Check if changes.json exists
// 检查 changes.json 是否存在
if _, err := os.Stat(changesPath); os.IsNotExist(err) {
// If changes.json doesn't exist, return empty changes info
// 如果 changes.json 不存在,返回空的变更信息
return &ChangesInfo{
Deleted: []string{},
Added: []string{},
@@ -173,72 +173,72 @@ func (m *Manager) ProcessChanges(changesPath string) (*ChangesInfo, error) {
}, nil
}
// Read the changes.json file
// 读取 changes.json 文件
data, err := os.ReadFile(changesPath)
if err != nil {
return nil, fmt.Errorf("failed to read changes file %s: %w", changesPath, err)
return nil, fmt.Errorf("读取变更文件 %s 失败: %w", changesPath, err)
}
// Parse JSON
// 解析 JSON
var changes ChangesInfo
if err := json.Unmarshal(data, &changes); err != nil {
return nil, fmt.Errorf("failed to parse changes JSON: %w", err)
return nil, fmt.Errorf("解析变更 JSON 失败: %w", err)
}
return &changes, nil
}
// HandleRunningProcess handles running processes by renaming files that are in use
// HandleRunningProcess 通过重命名正在使用的文件来处理正在运行的进程
func (m *Manager) HandleRunningProcess(processName string) error {
// Get the current executable path
// 获取当前可执行文件路径
exePath, err := os.Executable()
if err != nil {
return fmt.Errorf("failed to get executable path: %w", err)
return fmt.Errorf("获取可执行文件路径失败: %w", err)
}
exeDir := filepath.Dir(exePath)
targetFile := filepath.Join(exeDir, processName)
// Check if the target file exists
// 检查目标文件是否存在
if _, err := os.Stat(targetFile); os.IsNotExist(err) {
// File doesn't exist, nothing to handle
// 文件不存在,无需处理
return nil
}
// Try to rename the file to indicate it should be deleted on next startup
// 尝试重命名文件以指示应在下次启动时删除
oldFile := targetFile + ".old"
// Remove existing .old file if it exists
// 如果存在现有的 .old 文件则删除
if _, err := os.Stat(oldFile); err == nil {
if err := os.Remove(oldFile); err != nil {
return fmt.Errorf("failed to remove existing old file %s: %w", oldFile, err)
return fmt.Errorf("删除现有旧文件 %s 失败: %w", oldFile, err)
}
}
// Rename the current file to .old
// 将当前文件重命名为 .old
if err := os.Rename(targetFile, oldFile); err != nil {
// If rename fails, the process might be running
// On Windows, we can't rename a running executable
// 如果重命名失败,进程可能正在运行
// Windows 上,我们无法重命名正在运行的可执行文件
if isFileInUse(err) {
// Mark the file for deletion on next reboot (Windows specific)
// 标记文件在下次重启时删除Windows 特定)
return m.markFileForDeletion(targetFile)
}
return fmt.Errorf("failed to rename running process file %s: %w", targetFile, err)
return fmt.Errorf("重命名正在运行的进程文件 %s 失败: %w", targetFile, err)
}
return nil
}
// isFileInUse checks if the error indicates the file is in use
// isFileInUse 检查错误是否表示文件正在使用中
func isFileInUse(err error) bool {
if err == nil {
return false
}
// Check for Windows-specific "file in use" errors
// 检查 Windows 特定的"文件正在使用"错误
if pathErr, ok := err.(*os.PathError); ok {
if errno, ok := pathErr.Err.(syscall.Errno); ok {
// ERROR_SHARING_VIOLATION (32) or ERROR_ACCESS_DENIED (5)
// ERROR_SHARING_VIOLATION (32) ERROR_ACCESS_DENIED (5)
return errno == syscall.Errno(32) || errno == syscall.Errno(5)
}
}
@@ -247,226 +247,226 @@ func isFileInUse(err error) bool {
strings.Contains(err.Error(), "access is denied")
}
// markFileForDeletion marks a file for deletion on next system reboot (Windows specific)
// markFileForDeletion 标记文件在下次系统重启时删除Windows 特定)
func (m *Manager) markFileForDeletion(filePath string) error {
// This is a Windows-specific implementation
// For now, we'll create a marker file that can be handled by the main application
// 这是 Windows 特定的实现
// 目前,我们将创建一个可由主应用程序处理的标记文件
markerFile := filePath + ".delete_on_restart"
// Create a marker file
// 创建标记文件
file, err := os.Create(markerFile)
if err != nil {
return fmt.Errorf("failed to create deletion marker file: %w", err)
return fmt.Errorf("创建删除标记文件失败: %w", err)
}
defer file.Close()
// Write the target file path to the marker
// 将目标文件路径写入标记文件
_, err = file.WriteString(filePath)
if err != nil {
return fmt.Errorf("failed to write to marker file: %w", err)
return fmt.Errorf("写入标记文件失败: %w", err)
}
return nil
}
// DeleteMarkedFiles removes files that were marked for deletion
// DeleteMarkedFiles 删除标记为删除的文件
func (m *Manager) DeleteMarkedFiles(directory string) error {
// Find all .delete_on_restart files
// 查找所有 .delete_on_restart 文件
pattern := filepath.Join(directory, "*.delete_on_restart")
matches, err := filepath.Glob(pattern)
if err != nil {
return fmt.Errorf("failed to find marker files: %w", err)
return fmt.Errorf("查找标记文件失败: %w", err)
}
var errors []string
for _, markerFile := range matches {
// Read the target file path
// 读取目标文件路径
data, err := os.ReadFile(markerFile)
if err != nil {
errors = append(errors, fmt.Sprintf("failed to read marker file %s: %v", markerFile, err))
errors = append(errors, fmt.Sprintf("读取标记文件 %s 失败: %v", markerFile, err))
continue
}
targetFile := strings.TrimSpace(string(data))
// Try to delete the target file
// 尝试删除目标文件
if err := os.Remove(targetFile); err != nil && !os.IsNotExist(err) {
errors = append(errors, fmt.Sprintf("failed to delete marked file %s: %v", targetFile, err))
errors = append(errors, fmt.Sprintf("删除标记文件 %s 失败: %v", targetFile, err))
}
// Remove the marker file
// 删除标记文件
if err := os.Remove(markerFile); err != nil {
errors = append(errors, fmt.Sprintf("failed to remove marker file %s: %v", markerFile, err))
errors = append(errors, fmt.Sprintf("删除标记文件 %s 失败: %v", markerFile, err))
}
}
if len(errors) > 0 {
return fmt.Errorf("deletion errors: %s", strings.Join(errors, "; "))
return fmt.Errorf("删除错误: %s", strings.Join(errors, "; "))
}
return nil
}
// ApplyUpdate applies the update by copying files from source to target directory
// ApplyUpdate 通过从源目录复制文件到目标目录来应用更新
func (m *Manager) ApplyUpdate(sourcePath, targetPath string, changes *ChangesInfo) error {
// Create backup directory
// 创建备份目录
backupDir, err := m.createBackupDir(targetPath)
if err != nil {
return fmt.Errorf("failed to create backup directory: %w", err)
return fmt.Errorf("创建备份目录失败: %w", err)
}
// Backup existing files before applying update
// 在应用更新前备份现有文件
if err := m.backupFiles(targetPath, backupDir, changes); err != nil {
return fmt.Errorf("failed to backup files: %w", err)
return fmt.Errorf("备份文件失败: %w", err)
}
// Apply the update
// 应用更新
if err := m.applyUpdateFiles(sourcePath, targetPath, changes); err != nil {
// Rollback on failure
// 失败时回滚
if rollbackErr := m.rollbackUpdate(targetPath, backupDir); rollbackErr != nil {
return fmt.Errorf("update failed and rollback failed: update error: %w, rollback error: %v", err, rollbackErr)
return fmt.Errorf("更新失败且回滚失败: 更新错误: %w, 回滚错误: %v", err, rollbackErr)
}
return fmt.Errorf("update failed and was rolled back: %w", err)
return fmt.Errorf("更新失败已回滚: %w", err)
}
// Clean up backup directory after successful update
// 成功更新后清理备份目录
if err := os.RemoveAll(backupDir); err != nil {
// Log warning but don't fail the update
fmt.Printf("Warning: failed to cleanup backup directory %s: %v\n", backupDir, err)
// 记录警告但不让更新失败
fmt.Printf("警告: 清理备份目录 %s 失败: %v\n", backupDir, err)
}
return nil
}
// createBackupDir creates a backup directory for the update
// createBackupDir 为更新创建备份目录
func (m *Manager) createBackupDir(targetPath string) (string, error) {
backupDir := filepath.Join(targetPath, ".backup_"+fmt.Sprintf("%d", os.Getpid()))
if err := os.MkdirAll(backupDir, 0755); err != nil {
return "", fmt.Errorf("failed to create backup directory: %w", err)
return "", fmt.Errorf("创建备份目录失败: %w", err)
}
return backupDir, nil
}
// backupFiles creates backups of files that will be modified or deleted
// backupFiles 创建将被修改或删除的文件的备份
func (m *Manager) backupFiles(targetPath, backupDir string, changes *ChangesInfo) error {
// Backup files that will be modified
// 备份将被修改的文件
for _, file := range changes.Modified {
srcFile := filepath.Join(targetPath, file)
if _, err := os.Stat(srcFile); os.IsNotExist(err) {
continue // File doesn't exist, skip backup
continue // 文件不存在,跳过备份
}
backupFile := filepath.Join(backupDir, file)
if err := m.copyFileWithDirs(srcFile, backupFile); err != nil {
return fmt.Errorf("failed to backup modified file %s: %w", file, err)
return fmt.Errorf("备份修改文件 %s 失败: %w", file, err)
}
}
// Backup files that will be deleted
// 备份将被删除的文件
for _, file := range changes.Deleted {
srcFile := filepath.Join(targetPath, file)
if _, err := os.Stat(srcFile); os.IsNotExist(err) {
continue // File doesn't exist, skip backup
continue // 文件不存在,跳过备份
}
backupFile := filepath.Join(backupDir, file)
if err := m.copyFileWithDirs(srcFile, backupFile); err != nil {
return fmt.Errorf("failed to backup deleted file %s: %w", file, err)
return fmt.Errorf("备份删除文件 %s 失败: %w", file, err)
}
}
return nil
}
// applyUpdateFiles applies the actual file changes
// applyUpdateFiles 应用实际的文件更改
func (m *Manager) applyUpdateFiles(sourcePath, targetPath string, changes *ChangesInfo) error {
// Delete files marked for deletion
// 删除标记为删除的文件
for _, file := range changes.Deleted {
targetFile := filepath.Join(targetPath, file)
if err := os.Remove(targetFile); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to delete file %s: %w", file, err)
return fmt.Errorf("删除文件 %s 失败: %w", file, err)
}
}
// Copy new and modified files
// 复制新文件和修改的文件
filesToCopy := append(changes.Added, changes.Modified...)
for _, file := range filesToCopy {
srcFile := filepath.Join(sourcePath, file)
targetFile := filepath.Join(targetPath, file)
// Check if source file exists
// 检查源文件是否存在
if _, err := os.Stat(srcFile); os.IsNotExist(err) {
continue // Source file doesn't exist, skip
continue // 源文件不存在,跳过
}
if err := m.copyFileWithDirs(srcFile, targetFile); err != nil {
return fmt.Errorf("failed to copy file %s: %w", file, err)
return fmt.Errorf("复制文件 %s 失败: %w", file, err)
}
}
return nil
}
// copyFileWithDirs copies a file and creates necessary directories
// copyFileWithDirs 复制文件并创建必要的目录
func (m *Manager) copyFileWithDirs(src, dst string) error {
// Create parent directories
// 创建父目录
if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil {
return fmt.Errorf("failed to create parent directories: %w", err)
return fmt.Errorf("创建父目录失败: %w", err)
}
// Open source file
// 打开源文件
srcFile, err := os.Open(src)
if err != nil {
return fmt.Errorf("failed to open source file: %w", err)
return fmt.Errorf("打开源文件失败: %w", err)
}
defer srcFile.Close()
// Get source file info
// 获取源文件信息
srcInfo, err := srcFile.Stat()
if err != nil {
return fmt.Errorf("failed to get source file info: %w", err)
return fmt.Errorf("获取源文件信息失败: %w", err)
}
// Create destination file
// 创建目标文件
dstFile, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, srcInfo.Mode())
if err != nil {
return fmt.Errorf("failed to create destination file: %w", err)
return fmt.Errorf("创建目标文件失败: %w", err)
}
defer dstFile.Close()
// Copy file contents
// 复制文件内容
_, err = io.Copy(dstFile, srcFile)
if err != nil {
return fmt.Errorf("failed to copy file contents: %w", err)
return fmt.Errorf("复制文件内容失败: %w", err)
}
return nil
}
// rollbackUpdate restores files from backup in case of update failure
// rollbackUpdate 在更新失败时从备份恢复文件
func (m *Manager) rollbackUpdate(targetPath, backupDir string) error {
// Walk through backup directory and restore files
// 遍历备份目录并恢复文件
return filepath.Walk(backupDir, func(backupFile string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if info.IsDir() {
return nil // Skip directories
return nil // 跳过目录
}
// Calculate relative path
// 计算相对路径
relPath, err := filepath.Rel(backupDir, backupFile)
if err != nil {
return fmt.Errorf("failed to calculate relative path: %w", err)
return fmt.Errorf("计算相对路径失败: %w", err)
}
// Restore file to target location
// 将文件恢复到目标位置
targetFile := filepath.Join(targetPath, relPath)
if err := m.copyFileWithDirs(backupFile, targetFile); err != nil {
return fmt.Errorf("failed to restore file %s: %w", relPath, err)
return fmt.Errorf("恢复文件 %s 失败: %w", relPath, err)
}
return nil

View File

@@ -4,9 +4,9 @@ import (
"testing"
)
// Integration tests will be implemented here
// This file is currently a placeholder
// 集成测试将在此处实现
// 此文件目前是占位符
func TestIntegrationPlaceholder(t *testing.T) {
t.Skip("Integration tests not yet implemented")
t.Skip("集成测试尚未实现")
}

View File

@@ -64,8 +64,8 @@ type LoggerConfig struct {
Level LogLevel
MaxSize int64 // 最大文件大小字节默认10MB
MaxBackups int // 最大备份文件数默认5
LogDir string // 日志目录,默认%APPDATA%/LightweightUpdater/logs
Filename string // 日志文件名默认updater.log
LogDir string // 日志目录
Filename string // 日志文件名
}
// DefaultLoggerConfig 默认日志配置

File diff suppressed because it is too large Load Diff

View File

@@ -1,3 +0,0 @@
package utils
// Package utils provides utility functions for the updater

View File

@@ -8,16 +8,16 @@ import (
"strconv"
"strings"
"lightweight-updater/logger"
"AUTO_MAA_Go_Updater/logger"
)
// VersionInfo represents the version information from version.json
// VersionInfo 表示来自 version.json 的版本信息
type VersionInfo struct {
MainVersion string `json:"main_version"`
VersionInfo map[string]map[string][]string `json:"version_info"`
}
// ParsedVersion represents a parsed version with major, minor, patch, and beta components
// ParsedVersion 表示解析后的版本,包含主版本号、次版本号、补丁版本号和测试版本号组件
type ParsedVersion struct {
Major int
Minor int
@@ -25,13 +25,13 @@ type ParsedVersion struct {
Beta int
}
// VersionManager handles version-related operations
// VersionManager 处理版本相关操作
type VersionManager struct {
executableDir string
logger logger.Logger
}
// NewVersionManager creates a new version manager
// NewVersionManager 创建新的版本管理器
func NewVersionManager() *VersionManager {
execPath, _ := os.Executable()
execDir := filepath.Dir(execPath)
@@ -41,103 +41,93 @@ func NewVersionManager() *VersionManager {
}
}
// NewVersionManagerWithLogger creates a new version manager with a custom logger
func NewVersionManagerWithLogger(customLogger logger.Logger) *VersionManager {
execPath, _ := os.Executable()
execDir := filepath.Dir(execPath)
return &VersionManager{
executableDir: execDir,
logger: customLogger,
}
}
// createDefaultVersion creates a default version structure with v0.0.0
// createDefaultVersion 创建默认版本结构 v0.0.0
func (vm *VersionManager) createDefaultVersion() *VersionInfo {
return &VersionInfo{
MainVersion: "0.0.0.0", // Corresponds to v0.0.0
MainVersion: "0.0.0.0", // 对应 v0.0.0
VersionInfo: make(map[string]map[string][]string),
}
}
// LoadVersionFromFile loads version information from resources/version.json with fallback handling
// LoadVersionFromFile resources/version.json 加载版本信息并处理回退
func (vm *VersionManager) LoadVersionFromFile() (*VersionInfo, error) {
versionPath := filepath.Join(vm.executableDir, "resources", "version.json")
data, err := os.ReadFile(versionPath)
if err != nil {
if os.IsNotExist(err) {
vm.logger.Info("Version file not found at %s, will use default version", versionPath)
fmt.Println("未读取到版本信息,使用默认版本进行更新。")
return vm.createDefaultVersion(), nil
}
vm.logger.Warn("Failed to read version file at %s: %v, will use default version", versionPath, err)
vm.logger.Warn("读取版本文件 %s 失败: %v将使用默认版本", versionPath, err)
return vm.createDefaultVersion(), nil
}
var versionInfo VersionInfo
if err := json.Unmarshal(data, &versionInfo); err != nil {
vm.logger.Warn("Failed to parse version file at %s: %v, will use default version", versionPath, err)
vm.logger.Warn("解析版本文件 %s 失败: %v将使用默认版本", versionPath, err)
return vm.createDefaultVersion(), nil
}
vm.logger.Debug("Successfully loaded version information from %s", versionPath)
vm.logger.Debug("成功从 %s 加载版本信息", versionPath)
return &versionInfo, nil
}
// LoadVersionWithDefault loads version information with guaranteed fallback to default
// LoadVersionWithDefault 加载版本信息并保证回退到默认版本
func (vm *VersionManager) LoadVersionWithDefault() *VersionInfo {
versionInfo, err := vm.LoadVersionFromFile()
if err != nil {
// This should not happen with the updated LoadVersionFromFile, but adding as extra safety
vm.logger.Error("Unexpected error loading version file: %v, using default version", err)
// 这在更新的 LoadVersionFromFile 中不应该发生,但添加作为额外安全措施
vm.logger.Error("加载版本文件时出现意外错误: %v使用默认版本", err)
return vm.createDefaultVersion()
}
// Validate that we have a valid version structure
// 验证我们有一个有效的版本结构
if versionInfo == nil {
vm.logger.Warn("Version info is nil, using default version")
vm.logger.Warn("版本信息为空,使用默认版本")
return vm.createDefaultVersion()
}
if versionInfo.MainVersion == "" {
vm.logger.Warn("Version info has empty main version, using default version")
vm.logger.Warn("版本信息主版本为空,使用默认版本")
return vm.createDefaultVersion()
}
if versionInfo.VersionInfo == nil {
vm.logger.Debug("Version info map is nil, initializing empty map")
vm.logger.Debug("版本信息映射为空,初始化空映射")
versionInfo.VersionInfo = make(map[string]map[string][]string)
}
return versionInfo
}
// ParseVersion parses a version string like "4.4.1.3" into components
// ParseVersion 解析版本字符串如 "4.4.1.3" 为组件
func ParseVersion(versionStr string) (*ParsedVersion, error) {
parts := strings.Split(versionStr, ".")
if len(parts) < 3 || len(parts) > 4 {
return nil, fmt.Errorf("invalid version format: %s", versionStr)
return nil, fmt.Errorf("无效的版本格式: %s", versionStr)
}
major, err := strconv.Atoi(parts[0])
if err != nil {
return nil, fmt.Errorf("invalid major version: %s", parts[0])
return nil, fmt.Errorf("无效的主版本号: %s", parts[0])
}
minor, err := strconv.Atoi(parts[1])
if err != nil {
return nil, fmt.Errorf("invalid minor version: %s", parts[1])
return nil, fmt.Errorf("无效的次版本号: %s", parts[1])
}
patch, err := strconv.Atoi(parts[2])
if err != nil {
return nil, fmt.Errorf("invalid patch version: %s", parts[2])
return nil, fmt.Errorf("无效的补丁版本号: %s", parts[2])
}
beta := 0
if len(parts) == 4 {
beta, err = strconv.Atoi(parts[3])
if err != nil {
return nil, fmt.Errorf("invalid beta version: %s", parts[3])
return nil, fmt.Errorf("无效的测试版本号: %s", parts[3])
}
}
@@ -149,7 +139,7 @@ func ParseVersion(versionStr string) (*ParsedVersion, error) {
}, nil
}
// ToVersionString converts a ParsedVersion back to version string format
// ToVersionString ParsedVersion 转换回版本字符串格式
func (pv *ParsedVersion) ToVersionString() string {
if pv.Beta == 0 {
return fmt.Sprintf("%d.%d.%d.0", pv.Major, pv.Minor, pv.Patch)
@@ -157,7 +147,7 @@ func (pv *ParsedVersion) ToVersionString() string {
return fmt.Sprintf("%d.%d.%d.%d", pv.Major, pv.Minor, pv.Patch, pv.Beta)
}
// ToDisplayVersion converts version to display format (v4.4.0 or v4.4.1-beta3)
// ToDisplayVersion 将版本转换为显示格式 (v4.4.0 v4.4.1-beta3)
func (pv *ParsedVersion) ToDisplayVersion() string {
if pv.Beta == 0 {
return fmt.Sprintf("v%d.%d.%d", pv.Major, pv.Minor, pv.Patch)
@@ -165,7 +155,7 @@ func (pv *ParsedVersion) ToDisplayVersion() string {
return fmt.Sprintf("v%d.%d.%d-beta%d", pv.Major, pv.Minor, pv.Patch, pv.Beta)
}
// GetChannel returns the channel (stable or beta) based on version
// GetChannel 根据版本返回渠道 (stable beta)
func (pv *ParsedVersion) GetChannel() string {
if pv.Beta == 0 {
return "stable"
@@ -173,12 +163,7 @@ func (pv *ParsedVersion) GetChannel() string {
return "beta"
}
// GetDefaultChannel returns the default channel
func GetDefaultChannel() string {
return "stable"
}
// IsNewer checks if this version is newer than the other version
// IsNewer 检查此版本是否比其他版本更新
func (pv *ParsedVersion) IsNewer(other *ParsedVersion) bool {
if pv.Major != other.Major {
return pv.Major > other.Major

View File

@@ -1,41 +1,19 @@
package version
import (
"fmt"
"runtime"
)
var (
// Version is the current version of the application
// Version 应用程序的当前版本
Version = "1.0.0"
// BuildTime is set during build time
// BuildTime 在构建时设置
BuildTime = "unknown"
// GitCommit is set during build time
// GitCommit 在构建时设置
GitCommit = "unknown"
// GoVersion is the Go version used to build
// GoVersion 用于构建的 Go 版本
GoVersion = runtime.Version()
)
// GetVersionInfo returns formatted version information
func GetVersionInfo() string {
return fmt.Sprintf("Version: %s\nBuild Time: %s\nGit Commit: %s\nGo Version: %s",
Version, BuildTime, GitCommit, GoVersion)
}
// GetShortVersion returns just the version number
func GetShortVersion() string {
return Version
}
// GetBuildInfo returns build-specific information
func GetBuildInfo() map[string]string {
return map[string]string{
"version": Version,
"build_time": BuildTime,
"git_commit": GitCommit,
"go_version": GoVersion,
}
}