dayuian committed on
Commit
275d15c
·
verified ·
1 Parent(s): 6180cc7

Update ai_sentence.py

Browse files
Files changed (1) hide show
  1. ai_sentence.py +14 -6
ai_sentence.py CHANGED
@@ -3,22 +3,30 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
3
# Candidate language models selectable for sentence generation.
MODEL_LIST = [
    "EleutherAI/pythia-410m",
    "EleutherAI/pythia-1b",
    "mistralai/Mistral-7B-Instruct",
]

# Maps model name -> (tokenizer, model); filled lazily by load_model().
MODEL_CACHE = {}
 
10
 
11
def load_model(model_name):
    """Return a (tokenizer, model) pair for *model_name*, caching it after the first load."""
    cached = MODEL_CACHE.get(model_name)
    if cached is None:
        cached = (
            AutoTokenizer.from_pretrained(model_name),
            AutoModelForCausalLM.from_pretrained(model_name),
        )
        MODEL_CACHE[model_name] = cached
    return cached
 
17
 
18
def generate_sentence(word, model_name):
    """Generate one short example sentence containing *word* with the given model."""
    tokenizer, model = load_model(model_name)
    prompt = f"A simple English sentence with the word '{word}':"
    encoded = tokenizer(prompt, return_tensors="pt")
    generated = model.generate(**encoded, max_new_tokens=30)
    # Decode the full output sequence (prompt + continuation) as plain text.
    return tokenizer.decode(generated[0], skip_special_tokens=True)
 
3
# Candidate language models selectable for sentence generation.
MODEL_LIST = [
    "EleutherAI/pythia-410m",
    "EleutherAI/pythia-1b",
    "gpt2",
]

# Model cache: maps model name -> (tokenizer, model); filled lazily by load_model().
model_cache = {}
10
+
11
 
12
def load_model(model_name):
    """Return a cached (tokenizer, model) pair for *model_name*, loading it on first use."""
    cached = model_cache.get(model_name)
    if cached is not None:
        return cached
    # First request for this model: download/load, announce progress, then cache.
    print(f"⏳ 正在載入模型:{model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model_cache[model_name] = (tokenizer, model)
    return model_cache[model_name]
19
+
20
 
21
def generate_sentence(word, model_name):
    """Generate a single beginner-level English sentence containing *word*.

    Parameters:
        word: the vocabulary word to embed in the sentence.
        model_name: Hugging Face model id, resolved via load_model().

    Returns:
        The first sentence of the model's continuation, trimmed and ending
        with a period.
    """
    tokenizer, model = load_model(model_name)

    prompt = f"A simple English sentence using the word '{word}' suitable for beginners. Output only the sentence."
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=30)
    sentence = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # BUG FIX: decode() returns the prompt followed by the continuation, and
    # the prompt itself contains a period ("...for beginners."), so
    # split(".")[0] always yielded a truncated copy of the prompt and the
    # generated text was discarded. Strip the echoed prompt prefix first.
    if sentence.startswith(prompt):
        sentence = sentence[len(prompt):]

    # Keep only the first sentence of the continuation, trimmed at the first period.
    sentence = sentence.split(".")[0].strip() + "."

    return sentence