#include "./TcpStream.hpp"

using std::byte;
using std::pair;
using std::span;
using std::unexpected;
using std::chrono::duration_cast;
using std::chrono::floor;
using std::chrono::microseconds;
using std::chrono::seconds;

#include <unistd.h>

namespace nixnet {

// Opens a stream socket in the address family of `addr` and connects it.
//
// @param addr  Remote endpoint to connect to.
// @return      A TcpStream owning the connected descriptor, or the errno value
//              reported by socket(2) / connect(2) on failure. On a connect
//              failure the freshly created descriptor is closed before
//              returning, so nothing leaks.
result<TcpStream> TcpStream::connect(const SocketAddr& addr) {
  TcpStream result_sock;
  int fd = ::socket(addr.data()->sa_family, SOCK_STREAM, 0);
  if (fd < 0) {
    return unexpected(errno_t(errno));
  }
  int res = ::connect(fd, addr.data(), addr.data_size());
  if (res != 0) {
    // Capture errno *before* close(): close() is allowed to overwrite errno,
    // which would otherwise make us report the wrong failure reason.
    const int connect_errno = errno;
    ::close(fd);
    return unexpected(errno_t(connect_errno));
  }
  result_sock.m_fd = fd;
  return result_sock;
}

// Move constructor: takes over the descriptor held by `other` and leaves
// `other` holding -1, so only this object will close it.
TcpStream::TcpStream(TcpStream&& other) : m_fd(-1) {
  const int stolen = other.m_fd;
  other.m_fd = -1;
  m_fd = stolen;
}

// Move assignment: releases the descriptor currently owned (if any), then
// takes over `other`'s descriptor and leaves `other` holding -1.
// Self-assignment is a no-op.
TcpStream& TcpStream::operator=(TcpStream&& other) {
  if (this == &other) {
    return *this;
  }
  const int incoming = other.m_fd;
  other.m_fd = -1;
  if (m_fd >= 0) {
    ::close(m_fd);
  }
  m_fd = incoming;
  return *this;
}

// Destructor: closes the owned descriptor. A negative value means nothing is
// owned (e.g. the object was moved from or explicitly closed).
TcpStream::~TcpStream() {
  if (m_fd < 0) {
    return;
  }
  ::close(m_fd);
}

// Closes the owned descriptor early and marks the stream as empty (-1).
// Calling it again, or on a moved-from stream, does nothing. The return
// value of close(2) is not inspected.
void TcpStream::close() {
  if (m_fd < 0) {
    return;
  }
  ::close(m_fd);
  m_fd = -1;
}

// Writes up to `bytes.size()` bytes to the socket with a single send(2) call.
//
// @param bytes  Data to transmit.
// @return       Number of bytes actually accepted by the kernel (may be fewer
//               than bytes.size()), or the errno reported by send(2).
result<size_t> TcpStream::send(const std::span<const std::byte> bytes) const {
  // By default, writing to a connection the peer has closed raises SIGPIPE,
  // which terminates the process before we can return an error. Where the
  // platform supports it, suppress the signal so the failure surfaces as an
  // EPIPE error through the returned result instead.
#ifdef MSG_NOSIGNAL
  constexpr int flags = MSG_NOSIGNAL;
#else
  constexpr int flags = 0;
#endif
  ssize_t res = ::send(m_fd, bytes.data(), bytes.size(), flags);
  if (res < 0) {
    return unexpected(errno_t(errno));
  }
  return static_cast<size_t>(res);
}

// Reads up to `bytes.size()` bytes into `bytes` with a single recv(2) call.
// Returns the number of bytes stored (0 is passed through unchanged), or the
// errno reported by recv(2).
result<size_t> TcpStream::recv(std::span<std::byte> bytes) const {
  const ssize_t received = ::recv(m_fd, bytes.data(), bytes.size(), 0);
  if (received >= 0) {
    return static_cast<size_t>(received);
  }
  return unexpected(errno_t(errno));
}

// Converts a microsecond duration into a timeval split into whole seconds
// plus the remaining microseconds. Negative durations are clamped to a
// zero timeval. (Despite the name, the input unit is microseconds.)
static timeval ms_to_timeval(microseconds time) {
  timeval tv{};
  if (time.count() >= 0) {
    const seconds whole = floor<seconds>(time);
    tv.tv_sec = whole.count();
    tv.tv_usec = (time - whole).count();
  }
  return tv;
}

// Sets the receive timeout (SO_RCVTIMEO) on the socket. Negative durations
// are converted to a zero timeval by ms_to_timeval. Returns
// errno_t::SUCCESS, or the errno reported by setsockopt(2).
errno_t TcpStream::set_read_timeout(microseconds time) {
  const timeval timeout = ms_to_timeval(time);
  if (::setsockopt(m_fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof timeout) == 0) {
    return errno_t::SUCCESS;
  }
  return errno_t(errno);
}

// Sets the send timeout (SO_SNDTIMEO) on the socket. Negative durations are
// converted to a zero timeval by ms_to_timeval. Returns errno_t::SUCCESS,
// or the errno reported by setsockopt(2).
errno_t TcpStream::set_write_timeout(microseconds time) {
  const timeval timeout = ms_to_timeval(time);
  if (::setsockopt(m_fd, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof timeout) == 0) {
    return errno_t::SUCCESS;
  }
  return errno_t(errno);
}

// Converts a timeval into a single microsecond duration
// (tv_sec * 1'000'000 + tv_usec). Despite the name, the result is in
// microseconds.
static microseconds timeval_to_ms(timeval tv) {
  const microseconds whole = duration_cast<microseconds>(seconds(tv.tv_sec));
  const microseconds fraction(tv.tv_usec);
  return whole + fraction;
}

// Queries the current receive timeout (SO_RCVTIMEO) and returns it in
// microseconds, or the errno reported by getsockopt(2).
result<microseconds> TcpStream::read_timeout() const {
  timeval current;
  socklen_t len = sizeof current;
  if (::getsockopt(m_fd, SOL_SOCKET, SO_RCVTIMEO, &current, &len) != 0) {
    return unexpected(errno_t(errno));
  }
  return timeval_to_ms(current);
}

// Queries the current send timeout (SO_SNDTIMEO) and returns it in
// microseconds, or the errno reported by getsockopt(2).
result<microseconds> TcpStream::write_timeout() const {
  timeval current;
  socklen_t len = sizeof current;
  if (::getsockopt(m_fd, SOL_SOCKET, SO_SNDTIMEO, &current, &len) != 0) {
    return unexpected(errno_t(errno));
  }
  return timeval_to_ms(current);
}

int TcpStream::raw_fd() {
  return m_fd;
}

// Returns the local address the socket is bound to (getsockname(2)), or the
// errno reported by getsockname(2).
result<SocketAddr> TcpStream::local_addr() const {
  // Zero-fill the storage: from_csocket is not given the length the kernel
  // wrote, so any bytes beyond it must not be indeterminate.
  sockaddr_storage ss{};
  socklen_t ss_size = sizeof(ss);
  if (::getsockname(m_fd, reinterpret_cast<sockaddr*>(&ss), &ss_size) < 0) {
    return unexpected(errno_t(errno));
  }

  return SocketAddr::from_csocket(reinterpret_cast<sockaddr*>(&ss));
}

// Returns the address of the connected peer (getpeername(2)), or the errno
// reported by getpeername(2) (e.g. when the socket is not connected).
result<SocketAddr> TcpStream::peer_addr() const {
  // Zero-fill the storage: from_csocket is not given the length the kernel
  // wrote, so any bytes beyond it must not be indeterminate.
  sockaddr_storage ss{};
  socklen_t ss_size = sizeof(ss);
  if (::getpeername(m_fd, reinterpret_cast<sockaddr*>(&ss), &ss_size) < 0) {
    return unexpected(errno_t(errno));
  }

  return SocketAddr::from_csocket(reinterpret_cast<sockaddr*>(&ss));
}

};  // namespace nixnet
