MienOlle commited on
Commit
a44c152
·
1 Parent(s): fe61092

Batch predict fixes

Browse files
Files changed (1) hide show
  1. main.py +6 -4
main.py CHANGED
@@ -4,6 +4,7 @@ from transformers import AutoModelForSequenceClassification as modelSC, AutoToke
4
  from fastapi import FastAPI
5
  from pydantic import BaseModel
6
  import os
 
7
 
8
  app = FastAPI()
9
  os.environ["HF_HOME"] = "/tmp/huggingface"
@@ -23,19 +24,20 @@ model.to(device)
23
  model.eval()
24
 
25
  class TextInput(BaseModel):
26
- text: str
27
 
28
  def predict(input):
29
- inputs = modelToken(input, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
 
30
 
31
  with torch.no_grad():
32
  outputs = model(**inputs)
33
 
34
  logits = outputs.logits
35
- ret = logits.argmax().item()
36
 
37
  labels = ["positive", "neutral", "negative"]
38
- return {labels[ret]}
39
 
40
  @app.post("/predict")
41
  def get_sentiment(data: TextInput):
 
4
  from fastapi import FastAPI
5
  from pydantic import BaseModel
6
  import os
7
+ from typing import List
8
 
9
  app = FastAPI()
10
  os.environ["HF_HOME"] = "/tmp/huggingface"
 
24
  model.eval()
25
 
26
  class TextInput(BaseModel):
27
+ text: List[str]
28
 
29
  def predict(input):
30
+ inputs = modelToken(input, return_tensors="pt", padding=True, truncation=True, max_length=512)
31
+ inputs = {key: tensor.to(device) for key, tensor in inputs.items()}
32
 
33
  with torch.no_grad():
34
  outputs = model(**inputs)
35
 
36
  logits = outputs.logits
37
+ rets = logits.argmax(dim = 1).tolist()
38
 
39
  labels = ["positive", "neutral", "negative"]
40
+ return {[labels[ret] for ret in rets]}
41
 
42
  @app.post("/predict")
43
  def get_sentiment(data: TextInput):