由于公司销售的通话数据量比较大,导致数据库存储空间不足,需要迁移。使用 Navicat 迁移时总是只迁移一小部分就中断,故此写了此脚本,以供大家参考。

  1. 重试机制,当本次数据同步失败时,脚本会自动进行重试同步

  2. 断点续传,如果同步中途失败,脚本会自动从上次失败的地方继续同步

  3. 采用多个 goroutine(协程)并发写入,加快同步速度

package main

import (
	"context"
	"encoding/json"
	"fmt"
	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/bson/primitive"
	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"
	"log"
	"os"
	"runtime"
	"sync"
	"time"
)

// config gathers every tunable of the migration in a single package-level
// value: where to read from, where to write to, and how hard to push.
var config = struct {
	// Connection strings for the source and target deployments.
	SourceURI, TargetURI string
	// Database and collection to read from / write to.
	SourceDatabase, TargetDatabase     string
	SourceCollection, TargetCollection string

	// BatchSize is used as the cursor batch size, the jobs channel capacity
	// and the number of documents per InsertMany call.
	BatchSize int32
	// MaxRetries bounds the insert attempts per batch.
	MaxRetries int
	// WorkerCount is the number of inserting goroutines.
	WorkerCount int
	// MaxPoolSize is passed to the driver as the connection pool limit.
	MaxPoolSize uint64

	// Timeout bounds connecting, pinging and counting documents.
	Timeout time.Duration
	// CursorTimeout bounds one full pass over the source cursor.
	CursorTimeout time.Duration
}{
	SourceURI: "mongodb://root:xxxx@127.0.0.1:3717?replicaSet=mgset-81589849&authSource=admin",
	TargetURI: "mongodb://root:xxxx@127.0.0.1:3717?replicaSet=mgset-81592023&authSource=admin",

	SourceDatabase:   "admin",
	TargetDatabase:   "admin",
	SourceCollection: "xxx",
	TargetCollection: "xxx",

	BatchSize:   100,
	MaxRetries:  3,
	WorkerCount: 4,
	MaxPoolSize: 100,

	Timeout:       60 * time.Second,
	CursorTimeout: 10 * time.Minute,
}

// SyncState is the resume checkpoint of the migration. It is encoded to and
// decoded from JSON by saveSyncState / loadSyncState, so the field names are
// also the keys of the state file.
type SyncState struct {
	// LastProcessedID is the hex form of the ObjectID of the last document
	// that was handed to the workers; an empty string means "start from the
	// beginning".
	LastProcessedID string
	// ProcessedCount is the number of documents handed to the workers so far.
	ProcessedCount int64
}

var (
	// syncState is the in-memory checkpoint shared by the load, save and
	// sync functions.
	syncState SyncState
	// stateMutex is held by syncBatch while it updates syncState.
	stateMutex sync.Mutex
)

