103 lines
3.0 KiB
Python
103 lines
3.0 KiB
Python
from __future__ import annotations
|
|
|
|
import tempfile
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import Quartz
|
|
import Vision
|
|
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
|
|
|
|
# Application instance that the route decorators in this file attach to.
app = FastAPI(
    title="FastAPI Vision OCR",
    version="0.1.0",
)
|
|
|
|
|
|
def recognize_text(
    image_path: Path,
    recognition_level: str = "accurate",
    languages: list[str] | None = None,
) -> dict[str, Any]:
    """Run Vision text recognition on the first image stored in a file.

    Args:
        image_path: Filesystem path of the image to decode.
        recognition_level: ``"fast"`` selects the fast recognition level;
            any other value selects the accurate level.
        languages: Recognition languages to set on the request. When empty
            or ``None`` the request's languages are left untouched.

    Returns:
        A dict with two keys: ``"text"`` holds every recognized line joined
        with newlines, and ``"lines"`` holds one
        ``{"text": str, "confidence": float}`` dict per observation that
        yielded a candidate.

    Raises:
        ValueError: If no URL or image source can be created for the path,
            or the image at index 0 cannot be decoded.
        RuntimeError: If the Vision request handler reports failure.
    """
    # The length argument counts bytes of the encoded buffer, not characters
    # of the str; the two differ as soon as the path contains non-ASCII text,
    # so encode once and measure the encoded form.
    path_bytes = str(image_path).encode("utf-8")
    image_url = Quartz.CFURLCreateFromFileSystemRepresentation(
        None,
        path_bytes,
        len(path_bytes),
        False,
    )
    if image_url is None:
        raise ValueError("Invalid image path.")

    image_source = Quartz.CGImageSourceCreateWithURL(image_url, None)
    if image_source is None:
        raise ValueError("Unsupported image format.")

    cg_image = Quartz.CGImageSourceCreateImageAtIndex(image_source, 0, None)
    if cg_image is None:
        raise ValueError("Failed to decode image.")

    request = Vision.VNRecognizeTextRequest.alloc().init()
    request.setRecognitionLevel_(
        Vision.VNRequestTextRecognitionLevelFast
        if recognition_level == "fast"
        else Vision.VNRequestTextRecognitionLevelAccurate
    )
    request.setUsesLanguageCorrection_(True)

    if languages:
        request.setRecognitionLanguages_(languages)

    handler = Vision.VNImageRequestHandler.alloc().initWithCGImage_options_(
        cg_image, None
    )
    success, error = handler.performRequests_error_([request], None)
    if not success:
        message = str(error) if error else "Vision OCR failed."
        raise RuntimeError(message)

    lines: list[dict[str, Any]] = []
    for observation in request.results() or []:
        # Only the single best candidate per observation is reported.
        candidates = observation.topCandidates_(1)
        if not candidates:
            continue
        best = candidates[0]
        lines.append(
            {
                "text": str(best.string()),
                "confidence": float(best.confidence()),
            }
        )

    return {
        "text": "\n".join(line["text"] for line in lines),
        "lines": lines,
    }
|
|
|
|
|
|
@app.get("/health")
def health() -> dict[str, str]:
    """Return the fixed payload served by the ``/health`` route."""
    payload: dict[str, str] = {"status": "ok"}
    return payload
|
|
|
|
|
|
@app.post("/ocr")
async def ocr(
    image: UploadFile = File(...),
    recognition_level: str = Form("accurate"),
    languages: list[str] | None = Form(None),
) -> dict[str, Any]:
    """Handle ``POST /ocr``: spool the upload to disk and run text recognition.

    Args:
        image: Uploaded image file.
        recognition_level: Must be ``"fast"`` or ``"accurate"``.
        languages: Optional recognition languages forwarded to
            ``recognize_text``.

    Returns:
        The dict produced by ``recognize_text`` for the uploaded image.

    Raises:
        HTTPException: 400 when ``recognition_level`` is not an accepted
            value or ``recognize_text`` raises ``ValueError``; 500 for any
            other failure while writing the upload or running recognition.
    """
    if recognition_level not in {"fast", "accurate"}:
        raise HTTPException(
            status_code=400,
            detail="recognition_level must be fast or accurate",
        )

    # The extension comes from the client-supplied filename, so only a short
    # ASCII alphanumeric one is reused in the temp-file name.
    suffix = Path(image.filename or "upload.bin").suffix or ".bin"
    if len(suffix) > 16 or not (suffix.isascii() and suffix[1:].isalnum()):
        suffix = ".bin"

    # Read the upload before any file exists so a failed read leaves nothing
    # behind on disk.
    payload = await image.read()

    # delete=False keeps the file after the handle closes so it can be
    # reopened by path; the finally clause below is then responsible for
    # removing it.
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp_path = Path(tmp.name)

    try:
        # Writing inside the try guarantees cleanup even if the write fails.
        tmp_path.write_bytes(payload)
        return recognize_text(
            image_path=tmp_path,
            recognition_level=recognition_level,
            languages=languages,
        )
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    finally:
        tmp_path.unlink(missing_ok=True)
|