File size: 1,686 Bytes
df4d00f
d3c49d5
5488267
 
 
 
d3c49d5
 
5488267
ea1604a
 
5488267
 
 
 
 
 
ea1604a
 
 
 
d3c49d5
 
 
 
 
5488267
ea1604a
d3c49d5
 
 
 
 
 
 
5488267
d3c49d5
5488267
d3c49d5
 
 
5488267
ea1604a
d3c49d5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import os
import tempfile
from vllm import LLM
from vllm.sampling_params import SamplingParams
from datetime import datetime, timedelta

# モデル名
model_name = "mistralai/Mistral-7B-Instruct-v0.2"

# SYSTEM_PROMPTのロード関数 (ローカルファイルを読み込む)
def load_system_prompt(file_path: str) -> str:
    with open(file_path, 'r') as file:
        system_prompt = file.read()
    today = datetime.today().strftime('%Y-%m-%d')
    yesterday = (datetime.today() - timedelta(days=1)).strftime('%Y-%m-%d')
    return system_prompt.format(name=model_name, today=today, yesterday=yesterday)

# SYSTEM_PROMPT.txtは現在のディレクトリ内にあることを想定
system_prompt_path = "./SYSTEM_PROMPT.txt"  # 現在のディレクトリ内のファイルパス
SYSTEM_PROMPT = load_system_prompt(system_prompt_path)

# 一時ディレクトリをキャッシュ用に設定
with tempfile.TemporaryDirectory() as tmpdirname:
    os.environ["TRANSFORMERS_CACHE"] = tmpdirname
    os.environ["HF_HOME"] = tmpdirname
    os.environ["HF_TOKEN"] = os.getenv("HF_TOKEN")  # 環境変数からトークンを取得

    # メッセージリストの作成
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {
            "role": "user",
            "content": "Which of the depicted countries has the best food?",
        },
    ]

    sampling_params = SamplingParams(max_tokens=512)

    # モデルロードと実行
    llm = LLM(model=model_name, trust_remote_code=True, tensor_parallel_size=1, device="cpu")
    outputs = llm.chat(messages, sampling_params=sampling_params)

    # 出力を表示
    print(outputs[0].outputs[0].text)