Spaces:
Sleeping
Sleeping
File size: 3,207 Bytes
e7f2ddb 3a397cd e7f2ddb 7da3d86 e7f2ddb 7da3d86 e7f2ddb 7da3d86 e7f2ddb 6a0786f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import torch
from tokenizers import Tokenizer
from torch.utils.data import DataLoader
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel, Field
from model import CustomDataset, TransformerEncoder, load_model_to_cpu
# ASGI application object; the /predict/ route below is registered on it and
# it is the app handed to uvicorn when this file is run as a script.
app = FastAPI()
# Label vocabulary of the tagger.  "olumsuz" / "nötr" / "olumlu" are Turkish
# for negative / neutral / positive.  predict_fonk skips ids 0 ("O") and
# 4 ("org") as entity starts and uses id 4 to extend a preceding entity.
tag2id = {
    "O": 0,
    "olumsuz": 1,
    "nötr": 2,
    "olumlu": 3,
    "org": 4,
}

# Reverse lookup: predicted class id -> label string.
id2tag = dict(zip(tag2id.values(), tag2id.keys()))

# Inference always runs on the CPU.
device = torch.device("cpu")
def predict_fonk(model, device, example, tokenizer):
    """Tag one text and return the entities found with their sentiment.

    The text is tokenized, wrapped in a one-item ``CustomDataset`` and run
    through ``model`` in evaluation mode with gradients disabled.  Every
    non-padding token whose predicted class is neither 0 ("O") nor 4 ("org")
    starts an entity; the tokens that directly follow it and are predicted
    as 4 are appended to the entity text, separated by single spaces.

    Args:
        model: Called as ``model(input_ids, attention_mask)``; the output is
            flattened to one row of class scores per token.
        device: Device the model and the input tensors are moved to.
        example: Raw text to analyse.
        tokenizer: Object whose ``encode(example)`` result exposes
            ``tokens``, ``ids``, ``attention_mask`` and ``type_ids``.

    Returns:
        A dict with two keys: ``"entity_list"`` holds the unique entity
        strings in order of first appearance, and ``"results"`` holds one
        ``{"entity": ..., "sentiment": ...}`` dict per unique entity, where
        the sentiment is the ``id2tag`` label of the entity's first token.
    """
    # Class ids as defined in the module-level tag2id mapping.
    outside_id = 0
    org_id = 4

    model.to(device)
    model.eval()

    encoding = tokenizer.encode(example)
    attention_mask = encoding.attention_mask

    # The fifth CustomDataset argument is the label sequence; at inference
    # time the type ids are passed as a stand-in (presumably unused here --
    # TODO confirm against CustomDataset in model.py).
    predict_data = CustomDataset(
        [encoding.tokens],
        [encoding.ids],
        [attention_mask],
        [encoding.type_ids],
        [encoding.type_ids],
    )
    predict_loader = DataLoader(predict_data, batch_size=1, shuffle=False)

    predictions = []
    with torch.no_grad():
        for batch in predict_loader:
            batch_input_ids = batch['input_ids'].to(device)
            batch_att_mask = batch['attention_mask'].to(device)
            outputs = model(batch_input_ids, batch_att_mask)
            # Flatten to (tokens, classes) and keep the best class per token.
            logits = outputs.view(-1, outputs.size(-1))
            _, predicted = torch.max(logits, 1)
            predictions.append(predicted)

    # Materialise the per-token sequences once instead of rebuilding them for
    # every token inside the loop below.
    tokens = predict_loader.dataset[0]["text"]
    labels = predictions[0].tolist()
    # Only positions present in all three sequences can be inspected safely.
    limit = min(len(tokens), len(labels), len(attention_mask))

    entity_list = []
    results_list = []
    for start in range(limit):
        label = labels[start]
        if attention_mask[start] == 0 or label in (outside_id, org_id):
            continue

        # Extend the entity over the directly following "org" tokens.  The
        # scan stops at padding so pad tokens never leak into the entity text.
        entity = tokens[start]
        follower = start + 1
        while (
            follower < limit
            and labels[follower] == org_id
            and attention_mask[follower] != 0
        ):
            entity = entity + " " + tokens[follower]
            follower += 1

        # Report each distinct entity once, with the sentiment of its first
        # occurrence.
        if entity not in entity_list:
            entity_list.append(entity)
            results_list.append({"entity": entity, "sentiment": id2tag.get(label)})

    return {"entity_list": entity_list, "results": results_list}
# Instantiate the network and hand it to load_model_to_cpu together with the
# checkpoint path (presumably this loads the trained weights -- confirm in
# model.py).  The returned object is what the endpoint uses.
model = load_model_to_cpu(TransformerEncoder(), "model.pth")

# Tokenizer definition, read from a path relative to the working directory.
tokenizer = Tokenizer.from_file("tokenizer.json")
# Sample input passed to Field as the example value of Item.text.
_EXAMPLE_TEXT = """Fiber 100mb SuperOnline kullanıcısıyım yaklaşık 2 haftadır @Twitch @Kick_Turkey gibi canlı yayın platformlarında 360p yayın izlerken donmalar yaşıyoruz. Başka hiç bir operatörler bu sorunu yaşamazken ben parasını verip alamadığım hizmeti neden ödeyeyim ? @Turkcell """


class Item(BaseModel):
    # Request body of the /predict/ endpoint: the raw text to analyse.
    # The field is required (Ellipsis default).
    text: str = Field(..., example=_EXAMPLE_TEXT)
@app.post("/predict/", response_model=dict)
async def predict(item: Item):
    # Runs predict_fonk on the submitted text with the module-level model,
    # tokenizer and device, and returns its result dict unchanged.
    # Template notes (translated from Turkish): the model's output goes here;
    # the output format should follow the expected example.
    return predict_fonk(
        example=item.text,
        model=model,
        tokenizer=tokenizer,
        device=device,
    )
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces,
    # port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)