Skip to main content
Version: Next

LoRA With Quantization Inference

*在线运行 vLLM 入门教程:零基础分步指南

源码 examples/offline_inference/lora_with_quantization_inference.py

# SPDX-License-Identifier: Apache-2.0

"""
This example shows how to use LoRA together with different quantization
techniques for offline inference.
Requires HuggingFace credentials for access.
"""

import gc
from typing import Optional

import torch
from huggingface_hub import snapshot_download

from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest


def create_test_prompts(
    lora_path: str
) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]:
    """Build the demo prompts, each paired with sampling settings and an
    optional LoRA request.

    Args:
        lora_path: Local path of the LoRA adapter; every ``LoRARequest``
            created here points at it.

    Returns:
        A list of ``(prompt, sampling_params, lora_request)`` tuples. The
        first entry has ``lora_request=None`` (quantized model only); the
        remaining three run through the adapter at ``lora_path``.
    """

    def greedy() -> SamplingParams:
        # A fresh SamplingParams object for every prompt: temperature 0.0,
        # logprobs=1, prompt_logprobs=1, at most 128 new tokens.
        return SamplingParams(temperature=0.0,
                              logprobs=1,
                              prompt_logprobs=1,
                              max_tokens=128)

    # Baseline: quantization without any LoRA adapter.
    prompts: list[tuple[str, SamplingParams, Optional[LoRARequest]]] = [
        ("My name is", greedy(), None),
    ]

    # The next three prompts combine quantization with LoRA. They all use
    # lora_int_id=1 and the same adapter path; only the request name differs.
    lora_cases = (
        ("my name is", "lora-test-1"),
        ("The capital of USA is", "lora-test-2"),
        ("The capital of France is", "lora-test-3"),
    )
    for text, lora_name in lora_cases:
        prompts.append((text, greedy(), LoRARequest(lora_name, 1, lora_path)))
    return prompts


def process_requests(engine: LLMEngine,
                     test_prompts: list[tuple[str, SamplingParams,
                                              Optional[LoRARequest]]]):
    """Feed prompts to ``engine`` and print every request as it finishes.

    Each iteration submits at most one pending prompt, taken from the front
    of ``test_prompts`` and given an increasing string request id. It then
    runs a single ``engine.step()``. The loop stops once no prompts are left
    and the engine reports no unfinished requests.

    Note: ``test_prompts`` is consumed in place and is empty on return.
    """
    next_id = 0

    while test_prompts or engine.has_unfinished_requests():
        if test_prompts:
            prompt, params, lora = test_prompts.pop(0)
            engine.add_request(str(next_id),
                               prompt,
                               params,
                               lora_request=lora)
            next_id += 1

        # engine.step() is evaluated immediately when the generator is built;
        # only the filtering on `finished` is lazy.
        completed = (out for out in engine.step() if out.finished)
        for out in completed:
            print("----------------------------------------------------")
            print(f"Prompt: {out.prompt}")
            print(f"Output: {out.outputs[0].text}")


def initialize_engine(model: str, quantization: str,
                      lora_repo: Optional[str]) -> LLMEngine:
    """Create an ``LLMEngine`` for ``model`` with LoRA support enabled.

    Args:
        model: Model name or path handed to ``EngineArgs``.
        quantization: Quantization method name, e.g. ``"bitsandbytes"``,
            ``"awq"`` or ``"gptq"``.
        lora_repo: LoRA adapter repository. It is passed to the engine only
            in the ``"bitsandbytes"`` (QLoRA) case.

    Returns:
        The engine built by ``LLMEngine.from_engine_args``.
    """
    shared = dict(model=model, quantization=quantization, enable_lora=True)

    if quantization != "bitsandbytes":
        # Other quantization methods (e.g. AWQ / GPTQ): allow max_loras=4.
        return LLMEngine.from_engine_args(EngineArgs(**shared, max_loras=4))

    # QLoRA (https://arxiv.org/abs/2305.14314) quantizes the model at load
    # time, using some config info from the LoRA adapter repository, so the
    # adapter repo is passed via qlora_adapter_name_or_path.
    # NOTE(review): the original comment said load_format must also be set,
    # but no load_format is passed here. Presumably vLLM derives it from
    # quantization="bitsandbytes"; confirm against the installed version.
    return LLMEngine.from_engine_args(
        EngineArgs(**shared,
                   qlora_adapter_name_or_path=lora_repo,
                   max_lora_rank=64))


def main():
    """Run the LoRA demo once per quantization setup (bitsandbytes/QLoRA,
    AWQ, GPTQ).

    For each configuration, main() does the following:
    - builds an engine;
    - downloads the LoRA adapter with ``snapshot_download``;
    - runs the demo prompts;
    - deletes the engine, runs the garbage collector and empties the CUDA
      cache, to release GPU memory before the next configuration.
    """
    test_configs = [
        {
            "name": "qlora_inference_example",
            "model": "huggyllama/llama-7b",
            "quantization": "bitsandbytes",
            "lora_repo": "timdettmers/qlora-flan-7b",
        },
        {
            "name": "AWQ_inference_with_lora_example",
            "model": "TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
            "quantization": "awq",
            "lora_repo": "jashing/tinyllama-colorist-lora",
        },
        {
            "name": "GPTQ_inference_with_lora_example",
            "model": "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
            "quantization": "gptq",
            "lora_repo": "jashing/tinyllama-colorist-lora",
        },
    ]

    for cfg in test_configs:
        run_name = cfg["name"]
        print(f"~~~~~~~~~~~~~~~~ Running: {run_name} ~~~~~~~~~~~~~~~~")

        engine = initialize_engine(cfg["model"], cfg["quantization"],
                                   cfg["lora_repo"])
        lora_path = snapshot_download(repo_id=cfg["lora_repo"])
        process_requests(engine, create_test_prompts(lora_path))

        # Release GPU memory before the next configuration runs.
        del engine
        gc.collect()
        torch.cuda.empty_cache()


if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()