account.go
author Mikael Berthe <mikael@lilotux.net>
Sat, 29 Apr 2017 17:27:15 +0200
changeset 156 70aadba26338
parent 155 0c581e0108da
child 159 408aa794d9bb
permissions -rw-r--r--
Add field "All" to LimitParams, change Limit behaviour. If All is true, the library will send several requests (if needed) until the API server has sent all the results. If not, and if a Limit is set, the library will try to fetch at least this number of results.

/*
Copyright 2017 Mikael Berthe

Licensed under the MIT license.  Please see the LICENSE file in this directory.
*/

package madon

import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"mime/multipart"
	"net/http"
	"os"
	"strconv"
	"strings"

	"github.com/sendgrid/rest"
)

// getAccountsOptions contains the option fields used by getMultipleAccounts,
// i.e. for the GET API calls returning lists of accounts.
type getAccountsOptions struct {
	// ID is the account ID ("followers", "following") or the status ID
	// ("reblogged_by", "favourited_by") the request is about.
	ID int

	// Q is the query string used when searching for accounts.
	Q string

	// Limit holds the optional paging parameters; it can be nil.
	Limit *LimitParams
}

// getSingleAccount returns an account entity.
// The operation 'op' can be "account", "verify_credentials", "follow",
// "unfollow", "block", "unblock", "mute", "unmute",
// "follow_requests/authorize" or "follow_requests/reject".
// The id is ignored for "verify_credentials"; every other operation embeds
// it in the endpoint URL and requires a positive value (ErrInvalidID is
// returned otherwise).  An unknown op yields ErrInvalidParameter.
func (mc *Client) getSingleAccount(op string, id int) (*Account, error) {
	const followRequestsPrefix = "follow_requests/"

	var endPoint string
	method := rest.Get
	strID := strconv.Itoa(id)
	needsID := true

	switch op {
	case "account":
		endPoint = "accounts/" + strID
	case "verify_credentials":
		endPoint = "accounts/verify_credentials"
		needsID = false
	case "follow", "unfollow", "block", "unblock", "mute", "unmute":
		endPoint = "accounts/" + strID + "/" + op
		method = rest.Post
	case "follow_requests/authorize", "follow_requests/reject":
		// The documentation is incorrect, the endpoint actually
		// is "follow_requests/:id/{authorize|reject}"
		action := strings.TrimPrefix(op, followRequestsPrefix)
		endPoint = followRequestsPrefix + strID + "/" + action
		method = rest.Post
	default:
		return nil, ErrInvalidParameter
	}

	// Don't send a request for an ID that cannot exist (e.g. "accounts/0").
	if needsID && id < 1 {
		return nil, ErrInvalidID
	}

	var account Account
	if err := mc.apiCall(endPoint, method, nil, nil, nil, &account); err != nil {
		return nil, err
	}
	return &account, nil
}

// getMultipleAccounts returns a list of account entities.
// The operation 'op' can be "followers", "following", "search", "blocks",
// "mutes", "follow_requests", "reblogged_by" or "favourited_by".
// opts.ID is required (account ID) for "followers" and "following", and
// (status ID) for "reblogged_by" and "favourited_by"; opts.Q is required for
// "search".  opts and opts.Limit are otherwise optional.
// If opts.Limit.All is true, several requests will be made until the API
// server has nothing to return.  Otherwise, if opts.Limit.Limit is set, more
// pages are fetched until at least that many accounts have been collected;
// since whole pages are appended, the result may exceed that number.
// On error the returned slice is nil.
func (mc *Client) getMultipleAccounts(op string, opts *getAccountsOptions) ([]Account, error) {
	var endPoint string
	var lopt *LimitParams

	if opts != nil {
		lopt = opts.Limit
	}

	switch op {
	case "followers", "following":
		if opts == nil || opts.ID < 1 {
			return nil, ErrInvalidID
		}
		endPoint = "accounts/" + strconv.Itoa(opts.ID) + "/" + op
	case "follow_requests", "blocks", "mutes":
		endPoint = op
	case "search":
		if opts == nil || opts.Q == "" {
			return nil, ErrInvalidParameter
		}
		endPoint = "accounts/" + op
	case "reblogged_by", "favourited_by":
		if opts == nil || opts.ID < 1 {
			return nil, ErrInvalidID
		}
		endPoint = "statuses/" + strconv.Itoa(opts.ID) + "/" + op
	default:
		return nil, ErrInvalidParameter
	}

	// Handle target-specific query parameters
	params := make(apiCallParams)
	if op == "search" {
		params["q"] = opts.Q
	}

	var accounts []Account
	var links apiLinks
	if err := mc.apiCall(endPoint, rest.Get, params, lopt, &links, &accounts); err != nil {
		return nil, err
	}
	if lopt != nil { // Fetch more pages to reach our limit
		for (lopt.All || lopt.Limit > len(accounts)) && links.next != nil {
			newlopt := links.next
			links = apiLinks{}
			// Decode every page into a fresh slice.  Reusing a previous
			// page's backing array lets encoding/json (assuming apiCall
			// uses it) decode into the old elements, so fields missing
			// from the new page could keep stale values.
			var page []Account
			if err := mc.apiCall(endPoint, rest.Get, params, newlopt, &links, &page); err != nil {
				return nil, err
			}
			if len(page) == 0 {
				// Nothing new: stop rather than follow "next" links forever.
				break
			}
			accounts = append(accounts, page...)
		}
	}
	return accounts, nil
}

