hgext/infinitepush/sqlindexapi.py
changeset 50803 609a3b8058c3
parent 50802 cf0502231d56
child 50806 337bc83c1275
equal deleted inserted replaced
50802:cf0502231d56 50803:609a3b8058c3
     1 # Infinite push
       
     2 #
       
     3 # Copyright 2016 Facebook, Inc.
       
     4 #
       
     5 # This software may be used and distributed according to the terms of the
       
     6 # GNU General Public License version 2 or any later version.
       
     7 
       
     8 
       
     9 import logging
       
    10 import os
       
    11 import time
       
    12 
       
    13 import warnings
       
    14 import mysql.connector
       
    15 
       
    16 from . import indexapi
       
    17 
       
    18 
       
    19 def _convertbookmarkpattern(pattern):
       
    20     pattern = pattern.replace(b'_', b'\\_')
       
    21     pattern = pattern.replace(b'%', b'\\%')
       
    22     if pattern.endswith(b'*'):
       
    23         pattern = pattern[:-1] + b'%'
       
    24     return pattern
       
    25 
       
    26 
       
    27 class sqlindexapi(indexapi.indexapi):
       
    28     """
       
    29     Sql backend for infinitepush index. See schema.sql
       
    30     """
       
    31 
       
    32     def __init__(
       
    33         self,
       
    34         reponame,
       
    35         host,
       
    36         port,
       
    37         database,
       
    38         user,
       
    39         password,
       
    40         logfile,
       
    41         loglevel,
       
    42         waittimeout=300,
       
    43         locktimeout=120,
       
    44     ):
       
    45         super(sqlindexapi, self).__init__()
       
    46         self.reponame = reponame
       
    47         self.sqlargs = {
       
    48             b'host': host,
       
    49             b'port': port,
       
    50             b'database': database,
       
    51             b'user': user,
       
    52             b'password': password,
       
    53         }
       
    54         self.sqlconn = None
       
    55         self.sqlcursor = None
       
    56         if not logfile:
       
    57             logfile = os.devnull
       
    58         logging.basicConfig(filename=logfile)
       
    59         self.log = logging.getLogger()
       
    60         self.log.setLevel(loglevel)
       
    61         self._connected = False
       
    62         self._waittimeout = waittimeout
       
    63         self._locktimeout = locktimeout
       
    64 
       
    65     def sqlconnect(self):
       
    66         if self.sqlconn:
       
    67             raise indexapi.indexexception(b"SQL connection already open")
       
    68         if self.sqlcursor:
       
    69             raise indexapi.indexexception(
       
    70                 b"SQL cursor already open without connection"
       
    71             )
       
    72         retry = 3
       
    73         while True:
       
    74             try:
       
    75                 self.sqlconn = mysql.connector.connect(**self.sqlargs)
       
    76 
       
    77                 # Code is copy-pasted from hgsql. Bug fixes need to be
       
    78                 # back-ported!
       
    79                 # The default behavior is to return byte arrays, when we
       
    80                 # need strings. This custom convert returns strings.
       
    81                 self.sqlconn.set_converter_class(CustomConverter)
       
    82                 self.sqlconn.autocommit = False
       
    83                 break
       
    84             except mysql.connector.errors.Error:
       
    85                 # mysql can be flakey occasionally, so do some minimal
       
    86                 # retrying.
       
    87                 retry -= 1
       
    88                 if retry == 0:
       
    89                     raise
       
    90                 time.sleep(0.2)
       
    91 
       
    92         waittimeout = self.sqlconn.converter.escape(b'%s' % self._waittimeout)
       
    93 
       
    94         self.sqlcursor = self.sqlconn.cursor()
       
    95         self.sqlcursor.execute(b"SET wait_timeout=%s" % waittimeout)
       
    96         self.sqlcursor.execute(
       
    97             b"SET innodb_lock_wait_timeout=%s" % self._locktimeout
       
    98         )
       
    99         self._connected = True
       
   100 
       
   101     def close(self):
       
   102         """Cleans up the metadata store connection."""
       
   103         with warnings.catch_warnings():
       
   104             warnings.simplefilter(b"ignore")
       
   105             self.sqlcursor.close()
       
   106             self.sqlconn.close()
       
   107         self.sqlcursor = None
       
   108         self.sqlconn = None
       
   109 
       
   110     def __enter__(self):
       
   111         if not self._connected:
       
   112             self.sqlconnect()
       
   113         return self
       
   114 
       
   115     def __exit__(self, exc_type, exc_val, exc_tb):
       
   116         if exc_type is None:
       
   117             self.sqlconn.commit()
       
   118         else:
       
   119             self.sqlconn.rollback()
       
   120 
       
   121     def addbundle(self, bundleid, nodesctx):
       
   122         if not self._connected:
       
   123             self.sqlconnect()
       
   124         self.log.info(b"ADD BUNDLE %r %r" % (self.reponame, bundleid))
       
   125         self.sqlcursor.execute(
       
   126             b"INSERT INTO bundles(bundle, reponame) VALUES (%s, %s)",
       
   127             params=(bundleid, self.reponame),
       
   128         )
       
   129         for ctx in nodesctx:
       
   130             self.sqlcursor.execute(
       
   131                 b"INSERT INTO nodestobundle(node, bundle, reponame) "
       
   132                 b"VALUES (%s, %s, %s) ON DUPLICATE KEY UPDATE "
       
   133                 b"bundle=VALUES(bundle)",
       
   134                 params=(ctx.hex(), bundleid, self.reponame),
       
   135             )
       
   136 
       
   137             extra = ctx.extra()
       
   138             author_name = ctx.user()
       
   139             committer_name = extra.get(b'committer', ctx.user())
       
   140             author_date = int(ctx.date()[0])
       
   141             committer_date = int(extra.get(b'committer_date', author_date))
       
   142             self.sqlcursor.execute(
       
   143                 b"INSERT IGNORE INTO nodesmetadata(node, message, p1, p2, "
       
   144                 b"author, committer, author_date, committer_date, "
       
   145                 b"reponame) VALUES "
       
   146                 b"(%s, %s, %s, %s, %s, %s, %s, %s, %s)",
       
   147                 params=(
       
   148                     ctx.hex(),
       
   149                     ctx.description(),
       
   150                     ctx.p1().hex(),
       
   151                     ctx.p2().hex(),
       
   152                     author_name,
       
   153                     committer_name,
       
   154                     author_date,
       
   155                     committer_date,
       
   156                     self.reponame,
       
   157                 ),
       
   158             )
       
   159 
       
   160     def addbookmark(self, bookmark, node):
       
   161         """Takes a bookmark name and hash, and records mapping in the metadata
       
   162         store."""
       
   163         if not self._connected:
       
   164             self.sqlconnect()
       
   165         self.log.info(
       
   166             b"ADD BOOKMARKS %r bookmark: %r node: %r"
       
   167             % (self.reponame, bookmark, node)
       
   168         )
       
   169         self.sqlcursor.execute(
       
   170             b"INSERT INTO bookmarkstonode(bookmark, node, reponame) "
       
   171             b"VALUES (%s, %s, %s) ON DUPLICATE KEY UPDATE node=VALUES(node)",
       
   172             params=(bookmark, node, self.reponame),
       
   173         )
       
   174 
       
   175     def addmanybookmarks(self, bookmarks):
       
   176         if not self._connected:
       
   177             self.sqlconnect()
       
   178         args = []
       
   179         values = []
       
   180         for bookmark, node in bookmarks.items():
       
   181             args.append(b'(%s, %s, %s)')
       
   182             values.extend((bookmark, node, self.reponame))
       
   183         args = b','.join(args)
       
   184 
       
   185         self.sqlcursor.execute(
       
   186             b"INSERT INTO bookmarkstonode(bookmark, node, reponame) "
       
   187             b"VALUES %s ON DUPLICATE KEY UPDATE node=VALUES(node)" % args,
       
   188             params=values,
       
   189         )
       
   190 
       
   191     def deletebookmarks(self, patterns):
       
   192         """Accepts list of bookmark patterns and deletes them.
       
   193         If `commit` is set then bookmark will actually be deleted. Otherwise
       
   194         deletion will be delayed until the end of transaction.
       
   195         """
       
   196         if not self._connected:
       
   197             self.sqlconnect()
       
   198         self.log.info(b"DELETE BOOKMARKS: %s" % patterns)
       
   199         for pattern in patterns:
       
   200             pattern = _convertbookmarkpattern(pattern)
       
   201             self.sqlcursor.execute(
       
   202                 b"DELETE from bookmarkstonode WHERE bookmark LIKE (%s) "
       
   203                 b"and reponame = %s",
       
   204                 params=(pattern, self.reponame),
       
   205             )
       
   206 
       
   207     def getbundle(self, node):
       
   208         """Returns the bundleid for the bundle that contains the given node."""
       
   209         if not self._connected:
       
   210             self.sqlconnect()
       
   211         self.log.info(b"GET BUNDLE %r %r" % (self.reponame, node))
       
   212         self.sqlcursor.execute(
       
   213             b"SELECT bundle from nodestobundle "
       
   214             b"WHERE node = %s AND reponame = %s",
       
   215             params=(node, self.reponame),
       
   216         )
       
   217         result = self.sqlcursor.fetchall()
       
   218         if len(result) != 1 or len(result[0]) != 1:
       
   219             self.log.info(b"No matching node")
       
   220             return None
       
   221         bundle = result[0][0]
       
   222         self.log.info(b"Found bundle %r" % bundle)
       
   223         return bundle
       
   224 
       
   225     def getnode(self, bookmark):
       
   226         """Returns the node for the given bookmark. None if it doesn't exist."""
       
   227         if not self._connected:
       
   228             self.sqlconnect()
       
   229         self.log.info(
       
   230             b"GET NODE reponame: %r bookmark: %r" % (self.reponame, bookmark)
       
   231         )
       
   232         self.sqlcursor.execute(
       
   233             b"SELECT node from bookmarkstonode WHERE "
       
   234             b"bookmark = %s AND reponame = %s",
       
   235             params=(bookmark, self.reponame),
       
   236         )
       
   237         result = self.sqlcursor.fetchall()
       
   238         if len(result) != 1 or len(result[0]) != 1:
       
   239             self.log.info(b"No matching bookmark")
       
   240             return None
       
   241         node = result[0][0]
       
   242         self.log.info(b"Found node %r" % node)
       
   243         return node
       
   244 
       
   245     def getbookmarks(self, query):
       
   246         if not self._connected:
       
   247             self.sqlconnect()
       
   248         self.log.info(
       
   249             b"QUERY BOOKMARKS reponame: %r query: %r" % (self.reponame, query)
       
   250         )
       
   251         query = _convertbookmarkpattern(query)
       
   252         self.sqlcursor.execute(
       
   253             b"SELECT bookmark, node from bookmarkstonode WHERE "
       
   254             b"reponame = %s AND bookmark LIKE %s",
       
   255             params=(self.reponame, query),
       
   256         )
       
   257         result = self.sqlcursor.fetchall()
       
   258         bookmarks = {}
       
   259         for row in result:
       
   260             if len(row) != 2:
       
   261                 self.log.info(b"Bad row returned: %s" % row)
       
   262                 continue
       
   263             bookmarks[row[0]] = row[1]
       
   264         return bookmarks
       
   265 
       
   266     def saveoptionaljsonmetadata(self, node, jsonmetadata):
       
   267         if not self._connected:
       
   268             self.sqlconnect()
       
   269         self.log.info(
       
   270             (
       
   271                 b"INSERT METADATA, QUERY BOOKMARKS reponame: %r "
       
   272                 + b"node: %r, jsonmetadata: %s"
       
   273             )
       
   274             % (self.reponame, node, jsonmetadata)
       
   275         )
       
   276 
       
   277         self.sqlcursor.execute(
       
   278             b"UPDATE nodesmetadata SET optional_json_metadata=%s WHERE "
       
   279             b"reponame=%s AND node=%s",
       
   280             params=(jsonmetadata, self.reponame, node),
       
   281         )
       
   282 
       
   283 
       
   284 class CustomConverter(mysql.connector.conversion.MySQLConverter):
       
   285     """Ensure that all values being returned are returned as python string
       
   286     (versus the default byte arrays)."""
       
   287 
       
   288     def _STRING_to_python(self, value, dsc=None):
       
   289         return str(value)
       
   290 
       
   291     def _VAR_STRING_to_python(self, value, dsc=None):
       
   292         return str(value)
       
   293 
       
   294     def _BLOB_to_python(self, value, dsc=None):
       
   295         return str(value)