Fix bug (#538)
* fix some bug

* reformat
xiezipeng-ML authored Mar 29, 2024
1 parent 964a349 commit c9bdff1
Showing 6 changed files with 164 additions and 126 deletions.
2 changes: 1 addition & 1 deletion libai/inference/generator/generation_utils.py
@@ -189,7 +189,7 @@ def _get_decoder_start_token_id(
         elif bos_token_id is not None:
             return bos_token_id
         else:
-            return self.cfg.bos_token_idx
+            return self.cfg.bos_token_id

     @staticmethod
     def _expand_inputs_for_generation(
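
Note: the fix above corrects an attribute typo (`bos_token_idx` → `bos_token_id`) that raised AttributeError whenever generation fell through to the config default. A minimal sketch of the fallback chain, assuming a signature like the one below; only the elif/else tail is confirmed by the hunk:

# Sketch of the decoder-start-token fallback; the signature and first
# branch are assumptions, only the last two branches appear in the diff.
def _get_decoder_start_token_id(self, decoder_start_token_id=None, bos_token_id=None):
    if decoder_start_token_id is not None:
        return decoder_start_token_id
    elif bos_token_id is not None:
        return bos_token_id
    else:
        # Fall back to the model config; `bos_token_idx` was a typo here.
        return self.cfg.bos_token_id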
38 changes: 33 additions & 5 deletions libai/models/utils/model_loader/base_loader.py
@@ -20,6 +20,7 @@

 import omegaconf
 import oneflow as flow
+from safetensors import safe_open
 from termcolor import colored

 import libai.utils.distributed as dist
@@ -457,17 +458,35 @@ def _load_config_from_json(self, config_file):

         raise NotImplementedError("_load_config_from_json not implemented")

-    def _load_torch_state_dict(self, state_dict_file):
+    def _load_torch_state_dict(self, state_dict_file, use_safetensors=False):
         try:
             import torch
         except ImportError:
             raise ImportError("Load torch state dict need torch.")

+        if use_safetensors:
+            if isinstance(state_dict_file, str):
+                state_dict = {}
+                with safe_open(state_dict_file, framework="pt", device=0) as f:
+                    for k in f.keys():
+                        state_dict[k] = f.get_tensor(k)
+                return state_dict
+
+            elif isinstance(state_dict_file, list):
+                merged_state_dict = {}
+                for file in state_dict_file:
+                    state_dict = {}
+                    with safe_open(file, framework="pt") as f:
+                        for k in f.keys():
+                            state_dict[k] = f.get_tensor(k).to(torch.float)
+                    merged_state_dict.update(state_dict)
+                return merged_state_dict
+
         # load pytorch_model.bin
         if isinstance(state_dict_file, str):
             return torch.load(state_dict_file, map_location="cpu")

-        if isinstance(state_dict_file, list):
+        elif isinstance(state_dict_file, list):
             merged_state_dict = {}
             for file in state_dict_file:
                 state_dict = torch.load(file, map_location="cpu")
@@ -532,6 +551,7 @@ def load(self):
             >>> bert = loader.load()
         """
+        use_safetensors = False
         if dist.is_main_process():
             if os.path.isdir(self.pretrained_model_path):
                 # state_dict file pytorch
@@ -541,10 +561,18 @@
                     if file.endswith(".bin")
                 ]

+                if len(model_files) == 0:
+                    use_safetensors = True
+                    model_files = [
+                        os.path.join(self.pretrained_model_path, file)
+                        for file in os.listdir(self.pretrained_model_path)
+                        if file.endswith(".safetensors")
+                    ]
+
                 if len(model_files) == 0:
                     raise EnvironmentError(
-                        f"Error: no file named endswith '.bin' found"
-                        f"in directory {self.pretrained_model_path}."
+                        f"Error: no file named endswith '.bin' or '.safetensors' "
+                        f"found in directory {self.pretrained_model_path}."
                     )

                 # config file
@@ -565,7 +593,7 @@
             raise EnvironmentError(f"{self.pretrained_model_path} is not a directory.")

         logger.info("loading torch model...")
-        torch_state_dict = self._load_torch_state_dict(model_files)
+        torch_state_dict = self._load_torch_state_dict(model_files, use_safetensors)
         torch_state_dict = self._fix_key(torch_state_dict)
         logger.info("transfering torch model into oneflow model...")
         flow_state_dict = self._convert_tensors(torch_state_dict)
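
Note: the safetensors branch added above follows the usual safe_open pattern. A minimal standalone sketch, assuming torch and safetensors are installed; the helper name and file paths are illustrative, not part of libai:

import torch
from safetensors import safe_open


def load_safetensors_state_dict(files):
    """Merge one or more .safetensors shards into a single CPU state dict."""
    if isinstance(files, str):
        files = [files]
    merged = {}
    for path in files:
        with safe_open(path, framework="pt", device="cpu") as f:
            for key in f.keys():
                # Cast to float32, mirroring the multi-shard branch in the diff.
                merged[key] = f.get_tensor(key).to(torch.float)
    return merged


# Illustrative usage:
# state_dict = load_safetensors_state_dict(
#     ["model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"]
# )

One design note: the diff loads the single-file case onto device 0 without a cast, while the sharded case loads to CPU and casts to float; the sketch uses the CPU path for both.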
10 changes: 6 additions & 4 deletions libai/tokenizer/tokenization_base.py
@@ -807,17 +807,19 @@ def encode(self, text, return_tensors=None, is_global=False, **kwargs):
         if isinstance(text, str):
             tokens = self.tokenize(text)
             token_ids = self.convert_tokens_to_ids(tokens)
-            token_ids = self.build_inputs_with_special_tokens(token_ids)
+            if hasattr(self, "build_inputs_with_special_tokens"):
+                token_ids = self.build_inputs_with_special_tokens(token_ids)
             token_ids = self.convert_to_tensors(
                 token_ids, return_tensors=return_tensors, is_global=is_global, **kwargs
             )
             return token_ids
         elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
             tokens = [self.tokenize(t) for t in text]
             token_ids_list = self.convert_tokens_to_ids(tokens)
-            token_ids_list = [
-                self.build_inputs_with_special_tokens(token_ids) for token_ids in token_ids_list
-            ]
+            if hasattr(self, "build_inputs_with_special_tokens"):
+                token_ids_list = [
+                    self.build_inputs_with_special_tokens(token_ids) for token_ids in token_ids_list
+                ]
             token_ids_list = self.convert_to_tensors(
                 token_ids_list, return_tensors=return_tensors, is_global=is_global, **kwargs
            )
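
Note: the hasattr guard lets encode degrade gracefully for tokenizers that never define build_inputs_with_special_tokens, instead of raising AttributeError. A standalone sketch of the pattern; DummyTokenizer and encode_ids are illustrative, not from libai:

# Illustrative: the guard skips special-token insertion when the hook is absent.
class DummyTokenizer:
    def convert_tokens_to_ids(self, tokens):
        return [hash(t) % 30000 for t in tokens]  # toy vocabulary lookup


def encode_ids(tokenizer, tokens):
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    # Only add special tokens when the tokenizer actually implements the hook.
    if hasattr(tokenizer, "build_inputs_with_special_tokens"):
        token_ids = tokenizer.build_inputs_with_special_tokens(token_ids)
    return token_ids


print(encode_ids(DummyTokenizer(), ["hello", "world"]))  # works without the hook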
238 changes: 122 additions & 116 deletions projects/mock_transformers/mock_tokenization.py
@@ -16,125 +16,131 @@
 import os

 import oneflow as flow
+import oneflow.mock_torch as mock

 from libai.utils import distributed as dist

-flow.mock_torch.enable()
-
-from transformers import BertTokenizer, GPT2Tokenizer, MT5Tokenizer, T5Tokenizer  # noqa
-from transformers.tokenization_utils_base import *  # noqa
-from transformers.utils import generic  # noqa
-from transformers.utils.generic import TensorType  # noqa
-
-
-# ---------------- mock TensorType ------------------
-class TensorType(ExplicitEnum):  # noqa
-    PYTORCH = "pt"
-    TENSORFLOW = "tf"
-    ONEFLOW = "of"
-    NUMPY = "np"
-    JAX = "jax"
-
-
-generic.TensorType = TensorType
-
-
-# ---------------- mock convert_to_tensors ------------------
-def flow_convert_to_tensors(self, tensor_type=None, prepend_batch_axis=False):
-    if tensor_type is None:
-        return self
-
-    # Convert to TensorType
-    if not isinstance(tensor_type, TensorType):
-        tensor_type = TensorType(tensor_type)
-    as_tensor = None
-    is_tensor = None
-    # Get a function reference for the correct framework
-    if tensor_type == TensorType.TENSORFLOW:
-        if not is_tf_available():  # noqa
-            raise ImportError(
-                "Unable to convert output to TensorFlow tensors format, TensorFlow is not "
-                "installed."
-            )
-        import tensorflow as tf
-
-        as_tensor = tf.constant
-        is_tensor = tf.is_tensor
-    elif tensor_type == TensorType.PYTORCH:
-        if not is_torch_available():  # noqa
-            raise ImportError(
-                "Unable to convert output to PyTorch tensors format, PyTorch is not installed."
-            )
-        import torch
-
-        as_tensor = torch.tensor
-        is_tensor = torch.is_tensor
-    elif tensor_type == TensorType.ONEFLOW:
-        try:
-            import oneflow  # noqa
-        except ImportError as e:
-            msg = "Unable to convert output to OneFlow tensors format, OneFlow is not installed."
-            raise ImportError(msg) from e
-        as_tensor = flow.tensor
-        is_tensor = flow.is_tensor
-    elif tensor_type == TensorType.JAX:
-        if not is_flax_available():  # noqa
-            raise ImportError(
-                "Unable to convert output to JAX tensors format, JAX is not installed."
-            )
-        import jax.numpy as jnp  # noqa: F811
-
-        as_tensor = jnp.array
-        is_tensor = is_jax_tensor  # noqa
-    else:
-        as_tensor = np.asarray  # noqa
-        is_tensor = is_numpy_array  # noqa
-
-    # Do the tensor conversion in batch
-    for key, value in self.items():
-        try:
-            if prepend_batch_axis:
-                value = [value]
-
-            if not is_tensor(value):
-                tensor = as_tensor(value)
-
-                # Removing this for now in favor of controlling the shape with `prepend_batch_axis`
-                # # at-least2d
-                # if tensor.ndim > 2:
-                #     tensor = tensor.squeeze(0)
-                # elif tensor.ndim < 2:
-                #     tensor = tensor[None, :]
-
-                self[key] = tensor
-        except Exception as e:
-            if key == "overflowing_tokens":
-                raise ValueError(
-                    "Unable to create tensor returning overflowing tokens of different lengths. "
-                    "Please see if a fast version of this tokenizer is available to have this "
-                    "feature available."
-                ) from e
-            raise ValueError(
-                "Unable to create tensor, you should probably activate truncation and/or "
-                "padding with 'padding=True' 'truncation=True' to have batched tensors with "
-                f"the same length. Perhaps your features (`{key}` in this case) have "
-                "excessive nesting (inputs type `list` where type `int` is expected)."
-            ) from e
-    if os.getenv("IS_GLOBAL", True) is True:
-        size = self["input_ids"].size()
-        sbp = dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast])
-
-        for k, v in self.items():
-            if is_tensor != flow.is_tensor:
-                raise ValueError(
-                    "Unable to create tensor, you should probably set `return_tensors='of'` "
-                )
-            if v.size() != size:
-                raise ValueError(
-                    "Unable to create tensor, you should probably padding with `padding=True` "
-                )
-            self[k] = v.to_global(sbp=sbp, placement=dist.get_layer_placement(0))
-    return self
-
-
-BatchEncoding.convert_to_tensors = flow_convert_to_tensors  # noqa
+with mock.enable(lazy=True):
+
+    from transformers import (  # noqa
+        BertTokenizer,
+        GPT2Tokenizer,
+        MT5Tokenizer,
+        Qwen2Tokenizer,
+        T5Tokenizer,
+    )
+    from transformers.tokenization_utils_base import *  # noqa
+    from transformers.utils import generic  # noqa
+    from transformers.utils.generic import TensorType  # noqa
+
+    # ---------------- mock TensorType ------------------
+    class TensorType(ExplicitEnum):  # noqa
+        PYTORCH = "pt"
+        TENSORFLOW = "tf"
+        ONEFLOW = "of"
+        NUMPY = "np"
+        JAX = "jax"
+
+    generic.TensorType = TensorType
+
+    # ---------------- mock convert_to_tensors ------------------
+    def flow_convert_to_tensors(self, tensor_type=None, prepend_batch_axis=False):
+        if tensor_type is None:
+            return self
+
+        # Convert to TensorType
+        if not isinstance(tensor_type, TensorType):
+            tensor_type = TensorType(tensor_type)
+        as_tensor = None
+        is_tensor = None
+        # Get a function reference for the correct framework
+        if tensor_type == TensorType.TENSORFLOW:
+            if not is_tf_available():  # noqa
+                raise ImportError(
+                    "Unable to convert output to TensorFlow tensors format, TensorFlow is not "
+                    "installed."
+                )
+            import tensorflow as tf
+
+            as_tensor = tf.constant
+            is_tensor = tf.is_tensor
+        elif tensor_type == TensorType.PYTORCH:
+            if not is_torch_available():  # noqa
+                raise ImportError(
+                    "Unable to convert output to PyTorch tensors format, PyTorch is not installed."
+                )
+            import torch
+
+            as_tensor = torch.tensor
+            is_tensor = torch.is_tensor
+        elif tensor_type == TensorType.ONEFLOW:
+            try:
+                import oneflow  # noqa
+            except ImportError as e:
+                msg = (
+                    "Unable to convert output to OneFlow tensors format, OneFlow is not installed."
+                )
+                raise ImportError(msg) from e
+            as_tensor = flow.tensor
+            is_tensor = flow.is_tensor
+        elif tensor_type == TensorType.JAX:
+            if not is_flax_available():  # noqa
+                raise ImportError(
+                    "Unable to convert output to JAX tensors format, JAX is not installed."
+                )
+            import jax.numpy as jnp  # noqa: F811
+
+            as_tensor = jnp.array
+            is_tensor = is_jax_tensor  # noqa
+        else:
+            as_tensor = np.asarray  # noqa
+            is_tensor = is_numpy_array  # noqa
+
+        # Do the tensor conversion in batch
+        for key, value in self.items():
+            try:
+                if prepend_batch_axis:
+                    value = [value]
+
+                if not is_tensor(value):
+                    tensor = as_tensor(value)
+
+                    # Removing this for now in favor of controlling the shape
+                    # with `prepend_batch_axis`
+                    # # at-least2d
+                    # if tensor.ndim > 2:
+                    #     tensor = tensor.squeeze(0)
+                    # elif tensor.ndim < 2:
+                    #     tensor = tensor[None, :]
+
+                    self[key] = tensor
+            except Exception as e:
+                if key == "overflowing_tokens":
+                    raise ValueError(
+                        "Unable to create tensor returning overflowing tokens of different "
+                        "lengths. Please see if a fast version of this tokenizer is "
+                        "available to have this feature available."
+                    ) from e
+                raise ValueError(
+                    "Unable to create tensor, you should probably activate truncation and/or "
+                    "padding with 'padding=True' 'truncation=True' to have batched tensors with "
+                    f"the same length. Perhaps your features (`{key}` in this case) have "
+                    "excessive nesting (inputs type `list` where type `int` is expected)."
+                ) from e
+        if os.getenv("IS_GLOBAL", True) is True:
+            size = self["input_ids"].size()
+            sbp = dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast])
+
+            for k, v in self.items():
+                if is_tensor != flow.is_tensor:
+                    raise ValueError(
+                        "Unable to create tensor, you should probably set `return_tensors='of'` "
+                    )
+                if v.size() != size:
+                    raise ValueError(
+                        "Unable to create tensor, you should probably padding with `padding=True` "
+                    )
+                self[k] = v.to_global(sbp=sbp, placement=dist.get_layer_placement(0))
+        return self
+
+    BatchEncoding.convert_to_tensors = flow_convert_to_tensors  # noqa
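
Note: wrapping the module in mock.enable(lazy=True) scopes the torch-to-oneflow redirection to this block, where the old module-level flow.mock_torch.enable() patched torch imports globally for everything imported afterwards. A minimal sketch of the mechanism, assuming OneFlow is installed; the printed check is illustrative:

import oneflow.mock_torch as mock

with mock.enable(lazy=True):
    # Inside the block, "import torch" is served by oneflow's
    # torch-compatible API; transformers imported here runs on oneflow.
    import torch

    x = torch.ones(2, 2)
    print(type(x))  # a oneflow tensor type, not torch.Tensor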
1 change: 1 addition & 0 deletions requirements.txt
@@ -27,3 +27,4 @@ black==21.4b2
 autoflake
 tensorboardX<=2.5.1
 pytest
+safetensors
1 change: 1 addition & 0 deletions setup.py
@@ -140,6 +140,7 @@ def get_libai_configs() -> List[str]:
         "autoflake",
         "tensorboardX<=2.5.1",
         "pytest",
+        "safetensors",
     ],
     packages=find_packages(),
     package_data={"libai.config": get_libai_configs()},