// GetAccount returns the account entity identified by accountID.
// A nil account is returned on error; ErrEntityNotFound is returned when
// the server replies with an account whose ID is 0 (the requested ID
// does not exist).
func (mc *Client) GetAccount(accountID int) (*Account, error) {
	acct, err := mc.getSingleAccount("account", accountID)
	switch {
	case err != nil:
		return nil, err
	case acct != nil && acct.ID == 0:
		return nil, ErrEntityNotFound
	default:
		return acct, nil
	}
}

// GetCurrentAccount returns the current user account, as reported by the
// "verify_credentials" operation.
// ErrEntityNotFound is returned if the reply holds an account with ID 0.
func (mc *Client) GetCurrentAccount() (*Account, error) {
	me, err := mc.getSingleAccount("verify_credentials", 0)
	if err == nil && me != nil && me.ID == 0 {
		err = ErrEntityNotFound
	}
	if err != nil {
		return nil, err
	}
	return me, nil
}

// GetAccountFollowers returns the accounts following the account
// identified by accountID.
// The lopt parameter is optional (can be nil) and controls paging.
func (mc *Client) GetAccountFollowers(accountID int, lopt *LimitParams) ([]Account, error) {
	return mc.getMultipleAccounts("followers", &getAccountsOptions{
		ID:    accountID,
		Limit: lopt,
	})
}

// GetAccountFollowing returns the accounts that the account identified by
// accountID is following.
// The lopt parameter is optional (can be nil) and controls paging.
func (mc *Client) GetAccountFollowing(accountID int, lopt *LimitParams) ([]Account, error) {
	options := getAccountsOptions{ID: accountID, Limit: lopt}
	return mc.getMultipleAccounts("following", &options)
}

// FollowAccount follows the account identified by accountID.
// ErrEntityNotFound is returned if the server's reply describes a
// different account.
func (mc *Client) FollowAccount(accountID int) error {
	followed, err := mc.getSingleAccount("follow", accountID)
	if err != nil {
		return err
	}
	if followed == nil || followed.ID == accountID {
		return nil
	}
	return ErrEntityNotFound
}

// UnfollowAccount unfollows the account identified by accountID.
// ErrEntityNotFound is returned if the server's reply describes a
// different account.
func (mc *Client) UnfollowAccount(accountID int) error {
	unfollowed, err := mc.getSingleAccount("unfollow", accountID)
	switch {
	case err != nil:
		return err
	case unfollowed != nil && unfollowed.ID != accountID:
		return ErrEntityNotFound
	}
	return nil
}

// FollowRemoteAccount follows a remote account.
// The parameter 'uri' is a URI (e.g. "username@domain").
// ErrInvalidID is returned for an empty uri, and ErrEntityNotFound when the
// server's reply holds an account with ID 0.
func (mc *Client) FollowRemoteAccount(uri string) (*Account, error) {
	if uri == "" {
		return nil, ErrInvalidID
	}

	params := make(apiCallParams)
	params["uri"] = uri

	var account Account
	if err := mc.apiCall("follows", rest.Post, params, nil, nil, &account); err != nil {
		return nil, err
	}
	if account.ID == 0 {
		return nil, ErrEntityNotFound
	}
	return &account, nil
}

// BlockAccount blocks the account identified by accountID.
// ErrEntityNotFound is returned if the server's reply describes a
// different account.
func (mc *Client) BlockAccount(accountID int) error {
	blocked, err := mc.getSingleAccount("block", accountID)
	if err != nil {
		return err
	}
	if mismatch := blocked != nil && blocked.ID != accountID; mismatch {
		return ErrEntityNotFound
	}
	return nil
}

// UnblockAccount unblocks the account identified by accountID.
// ErrEntityNotFound is returned if the server's reply describes a
// different account.
func (mc *Client) UnblockAccount(accountID int) error {
	unblocked, err := mc.getSingleAccount("unblock", accountID)
	if err != nil {
		return err
	}
	if unblocked == nil || unblocked.ID == accountID {
		return nil
	}
	return ErrEntityNotFound
}

