mercurial/simplemerge.py
changeset 6002 abd66eb0889e
parent 5081 ea7b982b6c08
child 6212 e75aab656f46
equal deleted inserted replaced
6001:30d2fecaab76 6002:abd66eb0889e
       
     1 #!/usr/bin/env python
       
     2 # Copyright (C) 2004, 2005 Canonical Ltd
       
     3 #
       
     4 # This program is free software; you can redistribute it and/or modify
       
     5 # it under the terms of the GNU General Public License as published by
       
     6 # the Free Software Foundation; either version 2 of the License, or
       
     7 # (at your option) any later version.
       
     8 #
       
     9 # This program is distributed in the hope that it will be useful,
       
    10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
       
    11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
       
    12 # GNU General Public License for more details.
       
    13 #
       
    14 # You should have received a copy of the GNU General Public License
       
    15 # along with this program; if not, write to the Free Software
       
    16 # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
       
    17 
       
    18 # mbp: "you know that thing where cvs gives you conflict markers?"
       
    19 # s: "i hate that."
       
    20 
       
    21 from i18n import _
       
    22 import util, mdiff, fancyopts, sys, os
       
    23 
       
    24 class CantReprocessAndShowBase(Exception):
       
    25     pass
       
    26 
       
    27 def warn(message):
       
    28     sys.stdout.flush()
       
    29     sys.stderr.write(message)
       
    30     sys.stderr.flush()
       
    31 
       
    32 def intersect(ra, rb):
       
    33     """Given two ranges return the range where they intersect or None.
       
    34 
       
    35     >>> intersect((0, 10), (0, 6))
       
    36     (0, 6)
       
    37     >>> intersect((0, 10), (5, 15))
       
    38     (5, 10)
       
    39     >>> intersect((0, 10), (10, 15))
       
    40     >>> intersect((0, 9), (10, 15))
       
    41     >>> intersect((0, 9), (7, 15))
       
    42     (7, 9)
       
    43     """
       
    44     assert ra[0] <= ra[1]
       
    45     assert rb[0] <= rb[1]
       
    46 
       
    47     sa = max(ra[0], rb[0])
       
    48     sb = min(ra[1], rb[1])
       
    49     if sa < sb:
       
    50         return sa, sb
       
    51     else:
       
    52         return None
       
    53 
       
    54 def compare_range(a, astart, aend, b, bstart, bend):
       
    55     """Compare a[astart:aend] == b[bstart:bend], without slicing.
       
    56     """
       
    57     if (aend-astart) != (bend-bstart):
       
    58         return False
       
    59     for ia, ib in zip(xrange(astart, aend), xrange(bstart, bend)):
       
    60         if a[ia] != b[ib]:
       
    61             return False
       
    62     else:
       
    63         return True
       
    64 
       
    65 class Merge3Text(object):
       
    66     """3-way merge of texts.
       
    67 
       
    68     Given strings BASE, OTHER, THIS, tries to produce a combined text
       
    69     incorporating the changes from both BASE->OTHER and BASE->THIS."""
       
    70     def __init__(self, basetext, atext, btext, base=None, a=None, b=None):
       
    71         self.basetext = basetext
       
    72         self.atext = atext
       
    73         self.btext = btext
       
    74         if base is None:
       
    75             base = mdiff.splitnewlines(basetext)
       
    76         if a is None:
       
    77             a = mdiff.splitnewlines(atext)
       
    78         if b is None:
       
    79             b = mdiff.splitnewlines(btext)
       
    80         self.base = base
       
    81         self.a = a
       
    82         self.b = b
       
    83 
       
    84     def merge_lines(self,
       
    85                     name_a=None,
       
    86                     name_b=None,
       
    87                     name_base=None,
       
    88                     start_marker='<<<<<<<',
       
    89                     mid_marker='=======',
       
    90                     end_marker='>>>>>>>',
       
    91                     base_marker=None,
       
    92                     reprocess=False):
       
    93         """Return merge in cvs-like form.
       
    94         """
       
    95         self.conflicts = False
       
    96         newline = '\n'
       
    97         if len(self.a) > 0:
       
    98             if self.a[0].endswith('\r\n'):
       
    99                 newline = '\r\n'
       
   100             elif self.a[0].endswith('\r'):
       
   101                 newline = '\r'
       
   102         if base_marker and reprocess:
       
   103             raise CantReprocessAndShowBase()
       
   104         if name_a:
       
   105             start_marker = start_marker + ' ' + name_a
       
   106         if name_b:
       
   107             end_marker = end_marker + ' ' + name_b
       
   108         if name_base and base_marker:
       
   109             base_marker = base_marker + ' ' + name_base
       
   110         merge_regions = self.merge_regions()
       
   111         if reprocess is True:
       
   112             merge_regions = self.reprocess_merge_regions(merge_regions)
       
   113         for t in merge_regions:
       
   114             what = t[0]
       
   115             if what == 'unchanged':
       
   116                 for i in range(t[1], t[2]):
       
   117                     yield self.base[i]
       
   118             elif what == 'a' or what == 'same':
       
   119                 for i in range(t[1], t[2]):
       
   120                     yield self.a[i]
       
   121             elif what == 'b':
       
   122                 for i in range(t[1], t[2]):
       
   123                     yield self.b[i]
       
   124             elif what == 'conflict':
       
   125                 self.conflicts = True
       
   126                 yield start_marker + newline
       
   127                 for i in range(t[3], t[4]):
       
   128                     yield self.a[i]
       
   129                 if base_marker is not None:
       
   130                     yield base_marker + newline
       
   131                     for i in range(t[1], t[2]):
       
   132                         yield self.base[i]
       
   133                 yield mid_marker + newline
       
   134                 for i in range(t[5], t[6]):
       
   135                     yield self.b[i]
       
   136                 yield end_marker + newline
       
   137             else:
       
   138                 raise ValueError(what)
       
   139 
       
   140     def merge_annotated(self):
       
   141         """Return merge with conflicts, showing origin of lines.
       
   142 
       
   143         Most useful for debugging merge.
       
   144         """
       
   145         for t in self.merge_regions():
       
   146             what = t[0]
       
   147             if what == 'unchanged':
       
   148                 for i in range(t[1], t[2]):
       
   149                     yield 'u | ' + self.base[i]
       
   150             elif what == 'a' or what == 'same':
       
   151                 for i in range(t[1], t[2]):
       
   152                     yield what[0] + ' | ' + self.a[i]
       
   153             elif what == 'b':
       
   154                 for i in range(t[1], t[2]):
       
   155                     yield 'b | ' + self.b[i]
       
   156             elif what == 'conflict':
       
   157                 yield '<<<<\n'
       
   158                 for i in range(t[3], t[4]):
       
   159                     yield 'A | ' + self.a[i]
       
   160                 yield '----\n'
       
   161                 for i in range(t[5], t[6]):
       
   162                     yield 'B | ' + self.b[i]
       
   163                 yield '>>>>\n'
       
   164             else:
       
   165                 raise ValueError(what)
       
   166 
       
   167     def merge_groups(self):
       
   168         """Yield sequence of line groups.  Each one is a tuple:
       
   169 
       
   170         'unchanged', lines
       
   171              Lines unchanged from base
       
   172 
       
   173         'a', lines
       
   174              Lines taken from a
       
   175 
       
   176         'same', lines
       
   177              Lines taken from a (and equal to b)
       
   178 
       
   179         'b', lines
       
   180              Lines taken from b
       
   181 
       
   182         'conflict', base_lines, a_lines, b_lines
       
   183              Lines from base were changed to either a or b and conflict.
       
   184         """
       
   185         for t in self.merge_regions():
       
   186             what = t[0]
       
   187             if what == 'unchanged':
       
   188                 yield what, self.base[t[1]:t[2]]
       
   189             elif what == 'a' or what == 'same':
       
   190                 yield what, self.a[t[1]:t[2]]
       
   191             elif what == 'b':
       
   192                 yield what, self.b[t[1]:t[2]]
       
   193             elif what == 'conflict':
       
   194                 yield (what,
       
   195                        self.base[t[1]:t[2]],
       
   196                        self.a[t[3]:t[4]],
       
   197                        self.b[t[5]:t[6]])
       
   198             else:
       
   199                 raise ValueError(what)
       
   200 
       
   201     def merge_regions(self):
       
   202         """Return sequences of matching and conflicting regions.
       
   203 
       
   204         This returns tuples, where the first value says what kind we
       
   205         have:
       
   206 
       
   207         'unchanged', start, end
       
   208              Take a region of base[start:end]
       
   209 
       
   210         'same', astart, aend
       
   211              b and a are different from base but give the same result
       
   212 
       
   213         'a', start, end
       
   214              Non-clashing insertion from a[start:end]
       
   215 
       
   216         Method is as follows:
       
   217 
       
   218         The two sequences align only on regions which match the base
       
   219         and both descendents.  These are found by doing a two-way diff
       
   220         of each one against the base, and then finding the
       
   221         intersections between those regions.  These "sync regions"
       
   222         are by definition unchanged in both and easily dealt with.
       
   223 
       
   224         The regions in between can be in any of three cases:
       
   225         conflicted, or changed on only one side.
       
   226         """
       
   227 
       
   228         # section a[0:ia] has been disposed of, etc
       
   229         iz = ia = ib = 0
       
   230 
       
   231         for zmatch, zend, amatch, aend, bmatch, bend in self.find_sync_regions():
       
   232             #print 'match base [%d:%d]' % (zmatch, zend)
       
   233 
       
   234             matchlen = zend - zmatch
       
   235             assert matchlen >= 0
       
   236             assert matchlen == (aend - amatch)
       
   237             assert matchlen == (bend - bmatch)
       
   238 
       
   239             len_a = amatch - ia
       
   240             len_b = bmatch - ib
       
   241             len_base = zmatch - iz
       
   242             assert len_a >= 0
       
   243             assert len_b >= 0
       
   244             assert len_base >= 0
       
   245 
       
   246             #print 'unmatched a=%d, b=%d' % (len_a, len_b)
       
   247 
       
   248             if len_a or len_b:
       
   249                 # try to avoid actually slicing the lists
       
   250                 equal_a = compare_range(self.a, ia, amatch,
       
   251                                         self.base, iz, zmatch)
       
   252                 equal_b = compare_range(self.b, ib, bmatch,
       
   253                                         self.base, iz, zmatch)
       
   254                 same = compare_range(self.a, ia, amatch,
       
   255                                      self.b, ib, bmatch)
       
   256 
       
   257                 if same:
       
   258                     yield 'same', ia, amatch
       
   259                 elif equal_a and not equal_b:
       
   260                     yield 'b', ib, bmatch
       
   261                 elif equal_b and not equal_a:
       
   262                     yield 'a', ia, amatch
       
   263                 elif not equal_a and not equal_b:
       
   264                     yield 'conflict', iz, zmatch, ia, amatch, ib, bmatch
       
   265                 else:
       
   266                     raise AssertionError("can't handle a=b=base but unmatched")
       
   267 
       
   268                 ia = amatch
       
   269                 ib = bmatch
       
   270             iz = zmatch
       
   271 
       
   272             # if the same part of the base was deleted on both sides
       
   273             # that's OK, we can just skip it.
       
   274 
       
   275 
       
   276             if matchlen > 0:
       
   277                 assert ia == amatch
       
   278                 assert ib == bmatch
       
   279                 assert iz == zmatch
       
   280 
       
   281                 yield 'unchanged', zmatch, zend
       
   282                 iz = zend
       
   283                 ia = aend
       
   284                 ib = bend
       
   285 
       
   286     def reprocess_merge_regions(self, merge_regions):
       
   287         """Where there are conflict regions, remove the agreed lines.
       
   288 
       
   289         Lines where both A and B have made the same changes are
       
   290         eliminated.
       
   291         """
       
   292         for region in merge_regions:
       
   293             if region[0] != "conflict":
       
   294                 yield region
       
   295                 continue
       
   296             type, iz, zmatch, ia, amatch, ib, bmatch = region
       
   297             a_region = self.a[ia:amatch]
       
   298             b_region = self.b[ib:bmatch]
       
   299             matches = mdiff.get_matching_blocks(''.join(a_region),
       
   300                                                 ''.join(b_region))
       
   301             next_a = ia
       
   302             next_b = ib
       
   303             for region_ia, region_ib, region_len in matches[:-1]:
       
   304                 region_ia += ia
       
   305                 region_ib += ib
       
   306                 reg = self.mismatch_region(next_a, region_ia, next_b,
       
   307                                            region_ib)
       
   308                 if reg is not None:
       
   309                     yield reg
       
   310                 yield 'same', region_ia, region_len+region_ia
       
   311                 next_a = region_ia + region_len
       
   312                 next_b = region_ib + region_len
       
   313             reg = self.mismatch_region(next_a, amatch, next_b, bmatch)
       
   314             if reg is not None:
       
   315                 yield reg
       
   316 
       
   317     def mismatch_region(next_a, region_ia,  next_b, region_ib):
       
   318         if next_a < region_ia or next_b < region_ib:
       
   319             return 'conflict', None, None, next_a, region_ia, next_b, region_ib
       
   320     mismatch_region = staticmethod(mismatch_region)
       
   321 
       
   322     def find_sync_regions(self):
       
   323         """Return a list of sync regions, where both descendents match the base.
       
   324 
       
   325         Generates a list of (base1, base2, a1, a2, b1, b2).  There is
       
   326         always a zero-length sync region at the end of all the files.
       
   327         """
       
   328 
       
   329         ia = ib = 0
       
   330         amatches = mdiff.get_matching_blocks(self.basetext, self.atext)
       
   331         bmatches = mdiff.get_matching_blocks(self.basetext, self.btext)
       
   332         len_a = len(amatches)
       
   333         len_b = len(bmatches)
       
   334 
       
   335         sl = []
       
   336 
       
   337         while ia < len_a and ib < len_b:
       
   338             abase, amatch, alen = amatches[ia]
       
   339             bbase, bmatch, blen = bmatches[ib]
       
   340 
       
   341             # there is an unconflicted block at i; how long does it
       
   342             # extend?  until whichever one ends earlier.
       
   343             i = intersect((abase, abase+alen), (bbase, bbase+blen))
       
   344             if i:
       
   345                 intbase = i[0]
       
   346                 intend = i[1]
       
   347                 intlen = intend - intbase
       
   348 
       
   349                 # found a match of base[i[0], i[1]]; this may be less than
       
   350                 # the region that matches in either one
       
   351                 assert intlen <= alen
       
   352                 assert intlen <= blen
       
   353                 assert abase <= intbase
       
   354                 assert bbase <= intbase
       
   355 
       
   356                 asub = amatch + (intbase - abase)
       
   357                 bsub = bmatch + (intbase - bbase)
       
   358                 aend = asub + intlen
       
   359                 bend = bsub + intlen
       
   360 
       
   361                 assert self.base[intbase:intend] == self.a[asub:aend], \
       
   362                        (self.base[intbase:intend], self.a[asub:aend])
       
   363 
       
   364                 assert self.base[intbase:intend] == self.b[bsub:bend]
       
   365 
       
   366                 sl.append((intbase, intend,
       
   367                            asub, aend,
       
   368                            bsub, bend))
       
   369 
       
   370             # advance whichever one ends first in the base text
       
   371             if (abase + alen) < (bbase + blen):
       
   372                 ia += 1
       
   373             else:
       
   374                 ib += 1
       
   375 
       
   376         intbase = len(self.base)
       
   377         abase = len(self.a)
       
   378         bbase = len(self.b)
       
   379         sl.append((intbase, intbase, abase, abase, bbase, bbase))
       
   380 
       
   381         return sl
       
   382 
       
   383     def find_unconflicted(self):
       
   384         """Return a list of ranges in base that are not conflicted."""
       
   385         am = mdiff.get_matching_blocks(self.basetext, self.atext)
       
   386         bm = mdiff.get_matching_blocks(self.basetext, self.btext)
       
   387 
       
   388         unc = []
       
   389 
       
   390         while am and bm:
       
   391             # there is an unconflicted block at i; how long does it
       
   392             # extend?  until whichever one ends earlier.
       
   393             a1 = am[0][0]
       
   394             a2 = a1 + am[0][2]
       
   395             b1 = bm[0][0]
       
   396             b2 = b1 + bm[0][2]
       
   397             i = intersect((a1, a2), (b1, b2))
       
   398             if i:
       
   399                 unc.append(i)
       
   400 
       
   401             if a2 < b2:
       
   402                 del am[0]
       
   403             else:
       
   404                 del bm[0]
       
   405 
       
   406         return unc
       
   407 
       
   408 def simplemerge(local, base, other, **opts):
       
   409     def readfile(filename):
       
   410         f = open(filename, "rb")
       
   411         text = f.read()
       
   412         f.close()
       
   413         if util.binary(text):
       
   414             msg = _("%s looks like a binary file.") % filename
       
   415             if not opts.get('text'):
       
   416                 raise util.Abort(msg)
       
   417             elif not opts.get('quiet'):
       
   418                 warn(_('warning: %s\n') % msg)
       
   419         return text
       
   420 
       
   421     name_a = local
       
   422     name_b = other
       
   423     labels = opts.get('label', [])
       
   424     if labels:
       
   425         name_a = labels.pop(0)
       
   426     if labels:
       
   427         name_b = labels.pop(0)
       
   428     if labels:
       
   429         raise util.Abort(_("can only specify two labels."))
       
   430 
       
   431     localtext = readfile(local)
       
   432     basetext = readfile(base)
       
   433     othertext = readfile(other)
       
   434 
       
   435     orig = local
       
   436     local = os.path.realpath(local)
       
   437     if not opts.get('print'):
       
   438         opener = util.opener(os.path.dirname(local))
       
   439         out = opener(os.path.basename(local), "w", atomictemp=True)
       
   440     else:
       
   441         out = sys.stdout
       
   442 
       
   443     reprocess = not opts.get('no_minimal')
       
   444 
       
   445     m3 = Merge3Text(basetext, localtext, othertext)
       
   446     for line in m3.merge_lines(name_a=name_a, name_b=name_b,
       
   447                                reprocess=reprocess):
       
   448         out.write(line)
       
   449 
       
   450     if not opts.get('print'):
       
   451         out.rename()
       
   452 
       
   453     if m3.conflicts:
       
   454         if not opts.get('quiet'):
       
   455             warn(_("warning: conflicts during merge.\n"))
       
   456         return 1