面试题答案
一键面试Go 语言中 WaitGroup 的底层实现原理
- 数据结构:
WaitGroup
的核心数据结构定义在src/sync/waitgroup.go
中。其主要结构如下:
type WaitGroup struct { noCopy noCopy state1 [3]uint32 }
noCopy
是一个空结构体,用于防止WaitGroup
被拷贝,因为WaitGroup
内部状态是不可拷贝的。state1
数组包含了WaitGroup
的状态信息,前两个uint32
用于存储计数器的值,第三个uint32
用于存储等待队列的指针等信息。在 64 位系统上,state1
会被视为一个 12 字节的连续内存区域,其中高 32 位和低 32 位共同组成一个 64 位的计数器值,剩下的 32 位用于存储其他状态信息(如等待队列相关信息)。
- 核心方法的实现逻辑:
Add(delta int)
方法:Add
方法用于增加等待组的计数器。它会将传入的delta
值加到计数器上。如果计数器的值小于 0,会直接抛出一个恐慌(panic
),因为计数器不能为负。- 示例代码:
func (wg *WaitGroup) Add(delta int) { statep, semap := wg.state() state := atomic.AddUint64(statep, uint64(delta)<<32) v := int32(state >> 32) if v < 0 { panic("sync: negative WaitGroup counter") } if delta > 0 && v == int32(delta) { return } for ; state.Load() != 0; state.Store(state.Load() - 1) { runtime_Semacquire(semap) } }
Done()
方法:Done
方法实际上是Add(-1)
的快捷方式,它将等待组的计数器减 1。- 示例代码:
func (wg *WaitGroup) Done() { wg.Add(-1) }
Wait()
方法:Wait
方法会阻塞当前 goroutine,直到等待组的计数器变为 0。它通过检查计数器的值,如果计数器不为 0,则将当前 goroutine 放入等待队列,并挂起该 goroutine。当计数器变为 0 时,会唤醒等待队列中的所有 goroutine。- 示例代码:
func (wg *WaitGroup) Wait() { statep, semap := wg.state() for { state := atomic.LoadUint64(statep) v := int32(state >> 32) if v == 0 { return } atomic.AddUint64(statep, uint64(^uint32(0))<<32) runtime_Semacquire(semap) atomic.AddUint64(statep, uint64(1)<<32) } }
扩展 WaitGroup 支持任务优先级的设计思路
-
设计思路:
- 我们可以创建一个新的结构体,例如
PriorityWaitGroup
,它内部包含一个WaitGroup
以及一个用于存储任务优先级信息的数据结构,比如一个优先队列(可以使用堆来实现)。 - 当调用类似
AddWithPriority
方法时,将任务和其优先级信息加入到优先队列中。 Done
方法除了减少WaitGroup
的计数器外,还需要从优先队列中移除对应的任务。Wait
方法在等待WaitGroup
计数器为 0 的同时,要确保优先队列中高优先级任务都已完成。
- 我们可以创建一个新的结构体,例如
-
关键代码片段:
- 首先定义优先队列相关的数据结构和方法:
type Task struct { priority int // 这里可以添加任务相关的其他信息 } type PriorityQueue []Task func (pq PriorityQueue) Len() int { return len(pq) } func (pq PriorityQueue) Less(i, j int) bool { return pq[i].priority > pq[j].priority } func (pq PriorityQueue) Swap(i, j int) { pq[i], pq[j] = pq[j], pq[i] } func (pq *PriorityQueue) Push(x interface{}) { *pq = append(*pq, x.(Task)) } func (pq *PriorityQueue) Pop() interface{} { old := *pq n := len(old) item := old[n - 1] *pq = old[0 : n - 1] return item }
- 然后定义
PriorityWaitGroup
及其方法:
type PriorityWaitGroup struct { wg sync.WaitGroup pq PriorityQueue taskMap map[*Task]struct{} } func (pwg *PriorityWaitGroup) AddWithPriority(priority int) { task := Task{priority: priority} heap.Init(&pwg.pq) heap.Push(&pwg.pq, task) if pwg.taskMap == nil { pwg.taskMap = make(map[*Task]struct{}) } pwg.taskMap[&task] = struct{}{} pwg.wg.Add(1) } func (pwg *PriorityWaitGroup) Done() { // 找到当前完成的任务并从优先队列和任务映射中移除 for task := range pwg.taskMap { if task.priority == 0 { delete(pwg.taskMap, task) break } } pwg.wg.Done() } func (pwg *PriorityWaitGroup) Wait() { for { heap.Init(&pwg.pq) if len(pwg.pq) == 0 || pwg.pq[0].priority == 0 { break } time.Sleep(time.Millisecond) } pwg.wg.Wait() }
以上代码只是一个简单的示例,实际应用中可能需要根据具体需求进行更多的完善和优化,例如更好的错误处理、任务标识等。