streams: Use websockets
authorMikael Berthe <mikael@lilotux.net>
Wed, 19 Apr 2017 14:18:02 +0200
changeset 140 7665451f74cc
parent 139 7145e95b4f57
child 141 6068de3675c8
streams: Use websockets
streams.go
--- a/streams.go	Wed Apr 19 11:11:48 2017 +0200
+++ b/streams.go	Wed Apr 19 14:18:02 2017 +0200
@@ -7,18 +7,13 @@
 package madon
 
 import (
-	"bufio"
-	"bytes"
 	"encoding/json"
 	"errors"
 	"fmt"
-	"io"
-	"log"
-	"net/http"
+	"net/url"
 	"strings"
-	"time"
 
-	"github.com/sendgrid/rest"
+	"github.com/gorilla/websocket"
 )
 
 // StreamEvent contains a single event from the streaming API
@@ -31,168 +26,144 @@
 // openStream opens a stream URL and returns an http.Response
 // Note that the caller should close the connection when it's done reading
 // the stream.
-// The stream name can be "user", "public" or "hashtag".
-// For "hashtag", the hashTag argument cannot be empty.
-func (mc *Client) openStream(streamName, hashTag string) (*http.Response, error) {
-	params := make(apiCallParams)
+// The stream name can be "user", "local", "public" or "hashtag".
+// When it is "hashtag", the hashTag argument cannot be empty.
+func (mc *Client) openStream(streamName, hashTag string) (*websocket.Conn, error) {
+	var tag string
 
 	switch streamName {
-	case "user", "public":
+	case "user", "public", "public:local":
 	case "hashtag":
 		if hashTag == "" {
 			return nil, ErrInvalidParameter
 		}
-		params["tag"] = hashTag
+		tag = hashTag
 	default:
 		return nil, ErrInvalidParameter
 	}
 
-	req, err := mc.prepareRequest("streaming/"+streamName, rest.Get, params)
-	if err != nil {
-		return nil, fmt.Errorf("cannot build stream request: %s", err.Error())
+	if !strings.HasPrefix(mc.APIBase, "http") {
+		return nil, errors.New("cannot create Websocket URL: unexpected API base URL")
 	}
 
-	reqObj, err := rest.BuildRequestObject(req)
+	// Build streaming websocket URL
+	u, err := url.Parse("ws" + mc.APIBase[4:] + "/streaming/")
 	if err != nil {
-		return nil, fmt.Errorf("cannot build stream request: %s", err.Error())
+		return nil, errors.New("cannot create Websocket URL: " + err.Error())
 	}
 
-	resp, err := rest.MakeRequest(reqObj)
-	if err != nil {
-		return nil, fmt.Errorf("cannot open stream: %s", err.Error())
+	urlParams := url.Values{}
+	urlParams.Add("stream", streamName)
+	urlParams.Add("access_token", mc.UserToken.AccessToken)
+	if tag != "" {
+		urlParams.Add("tag", tag)
 	}
-	if resp.StatusCode != 200 {
-		resp.Body.Close()
-		return nil, errors.New(resp.Status)
-	}
-	return resp, nil
+	u.RawQuery = urlParams.Encode()
+
+	c, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
+	return c, err
 }
 
 // readStream reads from the http.Response and sends events to the events channel
 // It stops when the connection is closed or when the stopCh channel is closed.
 // The foroutine will close the doneCh channel when it terminates.
-func (mc *Client) readStream(events chan<- StreamEvent, stopCh <-chan bool, doneCh chan<- bool, r *http.Response) {
-	defer r.Body.Close()
+func (mc *Client) readStream(events chan<- StreamEvent, stopCh <-chan bool, doneCh chan bool, c *websocket.Conn) {
+	defer c.Close()
+	defer close(doneCh)
 
-	reader := bufio.NewReader(r.Body)
-
-	var line, eventName string
-	for {
+	go func() {
 		select {
 		case <-stopCh:
-			close(doneCh)
-			return
-		default:
+			// Close connection
+			c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
+		case <-doneCh:
+			// Leave
 		}
+	}()
 
-		lineBytes, partial, err := reader.ReadLine()
-		if err != nil {
-			var e error
-			if err == io.EOF {
-				e = fmt.Errorf("connection closed: %s", err.Error())
-			} else {
-				e = fmt.Errorf("read error: %s", err.Error())
-			}
-			log.Printf("Stream Reader: %s", e.Error())
-			events <- StreamEvent{Event: "error", Error: e}
-			close(doneCh)
-			return
+	for {
+		var msg struct {
+			Event   string
+			Payload interface{}
 		}
 
-		if partial {
-			e := fmt.Errorf("received incomplete line; not supported yet")
-			log.Printf("Stream Reader: %s", e.Error())
+		err := c.ReadJSON(&msg)
+		if err != nil {
+			if strings.Contains(err.Error(), "close 1000 (normal)") {
+				break // Connection properly closed
+			}
+			e := fmt.Errorf("read error: %v", err)
 			events <- StreamEvent{Event: "error", Error: e}
-			time.Sleep(5 * time.Second)
-			continue // Skip this
-		}
-
-		line = string(bytes.TrimSpace(lineBytes))
-
-		if line == "" {
-			continue // Skip empty line
-		}
-		if strings.HasPrefix(line, ":") {
-			continue // Skip comment
+			break
 		}
 
-		if strings.HasPrefix(line, "event: ") {
-			eventName = line[7:]
-			continue
-		}
-
-		if !strings.HasPrefix(line, "data: ") {
-			// XXX Needs improvement
-			e := fmt.Errorf("received unhandled event line '%s'", strings.Split(line, ":")[0])
-			log.Printf("Stream Reader: %s", e.Error())
-			events <- StreamEvent{Event: "error", Error: e}
-			continue
-		}
-
-		// This is a data line
-		data := []byte(line[6:])
-
 		var obj interface{}
 
 		// Decode API object
-		switch eventName {
+		switch msg.Event {
 		case "update":
+			strPayload, ok := msg.Payload.(string)
+			if !ok {
+				e := fmt.Errorf("could not decode status: payload isn't a string")
+				events <- StreamEvent{Event: "error", Error: e}
+				continue
+			}
 			var s Status
-			if err := json.Unmarshal(data, &s); err != nil {
-				e := fmt.Errorf("could not unmarshal data: %s", err.Error())
-				log.Printf("Stream Reader: %s", e.Error())
+			if err := json.Unmarshal([]byte(strPayload), &s); err != nil {
+				e := fmt.Errorf("could not decode status: %v", err)
 				events <- StreamEvent{Event: "error", Error: e}
 				continue
 			}
 			obj = s
 		case "notification":
+			strPayload, ok := msg.Payload.(string)
+			if !ok {
+				e := fmt.Errorf("could not decode notification: payload isn't a string")
+				events <- StreamEvent{Event: "error", Error: e}
+				continue
+			}
 			var notif Notification
-			if err := json.Unmarshal(data, &notif); err != nil {
-				e := fmt.Errorf("could not unmarshal data: %s", err.Error())
-				log.Printf("Stream Reader: %s", e.Error())
+			if err := json.Unmarshal([]byte(strPayload), &notif); err != nil {
+				e := fmt.Errorf("could not decode notification: %v", err)
 				events <- StreamEvent{Event: "error", Error: e}
 				continue
 			}
 			obj = notif
 		case "delete":
-			var statusID int
-			if err := json.Unmarshal(data, &statusID); err != nil {
-				e := fmt.Errorf("could not unmarshal data: %s", err.Error())
-				log.Printf("Stream Reader: %s", e.Error())
+			floatPayload, ok := msg.Payload.(float64)
+			if !ok {
+				e := fmt.Errorf("could not decode deletion: payload isn't a number")
 				events <- StreamEvent{Event: "error", Error: e}
 				continue
 			}
-			obj = statusID
-		case "":
-			fallthrough
+			obj = int(floatPayload) // statusID
 		default:
-			e := fmt.Errorf("unhandled event '%s'", eventName)
-			log.Printf("Stream Reader: %s", e.Error())
+			e := fmt.Errorf("unhandled event '%s'", msg.Event)
 			events <- StreamEvent{Event: "error", Error: e}
 			continue
 		}
 
 		// Send event to the channel
-		events <- StreamEvent{Event: eventName, Data: obj}
+		events <- StreamEvent{Event: msg.Event, Data: obj}
 	}
 }
 
 // StreamListener listens to a stream from the Mastodon server
-// The stream 'name' can be "user", "public" or "hashtag".
+// The stream 'name' can be "user", "local", "public" or "hashtag".
 // For 'hashtag', the hashTag argument cannot be empty.
 // The events are sent to the events channel (the errors as well).
 // The streaming is terminated if the 'stopCh' channel is closed.
 // The 'doneCh' channel is closed if the connection is closed by the server.
 // Please note that this method launches a goroutine to listen to the events.
-func (mc *Client) StreamListener(name, hashTag string, events chan<- StreamEvent, stopCh <-chan bool, doneCh chan<- bool) error {
+func (mc *Client) StreamListener(name, hashTag string, events chan<- StreamEvent, stopCh <-chan bool, doneCh chan bool) error {
 	if mc == nil {
 		return ErrUninitializedClient
 	}
 
-	resp, err := mc.openStream(name, hashTag)
+	conn, err := mc.openStream(name, hashTag)
 	if err != nil {
 		return err
 	}
-	go mc.readStream(events, stopCh, doneCh, resp)
+	go mc.readStream(events, stopCh, doneCh, conn)
 	return nil
 }