// MuteAccount mutes the account identified by accountID.
// ErrEntityNotFound is returned if the server's reply describes a
// different account.
func (mc *Client) MuteAccount(accountID int) error {
	muted, err := mc.getSingleAccount("mute", accountID)
	switch {
	case err != nil:
		return err
	case muted != nil && muted.ID != accountID:
		return ErrEntityNotFound
	default:
		return nil
	}
}

// UnmuteAccount unmutes the account identified by accountID.
// ErrEntityNotFound is returned if the server's reply describes a
// different account.
func (mc *Client) UnmuteAccount(accountID int) error {
	unmuted, err := mc.getSingleAccount("unmute", accountID)
	if err != nil {
		return err
	}
	if unmuted == nil || unmuted.ID == accountID {
		return nil
	}
	return ErrEntityNotFound
}

// SearchAccounts returns the accounts matching the query string.
// An empty query is rejected with ErrInvalidParameter.
// The lopt parameter is optional (can be nil) or can be used to set a limit.
func (mc *Client) SearchAccounts(query string, lopt *LimitParams) ([]Account, error) {
	return mc.getMultipleAccounts("search", &getAccountsOptions{
		Q:     query,
		Limit: lopt,
	})
}

// GetBlockedAccounts returns the accounts blocked by the current user.
// The lopt parameter is optional (can be nil) and controls paging.
func (mc *Client) GetBlockedAccounts(lopt *LimitParams) ([]Account, error) {
	return mc.getMultipleAccounts("blocks", &getAccountsOptions{Limit: lopt})
}

// GetMutedAccounts returns the accounts muted by the current user.
// The lopt parameter is optional (can be nil) and controls paging.
func (mc *Client) GetMutedAccounts(lopt *LimitParams) ([]Account, error) {
	return mc.getMultipleAccounts("mutes", &getAccountsOptions{Limit: lopt})
}

// GetAccountFollowRequests returns the accounts that have sent a follow
// request to the current user.
// The lopt parameter is optional (can be nil) and controls paging.
func (mc *Client) GetAccountFollowRequests(lopt *LimitParams) ([]Account, error) {
	return mc.getMultipleAccounts("follow_requests", &getAccountsOptions{Limit: lopt})
}

// GetAccountRelationships returns a list of relationship entities for the
// given account IDs.
// ErrInvalidID is returned if accountIDs is empty or contains an ID < 1.
func (mc *Client) GetAccountRelationships(accountIDs []int) ([]Relationship, error) {
	if len(accountIDs) == 0 {
		return nil, ErrInvalidID
	}

	params := make(apiCallParams)
	for idx := range accountIDs {
		accID := accountIDs[idx]
		if accID <= 0 {
			return nil, ErrInvalidID
		}
		// The parameter map needs distinct keys, so they are numbered
		// from 1: id[1], id[2], ...
		params["id["+strconv.Itoa(idx+1)+"]"] = strconv.Itoa(accID)
	}

	var relationships []Relationship
	err := mc.apiCall("accounts/relationships", rest.Get, params, nil, nil, &relationships)
	if err != nil {
		return nil, err
	}
	return relationships, nil
}

// GetAccountStatuses returns a list of status entities for the given account.
// If onlyMedia is true, returns only statuses that have media attachments.
// If excludeReplies is true, skip statuses that reply to other statuses.
// The lopt parameter is optional (can be nil).
// If lopt.All is true, several requests will be made until the API server
// has nothing to return.
// If lopt.Limit is set (and not All), several queries can be made until the
// limit is reached; since whole pages are appended, the result may contain
// more than lopt.Limit statuses.
func (mc *Client) GetAccountStatuses(accountID int, onlyMedia, excludeReplies bool, lopt *LimitParams) ([]Status, error) {
	if accountID < 1 {
		return nil, ErrInvalidID
	}

	endPoint := "accounts/" + strconv.Itoa(accountID) + "/statuses"
	params := make(apiCallParams)
	if onlyMedia {
		params["only_media"] = "true"
	}
	if excludeReplies {
		params["exclude_replies"] = "true"
	}

	var sl []Status
	var links apiLinks
	if err := mc.apiCall(endPoint, rest.Get, params, lopt, &links, &sl); err != nil {
		return nil, err
	}
	if lopt != nil { // Fetch more pages to reach our limit
		for (lopt.All || lopt.Limit > len(sl)) && links.next != nil {
			newlopt := links.next
			links = apiLinks{}
			// Decode every page into a fresh slice.  Reusing a previous
			// page's backing array lets encoding/json (assuming apiCall
			// uses it) decode into the old elements, so fields missing
			// from the new page could keep stale values.
			var page []Status
			if err := mc.apiCall(endPoint, rest.Get, params, newlopt, &links, &page); err != nil {
				return nil, err
			}
			if len(page) == 0 {
				// Nothing new: stop rather than follow "next" links forever.
				break
			}
			sl = append(sl, page...)
		}
	}
	return sl, nil
}

