#include <process.h>

#include <iostream>
using namespace std;

#define _WIN32_WINNT 0x400
#include <winsock2.h>
#pragma comment(lib,"ws2_32.lib")

class CMySocket
{
protected:
  void Release()
  {
    if (InterlockedDecrement(&m_nInUse))
      return;
    SOCKET h = m_sock;
    if (h!=(SOCKET)INVALID_HANDLE_VALUE && h==InterlockedCompareExchange((LONG*)&m_sock,(LONG)INVALID_HANDLE_VALUE,h))
    {
      cout << "closesocket()" << endl;
      closesocket(h);
    }
  }
public:
  SOCKET m_sock;
  LONG m_nInUse;
  LONG m_bClosed;

  class CKeeper
  {
    CMySocket& m_socket;
    bool m_bIncremented;
  public:
    CKeeper(CMySocket& socket);
    ~CKeeper();

    void Release()
    {
      if (m_bIncremented)
      {
        m_bIncremented = false;
        m_socket.Release();
      }
    }
  };

  CMySocket()
    :	m_nInUse(0),
      m_bClosed(1)		  
  {}

  void Connect(const char *name)
  {
    m_sock = socket(PF_INET,SOCK_STREAM,0);
    m_nInUse = 1;
    m_bClosed = 0;
    SOCKADDR_IN sockAddr;
    sockAddr.sin_family = AF_INET;
    hostent *he = gethostbyname(name);
    sockAddr.sin_addr.s_addr = *(long*)gethostbyname(name)->h_addr;
    sockAddr.sin_port = htons(80);    
    connect(m_sock,(SOCKADDR*)&sockAddr,sizeof sockAddr);
  }

  void Recv(char *buf, int size)
  {
    CKeeper hp(*this);
    int r = recv(m_sock,buf,size,0);
    if (r < 0)
      throw WSAGetLastError();
  }  
  
  void Close()
  {
    if (!InterlockedCompareExchange(&m_bClosed,1L,0L))
      Release();
  }

friend class CKeeper;
};

// Per-thread pointer to the CKeeper currently active on this thread: set by
// the CKeeper constructor, cleared by its destructor, read by APCProc.
__declspec(thread) CMySocket::CKeeper *t_curSocketKeeper = NULL;

/// Publishes this keeper in t_curSocketKeeper and, if the socket is still
/// open, takes one reference on it (recorded in m_bIncremented).
CMySocket::CKeeper::CKeeper(CMySocket& socket)
  :	m_socket(socket),
    m_bIncremented(false)
{
  // Make this keeper reachable from APCProc on the current thread.
  t_curSocketKeeper = this;

  if (!m_socket.m_bClosed)
  {
    InterlockedIncrement(&m_socket.m_nInUse);
    // Re-check after incrementing: Close() may have run in between, in which
    // case the reference is handed straight back.
    if (m_socket.m_bClosed)
      m_socket.Release();
    else
      m_bIncremented = true;
  }
}

/// Removes this keeper from the thread-local slot, then drops its socket
/// reference unless it was already released (e.g. by APCProc).
CMySocket::CKeeper::~CKeeper()
{
  t_curSocketKeeper = NULL;
  this->Release();
}


// The single socket shared by main() and both SubThread workers.
CMySocket g_mySocket;

/// Worker thread started by main(): waits one second, then blocks in
/// g_mySocket.Recv() until it returns or throws. Always returns 0.
/// @param pParam unused.
unsigned __stdcall SubThread(LPVOID pParam)
{
  (void)pParam;
  Sleep(1000);
  char buf[100];
  try
  {
    g_mySocket.Recv(buf,sizeof buf);
  }
  catch (int err)
  {
    // Recv() throws WSAGetLastError(); presumably expected here once main()
    // has closed the socket out from under the blocked recv().
    cout << "recv() failed: " << err << endl;
  }
  catch (...)
  {
    // Never let an exception escape a thread entry point.
  }
  return 0;
}


/// APC queued by main() to each worker thread. If the thread currently has
/// an active CKeeper, drops that keeper's socket reference. Once every
/// reference is gone, CMySocket::Release() closes the socket, which is how
/// this program ends the workers' blocked recv() calls.
/// @param dwParam unused.
VOID CALLBACK APCProc(ULONG_PTR dwParam)
{
  cout << "APC" << endl;
  CMySocket::CKeeper *keeper = t_curSocketKeeper;
  if (!keeper)
    return;
  keeper->Release();
}

void main()
{
  WSAData wsaData;
  WSAStartup(2,&wsaData);

  g_mySocket.Connect("www.microsoft.com");

  HANDLE threads[2] = { (HANDLE)_beginthreadex(0,0,SubThread,0,0,0),
                        (HANDLE)_beginthreadex(0,0,SubThread,0,0,0) };
  Sleep(3000);
  g_mySocket.Close();
  QueueUserAPC(APCProc,threads[0],0);
  QueueUserAPC(APCProc,threads[1],0);
  cout << "Waiting..." << endl;
  WaitForMultipleObjects(2,threads,TRUE,INFINITE);
  cout << "Closing complete" << endl;
}


