Resolve an issue
- generate.py +6 -6
- model.py +4 -5
generate.py CHANGED
```diff
@@ -64,17 +64,17 @@ def generate_code(model_data, prompt, max_new_tokens):
     input_ids = tokenizer.encode(prompt, return_tensors="pt")
 
     # build the attention mask
-    attention_mask = torch.ones(input_ids.shape, device=input_ids.device)
+    attention_mask = torch.ones(input_ids.shape, device=input_ids.device)
 
     # generate the code
     outputs = model.generate(
         input_ids=input_ids,
-        attention_mask=attention_mask,
+        attention_mask=attention_mask,
         max_new_tokens=max_new_tokens,
-        do_sample=False,
-        pad_token_id=tokenizer.eos_token_id,
-        repetition_penalty=1.2,
-        no_repeat_ngram_size=3,
+        do_sample=False,  # disable sampling (greedy mode)
+        pad_token_id=tokenizer.eos_token_id,
+        repetition_penalty=1.2,
+        no_repeat_ngram_size=3,
     )
 
     return tokenizer.decode(outputs[0], skip_special_tokens=True)
```
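For reference, here is a minimal sketch of how this generation path behaves end to end. The model and tokenizer loading shown below (GPT-Neo via the `AutoModel` classes) is an assumption taken from `model_dict` in model.py, not part of this diff. The explicit all-ones `attention_mask` is what silences the ambiguous-mask warning transformers emits when `pad_token_id` equals `eos_token_id`, and with `do_sample=False` the output is deterministic:

```python
# Hypothetical usage sketch for generate_code(); the loading code below is an
# assumption based on model.py, not part of this commit.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125m")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125m")

prompt = "def fibonacci(n):"
input_ids = tokenizer.encode(prompt, return_tensors="pt")

# All-ones mask: there is no padding in a single-prompt batch, so every
# position is a real token the model should attend to.
attention_mask = torch.ones(input_ids.shape, device=input_ids.device)

outputs = model.generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    max_new_tokens=64,
    do_sample=False,            # greedy decoding: deterministic output
    pad_token_id=tokenizer.eos_token_id,
    repetition_penalty=1.2,     # down-weight tokens already generated
    no_repeat_ngram_size=3,     # forbid repeating any 3-gram verbatim
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

The `repetition_penalty` and `no_repeat_ngram_size` settings curb the verbatim loops that greedy decoding tends to fall into on small causal models.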
model.py CHANGED
```diff
@@ -12,7 +12,7 @@ model_dict = {
     "dialoGPT": {"path": "microsoft/DialoGPT-small", "library": AutoModelForCausalLM, "tokenizer": AutoTokenizer, "use_pipeline": False},
     "dialoGPT-medium": {"path": "microsoft/DialoGPT-medium", "library": AutoModelForCausalLM, "tokenizer": AutoTokenizer, "use_pipeline": False},
     "dialoGPT-large": {"path": "microsoft/DialoGPT-large", "library": AutoModelForCausalLM, "tokenizer": AutoTokenizer, "use_pipeline": False},
-    "GPT-Neo-125M": {"path": "EleutherAI/gpt-neo-125m", "library": AutoModelForCausalLM, "tokenizer": AutoTokenizer, "use_pipeline": True},
+    "GPT-Neo-125M": {"path": "EleutherAI/gpt-neo-125m", "library": AutoModelForCausalLM, "tokenizer": AutoTokenizer, "use_pipeline": True},
     "bert-emotion": {"path": "bhadresh-savani/distilbert-base-uncased-emotion", "library": AutoModelForSequenceClassification, "tokenizer": AutoTokenizer, "use_pipeline": True},
 }
 
@@ -31,19 +31,18 @@ def load_model_lazy(model_name):
     if model_info.get("use_pipeline", False):
         print(f"Using pipeline for model: {model_name}")
         if model_name == "bert-emotion":
-            # use the text-classification pipeline for the bert-emotion model
             model_pipeline = pipeline(
-                "text-classification",
+                "text-classification",
                 model=model_info["path"],
                 truncation=True
             )
         else:
-            # use the text-generation pipeline for the other models
             model_pipeline = pipeline(
                 "text-generation",
                 model=model_info["path"],
                 truncation=True,
-                pad_token_id=50256
+                pad_token_id=50256,
+                do_sample=False  # disable sampling (greedy mode)
             )
         loaded_models[model_name] = {"pipeline": model_pipeline}
         return {"pipeline": model_pipeline}
```
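A hypothetical usage sketch of the lazy-loading path after this change; `load_model_lazy` and the model names come from model.py, while the calling code below is illustrative only. Note that 50256 is the EOS token id of the GPT-2 tokenizer family that GPT-Neo also uses, reused here as the padding id:

```python
# Illustrative only: load_model_lazy() and the model names are from model.py;
# the demo calls below are an assumption, not part of this commit.
from model import load_model_lazy

# The first call constructs the pipeline and caches it in loaded_models,
# so repeated calls reuse the same object instead of reloading the weights.
gen = load_model_lazy("GPT-Neo-125M")["pipeline"]
clf = load_model_lazy("bert-emotion")["pipeline"]

# Greedy, deterministic generation: do_sample=False was baked in at pipeline
# construction, and pad_token_id=50256 is the GPT-2/GPT-Neo EOS token.
print(gen("def add(a, b):", max_new_tokens=32)[0]["generated_text"])

# Emotion classification via the text-classification pipeline.
print(clf("I finally fixed the bug!"))  # e.g. [{'label': 'joy', 'score': ...}]
```

Setting `do_sample=False` at construction makes every pipeline call greedy by default, matching the decoding change made in generate.py.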