// FollowRequestAuthorize authorizes (authorize == true) or rejects
// (authorize == false) the follow request sent by the account accountID.
func (mc *Client) FollowRequestAuthorize(accountID int, authorize bool) error {
	op := "follow_requests/authorize"
	if !authorize {
		op = "follow_requests/reject"
	}
	_, err := mc.getSingleAccount(op, accountID)
	return err
}

// UpdateAccount updates the connected user's account data.
// The fields avatar & headerImage can contain base64-encoded images; if
// they do not (that is; if they don't contain ";base64,"), they are considered
// as file paths and their content will be encoded.
// All fields can be nil, in which case they are not updated.
// displayName and note can be set to "" to delete previous values;
// I'm not sure images can be deleted -- only replaced AFAICS.
// displayName and note are sent as request parameters, the images as
// fields of a multipart form body.
func (mc *Client) UpdateAccount(displayName, note, avatar, headerImage *string) (*Account, error) {
	const endPoint = "accounts/update_credentials"
	params := make(apiCallParams)

	if displayName != nil {
		params["display_name"] = *displayName
	}
	if note != nil {
		params["note"] = *note
	}

	var err error
	avatar, err = fileToBase64(avatar, nil)
	if err != nil {
		return nil, err
	}
	headerImage, err = fileToBase64(headerImage, nil)
	if err != nil {
		return nil, err
	}

	// Build the multipart form holding the images.
	var formBuf bytes.Buffer
	w := multipart.NewWriter(&formBuf)

	if avatar != nil {
		if err := w.WriteField("avatar", *avatar); err != nil {
			return nil, fmt.Errorf("cannot build form data: %s", err.Error())
		}
	}
	if headerImage != nil {
		if err := w.WriteField("header", *headerImage); err != nil {
			return nil, fmt.Errorf("cannot build form data: %s", err.Error())
		}
	}
	// Close writes the trailing boundary; without it the body is malformed.
	if err := w.Close(); err != nil {
		return nil, fmt.Errorf("cannot build form data: %s", err.Error())
	}

	// Prepare the request
	req, err := mc.prepareRequest(endPoint, rest.Patch, params)
	if err != nil {
		return nil, fmt.Errorf("prepareRequest failed: %s", err.Error())
	}
	req.Headers["Content-Type"] = w.FormDataContentType()
	req.Body = formBuf.Bytes()

	// Make API call
	r, err := restAPI(req)
	if err != nil {
		return nil, fmt.Errorf("account update failed: %s", err.Error())
	}

	// Check for error reply
	var errorResult Error
	if err := json.Unmarshal([]byte(r.Body), &errorResult); err == nil {
		// The empty object is not an error
		if errorResult.Text != "" {
			return nil, fmt.Errorf("%s", errorResult.Text)
		}
	}

	// Not an error reply; let's unmarshal the data
	var account Account
	if err := json.Unmarshal([]byte(r.Body), &account); err != nil {
		return nil, fmt.Errorf("cannot decode API response: %s", err.Error())
	}
	return &account, nil
}

// fileToBase64 is a helper function to convert a file's contents to a
// base64-encoded data URI ("data:<content-type>;base64,<payload>").
// If data is nil, nil is returned; if it is empty or already contains base64
// data (i.e. it contains ";base64,"), it is returned unmodified.  Otherwise
// data is treated as a file path and the whole file is read and encoded.
// If contentType is nil or empty, it is detected from the file contents.
func fileToBase64(data, contentType *string) (*string, error) {
	if data == nil {
		return nil, nil
	}

	if *data == "" {
		return data, nil
	}

	if strings.Contains(*data, ";base64,") {
		return data, nil
	}

	// We need to read the file and convert its contents to base64

	file, err := os.Open(*data)
	if err != nil {
		return nil, err
	}
	defer file.Close()

	fStat, err := file.Stat()
	if err != nil {
		return nil, err
	}

	// ReadFrom reads until EOF, so a short read cannot silently truncate
	// the content (a single file.Read gives no such guarantee).  The extra
	// MinRead bytes avoid a final reallocation when EOF is detected.
	var buf bytes.Buffer
	buf.Grow(int(fStat.Size()) + bytes.MinRead)
	if _, err := buf.ReadFrom(file); err != nil {
		return nil, err
	}
	content := buf.Bytes()

	cType := ""
	if contentType != nil {
		cType = *contentType
	}
	if cType == "" {
		// DetectContentType looks at no more than the first 512 bytes by
		// itself; slicing content[:512] here would panic on smaller files.
		cType = http.DetectContentType(content)
	}
	newData := "data:" + cType + ";base64," + base64.StdEncoding.EncodeToString(content)
	return &newData, nil
}