resolve an issue
Browse files- generate.py +6 -6
- model.py +4 -5
generate.py
CHANGED
@@ -64,17 +64,17 @@ def generate_code(model_data, prompt, max_new_tokens):
|
|
64 |
input_ids = tokenizer.encode(prompt, return_tensors="pt")
|
65 |
|
66 |
# ایجاد attention mask
|
67 |
-
attention_mask = torch.ones(input_ids.shape, device=input_ids.device)
|
68 |
|
69 |
# تولید کد
|
70 |
outputs = model.generate(
|
71 |
input_ids=input_ids,
|
72 |
-
attention_mask=attention_mask,
|
73 |
max_new_tokens=max_new_tokens,
|
74 |
-
do_sample=False,
|
75 |
-
pad_token_id=tokenizer.eos_token_id,
|
76 |
-
repetition_penalty=1.2,
|
77 |
-
no_repeat_ngram_size=3,
|
78 |
)
|
79 |
|
80 |
return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
64 |
input_ids = tokenizer.encode(prompt, return_tensors="pt")
|
65 |
|
66 |
# ایجاد attention mask
|
67 |
+
attention_mask = torch.ones(input_ids.shape, device=input_ids.device)
|
68 |
|
69 |
# تولید کد
|
70 |
outputs = model.generate(
|
71 |
input_ids=input_ids,
|
72 |
+
attention_mask=attention_mask,
|
73 |
max_new_tokens=max_new_tokens,
|
74 |
+
do_sample=False, # غیرفعال کردن نمونه‌گیری (حالت حریصانه)
|
75 |
+
pad_token_id=tokenizer.eos_token_id,
|
76 |
+
repetition_penalty=1.2,
|
77 |
+
no_repeat_ngram_size=3,
|
78 |
)
|
79 |
|
80 |
return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
model.py
CHANGED
@@ -12,7 +12,7 @@ model_dict = {
|
|
12 |
"dialoGPT": {"path": "microsoft/DialoGPT-small", "library": AutoModelForCausalLM, "tokenizer": AutoTokenizer, "use_pipeline": False},
|
13 |
"dialoGPT-medium": {"path": "microsoft/DialoGPT-medium", "library": AutoModelForCausalLM, "tokenizer": AutoTokenizer, "use_pipeline": False},
|
14 |
"dialoGPT-large": {"path": "microsoft/DialoGPT-large", "library": AutoModelForCausalLM, "tokenizer": AutoTokenizer, "use_pipeline": False},
|
15 |
-
"GPT-Neo-125M": {"path": "EleutherAI/gpt-neo-125m", "library": AutoModelForCausalLM, "tokenizer": AutoTokenizer, "use_pipeline": True},
|
16 |
"bert-emotion": {"path": "bhadresh-savani/distilbert-base-uncased-emotion", "library": AutoModelForSequenceClassification, "tokenizer": AutoTokenizer, "use_pipeline": True},
|
17 |
}
|
18 |
|
@@ -31,19 +31,18 @@ def load_model_lazy(model_name):
|
|
31 |
if model_info.get("use_pipeline", False):
|
32 |
print(f"Using pipeline for model: {model_name}")
|
33 |
if model_name == "bert-emotion":
|
34 |
-
# برای مدل bert-emotion از text-classification استفاده کنید
|
35 |
model_pipeline = pipeline(
|
36 |
-
"text-classification",
|
37 |
model=model_info["path"],
|
38 |
truncation=True
|
39 |
)
|
40 |
else:
|
41 |
-
# برای سایر مدل‌ها از text-generation استفاده کنید
|
42 |
model_pipeline = pipeline(
|
43 |
"text-generation",
|
44 |
model=model_info["path"],
|
45 |
truncation=True,
|
46 |
-
pad_token_id=50256
|
|
|
47 |
)
|
48 |
loaded_models[model_name] = {"pipeline": model_pipeline}
|
49 |
return {"pipeline": model_pipeline}
|
|
|
12 |
"dialoGPT": {"path": "microsoft/DialoGPT-small", "library": AutoModelForCausalLM, "tokenizer": AutoTokenizer, "use_pipeline": False},
|
13 |
"dialoGPT-medium": {"path": "microsoft/DialoGPT-medium", "library": AutoModelForCausalLM, "tokenizer": AutoTokenizer, "use_pipeline": False},
|
14 |
"dialoGPT-large": {"path": "microsoft/DialoGPT-large", "library": AutoModelForCausalLM, "tokenizer": AutoTokenizer, "use_pipeline": False},
|
15 |
+
"GPT-Neo-125M": {"path": "EleutherAI/gpt-neo-125m", "library": AutoModelForCausalLM, "tokenizer": AutoTokenizer, "use_pipeline": True},
|
16 |
"bert-emotion": {"path": "bhadresh-savani/distilbert-base-uncased-emotion", "library": AutoModelForSequenceClassification, "tokenizer": AutoTokenizer, "use_pipeline": True},
|
17 |
}
|
18 |
|
|
|
31 |
if model_info.get("use_pipeline", False):
|
32 |
print(f"Using pipeline for model: {model_name}")
|
33 |
if model_name == "bert-emotion":
|
|
|
34 |
model_pipeline = pipeline(
|
35 |
+
"text-classification",
|
36 |
model=model_info["path"],
|
37 |
truncation=True
|
38 |
)
|
39 |
else:
|
|
|
40 |
model_pipeline = pipeline(
|
41 |
"text-generation",
|
42 |
model=model_info["path"],
|
43 |
truncation=True,
|
44 |
+
pad_token_id=50256,
|
45 |
+
do_sample=False # غیرفعال کردن نمونه‌گیری (حالت حریصانه)
|
46 |
)
|
47 |
loaded_models[model_name] = {"pipeline": model_pipeline}
|
48 |
return {"pipeline": model_pipeline}
|