aboutsummaryrefslogblamecommitdiffstats
path: root/vendor/github.com/StackExchange/wmi/wmi.go
blob: a951b1258b04d6645df7bd66569cc48e9cf36183 (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486





































































































































































































































































































































































































































































































                                                                                                                    
// +build windows

/*
Package wmi provides a WQL interface for WMI on Windows.

Example code to print names of running processes:

    type Win32_Process struct {
        Name string
    }

    func main() {
        var dst []Win32_Process
        q := wmi.CreateQuery(&dst, "")
        err := wmi.Query(q, &dst)
        if err != nil {
            log.Fatal(err)
        }
        for i, v := range dst {
            println(i, v.Name)
        }
    }

*/
package wmi

import (
    "bytes"
    "errors"
    "fmt"
    "log"
    "os"
    "reflect"
    "runtime"
    "strconv"
    "strings"
    "sync"
    "time"

    "github.com/go-ole/go-ole"
    "github.com/go-ole/go-ole/oleutil"
)

// l is a package-level logger that writes to standard output, prefixing each
// line with the local date and time and no additional prefix string.
var l = log.New(os.Stdout, "", log.Ldate|log.Ltime)

var (
    // ErrInvalidEntityType is returned by Query when dst is not a non-nil
    // pointer to a slice of structs or of struct pointers.
    ErrInvalidEntityType = errors.New("wmi: invalid entity type")

    // ErrNilCreateObject is returned when CreateObject reports no error but
    // still yields a nil object.
    ErrNilCreateObject = errors.New("wmi: create object returned nil")

    // lock is held by Client.Query for the whole duration of each query, so
    // such queries never run concurrently.
    lock sync.Mutex
)

// S_FALSE is the HRESULT that CoInitializeEx returns when COM had already been
// initialized on the calling thread. Query treats it as success.
const S_FALSE = 0x1

// QueryNamespace runs query against the given WMI namespace on the local
// machine, loading the results into dst. It is shorthand for calling Query
// with a nil server argument followed by namespace.
func QueryNamespace(query string, dst interface{}, namespace string) error {
    var localServer interface{} // nil server argument: the local machine
    return Query(query, dst, localServer, namespace)
}

// Query runs the WQL query through DefaultClient and loads the results into
// dst.
//
// dst must have type *[]S or *[]*S, for some struct type S; see Client.Query
// for the supported field types. By default, the local machine and default
// namespace are used. These can be changed using connectServerArgs. See
// http://msdn.microsoft.com/en-us/library/aa393720.aspx for details.
//
// When DefaultClient.SWbemServicesClient is set, the query is delegated to
// that reusable SWbemServices object; otherwise DefaultClient.Query handles
// it and connects afresh for this call.
func Query(query string, dst interface{}, connectServerArgs ...interface{}) error {
    if svc := DefaultClient.SWbemServicesClient; svc != nil {
        return svc.Query(query, dst, connectServerArgs...)
    }
    return DefaultClient.Query(query, dst, connectServerArgs...)
}

// A Client is a WMI query client.
//
// Its zero value (DefaultClient) is a usable client.
type Client struct {
    // NonePtrZero specifies whether nil values for fields that are not
    // pointers should be returned as the field type's zero value.
    //
    // Setting this to true allows structs without pointer fields to be used
    // without the risk of failure should WMI return a nil value.
    NonePtrZero bool

    // PtrNil specifies whether nil values for pointer fields should be
    // returned as nil.
    //
    // Setting this to true sets pointer fields to nil where WMI returned nil;
    // otherwise the field is left pointing at a newly allocated zero value of
    // its element type.
    PtrNil bool

    // AllowMissingFields specifies that struct fields not present in the
    // query result should not result in an error.
    //
    // Setting this to true allows custom queries to be used with full
    // struct definitions instead of having to define multiple structs.
    // Missing pointer fields still point at a newly allocated zero value.
    AllowMissingFields bool

    // SWbemServicesClient is an optional SWbemServices object that can be
    // initialized and then reused across multiple queries. If it is nil, a
    // new temporary connection is made for each query.
    //
    // NOTE(review): within this file only the package-level Query function
    // consults this field (on DefaultClient); Client.Query itself always
    // connects on its own.
    SWbemServicesClient *SWbemServices
}

// DefaultClient is the Client used by the package-level Query and
// QueryNamespace functions. It starts out as a zero-valued Client.
var DefaultClient = new(Client)

// Query runs the WQL query and stores the resulting rows in dst.
//
// dst must have type *[]S or *[]*S, for some struct type S. Fields selected in
// the query must have the same name in dst. Supported field types are all
// signed and unsigned integers, float32, time.Time, string, bool, []string,
// []uint8, or a pointer to one of those. Fixed-size array types are not
// supported.
//
// Any elements already in *dst are discarded: the slice is replaced by a new
// one with capacity for the whole result set, and one element is appended
// per row.
//
// By default, the local machine and default namespace are used. These can be
// changed using connectServerArgs. See
// http://msdn.microsoft.com/en-us/library/aa393720.aspx for details.
//
// Each call holds the package-level lock and runs on a locked OS thread with
// COM initialized for the multithreaded apartment. If a row can only be
// partially loaded (an *ErrFieldMismatch), loading continues, every row is
// still appended, and the last such mismatch is returned. Any other error
// aborts the query.
func (c *Client) Query(query string, dst interface{}, connectServerArgs ...interface{}) error {
    dv := reflect.ValueOf(dst)
    if dv.Kind() != reflect.Ptr || dv.IsNil() {
        return ErrInvalidEntityType
    }
    dv = dv.Elem()
    mat, elemType := checkMultiArg(dv)
    if mat == multiArgTypeInvalid {
        return ErrInvalidEntityType
    }

    lock.Lock()
    defer lock.Unlock()
    // COM initialization is per OS thread, so keep this goroutine on one.
    runtime.LockOSThread()
    defer runtime.UnlockOSThread()

    if err := ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED); err != nil {
        // S_FALSE only means COM was already initialized on this thread.
        // Errors of any other dynamic type are returned as-is rather than
        // panicking on a failed type assertion.
        oleErr, ok := err.(*ole.OleError)
        if !ok {
            return err
        }
        if code := oleErr.Code(); code != ole.S_OK && code != S_FALSE {
            return err
        }
    }
    // Every successful CoInitializeEx, including one returning S_FALSE, must
    // be balanced by a CoUninitialize.
    defer ole.CoUninitialize()

    unknown, err := oleutil.CreateObject("WbemScripting.SWbemLocator")
    if err != nil {
        return err
    } else if unknown == nil {
        return ErrNilCreateObject
    }
    defer unknown.Release()

    wmi, err := unknown.QueryInterface(ole.IID_IDispatch)
    if err != nil {
        return err
    }
    defer wmi.Release()

    // service is a SWbemServices
    serviceRaw, err := oleutil.CallMethod(wmi, "ConnectServer", connectServerArgs...)
    if err != nil {
        return err
    }
    service := serviceRaw.ToIDispatch()
    defer serviceRaw.Clear()

    // result is a SWBemObjectSet
    resultRaw, err := oleutil.CallMethod(service, "ExecQuery", query)
    if err != nil {
        return err
    }
    result := resultRaw.ToIDispatch()
    defer resultRaw.Clear()

    count, err := oleInt64(result, "Count")
    if err != nil {
        return err
    }

    enumProperty, err := result.GetProperty("_NewEnum")
    if err != nil {
        return err
    }
    defer enumProperty.Clear()

    enum, err := enumProperty.ToIUnknown().IEnumVARIANT(ole.IID_IEnumVariant)
    if err != nil {
        return err
    }
    if enum == nil {
        return fmt.Errorf("can't get IEnumVARIANT, enum is nil")
    }
    defer enum.Release()

    // Replace *dst with an empty slice whose capacity is the result count.
    dv.Set(reflect.MakeSlice(dv.Type(), 0, int(count)))

    var errFieldMismatch error
    // The loop condition tests length before the body inspects err.
    // NOTE(review): IEnumVARIANT::Next returns S_FALSE at the end of the
    // enumeration, which go-ole presumably surfaces as a non-nil err with
    // length 0 — checking err first would turn normal termination into a
    // failure.
    for itemRaw, length, err := enum.Next(1); length > 0; itemRaw, length, err = enum.Next(1) {
        if err != nil {
            return err
        }

        err := func() error {
            // item is a SWbemObject, but really a Win32_Process
            item := itemRaw.ToIDispatch()
            defer item.Release()

            ev := reflect.New(elemType)
            if loadErr := c.loadEntity(ev.Interface(), item); loadErr != nil {
                if _, ok := loadErr.(*ErrFieldMismatch); !ok {
                    return loadErr
                }
                // We continue loading entities even in the face of field
                // mismatch errors. If we encounter any other error, that other
                // error is returned. Otherwise, an ErrFieldMismatch is returned.
                errFieldMismatch = loadErr
            }
            if mat != multiArgTypeStructPtr {
                ev = ev.Elem()
            }
            dv.Set(reflect.Append(dv, ev))
            return nil
        }()
        if err != nil {
            return err
        }
    }
    return errFieldMismatch
}