// main runs the migration end to end: it restores the checkpoint, connects to
// both deployments, starts the insert workers, streams the source collection
// into them (retrying failed passes indefinitely), persists the final
// checkpoint and finally compares document counts.
func main() {
	// Restore the checkpoint left by a previous run, if there is one.
	if err := loadSyncState("sync_state.json"); err != nil {
		log.Printf("Failed to load sync state, starting from beginning: %v", err)
	}

	// Connect to the source and target deployments.
	sourceClient, err := connectMongoDB(config.SourceURI)
	if err != nil {
		log.Fatalf("Failed to connect to source MongoDB: %v", err)
	}
	defer sourceClient.Disconnect(context.Background())

	targetClient, err := connectMongoDB(config.TargetURI)
	if err != nil {
		log.Fatalf("Failed to connect to target MongoDB: %v", err)
	}
	defer targetClient.Disconnect(context.Background())

	sourceColl := sourceClient.Database(config.SourceDatabase).Collection(config.SourceCollection)
	targetColl := targetClient.Database(config.TargetDatabase).Collection(config.TargetCollection)

	// Count the source documents; the number is informational only.
	ctx, cancel := context.WithTimeout(context.Background(), config.Timeout)
	totalCount, err := sourceColl.CountDocuments(ctx, bson.M{})
	cancel()
	if err != nil {
		log.Fatalf("Failed to count documents in source collection: %v", err)
	}

	log.Printf("Total documents to sync: %d", totalCount)
	log.Printf("Resuming from document ID: %v", syncState.LastProcessedID)

	// Start the insert workers, fed through a buffered channel.
	jobs := make(chan bson.M, config.BatchSize)
	var wg sync.WaitGroup
	for i := 0; i < config.WorkerCount; i++ {
		wg.Add(1)
		go worker(targetColl, jobs, &wg)
	}

	// Main sync loop. syncBatch returns nil only after its cursor has been
	// drained without error, i.e. every document after the checkpoint has
	// been queued, so a nil result means the copy is finished. Comparing
	// ProcessedCount with the count taken at start-up is not a safe exit
	// condition: if documents were deleted meanwhile, or some could not be
	// decoded, the counter never reaches totalCount and the loop would
	// re-query an exhausted collection forever.
	for {
		err := syncBatch(sourceColl, jobs)
		if err == nil {
			break
		}
		log.Printf("Error during sync batch: %v. Retrying...", err)
		time.Sleep(5 * time.Second) // give the server a moment before resuming
	}

	close(jobs)
	wg.Wait()

	// syncBatch only checkpoints every 1000 documents, so write the final
	// position explicitly; otherwise a later run would resume too early.
	if err := saveSyncState("sync_state.json"); err != nil {
		log.Printf("Failed to save final sync state: %v", err)
	}
	if syncState.ProcessedCount < totalCount {
		log.Printf("Processed %d documents but source reported %d at start", syncState.ProcessedCount, totalCount)
	}

	// Verify the result by comparing document counts.
	if err := validateSync(sourceColl, targetColl); err != nil {
		log.Printf("Data validation failed: %v", err)
	} else {
		log.Println("Data validation successful")
	}

	log.Println("Data synchronization completed.")
}

// syncBatch makes one pass over the source collection in ascending _id order,
// starting after syncState.LastProcessedID when that is set, and sends every
// decoded document to jobs. The whole pass is bounded by config.CursorTimeout.
//
// It returns nil once the cursor has been drained without error; any Find or
// cursor error is returned so the caller can retry from the updated
// checkpoint.
//
// Note that the checkpoint is advanced when a document is queued, not when a
// worker has inserted it, so the saved position can be ahead of what has
// actually been written to the target.
func syncBatch(sourceColl *mongo.Collection, jobs chan<- bson.M) error {
	ctx, cancel := context.WithTimeout(context.Background(), config.CursorTimeout)
	defer cancel()

	// Resume after the last checkpointed _id when it is a valid ObjectID.
	filter := bson.M{}
	if syncState.LastProcessedID != "" {
		objectID, err := primitive.ObjectIDFromHex(syncState.LastProcessedID)
		if err != nil {
			log.Printf("Invalid LastProcessedID: %v, starting from beginning", err)
		} else {
			filter = bson.M{"_id": bson.M{"$gt": objectID}}
		}
	}
	findOptions := options.Find().SetBatchSize(config.BatchSize).SetSort(bson.M{"_id": 1})

	cursor, err := sourceColl.Find(ctx, filter, findOptions)
	if err != nil {
		return fmt.Errorf("failed to find documents in source collection: %w", err)
	}
	defer func() {
		// ctx may already have expired when we get here, so close the cursor
		// with a fresh, bounded context.
		closeCtx, closeCancel := context.WithTimeout(context.Background(), config.Timeout)
		defer closeCancel()
		if err := cursor.Close(closeCtx); err != nil {
			log.Printf("Failed to close source cursor: %v", err)
		}
	}()

	for cursor.Next(ctx) {
		var doc bson.M
		if err := cursor.Decode(&doc); err != nil {
			log.Printf("Failed to decode document: %v", err)
			continue
		}
		jobs <- doc

		stateMutex.Lock()
		if oid, ok := doc["_id"].(primitive.ObjectID); ok {
			syncState.LastProcessedID = oid.Hex()
		}
		syncState.ProcessedCount++
		processed := syncState.ProcessedCount
		stateMutex.Unlock()

		// Checkpoint and report progress every 1000 documents.
		if processed%1000 == 0 {
			if err := saveSyncState("sync_state.json"); err != nil {
				log.Printf("Failed to save sync state: %v", err)
			}
			printProgress(processed, int32(cursor.RemainingBatchLength()))
		}
	}

	if err := cursor.Err(); err != nil {
		return fmt.Errorf("cursor error: %w", err)
	}

	return nil
}

