vendor/github.com/spf13/cobra/flag_groups.go
changeset 260 445e01aede7e
child 265 05c40b36d3b2
equal deleted inserted replaced
259:db4911b0c721 260:445e01aede7e
       
     1 // Copyright © 2022 Steve Francia <spf@spf13.com>.
       
     2 //
       
     3 // Licensed under the Apache License, Version 2.0 (the "License");
       
     4 // you may not use this file except in compliance with the License.
       
     5 // You may obtain a copy of the License at
       
     6 // http://www.apache.org/licenses/LICENSE-2.0
       
     7 //
       
     8 // Unless required by applicable law or agreed to in writing, software
       
     9 // distributed under the License is distributed on an "AS IS" BASIS,
       
    10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
       
    11 // See the License for the specific language governing permissions and
       
    12 // limitations under the License.
       
    13 
       
    14 package cobra
       
    15 
       
    16 import (
       
    17 	"fmt"
       
    18 	"sort"
       
    19 	"strings"
       
    20 
       
    21 	flag "github.com/spf13/pflag"
       
    22 )
       
    23 
       
    24 const (
       
    25 	requiredAsGroup   = "cobra_annotation_required_if_others_set"
       
    26 	mutuallyExclusive = "cobra_annotation_mutually_exclusive"
       
    27 )
       
    28 
       
    29 // MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors
       
    30 // if the command is invoked with a subset (but not all) of the given flags.
       
    31 func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) {
       
    32 	c.mergePersistentFlags()
       
    33 	for _, v := range flagNames {
       
    34 		f := c.Flags().Lookup(v)
       
    35 		if f == nil {
       
    36 			panic(fmt.Sprintf("Failed to find flag %q and mark it as being required in a flag group", v))
       
    37 		}
       
    38 		if err := c.Flags().SetAnnotation(v, requiredAsGroup, append(f.Annotations[requiredAsGroup], strings.Join(flagNames, " "))); err != nil {
       
    39 			// Only errs if the flag isn't found.
       
    40 			panic(err)
       
    41 		}
       
    42 	}
       
    43 }
       
    44 
       
    45 // MarkFlagsMutuallyExclusive marks the given flags with annotations so that Cobra errors
       
    46 // if the command is invoked with more than one flag from the given set of flags.
       
    47 func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
       
    48 	c.mergePersistentFlags()
       
    49 	for _, v := range flagNames {
       
    50 		f := c.Flags().Lookup(v)
       
    51 		if f == nil {
       
    52 			panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", v))
       
    53 		}
       
    54 		// Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed.
       
    55 		if err := c.Flags().SetAnnotation(v, mutuallyExclusive, append(f.Annotations[mutuallyExclusive], strings.Join(flagNames, " "))); err != nil {
       
    56 			panic(err)
       
    57 		}
       
    58 	}
       
    59 }
       
    60 
       
    61 // validateFlagGroups validates the mutuallyExclusive/requiredAsGroup logic and returns the
       
    62 // first error encountered.
       
    63 func (c *Command) validateFlagGroups() error {
       
    64 	if c.DisableFlagParsing {
       
    65 		return nil
       
    66 	}
       
    67 
       
    68 	flags := c.Flags()
       
    69 
       
    70 	// groupStatus format is the list of flags as a unique ID,
       
    71 	// then a map of each flag name and whether it is set or not.
       
    72 	groupStatus := map[string]map[string]bool{}
       
    73 	mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
       
    74 	flags.VisitAll(func(pflag *flag.Flag) {
       
    75 		processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
       
    76 		processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
       
    77 	})
       
    78 
       
    79 	if err := validateRequiredFlagGroups(groupStatus); err != nil {
       
    80 		return err
       
    81 	}
       
    82 	if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil {
       
    83 		return err
       
    84 	}
       
    85 	return nil
       
    86 }
       
    87 
       
    88 func hasAllFlags(fs *flag.FlagSet, flagnames ...string) bool {
       
    89 	for _, fname := range flagnames {
       
    90 		f := fs.Lookup(fname)
       
    91 		if f == nil {
       
    92 			return false
       
    93 		}
       
    94 	}
       
    95 	return true
       
    96 }
       
    97 
       
    98 func processFlagForGroupAnnotation(flags *flag.FlagSet, pflag *flag.Flag, annotation string, groupStatus map[string]map[string]bool) {
       
    99 	groupInfo, found := pflag.Annotations[annotation]
       
   100 	if found {
       
   101 		for _, group := range groupInfo {
       
   102 			if groupStatus[group] == nil {
       
   103 				flagnames := strings.Split(group, " ")
       
   104 
       
   105 				// Only consider this flag group at all if all the flags are defined.
       
   106 				if !hasAllFlags(flags, flagnames...) {
       
   107 					continue
       
   108 				}
       
   109 
       
   110 				groupStatus[group] = map[string]bool{}
       
   111 				for _, name := range flagnames {
       
   112 					groupStatus[group][name] = false
       
   113 				}
       
   114 			}
       
   115 
       
   116 			groupStatus[group][pflag.Name] = pflag.Changed
       
   117 		}
       
   118 	}
       
   119 }
       
   120 
       
   121 func validateRequiredFlagGroups(data map[string]map[string]bool) error {
       
   122 	keys := sortedKeys(data)
       
   123 	for _, flagList := range keys {
       
   124 		flagnameAndStatus := data[flagList]
       
   125 
       
   126 		unset := []string{}
       
   127 		for flagname, isSet := range flagnameAndStatus {
       
   128 			if !isSet {
       
   129 				unset = append(unset, flagname)
       
   130 			}
       
   131 		}
       
   132 		if len(unset) == len(flagnameAndStatus) || len(unset) == 0 {
       
   133 			continue
       
   134 		}
       
   135 
       
   136 		// Sort values, so they can be tested/scripted against consistently.
       
   137 		sort.Strings(unset)
       
   138 		return fmt.Errorf("if any flags in the group [%v] are set they must all be set; missing %v", flagList, unset)
       
   139 	}
       
   140 
       
   141 	return nil
       
   142 }
       
   143 
       
   144 func validateExclusiveFlagGroups(data map[string]map[string]bool) error {
       
   145 	keys := sortedKeys(data)
       
   146 	for _, flagList := range keys {
       
   147 		flagnameAndStatus := data[flagList]
       
   148 		var set []string
       
   149 		for flagname, isSet := range flagnameAndStatus {
       
   150 			if isSet {
       
   151 				set = append(set, flagname)
       
   152 			}
       
   153 		}
       
   154 		if len(set) == 0 || len(set) == 1 {
       
   155 			continue
       
   156 		}
       
   157 
       
   158 		// Sort values, so they can be tested/scripted against consistently.
       
   159 		sort.Strings(set)
       
   160 		return fmt.Errorf("if any flags in the group [%v] are set none of the others can be; %v were all set", flagList, set)
       
   161 	}
       
   162 	return nil
       
   163 }
       
   164 
       
   165 func sortedKeys(m map[string]map[string]bool) []string {
       
   166 	keys := make([]string, len(m))
       
   167 	i := 0
       
   168 	for k := range m {
       
   169 		keys[i] = k
       
   170 		i++
       
   171 	}
       
   172 	sort.Strings(keys)
       
   173 	return keys
       
   174 }
       
   175 
       
   176 // enforceFlagGroupsForCompletion will do the following:
       
   177 // - when a flag in a group is present, other flags in the group will be marked required
       
   178 // - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden
       
   179 // This allows the standard completion logic to behave appropriately for flag groups
       
   180 func (c *Command) enforceFlagGroupsForCompletion() {
       
   181 	if c.DisableFlagParsing {
       
   182 		return
       
   183 	}
       
   184 
       
   185 	flags := c.Flags()
       
   186 	groupStatus := map[string]map[string]bool{}
       
   187 	mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
       
   188 	c.Flags().VisitAll(func(pflag *flag.Flag) {
       
   189 		processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
       
   190 		processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
       
   191 	})
       
   192 
       
   193 	// If a flag that is part of a group is present, we make all the other flags
       
   194 	// of that group required so that the shell completion suggests them automatically
       
   195 	for flagList, flagnameAndStatus := range groupStatus {
       
   196 		for _, isSet := range flagnameAndStatus {
       
   197 			if isSet {
       
   198 				// One of the flags of the group is set, mark the other ones as required
       
   199 				for _, fName := range strings.Split(flagList, " ") {
       
   200 					_ = c.MarkFlagRequired(fName)
       
   201 				}
       
   202 			}
       
   203 		}
       
   204 	}
       
   205 
       
   206 	// If a flag that is mutually exclusive to others is present, we hide the other
       
   207 	// flags of that group so the shell completion does not suggest them
       
   208 	for flagList, flagnameAndStatus := range mutuallyExclusiveGroupStatus {
       
   209 		for flagName, isSet := range flagnameAndStatus {
       
   210 			if isSet {
       
   211 				// One of the flags of the mutually exclusive group is set, mark the other ones as hidden
       
   212 				// Don't mark the flag that is already set as hidden because it may be an
       
   213 				// array or slice flag and therefore must continue being suggested
       
   214 				for _, fName := range strings.Split(flagList, " ") {
       
   215 					if fName != flagName {
       
   216 						flag := c.Flags().Lookup(fName)
       
   217 						flag.Hidden = true
       
   218 					}
       
   219 				}
       
   220 			}
       
   221 		}
       
   222 	}
       
   223 }