// ErrFieldMismatch reports that a WMI property could not be stored in the
// corresponding destination struct field. This happens when the property is
// absent from the query result, when the field cannot be set (for example,
// because it is unexported), or when the property's value cannot be converted
// to the field's type.
type ErrFieldMismatch struct {
    // StructType is the destination type named in the error message.
    StructType reflect.Type
    // FieldName is the name of the struct field that could not be loaded.
    FieldName  string
    // Reason is a short human-readable explanation of the mismatch.
    Reason     string
}

// Error implements the error interface.
func (e *ErrFieldMismatch) Error() string {
    const format = "wmi: cannot load field %q into a %q: %s"
    return fmt.Sprintf(format, e.FieldName, e.StructType, e.Reason)
}

// timeType is the reflect.Type of time.Time. It lets loadEntity recognise
// struct fields that should receive parsed WMI datetime strings.
var timeType = reflect.TypeOf((*time.Time)(nil)).Elem()

// loadEntity loads a SWbemObject into a struct pointer.
//
// dst must be a pointer to a struct. For each field of that struct, the
// property with the same name is read from src and converted into the field.
// Pointer fields are allocated (set to point at a new zero value) before their
// property is read. They are reset to nil only when WMI returns a nil value and
// c.PtrNil is set.
//
// A property that cannot be read from src records an *ErrFieldMismatch (unless
// c.AllowMissingFields is set) and loading continues with the next field. The
// last such mismatch is returned once every field has been processed. Any
// other failure stops loading and is returned immediately. Such failures
// include an unsettable field, an incompatible value type, or an unparsable
// numeric or datetime string.
func (c *Client) loadEntity(dst interface{}, src *ole.IDispatch) (errFieldMismatch error) {
    v := reflect.ValueOf(dst).Elem()
    structType := v.Type()
    for i := 0; i < v.NumField(); i++ {
        of := v.Field(i)
        n := structType.Field(i).Name
        // Check settability on the field itself, before allocating pointer
        // fields: calling Set on an unexported pointer field would panic
        // instead of reporting this error.
        if !of.CanSet() {
            return &ErrFieldMismatch{
                StructType: structType,
                FieldName:  n,
                Reason:     "CanSet() is false",
            }
        }
        f := of
        isPtr := f.Kind() == reflect.Ptr
        if isPtr {
            f.Set(reflect.New(f.Type().Elem()))
            f = f.Elem()
        }
        prop, err := oleutil.GetProperty(src, n)
        if err != nil {
            if !c.AllowMissingFields {
                errFieldMismatch = &ErrFieldMismatch{
                    StructType: structType,
                    FieldName:  n,
                    Reason:     "no such struct field",
                }
            }
            continue
        }
        // Release each property VARIANT as soon as it has been converted
        // instead of deferring one Clear per field until the function returns.
        err = c.loadField(f, of, isPtr, n, structType, prop)
        prop.Clear()
        if err != nil {
            return err
        }
    }
    return errFieldMismatch
}

