// Copyright © 2023 OpenIM. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package pkg import ( "bufio" "context" "crypto/md5" "encoding/hex" "encoding/json" "fmt" "io" "log" "net/http" "net/url" "os" "path/filepath" "strconv" "strings" "sync" "sync/atomic" "time" "github.com/openimsdk/tools/errs" "github.com/openimsdk/protocol/third" ) type Upload struct { URL string `json:"url"` Name string `json:"name"` ContentType string `json:"contentType"` } type Task struct { Index int Upload Upload } type PartInfo struct { ContentType string PartSize int64 PartNum int FileMd5 string PartMd5 string PartSizes []int64 PartMd5s []string } func Run(conf Config) error { m := &Manage{ prefix: time.Now().Format("20060102150405"), conf: &conf, ctx: context.Background(), } return m.Run() } type Manage struct { conf *Config ctx context.Context api *Api partLimit *third.PartLimitResp prefix string tasks chan Task id uint64 success int64 failed int64 } func (m *Manage) tempFilePath() string { return filepath.Join(m.conf.TempDir, fmt.Sprintf("%s_%d", m.prefix, atomic.AddUint64(&m.id, 1))) } func (m *Manage) Run() error { defer func(start time.Time) { log.Printf("run time %s\n", time.Since(start)) }(time.Now()) m.api = &Api{ Api: m.conf.Api, UserID: m.conf.UserID, Secret: m.conf.Secret, Client: &http.Client{Timeout: m.conf.Timeout}, } var err error ctx := context.WithValue(m.ctx, "operationID", fmt.Sprintf("%s_init", m.prefix)) m.api.Token, err = m.api.GetAdminToken(ctx) if err != nil { return err } m.partLimit, err = m.api.GetPartLimit(ctx) if err != nil { return err } progress, err := ReadProgress(m.conf.ProgressPath) if err != nil { return err } progressFile, err := os.OpenFile(m.conf.ProgressPath, os.O_CREATE|os.O_RDWR|os.O_APPEND, 0666) if err != nil { return err } var mutex sync.Mutex writeSuccessIndex := func(index int) { mutex.Lock() defer mutex.Unlock() if _, err := progressFile.Write([]byte(strconv.Itoa(index) + "\n")); err != nil { log.Printf("write progress err: %v\n", err) } } file, err := os.Open(m.conf.TaskPath) if err != nil { return err } m.tasks = make(chan Task, m.conf.Concurrency*2) go func() { defer file.Close() defer close(m.tasks) scanner := bufio.NewScanner(file) var ( index int num int ) for scanner.Scan() { line := strings.TrimSpace(scanner.Text()) if line == "" { continue } index++ if progress.IsUploaded(index) { log.Printf("index: %d already uploaded %s\n", index, line) continue } var upload Upload if err := json.Unmarshal([]byte(line), &upload); err != nil { log.Printf("index: %d json.Unmarshal(%s) err: %v", index, line, err) continue } num++ m.tasks <- Task{ Index: index, Upload: upload, } } if num == 0 { log.Println("mark all completed") } }() var wg sync.WaitGroup wg.Add(m.conf.Concurrency) for i := 0; i < m.conf.Concurrency; i++ { go func(tid int) { defer wg.Done() for task := range m.tasks { var success bool for n := 0; n < m.conf.Retry; n++ { ctx := context.WithValue(m.ctx, "operationID", fmt.Sprintf("%s_%d_%d_%d", m.prefix, tid, task.Index, n+1)) if urlRaw, err := m.RunTask(ctx, task); err == nil { writeSuccessIndex(task.Index) log.Println("index:", task.Index, "upload success", "urlRaw", urlRaw) success = true break } else { log.Printf("index: %d upload: %+v err: %v", task.Index, task.Upload, err) } } if success { atomic.AddInt64(&m.success, 1) } else { atomic.AddInt64(&m.failed, 1) log.Printf("index: %d upload: %+v failed", task.Index, task.Upload) } } }(i + 1) } wg.Wait() log.Printf("execution completed success %d failed %d\n", m.success, m.failed) return nil } func (m *Manage) RunTask(ctx context.Context, task Task) (string, error) { resp, err := m.HttpGet(ctx, task.Upload.URL) if err != nil { return "", err } defer resp.Body.Close() reader, err := NewReader(resp.Body, m.conf.CacheSize, m.tempFilePath()) if err != nil { return "", err } defer reader.Close() part, err := m.getPartInfo(ctx, reader, reader.Size()) if err != nil { return "", err } var contentType string if task.Upload.ContentType == "" { contentType = part.ContentType } else { contentType = task.Upload.ContentType } initiateMultipartUploadResp, err := m.api.InitiateMultipartUpload(ctx, &third.InitiateMultipartUploadReq{ Hash: part.PartMd5, Size: reader.Size(), PartSize: part.PartSize, MaxParts: -1, Cause: "batch-import", Name: task.Upload.Name, ContentType: contentType, }) if err != nil { return "", err } if initiateMultipartUploadResp.Upload == nil { return initiateMultipartUploadResp.Url, nil } if _, err := reader.Seek(0, io.SeekStart); err != nil { return "", err } uploadParts := make([]*third.SignPart, part.PartNum) for _, part := range initiateMultipartUploadResp.Upload.Sign.Parts { uploadParts[part.PartNumber-1] = part } for i, currentPartSize := range part.PartSizes { md5Reader := NewMd5Reader(io.LimitReader(reader, currentPartSize)) if err := m.doPut(ctx, m.api.Client, initiateMultipartUploadResp.Upload.Sign, uploadParts[i], md5Reader, currentPartSize); err != nil { return "", err } if md5val := md5Reader.Md5(); md5val != part.PartMd5s[i] { return "", fmt.Errorf("upload part %d failed, md5 not match, expect %s, got %s", i, part.PartMd5s[i], md5val) } } urlRaw, err := m.api.CompleteMultipartUpload(ctx, &third.CompleteMultipartUploadReq{ UploadID: initiateMultipartUploadResp.Upload.UploadID, Parts: part.PartMd5s, Name: task.Upload.Name, ContentType: contentType, Cause: "batch-import", }) if err != nil { return "", err } return urlRaw, nil } func (m *Manage) partSize(size int64) (int64, error) { if size <= 0 { return 0, errs.New("size must be greater than 0") } if size > m.partLimit.MaxPartSize*int64(m.partLimit.MaxNumSize) { return 0, errs.New("size must be less than", "size", m.partLimit.MaxPartSize*int64(m.partLimit.MaxNumSize)) } if size <= m.partLimit.MinPartSize*int64(m.partLimit.MaxNumSize) { return m.partLimit.MinPartSize, nil } partSize := size / int64(m.partLimit.MaxNumSize) if size%int64(m.partLimit.MaxNumSize) != 0 { partSize++ } return partSize, nil } func (m *Manage) partMD5(parts []string) string { s := strings.Join(parts, ",") md5Sum := md5.Sum([]byte(s)) return hex.EncodeToString(md5Sum[:]) } func (m *Manage) getPartInfo(ctx context.Context, r io.Reader, fileSize int64) (*PartInfo, error) { partSize, err := m.partSize(fileSize) if err != nil { return nil, err } partNum := int(fileSize / partSize) if fileSize%partSize != 0 { partNum++ } partSizes := make([]int64, partNum) for i := 0; i < partNum; i++ { partSizes[i] = partSize } partSizes[partNum-1] = fileSize - partSize*(int64(partNum)-1) partMd5s := make([]string, partNum) buf := make([]byte, 1024*8) fileMd5 := md5.New() var contentType string for i := 0; i < partNum; i++ { h := md5.New() r := io.LimitReader(r, partSize) for { if n, err := r.Read(buf); err == nil { if contentType == "" { contentType = http.DetectContentType(buf[:n]) } h.Write(buf[:n]) fileMd5.Write(buf[:n]) } else if err == io.EOF { break } else { return nil, err } } partMd5s[i] = hex.EncodeToString(h.Sum(nil)) } partMd5Val := m.partMD5(partMd5s) fileMd5val := hex.EncodeToString(fileMd5.Sum(nil)) return &PartInfo{ ContentType: contentType, PartSize: partSize, PartNum: partNum, FileMd5: fileMd5val, PartMd5: partMd5Val, PartSizes: partSizes, PartMd5s: partMd5s, }, nil } func (m *Manage) doPut(ctx context.Context, client *http.Client, sign *third.AuthSignParts, part *third.SignPart, reader io.Reader, size int64) error { rawURL := part.Url if rawURL == "" { rawURL = sign.Url } if len(sign.Query)+len(part.Query) > 0 { u, err := url.Parse(rawURL) if err != nil { return err } query := u.Query() for i := range sign.Query { v := sign.Query[i] query[v.Key] = v.Values } for i := range part.Query { v := part.Query[i] query[v.Key] = v.Values } u.RawQuery = query.Encode() rawURL = u.String() } req, err := http.NewRequestWithContext(ctx, http.MethodPut, rawURL, reader) if err != nil { return err } for i := range sign.Header { v := sign.Header[i] req.Header[v.Key] = v.Values } for i := range part.Header { v := part.Header[i] req.Header[v.Key] = v.Values } req.ContentLength = size resp, err := client.Do(req) if err != nil { return err } defer func() { _ = resp.Body.Close() }() body, err := io.ReadAll(resp.Body) if err != nil { return err } if resp.StatusCode/200 != 1 { return fmt.Errorf("PUT %s part %d failed, status code %d, body %s", rawURL, part.PartNumber, resp.StatusCode, string(body)) } return nil } func (m *Manage) HttpGet(ctx context.Context, url string) (*http.Response, error) { reqUrl := url for { request, err := http.NewRequestWithContext(ctx, http.MethodGet, reqUrl, nil) if err != nil { return nil, err } DefaultRequestHeader(request.Header) response, err := m.api.Client.Do(request) if err != nil { return nil, err } if response.StatusCode != http.StatusOK { _ = response.Body.Close() return nil, fmt.Errorf("webhook get %s status %s", url, response.Status) } return response, nil } }