kamran-r123 committed
Commit 4cbaa02 · verified · 1 Parent(s): d465d44

Update main.py

Files changed (1): main.py (+30, -11)
main.py CHANGED
@@ -1,21 +1,19 @@
 from fastapi import FastAPI
 from pydantic import BaseModel
-from llama_cpp import Llama
 import uvicorn
 import prompt_style
 import os
-from huggingface_hub import hf_hub_download
+from huggingface_hub import InferenceClient
 
 
-model_id = "failspy/Meta-Llama-3-8B-Instruct-abliterated-v3-GGUF"
-model_path = hf_hub_download(repo_id=model_id, filename="Meta-Llama-3-8B-Instruct-abliterated-v3_q6.gguf", token=os.environ['HF_TOKEN'])
-model = Llama(model_path=model_path, n_gpu_layers=-1, n_ctx=4096, verbose=False)
+model_id = "failspy/Meta-Llama-3-8B-Instruct-abliterated-v3"
+client = InferenceClient(model_id, token=os.environ['HF_TOKEN'])
 
 class Item(BaseModel):
     prompt: str
     history: list
     system_prompt: str
-    temperature: float = 0.6
+    temperature: float = 0.8
     max_new_tokens: int = 1024
     top_p: float = 0.95
     repetition_penalty: float = 1.0
@@ -34,14 +32,35 @@ def format_prompt(item: Item):
     return messages
 
 def generate(item: Item):
+    temperature = float(item.temperature)
+    if temperature < 1e-2:
+        temperature = 1e-2
+    top_p = float(item.top_p)
+
+    generate_kwargs = dict(
+        temperature=temperature,
+        max_new_tokens=item.max_new_tokens,
+        top_p=top_p,
+        repetition_penalty=item.repetition_penalty,
+        do_sample=True,
+        seed=item.seed,
+    )
+
     formatted_prompt = format_prompt(item)
-    output = model.create_chat_completion(messages=formatted_prompt, seed=item.seed,
-                                          temperature=item.temperature,
-                                          max_tokens=item.max_new_tokens)
+    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+    output = ""
+
+    for response in stream:
+        output += response.token.text
+    return output
+
+    # output = model.create_chat_completion(messages=formatted_prompt, seed=item.seed,
+    #                                       temperature=item.temperature,
+    #                                       max_tokens=item.max_new_tokens)
 
 
-    out = output['choices'][0]['message']['content']
-    return out
+    # out = output['choices'][0]['message']['content']
+    # return out
 
 @app.post("/generate/")
 async def generate_text(item: Item):
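
For context, a minimal standalone sketch of the inference path this commit switches to: streaming tokens from the hosted Inference API via huggingface_hub's InferenceClient, mirroring the new generate() loop. The prompt text and token budget below are illustrative, and the sketch assumes HF_TOKEN is set in the environment and that the model is reachable through the Inference API; it is not part of the commit itself.

import os
from huggingface_hub import InferenceClient

# Same model id and client construction as the updated main.py.
client = InferenceClient(
    "failspy/Meta-Llama-3-8B-Instruct-abliterated-v3",
    token=os.environ["HF_TOKEN"],
)

# Stream tokens and accumulate them, as generate() now does.
output = ""
for response in client.text_generation(
    "Write one sentence about FastAPI.",  # illustrative prompt
    max_new_tokens=64,                    # illustrative budget
    temperature=0.8,                      # matches the new Item default
    top_p=0.95,
    do_sample=True,
    stream=True,
    details=True,
    return_full_text=False,
):
    output += response.token.text

print(output)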