import os
import tempfile
from datetime import datetime, timedelta

from vllm import LLM
from vllm.sampling_params import SamplingParams
|

model_name = "mistralai/Mistral-7B-Instruct-v0.2"


def load_system_prompt(file_path: str) -> str:
    # Read the prompt template and fill in the model name plus today's and
    # yesterday's dates, which the template references.
    with open(file_path, "r") as file:
        system_prompt = file.read()
    today = datetime.today().strftime("%Y-%m-%d")
    yesterday = (datetime.today() - timedelta(days=1)).strftime("%Y-%m-%d")
    return system_prompt.format(name=model_name, today=today, yesterday=yesterday)
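
# A minimal sketch of what SYSTEM_PROMPT.txt might contain; the {name},
# {today}, and {yesterday} placeholders are exactly the ones format() fills
# above. The wording here is hypothetical - use the prompt file that ships
# with the model:
#
#   You are {name}, a helpful assistant.
#   Today's date is {today}. Yesterday's date was {yesterday}.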


system_prompt_path = "./SYSTEM_PROMPT.txt"
SYSTEM_PROMPT = load_system_prompt(system_prompt_path)

with tempfile.TemporaryDirectory() as tmpdirname:
    # Point the Hugging Face caches at a throwaway directory for this run.
    os.environ["TRANSFORMERS_CACHE"] = tmpdirname
    os.environ["HF_HOME"] = tmpdirname
    # os.environ values must be strings, so assigning os.getenv("HF_TOKEN")
    # directly raises a TypeError when the variable is unset; copy it over
    # only if it is present (used when the model repo requires authentication).
    hf_token = os.getenv("HF_TOKEN")
    if hf_token is not None:
        os.environ["HF_TOKEN"] = hf_token

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {
            "role": "user",
            "content": "Which of the depicted countries has the best food?",
        },
    ]

    sampling_params = SamplingParams(max_tokens=512)
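    # Only max_tokens is set above; SamplingParams also accepts the usual
    # decoding knobs (temperature, top_p, stop, ...) when the defaults are
    # not what you want, e.g.:
    #
    #   sampling_params = SamplingParams(max_tokens=512, temperature=0.7, top_p=0.95)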

    # Build the engine and run one chat completion. device="cpu" keeps the
    # example hardware-agnostic but is slow for a 7B model; on a CUDA machine
    # you can drop it (the default is "auto") and raise tensor_parallel_size
    # to shard the model across several GPUs.
    llm = LLM(model=model_name, trust_remote_code=True, tensor_parallel_size=1, device="cpu")
    outputs = llm.chat(messages, sampling_params=sampling_params)

    print(outputs[0].outputs[0].text)
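    # llm.chat returns one RequestOutput per conversation, and each
    # RequestOutput carries a list of generated completions, hence the
    # double indexing above.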
|