timxiaohangt committed
Commit 96f90ba · 1 Parent(s): 7c1a7bc
Files changed (2)
  1. requirements.txt +5 -0
  2. rspo_mistral_api.py +65 -0
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ transformers
+ torch
+ fastapi
+ uvicorn
+ vllm
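For a quick environment check before starting the server, a minimal sketch that prints the installed version of each dependency listed above (it assumes they were installed with pip into the active environment and uses only the Python standard library):

# Print the installed version of each dependency from requirements.txt.
from importlib.metadata import PackageNotFoundError, version

for pkg in ("transformers", "torch", "fastapi", "uvicorn", "vllm"):
    try:
        print(pkg, version(pkg))
    except PackageNotFoundError:
        print(pkg, "is not installed")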
rspo_mistral_api.py ADDED
@@ -0,0 +1,65 @@
+ from transformers import LlamaTokenizer
+ from vllm import LLM, SamplingParams
+ import os
+ from fastapi import FastAPI
+ from pydantic import BaseModel
+
+ # Pin the server to GPU 0.
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+
+ MODEL_NAME = "RegularizedSelfPlay/sppo_forward1reverse5-0.1-PromptABC-Mistral-7B-Instruct-SPPO-Iter3"  # Example: "meta-llama/Llama-2-7b-chat-hf"
+
+ # Load the tokenizer (used only to format chat prompts) and the vLLM engine.
+ tokenizer = LlamaTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
+ llm = LLM(
+     model=MODEL_NAME,
+     # revision="1296dc8fd9b21e6424c9c305c06db9ae60c03ace",
+     # tokenizer_revision="1296dc8fd9b21e6424c9c305c06db9ae60c03ace",
+     tensor_parallel_size=1,
+ )
+ tokenizer.pad_token = tokenizer.eos_token
+ # Decoding parameters shared by every request.
+ sampling_params = SamplingParams(
+     temperature=0.7,
+     top_p=0.9,
+     seed=2024,
+     max_tokens=2048,  # set lower (e.g. 64) for a quick smoke test
+ )
+
+
+ def generate_response(prompt):
+     # Render the Mistral chat template with a placeholder assistant turn
+     # ("None"), then split the placeholder off so the prompt ends right
+     # after the user turn.
+     inputs = tokenizer.apply_chat_template(
+         [
+             {"role": "user", "content": prompt},
+             {"role": "assistant", "content": "None"},
+         ],
+         tokenize=False, add_generation_prompt=True,
+     ).split("None")[0]
+
+     response = llm.generate(
+         inputs,
+         sampling_params,
+     )[0].outputs[0].text
+
+     return response
+
+ # FastAPI application exposing a single POST /generate endpoint.
+ app = FastAPI()
+
+ class PromptRequest(BaseModel):
+     prompt: str
+
+ # Generate a completion for the submitted prompt.
+ @app.post("/generate")
+ def generate_text(request: PromptRequest):
+     response = generate_response(request.prompt)
+     return {"response": response}
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=8000)
+     # print(generate_response('hi I like u'))
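Once the server is running (uvicorn on port 8000, as in the __main__ block above), the /generate endpoint can be exercised with a small client. A minimal sketch using only the standard library; the prompt string is an arbitrary example, and localhost:8000 assumes the client runs on the same machine as the server:

import json
from urllib import request

# POST a prompt to the /generate endpoint and print the model's reply.
payload = json.dumps({"prompt": "Say hello in one sentence."}).encode("utf-8")
req = request.Request(
    "http://localhost:8000/generate",
    data=payload,
    headers={"Content-Type": "application/json"},
    method="POST",
)
with request.urlopen(req) as resp:
    print(json.loads(resp.read().decode("utf-8"))["response"])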