poltextlab commited on
Commit
ae818da
·
verified ·
1 Parent(s): ce666e1

add war domain

Browse files
Files changed (1) hide show
  1. interfaces/illframes.py +6 -13
interfaces/illframes.py CHANGED
@@ -8,7 +8,7 @@ from transformers import AutoModelForSequenceClassification
8
  from transformers import AutoTokenizer
9
  from huggingface_hub import HfApi
10
 
11
- from label_dicts import ILLFRAMES_MIGRATION_LABEL_NAMES, ILLFRAMES_COVID_LABEL_NAMES
12
 
13
  HF_TOKEN = os.environ["hf_read"]
14
 
@@ -18,7 +18,8 @@ languages = [
18
 
19
  domains = {
20
  "Covid": "covid",
21
- "Migration": "migration"
 
22
  }
23
 
24
 
@@ -55,16 +56,6 @@ def build_huggingface_path(domain: str):
55
 
56
  def predict(text, model_id, tokenizer_id, label_names):
57
  device = torch.device("cpu")
58
-
59
- # --- DEBUG ---
60
-
61
- disk_space = get_disk_space('/data/')
62
- print("Disk Space Info:")
63
- for key, value in disk_space.items():
64
- print(f"{key}: {value}")
65
-
66
- # ---
67
-
68
  try:
69
  model = AutoModelForSequenceClassification.from_pretrained(model_id, low_cpu_mem_usage=True, offload_folder="offload", device_map="auto", token=HF_TOKEN)
70
  except:
@@ -101,8 +92,10 @@ def predict_illframes(text, language, domain):
101
 
102
  if domain == "migration":
103
  label_names = ILLFRAMES_MIGRATION_LABEL_NAMES
104
- else:
105
  label_names = ILLFRAMES_COVID_LABEL_NAMES
 
 
106
 
107
  return predict(text, model_id, tokenizer_id, label_names)
108
 
 
8
  from transformers import AutoTokenizer
9
  from huggingface_hub import HfApi
10
 
11
+ from label_dicts import ILLFRAMES_MIGRATION_LABEL_NAMES, ILLFRAMES_COVID_LABEL_NAMES, ILLFRAMES_WAR_LABEL_NAMES
12
 
13
  HF_TOKEN = os.environ["hf_read"]
14
 
 
18
 
19
  domains = {
20
  "Covid": "covid",
21
+ "Migration": "migration",
22
+ "War": "war"
23
  }
24
 
25
 
 
56
 
57
  def predict(text, model_id, tokenizer_id, label_names):
58
  device = torch.device("cpu")
 
 
 
 
 
 
 
 
 
 
59
  try:
60
  model = AutoModelForSequenceClassification.from_pretrained(model_id, low_cpu_mem_usage=True, offload_folder="offload", device_map="auto", token=HF_TOKEN)
61
  except:
 
92
 
93
  if domain == "migration":
94
  label_names = ILLFRAMES_MIGRATION_LABEL_NAMES
95
+ elif domain == "covid":
96
  label_names = ILLFRAMES_COVID_LABEL_NAMES
97
+ elif domain == "war":
98
+ label_names = ILLFRAMES_WAR_LABEL_NAMES
99
 
100
  return predict(text, model_id, tokenizer_id, label_names)
101