// connectMongoDB creates a client for uri with the configured pool size and
// verifies the connection with a ping. Both steps share one config.Timeout
// budget. On failure it returns a nil client and a wrapped error.
func connectMongoDB(uri string) (*mongo.Client, error) {
	ctx, cancel := context.WithTimeout(context.Background(), config.Timeout)
	defer cancel()

	clientOptions := options.Client().ApplyURI(uri).SetMaxPoolSize(config.MaxPoolSize)
	client, err := mongo.Connect(ctx, clientOptions)
	if err != nil {
		return nil, fmt.Errorf("failed to create MongoDB client: %w", err)
	}

	if err := client.Ping(ctx, nil); err != nil {
		// The client was already started by Connect; release it instead of
		// leaking it. ctx may be expired by now, so use a fresh one.
		discCtx, discCancel := context.WithTimeout(context.Background(), config.Timeout)
		defer discCancel()
		if discErr := client.Disconnect(discCtx); discErr != nil {
			log.Printf("Failed to disconnect after unsuccessful ping: %v", discErr)
		}
		return nil, fmt.Errorf("failed to ping MongoDB: %w", err)
	}

	return client, nil
}

// worker drains jobs, grouping documents into slices of config.BatchSize and
// writing each full group to targetColl. When jobs is closed, whatever is left
// over is written as a final, smaller group. Insert failures are logged and
// the worker moves on. wg is marked done when the worker returns.
func worker(targetColl *mongo.Collection, jobs <-chan bson.M, wg *sync.WaitGroup) {
	defer wg.Done()

	var pending []interface{}
	for {
		doc, open := <-jobs
		if !open {
			break
		}

		pending = append(pending, doc)
		if len(pending) < int(config.BatchSize) {
			continue
		}

		// A full group is ready: write it and reuse the backing array.
		insertErr := retryInsertBatch(context.Background(), targetColl, pending)
		if insertErr != nil {
			log.Printf("Failed to insert batch: %v", insertErr)
		}
		pending = pending[:0]
	}

	// Channel closed: flush the remainder, if any.
	if len(pending) == 0 {
		return
	}
	insertErr := retryInsertBatch(context.Background(), targetColl, pending)
	if insertErr != nil {
		log.Printf("Failed to insert final batch: %v", insertErr)
	}
}

// retryInsertBatch writes batch to coll with InsertMany, trying up to
// config.MaxRetries times (at least once) and waiting 1s, 2s, 3s, ... between
// attempts. An empty batch is a no-op.
//
// It returns nil on the first successful insert, an error wrapping ctx.Err()
// if ctx is cancelled while waiting, and otherwise an error wrapping the last
// insert failure.
//
// NOTE(review): InsertMany is ordered by default, so a retry after a partially
// applied batch will presumably fail on a duplicate _id; consider unordered
// inserts plus duplicate-key tolerance if that matters for resumed runs.
func retryInsertBatch(ctx context.Context, coll *mongo.Collection, batch []interface{}) error {
	if len(batch) == 0 {
		return nil
	}

	// Always make at least one attempt, even with a non-positive MaxRetries.
	attempts := config.MaxRetries
	if attempts < 1 {
		attempts = 1
	}

	var err error
	for i := 0; i < attempts; i++ {
		if _, err = coll.InsertMany(ctx, batch); err == nil {
			return nil
		}
		log.Printf("Retry %d: Failed to insert batch: %v", i+1, err)

		// No point in waiting after the final attempt.
		if i == attempts-1 {
			break
		}

		// Linear backoff that also honours cancellation.
		select {
		case <-ctx.Done():
			return fmt.Errorf("insert batch aborted after %d attempts: %w", i+1, ctx.Err())
		case <-time.After(time.Second * time.Duration(i+1)):
		}
	}
	return fmt.Errorf("failed to insert batch after %d retries: %w", attempts, err)
}

