[feat] - Support S3 Source Resumption (#3570)
* add config option for s3 resumption

* updates

* initial progress tracking logic

* more testing

* revert s3 source file

* UpdateScanProgress tests

* adjust

* updates

* invert

* updates

* updates

* fix

* update

* adjust test

* fix

* remove progress tracking

* cleanup

* cleanup

* remove dupe

* remove context cancellation logic

* fix comment format

* make resumption logic more clear

* rename

* fixes

* update

* add edge case test

* remove dupe mu

* add comment

* fix comment
ahrav authored Nov 22, 2024
1 parent 9a6cad9 commit e495661
Showing 3 changed files with 343 additions and 37 deletions.
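
The diff wires a new resumption option on the S3 connection into a `Checkpointer` that records scan progress as buckets and objects are processed. As a rough caller-side sketch (not part of the commit), enabling the option could look like the following; the `EnableResumption` field name is inferred from the `conn.GetEnableResumption()` call in `Init` below, and the helper shown here is purely illustrative.

```go
// Caller-side sketch, not part of the diff: building an S3 connection with the
// new resumption flag. EnableResumption is an assumed field name, inferred from
// the conn.GetEnableResumption() call in Init.
package example

import (
	"google.golang.org/protobuf/types/known/anypb"

	"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
)

// s3ConnWithResumption returns the *anypb.Any that would be handed to
// (*Source).Init as its connection argument.
func s3ConnWithResumption(buckets []string) (*anypb.Any, error) {
	conn := &sourcespb.S3{
		Buckets:          buckets,
		EnableResumption: true, // assumed proto field behind GetEnableResumption()
	}
	return anypb.New(conn)
}
```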
pkg/sources/s3/s3.go (215 changes: 180 additions & 35 deletions)
@@ -2,6 +2,7 @@ package s3

import (
"fmt"
"slices"
"strings"
"sync"
"sync/atomic"
@@ -43,8 +44,10 @@ type Source struct {
jobID sources.JobID
verify bool
concurrency int
conn *sourcespb.S3

checkpointer *Checkpointer
sources.Progress
conn *sourcespb.S3

errorCount *sync.Map
jobPool *errgroup.Group
Expand All @@ -67,7 +70,7 @@ func (s *Source) JobID() sources.JobID { return s.jobID }

// Init returns an initialized AWS source
func (s *Source) Init(
_ context.Context,
ctx context.Context,
name string,
jobID sources.JobID,
sourceID sources.SourceID,
@@ -90,6 +93,8 @@ func (s *Source) Init(
}
s.conn = &conn

s.checkpointer = NewCheckpointer(ctx, conn.GetEnableResumption(), &s.Progress)

s.setMaxObjectSize(conn.GetMaxObjectSize())

if len(conn.GetBuckets()) > 0 && len(conn.GetIgnoreBuckets()) > 0 {
@@ -173,9 +178,16 @@ func (s *Source) newClient(region, roleArn string) (*s3.S3, error) {
return s3.New(sess), nil
}

// IAM identity needs s3:ListBuckets permission
// getBucketsToScan returns a list of S3 buckets to scan.
// If the connection has a list of buckets specified, those are returned.
// Otherwise, it lists all buckets the client has access to and filters out the ignored ones.
// The list of buckets is sorted lexicographically to ensure consistent ordering,
// which allows resuming scanning from the same place if the scan is interrupted.
//
// Note: The IAM identity needs the s3:ListBuckets permission.
func (s *Source) getBucketsToScan(client *s3.S3) ([]string, error) {
if buckets := s.conn.GetBuckets(); len(buckets) > 0 {
slices.Sort(buckets)
return buckets, nil
}

@@ -196,32 +208,122 @@ func (s *Source) getBucketsToScan(client *s3.S3) ([]string, error) {
bucketsToScan = append(bucketsToScan, name)
}
}
slices.Sort(bucketsToScan)

return bucketsToScan, nil
}

// pageMetadata contains metadata about a single page of S3 objects being scanned.
type pageMetadata struct {
bucket string // The name of the S3 bucket being scanned
pageNumber int // Current page number in the pagination sequence
client *s3.S3 // AWS S3 client configured for the appropriate region
page *s3.ListObjectsV2Output // Contains the list of S3 objects in this page
}

// processingState tracks the state of concurrent S3 object processing.
type processingState struct {
errorCount *sync.Map // Thread-safe map tracking errors per prefix
objectCount *uint64 // Total number of objects processed
}

// resumePosition tracks where to restart scanning S3 buckets and objects after an interruption.
// It encapsulates all the information needed to resume a scan from its last known position.
type resumePosition struct {
bucket string // The bucket name we were processing
index int // Index in the buckets slice where we should resume
startAfter string // The last processed object key within the bucket
isNewScan bool // True if we're starting a fresh scan
exactMatch bool // True if we found the exact bucket we were previously processing
}

// determineResumePosition calculates where to resume scanning from based on the last saved checkpoint
// and the current list of available buckets to scan. It handles several scenarios:
//
// 1. If getting the resume point fails or there is no previous bucket saved (CurrentBucket is empty),
// we start a new scan from the beginning; this is the safest option.
//
// 2. If the previous bucket exists in our current scan list (exactMatch=true),
// we resume from that exact position and use the StartAfter value
// to continue from the last processed object within that bucket.
//
// 3. If the previous bucket is not found in our current scan list (exactMatch=false), this typically means:
// - The bucket was deleted since our last scan
// - The bucket was explicitly excluded from this scan's configuration
// - The IAM role no longer has access to the bucket
// - The bucket name changed due to a configuration update
// In this case, we use binary search to find the closest position where the bucket would have been,
// allowing us to resume from the nearest available point in our sorted bucket list rather than
// restarting the entire scan.
func determineResumePosition(ctx context.Context, tracker *Checkpointer, buckets []string) resumePosition {
resumePoint, err := tracker.ResumePoint(ctx)
if err != nil {
ctx.Logger().Error(err, "failed to get resume point; starting from the beginning")
return resumePosition{isNewScan: true}
}

if resumePoint.CurrentBucket == "" {
return resumePosition{isNewScan: true}
}

startIdx, found := slices.BinarySearch(buckets, resumePoint.CurrentBucket)
return resumePosition{
bucket: resumePoint.CurrentBucket,
startAfter: resumePoint.StartAfter,
index: startIdx,
exactMatch: found,
}
}
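
As a standalone illustration of the fallback described above (not part of the diff): when the checkpointed bucket is missing from the sorted scan list, slices.BinarySearch reports the index where it would have been inserted, so the scan resumes at the next bucket in lexical order rather than restarting.

```go
// Illustration of determineResumePosition's fallback path. Bucket names here
// are made up; only the slices.BinarySearch behavior is being demonstrated.
package main

import (
	"fmt"
	"slices"
)

func main() {
	buckets := []string{"alpha-logs", "beta-assets", "delta-backups"}
	slices.Sort(buckets) // mirrors the sorting done in getBucketsToScan

	// Suppose the last scan checkpointed "charlie-data", which has since been
	// deleted or excluded from the configuration.
	idx, found := slices.BinarySearch(buckets, "charlie-data")
	fmt.Println(idx, found) // 2 false -> resume at "delta-backups" with exactMatch=false
}
```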

func (s *Source) scanBuckets(
ctx context.Context,
client *s3.S3,
role string,
bucketsToScan []string,
chunksChan chan *sources.Chunk,
) {
var objectCount uint64

if role != "" {
ctx = context.WithValue(ctx, "role", role)
}
var objectCount uint64

for i, bucket := range bucketsToScan {
pos := determineResumePosition(ctx, s.checkpointer, bucketsToScan)
switch {
case pos.isNewScan:
ctx.Logger().Info("Starting new scan from beginning")
case !pos.exactMatch:
ctx.Logger().Info(
"Resume bucket no longer available, starting from closest position",
"original_bucket", pos.bucket,
"position", pos.index,
)
default:
ctx.Logger().Info(
"Resuming scan from previous scan's bucket",
"bucket", pos.bucket,
"position", pos.index,
)
}

bucketsToScanCount := len(bucketsToScan)
for bucketIdx := pos.index; bucketIdx < bucketsToScanCount; bucketIdx++ {
bucket := bucketsToScan[bucketIdx]
ctx := context.WithValue(ctx, "bucket", bucket)

if common.IsDone(ctx) {
ctx.Logger().Error(ctx.Err(), "context done while scanning bucket")
return
}

s.SetProgressComplete(i, len(bucketsToScan), fmt.Sprintf("Bucket: %s", bucket), "")
ctx.Logger().V(3).Info("Scanning bucket")

s.SetProgressComplete(
bucketIdx,
len(bucketsToScan),
fmt.Sprintf("Bucket: %s", bucket),
s.Progress.EncodedResumeInfo,
)

regionalClient, err := s.getRegionalClientForBucket(ctx, client, role, bucket)
if err != nil {
ctx.Logger().Error(err, "could not get regional client for bucket")
@@ -230,10 +332,33 @@ func (s *Source) scanBuckets(

errorCount := sync.Map{}

input := &s3.ListObjectsV2Input{Bucket: &bucket}
if bucket == pos.bucket && pos.startAfter != "" {
input.StartAfter = &pos.startAfter
ctx.Logger().V(3).Info(
"Resuming bucket scan",
"start_after", pos.startAfter,
)
}

pageNumber := 1
err = regionalClient.ListObjectsV2PagesWithContext(
ctx, &s3.ListObjectsV2Input{Bucket: &bucket},
ctx,
input,
func(page *s3.ListObjectsV2Output, _ bool) bool {
s.pageChunker(ctx, regionalClient, chunksChan, bucket, page, &errorCount, i+1, &objectCount)
pageMetadata := pageMetadata{
bucket: bucket,
pageNumber: pageNumber,
client: regionalClient,
page: page,
}
processingState := processingState{
errorCount: &errorCount,
objectCount: &objectCount,
}
s.pageChunker(ctx, pageMetadata, processingState, chunksChan)

pageNumber++
return true
})

@@ -249,6 +374,7 @@ func (s *Source) scanBuckets(
}
}
}

s.SetProgressComplete(
len(bucketsToScan),
len(bucketsToScan),
@@ -289,29 +415,25 @@ func (s *Source) getRegionalClientForBucket(
return regionalClient, nil
}

// pageChunker emits chunks onto the given channel from a page
// pageChunker emits chunks onto the given channel from a page.
func (s *Source) pageChunker(
ctx context.Context,
client *s3.S3,
metadata pageMetadata,
state processingState,
chunksChan chan *sources.Chunk,
bucket string,
page *s3.ListObjectsV2Output,
errorCount *sync.Map,
pageNumber int,
objectCount *uint64,
) {
for _, obj := range page.Contents {
s.checkpointer.Reset() // Reset the checkpointer for each PAGE
ctx = context.WithValues(ctx, "bucket", metadata.bucket, "page_number", metadata.pageNumber)

for objIdx, obj := range metadata.page.Contents {
if obj == nil {
if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
ctx.Logger().Error(err, "could not update progress for nil object")
}
continue
}

ctx = context.WithValues(
ctx,
"key", *obj.Key,
"bucket", bucket,
"page", pageNumber,
"size", *obj.Size,
)
ctx = context.WithValues(ctx, "key", *obj.Key, "size", *obj.Size)

if common.IsDone(ctx) {
return
@@ -320,29 +442,44 @@ func (s *Source) pageChunker(
// Skip GLACIER and GLACIER_IR objects.
if obj.StorageClass == nil || strings.Contains(*obj.StorageClass, "GLACIER") {
ctx.Logger().V(5).Info("Skipping object in storage class", "storage_class", *obj.StorageClass)
if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
ctx.Logger().Error(err, "could not update progress for glacier object")
}
continue
}

// Ignore large files.
if *obj.Size > s.maxObjectSize {
ctx.Logger().V(5).Info("Skipping %d byte file (over maxObjectSize limit)")
if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
ctx.Logger().Error(err, "could not update progress for large file")
}
continue
}

// Skip empty files.
if *obj.Size == 0 {
ctx.Logger().V(5).Info("Skipping empty file")
if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
ctx.Logger().Error(err, "could not update progress for empty file")
}
continue
}

// Skip incompatible extensions.
if common.SkipFile(*obj.Key) {
ctx.Logger().V(5).Info("Skipping file with incompatible extension")
if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
ctx.Logger().Error(err, "could not update progress for incompatible file")
}
continue
}

s.jobPool.Go(func() error {
defer common.RecoverWithExit(ctx)
if common.IsDone(ctx) {
return ctx.Err()
}

if strings.HasSuffix(*obj.Key, "/") {
ctx.Logger().V(5).Info("Skipping directory")
@@ -352,7 +489,7 @@ func (s *Source) pageChunker(
path := strings.Split(*obj.Key, "/")
prefix := strings.Join(path[:len(path)-1], "/")

nErr, ok := errorCount.Load(prefix)
nErr, ok := state.errorCount.Load(prefix)
if !ok {
nErr = 0
}
@@ -366,8 +503,8 @@ func (s *Source) pageChunker(
objCtx, cancel := context.WithTimeout(ctx, getObjectTimeout)
defer cancel()

res, err := client.GetObjectWithContext(objCtx, &s3.GetObjectInput{
Bucket: &bucket,
res, err := metadata.client.GetObjectWithContext(objCtx, &s3.GetObjectInput{
Bucket: &metadata.bucket,
Key: obj.Key,
})
if err != nil {
@@ -382,7 +519,7 @@ func (s *Source) pageChunker(
res.Body.Close()
}

nErr, ok := errorCount.Load(prefix)
nErr, ok := state.errorCount.Load(prefix)
if !ok {
nErr = 0
}
@@ -391,7 +528,7 @@ func (s *Source) pageChunker(
return nil
}
nErr = nErr.(int) + 1
errorCount.Store(prefix, nErr)
state.errorCount.Store(prefix, nErr)
// too many consecutive errors on this page
if nErr.(int) > 3 {
ctx.Logger().V(2).Info("Too many consecutive errors, excluding prefix", "prefix", prefix)
@@ -413,9 +550,9 @@ func (s *Source) pageChunker(
SourceMetadata: &source_metadatapb.MetaData{
Data: &source_metadatapb.MetaData_S3{
S3: &source_metadatapb.S3{
Bucket: bucket,
Bucket: metadata.bucket,
File: sanitizer.UTF8(*obj.Key),
Link: sanitizer.UTF8(makeS3Link(bucket, *client.Config.Region, *obj.Key)),
Link: sanitizer.UTF8(makeS3Link(metadata.bucket, *metadata.client.Config.Region, *obj.Key)),
Email: sanitizer.UTF8(email),
Timestamp: sanitizer.UTF8(modified),
},
@@ -429,14 +566,19 @@ func (s *Source) pageChunker(
return nil
}

atomic.AddUint64(objectCount, 1)
ctx.Logger().V(5).Info("S3 object scanned.", "object_count", objectCount)
nErr, ok = errorCount.Load(prefix)
atomic.AddUint64(state.objectCount, 1)
ctx.Logger().V(5).Info("S3 object scanned.", "object_count", state.objectCount)
nErr, ok = state.errorCount.Load(prefix)
if !ok {
nErr = 0
}
if nErr.(int) > 0 {
errorCount.Store(prefix, 0)
state.errorCount.Store(prefix, 0)
}

// Update progress after successful processing.
if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.page.Contents); err != nil {
ctx.Logger().Error(err, "could not update progress for scanned object")
}

return nil
Expand Down Expand Up @@ -485,6 +627,9 @@ func (s *Source) validateBucketAccess(ctx context.Context, client *s3.S3, roleAr
// for each role, passing in the default S3 client, the role ARN, and the list of
// buckets to scan.
//
// The provided function parameter typically implements the core scanning logic
// and must handle context cancellation appropriately.
//
// If no roles are configured, it will call the function with an empty role ARN.
func (s *Source) visitRoles(
ctx context.Context,
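The remaining changed files, presumably the checkpointer implementation and its tests, are not rendered in this view. For orientation, the hedged sketch below reconstructs the checkpointer's surface purely from its call sites in s3.go; any name not visible above (the ResumeInfo struct name, the sketch's package name) is an assumption.

```go
// Hedged sketch, not part of the commit: the checkpointer surface as inferred
// from NewCheckpointer, Reset, UpdateObjectCompletion, and ResumePoint usage
// in s3.go. The commit implements it as a concrete *Checkpointer type.
package s3sketch

import (
	awss3 "github.com/aws/aws-sdk-go/service/s3"

	"github.com/trufflesecurity/trufflehog/v3/pkg/context"
	"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
)

// ResumeInfo is the persisted position returned by ResumePoint. The struct
// name is assumed; the field names match their usage in determineResumePosition.
type ResumeInfo struct {
	CurrentBucket string // bucket being scanned when the checkpoint was written
	StartAfter    string // last fully processed object key within that bucket
}

// checkpointer summarizes the methods s3.go relies on.
type checkpointer interface {
	// Reset clears per-page completion state; pageChunker calls it once at the
	// start of every ListObjectsV2 page.
	Reset()

	// UpdateObjectCompletion marks the object at objIdx within the page as
	// handled (scanned or deliberately skipped) and persists the resume point
	// into sources.Progress.
	UpdateObjectCompletion(ctx context.Context, objIdx int, bucket string, page []*awss3.Object) error

	// ResumePoint returns the last persisted position; a zero CurrentBucket
	// means there is nothing to resume.
	ResumePoint(ctx context.Context) (ResumeInfo, error)
}

// newCheckpointer mirrors the constructor observed in Init:
// NewCheckpointer(ctx, conn.GetEnableResumption(), &s.Progress).
// When enabled is false it is presumably a no-op recorder.
type newCheckpointer func(ctx context.Context, enabled bool, progress *sources.Progress) checkpointer
```

Under that shape, the per-object UpdateObjectCompletion calls in pageChunker are presumably what keep Progress.EncodedResumeInfo current, which is why scanBuckets now passes s.Progress.EncodedResumeInfo into SetProgressComplete instead of an empty string.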
