Tags: Text Generation · Transformers · PyTorch · English · llama · text-generation-inference
Conrad Lippert-Zajaczkowski committed ba9cbaf · 1 parent: 26b366e

refined input

Files changed (1): handler.py (+12 -5)
handler.py CHANGED
@@ -17,11 +17,11 @@ class EndpointHandler:
     def __init__(self, path=""):
         # load the model
         print('starting to load tokenizer')
-        tokenizer = LlamaTokenizer.from_pretrained("/repository/orca_tokenizer", local_files_only=True)
+        self.tokenizer = LlamaTokenizer.from_pretrained("/repository/orca_tokenizer", local_files_only=True)
         print('loaded tokenizer')
         gpu_info1 = nvmlDeviceGetMemoryInfo(gpu_h1)
         print(f'vram {gpu_info1.total} used {gpu_info1.used} free {gpu_info1.free}')
-        model = LlamaForCausalLM.from_pretrained(
+        self.model = LlamaForCausalLM.from_pretrained(
             "/repository",
             device_map="auto",
             torch_dtype=dtype,
@@ -33,10 +33,10 @@ class EndpointHandler:

         print('loaded model')
         # create inference pipeline
-        self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
+        self.pipeline = pipeline("text-generation", model=self.model, tokenizer=self.tokenizer)
         print('created pipeline')

-    def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
+    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         print('starting to call')
         inputs = data.pop("inputs", data)
         print('inputs: ', inputs)
@@ -46,6 +46,13 @@ class EndpointHandler:
         if parameters is not None:
             prediction = self.pipeline(inputs, **parameters)
         else:
-            prediction = self.pipeline(inputs)
+            prediction = self.pipeline(
+                inputs,
+                do_sample=True,
+                top_k=10,
+                num_return_sequences=1,
+                eos_token_id=self.tokenizer.eos_token_id,
+                max_length=200,
+            )
         # postprocess the prediction
         return prediction
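
For reference, a minimal sketch of how the refined handler might be exercised locally. The import path, instantiation, and payloads below are illustrative assumptions, not part of the commit; it also assumes the elided lines around old line 45 read "parameters" out of the request dict, as the `if parameters is not None:` branch implies.

# Hypothetical local smoke test for the handler above (not from the commit).
# Assumes handler.py is importable and that the model/tokenizer paths it
# hardcodes ("/repository", "/repository/orca_tokenizer") exist locally.
from handler import EndpointHandler

handler = EndpointHandler(path="/repository")

# No "parameters" key: the new defaults apply
# (do_sample=True, top_k=10, num_return_sequences=1, max_length=200).
default_out = handler({"inputs": "Explain LLaMA in one sentence."})

# An explicit "parameters" dict is forwarded to the pipeline unchanged.
custom_out = handler({
    "inputs": "Explain LLaMA in one sentence.",
    "parameters": {"do_sample": True, "top_k": 50, "max_length": 120},
})

# Per the new annotation, each call returns List[Dict[str, Any]],
# e.g. [{"generated_text": "..."}] for a text-generation pipeline.
print(default_out[0]["generated_text"])

Binding the tokenizer and model to self (rather than leaving them as locals in __init__) also keeps them reachable after construction, which is what lets the default branch in __call__ reference self.tokenizer.eos_token_id.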