m1k3wn committed (verified)
Commit 10c106d · 1 Parent(s): 19ec348

Update app.py

Files changed (1)
  1. app.py +6 -21
app.py CHANGED
@@ -28,37 +28,22 @@ async def predict(request: PredictionRequest):
         logger.info(f"Loading model: {request.model}")
         model_path = MODELS[request.model]
 
-        # Load tokenizer and model
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_path,
-            token=HF_TOKEN,
-        )
-
-        model = AutoModelForSeq2SeqLM.from_pretrained(
-            model_path,
-            token=HF_TOKEN,
-            device_map="auto"
-        )
+        tokenizer = AutoTokenizer.from_pretrained(model_path, token=HF_TOKEN)
+        model = AutoModelForSeq2SeqLM.from_pretrained(model_path, token=HF_TOKEN)
 
         full_input = "Interpret this dream: " + request.inputs
-        logger.info(f"Processing: {full_input}")
+        logger.info(f"Processing input: {full_input}")
 
         inputs = tokenizer(
             full_input,
             return_tensors="pt",
             truncation=True,
-            max_length=512,
-            padding=True
-        )
-
-        outputs = model.generate(
-            **inputs,
-            max_length=200,
-            num_beams=4,
-            no_repeat_ngram_size=2
+            max_length=512
         )
 
+        outputs = model.generate(**inputs, max_length=200)
         result = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
         return PredictionResponse(generated_text=result)
 
     except Exception as e:
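
For context, here is a minimal sketch of the full predict endpoint as it stands after this commit. Only the hunk above appears in the diff, so everything outside it (the FastAPI app object, the PredictionRequest/PredictionResponse models, the MODELS registry, the HF_TOKEN wiring, and the body of the except branch) is an assumption reconstructed from the context lines, not code confirmed by the commit. Functionally, the commit drops device_map="auto", padding=True, and the beam-search arguments (num_beams=4, no_repeat_ngram_size=2), so generation falls back to generate()'s greedy-decoding defaults.

```python
# Sketch of the endpoint after this commit. Everything outside the diff hunk
# (models, registry, token, error handling) is assumed, not shown in the commit.
import logging
import os

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

logger = logging.getLogger(__name__)

HF_TOKEN = os.environ.get("HF_TOKEN")  # assumed: Hub token read from the environment
MODELS = {
    # assumed registry shape: short model key mapped to a Hub repo id
    "dream-interpreter": "m1k3wn/dream-interpreter",  # hypothetical repo id
}

app = FastAPI()


class PredictionRequest(BaseModel):
    model: str
    inputs: str


class PredictionResponse(BaseModel):
    generated_text: str


@app.post("/predict")
async def predict(request: PredictionRequest):
    try:
        logger.info(f"Loading model: {request.model}")
        model_path = MODELS[request.model]

        # Loading was simplified by this commit: no explicit device_map, so
        # the model stays on the default device (CPU unless moved manually).
        tokenizer = AutoTokenizer.from_pretrained(model_path, token=HF_TOKEN)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_path, token=HF_TOKEN)

        full_input = "Interpret this dream: " + request.inputs
        logger.info(f"Processing input: {full_input}")

        inputs = tokenizer(
            full_input,
            return_tensors="pt",
            truncation=True,
            max_length=512
        )

        # The beam-search arguments were removed, so generate() now uses its
        # defaults (greedy decoding, num_beams=1).
        outputs = model.generate(**inputs, max_length=200)
        result = tokenizer.decode(outputs[0], skip_special_tokens=True)

        return PredictionResponse(generated_text=result)

    except Exception as e:
        # assumed error handling; the commit does not show this branch's body
        logger.error(f"Prediction failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
```

Assuming the app is served with uvicorn (e.g. uvicorn app:app), POSTing a JSON body such as {"model": "dream-interpreter", "inputs": "I was flying over the sea"} to /predict should return a JSON response with a generated_text field.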