Skip to content

Commit

Permalink
llm eval (#535)
Browse files Browse the repository at this point in the history
* llm eval

* add glm

* lm_eval 0.4.2 with oneflow 1.0.0

* lm_eval 0.4.2 with oneflow 1.0.0

* fix

* format

* format

* format

* format

* format

* format
  • Loading branch information
zsw256 authored Apr 10, 2024
1 parent c9bdff1 commit 544cb3a
Show file tree
Hide file tree
Showing 12 changed files with 551 additions and 7 deletions.
4 changes: 2 additions & 2 deletions libai/inference/generator/generation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,8 +526,8 @@ def greedy_search(

# if eos_token was found in one sentence, set sentence to finished
if eos_token_id is not None:
unfinished_sequences = flow.mul(
unfinished_sequences, (next_tokens != eos_token_id).long()
unfinished_sequences = unfinished_sequences.mul(
next_tokens.ne(eos_token_id).prod(dim=0)
)

if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
Expand Down
1 change: 1 addition & 0 deletions projects/BLOOM/configs/bloom_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
cfg = dict(
# model
vocab_size=250880,
max_position_embeddings=512,
hidden_size=64,
hidden_layers=2,
n_head=8,
Expand Down
11 changes: 8 additions & 3 deletions projects/BLOOM/utils/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _convert_state_dict(self, flow_state_dict, cfg):

# prefix
has_prefix = any(s.startswith(self.base_model_prefix_1) for s in oneflow_state_dict)
prefix2 = "transformer." if has_prefix else ""
prefix2 = "transformer." if not has_prefix else ""

# Convert layers.
for key in old_keys:
Expand All @@ -61,8 +61,13 @@ def _load_config_from_json(self, config_file):
cfg_dict = json.load(f)

self._update_cfg("hidden_layers", cfg_dict["n_layer"])
self._update_cfg("hidden_size", cfg_dict["n_embed"])
self._update_cfg("n_head", cfg_dict["num_attention_heads"])

if "n_embed" in cfg_dict.keys():
self._update_cfg("hidden_size", cfg_dict["n_embed"])
self._update_cfg("n_head", cfg_dict["num_attention_heads"])
else:
self._update_cfg("hidden_size", cfg_dict["hidden_size"])
self._update_cfg("n_head", cfg_dict["n_head"])

# update libai_cfg by config.json
for k, v in cfg_dict.items():
Expand Down
1 change: 1 addition & 0 deletions projects/ChatGLM/configs/chatglm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
layernorm_epsilon=1e-05,
multi_query_attention=True,
multi_query_group_num=2,
max_position_embeddings=2048,
num_attention_heads=32,
num_layers=28,
padded_vocab_size=65024,
Expand Down
49 changes: 49 additions & 0 deletions projects/Eval_LLM/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# LLM Evaluation

A tool for evaluating OneFlow models based on [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/)

## Environment

Follow this [Installation Instruction](https://libai.readthedocs.io/en/latest/tutorials/get_started/Installation.html) to install oneflow(1.0.0) and libai first. Conda is recommended.
**Make sure you have python>=3.10 to run evaluation for GLM.**
Then run ```pip install -r ./projects/Eval_LLM/requirements.txt``` to install dependencies.

## Run Eval

### Set the parameters in ./projects/Eval_LLM/config.py

> pretrained_model_path: The path of your model weights; both huggingface weights and libai weights are supported.
> hf_tokenizer_path: The path of huggingface tokenizer.
> model_type: Type of your model; this argument is needed for loading the model. All choices are listed in ./projects/Eval_LLM/special_arguments.json
> model_weight_type: Whether your weights are huggingface weights or libai weights.
> eval_tasks: Tasks you want to evaluate your model on.
> batch_size_per_gpu: Batch size on a single GPU. Set it larger to speed up your evaluation, but note that this may lead to an OOM error.
Tasks for Evaluation are listed [here](https://github.com/EleutherAI/lm-evaluation-harness/tree/main/lm_eval/tasks).

### Run the following command to start eval
```
bash tools/infer.sh projects/Eval_LLM/main.py 1
```
Note: The trailing number is how many GPUs you want to use.

If you want to eval GLM(ChatGLM), run this:
```
CHATGLM_HF_DIR=YOUR_MODEL_PATH bash tools/infer.sh projects/Eval_LLM/main.py 1
```

Note: Running a model with 6B parameters requires more than 24GB of VRAM. If a single device is not enough, you can use tensor or pipeline parallelism across multiple devices.

To know more about distributed inference: https://docs.oneflow.org/en/master/parallelism/04_launch.html

## Example of Eval Result
Using Llama2-7b
```
{'sciq':
{'acc,none': 0.794, 'acc_stderr,none': 0.012795613612786583, 'acc_norm,none': 0.707, 'acc_norm_stderr,none': 0.014399942998441271, 'alias': 'sciq'},
'lambada_openai':
{'perplexity,none': 28.778403569948463, 'perplexity_stderr,none': 1.0792474430271395, 'acc,none': 0.33980205705414324, 'acc_stderr,none': 0.006598757339311441, 'alias': 'lambada_openai'},
'gsm8k':
{'exact_match,strict-match': 0.001516300227445034, 'exact_match_stderr,strict-match': 0.0010717793485492675, 'exact_match,flexible-extract': 0.01061410159211524, 'exact_match_stderr,flexible-extract': 0.002822713322387704, 'alias': 'gsm8k'}
}
```
22 changes: 22 additions & 0 deletions projects/Eval_LLM/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from omegaconf import DictConfig

# Parallelism / device settings for the evaluation run. With all three
# parallel sizes left at 1 the configuration describes a single-device run.
# NOTE(review): pipeline_num_layers presumably has to match the layer count
# of the model being evaluated — confirm against the model config.
parallel_config = DictConfig(
    {
        "data_parallel_size": 1,
        "tensor_parallel_size": 1,
        "pipeline_parallel_size": 1,
        "pipeline_num_layers": 32,
        "device_type": "cuda",
    }
)

# Evaluation settings; both path fields are empty here and must be set
# before running.
eval_config = DictConfig(
    {
        # Path of the model weights (huggingface or libai format).
        "pretrained_model_path": "",
        # Path of the huggingface tokenizer.
        "hf_tokenizer_path": "",
        # Model family, needed for loading the model.
        "model_type": "llama",
        "model_weight_type": "libai",  # libai or huggingface
        # lm-evaluation-harness tasks to evaluate the model on.
        "eval_tasks": ["lambada_openai", "gsm8k"],
        # Batch size on a single GPU.
        "batch_size_per_gpu": 1,
    }
)
Loading

0 comments on commit 544cb3a

Please sign in to comment.