diff --git a/internal/routers/api/attachment.go b/internal/routers/api/attachment.go index 6ff07630..447058f6 100644 --- a/internal/routers/api/attachment.go +++ b/internal/routers/api/attachment.go @@ -53,6 +53,21 @@ func GetImageSize(img image.Rectangle) (int, int) { return width, height } +func fileCheck(uploadType string, size int64) error { + if uploadType != "public/video" && + uploadType != "public/image" && + uploadType != "public/avatar" && + uploadType != "attachment" { + return errcode.InvalidParams + } + + if size > 1024*1024*100 { + return errcode.FileInvalidSize.WithDetails("最大允许100MB") + } + + return nil +} + func UploadAttachment(c *gin.Context) { response := app.NewResponse(c) svc := service.New(c) @@ -66,16 +81,9 @@ func UploadAttachment(c *gin.Context) { } defer file.Close() - if uploadType != "public/video" && - uploadType != "public/image" && - uploadType != "public/avatar" && - uploadType != "attachment" { - response.ToErrorResponse(errcode.InvalidParams) - return - } - - if fileHeader.Size > 1024*1024*100 { - response.ToErrorResponse(errcode.FileInvalidSize.WithDetails("最大允许100MB")) + if err = fileCheck(uploadType, fileHeader.Size); err != nil { + cErr, _ := err.(*errcode.Error) + response.ToErrorResponse(cErr) return } @@ -129,20 +137,21 @@ func UploadAttachment(c *gin.Context) { attachment.UserID = userID.(int64) } - if uploadType == "public/image" || uploadType == "public/avatar" { - attachment.Type = model.ATTACHMENT_TYPE_IMAGE + var uploadAttachmentTypeMap = map[string]model.AttachmentType{ + "public/image": model.ATTACHMENT_TYPE_IMAGE, + "public/avatar": model.ATTACHMENT_TYPE_IMAGE, + "public/video": model.ATTACHMENT_TYPE_VIDEO, + "attachment": model.ATTACHMENT_TYPE_OTHER, + } - src, err := imaging.Decode(file) + attachment.Type = uploadAttachmentTypeMap[uploadType] + if attachment.Type == model.ATTACHMENT_TYPE_IMAGE { + var src image.Image + src, err = imaging.Decode(file) if err == nil { attachment.ImgWidth, attachment.ImgHeight = GetImageSize(src.Bounds()) } } - if uploadType == "public/video" { - attachment.Type = model.ATTACHMENT_TYPE_VIDEO - } - if uploadType == "attachment" { - attachment.Type = model.ATTACHMENT_TYPE_OTHER - } attachment, err = svc.CreateAttachment(attachment) if err != nil { diff --git a/internal/service/comment.go b/internal/service/comment.go index cf5bd84b..7986212d 100644 --- a/internal/service/comment.go +++ b/internal/service/comment.go @@ -186,21 +186,21 @@ func (svc *Service) DeletePostComment(comment *model.Comment) error { return svc.dao.DeleteComment(comment) } -func (svc *Service) CreatePostCommentReply(commentID int64, content string, userID, atUserID int64) (*model.CommentReply, error) { +func (svc *Service) createPostPreHandler(commentID int64, userID, atUserID int64) (*model.Post, *model.Comment, error) { // 加载Comment comment, err := svc.dao.GetCommentByID(commentID) if err != nil { - return nil, err + return nil, nil, err } // 加载comment的post post, err := svc.dao.GetPostByID(comment.PostID) if err != nil { - return nil, err + return nil, nil, err } if post.CommentCount >= global.AppSetting.MaxCommentCount { - return nil, errcode.MaxCommentCount + return nil, nil, errcode.MaxCommentCount } if userID == atUserID { @@ -215,6 +215,15 @@ func (svc *Service) CreatePostCommentReply(commentID int64, content string, user } } + return post, comment, nil +} + +func (svc *Service) CreatePostCommentReply(commentID int64, content string, userID, atUserID int64) (*model.CommentReply, error) { + var post, comment, err = svc.createPostPreHandler(commentID, userID, atUserID) + if err != nil { + return nil, err + } + // 创建评论 ip := svc.ctx.ClientIP() reply := &model.CommentReply{ diff --git a/internal/service/post.go b/internal/service/post.go index 00dec7d1..c93c37b4 100644 --- a/internal/service/post.go +++ b/internal/service/post.go @@ -57,6 +57,25 @@ type PostContentItem struct { Sort int64 `json:"sort" binding:"required"` } +// Check 检查PostContentItem属性 +func (p *PostContentItem) Check() error { + // 检查附件是否是本站资源 + if p.Type == model.CONTENT_TYPE_IMAGE || p.Type == model.CONTENT_TYPE_VIDEO || p.Type == model. + CONTENT_TYPE_ATTACHMENT { + if strings.Index(p.Content, "https://"+global.AliossSetting.AliossDomain) != 0 { + return fmt.Errorf("附件非本站资源") + } + } + // 检查链接是否合法 + if p.Type == model.CONTENT_TYPE_LINK { + if strings.Index(p.Content, "http://") != 0 && strings.Index(p.Content, "https://") != 0 { + return fmt.Errorf("链接不合法") + } + } + + return nil +} + func (svc *Service) CreatePost(userID int64, param PostCreationReq) (*model.Post, error) { ip := svc.ctx.ClientIP() @@ -82,18 +101,9 @@ func (svc *Service) CreatePost(userID int64, param PostCreationReq) (*model.Post } for _, item := range param.Contents { - - // 检查附件是否是本站资源 - if item.Type == model.CONTENT_TYPE_IMAGE || item.Type == model.CONTENT_TYPE_VIDEO || item.Type == model.CONTENT_TYPE_ATTACHMENT { - if strings.Index(item.Content, "https://"+global.AliossSetting.AliossDomain) != 0 { - continue - } - } - // 检查链接是否合法 - if item.Type == model.CONTENT_TYPE_LINK { - if strings.Index(item.Content, "http://") != 0 && strings.Index(item.Content, "https://") != 0 { - continue - } + if err = item.Check(); err != nil { + // 属性非法 + continue } if item.Type == model.CONTENT_TYPE_ATTACHMENT && param.AttachmentPrice > 0 {