// validateSync checks the migration by comparing the number of documents in
// the source and target collections. Both counts share one config.Timeout
// budget. It returns nil when the counts are equal, and an error when either
// count fails or the numbers differ.
func validateSync(sourceColl, targetColl *mongo.Collection) error {
	ctx, cancel := context.WithTimeout(context.Background(), config.Timeout)
	defer cancel()

	nSource, srcErr := sourceColl.CountDocuments(ctx, bson.M{})
	if srcErr != nil {
		return fmt.Errorf("failed to count source documents: %v", srcErr)
	}

	nTarget, dstErr := targetColl.CountDocuments(ctx, bson.M{})
	if dstErr != nil {
		return fmt.Errorf("failed to count target documents: %v", dstErr)
	}

	if nSource == nTarget {
		return nil
	}
	return fmt.Errorf("document count mismatch: source=%d, target=%d", nSource, nTarget)
}

// printProgress rewrites the current terminal line on standard output with the
// running document count and the number of documents left in the current
// cursor batch. The leading carriage return and the missing newline are what
// make successive calls overwrite each other.
func printProgress(processedCount int64, remainingBatch int32) {
	const layout = "\rProcessed: %d, Remaining in batch: %d"
	fmt.Fprintf(os.Stdout, layout, processedCount, remainingBatch)
}

func loadSyncState(filename string) error {
	file, err := os.Open(filename)
	if err != nil {
		if os.IsNotExist(err) {
			return nil // 文件不存在,使用默认状态
		}
		return err
	}
	defer file.Close()

	decoder := json.NewDecoder(file)
	return decoder.Decode(&syncState)
}

// saveSyncState writes syncState as JSON to filename.
//
// The data is first written to a sibling temporary file (filename + ".tmp")
// and then renamed over filename, so an interrupted write cannot truncate or
// corrupt an existing checkpoint. Any create, encode, close or rename error is
// returned, and the temporary file is removed on failure.
func saveSyncState(filename string) error {
	tmpName := filename + ".tmp"
	file, err := os.Create(tmpName)
	if err != nil {
		return err
	}

	if err := json.NewEncoder(file).Encode(syncState); err != nil {
		// Best-effort cleanup; the encode error is the one worth reporting.
		_ = file.Close()
		_ = os.Remove(tmpName)
		return fmt.Errorf("encoding sync state: %w", err)
	}

	// Close explicitly: a failed close can mean the data never hit the disk.
	if err := file.Close(); err != nil {
		_ = os.Remove(tmpName)
		return fmt.Errorf("closing %q: %w", tmpName, err)
	}

	if err := os.Rename(tmpName, filename); err != nil {
		_ = os.Remove(tmpName)
		return fmt.Errorf("replacing %q: %w", filename, err)
	}
	return nil
}

// init configures process-wide defaults before main runs: log lines carry
// date, time with microseconds and the short file:line of the caller, and
// GOMAXPROCS is set to the number of CPUs reported by the runtime.
func init() {
	const logFlags = log.Ldate | log.Ltime | log.Lmicroseconds | log.Lshortfile
	log.SetFlags(logFlags)

	cpus := runtime.NumCPU()
	runtime.GOMAXPROCS(cpus)
}