GPTfree api
Update app.py
ea1604a verified
raw
history blame
1.69 kB
import os
import tempfile
from vllm import LLM
from vllm.sampling_params import SamplingParams
from datetime import datetime, timedelta
# モデル名
model_name = "mistralai/Mistral-7B-Instruct-v0.2"
# SYSTEM_PROMPTのロード関数 (ローカルファイルを読み込む)
def load_system_prompt(file_path: str) -> str:
with open(file_path, 'r') as file:
system_prompt = file.read()
today = datetime.today().strftime('%Y-%m-%d')
yesterday = (datetime.today() - timedelta(days=1)).strftime('%Y-%m-%d')
return system_prompt.format(name=model_name, today=today, yesterday=yesterday)
# SYSTEM_PROMPT.txtは現在のディレクトリ内にあることを想定
system_prompt_path = "./SYSTEM_PROMPT.txt" # 現在のディレクトリ内のファイルパス
SYSTEM_PROMPT = load_system_prompt(system_prompt_path)
# 一時ディレクトリをキャッシュ用に設定
with tempfile.TemporaryDirectory() as tmpdirname:
os.environ["TRANSFORMERS_CACHE"] = tmpdirname
os.environ["HF_HOME"] = tmpdirname
os.environ["HF_TOKEN"] = os.getenv("HF_TOKEN") # 環境変数からトークンを取得
# メッセージリストの作成
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{
"role": "user",
"content": "Which of the depicted countries has the best food?",
},
]
sampling_params = SamplingParams(max_tokens=512)
# モデルロードと実行
llm = LLM(model=model_name, trust_remote_code=True, tensor_parallel_size=1, device="cpu")
outputs = llm.chat(messages, sampling_params=sampling_params)
# 出力を表示
print(outputs[0].outputs[0].text)