diff -r 7145e95b4f57 -r 7665451f74cc streams.go --- a/streams.go Wed Apr 19 11:11:48 2017 +0200 +++ b/streams.go Wed Apr 19 14:18:02 2017 +0200 @@ -7,18 +7,13 @@ package madon import ( - "bufio" - "bytes" "encoding/json" "errors" "fmt" - "io" - "log" - "net/http" + "net/url" "strings" - "time" - "github.com/sendgrid/rest" + "github.com/gorilla/websocket" ) // StreamEvent contains a single event from the streaming API @@ -31,168 +26,144 @@ // openStream opens a stream URL and returns an http.Response // Note that the caller should close the connection when it's done reading // the stream. -// The stream name can be "user", "public" or "hashtag". -// For "hashtag", the hashTag argument cannot be empty. -func (mc *Client) openStream(streamName, hashTag string) (*http.Response, error) { - params := make(apiCallParams) +// The stream name can be "user", "local", "public" or "hashtag". +// When it is "hashtag", the hashTag argument cannot be empty. +func (mc *Client) openStream(streamName, hashTag string) (*websocket.Conn, error) { + var tag string switch streamName { - case "user", "public": + case "user", "public", "public:local": case "hashtag": if hashTag == "" { return nil, ErrInvalidParameter } - params["tag"] = hashTag + tag = hashTag default: return nil, ErrInvalidParameter } - req, err := mc.prepareRequest("streaming/"+streamName, rest.Get, params) - if err != nil { - return nil, fmt.Errorf("cannot build stream request: %s", err.Error()) + if !strings.HasPrefix(mc.APIBase, "http") { + return nil, errors.New("cannot create Websocket URL: unexpected API base URL") } - reqObj, err := rest.BuildRequestObject(req) + // Build streaming websocket URL + u, err := url.Parse("ws" + mc.APIBase[4:] + "/streaming/") if err != nil { - return nil, fmt.Errorf("cannot build stream request: %s", err.Error()) + return nil, errors.New("cannot create Websocket URL: " + err.Error()) } - resp, err := rest.MakeRequest(reqObj) - if err != nil { - return nil, fmt.Errorf("cannot open stream: %s", err.Error()) + urlParams := url.Values{} + urlParams.Add("stream", streamName) + urlParams.Add("access_token", mc.UserToken.AccessToken) + if tag != "" { + urlParams.Add("tag", tag) } - if resp.StatusCode != 200 { - resp.Body.Close() - return nil, errors.New(resp.Status) - } - return resp, nil + u.RawQuery = urlParams.Encode() + + c, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + return c, err } // readStream reads from the http.Response and sends events to the events channel // It stops when the connection is closed or when the stopCh channel is closed. // The foroutine will close the doneCh channel when it terminates. -func (mc *Client) readStream(events chan<- StreamEvent, stopCh <-chan bool, doneCh chan<- bool, r *http.Response) { - defer r.Body.Close() +func (mc *Client) readStream(events chan<- StreamEvent, stopCh <-chan bool, doneCh chan bool, c *websocket.Conn) { + defer c.Close() + defer close(doneCh) - reader := bufio.NewReader(r.Body) - - var line, eventName string - for { + go func() { select { case <-stopCh: - close(doneCh) - return - default: + // Close connection + c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + case <-doneCh: + // Leave } + }() - lineBytes, partial, err := reader.ReadLine() - if err != nil { - var e error - if err == io.EOF { - e = fmt.Errorf("connection closed: %s", err.Error()) - } else { - e = fmt.Errorf("read error: %s", err.Error()) - } - log.Printf("Stream Reader: %s", e.Error()) - events <- StreamEvent{Event: "error", Error: e} - close(doneCh) - return + for { + var msg struct { + Event string + Payload interface{} } - if partial { - e := fmt.Errorf("received incomplete line; not supported yet") - log.Printf("Stream Reader: %s", e.Error()) + err := c.ReadJSON(&msg) + if err != nil { + if strings.Contains(err.Error(), "close 1000 (normal)") { + break // Connection properly closed + } + e := fmt.Errorf("read error: %v", err) events <- StreamEvent{Event: "error", Error: e} - time.Sleep(5 * time.Second) - continue // Skip this - } - - line = string(bytes.TrimSpace(lineBytes)) - - if line == "" { - continue // Skip empty line - } - if strings.HasPrefix(line, ":") { - continue // Skip comment + break } - if strings.HasPrefix(line, "event: ") { - eventName = line[7:] - continue - } - - if !strings.HasPrefix(line, "data: ") { - // XXX Needs improvement - e := fmt.Errorf("received unhandled event line '%s'", strings.Split(line, ":")[0]) - log.Printf("Stream Reader: %s", e.Error()) - events <- StreamEvent{Event: "error", Error: e} - continue - } - - // This is a data line - data := []byte(line[6:]) - var obj interface{} // Decode API object - switch eventName { + switch msg.Event { case "update": + strPayload, ok := msg.Payload.(string) + if !ok { + e := fmt.Errorf("could not decode status: payload isn't a string") + events <- StreamEvent{Event: "error", Error: e} + continue + } var s Status - if err := json.Unmarshal(data, &s); err != nil { - e := fmt.Errorf("could not unmarshal data: %s", err.Error()) - log.Printf("Stream Reader: %s", e.Error()) + if err := json.Unmarshal([]byte(strPayload), &s); err != nil { + e := fmt.Errorf("could not decode status: %v", err) events <- StreamEvent{Event: "error", Error: e} continue } obj = s case "notification": + strPayload, ok := msg.Payload.(string) + if !ok { + e := fmt.Errorf("could not decode notification: payload isn't a string") + events <- StreamEvent{Event: "error", Error: e} + continue + } var notif Notification - if err := json.Unmarshal(data, ¬if); err != nil { - e := fmt.Errorf("could not unmarshal data: %s", err.Error()) - log.Printf("Stream Reader: %s", e.Error()) + if err := json.Unmarshal([]byte(strPayload), ¬if); err != nil { + e := fmt.Errorf("could not decode notification: %v", err) events <- StreamEvent{Event: "error", Error: e} continue } obj = notif case "delete": - var statusID int - if err := json.Unmarshal(data, &statusID); err != nil { - e := fmt.Errorf("could not unmarshal data: %s", err.Error()) - log.Printf("Stream Reader: %s", e.Error()) + floatPayload, ok := msg.Payload.(float64) + if !ok { + e := fmt.Errorf("could not decode deletion: payload isn't a number") events <- StreamEvent{Event: "error", Error: e} continue } - obj = statusID - case "": - fallthrough + obj = int(floatPayload) // statusID default: - e := fmt.Errorf("unhandled event '%s'", eventName) - log.Printf("Stream Reader: %s", e.Error()) + e := fmt.Errorf("unhandled event '%s'", msg.Event) events <- StreamEvent{Event: "error", Error: e} continue } // Send event to the channel - events <- StreamEvent{Event: eventName, Data: obj} + events <- StreamEvent{Event: msg.Event, Data: obj} } } // StreamListener listens to a stream from the Mastodon server -// The stream 'name' can be "user", "public" or "hashtag". +// The stream 'name' can be "user", "local", "public" or "hashtag". // For 'hashtag', the hashTag argument cannot be empty. // The events are sent to the events channel (the errors as well). // The streaming is terminated if the 'stopCh' channel is closed. // The 'doneCh' channel is closed if the connection is closed by the server. // Please note that this method launches a goroutine to listen to the events. -func (mc *Client) StreamListener(name, hashTag string, events chan<- StreamEvent, stopCh <-chan bool, doneCh chan<- bool) error { +func (mc *Client) StreamListener(name, hashTag string, events chan<- StreamEvent, stopCh <-chan bool, doneCh chan bool) error { if mc == nil { return ErrUninitializedClient } - resp, err := mc.openStream(name, hashTag) + conn, err := mc.openStream(name, hashTag) if err != nil { return err } - go mc.readStream(events, stopCh, doneCh, resp) + go mc.readStream(events, stopCh, doneCh, conn) return nil }