kamran-r123 committed
Commit 3fbd422 · verified · 1 parent: e553cec

Update main.py

Files changed (1): main.py +35 -17
main.py CHANGED
@@ -1,14 +1,15 @@
 from fastapi import FastAPI
 from pydantic import BaseModel
-from huggingface_hub import InferenceClient
 import uvicorn
+import prompt_style
+
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
 
 # **************************************************
 # import transformers
 # import torch
 
-model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
-
 # pipeline = transformers.pipeline(
 #     "text-generation",
 #     model=model_id,
@@ -16,7 +17,7 @@ model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
 #     device_map="auto",
 # )
 
-def generate(item: Item):
+def generate_1(item: Item):
     messages = [
         {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
         {"role": "user", "content": "Who are you?"},
@@ -55,13 +56,14 @@ class Item(BaseModel):
 
 app = FastAPI()
 
-def format_prompt(message, history):
-    prompt = "<s>"
-    for user_prompt, bot_response in history:
-        prompt += f"[INST] {user_prompt} [/INST]"
-        prompt += f" {bot_response}</s> "
-    prompt += f"[INST] {message} [/INST]"
-    return prompt
+def format_prompt(item: Item):
+    messages = [
+        {"role": "system", "content": prompt_style.data},
+    ]
+    for it in item.history:
+        messages.append({"role": "user", "content": it[0]})
+        messages.append({"role": "assistant", "content": it[1]})
+    return messages
 
 def generate(item: Item):
     temperature = float(item.temperature)
@@ -78,13 +80,29 @@ def generate(item: Item):
         seed=item.seed,
     )
 
-    formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
-    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-    output = ""
+    formatted_prompt = format_prompt(item)
+
+    model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
+
+    input_ids = tokenizer.apply_chat_template(formatted_prompt, add_generation_prompt=True, return_tensors="pt").to(model.device)
+
+    terminators = [
+        tokenizer.eos_token_id,
+        tokenizer.convert_tokens_to_ids("<|eot_id|>")
+    ]
+
+    outputs = model.generate(input_ids, eos_token_id=terminators, do_sample=True, **generate_kwargs)
+    response = outputs[0][input_ids.shape[-1]:]
+    return tokenizer.decode(response, skip_special_tokens=True)
+
+    # stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+    # output = ""
 
-    for response in stream:
-        output += response.token.text
-    return output
+    # for response in stream:
+    #     output += response.token.text
+    # return output
 
 @app.post("/generate/")
 async def generate_text(item: Item):
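
The full Item schema is not part of this diff, so the exact request fields are an assumption inferred from the attributes the code reads (item.prompt, item.system_prompt, item.history, item.temperature, item.seed). A minimal sketch of how a client might call the updated /generate/ endpoint under those assumptions:

# Hypothetical client for the /generate/ endpoint added above.
# Field names are assumptions based on attributes referenced in main.py;
# the actual Item model may define more or differently named fields.
import requests

payload = {
    "prompt": "Where be the treasure?",
    "system_prompt": "You are a pirate chatbot who always responds in pirate speak!",
    "history": [["Who are you?", "Arrr, I be a pirate chatbot!"]],  # [user, assistant] pairs
    "temperature": 0.7,
    "seed": 42,
}

# Assumes the app is served locally, e.g. `uvicorn main:app --port 8000`.
response = requests.post("http://localhost:8000/generate/", json=payload)
print(response.json())

Note that, as committed, the tokenizer and Meta-Llama-3-8B-Instruct checkpoint are loaded inside generate() on every call, so each request pays the full model-loading cost; loading them once at module level is the usual alternative.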