|
const std = @import("../../std.zig");
const tls = std.crypto.tls;
const Client = @This();
const net = std.net;
const mem = std.mem;
const crypto = std.crypto;
const assert = std.debug.assert;
const Certificate = std.crypto.Certificate;

const max_ciphertext_len = tls.max_ciphertext_len;
const hkdfExpandLabel = tls.hkdfExpandLabel;
const int2 = tls.int2;
const int3 = tls.int3;
const array = tls.array;
const enum_array = tls.enum_array;

/// Sequence number of the next record to decrypt. It is mixed into the
/// server IV to form each record's nonce, incremented per record, and reset
/// to 0 when the server sends a key_update.
read_seq: u64,
/// Sequence number of the next record to encrypt. It is mixed into the
/// client IV to form each record's nonce and incremented per record.
write_seq: u64,
/// The starting index of cleartext bytes inside `partially_read_buffer`.
/// Buffered cleartext occupies `[partial_cleartext_idx, partial_ciphertext_idx)`.
partial_cleartext_idx: u15,
/// The end index of buffered cleartext and the start index of buffered
/// ciphertext inside `partially_read_buffer`.
partial_ciphertext_idx: u15,
/// The end index of buffered ciphertext inside `partially_read_buffer`.
partial_ciphertext_end: u15,
/// Set when the server's close_notify alert has been received. This is also
/// set on end-of-stream if `allow_truncation_attacks` is true. Buffered data
/// may still be pending, so check `eof()` rather than this flag alone.
received_close_notify: bool,
/// By default, reaching end-of-stream without a close_notify makes reads
/// return `error.TlsConnectionTruncated`. When this is true, end-of-stream is
/// treated as a close instead and forwarded to the layer above TLS.
allow_truncation_attacks: bool = false,
/// Application-traffic secrets, keys and IVs. They are derived at the end of
/// `init` and rotated when a key_update is received.
application_cipher: tls.ApplicationCipher,
/// Holds cleartext and ciphertext that have been read but not yet handed to
/// the caller (see the `partial_*` indices). Its size comes from
/// `tls.max_ciphertext_record_len`; presumably this is one maximal record.
partially_read_buffer: [tls.max_ciphertext_record_len]u8,
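StreamInterface — documents the stream methods (`readv`, `writev`, `writevAll`) that this client expects from its `stream` parameters. (Field note: `partial_cleartext_idx` is the starting index of cleartext bytes inside `partially_read_buffer`.) |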
/// Documents the interface that the `stream` arguments of this client are
/// expected to provide. Every function here is a stub that panics; the type
/// exists purely as a reference for implementers.
///
/// NOTE(review): `init` additionally calls `stream.writeAll`. It also calls
/// `writevAll` as `try stream.writevAll(...)` without using a result. A stream
/// used with `init` therefore appears to need `writevAll` returning
/// `WriteError!void`, not the `usize` declared below — confirm.
pub const StreamInterface = struct {
    /// May be any error set.
    pub const ReadError = error{};

    /// May be any error set.
    pub const WriteError = error{};

    /// Reads into `iovecs` and returns how many bytes were read. That count
    /// may be smaller than the space `iovecs` provides. A return value of 0
    /// signals end-of-stream.
    pub fn readv(this: @This(), iovecs: []std.os.iovec) ReadError!usize {
        _ = this;
        _ = iovecs;
        @panic("unimplemented");
    }

    /// Writes from `iovecs` and returns how many bytes were written. That
    /// count may be smaller than the total length of `iovecs`; a short write
    /// does not indicate end-of-stream.
    pub fn writev(this: @This(), iovecs: []const std.os.iovec_const) WriteError!usize {
        _ = this;
        _ = iovecs;
        @panic("unimplemented");
    }

    /// Writes every byte described by `iovecs`. The slice is mutable, which
    /// presumably lets an implementation advance the entries in place while
    /// handling partial writes.
    pub fn writevAll(this: @This(), iovecs: []std.os.iovec_const) WriteError!usize {
        // May be built on top of `writev`, or specialized if desired.
        _ = this;
        _ = iovecs;
        @panic("unimplemented");
    }
};
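InitError() — Returns the error set of `init` for a given stream type. It combines allocator errors, the stream's read and write errors, TLS alert errors, and the handshake, certificate and signature-verification errors listed in the definition. |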
/// Computes the error set returned by `init` when it is called with a stream
/// of type `Stream`. `Stream` must declare `ReadError` and `WriteError`.
pub fn InitError(comptime Stream: type) type {
    const HandshakeError = error{
        // Key generation and I/O.
        InsufficientEntropy,
        DiskQuota,
        LockViolation,
        NotOpenForWriting,
        // TLS protocol violations detected by this client.
        TlsUnexpectedMessage,
        TlsIllegalParameter,
        TlsDecryptFailure,
        TlsRecordOverflow,
        TlsBadRecordMac,
        TlsCertificateNotVerified,
        TlsBadSignatureScheme,
        TlsBadRsaSignatureBitCount,
        TlsDecryptError,
        TlsConnectionTruncated,
        TlsDecodeError,
        // Certificate parsing and validation.
        CertificateFieldHasInvalidLength,
        CertificateHostMismatch,
        CertificatePublicKeyInvalid,
        CertificateExpired,
        CertificateFieldHasWrongDataType,
        CertificateIssuerMismatch,
        CertificateNotYetValid,
        CertificateSignatureAlgorithmMismatch,
        CertificateSignatureAlgorithmUnsupported,
        CertificateSignatureInvalid,
        CertificateSignatureInvalidLength,
        CertificateSignatureNamedCurveUnsupported,
        CertificateSignatureUnsupportedBitCount,
        UnsupportedCertificateVersion,
        CertificateTimeInvalid,
        CertificateHasUnrecognizedObjectId,
        CertificateHasInvalidBitString,
        // Errors surfaced by the underlying signature and curve primitives.
        InvalidEncoding,
        IdentityElement,
        SignatureVerificationFailed,
        MessageTooLong,
        NegativeIntoUnsigned,
        TargetTooSmall,
        BufferTooSmall,
        InvalidSignature,
        NotSquare,
        NonCanonical,
    };
    return Stream.ReadError || Stream.WriteError || std.mem.Allocator.Error ||
        tls.AlertDescription.Error || HandshakeError;
}
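init() — Initiates a TLS handshake and establishes a TLSv1.3 session with `stream`. The server's certificate chain is verified against `ca_bundle`, and the first certificate is checked against `host`. |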
/// Initiates a TLS handshake and establishes a TLSv1.3 session with `stream`.
///
/// Sending uses `stream.writevAll` (ClientHello) and `stream.writeAll` (client
/// Finished). Receiving goes through `tls.Decoder.readAtLeast*`; see
/// `StreamInterface`.
///
/// Certificate checks:
/// * The server's certificate chain must reach a certificate accepted by
///   `ca_bundle.verify`.
/// * The first certificate must match `host`.
/// `host` is also sent in the server_name extension, and `host.len` must fit
/// in a `u16`.
///
/// Any bytes already read from `stream` after the server's Finished message
/// are carried over into the returned client's `partially_read_buffer`.
pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) InitError(@TypeOf(stream))!Client {
    const host_len: u16 = @intCast(host.len);

    // One entropy draw covers four 32-byte values: the ClientHello random, the
    // legacy session id, and the X25519 and P-256 key seeds.
    var random_buffer: [128]u8 = undefined;
    crypto.random.bytes(&random_buffer);
    const hello_rand = random_buffer[0..32].*;
    const legacy_session_id = random_buffer[32..64].*;
    const x25519_kp_seed = random_buffer[64..96].*;
    const secp256r1_kp_seed = random_buffer[96..128].*;

    const x25519_kp = crypto.dh.X25519.KeyPair.create(x25519_kp_seed) catch |err| switch (err) {
        // Only possible to happen if the private key is all zeroes.
        error.IdentityElement => return error.InsufficientEntropy,
    };
    const secp256r1_kp = crypto.sign.ecdsa.EcdsaP256Sha256.KeyPair.create(secp256r1_kp_seed) catch |err| switch (err) {
        // Only possible to happen if the private key is all zeroes.
        error.IdentityElement => return error.InsufficientEntropy,
    };
    const kyber768_kp = crypto.kem.kyber_d00.Kyber768.KeyPair.create(null) catch {};

    // The host name itself is not part of this array. It is sent as a second
    // iovec right after it, which is why the length fields below add `host_len`.
    const extensions_payload =
        tls.extension(.supported_versions, [_]u8{
        0x02, // byte length of supported versions
        0x03, 0x04, // TLS 1.3
    }) ++ tls.extension(.signature_algorithms, enum_array(tls.SignatureScheme, &.{
        .ecdsa_secp256r1_sha256,
        .ecdsa_secp384r1_sha384,
        .ecdsa_secp521r1_sha512,
        .rsa_pss_rsae_sha256,
        .rsa_pss_rsae_sha384,
        .rsa_pss_rsae_sha512,
        .rsa_pkcs1_sha256,
        .rsa_pkcs1_sha384,
        .rsa_pkcs1_sha512,
        .ed25519,
    })) ++ tls.extension(.supported_groups, enum_array(tls.NamedGroup, &.{
        .x25519_kyber768d00,
        .secp256r1,
        .x25519,
    })) ++ tls.extension(
        .key_share,
        array(1, int2(@intFromEnum(tls.NamedGroup.x25519)) ++
            array(1, x25519_kp.public_key) ++
            int2(@intFromEnum(tls.NamedGroup.secp256r1)) ++
            array(1, secp256r1_kp.public_key.toUncompressedSec1()) ++
            int2(@intFromEnum(tls.NamedGroup.x25519_kyber768d00)) ++
            array(1, x25519_kp.public_key ++ kyber768_kp.public_key.toBytes())),
    ) ++
        int2(@intFromEnum(tls.ExtensionType.server_name)) ++
        int2(host_len + 5) ++ // byte length of this extension payload
        int2(host_len + 3) ++ // server_name_list byte count
        [1]u8{0x00} ++ // name_type
        int2(host_len);

    const extensions_header =
        int2(@intCast(extensions_payload.len + host_len)) ++
        extensions_payload;

    const legacy_compression_methods = 0x0100;

    const client_hello =
        int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++
        hello_rand ++
        [1]u8{32} ++ legacy_session_id ++
        cipher_suites ++
        int2(legacy_compression_methods) ++
        extensions_header;

    const out_handshake =
        [_]u8{@intFromEnum(tls.HandshakeType.client_hello)} ++
        int3(@intCast(client_hello.len + host_len)) ++
        client_hello;

    const plaintext_header = [_]u8{
        @intFromEnum(tls.ContentType.handshake),
        0x03, 0x01, // legacy_record_version
    } ++ int2(@intCast(out_handshake.len + host_len)) ++ out_handshake;

    {
        var iovecs = [_]std.os.iovec_const{
            .{
                .iov_base = &plaintext_header,
                .iov_len = plaintext_header.len,
            },
            .{
                .iov_base = host.ptr,
                .iov_len = host.len,
            },
        };
        try stream.writevAll(&iovecs);
    }

    // The ClientHello handshake message without its 5-byte record header.
    // Together with `host` it forms the first transcript entry.
    const client_hello_bytes1 = plaintext_header[5..];

    var handshake_cipher: tls.HandshakeCipher = undefined;
    var handshake_buffer: [8000]u8 = undefined;
    var d: tls.Decoder = .{ .buf = &handshake_buffer };
    {
        try d.readAtLeastOurAmt(stream, tls.record_header_len);
        const ct = d.decode(tls.ContentType);
        d.skip(2); // legacy_record_version
        const record_len = d.decode(u16);
        try d.readAtLeast(stream, record_len);
        const server_hello_fragment = d.buf[d.idx..][0..record_len];
        var ptd = try d.sub(record_len);
        switch (ct) {
            .alert => {
                try ptd.ensure(2);
                const level = ptd.decode(tls.AlertLevel);
                const desc = ptd.decode(tls.AlertDescription);
                _ = level;

                // if this isn't a error alert, then it's a closure alert, which makes no sense in a handshake
                try desc.toError();
                // TODO: handle server-side closures
                return error.TlsUnexpectedMessage;
            },
            .handshake => {
                try ptd.ensure(4);
                const handshake_type = ptd.decode(tls.HandshakeType);
                if (handshake_type != .server_hello) return error.TlsUnexpectedMessage;
                const length = ptd.decode(u24);
                var hsd = try ptd.sub(length);
                try hsd.ensure(2 + 32 + 1 + 32 + 2 + 1 + 2);
                const legacy_version = hsd.decode(u16);
                const random = hsd.array(32);
                if (mem.eql(u8, random, &tls.hello_retry_request_sequence)) {
                    // This is a HelloRetryRequest message. This client implementation
                    // does not expect to get one.
                    return error.TlsUnexpectedMessage;
                }
                const legacy_session_id_echo_len = hsd.decode(u8);
                if (legacy_session_id_echo_len != 32) return error.TlsIllegalParameter;
                const legacy_session_id_echo = hsd.array(32);
                if (!mem.eql(u8, legacy_session_id_echo, &legacy_session_id))
                    return error.TlsIllegalParameter;
                const cipher_suite_tag = hsd.decode(tls.CipherSuite);
                hsd.skip(1); // legacy_compression_method
                const extensions_size = hsd.decode(u16);
                var all_extd = try hsd.sub(extensions_size);
                var supported_version: u16 = 0;
                var shared_key: []const u8 = undefined;
                var have_shared_key = false;
                while (!all_extd.eof()) {
                    try all_extd.ensure(2 + 2);
                    const et = all_extd.decode(tls.ExtensionType);
                    const ext_size = all_extd.decode(u16);
                    var extd = try all_extd.sub(ext_size);
                    switch (et) {
                        .supported_versions => {
                            if (supported_version != 0) return error.TlsIllegalParameter;
                            try extd.ensure(2);
                            supported_version = extd.decode(u16);
                        },
                        .key_share => {
                            if (have_shared_key) return error.TlsIllegalParameter;
                            have_shared_key = true;
                            try extd.ensure(4);
                            const named_group = extd.decode(tls.NamedGroup);
                            const key_size = extd.decode(u16);
                            try extd.ensure(key_size);
                            switch (named_group) {
                                .x25519_kyber768d00 => {
                                    const xksl = crypto.dh.X25519.public_length;
                                    const hksl = xksl + crypto.kem.kyber_d00.Kyber768.ciphertext_length;
                                    if (key_size != hksl) return error.TlsIllegalParameter;
                                    const server_ks = extd.array(hksl);

                                    shared_key = &((crypto.dh.X25519.scalarmult(
                                        x25519_kp.secret_key,
                                        server_ks[0..xksl].*,
                                    ) catch return error.TlsDecryptFailure) ++ (kyber768_kp.secret_key.decaps(
                                        server_ks[xksl..hksl],
                                    ) catch return error.TlsDecryptFailure));
                                },
                                .x25519 => {
                                    const ksl = crypto.dh.X25519.public_length;
                                    if (key_size != ksl) return error.TlsIllegalParameter;
                                    const server_pub_key = extd.array(ksl);

                                    shared_key = &(crypto.dh.X25519.scalarmult(
                                        x25519_kp.secret_key,
                                        server_pub_key.*,
                                    ) catch return error.TlsDecryptFailure);
                                },
                                .secp256r1 => {
                                    const server_pub_key = extd.slice(key_size);

                                    const PublicKey = crypto.sign.ecdsa.EcdsaP256Sha256.PublicKey;
                                    const pk = PublicKey.fromSec1(server_pub_key) catch {
                                        return error.TlsDecryptFailure;
                                    };
                                    const mul = pk.p.mulPublic(secp256r1_kp.secret_key.bytes, .big) catch {
                                        return error.TlsDecryptFailure;
                                    };
                                    shared_key = &mul.affineCoordinates().x.toBytes(.big);
                                },
                                else => {
                                    return error.TlsIllegalParameter;
                                },
                            }
                        },
                        else => {},
                    }
                }
                if (!have_shared_key) return error.TlsIllegalParameter;

                const tls_version = if (supported_version == 0) legacy_version else supported_version;
                if (tls_version != @intFromEnum(tls.ProtocolVersion.tls_1_3))
                    return error.TlsIllegalParameter;

                switch (cipher_suite_tag) {
                    inline .AES_128_GCM_SHA256,
                    .AES_256_GCM_SHA384,
                    .CHACHA20_POLY1305_SHA256,
                    .AEGIS_256_SHA384,
                    .AEGIS_128L_SHA256,
                    => |tag| {
                        const P = std.meta.TagPayloadByName(tls.HandshakeCipher, @tagName(tag));
                        handshake_cipher = @unionInit(tls.HandshakeCipher, @tagName(tag), .{
                            .handshake_secret = undefined,
                            .master_secret = undefined,
                            .client_handshake_key = undefined,
                            .server_handshake_key = undefined,
                            .client_finished_key = undefined,
                            .server_finished_key = undefined,
                            .client_handshake_iv = undefined,
                            .server_handshake_iv = undefined,
                            .transcript_hash = P.Hash.init(.{}),
                        });
                        const p = &@field(handshake_cipher, @tagName(tag));
                        p.transcript_hash.update(client_hello_bytes1); // Client Hello part 1
                        p.transcript_hash.update(host); // Client Hello part 2
                        p.transcript_hash.update(server_hello_fragment);
                        const hello_hash = p.transcript_hash.peek();
                        const zeroes = [1]u8{0} ** P.Hash.digest_length;
                        const early_secret = P.Hkdf.extract(&[1]u8{0}, &zeroes);
                        const empty_hash = tls.emptyHash(P.Hash);
                        const hs_derived_secret = hkdfExpandLabel(P.Hkdf, early_secret, "derived", &empty_hash, P.Hash.digest_length);
                        p.handshake_secret = P.Hkdf.extract(&hs_derived_secret, shared_key);
                        const ap_derived_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "derived", &empty_hash, P.Hash.digest_length);
                        p.master_secret = P.Hkdf.extract(&ap_derived_secret, &zeroes);
                        const client_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "c hs traffic", &hello_hash, P.Hash.digest_length);
                        const server_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "s hs traffic", &hello_hash, P.Hash.digest_length);
                        p.client_finished_key = hkdfExpandLabel(P.Hkdf, client_secret, "finished", "", P.Hmac.key_length);
                        p.server_finished_key = hkdfExpandLabel(P.Hkdf, server_secret, "finished", "", P.Hmac.key_length);
                        p.client_handshake_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length);
                        p.server_handshake_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length);
                        p.client_handshake_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length);
                        p.server_handshake_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length);
                    },
                    else => {
                        return error.TlsIllegalParameter;
                    },
                }
            },
            else => return error.TlsUnexpectedMessage,
        }
    }

    // This is used for two purposes:
    // * Detect whether a certificate is the first one presented, in which case
    //   we need to verify the host name.
    // * Flip back and forth between the two cleartext buffers in order to keep
    //   the previous certificate in memory so that it can be verified by the
    //   next one.
    var cert_index: usize = 0;
    var read_seq: u64 = 0;
    var prev_cert: Certificate.Parsed = undefined;
    const HandshakeState = enum {
        /// In this state we expect only an encrypted_extensions message.
        encrypted_extensions,
        /// In this state we expect certificate messages.
        certificate,
        /// Set once a trust chain has been established from the first
        /// certificate to a root CA. In this state we expect certificate or
        /// certificate_verify messages; certificate messages are ignored since
        /// the trust chain is already established.
        trust_chain_established,
        /// In this state, we expect only the finished message.
        finished,
    };
    var handshake_state: HandshakeState = .encrypted_extensions;
    var cleartext_bufs: [2][8000]u8 = undefined;
    var main_cert_pub_key_algo: Certificate.AlgorithmCategory = undefined;
    var main_cert_pub_key_buf: [600]u8 = undefined;
    var main_cert_pub_key_len: u16 = undefined;
    const now_sec = std.time.timestamp();

    while (true) {
        try d.readAtLeastOurAmt(stream, tls.record_header_len);
        const record_header = d.buf[d.idx..][0..5];
        const ct = d.decode(tls.ContentType);
        d.skip(2); // legacy_version
        const record_len = d.decode(u16);
        try d.readAtLeast(stream, record_len);
        var record_decoder = try d.sub(record_len);
        switch (ct) {
            .change_cipher_spec => {
                try record_decoder.ensure(1);
                if (record_decoder.decode(u8) != 0x01) return error.TlsIllegalParameter;
            },
            .application_data => {
                const cleartext_buf = &cleartext_bufs[cert_index % 2];

                const cleartext = switch (handshake_cipher) {
                    inline else => |*p| c: {
                        const P = @TypeOf(p.*);
                        // `record_len` is attacker-controlled. Without this check a
                        // record shorter than the AEAD tag makes the subtraction
                        // below underflow.
                        if (record_len < P.AEAD.tag_length) return error.TlsDecodeError;
                        const ciphertext_len = record_len - P.AEAD.tag_length;
                        try record_decoder.ensure(ciphertext_len + P.AEAD.tag_length);
                        const ciphertext = record_decoder.slice(ciphertext_len);
                        if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow;
                        const cleartext = cleartext_buf[0..ciphertext.len];
                        const auth_tag = record_decoder.array(P.AEAD.tag_length).*;
                        // Per-record nonce: the IV XORed with the big-endian
                        // sequence number, left-padded with zeroes.
                        const V = @Vector(P.AEAD.nonce_length, u8);
                        const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
                        const operand: V = pad ++ @as([8]u8, @bitCast(big(read_seq)));
                        read_seq += 1;
                        const nonce = @as(V, p.server_handshake_iv) ^ operand;
                        P.AEAD.decrypt(cleartext, ciphertext, auth_tag, record_header, nonce, p.server_handshake_key) catch
                            return error.TlsBadRecordMac;
                        break :c cleartext;
                    },
                };

                // The real content type is the last decrypted byte. An empty
                // record has no such byte, and indexing it would underflow.
                if (cleartext.len == 0) return error.TlsDecodeError;
                const inner_ct: tls.ContentType = @enumFromInt(cleartext[cleartext.len - 1]);
                if (inner_ct != .handshake) return error.TlsUnexpectedMessage;

                var ctd = tls.Decoder.fromTheirSlice(cleartext[0 .. cleartext.len - 1]);
                while (true) {
                    try ctd.ensure(4);
                    const handshake_type = ctd.decode(tls.HandshakeType);
                    const handshake_len = ctd.decode(u24);
                    var hsd = try ctd.sub(handshake_len);
                    const wrapped_handshake = ctd.buf[ctd.idx - handshake_len - 4 .. ctd.idx];
                    const handshake = ctd.buf[ctd.idx - handshake_len .. ctd.idx];
                    switch (handshake_type) {
                        .encrypted_extensions => {
                            if (handshake_state != .encrypted_extensions) return error.TlsUnexpectedMessage;
                            handshake_state = .certificate;
                            switch (handshake_cipher) {
                                inline else => |*p| p.transcript_hash.update(wrapped_handshake),
                            }
                            try hsd.ensure(2);
                            const total_ext_size = hsd.decode(u16);
                            var all_extd = try hsd.sub(total_ext_size);
                            while (!all_extd.eof()) {
                                try all_extd.ensure(4);
                                // This client does not act on any EncryptedExtensions
                                // entry, so each one is validated for framing and skipped.
                                _ = all_extd.decode(tls.ExtensionType);
                                const ext_size = all_extd.decode(u16);
                                _ = try all_extd.sub(ext_size);
                            }
                        },
                        .certificate => cert: {
                            switch (handshake_cipher) {
                                inline else => |*p| p.transcript_hash.update(wrapped_handshake),
                            }
                            switch (handshake_state) {
                                .certificate => {},
                                .trust_chain_established => break :cert,
                                else => return error.TlsUnexpectedMessage,
                            }
                            try hsd.ensure(1 + 4);
                            const cert_req_ctx_len = hsd.decode(u8);
                            if (cert_req_ctx_len != 0) return error.TlsIllegalParameter;
                            const certs_size = hsd.decode(u24);
                            var certs_decoder = try hsd.sub(certs_size);
                            while (!certs_decoder.eof()) {
                                try certs_decoder.ensure(3);
                                const cert_size = certs_decoder.decode(u24);
                                const certd = try certs_decoder.sub(cert_size);

                                const subject_cert: Certificate = .{
                                    .buffer = certd.buf,
                                    .index = @intCast(certd.idx),
                                };
                                const subject = try subject_cert.parse();
                                if (cert_index == 0) {
                                    // Verify the host on the first certificate.
                                    try subject.verifyHostName(host);

                                    // Keep track of the public key for the
                                    // certificate_verify message later.
                                    main_cert_pub_key_algo = subject.pub_key_algo;
                                    const pub_key = subject.pubKey();
                                    if (pub_key.len > main_cert_pub_key_buf.len)
                                        return error.CertificatePublicKeyInvalid;
                                    @memcpy(main_cert_pub_key_buf[0..pub_key.len], pub_key);
                                    main_cert_pub_key_len = @intCast(pub_key.len);
                                } else {
                                    try prev_cert.verify(subject, now_sec);
                                }

                                if (ca_bundle.verify(subject, now_sec)) |_| {
                                    handshake_state = .trust_chain_established;
                                    break :cert;
                                } else |err| switch (err) {
                                    error.CertificateIssuerNotFound => {},
                                    else => |e| return e,
                                }

                                prev_cert = subject;
                                cert_index += 1;

                                // Per-certificate extensions are not used; skip them.
                                try certs_decoder.ensure(2);
                                const total_ext_size = certs_decoder.decode(u16);
                                _ = try certs_decoder.sub(total_ext_size);
                            }
                        },
                        .certificate_verify => {
                            switch (handshake_state) {
                                .trust_chain_established => handshake_state = .finished,
                                .certificate => return error.TlsCertificateNotVerified,
                                else => return error.TlsUnexpectedMessage,
                            }

                            try hsd.ensure(4);
                            const scheme = hsd.decode(tls.SignatureScheme);
                            const sig_len = hsd.decode(u16);
                            try hsd.ensure(sig_len);
                            const encoded_sig = hsd.slice(sig_len);
                            // Signed content: 64 spaces, a context string, a zero
                            // separator, then the transcript hash so far.
                            const max_digest_len = 64;
                            var verify_buffer =
                                ([1]u8{0x20} ** 64) ++
                                "TLS 1.3, server CertificateVerify\x00".* ++
                                @as([max_digest_len]u8, undefined);

                            const verify_bytes = switch (handshake_cipher) {
                                inline else => |*p| v: {
                                    const transcript_digest = p.transcript_hash.peek();
                                    verify_buffer[verify_buffer.len - max_digest_len ..][0..transcript_digest.len].* = transcript_digest;
                                    p.transcript_hash.update(wrapped_handshake);
                                    break :v verify_buffer[0 .. verify_buffer.len - max_digest_len + transcript_digest.len];
                                },
                            };
                            const main_cert_pub_key = main_cert_pub_key_buf[0..main_cert_pub_key_len];

                            switch (scheme) {
                                inline .ecdsa_secp256r1_sha256,
                                .ecdsa_secp384r1_sha384,
                                => |comptime_scheme| {
                                    if (main_cert_pub_key_algo != .X9_62_id_ecPublicKey)
                                        return error.TlsBadSignatureScheme;
                                    const Ecdsa = SchemeEcdsa(comptime_scheme);
                                    const sig = try Ecdsa.Signature.fromDer(encoded_sig);
                                    const key = try Ecdsa.PublicKey.fromSec1(main_cert_pub_key);
                                    try sig.verify(verify_bytes, key);
                                },
                                inline .rsa_pss_rsae_sha256,
                                .rsa_pss_rsae_sha384,
                                .rsa_pss_rsae_sha512,
                                => |comptime_scheme| {
                                    if (main_cert_pub_key_algo != .rsaEncryption)
                                        return error.TlsBadSignatureScheme;

                                    const Hash = SchemeHash(comptime_scheme);
                                    const rsa = Certificate.rsa;
                                    const components = try rsa.PublicKey.parseDer(main_cert_pub_key);
                                    const exponent = components.exponent;
                                    const modulus = components.modulus;
                                    switch (modulus.len) {
                                        inline 128, 256, 512 => |modulus_len| {
                                            const key = try rsa.PublicKey.fromBytes(exponent, modulus);
                                            const sig = rsa.PSSSignature.fromBytes(modulus_len, encoded_sig);
                                            try rsa.PSSSignature.verify(modulus_len, sig, verify_bytes, key, Hash);
                                        },
                                        else => {
                                            return error.TlsBadRsaSignatureBitCount;
                                        },
                                    }
                                },
                                else => {
                                    return error.TlsBadSignatureScheme;
                                },
                            }
                        },
                        .finished => {
                            if (handshake_state != .finished) return error.TlsUnexpectedMessage;
                            // This message is to trick buggy proxies into behaving correctly.
                            const client_change_cipher_spec_msg = [_]u8{
                                @intFromEnum(tls.ContentType.change_cipher_spec),
                                0x03, 0x03, // legacy protocol version
                                0x00, 0x01, // length
                                0x01,
                            };
                            const app_cipher = switch (handshake_cipher) {
                                inline else => |*p, tag| c: {
                                    const P = @TypeOf(p.*);
                                    const finished_digest = p.transcript_hash.peek();
                                    p.transcript_hash.update(wrapped_handshake);
                                    const expected_server_verify_data = tls.hmac(P.Hmac, &finished_digest, p.server_finished_key);
                                    if (!mem.eql(u8, &expected_server_verify_data, handshake))
                                        return error.TlsDecryptError;
                                    const handshake_hash = p.transcript_hash.finalResult();
                                    const verify_data = tls.hmac(P.Hmac, &handshake_hash, p.client_finished_key);
                                    const out_cleartext = [_]u8{
                                        @intFromEnum(tls.HandshakeType.finished),
                                        0, 0, verify_data.len, // length
                                    } ++ verify_data ++ [1]u8{@intFromEnum(tls.ContentType.handshake)};

                                    const wrapped_len = out_cleartext.len + P.AEAD.tag_length;

                                    var finished_msg = [_]u8{
                                        @intFromEnum(tls.ContentType.application_data),
                                        0x03, 0x03, // legacy protocol version
                                        0, wrapped_len, // byte length of encrypted record
                                    } ++ @as([wrapped_len]u8, undefined);

                                    const ad = finished_msg[0..5];
                                    const ciphertext = finished_msg[5..][0..out_cleartext.len];
                                    const auth_tag = finished_msg[finished_msg.len - P.AEAD.tag_length ..];
                                    // First record under the client handshake key:
                                    // sequence number 0, so the nonce is the IV itself.
                                    const nonce = p.client_handshake_iv;
                                    P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, p.client_handshake_key);

                                    const both_msgs = client_change_cipher_spec_msg ++ finished_msg;
                                    try stream.writeAll(&both_msgs);

                                    const client_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length);
                                    const server_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length);
                                    break :c @unionInit(tls.ApplicationCipher, @tagName(tag), .{
                                        .client_secret = client_secret,
                                        .server_secret = server_secret,
                                        .client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length),
                                        .server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length),
                                        .client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length),
                                        .server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length),
                                    });
                                },
                            };
                            // Bytes read past the server Finished record belong to
                            // the application-data phase; hand them to the client.
                            const leftover = d.rest();
                            var client: Client = .{
                                .read_seq = 0,
                                .write_seq = 0,
                                .partial_cleartext_idx = 0,
                                .partial_ciphertext_idx = 0,
                                .partial_ciphertext_end = @intCast(leftover.len),
                                .received_close_notify = false,
                                .application_cipher = app_cipher,
                                .partially_read_buffer = undefined,
                            };
                            @memcpy(client.partially_read_buffer[0..leftover.len], leftover);
                            return client;
                        },
                        else => {
                            return error.TlsUnexpectedMessage;
                        },
                    }
                    if (ctd.eof()) break;
                }
            },
            else => {
                return error.TlsUnexpectedMessage;
            },
        }
    }
}
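write() — (Handshake-state notes for `init`: in `encrypted_extensions` we expect only an encrypted_extensions message; in `certificate` we expect certificate messages; in `trust_chain_established` we expect certificate or certificate_verify messages, and certificate messages are ignored since the trust chain is already established; in `finished` we expect only the finished message.) Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`, without sending a close_notify alert. Returns the number of bytes sent, which may be less than `bytes.len`. |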
/// Encrypts and sends as much of `bytes` as one `writeEnd` call manages,
/// without closing the connection. Returns how many bytes of `bytes` went out.
pub fn write(c: *Client, stream: anytype, bytes: []const u8) !usize {
    const close_after = false;
    return c.writeEnd(stream, bytes, close_after);
}
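writeAll() — Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`, repeating until every byte of `bytes` has been sent. |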
/// Calls `write` repeatedly until every byte of `bytes` has been sent.
pub fn writeAll(c: *Client, stream: anytype, bytes: []const u8) !void {
    var sent: usize = 0;
    while (sent < bytes.len) sent += try c.write(stream, bytes[sent..]);
}
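writeAllEnd() — Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`, until every byte of `bytes` has been sent. If `end` is true, a close_notify alert is also sent. |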
/// Sends every byte of `bytes` through `writeEnd`. When `end` is true, a
/// close_notify alert is also sent to terminate the connection. This happens
/// even when `bytes` is empty.
pub fn writeAllEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !void {
    var index: usize = 0;
    while (index < bytes.len) {
        index += try c.writeEnd(stream, bytes[index..], end);
    }
    // For empty input the loop above never runs, so the close_notify would
    // otherwise never be sent.
    if (end and bytes.len == 0) _ = try c.writeEnd(stream, bytes, true);
}
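writeEnd() — Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. If `end` is true, a close_notify alert is also sent. Returns how many bytes were sent, which may be fewer than `bytes.len`. |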
/// Encrypts as much of `bytes` as fits into one batch of records and sends it
/// to `stream`, which must conform to `StreamInterface`.
///
/// If `end` is true and all of `bytes` fits into this batch, a close_notify
/// alert follows the data. If some of `bytes` does not fit, the alert is
/// deferred so that no application data is ever sent after it.
///
/// Returns the number of bytes of `bytes` that were sent, which may be less
/// than `bytes.len`. The close_notify alert does not count toward this number.
/// The function never returns with a record only partially written.
pub fn writeEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !usize {
    var ciphertext_buf: [tls.max_ciphertext_record_len * 4]u8 = undefined;
    var iovecs_buf: [6]std.os.iovec_const = undefined;
    // Hold back the last iovec so a close_notify record always has a slot.
    var prepared = prepareCiphertextRecord(c, iovecs_buf[0 .. iovecs_buf.len - 1], &ciphertext_buf, bytes, .application_data);
    // Only terminate the stream once every byte of `bytes` is in flight;
    // otherwise a later call would send application data after close_notify.
    const send_close_notify = end and prepared.bytes_end == bytes.len;
    if (send_close_notify) {
        prepared.iovec_end += prepareCiphertextRecord(
            c,
            iovecs_buf[prepared.iovec_end..],
            ciphertext_buf[prepared.ciphertext_end..],
            &tls.close_notify_alert,
            .alert,
        ).iovec_end;
    }
    const iovec_end = prepared.iovec_end;
    const overhead_len = prepared.overhead_len;

    // Nothing to send: empty `bytes` and no close_notify requested.
    if (iovec_end == 0) return 0;

    // Index of the close_notify record, or `iovec_end` when there is none. Its
    // payload is not part of `bytes`, so it is not counted as sent data.
    const close_notify_i: usize = if (send_close_notify) iovec_end - 1 else iovec_end;

    // `iov_len` shrinks as records are partially written, so remember each
    // record's full size to know how much plaintext it carried.
    var record_lens: [iovecs_buf.len]usize = undefined;
    for (iovecs_buf[0..iovec_end], record_lens[0..iovec_end]) |iov, *record_len| {
        record_len.* = iov.iov_len;
    }

    // Ideally we would call writev exactly once here, however, we must ensure
    // that we don't return with a record partially written.
    var i: usize = 0;
    var total_amt: usize = 0;
    while (true) {
        var amt = try stream.writev(iovecs_buf[i..iovec_end]);
        while (amt >= iovecs_buf[i].iov_len) {
            amt -= iovecs_buf[i].iov_len;
            if (i != close_notify_i) total_amt += record_lens[i] - overhead_len;
            i += 1;
            // Rely on the property that iovecs delineate records, meaning that
            // if amt equals zero here, we have fortunately found ourselves
            // with a short write that aligns at the record boundary.
            if (i >= iovec_end) return total_amt;
            // We also cannot return on a vector boundary if the final close_notify is
            // not sent; otherwise the caller would not know to retry the call.
            if (amt == 0 and (!send_close_notify or i < iovec_end - 1)) return total_amt;
        }
        iovecs_buf[i].iov_base += amt;
        iovecs_buf[i].iov_len -= amt;
    }
}

/// Encrypts `bytes` into consecutive TLS records inside `ciphertext_buf`. Each
/// record gets `inner_content_type` as its inner content type. One iovec is
/// emitted per record, so iovecs delineate record boundaries. Encryption stops
/// when `bytes` is exhausted or when `ciphertext_buf` or `iovecs` is full.
///
/// For `.application_data`, room for one trailing close_notify record is left
/// unused at the end of `ciphertext_buf`. Other content types may use the
/// whole buffer, which is what lets the alert fill that reserved room.
fn prepareCiphertextRecord(
    c: *Client,
    iovecs: []std.os.iovec_const,
    ciphertext_buf: []u8,
    bytes: []const u8,
    inner_content_type: tls.ContentType,
) struct {
    iovec_end: usize,
    ciphertext_end: usize,
    /// How many bytes are taken up by overhead per record.
    overhead_len: usize,
    /// How many bytes of `bytes` were encrypted. This is less than
    /// `bytes.len` when the buffers filled up first.
    bytes_end: usize,
} {
    // RFC 8446, Section 5.1: TLSPlaintext.fragment MUST NOT exceed 2^14 bytes.
    const max_plaintext_len = 1 << 14;

    // Due to the trailing inner content type byte in the ciphertext, we need
    // an additional buffer for storing the cleartext into before encrypting.
    var cleartext_buf: [max_ciphertext_len]u8 = undefined;
    var ciphertext_end: usize = 0;
    var iovec_end: usize = 0;
    var bytes_i: usize = 0;
    switch (c.application_cipher) {
        inline else => |*p| {
            const P = @TypeOf(p.*);
            const V = @Vector(P.AEAD.nonce_length, u8);
            // Record header + AEAD tag + inner content type byte.
            const overhead_len = tls.record_header_len + P.AEAD.tag_length + 1;
            const close_notify_alert_reserved = tls.close_notify_alert.len + overhead_len;
            const reserved: usize = if (inner_content_type == .application_data) close_notify_alert_reserved else 0;
            while (true) {
                // Saturating subtraction: once the buffer is full the capacity
                // becomes 0 instead of underflowing.
                const encrypted_content_len: u16 = @intCast(@min(
                    @min(bytes.len - bytes_i, max_plaintext_len),
                    ciphertext_buf.len -| (reserved + overhead_len + ciphertext_end),
                ));
                if (encrypted_content_len == 0 or iovec_end == iovecs.len) return .{
                    .iovec_end = iovec_end,
                    .ciphertext_end = ciphertext_end,
                    .overhead_len = overhead_len,
                    .bytes_end = bytes_i,
                };

                @memcpy(cleartext_buf[0..encrypted_content_len], bytes[bytes_i..][0..encrypted_content_len]);
                cleartext_buf[encrypted_content_len] = @intFromEnum(inner_content_type);
                bytes_i += encrypted_content_len;
                const ciphertext_len = encrypted_content_len + 1;
                const cleartext = cleartext_buf[0..ciphertext_len];

                const record_start = ciphertext_end;
                const ad = ciphertext_buf[ciphertext_end..][0..5];
                ad.* =
                    [_]u8{@intFromEnum(tls.ContentType.application_data)} ++
                    int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++
                    int2(ciphertext_len + P.AEAD.tag_length);
                ciphertext_end += ad.len;
                const ciphertext = ciphertext_buf[ciphertext_end..][0..ciphertext_len];
                ciphertext_end += ciphertext_len;
                const auth_tag = ciphertext_buf[ciphertext_end..][0..P.AEAD.tag_length];
                ciphertext_end += auth_tag.len;
                // Per-record nonce: the client IV XORed with the big-endian
                // write sequence number, left-padded with zeroes.
                const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8);
                const operand: V = pad ++ @as([8]u8, @bitCast(big(c.write_seq)));
                c.write_seq += 1; // TODO send key_update on overflow
                const nonce = @as(V, p.client_iv) ^ operand;
                P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, p.client_key);

                const record = ciphertext_buf[record_start..ciphertext_end];
                iovecs[iovec_end] = .{
                    .iov_base = record.ptr,
                    .iov_len = record.len,
                };
                iovec_end += 1;
            }
        },
    }
}
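eof() — (Note for `prepareCiphertextRecord`'s `overhead_len`: how many bytes are taken up by overhead per record.) Returns true once a close_notify has been received (or end-of-stream was accepted via `allow_truncation_attacks`) and all buffered cleartext and ciphertext has been consumed. |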
/// Reports whether the connection has been cleanly closed by the server and
/// no buffered cleartext or ciphertext remains to be handed out.
pub fn eof(c: Client) bool {
    if (!c.received_close_notify) return false;
    const cleartext_drained = c.partial_cleartext_idx >= c.partial_ciphertext_idx;
    const ciphertext_drained = c.partial_ciphertext_idx >= c.partial_ciphertext_end;
    return cleartext_drained and ciphertext_drained;
}
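readAtLeast() — Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`, and writes the cleartext into `buffer`. Returns the number of bytes read; this is at least `len` unless end-of-stream was reached. |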
/// Reads at least `len` cleartext bytes into `buffer`, unless end-of-stream is
/// reached first. The buffer is wrapped in a single iovec and handed to
/// `readvAtLeast`.
pub fn readAtLeast(c: *Client, stream: anytype, buffer: []u8, len: usize) !usize {
    var single = [_]std.os.iovec{.{
        .iov_base = buffer.ptr,
        .iov_len = buffer.len,
    }};
    return c.readvAtLeast(stream, &single, len);
}
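read() — Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`, and writes the cleartext into `buffer`. Returns at least one byte unless end-of-stream was reached, in which case 0 may be returned. |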
/// Reads at least one cleartext byte into `buffer`, unless end-of-stream is
/// reached first.
pub fn read(c: *Client, stream: anytype, buffer: []u8) !usize {
    const at_least_one = 1;
    return c.readAtLeast(stream, buffer, at_least_one);
}
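readAll() — Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`, until `buffer` is full or end-of-stream is reached. Returns the number of bytes read. |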
/// Fills `buffer` completely with cleartext, unless end-of-stream is reached
/// first.
pub fn readAll(c: *Client, stream: anytype, buffer: []u8) !usize {
    const want = buffer.len;
    return c.readAtLeast(stream, buffer, want);
}
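readv() — Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`, and writes the cleartext into `iovecs`. Returns at least one byte unless end-of-stream was reached. |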
/// Reads at least one cleartext byte into `iovecs`, unless end-of-stream is
/// reached first. Entries of `iovecs` may be advanced in place.
pub fn readv(c: *Client, stream: anytype, iovecs: []std.os.iovec) !usize {
    const at_least_one = 1;
    return c.readvAtLeast(stream, iovecs, at_least_one);
}
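readvAtLeast() — Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`, and writes the cleartext into `iovecs`. Returns the number of bytes read; this is at least `len` unless end-of-stream was reached. The entries of `iovecs` are advanced in place as they fill. |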
/// Repeatedly calls `readvAdvanced` until at least `len` cleartext bytes have
/// been placed into `iovecs` or the connection reaches end-of-stream.
///
/// As data arrives, the entries of `iovecs` are advanced in place (base moved
/// forward, length reduced), so their contents are modified by this call.
/// NOTE(review): `len` must not exceed the total capacity of `iovecs`,
/// otherwise the indexing below runs past the end of the slice.
pub fn readvAtLeast(c: *Client, stream: anytype, iovecs: []std.os.iovec, len: usize) !usize {
    if (c.eof()) return 0;

    var total: usize = 0;
    var first: usize = 0;
    while (true) {
        const got = try c.readvAdvanced(stream, iovecs[first..]);
        total += got;
        if (total >= len or c.eof()) return total;

        // Skip every iovec that is now completely filled...
        var left = got;
        while (left >= iovecs[first].iov_len) : (first += 1) {
            left -= iovecs[first].iov_len;
        }
        // ...and trim the partially filled one.
        iovecs[first].iov_base += left;
        iovecs[first].iov_len -= left;
    }
}
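readvAdvanced() — Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`, and places the decrypted cleartext into `iovecs`. Returns the number of cleartext bytes placed. This may be less than the capacity of `iovecs`, and can be 0 when only non-application records such as handshake messages were processed. |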
pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.os.iovec) !usize { var vp: VecPut = .{ .iovecs = iovecs }; // Give away the buffered cleartext we have, if any. const partial_cleartext = c.partially_read_buffer[c.partial_cleartext_idx..c.partial_ciphertext_idx]; if (partial_cleartext.len > 0) { const amt: u15 = @intCast(vp.put(partial_cleartext)); c.partial_cleartext_idx += amt; if (c.partial_cleartext_idx == c.partial_ciphertext_idx and c.partial_ciphertext_end == c.partial_ciphertext_idx) { // The buffer is now empty. c.partial_cleartext_idx = 0; c.partial_ciphertext_idx = 0; c.partial_ciphertext_end = 0; } if (c.received_close_notify) { c.partial_ciphertext_end = 0; assert(vp.total == amt); return amt; } else if (amt > 0) { // We don't need more data, so don't call read. assert(vp.total == amt); return amt; } } assert(!c.received_close_notify); // Ideally, this buffer would never be used. It is needed when `iovecs` are // too small to fit the cleartext, which may be as large as `max_ciphertext_len`. var cleartext_stack_buffer: [max_ciphertext_len]u8 = undefined; // Temporarily stores ciphertext before decrypting it and giving it to `iovecs`. var in_stack_buffer: [max_ciphertext_len * 4]u8 = undefined; // How many bytes left in the user's buffer. const free_size = vp.freeSize(); // The amount of the user's buffer that we need to repurpose for storing // ciphertext. The end of the buffer will be used for such purposes. const ciphertext_buf_len = (free_size / 2) -| in_stack_buffer.len; // The amount of the user's buffer that will be used to give cleartext. The // beginning of the buffer will be used for such purposes. const cleartext_buf_len = free_size - ciphertext_buf_len; // Recoup `partially_read_buffer space`. This is necessary because it is assumed // below that `frag0` is big enough to hold at least one record. 
limitedOverlapCopy(c.partially_read_buffer[0..c.partial_ciphertext_end], c.partial_ciphertext_idx); c.partial_ciphertext_end -= c.partial_ciphertext_idx; c.partial_ciphertext_idx = 0; c.partial_cleartext_idx = 0; const first_iov = c.partially_read_buffer[c.partial_ciphertext_end..]; var ask_iovecs_buf: [2]std.os.iovec = .{ .{ .iov_base = first_iov.ptr, .iov_len = first_iov.len, }, .{ .iov_base = &in_stack_buffer, .iov_len = in_stack_buffer.len, }, }; // Cleartext capacity of output buffer, in records. Minimum one full record. const buf_cap = @max(cleartext_buf_len / max_ciphertext_len, 1); const wanted_read_len = buf_cap * (max_ciphertext_len + tls.record_header_len); const ask_len = @max(wanted_read_len, cleartext_stack_buffer.len); const ask_iovecs = limitVecs(&ask_iovecs_buf, ask_len); const actual_read_len = try stream.readv(ask_iovecs); if (actual_read_len == 0) { // This is either a truncation attack, a bug in the server, or an // intentional omission of the close_notify message due to truncation // detection handled above the TLS layer. if (c.allow_truncation_attacks) { c.received_close_notify = true; } else { return error.TlsConnectionTruncated; } } // There might be more bytes inside `in_stack_buffer` that need to be processed, // but at least frag0 will have one complete ciphertext record. const frag0_end = @min(c.partially_read_buffer.len, c.partial_ciphertext_end + actual_read_len); const frag0 = c.partially_read_buffer[c.partial_ciphertext_idx..frag0_end]; var frag1 = in_stack_buffer[0..actual_read_len -| first_iov.len]; // We need to decipher frag0 and frag1 but there may be a ciphertext record // straddling the boundary. We can handle this with two memcpy() calls to // assemble the straddling record in between handling the two sides. var frag = frag0; var in: usize = 0; while (true) { if (in == frag.len) { // Perfect split. 
if (frag.ptr == frag1.ptr) { c.partial_ciphertext_end = c.partial_ciphertext_idx; return vp.total; } frag = frag1; in = 0; continue; } if (in + tls.record_header_len > frag.len) { if (frag.ptr == frag1.ptr) return finishRead(c, frag, in, vp.total); const first = frag[in..]; if (frag1.len < tls.record_header_len) return finishRead2(c, first, frag1, vp.total); // A record straddles the two fragments. Copy into the now-empty first fragment. const record_len_byte_0: u16 = straddleByte(frag, frag1, in + 3); const record_len_byte_1: u16 = straddleByte(frag, frag1, in + 4); const record_len = (record_len_byte_0 << 8) | record_len_byte_1; if (record_len > max_ciphertext_len) return error.TlsRecordOverflow; const full_record_len = record_len + tls.record_header_len; const second_len = full_record_len - first.len; if (frag1.len < second_len) return finishRead2(c, first, frag1, vp.total); limitedOverlapCopy(frag, in); @memcpy(frag[first.len..][0..second_len], frag1[0..second_len]); frag = frag[0..full_record_len]; frag1 = frag1[second_len..]; in = 0; continue; } const ct: tls.ContentType = @enumFromInt(frag[in]); in += 1; const legacy_version = mem.readInt(u16, frag[in..][0..2], .big); in += 2; _ = legacy_version; const record_len = mem.readInt(u16, frag[in..][0..2], .big); if (record_len > max_ciphertext_len) return error.TlsRecordOverflow; in += 2; const end = in + record_len; if (end > frag.len) { // We need the record header on the next iteration of the loop. in -= tls.record_header_len; if (frag.ptr == frag1.ptr) return finishRead(c, frag, in, vp.total); // A record straddles the two fragments. Copy into the now-empty first fragment. 
const first = frag[in..]; const full_record_len = record_len + tls.record_header_len; const second_len = full_record_len - first.len; if (frag1.len < second_len) return finishRead2(c, first, frag1, vp.total); limitedOverlapCopy(frag, in); @memcpy(frag[first.len..][0..second_len], frag1[0..second_len]); frag = frag[0..full_record_len]; frag1 = frag1[second_len..]; in = 0; continue; } switch (ct) { .alert => { if (in + 2 > frag.len) return error.TlsDecodeError; const level: tls.AlertLevel = @enumFromInt(frag[in]); const desc: tls.AlertDescription = @enumFromInt(frag[in + 1]); _ = level; try desc.toError(); // TODO: handle server-side closures return error.TlsUnexpectedMessage; }, .application_data => { const cleartext = switch (c.application_cipher) { inline else => |*p| c: { const P = @TypeOf(p.*); const V = @Vector(P.AEAD.nonce_length, u8); const ad = frag[in - 5 ..][0..5]; const ciphertext_len = record_len - P.AEAD.tag_length; const ciphertext = frag[in..][0..ciphertext_len]; in += ciphertext_len; const auth_tag = frag[in..][0..P.AEAD.tag_length].*; const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); const operand: V = pad ++ @as([8]u8, @bitCast(big(c.read_seq))); const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.server_iv) ^ operand; const out_buf = vp.peek(); const cleartext_buf = if (ciphertext.len <= out_buf.len) out_buf else &cleartext_stack_buffer; const cleartext = cleartext_buf[0..ciphertext.len]; P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_key) catch return error.TlsBadRecordMac; break :c cleartext; }, }; c.read_seq = try std.math.add(u64, c.read_seq, 1); const inner_ct: tls.ContentType = @enumFromInt(cleartext[cleartext.len - 1]); switch (inner_ct) { .alert => { const level: tls.AlertLevel = @enumFromInt(cleartext[0]); const desc: tls.AlertDescription = @enumFromInt(cleartext[1]); if (desc == .close_notify) { c.received_close_notify = true; c.partial_ciphertext_end = c.partial_ciphertext_idx; return vp.total; } _ = level; try 
desc.toError(); // TODO: handle server-side closures return error.TlsUnexpectedMessage; }, .handshake => { var ct_i: usize = 0; while (true) { const handshake_type: tls.HandshakeType = @enumFromInt(cleartext[ct_i]); ct_i += 1; const handshake_len = mem.readInt(u24, cleartext[ct_i..][0..3], .big); ct_i += 3; const next_handshake_i = ct_i + handshake_len; if (next_handshake_i > cleartext.len - 1) return error.TlsBadLength; const handshake = cleartext[ct_i..next_handshake_i]; switch (handshake_type) { .new_session_ticket => { // This client implementation ignores new session tickets. }, .key_update => { switch (c.application_cipher) { inline else => |*p| { const P = @TypeOf(p.*); const server_secret = hkdfExpandLabel(P.Hkdf, p.server_secret, "traffic upd", "", P.Hash.digest_length); p.server_secret = server_secret; p.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); p.server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); }, } c.read_seq = 0; switch (@as(tls.KeyUpdateRequest, @enumFromInt(handshake[0]))) { .update_requested => { switch (c.application_cipher) { inline else => |*p| { const P = @TypeOf(p.*); const client_secret = hkdfExpandLabel(P.Hkdf, p.client_secret, "traffic upd", "", P.Hash.digest_length); p.client_secret = client_secret; p.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); p.client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); }, } c.write_seq = 0; }, .update_not_requested => {}, _ => return error.TlsIllegalParameter, } }, else => { return error.TlsUnexpectedMessage; }, } ct_i = next_handshake_i; if (ct_i >= cleartext.len - 1) break; } }, .application_data => { // Determine whether the output buffer or a stack // buffer was used for storing the cleartext. if (cleartext.ptr == &cleartext_stack_buffer) { // Stack buffer was used, so we must copy to the output buffer. const msg = cleartext[0 .. 
cleartext.len - 1];
                                // `msg` omits the final cleartext byte, which was decoded
                                // above as the inner content type.
                                if (c.partial_ciphertext_idx > c.partial_cleartext_idx) {
                                    // We have already run out of room in iovecs. Continue
                                    // appending to `partially_read_buffer`.
                                    @memcpy(
                                        c.partially_read_buffer[c.partial_ciphertext_idx..][0..msg.len],
                                        msg,
                                    );
                                    c.partial_ciphertext_idx = @intCast(c.partial_ciphertext_idx + msg.len);
                                } else {
                                    // Copy as much as fits into the caller's iovecs; any
                                    // remainder is stashed at the front of
                                    // `partially_read_buffer` (indices reset so that
                                    // partial_cleartext_idx..partial_ciphertext_idx covers it).
                                    const amt = vp.put(msg);
                                    if (amt < msg.len) {
                                        const rest = msg[amt..];
                                        c.partial_cleartext_idx = 0;
                                        c.partial_ciphertext_idx = @intCast(rest.len);
                                        @memcpy(c.partially_read_buffer[0..rest.len], rest);
                                    }
                                }
                            } else {
                                // Output buffer was used directly which means no
                                // memory copying needs to occur, and we can move
                                // on to the next ciphertext record.
                                vp.next(cleartext.len - 1);
                            }
                        },
                        else => {
                            return error.TlsUnexpectedMessage;
                        },
                    }
                },
                else => {
                    return error.TlsUnexpectedMessage;
                },
            }
            // Advance past the record just handled (`end` is the record's
            // payload start plus its declared length).
            in = end;
        }
    }

/// Saves the not-yet-processed ciphertext `frag[in..]` into
/// `c.partially_read_buffer` so a later read can resume from it, then
/// returns `out` unchanged (callers pass the running cleartext total,
/// `vp.total`).
///
/// If buffered cleartext is pending (`partial_ciphertext_idx >
/// partial_cleartext_idx`), the ciphertext is appended right after it;
/// otherwise all partial indices are reset and the ciphertext is stored at
/// offset 0.
///
/// At the visible call sites `frag` is always `frag1` (guarded by
/// `frag.ptr == frag1.ptr`), i.e. the stack buffer, so `saved_buf` does not
/// alias `partially_read_buffer` as `@memcpy` requires.
fn finishRead(c: *Client, frag: []const u8, in: usize, out: usize) usize {
    const saved_buf = frag[in..];
    if (c.partial_ciphertext_idx > c.partial_cleartext_idx) {
        // There is cleartext at the beginning already which we need to preserve.
        c.partial_ciphertext_end = @intCast(c.partial_ciphertext_idx + saved_buf.len);
        @memcpy(c.partially_read_buffer[c.partial_ciphertext_idx..][0..saved_buf.len], saved_buf);
    } else {
        c.partial_cleartext_idx = 0;
        c.partial_ciphertext_idx = 0;
        c.partial_ciphertext_end = @intCast(saved_buf.len);
        @memcpy(c.partially_read_buffer[0..saved_buf.len], saved_buf);
    }
    return out;
}

/// Like `finishRead`, but the leftover ciphertext comes in two pieces:
/// `first` followed by `frag1`. Both are saved contiguously into
/// `c.partially_read_buffer`, using the same "append after pending
/// cleartext, else reset to offset 0" rule. Returns `out` unchanged.
///
/// At the visible call sites `first` is sliced from a fragment of
/// `partially_read_buffer` itself, so it is moved with `copyForwards`
/// rather than `@memcpy`; `frag1` lives in a separate buffer and is copied
/// with `@memcpy`.
/// NOTE(review): `copyForwards` is only correct when the destination does
/// not start after `first`; the caller is presumed to guarantee this — confirm.
fn finishRead2(c: *Client, first: []const u8, frag1: []const u8, out: usize) usize {
    if (c.partial_ciphertext_idx > c.partial_cleartext_idx) {
        // There is cleartext at the beginning already which we need to preserve.
        c.partial_ciphertext_end = @intCast(c.partial_ciphertext_idx + first.len + frag1.len);
        // TODO: eliminate this call to copyForwards
        std.mem.copyForwards(u8, c.partially_read_buffer[c.partial_ciphertext_idx..][0..first.len], first);
        @memcpy(c.partially_read_buffer[c.partial_ciphertext_idx + first.len ..][0..frag1.len], frag1);
    } else {
        c.partial_cleartext_idx = 0;
        c.partial_ciphertext_idx = 0;
        c.partial_ciphertext_end = @intCast(first.len + frag1.len);
        // TODO: eliminate this call to copyForwards
        std.mem.copyForwards(u8, c.partially_read_buffer[0..first.len], first);
        @memcpy(c.partially_read_buffer[first.len..][0..frag1.len], frag1);
    }
    return out;
}

/// Moves `frag[in..]` to the start of `frag`, in place.
/// The source `[in, frag.len)` and destination `[0, frag.len - in)` are
/// disjoint exactly when `frag.len - in <= in`; only then is `@memcpy`
/// (which requires non-overlapping operands) used.
fn limitedOverlapCopy(frag: []u8, in: usize) void {
    const first = frag[in..];
    if (first.len <= in) {
        // A single, non-overlapping memcpy suffices.
        @memcpy(frag[0..first.len], first);
    } else {
        // One memcpy call would overlap, so just do this instead.
        std.mem.copyForwards(u8, frag, first);
    }
}

/// Returns byte `index` of the logical concatenation `s1 ++ s2` without
/// materializing it. Used to read a record-header length byte that may
/// straddle two fragments. An `index >= s1.len + s2.len` is an
/// out-of-bounds access.
fn straddleByte(s1: []const u8, s2: []const u8, index: usize) u8 {
    if (index < s1.len) {
        return s1[index];
    } else {
        return s2[index - s1.len];
    }
}

const builtin = @import("builtin");
/// Byte order of the compilation target; selects the behaviour of `big`.
const native_endian = builtin.cpu.arch.endian();

/// Converts integer `x` between native and big-endian byte order (the
/// operation is its own inverse). A no-op on big-endian targets.
inline fn big(x: anytype) @TypeOf(x) {
    return switch (native_endian) {
        .big => x,
        .little => @byteSwap(x),
    };
}

/// Maps a TLS ECDSA signature scheme to the matching
/// `std.crypto.sign.ecdsa` type at comptime; any other scheme is a compile
/// error.
/// NOTE(review): confirm `crypto.sign.ecdsa.EcdsaP512Sha512` exists in this
/// std version — secp521r1 is a 521-bit curve and the name looks suspect.
/// A missing decl would presumably only surface if this prong is selected.
fn SchemeEcdsa(comptime scheme: tls.SignatureScheme) type {
    return switch (scheme) {
        .ecdsa_secp256r1_sha256 => crypto.sign.ecdsa.EcdsaP256Sha256,
        .ecdsa_secp384r1_sha384 => crypto.sign.ecdsa.EcdsaP384Sha384,
        .ecdsa_secp521r1_sha512 => crypto.sign.ecdsa.EcdsaP512Sha512,
        else => @compileError("bad scheme"),
    };
}

/// Maps a TLS RSA-PSS (rsae) signature scheme to its SHA-2 hash type at
/// comptime; any other scheme is a compile error.
fn SchemeHash(comptime scheme: tls.SignatureScheme) type {
    return switch (scheme) {
        .rsa_pss_rsae_sha256 => crypto.hash.sha2.Sha256,
        .rsa_pss_rsae_sha384 => crypto.hash.sha2.Sha384,
        .rsa_pss_rsae_sha512 => crypto.hash.sha2.Sha512,
        else => @compileError("bad scheme"),
    };
}

/// A write cursor over a caller-supplied list of iovecs.
const VecPut = struct {
    /// Destination buffers, filled in order.
    iovecs: []const std.os.iovec,
    /// Index of the iovec currently being filled; `>= iovecs.len` means full.
    idx: usize = 0,
    /// Byte offset within `iovecs[idx]` where the next write goes.
    off: usize = 0,
    /// Running count of bytes written via `put` and `next`.
    total: usize = 0,

    /// Copies as much of `bytes` as fits into the remaining iovecs,
    /// advancing the cursor and `total`. Returns the number of bytes
    /// copied, which is less than `bytes.len` once the iovecs run out
    /// (0 if they were already exhausted on entry).
    fn put(vp: *VecPut, bytes: []const u8) usize {
        if (vp.idx >= vp.iovecs.len) return 0;
        var bytes_i: usize = 0;
        while (true) {
            const v = vp.iovecs[vp.idx];
            const dest = v.iov_base[vp.off..v.iov_len];
            const src = bytes[bytes_i..][0..@min(dest.len, bytes.len - bytes_i)];
            @memcpy(dest[0..src.len], src);
            bytes_i += src.len;
            vp.off += src.len;
            // Current iovec is full: move to the next one, stopping if none remain.
            if (vp.off >= v.iov_len) {
                vp.off = 0;
                vp.idx += 1;
                if (vp.idx >= vp.iovecs.len) {
                    vp.total += bytes_i;
                    return bytes_i;
                }
            }
            if (bytes_i >= bytes.len) {
                vp.total += bytes_i;
                return bytes_i;
            }
        }
    }

    /// Returns the unwritten tail of the current iovec, or an empty slice
    /// when all iovecs are full. Pair with `next` after writing into it.
    fn peek(vp: VecPut) []u8 {
        if (vp.idx >= vp.iovecs.len) return &.{};
        const v = vp.iovecs[vp.idx];
        return v.iov_base[vp.off..v.iov_len];
    }

    /// After writing to the result of peek(), one can call next() to
    /// advance the cursor by `len` bytes (also added to `total`).
    /// NOTE(review): indexes `iovecs[idx]` unchecked, so the cursor must not
    /// already be exhausted, and `len` is assumed to fit in the current iovec.
    fn next(vp: *VecPut, len: usize) void {
        vp.total += len;
        vp.off += len;
        if (vp.off >= vp.iovecs[vp.idx].iov_len) {
            vp.off = 0;
            vp.idx += 1;
        }
    }

    /// Returns the total remaining capacity: the rest of the current iovec
    /// plus every later iovec; 0 when exhausted.
    fn freeSize(vp: VecPut) usize {
        if (vp.idx >= vp.iovecs.len) return 0;
        var total: usize = 0;
        total += vp.iovecs[vp.idx].iov_len - vp.off;
        if (vp.idx + 1 >= vp.iovecs.len) return total;
        for (vp.iovecs[vp.idx + 1 ..]) |v| total += v.iov_len;
        return total;
    }
};

/// Shrinks `iovecs` so their combined length is at most `len`, returning the
/// prefix that is needed; the last returned iovec's `iov_len` is reduced in
/// place. If the iovecs hold `len` bytes or fewer in total, all of them are
/// returned with their lengths unchanged. With `len == 0` and a non-empty
/// `iovecs`, a single zero-length iovec is returned.
fn limitVecs(iovecs: []std.os.iovec, len: usize) []std.os.iovec {
    var bytes_left: usize = len;
    for (iovecs, 0..) |*iovec, vec_i| {
        if (bytes_left <= iovec.iov_len) {
            iovec.iov_len = bytes_left;
            return iovecs[0 .. vec_i + 1];
        }
        bytes_left -= iovec.iov_len;
    }
    return iovecs;
}

/// Cipher suites offered, in preference order. The AEGIS and AES-GCM suites
/// lead when the target has hardware AES support; otherwise
/// CHACHA20_POLY1305_SHA256 is listed first (presumably because software AES
/// is slower — TODO confirm rationale).
const cipher_suites = if (crypto.core.aes.has_hardware_support)
    enum_array(tls.CipherSuite, &.{
        .AEGIS_128L_SHA256,
        .AEGIS_256_SHA384,
        .AES_128_GCM_SHA256,
        .AES_256_GCM_SHA384,
        .CHACHA20_POLY1305_SHA256,
    })
else
    enum_array(tls.CipherSuite, &.{
        .CHACHA20_POLY1305_SHA256,
        .AEGIS_128L_SHA256,
        .AEGIS_256_SHA384,
        .AES_128_GCM_SHA256,
        .AES_256_GCM_SHA384,
    });

// Referencing `StreamInterface` forces it to be semantically analyzed when
// this file's tests are compiled.
test {
    _ = StreamInterface;
} |
Generated by zstd-browse2 on 2023-11-04 14:12:24 -0400. |