| 634 | } |
| 635 | |
| 636 | func (s *Socket) Accept(flags int) (*Socket, syscall.Errno, error) { |
| 637 | if !s.FD.IncRef() { |
| 638 | return nil, unix.EBADF, nil |
| 639 | } |
| 640 | defer s.FD.DecRef() |
| 641 | |
| 642 | cur := s.Inode.state.Load() |
| 643 | switch cur.state { |
| 644 | case StatePassive, StateConnected, StateConnecting: |
| 645 | return nil, unix.EINVAL, nil |
| 646 | case StateListening: |
| 647 | if !cur.listening.active.Load() { |
| 648 | return nil, unix.EINVAL, nil // TODO: right errno? |
| 649 | } |
| 650 | case StateClosed: |
| 651 | return nil, unix.EBADF, nil |
| 652 | } |
| 653 | |
| 654 | ret, sa, err := unix.Accept4(s.FD.FD(), flags|unix.SOCK_CLOEXEC) |
| 655 | if err != nil { |
| 656 | var errno syscall.Errno |
| 657 | if !errors.As(err, &errno) { |
| 658 | return nil, 0, fmt.Errorf("failed to interpret accept error as errno: %w", err) |
| 659 | } |
| 660 | // If accept(2) fails, Linux does not put the socket in an error state. |
| 661 | return nil, errno, nil |
| 662 | } |
| 663 | |
| 664 | var addr netip.AddrPort |
| 665 | switch sa := sa.(type) { |
| 666 | case *unix.SockaddrInet4: |
| 667 | addr = netip.AddrPortFrom(netip.AddrFrom4(sa.Addr), uint16(sa.Port)) |
| 668 | case *unix.SockaddrInet6: |
| 669 | addr = netip.AddrPortFrom(netip.AddrFrom16(sa.Addr), uint16(sa.Port)) |
| 670 | } |
| 671 | if addr.Addr().Is4In6() { |
| 672 | addr = netip.AddrPortFrom(netip.AddrFrom4(addr.Addr().As4()), addr.Port()) |
| 673 | } |
| 674 | |
| 675 | ch := make(chan *proxy, 1) |
| 676 | if found, loaded := cur.listening.backlog.LoadOrStore(addr, ch); loaded { |
| 677 | ch = found.(chan *proxy) |
| 678 | cur.listening.backlog.Delete(addr) |
| 679 | } |
| 680 | |
| 681 | p := <-ch |
| 682 | if p.process.LocalAddr().String() != addr.String() { |
| 683 | panic(fmt.Sprintf("dialed process-side local does not match accepted connection: %s != %s", p.process.LocalAddr(), addr)) |
| 684 | } |
| 685 | slog.Debug("accepter dequeued accepted connection", "sock", s, "addr", addr) |
| 686 | |
| 687 | var stat unix.Stat_t |
| 688 | if err := unix.Fstat(ret, &stat); err != nil { |
| 689 | unix.Close(ret) |
| 690 | return nil, 0, fmt.Errorf("stat after channel receive: %w", err) |
| 691 | } |
| 692 | |
| 693 | fd := fd.NewFD(ret) |