You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
386 lines
9.3 KiB
386 lines
9.3 KiB
1 year ago
|
package pkg
|
||
|
|
||
|
import (
|
||
|
"bufio"
|
||
|
"context"
|
||
|
"crypto/md5"
|
||
|
"encoding/hex"
|
||
|
"encoding/json"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"log"
|
||
|
"net/http"
|
||
|
"net/url"
|
||
|
"os"
|
||
|
"path/filepath"
|
||
|
"strconv"
|
||
|
"strings"
|
||
|
"sync"
|
||
|
"sync/atomic"
|
||
|
"time"
|
||
|
|
||
|
"github.com/OpenIMSDK/protocol/third"
|
||
|
)
|
||
|
|
||
|
// Upload describes one entry of the task file: a remote object that is
// downloaded and then uploaded again through the API. Every non-empty line of
// the task file is JSON-decoded into this struct.
type Upload struct {
	// URL is the address the object is downloaded from.
	URL string `json:"url"`
	// Name is sent as the object name when the upload is initiated and completed.
	Name string `json:"name"`
	// ContentType, when non-empty, is used instead of the sniffed content type.
	ContentType string `json:"contentType"`
}
|
||
|
|
||
|
type Task struct {
|
||
|
Index int
|
||
|
Upload Upload
|
||
|
}
|
||
|
|
||
|
// PartInfo is the result of the hashing pass over a downloaded object: how the
// object is split into parts and the checksums of the parts and of the whole.
type PartInfo struct {
	// ContentType is the type sniffed from the first chunk read from the object.
	ContentType string
	// PartSize is the size of every part except possibly the last one.
	PartSize int64
	// PartNum is the number of parts.
	PartNum int
	// FileMd5 is the hex MD5 of the whole object.
	FileMd5 string
	// PartMd5 is the hex MD5 of the comma-joined PartMd5s.
	PartMd5 string
	// PartSizes holds the exact size of each part, in order.
	PartSizes []int64
	// PartMd5s holds the hex MD5 of each part, in order.
	PartMd5s []string
}
|
||
|
|
||
|
func Run(conf Config) error {
|
||
|
m := &Manage{
|
||
|
prefix: time.Now().Format("20060102150405"),
|
||
|
conf: &conf,
|
||
|
ctx: context.Background(),
|
||
|
}
|
||
|
return m.Run()
|
||
|
}
|
||
|
|
||
|
type Manage struct {
|
||
|
conf *Config
|
||
|
ctx context.Context
|
||
|
api *Api
|
||
|
partLimit *third.PartLimitResp
|
||
|
prefix string
|
||
|
tasks chan Task
|
||
|
id uint64
|
||
|
success int64
|
||
|
failed int64
|
||
|
}
|
||
|
|
||
|
func (m *Manage) tempFilePath() string {
|
||
|
return filepath.Join(m.conf.TempDir, fmt.Sprintf("%s_%d", m.prefix, atomic.AddUint64(&m.id, 1)))
|
||
|
}
|
||
|
|
||
|
func (m *Manage) Run() error {
|
||
|
defer func(start time.Time) {
|
||
|
log.Printf("run time %s\n", time.Since(start))
|
||
|
}(time.Now())
|
||
|
m.api = &Api{
|
||
|
Api: m.conf.Api,
|
||
|
UserID: m.conf.UserID,
|
||
|
Secret: m.conf.Secret,
|
||
|
Client: &http.Client{Timeout: m.conf.Timeout},
|
||
|
}
|
||
|
var err error
|
||
|
ctx := context.WithValue(m.ctx, "operationID", fmt.Sprintf("%s_init", m.prefix))
|
||
|
m.api.Token, err = m.api.GetToken(ctx)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
m.partLimit, err = m.api.GetPartLimit(ctx)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
progress, err := ReadProgress(m.conf.ProgressPath)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
progressFile, err := os.OpenFile(m.conf.ProgressPath, os.O_CREATE|os.O_RDWR|os.O_APPEND, 0666)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
var mutex sync.Mutex
|
||
|
writeSuccessIndex := func(index int) {
|
||
|
mutex.Lock()
|
||
|
defer mutex.Unlock()
|
||
|
if _, err := progressFile.Write([]byte(strconv.Itoa(index) + "\n")); err != nil {
|
||
|
log.Printf("write progress err: %v\n", err)
|
||
|
}
|
||
|
}
|
||
|
file, err := os.Open(m.conf.TaskPath)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
m.tasks = make(chan Task, m.conf.Concurrency*2)
|
||
|
go func() {
|
||
|
defer file.Close()
|
||
|
defer close(m.tasks)
|
||
|
scanner := bufio.NewScanner(file)
|
||
|
var (
|
||
|
index int
|
||
|
num int
|
||
|
)
|
||
|
for scanner.Scan() {
|
||
|
line := strings.TrimSpace(scanner.Text())
|
||
|
if line == "" {
|
||
|
continue
|
||
|
}
|
||
|
index++
|
||
|
if progress.IsUploaded(index) {
|
||
|
log.Printf("index: %d already uploaded %s\n", index, line)
|
||
|
continue
|
||
|
}
|
||
|
var upload Upload
|
||
|
if err := json.Unmarshal([]byte(line), &upload); err != nil {
|
||
|
log.Printf("index: %d json.Unmarshal(%s) err: %v", index, line, err)
|
||
|
continue
|
||
|
}
|
||
|
num++
|
||
|
m.tasks <- Task{
|
||
|
Index: index,
|
||
|
Upload: upload,
|
||
|
}
|
||
|
}
|
||
|
if num == 0 {
|
||
|
log.Println("mark all completed")
|
||
|
}
|
||
|
}()
|
||
|
var wg sync.WaitGroup
|
||
|
wg.Add(m.conf.Concurrency)
|
||
|
for i := 0; i < m.conf.Concurrency; i++ {
|
||
|
go func(tid int) {
|
||
|
defer wg.Done()
|
||
|
for task := range m.tasks {
|
||
|
var success bool
|
||
|
for n := 0; n < m.conf.Retry; n++ {
|
||
|
ctx := context.WithValue(m.ctx, "operationID", fmt.Sprintf("%s_%d_%d_%d", m.prefix, tid, task.Index, n+1))
|
||
|
if urlRaw, err := m.RunTask(ctx, task); err == nil {
|
||
|
writeSuccessIndex(task.Index)
|
||
|
log.Println("index:", task.Index, "upload success", "urlRaw", urlRaw)
|
||
|
success = true
|
||
|
break
|
||
|
} else {
|
||
|
log.Printf("index: %d upload: %+v err: %v", task.Index, task.Upload, err)
|
||
|
}
|
||
|
}
|
||
|
if success {
|
||
|
atomic.AddInt64(&m.success, 1)
|
||
|
} else {
|
||
|
atomic.AddInt64(&m.failed, 1)
|
||
|
log.Printf("index: %d upload: %+v failed", task.Index, task.Upload)
|
||
|
}
|
||
|
}
|
||
|
}(i + 1)
|
||
|
}
|
||
|
wg.Wait()
|
||
|
log.Printf("execution completed success %d failed %d\n", m.success, m.failed)
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (m *Manage) RunTask(ctx context.Context, task Task) (string, error) {
|
||
|
resp, err := m.HttpGet(ctx, task.Upload.URL)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
defer resp.Body.Close()
|
||
|
reader, err := NewReader(resp.Body, m.conf.CacheSize, m.tempFilePath())
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
defer reader.Close()
|
||
|
part, err := m.getPartInfo(ctx, reader, reader.Size())
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
var contentType string
|
||
|
if task.Upload.ContentType == "" {
|
||
|
contentType = part.ContentType
|
||
|
} else {
|
||
|
contentType = task.Upload.ContentType
|
||
|
}
|
||
|
initiateMultipartUploadResp, err := m.api.InitiateMultipartUpload(ctx, &third.InitiateMultipartUploadReq{
|
||
|
Hash: part.PartMd5,
|
||
|
Size: reader.Size(),
|
||
|
PartSize: part.PartSize,
|
||
|
MaxParts: -1,
|
||
|
Cause: "batch-import",
|
||
|
Name: task.Upload.Name,
|
||
|
ContentType: contentType,
|
||
|
})
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
if initiateMultipartUploadResp.Upload == nil {
|
||
|
return initiateMultipartUploadResp.Url, nil
|
||
|
}
|
||
|
if _, err := reader.Seek(0, io.SeekStart); err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
uploadParts := make([]*third.SignPart, part.PartNum)
|
||
|
for _, part := range initiateMultipartUploadResp.Upload.Sign.Parts {
|
||
|
uploadParts[part.PartNumber-1] = part
|
||
|
}
|
||
|
for i, currentPartSize := range part.PartSizes {
|
||
|
md5Reader := NewMd5Reader(io.LimitReader(reader, currentPartSize))
|
||
|
if m.doPut(ctx, m.api.Client, initiateMultipartUploadResp.Upload.Sign, uploadParts[i], md5Reader, currentPartSize); err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
if md5val := md5Reader.Md5(); md5val != part.PartMd5s[i] {
|
||
|
return "", fmt.Errorf("upload part %d failed, md5 not match, expect %s, got %s", i, part.PartMd5s[i], md5val)
|
||
|
}
|
||
|
}
|
||
|
urlRaw, err := m.api.CompleteMultipartUpload(ctx, &third.CompleteMultipartUploadReq{
|
||
|
UploadID: initiateMultipartUploadResp.Upload.UploadID,
|
||
|
Parts: part.PartMd5s,
|
||
|
Name: task.Upload.Name,
|
||
|
ContentType: contentType,
|
||
|
Cause: "batch-import",
|
||
|
})
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
return urlRaw, nil
|
||
|
}
|
||
|
|
||
|
func (m *Manage) partSize(size int64) (int64, error) {
|
||
|
if size <= 0 {
|
||
|
return 0, errors.New("size must be greater than 0")
|
||
|
}
|
||
|
if size > m.partLimit.MaxPartSize*int64(m.partLimit.MaxNumSize) {
|
||
|
return 0, fmt.Errorf("size must be less than %db", m.partLimit.MaxPartSize*int64(m.partLimit.MaxNumSize))
|
||
|
}
|
||
|
if size <= m.partLimit.MinPartSize*int64(m.partLimit.MaxNumSize) {
|
||
|
return m.partLimit.MinPartSize, nil
|
||
|
}
|
||
|
partSize := size / int64(m.partLimit.MaxNumSize)
|
||
|
if size%int64(m.partLimit.MaxNumSize) != 0 {
|
||
|
partSize++
|
||
|
}
|
||
|
return partSize, nil
|
||
|
}
|
||
|
|
||
|
func (m *Manage) partMD5(parts []string) string {
|
||
|
s := strings.Join(parts, ",")
|
||
|
md5Sum := md5.Sum([]byte(s))
|
||
|
return hex.EncodeToString(md5Sum[:])
|
||
|
}
|
||
|
|
||
|
func (m *Manage) getPartInfo(ctx context.Context, r io.Reader, fileSize int64) (*PartInfo, error) {
|
||
|
partSize, err := m.partSize(fileSize)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
partNum := int(fileSize / partSize)
|
||
|
if fileSize%partSize != 0 {
|
||
|
partNum++
|
||
|
}
|
||
|
partSizes := make([]int64, partNum)
|
||
|
for i := 0; i < partNum; i++ {
|
||
|
partSizes[i] = partSize
|
||
|
}
|
||
|
partSizes[partNum-1] = fileSize - partSize*(int64(partNum)-1)
|
||
|
partMd5s := make([]string, partNum)
|
||
|
buf := make([]byte, 1024*8)
|
||
|
fileMd5 := md5.New()
|
||
|
var contentType string
|
||
|
for i := 0; i < partNum; i++ {
|
||
|
h := md5.New()
|
||
|
r := io.LimitReader(r, partSize)
|
||
|
for {
|
||
|
if n, err := r.Read(buf); err == nil {
|
||
|
if contentType == "" {
|
||
|
contentType = http.DetectContentType(buf[:n])
|
||
|
}
|
||
|
h.Write(buf[:n])
|
||
|
fileMd5.Write(buf[:n])
|
||
|
} else if err == io.EOF {
|
||
|
break
|
||
|
} else {
|
||
|
return nil, err
|
||
|
}
|
||
|
}
|
||
|
partMd5s[i] = hex.EncodeToString(h.Sum(nil))
|
||
|
}
|
||
|
partMd5Val := m.partMD5(partMd5s)
|
||
|
fileMd5val := hex.EncodeToString(fileMd5.Sum(nil))
|
||
|
return &PartInfo{
|
||
|
ContentType: contentType,
|
||
|
PartSize: partSize,
|
||
|
PartNum: partNum,
|
||
|
FileMd5: fileMd5val,
|
||
|
PartMd5: partMd5Val,
|
||
|
PartSizes: partSizes,
|
||
|
PartMd5s: partMd5s,
|
||
|
}, nil
|
||
|
}
|
||
|
|
||
|
func (m *Manage) doPut(ctx context.Context, client *http.Client, sign *third.AuthSignParts, part *third.SignPart, reader io.Reader, size int64) error {
|
||
|
rawURL := part.Url
|
||
|
if rawURL == "" {
|
||
|
rawURL = sign.Url
|
||
|
}
|
||
|
if len(sign.Query)+len(part.Query) > 0 {
|
||
|
u, err := url.Parse(rawURL)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
query := u.Query()
|
||
|
for i := range sign.Query {
|
||
|
v := sign.Query[i]
|
||
|
query[v.Key] = v.Values
|
||
|
}
|
||
|
for i := range part.Query {
|
||
|
v := part.Query[i]
|
||
|
query[v.Key] = v.Values
|
||
|
}
|
||
|
u.RawQuery = query.Encode()
|
||
|
rawURL = u.String()
|
||
|
}
|
||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPut, rawURL, reader)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
for i := range sign.Header {
|
||
|
v := sign.Header[i]
|
||
|
req.Header[v.Key] = v.Values
|
||
|
}
|
||
|
for i := range part.Header {
|
||
|
v := part.Header[i]
|
||
|
req.Header[v.Key] = v.Values
|
||
|
}
|
||
|
req.ContentLength = size
|
||
|
resp, err := client.Do(req)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
defer func() {
|
||
|
_ = resp.Body.Close()
|
||
|
}()
|
||
|
body, err := io.ReadAll(resp.Body)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if resp.StatusCode/200 != 1 {
|
||
|
return fmt.Errorf("PUT %s part %d failed, status code %d, body %s", rawURL, part.PartNumber, resp.StatusCode, string(body))
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (m *Manage) HttpGet(ctx context.Context, url string) (*http.Response, error) {
|
||
|
reqUrl := url
|
||
|
for {
|
||
|
request, err := http.NewRequestWithContext(ctx, http.MethodGet, reqUrl, nil)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
DefaultRequestHeader(request.Header)
|
||
|
response, err := m.api.Client.Do(request)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if response.StatusCode != http.StatusOK {
|
||
|
_ = response.Body.Close()
|
||
|
return nil, fmt.Errorf("http get %s status %s", url, response.Status)
|
||
|
}
|
||
|
return response, nil
|
||
|
}
|
||
|
}
|