Optimize performs the optimization of the parameters.
()
| 35 | |
| 36 | // Optimize performs the optimization of the parameters. |
| 37 | func (o *Optimizer) Optimize() error { |
| 38 | var wg sync.WaitGroup |
| 39 | guard := make(chan struct{}, runtime.NumCPU()*2) |
| 40 | errCh := make(chan error, 1) |
| 41 | |
| 42 | ctx, cancel := context.WithCancel(context.Background()) |
| 43 | defer cancel() |
| 44 | |
| 45 | for param := range o.parameters(ctx) { |
| 46 | select { |
| 47 | case err := <-errCh: |
| 48 | cancel() // As soon as an error occurs, stop the iteration over parameters |
| 49 | wg.Wait() // Wait for running goroutines to finish |
| 50 | return err |
| 51 | default: |
| 52 | param := param |
| 53 | wg.Add(1) |
| 54 | guard <- struct{}{} |
| 55 | go func() { |
| 56 | defer wg.Done() |
| 57 | defer func() { <-guard }() |
| 58 | if !param.HasGrad() { |
| 59 | return |
| 60 | } |
| 61 | if err := o.strategy.OptimizeParams(param); err != nil { |
| 62 | select { |
| 63 | case errCh <- err: |
| 64 | default: |
| 65 | } |
| 66 | } |
| 67 | }() |
| 68 | } |
| 69 | } |
| 70 | |
| 71 | close(errCh) |
| 72 | |
| 73 | if err, ok := <-errCh; ok { |
| 74 | return err |
| 75 | } |
| 76 | |
| 77 | return nil |
| 78 | } |
No test coverage detected.