// loadField converts the WMI property value held in prop and stores it in f.
//
// of is the struct field itself. f is the value to assign: the same as of for
// non-pointer fields, or the freshly allocated pointee when isPtr is true.
// name and structType are used only to describe failures in *ErrFieldMismatch.
func (c *Client) loadField(f, of reflect.Value, isPtr bool, name string, structType reflect.Type, prop *ole.VARIANT) error {
    mismatch := func(reason string) error {
        return &ErrFieldMismatch{
            StructType: structType,
            FieldName:  name,
            Reason:     reason,
        }
    }

    switch val := prop.Value().(type) {
    case int8, int16, int32, int64, int:
        iv := reflect.ValueOf(val).Int()
        switch f.Kind() {
        case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
            f.SetInt(iv)
        case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
            f.SetUint(uint64(iv))
        default:
            return mismatch("not an integer class")
        }
    case uint8, uint16, uint32, uint64:
        uv := reflect.ValueOf(val).Uint()
        switch f.Kind() {
        case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
            f.SetInt(int64(uv))
        case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
            f.SetUint(uv)
        default:
            return mismatch("not an integer class")
        }
    case string:
        switch f.Kind() {
        case reflect.String:
            f.SetString(val)
        case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
            iv, err := strconv.ParseInt(val, 10, 64)
            if err != nil {
                return err
            }
            f.SetInt(iv)
        case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
            uv, err := strconv.ParseUint(val, 10, 64)
            if err != nil {
                return err
            }
            f.SetUint(uv)
        case reflect.Struct:
            // Only time.Time is supported; other struct types are left as-is.
            if f.Type() == timeType {
                // A 25-character WMI datetime ends in a sign plus a
                // three-digit UTC offset in minutes (e.g. "+060"). Rewrite
                // that as hhmm so the "-0700" layout below can parse it.
                if len(val) == 25 {
                    mins, err := strconv.Atoi(val[22:])
                    if err != nil {
                        return err
                    }
                    val = val[:22] + fmt.Sprintf("%02d%02d", mins/60, mins%60)
                }
                t, err := time.Parse("20060102150405.000000-0700", val)
                if err != nil {
                    return err
                }
                f.Set(reflect.ValueOf(t))
            }
        }
    case bool:
        if f.Kind() != reflect.Bool {
            return mismatch("not a bool")
        }
        f.SetBool(val)
    case float32:
        switch f.Kind() {
        case reflect.Float32, reflect.Float64:
            f.SetFloat(float64(val))
        default:
            return mismatch("not a Float32")
        }
    case float64:
        // Only widen-safe targets: a float64 is not narrowed into float32.
        if f.Kind() != reflect.Float64 {
            return mismatch("not a Float64")
        }
        f.SetFloat(val)
    default:
        if f.Kind() == reflect.Slice {
            switch f.Type().Elem().Kind() {
            case reflect.String:
                if safeArray := prop.ToArray(); safeArray != nil {
                    arr := safeArray.ToValueArray()
                    fArr := reflect.MakeSlice(f.Type(), len(arr), len(arr))
                    for i, elem := range arr {
                        fArr.Index(i).SetString(elem.(string))
                    }
                    f.Set(fArr)
                }
            case reflect.Uint8:
                if safeArray := prop.ToArray(); safeArray != nil {
                    arr := safeArray.ToValueArray()
                    fArr := reflect.MakeSlice(f.Type(), len(arr), len(arr))
                    for i, elem := range arr {
                        fArr.Index(i).SetUint(reflect.ValueOf(elem).Uint())
                    }
                    f.Set(fArr)
                }
            default:
                return mismatch(fmt.Sprintf("unsupported slice type (%T)", val))
            }
            return nil
        }
        // A nil WMI value is tolerated for pointer fields, and for
        // non-pointer fields when NonePtrZero is set.
        if reflect.TypeOf(val) == nil && (isPtr || c.NonePtrZero) {
            if (isPtr && c.PtrNil) || (!isPtr && c.NonePtrZero) {
                of.Set(reflect.Zero(of.Type()))
            }
            return nil
        }
        return mismatch(fmt.Sprintf("unsupported type (%T)", val))
    }
    return nil
}

