@@ -376,30 +376,32 @@ def test_pipelines_generate_with_streaming(tmp_path, pipeline_type):
376
376
model_id : str = "facebook/opt-125m"
377
377
opt_model , hf_tokenizer = get_hugging_face_models (model_id )
378
378
379
- models_path : Path = tmp_path / "t_streaming" / model_id
379
+ models_path : Path = tmp_path / model_id
380
380
convert_models (opt_model , hf_tokenizer , models_path )
381
381
382
382
generation_config = GenerationConfig ()
383
- pipe , input , gen_config = get_data_by_pipeline_type (models_path , pipeline_type , generation_config )
383
+ pipe , input , generation_config = get_data_by_pipeline_type (models_path , pipeline_type , generation_config )
384
384
385
+ it_cnt = 0
385
386
def py_streamer (py_str : str ):
387
+ nonlocal it_cnt
388
+ it_cnt += 1
386
389
return False
387
390
388
- try :
389
- _ = pipe .generate (input , generation_config = generation_config , streamer = py_streamer )
390
- except Exception :
391
- assert True
391
+ _ = pipe .generate (input , generation_config = generation_config , streamer = py_streamer )
392
392
393
393
del pipe
394
394
rmtree (models_path )
395
395
396
+ assert it_cnt > 0
397
+
396
398
@pytest .mark .parametrize ("pipeline_type" , ["continuous_batching" , "speculative_decoding" , "prompt_lookup_decoding" , "llm_pipeline" ])
397
399
@pytest .mark .precommit
398
400
def test_pipelines_generate_with_streaming_empty_output (tmp_path , pipeline_type ):
399
401
model_id : str = "facebook/opt-125m"
400
402
opt_model , hf_tokenizer = get_hugging_face_models (model_id )
401
403
402
- models_path : Path = tmp_path / "t_streaming" / model_id
404
+ models_path : Path = tmp_path / model_id
403
405
convert_models (opt_model , hf_tokenizer , models_path )
404
406
405
407
generation_config = GenerationConfig ()
@@ -408,13 +410,15 @@ def test_pipelines_generate_with_streaming_empty_output(tmp_path, pipeline_type)
408
410
409
411
pipe , input , generation_config = get_data_by_pipeline_type (models_path , pipeline_type , generation_config )
410
412
413
+ it_cnt = 0
411
414
def py_streamer (py_str : str ):
412
- raise Exception ("Streamer was called" )
415
+ nonlocal it_cnt
416
+ it_cnt += 1
417
+ return False
413
418
414
- try :
415
- _ = pipe .generate (input , generation_config = generation_config , streamer = py_streamer )
416
- except Exception :
417
- assert False
419
+ _ = pipe .generate (input , generation_config = generation_config , streamer = py_streamer )
418
420
419
421
del pipe
420
422
rmtree (models_path )
423
+
424
+ assert it_cnt == 0
0 commit comments