Adapting commited on
Commit
675214a
·
1 Parent(s): 6085381
Files changed (1) hide show
  1. inference_hf/_inference.py +24 -0
inference_hf/_inference.py CHANGED
@@ -1,6 +1,8 @@
1
  import json
2
  import requests
3
  from typing import Union,List
 
 
4
 
5
  class InferenceHF:
6
  headers = {"Authorization": f"Bearer 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"}
@@ -19,6 +21,21 @@ class InferenceHF:
19
  response = requests.request("POST", cls.API_URL+model_name, headers=cls.headers, data=data)
20
  return json.loads(response.content.decode("utf-8"))
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  if __name__ == '__main__':
24
  print(InferenceHF.inference(
@@ -26,4 +43,11 @@ if __name__ == '__main__':
26
  model_name= 't5-small'
27
  ))
28
 
 
 
 
 
 
 
 
29
 
 
1
  import json
2
  import requests
3
  from typing import Union,List
4
+ import aiohttp
5
+ from asyncio import run
6
 
7
  class InferenceHF:
8
  headers = {"Authorization": f"Bearer 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"}
 
21
  response = requests.request("POST", cls.API_URL+model_name, headers=cls.headers, data=data)
22
  return json.loads(response.content.decode("utf-8"))
23
 
24
+ @classmethod
25
+ async def async_inference(cls, inputs: Union[List[str], str], model_name: str) -> dict:
26
+ payload = dict(
27
+ inputs=inputs,
28
+ options=dict(
29
+ wait_for_model=True
30
+ )
31
+ )
32
+
33
+ data = json.dumps(payload)
34
+
35
+ async with aiohttp.ClientSession() as session:
36
+ async with session.post(cls.API_URL + model_name, data=data, headers=cls.headers) as response:
37
+ return await response.json()
38
+
39
 
40
  if __name__ == '__main__':
41
  print(InferenceHF.inference(
 
43
  model_name= 't5-small'
44
  ))
45
 
46
+ print(
47
+ run(InferenceHF.async_inference(
48
+ inputs='hi how are you?',
49
+ model_name='t5-small'
50
+ ))
51
+ )
52
+
53