import os

from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

# Pin the GPU selection before any CUDA context is created.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
MODEL_NAME = "RegularizedSelfPlay/sppo_forward1reverse5-0.1-PromptABC-Mistral-7B-Instruct-SPPO-Iter3"  # Example: "meta-llama/Llama-2-7b-chat-hf"
HF_TOKEN = os.getenv("HF_API_TOKEN")

# Load the tokenizer from the base model (the fine-tuned checkpoint uses
# Mistral-7B-Instruct-v0.2's tokenizer) and load the model itself with vLLM.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", token=HF_TOKEN)
llm = LLM(
    model=MODEL_NAME,
    # revision="1296dc8fd9b21e6424c9c305c06db9ae60c03ace",
    # tokenizer_revision="1296dc8fd9b21e6424c9c305c06db9ae60c03ace",
    tensor_parallel_size=1,
)
tokenizer.pad_token = tokenizer.eos_token
sampling_params = SamplingParams(
    temperature=0.7,
    top_p=0.9,
    seed=2024,
    max_tokens=2048,
    # max_tokens=64,  # use a small value like 64 for a quick smoke test
)
def generate_response(prompt):
    # inputs = tokenizer(prompt, return_tensors="pt").to("cuda")  # Move to GPU
    # Build the chat-formatted prompt. A placeholder assistant turn ("None") is
    # appended and then split off, so only the text up to the assistant header remains.
    inputs = tokenizer.apply_chat_template(
        [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": "None"},
        ],
        tokenize=False, add_generation_prompt=True,
    ).split("None")[0]
    # outputs = model.generate(**inputs, max_length=512)
    response = llm.generate(
        inputs,
        sampling_params,
    )[0].outputs[0].text
    # response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response
app = FastAPI()
class PromptRequest(BaseModel):
    prompt: str

@app.post("/generate")
def generate_text(request: PromptRequest):
    response = generate_response(request.prompt)
    return {"response": response}
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
    # print(generate_response('hi I like u'))
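
# --- Minimal usage sketch ---
# Assuming the server above is running locally on the port configured in
# uvicorn.run (8000), the /generate endpoint accepts a JSON body matching
# PromptRequest, e.g. with curl:
#
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Write a haiku about GPUs"}'
#
# or from Python, here using the third-party `requests` package (any HTTP client works):
#
#   import requests
#   r = requests.post("http://localhost:8000/generate",
#                     json={"prompt": "Write a haiku about GPUs"})
#   print(r.json()["response"])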