diff --git a/src/utils.py b/src/utils.py
index db6a95c..8a66205 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -44,7 +44,11 @@ def __init__(self, job):
         self.max_batch_size = job.get("max_batch_size")
         self.apply_chat_template = job.get("apply_chat_template", False)
         self.use_openai_format = job.get("use_openai_format", False)
-        self.sampling_params = SamplingParams(max_tokens=100, **job.get("sampling_params", {}))
+        # Default max_tokens to 100 but let the job override it: the old call raised
+        # TypeError (duplicate keyword) whenever the job supplied its own max_tokens.
+        # Merge into a fresh dict so the caller's job payload is never mutated.
+        sampling_kwargs = {"max_tokens": 100, **(job.get("sampling_params") or {})}
+        self.sampling_params = SamplingParams(**sampling_kwargs)
         self.request_id = random_uuid()
         batch_size_growth_factor = job.get("batch_size_growth_factor")
         self.batch_size_growth_factor = float(batch_size_growth_factor) if batch_size_growth_factor else None