You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
275 lines
6.2 KiB
275 lines
6.2 KiB
5 years ago
|
package rpc
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"errors"
|
||
|
"log"
|
||
|
"net"
|
||
|
"net/http"
|
||
|
"net/url"
|
||
|
"sync"
|
||
|
"sync/atomic"
|
||
|
"time"
|
||
|
|
||
|
"github.com/gorilla/websocket"
|
||
|
)
|
||
|
|
||
|
// caller abstracts the transport used to reach an aria2 daemon.
type caller interface {
	// Call sends an RPC request for method, carrying params, to the aria2
	// daemon; reply is the destination for the decoded answer.
	Call(method string, params interface{}, reply interface{}) error

	// Close shuts the caller down.
	Close() error
}
|
||
|
|
||
|
// httpCaller is the caller implementation that posts requests over HTTP.
type httpCaller struct {
	c   *http.Client // client used by Call to POST requests
	uri string       // endpoint every request is posted to

	// Shutdown state, used by Close.
	once   sync.Once          // lets Close run its teardown only once
	cancel context.CancelFunc // cancels the context handed to setNotifier
	wg     *sync.WaitGroup    // tracks the goroutines started by setNotifier
}
|
||
|
|
||
|
func newHTTPCaller(ctx context.Context, u *url.URL, timeout time.Duration, notifer Notifier) *httpCaller {
|
||
|
c := &http.Client{
|
||
|
Transport: &http.Transport{
|
||
|
MaxIdleConnsPerHost: 1,
|
||
|
MaxConnsPerHost: 1,
|
||
|
// TLSClientConfig: tlsConfig,
|
||
|
Dial: (&net.Dialer{
|
||
|
Timeout: timeout,
|
||
|
KeepAlive: 60 * time.Second,
|
||
|
}).Dial,
|
||
|
TLSHandshakeTimeout: 3 * time.Second,
|
||
|
ResponseHeaderTimeout: timeout,
|
||
|
},
|
||
|
}
|
||
|
var wg sync.WaitGroup
|
||
|
ctx, cancel := context.WithCancel(ctx)
|
||
|
h := &httpCaller{uri: u.String(), c: c, cancel: cancel, wg: &wg}
|
||
|
if notifer != nil {
|
||
|
h.setNotifier(ctx, *u, notifer)
|
||
|
}
|
||
|
return h
|
||
|
}
|
||
|
|
||
|
func (h *httpCaller) Close() (err error) {
|
||
|
h.once.Do(func() {
|
||
|
h.cancel()
|
||
|
h.wg.Wait()
|
||
|
})
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (h *httpCaller) setNotifier(ctx context.Context, u url.URL, notifer Notifier) (err error) {
|
||
|
u.Scheme = "ws"
|
||
|
conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
h.wg.Add(1)
|
||
|
go func() {
|
||
|
defer h.wg.Done()
|
||
|
defer conn.Close()
|
||
|
select {
|
||
|
case <-ctx.Done():
|
||
|
conn.SetWriteDeadline(time.Now().Add(time.Second))
|
||
|
if err := conn.WriteMessage(websocket.CloseMessage,
|
||
|
websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
|
||
|
log.Printf("sending websocket close message: %v", err)
|
||
|
}
|
||
|
return
|
||
|
}
|
||
|
}()
|
||
|
h.wg.Add(1)
|
||
|
go func() {
|
||
|
defer h.wg.Done()
|
||
|
var request websocketResponse
|
||
|
var err error
|
||
|
for {
|
||
|
select {
|
||
|
case <-ctx.Done():
|
||
|
return
|
||
|
default:
|
||
|
}
|
||
|
if err = conn.ReadJSON(&request); err != nil {
|
||
|
select {
|
||
|
case <-ctx.Done():
|
||
|
return
|
||
|
default:
|
||
|
}
|
||
|
log.Printf("conn.ReadJSON|err:%v", err.Error())
|
||
|
return
|
||
|
}
|
||
|
switch request.Method {
|
||
|
case "aria2.onDownloadStart":
|
||
|
notifer.OnDownloadStart(request.Params)
|
||
|
case "aria2.onDownloadPause":
|
||
|
notifer.OnDownloadPause(request.Params)
|
||
|
case "aria2.onDownloadStop":
|
||
|
notifer.OnDownloadStop(request.Params)
|
||
|
case "aria2.onDownloadComplete":
|
||
|
notifer.OnDownloadComplete(request.Params)
|
||
|
case "aria2.onDownloadError":
|
||
|
notifer.OnDownloadError(request.Params)
|
||
|
case "aria2.onBtDownloadComplete":
|
||
|
notifer.OnBtDownloadComplete(request.Params)
|
||
|
default:
|
||
|
log.Printf("unexpected notification: %s", request.Method)
|
||
|
}
|
||
|
}
|
||
|
}()
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (h httpCaller) Call(method string, params, reply interface{}) (err error) {
|
||
|
payload, err := EncodeClientRequest(method, params)
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
r, err := h.c.Post(h.uri, "application/json", payload)
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
err = DecodeClientResponse(r.Body, &reply)
|
||
|
r.Body.Close()
|
||
|
return
|
||
|
}
|
||
|
|
||
|
type websocketCaller struct {
|
||
|
conn *websocket.Conn
|
||
|
sendChan chan *sendRequest
|
||
|
cancel context.CancelFunc
|
||
|
wg *sync.WaitGroup
|
||
|
once sync.Once
|
||
|
timeout time.Duration
|
||
|
}
|
||
|
|
||
|
func newWebsocketCaller(ctx context.Context, uri string, timeout time.Duration, notifier Notifier) (*websocketCaller, error) {
|
||
|
var header = http.Header{}
|
||
|
conn, _, err := websocket.DefaultDialer.Dial(uri, header)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
sendChan := make(chan *sendRequest, 16)
|
||
|
var wg sync.WaitGroup
|
||
|
ctx, cancel := context.WithCancel(ctx)
|
||
|
w := &websocketCaller{conn: conn, wg: &wg, cancel: cancel, sendChan: sendChan, timeout: timeout}
|
||
|
processor := NewResponseProcessor()
|
||
|
wg.Add(1)
|
||
|
go func() { // routine:recv
|
||
|
defer wg.Done()
|
||
|
defer cancel()
|
||
|
for {
|
||
|
select {
|
||
|
case <-ctx.Done():
|
||
|
return
|
||
|
default:
|
||
|
}
|
||
|
var resp websocketResponse
|
||
|
if err := conn.ReadJSON(&resp); err != nil {
|
||
|
select {
|
||
|
case <-ctx.Done():
|
||
|
return
|
||
|
default:
|
||
|
}
|
||
|
log.Printf("conn.ReadJSON|err:%v", err.Error())
|
||
|
return
|
||
|
}
|
||
|
if resp.Id == nil { // RPC notifications
|
||
|
if notifier != nil {
|
||
|
switch resp.Method {
|
||
|
case "aria2.onDownloadStart":
|
||
|
notifier.OnDownloadStart(resp.Params)
|
||
|
case "aria2.onDownloadPause":
|
||
|
notifier.OnDownloadPause(resp.Params)
|
||
|
case "aria2.onDownloadStop":
|
||
|
notifier.OnDownloadStop(resp.Params)
|
||
|
case "aria2.onDownloadComplete":
|
||
|
notifier.OnDownloadComplete(resp.Params)
|
||
|
case "aria2.onDownloadError":
|
||
|
notifier.OnDownloadError(resp.Params)
|
||
|
case "aria2.onBtDownloadComplete":
|
||
|
notifier.OnBtDownloadComplete(resp.Params)
|
||
|
default:
|
||
|
log.Printf("unexpected notification: %s", resp.Method)
|
||
|
}
|
||
|
}
|
||
|
continue
|
||
|
}
|
||
|
processor.Process(resp.clientResponse)
|
||
|
}
|
||
|
}()
|
||
|
wg.Add(1)
|
||
|
go func() { // routine:send
|
||
|
defer wg.Done()
|
||
|
defer cancel()
|
||
|
defer w.conn.Close()
|
||
|
|
||
|
for {
|
||
|
select {
|
||
|
case <-ctx.Done():
|
||
|
if err := w.conn.WriteMessage(websocket.CloseMessage,
|
||
|
websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
|
||
|
log.Printf("sending websocket close message: %v", err)
|
||
|
}
|
||
|
return
|
||
|
case req := <-sendChan:
|
||
|
processor.Add(req.request.Id, func(resp clientResponse) error {
|
||
|
err := resp.decode(req.reply)
|
||
|
req.cancel()
|
||
|
return err
|
||
|
})
|
||
|
w.conn.SetWriteDeadline(time.Now().Add(timeout))
|
||
|
w.conn.WriteJSON(req.request)
|
||
|
}
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
return w, nil
|
||
|
}
|
||
|
|
||
|
func (w *websocketCaller) Close() (err error) {
|
||
|
w.once.Do(func() {
|
||
|
w.cancel()
|
||
|
w.wg.Wait()
|
||
|
})
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (w websocketCaller) Call(method string, params, reply interface{}) (err error) {
|
||
|
ctx, cancel := context.WithTimeout(context.Background(), w.timeout)
|
||
|
defer cancel()
|
||
|
select {
|
||
|
case w.sendChan <- &sendRequest{cancel: cancel, request: &clientRequest{
|
||
|
Version: "2.0",
|
||
|
Method: method,
|
||
|
Params: params,
|
||
|
Id: reqid(),
|
||
|
}, reply: reply}:
|
||
|
|
||
|
default:
|
||
|
return errors.New("sending channel blocking")
|
||
|
}
|
||
|
|
||
|
select {
|
||
|
case <-ctx.Done():
|
||
|
if err := ctx.Err(); err == context.DeadlineExceeded {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
return
|
||
|
}
|
||
|
|
||
|
type sendRequest struct {
|
||
|
cancel context.CancelFunc
|
||
|
request *clientRequest
|
||
|
reply interface{}
|
||
|
}
|
||
|
|
||
|
// reqid returns the next request id. The underlying counter is seeded once,
// when this variable is initialized, with the current Unix time in
// nanoseconds; each call atomically adds one and returns the new value.
var reqid = func() func() uint64 {
	counter := new(uint64)
	*counter = uint64(time.Now().UnixNano())
	next := func() uint64 {
		return atomic.AddUint64(counter, 1)
	}
	return next
}()
|