Skip to content

Commit

Permalink
fix: remove indirection for protocols::tls::TlsStream.tls
Browse files Browse the repository at this point in the history
  • Loading branch information
hargut committed Sep 6, 2024
1 parent 022252e commit 97a3df7
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 37 deletions.
17 changes: 4 additions & 13 deletions pingora-core/src/protocols/tls/boringssl_openssl/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,15 @@

//! BoringSSL & OpenSSL TLS stream specific implementation

use async_trait::async_trait;
use log::warn;
use std::pin::Pin;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};

use pingora_error::{Error, ErrorType::*, OrErr, Result};

use crate::listeners::ALPN;
use crate::protocols::digest::{GetSocketDigest, SocketDigest, TimingDigest};
use crate::protocols::raw_connect::ProxyDigest;
use crate::protocols::tls::InnerTlsStream;
use crate::protocols::tls::SslDigest;
use crate::protocols::{GetProxyDigest, GetTimingDigest};
use crate::tls::error::ErrorStack;
Expand Down Expand Up @@ -58,10 +55,9 @@ impl<T: AsyncRead + AsyncWrite + Unpin> InnerStream<T> {
}
}

#[async_trait]
impl<T: AsyncRead + AsyncWrite + Unpin + Send> InnerTlsStream for InnerStream<T> {
impl<T: AsyncRead + AsyncWrite + Unpin + Send> InnerStream<T> {
/// Connect to the remote TLS server as a client
async fn connect(&mut self) -> Result<()> {
pub(crate) async fn connect(&mut self) -> Result<()> {
Self::clear_error();
match Pin::new(&mut self.0).connect().await {
Ok(_) => Ok(()),
Expand All @@ -70,22 +66,17 @@ impl<T: AsyncRead + AsyncWrite + Unpin + Send> InnerTlsStream for InnerStream<T>
}

/// Finish the TLS handshake from client as a server
async fn accept(&mut self) -> Result<()> {
pub(crate) async fn accept(&mut self) -> Result<()> {
Self::clear_error();
match Pin::new(&mut self.0).accept().await {
Ok(_) => Ok(()),
Err(err) => self.transform_ssl_error(err),
}
}

fn digest(&mut self) -> Option<Arc<SslDigest>> {
pub(crate) fn digest(&mut self) -> Option<Arc<SslDigest>> {
Some(Arc::new(SslDigest::from_ssl(self.0.ssl())))
}

fn selected_alpn_proto(&mut self) -> Option<ALPN> {
let ssl = self.0.ssl();
ALPN::from_wire_selected(ssl.selected_alpn_protocol()?)
}
}

impl<T: AsyncRead + AsyncWrite + Unpin> InnerStream<T> {
Expand Down
6 changes: 3 additions & 3 deletions pingora-core/src/protocols/tls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,16 @@ pub struct TlsStream<T> {
timing: TimingDigest,
}

// NOTE: keeping trait for documentation purpose
// switched to direct implementations to eliminate redirections in within the call-graph
// the below trait is required for InnerStream<T> to be implemented
#[async_trait]
pub trait InnerTlsStream {
async fn connect(&mut self) -> Result<()>;
async fn accept(&mut self) -> Result<()>;

/// Return the [`ssl::SslDigest`] for logging
fn digest(&mut self) -> Option<Arc<SslDigest>>;

/// Return selected ALPN if any
fn selected_alpn_proto(&mut self) -> Option<ALPN>;
}

/// The protocol for Application-Layer Protocol Negotiation
Expand Down
25 changes: 4 additions & 21 deletions pingora-core/src/protocols/tls/rustls/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use async_trait::async_trait;
use pingora_error::ErrorType::{AcceptError, ConnectError, TLSHandshakeFailure};
use pingora_error::{Error, ImmutStr, OrErr, Result};
use pingora_rustls::NoDebug;
Expand All @@ -24,10 +23,8 @@ use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};

use crate::listeners::tls::Acceptor;
use crate::listeners::ALPN;
use crate::protocols::digest::{GetSocketDigest, SocketDigest, TimingDigest};
use crate::protocols::raw_connect::ProxyDigest;
use crate::protocols::tls::InnerTlsStream;
use crate::protocols::tls::SslDigest;
use crate::protocols::{GetProxyDigest, GetTimingDigest};

Expand Down Expand Up @@ -67,11 +64,9 @@ impl<T: AsyncRead + AsyncWrite + Unpin> InnerStream<T> {
})
}
}

#[async_trait]
impl<T: AsyncRead + AsyncWrite + Unpin + Send> InnerTlsStream for InnerStream<T> {
impl<T: AsyncRead + AsyncWrite + Unpin + Send> InnerStream<T> {
/// Connect to the remote TLS server as a client
async fn connect(&mut self) -> Result<()> {
pub(crate) async fn connect(&mut self) -> Result<()> {
let connect = &mut (*self.connect);

if let Some(ref mut connect) = connect {
Expand All @@ -92,7 +87,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin + Send> InnerTlsStream for InnerStream<T>

/// Finish the TLS handshake from client as a server
/// no-op implementation within Rustls, handshake is performed during creation of stream.
async fn accept(&mut self) -> Result<()> {
pub(crate) async fn accept(&mut self) -> Result<()> {
let accept = &mut (*self.accept);

if let Some(ref mut accept) = accept {
Expand All @@ -111,21 +106,9 @@ impl<T: AsyncRead + AsyncWrite + Unpin + Send> InnerTlsStream for InnerStream<T>
}
}

fn digest(&mut self) -> Option<Arc<SslDigest>> {
pub(crate) fn digest(&mut self) -> Option<Arc<SslDigest>> {
Some(Arc::new(SslDigest::from_stream(&self.stream)))
}

fn selected_alpn_proto(&mut self) -> Option<ALPN> {
if let Some(stream) = self.stream.as_ref() {
let proto = stream.get_ref().1.alpn_protocol();
match proto {
None => None,
Some(raw) => ALPN::from_wire_selected(raw),
}
} else {
None
}
}
}

impl<S> GetSocketDigest for InnerStream<S>
Expand Down

0 comments on commit 97a3df7

Please sign in to comment.