 import sys
 from datetime import datetime
 from pathlib import Path
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
 import torch
 import torch.nn.functional as F
@@ -24,39 +24,39 @@
 from torch import Tensor
 from torch import nn
 from torch.utils.tensorboard import SummaryWriter
-from tqdm import tqdm
-from tqdm import trange
 from transformers import AutoModelForCausalLM
 from transformers import AutoTokenizer
 from whowhatbench import TextEvaluator
 
 import nncf
+from nncf.common.logging.track_progress import track
 from nncf.data.dataset import Dataset
 from nncf.parameters import CompressionFormat
 from nncf.parameters import CompressWeightsMode
 from nncf.parameters import StripFormat
+from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
 from nncf.quantization.quantize_model import compress_weights
 from nncf.torch.model_creation import load_from_config
 from nncf.torch.quantization.layers import AsymmetricLoraQuantizer
 from nncf.torch.quantization.layers import SymmetricLoraQuantizer
 
 
-def get_wikitext2(nsamples: int, seqlen: int, tokenizer: Any, device: torch.device) -> List[Tensor]:
+def get_wikitext2(num_samples: int, seqlen: int, tokenizer: Any, device: torch.device) -> List[Tensor]:
     """
     Loads and processes the Wikitext-2 dataset for training.
 
-    :param nsamples: Number of samples to generate.
+    :param num_samples: Number of samples to generate.
     :param seqlen: Sequence length for each sample.
     :param tokenizer: Tokenizer to encode the text.
     :param device: Device to move the tensors to (e.g., 'cpu' or 'cuda').
     :return: A list of tensors containing the tokenized text samples.
     """
     traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
-    limit = nsamples * seqlen // 4  # ~1k for 128 samples with seqlen=32 to be aligned with optimum
+    limit = num_samples * seqlen // 4  # ~1k for 128 samples with seqlen=32 to be aligned with optimum
     text = "".join(["\n" if s == "" else s for s in traindata["text"][:limit]])
     trainenc = tokenizer(text, return_tensors="pt")
     trainloader = []
-    for _ in range(nsamples):
+    for _ in range(num_samples):
         # Crop a sequence of tokens of length seqlen starting at a random position
         i = torch.randint(0, trainenc.input_ids.shape[1] - seqlen - 1, (1,)).item()
         j = i + seqlen
@@ -66,7 +66,7 @@ def get_wikitext2(nsamples: int, seqlen: int, tokenizer: Any, device: torch.devi
 
 
 @torch.no_grad()
-def save_wwb_ref(model: str, tokenizer: Any, wwb_ref_file: Path) -> None:
+def save_wwb_ref(model: str, tokenizer: Any, wwb_ref_file: Path, num_samples: Optional[int] = None) -> None:
     """
     Save the reference answers for the WWB (WhoWhatBenchmark) evaluation.
 
@@ -76,12 +76,14 @@ def save_wwb_ref(model: str, tokenizer: Any, wwb_ref_file: Path) -> None:
     """
     if not wwb_ref_file.exists():
         print("#" * 50 + " Collect reference answers for WWB " + "#" * 50)
-        wwb_eval = TextEvaluator(base_model=model, tokenizer=tokenizer, use_chat_template=True)
+        wwb_eval = TextEvaluator(base_model=model, tokenizer=tokenizer, use_chat_template=True, num_samples=num_samples)
         wwb_eval.dump_gt(str(wwb_ref_file))
         torch.cuda.empty_cache()
 
 
-def measure_similarity(model_for_eval: OVModelForCausalLM, tokenizer: Any, wwb_ref_file: Path) -> float:
+def measure_similarity(
+    model_for_eval: OVModelForCausalLM, tokenizer: Any, wwb_ref_file: Path, num_samples: Optional[int] = None
+) -> float:
     """
     Measures the similarity of a model's output to a reference outputs from a given file using WWB evaluation.
 
@@ -92,7 +94,11 @@ def measure_similarity(model_for_eval: OVModelForCausalLM, tokenizer: Any, wwb_r
     """
     print("#" * 50 + " Evaluate via WWB " + "#" * 50)
     wwb_eval = TextEvaluator(
-        tokenizer=tokenizer, gt_data=wwb_ref_file, test_data=str(wwb_ref_file), use_chat_template=True
+        tokenizer=tokenizer,
+        gt_data=wwb_ref_file,
+        test_data=str(wwb_ref_file),
+        use_chat_template=True,
+        num_samples=num_samples,
     )
     _, all_metrics = wwb_eval.score(model_for_eval)
     return float(all_metrics["similarity"].iloc[0])
@@ -108,8 +114,8 @@ def calc_hiddens(model: nn.Module, dataloader: List[Tensor]) -> List[Tensor]:
     :return: A list of hidden states for each input in the dataloader.
     """
    orig_hiddens = []
-    for i in trange(len(dataloader), total=len(dataloader), desc="Calculating original hiddens", leave=False):
-        model_input = get_model_input(dataloader[i])
+    for data in track(dataloader, description="Calculating original hiddens"):
+        model_input = get_model_input(data)
         orig_hiddens.append(model.model(**model_input).last_hidden_state)
     torch.cuda.empty_cache()
     return orig_hiddens
@@ -260,10 +266,12 @@ def get_argument_parser() -> argparse.ArgumentParser:
         help="Whether to start from previously saved checkpoint. If not specified or checkpoint does not exist, "
         "start from scratch by post-training weight compression initialization.",
     )
+    parser.add_argument("--lora_rank", type=int, default=256, help="Rank of lora adapters")
 
     # Data params
-    parser.add_argument("--nsamples", type=int, default=1024, help="Number of training samples")
+    parser.add_argument("--num_train_samples", type=int, default=1024, help="Number of training samples")
     parser.add_argument("--seqlen", type=int, default=1024, help="Calibration data context length.")
+    parser.add_argument("--num_val_samples", type=int, default=None, help="Number of validation samples for WWB.")
 
     # Training params
     parser.add_argument(
@@ -286,7 +294,7 @@ def get_argument_parser() -> argparse.ArgumentParser:
 
 def main(argv) -> float:
     """
-    Fine-tunes the specified model and returns the best validation similarity score.
+    Fine-tunes the specified model and returns the difference between initial and best validation similarity scores.
     """
     parser = get_argument_parser()
     args = parser.parse_args(argv)
@@ -295,7 +303,10 @@ def main(argv) -> float:
     device = "cuda"
     torch_dtype = torch.bfloat16
     compression_config = dict(
-        mode=CompressWeightsMode.INT4_ASYM, group_size=64, compression_format=CompressionFormat.FQ_LORA
+        mode=CompressWeightsMode.INT4_ASYM,
+        group_size=64,
+        compression_format=CompressionFormat.FQ_LORA,
+        advanced_parameters=AdvancedCompressionParameters(lora_adapter_rank=args.lora_rank),
     )
 
     # Configure output and log files.
@@ -320,11 +331,13 @@ def main(argv) -> float:
     # computed by for data generated by two models, original floating-point one and optimized.
     # TODO: (nlyalyus) Use original model for collecting reference, once the bug in WWB resolved.
     wwb_ref_model = AutoModelForCausalLM.from_pretrained(args.pretrained, torch_dtype=torch_dtype, device_map="cpu")
-    save_wwb_ref(wwb_ref_model, tokenizer, wwb_ref_file)
+    save_wwb_ref(wwb_ref_model, tokenizer, wwb_ref_file, args.num_val_samples)
     del wwb_ref_model
 
     # Prepare training data and pre-compute hiddens of teacher model for distillation loss.
-    train_loader = get_wikitext2(nsamples=args.nsamples, seqlen=args.seqlen, tokenizer=tokenizer, device=device)
+    train_loader = get_wikitext2(
+        num_samples=args.num_train_samples, seqlen=args.seqlen, tokenizer=tokenizer, device=device
+    )
     orig_hiddens = calc_hiddens(model, train_loader)
 
     # Create or load model to tune with Fake Quantizers and absorbable LoRA adapters.
@@ -341,9 +354,11 @@ def main(argv) -> float:
 
     # Convert torch checkpoint to an OpenVINO model and evaluate it via WWB.
     model_for_eval = export_to_openvino(args.pretrained, train_loader[0], ckpt_file, last_dir)
-    best_similarity = measure_similarity(model_for_eval, tokenizer, wwb_ref_file)
-    tb.add_scalar("similarity", best_similarity, 0)
-    print(f"Initial WWB similarity= {best_similarity:.4f}")
+    initial_similarity = best_similarity = measure_similarity(
+        model_for_eval, tokenizer, wwb_ref_file, args.num_val_samples
+    )
+    tb.add_scalar("similarity", initial_similarity, 0)
+    print(f"Initial WWB similarity= {initial_similarity:.4f}")
 
     # Run tuning with distillation loss and validation on WWB after each epoch.
     grad_accumulation_steps = args.batch_size // args.microbatch_size
@@ -354,7 +369,7 @@ def main(argv) -> float:
     loss_numerator = grad_steps = total_microbatches = 0
     for epoch in range(args.epochs):
         batch_indices_epoch = torch.randperm(num_samples)[:epoch_samples].chunk(microbatches_per_epoch)
-        for indices in tqdm(batch_indices_epoch, desc=f"Train epoch {epoch}", leave=False):
+        for indices in track(batch_indices_epoch, description=f"Train epoch {epoch}"):
             indices = indices.tolist()
             total_microbatches += 1
 
@@ -373,7 +388,7 @@ def form_batch(inputs: List[Tensor], model_input: bool):
             targets = torch.tanh(targets)
             targets = targets * fls
             outputs = model(**inputs).logits
-            loss = kl_div(outputs, targets.to(dtype=torch_dtype))
+            loss = kl_div(outputs, targets.to(device=device, dtype=torch_dtype))
 
             # Perform an optimization step after accumulating gradients over multiple minibatches.
             loss_numerator += loss.item()
@@ -393,16 +408,16 @@ def form_batch(inputs: List[Tensor], model_input: bool):
         # Save the best checkpoint and OpenVINO IR for the highest similarity score obtained from WWB.
         save_checkpoint(model, ckpt_file)
         model_for_eval = export_to_openvino(args.pretrained, train_loader[0], ckpt_file, last_dir)
-        similarity = measure_similarity(model_for_eval, tokenizer, wwb_ref_file)
-        print(f"[Epoch {epoch}], WWB similarity = {similarity:.4f}")
+        similarity = measure_similarity(model_for_eval, tokenizer, wwb_ref_file, args.num_val_samples)
+        print(f"[Epoch {epoch + 1}], WWB similarity = {similarity:.4f}")
         tb.add_scalar("similarity", similarity, total_microbatches)
         if similarity > best_similarity:
             print(f"New best WWB similarity = {similarity:.4f}")
             best_similarity = similarity
             shutil.copytree(last_dir, best_dir, dirs_exist_ok=True)
 
-    print(f"The finetuned OV model with the best similarity={best_similarity} saved to: {best_dir}")
-    return best_similarity
+    print(f"The finetuned OV model with the best similarity={best_similarity:.4f} saved to: {best_dir}")
+    return best_similarity - initial_similarity
 
 
 if __name__ == "__main__":
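
Note on usage after this change: a minimal sketch of invoking the updated script, shown only for illustration. The module name `main.py`, the model id, and the flag values are placeholders and not part of the diff; the `--pretrained` flag is assumed from the `args.pretrained` references above. As the updated docstring states, `main()` now returns the difference between the best and the initial WWB similarity rather than the best score itself.

# Hypothetical invocation of the updated example; all values below are placeholders.
from main import main  # assuming the diffed script is saved as main.py

similarity_gain = main(
    [
        "--pretrained", "HuggingFaceTB/SmolLM-1.7B-Instruct",  # placeholder model id
        "--lora_rank", "16",           # new flag: rank of the absorbable LoRA adapters (default 256)
        "--num_train_samples", "128",  # renamed from --nsamples
        "--num_val_samples", "100",    # new flag: limits WWB validation samples (default: all)
    ]
)
print(f"Best WWB similarity improved over the initial one by {similarity_gain:.4f}")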