From 3b8c431de673529bb3300b871894678f18ab189a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Sun, 18 Aug 2024 22:46:04 +0200 Subject: [PATCH] Add :skip_utf8_validation flag to encode --- lib/tokenizers/tokenizer.ex | 4 ++ native/ex_tokenizers/src/tokenizer.rs | 99 ++++++++++++++++++--------- test/tokenizers/tokenizer_test.exs | 7 ++ 3 files changed, 77 insertions(+), 33 deletions(-) diff --git a/lib/tokenizers/tokenizer.ex b/lib/tokenizers/tokenizer.ex index e2a045c..3c403ba 100644 --- a/lib/tokenizers/tokenizer.ex +++ b/lib/tokenizers/tokenizer.ex @@ -452,6 +452,10 @@ defmodule Tokenizers.Tokenizer do * `:add_special_tokens` - whether to add special tokens to the sequence. Defaults to `true` + * `:skip_utf8_validation` - whether to skip utf8 validation. + Defaults to `false`. Skipping validation and passing invalid strings + may lead to errors (including segmentation fault) + * `:encoding_transformations` - a list of `t:Tokenizers.Encoding.Transformation.t/0` to apply to the encoding. Check `Tokenizers.Encoding.transform/2` for more information. 
Defaults to `[]` diff --git a/native/ex_tokenizers/src/tokenizer.rs b/native/ex_tokenizers/src/tokenizer.rs index 807dd05..7e3802e 100644 --- a/native/ex_tokenizers/src/tokenizer.rs +++ b/native/ex_tokenizers/src/tokenizer.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use std::ops::Deref; use std::panic; -use rustler::{NifTaggedEnum, Term}; +use rustler::{Binary, NifTaggedEnum, Term}; use tokenizers::models::wordpiece::WordPieceTrainerBuilder; use tokenizers::models::TrainerWrapper; @@ -425,36 +425,79 @@ fn term_to_encode_input<'a, 'b>(term: &'a Term<'b>) -> Result<EncodeInput<'b>, E } } +fn unsafe_term_to_encode_input<'a, 'b>( + term: &'a Term<'b>, +) -> Result<EncodeInput<'b>, ExTokenizersError> { + if let Ok(bin) = term.decode::<Binary>() { + let slice: &'b [u8] = bin.as_slice(); + let string = unsafe { std::str::from_utf8_unchecked(slice) }; + Ok(EncodeInput::Single(string.into())) + } else if let Ok((bin1, bin2)) = term.decode::<(Binary, Binary)>() { + let slice1: &'b [u8] = bin1.as_slice(); + let string1 = unsafe { std::str::from_utf8_unchecked(slice1) }; + let slice2: &'b [u8] = bin2.as_slice(); + let string2 = unsafe { std::str::from_utf8_unchecked(slice2) }; + Ok(EncodeInput::Dual(string1.into(), string2.into())) + } else { + Err(ExTokenizersError::Other(String::from( + "input must be either a string (valid UTF-8 encoded) or a tuple", + ))) + } +} + #[derive(NifTaggedEnum)] pub enum EncodeOption { + SkipUTF8Validation(bool), AddSpecialTokens(bool), EncodingTransformations(Vec<TransformationElement>), } +struct EncodeConfig { + skip_utf8_validation: bool, + add_special_tokens: bool, + encoding_transformations: Vec<TransformationElement>, +} + +impl From<Vec<EncodeOption>> for EncodeConfig { + fn from(options: Vec<EncodeOption>) -> Self { + let mut config = EncodeConfig { + skip_utf8_validation: false, + add_special_tokens: true, + encoding_transformations: Vec::new(), + }; + + for option in options { + match option { + EncodeOption::SkipUTF8Validation(val) => { + config.skip_utf8_validation = val; + } + EncodeOption::AddSpecialTokens(val) => { + 
config.add_special_tokens = val; + } + EncodeOption::EncodingTransformations(transformations) => { + config.encoding_transformations = transformations; + } + } + } + + config + } +} + #[rustler::nif(schedule = "DirtyCpu")] pub fn tokenizer_encode( tokenizer: ExTokenizersTokenizer, input: Term, options: Vec<EncodeOption>, ) -> Result<ExTokenizersEncoding, ExTokenizersError> { - struct Opts { - add_special_tokens: bool, - encoding_transformations: Vec<TransformationElement>, - } - let mut opts = Opts { - add_special_tokens: true, - encoding_transformations: Vec::new(), + let opts: EncodeConfig = options.into(); + + let input = if opts.skip_utf8_validation { + unsafe_term_to_encode_input(&input)? + } else { + term_to_encode_input(&input)? }; - options.into_iter().for_each(|option| match option { - EncodeOption::AddSpecialTokens(add_special_tokens) => { - opts.add_special_tokens = add_special_tokens - } - EncodeOption::EncodingTransformations(encoding_transformations) => { - opts.encoding_transformations = encoding_transformations - } - }); - let input = term_to_encode_input(&input)?; let mut encoding = tokenizer .resource .0 @@ -470,25 +513,15 @@ pub fn tokenizer_encode_batch( options: Vec<EncodeOption>, // add_special_tokens: bool, ) -> Result<Vec<ExTokenizersEncoding>, ExTokenizersError> { - struct Opts { - add_special_tokens: bool, - encoding_transformations: Vec<TransformationElement>, - } - let mut opts = Opts { - add_special_tokens: true, - encoding_transformations: Vec::new(), + let opts: EncodeConfig = options.into(); + let callback = if opts.skip_utf8_validation { + unsafe_term_to_encode_input + } else { + term_to_encode_input }; - options.into_iter().for_each(|option| match option { - EncodeOption::AddSpecialTokens(add_special_tokens) => { - opts.add_special_tokens = add_special_tokens - } - EncodeOption::EncodingTransformations(encoding_transformations) => { - opts.encoding_transformations = encoding_transformations - } - }); let inputs = inputs .iter() - .map(term_to_encode_input) + .map(callback) .collect::<Result<Vec<EncodeInput>, ExTokenizersError>>()?; let mut encodings = tokenizer .resource diff --git 
a/test/tokenizers/tokenizer_test.exs b/test/tokenizers/tokenizer_test.exs index 387937b..1db14e9 100644 --- a/test/tokenizers/tokenizer_test.exs +++ b/test/tokenizers/tokenizer_test.exs @@ -142,6 +142,13 @@ defmodule Tokenizers.TokenizerTest do describe "encode/decode" do test "can encode a single string", %{tokenizer: tokenizer} do assert {:ok, %Tokenizers.Encoding{}} = Tokenizer.encode(tokenizer, "This is a test") + + assert {:ok, %Tokenizers.Encoding{}} = + Tokenizer.encode(tokenizer, "This is a test", skip_utf8_validation: true) + end + + test "errors when encoding a binary", %{tokenizer: tokenizer} do + assert {:error, _} = Tokenizer.encode(tokenizer, <<0xFF>>) end test "can apply transformations to encoding", %{tokenizer: tokenizer} do