m1k3wn committed on
Commit
19ec348
·
verified ·
1 Parent(s): 77ad07b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -32
app.py CHANGED
@@ -1,9 +1,8 @@
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
- from transformers import T5Tokenizer, T5ForConditionalGeneration
4
  import logging
5
  import os
6
- import json
7
 
8
  logging.basicConfig(level=logging.DEBUG)
9
  logger = logging.getLogger(__name__)
@@ -16,17 +15,6 @@ MODELS = {
16
  "nidra-v2": "m1k3wn/nidra-v2"
17
  }
18
 
19
- # Define the tokenizer configuration explicitly
20
- TOKENIZER_CONFIG = {
21
- "model_max_length": 512,
22
- "clean_up_tokenization_spaces": False,
23
- "tokenizer_class": "T5Tokenizer",
24
- "pad_token": "<pad>",
25
- "eos_token": "</s>",
26
- "unk_token": "<unk>",
27
- "extra_ids": 100
28
- }
29
-
30
  class PredictionRequest(BaseModel):
31
  inputs: str
32
  model: str = "nidra-v1"
@@ -40,56 +28,43 @@ async def predict(request: PredictionRequest):
40
  logger.info(f"Loading model: {request.model}")
41
  model_path = MODELS[request.model]
42
 
43
- # Initialize tokenizer with explicit config
44
- tokenizer = T5Tokenizer.from_pretrained(
45
  model_path,
46
  token=HF_TOKEN,
47
- model_max_length=TOKENIZER_CONFIG["model_max_length"],
48
- clean_up_tokenization_spaces=TOKENIZER_CONFIG["clean_up_tokenization_spaces"],
49
- pad_token=TOKENIZER_CONFIG["pad_token"],
50
- eos_token=TOKENIZER_CONFIG["eos_token"],
51
- unk_token=TOKENIZER_CONFIG["unk_token"],
52
- extra_ids=TOKENIZER_CONFIG["extra_ids"],
53
- use_fast=True # Try forcing the fast tokenizer
54
  )
55
 
56
- model = T5ForConditionalGeneration.from_pretrained(
57
  model_path,
58
  token=HF_TOKEN,
59
- torch_dtype="auto"
60
  )
61
 
62
  full_input = "Interpret this dream: " + request.inputs
63
  logger.info(f"Processing: {full_input}")
64
 
65
- # Add explicit encoding parameters
66
  inputs = tokenizer(
67
  full_input,
68
  return_tensors="pt",
69
  truncation=True,
70
  max_length=512,
71
- padding=True,
72
- add_special_tokens=True
73
  )
74
 
75
  outputs = model.generate(
76
  **inputs,
77
  max_length=200,
78
  num_beams=4,
79
- no_repeat_ngram_size=2,
80
- length_penalty=1.0
81
  )
82
 
83
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
84
- logger.info(f"Generated result: {result}")
85
-
86
  return PredictionResponse(generated_text=result)
87
 
88
  except Exception as e:
89
  logger.error(f"Error: {str(e)}")
90
  raise HTTPException(status_code=500, detail=str(e))
91
 
92
- # Add health check endpoint
93
  @app.get("/health")
94
  async def health():
95
  return {"status": "healthy"}
 
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
  import logging
5
  import os
 
6
 
7
  logging.basicConfig(level=logging.DEBUG)
8
  logger = logging.getLogger(__name__)
 
15
  "nidra-v2": "m1k3wn/nidra-v2"
16
  }
17
 
 
 
 
 
 
 
 
 
 
 
 
18
class PredictionRequest(BaseModel):
    """Request payload for the dream-interpretation prediction endpoint."""

    # Dream text to interpret.
    inputs: str
    # Registered model alias to load; defaults to the v1 checkpoint.
    # NOTE(review): presumably validated against the MODELS mapping by the
    # endpoint — an unknown alias appears to surface as a KeyError there.
    model: str = "nidra-v1"
 
28
  logger.info(f"Loading model: {request.model}")
29
  model_path = MODELS[request.model]
30
 
31
+ # Load tokenizer and model
32
+ tokenizer = AutoTokenizer.from_pretrained(
33
  model_path,
34
  token=HF_TOKEN,
 
 
 
 
 
 
 
35
  )
36
 
37
+ model = AutoModelForSeq2SeqLM.from_pretrained(
38
  model_path,
39
  token=HF_TOKEN,
40
+ device_map="auto"
41
  )
42
 
43
  full_input = "Interpret this dream: " + request.inputs
44
  logger.info(f"Processing: {full_input}")
45
 
 
46
  inputs = tokenizer(
47
  full_input,
48
  return_tensors="pt",
49
  truncation=True,
50
  max_length=512,
51
+ padding=True
 
52
  )
53
 
54
  outputs = model.generate(
55
  **inputs,
56
  max_length=200,
57
  num_beams=4,
58
+ no_repeat_ngram_size=2
 
59
  )
60
 
61
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
62
  return PredictionResponse(generated_text=result)
63
 
64
  except Exception as e:
65
  logger.error(f"Error: {str(e)}")
66
  raise HTTPException(status_code=500, detail=str(e))
67
 
 
68
  @app.get("/health")
69
  async def health():
70
  return {"status": "healthy"}