Update README.md
Browse files
README.md
CHANGED
@@ -81,26 +81,38 @@ import torch
|
|
81 |
|
82 |
model = "opencsg/opencsg-CodeLlama-7b-v0.1"
|
83 |
|
84 |
-
tokenizer = AutoTokenizer.from_pretrained(model)
|
85 |
pipeline = transformers.pipeline(
|
86 |
"text-generation",
|
87 |
model=model,
|
88 |
torch_dtype=torch.float16,
|
89 |
device_map="auto",
|
90 |
)
|
91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
sequences = pipeline(
|
93 |
-
|
94 |
-
do_sample=
|
95 |
top_k=10,
|
96 |
temperature=0.1,
|
97 |
top_p=0.95,
|
98 |
num_return_sequences=1,
|
99 |
-
eos_token_id=
|
100 |
-
max_length=
|
101 |
)
|
102 |
for seq in sequences:
|
103 |
-
print(
|
104 |
```
|
105 |
# Training
|
106 |
|
|
|
81 |
|
82 |
model = "opencsg/opencsg-CodeLlama-7b-v0.1"
|
83 |
|
84 |
+
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
|
85 |
pipeline = transformers.pipeline(
|
86 |
"text-generation",
|
87 |
model=model,
|
88 |
torch_dtype=torch.float16,
|
89 |
device_map="auto",
|
90 |
)
|
91 |
+
input_text = """def quick_sort(arr):
|
92 |
+
if len(arr) <= 1:
|
93 |
+
return arr
|
94 |
+
pivot = arr[0]
|
95 |
+
left = []
|
96 |
+
right = []
|
97 |
+
<FILL_ME>
|
98 |
+
if arr[i] < pivot:
|
99 |
+
left.append(arr[i])
|
100 |
+
else:
|
101 |
+
right.append(arr[i])
|
102 |
+
return quick_sort(left) + [pivot] + quick_sort(right)
|
103 |
+
"""
|
104 |
sequences = pipeline(
|
105 |
+
input_text,
|
106 |
+
do_sample=False,
|
107 |
top_k=10,
|
108 |
temperature=0.1,
|
109 |
top_p=0.95,
|
110 |
num_return_sequences=1,
|
111 |
+
eos_token_id=tokenizer1.eos_token_id,
|
112 |
+
max_length=256,
|
113 |
)
|
114 |
for seq in sequences:
|
115 |
+
print(seq['generated_text'][len(input_text):])
|
116 |
```
|
117 |
# Training
|
118 |
|