|
2 | 2 |
|
3 | 3 | #[cfg(unix)] |
4 | 4 | use std::path::Path; |
5 | | -use std::{io, net::SocketAddr, sync::Arc, time::Duration}; |
| 5 | +use std::{ |
| 6 | + io::{self, ErrorKind}, |
| 7 | + net::SocketAddr, |
| 8 | + sync::Arc, |
| 9 | + time::Duration, |
| 10 | +}; |
6 | 11 |
|
7 | 12 | use byteorder::{BigEndian, ByteOrder}; |
8 | 13 | use bytes::{BufMut, BytesMut}; |
@@ -145,6 +150,92 @@ impl DnsClient { |
145 | 150 | } |
146 | 151 | } |
147 | 152 | } |
| 153 | + |
| 154 | + /// Check if the underlying connection is still connecting |
| 155 | + /// |
| 156 | + /// This will only work for TCP and UNIX Stream connections. |
| 157 | + /// UDP clients will always return `true`. |
| 158 | + pub async fn check_connected(&mut self) -> bool { |
| 159 | + #[cfg(unix)] |
| 160 | + fn check_peekable<F: std::os::unix::io::AsRawFd>(fd: &mut F) -> bool { |
| 161 | + let fd = fd.as_raw_fd(); |
| 162 | + |
| 163 | + unsafe { |
| 164 | + let mut peek_buf = [0u8; 1]; |
| 165 | + |
| 166 | + let ret = libc::recv( |
| 167 | + fd, |
| 168 | + peek_buf.as_mut_ptr() as *mut libc::c_void, |
| 169 | + peek_buf.len(), |
| 170 | + libc::MSG_PEEK | libc::MSG_DONTWAIT, |
| 171 | + ); |
| 172 | + |
| 173 | + if ret == 0 { |
| 174 | + // EOF, connection lost |
| 175 | + false |
| 176 | + } else if ret > 0 { |
| 177 | + // Data in buffer |
| 178 | + true |
| 179 | + } else { |
| 180 | + let err = io::Error::last_os_error(); |
| 181 | + if err.kind() == ErrorKind::WouldBlock { |
| 182 | + // EAGAIN, EWOULDBLOCK |
| 183 | + // Still connected. |
| 184 | + true |
| 185 | + } else { |
| 186 | + false |
| 187 | + } |
| 188 | + } |
| 189 | + } |
| 190 | + } |
| 191 | + |
| 192 | + #[cfg(windows)] |
| 193 | + fn check_peekable<F: std::os::windows::io::AsRawSocket>(s: &mut F) -> bool { |
| 194 | + use winapi::{ |
| 195 | + ctypes::{c_char, c_int}, |
| 196 | + um::winsock2::{recv, MSG_PEEK}, |
| 197 | + }; |
| 198 | + |
| 199 | + let sock = s.as_raw_socket(); |
| 200 | + |
| 201 | + unsafe { |
| 202 | + let mut peek_buf = [0u8; 1]; |
| 203 | + |
| 204 | + let ret = recv( |
| 205 | + sock, |
| 206 | + peek_buf.as_mut_ptr() as *mut c_char, |
| 207 | + peek_buf.len() as c_int, |
| 208 | + MSG_PEEK, |
| 209 | + ); |
| 210 | + |
| 211 | + if ret == 0 { |
| 212 | + // EOF, connection lost |
| 213 | + false |
| 214 | + } else if ret > 0 { |
| 215 | + // Data in buffer |
| 216 | + true |
| 217 | + } else { |
| 218 | + let err = io::Error::last_os_error(); |
| 219 | + if err.kind() == ErrorKind::WouldBlock { |
| 220 | + // I have to trust the `s` have already set to non-blocking mode |
| 221 | + // Becuase windows doesn't have MSG_DONTWAIT |
| 222 | + true |
| 223 | + } else { |
| 224 | + false |
| 225 | + } |
| 226 | + } |
| 227 | + } |
| 228 | + } |
| 229 | + |
| 230 | + match *self { |
| 231 | + DnsClient::TcpLocal { ref mut stream } => check_peekable(stream), |
| 232 | + DnsClient::UdpLocal { .. } => true, |
| 233 | + #[cfg(unix)] |
| 234 | + DnsClient::UnixStream { ref mut stream } => check_peekable(stream), |
| 235 | + DnsClient::TcpRemote { ref mut stream } => check_peekable(stream.get_mut().get_mut()), |
| 236 | + DnsClient::UdpRemote { .. } => true, |
| 237 | + } |
| 238 | + } |
148 | 239 | } |
149 | 240 |
|
150 | 241 | pub async fn stream_query<S>(stream: &mut S, r: &Message) -> Result<Message, ProtoError> |
|
0 commit comments