// multiArgType classifies the slice that Query decodes rows into.
type multiArgType int

const (
    // multiArgTypeInvalid: the value is not a slice of structs or struct pointers.
    multiArgTypeInvalid multiArgType = 0
    // multiArgTypeStruct: the value is a []S for some struct type S.
    multiArgTypeStruct multiArgType = 1
    // multiArgTypeStructPtr: the value is a []*S for some struct type S.
    multiArgTypeStructPtr multiArgType = 2
)

// checkMultiArg reports whether v is a []S or []*S for some struct type S.
//
// It returns the matching category together with the reflect.Type of S, or
// (multiArgTypeInvalid, nil) when v has any other type.
func checkMultiArg(v reflect.Value) (m multiArgType, elemType reflect.Type) {
    if v.Kind() != reflect.Slice {
        return multiArgTypeInvalid, nil
    }
    et := v.Type().Elem()
    if et.Kind() == reflect.Struct {
        return multiArgTypeStruct, et
    }
    if et.Kind() == reflect.Ptr && et.Elem().Kind() == reflect.Struct {
        return multiArgTypeStructPtr, et.Elem()
    }
    return multiArgTypeInvalid, nil
}

// oleInt64 reads the named property of item and returns the raw payload of
// the resulting VARIANT converted to int64. The VARIANT type is not checked,
// so this assumes an integer-valued property. The VARIANT is cleared before
// returning.
func oleInt64(item *ole.IDispatch, prop string) (int64, error) {
    variant, err := oleutil.GetProperty(item, prop)
    if err != nil {
        return 0, err
    }
    defer variant.Clear()
    return int64(variant.Val), nil
}

