# VocabLine / ai_sentence.py
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_LIST = [
    "EleutherAI/pythia-410m",
    "EleutherAI/pythia-1b",
    "gpt2",
]
model_cache = {}  # cache of loaded models, keyed by model name


def load_model(model_name):
    """Load (or reuse) the tokenizer and model for the given model name."""
    if model_name not in model_cache:
        print(f"⏳ Loading model: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)
        model_cache[model_name] = (tokenizer, model)
    return model_cache[model_name]

def generate_sentence(word, model_name):
    """Generate a short beginner-level sentence containing the given word."""
    tokenizer, model = load_model(model_name)
    prompt = f"A simple English sentence using the word '{word}' suitable for beginners. Output only the sentence."
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        **inputs,
        max_new_tokens=30,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens so the prompt is not echoed back in the result.
    generated = outputs[0][inputs["input_ids"].shape[1]:]
    sentence = tokenizer.decode(generated, skip_special_tokens=True)
    # Clean up: keep only the first sentence (adjust as needed).
    sentence = sentence.split(".")[0].strip() + "."
    return sentence
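

# Minimal usage sketch (an illustrative addition, not part of the original module):
# running it directly will download the "gpt2" weights on first use.
if __name__ == "__main__":
    example = generate_sentence("apple", "gpt2")
    print(example)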