harsh-manvar committed
Commit 72c2e54 · verified · 1 Parent(s): ef9bb69

Update app.py

Files changed (1):
  1. app.py +5 -8
app.py CHANGED
@@ -10,9 +10,6 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
 vllm_model = LLM(model=model_name, tensor_parallel_size=1, device="cpu")

 def generate_response(prompt, max_tokens, temperature, top_p):
-    # Tokenize the prompt
-    inputs = tokenizer(prompt, return_tensors="pt")["input_ids"].tolist()[0]
-
     # Define sampling parameters
     sampling_params = SamplingParams(
         max_tokens=max_tokens,
@@ -20,11 +17,11 @@ def generate_response(prompt, max_tokens, temperature, top_p):
         top_p=top_p,
     )

-    # Generate text using vLLM
-    output = vllm_model.generate(inputs, sampling_params)
+    # Generate text using vLLM (input is the raw string `prompt`)
+    output = vllm_model.generate(prompt, sampling_params)

-    # Decode the generated tokens to text
-    generated_text = tokenizer.decode(output[0].outputs[0].token_ids, skip_special_tokens=True)
+    # Extract and decode the generated tokens
+    generated_text = output[0].outputs[0].text
     return generated_text

 # Gradio UI
@@ -76,4 +73,4 @@ with gr.Blocks() as demo:
     )

 # Launch the app
-demo.launch()
+demo.launch()
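
For context, here is a minimal sketch of how generate_response reads once this commit is applied: vLLM takes the raw prompt string and handles tokenization and decoding internally, so the manual tokenizer(...) / tokenizer.decode(...) round trip can be dropped. The model_name value and the temperature=temperature argument are assumptions (neither appears in the hunks shown), and the AutoTokenizer loaded earlier in app.py is omitted here because the function no longer uses it; everything else follows the diff.

from vllm import LLM, SamplingParams

# Placeholder only; app.py defines the real model_name earlier in the file.
model_name = "<your-model-id>"

vllm_model = LLM(model=model_name, tensor_parallel_size=1, device="cpu")

def generate_response(prompt, max_tokens, temperature, top_p):
    # Sampling parameters forwarded from the Gradio controls
    # (temperature=temperature is assumed; that line sits between the two hunks).
    sampling_params = SamplingParams(
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
    )

    # vLLM accepts the raw prompt string and tokenizes it with its own tokenizer.
    output = vllm_model.generate(prompt, sampling_params)

    # Each RequestOutput holds CompletionOutputs whose .text is already decoded.
    generated_text = output[0].outputs[0].text
    return generated_text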