harsh-manvar committed
Commit 9dad4e7 · verified · 1 Parent(s): 1e819e6

Update app.py

Files changed (1)
  1. app.py +6 -9
app.py CHANGED
@@ -6,12 +6,12 @@ from vllm import LLM, SamplingParams
 model_name = "facebook/opt-125m"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
-# Initialize vLLM with CPU-only configuration
-vllm_model = LLM(model=model_name, tensor_parallel_size=1, device="cpu", async_output=False)
+# Initialize vLLM with CPU configuration
+vllm_model = LLM(model=model_name, tensor_parallel_size=1, device="cpu")
 
 def generate_response(prompt, max_tokens, temperature, top_p):
     # Tokenize the prompt
-    inputs = tokenizer(prompt, return_tensors="pt")
+    inputs = tokenizer(prompt, return_tensors="pt")["input_ids"].tolist()[0]
 
     # Define sampling parameters
     sampling_params = SamplingParams(
@@ -20,14 +20,11 @@ def generate_response(prompt, max_tokens, temperature, top_p):
         top_p=top_p,
     )
 
-    # Generate text using vLLM (synchronous mode)
-    try:
-        output = vllm_model.generate(inputs["input_ids"], sampling_params)
-    except NotImplementedError as e:
-        return f"Error: {e}. Ensure that async_output is supported or disabled."
+    # Generate text using vLLM
+    output = vllm_model.generate(inputs, sampling_params)
 
     # Decode the generated tokens to text
-    generated_text = tokenizer.decode(output[0]["token_ids"], skip_special_tokens=True)
+    generated_text = tokenizer.decode(output[0].outputs[0].token_ids, skip_special_tokens=True)
     return generated_text
 
 # Gradio UI
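
For context, below is a minimal, self-contained sketch of the flow the updated app.py follows: load the model with vLLM's offline LLM API, generate with SamplingParams, and read back the completion. It uses the documented string-prompt form of LLM.generate() and reads text from the returned RequestOutput; the committed code instead passes pre-tokenized input IDs and decodes output[0].outputs[0].token_ids with the Hugging Face tokenizer. The prompt text and sampling values here are illustrative, not taken from the app, and the device="cpu" argument used by the commit is omitted.

# Minimal sketch of the generation path in the updated app.py (assumptions noted above).
from vllm import LLM, SamplingParams

# Load the same small model the app uses.
llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1)

# Sampling settings; the app wires these to Gradio sliders.
sampling_params = SamplingParams(
    max_tokens=64,
    temperature=0.7,
    top_p=0.9,
)

# Generate from a plain string prompt (illustrative prompt, not from the app).
outputs = llm.generate(["What does vLLM do?"], sampling_params)

# Each RequestOutput holds one or more completions; outputs[0].outputs[0]
# is the first completion for the first prompt.
print(outputs[0].outputs[0].text)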