kamran-r123 commited on
Commit
eece387
·
verified ·
1 Parent(s): 4b230ff

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +40 -8
main.py CHANGED
@@ -8,13 +8,20 @@ from llama_cpp import Llama
8
  import time
9
 
10
 
11
- model_id = "failspy/Meta-Llama-3-8B-Instruct-abliterated-v3-GGUF"
12
- filename="Meta-Llama-3-8B-Instruct-abliterated-v3_q6.gguf"
13
  # model_path = hf_hub_download(repo_id=model_id, filename="Meta-Llama-3-8B-Instruct-abliterated-v3_q6.gguf", token=os.environ['HF_TOKEN'])
14
  # model = Llama(model_path=model_path, n_gpu_layers=-1, n_ctx=4096, verbose=False)
15
 
16
- model = Llama.from_pretrained(repo_id=model_id, filename=filename, n_gpu_layers=-1, token=os.environ['HF_TOKEN'],
17
- n_ctx=4096, verbose=False, attn_implementation="flash_attention_2")
 
 
 
 
 
 
 
18
 
19
  class Item(BaseModel):
20
  prompt: str
@@ -40,11 +47,36 @@ def format_prompt(item: Item):
40
 
41
  def generate(item: Item):
42
  formatted_prompt = format_prompt(item)
43
- output = model.create_chat_completion(messages=formatted_prompt, seed=item.seed,
44
- temperature=item.temperature, max_tokens=item.max_new_tokens)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- out = output['choices'][0]['message']['content']
47
- return out
 
48
 
49
  @app.post("/generate/")
50
  async def generate_text(item: Item):
 
8
  import time
9
 
10
 
11
+ # model_id = "failspy/Meta-Llama-3-8B-Instruct-abliterated-v3-GGUF"
12
+ # filename="Meta-Llama-3-8B-Instruct-abliterated-v3_q6.gguf"
13
  # model_path = hf_hub_download(repo_id=model_id, filename="Meta-Llama-3-8B-Instruct-abliterated-v3_q6.gguf", token=os.environ['HF_TOKEN'])
14
  # model = Llama(model_path=model_path, n_gpu_layers=-1, n_ctx=4096, verbose=False)
15
 
16
+ # model = Llama.from_pretrained(repo_id=model_id, filename=filename, n_gpu_layers=-1, token=os.environ['HF_TOKEN'],
17
+ # n_ctx=4096, verbose=False, attn_implementation="flash_attention_2")
18
+
19
+ from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
20
+
21
+ model_id = "failspy/Meta-Llama-3-8B-Instruct-abliterated-v3"
22
+ model_8bit = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=BitsAndBytesConfig(load_in_8bit=True),
23
+ token=os.environ['HF_TOKEN'], attn_implementation="flash_attention_2")
24
+
25
 
26
  class Item(BaseModel):
27
  prompt: str
 
47
 
48
  def generate(item: Item):
49
  formatted_prompt = format_prompt(item)
50
+ # output = model.create_chat_completion(messages=formatted_prompt, seed=item.seed,
51
+ # temperature=item.temperature, max_tokens=item.max_new_tokens)
52
+ # out = output['choices'][0]['message']['content']
53
+ # return out
54
+
55
+ input_ids = tokenizer.apply_chat_template(
56
+ formatted_prompt,
57
+ add_generation_prompt=True,
58
+ return_tensors="pt"
59
+ ).to("cuda")
60
+
61
+ terminators = [
62
+ tokenizer.eos_token_id,
63
+ tokenizer.convert_tokens_to_ids("<|eot_id|>")
64
+ ]
65
+
66
+ outputs = model_8bit.generate(
67
+ input_ids,
68
+ max_new_tokens=item.max_new_tokens,
69
+ eos_token_id=terminators,
70
+ do_sample=True,
71
+ temperature=item.temperature,
72
+ top_p=item.top_p,
73
+ )
74
+ response = outputs[0][input_ids.shape[-1]:]
75
+ return tokenizer.decode(response, skip_special_tokens=True)
76
 
77
+ # inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
78
+ # generated_ids = model.generate(**inputs)
79
+ # outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
80
 
81
  @app.post("/generate/")
82
  async def generate_text(item: Item):