|
1 // Copyright © 2022 Steve Francia <spf@spf13.com>. |
|
2 // |
|
3 // Licensed under the Apache License, Version 2.0 (the "License"); |
|
4 // you may not use this file except in compliance with the License. |
|
5 // You may obtain a copy of the License at |
|
6 // http://www.apache.org/licenses/LICENSE-2.0 |
|
7 // |
|
8 // Unless required by applicable law or agreed to in writing, software |
|
9 // distributed under the License is distributed on an "AS IS" BASIS, |
|
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
11 // See the License for the specific language governing permissions and |
|
12 // limitations under the License. |
|
13 |
|
14 package cobra |
|
15 |
|
16 import ( |
|
17 "fmt" |
|
18 "sort" |
|
19 "strings" |
|
20 |
|
21 flag "github.com/spf13/pflag" |
|
22 ) |
|
23 |
|
// Annotation keys written into pflag.Flag.Annotations to record which flag
// groups a flag belongs to. Each annotation value is a list of groups, and each
// group is the space-joined list of its member flag names.
const (
	// requiredAsGroup tags flags registered via MarkFlagsRequiredTogether.
	requiredAsGroup = "cobra_annotation_required_if_others_set"
	// mutuallyExclusive tags flags registered via MarkFlagsMutuallyExclusive.
	mutuallyExclusive = "cobra_annotation_mutually_exclusive"
)
|
28 |
|
29 // MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors |
|
30 // if the command is invoked with a subset (but not all) of the given flags. |
|
31 func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) { |
|
32 c.mergePersistentFlags() |
|
33 for _, v := range flagNames { |
|
34 f := c.Flags().Lookup(v) |
|
35 if f == nil { |
|
36 panic(fmt.Sprintf("Failed to find flag %q and mark it as being required in a flag group", v)) |
|
37 } |
|
38 if err := c.Flags().SetAnnotation(v, requiredAsGroup, append(f.Annotations[requiredAsGroup], strings.Join(flagNames, " "))); err != nil { |
|
39 // Only errs if the flag isn't found. |
|
40 panic(err) |
|
41 } |
|
42 } |
|
43 } |
|
44 |
|
45 // MarkFlagsMutuallyExclusive marks the given flags with annotations so that Cobra errors |
|
46 // if the command is invoked with more than one flag from the given set of flags. |
|
47 func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) { |
|
48 c.mergePersistentFlags() |
|
49 for _, v := range flagNames { |
|
50 f := c.Flags().Lookup(v) |
|
51 if f == nil { |
|
52 panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", v)) |
|
53 } |
|
54 // Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed. |
|
55 if err := c.Flags().SetAnnotation(v, mutuallyExclusive, append(f.Annotations[mutuallyExclusive], strings.Join(flagNames, " "))); err != nil { |
|
56 panic(err) |
|
57 } |
|
58 } |
|
59 } |
|
60 |
|
61 // validateFlagGroups validates the mutuallyExclusive/requiredAsGroup logic and returns the |
|
62 // first error encountered. |
|
63 func (c *Command) validateFlagGroups() error { |
|
64 if c.DisableFlagParsing { |
|
65 return nil |
|
66 } |
|
67 |
|
68 flags := c.Flags() |
|
69 |
|
70 // groupStatus format is the list of flags as a unique ID, |
|
71 // then a map of each flag name and whether it is set or not. |
|
72 groupStatus := map[string]map[string]bool{} |
|
73 mutuallyExclusiveGroupStatus := map[string]map[string]bool{} |
|
74 flags.VisitAll(func(pflag *flag.Flag) { |
|
75 processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus) |
|
76 processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus) |
|
77 }) |
|
78 |
|
79 if err := validateRequiredFlagGroups(groupStatus); err != nil { |
|
80 return err |
|
81 } |
|
82 if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil { |
|
83 return err |
|
84 } |
|
85 return nil |
|
86 } |
|
87 |
|
88 func hasAllFlags(fs *flag.FlagSet, flagnames ...string) bool { |
|
89 for _, fname := range flagnames { |
|
90 f := fs.Lookup(fname) |
|
91 if f == nil { |
|
92 return false |
|
93 } |
|
94 } |
|
95 return true |
|
96 } |
|
97 |
|
98 func processFlagForGroupAnnotation(flags *flag.FlagSet, pflag *flag.Flag, annotation string, groupStatus map[string]map[string]bool) { |
|
99 groupInfo, found := pflag.Annotations[annotation] |
|
100 if found { |
|
101 for _, group := range groupInfo { |
|
102 if groupStatus[group] == nil { |
|
103 flagnames := strings.Split(group, " ") |
|
104 |
|
105 // Only consider this flag group at all if all the flags are defined. |
|
106 if !hasAllFlags(flags, flagnames...) { |
|
107 continue |
|
108 } |
|
109 |
|
110 groupStatus[group] = map[string]bool{} |
|
111 for _, name := range flagnames { |
|
112 groupStatus[group][name] = false |
|
113 } |
|
114 } |
|
115 |
|
116 groupStatus[group][pflag.Name] = pflag.Changed |
|
117 } |
|
118 } |
|
119 } |
|
120 |
|
121 func validateRequiredFlagGroups(data map[string]map[string]bool) error { |
|
122 keys := sortedKeys(data) |
|
123 for _, flagList := range keys { |
|
124 flagnameAndStatus := data[flagList] |
|
125 |
|
126 unset := []string{} |
|
127 for flagname, isSet := range flagnameAndStatus { |
|
128 if !isSet { |
|
129 unset = append(unset, flagname) |
|
130 } |
|
131 } |
|
132 if len(unset) == len(flagnameAndStatus) || len(unset) == 0 { |
|
133 continue |
|
134 } |
|
135 |
|
136 // Sort values, so they can be tested/scripted against consistently. |
|
137 sort.Strings(unset) |
|
138 return fmt.Errorf("if any flags in the group [%v] are set they must all be set; missing %v", flagList, unset) |
|
139 } |
|
140 |
|
141 return nil |
|
142 } |
|
143 |
|
144 func validateExclusiveFlagGroups(data map[string]map[string]bool) error { |
|
145 keys := sortedKeys(data) |
|
146 for _, flagList := range keys { |
|
147 flagnameAndStatus := data[flagList] |
|
148 var set []string |
|
149 for flagname, isSet := range flagnameAndStatus { |
|
150 if isSet { |
|
151 set = append(set, flagname) |
|
152 } |
|
153 } |
|
154 if len(set) == 0 || len(set) == 1 { |
|
155 continue |
|
156 } |
|
157 |
|
158 // Sort values, so they can be tested/scripted against consistently. |
|
159 sort.Strings(set) |
|
160 return fmt.Errorf("if any flags in the group [%v] are set none of the others can be; %v were all set", flagList, set) |
|
161 } |
|
162 return nil |
|
163 } |
|
164 |
|
// sortedKeys returns the keys of m in ascending lexical order, giving callers a
// deterministic iteration order over the map. A nil or empty map yields an
// empty, non-nil slice.
func sortedKeys(m map[string]map[string]bool) []string {
	keys := make([]string, 0, len(m))
	for key := range m {
		keys = append(keys, key)
	}
	sort.Strings(keys)
	return keys
}
|
175 |
|
176 // enforceFlagGroupsForCompletion will do the following: |
|
177 // - when a flag in a group is present, other flags in the group will be marked required |
|
178 // - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden |
|
179 // This allows the standard completion logic to behave appropriately for flag groups |
|
180 func (c *Command) enforceFlagGroupsForCompletion() { |
|
181 if c.DisableFlagParsing { |
|
182 return |
|
183 } |
|
184 |
|
185 flags := c.Flags() |
|
186 groupStatus := map[string]map[string]bool{} |
|
187 mutuallyExclusiveGroupStatus := map[string]map[string]bool{} |
|
188 c.Flags().VisitAll(func(pflag *flag.Flag) { |
|
189 processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus) |
|
190 processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus) |
|
191 }) |
|
192 |
|
193 // If a flag that is part of a group is present, we make all the other flags |
|
194 // of that group required so that the shell completion suggests them automatically |
|
195 for flagList, flagnameAndStatus := range groupStatus { |
|
196 for _, isSet := range flagnameAndStatus { |
|
197 if isSet { |
|
198 // One of the flags of the group is set, mark the other ones as required |
|
199 for _, fName := range strings.Split(flagList, " ") { |
|
200 _ = c.MarkFlagRequired(fName) |
|
201 } |
|
202 } |
|
203 } |
|
204 } |
|
205 |
|
206 // If a flag that is mutually exclusive to others is present, we hide the other |
|
207 // flags of that group so the shell completion does not suggest them |
|
208 for flagList, flagnameAndStatus := range mutuallyExclusiveGroupStatus { |
|
209 for flagName, isSet := range flagnameAndStatus { |
|
210 if isSet { |
|
211 // One of the flags of the mutually exclusive group is set, mark the other ones as hidden |
|
212 // Don't mark the flag that is already set as hidden because it may be an |
|
213 // array or slice flag and therefore must continue being suggested |
|
214 for _, fName := range strings.Split(flagList, " ") { |
|
215 if fName != flagName { |
|
216 flag := c.Flags().Lookup(fName) |
|
217 flag.Hidden = true |
|
218 } |
|
219 } |
|
220 } |
|
221 } |
|
222 } |
|
223 } |