refactor(updater): 重构 Go 版本更新器
- 更新项目名称为 AUTO_MAA_Go_Updater - 重构代码结构,优化函数命名和逻辑 - 移除 CDK 相关的冗余代码 - 调整版本号为 git commit hash - 更新构建配置和脚本 - 优化 API 客户端实现
This commit is contained in:
@@ -9,7 +9,7 @@ BUILD_DIR := build
|
||||
DIST_DIR := dist
|
||||
|
||||
# Go build flags
|
||||
LDFLAGS := -s -w -X lightweight-updater/version.Version=$(VERSION) -X lightweight-updater/version.BuildTime=$(BUILD_TIME) -X lightweight-updater/version.GitCommit=$(GIT_COMMIT)
|
||||
LDFLAGS := -s -w -X AUTO_MAA_Go_Updater/version.Version=$(VERSION) -X AUTO_MAA_Go_Updater/version.BuildTime=$(BUILD_TIME) -X AUTO_MAA_Go_Updater/version.GitCommit=$(GIT_COMMIT)
|
||||
|
||||
# Default target
|
||||
.PHONY: all
|
||||
|
||||
@@ -10,204 +10,140 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// MirrorResponse represents the response from MirrorChyan API
|
||||
// MirrorResponse 表示 MirrorChyan API 的响应结构
|
||||
type MirrorResponse struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
Data struct {
|
||||
VersionName string `json:"version_name"`
|
||||
VersionNumber int `json:"version_number"`
|
||||
URL string `json:"url,omitempty"` // Only present when using CDK
|
||||
SHA256 string `json:"sha256,omitempty"` // Only present when using CDK
|
||||
URL string `json:"url,omitempty"`
|
||||
SHA256 string `json:"sha256,omitempty"`
|
||||
Channel string `json:"channel"`
|
||||
OS string `json:"os"`
|
||||
Arch string `json:"arch"`
|
||||
UpdateType string `json:"update_type,omitempty"` // Only present when using CDK
|
||||
UpdateType string `json:"update_type,omitempty"`
|
||||
ReleaseNote string `json:"release_note"`
|
||||
FileSize int64 `json:"filesize,omitempty"` // Only present when using CDK
|
||||
CDKExpiredTime int64 `json:"cdk_expired_time,omitempty"` // Only present when using CDK
|
||||
FileSize int64 `json:"filesize,omitempty"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
// UpdateCheckParams represents parameters for update checking
|
||||
// UpdateCheckParams 表示更新检查的参数
|
||||
type UpdateCheckParams struct {
|
||||
ResourceID string
|
||||
CurrentVersion string
|
||||
Channel string
|
||||
CDK string
|
||||
UserAgent string
|
||||
}
|
||||
|
||||
// MirrorClient interface defines the methods for Mirror API client
|
||||
// MirrorClient 定义 Mirror API 客户端的接口方法
|
||||
type MirrorClient interface {
|
||||
CheckUpdate(params UpdateCheckParams) (*MirrorResponse, error)
|
||||
CheckUpdateLegacy(resourceID, currentVersion, cdk, userAgent string) (*MirrorResponse, error)
|
||||
IsUpdateAvailable(response *MirrorResponse, currentVersion string) bool
|
||||
GetOfficialDownloadURL(versionName string) string
|
||||
GetDownloadURL(versionName string) string
|
||||
}
|
||||
|
||||
// Client implements MirrorClient interface
|
||||
// Client 实现 MirrorClient 接口
|
||||
type Client struct {
|
||||
httpClient *http.Client
|
||||
baseURL string
|
||||
downloadURL string
|
||||
}
|
||||
|
||||
// NewClient creates a new Mirror API client
|
||||
// NewClient 创建新的 Mirror API 客户端
|
||||
func NewClient() *Client {
|
||||
return &Client{
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
baseURL: "https://mirrorchyan.com/api/resources",
|
||||
downloadURL: "http://221.236.27.82:10197/d/AUTO_MAA",
|
||||
}
|
||||
}
|
||||
|
||||
// CheckUpdate calls MirrorChyan API to check for updates with new parameter structure
|
||||
// CheckUpdate 调用 MirrorChyan API 检查更新
|
||||
func (c *Client) CheckUpdate(params UpdateCheckParams) (*MirrorResponse, error) {
|
||||
// Construct the API URL
|
||||
// 构建 API URL
|
||||
apiURL := fmt.Sprintf("%s/%s/latest", c.baseURL, params.ResourceID)
|
||||
|
||||
// Parse URL to add query parameters
|
||||
// 解析 URL 并添加查询参数
|
||||
u, err := url.Parse(apiURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse API URL: %w", err)
|
||||
return nil, fmt.Errorf("解析 API URL 失败: %w", err)
|
||||
}
|
||||
|
||||
// Add query parameters
|
||||
// 添加查询参数
|
||||
q := u.Query()
|
||||
q.Set("current_version", params.CurrentVersion)
|
||||
q.Set("channel", params.Channel)
|
||||
q.Set("os", "") // Empty for cross-platform
|
||||
q.Set("arch", "") // Empty for cross-platform
|
||||
|
||||
if params.CDK != "" {
|
||||
q.Set("cdk", params.CDK)
|
||||
}
|
||||
q.Set("os", "") // 跨平台为空
|
||||
q.Set("arch", "") // 跨平台为空
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
// Create HTTP request
|
||||
// 创建 HTTP 请求
|
||||
req, err := http.NewRequest("GET", u.String(), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
|
||||
return nil, fmt.Errorf("创建 HTTP 请求失败: %w", err)
|
||||
}
|
||||
|
||||
// Set User-Agent header
|
||||
// 设置 User-Agent 头
|
||||
if params.UserAgent != "" {
|
||||
req.Header.Set("User-Agent", params.UserAgent)
|
||||
} else {
|
||||
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/138.0.0.0 Safari/537.36")
|
||||
}
|
||||
|
||||
// Make HTTP request
|
||||
// 发送 HTTP 请求
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to make HTTP request: %w", err)
|
||||
return nil, fmt.Errorf("发送 HTTP 请求失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Check HTTP status code
|
||||
// 检查 HTTP 状态码
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API returned non-200 status code: %d", resp.StatusCode)
|
||||
return nil, fmt.Errorf("API 返回非 200 状态码: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Read response body
|
||||
// 读取响应体
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
return nil, fmt.Errorf("读取响应体失败: %w", err)
|
||||
}
|
||||
|
||||
// Parse JSON response
|
||||
// 解析 JSON 响应
|
||||
var mirrorResp MirrorResponse
|
||||
if err := json.Unmarshal(body, &mirrorResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JSON response: %w", err)
|
||||
return nil, fmt.Errorf("解析 JSON 响应失败: %w", err)
|
||||
}
|
||||
|
||||
return &mirrorResp, nil
|
||||
}
|
||||
|
||||
// CheckUpdateLegacy calls Mirror API to check for updates (legacy method for backward compatibility)
|
||||
func (c *Client) CheckUpdateLegacy(resourceID, currentVersion, cdk, userAgent string) (*MirrorResponse, error) {
|
||||
// Construct the API URL
|
||||
apiURL := fmt.Sprintf("%s/%s/latest", c.baseURL, resourceID)
|
||||
|
||||
// Parse URL to add query parameters
|
||||
u, err := url.Parse(apiURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse API URL: %w", err)
|
||||
}
|
||||
|
||||
// Add query parameters
|
||||
q := u.Query()
|
||||
q.Set("current_version", currentVersion)
|
||||
if cdk != "" {
|
||||
q.Set("cdk", cdk)
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequest("GET", u.String(), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
|
||||
}
|
||||
|
||||
// Set User-Agent header
|
||||
if userAgent != "" {
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
} else {
|
||||
req.Header.Set("User-Agent", "LightweightUpdater/1.0")
|
||||
}
|
||||
|
||||
// Make HTTP request
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to make HTTP request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Check HTTP status code
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API returned non-200 status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Read response body
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
// Parse JSON response
|
||||
var mirrorResp MirrorResponse
|
||||
if err := json.Unmarshal(body, &mirrorResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JSON response: %w", err)
|
||||
}
|
||||
|
||||
return &mirrorResp, nil
|
||||
}
|
||||
|
||||
// IsUpdateAvailable compares current version with the latest version from API response
|
||||
// IsUpdateAvailable 比较当前版本与 API 响应中的最新版本
|
||||
func (c *Client) IsUpdateAvailable(response *MirrorResponse, currentVersion string) bool {
|
||||
// Check if API response is successful
|
||||
// 检查 API 响应是否成功
|
||||
if response.Code != 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Get latest version from response
|
||||
// 从响应中获取最新版本
|
||||
latestVersion := response.Data.VersionName
|
||||
if latestVersion == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Convert version formats for comparison
|
||||
// 转换版本格式以便比较
|
||||
currentVersionNormalized := c.normalizeVersionForComparison(currentVersion)
|
||||
latestVersionNormalized := c.normalizeVersionForComparison(latestVersion)
|
||||
|
||||
// Compare versions using semantic version comparison
|
||||
// 使用语义版本比较
|
||||
return compareVersions(currentVersionNormalized, latestVersionNormalized) < 0
|
||||
}
|
||||
|
||||
// normalizeVersionForComparison converts different version formats to comparable format
|
||||
// normalizeVersionForComparison 将不同版本格式转换为可比较格式
|
||||
func (c *Client) normalizeVersionForComparison(version string) string {
|
||||
// Handle AUTO_MAA version format: "4.4.1.3" -> "v4.4.1-beta3"
|
||||
// 处理 AUTO_MAA 版本格式: "4.4.1.3" -> "v4.4.1-beta3"
|
||||
if !strings.HasPrefix(version, "v") && strings.Count(version, ".") == 3 {
|
||||
parts := strings.Split(version, ".")
|
||||
if len(parts) == 4 {
|
||||
@@ -220,22 +156,22 @@ func (c *Client) normalizeVersionForComparison(version string) string {
|
||||
}
|
||||
}
|
||||
|
||||
// Return as-is if already in standard format
|
||||
// 如果已经是标准格式则直接返回
|
||||
return version
|
||||
}
|
||||
|
||||
// compareVersions compares two semantic version strings
|
||||
// Returns: -1 if v1 < v2, 0 if v1 == v2, 1 if v1 > v2
|
||||
// compareVersions 比较两个语义版本字符串
|
||||
// 返回值: -1 如果 v1 < v2, 0 如果 v1 == v2, 1 如果 v1 > v2
|
||||
func compareVersions(v1, v2 string) int {
|
||||
// Normalize versions by removing 'v' prefix if present
|
||||
// 通过移除 'v' 前缀来标准化版本
|
||||
v1 = normalizeVersion(v1)
|
||||
v2 = normalizeVersion(v2)
|
||||
|
||||
// Parse version components
|
||||
// 解析版本组件
|
||||
parts1 := parseVersionParts(v1)
|
||||
parts2 := parseVersionParts(v2)
|
||||
|
||||
// Compare each component
|
||||
// 比较每个组件
|
||||
maxLen := len(parts1)
|
||||
if len(parts2) > maxLen {
|
||||
maxLen = len(parts2)
|
||||
@@ -260,7 +196,7 @@ func compareVersions(v1, v2 string) int {
|
||||
return 0
|
||||
}
|
||||
|
||||
// normalizeVersion removes 'v' prefix and handles common version formats
|
||||
// normalizeVersion 移除 'v' 前缀并处理常见版本格式
|
||||
func normalizeVersion(version string) string {
|
||||
if len(version) > 0 && (version[0] == 'v' || version[0] == 'V') {
|
||||
return version[1:]
|
||||
@@ -268,7 +204,7 @@ func normalizeVersion(version string) string {
|
||||
return version
|
||||
}
|
||||
|
||||
// parseVersionParts parses version string into numeric components
|
||||
// parseVersionParts 将版本字符串解析为数字组件
|
||||
func parseVersionParts(version string) []int {
|
||||
if version == "" {
|
||||
return []int{0}
|
||||
@@ -284,15 +220,15 @@ func parseVersionParts(version string) []int {
|
||||
parts = append(parts, current)
|
||||
current = 0
|
||||
} else {
|
||||
// Stop parsing at non-numeric, non-dot characters (like pre-release identifiers)
|
||||
// 在非数字、非点字符处停止解析(如预发布标识符)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Add the last component
|
||||
// 添加最后一个组件
|
||||
parts = append(parts, current)
|
||||
|
||||
// Ensure at least 3 components (major.minor.patch)
|
||||
// 确保至少有 3 个组件 (major.minor.patch)
|
||||
for len(parts) < 3 {
|
||||
parts = append(parts, 0)
|
||||
}
|
||||
@@ -300,33 +236,17 @@ func parseVersionParts(version string) []int {
|
||||
return parts
|
||||
}
|
||||
|
||||
// GetOfficialDownloadURL generates the official download URL based on version name
|
||||
func (c *Client) GetOfficialDownloadURL(versionName string) string {
|
||||
// Official download site base URL
|
||||
baseURL := "http://221.236.27.82:10197/d/AUTO_MAA"
|
||||
|
||||
// Convert version name to filename format
|
||||
// e.g., "v4.4.0" -> "AUTO_MAA_v4.4.0.zip"
|
||||
// e.g., "v4.4.1-beta3" -> "AUTO_MAA_v4.4.1-beta.3.zip"
|
||||
// GetDownloadURL 根据版本名生成下载站的下载 URL
|
||||
func (c *Client) GetDownloadURL(versionName string) string {
|
||||
// 将版本名转换为文件名格式
|
||||
// 例如: "v4.4.0" -> "AUTO_MAA_v4.4.0.zip"
|
||||
// 例如: "v4.4.1-beta3" -> "AUTO_MAA_v4.4.1-beta.3.zip"
|
||||
filename := fmt.Sprintf("AUTO_MAA_%s.zip", versionName)
|
||||
|
||||
// Handle beta versions: convert "beta3" to "beta.3"
|
||||
// 处理 beta 版本: 将 "beta3" 转换为 "beta.3"
|
||||
if strings.Contains(filename, "-beta") && !strings.Contains(filename, "-beta.") {
|
||||
filename = strings.Replace(filename, "-beta", "-beta.", 1)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s/%s", baseURL, filename)
|
||||
}
|
||||
|
||||
// HasCDKDownloadURL checks if the response contains a CDK download URL
|
||||
func (c *Client) HasCDKDownloadURL(response *MirrorResponse) bool {
|
||||
return response != nil && response.Data.URL != ""
|
||||
}
|
||||
|
||||
// GetDownloadURL returns the appropriate download URL based on available options
|
||||
func (c *Client) GetDownloadURL(response *MirrorResponse) string {
|
||||
if c.HasCDKDownloadURL(response) {
|
||||
return response.Data.URL
|
||||
}
|
||||
return c.GetOfficialDownloadURL(response.Data.VersionName)
|
||||
return fmt.Sprintf("%s/%s", c.downloadURL, filename)
|
||||
}
|
||||
|
||||
@@ -10,17 +10,20 @@ import (
|
||||
func TestNewClient(t *testing.T) {
|
||||
client := NewClient()
|
||||
if client == nil {
|
||||
t.Fatal("NewClient() returned nil")
|
||||
t.Fatal("NewClient() 返回 nil")
|
||||
}
|
||||
if client.httpClient == nil {
|
||||
t.Fatal("HTTP client is nil")
|
||||
t.Fatal("HTTP 客户端为 nil")
|
||||
}
|
||||
if client.baseURL != "https://mirrorchyan.com/api/resources" {
|
||||
t.Errorf("Expected base URL 'https://mirrorchyan.com/api/resources', got '%s'", client.baseURL)
|
||||
t.Errorf("期望基础 URL 'https://mirrorchyan.com/api/resources',得到 '%s'", client.baseURL)
|
||||
}
|
||||
if client.downloadURL != "http://221.236.27.82:10197/d/AUTO_MAA" {
|
||||
t.Errorf("期望下载 URL 'http://221.236.27.82:10197/d/AUTO_MAA',得到 '%s'", client.downloadURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetOfficialDownloadURL(t *testing.T) {
|
||||
func TestGetDownloadURL(t *testing.T) {
|
||||
client := NewClient()
|
||||
|
||||
tests := []struct {
|
||||
@@ -30,51 +33,19 @@ func TestGetOfficialDownloadURL(t *testing.T) {
|
||||
{"v4.4.0", "http://221.236.27.82:10197/d/AUTO_MAA/AUTO_MAA_v4.4.0.zip"},
|
||||
{"v4.4.1-beta3", "http://221.236.27.82:10197/d/AUTO_MAA/AUTO_MAA_v4.4.1-beta.3.zip"},
|
||||
{"v1.2.3", "http://221.236.27.82:10197/d/AUTO_MAA/AUTO_MAA_v1.2.3.zip"},
|
||||
{"v1.2.3-beta1", "http://221.236.27.82:10197/d/AUTO_MAA/AUTO_MAA_v1.2.3-beta.1.zip"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
result := client.GetOfficialDownloadURL(test.versionName)
|
||||
result := client.GetDownloadURL(test.versionName)
|
||||
if result != test.expected {
|
||||
t.Errorf("For version %s, expected %s, got %s", test.versionName, test.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeVersionForComparison(t *testing.T) {
|
||||
client := NewClient()
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"4.4.0.0", "v4.4.0"},
|
||||
{"4.4.1.3", "v4.4.1-beta3"},
|
||||
{"v4.4.0", "v4.4.0"},
|
||||
{"v4.4.1-beta3", "v4.4.1-beta3"},
|
||||
{"1.2.3", "1.2.3"}, // Not 4-part version, return as-is
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
result := client.normalizeVersionForComparison(test.input)
|
||||
if result != test.expected {
|
||||
t.Errorf("For input %s, expected %s, got %s", test.input, test.expected, result)
|
||||
t.Errorf("版本 %s,期望 %s,得到 %s", test.versionName, test.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckUpdate(t *testing.T) {
|
||||
// Create test server
|
||||
// 创建测试服务器
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Verify request parameters
|
||||
if r.URL.Query().Get("current_version") != "4.4.0.0" {
|
||||
t.Errorf("Expected current_version=4.4.0.0, got %s", r.URL.Query().Get("current_version"))
|
||||
}
|
||||
if r.URL.Query().Get("channel") != "stable" {
|
||||
t.Errorf("Expected channel=stable, got %s", r.URL.Query().Get("channel"))
|
||||
}
|
||||
|
||||
// Return mock response
|
||||
response := MirrorResponse{
|
||||
Code: 0,
|
||||
Msg: "success",
|
||||
@@ -89,125 +60,47 @@ func TestCheckUpdate(t *testing.T) {
|
||||
UpdateType string `json:"update_type,omitempty"`
|
||||
ReleaseNote string `json:"release_note"`
|
||||
FileSize int64 `json:"filesize,omitempty"`
|
||||
CDKExpiredTime int64 `json:"cdk_expired_time,omitempty"`
|
||||
}{
|
||||
VersionName: "v4.4.1",
|
||||
VersionNumber: 48,
|
||||
Channel: "stable",
|
||||
OS: "",
|
||||
Arch: "",
|
||||
ReleaseNote: "Test release notes",
|
||||
ReleaseNote: "测试发布说明",
|
||||
},
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
err := json.NewEncoder(w).Encode(response)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Create client with test server URL
|
||||
// 使用测试服务器 URL 创建客户端
|
||||
client := &Client{
|
||||
httpClient: &http.Client{},
|
||||
baseURL: server.URL,
|
||||
downloadURL: "http://221.236.27.82:10197/d/AUTO_MAA",
|
||||
}
|
||||
|
||||
// Test update check
|
||||
// 测试更新检查
|
||||
params := UpdateCheckParams{
|
||||
ResourceID: "AUTO_MAA",
|
||||
CurrentVersion: "4.4.0.0",
|
||||
Channel: "stable",
|
||||
CDK: "",
|
||||
UserAgent: "TestAgent/1.0",
|
||||
}
|
||||
|
||||
response, err := client.CheckUpdate(params)
|
||||
if err != nil {
|
||||
t.Fatalf("CheckUpdate failed: %v", err)
|
||||
t.Fatalf("CheckUpdate 失败: %v", err)
|
||||
}
|
||||
|
||||
if response.Code != 0 {
|
||||
t.Errorf("Expected code 0, got %d", response.Code)
|
||||
t.Errorf("期望代码 0,得到 %d", response.Code)
|
||||
}
|
||||
if response.Data.VersionName != "v4.4.1" {
|
||||
t.Errorf("Expected version v4.4.1, got %s", response.Data.VersionName)
|
||||
}
|
||||
if response.Data.Channel != "stable" {
|
||||
t.Errorf("Expected channel stable, got %s", response.Data.Channel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckUpdateWithCDK(t *testing.T) {
|
||||
// Create test server
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Verify CDK parameter
|
||||
if r.URL.Query().Get("cdk") != "test_cdk_123" {
|
||||
t.Errorf("Expected cdk=test_cdk_123, got %s", r.URL.Query().Get("cdk"))
|
||||
}
|
||||
|
||||
// Return mock response with CDK download URL
|
||||
response := MirrorResponse{
|
||||
Code: 0,
|
||||
Msg: "success",
|
||||
Data: struct {
|
||||
VersionName string `json:"version_name"`
|
||||
VersionNumber int `json:"version_number"`
|
||||
URL string `json:"url,omitempty"`
|
||||
SHA256 string `json:"sha256,omitempty"`
|
||||
Channel string `json:"channel"`
|
||||
OS string `json:"os"`
|
||||
Arch string `json:"arch"`
|
||||
UpdateType string `json:"update_type,omitempty"`
|
||||
ReleaseNote string `json:"release_note"`
|
||||
FileSize int64 `json:"filesize,omitempty"`
|
||||
CDKExpiredTime int64 `json:"cdk_expired_time,omitempty"`
|
||||
}{
|
||||
VersionName: "v4.4.1",
|
||||
VersionNumber: 48,
|
||||
URL: "https://mirrorchyan.com/api/resources/download/test123",
|
||||
SHA256: "abcd1234",
|
||||
Channel: "stable",
|
||||
OS: "",
|
||||
Arch: "",
|
||||
UpdateType: "full",
|
||||
ReleaseNote: "Test release notes",
|
||||
FileSize: 12345678,
|
||||
CDKExpiredTime: 1776013593,
|
||||
},
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Create client with test server URL
|
||||
client := &Client{
|
||||
httpClient: &http.Client{},
|
||||
baseURL: server.URL,
|
||||
}
|
||||
|
||||
// Test update check with CDK
|
||||
params := UpdateCheckParams{
|
||||
ResourceID: "AUTO_MAA",
|
||||
CurrentVersion: "4.4.0.0",
|
||||
Channel: "stable",
|
||||
CDK: "test_cdk_123",
|
||||
UserAgent: "TestAgent/1.0",
|
||||
}
|
||||
|
||||
response, err := client.CheckUpdate(params)
|
||||
if err != nil {
|
||||
t.Fatalf("CheckUpdate with CDK failed: %v", err)
|
||||
}
|
||||
|
||||
if response.Data.URL == "" {
|
||||
t.Error("Expected CDK download URL, but got empty")
|
||||
}
|
||||
if response.Data.SHA256 == "" {
|
||||
t.Error("Expected SHA256 hash, but got empty")
|
||||
}
|
||||
if response.Data.FileSize == 0 {
|
||||
t.Error("Expected file size, but got 0")
|
||||
t.Errorf("期望版本 v4.4.1,得到 %s", response.Data.VersionName)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -221,7 +114,7 @@ func TestIsUpdateAvailable(t *testing.T) {
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Update available - stable",
|
||||
name: "有可用更新",
|
||||
response: &MirrorResponse{
|
||||
Code: 0,
|
||||
Data: struct {
|
||||
@@ -235,14 +128,13 @@ func TestIsUpdateAvailable(t *testing.T) {
|
||||
UpdateType string `json:"update_type,omitempty"`
|
||||
ReleaseNote string `json:"release_note"`
|
||||
FileSize int64 `json:"filesize,omitempty"`
|
||||
CDKExpiredTime int64 `json:"cdk_expired_time,omitempty"`
|
||||
}{VersionName: "v4.4.1"},
|
||||
},
|
||||
currentVersion: "4.4.0.0",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "No update available - same version",
|
||||
name: "无可用更新",
|
||||
response: &MirrorResponse{
|
||||
Code: 0,
|
||||
Data: struct {
|
||||
@@ -256,167 +148,18 @@ func TestIsUpdateAvailable(t *testing.T) {
|
||||
UpdateType string `json:"update_type,omitempty"`
|
||||
ReleaseNote string `json:"release_note"`
|
||||
FileSize int64 `json:"filesize,omitempty"`
|
||||
CDKExpiredTime int64 `json:"cdk_expired_time,omitempty"`
|
||||
}{VersionName: "v4.4.0"},
|
||||
},
|
||||
currentVersion: "4.4.0.0",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "API error",
|
||||
response: &MirrorResponse{
|
||||
Code: 1,
|
||||
Data: struct {
|
||||
VersionName string `json:"version_name"`
|
||||
VersionNumber int `json:"version_number"`
|
||||
URL string `json:"url,omitempty"`
|
||||
SHA256 string `json:"sha256,omitempty"`
|
||||
Channel string `json:"channel"`
|
||||
OS string `json:"os"`
|
||||
Arch string `json:"arch"`
|
||||
UpdateType string `json:"update_type,omitempty"`
|
||||
ReleaseNote string `json:"release_note"`
|
||||
FileSize int64 `json:"filesize,omitempty"`
|
||||
CDKExpiredTime int64 `json:"cdk_expired_time,omitempty"`
|
||||
}{VersionName: "v4.4.1"},
|
||||
},
|
||||
currentVersion: "4.4.0.0",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
result := client.IsUpdateAvailable(test.response, test.currentVersion)
|
||||
if result != test.expected {
|
||||
t.Errorf("Expected %t, got %t", test.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasCDKDownloadURL(t *testing.T) {
|
||||
client := NewClient()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
response *MirrorResponse
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Has CDK URL",
|
||||
response: &MirrorResponse{
|
||||
Data: struct {
|
||||
VersionName string `json:"version_name"`
|
||||
VersionNumber int `json:"version_number"`
|
||||
URL string `json:"url,omitempty"`
|
||||
SHA256 string `json:"sha256,omitempty"`
|
||||
Channel string `json:"channel"`
|
||||
OS string `json:"os"`
|
||||
Arch string `json:"arch"`
|
||||
UpdateType string `json:"update_type,omitempty"`
|
||||
ReleaseNote string `json:"release_note"`
|
||||
FileSize int64 `json:"filesize,omitempty"`
|
||||
CDKExpiredTime int64 `json:"cdk_expired_time,omitempty"`
|
||||
}{URL: "https://mirrorchyan.com/download/test"},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "No CDK URL",
|
||||
response: &MirrorResponse{
|
||||
Data: struct {
|
||||
VersionName string `json:"version_name"`
|
||||
VersionNumber int `json:"version_number"`
|
||||
URL string `json:"url,omitempty"`
|
||||
SHA256 string `json:"sha256,omitempty"`
|
||||
Channel string `json:"channel"`
|
||||
OS string `json:"os"`
|
||||
Arch string `json:"arch"`
|
||||
UpdateType string `json:"update_type,omitempty"`
|
||||
ReleaseNote string `json:"release_note"`
|
||||
FileSize int64 `json:"filesize,omitempty"`
|
||||
CDKExpiredTime int64 `json:"cdk_expired_time,omitempty"`
|
||||
}{URL: ""},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Nil response",
|
||||
response: nil,
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
result := client.HasCDKDownloadURL(test.response)
|
||||
if result != test.expected {
|
||||
t.Errorf("Expected %t, got %t", test.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDownloadURL(t *testing.T) {
|
||||
client := NewClient()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
response *MirrorResponse
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "CDK URL available",
|
||||
response: &MirrorResponse{
|
||||
Data: struct {
|
||||
VersionName string `json:"version_name"`
|
||||
VersionNumber int `json:"version_number"`
|
||||
URL string `json:"url,omitempty"`
|
||||
SHA256 string `json:"sha256,omitempty"`
|
||||
Channel string `json:"channel"`
|
||||
OS string `json:"os"`
|
||||
Arch string `json:"arch"`
|
||||
UpdateType string `json:"update_type,omitempty"`
|
||||
ReleaseNote string `json:"release_note"`
|
||||
FileSize int64 `json:"filesize,omitempty"`
|
||||
CDKExpiredTime int64 `json:"cdk_expired_time,omitempty"`
|
||||
}{
|
||||
VersionName: "v4.4.1",
|
||||
URL: "https://mirrorchyan.com/download/test",
|
||||
},
|
||||
},
|
||||
expected: "https://mirrorchyan.com/download/test",
|
||||
},
|
||||
{
|
||||
name: "Official URL fallback",
|
||||
response: &MirrorResponse{
|
||||
Data: struct {
|
||||
VersionName string `json:"version_name"`
|
||||
VersionNumber int `json:"version_number"`
|
||||
URL string `json:"url,omitempty"`
|
||||
SHA256 string `json:"sha256,omitempty"`
|
||||
Channel string `json:"channel"`
|
||||
OS string `json:"os"`
|
||||
Arch string `json:"arch"`
|
||||
UpdateType string `json:"update_type,omitempty"`
|
||||
ReleaseNote string `json:"release_note"`
|
||||
FileSize int64 `json:"filesize,omitempty"`
|
||||
CDKExpiredTime int64 `json:"cdk_expired_time,omitempty"`
|
||||
}{
|
||||
VersionName: "v4.4.1",
|
||||
URL: "",
|
||||
},
|
||||
},
|
||||
expected: "http://221.236.27.82:10197/d/AUTO_MAA/AUTO_MAA_v4.4.1.zip",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
result := client.GetDownloadURL(test.response)
|
||||
if result != test.expected {
|
||||
t.Errorf("Expected %s, got %s", test.expected, result)
|
||||
t.Errorf("期望 %t,得到 %t", test.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -8,17 +8,17 @@ import (
|
||||
//go:embed config_template.yaml
|
||||
var EmbeddedAssets embed.FS
|
||||
|
||||
// GetConfigTemplate returns the embedded config template
|
||||
// GetConfigTemplate 返回嵌入的配置模板
|
||||
func GetConfigTemplate() ([]byte, error) {
|
||||
return EmbeddedAssets.ReadFile("config_template.yaml")
|
||||
}
|
||||
|
||||
// GetAssetFS returns the embedded filesystem
|
||||
// GetAssetFS 返回嵌入的文件系统
|
||||
func GetAssetFS() fs.FS {
|
||||
return EmbeddedAssets
|
||||
}
|
||||
|
||||
// ListAssets returns a list of all embedded assets
|
||||
// ListAssets 返回所有嵌入资源的列表
|
||||
func ListAssets() ([]string, error) {
|
||||
var assets []string
|
||||
err := fs.WalkDir(EmbeddedAssets, ".", func(path string, d fs.DirEntry, err error) error {
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
resource_id: "AUTO_MAA"
|
||||
current_version: "v1.0.0"
|
||||
cdk: "" # Will be encrypted when saved
|
||||
user_agent: "AUTO_MAA_Go_Updater/1.0"
|
||||
backup_url: "https://backup-download-site.com/releases"
|
||||
log_level: "info"
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
# Build Configuration for Lightweight Updater
|
||||
# Build Configuration for AUTO_MAA_Go_Updater
|
||||
|
||||
project:
|
||||
name: "Lightweight Updater"
|
||||
module: "lightweight-updater"
|
||||
description: "轻量级自动更新器"
|
||||
name: "AUTO_MAA_Go_Updater"
|
||||
module: "AUTO_MAA_Go_Updater"
|
||||
description: "AUTO_MAA_Go版本更新器"
|
||||
|
||||
version:
|
||||
default: "1.0.0"
|
||||
@@ -14,7 +14,7 @@ targets:
|
||||
goos: "windows"
|
||||
goarch: "amd64"
|
||||
cgo_enabled: true
|
||||
output: "lightweight-updater.exe"
|
||||
output: "AUTO_MAA_Go_Updater.exe"
|
||||
|
||||
build:
|
||||
flags:
|
||||
@@ -40,7 +40,7 @@ directories:
|
||||
temp: "temp"
|
||||
|
||||
version_injection:
|
||||
package: "lightweight-updater/version"
|
||||
package: "AUTO_MAA_Go_Updater/version"
|
||||
variables:
|
||||
- name: "Version"
|
||||
source: "version"
|
||||
|
||||
@@ -6,14 +6,13 @@ echo AUTO_MAA_Go_Updater Build Script
|
||||
echo ========================================
|
||||
|
||||
:: Set build variables
|
||||
set VERSION=1.0.0
|
||||
set OUTPUT_NAME=AUTO_MAA_Go_Updater.exe
|
||||
set BUILD_DIR=build
|
||||
set DIST_DIR=dist
|
||||
|
||||
:: Get current timestamp
|
||||
:: Get current datetime for build time
|
||||
for /f "tokens=2 delims==" %%a in ('wmic OS Get localdatetime /value') do set "dt=%%a"
|
||||
set "YY=%dt:~2,2%" & set "YYYY=%dt:~0,4%" & set "MM=%dt:~4,2%" & set "DD=%dt:~6,2%"
|
||||
set "YYYY=%dt:~0,4%" & set "MM=%dt:~4,2%" & set "DD=%dt:~6,2%"
|
||||
set "HH=%dt:~8,2%" & set "Min=%dt:~10,2%" & set "Sec=%dt:~12,2%"
|
||||
set "BUILD_TIME=%YYYY%-%MM%-%DD%T%HH%:%Min%:%Sec%Z"
|
||||
|
||||
@@ -26,6 +25,9 @@ if exist temp_commit.txt (
|
||||
set GIT_COMMIT=unknown
|
||||
)
|
||||
|
||||
:: Use commit hash as version
|
||||
set VERSION=%GIT_COMMIT%
|
||||
|
||||
echo Build Information:
|
||||
echo - Version: %VERSION%
|
||||
echo - Build Time: %BUILD_TIME%
|
||||
@@ -38,7 +40,7 @@ if not exist %BUILD_DIR% mkdir %BUILD_DIR%
|
||||
if not exist %DIST_DIR% mkdir %DIST_DIR%
|
||||
|
||||
:: Set build flags
|
||||
set LDFLAGS=-s -w -X lightweight-updater/version.Version=%VERSION% -X lightweight-updater/version.BuildTime=%BUILD_TIME% -X lightweight-updater/version.GitCommit=%GIT_COMMIT%
|
||||
set LDFLAGS=-s -w -X AUTO_MAA_Go_Updater/version.Version=%VERSION% -X AUTO_MAA_Go_Updater/version.BuildTime=%BUILD_TIME% -X AUTO_MAA_Go_Updater/version.GitCommit=%GIT_COMMIT%
|
||||
|
||||
echo Building application...
|
||||
|
||||
@@ -58,6 +60,7 @@ if not exist app.syso (
|
||||
)
|
||||
)
|
||||
|
||||
:: Set environment variables for Go build
|
||||
set GOOS=windows
|
||||
set GOARCH=amd64
|
||||
set CGO_ENABLED=1
|
||||
@@ -74,8 +77,6 @@ echo Build completed successfully!
|
||||
|
||||
:: Get file size
|
||||
for %%A in (%BUILD_DIR%\%OUTPUT_NAME%) do set FILE_SIZE=%%~zA
|
||||
|
||||
:: Convert bytes to MB
|
||||
set /a FILE_SIZE_MB=%FILE_SIZE%/1024/1024
|
||||
|
||||
echo.
|
||||
@@ -83,13 +84,6 @@ echo Build Results:
|
||||
echo - Output: %BUILD_DIR%\%OUTPUT_NAME%
|
||||
echo - Size: %FILE_SIZE% bytes (~%FILE_SIZE_MB% MB)
|
||||
|
||||
:: Check if file size is within requirements (<10MB)
|
||||
if %FILE_SIZE_MB% gtr 10 (
|
||||
echo WARNING: File size exceeds 10MB requirement!
|
||||
) else (
|
||||
echo File size meets requirements (^<10MB)
|
||||
)
|
||||
|
||||
:: Copy to dist directory
|
||||
copy %BUILD_DIR%\%OUTPUT_NAME% %DIST_DIR%\%OUTPUT_NAME% >nul
|
||||
echo - Copied to: %DIST_DIR%\%OUTPUT_NAME%
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# Lightweight Updater Build Script (PowerShell)
|
||||
# AUTO_MAA_Go_Updater Build Script (PowerShell)
|
||||
param(
|
||||
[string]$Version = "1.0.0",
|
||||
[string]$OutputName = "AUTO_MAA_Go_Updater.exe",
|
||||
[switch]$Compress = $false
|
||||
)
|
||||
@@ -14,6 +13,7 @@ $BuildDir = "build"
|
||||
$DistDir = "dist"
|
||||
$BuildTime = (Get-Date).ToString("yyyy-MM-ddTHH:mm:ssZ")
|
||||
|
||||
|
||||
# Get git commit hash
|
||||
try {
|
||||
$GitCommit = (git rev-parse --short HEAD 2>$null).Trim()
|
||||
@@ -23,7 +23,7 @@ try {
|
||||
}
|
||||
|
||||
Write-Host "Build Information:" -ForegroundColor Yellow
|
||||
Write-Host "- Version: $Version"
|
||||
Write-Host "- Version: $GitCommit"
|
||||
Write-Host "- Build Time: $BuildTime"
|
||||
Write-Host "- Git Commit: $GitCommit"
|
||||
Write-Host "- Target: Windows 64-bit"
|
||||
@@ -39,7 +39,7 @@ $env:GOARCH = "amd64"
|
||||
$env:CGO_ENABLED = "1"
|
||||
|
||||
# Set build flags
|
||||
$LdFlags = "-s -w -X lightweight-updater/version.Version=$Version -X lightweight-updater/version.BuildTime=$BuildTime -X lightweight-updater/version.GitCommit=$GitCommit"
|
||||
$LdFlags = "-s -w -X AUTO_MAA_Go_Updater/version.Version=$Version -X AUTO_MAA_Go_Updater/version.BuildTime=$BuildTime -X AUTO_MAA_Go_Updater/version.GitCommit=$GitCommit"
|
||||
|
||||
Write-Host "Building application..." -ForegroundColor Green
|
||||
|
||||
@@ -78,12 +78,6 @@ Write-Host "Build Results:" -ForegroundColor Yellow
|
||||
Write-Host "- Output: $($OutputFile.FullName)"
|
||||
Write-Host "- Size: $($OutputFile.Length) bytes (~$FileSizeMB MB)"
|
||||
|
||||
# Check file size requirement
|
||||
if ($FileSizeMB -gt 10) {
|
||||
Write-Host "WARNING: File size exceeds 10MB requirement!" -ForegroundColor Red
|
||||
} else {
|
||||
Write-Host "File size meets requirements (<10MB)" -ForegroundColor Green
|
||||
}
|
||||
|
||||
# Optional UPX compression
|
||||
if ($Compress) {
|
||||
|
||||
@@ -1,40 +1,38 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"AUTO_MAA_Go_Updater/assets"
|
||||
"gopkg.in/yaml.v3"
|
||||
"lightweight-updater/assets"
|
||||
)
|
||||
|
||||
// Config represents the application configuration
|
||||
// Config 表示应用程序配置
|
||||
type Config struct {
|
||||
ResourceID string `yaml:"resource_id"`
|
||||
CurrentVersion string `yaml:"current_version"`
|
||||
CDK string `yaml:"cdk,omitempty"`
|
||||
UserAgent string `yaml:"user_agent"`
|
||||
BackupURL string `yaml:"backup_url"`
|
||||
LogLevel string `yaml:"log_level"`
|
||||
AutoCheck bool `yaml:"auto_check"`
|
||||
CheckInterval int `yaml:"check_interval"` // seconds
|
||||
CheckInterval int `yaml:"check_interval"` // 秒
|
||||
}
|
||||
|
||||
// ConfigManager interface defines methods for configuration management
|
||||
// ConfigManager 定义配置管理的接口方法
|
||||
type ConfigManager interface {
|
||||
Load() (*Config, error)
|
||||
Save(config *Config) error
|
||||
GetConfigPath() string
|
||||
}
|
||||
|
||||
// DefaultConfigManager implements ConfigManager interface
|
||||
// DefaultConfigManager 实现 ConfigManager 接口
|
||||
type DefaultConfigManager struct {
|
||||
configPath string
|
||||
}
|
||||
|
||||
// NewConfigManager creates a new configuration manager
|
||||
// NewConfigManager 创建新的配置管理器
|
||||
func NewConfigManager() ConfigManager {
|
||||
configDir := getConfigDir()
|
||||
configPath := filepath.Join(configDir, "config.yaml")
|
||||
@@ -43,77 +41,77 @@ func NewConfigManager() ConfigManager {
|
||||
}
|
||||
}
|
||||
|
||||
// GetConfigPath returns the path to the configuration file
|
||||
// GetConfigPath 返回配置文件的路径
|
||||
func (cm *DefaultConfigManager) GetConfigPath() string {
|
||||
return cm.configPath
|
||||
}
|
||||
|
||||
// Load reads and parses the configuration file
|
||||
// Load 读取并解析配置文件
|
||||
func (cm *DefaultConfigManager) Load() (*Config, error) {
|
||||
// Create config directory if it doesn't exist
|
||||
// 如果配置目录不存在则创建
|
||||
configDir := filepath.Dir(cm.configPath)
|
||||
if err := os.MkdirAll(configDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create config directory: %w", err)
|
||||
return nil, fmt.Errorf("创建配置目录失败: %w", err)
|
||||
}
|
||||
|
||||
// If config file doesn't exist, create default config
|
||||
// 如果配置文件不存在,创建默认配置
|
||||
if _, err := os.Stat(cm.configPath); os.IsNotExist(err) {
|
||||
defaultConfig := getDefaultConfig()
|
||||
if err := cm.Save(defaultConfig); err != nil {
|
||||
return nil, fmt.Errorf("failed to create default config: %w", err)
|
||||
return nil, fmt.Errorf("创建默认配置失败: %w", err)
|
||||
}
|
||||
return defaultConfig, nil
|
||||
}
|
||||
|
||||
// Read existing config file
|
||||
// 读取现有配置文件
|
||||
data, err := os.ReadFile(cm.configPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read config file: %w", err)
|
||||
return nil, fmt.Errorf("读取配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
var config Config
|
||||
if err := yaml.Unmarshal(data, &config); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse config file: %w", err)
|
||||
return nil, fmt.Errorf("解析配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
// Validate and apply defaults for missing fields
|
||||
// 验证并应用缺失字段的默认值
|
||||
if err := validateAndApplyDefaults(&config); err != nil {
|
||||
return nil, fmt.Errorf("config validation failed: %w", err)
|
||||
return nil, fmt.Errorf("配置验证失败: %w", err)
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// Save writes the configuration to file
|
||||
// Save 将配置写入文件
|
||||
func (cm *DefaultConfigManager) Save(config *Config) error {
|
||||
// Validate config before saving
|
||||
// 保存前验证配置
|
||||
if err := validateConfig(config); err != nil {
|
||||
return fmt.Errorf("config validation failed: %w", err)
|
||||
return fmt.Errorf("配置验证失败: %w", err)
|
||||
}
|
||||
|
||||
// Create config directory if it doesn't exist
|
||||
// 如果配置目录不存在则创建
|
||||
configDir := filepath.Dir(cm.configPath)
|
||||
if err := os.MkdirAll(configDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create config directory: %w", err)
|
||||
return fmt.Errorf("创建配置目录失败: %w", err)
|
||||
}
|
||||
|
||||
// Marshal config to YAML
|
||||
// 将配置序列化为 YAML
|
||||
data, err := yaml.Marshal(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal config: %w", err)
|
||||
return fmt.Errorf("序列化配置失败: %w", err)
|
||||
}
|
||||
|
||||
// Write to file
|
||||
// 写入文件
|
||||
if err := os.WriteFile(cm.configPath, data, 0644); err != nil {
|
||||
return fmt.Errorf("failed to write config file: %w", err)
|
||||
return fmt.Errorf("写入配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getDefaultConfig returns a configuration with default values
|
||||
// getDefaultConfig 返回带有默认值的配置
|
||||
func getDefaultConfig() *Config {
|
||||
// Try to load from embedded template first
|
||||
// 首先尝试从嵌入模板加载
|
||||
if templateData, err := assets.GetConfigTemplate(); err == nil {
|
||||
var config Config
|
||||
if err := yaml.Unmarshal(templateData, &config); err == nil {
|
||||
@@ -121,35 +119,34 @@ func getDefaultConfig() *Config {
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to hardcoded defaults if template loading fails
|
||||
// 如果模板加载失败则回退到硬编码默认值
|
||||
return &Config{
|
||||
ResourceID: "M9A", // Default resource ID
|
||||
ResourceID: "M9A", // 默认资源 ID
|
||||
CurrentVersion: "v1.0.0",
|
||||
CDK: "",
|
||||
UserAgent: "LightweightUpdater/1.0",
|
||||
UserAgent: "AUTO_MAA_Go_Updater/1.0",
|
||||
BackupURL: "",
|
||||
LogLevel: "info",
|
||||
AutoCheck: true,
|
||||
CheckInterval: 3600, // 1 hour
|
||||
CheckInterval: 3600, // 1 小时
|
||||
}
|
||||
}
|
||||
|
||||
// validateConfig validates the configuration values
|
||||
// validateConfig 验证配置值
|
||||
func validateConfig(config *Config) error {
|
||||
if config == nil {
|
||||
return fmt.Errorf("config cannot be nil")
|
||||
return fmt.Errorf("配置不能为空")
|
||||
}
|
||||
|
||||
if config.ResourceID == "" {
|
||||
return fmt.Errorf("resource_id cannot be empty")
|
||||
return fmt.Errorf("resource_id 不能为空")
|
||||
}
|
||||
|
||||
if config.CurrentVersion == "" {
|
||||
return fmt.Errorf("current_version cannot be empty")
|
||||
return fmt.Errorf("current_version 不能为空")
|
||||
}
|
||||
|
||||
if config.UserAgent == "" {
|
||||
return fmt.Errorf("user_agent cannot be empty")
|
||||
return fmt.Errorf("user_agent 不能为空")
|
||||
}
|
||||
|
||||
validLogLevels := map[string]bool{
|
||||
@@ -159,21 +156,21 @@ func validateConfig(config *Config) error {
|
||||
"error": true,
|
||||
}
|
||||
if !validLogLevels[config.LogLevel] {
|
||||
return fmt.Errorf("invalid log_level: %s (must be debug, info, warn, or error)", config.LogLevel)
|
||||
return fmt.Errorf("无效的 log_level: %s (必须是 debug, info, warn 或 error)", config.LogLevel)
|
||||
}
|
||||
|
||||
if config.CheckInterval < 60 {
|
||||
return fmt.Errorf("check_interval must be at least 60 seconds")
|
||||
return fmt.Errorf("check_interval 必须至少为 60 秒")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateAndApplyDefaults validates config and applies defaults for missing fields
|
||||
// validateAndApplyDefaults 验证配置并为缺失字段应用默认值
|
||||
func validateAndApplyDefaults(config *Config) error {
|
||||
defaults := getDefaultConfig()
|
||||
|
||||
// Apply defaults for empty fields
|
||||
// 为空字段应用默认值
|
||||
if config.UserAgent == "" {
|
||||
config.UserAgent = defaults.UserAgent
|
||||
}
|
||||
@@ -187,62 +184,15 @@ func validateAndApplyDefaults(config *Config) error {
|
||||
config.CurrentVersion = defaults.CurrentVersion
|
||||
}
|
||||
|
||||
// Validate after applying defaults
|
||||
// 应用默认值后进行验证
|
||||
return validateConfig(config)
|
||||
}
|
||||
|
||||
// getConfigDir returns the configuration directory path
|
||||
// getConfigDir 返回配置目录路径
|
||||
func getConfigDir() string {
|
||||
// Use APPDATA on Windows, fallback to current directory
|
||||
// 在 Windows 上使用 APPDATA,回退到当前目录
|
||||
if appData := os.Getenv("APPDATA"); appData != "" {
|
||||
return filepath.Join(appData, "LightweightUpdater")
|
||||
return filepath.Join(appData, "AUTO_MAA_Go_Updater")
|
||||
}
|
||||
return "."
|
||||
}
|
||||
|
||||
// encryptCDK encrypts the CDK using XOR encryption with a static key
|
||||
func encryptCDK(cdk string) string {
|
||||
if cdk == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
key := []byte("updater-key-2024")
|
||||
encrypted := make([]byte, len(cdk))
|
||||
|
||||
for i, b := range []byte(cdk) {
|
||||
encrypted[i] = b ^ key[i%len(key)]
|
||||
}
|
||||
|
||||
return base64.StdEncoding.EncodeToString(encrypted)
|
||||
}
|
||||
|
||||
// decryptCDK decrypts the CDK using XOR decryption with a static key
|
||||
func decryptCDK(encryptedCDK string) (string, error) {
|
||||
if encryptedCDK == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
encrypted, err := base64.StdEncoding.DecodeString(encryptedCDK)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to decode encrypted CDK: %w", err)
|
||||
}
|
||||
|
||||
key := []byte("updater-key-2024")
|
||||
decrypted := make([]byte, len(encrypted))
|
||||
|
||||
for i, b := range encrypted {
|
||||
decrypted[i] = b ^ key[i%len(key)]
|
||||
}
|
||||
|
||||
return string(decrypted), nil
|
||||
}
|
||||
|
||||
// SetCDK sets the CDK in the config with encryption
|
||||
func (c *Config) SetCDK(cdk string) {
|
||||
c.CDK = encryptCDK(cdk)
|
||||
}
|
||||
|
||||
// GetCDK returns the decrypted CDK from the config
|
||||
func (c *Config) GetCDK() (string, error) {
|
||||
return decryptCDK(c.CDK)
|
||||
}
|
||||
|
||||
@@ -44,7 +44,6 @@
|
||||
},
|
||||
"Update": {
|
||||
"IfAutoUpdate": false,
|
||||
"MirrorChyanCDK": "",
|
||||
"ProxyUrlList": [],
|
||||
"ThreadNumb": 8,
|
||||
"UpdateType": "stable"
|
||||
|
||||
@@ -3,163 +3,55 @@ package config
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEncryptDecryptCDK(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
original string
|
||||
}{
|
||||
{
|
||||
name: "Empty CDK",
|
||||
original: "",
|
||||
},
|
||||
{
|
||||
name: "Simple CDK",
|
||||
original: "test123",
|
||||
},
|
||||
{
|
||||
name: "Complex CDK",
|
||||
original: "ABC123-DEF456-GHI789",
|
||||
},
|
||||
{
|
||||
name: "CDK with special characters",
|
||||
original: "test@#$%^&*()_+-={}[]|\\:;\"'<>?,./",
|
||||
},
|
||||
{
|
||||
name: "Long CDK",
|
||||
original: "this-is-a-very-long-cdk-key-that-should-still-work-properly-with-encryption-and-decryption",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test encryption
|
||||
encrypted := encryptCDK(tt.original)
|
||||
|
||||
// Empty string should remain empty
|
||||
if tt.original == "" {
|
||||
if encrypted != "" {
|
||||
t.Errorf("Expected empty string for empty input, got %s", encrypted)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Encrypted should be different from original (unless original is empty)
|
||||
if encrypted == tt.original {
|
||||
t.Errorf("Encrypted CDK should be different from original")
|
||||
}
|
||||
|
||||
// Test decryption
|
||||
decrypted, err := decryptCDK(encrypted)
|
||||
if err != nil {
|
||||
t.Errorf("Decryption failed: %v", err)
|
||||
}
|
||||
|
||||
// Decrypted should match original
|
||||
if decrypted != tt.original {
|
||||
t.Errorf("Expected %s, got %s", tt.original, decrypted)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSetGetCDK(t *testing.T) {
|
||||
config := &Config{}
|
||||
|
||||
testCDK := "test-cdk-123"
|
||||
|
||||
// Set CDK (should encrypt)
|
||||
config.SetCDK(testCDK)
|
||||
|
||||
// CDK field should be encrypted (different from original)
|
||||
if config.CDK == testCDK {
|
||||
t.Errorf("CDK should be encrypted in config")
|
||||
}
|
||||
|
||||
// Get CDK (should decrypt)
|
||||
retrievedCDK, err := config.GetCDK()
|
||||
if err != nil {
|
||||
t.Errorf("Failed to get CDK: %v", err)
|
||||
}
|
||||
|
||||
if retrievedCDK != testCDK {
|
||||
t.Errorf("Expected %s, got %s", testCDK, retrievedCDK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptInvalidCDK(t *testing.T) {
|
||||
// Test with invalid base64
|
||||
_, err := decryptCDK("invalid-base64!")
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for invalid base64")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigManagerLoadSave(t *testing.T) {
|
||||
// Create temporary directory for test
|
||||
// 为测试创建临时目录
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create config manager with temp path
|
||||
// 使用临时路径创建配置管理器
|
||||
cm := &DefaultConfigManager{
|
||||
configPath: filepath.Join(tempDir, "test-config.yaml"),
|
||||
}
|
||||
|
||||
// Test loading non-existent config (should create default)
|
||||
// 测试加载不存在的配置(应创建默认配置)
|
||||
config, err := cm.Load()
|
||||
if err != nil {
|
||||
t.Errorf("Failed to load config: %v", err)
|
||||
t.Errorf("加载配置失败: %v", err)
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
t.Errorf("Config should not be nil")
|
||||
t.Errorf("配置不应为 nil")
|
||||
}
|
||||
|
||||
// Verify default values
|
||||
// 验证默认值
|
||||
if config.CurrentVersion != "v1.0.0" {
|
||||
t.Errorf("Expected default version v1.0.0, got %s", config.CurrentVersion)
|
||||
t.Errorf("期望默认版本 v1.0.0,得到 %s", config.CurrentVersion)
|
||||
}
|
||||
|
||||
if config.UserAgent != "LightweightUpdater/1.0" {
|
||||
t.Errorf("Expected default user agent, got %s", config.UserAgent)
|
||||
if config.UserAgent != "AUTO_MAA_Go_Updater/1.0" {
|
||||
t.Errorf("期望默认用户代理,得到 %s", config.UserAgent)
|
||||
}
|
||||
|
||||
// Set some values including CDK
|
||||
// 设置一些值
|
||||
config.ResourceID = "TEST123"
|
||||
config.SetCDK("secret-cdk-key")
|
||||
|
||||
// Save config
|
||||
// 保存配置
|
||||
err = cm.Save(config)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to save config: %v", err)
|
||||
t.Errorf("保存配置失败: %v", err)
|
||||
}
|
||||
|
||||
// Load config again
|
||||
// 再次加载配置
|
||||
loadedConfig, err := cm.Load()
|
||||
if err != nil {
|
||||
t.Errorf("Failed to load saved config: %v", err)
|
||||
t.Errorf("加载已保存配置失败: %v", err)
|
||||
}
|
||||
|
||||
// Verify values
|
||||
// 验证值
|
||||
if loadedConfig.ResourceID != "TEST123" {
|
||||
t.Errorf("Expected ResourceID TEST123, got %s", loadedConfig.ResourceID)
|
||||
}
|
||||
|
||||
// Verify CDK is properly encrypted/decrypted
|
||||
retrievedCDK, err := loadedConfig.GetCDK()
|
||||
if err != nil {
|
||||
t.Errorf("Failed to get CDK from loaded config: %v", err)
|
||||
}
|
||||
|
||||
if retrievedCDK != "secret-cdk-key" {
|
||||
t.Errorf("Expected CDK secret-cdk-key, got %s", retrievedCDK)
|
||||
}
|
||||
|
||||
// Verify CDK is encrypted in the config struct
|
||||
if loadedConfig.CDK == "secret-cdk-key" {
|
||||
t.Errorf("CDK should be encrypted in config file")
|
||||
t.Errorf("期望 ResourceID TEST123,得到 %s", loadedConfig.ResourceID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -170,12 +62,12 @@ func TestConfigValidation(t *testing.T) {
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Nil config",
|
||||
name: "空配置",
|
||||
config: nil,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Empty ResourceID",
|
||||
name: "空 ResourceID",
|
||||
config: &Config{
|
||||
ResourceID: "",
|
||||
CurrentVersion: "v1.0.0",
|
||||
@@ -186,40 +78,7 @@ func TestConfigValidation(t *testing.T) {
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Empty CurrentVersion",
|
||||
config: &Config{
|
||||
ResourceID: "TEST",
|
||||
CurrentVersion: "",
|
||||
UserAgent: "Test/1.0",
|
||||
LogLevel: "info",
|
||||
CheckInterval: 3600,
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid LogLevel",
|
||||
config: &Config{
|
||||
ResourceID: "TEST",
|
||||
CurrentVersion: "v1.0.0",
|
||||
UserAgent: "Test/1.0",
|
||||
LogLevel: "invalid",
|
||||
CheckInterval: 3600,
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid CheckInterval",
|
||||
config: &Config{
|
||||
ResourceID: "TEST",
|
||||
CurrentVersion: "v1.0.0",
|
||||
UserAgent: "Test/1.0",
|
||||
LogLevel: "info",
|
||||
CheckInterval: 30, // Less than 60
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Valid config",
|
||||
name: "有效配置",
|
||||
config: &Config{
|
||||
ResourceID: "TEST",
|
||||
CurrentVersion: "v1.0.0",
|
||||
@@ -235,112 +94,10 @@ func TestConfigValidation(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateConfig(tt.config)
|
||||
if tt.expectError && err == nil {
|
||||
t.Errorf("Expected error but got none")
|
||||
t.Errorf("期望错误但没有得到")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error but got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetConfigDir(t *testing.T) {
|
||||
// Save original APPDATA
|
||||
originalAppData := os.Getenv("APPDATA")
|
||||
defer os.Setenv("APPDATA", originalAppData)
|
||||
|
||||
// Test with APPDATA set
|
||||
os.Setenv("APPDATA", "C:\\Users\\Test\\AppData\\Roaming")
|
||||
dir := getConfigDir()
|
||||
expected := "C:\\Users\\Test\\AppData\\Roaming\\LightweightUpdater"
|
||||
if dir != expected {
|
||||
t.Errorf("Expected %s, got %s", expected, dir)
|
||||
}
|
||||
|
||||
// Test without APPDATA
|
||||
os.Unsetenv("APPDATA")
|
||||
dir = getConfigDir()
|
||||
if dir != "." {
|
||||
t.Errorf("Expected current directory, got %s", dir)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAndApplyDefaults(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input *Config
|
||||
expected *Config
|
||||
hasError bool
|
||||
}{
|
||||
{
|
||||
name: "Apply defaults to empty config",
|
||||
input: &Config{
|
||||
ResourceID: "TEST",
|
||||
},
|
||||
expected: &Config{
|
||||
ResourceID: "TEST",
|
||||
CurrentVersion: "v1.0.0",
|
||||
UserAgent: "LightweightUpdater/1.0",
|
||||
LogLevel: "info",
|
||||
CheckInterval: 3600,
|
||||
},
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "Partial config with some defaults needed",
|
||||
input: &Config{
|
||||
ResourceID: "TEST",
|
||||
CurrentVersion: "v2.0.0",
|
||||
LogLevel: "debug",
|
||||
},
|
||||
expected: &Config{
|
||||
ResourceID: "TEST",
|
||||
CurrentVersion: "v2.0.0",
|
||||
UserAgent: "LightweightUpdater/1.0",
|
||||
LogLevel: "debug",
|
||||
CheckInterval: 3600,
|
||||
},
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "Config with invalid values after defaults",
|
||||
input: &Config{
|
||||
ResourceID: "", // Invalid - empty
|
||||
CheckInterval: 30, // Invalid - too small
|
||||
},
|
||||
expected: nil,
|
||||
hasError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateAndApplyDefaults(tt.input)
|
||||
|
||||
if tt.hasError {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Check that defaults were applied correctly
|
||||
if tt.input.CurrentVersion != tt.expected.CurrentVersion {
|
||||
t.Errorf("CurrentVersion: expected %s, got %s", tt.expected.CurrentVersion, tt.input.CurrentVersion)
|
||||
}
|
||||
if tt.input.UserAgent != tt.expected.UserAgent {
|
||||
t.Errorf("UserAgent: expected %s, got %s", tt.expected.UserAgent, tt.input.UserAgent)
|
||||
}
|
||||
if tt.input.LogLevel != tt.expected.LogLevel {
|
||||
t.Errorf("LogLevel: expected %s, got %s", tt.expected.LogLevel, tt.input.LogLevel)
|
||||
}
|
||||
if tt.input.CheckInterval != tt.expected.CheckInterval {
|
||||
t.Errorf("CheckInterval: expected %d, got %d", tt.expected.CheckInterval, tt.input.CheckInterval)
|
||||
t.Errorf("期望无错误但得到: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -350,123 +107,47 @@ func TestGetDefaultConfig(t *testing.T) {
|
||||
config := getDefaultConfig()
|
||||
|
||||
if config == nil {
|
||||
t.Fatal("getDefaultConfig() returned nil")
|
||||
t.Fatal("getDefaultConfig() 返回 nil")
|
||||
}
|
||||
|
||||
// Verify default values
|
||||
if config.ResourceID != "PLACEHOLDER" {
|
||||
t.Errorf("Expected ResourceID 'PLACEHOLDER', got %s", config.ResourceID)
|
||||
// 验证默认值
|
||||
if config.ResourceID != "AUTO_MAA" {
|
||||
t.Errorf("期望 ResourceID 'AUTO_MAA',得到 %s", config.ResourceID)
|
||||
}
|
||||
if config.CurrentVersion != "v1.0.0" {
|
||||
t.Errorf("Expected CurrentVersion 'v1.0.0', got %s", config.CurrentVersion)
|
||||
t.Errorf("期望 CurrentVersion 'v1.0.0',得到 %s", config.CurrentVersion)
|
||||
}
|
||||
if config.UserAgent != "LightweightUpdater/1.0" {
|
||||
t.Errorf("Expected UserAgent 'LightweightUpdater/1.0', got %s", config.UserAgent)
|
||||
if config.UserAgent != "AUTO_MAA_Go_Updater/1.0" {
|
||||
t.Errorf("期望 UserAgent 'AUTO_MAA_Go_Updater/1.0',得到 %s", config.UserAgent)
|
||||
}
|
||||
if config.LogLevel != "info" {
|
||||
t.Errorf("Expected LogLevel 'info', got %s", config.LogLevel)
|
||||
t.Errorf("期望 LogLevel 'info',得到 %s", config.LogLevel)
|
||||
}
|
||||
if config.CheckInterval != 3600 {
|
||||
t.Errorf("Expected CheckInterval 3600, got %d", config.CheckInterval)
|
||||
t.Errorf("期望 CheckInterval 3600,得到 %d", config.CheckInterval)
|
||||
}
|
||||
if !config.AutoCheck {
|
||||
t.Errorf("Expected AutoCheck true, got %v", config.AutoCheck)
|
||||
t.Errorf("期望 AutoCheck true,得到 %v", config.AutoCheck)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigManagerWithCustomPath(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
customPath := filepath.Join(tempDir, "custom-config.yaml")
|
||||
func TestGetConfigDir(t *testing.T) {
|
||||
// 保存原始 APPDATA
|
||||
originalAppData := os.Getenv("APPDATA")
|
||||
defer os.Setenv("APPDATA", originalAppData)
|
||||
|
||||
cm := &DefaultConfigManager{
|
||||
configPath: customPath,
|
||||
// 测试设置了 APPDATA
|
||||
os.Setenv("APPDATA", "C:\\Users\\Test\\AppData\\Roaming")
|
||||
dir := getConfigDir()
|
||||
expected := "C:\\Users\\Test\\AppData\\Roaming\\AUTO_MAA_Go_Updater"
|
||||
if dir != expected {
|
||||
t.Errorf("期望 %s,得到 %s", expected, dir)
|
||||
}
|
||||
|
||||
// Test GetConfigPath
|
||||
if cm.GetConfigPath() != customPath {
|
||||
t.Errorf("Expected config path %s, got %s", customPath, cm.GetConfigPath())
|
||||
}
|
||||
|
||||
// Test Save and Load with custom path
|
||||
testConfig := &Config{
|
||||
ResourceID: "CUSTOM",
|
||||
CurrentVersion: "v1.5.0",
|
||||
UserAgent: "CustomUpdater/1.0",
|
||||
LogLevel: "debug",
|
||||
CheckInterval: 7200,
|
||||
AutoCheck: false,
|
||||
}
|
||||
|
||||
// Save config
|
||||
err := cm.Save(testConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save config: %v", err)
|
||||
}
|
||||
|
||||
// Load config
|
||||
loadedConfig, err := cm.Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Verify loaded config matches saved config
|
||||
if loadedConfig.ResourceID != testConfig.ResourceID {
|
||||
t.Errorf("ResourceID mismatch: expected %s, got %s", testConfig.ResourceID, loadedConfig.ResourceID)
|
||||
}
|
||||
if loadedConfig.CurrentVersion != testConfig.CurrentVersion {
|
||||
t.Errorf("CurrentVersion mismatch: expected %s, got %s", testConfig.CurrentVersion, loadedConfig.CurrentVersion)
|
||||
}
|
||||
if loadedConfig.AutoCheck != testConfig.AutoCheck {
|
||||
t.Errorf("AutoCheck mismatch: expected %v, got %v", testConfig.AutoCheck, loadedConfig.AutoCheck)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigManagerErrorHandling(t *testing.T) {
|
||||
// Test with invalid directory path
|
||||
invalidPath := string([]byte{0}) + "/invalid/config.yaml"
|
||||
cm := &DefaultConfigManager{
|
||||
configPath: invalidPath,
|
||||
}
|
||||
|
||||
// Load should fail with invalid path
|
||||
_, err := cm.Load()
|
||||
if err == nil {
|
||||
t.Error("Expected error when loading from invalid path")
|
||||
}
|
||||
|
||||
// Save should fail with invalid path
|
||||
testConfig := getDefaultConfig()
|
||||
testConfig.ResourceID = "TEST"
|
||||
err = cm.Save(testConfig)
|
||||
if err == nil {
|
||||
t.Error("Expected error when saving to invalid path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptDecryptEdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
}{
|
||||
{"Unicode characters", "测试CDK密钥🔑"},
|
||||
{"Very long string", strings.Repeat("A", 1000)},
|
||||
{"Binary-like data", string([]byte{0, 1, 2, 3, 255, 254, 253})},
|
||||
{"Only spaces", " "},
|
||||
{"Newlines and tabs", "line1\nline2\tindented"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
encrypted := encryptCDK(tt.input)
|
||||
decrypted, err := decryptCDK(encrypted)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Decryption failed: %v", err)
|
||||
}
|
||||
|
||||
if decrypted != tt.input {
|
||||
t.Errorf("Encryption/decryption mismatch: expected %q, got %q", tt.input, decrypted)
|
||||
}
|
||||
})
|
||||
// 测试没有 APPDATA
|
||||
os.Unsetenv("APPDATA")
|
||||
dir = getConfigDir()
|
||||
if dir != "." {
|
||||
t.Errorf("期望当前目录,得到 %s", dir)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,18 +13,18 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// DownloadProgress represents the current download progress
|
||||
// DownloadProgress 表示当前下载进度
|
||||
type DownloadProgress struct {
|
||||
BytesDownloaded int64
|
||||
TotalBytes int64
|
||||
Percentage float64
|
||||
Speed int64 // bytes per second
|
||||
Speed int64 // 每秒字节数
|
||||
}
|
||||
|
||||
// ProgressCallback is called during download to report progress
|
||||
// ProgressCallback 在下载过程中调用以报告进度
|
||||
type ProgressCallback func(DownloadProgress)
|
||||
|
||||
// DownloadManager interface defines download operations
|
||||
// DownloadManager 定义下载操作的接口
|
||||
type DownloadManager interface {
|
||||
Download(url, destination string, progressCallback ProgressCallback) error
|
||||
DownloadWithResume(url, destination string, progressCallback ProgressCallback) error
|
||||
@@ -32,13 +32,13 @@ type DownloadManager interface {
|
||||
SetTimeout(timeout time.Duration)
|
||||
}
|
||||
|
||||
// Manager implements DownloadManager interface
|
||||
// Manager 实现 DownloadManager 接口
|
||||
type Manager struct {
|
||||
client *http.Client
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
// NewManager creates a new download manager
|
||||
// NewManager 创建新的下载管理器
|
||||
func NewManager() *Manager {
|
||||
return &Manager{
|
||||
client: &http.Client{
|
||||
@@ -48,24 +48,24 @@ func NewManager() *Manager {
|
||||
}
|
||||
}
|
||||
|
||||
// Download downloads a file from the given URL to the destination path
|
||||
// Download 从给定 URL 下载文件到目标路径
|
||||
func (m *Manager) Download(url, destination string, progressCallback ProgressCallback) error {
|
||||
return m.downloadWithContext(context.Background(), url, destination, progressCallback, false)
|
||||
}
|
||||
|
||||
// DownloadWithResume downloads a file with resume capability
|
||||
// DownloadWithResume 下载文件并支持断点续传
|
||||
func (m *Manager) DownloadWithResume(url, destination string, progressCallback ProgressCallback) error {
|
||||
return m.downloadWithContext(context.Background(), url, destination, progressCallback, true)
|
||||
}
|
||||
|
||||
// downloadWithContext performs the actual download with context support
|
||||
// downloadWithContext 执行实际的下载并支持上下文
|
||||
func (m *Manager) downloadWithContext(ctx context.Context, url, destination string, progressCallback ProgressCallback, resume bool) error {
|
||||
// Create destination directory if it doesn't exist
|
||||
// 如果目标目录不存在则创建
|
||||
if err := os.MkdirAll(filepath.Dir(destination), 0755); err != nil {
|
||||
return fmt.Errorf("failed to create destination directory: %w", err)
|
||||
return fmt.Errorf("创建目标目录失败: %w", err)
|
||||
}
|
||||
|
||||
// Check if file exists for resume
|
||||
// 检查文件是否存在以支持断点续传
|
||||
var existingSize int64
|
||||
if resume {
|
||||
if stat, err := os.Stat(destination); err == nil {
|
||||
@@ -73,30 +73,30 @@ func (m *Manager) downloadWithContext(ctx context.Context, url, destination stri
|
||||
}
|
||||
}
|
||||
|
||||
// Create HTTP request
|
||||
// 创建 HTTP 请求
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
return fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
|
||||
// Add range header for resume
|
||||
// 为断点续传添加范围头
|
||||
if resume && existingSize > 0 {
|
||||
req.Header.Set("Range", fmt.Sprintf("bytes=%d-", existingSize))
|
||||
}
|
||||
|
||||
// Execute request
|
||||
// 执行请求
|
||||
resp, err := m.client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute request: %w", err)
|
||||
return fmt.Errorf("执行请求失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Check response status
|
||||
// 检查响应状态
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
|
||||
return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
return fmt.Errorf("意外的状态码: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Get total size
|
||||
// 获取总大小
|
||||
totalSize := existingSize
|
||||
if contentLength := resp.Header.Get("Content-Length"); contentLength != "" {
|
||||
if size, err := strconv.ParseInt(contentLength, 10, 64); err == nil {
|
||||
@@ -104,7 +104,7 @@ func (m *Manager) downloadWithContext(ctx context.Context, url, destination stri
|
||||
}
|
||||
}
|
||||
|
||||
// Open destination file
|
||||
// 打开目标文件
|
||||
var file *os.File
|
||||
if resume && existingSize > 0 {
|
||||
file, err = os.OpenFile(destination, os.O_WRONLY|os.O_APPEND, 0644)
|
||||
@@ -113,17 +113,17 @@ func (m *Manager) downloadWithContext(ctx context.Context, url, destination stri
|
||||
existingSize = 0
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create destination file: %w", err)
|
||||
return fmt.Errorf("创建目标文件失败: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Download with progress tracking
|
||||
// 下载并跟踪进度
|
||||
return m.copyWithProgress(resp.Body, file, existingSize, totalSize, progressCallback)
|
||||
}
|
||||
|
||||
// copyWithProgress copies data while tracking progress
|
||||
// copyWithProgress 复制数据并跟踪进度
|
||||
func (m *Manager) copyWithProgress(src io.Reader, dst io.Writer, startBytes, totalBytes int64, progressCallback ProgressCallback) error {
|
||||
buffer := make([]byte, 32*1024) // 32KB buffer
|
||||
buffer := make([]byte, 32*1024) // 32KB 缓冲区
|
||||
downloaded := startBytes
|
||||
startTime := time.Now()
|
||||
lastUpdate := startTime
|
||||
@@ -132,11 +132,11 @@ func (m *Manager) copyWithProgress(src io.Reader, dst io.Writer, startBytes, tot
|
||||
n, err := src.Read(buffer)
|
||||
if n > 0 {
|
||||
if _, writeErr := dst.Write(buffer[:n]); writeErr != nil {
|
||||
return fmt.Errorf("failed to write to destination: %w", writeErr)
|
||||
return fmt.Errorf("写入目标失败: %w", writeErr)
|
||||
}
|
||||
downloaded += int64(n)
|
||||
|
||||
// Update progress every 100ms
|
||||
// 每 100ms 更新一次进度
|
||||
now := time.Now()
|
||||
if progressCallback != nil && now.Sub(lastUpdate) >= 100*time.Millisecond {
|
||||
elapsed := now.Sub(startTime).Seconds()
|
||||
@@ -164,11 +164,11 @@ func (m *Manager) copyWithProgress(src io.Reader, dst io.Writer, startBytes, tot
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read from source: %w", err)
|
||||
return fmt.Errorf("从源读取失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Final progress update
|
||||
// 最终进度更新
|
||||
if progressCallback != nil {
|
||||
elapsed := time.Since(startTime).Seconds()
|
||||
speed := int64(0)
|
||||
@@ -192,32 +192,32 @@ func (m *Manager) copyWithProgress(src io.Reader, dst io.Writer, startBytes, tot
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateChecksum validates the SHA256 checksum of a file
|
||||
// ValidateChecksum 验证文件的 SHA256 校验和
|
||||
func (m *Manager) ValidateChecksum(filePath, expectedChecksum string) error {
|
||||
if expectedChecksum == "" {
|
||||
return nil // No checksum to validate
|
||||
return nil // 没有校验和需要验证
|
||||
}
|
||||
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open file for checksum validation: %w", err)
|
||||
return fmt.Errorf("打开文件进行校验和验证失败: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
hash := sha256.New()
|
||||
if _, err := io.Copy(hash, file); err != nil {
|
||||
return fmt.Errorf("failed to calculate checksum: %w", err)
|
||||
return fmt.Errorf("计算校验和失败: %w", err)
|
||||
}
|
||||
|
||||
actualChecksum := hex.EncodeToString(hash.Sum(nil))
|
||||
if actualChecksum != expectedChecksum {
|
||||
return fmt.Errorf("checksum mismatch: expected %s, got %s", expectedChecksum, actualChecksum)
|
||||
return fmt.Errorf("校验和不匹配: 期望 %s,得到 %s", expectedChecksum, actualChecksum)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetTimeout sets the timeout for download operations
|
||||
// SetTimeout 设置下载操作的超时时间
|
||||
func (m *Manager) SetTimeout(timeout time.Duration) {
|
||||
m.timeout = timeout
|
||||
m.client.Timeout = timeout
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
module lightweight-updater
|
||||
module AUTO_MAA_Go_Updater
|
||||
|
||||
go 1.24.5
|
||||
|
||||
|
||||
@@ -2,15 +2,15 @@ package gui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"fyne.io/fyne/v2"
|
||||
"fyne.io/fyne/v2/app"
|
||||
"fyne.io/fyne/v2/container"
|
||||
"fyne.io/fyne/v2/dialog"
|
||||
"fyne.io/fyne/v2/widget"
|
||||
"fyne.io/fyne/v2/theme"
|
||||
"fyne.io/fyne/v2"
|
||||
"fyne.io/fyne/v2/widget"
|
||||
)
|
||||
|
||||
// UpdateStatus represents the current status of the update process
|
||||
// UpdateStatus 表示更新过程的当前状态
|
||||
type UpdateStatus int
|
||||
|
||||
const (
|
||||
@@ -22,16 +22,15 @@ const (
|
||||
StatusError
|
||||
)
|
||||
|
||||
// Config represents the configuration structure for the GUI
|
||||
// Config 表示 GUI 的配置结构
|
||||
type Config struct {
|
||||
ResourceID string
|
||||
CurrentVersion string
|
||||
CDK string
|
||||
UserAgent string
|
||||
BackupURL string
|
||||
}
|
||||
|
||||
// GUIManager interface defines the methods for GUI management
|
||||
// GUIManager 定义 GUI 管理的接口方法
|
||||
type GUIManager interface {
|
||||
ShowMainWindow()
|
||||
UpdateStatus(status UpdateStatus, message string)
|
||||
@@ -41,7 +40,7 @@ type GUIManager interface {
|
||||
Close()
|
||||
}
|
||||
|
||||
// Manager implements the GUIManager interface
|
||||
// Manager 实现 GUIManager 接口
|
||||
type Manager struct {
|
||||
app fyne.App
|
||||
window fyne.Window
|
||||
@@ -60,7 +59,7 @@ func NewManager() *Manager {
|
||||
a := app.New()
|
||||
a.SetIcon(theme.ComputerIcon())
|
||||
|
||||
w := a.NewWindow("轻量级更新器")
|
||||
w := a.NewWindow("AUTO_MAA_Go_Updater")
|
||||
w.Resize(fyne.NewSize(500, 400))
|
||||
w.SetFixedSize(false)
|
||||
w.CenterOnScreen()
|
||||
@@ -117,11 +116,11 @@ func (m *Manager) createUIComponents() {
|
||||
}
|
||||
|
||||
// createMainLayout creates the main window layout
|
||||
func (m *Manager) createMainLayout() *container.VBox {
|
||||
func (m *Manager) createMainLayout() *fyne.Container {
|
||||
// Header section
|
||||
header := container.NewVBox(
|
||||
widget.NewCard("", "", container.NewVBox(
|
||||
widget.NewLabelWithStyle("轻量级更新器", fyne.TextAlignCenter, fyne.TextStyle{Bold: true}),
|
||||
widget.NewLabelWithStyle("AUTO_MAA_Go_Updater", fyne.TextAlignCenter, fyne.TextStyle{Bold: true}),
|
||||
m.versionLabel,
|
||||
)),
|
||||
)
|
||||
@@ -226,11 +225,8 @@ func (m *Manager) showConfigDialog() (*Config, error) {
|
||||
versionEntry := widget.NewEntry()
|
||||
versionEntry.SetPlaceHolder("例如: v1.0.0")
|
||||
|
||||
cdkEntry := widget.NewPasswordEntry()
|
||||
cdkEntry.SetPlaceHolder("输入您的CDK(可选)")
|
||||
|
||||
userAgentEntry := widget.NewEntry()
|
||||
userAgentEntry.SetText("LightweightUpdater/1.0")
|
||||
userAgentEntry.SetText("AUTO_MAA_Go_Updater/1.0")
|
||||
|
||||
backupURLEntry := widget.NewEntry()
|
||||
backupURLEntry.SetPlaceHolder("备用下载地址(可选)")
|
||||
@@ -240,7 +236,6 @@ func (m *Manager) showConfigDialog() (*Config, error) {
|
||||
Items: []*widget.FormItem{
|
||||
{Text: "资源ID:", Widget: resourceIDEntry},
|
||||
{Text: "当前版本:", Widget: versionEntry},
|
||||
{Text: "CDK:", Widget: cdkEntry},
|
||||
{Text: "用户代理:", Widget: userAgentEntry},
|
||||
{Text: "备用下载地址:", Widget: backupURLEntry},
|
||||
},
|
||||
@@ -261,7 +256,6 @@ func (m *Manager) showConfigDialog() (*Config, error) {
|
||||
config := &Config{
|
||||
ResourceID: resourceIDEntry.Text,
|
||||
CurrentVersion: versionEntry.Text,
|
||||
CDK: cdkEntry.Text,
|
||||
UserAgent: userAgentEntry.Text,
|
||||
BackupURL: backupURLEntry.Text,
|
||||
}
|
||||
@@ -289,11 +283,8 @@ func (m *Manager) showConfigDialog() (*Config, error) {
|
||||
**配置说明:**
|
||||
- **资源ID**: Mirror酱服务中的资源标识符
|
||||
- **当前版本**: 当前软件的版本号
|
||||
- **CDK**: Mirror酱服务的访问密钥(可选,提供更好的下载体验)
|
||||
- **用户代理**: HTTP请求的用户代理字符串
|
||||
- **备用下载地址**: 当Mirror酱服务不可用时的备用下载地址
|
||||
|
||||
如需获取CDK,请访问 [Mirror酱官网](https://mirrorchyan.com)
|
||||
`)
|
||||
|
||||
// Create container with help text
|
||||
|
||||
@@ -11,14 +11,14 @@ import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// ChangesInfo represents the structure of changes.json file
|
||||
// ChangesInfo 表示 changes.json 文件的结构
|
||||
type ChangesInfo struct {
|
||||
Deleted []string `json:"deleted"`
|
||||
Added []string `json:"added"`
|
||||
Modified []string `json:"modified"`
|
||||
}
|
||||
|
||||
// InstallManager interface defines the contract for installation operations
|
||||
// InstallManager 定义安装操作的接口契约
|
||||
type InstallManager interface {
|
||||
ExtractZip(zipPath, destPath string) error
|
||||
ProcessChanges(changesPath string) (*ChangesInfo, error)
|
||||
@@ -28,31 +28,31 @@ type InstallManager interface {
|
||||
CleanupTempDir(tempDir string) error
|
||||
}
|
||||
|
||||
// Manager implements the InstallManager interface
|
||||
// Manager 实现 InstallManager 接口
|
||||
type Manager struct {
|
||||
tempDirs []string // Track temporary directories for cleanup
|
||||
tempDirs []string // 跟踪临时目录以便清理
|
||||
}
|
||||
|
||||
// NewManager creates a new install manager instance
|
||||
// NewManager 创建新的安装管理器实例
|
||||
func NewManager() *Manager {
|
||||
return &Manager{
|
||||
tempDirs: make([]string, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// CreateTempDir creates a temporary directory for extraction
|
||||
// CreateTempDir 为解压创建临时目录
|
||||
func (m *Manager) CreateTempDir() (string, error) {
|
||||
tempDir, err := os.MkdirTemp("", "updater_*")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create temp directory: %w", err)
|
||||
return "", fmt.Errorf("创建临时目录失败: %w", err)
|
||||
}
|
||||
|
||||
// Track temp directory for cleanup
|
||||
// 跟踪临时目录以便清理
|
||||
m.tempDirs = append(m.tempDirs, tempDir)
|
||||
return tempDir, nil
|
||||
}
|
||||
|
||||
// CleanupTempDir removes a temporary directory and its contents
|
||||
// CleanupTempDir 删除临时目录及其内容
|
||||
func (m *Manager) CleanupTempDir(tempDir string) error {
|
||||
if tempDir == "" {
|
||||
return nil
|
||||
@@ -60,10 +60,10 @@ func (m *Manager) CleanupTempDir(tempDir string) error {
|
||||
|
||||
err := os.RemoveAll(tempDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to cleanup temp directory %s: %w", tempDir, err)
|
||||
return fmt.Errorf("清理临时目录 %s 失败: %w", tempDir, err)
|
||||
}
|
||||
|
||||
// Remove from tracking list
|
||||
// 从跟踪列表中删除
|
||||
for i, dir := range m.tempDirs {
|
||||
if dir == tempDir {
|
||||
m.tempDirs = append(m.tempDirs[:i], m.tempDirs[i+1:]...)
|
||||
@@ -74,98 +74,98 @@ func (m *Manager) CleanupTempDir(tempDir string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanupAllTempDirs removes all tracked temporary directories
|
||||
// CleanupAllTempDirs 删除所有跟踪的临时目录
|
||||
func (m *Manager) CleanupAllTempDirs() error {
|
||||
var errors []string
|
||||
|
||||
for _, tempDir := range m.tempDirs {
|
||||
if err := os.RemoveAll(tempDir); err != nil {
|
||||
errors = append(errors, fmt.Sprintf("failed to cleanup %s: %v", tempDir, err))
|
||||
errors = append(errors, fmt.Sprintf("清理 %s 失败: %v", tempDir, err))
|
||||
}
|
||||
}
|
||||
|
||||
m.tempDirs = m.tempDirs[:0] // Clear the slice
|
||||
m.tempDirs = m.tempDirs[:0] // 清空切片
|
||||
|
||||
if len(errors) > 0 {
|
||||
return fmt.Errorf("cleanup errors: %s", strings.Join(errors, "; "))
|
||||
return fmt.Errorf("清理错误: %s", strings.Join(errors, "; "))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExtractZip extracts a ZIP file to the specified destination directory
|
||||
// ExtractZip 将 ZIP 文件解压到指定的目标目录
|
||||
func (m *Manager) ExtractZip(zipPath, destPath string) error {
|
||||
// Open ZIP file for reading
|
||||
// 打开 ZIP 文件进行读取
|
||||
reader, err := zip.OpenReader(zipPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open ZIP file %s: %w", zipPath, err)
|
||||
return fmt.Errorf("打开 ZIP 文件 %s 失败: %w", zipPath, err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
// Create destination directory if it doesn't exist
|
||||
// 如果目标目录不存在则创建
|
||||
if err := os.MkdirAll(destPath, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create destination directory %s: %w", destPath, err)
|
||||
return fmt.Errorf("创建目标目录 %s 失败: %w", destPath, err)
|
||||
}
|
||||
|
||||
// Extract files
|
||||
// 解压文件
|
||||
for _, file := range reader.File {
|
||||
if err := m.extractFile(file, destPath); err != nil {
|
||||
return fmt.Errorf("failed to extract file %s: %w", file.Name, err)
|
||||
return fmt.Errorf("解压文件 %s 失败: %w", file.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractFile extracts a single file from the ZIP archive
|
||||
// extractFile 从 ZIP 归档中解压单个文件
|
||||
func (m *Manager) extractFile(file *zip.File, destPath string) error {
|
||||
// Clean the file path to prevent directory traversal attacks
|
||||
// 清理文件路径以防止目录遍历攻击
|
||||
cleanPath := filepath.Clean(file.Name)
|
||||
if strings.Contains(cleanPath, "..") {
|
||||
return fmt.Errorf("invalid file path: %s", file.Name)
|
||||
return fmt.Errorf("无效的文件路径: %s", file.Name)
|
||||
}
|
||||
|
||||
// Create full destination path
|
||||
// 创建完整的目标路径
|
||||
destFile := filepath.Join(destPath, cleanPath)
|
||||
|
||||
// Create directory structure if needed
|
||||
// 如果需要则创建目录结构
|
||||
if file.FileInfo().IsDir() {
|
||||
return os.MkdirAll(destFile, file.FileInfo().Mode())
|
||||
}
|
||||
|
||||
// Create parent directories
|
||||
// 创建父目录
|
||||
if err := os.MkdirAll(filepath.Dir(destFile), 0755); err != nil {
|
||||
return fmt.Errorf("failed to create parent directory: %w", err)
|
||||
return fmt.Errorf("创建父目录失败: %w", err)
|
||||
}
|
||||
|
||||
// Open file in ZIP archive
|
||||
// 打开 ZIP 归档中的文件
|
||||
rc, err := file.Open()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open file in archive: %w", err)
|
||||
return fmt.Errorf("打开归档中的文件失败: %w", err)
|
||||
}
|
||||
defer rc.Close()
|
||||
|
||||
// Create destination file
|
||||
// 创建目标文件
|
||||
outFile, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, file.FileInfo().Mode())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create destination file: %w", err)
|
||||
return fmt.Errorf("创建目标文件失败: %w", err)
|
||||
}
|
||||
defer outFile.Close()
|
||||
|
||||
// Copy file contents
|
||||
// 复制文件内容
|
||||
_, err = io.Copy(outFile, rc)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to copy file contents: %w", err)
|
||||
return fmt.Errorf("复制文件内容失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ProcessChanges reads and parses the changes.json file
|
||||
// ProcessChanges 读取并解析 changes.json 文件
|
||||
func (m *Manager) ProcessChanges(changesPath string) (*ChangesInfo, error) {
|
||||
// Check if changes.json exists
|
||||
// 检查 changes.json 是否存在
|
||||
if _, err := os.Stat(changesPath); os.IsNotExist(err) {
|
||||
// If changes.json doesn't exist, return empty changes info
|
||||
// 如果 changes.json 不存在,返回空的变更信息
|
||||
return &ChangesInfo{
|
||||
Deleted: []string{},
|
||||
Added: []string{},
|
||||
@@ -173,72 +173,72 @@ func (m *Manager) ProcessChanges(changesPath string) (*ChangesInfo, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Read the changes.json file
|
||||
// 读取 changes.json 文件
|
||||
data, err := os.ReadFile(changesPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read changes file %s: %w", changesPath, err)
|
||||
return nil, fmt.Errorf("读取变更文件 %s 失败: %w", changesPath, err)
|
||||
}
|
||||
|
||||
// Parse JSON
|
||||
// 解析 JSON
|
||||
var changes ChangesInfo
|
||||
if err := json.Unmarshal(data, &changes); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse changes JSON: %w", err)
|
||||
return nil, fmt.Errorf("解析变更 JSON 失败: %w", err)
|
||||
}
|
||||
|
||||
return &changes, nil
|
||||
}
|
||||
|
||||
// HandleRunningProcess handles running processes by renaming files that are in use
|
||||
// HandleRunningProcess 通过重命名正在使用的文件来处理正在运行的进程
|
||||
func (m *Manager) HandleRunningProcess(processName string) error {
|
||||
// Get the current executable path
|
||||
// 获取当前可执行文件路径
|
||||
exePath, err := os.Executable()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get executable path: %w", err)
|
||||
return fmt.Errorf("获取可执行文件路径失败: %w", err)
|
||||
}
|
||||
|
||||
exeDir := filepath.Dir(exePath)
|
||||
targetFile := filepath.Join(exeDir, processName)
|
||||
|
||||
// Check if the target file exists
|
||||
// 检查目标文件是否存在
|
||||
if _, err := os.Stat(targetFile); os.IsNotExist(err) {
|
||||
// File doesn't exist, nothing to handle
|
||||
// 文件不存在,无需处理
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to rename the file to indicate it should be deleted on next startup
|
||||
// 尝试重命名文件以指示应在下次启动时删除
|
||||
oldFile := targetFile + ".old"
|
||||
|
||||
// Remove existing .old file if it exists
|
||||
// 如果存在现有的 .old 文件则删除
|
||||
if _, err := os.Stat(oldFile); err == nil {
|
||||
if err := os.Remove(oldFile); err != nil {
|
||||
return fmt.Errorf("failed to remove existing old file %s: %w", oldFile, err)
|
||||
return fmt.Errorf("删除现有旧文件 %s 失败: %w", oldFile, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Rename the current file to .old
|
||||
// 将当前文件重命名为 .old
|
||||
if err := os.Rename(targetFile, oldFile); err != nil {
|
||||
// If rename fails, the process might be running
|
||||
// On Windows, we can't rename a running executable
|
||||
// 如果重命名失败,进程可能正在运行
|
||||
// 在 Windows 上,我们无法重命名正在运行的可执行文件
|
||||
if isFileInUse(err) {
|
||||
// Mark the file for deletion on next reboot (Windows specific)
|
||||
// 标记文件在下次重启时删除(Windows 特定)
|
||||
return m.markFileForDeletion(targetFile)
|
||||
}
|
||||
return fmt.Errorf("failed to rename running process file %s: %w", targetFile, err)
|
||||
return fmt.Errorf("重命名正在运行的进程文件 %s 失败: %w", targetFile, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isFileInUse checks if the error indicates the file is in use
|
||||
// isFileInUse 检查错误是否表示文件正在使用中
|
||||
func isFileInUse(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for Windows-specific "file in use" errors
|
||||
// 检查 Windows 特定的"文件正在使用"错误
|
||||
if pathErr, ok := err.(*os.PathError); ok {
|
||||
if errno, ok := pathErr.Err.(syscall.Errno); ok {
|
||||
// ERROR_SHARING_VIOLATION (32) or ERROR_ACCESS_DENIED (5)
|
||||
// ERROR_SHARING_VIOLATION (32) 或 ERROR_ACCESS_DENIED (5)
|
||||
return errno == syscall.Errno(32) || errno == syscall.Errno(5)
|
||||
}
|
||||
}
|
||||
@@ -247,226 +247,226 @@ func isFileInUse(err error) bool {
|
||||
strings.Contains(err.Error(), "access is denied")
|
||||
}
|
||||
|
||||
// markFileForDeletion marks a file for deletion on next system reboot (Windows specific)
|
||||
// markFileForDeletion 标记文件在下次系统重启时删除(Windows 特定)
|
||||
func (m *Manager) markFileForDeletion(filePath string) error {
|
||||
// This is a Windows-specific implementation
|
||||
// For now, we'll create a marker file that can be handled by the main application
|
||||
// 这是 Windows 特定的实现
|
||||
// 目前,我们将创建一个可由主应用程序处理的标记文件
|
||||
markerFile := filePath + ".delete_on_restart"
|
||||
|
||||
// Create a marker file
|
||||
// 创建标记文件
|
||||
file, err := os.Create(markerFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create deletion marker file: %w", err)
|
||||
return fmt.Errorf("创建删除标记文件失败: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Write the target file path to the marker
|
||||
// 将目标文件路径写入标记文件
|
||||
_, err = file.WriteString(filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write to marker file: %w", err)
|
||||
return fmt.Errorf("写入标记文件失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteMarkedFiles removes files that were marked for deletion
|
||||
// DeleteMarkedFiles 删除标记为删除的文件
|
||||
func (m *Manager) DeleteMarkedFiles(directory string) error {
|
||||
// Find all .delete_on_restart files
|
||||
// 查找所有 .delete_on_restart 文件
|
||||
pattern := filepath.Join(directory, "*.delete_on_restart")
|
||||
matches, err := filepath.Glob(pattern)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find marker files: %w", err)
|
||||
return fmt.Errorf("查找标记文件失败: %w", err)
|
||||
}
|
||||
|
||||
var errors []string
|
||||
for _, markerFile := range matches {
|
||||
// Read the target file path
|
||||
// 读取目标文件路径
|
||||
data, err := os.ReadFile(markerFile)
|
||||
if err != nil {
|
||||
errors = append(errors, fmt.Sprintf("failed to read marker file %s: %v", markerFile, err))
|
||||
errors = append(errors, fmt.Sprintf("读取标记文件 %s 失败: %v", markerFile, err))
|
||||
continue
|
||||
}
|
||||
|
||||
targetFile := strings.TrimSpace(string(data))
|
||||
|
||||
// Try to delete the target file
|
||||
// 尝试删除目标文件
|
||||
if err := os.Remove(targetFile); err != nil && !os.IsNotExist(err) {
|
||||
errors = append(errors, fmt.Sprintf("failed to delete marked file %s: %v", targetFile, err))
|
||||
errors = append(errors, fmt.Sprintf("删除标记文件 %s 失败: %v", targetFile, err))
|
||||
}
|
||||
|
||||
// Remove the marker file
|
||||
// 删除标记文件
|
||||
if err := os.Remove(markerFile); err != nil {
|
||||
errors = append(errors, fmt.Sprintf("failed to remove marker file %s: %v", markerFile, err))
|
||||
errors = append(errors, fmt.Sprintf("删除标记文件 %s 失败: %v", markerFile, err))
|
||||
}
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
return fmt.Errorf("deletion errors: %s", strings.Join(errors, "; "))
|
||||
return fmt.Errorf("删除错误: %s", strings.Join(errors, "; "))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApplyUpdate applies the update by copying files from source to target directory
|
||||
// ApplyUpdate 通过从源目录复制文件到目标目录来应用更新
|
||||
func (m *Manager) ApplyUpdate(sourcePath, targetPath string, changes *ChangesInfo) error {
|
||||
// Create backup directory
|
||||
// 创建备份目录
|
||||
backupDir, err := m.createBackupDir(targetPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create backup directory: %w", err)
|
||||
return fmt.Errorf("创建备份目录失败: %w", err)
|
||||
}
|
||||
|
||||
// Backup existing files before applying update
|
||||
// 在应用更新前备份现有文件
|
||||
if err := m.backupFiles(targetPath, backupDir, changes); err != nil {
|
||||
return fmt.Errorf("failed to backup files: %w", err)
|
||||
return fmt.Errorf("备份文件失败: %w", err)
|
||||
}
|
||||
|
||||
// Apply the update
|
||||
// 应用更新
|
||||
if err := m.applyUpdateFiles(sourcePath, targetPath, changes); err != nil {
|
||||
// Rollback on failure
|
||||
// 失败时回滚
|
||||
if rollbackErr := m.rollbackUpdate(targetPath, backupDir); rollbackErr != nil {
|
||||
return fmt.Errorf("update failed and rollback failed: update error: %w, rollback error: %v", err, rollbackErr)
|
||||
return fmt.Errorf("更新失败且回滚失败: 更新错误: %w, 回滚错误: %v", err, rollbackErr)
|
||||
}
|
||||
return fmt.Errorf("update failed and was rolled back: %w", err)
|
||||
return fmt.Errorf("更新失败已回滚: %w", err)
|
||||
}
|
||||
|
||||
// Clean up backup directory after successful update
|
||||
// 成功更新后清理备份目录
|
||||
if err := os.RemoveAll(backupDir); err != nil {
|
||||
// Log warning but don't fail the update
|
||||
fmt.Printf("Warning: failed to cleanup backup directory %s: %v\n", backupDir, err)
|
||||
// 记录警告但不让更新失败
|
||||
fmt.Printf("警告: 清理备份目录 %s 失败: %v\n", backupDir, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createBackupDir creates a backup directory for the update
|
||||
// createBackupDir 为更新创建备份目录
|
||||
func (m *Manager) createBackupDir(targetPath string) (string, error) {
|
||||
backupDir := filepath.Join(targetPath, ".backup_"+fmt.Sprintf("%d", os.Getpid()))
|
||||
|
||||
if err := os.MkdirAll(backupDir, 0755); err != nil {
|
||||
return "", fmt.Errorf("failed to create backup directory: %w", err)
|
||||
return "", fmt.Errorf("创建备份目录失败: %w", err)
|
||||
}
|
||||
|
||||
return backupDir, nil
|
||||
}
|
||||
|
||||
// backupFiles creates backups of files that will be modified or deleted
|
||||
// backupFiles 创建将被修改或删除的文件的备份
|
||||
func (m *Manager) backupFiles(targetPath, backupDir string, changes *ChangesInfo) error {
|
||||
// Backup files that will be modified
|
||||
// 备份将被修改的文件
|
||||
for _, file := range changes.Modified {
|
||||
srcFile := filepath.Join(targetPath, file)
|
||||
if _, err := os.Stat(srcFile); os.IsNotExist(err) {
|
||||
continue // File doesn't exist, skip backup
|
||||
continue // 文件不存在,跳过备份
|
||||
}
|
||||
|
||||
backupFile := filepath.Join(backupDir, file)
|
||||
if err := m.copyFileWithDirs(srcFile, backupFile); err != nil {
|
||||
return fmt.Errorf("failed to backup modified file %s: %w", file, err)
|
||||
return fmt.Errorf("备份修改文件 %s 失败: %w", file, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Backup files that will be deleted
|
||||
// 备份将被删除的文件
|
||||
for _, file := range changes.Deleted {
|
||||
srcFile := filepath.Join(targetPath, file)
|
||||
if _, err := os.Stat(srcFile); os.IsNotExist(err) {
|
||||
continue // File doesn't exist, skip backup
|
||||
continue // 文件不存在,跳过备份
|
||||
}
|
||||
|
||||
backupFile := filepath.Join(backupDir, file)
|
||||
if err := m.copyFileWithDirs(srcFile, backupFile); err != nil {
|
||||
return fmt.Errorf("failed to backup deleted file %s: %w", file, err)
|
||||
return fmt.Errorf("备份删除文件 %s 失败: %w", file, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyUpdateFiles applies the actual file changes
|
||||
// applyUpdateFiles 应用实际的文件更改
|
||||
func (m *Manager) applyUpdateFiles(sourcePath, targetPath string, changes *ChangesInfo) error {
|
||||
// Delete files marked for deletion
|
||||
// 删除标记为删除的文件
|
||||
for _, file := range changes.Deleted {
|
||||
targetFile := filepath.Join(targetPath, file)
|
||||
if err := os.Remove(targetFile); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to delete file %s: %w", file, err)
|
||||
return fmt.Errorf("删除文件 %s 失败: %w", file, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Copy new and modified files
|
||||
// 复制新文件和修改的文件
|
||||
filesToCopy := append(changes.Added, changes.Modified...)
|
||||
for _, file := range filesToCopy {
|
||||
srcFile := filepath.Join(sourcePath, file)
|
||||
targetFile := filepath.Join(targetPath, file)
|
||||
|
||||
// Check if source file exists
|
||||
// 检查源文件是否存在
|
||||
if _, err := os.Stat(srcFile); os.IsNotExist(err) {
|
||||
continue // Source file doesn't exist, skip
|
||||
continue // 源文件不存在,跳过
|
||||
}
|
||||
|
||||
if err := m.copyFileWithDirs(srcFile, targetFile); err != nil {
|
||||
return fmt.Errorf("failed to copy file %s: %w", file, err)
|
||||
return fmt.Errorf("复制文件 %s 失败: %w", file, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// copyFileWithDirs copies a file and creates necessary directories
|
||||
// copyFileWithDirs 复制文件并创建必要的目录
|
||||
func (m *Manager) copyFileWithDirs(src, dst string) error {
|
||||
// Create parent directories
|
||||
// 创建父目录
|
||||
if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil {
|
||||
return fmt.Errorf("failed to create parent directories: %w", err)
|
||||
return fmt.Errorf("创建父目录失败: %w", err)
|
||||
}
|
||||
|
||||
// Open source file
|
||||
// 打开源文件
|
||||
srcFile, err := os.Open(src)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open source file: %w", err)
|
||||
return fmt.Errorf("打开源文件失败: %w", err)
|
||||
}
|
||||
defer srcFile.Close()
|
||||
|
||||
// Get source file info
|
||||
// 获取源文件信息
|
||||
srcInfo, err := srcFile.Stat()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get source file info: %w", err)
|
||||
return fmt.Errorf("获取源文件信息失败: %w", err)
|
||||
}
|
||||
|
||||
// Create destination file
|
||||
// 创建目标文件
|
||||
dstFile, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, srcInfo.Mode())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create destination file: %w", err)
|
||||
return fmt.Errorf("创建目标文件失败: %w", err)
|
||||
}
|
||||
defer dstFile.Close()
|
||||
|
||||
// Copy file contents
|
||||
// 复制文件内容
|
||||
_, err = io.Copy(dstFile, srcFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to copy file contents: %w", err)
|
||||
return fmt.Errorf("复制文件内容失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// rollbackUpdate restores files from backup in case of update failure
|
||||
// rollbackUpdate 在更新失败时从备份恢复文件
|
||||
func (m *Manager) rollbackUpdate(targetPath, backupDir string) error {
|
||||
// Walk through backup directory and restore files
|
||||
// 遍历备份目录并恢复文件
|
||||
return filepath.Walk(backupDir, func(backupFile string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
return nil // Skip directories
|
||||
return nil // 跳过目录
|
||||
}
|
||||
|
||||
// Calculate relative path
|
||||
// 计算相对路径
|
||||
relPath, err := filepath.Rel(backupDir, backupFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to calculate relative path: %w", err)
|
||||
return fmt.Errorf("计算相对路径失败: %w", err)
|
||||
}
|
||||
|
||||
// Restore file to target location
|
||||
// 将文件恢复到目标位置
|
||||
targetFile := filepath.Join(targetPath, relPath)
|
||||
if err := m.copyFileWithDirs(backupFile, targetFile); err != nil {
|
||||
return fmt.Errorf("failed to restore file %s: %w", relPath, err)
|
||||
return fmt.Errorf("恢复文件 %s 失败: %w", relPath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -4,9 +4,9 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Integration tests will be implemented here
|
||||
// This file is currently a placeholder
|
||||
// 集成测试将在此处实现
|
||||
// 此文件目前是占位符
|
||||
|
||||
func TestIntegrationPlaceholder(t *testing.T) {
|
||||
t.Skip("Integration tests not yet implemented")
|
||||
t.Skip("集成测试尚未实现")
|
||||
}
|
||||
@@ -64,8 +64,8 @@ type LoggerConfig struct {
|
||||
Level LogLevel
|
||||
MaxSize int64 // 最大文件大小(字节),默认10MB
|
||||
MaxBackups int // 最大备份文件数,默认5
|
||||
LogDir string // 日志目录,默认%APPDATA%/LightweightUpdater/logs
|
||||
Filename string // 日志文件名,默认updater.log
|
||||
LogDir string // 日志目录
|
||||
Filename string // 日志文件名
|
||||
}
|
||||
|
||||
// DefaultLoggerConfig 默认日志配置
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,3 +0,0 @@
|
||||
package utils
|
||||
|
||||
// Package utils provides utility functions for the updater
|
||||
@@ -8,16 +8,16 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"lightweight-updater/logger"
|
||||
"AUTO_MAA_Go_Updater/logger"
|
||||
)
|
||||
|
||||
// VersionInfo represents the version information from version.json
|
||||
// VersionInfo 表示来自 version.json 的版本信息
|
||||
type VersionInfo struct {
|
||||
MainVersion string `json:"main_version"`
|
||||
VersionInfo map[string]map[string][]string `json:"version_info"`
|
||||
}
|
||||
|
||||
// ParsedVersion represents a parsed version with major, minor, patch, and beta components
|
||||
// ParsedVersion 表示解析后的版本,包含主版本号、次版本号、补丁版本号和测试版本号组件
|
||||
type ParsedVersion struct {
|
||||
Major int
|
||||
Minor int
|
||||
@@ -25,13 +25,13 @@ type ParsedVersion struct {
|
||||
Beta int
|
||||
}
|
||||
|
||||
// VersionManager handles version-related operations
|
||||
// VersionManager 处理版本相关操作
|
||||
type VersionManager struct {
|
||||
executableDir string
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
// NewVersionManager creates a new version manager
|
||||
// NewVersionManager 创建新的版本管理器
|
||||
func NewVersionManager() *VersionManager {
|
||||
execPath, _ := os.Executable()
|
||||
execDir := filepath.Dir(execPath)
|
||||
@@ -41,103 +41,93 @@ func NewVersionManager() *VersionManager {
|
||||
}
|
||||
}
|
||||
|
||||
// NewVersionManagerWithLogger creates a new version manager with a custom logger
|
||||
func NewVersionManagerWithLogger(customLogger logger.Logger) *VersionManager {
|
||||
execPath, _ := os.Executable()
|
||||
execDir := filepath.Dir(execPath)
|
||||
return &VersionManager{
|
||||
executableDir: execDir,
|
||||
logger: customLogger,
|
||||
}
|
||||
}
|
||||
|
||||
// createDefaultVersion creates a default version structure with v0.0.0
|
||||
// createDefaultVersion 创建默认版本结构 v0.0.0
|
||||
func (vm *VersionManager) createDefaultVersion() *VersionInfo {
|
||||
return &VersionInfo{
|
||||
MainVersion: "0.0.0.0", // Corresponds to v0.0.0
|
||||
MainVersion: "0.0.0.0", // 对应 v0.0.0
|
||||
VersionInfo: make(map[string]map[string][]string),
|
||||
}
|
||||
}
|
||||
|
||||
// LoadVersionFromFile loads version information from resources/version.json with fallback handling
|
||||
// LoadVersionFromFile 从 resources/version.json 加载版本信息并处理回退
|
||||
func (vm *VersionManager) LoadVersionFromFile() (*VersionInfo, error) {
|
||||
versionPath := filepath.Join(vm.executableDir, "resources", "version.json")
|
||||
|
||||
data, err := os.ReadFile(versionPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
vm.logger.Info("Version file not found at %s, will use default version", versionPath)
|
||||
fmt.Println("未读取到版本信息,使用默认版本进行更新。")
|
||||
return vm.createDefaultVersion(), nil
|
||||
}
|
||||
vm.logger.Warn("Failed to read version file at %s: %v, will use default version", versionPath, err)
|
||||
vm.logger.Warn("读取版本文件 %s 失败: %v,将使用默认版本", versionPath, err)
|
||||
return vm.createDefaultVersion(), nil
|
||||
}
|
||||
|
||||
var versionInfo VersionInfo
|
||||
if err := json.Unmarshal(data, &versionInfo); err != nil {
|
||||
vm.logger.Warn("Failed to parse version file at %s: %v, will use default version", versionPath, err)
|
||||
vm.logger.Warn("解析版本文件 %s 失败: %v,将使用默认版本", versionPath, err)
|
||||
return vm.createDefaultVersion(), nil
|
||||
}
|
||||
|
||||
vm.logger.Debug("Successfully loaded version information from %s", versionPath)
|
||||
vm.logger.Debug("成功从 %s 加载版本信息", versionPath)
|
||||
return &versionInfo, nil
|
||||
}
|
||||
|
||||
// LoadVersionWithDefault loads version information with guaranteed fallback to default
|
||||
// LoadVersionWithDefault 加载版本信息并保证回退到默认版本
|
||||
func (vm *VersionManager) LoadVersionWithDefault() *VersionInfo {
|
||||
versionInfo, err := vm.LoadVersionFromFile()
|
||||
if err != nil {
|
||||
// This should not happen with the updated LoadVersionFromFile, but adding as extra safety
|
||||
vm.logger.Error("Unexpected error loading version file: %v, using default version", err)
|
||||
// 这在更新的 LoadVersionFromFile 中不应该发生,但添加作为额外安全措施
|
||||
vm.logger.Error("加载版本文件时出现意外错误: %v,使用默认版本", err)
|
||||
return vm.createDefaultVersion()
|
||||
}
|
||||
|
||||
// Validate that we have a valid version structure
|
||||
// 验证我们有一个有效的版本结构
|
||||
if versionInfo == nil {
|
||||
vm.logger.Warn("Version info is nil, using default version")
|
||||
vm.logger.Warn("版本信息为空,使用默认版本")
|
||||
return vm.createDefaultVersion()
|
||||
}
|
||||
|
||||
if versionInfo.MainVersion == "" {
|
||||
vm.logger.Warn("Version info has empty main version, using default version")
|
||||
vm.logger.Warn("版本信息主版本为空,使用默认版本")
|
||||
return vm.createDefaultVersion()
|
||||
}
|
||||
|
||||
if versionInfo.VersionInfo == nil {
|
||||
vm.logger.Debug("Version info map is nil, initializing empty map")
|
||||
vm.logger.Debug("版本信息映射为空,初始化空映射")
|
||||
versionInfo.VersionInfo = make(map[string]map[string][]string)
|
||||
}
|
||||
|
||||
return versionInfo
|
||||
}
|
||||
|
||||
// ParseVersion parses a version string like "4.4.1.3" into components
|
||||
// ParseVersion 解析版本字符串如 "4.4.1.3" 为组件
|
||||
func ParseVersion(versionStr string) (*ParsedVersion, error) {
|
||||
parts := strings.Split(versionStr, ".")
|
||||
if len(parts) < 3 || len(parts) > 4 {
|
||||
return nil, fmt.Errorf("invalid version format: %s", versionStr)
|
||||
return nil, fmt.Errorf("无效的版本格式: %s", versionStr)
|
||||
}
|
||||
|
||||
major, err := strconv.Atoi(parts[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid major version: %s", parts[0])
|
||||
return nil, fmt.Errorf("无效的主版本号: %s", parts[0])
|
||||
}
|
||||
|
||||
minor, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid minor version: %s", parts[1])
|
||||
return nil, fmt.Errorf("无效的次版本号: %s", parts[1])
|
||||
}
|
||||
|
||||
patch, err := strconv.Atoi(parts[2])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid patch version: %s", parts[2])
|
||||
return nil, fmt.Errorf("无效的补丁版本号: %s", parts[2])
|
||||
}
|
||||
|
||||
beta := 0
|
||||
if len(parts) == 4 {
|
||||
beta, err = strconv.Atoi(parts[3])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid beta version: %s", parts[3])
|
||||
return nil, fmt.Errorf("无效的测试版本号: %s", parts[3])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -149,7 +139,7 @@ func ParseVersion(versionStr string) (*ParsedVersion, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ToVersionString converts a ParsedVersion back to version string format
|
||||
// ToVersionString 将 ParsedVersion 转换回版本字符串格式
|
||||
func (pv *ParsedVersion) ToVersionString() string {
|
||||
if pv.Beta == 0 {
|
||||
return fmt.Sprintf("%d.%d.%d.0", pv.Major, pv.Minor, pv.Patch)
|
||||
@@ -157,7 +147,7 @@ func (pv *ParsedVersion) ToVersionString() string {
|
||||
return fmt.Sprintf("%d.%d.%d.%d", pv.Major, pv.Minor, pv.Patch, pv.Beta)
|
||||
}
|
||||
|
||||
// ToDisplayVersion converts version to display format (v4.4.0 or v4.4.1-beta3)
|
||||
// ToDisplayVersion 将版本转换为显示格式 (v4.4.0 或 v4.4.1-beta3)
|
||||
func (pv *ParsedVersion) ToDisplayVersion() string {
|
||||
if pv.Beta == 0 {
|
||||
return fmt.Sprintf("v%d.%d.%d", pv.Major, pv.Minor, pv.Patch)
|
||||
@@ -165,7 +155,7 @@ func (pv *ParsedVersion) ToDisplayVersion() string {
|
||||
return fmt.Sprintf("v%d.%d.%d-beta%d", pv.Major, pv.Minor, pv.Patch, pv.Beta)
|
||||
}
|
||||
|
||||
// GetChannel returns the channel (stable or beta) based on version
|
||||
// GetChannel 根据版本返回渠道 (stable 或 beta)
|
||||
func (pv *ParsedVersion) GetChannel() string {
|
||||
if pv.Beta == 0 {
|
||||
return "stable"
|
||||
@@ -173,12 +163,7 @@ func (pv *ParsedVersion) GetChannel() string {
|
||||
return "beta"
|
||||
}
|
||||
|
||||
// GetDefaultChannel returns the default channel
|
||||
func GetDefaultChannel() string {
|
||||
return "stable"
|
||||
}
|
||||
|
||||
// IsNewer checks if this version is newer than the other version
|
||||
// IsNewer 检查此版本是否比其他版本更新
|
||||
func (pv *ParsedVersion) IsNewer(other *ParsedVersion) bool {
|
||||
if pv.Major != other.Major {
|
||||
return pv.Major > other.Major
|
||||
|
||||
@@ -1,41 +1,19 @@
|
||||
package version
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
var (
|
||||
// Version is the current version of the application
|
||||
// Version 应用程序的当前版本
|
||||
Version = "1.0.0"
|
||||
|
||||
// BuildTime is set during build time
|
||||
// BuildTime 在构建时设置
|
||||
BuildTime = "unknown"
|
||||
|
||||
// GitCommit is set during build time
|
||||
// GitCommit 在构建时设置
|
||||
GitCommit = "unknown"
|
||||
|
||||
// GoVersion is the Go version used to build
|
||||
// GoVersion 用于构建的 Go 版本
|
||||
GoVersion = runtime.Version()
|
||||
)
|
||||
|
||||
// GetVersionInfo returns formatted version information
|
||||
func GetVersionInfo() string {
|
||||
return fmt.Sprintf("Version: %s\nBuild Time: %s\nGit Commit: %s\nGo Version: %s",
|
||||
Version, BuildTime, GitCommit, GoVersion)
|
||||
}
|
||||
|
||||
// GetShortVersion returns just the version number
|
||||
func GetShortVersion() string {
|
||||
return Version
|
||||
}
|
||||
|
||||
// GetBuildInfo returns build-specific information
|
||||
func GetBuildInfo() map[string]string {
|
||||
return map[string]string{
|
||||
"version": Version,
|
||||
"build_time": BuildTime,
|
||||
"git_commit": GitCommit,
|
||||
"go_version": GoVersion,
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user