import os
import time
import warnings
from urllib.request import urlretrieve
from dataclasses import dataclass
import requests
import base64
import PIL.Image
from io import BytesIO
[docs]class BackendClient:
"""A backend client class for raw GET/POST requests to the REST API.
.. warning::
This class should not be accessed directly. Instead, use :class:`CSMClient`
to interface with the API.
Parameters
----------
api_key : str, optional
API key for the CSM account you would like to use. If not provided,
the environment variable :envvar:`CSM_API_KEY` is used instead.
base_url : str
Base url for the API. In general this should not be modified; it is
included only for debugging purposes.
"""
def __init__(
self,
api_key=None,
base_url="https://api.csm.ai",
) -> None:
if api_key is None:
api_key = os.environ.get('CSM_API_KEY')
if api_key is None:
raise Exception(
"The argument `api_key` must be provided when env variable "
"CSM_API_KEY is not set."
)
self.api_key = api_key
self.base_url = base_url
@property
def headers(self):
"""Constructs and returns the HTTP headers required for API requests."""
return {
'Content-Type': 'application/json',
'x-api-key': self.api_key,
}
def _check_http_response(self, response: requests.Response) -> None:
"""Validates the HTTP response for successful status codes.
Raises an error if the response indicates a failed request.
Parameters
----------
response : requests.Response
The HTTP response object to check.
"""
if not (200 <= response.status_code < 300):
raise RuntimeError(
f"HTTP request failed with status code {response.status_code} "
f"({response.reason})"
)
# image-to-3d API
# -------------------------------------------
[docs] def create_image_to_3d_session(
self,
image_url,
*,
generate_preview_mesh=False,
auto_refine=False,
preview_model="fast_sculpt",
## refine args
creativity="lowest",
refine_speed="fast",
polygon_count="high_poly",
topology="tris",
texture_resolution=2048,
scaled_bbox=[],
pivot_point=[0.0, 0.0, 0.0]
) -> dict:
"""Creates an image-to-3D conversion session.
Parameters
----------
image_url : str
URL of the image to convert into a 3D model.
Returns
-------
dict
The response from the API containing session details.
"""
assert creativity in ["lowest", "moderate", "highest"]
assert refine_speed in ["slow", "fast"]
assert polygon_count in ["low_poly", "high_poly"]
assert topology in ["tris", "quads"]
assert preview_model in ["fast_sculpt", "turbo"]
assert 128 <= texture_resolution <= 2048
parameters = {
"image_url": image_url,
"generate_preview_mesh": generate_preview_mesh,
"creativity": creativity,
"refine_speed": refine_speed,
"resolution": polygon_count,
"auto_refine": auto_refine,
"topology": topology,
"texture_resolution": texture_resolution,
"manual_segmentation": False, # TODO: implement this option
"pivot_point": [float(s) for s in pivot_point],
"preview_mesh": preview_model,
}
if len(scaled_bbox) == 3:
parameters["scaled_bbox"] = [float(s) for s in scaled_bbox]
response = requests.post(
url=f"{self.base_url}/image-to-3d-sessions",
json=parameters,
headers=self.headers,
)
self._check_http_response(response) # expected=201
return response.json()
[docs] def get_image_to_3d_session_info(self, session_code) -> dict:
"""Fetches information about an existing image-to-3D session.
Parameters
----------
session_code : str
The session code of the image-to-3D session.
Returns
-------
dict
The response from the API containing session details.
"""
response = requests.get(
url=f"{self.base_url}/image-to-3d-sessions/{session_code}",
headers=self.headers,
)
self._check_http_response(response) # expected=200
return response.json()
[docs] def get_3d_refine(self, session_code, scaled_bbox=[], pivot_point=[0.0, 0.0, 0.0]) -> dict:
"""Requests refinement for an existing 3D model.
Parameters
----------
session_code : str
The session code of the 3D model to refine.
Returns
-------
dict
The response from the API with refinement results.
"""
parameters = {"pivot_point": [s for s in pivot_point]}
if len(scaled_bbox) == 3:
parameters["scaled_bbox"] = [float(s) for s in scaled_bbox]
response = requests.post(
url=f"{self.base_url}/image-to-3d-sessions/get-3d/refine/{session_code}",
json=parameters,
headers=self.headers,
)
return response.json()
[docs] def get_3d_preview(self, session_code, spin_url=None, scaled_bbox=[], pivot_point=[0.0, 0.0, 0.0]) -> dict:
"""Fetches a preview of the generated 3D model.
Parameters
----------
session_code : str
The session code of the 3D model to preview.
Returns
-------
dict
The response from the API with preview data.
"""
selected_spin_index = 0
if spin_url is None:
result = self.get_image_to_3d_session_info(session_code)
spin_url = result['data']['spins'][selected_spin_index]["image_url"]
parameters = {
"selected_spin_index": selected_spin_index,
"selected_spin": spin_url,
"pivot_point": [s for s in pivot_point]
}
if len(scaled_bbox) == 3:
parameters["scaled_bbox"] = [float(s) for s in scaled_bbox]
response = requests.post(
url=f"{self.base_url}/image-to-3d-sessions/get-3d/preview/{session_code}",
json=parameters,
headers=self.headers,
timeout=100,
)
self._check_http_response(response) # expected=200
return response.json()
# text-to-image API methods
# -------------------------------------------
[docs] def create_text_to_image_session(
self,
prompt,
style_id="",
guidance=6,
) -> dict:
"""Creates a text-to-image session.
Parameters
----------
prompt : str
The text prompt for generating an image.
Returns
-------
dict
The response from the API with session details.
"""
parameters = {
'prompt': str(prompt),
'style_id': str(style_id),
'guidance': str(guidance),
}
response = requests.post(
url=f"{self.base_url}/tti-sessions",
json=parameters,
headers=self.headers,
)
self._check_http_response(response)
return response.json()
[docs] def get_text_to_image_session_info(self, session_code: str) -> dict:
"""Fetches information for an existing text-to-image session.
Parameters
----------
session_code : str
The session code of the text-to-image session.
Returns
-------
dict
The response from the API with session details.
"""
response = requests.get(
url=f"{self.base_url}/tti-sessions/{session_code}",
headers=self.headers,
)
self._check_http_response(response)
return response.json()
[docs]@dataclass
class ImageTo3DResult:
"""Output class for image-to-3d generation.
Parameters
----------
session_code : str
The image-to-3d session code.
mesh_path : str
Local path of the generated mesh file.
"""
session_code: str
mesh_path: str
[docs]@dataclass
class TextTo3DResult:
"""Output class for text-to-3d generation.
Parameters
----------
session_code : str
The image-to-3d session code.
mesh_path : str
Local path of the generated mesh file.
image_path : str
Local path of the image generated as part of text-to-3d.
"""
session_code: str
mesh_path: str
image_path: str
[docs]class CSMClient:
"""Core client utility for accessing the CSM API.
Parameters
----------
api_key : str, optional
API key for the CSM account you would like to use. If not provided,
the environment variable :envvar:`CSM_API_KEY` is used instead.
base_url : str
Base url for the API. In general this should not be modified; it is
included only for debugging purposes.
"""
def __init__(
self,
api_key=None,
base_url="https://api.csm.ai",
) -> None:
self.backend = BackendClient(api_key=api_key, base_url=base_url)
def _handle_image_input(self, image) -> str:
"""Handles image input by converting it to a base64-encoded string.
Parameters
----------
image : str or PIL.Image.Image
The input image, either as a file path, URL, or PIL Image.
Returns
-------
str or PIL.Image.Image
The base64 encoded string if local or PIL image, else image URL.
"""
if isinstance(image, str):
if os.path.isfile(image): # local file path
image_path = image
pil_image = PIL.Image.open(image_path)
else: # URL for a web file
image_url = image
return image_url # TODO: verify that this is a valid URL
elif isinstance(image, PIL.Image.Image):
pil_image = image
else:
raise ValueError(f"Encountered unexpected type for the input image.")
return pil_image_to_x64(pil_image)
[docs] def image_to_3d(
self,
image,
*,
generate_spin_video=False,
mesh_format='obj',
output='./',
timeout=200,
verbose=True,
scaled_bbox=[],
pivot_point=[0.0, 0.0, 0.0],
refine_speed="fast",
preview_model="fast_sculpt"
) -> ImageTo3DResult:
"""Generate a 3D mesh from an image.
The input image can be provided as a URL, a local path, or a :class:`PIL.Image.Image`.
Parameters
----------
image : str or PIL.Image.Image
The input image. May be provided as a URL, a local file path, or a
:class:`PIL.Image.Image` instance to be converted into a 3D mesh.
generate_spin_video : bool, optional
If True, a spin video of the generated 3D mesh is created. Defaults to False.
mesh_format : str, optional
The format of the output 3D mesh file. Choices are 'obj', 'glb', or 'usdz'.
Defaults to 'obj'.
output : str, optional
The directory path where output files (mesh and video, if generated)
will be saved. Defaults to the current directory.
timeout : int, optional
The maximum time (in seconds) to wait for the 3D mesh generation.
Defaults to 200 seconds.
verbose : bool, optional
If True, outputs detailed progress information. Defaults to True.
scaled_bbox : list, optional
A 3-element list specifying the scaled bounding box for the generated 3D model.
Defaults to an empty list, meaning no custom bounding box.
pivot_point : list, optional
A 3-element list specifying the pivot point for the 3D model's orientation.
Defaults to [0.0, 0.0, 0.0].
refine_speed : str, optional
The refinement speed for the model generation process. Choices are 'fast'
or 'slow'. Defaults to 'fast'.
preview_model : str, optional
The preview model type to use during 3D mesh creation. Choices are
'fast_sculpt' or 'turbo'. Defaults to 'fast_sculpt'.
Returns
-------
ImageTo3DResult
Result object. Contains the local path of the generated mesh file
and session code.
"""
if generate_spin_video:
warnings.warn(
"The option `generate_spin_video=True` is deprecated and will be removed "
"in a future release", DeprecationWarning)
mesh_format = mesh_format.lower()
if mesh_format not in ['obj', 'glb', 'usdz']:
raise ValueError(
f"Unexpected mesh_format value ('{mesh_format}'). Please choose "
f"from options ['obj', 'glb', 'usdz']."
)
image_url = self._handle_image_input(image)
os.makedirs(output, exist_ok=True)
# initialize session
result = self.backend.create_image_to_3d_session(
image_url,
generate_preview_mesh=not generate_spin_video,
auto_refine=False,
preview_model=preview_model,
scaled_bbox=scaled_bbox,
pivot_point=pivot_point,
refine_speed=refine_speed
)
status = result['data']['status']
if (
(generate_spin_video and status not in ["spin_generate_processing", "spin_generate_done"]) or
(not generate_spin_video and status not in ["training_preview", "preview_done"])
):
raise RuntimeError(f"Image-to-3d session creation failed (status='{status}')")
session_code = result['data']['session_code']
step_label = "spin generation" if generate_spin_video else "mesh generation"
if verbose:
print(f'[INFO] Image-to-3d session created ({session_code})')
if generate_spin_video:
if verbose:
print(f'[INFO] Running preview {step_label}...')
# wait for preview spin generation to complete (20-30s)
start_time = time.time()
run_time = 0.
while True:
time.sleep(2)
result = self.backend.get_image_to_3d_session_info(session_code)
status = result['data']['status']
if status == 'spin_generate_done':
break
elif status == 'spin_generate_failed':
raise RuntimeError(f"Preview {step_label} failed")
elif status != 'spin_generate_processing':
raise RuntimeError(f"Unexpected error during preview {step_label} (status='{status}')")
run_time = time.time() - start_time
if run_time >= timeout:
raise RuntimeError(f"Preview {step_label} timed out")
if verbose:
print(f'[INFO] Preview {step_label} completed in {run_time:.1f}s')
spin_url = result['data']['spins'][0]["image_url"]
# launch preview mesh export
result = self.backend.get_3d_preview(
session_code,
spin_url=spin_url,
scaled_bbox=scaled_bbox,
pivot_point=pivot_point
)
step_label = "mesh export"
if verbose:
print(f'[INFO] Running preview {step_label}...')
# wait for preview mesh export to complete (20-30s)
start_time = time.time()
run_time = 0.
while True:
time.sleep(2)
result = self.backend.get_image_to_3d_session_info(session_code)
status = result['data']['status']
if status == 'preview_done':
break
elif status == 'preview_failed':
raise RuntimeError(f"Preview {step_label} failed.")
elif status != 'training_preview':
raise RuntimeError(f"Unexpected error during preview {step_label} (status='{status}')")
run_time = time.time() - start_time
if run_time >= timeout:
raise RuntimeError(f"Preview {step_label} timed out")
if verbose:
print(f'[INFO] Preview {step_label} completed in {run_time:.1f}s')
# download mesh file based on the requested format
if mesh_format == 'obj':
mesh_url = result['data']['preview_mesh_url_zip']
mesh_file = 'mesh.obj' if preview_model == 'fast_sculpt' else 'mesh.zip'
elif mesh_format == 'glb':
mesh_url = result['data']['preview_mesh_url_glb']
mesh_file = 'mesh.glb'
elif mesh_format == 'usdz':
mesh_url = result['data']['preview_mesh_url_usdz']
mesh_file = 'mesh.usdz'
else:
raise ValueError(f"Encountered unexpected mesh_format value ('{mesh_format}').")
mesh_path = os.path.join(output, mesh_file) # TODO: os.path.abspath ?
urlretrieve(mesh_url, mesh_path)
return ImageTo3DResult(session_code=session_code, mesh_path=mesh_path)
[docs] def text_to_3d(
self,
prompt,
*,
style_id="",
guidance=6,
generate_spin_video=False,
mesh_format='obj',
output='./',
timeout=200,
verbose=True,
scaled_bbox=[],
pivot_point=[0.0, 0.0, 0.0],
refine_speed="fast",
preview_model="fast_sculpt"
) -> TextTo3DResult:
"""Generate a 3D mesh from a text prompt.
Parameters
----------
prompt : str
The input text prompt to generate a 3D model based on text description.
style_id : str, optional
The style ID that influences the visual characteristics of the generated model.
Defaults to an empty string, meaning no specific style is applied.
guidance : int, optional
A parameter that adjusts guidance strength, affecting how closely
the generation follows the input text. Default is 6.
generate_spin_video : bool, optional
If True, a spin video of the generated 3D mesh is created. Defaults to False.
mesh_format : str, optional
The format of the output 3D mesh file. Choices are 'obj', 'glb', or 'usdz'.
Defaults to 'obj'.
output : str, optional
The directory path where output files (mesh and video, if generated)
will be saved. Defaults to the current directory.
timeout : int, optional
The maximum time (in seconds) to wait for the 3D mesh generation.
Defaults to 200 seconds.
verbose : bool, optional
If True, outputs detailed progress information. Defaults to True.
scaled_bbox : list, optional
A 3-element list specifying the scaled bounding box for the generated 3D model.
Defaults to an empty list, meaning no custom bounding box.
pivot_point : list, optional
A 3-element list specifying the pivot point for the 3D model's orientation.
Defaults to [0.0, 0.0, 0.0].
refine_speed : str, optional
The refinement speed for the model generation process. Choices are 'fast'
or 'slow'. Defaults to 'fast'.
preview_model : str, optional
The preview model type to use during 3D mesh creation. Choices are
'fast_sculpt' or 'turbo'. Defaults to 'fast_sculpt'.
Returns
-------
TextTo3DResult
Result object. Contains the local path of the generated mesh file,
as well as the image generated as part of the pipeline, and session code.
"""
os.makedirs(output, exist_ok=True)
# initialize text-to-image session
result = self.backend.create_text_to_image_session(
prompt,
style_id=style_id,
guidance=guidance,
)
status = result['data']['status']
if status != "processing" and status != "completed":
raise RuntimeError(f"Text-to-image session creation failed (status='{status}')")
session_code = result['data']['session_code']
if verbose:
print(f'[INFO] Text-to-image session created ({session_code})')
print(f'[INFO] Running text-to-image generation...')
# wait for image generation to complete
start_time = time.time()
run_time = 0.
while True:
time.sleep(2)
result = self.backend.get_text_to_image_session_info(session_code)
status = result['data']['status']
if status == 'completed':
break
elif status != 'processing':
raise RuntimeError(f"Unexpected error during text-to-image generation (status='{status}')")
run_time = time.time() - start_time
if run_time >= timeout:
raise RuntimeError("Text-to-image generation timed out")
if verbose:
print(f'[INFO] Text-to-image generation completed in {run_time:.1f}s')
# access the image URL
image_url = result['data']['image_url']
# download image
image_path = os.path.join(output, 'image.png')
urlretrieve(image_url, image_path)
# launch image-to-3d
i23 = self.image_to_3d(
image_url,
generate_spin_video=generate_spin_video,
mesh_format=mesh_format,
output=output,
timeout=timeout,
verbose=verbose,
scaled_bbox=scaled_bbox,
pivot_point=pivot_point,
refine_speed=refine_speed,
preview_model=preview_model
)
return TextTo3DResult(session_code=i23.session_code, mesh_path=i23.mesh_path, image_path=image_path)
def pil_image_to_x64(image: PIL.Image.Image) -> str:
"""Converts a PIL.Image.Image to a base64-encoded PNG string.
Parameters
----------
image : PIL.Image.Image
The image to convert.
Returns
-------
str
The base64 encoded image string.
"""
buffer = BytesIO()
image.save(buffer, "PNG")
x64 = buffer.getvalue()
return 'data:image/png;base64,' + base64.b64encode(x64).decode("utf-8")