// CreateQuery returns a WQL query string that selects every field of the
// struct type described by src, from the WMI class named after that type.
//
// src may be a struct, a pointer to a struct, a slice of structs or of struct
// pointers, or a pointer to such a slice (for example the same &dst later
// passed to Query). Typed nil pointers are accepted, since only the type is
// inspected. If no struct type can be derived from src (including a nil src),
// "" is returned.
//
// where is an optional string appended after a single space, to be used with
// WHERE clauses. In such a case, the "WHERE" string should appear at the
// beginning. An empty where leaves a trailing space.
func CreateQuery(src interface{}, where string) string {
    t := reflect.TypeOf(src)
    if t == nil {
        return ""
    }
    // Inspect types rather than values so nil pointers do not panic.
    if t.Kind() == reflect.Ptr {
        t = t.Elem()
    }
    if t.Kind() == reflect.Slice {
        t = t.Elem()
        // Query accepts *[]*S, so derive S from []*S as well.
        if t.Kind() == reflect.Ptr {
            t = t.Elem()
        }
    }
    if t.Kind() != reflect.Struct {
        return ""
    }

    fields := make([]string, 0, t.NumField())
    for i := 0; i < t.NumField(); i++ {
        fields = append(fields, t.Field(i).Name)
    }

    var b bytes.Buffer
    b.WriteString("SELECT ")
    b.WriteString(strings.Join(fields, ", "))
    b.WriteString(" FROM ")
    b.WriteString(t.Name())
    b.WriteString(" " + where)
    return b.String()
}