import pandas as pd
import torch
from datasets import load_dataset


def evaluate(model, examples, num_classes=2):
    """Score ``model`` on ``examples`` and collect per-class probabilities.

    Args:
        model: Callable that receives one example's ``"text"`` value and
            returns a tensor of class scores. Softmax is taken over dim 0,
            so a 1-D tensor with ``num_classes`` entries is expected.
        examples: Iterable of mappings with ``"text"`` and ``"label"`` keys.
        num_classes: Number of probability columns in the returned frame.

    Returns:
        A ``(submission, accuracy)`` tuple. ``submission`` is a DataFrame
        with one row per example (indexed 0..n-1) and columns
        ``0..num_classes-1`` holding the softmax probabilities.
        ``accuracy`` is the fraction of examples whose arg-max class equals
        the label, or ``0.0`` when ``examples`` is empty.

    Raises:
        TypeError: If ``model`` is not callable (e.g. still ``None``).
        ValueError: If a prediction does not have ``num_classes`` entries
            (raised by pandas when the frame is built).
    """
    if not callable(model):
        raise TypeError(
            f"model must be a callable mapping text to class scores, "
            f"got {type(model).__name__}"
        )

    rows = []
    correct = 0
    # Pure inference: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        for example in examples:
            probs = torch.softmax(model(example["text"]), dim=0)
            rows.append(probs.tolist())
            correct += int(torch.argmax(probs).item() == example["label"])

    # Build the frame once instead of assigning row by row through .loc.
    submission = pd.DataFrame(rows, columns=list(range(num_classes)))
    accuracy = correct / len(rows) if rows else 0.0
    return submission, accuracy


if __name__ == "__main__":
    imdb = load_dataset("imdb")

    # Placeholder: no model is defined in this script yet. It must be replaced
    # by a callable that maps a review text to a tensor of 2 class scores
    # before the evaluation below can run.
    model = None

    # Evaluate on the test split, report accuracy, and save the per-class
    # probabilities to submission.csv.
    submission, accuracy = evaluate(model, imdb["test"])
    print("Accuracy: ", accuracy)

    submission.to_csv('submission.csv', index_label='idx')