kamran-r123 committed (verified)
Commit: f6819c7
Parent: 107e792

Update main.py

Files changed (1):
  main.py (+14, -25)
main.py CHANGED
@@ -1,12 +1,12 @@
  from fastapi import FastAPI
  from pydantic import BaseModel
+ from huggingface_hub import InferenceClient
  import uvicorn
  import prompt_style

- from transformers import AutoTokenizer, AutoModelForCausalLM
- import torch

- # client = InferenceClient(model_id)
+ model_id = "failspy/Meta-Llama-3-8B-Instruct-abliterated-v3"
+ client = InferenceClient(model_id)

  class Item(BaseModel):
      prompt: str
@@ -45,30 +45,19 @@ def generate(item: Item):
      )

      formatted_prompt = format_prompt(item)
+     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+     output = ""

-     model_id = "failspy/Meta-Llama-3-8B-Instruct-abliterated-v3"
-     tokenizer = AutoTokenizer.from_pretrained(model_id)
-     model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto",)
-
-     input_ids = tokenizer.apply_chat_template(formatted_prompt, add_generation_prompt=True, return_tensors="pt").to(model.device)
-
-     terminators = [
-         tokenizer.eos_token_id,
-         tokenizer.convert_tokens_to_ids("<|eot_id|>")
-     ]
-
-     outputs = model.generate(input_ids, eos_token_id=terminators, do_sample=True, **generate_kwargs,)
-     response = outputs[0][input_ids.shape[-1]:]
-     return tokenizer.decode(response, skip_special_tokens=True)
-
-     # stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-     # output = ""
-
-     # for response in stream:
-     #     output += response.token.text
-     # return output
+     for response in stream:
+         output += response.token.text
+     return output

  @app.post("/generate/")
  async def generate_text(item: Item):
      ans = generate(item)
-     return {"response": ans}
+     return {"response": ans}
+
+
+ @app.get("/")
+ def read_root():
+     return {"Hello": "World!"}
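For reference, a minimal client-side sketch of how the two routes behave after this change. It is not part of the commit: the base URL is an assumption (uvicorn's default host/port), and it assumes the Item fields not visible in this diff have defaults, so only prompt is sent; the /generate/ and / routes and the "response" key are taken from the code above.

# Hypothetical smoke test for the updated API (assumptions noted above).
import requests

BASE_URL = "http://localhost:8000"  # assumed uvicorn default host/port

# POST /generate/ -> generate() streams tokens from the InferenceClient
# and returns the accumulated text under the "response" key.
resp = requests.post(f"{BASE_URL}/generate/", json={"prompt": "Hello, who are you?"})
resp.raise_for_status()
print(resp.json()["response"])

# GET / -> the new root route added in this commit.
print(requests.get(f"{BASE_URL}/").json())  # {"Hello": "World!"}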