Spaces:
Runtime error
Runtime error
Commit
·
2e1968d
1
Parent(s):
490c46c
gpu inference
Browse files
app.py
CHANGED
@@ -16,6 +16,9 @@ DEBIASING_KEYWORDS = [
|
|
16 |
"(rude) ", "(sexually explicit) ", "(hateful) ", "(aggressive) ", "(racist) ", "(threat) ", "(violent) ", "(sexist) "
|
17 |
]
|
18 |
|
|
|
|
|
|
|
19 |
def debias(prompt, model,use_prefix, max_length=50, num_beam=3):
|
20 |
"""
|
21 |
Debiasing inference function.
|
@@ -24,7 +27,7 @@ def debias(prompt, model,use_prefix, max_length=50, num_beam=3):
|
|
24 |
:param max_length: The maximum length of the output sentence.
|
25 |
:return: The debiased output sentence.
|
26 |
"""
|
27 |
-
wrapper = GPT2Wrapper(model_name=str(model), use_cuda=
|
28 |
if use_prefix == 'Prefixes':
|
29 |
debiasing_prefixes = DEBIASING_PREFIXES
|
30 |
else:
|
|
|
16 |
"(rude) ", "(sexually explicit) ", "(hateful) ", "(aggressive) ", "(racist) ", "(threat) ", "(violent) ", "(sexist) "
|
17 |
]
|
18 |
|
19 |
+
if torch.cuda.is_available():
|
20 |
+
use_cuda = True
|
21 |
+
|
22 |
def debias(prompt, model,use_prefix, max_length=50, num_beam=3):
|
23 |
"""
|
24 |
Debiasing inference function.
|
|
|
27 |
:param max_length: The maximum length of the output sentence.
|
28 |
:return: The debiased output sentence.
|
29 |
"""
|
30 |
+
wrapper = GPT2Wrapper(model_name=str(model), use_cuda=use_cuda)
|
31 |
if use_prefix == 'Prefixes':
|
32 |
debiasing_prefixes = DEBIASING_PREFIXES
|
33 |
else:
|