dayuian committed
Commit e65b5c7 · verified · 1 Parent(s): 792e10f

Update ai_sentence.py

Files changed (1)
  1. ai_sentence.py +18 -14
ai_sentence.py CHANGED
@@ -1,32 +1,36 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer
+import re
 
+# List of available models
 MODEL_LIST = [
     "EleutherAI/pythia-410m",
-    "EleutherAI/pythia-1b",
-    "gpt2"
+    "gpt2",
+    "mistralai/Mistral-7B-Instruct"
 ]
 
-model_cache = {}  # model cache
-
+# Model cache to avoid reloading on every call
+MODEL_CACHE = {}
 
+# Load a model
 def load_model(model_name):
-    if model_name not in model_cache:
-        print(f"⏳ Loading model: {model_name}")
+    if model_name not in MODEL_CACHE:
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         model = AutoModelForCausalLM.from_pretrained(model_name)
-        model_cache[model_name] = (tokenizer, model)
-    return model_cache[model_name]
-
+        MODEL_CACHE[model_name] = (tokenizer, model)
+    return MODEL_CACHE[model_name]
 
+# Generate an AI example sentence
 def generate_sentence(word, model_name):
     tokenizer, model = load_model(model_name)
 
-    prompt = f"A simple English sentence using the word '{word}' suitable for beginners. Output only the sentence."
+    prompt = f"Example sentence using '{word}':"
     inputs = tokenizer(prompt, return_tensors="pt")
-    outputs = model.generate(**inputs, max_new_tokens=30)
+    outputs = model.generate(**inputs, max_new_tokens=20)
     sentence = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-    # Clean up the sentence (adjust as needed)
-    sentence = sentence.split(".")[0].strip() + "."
-
+    # Clean up the generated sentence
+    sentence = sentence.split(":")[-1].strip()
+    sentence = re.sub(r'[^a-zA-Z0-9, .!?]', '', sentence)
+    if not sentence.endswith("."):
+        sentence += "."
     return sentence